import os
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from torch.optim import SGD, Adam

from utils.utils import *

class BaseModel(nn.Module):
    def __init__(self, args):
        super(BaseModel, self).__init__()
        self.scores, self.scores_md, self.scores_total, self.out_score = [], [], [], []
        self.num_cls_per_task = args.num_cls_per_task # only makes sense if task id is known during training
        
        self.buffer = None

        self.statistics = {'mu': [], 'eigvec': [], 'eigval': []}

        self.args = args
        self.criterion = args.criterion # NLL
        self.net = args.net

        if self.args.optim_type == 'sgd':
            self.optimizer = SGD(self.net.parameters(), lr=self.args.lr)
        elif self.args.optim_type == 'adam':
            self.optimizer = Adam(self.net.parameters(), lr=self.args.lr)

        self.model_clip = args.model_clip

        # The actual label id from DataLoader.
        # This carries the label ids we've seen throughout training
        self.seen_ids = []
        # The actual label name from DataLoader.
        self.seen_names = []

        self.correct, self.til_correct, self.total, self.total_loss = 0., 0., 0., 0.
        self.cal_correct = 0.
        self.true_lab, self.pred_lab = [], []
        self.output_list, self.feature_list, self.label_list = [], [], []

        self.saving_buffer = {}

        self.margin, self.indices = [], []

    def observe(self, inputs, labels, not_aug_inputs=None, f_y=None, **kwargs):
        """Training-step hook; this base implementation does nothing.

        Every argument is accepted and ignored. Returns None.
        """
        return None

    def evaluate(self, inputs, labels, task_id=None, **kwargs):
        self.net.eval()
        # labels = self.map_labels(labels)
        with torch.no_grad():
            out = self.net(inputs, self.args.normalize)

        scores, pred = F.softmax(out, dim=1).max(1)
        self.scores_total.append(scores.detach().cpu().numpy())
        self.correct += pred.eq(labels).sum().item()
        self.total += len(labels)

        if task_id is not None:
            normalized_labels = labels % self.num_cls_per_task
            til_pred = out[:, task_id * self.num_cls_per_task:(task_id + 1) * self.num_cls_per_task]
            scores, til_pred = F.softmax(til_pred, dim=1).max(1)
            self.scores.append(scores.detach().cpu().numpy())
            # til_pred = til_pred.argmax(1)
            self.til_correct += til_pred.eq(normalized_labels).sum().item()

        self.net.train()

        if self.args.confusion:
            self.true_lab.append(labels.cpu().numpy())
            self.pred_lab.append(pred.cpu().numpy())

        if self.args.save_output:
            self.output_list.append(out.data.cpu().numpy())
            self.label_list.append(labels.data.cpu().numpy())

    def map_labels(self, labels):
        # labels: tensor
        relabel = []
        for y_ in labels:
            relabel.append(self.seen_ids.index(y_))
        return torch.tensor(relabel).to(self.args.device)

    def acc(self, reset=True):
        metrics = {}
        metrics['cil_acc'] = self.correct / self.total * 100
        metrics['til_acc'] = self.til_correct / self.total * 100
        metrics['cal_cil_acc'] = self.cal_correct / self.total * 100
        if len(self.scores_total) > 0: metrics['scores_total'] = np.concatenate(self.scores_total)
        if len(self.scores) > 0: metrics['scores'] = np.concatenate(self.scores)
        if len(self.scores_md) > 0: metrics['scores_md'] = np.concatenate(self.scores_md)

        if len(self.margin) > 0: metrics['margin'] = np.concatenate(self.margin)
        if len(self.indices) > 0: metrics['indices'] = np.concatenate(self.indices)
        if reset: self.reset_eval()
        return metrics

    def reset_eval(self):
        self.correct, self.til_correct, self.total, self.total_loss = 0., 0., 0., 0.
        self.cal_correct = 0.
        self.true_lab, self.pred_lab = [], []
        self.output_list, self.label_list = [], []
        self.scores, self.scores_md, self.scores_total = [], [], []
        self.feature_list, self.label_list = [], []
        self.margin, self.indices = [], []

    # def acc(self, reset=True):
    #     metrics = {}
    #     metrics['cil_acc'] = self.correct / self.total * 100
    #     metrics['til_acc'] = self.til_correct / self.total * 100
    #     if reset: self.reset_eval()
    #     return metrics

    # def reset_eval(self):
    #     self.correct, self.til_correct, self.total, self.total_loss = 0., 0., 0., 0.
    #     self.true_lab, self.pred_lab = [], []
    #     self.output_list, self.label_list = [], []

    def save(self, **kwargs):
        """Persist whatever the model needs in order to resume training.

        Not implemented in this base class; calling it always raises.

        Args:
            **kwargs: e.g. model state_dict, optimizer state_dict, epochs, etc.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    def load(self, **kwargs):
        """Restore model state; not implemented in this base class.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    def compute_stats(self, task_id, loader):
        self.args.logger.print(f"Compute Stats for task {task_id}")
        self.reset_eval()
        self.net.eval() # NO NEED .EVAL FOR CLIP AND VIT?

        for _, inputs, labels, _, _, _ in loader:
            inputs, labels = inputs.to(self.args.device), labels.to(self.args.device)

            with torch.no_grad():
                if self.args.model_clip:
                    inputs = self.args.model_clip.encode_image(inputs).type(torch.FloatTensor).to(self.args.device)
                elif self.args.model_vit:
                    inputs = self.args.model_vit.forward_features(inputs)

            if self.args.model_clip or self.args.model_vit:
                self.feature_list.append(inputs.data.cpu().numpy())
                self.label_list.append(labels.data.cpu().numpy())
            else:
                self.evaluate(inputs, labels, task_id, report_cil=False,
                    total_learned_task_id=task_id, ensemble=self.args.pass_ensemble)

        self.feature_list = np.concatenate(self.feature_list)
        self.label_list = np.concatenate(self.label_list)


        torch.save(self.feature_list,
                    self.args.logger.dir() + f'/feature_task_{task_id}')
        torch.save(self.label_list,
                    self.args.logger.dir() + f'/label_task_{task_id}')

        cov_list = []
        ys = list(sorted(set(self.label_list)))
        # Compute/save the statistics for MD
        for y in ys:
            idx = np.where(self.label_list == y)[0]
            f = self.feature_list[idx]
            cov = np.cov(f.T)
            cov_list.append(cov)
            mean = np.mean(f, 0)
            np.save(os.path.join(self.args.logger.dir(),
                                f'{self.args.mean_label_name}_{y}'),
                    mean)
            # np.save(self.args.logger.dir() + f'mean_label_{y}', mean)
            self.args.mean[y] = mean
        cov = np.array(cov_list).mean(0)
        np.save(os.path.join(self.args.logger.dir(),
                            f'{self.args.cov_task_name}_{task_id}'),
                cov)
        # np.save(self.args.logger.dir() + f'cov_task_{task_id}', cov)
        self.args.cov[task_id] = cov
        self.args.cov_inv[task_id] = np.linalg.inv(0.8 * cov + 0.2 * np.eye(len(cov)))

        if self.args.noise:
            mean = np.mean(self.feature_list, axis=0)
            cov = np.cov(self.feature_list.T)
            np.save(os.path.join(self.args.logger.dir(),
                                f'{self.args.mean_task_name}_{task_id}'),
                mean)
            # np.save(self.args.logger.dir() + f'mean_task_{task_id}', mean)
            np.save(os.path.join(self.args.logger.dir(),
                                f'{self.args.cov_task_noise_name}_{task_id}'),
                cov)
            # np.save(self.args.logger.dir() + f'cov_task_noise_{task_id}', cov)
            self.args.mean_task[task_id] = mean
            self.args.cov_noise[task_id] = cov
            self.args.cov_inv_noise[task_id] = np.linalg.inv(cov)

        self.net.train()
        self.reset_eval()

    def compute_md_by_task(self, net_id, features):
        """
            Compute Mahalanobis distance of features to the Gaussian distribution of task == net_id
            return: scores_md of np.array of ndim == (B, 1) if cov_inv is available
                    None if cov_inv is not available (e.g. task=0 or cov_inv is not yet computed)
        """
        md_list, dist_list = [], []
        if len(self.args.cov_inv) > 0:
            for y in range(net_id * self.num_cls_per_task, (net_id + 1) * self.num_cls_per_task):
                mean, cov_inv = self.mean_cov(y, net_id)
                dist = md(features, mean, cov_inv, inverse=True)

                if self.args.noise:
                    cov_inv_noise = self.args.cov_inv_noise[net_id]
                    dist = dist - 0.7 * md(features, mean, cov_inv_noise, inverse=True)

                scores_md = 1 / dist
                md_list.append(scores_md)
                dist_list.append(-dist)

            scores_md = np.concatenate(md_list, axis=1)
            dist_list = np.concatenate(dist_list, axis=1)
            scores_md = scores_md.max(1, keepdims=True)
            dist_list = dist_list.max(1)
            return scores_md, dist_list
        return None, None

    def mean_cov(self, y, net_id, inverse=True):
        if inverse:
            cov = self.args.cov_inv[net_id]
        else:
            cov = self.args.cov[net_id]
        return self.args.mean[y], cov