from typing import Callable, Literal, Optional, cast
import itertools
import math

from joblib import Parallel, delayed
import torch
import pandas as pd
import numpy as np

from vis_models.metrics.representation_similarity.cka import MinibatchCKA
from .representation_distances import (
    RepDistanceMetric,
    compute_representation_distances,
)


# Names of the aggregation modes used in annotations.
# NOTE(review): MultiGroupDistTracker.compute_mean_dist only handles
# "intra_group" and "inter_group"; any other value (including
# "across_all") makes it raise ValueError.
AggregationMode = Literal["intra_group", "inter_group", "across_all"]

# Signature of a subsampling function: takes a tensor and an optional
# int and returns a tensor.
# NOTE(review): in the visible source this alias is only referenced from
# commented-out code.
RandomSubsetGetter = Callable[[torch.Tensor, Optional[int]], torch.Tensor]

# Name of a distance metric.
# NOTE(review): in the visible source this constant is only referenced
# from commented-out code; the live classes receive the metric through
# their constructor instead.
DIST_METRIC = "linear_cka"

class MultiGroupDistTracker:
    """Accumulate per-layer representation distances for pairs of groups.

    Minibatches are registered with :meth:`track` under a
    ``(group_1, group_2)`` key. Every (group pair, layer) combination owns
    one ``EagerTracker``; :meth:`compute_mean_dist` averages the tracker
    values per layer over either the intra-group pairs (both names equal)
    or the inter-group pairs (names differ).
    """

    ACROSS_ALL_KEY = "across_all"

    def __init__(
        self,
        dist_metric: RepDistanceMetric = "linear_cka",
    ) -> None:
        """Create a tracker without any registered groups.

        Args:
            dist_metric: Metric handed to every ``EagerTracker`` that
                :meth:`track` creates.
        """
        # NOTE: every layer is tracked with EagerTracker; the MinibatchCKA
        # class imported at the top of the file is not wired in here.
        self.dist_metric: RepDistanceMetric = dist_metric
        # Maps (group_1, group_2) -> layer name -> tracker of that layer.
        self.group_trackers: dict[
            tuple[str, str], dict[str, EagerTracker]
        ] = {}

    def track(
        self,
        tracking_group: tuple[str, str],
        g1_samples: dict[str, torch.Tensor],
        g2_samples: dict[str, torch.Tensor],
    ) -> None:
        """Feed one minibatch of paired samples to the trackers of a group.

        Args:
            tracking_group: ``(group_1, group_2)`` key under which the
                minibatch is accumulated.
            g1_samples: Layer name -> samples of the first group.
            g2_samples: Layer name -> samples of the second group. Must
                have exactly the same keys as ``g1_samples``.

        Raises:
            AssertionError: If the two dicts do not share the same keys.
        """
        assert g1_samples.keys() == g2_samples.keys(), (
            "g1_samples and g2_samples must contain the same layers; "
            f"mismatching layers: {g1_samples.keys() ^ g2_samples.keys()}"
        )
        layer_trackers = self.group_trackers.setdefault(tracking_group, {})
        for layer_name, g1_layer_sample in g1_samples.items():
            layer_tracker = layer_trackers.get(layer_name)
            if layer_tracker is None:
                # Only build a tracker the first time a layer is seen;
                # dict.setdefault would construct a throwaway instance on
                # every call because its default is evaluated eagerly.
                layer_tracker = EagerTracker(self.dist_metric)
                layer_trackers[layer_name] = layer_tracker
            layer_tracker.add_minibatch(
                X=g1_layer_sample, Y=g2_samples[layer_name]
            )

    def compute_mean_dist(
        self,
        aggregation_mode: AggregationMode,
    ) -> pd.Series:
        """Average the tracked distances per layer over a set of groups.

        Args:
            aggregation_mode: ``"intra_group"`` selects the tracked pairs
                whose two group names are equal, ``"inter_group"`` the
                pairs whose names differ.

        Returns:
            A Series whose index is named ``"layer"`` and whose values are
            the row-wise mean over the selected group pairs. It is empty
            when no tracked pair matches the mode.

        Raises:
            ValueError: For any other mode, including ``"across_all"``,
                which is not implemented by this class.
        """
        if aggregation_mode == "intra_group":
            want_same_group = True
        elif aggregation_mode == "inter_group":
            want_same_group = False
        else:
            raise ValueError(
                f"Unsupported aggregation mode {aggregation_mode!r}; "
                "only 'intra_group' and 'inter_group' are implemented."
            )
        aggregation_groups = [
            group for group in self.group_trackers
            if (group[0] == group[1]) == want_same_group
        ]
        dists = self._compute_layer_dists(aggregation_groups)
        mean_dists = cast(pd.Series, dists.mean(axis=1))
        mean_dists.index.set_names("layer", inplace=True)
        return mean_dists

    def _compute_layer_dists(
        self, groups: list[tuple[str, str]]
    ) -> pd.DataFrame:
        """Collect the current tracker values of ``groups`` in one table.

        Args:
            groups: Keys of ``self.group_trackers`` to include.

        Returns:
            A float DataFrame with one row per layer and one column per
            group pair, the column being named ``"<group_1>-<group_2>"``.
            Rows and columns appear in first-seen order; combinations
            without a tracker are NaN. The frame is empty when ``groups``
            contributes no tracker.

        Raises:
            KeyError: If a group in ``groups`` has never been tracked.
        """
        columns: dict[str, dict[str, float]] = {}
        # A dict doubles as an insertion-ordered set of layer names.
        layer_order: dict[str, None] = {}
        for group in groups:
            group_name = f"{group[0]}-{group[1]}"
            for layer_name, layer_tracker in self.group_trackers[group].items():
                columns.setdefault(group_name, {})[layer_name] = (
                    layer_tracker.value()
                )
                layer_order.setdefault(layer_name, None)
        if not columns:
            return pd.DataFrame(dtype=float)
        # Build the frame in one go instead of enlarging it cell by cell;
        # passing the index explicitly pins the row order.
        return pd.DataFrame(columns, index=list(layer_order), dtype=float)

    # def add_rep_samples(
    #     self, group_name: str, reps: dict[str, torch.Tensor]
    # ) -> None:
    #     group_samples = self.rep_samples.setdefault(group_name, {})
    #     for layer_name, layer_reps in reps.items():
    #         layer_batches = group_samples.setdefault(layer_name, [])
    #         # Use twice the amount of samples because they will be split
    #         # later to create distance pairs
    #         rep_sample = self.get_random_subset(
    #             layer_reps, 2 * self.n_samples,
    #         )
    #         layer_batches.append(rep_sample)

    # def compute_mean_dist(
    #     self,
    #     aggregation_mode: AggregationMode,
    # ) -> pd.Series:
    #     if aggregation_mode == "intra_group":
    #         dists = self._compute_across_all_dists(
    #             shuffle_all=False,
    #             dist_metric=dist_metric,
    #         )
    #     elif aggregation_mode == "inter_group":
    #         dists = self._compute_inter_group_dists(
    #             dist_metric=dist_metric,
    #         )
    #     elif aggregation_mode == "across_all":
    #         dists = self._compute_across_all_dists(
    #             shuffle_all=True,
    #             dist_metric=dist_metric,
    #         )
    #     else:
    #         raise ValueError()
    #     dists.index.set_names("layer", inplace=True)
    #     return dists

    # def _compute_inter_group_dists(
    #     self,
    #     dist_metric: RepDistanceMetric,
    # ) -> pd.Series:
    #     # layer_samples = {}
    #     layers = list(self.rep_samples.values())[0].keys()
    #     dists = pd.Series(0, index=layers, dtype=float)
    #     # Subsample each group, else the n^2 complexity of this operation
    #     # blows up
    #     # n_group_samples = (
    #     #     2 * math.ceil(self.n_samples / (len(self.rep_samples) - 1))
    #     #     if len(self.rep_samples) > 1
    #     #     else self.n_samples
    #     # )
    #     n_group_samples = self.n_samples
    #     n_group_pairs = 0
    #     for group_1, group_1_samples in self.rep_samples.items():
    #         for group_2, group_2_samples in self.rep_samples.items():
    #             if group_1 >= group_2:
    #                 # Only count each combination once
    #                 continue
    #             assert group_1_samples.keys() == group_2_samples.keys()
    #             n_group_pairs += 1
    #             # print("computing distance for group pair", group_1, group_2)
    #             layer_samples = {}
    #             for layer in group_1_samples.keys():
    #                 layer_1_samples = self.get_random_subset(
    #                     _concatenate_batches(group_1_samples[layer]),
    #                     n_group_samples,
    #                 )
    #                 layer_2_samples = self.get_random_subset(
    #                     _concatenate_batches(group_2_samples[layer]),
    #                     n_group_samples,
    #                 )
    #                 layer_1_batches, layer_2_batches = layer_samples.setdefault(
    #                     layer, ([], [])
    #                 )
    #                 layer_1_batches.append(layer_1_samples)
    #                 layer_2_batches.append(layer_2_samples)
    #             group_pair_dists = self._compute_layerwise_dists(
    #                 layer_samples,
    #                 shuffle=False,
    #                 dist_metric=dist_metric,
    #             )
    #             # print("group pair dists:", group_pair_dists)
    #             dists += group_pair_dists
    #     dists /= n_group_pairs
    #     return dists
    #     # return self._compute_layerwise_dists(
    #     #     layer_samples,
    #     #     shuffle=False,
    #     # )

    # def _compute_inter_group_dists(self) -> pd.DataFrame:
    #     dist_results = Parallel(n_jobs=-1, backend="threading")(
    #         delayed(RepresentationTracker._inter_group_kernel)(
    #             group_1,
    #             group_2,
    #             # self.n_distance_samples,
    #             # self.generator.seed(),
    #         )
    #         for group_1, group_2 in itertools.product(
    #             self.rep_samples.items(), self.rep_samples.items()
    #         )
    #     )
    #     if dist_results is None:
    #         raise ValueError()
    #     return pd.concat(
    #         [dist_res for dist_res in dist_results if dist_res is not None],
    #         axis=1,
    #     )

    # def _compute_across_all_dists(
    #     self,
    #     shuffle_all: bool,
    #     dist_metric: RepDistanceMetric,
    # ) -> pd.Series:
    #     layer_samples = {}
    #     for group_samples in self.rep_samples.values():
    #         for layer, layer_batches in group_samples.items():
    #             full_reps = _concatenate_batches(layer_batches)
    #             reps1, reps2 = _split_samples(full_reps, self.get_random_subset)
    #             layer_1_batches, layer_2_batches = layer_samples.setdefault(
    #                 layer, ([], [])
    #             )
    #             layer_1_batches.append(reps1)
    #             layer_2_batches.append(reps2)

    #     return self._compute_layerwise_dists(
    #         layer_samples,
    #         shuffle=shuffle_all,
    #         dist_metric=dist_metric,
    #     )

    # def _compute_layerwise_dists(
    #     self,
    #     layer_samples: dict[str, tuple[list[torch.Tensor], list[torch.Tensor]]],
    #     shuffle: bool,
    #     dist_metric: RepDistanceMetric,
    # ) -> pd.Series:
    #     dists = pd.Series(dtype=float)
    #     for layer, (layer_1_batches, layer_2_batches) in layer_samples.items():
    #         full_reps_1 = _concatenate_batches(layer_1_batches)
    #         print("num samples:", len(full_reps_1))
    #         full_reps_2 = _concatenate_batches(layer_2_batches)
    #         assert len(full_reps_1) == len(full_reps_2)
    #         if shuffle:
    #             full_reps_1 = self.get_random_subset(
    #                 full_reps_1, None,
    #             )
    #             full_reps_2 = self.get_random_subset(
    #                 full_reps_2, None,
    #             )
    #         dists[layer] = (
    #             compute_representation_distances(
    #                 full_reps_1, full_reps_2, dist_metric,
    #             ).mean().item()
    #         )
    #     return dists

    # @staticmethod
    # def _inter_group_kernel(
    #     group_1: tuple[str, dict],
    #     group_2: tuple[str, dict],
    #     # n_samples: Optional[int],
    #     # seed: int,
    #     # rng: torch.Generator,
    # ) -> Optional[pd.DataFrame]:
    #     # print("parallel worker")
    #     group_1_name, group_1_samples = group_1
    #     group_2_name, group_2_samples = group_2
    #     # rng = torch.Generator()
    #     # rng.manual_seed(seed)
    #     if group_1_name >= group_2_name:
    #         # inter-group is only meant for comparing groups
    #         # that are not the same
    #         # We also only need to do one direction of comparison, since
    #         # the distances are symmetric
    #         return None
    #     assert group_1_samples.keys() == group_2_samples.keys()

    #     dists = pd.DataFrame()
    #     inter_group_name = f"{group_1_name}-{group_2_name}"
    #     for (layer_1, layer_1_batches), (_, layer_2_batches) \
    #         in zip(group_1_samples.items(), group_2_samples.items()) \
    #     :
    #         full_1_reps = _concatenate_batches(layer_1_batches)
    #         full_2_reps = _concatenate_batches(layer_2_batches)
    #         assert len(full_1_reps) == len(full_2_reps), (
    #             f"Length {len(full_1_reps)} != {len(full_2_reps)}"
    #         )
    #         reps1, reps2 = full_1_reps, full_2_reps
    #         # reps1 = random_subset(full_1_reps, n_samples, rng)
    #         # reps2 = random_subset(full_2_reps, n_samples, rng)
    #         layer_dists = compute_representation_distances(
    #             reps1, reps2, DIST_METRIC
    #         )
    #         dists.loc[layer_1, inter_group_name] = layer_dists.mean().item()
    #     return dists


# def _concatenate_batches(batches: list[torch.Tensor]) -> torch.Tensor:
#     return torch.cat(batches, 0)

# def _split_samples(
#     samples: torch.Tensor, get_random_subset: RandomSubsetGetter,
# ) -> tuple[torch.Tensor, torch.Tensor]:
#     subsampled_reps = get_random_subset(samples, None)
#     n_samples = len(subsampled_reps)
#     assert n_samples % 2 == 0
#     n_half_samples = n_samples // 2
#     reps1 = subsampled_reps[:n_half_samples]
#     reps2 = subsampled_reps[n_half_samples:]
#     return (reps1, reps2)

# def _random_subset(
#     array: torch.Tensor, n_samples: Optional[int], rng: torch.Generator,
# ) -> torch.Tensor:
#     n_samples = n_samples if n_samples is not None else len(array)
#     """Random subsampling without replacement"""
#     return array[torch.randperm(len(array), generator=rng)[:n_samples]]

# def _random_choice(
#     array: torch.Tensor, n_samples: int, rng: torch.Generator,
# ) -> torch.Tensor:
#     """Random subsampling with replacement"""
#     return array[torch.randint(len(array), (n_samples,), generator=rng)]


class EagerTracker:
    """Distance tracker that evaluates the metric on every minibatch.

    Each call to :meth:`add_minibatch` immediately runs
    ``compute_representation_distances`` and stores its result;
    :meth:`value` averages everything stored so far.
    """

    def __init__(self, dist_metric: RepDistanceMetric) -> None:
        """Remember the metric and start with no recorded minibatches."""
        self.dist_metric: RepDistanceMetric = dist_metric
        self.reset()

    def reset(self) -> None:
        """Discard every result recorded so far."""
        self.dists = list()

    def add_minibatch(self, X: torch.Tensor, Y: torch.Tensor) -> None:
        """Compute the distance between ``X`` and ``Y`` and record it."""
        self.dists.append(
            compute_representation_distances(X, Y, self.dist_metric)
        )

    def value(self) -> float:
        """Return the NumPy mean of all recorded results as a scalar.

        With nothing recorded this is the mean of an empty list, which
        NumPy evaluates to NaN.
        """
        overall_mean = np.mean(self.dists)
        return overall_mean.item()
