from typing import Generator, Dict, Tuple, Any, Callable, List
import logging
from copy import deepcopy

import torch.nn
from torch.utils.data import DataLoader
import numpy as np

from .analysis_method import _SingleAnalysisMethod, ResultGeneratorType
from path_learning.utils.result import TaskResult
from path_learning.loss.domain_distance_metrics.clean_fid import fid_from_feats
from path_learning.learning.learning_intervention import intervention
from path_learning.utils.log import get_logger
from path_learning.models.resnet import ResNet

# Module-level logger for the invariance analyses; it is also the logger handed
# to `intervention` when a model is modified before feature extraction.
logger = get_logger("invariance_analysis")


class SingleTaskInvarianceAnalysis(_SingleAnalysisMethod):
    """Frechet-distance analysis of per-class ``avgpool`` features of a ResNet.

    For a given model the analysis collects the flattened output of
    ``model.avgpool`` per target class and reports, via ``fid_from_feats``,
    the distance between two halves of each class (within-class) and between
    every pair of classes (between-class). Both sets of numbers are produced
    twice: once for the model as given and once after ``intervention`` has
    been applied to it.
    """
    name = "task_invariance_analysis"

    def __init__(self, *args, **kwargs):
        """Pop the analysis-specific keyword arguments and forward the rest.

        Keyword Args:
            n_batches: required; number of batches read from a dataloader
                during feature extraction.
            reset_bias: required; stored on the instance (not read by the
                methods of this class).
            kwargs_intervention: optional; forwarded as ``kwargs`` to
                ``intervention``. Defaults to ``None``.

        Raises:
            KeyError: if ``n_batches`` or ``reset_bias`` is missing. A warning
                listing the received arguments is logged before re-raising.
        """
        try:
            self.n_batches = kwargs.pop("n_batches")
            self.reset_bias = kwargs.pop("reset_bias")
        except KeyError:
            logging.warning(f"invalid kwargs {self.name}, received : {args} and {kwargs}")
            raise
        # kwargs_intervention is used for the intervention, e.g.:
        #   {"make_identity_w_symmetry_breaking": True,
        #    "identity_scaling": 0.05,
        #    "rand_scaling": 0.01}
        self.kwargs_intervention = kwargs.pop("kwargs_intervention", None)
        super().__init__(*args, **kwargs)

    def criterion(self, model: torch.nn.Module, loss_function: Callable,
                  model_inputs: Tuple[torch.Tensor, torch.Tensor],
                  model_outputs: Tuple[torch.Tensor, torch.Tensor], target: torch.Tensor) -> float:
        """Not supported by this analysis; always raises ``NotImplementedError``."""
        raise NotImplementedError("MultiTaskSensitivityAnalysis is abstract")

    def extract_feature_dict(self, model: torch.nn.Module, dataloader: DataLoader,
                             num_classes=10, with_intervention=False, len_min_set=None) -> Tuple[Dict, int]:
        """Collect flattened ``avgpool`` features, grouped by target class.

        Args:
            model: a ``ResNet`` instance (asserted); it is moved to
                ``self.device`` and put into train mode.
            dataloader: yields ``(data, target)`` batches; at most
                ``self.n_batches`` batches are consumed.
            num_classes: targets ``0 .. num_classes - 1`` each get a bucket.
            with_intervention: if true, ``intervention`` is applied to the
                model before any batch is processed.
            len_min_set: optional upper bound on the number of samples kept
                per class.

        Returns:
            ``(features_dict, len_min)`` where ``features_dict[i]`` is a NumPy
            array holding the first ``len_min`` feature rows of class ``i``.
            ``len_min`` is the smallest per-class sample count, rounded down
            to an even number and capped by ``len_min_set`` when given.

        Raises:
            ValueError: if no batch could be read or ``num_classes`` is not
                positive.
        """
        assert isinstance(model, ResNet)

        # Filled by the forward hook on every forward pass.
        activation = {}

        def store_avgpool(module, inputs, output):
            activation['avgpool'] = output.detach().view(output.size(0), -1).cpu()

        model = model.to(self.device)
        # NOTE(review): the model is put into train mode for the extraction, so
        # mode-dependent layers behave as during training - confirm this is intended.
        model.train()

        # Note: specifically coded for RESNET (relies on `model.avgpool`).
        # Keep the handle so the hook does not pile up on the model across calls.
        hook_handle = model.avgpool.register_forward_hook(store_avgpool)
        buckets = {i: [] for i in range(num_classes)}
        try:
            with torch.no_grad():
                if with_intervention:
                    model = intervention(None, model=model, logger=logger, device=self.device,
                                         kwargs=self.kwargs_intervention)

                # COLLECT FEATURES IN CLASS "BUCKETS"
                for batch_idx, (data, target) in enumerate(dataloader):
                    # The batch must live on the same device as the model.
                    model(data.to(self.device))
                    feats = activation['avgpool']
                    # The stored activation is on the CPU, so the mask is built on the CPU too.
                    labels = target.cpu()
                    for i in range(num_classes):
                        buckets[i].append(feats[labels == i].numpy())
                    # Break after n_batches
                    if batch_idx + 1 >= self.n_batches:
                        break
        finally:
            hook_handle.remove()

        if not buckets or not buckets[0]:
            raise ValueError("no features collected: the dataloader yielded no batches "
                             "or num_classes is not positive")

        # One concatenation per class instead of one per batch.
        features_dict = {i: np.concatenate(chunks, axis=0) for i, chunks in buckets.items()}

        # Ensure that the number of samples is equal for all classes
        len_min = min(len(feats) for feats in features_dict.values())
        # make min length even (the within-class comparison splits each class in half)
        len_min -= len_min % 2
        if len_min_set is not None:
            len_min = min(len_min, len_min_set)

        for i in features_dict:
            features_dict[i] = np.array(features_dict[i][:len_min])
        return features_dict, len_min

    def analyze_model(self, task_result: TaskResult,
                      model: torch.nn.Module) -> ResultGeneratorType:
        """Yield ``(name, value)`` Frechet-distance results for ``model``.

        Works on a deep copy of ``model``. Results are produced for the copy
        as given (names prefixed ``without_intervention``) and then after the
        intervention (prefix ``with_intervention``). Per prefix it yields one
        within-class value per class, one value per unordered class pair and
        the average over all pairs.
        """
        dataloader = self.generate_dataloader(task_result)

        model = deepcopy(model)
        # Currently specialized only to 10 classes
        num_classes = 10

        for with_intervention in (False, True):
            # Prefix that keeps the two passes apart in the result names.
            mode = "with_intervention" if with_intervention else "without_intervention"

            features_dict, _ = self.extract_feature_dict(model, dataloader, num_classes=num_classes,
                                                         with_intervention=with_intervention)
            # --------------------------------------
            # Frechet distance using model features:
            # --------------------------------------

            # COMPARE WITHIN BUCKET: first half of a class against its second half
            for class_i, feats in features_dict.items():
                half = len(feats) // 2
                fid_within = fid_from_feats(feats[:half], feats[half:])
                yield f"{mode}_FID_within_class{class_i}", fid_within

            # COMPARE BETWEEN BUCKETS
            fid_between = 0.0
            count = 0
            n_buckets = len(features_dict)
            for i in range(n_buckets - 1):
                # distance is symmetric, so each unordered pair is visited once
                for j in range(i + 1, n_buckets):
                    fid_between_ij = fid_from_feats(features_dict[i], features_dict[j])
                    fid_between += fid_between_ij
                    count += 1
                    yield f"{mode}_FID_between_classes_{i}_{j}", fid_between_ij
            if count:
                yield f"{mode}_FID_avg_total_between_classes", fid_between / count


class MultiDataloaderTaskInvarianceAnalysis(SingleTaskInvarianceAnalysis):
    """Invariance analysis that compares features across two dataloaders.

    Construction and ``criterion`` are inherited from
    ``SingleTaskInvarianceAnalysis``. The parent ``__init__`` already pops
    ``n_batches``, ``reset_bias`` and ``kwargs_intervention``; popping them
    here as well before delegating made the parent raise ``KeyError`` on
    every construction, so no override is defined.
    """
    name = "multi_dataloader_task_invariance_analysis"

    @staticmethod
    def _between_class_fids(features: Dict, prefix: str, suffix: str = ""):
        """Yield the FID of every unordered class pair, then their average.

        ``features`` is expected to be keyed by ``0 .. len(features) - 1``.
        ``suffix`` is appended to every result name so that results of
        different dataloaders do not share a name.
        """
        total = 0.0
        count = 0
        n_buckets = len(features)
        for i in range(n_buckets - 1):
            # distance is symmetric, so each unordered pair is visited once
            for j in range(i + 1, n_buckets):
                fid_ij = fid_from_feats(features[i], features[j])
                total += fid_ij
                count += 1
                yield f"{prefix}_FID_between_classes_{i}_{j}{suffix}", fid_ij
        if count:
            yield f"{prefix}_FID_avg_total_between_classes{suffix}", total / count

    def analyze_model(self, task_result: TaskResult,
                      model: torch.nn.Module) -> ResultGeneratorType:
        """Yield ``(name, value)`` Frechet-distance results over two dataloaders.

        Works on a deep copy of ``model``. For the copy as given (prefix
        ``without_intervention``) and after the intervention (prefix
        ``with_intervention``) it yields:

        * per class, the distance between the class features of the first
          and of the second dataloader (``..._FID_within_class<i>``);
        * the pairwise between-class distances and their average for the
          first dataloader;
        * the same for the second dataloader, with ``_second_dataloader``
          appended to the names.
        """
        dataloader = self.generate_dataloader(task_result)
        dataloader_unlabeled = self.generate_second_dataloader(task_result)

        model = deepcopy(model)
        # Currently specialized only to 10 classes
        num_classes = 10

        for with_intervention in (False, True):
            # Prefix that keeps the two passes apart in the result names.
            mode = "with_intervention" if with_intervention else "without_intervention"

            if with_intervention:
                # Apply the intervention exactly once so that both dataloaders are
                # evaluated on the same weights. Letting each extract_feature_dict
                # call apply it would run it twice; the example kwargs include a
                # random scaling, so the two runs may not agree - TODO confirm.
                model = model.to(self.device)
                model.train()
                with torch.no_grad():
                    model = intervention(None, model=model, logger=logger, device=self.device,
                                         kwargs=self.kwargs_intervention)

            # Extract features; the second extraction is capped at the per-class
            # sample count found for the first dataloader.
            features_dict, len_min_set = self.extract_feature_dict(model, dataloader,
                                                                   num_classes=num_classes,
                                                                   with_intervention=False)
            features_dict2, _ = self.extract_feature_dict(model, dataloader_unlabeled,
                                                          num_classes=num_classes,
                                                          with_intervention=False,
                                                          len_min_set=len_min_set)
            # --------------------------------------
            # Frechet distance using model features:
            # --------------------------------------

            # COMPARE WITHIN BUCKET BUT ACROSS DATALOADERS
            for class_i in features_dict:
                fid_within = fid_from_feats(features_dict[class_i],  # first dataloader
                                            features_dict2[class_i])  # second dataloader
                yield f"{mode}_FID_within_class{class_i}", fid_within

            # COMPARE BETWEEN BUCKETS - FIRST DATALOADER
            yield from self._between_class_fids(features_dict, mode)

            # COMPARE BETWEEN BUCKETS - SECOND "UNLABELED" DATALOADER
            # The suffix keeps these results from reusing the first dataloader's names.
            yield from self._between_class_fids(features_dict2, mode, suffix="_second_dataloader")
