"""Implements core function nearest_neighbours used for AMD and PDD calculations."""

import collections
from typing import Iterable
from itertools import product

import numba
import numpy as np
from scipy.spatial import KDTree


def nearest_neighbours(
        motif: np.ndarray,
        cell: np.ndarray,
        x: np.ndarray,
        k: int):
    """
    Given a periodic set represented by (motif, cell) and an integer k, find
    the k nearest neighbours in the periodic set to points in x.

    Parameters
    ----------
    motif : numpy.ndarray
        Orthogonal (Cartesian) coords of the motif, shape (no points, dims).
    cell : numpy.ndarray
        Orthogonal (Cartesian) coords of the unit cell, shape (dims, dims).
    x : numpy.ndarray
        Array of points to query for neighbours, shape (no points, dims).
        For invariants of crystals this is the asymmetric unit. The first
        neighbour found for each point is discarded. This assumes every
        point of x lies in the periodic set (e.g. is a row of motif), so
        that its first neighbour is itself at distance 0.
    k : int
        Number of nearest neighbours to find for each point in x.

    Returns
    -------
    pdd : numpy.ndarray
        Array shape (x.shape[0], k) of distances from points in x
        to their k nearest neighbours in the periodic set, in order.
        E.g. pdd[m][n] is the distance from x[m] to its n-th nearest
        neighbour in the periodic set.
    cloud : numpy.ndarray
        Collection of points in the periodic set that was generated
        during the nearest neighbour search.
    inds : numpy.ndarray
        Array shape (x.shape[0], k) containing the indices of
        nearest neighbours in cloud. E.g. the n-th nearest neighbour to
        x[m] is cloud[inds[m][n]].
    """

    cloud_generator = generate_concentric_cloud(motif, cell)

    # Gather layers until the cloud holds more than k points, then take one
    # extra layer before the first query.
    layers = []
    n_points = 0
    while n_points <= k:
        layer = next(cloud_generator)
        n_points += layer.shape[0]
        layers.append(layer)
    layers.append(next(cloud_generator))
    cloud = np.concatenate(layers)

    # Query k+1 neighbours: column 0 is each query point matched to itself
    # and is removed before returning.
    tree = KDTree(cloud, compact_nodes=False, balanced_tree=False)
    pdd_, inds = tree.query(x, k=k+1, workers=-1)
    pdd = np.zeros_like(pdd_)

    # Keep adding layers until adding one more leaves the distances
    # unchanged (up to 1e-12).
    while not np.allclose(pdd, pdd_, atol=1e-12, rtol=0):
        pdd = pdd_
        cloud = np.vstack((cloud, next(cloud_generator)))
        tree = KDTree(cloud, compact_nodes=False, balanced_tree=False)
        pdd_, inds = tree.query(x, k=k+1, workers=-1)

    return pdd_[:, 1:], cloud, inds[:, 1:]


def nearest_neighbours_minval(motif, cell, min_val):
    """The same as nearest_neighbours, except a distance threshold is given
    instead of an integer k.

    Distances are computed from every point of motif. The result keeps the
    smallest number of columns such that every value in its last column is
    at least min_val.

    Parameters
    ----------
    motif : numpy.ndarray
        Orthogonal (Cartesian) coords of the motif, shape (no points, dims).
    cell : numpy.ndarray
        Orthogonal (Cartesian) coords of the unit cell, shape (dims, dims).
    min_val : float
        Threshold that every entry of the last returned column must reach.

    Returns
    -------
    numpy.ndarray
        Array shape (motif.shape[0], k) of sorted distances from each motif
        point to its nearest neighbours in the periodic set. The zero
        self-distance is excluded. k is the smallest count for which the
        k-th neighbour distance is >= min_val for every row. If
        min_val <= 0, k is 0 and the array has no columns.
    """

    cloud_generator = generate_concentric_cloud(motif, cell)

    # Seed the cloud with the first three layers.
    cloud = np.concatenate([next(cloud_generator) for _ in range(3)])
    tree = KDTree(cloud, compact_nodes=False, balanced_tree=False)
    # Ask for every point in the cloud, so the number of columns grows with it.
    pdd_, _ = tree.query(motif, k=cloud.shape[0], workers=-1)
    pdd = np.zeros_like(pdd_)

    while True:
        if np.all(pdd[:, -1] >= min_val):
            # First column in which every row has reached min_val.
            k = np.argwhere(np.all(pdd >= min_val, axis=0))[0][0]
            # Stop once the first k+1 columns survive adding another layer.
            # Use the same tolerance as nearest_neighbours rather than
            # relying on bit-exact float equality.
            if np.allclose(pdd[:, :k+1], pdd_[:, :k+1], atol=1e-12, rtol=0):
                break

        pdd = pdd_
        cloud = np.vstack((cloud, next(cloud_generator)))
        tree = KDTree(cloud, compact_nodes=False, balanced_tree=False)
        pdd_, _ = tree.query(motif, k=cloud.shape[0], workers=-1)

    # Column 0 is each motif point's zero distance to itself.
    return pdd[:, 1:k+1]


def generate_concentric_cloud(
        motif: np.ndarray,
        cell: np.ndarray
) -> Iterable[np.ndarray]:
    """
    Yield batches of points of the periodic set given by (motif, cell),
    moving successively further away from the origin.

    Integer lattice points are drawn shell by shell from
    generate_integer_lattice(cell.shape[0]). Each shell is mapped into
    Cartesian space through cell, and a copy of the motif is placed at every
    resulting translation. Each yield therefore contains the points (not
    already yielded) lying in unit cells whose corner lattice points belong
    to the current shell.

    Parameters
    ----------
    motif : ndarray
        Cartesian representation of the motif, shape (no points, dims).
    cell : ndarray
        Cartesian representation of the unit cell, shape (dims, dims).

    Yields
    -------
    ndarray
        Float64 array of shape (no points * no translations, dims). Rows
        are grouped by translation: the motif copy for the i-th translation
        of a shell occupies rows i*m to (i+1)*m - 1, where m = len(motif).
    """

    points = np.asarray(motif)
    dims = cell.shape[0]

    for int_points in generate_integer_lattice(dims):
        translations = int_points @ cell
        # (n_translations, 1, dims) + (1, m, dims) -> (n_translations, m, dims);
        # flattening the first two axes keeps translation-major row order.
        shifted = translations[:, np.newaxis, :] + points[np.newaxis, :, :]
        yield shifted.reshape(-1, dims).astype(np.float64)


def generate_integer_lattice(dims: int) -> Iterable[np.ndarray]:
    """Yield the integer lattice Z^dims in concentric shells.

    Shell r (r = 0, 1, 2, ...) holds every integer point p with
    |p|^2 <= r^2 that was not in an earlier shell. Concatenating shells
    0..r therefore gives exactly the integer points in the closed ball of
    radius r about the origin. Shell 0 is the origin alone.

    Parameters
    ----------
    dims : int
        The dimension of Euclidean space the lattice is in.

    Yields
    ------
    ndarray
        Integer array of shape (no new points, dims).
    """

    if dims == 1:
        yield np.array([[0]])
        radius = 1
        while True:
            yield np.array([[-radius], [radius]])
            radius += 1

    # For each tuple of the first dims-1 (non-negative) coordinates, the
    # smallest non-negative last coordinate that has not been emitted yet.
    next_last = collections.defaultdict(int)
    radius = 0

    while True:
        bound = radius ** 2
        orthant_points = []
        # Sweep all leading-coordinate tuples repeatedly. Each sweep raises
        # the last coordinate by at most one per tuple. Stop after a sweep
        # that adds nothing.
        added = True
        while added:
            added = False
            for head in product(range(radius + 1), repeat=dims - 1):
                last = next_last[head]
                if sum(c ** 2 for c in head) + last ** 2 <= bound:
                    orthant_points.append((*head, last))
                    next_last[head] = last + 1
                    added = True

        # Expand the non-negative orthant points into every sign combination.
        reflected = _reflect_positive_lattice(np.array(orthant_points))
        yield np.array(np.concatenate(reflected))
        radius += 1


@numba.njit()
def _reflect_positive_lattice(positive_int_lattice):
    """Reflect a set of points in the +ve quadrant in all axes.

    For every non-empty subset S of the coordinate axes, select the points
    whose coordinates along every axis in S are non-zero, and negate those
    coordinates. Points lying on an axis hyperplane are skipped for subsets
    containing that axis, so no point is produced twice. If all input
    coordinates are >= 0 (as for the points built in
    generate_integer_lattice), the input plus all batches contain every
    sign variant of every input point exactly once.

    Parameters
    ----------
    positive_int_lattice : ndarray
        Integer points, shape (no points, dims), with non-negative coords.

    Returns
    -------
    list of ndarray
        The first entry is the input array itself. Each further entry is the
        reflection for one subset of axes. Subsets are visited in order of
        size, and within a size in lexicographic order of their axis indices.
    """
    dims = positive_int_lattice.shape[-1]
    batches = [positive_int_lattice]

    # n_reflections = number of axes flipped at once (size of the subset).
    for n_reflections in range(1, dims + 1):

        # First subset of this size: axes 0, 1, ..., n_reflections - 1.
        indices = np.arange(n_reflections)
        # Keep only points that are non-zero on every axis in `indices`.
        # Boolean-mask indexing returns a copy, so negating batch leaves the
        # input array untouched.
        batch = positive_int_lattice[(positive_int_lattice[:, indices] == 0).sum(axis=-1) == 0]
        batch[:, indices] *= -1
        batches.append(batch)

        # Step `indices` through the remaining subsets of this size in
        # lexicographic order.
        while True:
            # Find the rightmost position that can still be incremented.
            i = n_reflections - 1
            for _ in range(n_reflections):
                if indices[i] != i + dims - n_reflections:
                    break
                i -= 1
            else:
                # Every position is at its maximum: last subset reached.
                break
            # Advance that position and reset everything to its right.
            indices[i] += 1
            for j in range(i+1, n_reflections):
                indices[j] = indices[j-1] + 1
            
            batch = positive_int_lattice[(positive_int_lattice[:, indices] == 0).sum(axis=-1) == 0]
            batch[:, indices] *= -1
            batches.append(batch)

    return batches


@numba.njit()
def _dist(xy, z):
    """Return the squared Euclidean norm of the point (*xy, z)."""
    total = z ** 2
    for coord in xy:
        total += coord ** 2
    return total


# # @numba.njit()
# def cartesian_product(n, repeat):
#     arrays = [np.arange(n)] * repeat
#     arr = np.empty(tuple([n] * repeat + [repeat]), dtype=np.int64)
#     for i, a in enumerate(np.ix_(*arrays)):
#         arr[..., i] = a
#     return arr.reshape(-1, repeat)
