import torch
import numpy as np
from apprs.basemodel import BaseModel
from collections import Counter

class LSP(BaseModel):
    """
    Base module for LSP methods and their variants such as iLSP.

    It mainly provides a sampling method that draws pseudo-features from the
    per-class statistics (mean and eigenpairs) kept in ``self.statistics``.
    """
    def __init__(self, args):
        super(LSP, self).__init__(args)
        # 'mu' is indexed by class label. 'eigvec'/'eigval' are indexed by
        # label too, except for 'ipca_single' models, where each holds one
        # array shared by every class (see LSP.sampling).
        self.statistics = {'mu': [], 'eigvec': [], 'eigval': []}
        self.n_seen_samples = []

    def sampling(self, ys, batch_size=None):
        """
        Sample features from the latent feature space.

        Args:
            ys: labels to sample from. They need not appear in the current
                batch, but their mean (and eigenpairs) must already have been
                computed. Drawn labels whose mean is ``None`` are skipped.
            batch_size: number of labels drawn (uniformly, with replacement)
                from ``ys``. Defaults to ``self.args.minibatch_size``.

        Returns:
            ``(sample_data, sample_label)``: a float32 tensor of samples and a
            long tensor of their labels, jointly shuffled and moved to
            ``self.args.device``. At most ``batch_size`` rows are returned
            (fewer when some drawn labels have no mean). If nothing could be
            sampled (empty ``ys``, or no drawn label has a mean), two empty
            lists are returned instead.

        NOTE: eigenpairs are saved per label except for ``ipca_single``. There,
        ``statistics['eigval']`` and ``statistics['eigvec']`` are single arrays
        shared by all classes, so they must not be indexed by the label.
        """
        n_draws = self.args.minibatch_size if batch_size is None else batch_size
        if len(ys) == 0:
            # Same "nothing sampled" result as when every drawn class lacks a mean.
            return [], []

        selected_cls = np.random.choice(ys, size=n_draws, replace=True)
        shared_eig = 'ipca_single' in self.args.model

        sample_data, sample_label = [], []
        # Counter keeps first-appearance order, so the order of RNG draws below
        # is deterministic for a given seed.
        for y_, sz in Counter(selected_cls).items():
            mu = self.statistics['mu'][y_]
            if mu is None:
                continue
            if shared_eig:
                eigvec = self.statistics['eigvec']
                eigval = self.statistics['eigval']
            else:
                eigvec = self.statistics['eigvec'][y_]
                eigval = self.statistics['eigval'][y_]
            # NOTE(review): presumably eigvec is (k, d) with one eigenvector per
            # row and eigval is (k,). Then scale is (d, k) with columns scaled
            # by sqrt(eigval), and noise @ scale.T is (sz, d). TODO confirm.
            scale = eigvec.T * np.sqrt(eigval)
            noise = np.random.standard_normal(size=(sz, eigval.shape[0]))
            sample_data.append(mu + np.dot(noise, scale.T))
            sample_label.append(torch.full((sz,), int(y_), dtype=torch.long))

        if len(sample_data) == 0:
            return sample_data, sample_label

        data = torch.from_numpy(np.concatenate(sample_data))
        labels = torch.cat(sample_label)

        # At most n_draws rows exist, so a plain permutation suffices. The old
        # truncation to args.minibatch_size silently dropped samples whenever
        # an explicit batch_size exceeded it.
        idx = torch.from_numpy(np.random.permutation(len(data)))
        sample_data = data[idx].float().to(self.args.device)
        sample_label = labels[idx].to(self.args.device)
        return sample_data, sample_label

