import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from .subNets.BertTextEncoder import BertTextEncoder
from .subNets.transformers_encoder.transformer import TransformerEncoder
from .scoremodel import ScoreNet, loss_fn, Euler_Maruyama_sampler
import functools
from .rcan import Group
from random import sample

# Only the IMDER model is exported by `from <this module> import *`.
__all__ = ['IMDER']


class MSE(nn.Module):
    """Mean squared error averaged over every element of ``real - pred``.

    ``real - pred`` follows normal torch broadcasting; the squared residuals
    are summed and divided by the element count of that (broadcast) result,
    giving a 0-dim tensor.
    """

    def __init__(self):
        super().__init__()

    def forward(self, pred, real):
        residual = real - pred
        return residual.pow(2).sum() / residual.numel()

# Set up the SDE (the SDE defines the diffusion process used by the score models).
# Device onto which the SDE helpers below place their tensors.
# NOTE(review): hard-coded to CUDA, so torch.as_tensor(..., device=device) in the
# helpers will raise when CUDA is unavailable — confirm whether this should
# follow the input tensors' device instead.
device = 'cuda'
def marginal_prob_std(t, sigma):
    """Standard deviation of the perturbation kernel $p_{0t}(x(t) | x(0))$.

    With diffusion coefficient g(t) = sigma**t (see ``diffusion_coeff``), the
    kernel's standard deviation at time t is
    sqrt((sigma**(2 t) - 1) / (2 ln sigma)). Only the standard deviation is
    returned; the mean is not computed here.

    Args:
      t: time step(s), scalar or tensor; converted to a tensor on the
        module-level ``device``.
      sigma: the $\\sigma$ of the SDE (must be > 0 and != 1 for a finite result).

    Returns:
      Tensor of standard deviations, elementwise over ``t``.
    """
    t_on_device = torch.as_tensor(t, device=device)
    variance = (sigma ** (2 * t_on_device) - 1.) / 2. / np.log(sigma)
    return torch.sqrt(variance)

def diffusion_coeff(t, sigma):
    """Diffusion coefficient g(t) = sigma**t of the SDE.

    Args:
      t: time step(s), scalar or tensor.
      sigma: the $\\sigma$ of the SDE.

    Returns:
      ``sigma ** t`` as a tensor on the module-level ``device``.
    """
    g_t = sigma ** t
    return torch.as_tensor(g_t, device=device)

# Set up IMDer
class IMDER(nn.Module):
    """Text / audio / video model with diffusion-based recovery of missing modalities.

    Missing modalities are generated by per-modality conditional score networks
    (``score_l`` / ``score_a`` / ``score_v``) sampled with an Euler-Maruyama
    sampler and conditioned on the available modalities. The three (real or
    generated) sequences are then fused with MulT-style cross-modal attention,
    per-modality self-attention, and a residual MLP head.
    """

    def __init__(self, args):
        """Build every sub-network from ``args``.

        Attributes read from ``args``: ``dst_feature_dim_nheads`` (embedding dim,
        number of heads), ``use_bert`` and, when it is set, ``use_finetune`` /
        ``transformers`` / ``pretrained``; ``feature_dims`` (input feature sizes
        of text, audio, video), ``nlevels``, ``attn_dropout``,
        ``attn_dropout_modalities``, ``relu_dropout``, ``embed_dropout``,
        ``res_dropout``, ``output_dropout``, ``text_dropout``, ``attn_mask``,
        ``train_mode`` and ``num_classes``.
        """
        super(IMDER, self).__init__()
        self.args = args
        self.embed_dim = args.dst_feature_dim_nheads[0]

        if args.use_bert:
            self.text_model = BertTextEncoder(use_finetune=args.use_finetune, transformers=args.transformers, pretrained=args.pretrained)
        self.use_bert = args.use_bert
        self.orig_d_l, self.orig_d_a, self.orig_d_v = args.feature_dims
        self.d_l = self.d_a = self.d_v = self.embed_dim
        self.num_heads = args.dst_feature_dim_nheads[1]
        self.layers = args.nlevels
        self.attn_dropout = args.attn_dropout
        self.attn_dropout_modalities = args.attn_dropout_modalities
        self.relu_dropout = args.relu_dropout
        self.embed_dropout = args.embed_dropout
        self.res_dropout = args.res_dropout
        self.output_dropout = args.output_dropout
        self.text_dropout = args.text_dropout
        self.attn_mask = args.attn_mask
        self.MSE = MSE()

        # Each modality ends up as a 2 * embed_dim vector (two cross-modal streams).
        combined_dim = 2 * (self.d_l + self.d_a + self.d_v)

        output_dim = args.num_classes if args.train_mode == "classification" else 1

        # One conditional score network per modality, all sharing the same SDE (sigma).
        sigma = 25.0
        self.marginal_prob_std_fn = functools.partial(marginal_prob_std, sigma=sigma)
        self.diffusion_coeff_fn = functools.partial(diffusion_coeff, sigma=sigma)  # used for sampling
        self.score_l = ScoreNet(marginal_prob_std=self.marginal_prob_std_fn)
        self.score_v = ScoreNet(marginal_prob_std=self.marginal_prob_std_fn)
        self.score_a = ScoreNet(marginal_prob_std=self.marginal_prob_std_fn)

        # 1x1 convs that merge two available modalities (2 * d_l channels) into a
        # single d_l-channel condition when exactly one modality is missing.
        self.cat_lv = nn.Conv1d(self.d_l * 2, self.d_l, kernel_size=1, padding=0)
        self.cat_la = nn.Conv1d(self.d_l * 2, self.d_l, kernel_size=1, padding=0)
        self.cat_va = nn.Conv1d(self.d_l * 2, self.d_l, kernel_size=1, padding=0)

        n_modalities = len(args.feature_dims)

        # 1. Temporal convolutional layers: feature_dims[i] -> embed_dim for modality i.
        self.proj = nn.ModuleList([nn.Conv1d(args.feature_dims[i], self.embed_dim, kernel_size=1,
                                             padding=0, bias=False) for i in range(n_modalities)])

        # 2. Crossmodal attentions: trans[i] holds one encoder per other modality j != i,
        #    in increasing j order (text: [audio, video]; audio: [text, video];
        #    video: [text, audio]).
        self.trans = nn.ModuleList([
            nn.ModuleList([self.get_network(i, j) for j in range(n_modalities) if i != j])
            for i in range(n_modalities)])

        # 3. Self attentions over each modality's concatenated cross-modal outputs.
        self.trans_mems = nn.ModuleList([self.get_network(i, i, mem=True, layers=3)
                                         for i in range(n_modalities)])

        # Residual MLP head.
        self.proj1 = nn.Linear(combined_dim, combined_dim)
        self.proj2 = nn.Linear(combined_dim, combined_dim)
        self.out_layer = nn.Linear(combined_dim, output_dim)

    def get_network(self, mod1, mod2, mem=False, layers=-1):
        """Build a TransformerEncoder (matches MULTModel's get_network function).

        Cross-modal encoders (``mem=False``) are ``embed_dim`` wide and use the
        attention dropout of modality ``mod2`` (the key/value modality in
        ``forward``). Self-attention memory encoders (``mem=True``) are
        ``2 * embed_dim`` wide and use ``attn_dropout``. Depth is
        ``max(self.layers, layers)``. ``mod1`` is not used.
        """
        if not mem:
            embed_dim = self.embed_dim
            attn_dropout = self.attn_dropout_modalities[mod2]
        else:
            embed_dim = 2 * self.embed_dim
            attn_dropout = self.attn_dropout

        return TransformerEncoder(embed_dim=embed_dim,
                                  num_heads=self.num_heads,
                                  layers=max(self.layers, layers),
                                  attn_dropout=attn_dropout,
                                  relu_dropout=self.relu_dropout,
                                  res_dropout=self.res_dropout,
                                  embed_dropout=self.embed_dropout,
                                  attn_mask=self.attn_mask)

    def _sample(self, score_net, batch_size, conditions):
        """Run the Euler-Maruyama sampler for ``score_net`` with ``batch_size``, conditioned on ``conditions``."""
        return Euler_Maruyama_sampler(score_net, self.marginal_prob_std_fn, self.diffusion_coeff_fn, batch_size,
                                      device='cuda', condition=conditions)

    def generate_missing_modalities(self, proj_x_l, proj_x_a, proj_x_v, num_modal=None, ava_modal_idx=None):
        """Compute score losses for, and sample, the modalities that are missing.

        Args:
            proj_x_l, proj_x_a, proj_x_v: projected text / audio / video features
                (channels-first, as produced in ``forward``).
            num_modal: number of available modalities (1, 2 or 3). Any other value
                (including ``None``) leaves the features untouched and every loss
                at ``0.0``.
            ava_modal_idx: indices of the available modalities
                (0 = text, 1 = audio, 2 = video).

        Returns:
            ``([x_l, x_a, x_v], [loss_l, loss_a, loss_v])``: the features with each
            missing modality replaced by a conditional sample, and the ``loss_fn``
            loss of each missing modality's score network (``torch.tensor(0)`` for
            modalities that were available).
        """
        loss_score_l, loss_score_a, loss_score_v = 0.0, 0.0, 0.0
        batch_size = proj_x_l.size(0)

        # Order of loss_fn / sampler calls is kept fixed: both draw random numbers.
        if num_modal == 1:  # one modality available: condition directly on it
            if ava_modal_idx[0] == 0:  # only text
                conditions = proj_x_l
                loss_score_a = loss_fn(self.score_a, proj_x_a, self.marginal_prob_std_fn, condition=conditions)
                loss_score_v = loss_fn(self.score_v, proj_x_v, self.marginal_prob_std_fn, condition=conditions)
                loss_score_l = torch.tensor(0)
                proj_x_a = self._sample(self.score_a, batch_size, conditions)
                proj_x_v = self._sample(self.score_v, batch_size, conditions)
            elif ava_modal_idx[0] == 1:  # only audio
                conditions = proj_x_a
                loss_score_l = loss_fn(self.score_l, proj_x_l, self.marginal_prob_std_fn, condition=conditions)
                loss_score_v = loss_fn(self.score_v, proj_x_v, self.marginal_prob_std_fn, condition=conditions)
                loss_score_a = torch.tensor(0)
                proj_x_l = self._sample(self.score_l, batch_size, conditions)
                proj_x_v = self._sample(self.score_v, batch_size, conditions)
            else:  # only video
                conditions = proj_x_v
                loss_score_l = loss_fn(self.score_l, proj_x_l, self.marginal_prob_std_fn, condition=conditions)
                loss_score_a = loss_fn(self.score_a, proj_x_a, self.marginal_prob_std_fn, condition=conditions)
                loss_score_v = torch.tensor(0)
                proj_x_l = self._sample(self.score_l, batch_size, conditions)
                proj_x_a = self._sample(self.score_a, batch_size, conditions)
        elif num_modal == 2:  # two available: merge them into one condition
            if 0 not in ava_modal_idx:  # text missing (video, audio available)
                conditions = self.cat_va(torch.cat([proj_x_v, proj_x_a], dim=1))
                loss_score_l = loss_fn(self.score_l, proj_x_l, self.marginal_prob_std_fn, condition=conditions)
                loss_score_v, loss_score_a = torch.tensor(0), torch.tensor(0)
                proj_x_l = self._sample(self.score_l, batch_size, conditions)
            if 1 not in ava_modal_idx:  # audio missing (text, video available)
                conditions = self.cat_lv(torch.cat([proj_x_l, proj_x_v], dim=1))
                loss_score_a = loss_fn(self.score_a, proj_x_a, self.marginal_prob_std_fn, condition=conditions)
                loss_score_l, loss_score_v = torch.tensor(0), torch.tensor(0)
                proj_x_a = self._sample(self.score_a, batch_size, conditions)
            if 2 not in ava_modal_idx:  # video missing (text, audio available)
                conditions = self.cat_la(torch.cat([proj_x_l, proj_x_a], dim=1))
                loss_score_v = loss_fn(self.score_v, proj_x_v, self.marginal_prob_std_fn, condition=conditions)
                loss_score_l, loss_score_a = torch.tensor(0), torch.tensor(0)
                proj_x_v = self._sample(self.score_v, batch_size, conditions)
        elif num_modal == 3:  # nothing missing
            loss_score_l, loss_score_v, loss_score_a = torch.tensor(0), torch.tensor(0), torch.tensor(0)

        generated_features = [proj_x_l, proj_x_a, proj_x_v]
        generation_loss = [loss_score_l, loss_score_a, loss_score_v]

        return generated_features, generation_loss

    def _project_input(self, idx, v):
        """Turn raw input ``idx`` into a channels-first tensor projected to ``embed_dim`` channels."""
        if self.use_bert and idx == 0:
            # The BERT encoder runs without gradients (kept frozen here).
            with torch.no_grad():
                v = self.text_model(v)
                v = F.dropout(v.transpose(1, 2), p=self.text_dropout, training=self.training)
        else:
            v = v.permute(0, 2, 1)
        # The projection runs outside no_grad so its weights can be trained, and the
        # check uses the channel dim (size(1), as nn.Conv1d expects), not the
        # sequence length (size(-1)).
        if v.size(1) != self.embed_dim:
            v = self.proj[idx](v)
        return v

    def _fuse_modality(self, idx, target, source_0, source_1):
        """Cross-attend modality ``idx`` to the two others, self-attend, return the last step.

        ``source_0`` / ``source_1`` must be the other modalities in increasing index
        order, matching ``self.trans[idx]``. Inputs are (seq, batch, embed_dim)
        after the permute in ``forward``.
        """
        h_with_0 = self.trans[idx][0](target, source_0, source_0)
        h_with_1 = self.trans[idx][1](target, source_1, source_1)
        h = torch.cat([h_with_0, h_with_1], dim=2)
        h = self.trans_mems[idx](h)
        if isinstance(h, tuple):
            h = h[0]
        last_h = h[-1]
        # NOTE(review): detach() cuts the graph here, so gradients from the output ``M``
        # reach only proj1 / proj2 / out_layer, never the projections or attention
        # encoders; requires_grad_(True) lets callers take gradients w.r.t. these
        # features. Kept as-is because it looks deliberate — confirm it is intended.
        return last_h.detach().clone().requires_grad_(True)

    def forward(self, x, num_modal=None, ava_modal_idx=None):
        """Run the model.

        Args:
            x: the three raw inputs in order text, audio, video. Non-BERT inputs are
                permuted with ``permute(0, 2, 1)``, so presumably (batch, seq,
                feature) — TODO confirm against the data loader.
            num_modal, ava_modal_idx: forwarded to ``generate_missing_modalities``.

        Returns:
            dict with per-modality features ``Feature_t`` / ``Feature_a`` /
            ``Feature_v``, their concatenation ``Feature_f``, the score losses
            ``loss_score_l`` / ``loss_score_a`` / ``loss_score_v``, the (possibly
            generated) channels-first features under ``Generated``, and the head
            output ``M``.
        """
        proj_x = [self._project_input(i, v) for i, v in enumerate(x)]
        proj_x_l, proj_x_a, proj_x_v = proj_x[0], proj_x[1], proj_x[2]

        generated_features, generation_loss = self.generate_missing_modalities(
            proj_x_l, proj_x_a, proj_x_v, num_modal=num_modal, ava_modal_idx=ava_modal_idx)
        loss_score_l, loss_score_a, loss_score_v = generation_loss

        # (batch, channels, seq) -> (seq, batch, channels) for the transformer encoders.
        proj_x_l, proj_x_a, proj_x_v = (feat.permute(2, 0, 1) for feat in generated_features)

        last_h_l = self._fuse_modality(0, proj_x_l, proj_x_a, proj_x_v)  # text
        last_h_a = self._fuse_modality(1, proj_x_a, proj_x_l, proj_x_v)  # audio
        last_h_v = self._fuse_modality(2, proj_x_v, proj_x_l, proj_x_a)  # video

        last_hs = torch.cat([last_h_l, last_h_a, last_h_v], dim=1)
        # A residual block
        last_hs_proj = self.proj2(
            F.dropout(F.relu(self.proj1(last_hs), inplace=True), p=self.output_dropout, training=self.training))
        last_hs_proj = last_hs_proj + last_hs

        output = self.out_layer(last_hs_proj)

        res = {
            'Feature_t': last_h_l,
            'Feature_a': last_h_a,
            'Feature_v': last_h_v,
            'Feature_f': last_hs,
            'loss_score_l': loss_score_l,
            'loss_score_a': loss_score_a,
            'loss_score_v': loss_score_v,
            'Generated':
                {
                    'text': generated_features[0],
                    'audio': generated_features[1],
                    'video': generated_features[2]
                },
            'M': output
        }
        return res
