from enum import Enum, auto
from typing import *

import numpy as np
from pyclustering.cluster.center_initializer import kmeans_plusplus_initializer
from pyclustering.cluster.xmeans import xmeans
from torch.utils.data import DataLoader

from approaches.abst_appr import AbstractAppr
from approaches.hat.model_hat import ModelHAT
from approaches.prmwo2so.constants import ConstantsPRM
from utils import myprint as print


class Ablation(Enum):
    """Fixed similarity rules used instead of X-means clustering in ``Appr``.

    S: every previous task is treated as similar to the current one.
    D: every previous task is treated as dissimilar.
    T: previous tasks with the same number of classes as the current task
       are treated as similar.
    """
    S = 1
    D = 2
    T = 3


# endclass


class Appr(AbstractAppr):
    """HAT-based approach that splits previous tasks into similar / dissimilar ones.

    For every new task, each HAT layer gets a list of previous tasks that are
    considered similar to the current one. By default this is decided by
    clustering the HAT task embeddings with X-means; the ``Ablation`` modes
    replace that with fixed rules. The complementary "dissimilar" lists are
    passed to training under ``ConstantsPRM.KEY_DISSIMILARS``.
    """

    def __init__(self, device: str, list__ncls: List[int], inputsize: Tuple[int, ...],
                 lr: float, lr_factor: float, lr_min: float, epochs_max: int, patience_max: int,
                 smax: float, lamb: float, seed_pt: int, ablation: Optional[str],
                 drop1: float, drop2: float,
                 ):
        """Build the HAT model and parse the ablation mode.

        Args:
            seed_pt: seed forwarded to k-means++ initialisation and X-means.
            ablation: ``None`` for X-means based similarity, or one of
                ``'allsimilar'``, ``'alldissimilar'``, ``'typegiven'``
                (case-insensitive).
            Other arguments are forwarded to ``AbstractAppr`` / ``ModelHAT``.

        Raises:
            ValueError: if ``ablation`` is not one of the accepted strings.
        """
        super().__init__(device=device, list__ncls=list__ncls, inputsize=inputsize,
                         lr=lr, lr_factor=lr_factor, lr_min=lr_min,
                         epochs_max=epochs_max, patience_max=patience_max,
                         smax=smax, lamb=lamb)
        self.seed_pt = seed_pt
        self.model = ModelHAT(list__ncls=list__ncls, inputsize=inputsize,
                              smax=smax, hat_enabled=True,
                              drop1=drop1, drop2=drop2).to(self.device)

        # idx_task -> {idx_layer: [previous tasks treated as dissimilar]}.
        # Stored per instance: a class-level dict would be shared (and
        # overwritten) by every Appr object created in the same process.
        self.dict__idx_task__idx_layer__dissimilars: Dict[int, Dict[int, List[int]]] = {}

        # for ablation study
        if ablation is None:
            self.abl = None
        elif ablation.lower() == 'allsimilar':
            self.abl = Ablation.S
        elif ablation.lower() == 'alldissimilar':
            self.abl = Ablation.D
        elif ablation.lower() == 'typegiven':
            self.abl = Ablation.T
        else:
            raise ValueError(ablation)
        # endif
        print(f'Ablation: {self.abl}')
    # enddef

    def xmeans(self, x: np.ndarray, amount_centers: int) -> List[int]:
        """Cluster the rows of ``x`` with X-means and return one label per row.

        X-means is started from ``amount_centers`` k-means++ centers (seeded
        with ``self.seed_pt``) and may end with a different number of clusters.
        Labels are renumbered so that row 0 always gets label 0.

        Args:
            x: 2-D array, one sample per row.
            amount_centers: number of initial centers for X-means.

        Returns:
            A list of ``x.shape[0]`` integer cluster labels.

        Raises:
            ValueError: if X-means leaves any row without a cluster (the labels
                would otherwise silently misalign with the rows).
        """
        init_centers = kmeans_plusplus_initializer(x,
                                                   random_state=self.seed_pt,
                                                   amount_centers=amount_centers).initialize()
        model_xmeans = xmeans(x, init_centers, ccore=False, random_state=self.seed_pt)
        model_xmeans.process()

        # Invert "cluster -> member rows" once instead of scanning every cluster
        # for every row; the first cluster containing a row wins, as before.
        labels: List[Optional[int]] = [None] * x.shape[0]
        for label, group in enumerate(model_xmeans.get_clusters()):
            for i in group:
                if labels[i] is None:
                    labels[i] = label
                # endif
            # endfor
        # endfor
        missing = [i for i, lab in enumerate(labels) if lab is None]
        if missing:
            raise ValueError(f'X-means left rows unassigned: {missing}')
        # endif

        # Shift so row 0 has label 0; labels that become negative are moved
        # above the largest shifted label, which keeps all labels distinct.
        first = labels[0]
        shifted = [lab - first for lab in labels]
        top = max(shifted)  # hoisted: constant for the whole comprehension

        return [top - lab if lab < 0 else lab for lab in shifted]
    # enddef

    def find_similars(self, idx_task: int, dl_train: DataLoader, dl_val: DataLoader,
                      args_on_forward: Dict[str, Any],
                      args_on_after_backward: Dict[str, Any]) -> Dict[int, List[int]]:
        """Find, per HAT layer, the previous tasks clustered with ``idx_task``.

        The model is trained on the current task with every previous task
        marked dissimilar in layers 0 and 1. The first ``idx_task + 1`` rows
        of each HAT embedding are then clustered with X-means (2 initial
        centers). The model is restored from a backup afterwards, even if
        training or clustering raises.

        Returns:
            ``{idx_hat: [previous task indices in the current task's cluster]}``;
            ``{0: [], 1: []}`` for the first task.
        """
        dict__idx_hat__list__same_cluster = {0: [], 1: []}
        if idx_task == 0:
            return dict__idx_hat__list__same_cluster
        # endif

        # Learn with HAT on a throw-away copy of the current state.
        model_backup = self.copy_model()
        try:
            args = args_on_after_backward.copy()
            args[ConstantsPRM.KEY_DISSIMILARS] = {0: list(range(idx_task)), 1: list(range(idx_task))}
            # NOTE(review): unlike train(), args_on_forward is passed without
            # KEY_DISSIMILARS here — confirm this asymmetry is intended.
            super().train(idx_task=idx_task,
                          dl_train=dl_train, dl_val=dl_val,
                          args_on_forward=args_on_forward,
                          args_on_after_backward=args,
                          )

            for idx_hat, hat in enumerate(self.model.feature.model.hats):
                # clone() so the array does not alias weights that load_model may overwrite
                emb_all = hat.emb.weight.clone().detach().cpu().numpy()[:idx_task + 1]
                labels = self.xmeans(emb_all, amount_centers=2)
                label_curr = labels[idx_task]

                dict__idx_hat__list__same_cluster[idx_hat] = [t for t in range(idx_task)
                                                              if labels[t] == label_curr]
            # endfor
        finally:
            # Load backup so the exploratory run never leaks into the real model.
            self.load_model(model_backup)
        # endtry

        return dict__idx_hat__list__same_cluster
    # enddef

    def train(self, idx_task: int, dl_train: DataLoader, dl_val: DataLoader,
              args_on_forward: Dict[str, Any],
              args_on_after_backward: Dict[str, Any],
              ) -> float:
        """Decide similar/dissimilar previous tasks, then train on ``idx_task``.

        Similar tasks come from ``find_similars`` (no ablation) or from the
        selected ``Ablation`` rule. The complement for each layer is recorded
        in ``self.dict__idx_task__idx_layer__dissimilars`` and passed to both
        forward and after-backward args under ``ConstantsPRM.KEY_DISSIMILARS``.

        Returns:
            The value returned by ``AbstractAppr.train``.

        Raises:
            NotImplementedError: if ``self.abl`` holds an unknown mode.
        """
        # find similars
        if self.abl is None:
            dict__idx_layer__similars = self.find_similars(idx_task=idx_task,
                                                           dl_train=dl_train, dl_val=dl_val,
                                                           args_on_forward=args_on_forward,
                                                           args_on_after_backward=args_on_after_backward)
        elif self.abl == Ablation.S:
            similars = list(range(idx_task))
            dict__idx_layer__similars = {0: similars, 1: similars}
        elif self.abl == Ablation.D:
            dict__idx_layer__similars = {0: [], 1: []}
        elif self.abl == Ablation.T:
            # "type" = number of classes of the task
            similars = [i for i in range(idx_task) if self.list__ncls[i] == self.list__ncls[idx_task]]
            dict__idx_layer__similars = {0: similars, 1: similars}
        else:
            raise NotImplementedError
        # endif
        self.save_object_as_artifact(dict__idx_layer__similars, f'dict__{idx_task}__similars.txt')

        # save into dict/dissimilars: every previous task that is not similar
        dissimilars = {
            0: [t for t in range(idx_task) if t not in dict__idx_layer__similars[0]],
            1: [t for t in range(idx_task) if t not in dict__idx_layer__similars[1]],
            }
        self.dict__idx_task__idx_layer__dissimilars[idx_task] = dissimilars

        # main learning
        args_fw = args_on_forward.copy()
        args_bw = args_on_after_backward.copy()
        args_fw[ConstantsPRM.KEY_DISSIMILARS] = dissimilars
        args_bw[ConstantsPRM.KEY_DISSIMILARS] = dissimilars
        return super().train(idx_task, dl_train, dl_val,
                             args_on_forward=args_fw,
                             args_on_after_backward=args_bw)
    # enddef

    def complete_learning(self, idx_task: int) -> None:
        """Save the full dissimilarity history and freeze the HAT masks of ``idx_task``."""
        self.save_object_as_artifact(self.dict__idx_task__idx_layer__dissimilars,
                                     'dict__idx_task__dissimilars.txt')
        self.model.freeze_masks(idx_task)
    # enddef

# endclass
