
import torch
from torch import nn
import e3nn
from e3nn import o3

from .linearity import linearity
from .nonlinearity import nonlinearity
from .normalization import signal_norm
from .blocks import CGNetBlock, FFNN_block
from loss_functions import *

from torch import Tensor
from typing import Optional, Union, List, Dict

import sys, os

# MAX_FLOAT32 = torch.tensor(3e30).type(torch.float32)

class ClebschGordanVAE_symmetric_simple(torch.nn.Module):
    def __init__(self,
                 irreps_in: o3.Irreps,
                 latent_dim: int,
                 n_cg_blocks: int,
                 irreps_cg_hidden: o3.Irreps,
                 w3j_matrices: Dict[int, Tensor],
                 device: str,
                 use_batch_norm: bool = True,
                 weights_initializer: Optional[str] = None,
                 nonlinearity_rule: str = 'full',
                 ch_nonlin_rule: str = 'full',
                 norm_type: str = 'layer', # None, layer, signal
                 normalization: str = 'component', # norm, component -> only considered if norm_type is not none
                 norm_affine: Optional[Union[str, bool]] = True, # None, {True, False} -> for layer_norm, {unique, per_l, per_feature} -> for signal_norm
                 norm_nonlinearity: str = 'swish', # None (identity), identity, relu, swish, sigmoid -> only for layer_norm
                 norm_location: str = 'between', # first, between, last
                 linearity_first: bool = False, # currently only works with this being false
                 filter_symmetric: bool = True, # whether to exclude duplicate pairs of l's from the tensor product nonlinearity
                 x_rec_loss_fn: str = 'cosine', # mse, mse_normalized, cosine
                 do_final_signal_norm: bool = True,
                 learn_frame: bool = True):
        super().__init__()

        self.irreps_in = irreps_in
        self.device = device
        self.linearity_first = linearity_first
        self.use_batch_norm = use_batch_norm
        self.x_rec_loss_fn = x_rec_loss_fn
        self.latent_dim = latent_dim
        self.do_final_signal_norm = do_final_signal_norm

        initial_irreps = irreps_in

        # prepare irreps for both encoder and decoder
        prev_irreps = o3.Irreps([irr for irr in irreps_cg_hidden if irr.ir.l <= 1]) # start with lmax = 1, the multiplicities are arbitrary
        irreps_list = [prev_irreps]
        for i in range(1, n_cg_blocks):
            next_irreps = o3.Irreps([irr for irr in irreps_cg_hidden if irr.ir.l <= max(prev_irreps.ls)*2])
            irreps_list.append(next_irreps)
            prev_irreps = next_irreps
        irreps_list.append(initial_irreps)

        decoder_irreps_list = irreps_list[1:] # exclude first one since it's different
        encoder_irreps_list = irreps_list[:-1][::-1] # exclude data irreps and reverse

        ## encoder - cg
        prev_irreps = initial_irreps
        self.irreps_for_output = [initial_irreps]
        encoder_cg_blocks = []
        for i in range(n_cg_blocks):
            temp_irreps_hidden = encoder_irreps_list[i]
            encoder_cg_blocks.append(CGNetBlock(prev_irreps,
                                                temp_irreps_hidden,
                                                w3j_matrices,
                                                linearity_first=linearity_first,
                                                filter_symmetric=filter_symmetric,
                                                use_batch_norm=self.use_batch_norm,
                                                ls_nonlin_rule=nonlinearity_rule, # full, elementwise, efficient
                                                ch_nonlin_rule=ch_nonlin_rule, # full, elementwise
                                                norm_type=norm_type, # None, layer, signal
                                                normalization=normalization, # norm, component -> only if norm_type is not none
                                                norm_affine=norm_affine, # None, {True, False} -> for layer_norm, {unique, per_l, per_feature} -> for signal_norm
                                                norm_nonlinearity=norm_nonlinearity, # None (identity), identity, relu, swish, sigmoid -> only for layer_norm
                                                norm_location=norm_location, # first, between, last
                                                weights_initializer=weights_initializer,
                                                init_scale=1.0))


            prev_irreps = encoder_cg_blocks[-1].irreps_out
            print(prev_irreps)
            self.irreps_for_output.append(encoder_cg_blocks[-1].irreps_out)
        
        self.encoder_cg_blocks = torch.nn.ModuleList(encoder_cg_blocks)

        invariants_latent_space = [mul for (mul, _) in prev_irreps][0]
        self.encoder_mean = torch.nn.Linear(invariants_latent_space, latent_dim)
        self.encoder_log_var = torch.nn.Linear(invariants_latent_space, latent_dim)

        # component that learns the frame
        self.learn_frame = learn_frame
        if learn_frame:
            # take l=1 vectors (extract multiplicities) of last block and learn two l=1 vectors (x and pseudo-y direction)
            frame_learner_l1_mul = torch.sum(torch.tensor(self.encoder_cg_blocks[-1].irreps_out.ls) == 1).item()
            frame_learner_irreps_in = o3.Irreps('%dx1e' % frame_learner_l1_mul)
            frame_learner_irreps_out = o3.Irreps('2x1e')
            self.frame_learner = torch.nn.Sequential(linearity(frame_learner_irreps_in, frame_learner_irreps_out))
            self.frame_upsample = torch.nn.Sequential(linearity(o3.Irreps('3x1e'), frame_learner_irreps_in))

        ## decoder - cg
        decoder_cg_blocks = []
        prev_irreps = o3.Irreps([irr for irr in irreps_cg_hidden if irr.ir.l <= 1])
        print(prev_irreps)
        for i in range(n_cg_blocks):

            ## NOTE: linearity_first must be False
            temp_decoder_irreps_cg_hidden = decoder_irreps_list[i]

            decoder_cg_blocks.append(CGNetBlock(prev_irreps,
                                                temp_decoder_irreps_cg_hidden, # change to temp_decoder_irreps_cg_hidden for the pretrained model for CSE543 report
                                                w3j_matrices,
                                                linearity_first=linearity_first,
                                                filter_symmetric=filter_symmetric,
                                                use_batch_norm=self.use_batch_norm,
                                                ls_nonlin_rule=nonlinearity_rule, # full, elementwise, efficient
                                                ch_nonlin_rule=ch_nonlin_rule, # full, elementwise
                                                norm_type=norm_type, # None, layer, signal
                                                normalization=normalization, # norm, component -> only if norm_type is not none
                                                norm_affine=norm_affine, # None, {True, False} -> for layer_norm, {unique, per_l, per_feature} -> for signal_norm
                                                norm_nonlinearity=norm_nonlinearity, # None (identity), identity, relu, swish, sigmoid -> only for layer_norm
                                                norm_location=norm_location, # first, between, last
                                                weights_initializer=weights_initializer,
                                                init_scale=1.0))

            # prev_irreps = decoder_cg_blocks[-1].irreps_out
            prev_irreps = temp_decoder_irreps_cg_hidden
            print(prev_irreps)

        self.decoder_cg_blocks = torch.nn.ModuleList(decoder_cg_blocks)

        if self.do_final_signal_norm:
            self.final_signal_norm = torch.nn.Sequential(signal_norm(irreps_in, normalization='component', affine=None))

        ## setup reconstruction loss functions
        self.signal_rec_loss = eval(NAME_TO_LOSS_FN[x_rec_loss_fn])(irreps_in, self.device)
    
    # @profile
    def encode(self, x: Tensor):
        # print('---------------------- In encoder ----------------------', file=sys.stderr)

        h = x
        for i, block in enumerate(self.encoder_cg_blocks):
            h = block(h)
            
            if self.learn_frame and i == len(self.encoder_cg_blocks) - 1:
                last_l1_values = {1: h[1]}

        z_mean = self.encoder_mean(h[0].squeeze())
        z_log_var = self.encoder_log_var(h[0].squeeze())

        if self.learn_frame:
            learned_frame = self.orthonormalize_frame(self.frame_learner(last_l1_values)[1])
        else:
            learned_frame = None

        # print('---------------------- Out of encoder ----------------------', file=sys.stderr)
        return (z_mean, z_log_var), None, learned_frame
    
    # @profile
    def decode(self, z: Tensor, frame: Tensor):
        # print('---------------------- In decoder ----------------------', file=sys.stderr)
        
        h = self.frame_upsample({1: frame})
        h[0] = z.unsqueeze(-1)
        for i, block in enumerate(self.decoder_cg_blocks):
            h = block(h)

        x_reconst = h

        if self.do_final_signal_norm:
            x_reconst = self.final_signal_norm(x_reconst)
        # print('---------------------- Out of decoder ----------------------', file=sys.stderr)

        return x_reconst, None

    # @profile
    def forward(self, x: Dict[int, Tensor], x_vec: Optional[Tensor] = None, frame: Optional[Tensor] = None):
        '''
        Note: this function is independent of the choice of probability distribution for the latent space,
              and of the choice of encoder and decoder. Only the inputs and outputs must be respected
        '''

        distribution_params, _, learned_frame = self.encode(x)
        z = self.reparameterization(*distribution_params)

        if self.learn_frame:
            frame = learned_frame

        x_reconst, _ = self.decode(z, frame)

        def make_vector(x: Dict[int, Tensor]):
            x_vec = []
            for l in sorted(list(x.keys())):
                x_vec.append(x[l].reshape((x[l].shape[0], -1)))
            return torch.cat(x_vec, dim=-1)

        # gather loss values
        x_reconst_vec = make_vector(x_reconst)
        if x_vec is None:
            x_vec = make_vector(x) # NOTE: doing this is sub-optimal!
        
        x_reconst_loss = self.signal_rec_loss(x_reconst_vec, x_vec)

        kl_divergence = self.kl_divergence(*distribution_params) / self.latent_dim  # KLD is summed over each latent variable, so it's better to divide it by the latent dim
                                                                                    # to get a value that is independent (or less dependent) of the latent dim size

        return None, x_reconst_loss, kl_divergence, None, x_reconst, (distribution_params, None, None)
    
    def reparameterization(self, mean: Tensor, log_var: Tensor):

        # isotropic gaussian latent space
        stddev = torch.exp(0.5 * log_var) # takes exponential function (log var -> stddev)
        # stddev = torch.where(torch.isposinf(stddev), MAX_FLOAT32.to(self.device), stddev)
        # stddev = torch.where(torch.isneginf(stddev), (-MAX_FLOAT32).to(self.device), stddev)
        epsilon = torch.randn_like(stddev).to(self.device)        # sampling epsilon        
        z = mean + stddev*epsilon                          # reparameterization trick

        return z
    
    def kl_divergence(self, z_mean: Tensor, z_log_var: Tensor):
        # isotropic normal prior on the latent space
        return torch.mean(- 0.5 * torch.sum(1 + z_log_var - z_mean.pow(2) - z_log_var.exp(), dim=-1))
    
    def orthonormalize_frame(self, x_psy_N6):
        '''
        Gram-Schmidt process
        
        y = psy - (<x, psy> / <x, x>) x
        z = x \cross y

        x = x / ||x||
        y = y / ||y||
        z = z / ||z||
        '''
        
        x, psy = x_psy_N6[:, 0, :], x_psy_N6[:, 1, :]
        
        x_dot_psy = torch.sum(torch.mul(x, psy), dim=1).view(-1, 1)
        x_dot_x = torch.sum(torch.mul(x, x), dim=1).view(-1, 1)

        y = psy - (x_dot_psy/x_dot_x) * x
        
        z = torch.cross(x, y, dim=1)
        
        x = x / torch.sqrt(torch.sum(torch.mul(x, x), dim=1).view(-1, 1))
        y = y / torch.sqrt(torch.sum(torch.mul(y, y), dim=1).view(-1, 1))
        z = z / torch.sqrt(torch.sum(torch.mul(z, z), dim=1).view(-1, 1))
        
        xyz = torch.stack([x, y, z], dim=1)
        
        return xyz
