import math
from typing import Callable, Iterator, Literal, Optional, cast

import torch
import pandas as pd
from tqdm import tqdm
import numpy as np

from datasets.objects_2d import OBJECTS as OBJ2D_OBJECTS
from datasets.objects_2d import TRANSFORMS as OBJ2D_TRANSFORMS
from utils.representations.representation_distances import (
    compute_representations,
    ModelConfig,
)
from utils.representations.rep_dist_tracker import MultiGroupDistTracker
from utils import persistence
from .data import RIDataConfig, create_single_object_dataset


# @dataclass
# class InvarianceEvaluationDistances:
#     withinClasses: pd.DataFrame
#     acrossClasses: pd.DataFrame
#     acrossTransforms: pd.DataFrame

# Object groups evaluated for every (model, transform) combination.
# _get_object_type_objects also understands "random", but that group is
# currently disabled here.
OBJECT_TYPES = [
    "training",
    "holdout",
]
# ACROSS_ALL_KEY = "across_all"

@persistence.cached_result
def evaluate_invariance(
    models: dict[str, torch.nn.Module],
    model_config: ModelConfig,
    training_objects: list[int],
    data_config: RIDataConfig,
    n_distance_samples: int,
    n_object_samples: int,
    dist_metric: str,
) -> Iterator[pd.DataFrame]:
    """Measure per-layer representation distances between objects under transforms.

    For every model, every transform in ``data_config.transforms`` and every
    object type in ``OBJECT_TYPES``, up to ``n_object_samples`` objects are
    selected and their monitored-layer representations are compared with
    ``dist_metric`` for every pair ``(object_1, object_2)`` with
    ``object_1 <= object_2``. An object paired with itself is compared using
    two disjoint slices of one batch of ``2 * n_distance_samples`` samples.

    The tracker's mean ``"intra_group"`` and ``"inter_group"`` distances are
    stored in the columns ``within_<object_type>`` and
    ``between_<object_type>`` of a DataFrame indexed by
    (model, transform, layer).

    Yields:
        The results DataFrame after each object type has been evaluated. The
        SAME object is yielded every time and filled in progressively; copy it
        if a snapshot is needed.

    Returns:
        The completed DataFrame, as the generator's return value.

    Raises:
        ValueError: If ``data_config.transforms`` is None.
    """
    layer_names = list(_get_monitored_layers(model_config, None).keys())
    transforms = data_config.transforms
    if transforms is None:
        raise ValueError("Specify the list of transforms explicitly")
    index = pd.MultiIndex.from_product(
        [models.keys(), transforms, layer_names],
        names=("model", "transform", "layer"),
    )
    dists = pd.DataFrame(index=index)

    # The selection uses a freshly seeded RNG, so it does not depend on the
    # model or transform; compute it once instead of once per loop iteration.
    objects_by_type = {
        object_type: _get_object_type_objects(
            object_type,
            training_objects,
            data_config,
            n_object_samples=n_object_samples,
        )
        for object_type in OBJECT_TYPES
    }

    for model_name, model in models.items():
        print(f"------\nEvaluating invariance for model {model_name}\n------")

        for transform_name in transforms:
            print(f"Evaluating transform {transform_name}")

            for object_type in OBJECT_TYPES:
                print(f"Evaluating object type {object_type}")
                dist_tracker = _track_pairwise_dists(
                    model,
                    model_config,
                    transform_name,
                    objects_by_type[object_type],
                    random_object=object_type == "random",
                    data_config=data_config,
                    n_distance_samples=n_distance_samples,
                    dist_metric=dist_metric,
                )

                # NOTE(review): assigning ``.values`` assumes the tracker
                # returns one value per layer in ``layer_names`` order —
                # confirm against MultiGroupDistTracker.compute_mean_dist.
                rows = (model_name, transform_name)
                intra_class_dists = cast(
                    pd.Series, dist_tracker.compute_mean_dist("intra_group")
                ).values
                dists.loc[rows, f"within_{object_type}"] = intra_class_dists
                inter_class_dists = cast(
                    pd.Series, dist_tracker.compute_mean_dist("inter_group")
                ).values
                dists.loc[rows, f"between_{object_type}"] = inter_class_dists

                print("dists:", dists)
                yield dists

    print("dists:", dists)
    return dists


def _track_pairwise_dists(
    model: torch.nn.Module,
    model_config: ModelConfig,
    transform_name: str,
    objects: list[int],
    random_object: bool,
    data_config: RIDataConfig,
    n_distance_samples: int,
    dist_metric: str,
) -> MultiGroupDistTracker:
    """Track representation distances for every pair of ``objects``.

    Each pair ``(object_1, object_2)`` with ``object_1 <= object_2`` is
    tracked once, keyed by ``(str(object_1), str(object_2))``.
    """
    dist_tracker = MultiGroupDistTracker(dist_metric)
    for object_1 in tqdm(objects):
        # Twice the samples: the first half is object_1's reference batch, and
        # the second half is a separate batch used when it is paired with itself.
        reps_base = _compute_representations(
            model,
            model_config,
            transform_name,
            object_1,
            random_object=random_object,
            data_config=data_config,
            n_distance_samples=2 * n_distance_samples,
        )
        reps_1 = {
            layer_name: reps[:n_distance_samples]
            for layer_name, reps in reps_base.items()
        }

        for object_2 in objects:
            if object_1 > object_2:
                continue  # each unordered pair is tracked only once
            if object_1 == object_2:
                reps_2 = {
                    layer_name: reps[n_distance_samples:2 * n_distance_samples]
                    for layer_name, reps in reps_base.items()
                }
            else:
                reps_2 = _compute_representations(
                    model,
                    model_config,
                    transform_name,
                    object_2,
                    random_object=random_object,
                    data_config=data_config,
                    n_distance_samples=n_distance_samples,
                )
            dist_tracker.track((str(object_1), str(object_2)), reps_1, reps_2)
    return dist_tracker

# def get_transforms(data_config: ITDataConfig) -> list[str]:
#     if data_config.dataset == "obj2d":
#         return list(OBJ2D_TRANSFORMS.keys())
#     else:
#         raise ValueError(f"Invalid data type '{data_config.dataset}'")

def _get_object_type_objects(
    object_type: str,
    training_objects: list[int],
    data_config: RIDataConfig,
    n_object_samples: int,
) -> list[int]:
    """Select up to ``n_object_samples`` object ids of the given type.

    Args:
        object_type: ``"training"`` (a random subset of ``training_objects``),
            ``"holdout"`` (a random subset of the dataset's objects that are
            not in ``training_objects``), or ``"random"`` (the ids
            ``0 .. n_object_samples - 1``).
        training_objects: Ids of the objects used for training.
        data_config: Provides ``config_seed`` (selection seed) and ``dataset``.
        n_object_samples: Maximum number of ids to return.

    Returns:
        The selected ids as a plain Python list.

    Raises:
        ValueError: For an unknown ``object_type``, or for ``"holdout"`` with
            a dataset other than ``"obj2d"``.
    """
    # A fresh generator seeded from the config makes repeated calls with the
    # same arguments return the same selection.
    rng = np.random.default_rng(data_config.config_seed + 942)
    if object_type == "training":
        candidates = list(training_objects)
    elif object_type == "holdout":
        if data_config.dataset != "obj2d":
            raise ValueError(f"Invalid data type '{data_config.dataset}'")
        # Sort the set difference: set iteration order is an implementation
        # detail, and the seeded permutation is only reproducible when its
        # input order is fixed.
        candidates = sorted(set(OBJ2D_OBJECTS) - set(training_objects))
    elif object_type == "random":
        return list(range(n_object_samples))
    else:
        raise ValueError(f"Invalid object type '{object_type}'")
    return rng.permutation(candidates)[:n_object_samples].tolist()

def _compute_representations(
    model: torch.nn.Module,
    model_config: ModelConfig,
    transform_name: str,
    object_id: int,
    random_object: bool,
    data_config: RIDataConfig,
    n_distance_samples: int,
) -> dict[str, torch.Tensor]:
    """Compute monitored-layer representations of ``model`` for one object.

    A single-object dataset for ``object_id`` under ``transform_name`` (sized
    by ``n_distance_samples``) is built with ``create_single_object_dataset``.
    It is then passed through ``compute_representations`` together with the
    (name, module) pairs from ``_get_monitored_layers``.
    """
    dataset = create_single_object_dataset(
        data_config, transform_name, object_id, random_object, n_distance_samples
    )
    layer_pairs = list(_get_monitored_layers(model_config, model).items())
    return compute_representations(model, model_config, dataset, layer_pairs)

def _attribute_getter(attr_name: str) -> Callable[[torch.nn.Module], torch.nn.Module]:
    """Return a function that reads attribute ``attr_name`` from a model."""
    return lambda model: getattr(model, attr_name)


# (layer name, accessor) pairs for the CIFAR ResNet-18 layers whose outputs
# are monitored. Each layer name is also the attribute holding it on the model.
CIFAR_RESNET_18_LAYERS = [
    (layer_name, _attribute_getter(layer_name))
    for layer_name in ("conv1", "layer1", "layer2", "layer3", "layer4", "avgpool")
]

def _get_monitored_layers(
    config: ModelConfig, model: Optional[torch.nn.Module],
) -> dict[str, Optional[torch.nn.Module]]:
    """Map each monitored layer name to its module in ``model``.

    Pass ``model=None`` to obtain only the layer names; every value is then
    None.

    Raises:
        NotImplementedError: If no monitored layers are defined for the
            config's (domain, type) combination. Only ``("cifar", "resnet-18")``
            is currently supported.
    """
    if config.domain == "cifar" and config.type == "resnet-18":
        layer_getters = CIFAR_RESNET_18_LAYERS
    else:
        raise NotImplementedError(
            f"Monitored layers are not defined for model config {config!r}"
        )
    return {
        layer_name: layer_getter(model) if model is not None else None
        for layer_name, layer_getter in layer_getters
    }
