from typing import List, Optional

import hydra
import torch
import torch.nn as nn

from src.models.components.mlp import MLP
from src.models.components.dec import ConvNet, ConsistencyDec
from src.models.components.fourier import FNO2d


class MLPConv(nn.Module):
    """
    Bundled MLPConv dynamic model.

    The context vector and the flattened input history are concatenated and fed to
    an MLP. The MLP output is reshaped so its last axis has size ``latent_dim`` and
    passed through a ``ConvNet``. ``ConsistencyDec`` then combines the ConvNet output
    with the last entry of ``us`` along dim 1 to produce the prediction.
    """

    def __init__(self,
                 input_dim,
                 ctx_dim,
                 latent_dim,
                 bundling_k,
                 conv1_dim,
                 output_dim=None,
                 act='SiLU',
                 num_neurons: Optional[List[int]] = None,
                 dt: float = 0.01):
        """
        Args:
            input_dim: per-step input width. The MLP consumes ``input_dim * bundling_k``
                input values (plus the context) and emits ``input_dim * latent_dim``.
            ctx_dim: length of the context vector concatenated in front of the inputs.
            latent_dim: size of the last axis the MLP output is reshaped to.
            bundling_k: number of bundled input steps (presumably the size of dim 1 of
                ``us``). Also forwarded to ``ConvNet``.
            conv1_dim: forwarded to ``ConvNet``.
            output_dim: forwarded to ``ConvNet``.
            act: activation name used for the MLP hidden/output layers and ``ConvNet``.
            num_neurons: hidden layer sizes of the MLP. ``None`` means ``[64, 32]``.
            dt: time step, stored and passed to ``ConsistencyDec``.
        """
        super().__init__()
        # Resolve the default here rather than in the signature, so instances never
        # share (and possibly mutate) one list object.
        if num_neurons is None:
            num_neurons = [64, 32]
        self.mlp = MLP(input_dim * bundling_k + ctx_dim,
                       input_dim * latent_dim,
                       num_neurons=num_neurons,
                       hidden_act=act,
                       out_act=act)
        self.cnn = ConvNet(input_dim, latent_dim, conv1_dim, bundling_k, output_dim, act)
        self.dec = ConsistencyDec(dt)

        self.ctx_dim = ctx_dim
        self.latent_dim = latent_dim
        self.dt = dt

    def forward(self, ctx, us):
        """
        Args:
            ctx: context tensor, indexed as ``(batch, ctx_dim)``.
            us: input history. Each sample must hold ``input_dim * bundling_k``
                values; presumably shaped ``(batch, bundling_k, input_dim)`` — confirm.

        Returns:
            The output of ``ConsistencyDec(us[:, -1, :], ds)``.
        """
        bs = ctx.shape[0]
        xs = torch.cat([ctx, us.reshape(bs, -1)], dim=-1)
        xs = self.mlp(xs)

        xs = xs.reshape(bs, -1, self.latent_dim)
        ds = self.cnn(xs)  # presumably [bs, K, dim] — confirm against ConvNet
        # Anchor the decoder on the last entry along dim 1 of the input history.
        pred = self.dec(us[:, -1, :], ds)
        return pred


class MLPFourier(nn.Module):
    """
    Bundled MLPFourier dynamic model.

    The context vector and the spatially flattened input field are joined and sent
    through an MLP. The MLP output is folded back to ``(batch, -1, h, w)`` and
    processed by an ``FNO2d``. Dim -3 of the FNO output is squeezed when it has
    size 1.
    """

    def __init__(self,
                 mlp_config,
                 fno_config,
                 dt: float = 0.01,
                 **unused_kwargs):
        super().__init__()
        self.mlp = MLP(**mlp_config)
        self.fno = FNO2d(**fno_config)
        # Built but not called by forward(); the consistency decoder is disabled there.
        self.dec = ConsistencyDec(dt)
        self.dt = dt

    def forward(self, ctx, us):
        # Dim -3 is the bundling axis; squeeze() only drops it when it has size 1.
        field = us.squeeze(-3)
        batch = ctx.size(0)
        height, width = field.shape[-2:]

        # Flatten the two trailing spatial axes so the field can be joined with ctx.
        flat_field = field.flatten(-2)
        mlp_in = torch.cat((ctx, flat_field), dim=-1)
        grid = self.mlp(mlp_in).reshape(batch, -1, height, width)

        return self.fno(grid).squeeze(-3)


class ContextFNO2d(nn.Module):
    """
    Bundled FNO model with context input.

    The context vector is projected by a linear layer to a ``width * height`` map
    and concatenated with the input as one additional channel. Channel mixing in
    the FNO's first layer is then what consumes the context.

    NOTE: because ``ctx_dec`` always produces exactly ``width * height`` values, the
    model cannot handle inputs whose spatial size differs from ``(width, height)``.
    """

    def __init__(self,
                 fno_config,
                 ctx_dim,
                 width=32,
                 height=32,
                 dt: float = 0.01,
                 residual: bool = False,
                 **unused_kwargs):
        """
        Args:
            fno_config: mapping of keyword arguments for ``FNO2d``. It is NOT modified;
                a copy with ``in_channels`` incremented by one (for the context
                channel) is used instead.
            ctx_dim: length of the context vector.
            width: first spatial size the context is projected/reshaped to.
            height: second spatial size the context is projected/reshaped to.
            dt: time step, stored and passed to ``ConsistencyDec``.
            residual: if True, ``forward`` adds the input ``us`` to the FNO output.
            **unused_kwargs: ignored.
        """
        super().__init__()
        self.ctx_dec = nn.Linear(ctx_dim, width * height)
        self.fno = FNO2d(**self._with_ctx_channel(fno_config))

        # Not used by forward(); kept so the module's submodules stay unchanged.
        self.dec = ConsistencyDec(dt)
        self.width = width
        self.height = height
        self.dt = dt
        self.residual = residual

    @staticmethod
    def _with_ctx_channel(fno_config):
        """Return a shallow copy of ``fno_config`` with one extra input channel for the context."""
        # Copy instead of ``+=`` in place: mutating the caller's config would add a
        # further channel every time the same config is reused.
        config = dict(fno_config)
        config['in_channels'] = config['in_channels'] + 1
        return config

    # Naming based on conventions in trainer
    def forward(self, ctx, us):
        """
        Args:
            ctx: context tensor fed to ``ctx_dec`` (last dim must be ``ctx_dim``).
            us: input tensor; the context map is concatenated along dim 1.

        Returns:
            The FNO output, plus ``us`` when ``residual`` is set.
        """
        # One (width, height) context map per sample, as a single extra channel.
        # NOTE(review): the map is laid out as (width, height); if ``us`` is
        # (..., h, w) with h != w this order may be transposed — confirm.
        ctx_map = self.ctx_dec(ctx).reshape(-1, 1, self.width, self.height)
        x = torch.cat([us, ctx_map], dim=1)  # concatenate in channel dimension
        out = self.fno(x)
        if self.residual:
            out = out + us
        return out


class HyperContextFNO2d(nn.Module):
    """
    Bundled FNO model with context input, where the FNO itself receives the context.

    Unlike a channel-concatenation design, the input is passed unchanged and the
    context vector is handed to the FNO as a separate argument,
    ``self.fno(us, ctx)``. How the context is used is up to the FNO class
    instantiated from ``fno_config``; presumably it conditions the FNO's weights
    (hypernetwork-style) — confirm against the configured target.
    """

    def __init__(self,
                 fno_config,
                 ctx_dim,
                 width=32,
                 height=32,
                 dt: float = 0.01,
                 residual: bool = False,
                 **unused_kwargs):
        """
        Args:
            fno_config: Hydra config instantiated (non-recursively) as the FNO. Its
                forward must accept ``(x, ctx)``.
            ctx_dim: accepted for config compatibility; not used by this class.
            width: stored; not used by ``forward``.
            height: stored; not used by ``forward``.
            dt: time step, stored and passed to ``ConsistencyDec``.
            residual: if True, ``forward`` adds the input to the FNO output.
            **unused_kwargs: ignored.
        """
        super().__init__()
        self.fno = hydra.utils.instantiate(fno_config, _recursive_=False)
        # Not used by forward(); kept so the module's submodules stay unchanged.
        self.dec = ConsistencyDec(dt)
        self.width = width
        self.height = height
        self.dt = dt
        self.residual = residual

    # Naming based on conventions in trainer
    def forward(self, ctx, us):
        """
        Args:
            ctx: context passed through to the FNO.
            us: input passed to the FNO unchanged.

        Returns:
            ``self.fno(us, ctx)``, plus ``us`` when ``residual`` is set.
        """
        x = us
        out = self.fno(x, ctx)
        if self.residual:
            out = out + x
        return out