import os
import math
import numpy as np

import torch
from torch import nn, optim
import torch.nn.functional as F

from utils import *
from filters import filters
from funcs import *
import timm
from timm.models.vision_transformer import Block
from models_tsn import *

class PositionalEncoding(nn.Module):
    """ position encoding

    Fixed sinusoidal positional encoding. The table is stored in the ``pe``
    buffer with shape ``(1, max_len, d_model)``: even channels hold
    ``sin(pos * w_i)`` and odd channels ``cos(pos * w_i)`` with
    ``w_i = 10000 ** (-2i / d_model)``. Callers may either call ``forward``
    or slice ``pe`` directly.

    Arguments:
        d_model: embedding size (odd sizes are supported; the last channel
            is then a sin channel).
        dropout: dropout probability applied after adding the encoding.
        max_len: number of positions in the table.
    """
    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer cos channel than sin channel,
        # so only the first d_model // 2 frequencies are used here.
        pe[0, :, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the encoding of the first ``seq_len`` positions, then apply dropout.

        Arguments:
            x: Tensor, shape ``[batch_size, seq_len, embedding_dim]``; requires
               ``seq_len <= max_len`` (otherwise the addition fails to broadcast).
        """
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)
    
class fMRIAutoEncoder(nn.Module):
    """ fMRI autoencoder

    Masked autoencoder over segmented fMRI signals of shape
    ``(B, num_segs, seg_len, num_rois)``.

    The TSN variant chosen by ``cfg.TSN_variants`` embeds each segment and
    produces the mask. A learnable cls token is prepended, and the sequence
    runs through transformer ``Block``s. A separate decoder then predicts
    every segment back. ``forward`` returns a loss that mixes the
    reconstruction error on masked and unmasked entries.
    """
    def __init__(self, cfg):
        super().__init__()

        self.embed_dim = cfg.mdl_embed_dim
        self.num_segments = cfg.fmri_num_segs
        self.segment_len = cfg.fmri_seg_size
        self.num_rois = cfg.num_rois

        TSN_variants = {'TSN_Vanilla': TransientStateNetVanilla,
                        'TSN_EmbConcat': TransientStateNetEmbedConcat,
                        'TSN_EmbAttn': TransientStateNetEmbedAttn,
                        'MaskedTSN': MaskedTransientStateNet,
                        'DynamicMaskedTSN': DynamicMaskedTransientStateNet}
        # --------------------------------------------------------------------------
        # Encoder specifics
        self.tsn = TSN_variants[cfg.TSN_variants](cfg)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        # Only the `.pe` table is used (sliced in forward_encoder). max_len=100
        # bounds the token sequence (cls + segments) to 100 positions.
        self.pos_encoder = PositionalEncoding(d_model=self.embed_dim, dropout=0,
                                              max_len=100)
        self.blocks = nn.ModuleList([
            Block(self.embed_dim, cfg.AE_num_heads, cfg.AE_mlp_ratio,
                  qkv_bias=True, norm_layer=cfg.AE_norm_layer)
            for i in range(cfg.AE_depth)])
        self.norm = cfg.AE_norm_layer(self.embed_dim)
        # --------------------------------------------------------------------------

        # --------------------------------------------------------------------------
        # Decoder specifics
        self.decoder_embed = nn.Linear(self.embed_dim, cfg.AE_dec_embed_dim, bias=True)
        self.decoder_pos_encoder = PositionalEncoding(d_model=cfg.AE_dec_embed_dim, dropout=0,
                                                      max_len=100)
        self.decoder_blocks = nn.ModuleList([
            Block(cfg.AE_dec_embed_dim, cfg.AE_num_heads, cfg.AE_mlp_ratio,
                  qkv_bias=True, norm_layer=cfg.AE_norm_layer)
            for i in range(cfg.AE_dec_depth)])

        self.decoder_norm = cfg.AE_norm_layer(cfg.AE_dec_embed_dim)
        # decoder to segment: one linear map per token -> seg_len * num_rois values
        self.decoder_pred = nn.Linear(cfg.AE_dec_embed_dim,
                                      self.segment_len * self.num_rois, bias=True)
        # --------------------------------------------------------------------------

        self.initialize_weights()

    def initialize_weights(self):
        """Initialize the cls token; other modules keep their default init."""
        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=.02)

    def forward_encoder(self, x, mask_ratio):
        """Encode the segments into a cls token plus one latent token per segment.

        x: (B, num_segs, seg_len, num_rois)
        mask_ratio: forwarded to the TSN, which produces ``mask`` and ``ids_restore``.
        returns: (latent, mask, ids_restore). latent is (B, num_segs + 1, embed_dim),
            assuming the TSN keeps one token per segment (required for the
            positional add below), with the cls token at index 0.
        """
        L = x.shape[1]
        pos_embed = self.pos_encoder.pe

        x, mask, ids_restore = self.tsn(x, mask_ratio)
        # segment tokens use positions 1..L; position 0 is reserved for the cls token
        x = x + pos_embed[:, 1:(L+1), :]

        # prepend cls token
        cls_token = self.cls_token + pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return x, mask, ids_restore

    def forward_decoder(self, x):
        """Reconstruct every segment from the encoder output.

        x: (B, num_segs + 1, embed_dim), cls token at index 0
        returns: (B, num_segs, seg_len, num_rois); the cls position is dropped.
        """
        decoder_pos_embed = self.decoder_pos_encoder.pe

        x = self.decoder_embed(x)

        # positions 0..T-1 cover the cls token and all segment tokens
        B, T, _ = x.shape
        x = x + decoder_pos_embed[:, :T, :]

        # apply Transformer blocks
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        # predictor projection, then unflatten each token into a segment
        x = self.decoder_pred(x)
        x = x.reshape(B, T, self.segment_len, self.num_rois)
        # remove cls token
        return x[:, 1:, :, :]

    def forward(self, x, mask_ratio=0.5, mask_loss_ratio=0.65):
        """Encode, decode and score one batch.

        x: (B, num_segs, seg_len, num_rois)
        returns: (loss, pred, mask, latent, (loss_masked, loss_unmasked)) with
            pred shaped like ``x`` and latent as returned by ``forward_encoder``.
        """
        latent, mask, ids_restore = self.forward_encoder(x, mask_ratio)
        pred = self.forward_decoder(latent)  # (B, num_segs, seg_len, num_rois)
        loss, (loss_masked, loss_unmasked) = self.forward_loss(x, pred, mask, mask_loss_ratio)
        return loss, pred, mask, latent, (loss_masked, loss_unmasked)

    def forward_loss(self, tgt, pred, mask, mask_loss_ratio):
        """Weighted reconstruction MSE over masked / unmasked (segment, ROI) entries.

        tgt:  (B, num_segs, seg_len, num_rois)
        pred: (B, num_segs, seg_len, num_rois)
        mask: (B * num_segs, num_rois); 0 is keep, 1 is remove. Bool, integer
              and float masks are accepted.
        mask_loss_ratio: weight of the masked term; the unmasked term is
              weighted by ``1 - mask_loss_ratio``.

        returns: (loss, (loss_masked, loss_unmasked)). If no entry is masked
            (resp. unmasked), that mean is NaN and ``loss`` is the other mean alone.
        """
        B, S, L, N = tgt.shape
        # squared error averaged over the seg_len time points -> (B * num_segs, num_rois)
        loss = ((pred.reshape(B, S, L, N) - tgt) ** 2).mean(dim=2).reshape(B * S, N)

        # PyTorch does not support `1 - mask` on bool tensors; work in the loss dtype.
        mask = mask.to(loss.dtype)
        keep = 1 - mask

        n_masked = mask.sum()
        n_unmasked = keep.sum()
        if n_masked == 0:
            loss_masked_samples = torch.tensor(torch.nan, device=pred.device)
            loss_unmasked_samples = (loss * keep).sum() / n_unmasked
            loss = loss_unmasked_samples

        elif n_unmasked == 0:
            loss_masked_samples = (loss * mask).sum() / n_masked
            loss_unmasked_samples = torch.tensor(torch.nan, device=pred.device)
            loss = loss_masked_samples

        else:
            loss_masked_samples = (loss * mask).sum() / n_masked  # mean loss on removed entries
            loss_unmasked_samples = (loss * keep).sum() / n_unmasked  # mean loss on kept entries
            loss = mask_loss_ratio * loss_masked_samples + (1 - mask_loss_ratio) * loss_unmasked_samples

        return loss, (loss_masked_samples, loss_unmasked_samples)
    
class fMRIStateTransferModel(nn.Module):
    """ fMRI state-transfer model

    Downstream model built on the same encoder as ``fMRIAutoEncoder``: TSN,
    cls token, positional encoding, transformer blocks and norm, with the
    same attribute names. The encoder feeds a projection head, which reads
    either the cls token or the mean of the segment tokens. There is also an
    optional per-segment regression ("transient") head.
    """
    def __init__(self, cfg):
        super().__init__()

        self.embed_dim = cfg.mdl_embed_dim
        self.num_segments = cfg.fmri_num_segs
        self.segment_len = cfg.fmri_seg_size
        self.num_rois = cfg.num_rois

        TSN_variants = {'TSN_Vanilla': TransientStateNetVanilla,
                        'TSN_EmbConcat': TransientStateNetEmbedConcat,
                        'TSN_EmbAttn': TransientStateNetEmbedAttn,
                        'MaskedTSN': MaskedTransientStateNet,
                        'DynamicMaskedTSN': DynamicMaskedTransientStateNet}
        # --------------------------------------------------------------------------
        # Encoder specifics (must stay identical to fMRIAutoEncoder -- presumably so
        # pretrained encoder weights load by parameter name; TODO confirm)
        self.tsn = TSN_variants[cfg.TSN_variants](cfg)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        self.pos_encoder = PositionalEncoding(d_model=self.embed_dim, dropout=0,
                                              max_len=100)
        self.blocks = nn.ModuleList([
            Block(self.embed_dim, cfg.AE_num_heads, cfg.AE_mlp_ratio,
                  qkv_bias=True, norm_layer=cfg.AE_norm_layer)
            for i in range(cfg.AE_depth)])
        self.norm = cfg.AE_norm_layer(self.embed_dim)
        # --------------------------------------------------------------------------

        # --------------------------------------------------------------------------
        # Projection head specifics
        self.global_pool = cfg.global_pool
        self.fc_norm = cfg.head_norm_layer(self.embed_dim) if cfg.global_pool else nn.Identity()
        self.head_drop = nn.Dropout(cfg.head_drop_rate)
        self.head = nn.Linear(self.embed_dim, cfg.head_output_dim)

        # Transient head (optional; enabled by cfg.has_reg_head, off when absent)
        self.has_reg_head = getattr(cfg, 'has_reg_head', False)
        if self.has_reg_head:
            self.reg_head_drop = nn.Dropout(cfg.reg_head_drop_rate)
            self.reg_head = nn.Linear(self.embed_dim, cfg.reg_head_output_dim)

    def _embed_tokens(self, x, mask_ratio):
        """Run the TSN, add positional encodings and prepend the cls token.

        x: (B, num_segs, seg_len, num_rois)
        returns: (B, num_segs + 1, D) token sequence; index 0 is the cls token.
        """
        num_segs = x.shape[1]
        pos_embed = self.pos_encoder.pe

        tokens, _, _ = self.tsn(x, mask_ratio)
        # segment tokens use positions 1..num_segs; position 0 belongs to the cls token
        tokens = tokens + pos_embed[:, 1:(num_segs + 1), :]

        cls_tokens = (self.cls_token + pos_embed[:, :1, :]).expand(tokens.shape[0], -1, -1)
        return torch.cat((cls_tokens, tokens), dim=1)

    def forward_features(self, x, mask_ratio=0.0):
        """Encode a batch.

        x: (B, num_segs, seg_len, num_rois)
        returns: (B, num_segs + 1, D) normalized tokens, cls token at index 0.
        """
        feats = self._embed_tokens(x, mask_ratio)
        for blk in self.blocks:
            feats = blk(feats)
        return self.norm(feats)

    def forward_head(self, x):
        """Pool the encoder output and project it.

        x: (B, num_segs + 1, D), cls token at index 0
        Pooling is the mean of the segment tokens when ``global_pool`` is
        truthy, otherwise the cls token.
        """
        pooled = x[:, 1:, :].mean(dim=1) if self.global_pool else x[:, 0, :]
        pooled = self.fc_norm(pooled)
        pooled = self.head_drop(pooled)
        return self.head(pooled)

    def forward_transient_head(self, x):
        """Per-segment regression on the segment tokens (cls token excluded).

        x: (B, num_segs + 1, D), cls token at index 0
        returns: (B, num_segs, reg_head_output_dim)
        """
        segs = x[:, 1:, :]
        segs = self.reg_head_drop(segs)
        return self.reg_head(segs)

    def forward(self, x, mask_ratio=0.0):
        """Return the head output, plus the transient-head output when enabled."""
        feats = self.forward_features(x, mask_ratio=mask_ratio)
        if not self.has_reg_head:
            return self.forward_head(feats)
        # transient head first, matching the original dropout call order
        reg = self.forward_transient_head(feats)
        return self.forward_head(feats), reg

    def get_last_selfattention(self, x, mask_ratio=0.0):
        """Return what the last transformer block yields with ``return_attention=True``.

        x: (B, num_segs, seg_len, num_rois)
        returns None when there are no blocks.
        """
        if len(self.blocks) == 0:
            return None
        feats = self._embed_tokens(x, mask_ratio)
        for blk in self.blocks[:-1]:
            feats = blk(feats)
        # NOTE(review): timm's vision_transformer.Block.forward does not accept
        # `return_attention`; this only works if `Block` resolves to a variant that
        # does (e.g. one re-exported by a later star import) -- confirm.
        return self.blocks[-1](feats, return_attention=True)