import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.optim import Adam
from utils.sgd_hat import SGD_hat as SGD
from utils.my_ipca import MyIPCA as IPCA
# from apprs.basemodel import BaseModel
# from apprs.maha_ipca import Maha
from apprs.vitadapter import ViTAdapter
from copy import deepcopy
from utils.utils import *
from torch.utils.data import DataLoader
from collections import Counter

import types

device = "cuda" if torch.cuda.is_available() else "cpu"

class ViTAdapterEWT(ViTAdapter):
    """ViT-adapter learner trained with HAT masks plus per-class feature statistics.

    Each training batch (``observe``) performs a HAT-regularised update of the
    adapter and the current task's head. It then updates incremental-PCA
    statistics of the extracted features for every class in the batch:

    * ``self.statistics`` / ``self.pca_list`` through the parent's
      ``update_stats`` (these statistics are later fed to ``self.sampling`` to
      train the head on sampled features), and
    * ``self.statistics_MD`` / ``self.pca_list_MD`` (IPCA with
      ``n_components=args.in_dim``). At the end of a task these are turned into
      per-class means (``args.mean``) and a pooled covariance plus its shrunk
      inverse (``args.cov`` / ``args.cov_inv``). Presumably this is for
      Mahalanobis-distance scoring; confirm against the consumer of
      ``args.cov_inv``.
    """

    def __init__(self, args):
        super().__init__(args)
        # ``self.scores`` collects the per-sample max logit of every training batch.
        self.scores, self.scores_md, self.scores_total, self.out_score = [], [], [], []
        self.feature_list, self.label_list = [], []
        # HAT state: cumulative mask of learned tasks and the gradient-freezing mask.
        self.p_mask, self.mask_back = None, None
        self.last_task_id = -1  # task id of the most recently learned task
        self.cal_correct = 0.
        self.w, self.b = None, None
        # Initialised to -100 (presumably a "not evaluated yet" sentinel).
        self.cil_acc_mat_test = np.zeros((args.n_tasks + 1, args.n_tasks + 1)) - 100

        # Per-class IPCA models / statistics of the current task, keyed by label.
        self.pca_list_MD = {}
        self.statistics_MD = {'mu': {}, 'eigvec': {}, 'eigval': {}}

        if args.distillation:
            self.criterion_distill = nn.KLDivLoss()

        if args.use_buffer:
            assert args.buffer_size
            if self.args.dataset in ['imagenet', 'timgnet']:
                self.buffer_dataset = Memory_ImageFolder(args)
            else:
                self.buffer_dataset = Memory(args)
        else:
            self.buffer_dataset = None

        self.ent = ComputeEnt(self.args)

    def update_stats_for_MD(self, inputs, labels, y_, n_samples):
        """Feed the features of class ``y_`` into its MD incremental PCA.

        Only the first ``n_samples`` rows are used. In ``observe``, rows after
        ``n_samples`` are replayed buffer samples. The fitted mean, components
        and explained variance are cached in ``self.statistics_MD``.

        Args:
            inputs: feature matrix. It is reshaped to ``(-1, inputs.size(1))``,
                so presumably it is ``(N, D)``.
            labels: class labels aligned with the rows of ``inputs``.
            y_: the class to select. It must be a key of ``self.pca_list_MD``
                (keys are created in ``append_model_heads``).
            n_samples: number of leading rows that belong to the current batch.
        """
        idx = labels[:n_samples] == y_
        selected_inputs = inputs[:n_samples][idx].view(-1, inputs.size(1))

        pca = self.pca_list_MD[y_]
        pca.partial_fit(selected_inputs.cpu().numpy())
        self.statistics_MD['mu'][y_] = pca.mean_
        self.statistics_MD['eigvec'][y_] = pca.components_
        self.statistics_MD['eigval'][y_] = pca.explained_variance_

    def _pooled_md_covariance(self):
        """Return the class-averaged covariance rebuilt from ``self.statistics_MD``.

        For each class, ``(eigvec.T * eigval) @ eigvec`` reconstructs the
        covariance from its principal components. The per-class matrices are
        then averaged with equal weight.

        Raises:
            ValueError: if no class statistics have been collected.
        """
        eigvecs = self.statistics_MD['eigvec']
        eigvals = self.statistics_MD['eigval']
        if not eigvecs:
            raise ValueError("No MD statistics collected; cannot compute the covariance.")
        # Look eigenvalues up by class rather than zipping two dicts, so the
        # result does not depend on both dicts sharing an insertion order.
        cov = sum(np.dot(eigvecs[y].T * eigvals[y], eigvecs[y]) for y in eigvecs)
        return cov / len(eigvecs)

    def _store_md_stats(self, task_id):
        """Publish MD class means and task ``task_id``'s covariance to ``self.args``.

        Writes ``args.mean[y]`` for every class seen in this task, plus
        ``args.cov[task_id]`` and ``args.cov_inv[task_id]``.

        Returns:
            The pooled (unshrunk) covariance matrix.
        """
        for y, mu in self.statistics_MD['mu'].items():
            self.args.mean[y] = mu

        cov = self._pooled_md_covariance()
        self.args.cov[task_id] = cov
        # Shrink toward the identity before inverting, so the inverse exists
        # even if the pooled covariance is rank-deficient.
        self.args.cov_inv[task_id] = np.linalg.inv(0.8 * cov + 0.2 * np.eye(len(cov)))
        return cov

    def compute_stats(self, task_id, loader):
        """Store the current MD means / covariance for ``task_id`` in ``self.args``.

        ``loader`` is not used. It is kept for interface compatibility.

        Raises:
            ValueError: if no MD statistics have been collected.
        """
        self._store_md_stats(task_id)

    def observe(self, inputs, labels, names, not_aug_inputs=None, f_y=None, **kwargs):
        """Run one HAT training step on a batch and update the feature statistics.

        Args:
            inputs: input batch of the current task.
            labels: global class labels of ``inputs``. They are reduced modulo
                ``num_cls_per_task`` for the task-local head.
            names, not_aug_inputs, f_y: not used here; kept for interface
                compatibility.
            **kwargs: must contain ``task_id``, plus ``b`` and ``B``. ``b`` and
                ``B`` are passed to ``self.update_s`` to obtain the HAT scale ``s``.

        Returns:
            The scalar value of the main (adapter + head + HAT-reg) loss.

        Raises:
            NotImplementedError: if a replay buffer is active (``self.buffer``).
        """
        task_id = kwargs['task_id']
        s = self.update_s(kwargs['b'], kwargs['B'])

        n_samples = len(inputs)
        normalized_labels = labels % self.num_cls_per_task

        ys = sorted(set(labels.data.cpu().numpy()))

        if self.buffer:
            # Replay is not supported by this learner. The code below the raise
            # is unreachable and kept only as a reference implementation.
            raise NotImplementedError()
            try:
                inputs_bf, labels_bf = next(self.buffer_iter)
            except StopIteration:
                del self.buffer_iter
                self.buffer = DataLoader(self.buffer_dataset,
                                        batch_size=self.args.batch_size,
                                        sampler=self.sampler,
                                        num_workers=5,
                                        pin_memory=self.args.pin_memory)
                self.buffer_iter = iter(self.buffer)
                inputs_bf, labels_bf = next(self.buffer_iter)

            inputs_bf = inputs_bf.to(device)
            # single dummy head
            labels_bf = torch.zeros_like(labels_bf).to(device) + self.num_cls_per_task
            normalized_labels_bf = labels_bf
            inputs = torch.cat([inputs, inputs_bf])
            labels = torch.cat([labels, labels_bf])
            normalized_labels = torch.cat([normalized_labels, normalized_labels_bf])

        features, masks = self.net.forward_features(task_id, inputs, s=s)
        outputs = self.net.forward_classifier(task_id, features, normalize=self.args.normalize)

        if self.args.train_joint_clf:
            outputs_list, labels_list = [outputs], [labels]
            for p_task_id in range(task_id):
                y_choices = np.arange((p_task_id + 1) * self.args.num_cls_per_task)
                sample_data, sample_label = self.sampling(y_choices)

                p_outputs = self.net.forward_classifier(p_task_id, sample_data, normalize=self.args.normalize)
                outputs_list.append(p_outputs)
                labels_list.append(sample_label)
            # NOTE(review): outputs are concatenated along dim=1 but labels along
            # dim=0. Confirm the batch sizes line up as intended.
            outputs_list = torch.cat(outputs_list, dim=1)
            labels_list = torch.cat(labels_list)

            loss = self.criterion(outputs_list, labels_list)
        else:
            loss = self.criterion(outputs, normalized_labels)

        if self.args.distillation:
            self.teacher.eval()
            with torch.no_grad():
                outputs_t = self.teacher(inputs)
                outputs_t = outputs_t[:, task_id * self.num_cls_per_task:(task_id + 1) * self.num_cls_per_task]
            # Temperature-scaled KL distillation on the current task's logits.
            loss += self.criterion_distill(F.log_softmax(outputs / self.args.T, dim=1),
                    F.softmax(outputs_t / self.args.T, dim=1)) * self.args.T * self.args.T * self.args.distill_lambda * self.num_cls_per_task

        loss += self.hat_reg(self.p_mask, masks)
        self.optimizer.zero_grad()
        loss.backward()
        self.compensation(self.net, self.args.thres_cosh, s=s)

        # Gradient masking only applies once at least one task has been learned.
        hat = self.last_task_id >= 0
        self.optimizer.step(hat=hat)
        self.compensation_clamp(self.net, self.args.thres_emb)

        # Copy the features to CPU once per batch rather than twice per class.
        features_cpu = features.detach().cpu()
        for y_ in ys:
            self.update_stats_for_MD(features_cpu, labels, y_, n_samples)
            self.update_stats(features_cpu, labels, y_, n_samples)

        # Train the current head on features sampled from the class statistics.
        loss_head = torch.tensor([0])
        if len(self.statistics['mu']) > 0:
            y_choices = np.arange(task_id * self.args.num_cls_per_task,
                                (task_id + 1) * self.args.num_cls_per_task)

            sample_data, sample_label = self.sampling(y_choices)
            sample_label = sample_label % self.num_cls_per_task
            # sample_data is already (roughly) normalized: mu and cov were
            # computed from normalized features.
            sample_outputs = self.net.forward_classifier(task_id, sample_data, self.args.normalize)

            loss_head = self.criterion(sample_outputs, sample_label)
            self.optimizer_head.zero_grad()
            loss_head.backward()
            self.optimizer_head.step()

        self.total_loss += loss.item() + loss_head.item()
        # Accuracy bookkeeping only over the current-task rows (no buffer rows).
        outputs = outputs[:n_samples]
        scores, pred = outputs.max(1)
        self.scores.append(scores.detach().cpu().numpy())
        self.correct += pred.eq(normalized_labels[:n_samples]).sum().item()
        self.total += n_samples

        return loss.item()

    def save(self, task_id, **kwargs):
        """
            Save model specific elements required for resuming training
            kwargs: e.g. model state_dict, optimizer state_dict, epochs, etc.

            Everything is written with ``torch.save`` to
            ``<logger dir>/saving_buffer_<task_id>``.
        """
        self.saving_buffer['buffer_dataset'] = self.buffer_dataset
        self.saving_buffer['w'] = self.w
        self.saving_buffer['b'] = self.b
        self.saving_buffer['p_mask'] = self.p_mask
        self.saving_buffer['mask_back'] = self.mask_back
        self.saving_buffer['statistics'] = self.statistics
        self.saving_buffer['statistics_MD'] = self.statistics_MD

        for key in kwargs:
            self.saving_buffer[key] = kwargs[key]

        torch.save(self.saving_buffer,
                    os.path.join(self.args.logger.dir(), f'saving_buffer_{task_id}'))

    def preprocess_task(self, **kwargs):
        """Prepare the learner for task ``kwargs['task_id']``.

        Steps:
            1. Add HAT embeddings for the new task.
            2. Register one head entry per unique label in
               ``kwargs['loader'].dataset``.
            3. Rebuild the optimizers.
            4. Attach the HAT backward masks to the parameters.
            5. Set up the replay loader if a non-empty buffer exists.

        Raises:
            NotImplementedError: if ``args.optim_type == 'adam'``.
        """
        task_id = kwargs['task_id']
        # Add new embeddings for HAT
        self.net.append_embedddings()

        # Register each label once, using the first name seen for it after sorting.
        dataset = kwargs['loader'].dataset
        targets, names = zip(*sorted(zip(dataset.targets, dataset.names)))
        targets, names = list(targets), list(names)
        _, idx = np.unique(targets, return_index=True)
        for i in idx:
            self.append_model_heads(names[i], targets[i])

        # Reset optimizer as there might be some leftover in optimizer
        if self.args.optim_type == 'sgd':
            self.optimizer = SGD(self.net.only_adapter_parameters() + list(self.net.head[task_id].parameters()), lr=self.args.lr, momentum=self.args.momentum)
            self.optimizer_head = SGD(self.net.head[task_id].parameters(), lr=self.args.lr, momentum=self.args.momentum)
        elif self.args.optim_type == 'adam':
            raise NotImplementedError("HAT for Adam is not implemented")
            # Unreachable; kept for reference.
            self.optimizer = Adam(self.net.only_adapter_parameters() + list(self.net.head[task_id].parameters()), lr=self.args.lr)
            self.optimizer_head = Adam(self.net.head[task_id].parameters(), lr=self.args.lr)

        # Attach per-parameter HAT masks (None = unmasked) and clear stale grads.
        for n, p in self.net.named_parameters():
            p.grad = None
            if self.mask_back is not None and n in self.mask_back:
                p.hat = self.mask_back[n]
            else:
                p.hat = None

        # Prepare memory loader if memory data exist
        if self.args.use_buffer:
            if len(self.buffer_dataset.data) > 0:
                self.sampler = MySampler(len(self.buffer_dataset), len(kwargs['loader'].dataset))
                # We don't use minibatch. Use upsampling.
                self.buffer = DataLoader(self.buffer_dataset,
                                        batch_size=self.args.batch_size,
                                        sampler=self.sampler,
                                        num_workers=15,
                                        pin_memory=self.args.pin_memory)
                self.buffer_iter = iter(self.buffer)

    def append_model_heads(self, name, y):
        """
        Append a new head to the model
        name, y: string of a label name, and label

        Records the label and creates empty per-class statistics plus IPCA
        models for it: one with ``self.n_components`` components and one MD
        IPCA with ``args.in_dim`` components. When ``args.dynamic`` is set, the
        per-class component budget is recomputed as
        ``min(args.dynamic // n_seen_classes, args.in_dim)``, and existing
        statistics are truncated to that budget.
        """
        self.seen_names.append(name)
        self.seen_ids.append(y)

        if self.args.dynamic is not None:
            new_components = min(self.args.dynamic // len(self.seen_ids), self.args.in_dim)
            self.args.logger.print(f"Save {self.n_components} per class -> Save {new_components} per class")
            self.n_components = new_components

        self.n_seen_samples.append(0)
        self.statistics['mu'].append(None)
        self.statistics['eigvec'].append(None)
        self.statistics['eigval'].append(None)
        self.pca_list.append(IPCA(n_components=self.n_components, ff=self.ff, max_size=self.args.in_dim))
        self.left_samples.append(torch.tensor([]))

        self.pca_list_MD[y] = IPCA(n_components=self.args.in_dim, ff=self.ff, max_size=self.args.in_dim)

        if self.args.dynamic is not None:
            # Shrink already-fitted statistics to the new per-class budget.
            for y_, eigval in enumerate(self.statistics['eigval']):
                if eigval is not None:
                    self.pca_list[y_].n_components = self.n_components
                    self.statistics['eigval'][y_] = eigval[:self.n_components]
                    self.statistics['eigvec'][y_] = self.statistics['eigvec'][y_][:self.n_components]

    def end_task(self, **kwargs):
        """Finalize task ``kwargs['task_id']``.

        Steps:
            1. Update the HAT masks.
            2. Publish the task's MD means and covariance to ``self.args`` and
               save them as ``.npy`` files in the logger directory.
            3. Reset the per-task MD statistics.
            4. Update the replay memory (and save it once) when
               ``args.use_buffer`` is set and ``args.train_clf`` is not.

        Raises:
            ValueError: if no MD statistics were collected during the task.
        """
        task_id = kwargs['task_id']
        log_dir = self.args.logger.dir()

        self.last_task_id += 1

        # Update masks for HAT
        self.p_mask = self.cum_mask(self.last_task_id, self.p_mask)
        self.mask_back = self.freeze_mask(self.last_task_id, self.p_mask)

        cov = self._store_md_stats(task_id)
        for y, val in self.statistics_MD['mu'].items():
            np.save(os.path.join(log_dir, f'mean_label_{y}'), val)
        np.save(os.path.join(log_dir, f'cov_task_{task_id}'), cov)

        # MD statistics are per task; start fresh for the next one.
        self.statistics_MD = {'mu': {}, 'eigvec': {}, 'eigval': {}}
        self.pca_list_MD = {}

        # Update memory if used
        if self.args.use_buffer and not self.args.train_clf:
            self.buffer_dataset.update(kwargs['train_loader'].dataset)

            self.args.logger.print(Counter(self.buffer_dataset.targets))

            memory_path = os.path.join(log_dir, f'memory_{self.last_task_id}')
            if os.path.exists(memory_path):
                # Never overwrite an existing memory snapshot.
                self.args.logger.print("Memory exists. Not saving memory...")
            else:
                self.args.logger.print("Saving memory...")
                torch.save([deepcopy(self.buffer_dataset.data),
                            deepcopy(self.buffer_dataset.targets)],
                           memory_path)
