"""AViT model, as used in MPP, for next-frame prediction.
Largely inspired by https://github.com/PolymathicAI/multiple_physics_pretraining."""
from typing import Tuple
import torch
import torch.nn as nn
import numpy as np
from einops import rearrange

from src.models.tokenizer import CNN
from src.models.attention import SpaceTimeBlock
from src.utils.database import standardize


class AViT(nn.Module):
    """

    Args:
        patch_size (tuple): Size of the input patch
        embed_dim (int): Dimension of the embedding
        processor_blocks (int): Number of blocks (consisting of spatial mixing - temporal attention)
        n_states (int): Number of input state variables.  
    """
    def __init__(self, 
        tokenizer_type: str, 
        padding_mode: str,
        in_channels: int,
        spatial_ndims: int,
        patch_size: Tuple, 
        embed_dim: int,  
        num_heads: int,
        processor_blocks: int, 
        bias_type: str,
        drop_path: float,
        no_decoder: bool = False,
        mpp_norm: bool = False,
        finetune: bool = False,
    ):
        super().__init__()
        self.drop_path = drop_path
        self.dp = np.linspace(0, drop_path, processor_blocks)

        # tokenizer
        if tokenizer_type == "fold_unfold":
            self.tokenizer = FoldUnFold(patch_size[0])  # for 1D data
        elif tokenizer_type == "CNN":
            self.tokenizer = CNN(
                in_channels, embed_dim, spatial_ndims=spatial_ndims, 
                groups=12, padding_mode=padding_mode, customize=spatial_ndims is None, 
                finetune=finetune
            )
        else: 
            raise ValueError(f"Unknown tokenizer {tokenizer_type}")

        # additional linear layers that are learned per-dataset in MPP
        if mpp_norm:
            n_states = 12  # see MPP implementation
            scale = (n_states / in_channels)**.5
            self.tokenizer.encoder[0].weight.data = scale * self.tokenizer.encoder[0].weight.data
            self.tokenizer.encoder[0].bias.data = scale * self.tokenizer.encoder[0].bias.data
            
        self.no_decoder = no_decoder  # deprecated

        # attention layers
        self.blocks = nn.ModuleList([
            SpaceTimeBlock(dim=embed_dim, num_heads=num_heads, bias_type=bias_type, drop_path=dp)
            for dp in self.dp
        ])

    def freeze_transformer(self):
        for param in self.blocks.parameters():
            param.requires_grad = False

    def freeze(self):
        for param in self.parameters():
            param.requires_grad = False

    def unfreeze(self):
        for param in self.parameters():
            param.requires_grad = True

    def forward_single_step(self, 
        x, 
        predict_normed=False,
        state_labels=None,
        dset_name: str | None = None
    ):
        # dimensions
        B = x.shape[0]
        spatial_dims = tuple(range(3,x.squeeze(-1,-2).ndim))

        # preprocess
        x, mean, std = standardize(x, dims=(1,*spatial_dims), return_stats=True)
        metadata = {'mean': mean, 'std': std}

        # encode
        x = rearrange(x, 'b t c h w -> (b t) c h w')
        x = self.tokenizer.encode(x, state_labels)
        x = rearrange(x, '(b t) c h w -> b t c h w', b=B)

        # attention layers
        for blk in self.blocks:
            x, _ = blk(x, return_att=False)

        x = x[:,-1:,...]

        # decode
        x = rearrange(x, 'b t c h w -> (b t) c h w')
        x = self.tokenizer.decode(x, state_labels)
        x = rearrange(x, '(b t) c h w -> b t c h w', b=B)

        if predict_normed:
            x = x * std + mean
        
        return x, metadata
    
    def forward(self,
        x: torch.Tensor,
        predict_normed: bool = False,
        n_future_steps: int = 1,
        state_labels: torch.Tensor | None = None,
        dset_name: str | None = None
    ):
        """ x is B T C H W """
        # first iteration: 
        out, metadata = self.forward_single_step(x, predict_normed=False, state_labels=state_labels, dset_name=dset_name)
        if n_future_steps == 1:
            if predict_normed:
                out = out * metadata['std'] + metadata['mean']
            return out, metadata
        # more iterations: rollout
        context = x.clone()
        spatial_dims = tuple(range(3,x.squeeze(-1,-2).ndim))
        context = standardize(context, dims=(1,*spatial_dims), return_stats=False)
        outputs = [out]
        for _ in range(n_future_steps-1):
            context = torch.cat([context[:,1:,...], out], dim=1)
            out, _ = self.forward_single_step(context, predict_normed=False, state_labels=state_labels, dset_name=dset_name)
            outputs.append(out)
        out = torch.cat(outputs, dim=1)
        if predict_normed:
            out = out * metadata['std'] + metadata['mean']
        return out, metadata
