import torch
import numpy as np
from my_ipca import MyIPCA as IPCA
from torch.optim import SGD, Adam
from basemodel import BaseModel
from collections import Counter

class PLS(BaseModel):
    """Model that trains on the incoming batch plus sampled pseudo-features.

    Original author note: Mahalanobis distance. No memory used.

    For every seen class that is absent from the current batch, ``observe``
    may draw synthetic feature vectors from the per-class statistics held in
    ``mu_list`` / ``eigvec_list`` / ``eigval_list`` and train on them together
    with the real batch. Those lists start empty here and are not filled by
    any method of this class, so they must be populated elsewhere.
    """

    def __init__(self, args):
        """Set up empty per-class statistics and read hyper-parameters.

        Args:
            args: Configuration object; ``n_components`` and ``ff`` are read
                here, the rest is forwarded to ``BaseModel.__init__``.
        """
        super(PLS, self).__init__(args)
        # Per-class statistics, indexed by the original (unmapped) class id.
        self.mu_list, self.eigvec_list, self.eigval_list = [], [], []
        self.pca_list = []

        self.n_components = args.n_components
        self.ff = args.ff

    def observe(self, inputs, labels, names, not_aug_inputs=None, f_y=None, **kwargs):
        """Run one optimisation step on a batch and return its loss.

        Args:
            inputs: Batch passed to ``self.net``.
            labels: Tensor of original class ids, one per sample.
            names: Class names, iterated in lockstep with ``labels``.
            not_aug_inputs: Unused by this method.
            f_y: Unused by this method.
            **kwargs: Accepted for interface compatibility and ignored (a
                ``text_embedding`` entry may be supplied; it is not used).

        Returns:
            float: The loss of this step, computed over the real batch and
            any sampled pseudo-features.

        Side effects:
            May add classifier heads and replace ``self.optimizer``; updates
            ``self.total_loss``, ``self.correct`` and ``self.total``.
        """
        self.net.train()

        original_labels = labels.clone()
        n_samples = len(labels)

        # Grow the classifier first so the forward pass below already has an
        # output for every class present in this batch.
        for y_, name_ in zip(labels, names):
            if y_ not in self.seen_ids:
                self.append_model_heads(y_, name_)

        labels = self.map_labels(labels)
        outputs = self.net(inputs, self.args.normalize)

        if len(self.mu_list) > 0:
            # NOTE(review): the result of this call was never used; the call
            # is kept only in case map_labels has side effects -- confirm and
            # remove.
            self.map_labels(self.seen_ids)

            replay = self._sample_replay_features(original_labels)
            if replay is not None:
                sample_inputs, sample_label = replay
                # Author note: sample_data is already normalized to some
                # extent since the mu and cov are computed with normalized
                # features.
                sample_outputs = self.net(sample_inputs, self.args.normalize)

                outputs = torch.cat([outputs, sample_outputs])
                labels = torch.cat([labels, sample_label])

        loss = self.criterion(outputs, labels)
        # Zero the gradients here rather than at the top of the method: the
        # optimizer may have been replaced by append_model_heads above.
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        loss_value = loss.item()
        self.total_loss = loss_value
        # Accuracy bookkeeping only counts the real samples, which occupy the
        # first n_samples rows of outputs/labels.
        self.correct += outputs[:n_samples].argmax(1).eq(labels[:n_samples]).sum().item()
        self.total += n_samples

        return loss_value

    def _sample_replay_features(self, original_labels):
        """Sample pseudo-features for seen classes missing from the batch.

        Args:
            original_labels: Tensor of the batch's original (unmapped) class
                ids.

        Returns:
            ``None`` when there is nothing to sample; otherwise a tuple
            ``(features, labels)`` of tensors on ``self.args.device`` where
            ``features`` is float32 and ``labels`` holds mapped class indices.
        """
        ys_not_in_batch = [y_.item() for y_ in self.seen_ids if y_ not in original_labels]
        if not ys_not_in_batch:
            return None

        budget = self.args.minibatch_size
        # Drawing with replacement decides how many samples each class gets.
        selected_cls = np.random.choice(ys_not_in_batch, size=budget, replace=True)

        sample_data, sample_label = [], []
        for y_, sz in Counter(selected_cls).items():
            mu = self.mu_list[y_]
            if mu is None:
                # No statistics stored for this class; skip it.
                continue

            eigval = self.eigval_list[y_]
            rand_samples = np.random.standard_normal(size=(sz, eigval.shape[0]))
            # Scale each eigenvector by sqrt(eigenvalue), project the standard
            # normal draws through the result and shift by the class mean.
            scaled_basis = self.eigvec_list[y_].T * np.sqrt(eigval)
            sample_data.append(mu + np.dot(rand_samples, scaled_basis.T))

            mapped = self.map_labels(torch.tensor([y_]))
            sample_label.append(torch.full((sz,), mapped.item(), dtype=torch.long))

        if not sample_data:
            return None

        features = torch.from_numpy(np.concatenate(sample_data))
        targets = torch.cat(sample_label)

        # Shuffle the samples and cap their number at the minibatch size.
        idx = np.random.permutation(len(features))[:budget]
        targets = targets[idx].to(self.args.device)
        features = features[idx].float().to(self.args.device)
        return features, targets

    def append_model_heads(self, y, name):
        """Register a new class and append a matching head to the model.

        Args:
            y: Class id, appended to ``self.seen_ids``.
            name: Class name, appended to ``self.seen_names`` and passed to
                ``self.net.make_head``.
        """
        self.seen_names.append(name)
        self.seen_ids.append(y)

        # Append a new head to the classifier
        clip_init = self.args.model_clip if self.args.clip_init else None
        self.net.make_head(1, name, clip_init=clip_init)

        # Make a new optimizer as we added new params. A fresh instance is
        # created, so state held by the previous optimizer is not carried over.
        # NOTE(review): for any other optim_type the existing optimizer is
        # kept unchanged -- confirm that is intended.
        if self.args.optim_type == 'adam':
            self.optimizer = Adam(self.net.parameters(), lr=self.args.lr)
        elif self.args.optim_type == 'sgd':
            self.optimizer = SGD(self.net.parameters(), lr=self.args.lr)

