import copy
import logging
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from utils.toolkit import tensor2numpy, accuracy
from scipy.spatial.distance import cdist
import os
from sklearn.metrics import confusion_matrix
import seaborn as sn
import pandas as pd
import matplotlib.pyplot as plt
import os
import shutil
from sklearn.metrics import f1_score, recall_score, precision_score
import matplotlib
import matplotlib.pyplot as plt
# Use a serif font for all matplotlib text produced after this module is imported.
plt.rcParams["font.family"] = "serif"

# Small constant added to feature norms before dividing, so a zero norm
# does not cause a division by zero.
EPSILON = 1e-8
# Batch size of the DataLoaders built while extracting exemplar features.
batch_size = 64


class BaseLearner(object):
    def __init__(self, args):
        """Initialise the bookkeeping state shared by incremental learners.

        Args:
            args: Mapping of run options. ``memory_size``, ``device`` and
                ``summary_file`` are required keys; ``memory_per_class``
                (default ``None``) and ``fixed_memory`` (default ``False``)
                are optional.

        Side effects:
            Recreates ``predictions/<summary_file>`` from scratch: an
            existing folder of that name is deleted, then an empty one is
            created (together with any missing parent directory).
        """
        self.args = args
        self.curr_acc = None
        self._cur_task = -1
        self._known_classes = 0
        self._total_classes = 0
        self._network = None
        self._old_network = None
        self._data_memory, self._targets_memory = np.array([]), np.array([])
        self.topk = 3
        self._memory_size = args["memory_size"]
        self._memory_per_class = args.get("memory_per_class", None)
        self._fixed_memory = args.get("fixed_memory", False)
        # The device is picked automatically; args["device"] is only stored.
        self._device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self._multiple_gpus = args["device"]
        self._predictions_folder = os.path.join('predictions', args["summary_file"])

        # Start every run with an empty output folder.
        if os.path.exists(self._predictions_folder):
            shutil.rmtree(self._predictions_folder)

        # makedirs (not mkdir): the parent "predictions" directory may not
        # exist yet on a fresh checkout.
        os.makedirs(self._predictions_folder, exist_ok=True)
        
        

    @property
    def exemplar_size(self):
        assert len(self._data_memory) == len(
            self._targets_memory
        ), "Exemplar size error."
        return len(self._targets_memory)

    @property
    def samples_per_class(self):
        """Exemplar budget per class.

        With a fixed memory this is the configured ``_memory_per_class``;
        otherwise the total ``_memory_size`` is split evenly (integer
        division) over ``_total_classes``, which must be non-zero.
        """
        if not self._fixed_memory:
            assert self._total_classes != 0, "Total classes is 0"
            return self._memory_size // self._total_classes
        return self._memory_per_class

    @property
    def feature_dim(self):
        if isinstance(self._network, nn.DataParallel):
            return self._network.module.feature_dim
        else:
            return self._network.feature_dim

    def build_rehearsal_memory(self, data_manager, per_class):
        """Update the exemplar memory after a task.

        With a fixed memory the unified construction is used; otherwise the
        stored exemplars are first reduced to ``per_class`` each and new
        exemplars are then constructed.
        """
        if self._fixed_memory:
            self._construct_exemplar_unified(data_manager, per_class)
            return

        self._reduce_exemplar(data_manager, per_class)
        self._construct_exemplar(data_manager, per_class)

    def save_checkpoint(self):
        
        torch.save(self._network.state_dict(), os.path.join(self._predictions_folder, "state_dict_task_{}.pth".format(self._cur_task)))

    def save_model(self):
        self._network.cpu()
        model_scripted = torch.jit.script(self._network)
        model_scripted.save(os.path.join(self._predictions_folder, 'model_task_{}.pt'.format(self._cur_task)))
        
    def after_task(self):
        """Hook with no behaviour in this base class (the body is ``pass``).

        NOTE(review): presumably overridden by concrete learners to run
        bookkeeping once a task has finished -- confirm against subclasses.
        """
        pass

    def _evaluate(self, y_pred, y_true):
        """Summarise top-1 and top-k accuracy for a set of predictions.

        Args:
            y_pred: Ranked predictions; column 0 is used as the top-1 guess
                and the array is compared against ``self.topk`` copies of
                ``y_true`` (so it is expected to be ``[N, topk]``).
            y_true: Ground-truth labels, one per sample.

        Returns:
            dict with ``"grouped"`` (the result of ``accuracy`` on the top-1
            guesses), ``"top1"`` (its ``"total"`` entry) and
            ``"top<k>"`` (percentage of samples whose label appears among
            the k predictions, rounded to 2 decimals).
        """
        best_guess = y_pred.T[0]
        grouped = accuracy(best_guess, y_true, self._known_classes)

        # A sample counts as a hit when any of its k guesses equals the label.
        repeated_truth = np.tile(y_true, (self.topk, 1))
        hits = (y_pred.T == repeated_truth).sum()
        topk_key = "top{}".format(self.topk)
        topk_acc = np.around(hits * 100 / len(y_true), decimals=2)

        return {"grouped": grouped, "top1": grouped["total"], topk_key: topk_acc}

    def eval_task(self, save_conf=False):
        y_pred, y_true = self._eval_cnn(self.test_loader)
        cnn_accy = self._evaluate(y_pred, y_true)

        if hasattr(self, "_class_means"):
            y_pred, y_true = self._eval_nme(self.test_loader, self._class_means)
            nme_accy = self._evaluate(y_pred, y_true)
        else:
            nme_accy = None

        if save_conf:
            _pred = y_pred.T[0]
            _pred_path = os.path.join(self.args['logfilename'], "pred.npy")
            _target_path = os.path.join(self.args['logfilename'], "target.npy")
            np.save(_pred_path, _pred)
            np.save(_target_path, y_true)

            _save_dir = os.path.join(f"./results/conf_matrix/{self.args['prefix']}")
            os.makedirs(_save_dir, exist_ok=True)
            _save_path = os.path.join(_save_dir, f"{self.args['csv_name']}.csv")
            with open(_save_path, "a+") as f:
                f.write(f"{self.args['time_str']},{self.args['model_name']},{_pred_path},{_target_path} \n")

        return cnn_accy, nme_accy

    def incremental_train(self):
        """Placeholder with no behaviour in this base class (body is ``pass``).

        NOTE(review): presumably the per-task training entry point that
        concrete learners override -- confirm against subclasses.
        """
        pass

    def _train(self):
        """Placeholder with no behaviour in this base class (body is ``pass``).

        NOTE(review): presumably the inner training routine that concrete
        learners override -- confirm against subclasses.
        """
        pass

    def _get_memory(self):
        if len(self._data_memory) == 0:
            return None
        else:
            return (self._data_memory, self._targets_memory)

    def _compute_accuracy(self, model, loader):
        model.eval()
        y_pred = [] # save predction
        y_true = [] # save ground truth

        correct, total = 0, 0
        for i, (_, inputs, targets) in enumerate(loader):
            inputs = inputs.to(self._device)
            with torch.no_grad():
                outputs = model(inputs)["logits"]
            predicts = torch.max(outputs, dim=1)[1]
            
            correct += (predicts.cpu() == targets).sum()
            total += len(targets)

            y_pred.extend(predicts.data.cpu().numpy())  # save prediction
            y_true.extend(targets.data.cpu().numpy())  # save ground truth

        true_targets = np.array(y_true)
        pred_targets = np.array(y_pred)

        uq_true, count_true = np.unique(true_targets, return_counts=True)
        uq_pred, count_pred = np.unique(pred_targets, return_counts=True)
        logging.debug('Task: {}\tUnique True samples after each epoch: {}\tCounts: {}'.format(self._cur_task, uq_true, count_true))
        logging.debug('Task: {}\tUnique Predicted samples after each epoch: {}\tCounts: {}'.format(self._cur_task, uq_pred, count_pred))

        correct_new = (true_targets == pred_targets).sum()
        
        logging.info('#### Task: {} Test ACC after each epoch: {} ####'.format(self._cur_task, 
                                                               np.around(correct_new * 100 / len(y_pred), decimals=2)))
        return np.around(tensor2numpy(correct) * 100 / total, decimals=2)

    def _createConfusionMatrix(self, model, loader):
        """
        Evaluate ``model`` on ``loader`` and draw an annotated confusion matrix.

        Parameters:
            model (torch.nn.Module): Model whose forward pass returns a mapping
                with ``"logits"`` and ``"features"`` entries.
            loader (torch.utils.data.DataLoader): Yields ``(_, inputs, targets)``
                batches of the data to evaluate.

        Returns:
            matplotlib.figure.Figure: Figure holding the heat map. Every cell
            is annotated with its count followed by F1, precision and recall;
            the three scores are non-zero only on the diagonal.

        Side effects:
            * saves the ``(true, predicted)`` label pairs to
              ``predictions_task_<task>_tr_pr.npy`` and the ``"features"``
              outputs to ``embeddings_task_<task>.npy`` inside the
              predictions folder;
            * calls ``self.save_checkpoint()``;
            * stores the accuracy (percent, 2 decimals) in ``self.curr_acc``.

        The set of classes is the union of true and predicted labels, so the
        per-class scores, the confusion matrix and the axis labels always have
        matching sizes, even when the model predicts a class that does not
        occur in the ground truth.
        """
        model.eval()
        y_pred = []  # predicted labels
        y_true = []  # ground-truth labels
        embeddings = []  # per-sample feature vectors

        with torch.no_grad():
            for _, inputs, targets in loader:
                # One forward pass serves both the logits and the features.
                outputs = model(inputs.to(self._device))
                predicts = torch.max(outputs["logits"], dim=1)[1]

                y_pred.extend(predicts.cpu().numpy())
                y_true.extend(targets.cpu().numpy())
                embeddings.extend(outputs["features"].cpu().numpy())

        np.save(os.path.join(self._predictions_folder,
                              'predictions_task_{}_tr_pr.npy'.format(self._cur_task)), np.array(list(zip(y_true, y_pred))))
        np.save(os.path.join(self._predictions_folder,
                              'embeddings_task_{}.npy'.format(self._cur_task)), np.stack(embeddings))
        self.save_checkpoint()

        true_arr = np.asarray(y_true)
        pred_arr = np.asarray(y_pred)
        self.curr_acc = np.around((true_arr == pred_arr).sum() * 100 / len(true_arr), decimals=2)

        logging.info('Task: {}\tTest ACC on Inference data ----> {}'.format(self._cur_task, self.curr_acc))

        # Use one explicit label set everywhere so all matrices line up.
        label_ids = np.unique(np.concatenate((true_arr, pred_arr)))
        classes = tuple(map(str, label_ids))

        f1_scores = f1_score(true_arr, pred_arr, labels=label_ids, average=None)
        recall_scores = recall_score(true_arr, pred_arr, labels=label_ids, average=None)
        precision_scores = precision_score(true_arr, pred_arr, labels=label_ids, average=None)

        # Per-class scores live on the diagonal; off-diagonal cells stay 0.
        f1_matrix = np.diag(f1_scores)
        precision_matrix = np.diag(precision_scores)
        recall_matrix = np.diag(recall_scores)

        # Log how many test samples each true class has.
        uq_lbls, lbls_counts = np.unique(true_arr, return_counts=True)
        logging.info('Testing on the classes {}\tCount of each class {}'.format(uq_lbls, lbls_counts))

        # Build confusion matrix
        cf_matrix = confusion_matrix(true_arr, pred_arr, labels=label_ids)

        group_counts = ['{0:0.0f}'.format(value) for value in cf_matrix.flatten()]
        group_f1 = ['{0:0.4f}'.format(value) for value in f1_matrix.flatten()]
        group_precision = ['{0:0.4f}'.format(value) for value in precision_matrix.flatten()]
        group_recall = ['{0:0.4f}'.format(value) for value in recall_matrix.flatten()]

        labels = ['{}\n{}\n{} {}'.format(v2, v3, v4, v5) for v2, v3, v4, v5 in zip(group_counts, group_f1, group_precision, group_recall)]
        labels = np.asarray(labels).reshape(len(classes), len(classes))

        df_cm = pd.DataFrame(cf_matrix, index=list(classes), columns=list(classes))

        plt.figure(figsize=(15, 15))
        return sn.heatmap(df_cm, annot=labels, fmt='',
                            cmap="rainbow",
                            norm=matplotlib.colors.LogNorm(),
                            linewidths=0.3,
                            linecolor='black',
                            annot_kws={"color":"black"}
                            ).get_figure()


    def _eval_cnn(self, loader):
        self._network.eval()
        y_pred, y_true = [], []
        for _, (_, inputs, targets) in enumerate(loader):
            inputs = inputs.to(self._device)
            with torch.no_grad():
                outputs = self._network(inputs)["logits"]
            predicts = torch.topk(
                outputs, k=self.topk, dim=1, largest=True, sorted=True
            )[
                1
            ]  # [bs, topk]
            y_pred.append(predicts.cpu().numpy())
            y_true.append(targets.cpu().numpy())

        return np.concatenate(y_pred), np.concatenate(y_true)  # [N, topk]
    
    def _eval_nme(self, loader, class_means):
        """Rank classes for every sample by distance to the class means.

        Features from ``_extract_vectors`` are L2-normalised per sample and
        compared with ``class_means`` using the squared Euclidean distance.

        Returns:
            ``(y_pred, y_true)`` where ``y_pred`` holds, per sample, the
            indices of the ``self.topk`` nearest class means (closest first).
        """
        self._network.eval()
        features, y_true = self._extract_vectors(loader)

        # Normalise every feature vector; EPSILON guards against zero norms.
        features_t = features.T
        features = (features_t / (np.linalg.norm(features_t, axis=0) + EPSILON)).T

        # cdist yields [nb_classes, N]; transpose to one row per sample.
        scores = cdist(class_means, features, "sqeuclidean").T

        # Smaller distance is better, so an ascending sort gives the ranking.
        ranking = np.argsort(scores, axis=1)
        return ranking[:, : self.topk], y_true  # [N, topk]

    def _extract_vectors(self, loader):
        """Run ``extract_vector`` over ``loader`` and gather the results.

        Args:
            loader: Iterable yielding ``(_, inputs, targets)`` batches.

        Returns:
            ``(vectors, targets)``: the per-batch outputs of the network's
            ``extract_vector`` and the batch targets, each concatenated into
            a single numpy array.
        """
        self._network.eval()
        # DataParallel keeps the real model under ``.module``.
        is_wrapped = isinstance(self._network, nn.DataParallel)

        feature_batches, target_batches = [], []
        for _, batch_inputs, batch_targets in loader:
            target_batches.append(batch_targets.numpy())
            backbone = self._network.module if is_wrapped else self._network
            features = backbone.extract_vector(batch_inputs.to(self._device))
            feature_batches.append(tensor2numpy(features))

        return np.concatenate(feature_batches), np.concatenate(target_batches)

    def _reduce_exemplar(self, data_manager, m):
        """Trim the memory to at most ``m`` exemplars per known class.

        The rehearsal memory is rebuilt from the first ``m`` stored exemplars
        of each class in ``range(self._known_classes)``, and
        ``self._class_means`` is re-allocated (one zero row per class in
        ``self._total_classes``) with the rows of the known classes filled
        by the normalised mean feature of the kept exemplars.

        Args:
            data_manager: Object exposing ``get_dataset``; called with
                ``appendent=(data, targets)`` to wrap the kept exemplars.
            m: Maximum number of exemplars kept per class.
        """
        logging.info("Reducing exemplars...({} per classes)".format(m))
        old_data = copy.deepcopy(self._data_memory)
        old_targets = copy.deepcopy(self._targets_memory)
        self._class_means = np.zeros((self._total_classes, self.feature_dim))
        self._data_memory = np.array([])
        self._targets_memory = np.array([])

        for class_idx in range(self._known_classes):
            rows = np.nonzero(old_targets == class_idx)[0]
            kept_data = old_data[rows][:m]
            kept_targets = old_targets[rows][:m]

            if len(self._data_memory) != 0:
                self._data_memory = np.concatenate((self._data_memory, kept_data))
            else:
                self._data_memory = kept_data

            if len(self._targets_memory) != 0:
                self._targets_memory = np.concatenate(
                    (self._targets_memory, kept_targets)
                )
            else:
                self._targets_memory = kept_targets

            # Recompute the class mean from the exemplars that were kept.
            kept_dataset = data_manager.get_dataset(
                [], source="train", mode="test", appendent=(kept_data, kept_targets)
            )
            kept_loader = DataLoader(
                kept_dataset, batch_size=batch_size, shuffle=False, num_workers=4
            )
            feats, _ = self._extract_vectors(kept_loader)
            feats_t = feats.T
            feats = (feats_t / (np.linalg.norm(feats_t, axis=0) + EPSILON)).T
            centroid = np.mean(feats, axis=0)

            self._class_means[class_idx, :] = centroid / np.linalg.norm(centroid)

    def _construct_exemplar(self, data_manager, m):
        """Pick up to ``m`` exemplars for every new class and store their means.

        For each class in ``range(self._known_classes, self._total_classes)``
        the class samples are embedded with ``_extract_vectors`` and
        L2-normalised. Exemplars are then chosen greedily: at step ``k`` the
        sample is taken whose inclusion brings the mean of the chosen vectors
        closest to the class mean. The chosen samples are appended to
        ``self._data_memory`` / ``self._targets_memory`` and the normalised
        mean of their features is written to ``self._class_means[class_idx]``.

        Args:
            data_manager: Object exposing ``get_dataset``; called with
                ``ret_data=True`` (unpacked here as ``(data, targets,
                dataset)``) and with ``appendent=(data, targets)``.
            m: Requested number of exemplars per class. A class with fewer
                than ``m`` samples contributes all of its samples (with a
                warning) instead of failing on an empty candidate pool.

        Note:
            ``self._class_means`` must already exist with a row for each new
            class; this method only fills rows (``_reduce_exemplar``
            allocates the array).
        """
        logging.info("Constructing exemplars...({} per classes)".format(m))
        for class_idx in range(self._known_classes, self._total_classes):
            data, _, idx_dataset = data_manager.get_dataset(
                np.arange(class_idx, class_idx + 1),
                source="train",
                mode="test",
                ret_data=True,
            )
            idx_loader = DataLoader(
                idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4
            )
            vectors, _ = self._extract_vectors(idx_loader)
            vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T
            class_mean = np.mean(vectors, axis=0)

            # Never ask for more exemplars than the class has samples:
            # argmin over an exhausted candidate pool would raise.
            n_select = min(m, len(vectors))
            if n_select < m:
                logging.warning(
                    "Class {} has only {} samples; keeping {} exemplars instead of {}".format(
                        class_idx, len(vectors), n_select, m
                    )
                )

            # Select
            selected_exemplars = []
            exemplar_vectors = []  # [n, feature_dim]
            for k in range(1, n_select + 1):
                S = np.sum(
                    exemplar_vectors, axis=0
                )  # [feature_dim] sum of selected exemplars vectors
                mu_p = (vectors + S) / k  # [n, feature_dim] candidate means
                i = np.argmin(np.sqrt(np.sum((class_mean - mu_p) ** 2, axis=1)))
                selected_exemplars.append(
                    np.array(data[i])
                )  # New object to avoid passing by reference
                exemplar_vectors.append(
                    np.array(vectors[i])
                )  # New object to avoid passing by reference

                vectors = np.delete(
                    vectors, i, axis=0
                )  # Remove it to avoid duplicative selection
                data = np.delete(
                    data, i, axis=0
                )  # Remove it to avoid duplicative selection

            selected_exemplars = np.array(selected_exemplars)
            exemplar_targets = np.full(n_select, class_idx)
            self._data_memory = (
                np.concatenate((self._data_memory, selected_exemplars))
                if len(self._data_memory) != 0
                else selected_exemplars
            )
            self._targets_memory = (
                np.concatenate((self._targets_memory, exemplar_targets))
                if len(self._targets_memory) != 0
                else exemplar_targets
            )

            # Exemplar mean
            idx_dataset = data_manager.get_dataset(
                [],
                source="train",
                mode="test",
                appendent=(selected_exemplars, exemplar_targets),
            )
            idx_loader = DataLoader(
                idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4
            )
            vectors, _ = self._extract_vectors(idx_loader)
            vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T
            mean = np.mean(vectors, axis=0)
            mean = mean / np.linalg.norm(mean)

            self._class_means[class_idx, :] = mean

    def _construct_exemplar_unified(self, data_manager, m):
        """Refresh old class means and pick up to ``m`` exemplars per new class.

        Two passes are made with the current network:

        1. For every class in ``range(self._known_classes)`` the stored
           exemplars are re-embedded and their normalised mean feature is
           recomputed (the memory itself is left untouched).
        2. For every class in ``range(self._known_classes,
           self._total_classes)`` exemplars are chosen greedily so that the
           mean of the chosen vectors stays closest to the class mean; they
           are appended to the memory and their normalised mean is stored.

        The resulting array replaces ``self._class_means`` at the end.

        Args:
            data_manager: Object exposing ``get_dataset``; called with
                ``ret_data=True`` (unpacked here as ``(data, targets,
                dataset)``) and with ``appendent=(data, targets)``.
            m: Requested number of exemplars per new class. A class with
                fewer than ``m`` samples contributes all of its samples
                (with a warning) instead of failing on an empty candidate
                pool.
        """
        logging.info(
            "Constructing exemplars for new classes...({} per classes)".format(m)
        )
        _class_means = np.zeros((self._total_classes, self.feature_dim))

        # Calculate the means of old classes with newly trained network
        for class_idx in range(self._known_classes):
            mask = np.where(self._targets_memory == class_idx)[0]
            class_data, class_targets = (
                self._data_memory[mask],
                self._targets_memory[mask],
            )

            class_dset = data_manager.get_dataset(
                [], source="train", mode="test", appendent=(class_data, class_targets)
            )
            class_loader = DataLoader(
                class_dset, batch_size=batch_size, shuffle=False, num_workers=4
            )
            vectors, _ = self._extract_vectors(class_loader)
            vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T
            mean = np.mean(vectors, axis=0)
            mean = mean / np.linalg.norm(mean)

            _class_means[class_idx, :] = mean

        # Construct exemplars for new classes and calculate the means
        for class_idx in range(self._known_classes, self._total_classes):
            data, _, class_dset = data_manager.get_dataset(
                np.arange(class_idx, class_idx + 1),
                source="train",
                mode="test",
                ret_data=True,
            )
            class_loader = DataLoader(
                class_dset, batch_size=batch_size, shuffle=False, num_workers=4
            )

            vectors, _ = self._extract_vectors(class_loader)
            vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T
            class_mean = np.mean(vectors, axis=0)

            # Never ask for more exemplars than the class has samples:
            # argmin over an exhausted candidate pool would raise.
            n_select = min(m, len(vectors))
            if n_select < m:
                logging.warning(
                    "Class {} has only {} samples; keeping {} exemplars instead of {}".format(
                        class_idx, len(vectors), n_select, m
                    )
                )

            # Select
            selected_exemplars = []
            exemplar_vectors = []
            for k in range(1, n_select + 1):
                S = np.sum(
                    exemplar_vectors, axis=0
                )  # [feature_dim] sum of selected exemplars vectors
                mu_p = (vectors + S) / k  # [n, feature_dim] candidate means
                i = np.argmin(np.sqrt(np.sum((class_mean - mu_p) ** 2, axis=1)))

                selected_exemplars.append(
                    np.array(data[i])
                )  # New object to avoid passing by reference
                exemplar_vectors.append(
                    np.array(vectors[i])
                )  # New object to avoid passing by reference

                vectors = np.delete(
                    vectors, i, axis=0
                )  # Remove it to avoid duplicative selection
                data = np.delete(
                    data, i, axis=0
                )  # Remove it to avoid duplicative selection

            selected_exemplars = np.array(selected_exemplars)
            exemplar_targets = np.full(n_select, class_idx)
            self._data_memory = (
                np.concatenate((self._data_memory, selected_exemplars))
                if len(self._data_memory) != 0
                else selected_exemplars
            )
            self._targets_memory = (
                np.concatenate((self._targets_memory, exemplar_targets))
                if len(self._targets_memory) != 0
                else exemplar_targets
            )

            # Exemplar mean
            exemplar_dset = data_manager.get_dataset(
                [],
                source="train",
                mode="test",
                appendent=(selected_exemplars, exemplar_targets),
            )
            exemplar_loader = DataLoader(
                exemplar_dset, batch_size=batch_size, shuffle=False, num_workers=4
            )
            vectors, _ = self._extract_vectors(exemplar_loader)
            vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T
            mean = np.mean(vectors, axis=0)
            mean = mean / np.linalg.norm(mean)

            _class_means[class_idx, :] = mean

        self._class_means = _class_means
