from dataclasses import replace
from functools import reduce
from operator import mul
from typing import Callable, List, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from einops import rearrange
from scOT.model import ScOT, ScOTConfig
from the_well.data.datasets import BoundaryCondition

from crps_retrofitting.models.shared_utils.flexi_utils import (
    choose_kernel_size_deterministic,
    choose_kernel_size_random,
)
from crps_retrofitting.models.shared_utils.normalization import RMSGroupNorm
from crps_retrofitting.models.shared_utils.patch_jitterers import (
    FixedPatchJittererBoundaryPad,
    PatchJittererBoundaryPad,
)


def dim_pad(x, max_d):
    """
    Append trailing singleton axes so that ``x`` has up to ``max_d`` spatial dims.

    The first three axes of ``x`` are assumed to be T, B, C; everything after
    them counts as a spatial axis. At most two axes are appended, so an input
    that is more than two spatial axes short of ``max_d`` is only partially
    padded.

    Returns a ``(padded_x, n_added)`` tuple, where ``n_added`` is the number of
    axes appended (0, 1 or 2) so the caller can squeeze them off again later.
    """
    n_added = 0
    # Two attempts at most; stop as soon as the spatial rank reaches max_d.
    for _ in range(2):
        if x.ndim - 3 >= max_d:
            break
        x = x.unsqueeze(-1)
        n_added += 1
    return x, n_added


class ScOTWrapper(nn.Module):
    """Adapter that exposes a ScOT model through this project's model signature.

    The wrapper loads the ScOT configuration identified by ``pretrained_name``,
    overrides its input/output channel counts, and builds the backbone as
    ``self.sc_ot``. ``forward`` feeds only the last entry along the first axis
    of the input to the backbone, together with a fixed lead time ``delta_t``.

    Args:
        pretrained_name: Identifier or path passed to
            ``ScOTConfig.from_pretrained`` / ``ScOT.from_pretrained``.
        image_size: Stored on the instance; not used to configure the backbone
            here.
        num_channels: Number of input channels written into the config.
        num_out_channels: Number of output channels written into the config.
        in_timesteps: Accepted for signature compatibility; unused.
        out_timesteps: Accepted for signature compatibility; unused.
        n_states: Accepted for signature compatibility; unused.
        delta_t: Value passed to the backbone as ``time`` on every call.
        from_pretrained: When true, load pretrained weights. When false, build
            the same architecture from the (channel-adjusted) config with
            freshly initialised weights, e.g. to restore a checkpoint into it.
    """

    def __init__(
        self,
        pretrained_name,
        image_size,
        num_channels=5,
        num_out_channels=5,
        in_timesteps=1,
        out_timesteps=1,
        n_states=1,
        delta_t=0.05,
        from_pretrained=True,
    ):
        super().__init__()
        self.num_channels = num_channels
        self.num_out_channels = num_out_channels
        self.delta_t = delta_t
        self.image_size = image_size
        print(f"Delta t set to {self.delta_t}")
        self.causal_in_time = False
        self.pretrained_name = pretrained_name

        # The architecture always comes from the named configuration; only the
        # channel counts are replaced with the values requested by the caller.
        config = ScOTConfig.from_pretrained(self.pretrained_name)
        config.num_channels = self.num_channels
        config.num_out_channels = self.num_out_channels

        if from_pretrained:
            # ignore_mismatched_sizes: checkpoint tensors whose shapes no longer
            # match (because of the channel override above) are skipped instead
            # of aborting the load.
            self.sc_ot = ScOT.from_pretrained(
                self.pretrained_name, config=config, ignore_mismatched_sizes=True
            )
        else:
            # Previously this path left ``self.sc_ot`` undefined, so ``forward``
            # failed with AttributeError. Build the model from the config with
            # untrained weights instead.
            self.sc_ot = ScOT(config)

    def forward(
        self,
        x: torch.Tensor,
        state_labels,
        bcs,
        metadata,
        proj_axes=None,
        return_att=False,
        train=True,
    ) -> torch.Tensor:
        """Run the backbone on the last leading-axis entry of ``x``.

        Args:
            x: Input tensor; only ``x[-1]`` is used. Assumes the leading axis is
                time and the rest is what ScOT expects as ``pixel_values``
                -- TODO confirm against callers.
            state_labels, bcs, metadata, proj_axes, return_att, train:
                Accepted for signature compatibility; not used here.

        Returns:
            The backbone's ``.output`` with a new leading axis of size 1.
        """
        # Keep only the final entry along the first axis.
        last_frame = x[-1]

        # One-element lead-time tensor, created on the input's device.
        lead_time = torch.tensor([self.delta_t], device=last_frame.device)
        preds = self.sc_ot(pixel_values=last_frame, time=lead_time).output

        # Restore the leading axis dropped by the indexing above.
        return preds.unsqueeze(0)
