
The K-Median Problem

The k-median problem is the problem of clustering data points into k clusters so as to minimize the sum of the distances between the points in each cluster and the data point that serves as that cluster's center. It can be regarded as a variant of k-means clustering: k-means represents each cluster by its mean value, whereas k-median uses the median value. The problem is known to be NP-hard. Here we describe how to implement a mathematical model of the k-median problem with JijModeling and solve it with JijZept.
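To make the mean-versus-median contrast concrete, here is a small NumPy sketch (not part of the original tutorial) for a single cluster in one dimension. The center that minimizes the sum of absolute distances is the median, while the center that minimizes the sum of squared distances (the k-means cost) is the mean. The toy data and the integer candidate grid are illustrative assumptions.

```python
import numpy as np

# Toy 1-D data with one outlier
points = np.array([1.0, 2.0, 3.0, 4.0, 100.0])

# Candidate centers: every integer from 0 to 100
candidates = np.arange(0, 101, dtype=float)

# Pairwise differences, shape (num_candidates, num_points)
diff = points[None, :] - candidates[:, None]

# Sum of absolute distances (k-median style cost with k = 1)
abs_cost = np.abs(diff).sum(axis=1)
# Sum of squared distances (k-means style cost with k = 1)
sq_cost = (diff ** 2).sum(axis=1)

best_abs = candidates[np.argmin(abs_cost)]
best_sq = candidates[np.argmin(sq_cost)]
print(best_abs, np.median(points))  # 3.0 3.0
print(best_sq, np.mean(points))     # 22.0 22.0
```

The outlier at 100 pulls the mean far away from the bulk of the data, while the median stays put, which is one reason median-based centers are considered more robust.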

Mathematical Model

Let us consider a mathematical model for the k-median problem.

Decision variables

We denote by $x_{i,j}$ a binary variable that is 1 if the $i$-th data point belongs to the $j$-th median data point and 0 otherwise. We also use a binary variable $y_j$ that is 1 if the $j$-th data point is a median and 0 otherwise.

$$
x_{i,j} = \begin{cases} 1, & \text{node $i$ is covered by median $j$} \\ 0, & \text{otherwise} \end{cases}
$$

$$
y_j = \begin{cases} 1, & \text{node $j$ is a median} \\ 0, & \text{otherwise} \end{cases}
$$
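As an illustration of how these two variables encode a clustering, the following NumPy sketch builds $x$ and $y$ from a cluster-label array. The function name encode_assignment and the label format (assign[i] is the index of the median covering point i) are hypothetical, and the sketch assumes every data point is a candidate median, so $x$ is square.

```python
import numpy as np

def encode_assignment(assign):
    """Hypothetical helper: build the binary arrays x (N x N) and y (N,)
    from a label array where assign[i] is the median covering point i.
    A median is expected to cover itself, i.e. assign[j] == j."""
    assign = np.asarray(assign)
    n = assign.shape[0]
    x = np.zeros((n, n), dtype=int)
    # x[i, j] = 1 iff data point i is covered by median j
    x[np.arange(n), assign] = 1
    y = np.zeros(n, dtype=int)
    # y[j] = 1 iff data point j is used as a median
    y[np.unique(assign)] = 1
    return x, y

# Five points; points 0 and 3 are the medians
x, y = encode_assignment([0, 0, 3, 3, 0])
print(y)              # [1 0 0 1 0]
print(x.sum(axis=1))  # [1 1 1 1 1]
```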

Mathematical Model

Our goal is to find a solution that minimizes the sum of the distances between each data point $i$ and the median point $j$ it belongs to. We also impose three constraints:

  1. each data point must belong to exactly one median point,
  2. the number of median points is $k$,
  3. a data point can only belong to a data point that is selected as a median.

These can be expressed in a mathematical model as follows.

$$
\begin{aligned}
\min_{x} \quad & \sum_{i}\sum_{j} d_{i,j}x_{i,j} \\
\mathrm{s.t.} \quad & \sum_{j} x_{i,j} = 1, \quad \forall i \\
& \sum_{j} y_j = k \\
& x_{i,j} \leq y_j, \quad \forall i, j \\
& x_{i,j} \in \{0, 1\}, \quad \forall i, j \\
& y_j \in \{0, 1\}, \quad \forall j
\end{aligned}
\tag{1}
$$
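For small instances, model (1) can be solved exactly by enumeration, which is handy as a reference when checking a sampler's output. The key step is that once the medians (the $y_j$) are fixed, the constraints $\sum_j x_{i,j} = 1$ and $x_{i,j} \leq y_j$ let each point pick exactly one selected median, so the optimal $x$ assigns every point to its nearest one. The sketch below is not from the original tutorial and assumes every data point is a candidate median, so $d$ is square.

```python
from itertools import combinations

import numpy as np

def brute_force_k_median(d, k):
    """Exhaustively solve model (1) for a small square distance matrix d.

    For each k-subset of medians, the optimal x sends every point to
    its closest chosen median, so the cost is a row-wise minimum."""
    n = d.shape[0]
    best_cost, best_medians = np.inf, None
    for medians in combinations(range(n), k):
        cols = list(medians)
        # each point pays the distance to its nearest chosen median
        cost = d[:, cols].min(axis=1).sum()
        if cost < best_cost:
            best_cost, best_medians = cost, cols
    return best_cost, best_medians

# Two well-separated groups on a line
pts = np.array([0.0, 1.0, 2.0, 10.0, 11.0, 12.0])
d = np.abs(pts[:, None] - pts[None, :])
print(brute_force_k_median(d, 2))
```

The number of subsets grows as $\binom{N}{k}$; for the 30-point instance with $k = 4$ used in this tutorial that is 27,405 subsets, still small enough for a cross-check, but the approach quickly becomes impractical as $N$ and $k$ grow.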

Modeling by JijModeling

Here, we show an implementation using JijModeling. We first define variables for the mathematical model described above.

import jijmodeling as jm

d = jm.Placeholder("d", ndim=2)
N = d.len_at(0, latex="N")
J = jm.Placeholder("J", ndim=1)
k = jm.Placeholder("k")
i = jm.Element("i", belong_to=(0, N))
j = jm.Element("j", belong_to=J)
x = jm.BinaryVar("x", shape=(N, J.shape[0]))
y = jm.BinaryVar("y", shape=(J.shape[0],))

d is a two-dimensional array representing the distance between each data point and each candidate median point. The number of data points N is obtained from the length of the first axis of d. J is a one-dimensional array holding the indices of the candidate median points. k is the number of median points. i and j are the indices used in the binary variables: i runs over the data points and j over the candidate medians. Finally, we define the binary variables x and y.

Then, we implement (1).

problem = jm.Problem("k-median")
problem += jm.sum([i, j], d[i, j]*x[i, j])
problem += jm.Constraint("onehot", x[i, :].sum() == 1, forall=i)
problem += jm.Constraint("k-median", y[:].sum() == k)
problem += jm.Constraint("cover", x[i, j] <= y[j], forall=[i, j])

With jm.Constraint("onehot", x[i, :].sum() == 1, forall=i), we impose the constraint $\sum_j x_{i,j} = 1$ for all $i$. jm.Constraint("k-median", y[:].sum() == k) represents $\sum_j y_j = k$. jm.Constraint("cover", x[i, j] <= y[j], forall=[i, j]) requires that $x_{i,j} \leq y_j$ holds for all $i, j$.
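The same three constraints can be checked directly in NumPy for a candidate solution, which may help when inspecting results by hand. This is a sketch, not part of JijModeling; the function name check_constraints is hypothetical, and it simply reuses the constraint names from the code above as dictionary keys.

```python
import numpy as np

def check_constraints(x, y, k):
    """Hypothetical helper: report whether binary arrays x (N x M) and
    y (M,) satisfy each constraint of model (1)."""
    x, y = np.asarray(x), np.asarray(y)
    return {
        # sum_j x[i, j] == 1 for every i
        "onehot": bool(np.all(x.sum(axis=1) == 1)),
        # sum_j y[j] == k
        "k-median": bool(y.sum() == k),
        # x[i, j] <= y[j] for every i, j (y is broadcast over the rows)
        "cover": bool(np.all(x <= y[None, :])),
    }

# Medians are points 0 and 3; points 0, 1, 4 go to 0 and points 2, 3 go to 3
y = np.array([1, 0, 0, 1, 0])
x = np.zeros((5, 5), dtype=int)
x[[0, 1, 4], 0] = 1
x[[2, 3], 3] = 1
print(check_constraints(x, y, k=2))
# {'onehot': True, 'k-median': True, 'cover': True}
```

Moving point 1 to point 2, which is not a median, keeps "onehot" satisfied but violates "cover".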

We can check the implemented mathematical model by displaying it in a Jupyter Notebook.


$$
\begin{array}{cccc}
\text{Problem:} & \text{k-median} & & \\
& & \min \quad \displaystyle \sum_{i = 0}^{N - 1} \sum_{j \in J} d_{i, j} \cdot x_{i, j} & \\
\text{s.t.} & & & \\
& \text{cover} & \displaystyle x_{i, j} \leq y_{j} & \forall i \in \left\{0,\ldots,N - 1\right\} \forall j \in J \\
& \text{k-median} & \displaystyle \sum_{\ast_{0} = 0}^{\mathrm{len}\left(J, 0\right) - 1} y_{\ast_{0}} = k & \\
& \text{onehot} & \displaystyle \sum_{\ast_{1} = 0}^{\mathrm{len}\left(J, 0\right) - 1} x_{i, \ast_{1}} = 1 & \forall i \in \left\{0,\ldots,N - 1\right\} \\
\text{where} & & & \\
& y & 1\text{-dim binary variable} & \\
& x & 2\text{-dim binary variable} & \\
\end{array}
$$

Prepare instance

We generate random data points in the unit square and visualize them.

import matplotlib.pyplot as plt
import numpy as np

num_nodes = 30
X, Y = np.random.uniform(0, 1, (2, num_nodes))

plt.plot(X, Y, "o")


Next, we calculate the distance between every pair of data points.

XX, XX_T = np.meshgrid(X, X)
YY, YY_T = np.meshgrid(Y, Y)
distance = np.sqrt((XX - XX_T)**2 + (YY - YY_T)**2)

# instance data for the placeholders d, J, and k
ph_value = {
    "d": distance,
    "J": np.arange(0, num_nodes),
    "k": 4,
}
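The meshgrid construction above computes the full $N \times N$ matrix of Euclidean distances. An equivalent formulation, sketched here as an alternative rather than the tutorial's method, stacks the coordinates into an $(N, 2)$ array and relies on NumPy broadcasting. The fixed random seed is an assumption added only so the check is reproducible.

```python
import numpy as np

rng = np.random.default_rng(0)
num_nodes = 30
X, Y = rng.uniform(0, 1, (2, num_nodes))

# Tutorial approach: meshgrid
XX, XX_T = np.meshgrid(X, X)
YY, YY_T = np.meshgrid(Y, Y)
distance_meshgrid = np.sqrt((XX - XX_T)**2 + (YY - YY_T)**2)

# Broadcasting approach: (N, 1, 2) - (1, N, 2) -> (N, N, 2)
points = np.stack([X, Y], axis=1)
diff = points[:, None, :] - points[None, :, :]
distance_broadcast = np.linalg.norm(diff, axis=-1)

print(np.allclose(distance_meshgrid, distance_broadcast))  # True
```

Both versions produce a symmetric matrix with zeros on the diagonal; the broadcasting form extends unchanged to points with more than two coordinates.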

Solve with JijZept

We solve the problem implemented so far with JijZept's JijSASampler. We also use the parameter search function by setting search=True.

import jijzept as jz

sampler = jz.JijSASampler(config="config.toml")
results = sampler.sample_model(problem, ph_value, search=True)

From the results, we extract the feasible solutions and show the smallest value of the objective function among them.

# extract feasible solutions
feasibles = results.feasible()
# get objective values from feasibles
objectives = np.array(feasibles.evaluation.objective)
# get lowest index from objective values
lowest_index = np.argmin(objectives)
print(lowest_index, objectives[lowest_index])
0 7.455011036647106

From the binary variable $y_j$, which indicates whether each data point is used as a median, we obtain the indices of the data points actually selected as medians.

# check solution
nonzero_indices, nonzero_values, shape = feasibles.record.solution["y"][lowest_index]
print("indices: ", nonzero_indices)
print("values: ", nonzero_values)
indices:  ([0, 8, 15, 23],)
values: [1.0, 1.0, 1.0, 1.0]
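The unpacking above suggests that the solution for y is stored in a sparse form: one index list per axis, the matching nonzero values, and the array shape. Assuming that layout, a dense array can be rebuilt with plain NumPy indexing. The helper to_dense is hypothetical, and the shape (30,) is an assumption based on num_nodes = 30 rather than a value printed by the tutorial.

```python
import numpy as np

def to_dense(nonzero_indices, nonzero_values, shape):
    """Hypothetical helper: rebuild a dense array from an
    (indices, values, shape) triple, assuming one index list per axis."""
    dense = np.zeros(shape)
    dense[tuple(np.asarray(ix) for ix in nonzero_indices)] = nonzero_values
    return dense

# Values copied from the printed output above; shape assumes num_nodes = 30
y_dense = to_dense(([0, 8, 15, 23],), [1.0, 1.0, 1.0, 1.0], (30,))
print(np.flatnonzero(y_dense).tolist())  # [0, 8, 15, 23]
print(y_dense.sum())                     # 4.0
```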

This information allows us to visualize how data points are clustered.

# indices of the data points selected as medians
median_indices = np.array(nonzero_indices[0])

median_X, median_Y = X[median_indices], Y[median_indices]

# distance from every data point to each median (one row per median)
d_from_m = []
for m in median_indices:
    d_from_m.append(np.sqrt((X - X[m])**2 + (Y - Y[m])**2))

# assign each data point to its nearest median
cover_median = median_indices[np.argmin(d_from_m, axis=0)]

plt.plot(X, Y, "o")
plt.plot(median_X, median_Y, "o", markersize=10)
for index in range(len(X)):
    j_value = cover_median[index]
    i_value = index
    # gray segment from data point i to its median j
    plt.plot(X[[i_value, j_value]], Y[[i_value, j_value]], c="gray")


Orange points are the medians and blue points are the other data points. Each gray line connects a median to a data point belonging to its cluster. The figure shows how the data points are grouped into clusters.
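To relate the plotted clustering back to the objective of model (1), one can sum each point's distance to its nearest selected median. This sketch is not part of the original tutorial; the function name clustering_cost is hypothetical, and it assumes the solution assigns every point to its closest median, which is what an optimal $x$ does once the medians are fixed. Applied to X, Y, and median_indices from above, it gives a value that can be compared with the objective reported by the sampler.

```python
import numpy as np

def clustering_cost(X, Y, median_indices):
    """Hypothetical helper: sum of Euclidean distances from each point
    to its nearest selected median."""
    points = np.stack([X, Y], axis=1)
    medians = points[np.asarray(median_indices)]
    # distances from every point (rows) to every median (columns)
    dist = np.linalg.norm(points[:, None, :] - medians[None, :, :], axis=-1)
    return dist.min(axis=1).sum()

# Tiny example: two pairs of points, one median in each pair
X = np.array([0.0, 0.0, 5.0, 5.0])
Y = np.array([0.0, 1.0, 0.0, 1.0])
print(clustering_cost(X, Y, [0, 2]))  # 2.0
```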