from typing import Tuple, List
from abc import ABC, abstractmethod
import itertools
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torchdiffeq import odeint_adjoint

from src.models.hypernetwork import Hypernetwork
from src.utils.database import standardize


# Per-dataset configuration of the integrated network f_theta. Each entry is
# [c_in, spatial_ndims, padding_mode, ksize (for OML), scheme_type].
PDEBENCH_SPECIFICS = dict(
    swe=[1, 2, 'zeros', 5, 'VI'],
    incompNS=[3, 2, 'zeros', 5, 'VII'],
    compNS128=[4, 2, 'circular', 11, 'VI'],
    compNS512=[4, 2, 'circular', 11, 'VI'],
    compNS=[4, 2, 'circular', 11, 'VI'],
    diffre2d=[2, 2, 'zeros', 5, 'VII'],
    burgers=[1, 1, 'circular', 5, 'VI'],
)


def get_hyperparams_shapes(model_type, c_in, c_hidden, ksize, d):
    r""" Obtain the shapes of the weights of the integrated network f_\theta.

    Args:
        model_type: scheme type, "VI" or "VII". "VII" feeds one extra input
            channel (a boundary mask) to the first linear filtering layer.
        c_in: number of channels of the state.
        c_hidden: hidden width of the 1x1-convolution MLPs; the linear
            filtering layer produces c_hidden // 2 channels.
        ksize: spatial size of the linear filtering kernel.
        d: number of spatial dimensions.

    Returns:
        A list of 7 weight shapes (c_out, c_in, *kernel), in order: linear
        filtering, adapter in, four MLP layers (two residual blocks), adapter out.

    Raises:
        ValueError: if model_type is not recognized.
    """
    if model_type == "VI":
        extra_in = 0
    elif model_type == "VII":
        extra_in = 1  # boundary-mask channel
    else:
        raise ValueError(f"model_type {model_type} not recognized")

    k_shape = (ksize,) * d
    ones = (1,) * d
    # expand to 2*c_hidden then project back to c_hidden
    mlp_block = [(2*c_hidden, c_hidden, *ones), (c_hidden, 2*c_hidden, *ones)]
    return [
        # linear filtering
        (c_hidden//2, c_in + extra_in, *k_shape),
        # adapter in: consumes [state, filtered features]
        (c_hidden, c_in + c_hidden//2, *ones),
        # mlps: two blocks
        *mlp_block,
        *mlp_block,
        # adapter out
        (c_in, c_hidden, *ones),
    ]


def create_frontier_mask(shape, dtype, device):
    """ Build a standardized indicator of the frontier of the last two axes.

    The first/last row and first/last column of the last two dimensions are
    set to one and every other entry to zero. The result is then shifted and
    scaled over those two dimensions to zero mean and unit standard deviation
    (torch's default, unbiased, std).
    """
    border = torch.zeros(shape, dtype=dtype, device=device)
    for edge in (0, -1):
        border[..., edge, :] = 1
        border[..., :, edge] = 1

    plane = (-1, -2)
    border -= border.mean(plane, keepdim=True)
    border /= border.std(plane, keepdim=True)
    return border


def conv(v: torch.Tensor, w: torch.Tensor, padding_mode: str, fc: bool = True):
    """ Per-sample convolution of a 1D or 2D signal v by the kernels w.

    The batch axis of ``v`` (b, c_in, *spatial) is folded into the channel
    axis so that a single grouped convolution applies a different kernel to
    each sample. With ``fc`` the convolution uses ``b`` groups, so ``w`` is
    expected as (b * c_out, c_in, *kernel). Otherwise it uses ``w.shape[0]``
    groups (a depthwise convolution).

    ``padding_mode`` is forwarded to ``F.pad`` with half the kernel width on
    each side (None disables padding). The result is (b, c_out, *spatial').
    """
    batch = v.shape[0]
    ndims = v.ndim - 2
    half_width = (w.shape[-1] - 1) // 2

    folded = v.reshape(batch * v.shape[1], *v.shape[2:])
    if padding_mode is not None:
        folded = F.pad(folded, (half_width,) * (2 * ndims), mode=padding_mode)

    n_groups = w.shape[0] if not fc else batch
    conv_fn = {1: F.conv1d, 2: F.conv2d}[ndims]
    out = conv_fn(folded, w, padding=0, stride=1, groups=n_groups)
    return out.reshape(batch, -1, *out.shape[1:])


def _gn_gelu(v, groups):
    """ GroupNorm without affine parameters (eps=1e-5) followed by GELU. """
    return F.gelu(F.group_norm(v, groups, None, None, 1e-5))


def _integrand_trunk(v, w, padding_mode, groups):
    """ Layers shared by every scheme type, applied after the linear filtering.

    These are adapter in (w[1]), two residual 1x1-conv MLP blocks (w[2..5])
    and adapter out (w[6]).
    """
    # adapter in
    v = _gn_gelu(conv(v, w[1], padding_mode), groups)
    # first residual mlp block
    residual = v
    v = _gn_gelu(conv(v, w[2], padding_mode), groups)
    v = conv(v, w[3], padding_mode) + residual
    v = _gn_gelu(v, groups)
    # second residual mlp block
    residual = v
    v = _gn_gelu(conv(v, w[4], padding_mode), groups)
    v = conv(v, w[5], padding_mode) + residual
    # adapter out
    v = _gn_gelu(v, groups)
    return conv(v, w[6], padding_mode)


def create_integrand(hyperparams, scheme_type, param_idx, param_shapes, padding_mode, spatial_ndims):
    r""" Create the integrated network f_\theta from the flat hypernetwork output.

    Args:
        hyperparams: flat weights, one row per sample, sliced as
            ``hyperparams[:, i:j]`` for each layer. Presumably its first axis
            matches the batch of the ODE state (TODO confirm).
        scheme_type: "VI" or "VII". "VII" appends a standardized boundary mask
            (over the last two axes) as an extra input channel of the first
            layer, so it presumably targets 2D data only.
        param_idx: cumulative offsets delimiting each layer's weights in
            ``hyperparams``.
        param_shapes: per-layer weight shapes (c_out, c_in, *kernel), as
            returned by ``get_hyperparams_shapes``.
        padding_mode: padding mode for ``F.pad`` ("zeros" is mapped to
            "constant"), or None for no padding.
        spatial_ndims: unused; kept for interface compatibility.

    Returns:
        ``integrand(t, v)``, a callable suitable for torchdiffeq's odeint
        (``t`` is ignored).

    Raises:
        ValueError: if scheme_type is not recognized.
    """
    if scheme_type not in ("VI", "VII"):
        raise ValueError(f"scheme_type {scheme_type} not recognized")

    if padding_mode == "zeros":
        padding_mode = "constant"  # F.pad's name for zero padding

    # decompose the weights needed at different layers; the batch and output
    # channel dimensions are merged: (B, c_out*c_in*k...) -> (B*c_out, c_in, *k)
    w = [
        hyperparams[:, i:j].reshape(-1, *sh[1:])
        for (i, j), sh in zip(itertools.pairwise(param_idx), param_shapes)
    ]
    groups = 8  # number of GroupNorm groups
    with_mask = scheme_type == "VII"

    def integrand(t, v):
        """ A simple CNN: a linear filtering of the state (plus a boundary-mask
        input channel for scheme VII), concatenated with the state, followed
        by 1x1 convolutions. """
        filter_in = v
        if with_mask:
            # one mask channel over the full spatial extent (supports non-square grids)
            mask = create_frontier_mask((v.shape[0], 1, *v.shape[2:]), v.dtype, v.device)
            filter_in = torch.cat([v, mask], dim=1)
        filtered = conv(filter_in, w[0], padding_mode)
        return _integrand_trunk(torch.cat([v, filtered], dim=1), w, padding_mode, groups)

    return integrand


class Operator(nn.Module, ABC):
    r""" Operator F(t, u, du, ddu) that defines a PDE
    \partial_t u = F(t, u, du, ddu).

    Subclasses must implement ``forward``. In its signature ``du`` and ``Du``
    presumably correspond to the ``du`` and ``ddu`` of the formula (first and
    second spatial derivatives of ``u``), and ``W`` to the operator weights;
    TODO confirm with the concrete subclasses.
    """

    @abstractmethod
    def forward(self, 
        t: float, 
        u: torch.Tensor, 
        du: torch.Tensor, 
        Du: torch.Tensor,
        W: nn.Parameter,
    ): 
        """ Evaluate F at time ``t`` for the state ``u`` and its derivatives. """


class ICneuralPDE(nn.Module):
    r""" Meta-operator that combines
    - a parameter network \theta = \psi(u_{t+1},...,u_{t+T}) (the ``Hypernetwork``)
    - an operator F_\theta(t, u, du, ddu) (the integrand built by ``create_integrand``)
    - an integrator (RK4 through ``odeint_adjoint``)

    Two configuration modes:
    - ``in_channels is None`` ("customize"): one operator configuration per
      dataset of ``PDEBENCH_SPECIFICS`` (plus two fine-tuning datasets when
      ``finetune``). The per-dataset values from that table are used, so the
      ``spatial_ndims``, ``scheme_type``, ``padding_mode`` and ``ksize``
      arguments are ignored, and ``dset_name`` must be passed to ``forward``.
    - otherwise: a single configuration built from the constructor arguments.
    """

    def __init__(self, 
        in_channels: int | None,
        spatial_ndims: int | None,
        embed_dim: int,
        num_heads: int,
        processor_blocks: int,
        bias_type: str,
        drop_path: float,
        scheme_type: str | None,
        timesteps_integrator: int,
        padding_mode: str | None,
        hidden_dim: int,
        ksize: int | None,
        finetune: bool = False,
    ):
        super().__init__()

        # number of time points on [0, 1] used by the integrator
        self.timesteps_integrator = timesteps_integrator

        # configuration of the operator F_\theta
        self.customize = in_channels is None
        if self.customize:
            # shallow copy: the fine-tuning entries must not leak into the module-level table
            dataset_specifics = PDEBENCH_SPECIFICS.copy()
            if finetune:
                dataset_specifics['shearflow'] = [4, 2, 'circular', 5, 'VI']
                dataset_specifics['euler_multi_quadrants_periodicBC'] = [5, 2, 'circular', 5, 'VI']

            # each entry is [c_in, spatial_ndims, padding_mode, ksize, scheme_type]
            self.param_shapes = {
                name: get_hyperparams_shapes(stype, c_in, c_hidden=hidden_dim, ksize=k_size, d=d)
                for name, (c_in, d, _, k_size, stype) in dataset_specifics.items()
            }
            self.param_idx = {
                name: self._flat_offsets(shapes) for name, shapes in self.param_shapes.items()
            }
            self.padding_mode = {name: spec[2] for name, spec in dataset_specifics.items()}
            self.out_chans = {name: idx[-1] for name, idx in self.param_idx.items()}
            spatial_ndims = {name: spec[1] for name, spec in dataset_specifics.items()}
            self.scheme_type = {name: spec[4] for name, spec in dataset_specifics.items()}
        else:
            self.scheme_type = scheme_type
            self.param_shapes = get_hyperparams_shapes(scheme_type, in_channels, c_hidden=hidden_dim, ksize=ksize, d=spatial_ndims)
            self.param_idx = self._flat_offsets(self.param_shapes)
            self.padding_mode = padding_mode
            self.out_chans = self.param_idx[-1]

        # the parameter network, predicting the flat operator weights
        self.hypernetwork = Hypernetwork(
            in_chans=in_channels, out_chans=self.out_chans,
            embed_dim=embed_dim, spatial_ndims=spatial_ndims, padding_mode=self.padding_mode, groups=12, 
            processor_blocks=processor_blocks, drop_path=drop_path, 
            num_heads=num_heads, bias_type=bias_type, finetune=finetune
        )

    @staticmethod
    def _flat_offsets(shapes):
        """ Cumulative offsets [0, n_0, n_0+n_1, ...] of each weight in the flat hyperparameter vector. """
        return list(itertools.accumulate([0] + [int(np.prod(sh)) for sh in shapes]))

    def _operator_config(self, dset_name):
        """ (scheme_type, param_idx, param_shapes, padding_mode) of the operator; per dataset in customize mode. """
        if self.customize:
            return (
                self.scheme_type[dset_name], self.param_idx[dset_name],
                self.param_shapes[dset_name], self.padding_mode[dset_name],
            )
        return self.scheme_type, self.param_idx, self.param_shapes, self.padding_mode

    def forward_single_step(self, 
        x: torch.Tensor, 
        predict_normed: bool = False,
        state_labels: torch.Tensor | None = None,
        dset_name: str | None = None
    ) -> Tuple:
        """ Predict the next time step from the context ``x`` (B T C H W).

        Args:
            x: context trajectory, B T C H W. Trailing singleton spatial axes
                are squeezed for the integrand (presumably 1D data is stored
                with W == 1; TODO confirm).
            predict_normed: if True, the prediction is mapped back with the
                context statistics (``metadata``), i.e. returned in the same
                space as ``x``. Otherwise it is returned in the standardized
                space of the context.
            state_labels: forwarded to the hypernetwork.
            dset_name: dataset key selecting the operator configuration
                (customize mode).

        Returns:
            (prediction of shape B 1 C H W, {'mean', 'std'} of the context).
        """
        spatial_dims = tuple(range(3, x.squeeze(-1, -2).ndim))

        # preprocess the context: standardize over time and space
        x, mean, std = standardize(x, dims=(1, *spatial_dims), return_stats=True)
        metadata = {'mean': mean, 'std': std}

        # operator parameters (: numerical scheme parameters)
        hyperparams = self.hypernetwork(x, state_labels, dset_name)
        x = x.squeeze(-1, -2)
        integrand = create_integrand(hyperparams, *self._operator_config(dset_name), x.ndim - 3)

        # preprocess the last step: standardize over space
        x = x[:, -1, ...]
        x, mean_t, std_t = standardize(x, dims=tuple(range(2, x.ndim)), return_stats=True)  # b c h (w)

        t = torch.linspace(0, 1, self.timesteps_integrator, device=x.device)
        if t.shape[0] == 2:
            # perform a simple forward, odeint_adjoint may have unstable gradients in that case.
            # NOTE(review): this uses f(0, x) itself as the next state rather than an
            # Euler step x + f(0, x); confirm this is intended.
            x = integrand(.0, x)
        else:
            # integrate f_\theta with a memory-efficient backward through the adjoint method
            x = odeint_adjoint(integrand, x, t=t, method="rk4", adjoint_params=(hyperparams,))[-1, :, :]
        x = x * std_t + mean_t
        x = x.unsqueeze(1)

        if x.ndim == 4:  # x is 1d: restore the trailing singleton spatial axis
            x = x.unsqueeze(-1)

        if predict_normed:
            x = x * metadata['std'] + metadata['mean']

        return x, metadata

    def forward(self,
        x: torch.Tensor,
        predict_normed: bool = False,
        n_future_steps: int = 1,
        state_labels: torch.Tensor | None = None,
        dset_name: str | None = None
    ):
        """ Predict ``n_future_steps`` steps autoregressively from the context ``x`` (B T C H W).

        Each prediction is appended to a sliding window over the context, and
        the oldest frame is dropped. All predictions are expressed in the
        standardized space of the initial context. With ``predict_normed``
        they are mapped back with its statistics. At least one step is always
        predicted.

        Returns:
            (predictions of shape B n_future_steps C H W, metadata of the first step).
        """
        # first iteration, predicted in the standardized space of the initial context
        out, metadata = self.forward_single_step(x, predict_normed=False, state_labels=state_labels, dset_name=dset_name)
        outputs = [out]
        if n_future_steps > 1:
            # same standardization as in forward_single_step, hence the same space as `out`
            spatial_dims = tuple(range(3, x.squeeze(-1, -2).ndim))
            context = standardize(x.clone(), dims=(1, *spatial_dims), return_stats=False)
            for _ in range(n_future_steps - 1):
                context = torch.cat([context[:, 1:, ...], out], dim=1)
                # forward_single_step re-standardizes the rolled context with its own
                # statistics; predict_normed=True maps the prediction back into the
                # space of `context`, keeping every step in the same space
                out, _ = self.forward_single_step(context, predict_normed=True, state_labels=state_labels, dset_name=dset_name)
                outputs.append(out)
        out = outputs[0] if len(outputs) == 1 else torch.cat(outputs, dim=1)
        if predict_normed:
            out = out * metadata['std'] + metadata['mean']
        return out, metadata