API Reference

mlx_cluster.random_walk(rowptr: mlx.core.array, col: mlx.core.array, start: mlx.core.array, rand: mlx.core.array, walk_length: int, *, stream: object | None = None) -> tuple

Perform uniform random walks from the given start nodes.

Parameters:
  • rowptr (mlx.core.array) -- Row pointers of the graph in CSR format.

  • col (mlx.core.array) -- Column indices (edge targets) of the graph in CSR format.

  • start (mlx.core.array) -- Starting nodes from which the walks are sampled.

  • rand (mlx.core.array) -- Random values between 0 and 1.

  • walk_length (int) -- Length of each random walk.

Returns:

tuple (mlx.core.array, mlx.core.array)
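
To illustrate what the inputs mean, here is a pure-Python sketch of a uniform walk over a CSR graph. It is not the library's implementation. The helper names `edges_to_csr` and `uniform_walks` are hypothetical, and the layout of `rand` (one value per walk per step), the mapping from a random value to a neighbor, and the handling of nodes without outgoing edges are all assumptions.

```python
import random


def edges_to_csr(num_nodes, edges):
    """Build (rowptr, col) from a list of (src, dst) pairs (hypothetical helper)."""
    buckets = [[] for _ in range(num_nodes)]
    for src, dst in edges:
        buckets[src].append(dst)
    rowptr, col = [0], []
    for neighbors in buckets:
        col.extend(sorted(neighbors))
        rowptr.append(len(col))  # rowptr[v + 1] marks the end of v's neighbors
    return rowptr, col


def uniform_walks(rowptr, col, start, rand, walk_length):
    """Reference walk; rand[i][t] picks step t of walk i (assumed layout)."""
    walks = []
    for i, node in enumerate(start):
        walk = [node]
        for t in range(walk_length):
            begin, end = rowptr[node], rowptr[node + 1]
            degree = end - begin
            if degree > 0:
                # Assumed mapping: scale the random value to a neighbor offset.
                offset = min(int(rand[i][t] * degree), degree - 1)
                node = col[begin + offset]
            # Assumption: a node with no outgoing edges stays in place.
            walk.append(node)
        walks.append(walk)
    return walks


# Undirected 4-cycle 0-1-2-3-0, stored as directed edges in both directions.
edges = [(0, 1), (1, 0), (1, 2), (2, 1), (2, 3), (3, 2), (3, 0), (0, 3)]
rowptr, col = edges_to_csr(4, edges)
# rowptr == [0, 2, 4, 6, 8], col == [1, 3, 0, 2, 1, 3, 0, 2]

rng = random.Random(0)
start = [0, 2]
walk_length = 3
rand = [[rng.random() for _ in range(walk_length)] for _ in start]
walks = uniform_walks(rowptr, col, start, rand, walk_length)
```

With the library itself, `rowptr`, `col`, `start` and `rand` would be passed as `mlx.core.array` values in the order given by the signature.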

mlx_cluster.rejection_sampling(rowptr: mlx.core.array, col: mlx.core.array, start: mlx.core.array, walk_length: int, p: float, q: float, *, stream: object | None = None) -> tuple

Sample nodes from the graph by walking from the start nodes and sampling neighbors according to the parameters p and q.

Parameters:
  • rowptr (mlx.core.array) -- Row pointers of the graph in CSR format.

  • col (mlx.core.array) -- Column indices (edge targets) of the graph in CSR format.

  • start (mlx.core.array) -- Starting nodes from which biased sampling is performed.

  • walk_length (int) -- Length of each random walk.

  • p (float) -- Likelihood of immediately revisiting a node in the walk.

  • q (float) -- Control parameter that interpolates between breadth-first and depth-first strategies.

Returns:

tuple (mlx.core.array, mlx.core.array)
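
The descriptions of `p` and `q` read like the return and in-out parameters of a node2vec-style second-order walk. Assuming that interpretation, the following pure-Python sketch shows how rejection sampling can choose the next node. It is an illustration only, not the library's code: the function name `biased_walk`, the `rng` argument, and the weights `1/p`, `1` and `1/q` are assumptions taken from the usual node2vec formulation.

```python
import random


def biased_walk(rowptr, col, start_node, walk_length, p, q, rng):
    """node2vec-style walk using rejection sampling (hypothetical helper)."""

    def neighbors(node):
        return col[rowptr[node]:rowptr[node + 1]]

    # Largest possible unnormalized weight; used as the acceptance bound.
    upper = max(1.0 / p, 1.0, 1.0 / q)
    walk = [start_node]
    prev, cur = None, start_node
    for _ in range(walk_length):
        nbrs = neighbors(cur)
        if not nbrs:
            walk.append(cur)  # assumption: a dead end stays in place
            continue
        if prev is None:
            nxt = rng.choice(nbrs)  # first step has no history: uniform
        else:
            prev_nbrs = set(neighbors(prev))
            while True:
                cand = rng.choice(nbrs)  # uniform proposal
                if cand == prev:
                    weight = 1.0 / p     # step back to the previous node
                elif cand in prev_nbrs:
                    weight = 1.0         # stay close to the previous node
                else:
                    weight = 1.0 / q     # move further away
                if rng.random() * upper < weight:
                    nxt = cand           # accept the proposal
                    break
        walk.append(nxt)
        prev, cur = cur, nxt
    return walk


# Triangle 0-1-2 with a tail 2-3, undirected, in CSR form.
rowptr = [0, 2, 4, 7, 8]
col = [1, 2, 0, 2, 0, 1, 3, 2]
walk = biased_walk(rowptr, col, start_node=0, walk_length=5,
                   p=4.0, q=0.5, rng=random.Random(0))
```

Under this formulation, a large `p` makes stepping straight back less likely and a small `q` favors moving away from the previous node. Rejection sampling avoids building a normalized transition table for every (previous, current) pair.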

mlx_cluster.neighbor_sample(colptr: mlx.core.array, row: mlx.core.array, input_node: mlx.core.array, num_neighbors: collections.abc.Sequence[int], replace: bool = False, directed: bool = True) -> tuple

Simple neighbor sampling without primitives.

Parameters:
  • colptr (mlx.core.array) -- Column pointers (CSC format)

  • row (mlx.core.array) -- Row indices (CSC format)

  • input_node (mlx.core.array) -- Input nodes to sample from

  • num_neighbors (collections.abc.Sequence[int]) -- Number of neighbors to sample in each hop, one entry per hop.

  • replace (bool) -- Whether to sample with replacement (defaults to False).

  • directed (bool) -- Whether to treat the graph as directed (defaults to True).

Returns:

(samples, rows, cols, edges)

Return type:

tuple (mlx.core.array, mlx.core.array, mlx.core.array, mlx.core.array)
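
As a rough illustration of multi-hop sampling over a CSC graph, the pure-Python sketch below expands a frontier one hop at a time. It is not the library's implementation. The name `sample_neighbors` and the `rng` argument are hypothetical, and the meaning given to the four outputs (sampled node ids, local source index, local target index, global edge id) is an assumption, since the reference above only names them. The `directed` flag is omitted and the seed nodes are assumed to be unique.

```python
import random


def sample_neighbors(colptr, row, input_node, num_neighbors,
                     replace=False, rng=None):
    """Multi-hop neighbor sampling over CSC arrays (hypothetical helper)."""
    rng = rng or random.Random()
    samples = list(input_node)                 # global node ids, seeds first
    local = {n: i for i, n in enumerate(samples)}
    rows, cols, edges = [], [], []
    frontier_begin = 0
    for fanout in num_neighbors:
        frontier_end = len(samples)
        for j in range(frontier_begin, frontier_end):
            target = samples[j]
            # Edge positions of the column that belongs to `target`.
            edge_ids = range(colptr[target], colptr[target + 1])
            if len(edge_ids) == 0:
                continue
            if replace:
                picked = [rng.choice(edge_ids) for _ in range(fanout)]
            else:
                picked = rng.sample(edge_ids, min(fanout, len(edge_ids)))
            for e in picked:
                source = row[e]
                if source not in local:        # first time this node is seen
                    local[source] = len(samples)
                    samples.append(source)
                rows.append(local[source])     # assumed: local source index
                cols.append(j)                 # assumed: local target index
                edges.append(e)                # assumed: global edge id
        frontier_begin = frontier_end          # next hop expands new nodes only
    return samples, rows, cols, edges


# Triangle 0-1-2 with a tail 2-3, undirected, in CSC form.
colptr = [0, 2, 4, 7, 8]
row = [1, 2, 0, 2, 0, 1, 3, 2]
samples, rows, cols, edges = sample_neighbors(
    colptr, row, input_node=[3], num_neighbors=[2, 2], rng=random.Random(0))
```

Here node 3 has a single neighbor, so the first hop yields one edge even though two were requested. The second hop then draws two of the three edges that end at node 2.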