import re

from golearn.core import AlgorithmBase
from golearn.core.utils import ALGORITHMS
import matplotlib.pyplot as plt
import torch
import random
from torch.utils.data import TensorDataset, DataLoader, RandomSampler
from torch.nn.functional import cosine_similarity
import torch.nn.functional as F
from golearn.algorithms.utils.count import count_memory, refresh_memory
from golearn.algorithms.utils.memory import reservoir_update
import os
import contextlib
import numpy as np
from inspect import signature
from collections import OrderedDict
from sklearn.metrics import accuracy_score, balanced_accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
from golearn.algorithms.hooks import PseudoLabelingHook, FixedThresholdingHook
from golearn.algorithms.utils import SSL_Argument
from numpy import trapz
from golearn.core.criterions.sup_contrastive import SupConLoss


@ALGORITHMS.register('nce')
class NCE(AlgorithmBase):
    """
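        Train a model on labeled data with a reservoir replay memory, per-class feature prototypes
        and a head-sparsity / maximum-separation regularizer. Pseudo-labeling and fixed-threshold
        hooks are registered, but `train_step` does not use the unlabeled batch.
        This serves as a baseline for comparison.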

        Args:
            - args (`argparse`):
                algorithm arguments
            - net_builder (`callable`):
                network loading function
            - tb_log (`TBLog`):
                tensorboard logger
            - logger (`logging.Logger`):
                logger to use
        """

    def __init__(self, args, net_builder, tb_log=None, logger=None):
        """Initialise the base algorithm, then the bookkeeping state used by this class.

        Args:
            args: parsed arguments; `p_cutoff`, `num_eval_iter`, `proto_mom`,
                `lambda_reg` and `threshold` are read here.
            net_builder: network loading function, forwarded to the base class.
            tb_log: tensorboard logger, forwarded to the base class.
            logger: logger, forwarded to the base class.
        """
        super().__init__(args, net_builder, tb_log, logger)

        # Values copied from the parsed arguments.
        self.p_cutoff = args.p_cutoff
        self.num_eval_iter = args.num_eval_iter
        self.momentum = args.proto_mom
        self.lambda_reg = args.lambda_reg
        self.threshold = args.threshold

        # Three independent empty tensors on the configured GPU: replay samples,
        # their labels, and the class ids seen so far.
        self.buffer, self.label, self.class_id = (
            torch.tensor([]).cuda(self.gpu) for _ in range(3)
        )

        # Class bookkeeping maintained by `update_status`.
        self.current_class_num = 0
        self.old_class_num = 0
        self.new_classnum = 0
        self.new_class_id = None
        self.label_mapping = None
        self.well_flag = False

        # Prototype state maintained by `update_prototypes`.
        self.prototypes = {}
        self.proto_sim_matrix = None
        self.proto_norm = None
        self.idea_sim_matrix = None

        # Histories used for plotting and for the accuracy-curve metric.
        self.loss_list = []
        self.acc_list = []
        self.train_preds = []
        self.train_y = []
        self.sparsity = []

        self.classifier = None

    def set_hooks(self):
        """Register the pseudo-labeling and masking hooks, then the base-class hooks."""
        hook_specs = (
            (PseudoLabelingHook, "PseudoLabelingHook"),
            (FixedThresholdingHook, "MaskingHook"),
        )
        # Instantiate each hook right before registering it, in the listed order.
        for hook_cls, hook_name in hook_specs:
            self.register_hook(hook_cls(), hook_name)
        super().set_hooks()

    def unique(self, x):
        """Return the distinct values of `x` in first-occurrence order as a tensor on `self.gpu`."""
        first_seen = []
        already = set()
        for value in x.tolist():
            if value in already:
                continue
            already.add(value)
            first_seen.append(value)
        return torch.tensor(first_seen).cuda(self.gpu)

    def plot1(self):
        """Save the recorded supervised-loss history (`self.loss_list`) to ``./fig1.pdf``.

        A dedicated figure is created and closed here so the curve is not drawn on
        pyplot's implicit current figure, where it would accumulate across calls and
        leak into any later plot.
        """
        fig, ax = plt.subplots()
        try:
            ax.plot(np.arange(len(self.loss_list)), np.array(self.loss_list))
            fig.savefig('./fig1.pdf')
        finally:
            # Always release the figure, even if saving fails.
            plt.close(fig)

    def plot2(self):
        """Save the recorded accuracy history (`self.acc_list`) to ``./fig2.pdf``.

        A dedicated figure is created and closed here so the output contains only the
        accuracy curve, not whatever was previously drawn on pyplot's current figure.
        """
        fig, ax = plt.subplots()
        try:
            ax.plot(np.arange(len(self.acc_list)), np.array(self.acc_list))
            fig.savefig('./fig2.pdf')
        finally:
            # Always release the figure, even if saving fails.
            plt.close(fig)

    def update_memory(self, x_lb_w, x_lb_s, y_lb):
        """Pass `x_lb_w` and `y_lb` to `reservoir_update`; `x_lb_s` is accepted but unused."""
        samples, targets = x_lb_w, y_lb
        reservoir_update(self, samples, targets)

    def update_prototypes(self, feats, y):
        with torch.no_grad():
            for i, feat in enumerate(feats):
                label = int(y[i])
                if label not in self.prototypes:
                    self.prototypes[label] = feat
                else:
                    self.prototypes[label] = self.momentum * self.prototypes[label] + (1 - self.momentum) * feat
            vectors = torch.stack(list(self.prototypes.values()))
            self.proto_norm = torch.norm(vectors, dim=1, keepdim=True)
            normalized_vectors = vectors / self.proto_norm
            # print(normalized_vectors)
            self.proto_sim_matrix = cosine_similarity(normalized_vectors.unsqueeze(1), normalized_vectors.unsqueeze(0), dim=2)

    def proto_classify(self, data):
        class_prototypes = torch.stack(list(self.prototypes.values()))
        sim = F.cosine_similarity(data.unsqueeze(1), class_prototypes.unsqueeze(0), dim=2)
        preds = torch.argmax(sim, dim=1)
        return preds

    def update_status(self, num, y):
        """Track which class ids have been seen and refresh the derived bookkeeping.

        On the first call (no classes recorded yet) the class count, the id list and
        the label mapping are initialised from `y`. On later calls the ids present in
        `y` are recorded as `new_class_id`, merged into `class_id`, and, if the merge
        added ids, the counters and the mapping are refreshed. `num` is unused.
        """
        batch_ids = self.unique(y)
        merged_ids = self.unique(torch.cat([self.class_id, batch_ids])).to(torch.long)

        if self.current_class_num == 0:
            self.current_class_num = len(torch.unique(y))
            self.class_id = merged_ids
            self.label_mapping = {str(cid): pos for pos, cid in enumerate(self.class_id)}
            return

        self.new_class_id = batch_ids.to(torch.long)
        added = len(merged_ids) - len(self.class_id)
        self.class_id = merged_ids
        if added == 0:
            return

        # At least one previously unseen class id arrived in this batch.
        self.new_classnum = added
        self.well_flag = False
        self.label_mapping = {str(cid): pos for pos, cid in enumerate(self.class_id)}
        self.old_class_num = self.current_class_num
        self.current_class_num += added

    def max_sep_loss(self, features, y):
        features = features.unsqueeze(1)
        if len(self.prototypes) <= 1:
            return 0
        vectors = torch.stack(list(self.prototypes.values())).unsqueeze(0)
        similarity = cosine_similarity(features, vectors, dim=2)
        proto_y = torch.arange(len(self.prototypes)).cuda(self.gpu)
        eq_matrix = torch.eq(y.unsqueeze(1), proto_y.unsqueeze(0)).bool().cuda(self.gpu)
        sim_label = torch.where(eq_matrix, torch.tensor(1, dtype=torch.float).cuda(self.gpu),
                                torch.tensor(-1/(len(self.prototypes)-1), dtype=torch.float).cuda(self.gpu)).cuda(self.gpu)
        # print(similarity); print(sim_label)
        return torch.cdist(sim_label, similarity).sum()

    def label_transfer(self, y):
        """Map each element of `y` through `self.label_mapping` (keyed by `str(element)`).

        Returns the mapped values as a tensor on `self.gpu`.
        """
        mapped = []
        for item in y:
            mapped.append(self.label_mapping.get(str(item)))
        return torch.tensor(mapped).cuda(self.gpu)

    def train_step(self, x_lb_w, x_lb_s, y_lb, x_ulb_w):
        """Compute the training loss for one labeled batch and update the prototypes.

        The loss is the cross-entropy on `x_lb_w` plus ``lambda_reg`` times the sum of
        an L1/L2 ratio term on the head weights and `max_sep_loss`. `x_lb_s` and
        `x_ulb_w` are accepted but not used.

        Returns:
            ``(out_dict, log_dict)`` as produced by `process_out_dict` / `process_log_dict`.
        """
        with self.amp_cm():
            logits, feats = self.model(x_lb_w)
            feat_dim = len(feats[0])
            self.update_status(feat_dim, y_lb)
            targets = y_lb
            sup_loss = self.ce_loss(logits, targets, reduction='mean')

            head_weight = self.model.module.head.weight
            if self.old_class_num == 0:
                # No earlier classes recorded: the ratio over the whole head enters with a plus sign.
                reg = torch.norm(head_weight, 1) / torch.norm(head_weight, 2) / feat_dim
            else:
                # Only the rows of the earlier classes are used, and the sign is flipped.
                old_rows = head_weight[0:self.old_class_num]
                reg = - torch.norm(old_rows, 1) / torch.norm(old_rows, 2) / feat_dim

            sep_loss = self.max_sep_loss(features=feats, y=targets)
            total_loss = sup_loss + self.lambda_reg * (reg + sep_loss)

        out_dict = self.process_out_dict(loss=total_loss)
        log_dict = self.process_log_dict(sup_loss=sup_loss.item(),
                                         total_loss=total_loss.item())
        self.loss_list.append(log_dict['train/sup_loss'])
        # Prototypes are refreshed after the loss, so the loss saw the previous ones.
        self.update_prototypes(feats=feats, y=targets)
        return out_dict, log_dict

    def reverse_lookup(self, value):
        # print(value)
        # print(self.label_mapping)
        for key, val in self.label_mapping.items():
            if val == value:
                return key

    def train(self):
        """Run one pass over the zipped labeled/unlabeled loaders with periodic replay.

        For every batch pair: the labeled batch is pushed into the memory, a train
        step runs between the before/after train-step hooks, and `replay` is called
        when the iteration counter hits the replay condition. After the loop a final
        replay, the "evaluate" and "after_run" hooks, and both plots are executed.
        """
        # lb: labeled, ulb: unlabeled
        self.model.train()
        self.epoch = 1
        self.call_hook("before_run")
        self.call_hook("before_train_epoch")
        # zip stops at the shorter of the two loaders.
        for data_lb, data_ulb in zip(self.loader_dict['train_lb'],
                                     self.loader_dict['train_ulb']):
            # NOTE(review): no check against args.num_train_iter happens in this loop;
            # the iteration counter is only incremented.
            self.it += 1
            # The labeled batch goes through process_batch twice: once for the memory
            # update here and once (together with the unlabeled batch) for train_step.
            self.update_memory(**self.process_batch(**data_lb))
            self.call_hook("before_train_step")
            self.out_dict, self.log_dict = self.train_step(**self.process_batch(**data_lb, **data_ulb))
            self.call_hook("after_train_step")
            # experience replay: fires when `it` is 5 short of a multiple of er_frequency.
            # NOTE(review): for 0 < er_frequency < 5 the right-hand side is negative and
            # this never matches — confirm the intended schedule.
            if self.it % self.er_frequency == self.er_frequency - 5:
                # print(self.it)
                self.replay()
        # One last replay over the memory once the loaders are exhausted.
        self.replay()
        self.call_hook("evaluate")
        self.call_hook("after_run")
        self.plot1()
        self.plot2()

    def replay(self):
        """Replay the memory buffer, with extra passes on pairs of confused classes.

        Steps:
            1. One supervised pass over the whole buffer, recording predictions.
            2. For every ordered class pair whose row-normalised confusion exceeds
               `self.threshold`, a pass over the buffered samples of those two
               classes using a logit mask that selects the two classes.
            3. A second supervised pass over the whole buffer.

        Does nothing when the buffer is empty.
        """
        if len(self.buffer) == 0:
            # Nothing stored yet; a loader/confusion matrix over no samples would fail.
            return

        # NOTE(review): the counter is shifted by 2 for the duration of the replay,
        # presumably so iteration-triggered hooks do not fire on replay steps — confirm.
        self.it -= 2
        try:
            supervised_memory = TensorDataset(self.buffer, self.label)
            memory_loader = DataLoader(dataset=supervised_memory, batch_size=self.args.batch_size, shuffle=True)
            y_true, y_pred = self._replay_pass(memory_loader, collect_preds=True)
            self._replay_confused_pairs(y_true, y_pred)
            self._replay_pass(memory_loader)
        finally:
            # Restore the counter even if a replay step raises.
            self.it += 2

    def _replay_pass(self, loader, collect_preds=False):
        """Run one mean-reduced cross-entropy pass over `loader` through the train-step hooks.

        Returns ``(y_true, y_pred)`` lists; both stay empty unless `collect_preds` is set.
        """
        y_true, y_pred = [], []
        for x, y in loader:
            self.call_hook("before_train_step")
            logits, _ = self.model(x)
            sup_loss = self.ce_loss(logits, y, reduction='mean')
            if collect_preds:
                # Predictions are restricted to the logits of the classes seen so far.
                y_pred.extend(torch.max(logits[:, :self.current_class_num], dim=-1)[1].cpu().tolist())
                y_true.extend(y.cpu().tolist())
            self.out_dict = self.process_out_dict(loss=sup_loss)
            self.log_dict = self.process_log_dict(loss=sup_loss.item())
            self.call_hook("after_train_step")
        return y_true, y_pred

    def _replay_confused_pairs(self, y_true, y_pred):
        """Run a masked pass over the buffered samples of every confused class pair."""
        # Rows/columns of the confusion matrix follow the sorted labels that occur in
        # y_true or y_pred; pass them explicitly so positions can be mapped back to
        # class ids instead of being mistaken for class ids.
        class_labels = np.unique(np.concatenate([y_true, y_pred]))
        cf_mat = confusion_matrix(y_true, y_pred, labels=class_labels, normalize='true')
        off_diagonal = ~np.eye(cf_mat.shape[0], dtype=bool)
        confused_pairs = np.argwhere((cf_mat > self.threshold) & off_diagonal)

        buffer_rows = {}
        for row, col in confused_pairs:
            c1, c2 = int(class_labels[row]), int(class_labels[col])
            for cls in (c1, c2):
                if cls not in buffer_rows:
                    # flatten() keeps a 1-D index for 0, 1 or many matches, so no
                    # fallback to unrelated buffer rows is needed.
                    buffer_rows[cls] = torch.nonzero(self.label == cls).flatten()
            idxes = torch.cat([buffer_rows[c1], buffer_rows[c2]], dim=0)
            if idxes.numel() == 0:
                # Neither class has samples in the buffer (e.g. predicted-only classes).
                continue

            confused_set = TensorDataset(self.buffer[idxes], self.label[idxes])
            confused_loader = DataLoader(dataset=confused_set, batch_size=self.args.batch_size, shuffle=True)
            for x, y in confused_loader:
                self.call_hook("before_train_step")
                logits, _ = self.model(x)
                # One-hot style mask over the logit columns of the two confused classes.
                mask = torch.zeros(len(logits[0]))
                mask[c1] = 1
                mask[c2] = 1
                mask = mask.cuda(self.gpu)
                # NOTE(review): `.item()` below needs a scalar, so `ce_loss` must reduce
                # when `mask` is given despite reduction='none' — verify in ce_loss.
                sup_loss = self.ce_loss(logits, y, reduction='none', mask=mask)
                self.out_dict = self.process_out_dict(loss=sup_loss)
                self.log_dict = self.process_log_dict(loss=sup_loss.item())
                self.call_hook("after_train_step")

    def evaluate(self, eval_dest='eval', out_key='logits', return_logits=False):
        """Evaluate the EMA weights on `self.loader_dict[eval_dest]`.

        Only samples whose label is in `self.class_id` are evaluated. Predictions
        are the argmax over the first `self.current_class_num` logits.

        Args:
            eval_dest: key of the loader to evaluate; also the prefix of the result keys.
            out_key: unused; kept for interface compatibility.
            return_logits: if True, the concatenated logits are added to the result.

        Returns:
            dict with loss, top-1 / balanced accuracy, macro precision / recall / F1
            and the accuracy-curve metric ``aauc`` (plus logits when requested).

        Raises:
            ValueError: if no sample of the loader belongs to a known class.
        """
        self.model.eval()
        self.ema.apply_shadow()

        eval_loader = self.loader_dict[eval_dest]
        total_loss = 0.0
        total_num = 0.0
        y_true = []
        y_pred = []
        y_logits = []
        try:
            with torch.no_grad():
                for data in eval_loader:
                    x = data['x_lb_w']
                    y = data['y_lb'].cuda(self.gpu)
                    # Keep only the samples whose label is among the known class ids.
                    matches = torch.eq(y.unsqueeze(1), self.class_id.unsqueeze(0))
                    index = torch.nonzero(torch.any(matches, dim=1), as_tuple=True)[0]
                    if index.numel() == 0:
                        # A mean-reduced loss over zero samples is NaN and would poison the total.
                        continue
                    # Move to the GPU before filtering so the (GPU) index always matches
                    # the device of the indexed tensor; handle dict inputs per entry.
                    if isinstance(x, dict):
                        x = {k: v.cuda(self.gpu)[index] for k, v in x.items()}
                    else:
                        x = x.cuda(self.gpu)[index]
                    y = y[index]

                    num_batch = y.shape[0]
                    total_num += num_batch
                    logits, _ = self.model(x)
                    loss = F.cross_entropy(logits, y, reduction='mean', ignore_index=-1)
                    y_true.extend(y.cpu().tolist())
                    y_pred.extend(torch.max(logits[:, :self.current_class_num], dim=-1)[1].cpu().tolist())
                    y_logits.append(logits.cpu().numpy())
                    total_loss += loss.item() * num_batch
        finally:
            # Put the raw weights and train mode back even if inference fails.
            self.ema.restore()
            self.model.train()

        if not y_logits:
            raise ValueError(f"no sample in loader '{eval_dest}' belongs to a known class")

        y_true = np.array(y_true)
        y_pred = np.array(y_pred)
        y_logits = np.concatenate(y_logits)
        top1 = accuracy_score(y_true, y_pred)
        balanced_top1 = balanced_accuracy_score(y_true, y_pred)
        precision = precision_score(y_true, y_pred, average='macro')
        recall = recall_score(y_true, y_pred, average='macro')
        F1 = f1_score(y_true, y_pred, average='macro')

        cf_mat = confusion_matrix(y_true, y_pred, normalize='true')
        np.set_printoptions(precision=2, threshold=100)
        self.print_fn('confusion matrix:\n' + np.array_str(cf_mat))

        self.acc_list.append(top1)
        acc_list = np.array(self.acc_list)
        # NOTE(review): the trapezoid over n points spans (n - 1) * dx but is divided by
        # n * dx, so a single evaluation yields 0 — confirm this normalisation is intended.
        aauc = trapz(acc_list, dx=self.num_eval_iter) / (len(acc_list) * self.num_eval_iter)
        eval_dict = {eval_dest + '/loss': total_loss / total_num, eval_dest + '/top-1-acc': top1,
                     eval_dest + '/balanced_acc': balanced_top1, eval_dest + '/precision': precision,
                     eval_dest + '/recall': recall, eval_dest + '/F1': F1,
                     eval_dest + '/aauc': aauc,}
        if return_logits:
            eval_dict[eval_dest + '/logits'] = y_logits

        return eval_dict

    @staticmethod
    def get_argument():
        """Return the algorithm-specific arguments: `--p_cutoff` and `--proto_mom`, both floats defaulting to 0.95."""
        float_options = (
            ('--p_cutoff', 0.95),
            ('--proto_mom', 0.95),
        )
        return [SSL_Argument(flag, float, default) for flag, default in float_options]

# Register NCE under the 'nce' key. The class is also decorated with
# @ALGORITHMS.register('nce'), so this assignment looks redundant — TODO confirm before removing.
ALGORITHMS['nce'] = NCE