import torch
import numpy as np
# from numpy.random import multivariate_normal
# from numpy.random.default_rng import multivariate_normal
# from sklearn.decomposition import IncrementalPCA as IPCA
from utils.my_ipca import MyIPCA as IPCA
from torch.optim import SGD, Adam
from torch.autograd import Variable
from apprs.lsp_base import LSP

class Maha(LSP):
    # Mahalanobis distance. No memory used
    def __init__(self, args):
        """
        Initialise the base learner and the empty per-class PCA bookkeeping.

        args: configuration object; n_components and ff are read from it here
        """
        super().__init__(args)

        # Hyper-parameters handed to each IPCA created in append_model_heads
        self.ff = args.ff
        self.n_components = args.n_components

        # Per-class containers; all start empty
        self.pca_list = []
        self.left_samples = []
        self.statistics = {key: [] for key in ('mu', 'eigvec', 'eigval')}

        # Empty placeholders for sampled data and labels
        self.sample_data = torch.tensor([])
        self.sample_label = torch.tensor([])

    def observe(self, inputs, labels, names, not_aug_inputs=None, f_y=None, **kwargs):
        self.net.train()
        try:
            text_embedding = kwargs['text_embedding']
            # print('text_embedding', text_embedding.sum(1))
        except KeyError:
            text_embedding = None

        original_labels = labels.clone()
        n_samples = len(labels)
        ys = list(sorted(set(labels.data.cpu().numpy())))

        self.optimizer.zero_grad()

        ys = list(sorted(set(labels.data.cpu().numpy())))

        outputs = self.net(inputs, self.args.normalize)

        # Start sampling
        if len(self.statistics['mu']) > 0:
            ys_not_in_batch = [y_.item() for y_ in self.seen_ids if y_ not in original_labels]

            if len(ys_not_in_batch) > 0:
                sample_data, sample_label = self.sampling(ys_not_in_batch)
                # If some data are drawn
                if len(sample_data) > 0:
                    sample_outputs = self.net(sample_data, self.args.normalize) # sample_data is already normalized to some extend since the mu and cov are computed with normalized features

                    outputs = torch.cat([outputs, sample_outputs])
                    labels = torch.cat([labels, sample_label])
        # End sampling

        loss = self.criterion(outputs, labels)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.total_loss = loss.item()
        self.correct += outputs[:n_samples].argmax(1).eq(labels[:n_samples]).sum().item()
        self.total += n_samples

        # update mean and eigenvectors for iLSP
        for y_ in ys:
            self.update_stats(inputs, labels, y_, n_samples)
            # inputs = inputs / torch.norm(inputs, dim=-1, keepdim=True)
            # idx = labels[:n_samples] == y_

            # selected_inputs = inputs[:n_samples][idx].view(-1, inputs.size(1))

            # if len(selected_inputs) > 0:
            #     self.pca_list[y_].partial_fit(selected_inputs.cpu().numpy())
            #     self.statistics['mu'][y_] = self.pca_list[y_].mean_
            #     self.statistics['eigvec'][y_] = self.pca_list[y_].components_
            #     self.statistics['eigval'][y_] = self.pca_list[y_].explained_variance_
            # else:
            #     self.left_samples[y_] = torch.cat((self.left_samples[y_],
            #                                         selected_inputs.cpu()))
            #     if len(self.left_samples[y_]) > 0:
            #         self.pca_list[y_].partial_fit(self.left_samples[y_].numpy())
            #         self.left_samples[y_] = torch.tensor([])
            #         self.statistics['mu'][y_] = self.pca_list[y_].mean_
            #         self.statistics['eigvec'][y_] = self.pca_list[y_].components_
            #         self.statistics['eigval'][y_] = self.pca_list[y_].explained_variance_

        return loss.item()

    def preprocess_task(self, **kwargs):
        """
        Preprocess things before learning a task
        names: list of string
        labels: sorted list
        """
        names = kwargs['names']
        labels = kwargs['labels']

        for name, y in zip(names, labels):
            self.append_model_heads(name, y)

    def append_model_heads(self, name, y):
        """
        Register a newly seen class and grow the model by one head.

        name, y: string of a label name, and label

        Besides the bookkeeping lists, a fresh IPCA is created for the class
        and the optimizer is rebuilt so that it covers the new head's
        parameters. If args.optim_type is neither 'adam' nor 'sgd' the
        existing optimizer is left untouched.
        """
        args = self.args

        self.seen_names.append(name)
        self.seen_ids.append(y)

        if args.dynamic is not None:
            # Split the component budget evenly over all classes seen so far
            per_class = min(args.dynamic // len(self.seen_ids), args.in_dim)
            args.logger.print(f"Save {self.n_components} per class -> Save {per_class} per class")
            self.n_components = per_class

        # One new slot per class in every per-class container
        self.n_seen_samples.append(0)
        for stat_name in ('mu', 'eigvec', 'eigval'):
            self.statistics[stat_name].append(None)
        self.pca_list.append(IPCA(n_components=self.n_components, ff=self.ff, max_size=args.in_dim))
        self.left_samples.append(torch.tensor([]))

        # Append a new head to the classifier
        clip_init = args.model_clip if args.clip_init else None
        self.net.make_head(1, name, clip_init=clip_init)

        # Make a new optimizer as we added new params
        optimizer_cls = {'adam': Adam, 'sgd': SGD}.get(args.optim_type)
        if optimizer_cls is not None:
            self.optimizer = optimizer_cls(self.net.parameters(), lr=args.lr)

        if args.dynamic is None:
            return

        # Shrink the statistics of already-fitted classes to the new budget.
        # Classes that were never fitted (eigval is None) are skipped.
        keep = self.n_components
        eigvals = self.statistics['eigval']
        eigvecs = self.statistics['eigvec']
        for class_idx, pca in enumerate(self.pca_list):
            if eigvals[class_idx] is None:
                continue
            pca.n_components = keep
            eigvals[class_idx] = eigvals[class_idx][:keep]
            eigvecs[class_idx] = eigvecs[class_idx][:keep]

    def append_calibration_heads(self):
        """ Add new calibration parameters for new heads """
        temp_w = torch.rand(len(self.seen_ids), device=self.args.device)
        temp_b = torch.rand(len(self.seen_ids), device=self.args.device)

        # Carry over the learned parameters
        temp_w.data[:-1] = self.w.data
        temp_b.data[:-1] = self.b.data

        self.w = Variable(temp_w.data, requires_grad=True)
        self.b = Variable(temp_b.data, requires_grad=True)

        del temp_w, temp_b

    def save(self, **kwargs):
        self.saving_buffer['statistics'] = self.statistics
        self.saving_buffer['pca_list'] = self.pca_list
        self.saving_buffer['seen_names'] = self.seen_names
        self.saving_buffer['seen_ids'] = self.seen_ids
        self.saving_buffer['n_components'] = self.args.n_components

        for key in kwargs:
            self.saving_buffer[key] = kwargs[key]

        torch.save(self.saving_buffer, self.args.logger.dir() + 'saving_buffer')

    def load(self, **kwargs):
        for key, val in kwargs.items():
            if hasattr(self, key):
                self.args.logger.print(f"** {self.__class__.__name__}: Update {key} values **")
                setattr(self, key, val)
            else:
                self.args.logger.print(f"** WARNING: {self.__class__.__name__}: {key} values are not updated **")

    def update_stats(self, inputs, labels, y_, n_samples):
        if self.args.normalize:
            inputs = inputs / torch.norm(inputs, dim=-1, keepdim=True)
        idx = labels[:n_samples] == y_

        selected_inputs = inputs[:n_samples][idx].view(-1, inputs.size(1))

        if len(selected_inputs) > 0:
            self.pca_list[y_].partial_fit(selected_inputs.cpu().numpy())
            self.statistics['mu'][y_] = self.pca_list[y_].mean_
            self.statistics['eigvec'][y_] = self.pca_list[y_].components_
            self.statistics['eigval'][y_] = self.pca_list[y_].explained_variance_
        else:
            self.left_samples[y_] = torch.cat((self.left_samples[y_],
                                                selected_inputs.cpu()))
            if len(self.left_samples[y_]) > 0:
                self.pca_list[y_].partial_fit(self.left_samples[y_].numpy())
                self.left_samples[y_] = torch.tensor([])
                self.statistics['mu'][y_] = self.pca_list[y_].mean_
                self.statistics['eigvec'][y_] = self.pca_list[y_].components_
                self.statistics['eigval'][y_] = self.pca_list[y_].explained_variance_