import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from einops import rearrange
from functools import partial
try:
    from xspatial_modules import hMLP_stem, hMLP_output, SubsampledLinear, UpsampledLinear
    from xmixed_modules import build_spacetime_block, SpaceTimeBlock
except:
    from .xspatial_modules import hMLP_stem, hMLP_output, SubsampledLinear, UpsampledLinear
    from .xmixed_modules import build_spacetime_block, SpaceTimeBlock

def build_avit(params):
    """Construct an ``XAViT`` model from a hyperparameter object.

    The processor-block factory is produced first by ``build_spacetime_block``
    and handed to ``XAViT`` as ``override_block``. The remaining architecture
    settings are read directly off ``params``. The stem and destem are not yet
    configurable from ``params``, and ``drop_path`` keeps the ``XAViT`` default.

    Args:
        params: Object exposing ``patch_size``, ``embed_dim``,
            ``processor_blocks`` and ``n_states`` attributes, plus whatever
            ``build_spacetime_block`` reads from it.

    Returns:
        XAViT: The assembled model.
    """
    block_factory = build_spacetime_block(params)
    return XAViT(
        patch_size=params.patch_size,
        embed_dim=params.embed_dim,
        processor_blocks=params.processor_blocks,
        n_states=params.n_states,
        override_block=block_factory,
    )

class XAViT(nn.Module):
    """
    Naive model that interweaves spatial and temporal attention blocks. Temporal attention
    acts only on the time dimension.

    Pipeline for an input of shape (T, B, C, *spatial) with 1, 2 or 3 spatial dims:

    1. Normalize each sample over time and space.
    2. Sparsely project the state channels selected by the state labels.
    3. Lift one copy of the field per spatial axis (that axis swapped to the last
       position) through the hMLP stem.
    4. Run ``processor_blocks`` processor blocks.
    5. Decode each copy, undo the per-axis swaps and take the element-wise max
       over the copies.
    6. Denormalize and return the final time step.

    Args:
        patch_size (tuple): Size of the input patch
        embed_dim (int): Dimension of the embedding
        processor_blocks (int): Number of blocks (consisting of spatial mixing - temporal attention)
        n_states (int): Number of input state variables.
        override_block (callable, optional): Factory called as ``override_block(drop_path=p)``
            to build each processor block. Defaults to ``SpaceTimeBlock`` with
            ``hidden_dim=embed_dim``.
        drop_path (float): Largest drop-path rate. Per-block rates are linearly spaced
            from 0 to this value; how the rate is used is up to the block.
    """
    def __init__(self, patch_size=(16, 16), embed_dim=768, processor_blocks=8, n_states=6,
                 override_block=None, drop_path=.2):
        super().__init__()
        self.drop_path = drop_path
        # Per-block drop-path rates, increasing linearly with depth.
        self.dp = np.linspace(0, drop_path, processor_blocks)
        self.space_bag = SubsampledLinear(n_states, embed_dim//4)
        self.embed = hMLP_stem(patch_size=patch_size, in_chans=embed_dim//4, embed_dim=embed_dim)

        # Default to factored spacetime block with default settings (space/time axial attention)
        if override_block is not None:
            inner_block = override_block
        else:
            inner_block = partial(SpaceTimeBlock, hidden_dim=embed_dim)
        self.blocks = nn.ModuleList([inner_block(drop_path=self.dp[i])
                                     for i in range(processor_blocks)])
        self.debed = hMLP_output(patch_size=patch_size, embed_dim=embed_dim, out_chans=n_states)
        self.out_linear = UpsampledLinear(embed_dim//4, n_states)

        # One entry point per spatial dimensionality so each can be compiled separately
        # (torch.compile(self._forward, dynamic=True) caches per callable); currently all
        # three are the eager implementation.
        # self.forward1d = torch.compile(self._forward, dynamic=True)
        # self.forward2d = torch.compile(self._forward, dynamic=True)
        # self.forward3d = torch.compile(self._forward, dynamic=True)
        self.forward1d = self._forward
        self.forward2d = self._forward
        self.forward3d = self._forward

    def expand_projections(self, expansion_amount):
        """ Appends additional state embeddings for finetuning on new data.

        The input projection (``space_bag``) and the output kernel/bias of ``debed``
        are replaced by freshly initialized, wider versions. The existing weights are
        copied into the leading slots, and the ``expansion_amount`` new slots keep
        their fresh initialization.

        The replacements are created on the same device and dtype as the tensors they
        replace, so this may be called after the model has been moved (e.g. ``.cuda()``).
        ``debed.out_chans`` is kept in sync so the method can be called repeatedly.

        Args:
            expansion_amount (int): Number of new state variables to add.
        """
        with torch.no_grad():
            # Expand input projections
            old_bag = self.space_bag
            new_bag = SubsampledLinear(dim_in=old_bag.dim_in + expansion_amount,
                                       dim_out=old_bag.dim_out)
            # Freshly built modules live on the CPU in the default dtype; match the old one.
            new_bag = new_bag.to(device=old_bag.weight.device, dtype=old_bag.weight.dtype)
            new_bag.weight[:, :old_bag.dim_in] = old_bag.weight
            new_bag.bias[:] = old_bag.bias[:]
            self.space_bag = new_bag

            # Expand output projections
            old_kernel = self.debed.out_kernel
            old_bias = self.debed.out_bias
            n_out = self.debed.out_chans
            out_head = nn.ConvTranspose2d(self.debed.embed_dim//4, n_out + expansion_amount,
                                          kernel_size=4, stride=4)
            out_head = out_head.to(device=old_kernel.device, dtype=old_kernel.dtype)
            new_kernel = out_head.weight
            new_bias = out_head.bias
            new_kernel[:, :n_out, :, :] = old_kernel
            new_bias[:n_out] = old_bias
            self.debed.out_kernel = nn.Parameter(new_kernel)
            self.debed.out_bias = nn.Parameter(new_bias)
            # Keep the recorded channel count in step with the kernel so a later
            # expansion copies the correct number of existing channels.
            self.debed.out_chans = n_out + expansion_amount

    def freeze_middle(self):
        """ Train only the input projection and the decoder's output kernel/bias. """
        # First just turn grad off for everything
        for param in self.parameters():
            param.requires_grad = False
        # Activate for embed/debed layers
        for param in self.space_bag.parameters():
            param.requires_grad = True
        self.debed.out_kernel.requires_grad = True
        self.debed.out_bias.requires_grad = True

    def freeze_processor(self):
        """ Train the input projection, stem and decoder; freeze everything else. """
        # First just turn grad off for everything
        for param in self.parameters():
            param.requires_grad = False
        # Activate for embed/debed layers
        for module in (self.space_bag, self.debed, self.embed):
            for param in module.parameters():
                param.requires_grad = True

    def unfreeze(self):
        """ Make every parameter trainable again. """
        for param in self.parameters():
            param.requires_grad = True

    def forward(self, x, *args, **kwargs):
        """ Dispatch to the 1D/2D/3D entry point based on the number of spatial dims.

        Args:
            x (torch.Tensor): Input of shape (T, B, C, *spatial).
            *args, **kwargs: Forwarded unchanged (``_forward`` expects
                ``state_labels`` and ``bcs``).

        Raises:
            ValueError: If ``x`` has other than 1, 2 or 3 spatial dimensions.
        """
        _, _, _, *spatial = x.shape
        D = len(spatial)
        if D == 1:
            return self.forward1d(x, *args, **kwargs)
        elif D == 2:
            return self.forward2d(x, *args, **kwargs)
        elif D == 3:
            return self.forward3d(x, *args, **kwargs)
        else:
            raise ValueError(f"Input tensor has {D} spatial dimensions, but expected 1D, 2D, or 3D tensor.")

    def _forward(self, x, state_labels, bcs):
        """ Full normalize -> encode -> process -> decode -> denormalize pipeline.

        Args:
            x (torch.Tensor): Input of shape (T, B, C, *spatial) with 1-3 spatial dims.
            state_labels: Passed whole to ``space_bag``; its first entry is passed to
                ``debed``, so all samples in the batch are expected to share labels.
            bcs: Passed unchanged to every processor block.

        Returns:
            torch.Tensor: Prediction for the last time step in the original
            (un-normalized) units. Denormalization broadcasts the per-channel input
            statistics, so the decoder is expected to emit C channels.
        """
        T, B, C, *H = x.shape
        D = len(H)
        # Normalize (time + space per sample); statistics are treated as constants.
        with torch.no_grad():
            spatial_axes = tuple(range(-D, 0))
            data_std, data_mean = torch.std_mean(x, dim=(0, *spatial_axes), keepdim=True)
            data_std = data_std + 1e-7  # avoid division by zero on constant fields
        x = (x - data_mean) / data_std

        # Sparse proj: move channels last so the label-selected projection acts on them
        x = rearrange(x, 't b c ... -> t b ... c')
        x = self.space_bag(x, state_labels)  # dimension-free

        # Encode: one copy per spatial axis, with that axis swapped to the last position
        x = rearrange(x, 't b ... c -> (t b) c ...')
        axis_views = [torch.swapaxes(x.clone(), i + 2, -1) for i in range(D)]
        axis_views = self.embed(axis_views)  # dimension-dependent, lifting
        axis_views = [rearrange(v, '(t b) c ... -> t b c ...', t=T) for v in axis_views]

        # Process
        for blk in self.blocks:
            axis_views = blk(axis_views, bcs)

        # Decode - It would probably be better to grab the last time here since we're only
        # predicting the last step, but leaving it like this for compatibility to causal masking
        axis_views = [rearrange(v, 't b c ... -> (t b) c ...') for v in axis_views]
        axis_views = self.debed(axis_views, state_labels[0])

        # Aggregation: undo each per-axis swap, then take the element-wise max over the copies
        axis_views = [torch.swapaxes(axis_views[i], -1, i + 2) for i in range(D)]
        x = torch.max(torch.stack(axis_views, dim=0), dim=0).values

        x = rearrange(x, '(t b) c ... -> t b c ...', t=T)

        # Denormalize
        x = x * data_std + data_mean  # All state labels in the batch should be identical
        return x[-1]  # Just return last step - now just predict delta.



if __name__ == '__main__':
    # Smoke test: build a default model, widen it by two states, run one forward pass.
    print(torch.cuda.is_available())
    # Fall back to the CPU instead of crashing on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = XAViT().to(device)
    for n, p in model.debed.named_parameters():
        print(n, p.shape)
    model.expand_projections(2)
    for n, p in model.debed.named_parameters():
        print(n, p.shape)
    T = 10
    bs = 4
    nx = 128
    ny = 128
    x = torch.randn(T, bs, 2, nx, ny, device=device)
    print('xshape', x.shape)
    # _forward hands state_labels[0] to the decoder, so labels are given batched:
    # one (identical) row of state indices per sample. Assumed layout — TODO confirm.
    labels = torch.tensor([[0, 1]] * bs, device=device)
    # NOTE(review): _forward requires bcs and passes it straight to the processor blocks;
    # a (B, n_spatial_dims) flag tensor is assumed here — TODO confirm against SpaceTimeBlock.
    bcs = torch.zeros(bs, 2, dtype=torch.long, device=device)
    with torch.no_grad():
        y = model(x, labels, bcs)
    print('yshape', y.shape)


