from typing import Union, List, Optional

from scipy.stats import ortho_group
import torch
import torch.nn as nn
from torch.nn import functional as F

from src.utils import jl_transform


class BaseCDVAE(nn.Module):
    """Base class for CDVAE models: a VAE whose KL target depends on labels.

    Subclasses must implement ``encode`` (returning ``(mean, logvar)``) and
    ``decode``. ``kl_loss`` compares the encoder posterior with a diagonal
    Gaussian of variance ``var_norm`` whose mean is looked up in
    ``self.means`` from the labels of each sample.

    Args:
        n_classes: Number of classes of each label group (an int is treated
            as a single group).
        latent_dim: Total size of the latent space.
        latent_dim_class: Number of latent dimensions assigned to each label
            group; must have the same length as ``n_classes``.
        beta_kl: Weight of the KL term in ``calc_loss``.
        mean_norm: Scale applied to the target means. A falsy value (e.g. 0)
            makes the scale a learnable parameter instead of a fixed buffer.
        var_norm: Variance of the target Gaussian.
        n_samples: Default number of latent draws per input in ``forward``.
    """

    def __init__(
            self,
            n_classes: Optional[Union[int, List[int]]],
            latent_dim: Optional[int],
            latent_dim_class: Optional[Union[int, List[int]]],
            beta_kl: float=1.,
            mean_norm: float=1.,
            var_norm: float=1.,
            n_samples: int=1,
        ):
        super().__init__()

        # Normalise scalar arguments to one-element lists so the rest of the
        # class can always iterate over label groups.
        self.n_classes = [n_classes] if isinstance(n_classes, int) else n_classes
        self.latent_dim = latent_dim
        self.latent_dim_class = [latent_dim_class] if isinstance(latent_dim_class, int) else latent_dim_class

        assert len(self.n_classes)==len(self.latent_dim_class), \
            "lengths of n_classes and latent_dim_class must be the same"

        self.beta_kl = beta_kl
        if mean_norm:
            # Fixed scale: stored as a buffer so it follows the module's device.
            self.register_buffer("mean_norm", torch.Tensor([mean_norm]))
        else:
            # Falsy mean_norm: the scale is learned, starting near zero.
            self.mean_norm = nn.Parameter(0.1*torch.randn(1))
        self.register_buffer("var_norm", torch.Tensor([var_norm]))

        self._init_distribution_classes()
        self.n_samples = n_samples


    #==========Initialize target latent space====================
    def _init_distribution_classes(self):
        """Build one (n_classes[i], latent_dim_class[i]) target-mean matrix per label group."""
        self.means = [
            create_onedim_base(n, self.latent_dim_class[i])
            for i, n in enumerate(self.n_classes)
        ]


    #==========Forward methods====================
    def encode(self, *args, **kwargs):
        """Map inputs to ``(mean, logvar)``; must be provided by subclasses."""
        # Accept any arguments so that calling the stub through ``forward``
        # raises NotImplementedError rather than a confusing TypeError.
        raise NotImplementedError


    def decode(self, *args, **kwargs):
        """Map latent codes back to input space; must be provided by subclasses."""
        raise NotImplementedError


    def reparameterize(self, mean, logvar, n_samples):
        """Draw latent codes with the reparameterisation trick.

        With ``n_samples == 0`` the posterior mean is returned unchanged.
        Otherwise ``n_samples`` draws are taken per input and flattened into
        the first dimension so that the draws of one input are contiguous.
        Expects ``mean``/``logvar`` to be 2-D (batch, dim).
        """
        if n_samples==0:
            return mean
        std = torch.exp(0.5 * logvar)
        eps = torch.randn((n_samples,) + tuple(std.shape)).to(mean.device)
        z = mean + eps * std
        # (n_samples, batch, dim) -> (batch, n_samples, dim) -> (batch*n_samples, dim)
        z = z.swapaxes(0,1).reshape(-1, z.shape[2])
        return z


    def forward(self, x, n_samples=None):
        """Encode, sample and decode ``x``.

        Returns a dict with the reconstruction ``'x_hat'`` (averaged over the
        draws when ``n_samples > 1``) and the posterior ``'mean'`` and
        ``'logvar'``. ``n_samples`` defaults to ``self.n_samples``.
        """
        n_samples = self.n_samples if n_samples is None else n_samples
        latent_mean, latent_logvar = self.encode(x)
        z = self.reparameterize(latent_mean, latent_logvar, n_samples=n_samples)
        x_hat = self.decode(z)
        if n_samples>1:
            # Undo the flattening done in reparameterize and average the draws.
            x_hat = x_hat.reshape(-1, n_samples, *x_hat.shape[1:]).mean(dim=1)
        return {'x_hat': x_hat, 'mean': latent_mean, 'logvar': latent_logvar}


    #==========Calculate Losses====================
    def kl_loss(self, mean, logvar, labels):
        """Per-sample KL divergence between the posterior and the label-dependent target.

        ``labels`` is a single tensor when there is one label group, otherwise
        an indexable collection with one tensor per group. A 1-D tensor is
        read as integer class indices and one-hot encoded; any other shape is
        used directly as per-class weights.
        Latent dimensions not covered by ``latent_dim_class`` keep a zero target mean.
        """
        # ``mean.device`` works for CPU and CUDA alike; ``get_device()``
        # returns -1 for CPU tensors, which is not a valid target for ``.to``.
        device = mean.device
        batch_size = mean.shape[0]
        mean_y = torch.zeros((batch_size, self.latent_dim), device=device)
        var_y = self.var_norm * torch.ones((batch_size, self.latent_dim), device=device)

        labels = [labels] if len(self.n_classes)==1 else labels
        dim_ini = 0
        for i, n in enumerate(self.n_classes):
            dim_fin = dim_ini + self.latent_dim_class[i]
            labels_i = labels[i]
            if labels_i.dim()==1:
                labels_i = F.one_hot(labels_i, n).to(dtype=torch.float)
            # Each label group fills its own slice of the target mean.
            mean_y[:, dim_ini:dim_fin] = (self.mean_norm * labels_i) @ self.means[i]
            dim_ini = dim_fin

        kl_loss = -0.5 * torch.sum(
            1
            + logvar - torch.log(var_y)
            - (mean - mean_y)**2 / var_y
            - logvar.exp() / var_y,
            dim=1
        )

        # NOTE(review): branch kept from the original for 3-D ``mean``; confirm
        # it is reachable, since ``mean_y`` above is always built as 2-D.
        if mean.dim()==3:
            return kl_loss.mean(dim=1)
        return kl_loss


    def recons_loss(self, x, x_hat):
        """Per-sample mean squared error, averaged over dimensions 1, 2 and 3."""
        return torch.mean((x-x_hat)**2, dim=(1,2,3))


    def calc_loss(self, x, x_hat, mean, logvar, labels):
        """Return per-sample ``'kl_loss'``, ``'recons_loss'`` and ``'total_loss'``.

        The total is ``recons_loss + beta_kl * kl_loss``.
        """
        kl_loss = self.kl_loss(mean, logvar, labels)
        recons_loss = self.recons_loss(x, x_hat)
        total_loss = recons_loss + self.beta_kl*kl_loss
        return {
            'kl_loss':      kl_loss,
            'recons_loss':  recons_loss,
            'total_loss':   total_loss
        }


    #==========Generate Samples====================
    def generate_samples(self, num_samples, device):
        """Decode ``num_samples`` latent codes drawn from a standard normal."""
        z = torch.randn(num_samples, self.latent_dim).to(device)
        samples = self.decode(z)
        return samples


    #==========Reconstruct====================
    def reconstruct(self, x):
        """Reconstruct ``x`` deterministically by decoding the posterior mean."""
        mean, _ = self.encode(x)
        x_rec = self.decode(mean)
        return x_rec


    #==========Device management methods====================
    def to(self, *args, **kwargs):
        """Move/cast the module and the target means, which are not registered buffers.

        NOTE: only ``to`` is overridden; ``nn.Module.cuda()``/``cpu()`` do not
        route through this method, so prefer ``model.to(device)``.
        """
        super().to(*args, **kwargs)
        for i, class_means in enumerate(self.means):
            self.means[i] = class_means.to(*args, **kwargs)
        return self


    #==========Load & Save model==========
    def state_dict(self, *args, **kwargs):
        """Return the module state extended with the target means under ``'mean_{i}'``."""
        state_dict = super().state_dict(*args, **kwargs)
        for i, class_means in enumerate(self.means):
            state_dict['mean_{}'.format(i)] = class_means
        return state_dict


    def load_state_dict(self, state_dict, strict: bool=False):
        """Restore the target means and the regular module state.

        The ``'mean_{i}'`` entries are consumed here and removed from the
        mapping handed to ``nn.Module.load_state_dict``, so ``strict=True``
        does not reject them as unexpected keys. The caller's mapping is left
        untouched. Raises ``KeyError`` if a ``'mean_{i}'`` entry is missing.

        Returns:
            The missing/unexpected-keys result of ``nn.Module.load_state_dict``.
        """
        # Work on a shallow copy; keep the version metadata PyTorch attaches
        # to state dicts, which ``copy()`` does not carry over.
        metadata = getattr(state_dict, '_metadata', None)
        module_state = state_dict.copy()
        if metadata is not None:
            module_state._metadata = metadata

        for i in range(len(self.n_classes)):
            loaded = module_state.pop('mean_{}'.format(i))
            # Keep the means on the device the model currently lives on, the
            # same way parameters and buffers are copied into existing storage.
            self.means[i] = loaded.to(self.means[i].device)
        return super().load_state_dict(module_state, strict)



class CDVAEDataParallel(nn.DataParallel):
    """``nn.DataParallel`` wrapper that forwards the CDVAE-specific API to the wrapped model.

    Loss computation, sample generation and (de)serialisation are delegated
    to ``self.module``, so checkpoints use the wrapped model's own keys
    rather than the wrapper's.
    """

    def __init__(self, vae_model, *args, **kwargs):
        # Extra arguments (e.g. device_ids, output_device, dim) are passed
        # straight through to nn.DataParallel.
        super().__init__(vae_model, *args, **kwargs)

    def calc_loss(self, x, x_hat, mean, logvar, labels):
        """Delegate to the wrapped model's ``calc_loss``."""
        return self.module.calc_loss(x, x_hat, mean, logvar, labels)

    def state_dict(self, *args, **kwargs):
        """Return the wrapped model's state dict, forwarding any arguments."""
        return self.module.state_dict(*args, **kwargs)

    def load_state_dict(self, state_dict, *args, **kwargs):
        """Load ``state_dict`` into the wrapped model and return its result."""
        return self.module.load_state_dict(state_dict, *args, **kwargs)

    def generate_samples(self, num_samples, device):
        """Delegate to the wrapped model's ``generate_samples``."""
        return self.module.generate_samples(num_samples, device)
    

def create_canonical_base(n_vectors, len_vectors):
    """Return the first ``n_vectors`` canonical basis vectors of length ``len_vectors``.

    The result is an ``(n_vectors, len_vectors)`` tensor made of the leading
    rows of the identity matrix. Asserts that ``n_vectors <= len_vectors``.
    """
    assert n_vectors <= len_vectors, \
        "number of vectors cannot be superior to the length of them"
    identity = torch.eye(len_vectors)
    return identity[:n_vectors]


def create_orthonormal_base(n_vectors, len_vectors):
    """Return ``n_vectors`` random vectors derived from a random orthogonal matrix.

    When the vectors fit (``n_vectors <= len_vectors``) the result holds the
    first ``n_vectors`` rows of a random ``len_vectors``-dimensional
    orthogonal matrix. Otherwise a random ``n_vectors``-dimensional orthogonal
    matrix is passed through ``jl_transform`` together with ``len_vectors``
    (presumably a Johnson-Lindenstrauss projection down to that size; see
    ``src.utils`` to confirm).
    """
    if n_vectors <= len_vectors:
        rows = ortho_group.rvs(dim=len_vectors)[:n_vectors]
        return torch.Tensor(rows)

    square = ortho_group.rvs(dim=n_vectors)
    projected = jl_transform(square, len_vectors)
    return torch.Tensor(projected)


def create_zeros_base(n_vectors, len_vectors):
    """Return an all-zero ``(n_vectors, len_vectors)`` tensor of target means."""
    shape = (n_vectors, len_vectors)
    return torch.zeros(*shape)


def create_onedim_base(n_vectors, len_vectors, max_val=10):
    """Return target means that differ only along the first coordinate.

    The result is an ``(n_vectors, len_vectors)`` tensor whose first column
    holds ``n_vectors`` evenly spaced values from ``-max_val`` to ``max_val``;
    every other entry is zero.
    """
    base = torch.zeros((n_vectors, len_vectors))
    if n_vectors:
        # One vectorised column write instead of a Python loop over rows.
        base[:, 0] = torch.linspace(-max_val, max_val, n_vectors)
    return base