import hashlib
import json
import math
from typing import List, Tuple, Dict

import numpy as np
import torch
from matplotlib import pyplot as plt
from matplotlib.colors import to_hex
from torch import Tensor

from compute_result.typing import (
    BudgetRun,
    ProblemSpace,
    StatisticalAnalysisOptions,
    BudgetStepsRun,
    MetricManipulation,
    Point,
    RunResult,
    AxisAnalysis,
    AxisAnalysisData,
)
from problems.types import Suites

# Module-level default colormap: matplotlib's "tab10", resolved once at import time.
DEFAULT_COLOR_MAP = plt.get_cmap("tab10")


def torch_wrapper(func):
    """
    Adapt a tensor-only function so it also accepts numpy arrays and plain sequences.

    The returned callable converts its argument to a torch tensor when needed,
    applies ``func`` and converts the result back to the container kind it received:
    a ``Tensor`` stays untouched, an ``np.ndarray`` comes back through ``.numpy()``
    and anything else comes back through ``.tolist()``.
    """

    def wrapper(values):
        # Tensors need no conversion in either direction.
        if isinstance(values, Tensor):
            return func(values)

        # numpy in -> numpy out
        if isinstance(values, np.ndarray):
            as_tensor = torch.from_numpy(values)
            return func(as_tensor).numpy()

        # Everything else is handed to torch.tensor and returned as Python lists/scalars.
        as_tensor = torch.tensor(values)
        return func(as_tensor).tolist()

    return wrapper


# Maps each statistical analysis option to a callable that reduces the last
# dimension of its input. Every callable is wrapped with ``torch_wrapper`` so it
# can be fed a Tensor, a numpy array or a plain sequence.
STATISTICAL_ANALYSIS_MAPPING = {
    StatisticalAnalysisOptions.MEAN: torch_wrapper(lambda values: values.mean(dim=-1)),
    # Tensor.median(dim=...) returns a (values, indices) named tuple rather than a
    # tensor; only the values are kept so that the wrapper's ``.numpy()`` /
    # ``.tolist()`` conversions work and the result matches the other reductions.
    StatisticalAnalysisOptions.MEDIAN: torch_wrapper(
        lambda values: values.median(dim=-1).values
    ),
    StatisticalAnalysisOptions.QUANTILE_90: torch_wrapper(
        lambda values: values.quantile(0.9, dim=-1)
    ),
    StatisticalAnalysisOptions.QUANTILE_10: torch_wrapper(
        lambda values: values.quantile(0.1, dim=-1)
    ),
    StatisticalAnalysisOptions.STD: torch_wrapper(lambda values: values.std(dim=-1)),
}

# Maps each metric manipulation to a callable that turns a metric table into a
# numpy array. Every callable reads the table's third column (``metric.iloc[:, 2]``).
METRIC_MANIPULATION = {
    # Column values as they are.
    MetricManipulation.NONE: lambda metric: np.array(list(metric.iloc[:, 2])),
    # Negative natural log of the absolute value of each entry.
    MetricManipulation.COS: lambda metric: np.array(
        [-math.log(magnitude) for magnitude in map(math.fabs, metric.iloc[:, 2])]
    ),
    # Running sum of the absolute values.
    MetricManipulation.COS_CUMSUM: lambda metric: np.cumsum(
        np.array(list(map(math.fabs, metric.iloc[:, 2])))
    ),
    # Each entry is a JSON-encoded vector; take the norm of every row.
    MetricManipulation.NORM: lambda metric: np.linalg.norm(
        np.array(list(map(json.loads, metric.iloc[:, 2]))), axis=1
    ),
    # Running sum of the per-row norms.
    MetricManipulation.NORM_CUMSUM: lambda metric: np.cumsum(
        np.linalg.norm(np.array(list(map(json.loads, metric.iloc[:, 2]))), axis=1)
    ),
}
# Maps each AxisAnalysis option to an AxisAnalysisData built from a reduction
# callable and a ``name`` callable that returns a label string.
# Every callable accepts ``**kwargs``, so keyword arguments it does not use are ignored.
AXIS_ANALYSIS = {
    # Mean of the given results along axis 1.
    AxisAnalysis.MEAN: AxisAnalysisData(
        lambda norm_results, **kwargs: norm_results.mean(axis=1),
        name=lambda **kwargs: "Mean",
    ),
    # Standard deviation of the given results along axis 1.
    AxisAnalysis.STD: AxisAnalysisData(
        lambda norm_results, **kwargs: norm_results.std(axis=1),
        name=lambda **kwargs: "Standard Deviation",
    ),
    # Fraction of entries along axis 1 whose value is <= ``offset`` (default 0.01).
    AxisAnalysis.SOLVED: AxisAnalysisData(
        lambda norm_results, offset=0.01, **kwargs: (norm_results <= offset).sum(axis=1)
        / norm_results.shape[1],
        name=lambda **kwargs: "Solved Functions",
    ),
    # Mean along axis 1 of the budgets returned by how_much_budget_to_reach_percentile.
    # Unlike the entries above, this one takes ``results`` (not ``norm_results``) and
    # requires a ``percentile`` argument.
    AxisAnalysis.BUDGET_FOR_PROGRESS: AxisAnalysisData(
        lambda results, percentile, **kwargs: how_much_budget_to_reach_percentile(
            results, percentile
        ).mean(axis=1),
        name=lambda percentile, **kwargs: f"Budget to reach {percentile} percentile",
    ),
}


def max_budget_of_runs(runs_result: List[BudgetStepsRun]):
    """
    Return the largest first element (index 0) found among all steps of all runs.
    """
    largest_per_run = (max(step[0] for step in run) for run in runs_result)
    return max(largest_per_run)


def from_mapped_result_to_steps(
    steps: BudgetStepsRun, max_steps: int = None
) -> List[float]:
    """
    Expand a sparse (x, y) mapping into a dense list with one value per integer x.

    The output has ``max_steps + 1`` entries (x = 0 .. max_steps). When ``max_steps``
    is falsy, the largest x present in ``steps`` is used. An x without a recorded
    value repeats the previous entry; before any value has been emitted, the value
    recorded for the smallest x is used.
    """
    value_at = {}
    for step in steps:
        value_at[step[0]] = step[1]

    last_step = max_steps or max(value_at)
    # Used until the first recorded x is reached, then updated as values appear.
    carried = value_at[min(value_at)]

    dense = []
    for position in range(last_step + 1):
        carried = value_at.get(position, carried)
        dense.append(carried)
    return dense


def full_steps_for_budget_run(
    results: List[BudgetRun], max_steps: int = None
) -> List[List[float]]:
    """
    Expand every run in ``results`` to a dense step list of one shared length.

    The shared last step is ``max_steps`` when it is truthy, otherwise the largest
    budget reported by ``max_budget_of_runs``.
    """
    shared_last_step = max_steps or max_budget_of_runs(results)
    expanded_runs = []
    for run in results:
        expanded_runs.append(from_mapped_result_to_steps(run, shared_last_step))
    return expanded_runs


def measure_statistic_result_over_multiple_runs(
    results: List[BudgetRun],
    statistical_analysis: StatisticalAnalysisOptions = StatisticalAnalysisOptions.MEAN,
):
    """
    Apply one statistic across runs, step by step.

    The runs are first expanded to dense step lists of equal length; then, for every
    step, the values of all runs at that step are passed to the analyzer registered
    in ``STATISTICAL_ANALYSIS_MAPPING``. If no analyzer is registered for
    ``statistical_analysis``, an empty list is returned.
    """
    dense_runs = full_steps_for_budget_run(results)

    analyzer = STATISTICAL_ANALYSIS_MAPPING.get(statistical_analysis)
    if not analyzer:
        return []

    # zip(*runs) transposes "per run" lists into "per step" tuples.
    return [analyzer(values_at_step) for values_at_step in zip(*dense_runs)]


def problems_with_the_same_func(
    suite: Suites, func_id: int, problems: List[ProblemSpace]
):
    """
    Return the problems whose suite and function id both match the given ones.

    Each problem is unpacked as (suite, func_id, dim, instance) and re-emitted as a tuple.
    """
    matching = []
    for problem_suite, problem_func_id, dim, instance in problems:
        if problem_suite != suite:
            continue
        if problem_func_id != func_id:
            continue
        matching.append((problem_suite, problem_func_id, dim, instance))
    return matching


def problems_with_the_same_dim(dim: int, problems: List[ProblemSpace]):
    """
    Return the problems whose dimension equals ``dim``.

    Each problem is unpacked as (suite, func_id, dim, instance) and re-emitted as a tuple.
    """
    matching = []
    for problem_suite, func_id, problem_dim, instance in problems:
        if problem_dim == dim:
            matching.append((problem_suite, func_id, problem_dim, instance))
    return matching


def add_to_range(ranges, start, end, array_size):
    """
    Append the (start, end) range to ``ranges`` in place and return ``ranges``.

    A range whose length (``end - start``) exceeds half of ``array_size`` is split at
    the midpoint into (midpoint, end) followed by (start, midpoint - 1).
    """
    half = int(array_size / 2)
    if end - start > half:
        ranges.extend([(half, end), (start, half - 1)])
        return ranges

    ranges.append((start, end))
    return ranges


def get_index_ranges(indexes: List[int], array_size: int) -> List[Tuple[int, int]]:
    """
    Describe a set of indexes as chunks of an array of ``array_size`` elements.

    The indexes are sorted and grouped into runs of consecutive values. Each run is
    registered through ``add_to_range`` and then translated by
    ``find_fractile_for_range`` into ``(chunk_number, num_of_chunks)`` pairs.
    Runs are processed from longest to shortest, and a run that already lies inside
    a previously produced chunk does not produce new pairs.

    :param indexes: the indexes to cover; may be empty
    :param array_size: size of the array the indexes refer to
    :return: list of ``(chunk_number, num_of_chunks)`` pairs sorted by
        ``num_of_chunks`` in descending order; empty when ``indexes`` is empty
    """
    indexes = sorted(indexes)
    if not indexes:
        return []

    # Collect runs of consecutive indexes
    ranges = []
    start = end = indexes[0]
    for previous, index in zip(indexes, indexes[1:]):
        if index - previous == 1:
            # Still inside the current run
            end = index
        else:
            ranges = add_to_range(ranges, start, end, array_size)
            start = end = index
    # Always flush the last run, including a run made of a single index;
    # otherwise a trailing isolated index would be lost.
    ranges = add_to_range(ranges, start, end, array_size)

    parts = []
    for range_start, range_end in sorted(
        ranges, key=lambda x: x[1] - x[0], reverse=True
    ):
        # Skip the range if one of the chunks found so far already contains it
        already_covered = False
        for chunk_number, num_of_chunks in parts:
            chunk_size = math.floor(array_size / num_of_chunks)
            part_start = chunk_number * chunk_size
            part_end = part_start + chunk_size
            if part_start <= range_start <= range_end < part_end:
                already_covered = True
                break
        if already_covered:
            continue

        # Create new part
        new_part, _ = find_fractile_for_range(array_size, range_start, range_end)
        parts += new_part

    # Sort by number of chunks, largest first
    return sorted(parts, key=lambda x: x[1], reverse=True)


def find_fractile_for_range(array_size: int, start: int, end: int):
    """
    Find the equal-sized chunks of the array that contain ``start`` or ``end``.

    The array is divided into ``floor(array_size / range_length)`` chunks, where
    ``range_length = end - start + 1``.

    :return: a tuple ``(hits, chunk_size)`` where ``hits`` is a list of
        ``(chunk_index, num_of_chunks)`` pairs for every chunk holding ``start``
        or ``end``
    """
    span = end - start + 1
    num_of_chunks = math.floor(array_size / span)
    chunk_size = math.floor(array_size / num_of_chunks)

    def holds_an_endpoint(chunk_index):
        low = chunk_index * chunk_size
        high = low + chunk_size
        return low <= start < high or low <= end < high

    hits = [
        (chunk_index, num_of_chunks)
        for chunk_index in range(num_of_chunks)
        if holds_an_endpoint(chunk_index)
    ]
    return hits, chunk_size


def normalize_result(
    func_value: float, algorithm_min: float, algorithm_max: float
) -> float:
    """
    Rescale ``func_value`` relative to the [algorithm_min, algorithm_max] interval,
    so that ``algorithm_min`` maps to 0 and ``algorithm_max`` maps to 1.
    """
    offset_from_min = func_value - algorithm_min
    interval_width = algorithm_max - algorithm_min
    return offset_from_min / interval_width


def points_from_run(results: RunResult) -> List[Point]:
    """
    Collect the third element of every entry in ``results``.

    Each entry must unpack into at least three items; anything after the third is ignored.
    """
    points = []
    for _first, _second, point, *_rest in results:
        points.append(point)
    return points


def np_array_from_results(
    results: List[List[RunResult]], index_to_take=1, max_length=None
) -> np.ndarray:
    """
    Build a numpy array out of one field of every step in a doubly nested list of runs.

    Runs shorter than the target length are padded by repeating their last step.
    The target length is ``max_length`` when it is truthy, otherwise the length of
    the longest run. From every step, only ``step[index_to_take]`` is kept.
    """
    target_length = max_length
    if not target_length:
        target_length = max(max(len(run) for run in group) for group in results)

    # Pad every run up to the target length with copies of its last step.
    padded_groups = []
    for group in results:
        padded_groups.append(
            [run + [run[-1]] * (target_length - len(run)) for run in group]
        )

    # Keep only the requested field of each step.
    extracted = []
    for group in padded_groups:
        extracted.append([[step[index_to_take] for step in run] for run in group])
    return np.array(extracted)


def np_array_from_res_split_by_dim(
    results: Dict[int, List[List[RunResult]]], index_to_take=1
):
    """
    Convert each entry of ``results`` with ``np_array_from_results``, using one padding
    length shared by all entries: the length of the longest run found under any key.

    :return: dict with the same keys as ``results`` and the converted arrays as values
    """
    longest_per_key = []
    for groups in results.values():
        longest_per_key.append(max(max(len(run) for run in group) for group in groups))
    shared_length = max(longest_per_key)

    arrays = {}
    for dim, groups in results.items():
        arrays[dim] = np_array_from_results(groups, index_to_take, shared_length)
    return arrays


def how_much_budget_to_reach_percentile(
    results: List[List[RunResult]], finish_percentile: float
):
    """
    For every (first axis, second axis) pair of ``results``, find the smallest budget at
    which the value is strictly below ``min + (max - min) * finish_percentile``, where
    min and max are taken over the steps (third axis) of that pair.

    Pairs that never get below the threshold are assigned the budget of their last step.
    Values are read from index 1 of each step and budgets from index 0.
    """
    values = np_array_from_results(results)
    budgets = np_array_from_results(results, 0)

    highest = values.max(axis=2)
    lowest = values.min(axis=2)
    threshold = (highest - lowest) * finish_percentile + lowest

    below_threshold = values < np.expand_dims(threshold, axis=2)
    # Steps that did not reach the threshold are masked with +inf so they never win the min.
    candidate_budgets = np.where(below_threshold, budgets, np.inf)
    first_budget = candidate_budgets.min(axis=2)

    never_reached = first_budget == np.inf
    last_budget = budgets[:, :, -1]
    first_budget[never_reached] = last_budget[never_reached]
    return first_budget


def rolling_average(data, window_size=10):
    """
    Smooth a 1-D tensor with a uniform moving-average kernel.

    The convolution zero-pads ``(window_size - 1) // 2`` elements on each side, so
    values near the edges are averaged together with zeros. For an odd
    ``window_size`` the output has the same length as ``data``; for an even one
    (including the default) it is one element shorter.

    :param data: 1-D floating point tensor
    :param window_size: number of elements averaged per output value
    :return: the smoothed tensor, squeezed, with the same dtype and device as ``data``
    """
    padding = (window_size - 1) // 2
    # The kernel is created with the dtype (and device) of the input: conv1d refuses
    # mismatched input/weight dtypes, so a default float32 kernel breaks float64 data.
    kernel = (
        torch.ones(window_size, device=data.device, dtype=data.dtype) / window_size
    )
    smoothed_data = torch.nn.functional.conv1d(
        data.unsqueeze(0).unsqueeze(0),
        kernel.unsqueeze(0).unsqueeze(0),
        padding=padding,
    )
    return smoothed_data.squeeze()


def get_color(name: str, colormap):
    """
    Pick a color for ``name`` deterministically.

    The MD5 digest of the encoded name is reduced modulo ``colormap.N`` and the
    resulting bucket, divided by ``colormap.N``, is passed to ``colormap``.
    """
    digest = hashlib.md5(name.encode()).hexdigest()
    bucket = int(digest, 16) % colormap.N
    return colormap(bucket / colormap.N)


def get_label_color_map(labels, cmap):
    """
    Assign a hex color to every distinct label.

    Labels are de-duplicated and sorted. The first ten labels use ``cmap(0)`` ..
    ``cmap(9)``; every further group of ten reuses the same ten base colors with all
    channels scaled by a factor that shrinks by 0.1 per group (clipped to [0, 1]).

    :return: dict mapping each unique label to its hex color string
    """
    unique_labels = sorted(set(labels))

    palette = []
    for position in range(len(unique_labels)):
        # cycle: how many times the ten base colors were already used up
        # slot: which of the ten base colors to take
        cycle, slot = divmod(position, 10)
        brightness = 1 - (cycle * 0.1)
        shaded = np.clip(np.array(cmap(slot)) * brightness, 0, 1)
        palette.append(to_hex(shaded))

    return dict(zip(unique_labels, palette))


def create_plot_with_statistics(x, mean_results, lower, upper, plot_names, colormap):
    """
    Draw one line per plot name together with a shaded band between its lower and
    upper series.

    The i-th entries of ``mean_results``, ``lower`` and ``upper`` belong to the i-th
    name in ``plot_names``. Colors come from ``get_label_color_map``; a line and its
    band share one color, the band being drawn with alpha 0.3.

    :return: the created ``(fig, ax)`` pair
    """
    fig, ax = plt.subplots()
    color_by_label = get_label_color_map(plot_names, colormap)

    for series_index, label in enumerate(plot_names):
        series_color = color_by_label[label]
        ax.plot(x, mean_results[series_index], label=label, color=series_color)
        ax.fill_between(
            x, lower[series_index], upper[series_index], alpha=0.3, color=series_color
        )
    return fig, ax
