
import torch
from torch import nn
import e3nn
from e3nn import o3

from .linearity import linearity
from .nonlinearity import nonlinearity
from .normalization import signal_norm
from .blocks import CGNetBlock, FFNN_block
from loss_functions import *

# NOTE: kept after the star import above so these names always win over anything it re-exports.
from torch import Tensor
from typing import Optional, List, Dict, Union

import sys, os

# MAX_FLOAT32 = torch.tensor(3e30).type(torch.float32)

class ClebschGordanVAE_symmetric(torch.nn.Module):
    def __init__(self,
                 irreps_in: o3.Irreps,
                 latent_dim: int,
                 n_cg_blocks: int,
                 irreps_cg_hidden: o3.Irreps,
                 w3j_matrices: Dict[int, Tensor],
                 device: str,
                 n_reconstruction_layers: int = 0,
                 bottleneck_hidden_dims: Optional[List[int]] = None, # in order for encoder, get reversed for decoder
                 dropout_rate: float = 0.0,
                 use_batch_norm: bool = True,
                 weights_initializer: Optional[str] = None,
                 nonlinearity_rule: str = 'full',
                 ch_nonlin_rule:str = 'full',
                 norm_type: str = 'layer', # None, layer, signal
                 normalization: str = 'component', # norm, component -> only considered if norm_type is not none
                 norm_affine: Optional[Union[str, bool]] = True, # None, {True, False} -> for layer_norm, {unique, per_l, per_feature} -> for signal_norm
                 norm_nonlinearity: str = 'swish', # None (identity), identity, relu, swish, sigmoid -> only for layer_norm
                 norm_location: str = 'between', # first, between, last
                 linearity_first: bool = False, # currently only works with this being false
                 filter_symmetric: bool = True, # whether to exclude duplicate pairs of l's from the tensor product nonlinearity
                 sf_rec_loss_fn: str = 'mse', # mse, cosine
                 softmax_before_sf_mse: bool = False, # unused
                 x_rec_loss_fn: str = 'cosine', # mse, mse_normalized, cosine
                 do_final_signal_norm: bool = True,
                 teacher_forcing: bool = False,
                 learn_frame: bool = True):
        super().__init__()

        self.irreps_in = irreps_in
        self.device = device
        self.linearity_first = linearity_first
        self.use_batch_norm = use_batch_norm
        self.sf_rec_loss_fn = sf_rec_loss_fn
        self.softmax_before_sf_mse = softmax_before_sf_mse
        self.x_rec_loss_fn = x_rec_loss_fn
        self.latent_dim = latent_dim
        self.n_reconstruction_layers = n_reconstruction_layers
        self.teacher_forcing = teacher_forcing
        self.do_final_signal_norm = do_final_signal_norm

        initial_irreps = irreps_in

        # prepare irreps for both encoder and decoder
        self.l_0_hidden_dim = (torch.Tensor(irreps_cg_hidden.ls) == 0).sum().item()
        prev_irreps = o3.Irreps([irr for irr in irreps_cg_hidden if irr.ir.l <= 1]) # start with lmax = 1, the multiplicities are arbitrary
        irreps_list = [prev_irreps]
        for i in range(1, n_cg_blocks):
            next_irreps = o3.Irreps([irr for irr in irreps_cg_hidden if irr.ir.l <= max(prev_irreps.ls)*2])
            irreps_list.append(next_irreps)
            prev_irreps = next_irreps
        irreps_list.append(initial_irreps)

        decoder_irreps_list = irreps_list[1:] # exclude first one since it's different
        encoder_irreps_list = irreps_list[:-1][::-1] # exclude data irreps and reverse

        ## encoder - cg
        prev_irreps = initial_irreps
        self.irreps_for_output = [initial_irreps]
        encoder_cg_blocks = []
        for i in range(n_cg_blocks):
            temp_irreps_hidden = encoder_irreps_list[i]
            encoder_cg_blocks.append(CGNetBlock(prev_irreps,
                                                temp_irreps_hidden,
                                                w3j_matrices,
                                                linearity_first=linearity_first,
                                                filter_symmetric=filter_symmetric,
                                                use_batch_norm=self.use_batch_norm,
                                                ls_nonlin_rule=nonlinearity_rule, # full, elementwise, efficient
                                                ch_nonlin_rule=ch_nonlin_rule, # full, elementwise
                                                norm_type=norm_type, # None, layer, signal
                                                normalization=normalization, # norm, component -> only if norm_type is not none
                                                norm_affine=norm_affine, # None, {True, False} -> for layer_norm, {unique, per_l, per_feature} -> for signal_norm
                                                norm_nonlinearity=norm_nonlinearity, # None (identity), identity, relu, swish, sigmoid -> only for layer_norm
                                                norm_location=norm_location, # first, between, last
                                                weights_initializer=weights_initializer,
                                                init_scale=1.0))


            prev_irreps = encoder_cg_blocks[-1].irreps_out
            print(prev_irreps)
            self.irreps_for_output.append(encoder_cg_blocks[-1].irreps_out)
        
        self.encoder_cg_blocks = torch.nn.ModuleList(encoder_cg_blocks)
        
        # compute ls_indices for each of the irreps_for_output, and compute dimensionality of scalar features as well
        self.ls_indices_per_irreps_for_output = []
        scalar_features_dim = 0
        for irreps in self.irreps_for_output:
            self.ls_indices_per_irreps_for_output.append(torch.cat([torch.tensor(irreps.ls)[torch.tensor(irreps.ls) == l].repeat(2*l+1) for l in sorted(list(set(irreps.ls)))]))
            scalar_features_dim += sum([irr.mul for irr in irreps if irr.ir.l <= 0]) # get size of scalar features

        ## encoder - bottleneck
        self.scalar_features_dim = scalar_features_dim
        self.encoder_bottleneck = FFNN_block(self.scalar_features_dim, bottleneck_hidden_dims,
                                             nonlinearity='relu', use_batch_norm=False, dropout_rate=dropout_rate)

        self.encoder_mean = torch.nn.Linear(self.encoder_bottleneck.output_dim, latent_dim)
        self.encoder_log_var = torch.nn.Linear(self.encoder_bottleneck.output_dim, latent_dim)

        ## decoder - bottleneck
        self.decoder_bottleneck = FFNN_block(latent_dim, bottleneck_hidden_dims[::-1], output_dim=self.scalar_features_dim,
                                             nonlinearity='relu', use_batch_norm=False, dropout_rate=dropout_rate)

        # component that learns the frame
        self.learn_frame = learn_frame
        if learn_frame:
            # take l=1 vectors (extract multiplicities) of last block and learn two l=1 vectors (x and pseudo-y direction)
            frame_learner_l1_mul = torch.sum(torch.tensor(self.encoder_cg_blocks[-1].irreps_out.ls) == 1).item()
            frame_learner_irreps_in = o3.Irreps('%dx1e' % frame_learner_l1_mul)
            frame_learner_irreps_out = o3.Irreps('2x1e')
            self.frame_learner = torch.nn.Sequential(linearity(frame_learner_irreps_in, frame_learner_irreps_out))
            self.frame_upsample = torch.nn.Sequential(linearity(o3.Irreps('3x1e'), frame_learner_irreps_in))

        ## decoder - cg
        decoder_cg_blocks = []
        prev_irreps = o3.Irreps([irr for irr in irreps_cg_hidden if irr.ir.l <= 1])
        print(prev_irreps)
        for i in range(n_cg_blocks):

            ## NOTE: linearity_first must be False
            temp_decoder_irreps_cg_hidden = decoder_irreps_list[i]
            temp_decoder_irreps_cg_hidden_l_gr_0 = o3.Irreps([irr for irr in temp_decoder_irreps_cg_hidden if irr.ir.l > 0])

            decoder_cg_blocks.append(CGNetBlock(prev_irreps,
                                                temp_decoder_irreps_cg_hidden_l_gr_0, # change to temp_decoder_irreps_cg_hidden for the pretrained model for CSE543 report
                                                w3j_matrices,
                                                linearity_first=linearity_first,
                                                filter_symmetric=filter_symmetric,
                                                use_batch_norm=self.use_batch_norm,
                                                ls_nonlin_rule=nonlinearity_rule, # full, elementwise, efficient
                                                ch_nonlin_rule=ch_nonlin_rule, # full, elementwise
                                                norm_type=norm_type, # None, layer, signal
                                                normalization=normalization, # norm, component -> only if norm_type is not none
                                                norm_affine=norm_affine, # None, {True, False} -> for layer_norm, {unique, per_l, per_feature} -> for signal_norm
                                                norm_nonlinearity=norm_nonlinearity, # None (identity), identity, relu, swish, sigmoid -> only for layer_norm
                                                norm_location=norm_location, # first, between, last
                                                weights_initializer=weights_initializer,
                                                init_scale=1.0))

            # prev_irreps = decoder_cg_blocks[-1].irreps_out
            prev_irreps = temp_decoder_irreps_cg_hidden
            print(prev_irreps)

        self.decoder_cg_blocks = torch.nn.ModuleList(decoder_cg_blocks)

        if self.do_final_signal_norm:
            self.final_signal_norm = torch.nn.Sequential(signal_norm(irreps_in, normalization='component', affine=None))

        ## setup reconstruction loss functions
        scalar_irreps = (self.scalar_features_dim * o3.Irreps('0e')).simplify()
        self.scalar_rec_loss = eval(NAME_TO_LOSS_FN[sf_rec_loss_fn])(scalar_irreps, self.device)
        self.signal_rec_loss = eval(NAME_TO_LOSS_FN[x_rec_loss_fn])(irreps_in, self.device)
        
        ## setup scalar features regularization loss function
        self.sf_regularization_fn = lambda x: torch.mean(-torch.log(torch.abs(x)))
    
    # @profile
    def encode(self, x: Tensor):
        # print('---------------------- In encoder ----------------------', file=sys.stderr)
        scalar_features = []
        scalar_features.append(x[0].squeeze(-1)) # get scalar values

        h = x
        for i, block in enumerate(self.encoder_cg_blocks):
            h = block(h)

            scalar_features.append(h[0].squeeze(-1)) # get scalar values
            
            if self.learn_frame and i == len(self.encoder_cg_blocks) - 1:
                last_l1_values = {1: h[1]}

            
        scalar_features_raw = torch.cat(scalar_features, dim=-1)

        # makes each value of direc_scalar_features ~average~ 1.0 --> I don't use this, and it changes the entropy of the vector anyway
        if 'mse' in self.sf_rec_loss_fn and self.softmax_before_sf_mse:
            scalar_features = scalar_features_raw * torch.nn.functional.softmax(scalar_features_raw, dim=-1)
        else:
            scalar_features = scalar_features_raw
        
        h = self.encoder_bottleneck(scalar_features)

        z_mean = self.encoder_mean(h)
        z_log_var = self.encoder_log_var(h)

        if self.learn_frame:
            learned_frame = self.orthonormalize_frame(self.frame_learner(last_l1_values)[1])
        else:
            learned_frame = None

        # print('---------------------- Out of encoder ----------------------', file=sys.stderr)
        return (z_mean, z_log_var), scalar_features, learned_frame
    
    # @profile
    def decode(self, z: Tensor, frame: Tensor, scalar_features_from_encoder: Optional[Tensor] = None):
        # print('---------------------- In decoder ----------------------', file=sys.stderr)
        scalar_features_reconst = self.decoder_bottleneck(z)

        if self.training and self.teacher_forcing:
            scalar_features_for_signal_reconst = scalar_features_from_encoder
        else:
            scalar_features_for_signal_reconst = scalar_features_reconst
        
        num_l_0_in_input = (self.ls_indices_per_irreps_for_output[0] == 0).sum()
        scalar_x_reconst = scalar_features_for_signal_reconst[:, : num_l_0_in_input] # save the first one
        
        l_gr_0 = self.frame_upsample({1: frame})
        for i, block in enumerate(self.decoder_cg_blocks):
            
            if i == 0:
                sf = scalar_features_for_signal_reconst[:, -self.l_0_hidden_dim :]
            else:
                sf = scalar_features_for_signal_reconst[:, -(i+1) * self.l_0_hidden_dim : -i * self.l_0_hidden_dim]
           
            h = l_gr_0
            h[0] = sf.unsqueeze(-1) # need to add an extra dimension for the moments

            l_gr_0 = block(h)

        x_reconst = l_gr_0
        x_reconst[0] = scalar_x_reconst.unsqueeze(-1) # need to add an extra dimension for the moments

        if self.do_final_signal_norm:
            x_reconst = self.final_signal_norm(x_reconst)
        # print('---------------------- Out of decoder ----------------------', file=sys.stderr)

        return x_reconst, scalar_features_reconst

    # @profile
    def forward(self, x: Dict[int, Tensor], x_vec: Optional[Tensor] = None, frame: Optional[Tensor] = None):
        '''
        Note: this function is independent of the choice of probability distribution for the latent space,
              and of the choice of encoder and decoder. Only the inputs and outputs must be respected
        '''

        distribution_params, scalar_features, learned_frame = self.encode(x)
        z = self.reparameterization(*distribution_params)

        if self.learn_frame:
            frame = learned_frame

        x_reconst, scalar_features_reconst = self.decode(z, frame, scalar_features)

        def make_vector(x: Dict[int, Tensor]):
            x_vec = []
            for l in sorted(list(x.keys())):
                x_vec.append(x[l].reshape((x[l].shape[0], -1)))
            return torch.cat(x_vec, dim=-1)

        # gather loss values
        scalar_features_reconst_loss = self.scalar_rec_loss(scalar_features_reconst.to(self.device), scalar_features.to(self.device))
        
        x_reconst_vec = make_vector(x_reconst)
        if x_vec is None:
            x_vec = make_vector(x) # NOTE: doing this is sub-optimal!
        
        x_reconst_loss = self.signal_rec_loss(x_reconst_vec, x_vec)

        kl_divergence = self.kl_divergence(*distribution_params) / self.latent_dim  # KLD is summed over each latent variable, so it's better to divide it by the latent dim
                                                                                    # to get a value that is independent (or less dependent) of the latent dim size
        sf_reg = self.sf_regularization_fn(scalar_features)

        return scalar_features_reconst_loss, x_reconst_loss, kl_divergence, sf_reg, x_reconst, (distribution_params, scalar_features, scalar_features_reconst)
    
    def reparameterization(self, mean: Tensor, log_var: Tensor):

        # isotropic gaussian latent space
        stddev = torch.exp(0.5 * log_var) # takes exponential function (log var -> stddev)
        # stddev = torch.where(torch.isposinf(stddev), MAX_FLOAT32.to(self.device), stddev)
        # stddev = torch.where(torch.isneginf(stddev), (-MAX_FLOAT32).to(self.device), stddev)
        epsilon = torch.randn_like(stddev).to(self.device)        # sampling epsilon        
        z = mean + stddev*epsilon                          # reparameterization trick

        return z
    
    def kl_divergence(self, z_mean: Tensor, z_log_var: Tensor):
        # isotropic normal prior on the latent space
        return torch.mean(- 0.5 * torch.sum(1 + z_log_var - z_mean.pow(2) - z_log_var.exp(), dim=-1))
    
    def orthonormalize_frame(self, x_psy_N6):
        '''
        Gram-Schmidt process
        
        y = psy - (<x, psy> / <x, x>) x
        z = x \cross y

        x = x / ||x||
        y = y / ||y||
        z = z / ||z||
        '''
        
        x, psy = x_psy_N6[:, 0, :], x_psy_N6[:, 1, :]
        
        x_dot_psy = torch.sum(torch.mul(x, psy), dim=1).view(-1, 1)
        x_dot_x = torch.sum(torch.mul(x, x), dim=1).view(-1, 1)

        y = psy - (x_dot_psy/x_dot_x) * x
        
        z = torch.cross(x, y, dim=1)
        
        x = x / torch.sqrt(torch.sum(torch.mul(x, x), dim=1).view(-1, 1))
        y = y / torch.sqrt(torch.sum(torch.mul(y, y), dim=1).view(-1, 1))
        z = z / torch.sqrt(torch.sum(torch.mul(z, z), dim=1).view(-1, 1))
        
        xyz = torch.stack([x, y, z], dim=1)
        
        return xyz
