import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from . import models_vit
from .utils import safe_resize
from .RevIN import RevIN
import einops
from PIL import Image
from .swin_transformer import swin_t
from torchvision.models import Swin_T_Weights

# Registry of supported vision backbones, keyed by backbone name.
# Each value is a two-item list: [model constructor, pretrained-weight source].
# For the 'vit_*' entries the weight source is a checkpoint file name (a
# string); for 'swin_tiny' it is the torchvision ``Swin_T_Weights.DEFAULT``
# weights enum member.
ViT_MODELS = {
    'vit_base': [models_vit.vit_base_patch16, "vit_base.pth"],
    'vit_large': [models_vit.vit_large_patch16, "vit_large.pth"],
    'vit_huge': [models_vit.vit_huge_patch14, "vit_huge.pth"],
    'swin_tiny': [swin_t, Swin_T_Weights.DEFAULT],
}


class Flatten_Head(nn.Module):
    """Per-variable linear regression head over flattened backbone tokens.

    A shared ``nn.Linear(vm_dim, head_dim)`` first projects every token. Each
    variable then owns its own ``nn.Linear`` that maps the flattened
    ``patch_num * head_dim`` features to ``pred_len`` outputs, followed by
    dropout.

    Args:
        n_vars: Number of variables; one linear layer is created per variable.
        head_dim: Token width after the shared projection.
        pred_len: Number of output steps produced per variable.
        head_dropout: Dropout probability applied to each variable's output.
        vm_dim: Token width of the incoming features.
        patch_num: Number of tokens per variable.
    """

    def __init__(
        self,
        n_vars,
        head_dim,
        pred_len,
        head_dropout=0,
        vm_dim=768,
        patch_num=197
    ):
        super().__init__()

        self.n_vars = n_vars

        # Shared projection, applied on the last (feature) axis of the input.
        self.dimension_reduction = nn.Linear(vm_dim, head_dim)

        # Attribute names and ordering are kept so that existing state dicts
        # (keys ``linears.<i>.weight`` / ``linears.<i>.bias``) still load.
        self.linears = nn.ModuleList(
            [nn.Linear(head_dim * patch_num, pred_len) for _ in range(n_vars)]
        )
        self.dropouts = nn.ModuleList(
            [nn.Dropout(head_dropout) for _ in range(n_vars)]
        )
        self.flattens = nn.ModuleList(
            [nn.Flatten(start_dim=-2) for _ in range(n_vars)]
        )

    def forward(self, x):
        """Regress ``pred_len`` values for every variable.

        Args:
            x: Tensor of shape ``[bs, n_vars, patch_num, vm_dim]``.

        Returns:
            Tensor of shape ``[bs, pred_len, n_vars]``.

        Raises:
            ValueError: If ``x`` is not 4-D or its second axis is not
                ``n_vars``. Without this check, surplus variables would be
                silently ignored.
        """
        if x.dim() != 4 or x.shape[1] != self.n_vars:
            raise ValueError(
                f"Flatten_Head expects input of shape "
                f"[bs, {self.n_vars}, patch_num, vm_dim], got {tuple(x.shape)}"
            )

        x = self.dimension_reduction(x)  # [bs, n_vars, patch_num, head_dim]

        per_var = []
        for z, flatten, linear, dropout in zip(
            x.unbind(dim=1), self.flattens, self.linears, self.dropouts
        ):
            z = flatten(z)  # [bs, patch_num * head_dim]
            z = linear(z)  # [bs, pred_len]
            per_var.append(dropout(z))

        # Stacking on the last axis yields [bs, pred_len, n_vars] directly.
        return torch.stack(per_var, dim=-1)

class UTSRegression(nn.Module):
    """Time-series regressor that feeds the series to a vision backbone as images.

    The input series is normalized with ``RevIN``, folded into a 2-D image per
    variable according to ``config.periodicity``, resized to the backbone's
    image size, encoded by the backbone, regressed by a ``Flatten_Head`` and
    finally de-normalized.

    Args:
        config: Object providing the attributes read here: ``ckpt``,
            ``load_ckpt``, ``backbone`` (a key of ``ViT_MODELS``),
            ``num_features``, ``context_len``, ``periodicity``,
            ``interpolation`` (``"bilinear"``, ``"nearest"`` or ``"bicubic"``),
            ``head_dim``, ``pred_len``, ``head_dropout`` and ``vm_dim``.

    Raises:
        FileNotFoundError: If ``config.load_ckpt`` is set for a non-Swin
            backbone and the checkpoint file does not exist.
    """

    def __init__(self, config):
        super().__init__()
        self.ckpt = config.ckpt
        self.load_ckpt = config.load_ckpt
        self.vm = ViT_MODELS[config.backbone][0]()
        self.rev_in = RevIN(config.num_features)

        if self.load_ckpt:
            self._load_pretrained(config.backbone)

        # Left padding that makes the context length a multiple of the period,
        # so the series can be folded into whole periods in ``_time2img``.
        self.pad_left = 0
        if config.context_len % config.periodicity != 0:
            self.pad_left = config.periodicity - config.context_len % config.periodicity
        self.periodicity = config.periodicity

        self.interpolation = {
            "bilinear": Image.BILINEAR,
            "nearest": Image.NEAREST,
            "bicubic": Image.BICUBIC,
        }[config.interpolation]

        self.resize = safe_resize(self.vm.patch_embed.img_size, interpolation=self.interpolation)

        self.image_size = self.vm.patch_embed.img_size[0]
        self.patch_size = self.vm.patch_embed.patch_size[0]
        self.num_patch = self.image_size // self.patch_size

        # Only the token count seen by the head differs between backbones.
        if config.backbone == 'swin_tiny':
            patch_num = self.vm.window_size[0] * self.vm.window_size[1]
        else:
            # NOTE(review): the "+ 1" presumably accounts for a class token
            # emitted by the ViT backbone -- confirm against models_vit.
            patch_num = self.num_patch * self.num_patch + 1

        self.head = Flatten_Head(
            n_vars=config.num_features,
            head_dim=config.head_dim,
            pred_len=config.pred_len,
            head_dropout=config.head_dropout,
            vm_dim=config.vm_dim,
            patch_num=patch_num
        )

    def _load_pretrained(self, backbone):
        """Load pretrained weights for ``backbone`` into ``self.vm``.

        For ``'swin_tiny'`` the weights come from the registered weights enum
        and ``self.vm_transform`` is set (it is not set in any other case).
        For every other backbone the checkpoint file is read from
        ``self.ckpt``.
        """
        weight_source = ViT_MODELS[backbone][1]

        if backbone == 'swin_tiny':
            self.vm.load_state_dict(weight_source.get_state_dict(progress=True), strict=False)
            self.vm_transform = weight_source.transforms()
            return

        ckpt_path = os.path.join(self.ckpt, weight_source)
        if not os.path.isfile(ckpt_path):
            # A real exception instead of ``assert False``: asserts are
            # stripped under ``python -O``, which would hide the missing file.
            raise FileNotFoundError(f"Checkpoint file {ckpt_path} not found. ")
        try:
            # NOTE(security): torch.load unpickles the file; only load
            # checkpoints from a trusted source (consider weights_only=True
            # where the installed torch version supports it).
            checkpoint = torch.load(ckpt_path, map_location='cpu')
            self.vm.load_state_dict(checkpoint['model'], strict=False)
        except Exception:
            # Deliberately best-effort: report and continue with the
            # backbone's current weights. ``Exception`` (not a bare except)
            # lets KeyboardInterrupt/SystemExit propagate.
            print(f"Bad checkpoint file. Please delete {ckpt_path} and redownload!")

    def _vm_forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the backbone's feature extractor on images ``x`` of size (B, C, H, W)."""
        x = self.vm.forward_features(x)
        return x

    def _time2img(self, x: torch.Tensor) -> torch.Tensor:
        """Convert a series of size (b, l, n) into 3-channel images of size (b*n, 3, h, w)."""
        x = einops.rearrange(x, "b l n -> b n l")
        # Replicate the first value on the left so the length divides evenly
        # into periods.
        x_pad = F.pad(x, (self.pad_left, 0), mode='replicate')  # [b n l']
        # One image per (batch, variable): rows index the position within a
        # period, columns index the period.
        x_2d = einops.rearrange(x_pad, 'b n (p f) -> (b n) 1 f p', f=self.periodicity)
        x_resize = self.resize(x_2d)
        # Copy the single channel three times for the vision backbone.
        x_input = einops.repeat(x_resize, 'b 1 h w -> b c h w', c=3)
        return x_input

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Predict from a series ``x`` of size (b, l, n).

        Returns:
            The de-normalized output of the regression head, of size
            (b, pred_len, n).
        """
        B = x.shape[0]
        # first normalize
        x = self.rev_in(x, mode="norm")
        # then fold into images, resize to the backbone's image size, copy 3 channels
        x = self._time2img(x)  # (bn, c, h, w)
        # then forward x to the vision model
        x = self._vm_forward(x)  # (bn, image_patch, vm_dim)
        x = einops.rearrange(x, "(b n) p d -> b n p d", b=B)
        # finally use a linear layer to regress the output
        x = self.head(x)
        # denormalize x (b, l, n)
        x = self.rev_in(x, mode="denorm")
        return x