import os
import math
import numpy as np

import torch
from torch import nn, optim
import torch.nn.functional as F

from utils import *
from filters import filters
from funcs import *
import timm
from timm.layers import Mlp, DropPath
from timm.models.vision_transformer import Block
from model_utils import MultiheadAttention

class fMRISegEmbed(nn.Module):
    """ Embed each fMRI segment (one row of samples) into an ``embed_dim`` vector.

    A single-input-channel Conv1d with kernel size ``fmri_seg_size`` projects
    the segment, and the result is layer-normalised over ``embed_dim``.
    """
    def __init__(self, embed_dim, fmri_seg_size):
        super().__init__()
        
        self.proj = nn.Conv1d(1, embed_dim, kernel_size=fmri_seg_size)
        self.norm = nn.LayerNorm(embed_dim)
        
    def forward(self, x):
        """
        x: (B, seg_len)
        out: (B, embed_dim) when seg_len equals the conv kernel size
        """
        # (B, seg_len) -> (B, 1, seg_len) -> conv -> (B, embed_dim, 1).
        # Only the trailing length axis is squeezed: a bare squeeze() would
        # also drop the batch axis when B == 1 (or the channel axis when
        # embed_dim == 1), making the output rank depend on the batch size.
        x = self.proj(x.unsqueeze(dim=1)).squeeze(-1)
        x = self.norm(x)
        return x
    
class TransientStateNetVanilla(nn.Module):
    """ Masked transient state net built from plain transformer blocks.

    Every ROI segment is embedded into one token, part of the tokens is
    replaced by a learnable mask token, the tokens run through
    ``cfg.TSN_depth`` blocks, and a Conv1d with kernel size ``cfg.num_rois``
    aggregates the ROI axis into one vector per segment.
    """
    def __init__(self, cfg):
        super().__init__()

        assert cfg.fmri_embed_dim == cfg.mdl_embed_dim
        width = cfg.mdl_embed_dim
        make_norm = cfg.TSN_norm_layer

        self.fmriEmbed = fMRISegEmbed(embed_dim=cfg.fmri_embed_dim,
                                      fmri_seg_size=cfg.fmri_seg_size)

        # self-attention
        layers = []
        for _ in range(cfg.TSN_depth):
            layers.append(Block(width, cfg.TSN_num_heads, cfg.TSN_mlp_ratio,
                                qkv_bias=True, norm_layer=make_norm))
        self.blocks = nn.ModuleList(layers)
        self.norm1 = make_norm(width)

        # out projection (aggregate results)
        self.outProj = nn.Conv1d(width, width, kernel_size=cfg.num_rois)
        self.norm2 = make_norm(width)

        self.mask_token = nn.Parameter(torch.zeros(1, 1, cfg.fmri_embed_dim))
        # generate random mask with fixed roi set
        self.roi_fixed_mask = False

    def forward(self, x, mask_ratio):
        """
        input: (B, num_segs, seg_len, num_rois)
        out: tuple of
            (B, num_segs, mdl_embed_dim) segment representation,
            (B*num_segs, num_rois) binary mask (0 is keep, 1 is remove),
            (B*num_segs, num_rois) indices that undo the masking shuffle
        """
        B, S, L, N = x.shape

        # one row of seg_len samples per (batch, segment, roi)
        rows = x.permute(0, 1, 3, 2).reshape(-1, L)
        tokens = self.fmriEmbed(rows).reshape(B * S, N, -1)
        tokens, mask, ids_restore = self.random_masking(tokens, mask_ratio)

        # apply Transformer blocks
        for blk in self.blocks:
            tokens = blk(tokens)
        tokens = self.norm1(tokens)

        # output projection: channels first for Conv1d; the reshape absorbs
        # the trailing length-1 axis the conv leaves behind
        pooled = self.outProj(tokens.transpose(1, 2)).reshape(B, S, -1)
        return self.norm2(pooled), mask, ids_restore

    def random_masking(self, x, mask_ratio):
        """
        Replace a random subset of tokens in every sample by the mask token.
        The subset is chosen by argsorting per-sample random noise.
        x: [N, L, D], sequence
        returns: masked sequence [N, L, D], binary mask [N, L]
                 (0 is keep, 1 is remove), restore indices [N, L]
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))
        len_mask = L - len_keep

        if self.roi_fixed_mask:
            # NOTE(review): every entry of a row shares one noise value here,
            # so the ordering depends only on how argsort breaks ties --
            # confirm this is the intended "fixed roi set" behaviour.
            noise = torch.rand(N, 1, device=x.device).repeat(1, L)
        else:
            noise = torch.rand(N, L, device=x.device)

        # ascending sort: the smallest noise values are kept
        ids_shuffle = noise.argsort(dim=1)
        ids_restore = ids_shuffle.argsort(dim=1)

        def widen(idx):
            # broadcast an [N, k] index over the feature axis for gather
            return idx.unsqueeze(-1).repeat(1, 1, D)

        kept = x.gather(1, widen(ids_shuffle[:, :len_keep]))
        fill = self.mask_token.repeat(N, len_mask, 1)

        # kept tokens first, mask tokens last, then back to the original order
        x_masked = torch.cat([kept, fill], dim=1).gather(1, widen(ids_restore))

        # binary mask in shuffled order (0 is keep, 1 is remove), unshuffled
        mask = torch.cat([torch.zeros(N, len_keep, device=x.device),
                          torch.ones(N, len_mask, device=x.device)], dim=1)
        mask = mask.gather(1, ids_restore)

        return x_masked, mask, ids_restore
    
class TransientStateNetEmbedConcat(nn.Module):
    """ Transient state net that concatenates a learned ROI embedding to each
    fMRI segment embedding before the transformer blocks.
    """
    def __init__(self, cfg):
        super().__init__()
        # the two embeddings are concatenated along the feature axis
        assert cfg.fmri_embed_dim + cfg.roi_embed_dim == cfg.mdl_embed_dim
        width = cfg.mdl_embed_dim
        make_norm = cfg.TSN_norm_layer

        # fmri embedding
        self.fmriEmbed = fMRISegEmbed(embed_dim=cfg.fmri_embed_dim,
                                      fmri_seg_size=cfg.fmri_seg_size)

        # roi embedding, looked up with the fixed index 0..num_rois-1
        self.roiEmbed = nn.Embedding(cfg.num_rois, cfg.roi_embed_dim)
        self.roiIndex = torch.LongTensor(np.arange(cfg.num_rois))

        # self-attention
        layers = []
        for _ in range(cfg.TSN_depth):
            layers.append(Block(width, cfg.TSN_num_heads, cfg.TSN_mlp_ratio,
                                qkv_bias=True, norm_layer=make_norm))
        self.blocks = nn.ModuleList(layers)
        self.norm1 = make_norm(width)

        # out projection (aggregate results)
        self.outProj = nn.Conv1d(width, width, kernel_size=cfg.num_rois)
        self.norm2 = make_norm(width)

    def forward(self, x):
        """
        input: (B, num_segs, seg_len, num_rois)
        out: (B, num_segs, mdl_embed_dim)
        """
        B, S, L, N = x.shape

        # one row of seg_len samples per (batch, segment, roi)
        rows = x.permute(0, 1, 3, 2).reshape(-1, L)
        signal = self.fmriEmbed(rows).reshape(B * S, N, -1)

        # concatenate embedding: the same ROI table is shared by all segments
        roi_ids = self.roiIndex.expand(B * S, -1).to(signal.device)
        tokens = torch.cat((signal, self.roiEmbed(roi_ids)), 2)

        # apply Transformer blocks
        for blk in self.blocks:
            tokens = blk(tokens)
        tokens = self.norm1(tokens)

        # output projection: channels first for Conv1d; the reshape absorbs
        # the trailing length-1 axis the conv leaves behind
        pooled = self.outProj(tokens.transpose(1, 2)).reshape(B, S, -1)
        return self.norm2(pooled)
    
class TransientStateNetEmbedAttn(nn.Module):
    """ Transient state net whose attention blocks take separate query/key and
    value inputs.

    With ``TSN_attn_type == 'embedding'`` the ROI embedding is the query/key
    and the fMRI tokens are the value; with ``'fc'`` the roles are swapped.
    """
    def __init__(self, cfg):
        super().__init__()
        assert cfg.fmri_embed_dim == cfg.mdl_embed_dim
        assert cfg.roi_embed_dim == cfg.mdl_embed_dim
        assert cfg.TSN_attn_type in ('embedding', 'fc')

        self.attn_type = cfg.TSN_attn_type
        width = cfg.mdl_embed_dim

        # fmri embedding
        self.fmriEmbed = fMRISegEmbed(embed_dim=cfg.fmri_embed_dim,
                                      fmri_seg_size=cfg.fmri_seg_size)

        # roi embedding, looked up with the fixed index 0..num_rois-1
        self.roiEmbed = nn.Embedding(cfg.num_rois, cfg.roi_embed_dim)
        self.roiIndex = torch.LongTensor(np.arange(cfg.num_rois))

        # self-attention
        layers = []
        for _ in range(cfg.TSN_depth):
            layers.append(FCAttnBlock(width, cfg.TSN_num_heads, cfg.TSN_mlp_ratio,
                                      qkv_bias=True, norm_layer=cfg.TSN_norm_layer,
                                      qk_proj=cfg.TSN_qk_proj,
                                      self_inclusive=cfg.TSN_self_inclusive))
        self.blocks = nn.ModuleList(layers)

        # out projection (aggregate results)
        self.outProj = nn.Conv1d(width, width, kernel_size=cfg.num_rois)
        self.norm = cfg.TSN_norm_layer(width)

    def forward(self, x):
        """
        input: (B, num_segs, seg_len, num_rois)
        out: (B, num_segs, mdl_embed_dim)
        """
        B, S, L, N = x.shape

        # one row of seg_len samples per (batch, segment, roi)
        rows = x.permute(0, 1, 3, 2).reshape(-1, L)
        signal = self.fmriEmbed(rows).reshape(B * S, N, -1)
        roi_tokens = self.roiEmbed(self.roiIndex.expand(B * S, -1).to(signal.device))

        if self.attn_type == 'embedding':
            # roi embedding as key and query, fmri tokens as value
            kq, x = roi_tokens, signal
        else:
            # fmri tokens as key and query, roi embedding as value
            kq, x = signal, roi_tokens

        # apply Transformer blocks; key/query stay fixed, the value is updated
        for blk in self.blocks:
            x = blk(kq, kq, x)

        # output projection: channels first for Conv1d; the reshape absorbs
        # the trailing length-1 axis the conv leaves behind
        pooled = self.outProj(x.transpose(1, 2)).reshape(B, S, -1)
        return self.norm(pooled)
    
class MaskedTransientStateNet(nn.Module):
    """ Masked transient state net with embedding/fc as attention.

    fMRI tokens are randomly replaced by a learnable mask token before they
    enter attention blocks that take separate query/key and value inputs.
    """
    def __init__(self, cfg):
        super().__init__()
        assert cfg.fmri_embed_dim == cfg.mdl_embed_dim
        assert cfg.roi_embed_dim == cfg.mdl_embed_dim
        assert cfg.TSN_attn_type in ('embedding', 'fc')

        self.attn_type = cfg.TSN_attn_type
        width = cfg.mdl_embed_dim

        # fmri embedding
        self.fmriEmbed = fMRISegEmbed(embed_dim=cfg.fmri_embed_dim,
                                      fmri_seg_size=cfg.fmri_seg_size)

        # roi embedding, looked up with the fixed index 0..num_rois-1
        self.roiEmbed = nn.Embedding(cfg.num_rois, cfg.roi_embed_dim)
        self.roiIndex = torch.LongTensor(np.arange(cfg.num_rois))

        self.mask_token = nn.Parameter(torch.zeros(1, 1, cfg.fmri_embed_dim))
        # generate random mask with fixed roi set
        self.roi_fixed_mask = False

        # self-attention
        layers = []
        for _ in range(cfg.TSN_depth):
            layers.append(FCAttnBlock(width, cfg.TSN_num_heads, cfg.TSN_mlp_ratio,
                                      qkv_bias=True, norm_layer=cfg.TSN_norm_layer,
                                      qk_proj=cfg.TSN_qk_proj,
                                      self_inclusive=cfg.TSN_self_inclusive))
        self.blocks = nn.ModuleList(layers)

        # out projection (aggregate results)
        self.outProj = nn.Conv1d(width, width, kernel_size=cfg.num_rois)
        self.norm = cfg.TSN_norm_layer(width)

        self.initialize_weights()

    def initialize_weights(self):
        """ Draw the mask token from a normal distribution with std 0.02. """
        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        nn.init.normal_(self.mask_token, std=0.02)

    def forward(self, x, mask_ratio):
        """
        input: (B, num_segs, seg_len, num_rois)
        out: tuple of
            (B, num_segs, mdl_embed_dim) segment representation,
            (B*num_segs, num_rois) binary mask (0 is keep, 1 is remove),
            (B*num_segs, num_rois) indices that undo the masking shuffle
        """
        B, S, L, N = x.shape

        # one row of seg_len samples per (batch, segment, roi)
        rows = x.permute(0, 1, 3, 2).reshape(-1, L)
        signal = self.fmriEmbed(rows).reshape(B * S, N, -1)
        signal, mask, ids_restore = self.random_masking(signal, mask_ratio)
        roi_tokens = self.roiEmbed(self.roiIndex.expand(B * S, -1).to(signal.device))

        if self.attn_type == 'embedding':
            # roi embedding as key and query, masked fmri tokens as value
            kq, x = roi_tokens, signal
        else:
            # masked fmri tokens as key and query, roi embedding as value
            kq, x = signal, roi_tokens

        # apply Transformer blocks; key/query stay fixed, the value is updated
        for blk in self.blocks:
            x = blk(kq, kq, x)

        # output projection: channels first for Conv1d; the reshape absorbs
        # the trailing length-1 axis the conv leaves behind
        pooled = self.outProj(x.transpose(1, 2)).reshape(B, S, -1)
        return self.norm(pooled), mask, ids_restore

    def random_masking(self, x, mask_ratio):
        """
        Replace a random subset of tokens in every sample by the mask token.
        The subset is chosen by argsorting per-sample random noise.
        x: [N, L, D], sequence
        returns: masked sequence [N, L, D], binary mask [N, L]
                 (0 is keep, 1 is remove), restore indices [N, L]
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))
        len_mask = L - len_keep

        if self.roi_fixed_mask:
            # NOTE(review): every entry of a row shares one noise value here,
            # so the ordering depends only on how argsort breaks ties --
            # confirm this is the intended "fixed roi set" behaviour.
            noise = torch.rand(N, 1, device=x.device).repeat(1, L)
        else:
            noise = torch.rand(N, L, device=x.device)

        # ascending sort: the smallest noise values are kept
        ids_shuffle = noise.argsort(dim=1)
        ids_restore = ids_shuffle.argsort(dim=1)

        def widen(idx):
            # broadcast an [N, k] index over the feature axis for gather
            return idx.unsqueeze(-1).repeat(1, 1, D)

        kept = x.gather(1, widen(ids_shuffle[:, :len_keep]))
        fill = self.mask_token.repeat(N, len_mask, 1)

        # kept tokens first, mask tokens last, then back to the original order
        x_masked = torch.cat([kept, fill], dim=1).gather(1, widen(ids_restore))

        # binary mask in shuffled order (0 is keep, 1 is remove), unshuffled
        mask = torch.cat([torch.zeros(N, len_keep, device=x.device),
                          torch.ones(N, len_mask, device=x.device)], dim=1)
        mask = mask.gather(1, ids_restore)

        return x_masked, mask, ids_restore
    
class DynamicMaskedTransientStateNet(nn.Module):
    """ Masked transient state net whose first block attends with the ROI
    embedding as query/key and whose remaining blocks are plain transformer
    blocks.

    Only ``TSN_attn_type == 'embedding'`` is accepted by the constructor.
    """
    def __init__(self, cfg):
        super().__init__()
        assert cfg.fmri_embed_dim == cfg.mdl_embed_dim
        assert cfg.roi_embed_dim == cfg.mdl_embed_dim
        assert cfg.TSN_attn_type == 'embedding'
        
        self.attn_type = cfg.TSN_attn_type

        # fmri embedding
        self.fmriEmbed = fMRISegEmbed(embed_dim=cfg.fmri_embed_dim,
                                      fmri_seg_size=cfg.fmri_seg_size)
        
        # roi embedding, looked up with the fixed index 0..num_rois-1
        self.roiEmbed = nn.Embedding(cfg.num_rois, cfg.roi_embed_dim)
        self.roiIndex = torch.LongTensor(np.arange(cfg.num_rois))
        
        self.mask_token = nn.Parameter(torch.zeros(1, 1, cfg.fmri_embed_dim))
        self.roi_fixed_mask = False # generate random mask with fixed roi set 
        
        # self-attention: one query/key/value block (without its output norm)
        # followed by TSN_depth-1 plain blocks
        blocks = [FCAttnBlock(cfg.mdl_embed_dim, cfg.TSN_num_heads, cfg.TSN_mlp_ratio, 
                              qkv_bias=True, norm_layer=cfg.TSN_norm_layer, qk_proj=cfg.TSN_qk_proj,
                              self_inclusive=cfg.TSN_self_inclusive, normalize_outputs=False)]
        blocks += [Block(cfg.mdl_embed_dim, cfg.TSN_num_heads, cfg.TSN_mlp_ratio, 
                         qkv_bias=True, norm_layer=cfg.TSN_norm_layer)
                    for _ in range(cfg.TSN_depth-1)]
        self.blocks = nn.ModuleList(blocks)
        
        self.norm1 = cfg.TSN_norm_layer(cfg.mdl_embed_dim)
        
        # out projection (aggregate results)
        self.outProj = nn.Conv1d(cfg.mdl_embed_dim, cfg.mdl_embed_dim, kernel_size=cfg.num_rois)
        self.norm2 = cfg.TSN_norm_layer(cfg.mdl_embed_dim)
        
        # initialize_weights() was defined but never invoked, which left the
        # mask token at all zeros; call it as MaskedTransientStateNet does.
        self.initialize_weights()
        
    def initialize_weights(self):
        """ Draw the mask token from a normal distribution with std 0.02. """
        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.mask_token, std=.02)
        
    def forward(self, x, mask_ratio):
        """
        input: (B, num_segs, seg_len, num_rois)
        out: tuple of
            (B, num_segs, mdl_embed_dim) segment representation,
            (B*num_segs, num_rois) binary mask (0 is keep, 1 is remove),
            (B*num_segs, num_rois) indices that undo the masking shuffle
        """
        
        B, S, L, N = x.shape
        # one row of seg_len samples per (batch, segment, roi)
        x = torch.einsum('bsln->bsnl', x)
        x = x.reshape(-1, L)
        x = self.fmriEmbed(x)
        x = x.reshape(B*S, N, -1)
        x, mask, ids_restore = self.random_masking(x, mask_ratio)
        
        if self.attn_type == 'embedding':
            # roi embedding as key and query
            kq = self.roiEmbed(self.roiIndex.expand(B*S, -1).to(x.device))
        else:
            # not reachable through __init__, which asserts 'embedding';
            # kept so the roles can be swapped if attn_type is set by hand
            kq = x
            x = self.roiEmbed(self.roiIndex.expand(B*S, -1).to(x.device))
        
        # apply Transformer blocks: the first takes (query, key, value) ...
        for blk in self.blocks[:1]:
            x = blk(kq, kq, x)

        # ... the rest are ordinary single-input blocks
        for blk in self.blocks[1:]:
            x = blk(x)
        x = self.norm1(x)
        
        # output projection: channels first for Conv1d, then one vector per segment
        x = torch.einsum('bnc->bcn', x)
        x = self.outProj(x).squeeze().reshape(B,S,-1)
        x = self.norm2(x)
        return x, mask, ids_restore
        
    def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence
        returns: masked sequence [N, L, D], binary mask [N, L]
                 (0 is keep, 1 is remove), restore indices [N, L]
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))
        len_mask = L - len_keep
        
        if self.roi_fixed_mask:
            # NOTE(review): every entry of a row shares one noise value here,
            # so the ordering depends only on how argsort breaks ties --
            # confirm this is the intended "fixed roi set" behaviour.
            noise = torch.rand(N, 1, device=x.device).repeat(1, L)  # noise in [0, 1]
        else:
            noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_keep = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
        mask_tokens = self.mask_token.repeat(N, len_mask, 1)

        # kept tokens first, mask tokens last, then back to the original order
        x_masked = torch.cat([x_keep, mask_tokens], dim=1)
        x_masked = torch.gather(x_masked, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, D)) 

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)
        
        return x_masked, mask, ids_restore
    
class FCAttnBlock(nn.Module):
    """ Attention block with separate query, key and value inputs.

    The attention output is added to the value input, an MLP residual branch
    follows, and an optional final norm is applied.

    embed_dim: feature width of the tokens
    num_heads: number of attention heads
    mlp_ratio: hidden width of the MLP as a multiple of embed_dim
    qkv_bias: forwarded to the attention module as ``bias``
    qk_norm: accepted for signature compatibility; not used in this block
    qk_proj: forwarded to the attention module
    proj_drop: dropout used by the attention projection and the MLP
    attn_drop: dropout forwarded to the attention module
    init_values: when truthy, LayerScale is applied to both branches
    drop_path: stochastic-depth rate; Identity when 0
    act_layer, norm_layer, mlp_layer: layer factories
    self_inclusive: when False the attention mask has a zero diagonal
    normalize_outputs: when False the final norm is an Identity
    """
    def __init__(
            self,
            embed_dim,
            num_heads,
            mlp_ratio=4.,
            qkv_bias=False,
            qk_norm=False,
            qk_proj=True,
            proj_drop=0.,
            attn_drop=0.05,
            init_values=None,
            drop_path=0.,
            act_layer=nn.GELU,
            norm_layer=nn.LayerNorm,
            mlp_layer=Mlp,
            self_inclusive=False,
            normalize_outputs=True
    ):
        super().__init__()
        # norm1 is applied to the MLP input (see forward), not before attention
        self.norm1 = norm_layer(embed_dim)
        self.attn = MultiheadAttention(embed_dim, num_heads, attn_drop=attn_drop, 
                                       proj_drop=proj_drop, bias=qkv_bias, qk_proj=qk_proj)
        
        # NOTE(review): LayerScale is not imported by name in this file;
        # presumably one of the star imports provides it -- confirm before
        # passing init_values.
        self.ls1 = LayerScale(embed_dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer(embed_dim) if normalize_outputs else nn.Identity()
        self.mlp = mlp_layer(
            in_features=embed_dim,
            hidden_features=int(embed_dim * mlp_ratio),
            act_layer=act_layer,
            drop=proj_drop,
        )
        self.ls2 = LayerScale(embed_dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.self_inclusive = self_inclusive
        
    def forward(self, q, k, v):
        """
        q, k, v: query, key and value tokens; the mask is sized from q.shape[1]
        out: tensor with the shape of v + attention output
        """
        N = q.shape[1]
        # Build the (N, N) mask directly on the input's device instead of
        # allocating it on the CPU and copying it over on every call.
        # NOTE(review): presumably a zero entry stops a token from attending
        # to itself -- confirm against model_utils.MultiheadAttention.
        if self.self_inclusive:
            attn_mask = torch.ones(N, N, device=q.device)
        else:
            attn_mask = 1 - torch.eye(N, device=q.device)
        
        # the second return value of the attention module is not needed here
        v_, _ = self.attn(q, k, v, attn_mask=attn_mask)
        v_ = self.ls1(self.drop_path1(v_))
        # residual connection on the value input
        x = v + v_
        
        x = x + self.drop_path2(self.ls2(self.mlp(self.norm1(x))))
        x = self.norm2(x)
        return x
    
    