"""
risk/network/plot/utils/color
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
"""

from typing import Any, Dict, List, Tuple, Union

import matplotlib
import matplotlib.colors as mcolors
import numpy as np

from risk.network.graph import NetworkGraph
from risk.network.plot.utils.layout import calculate_centroids


def get_annotated_domain_colors(
    graph: NetworkGraph,
    cmap: str = "gist_rainbow",
    color: Union[str, List, Tuple, np.ndarray, None] = None,
    blend_colors: bool = False,
    blend_gamma: float = 2.2,
    min_scale: float = 0.8,
    max_scale: float = 1.0,
    scale_factor: float = 1.0,
    random_seed: int = 888,
) -> np.ndarray:
    """Pick one representative RGBA color per domain for annotation purposes.

    Per-node colors are produced by `get_domain_colors`. For each domain (in the iteration order of
    `graph.domain_id_to_node_ids_map`), a domain with more than one node is represented by the color of
    its member node with the largest R+G+B sum. A domain with zero or one node is represented by opaque
    white.

    Args:
        graph (NetworkGraph): The network data and attributes to be visualized.
        cmap (str, optional): Colormap to use for generating domain colors. Defaults to "gist_rainbow".
        color (str, List, Tuple, np.ndarray, or None, optional): Color to use for the domains. Can be a single color or an array of colors.
            If None, the colormap will be used. Defaults to None.
        blend_colors (bool, optional): Whether to blend colors for nodes with multiple domains. Defaults to False.
        blend_gamma (float, optional): Gamma correction factor for perceptual color blending. Defaults to 2.2.
        min_scale (float, optional): Minimum scale for color intensity when generating domain colors. Defaults to 0.8.
        max_scale (float, optional): Maximum scale for color intensity when generating domain colors. Defaults to 1.0.
        scale_factor (float, optional): Factor for adjusting the contrast in the colors generated based on significance. Higher values
            increase the contrast. Defaults to 1.0.
        random_seed (int, optional): Seed for random number generation to ensure reproducibility. Defaults to 888.

    Returns:
        np.ndarray: Array of RGBA colors, one row per domain.
    """
    per_node_colors = get_domain_colors(
        graph=graph,
        cmap=cmap,
        color=color,
        blend_colors=blend_colors,
        blend_gamma=blend_gamma,
        min_scale=min_scale,
        max_scale=max_scale,
        scale_factor=scale_factor,
        random_seed=random_seed,
    )
    # Fallback for domains too small to have a meaningful "brightest" member
    opaque_white = np.array([1.0, 1.0, 1.0, 1.0])

    picked = []
    for member_nodes in graph.domain_id_to_node_ids_map.values():
        if len(member_nodes) <= 1:
            picked.append(opaque_white)
            continue
        member_colors = np.array([per_node_colors[node_id] for node_id in member_nodes])
        # Brightness proxy: plain sum of the RGB channels (alpha ignored)
        rgb_totals = member_colors[:, :3].sum(axis=1)
        picked.append(member_colors[np.argmax(rgb_totals)])

    return np.array(picked)


def get_domain_colors(
    graph: NetworkGraph,
    cmap: str = "gist_rainbow",
    color: Union[str, List, Tuple, np.ndarray, None] = None,
    blend_colors: bool = False,
    blend_gamma: float = 2.2,
    min_scale: float = 0.8,
    max_scale: float = 1.0,
    scale_factor: float = 1.0,
    random_seed: int = 888,
) -> np.ndarray:
    """Compute the final per-node RGBA colors used to draw domains.

    The pipeline has three stages:
        1. Assign a base color to every domain (`_get_domain_colors`).
        2. Spread domain colors onto nodes, optionally blending for multi-domain nodes
           (`_get_composite_node_colors`).
        3. Rescale node color intensity by `graph.node_significance_sums` (`_transform_colors`).

    Args:
        graph (NetworkGraph): The network data and attributes to be visualized.
        cmap (str, optional): Name of the colormap to use for generating domain colors. Defaults to "gist_rainbow".
        color (str, List, Tuple, np.ndarray, or None, optional): A specific color or array of colors to use for all domains.
            If None, the colormap will be used. Defaults to None.
        blend_colors (bool, optional): Whether to blend colors for nodes with multiple domains. Defaults to False.
        blend_gamma (float, optional): Gamma correction factor for perceptual color blending. Defaults to 2.2.
        min_scale (float, optional): Minimum intensity scale for the colors generated by the colormap. Controls the dimmest colors.
            Defaults to 0.8.
        max_scale (float, optional): Maximum intensity scale for the colors generated by the colormap. Controls the brightest colors.
            Defaults to 1.0.
        scale_factor (float, optional): Exponent for adjusting the color scaling based on significance scores. Higher values increase
            contrast by dimming lower scores more. Defaults to 1.0.
        random_seed (int, optional): Seed for random number generation to ensure reproducibility of color assignments. Defaults to 888.

    Returns:
        np.ndarray: Array of RGBA colors with one row per node (indexed by node ID), intensity-scaled by significance.
    """
    per_domain = _get_domain_colors(graph=graph, cmap=cmap, color=color, random_seed=random_seed)
    per_node = _get_composite_node_colors(
        graph=graph,
        domain_colors=per_domain,
        blend_colors=blend_colors,
        blend_gamma=blend_gamma,
    )
    return _transform_colors(
        colors=per_node,
        significance_sums=graph.node_significance_sums,
        min_scale=min_scale,
        max_scale=max_scale,
        scale_factor=scale_factor,
    )


def _get_domain_colors(
    graph: NetworkGraph,
    cmap: str = "gist_rainbow",
    color: Union[str, List, Tuple, np.ndarray, None] = None,
    random_seed: int = 888,
) -> Dict[int, Any]:
    """Map each domain ID in the graph to an RGBA color.

    Args:
        graph (NetworkGraph): The network data and attributes to be visualized.
        cmap (str, optional): The name of the colormap to use. Defaults to "gist_rainbow".
        color (str, List, Tuple, np.ndarray, or None, optional): A specific color or array of colors to use for the domains.
            If None, the colormap will be used. Defaults to None.
        random_seed (int, optional): Seed for random number generation. Defaults to 888.

    Returns:
        Dict[int, Any]: A dictionary mapping domain keys to their corresponding RGBA colors.
    """
    network = graph.network
    domain_map = graph.domain_id_to_node_ids_map
    palette = _get_colors(
        network,
        domain_map,
        cmap=cmap,
        color=color,
        random_seed=random_seed,
    )
    # Colors are paired positionally with domain IDs in the map's iteration order
    return {domain_id: rgba for domain_id, rgba in zip(domain_map.keys(), palette)}


def _get_composite_node_colors(
    graph, domain_colors: Dict[int, Any], blend_colors: bool = False, blend_gamma: float = 2.2
) -> np.ndarray:
    """Generate composite colors for nodes based on domain colors and significance values, with optional color blending.

    Args:
        graph (NetworkGraph): The network data and attributes to be visualized. Reads `node_coordinates`
            (to size the output), `domain_id_to_node_ids_map` and, when blending,
            `node_id_to_domain_ids_and_significance_map` (entries with "domains" and "significances" lists).
        domain_colors (Dict[int, Any]): Mapping from domain ID to RGBA color. It must be keyed by domain ID:
            membership tests and lookups are done by key. A positional array would make `in` compare values.
        blend_colors (bool): Whether to blend colors for nodes with multiple domains. Defaults to False.
        blend_gamma (float, optional): Gamma correction factor to be used for perceptual color blending.
            This parameter is only relevant if blend_colors is True. None falls back to 2.2. Defaults to 2.2.

    Returns:
        np.ndarray: Array of shape (num_nodes, 4) with one RGBA row per node. Nodes not covered by any
            domain that has a color stay [0, 0, 0, 0].
    """
    num_nodes = len(graph.node_coordinates)
    composite_colors = np.zeros((num_nodes, 4))

    if not blend_colors:
        # No blending: a node in several domains ends up with the color of the last domain visited.
        for domain_id, nodes in graph.domain_id_to_node_ids_map.items():
            # Skip uncolored domains, consistent with the blending branch below
            if domain_id not in domain_colors:
                continue
            color = domain_colors[domain_id]
            for node in nodes:
                composite_colors[node] = color
        return composite_colors

    # Resolve the gamma once rather than per node; None means the conventional 2.2
    gamma = blend_gamma if blend_gamma is not None else 2.2
    for node, node_info in graph.node_id_to_domain_ids_and_significance_map.items():
        # Keep only the (domain, significance) pairs whose domain actually has a color
        colored_pairs = [
            (domain_id, significance)
            for domain_id, significance in zip(node_info["domains"], node_info["significances"])
            if domain_id in domain_colors
        ]
        if not colored_pairs:
            continue

        filtered_domains, filtered_significances = zip(*colored_pairs)
        colors = [domain_colors[domain_id] for domain_id in filtered_domains]
        composite_colors[node] = _blend_colors_perceptually(colors, filtered_significances, gamma)

    return composite_colors


def _get_colors(
    network,
    domain_id_to_node_ids_map,
    cmap: str = "gist_rainbow",
    color: Union[str, List, Tuple, np.ndarray, None] = None,
    random_seed: int = 888,
) -> Union[List[Tuple], np.ndarray]:
    """Generate one RGBA color per domain.

    A specified color is used when one is given. Otherwise colors are drawn from a colormap at positions
    chosen from the domain centroids, which tries to give spatially close domains well-separated colors.
    A small random global shift adds some variety.

    Args:
        network (NetworkX graph): The graph representing the network.
        domain_id_to_node_ids_map (Dict[int, Any]): Mapping from domain IDs to lists of node IDs.
        cmap (str, optional): The name of the colormap to use. Defaults to "gist_rainbow".
        color (str, List, Tuple, np.ndarray, or None, optional): A specific color or array of colors to use for the domains.
            If None (or empty), the colormap will be used. Defaults to None.
        random_seed (int, optional): Seed for random number generation. Defaults to 888.

    Returns:
        List[Tuple] or np.ndarray: RGBA colors: a list of colormap tuples, or the array returned by
            `to_rgba` when `color` is specified.
    """
    # NOTE: this seeds NumPy's *global* RNG (kept so existing palettes stay reproducible);
    # it also affects any later np.random calls made by the caller.
    np.random.seed(random_seed)
    num_colors_to_generate = len(domain_id_to_node_ids_map)

    # A plain `if color:` raises "truth value of an array ... is ambiguous" for multi-element
    # arrays, so arrays are tested by size. Other types keep plain truthiness (None/""/[] -> colormap).
    if isinstance(color, np.ndarray):
        use_specified_color = color.size > 0
    else:
        use_specified_color = bool(color)
    if use_specified_color:
        return to_rgba(color, num_repeats=num_colors_to_generate)

    colormap = matplotlib.colormaps.get_cmap(cmap)
    # Pairwise Euclidean distances between domain centroids
    centroid_array = np.array(calculate_centroids(network, domain_id_to_node_ids_map))
    dist_matrix = np.linalg.norm(centroid_array[:, None] - centroid_array, axis=-1)
    color_positions = _assign_distant_colors(dist_matrix, num_colors_to_generate)
    # Shift the whole palette slightly (relative spacing preserved), wrapping into [0, 1)
    global_shift = np.random.uniform(-0.1, 0.1)
    color_positions = (color_positions + global_shift) % 1
    color_positions = np.clip(color_positions, 0, 1)

    return [colormap(pos) for pos in color_positions]


def _assign_distant_colors(dist_matrix, num_colors_to_generate):
    """Assign a distinct position on the color spectrum to every centroid.

    Centroids are ranked by their total distance to all others, ascending, so the most central come
    first. The i-th centroid from the front and the i-th from the back are paired and given the slots
    2i/n and (2i+1)/n. When n is odd, the middle centroid gets the one remaining slot, (n-1)/n.
    Every centroid therefore receives a unique position k/n.

    Args:
        dist_matrix (ndarray): Matrix of pairwise centroid distances; row `idx` holds distances from centroid `idx`.
        num_colors_to_generate (int): Number of colors to generate (number of centroids).

    Returns:
        np.array: Array of `num_colors_to_generate` distinct color positions in the range [0, 1).
    """
    num = num_colors_to_generate
    color_positions = np.zeros(num)
    # Rank centroids by total distance to the others (stable sort keeps ties in index order)
    proximity_order = sorted(range(num), key=lambda idx: np.sum(dist_matrix[idx]))

    half = num // 2
    for i in range(half):
        color_positions[proximity_order[i]] = (2 * i) / num
        color_positions[proximity_order[-(i + 1)]] = (2 * i + 1) / num

    if num % 2 == 1:
        # The pairing loop uses slots 0 .. (num-2)/num. Give the unpaired middle centroid the last
        # free slot; the old code left it at half/num, which collided with an already-used slot.
        color_positions[proximity_order[half]] = (num - 1) / num

    return color_positions


def _blend_colors_perceptually(
    colors: Union[List, Tuple, np.ndarray], significances: List[float], gamma: float = 2.2
) -> Tuple[float, float, float, float]:
    """Blend RGBA colors as a significance-weighted mean taken in gamma-raised space.

    Each RGB channel is raised to `gamma`, averaged with weights proportional to `significances`, and
    raised back to `1 / gamma`. Alpha is averaged with the same weights, without gamma correction.
    If the significances sum to zero, all colors are weighted equally instead of dividing by zero.

    Args:
        colors (List, Tuple, np.ndarray): Sequence of RGBA colors (4 components each).
        significances (List[float]): Corresponding list of significance values, one per color.
        gamma (float, optional): Gamma correction factor, default is 2.2 (typical for perceptual blending).

    Returns:
        Tuple[float, float, float, float]: The blended RGBA color.

    Raises:
        ValueError: If `colors` is empty.
    """
    if len(colors) == 0:
        raise ValueError("Cannot blend an empty sequence of colors.")

    rgba = np.asarray(colors, dtype=float)
    weights = np.asarray(significances, dtype=float)
    total_significance = weights.sum()
    if total_significance == 0:
        # Degenerate weights: fall back to a plain average rather than dividing by zero
        weights = np.ones(len(weights)) / len(weights)
    else:
        weights = weights / total_significance

    # Average in gamma-raised space, then undo the gamma
    blended_rgb = np.dot(weights, rgba[:, :3] ** gamma) ** (1.0 / gamma)
    alpha = np.dot(weights, rgba[:, 3])
    return tuple(blended_rgb.tolist()) + (float(alpha),)


def _transform_colors(
    colors: np.ndarray,
    significance_sums: np.ndarray,
    min_scale: float = 0.8,
    max_scale: float = 1.0,
    scale_factor: float = 1.0,
) -> np.ndarray:
    """Scale RGB intensity by significance using a power curve, returning a new array.

    Pure-black RGB rows are first replaced with very dark grey, rgb(0.1, 0.1, 0.1), so scaling still has
    an effect on them. Significance sums are normalized by their maximum, raised to `scale_factor`, and
    mapped linearly onto [min_scale, max_scale]. Each row's RGB channels are multiplied by its scale;
    alpha is left unchanged. The input array is not modified.

    Args:
        colors (np.ndarray): An array of RGBA colors, one row per entry.
        significance_sums (np.ndarray): 1-D array of significance sums, one per row of `colors`.
        min_scale (float, optional): Minimum scale for color intensity. Defaults to 0.8.
        max_scale (float, optional): Maximum scale for color intensity. Defaults to 1.0.
        scale_factor (float, optional): Exponent for scaling, where values > 1 increase contrast by dimming small
            values more. Defaults to 1.0.

    Returns:
        np.ndarray: A new float array of RGBA colors with adjusted intensities.
    """
    # Float copy: don't mutate the caller's array, and don't truncate 0.1 to 0 on integer input
    transformed = np.array(colors, dtype=float)

    # Replace exactly-black RGB (#000000) with very dark grey (#1A1A1A)
    is_black = np.all(transformed[:, :3] == 0.0, axis=1)
    transformed[is_black, :3] = np.array([0.1, 0.1, 0.1])

    sums = np.asarray(significance_sums, dtype=float)
    max_sum = np.max(sums) if sums.size else 0.0
    if max_sum != 0:
        normalized_sums = sums / max_sum
    else:
        # All-zero (or empty) sums: dividing would produce NaN colors; treat everything as least significant
        normalized_sums = np.zeros_like(sums)

    # Power curve first (emphasizes high values), then a linear map onto [min_scale, max_scale].
    # min_scale == max_scale is fine here: it simply yields a constant scale.
    scaled_sums = min_scale + (max_scale - min_scale) * normalized_sums**scale_factor
    transformed[:, :3] *= scaled_sums[:, np.newaxis]

    return transformed


def to_rgba(
    color: Union[str, List, Tuple, np.ndarray],
    alpha: Union[float, None] = None,
    num_repeats: Union[int, None] = None,
) -> np.ndarray:
    """Normalize one color or a collection of colors to a 2-D array of RGBA rows.

    Input handling:
        - A string (name or hex), or a flat RGB/RGBA sequence, is treated as a single color.
        - A sequence whose items are themselves strings or sequences (or an (N, 4) array) is treated
          as a collection of colors.

    Args:
        color (str, List, Tuple, np.ndarray): The color(s) to convert. Can be a string (e.g., 'red'), a list or tuple of RGB/RGBA values,
            or an `np.ndarray` of colors.
        alpha (float, None, optional): Alpha value (transparency) to apply. If provided, it overrides any existing alpha values found
            in color.
        num_repeats (int, None, optional): If truthy, a single color is tiled this many times; a collection is
            cycled until this many rows are produced. Defaults to None.

    Returns:
        np.ndarray: Array of RGBA colors, one per row.

    Raises:
        ValueError: If `color`, or any color within it, is not a string or a 3/4-element sequence.
    """
    sequence_types = (list, tuple, np.ndarray)

    def convert_to_rgba(c: Union[str, List, Tuple, np.ndarray]) -> np.ndarray:
        """Convert one color to an RGBA array (matplotlib supplies alpha 1.0 when none is given)."""
        if not isinstance(c, str) and not (isinstance(c, sequence_types) and len(c) in [3, 4]):
            raise ValueError(
                f"Invalid color format: {c}. Must be a valid string or RGB/RGBA sequence."
            )
        converted = np.array(mcolors.to_rgba(c))
        if alpha is not None:
            converted[3] = alpha
        return converted

    # An (N, 4) array is handled exactly like a list of RGBA rows
    if isinstance(color, np.ndarray) and color.ndim == 2 and color.shape[1] == 4:
        color = [list(row) for row in color]

    if isinstance(color, str):
        is_single = True
    elif isinstance(color, sequence_types):
        # Flat sequence of scalars -> one color; any nested string/sequence -> many colors
        is_single = not any(isinstance(item, (str,) + sequence_types) for item in color)
    else:
        raise ValueError("Color must be a valid RGB/RGBA or array of RGB/RGBA colors.")

    if is_single:
        single = convert_to_rgba(color)
        return np.tile(single, (num_repeats, 1)) if num_repeats else np.array([single])

    converted_all = np.array([convert_to_rgba(item) for item in color])
    if not num_repeats:
        return converted_all
    count = len(converted_all)
    # Cycle through the palette until num_repeats rows have been produced
    return np.array([converted_all[k % count] for k in range(num_repeats)])
