import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, reduce
from einops.layers.torch import Rearrange
from timm.models.layers import DropPath
from .utils import get_activation
from efficient_kan import KAN , KANLinear , ReLUKANLayer
# from fastkan import FastKAN as KAN
import timm
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from kan.MultKAN import MultKAN,KANLayer,KAN
import itertools
from random import randint


class Kan_Block(nn.Module):
    """Stack of ``KANLayer`` modules joined by layer-normalised skip connections.

    ``KAN_layers_num`` layers are chained; every layer but the last maps to
    ``hid_features`` and the last one maps to ``out_features``.  Only element
    ``[0]`` of each ``KANLayer`` return value is used.

    Args:
        in_features: Width of the last axis of the input.
        hid_features: Width produced by every layer except the last.
        out_features: Width produced by the last layer.
        KAN_layers_num: Number of ``KANLayer`` modules in the stack.  With 0
            layers ``forward`` returns its input unchanged.
        activ: Name handed to ``get_activation``; the resulting module is
            stored as ``self.act`` but is not used by ``forward``.
        drop: Rate handed to ``DropPath``; it is applied only after the
            intermediate (neither first nor last) layers.
    """

    def __init__(
        self,
        in_features: int,
        hid_features: int,
        out_features: int,
        KAN_layers_num: int,
        activ="gelu",
        drop: float = 0.0,
    ):
        super().__init__()

        self.KAN_layers_num = KAN_layers_num
        self.in_features = in_features
        self.hid_features = hid_features
        self.out_features = out_features

        self.KAN_Layers = nn.ModuleList()
        input_dim = in_features
        for idx in range(self.KAN_layers_num):
            # Only the final layer projects to the output width.
            if idx == self.KAN_layers_num - 1:
                output_dim = out_features
            else:
                output_dim = hid_features
            self.KAN_Layers.append(
                KANLayer(in_dim=input_dim, out_dim=output_dim,
                         sp_trainable=True, sb_trainable=True))
            input_dim = output_dim

        # Linear skip paths used where the layer changes the feature width.
        self.jump_net_1 = nn.Linear(in_features, self.hid_features)
        self.jump_net_2 = nn.Linear(self.hid_features, out_features)
        self.act = get_activation(activ)
        self.drop = DropPath(drop)
        self.norm = nn.LayerNorm(self.hid_features)
        self.norm_out = nn.LayerNorm(self.out_features)

    def forward(self, x):
        """Run the stack over the last axis of ``x``.

        Inputs with more than two axes are flattened to
        ``(prod(leading axes), features)`` before the layers run, and the
        leading axes are restored on the result.

        Returns:
            Tensor with the same leading axes as ``x``; the last axis has
            width ``out_features`` when at least one layer exists.
        """
        sp = x.shape
        if len(sp) > 2:
            # The layers are fed 2-D input: collapse every leading axis.
            x = x.reshape(-1, sp[-1])

        last = self.KAN_layers_num - 1
        for idx in range(self.KAN_layers_num):
            kan_out = self.KAN_Layers[idx](x)[0]
            if idx == 0 and idx == last and self.hid_features != self.out_features:
                # Single-layer stack: the layer already emits out_features,
                # so the hidden-width norm/skip cannot be used.  Normalise
                # with norm_out and chain both projections for the skip.
                x = self.norm_out(kan_out) + self.jump_net_2(self.jump_net_1(x))
            elif idx == 0:
                x = self.norm(kan_out) + self.jump_net_1(x)
            elif idx == last:
                x = self.norm_out(kan_out) + self.jump_net_2(x)
            else:
                x = self.norm(kan_out) + x
                x = self.drop(x)

        if len(sp) > 2:
            x = x.reshape(*sp[:-1], -1)

        return x


class MLPBlock(nn.Module):
    """Residual mixing block applied along a single axis of the input.

    ``forward`` swaps axis ``dim`` with the last axis, adds a skip path to the
    output of ``self.net`` and swaps the axes back.

    Args:
        dim: Index of the axis that is mixed.
        in_features: Size of that axis on input.
        hid_features: Hidden width; when it is 0 ``forward`` is the identity.
        out_features: Size of that axis on output.
        activ: Activation name for the plain MLP branch.
        drop: ``DropPath`` rate for the plain MLP branch.
        jump_conn: ``"trunc"`` uses an identity skip, ``"proj"`` a linear
            projection; any other value raises ``ValueError``.
        use_kan: When equal to 1 the body is a ``Kan_Block`` instead of the
            two-layer MLP.
        KAN_layers_num: Layer count forwarded to ``Kan_Block``.
    """

    def __init__(
        self,
        dim,
        in_features: int,
        hid_features: int,
        out_features: int,
        activ="gelu",
        drop: float = 0.0,
        jump_conn='trunc',
        use_kan = 0 ,
        KAN_layers_num = 0 ,
    ):
        super().__init__()

        self.dim = dim
        self.in_features = in_features
        self.hid_features = hid_features
        self.out_features = out_features

        if use_kan != 1:
            # Plain two-layer perceptron followed by stochastic depth.
            self.net = nn.Sequential(
                nn.Linear(in_features, hid_features),
                get_activation(activ),
                nn.Linear(hid_features, out_features),
                DropPath(drop),
            )
        else:
            # NOTE(review): the KAN branch ignores the caller's `activ` and
            # `drop` and always uses 'gelu' / 0.1 -- confirm this is intended.
            self.net = Kan_Block(
                in_features=in_features,
                hid_features=hid_features,
                out_features=out_features,
                KAN_layers_num=KAN_layers_num,
                activ='gelu',
                drop=0.1,
            )

        if jump_conn not in ("trunc", "proj"):
            raise ValueError(f"jump_conn:{jump_conn}")
        if jump_conn == "trunc":
            self.jump_net = nn.Identity()
        else:
            self.jump_net = nn.Linear(in_features, out_features)

    def forward(self, x):
        """Mix ``x`` along ``self.dim`` and return it in the original layout."""
        if self.hid_features == 0:
            return x

        swapped = torch.transpose(x, self.dim, -1)
        # The skip path is cut to at most out_features entries on the last axis.
        skip = self.jump_net(swapped)[..., :self.out_features]
        mixed = skip + self.net(swapped)
        return torch.transpose(mixed, self.dim, -1)


class PatchEncoder(nn.Module):
    """Split a series into patches, mix it, and collapse each patch to one value.

    ``self.net`` rearranges ``b c (l1 l2) -> b c l1 l2`` with
    ``l2 = patch_size`` and then applies, in order, a channel-wise, an
    inter-patch and an intra-patch ``MLPBlock`` (all KAN based), each preceded
    by the selected normalisation.  ``self.linear`` is a ``KANLayer`` mapping
    ``patch_size`` values to 1.

    Args:
        in_len: Series length; ``in_len // patch_size`` is the patch count.
        hid_len: Hidden width of the inter-patch block.
        in_chn: Input channel count.
        hid_chn: Hidden width of the channel-wise block.
        out_chn: Channel count after the channel-wise block.
        patch_size: Length of each patch.
        hid_pch: Hidden width of the intra-patch block.
        KAN_layers_num: Layer count forwarded to every ``MLPBlock``.
        norm: ``'bn'`` -> ``BatchNorm2d``, ``'in'`` -> ``InstanceNorm2d``,
            anything else -> ``Identity``.
        activ: Activation name forwarded to every ``MLPBlock``.
        drop: Drop rate forwarded to every ``MLPBlock``.
    """

    def __init__(
        self,
        in_len: int,
        hid_len: int,
        in_chn: int,
        hid_chn: int,
        out_chn,
        patch_size: int,
        hid_pch: int,
        KAN_layers_num: int ,
        norm=None,
        activ="gelu",
        drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.net = nn.Sequential()
        self.patch_size = patch_size
        n_patches = in_len // patch_size

        mix_channels = MLPBlock(1, in_chn, hid_chn, out_chn, activ, drop,
                                use_kan=1, KAN_layers_num=KAN_layers_num)
        mix_patches = MLPBlock(2, n_patches, hid_len, n_patches, activ, drop,
                               use_kan=1, KAN_layers_num=KAN_layers_num)

        norm_class = (nn.BatchNorm2d if norm == 'bn'
                      else nn.InstanceNorm2d if norm == 'in'
                      else nn.Identity)

        self.linear = KANLayer(in_dim=patch_size, out_dim=1,
                               sp_trainable=True, sb_trainable=True)
        mix_within = MLPBlock(3, patch_size, hid_pch, patch_size, activ, drop,
                              use_kan=1, KAN_layers_num=KAN_layers_num)

        stages = (
            Rearrange("b c (l1 l2) -> b c l1 l2", l2=patch_size),
            norm_class(in_chn),
            mix_channels,
            norm_class(out_chn),
            mix_patches,
            norm_class(out_chn),
            mix_within,
        )
        for stage in stages:
            self.net.append(stage)

    def forward(self, x):
        """Encode ``x`` and return a 3-D tensor with one value per patch."""
        feats = self.net(x)
        batch, chn, _, _ = feats.shape
        # Every patch becomes one row so the KAN layer can reduce it to a scalar.
        rows = feats.reshape(-1, self.patch_size)
        collapsed = self.linear(rows)[0]
        return collapsed.reshape(batch, chn, -1)


class PatchDecoder(nn.Module):
    """Expand one value per patch back into patches and mix the result.

    ``self.linear`` is a ``KANLayer`` mapping 1 value to ``patch_size`` values.
    ``self.net`` then applies, in order, an intra-patch, an inter-patch and a
    channel-wise ``MLPBlock`` (all KAN based), each preceded by the selected
    normalisation built with ``in_chn``.

    Args:
        in_len: Series length; ``in_len // patch_size`` is the patch count.
        hid_len: Hidden width of the inter-patch block.
        in_chn: Input channel count.
        hid_chn: Hidden width of the channel-wise block.
        out_chn: Channel count after the channel-wise block.
        patch_size: Length of each patch.
        hid_pch: Hidden width of the intra-patch block.
        KAN_layers_num: Layer count forwarded to every ``MLPBlock``.
        norm: ``'bn'`` -> ``BatchNorm2d``, ``'in'`` -> ``InstanceNorm2d``,
            anything else -> ``Identity``.
        activ: Activation name forwarded to every ``MLPBlock``.
        drop: Drop rate forwarded to every ``MLPBlock``.
    """

    def __init__(
        self,
        in_len: int,
        hid_len: int,
        in_chn: int,
        hid_chn: int,
        out_chn,
        patch_size: int,
        hid_pch: int,
        KAN_layers_num: int ,
        norm=None,
        activ="gelu",
        drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.net = nn.Sequential()
        self.patch_size = patch_size
        n_patches = in_len // patch_size

        mix_patches = MLPBlock(2, n_patches, hid_len, n_patches, activ, drop,
                               use_kan=1, KAN_layers_num=KAN_layers_num)
        mix_channels = MLPBlock(1, in_chn, hid_chn, out_chn, activ, drop,
                                use_kan=1, KAN_layers_num=KAN_layers_num)

        norm_class = (nn.BatchNorm2d if norm == 'bn'
                      else nn.InstanceNorm2d if norm == 'in'
                      else nn.Identity)

        self.linear = KANLayer(in_dim=1, out_dim=patch_size,
                               sp_trainable=True, sb_trainable=True)

        mix_within = MLPBlock(3, patch_size, hid_pch, patch_size, activ, drop,
                              use_kan=1, KAN_layers_num=KAN_layers_num)

        # Mixing order is the reverse of the encoder: patch, sequence, channel.
        for mixer in (mix_within, mix_patches, mix_channels):
            self.net.append(norm_class(in_chn))
            self.net.append(mixer)

    def forward(self, x):
        """Decode a 3-D tensor ``x`` and return a 3-D tensor with patches merged."""
        batch, chn, n_patch = x.shape
        # One scalar per row; the KAN layer expands each to a full patch.
        scalars = x.reshape(batch * chn * n_patch, 1)
        patches = self.linear(scalars)[0]

        patches = patches.reshape(batch, chn, n_patch, -1)
        mixed = self.net(patches)
        return mixed.reshape(batch, chn, -1)


class PredictionHead(nn.Module):
    """Two chained ``MLPBlock`` modules: one over axis 1, then one over axis 2.

    The first block maps ``in_chn`` to ``out_chn`` along axis 1, with a
    projected skip when the two differ and an identity skip otherwise.  The
    second maps ``in_len`` to ``out_len`` along axis 2 with a projected skip.

    Args:
        in_len: Input size along axis 2.
        out_len: Output size along axis 2.
        hid_len: Hidden width used by both blocks.
        in_chn: Input size along axis 1.
        out_chn: Output size along axis 1.
        hid_chn: Accepted but not used.
        activ: Activation name forwarded to both blocks.
        drop: Drop rate forwarded to both blocks.
    """

    def __init__(self,
                 in_len,
                 out_len,
                 hid_len,
                 in_chn,
                 out_chn,
                 hid_chn,
                 activ,
                 drop=0.0) -> None:
        super().__init__()
        self.net = nn.Sequential()

        c_jump_conn = "trunc" if in_chn == out_chn else "proj"

        # NOTE(review): the channel block uses hid_len as its hidden width and
        # hid_chn is never read -- confirm this is intended.
        channel_block = MLPBlock(1, in_chn, hid_len, out_chn,
                                 activ=activ, drop=drop, jump_conn=c_jump_conn)
        self.net.append(channel_block)

        length_block = MLPBlock(2, in_len, hid_len, out_len,
                                activ=activ, drop=drop, jump_conn='proj')
        self.net.append(length_block)

    def forward(self, x):
        """Apply both blocks to ``x`` in sequence."""
        out = self.net(x)
        return out


class TimeKAN(nn.Module):
    """Multi-level, multi-branch patch encoder/decoder model with forecast heads.

    The model has ``Patch_num`` levels.  At each level, ``MST_num`` branches
    each encode the current series with their own patch size and decode it
    back.  The decoded components are summed and subtracted from the series
    before the next level.  The branch embeddings are concatenated and passed
    to that level's forecast module, and level forecasts are combined with
    ``reduction``.

    Branch 0 uses ``patch_sizes[:Patch_num]``; every other branch uses the
    first ``Patch_num`` entries of a randomly drawn permutation of
    ``patch_sizes`` (drawn with the global ``random`` generator).

    Args:
        in_len: Length of the input series.
        out_len: Forecast length; 0 disables the forecast heads.
        in_chn: Number of series channels.
        ex_chn: Number of extra channels; encoders are built for
            ``in_chn + ex_chn`` input channels.
        out_chn: Number of forecast channels; 0 disables the forecast heads.
        patch_sizes: Sequence of candidate patch sizes.
        hid_len: Hidden width forwarded to encoders/decoders.
        hid_chn: Hidden channel width forwarded to encoders/decoders/heads.
        hid_pch: Hidden patch width forwarded to encoders/decoders.
        hid_pred: Hidden width forwarded to ``PredictionHead``.
        norm: Normalisation selector forwarded to encoders/decoders.
        last_norm: When truthy, the last time step is subtracted from the
            input (and from ``x_mark``) before processing.
        activ: Activation name forwarded to sub-modules.
        drop: Drop rate forwarded to sub-modules.
        Patch_num: Number of levels.
        MST_num: Number of branches per level.
        KAN_layers_num: Layer count forwarded to encoders/decoders.
        reduction: einops reduction used to merge level forecasts.
    """

    def __init__(self,
                 in_len,
                 out_len,
                 in_chn,
                 ex_chn,
                 out_chn,
                 patch_sizes,
                 hid_len,
                 hid_chn,
                 hid_pch,
                 hid_pred,
                 norm,
                 last_norm,
                 activ,
                 drop,
                 Patch_num,
                 MST_num,
                 KAN_layers_num ,
                 reduction="sum" ,
                 ) -> None:
        super().__init__()
        self.in_len = in_len
        self.out_len = out_len
        self.in_chn = in_chn
        self.out_chn = out_chn
        self.last_norm = last_norm
        self.reduction = reduction
        self.patch_encoders = nn.ModuleList()
        self.patch_decoders = nn.ModuleList()
        self.pred_heads = nn.ModuleList()
        self.Patch_num = Patch_num

        self.MST_num = MST_num
        self.MST_paddings = []
        self.MST_encoder = nn.ModuleList()
        self.MST_decoder = nn.ModuleList()
        self.MST_Forecast = nn.ModuleList()

        self.patch_sizes = patch_sizes
        self.MST_patch_sizes = self._draw_patch_orders(
            self.patch_sizes, Patch_num, self.MST_num)

        for _ in range(self.MST_num):
            self.MST_encoder.append(nn.ModuleList())
            self.MST_decoder.append(nn.ModuleList())
            self.MST_paddings.append([])
        # Per level: total number of patches summed over all branches.
        self.MST_Forecast_len = [0] * self.Patch_num
        all_chn = in_chn + ex_chn

        for level in range(self.Patch_num):
            for branch in range(self.MST_num):
                patch_size = self.MST_patch_sizes[branch][level]
                # Pad so the padded length is a multiple of patch_size.
                padding = (patch_size - in_len % patch_size) % patch_size
                self.MST_paddings[branch].append(padding)
                padded_len = in_len + padding

                self.MST_encoder[branch].append(
                    PatchEncoder(padded_len, hid_len, all_chn, hid_chn,
                                 in_chn, patch_size, hid_pch, KAN_layers_num,
                                 norm, activ, drop))
                self.MST_decoder[branch].append(
                    PatchDecoder(padded_len, hid_len, in_chn, hid_chn, in_chn,
                                 patch_size, hid_pch, KAN_layers_num,
                                 norm, activ, drop))
                self.MST_Forecast_len[level] += padded_len // patch_size

            if out_len != 0 and out_chn != 0:
                self.MST_Forecast.append(
                    PredictionHead(self.MST_Forecast_len[level], out_len,
                                   hid_pred, in_chn, out_chn, hid_chn,
                                   activ, drop))
            else:
                self.MST_Forecast.append(nn.Identity())

    @staticmethod
    def _draw_patch_orders(patch_sizes, patch_num, mst_num):
        """Return the per-branch patch-size orderings (branch 0 is unshuffled)."""
        orders = [patch_sizes[:patch_num]]
        if mst_num > 1:
            all_permutations = list(itertools.permutations(patch_sizes))
            # randint is inclusive on both ends, so the upper bound must be
            # the last valid index rather than the list length.
            last_index = len(all_permutations) - 1
            for _ in range(mst_num - 1):
                orders.append(
                    all_permutations[randint(0, last_index)][:patch_num])
        return orders

    def forward(self, x, x_mark=None, x_mask=None):
        """Decompose ``x`` level by level and optionally forecast.

        Args:
            x: Tensor laid out as ``(b, l, c)``.
            x_mark: Optional ``(b, l, c)`` tensor concatenated to ``x`` along
                the channel axis before encoding.
            x_mask: Optional ``(b, l, c)`` tensor multiplied into the summed
                decoded component at every level.

        Returns:
            ``(y_pred, residual)``.  ``y_pred`` is laid out as ``(b, l, c)``,
            or is ``None`` when ``out_len`` or ``out_chn`` is 0.  ``residual``
            is what remains of ``x`` after all levels, laid out as
            ``(b, c, l)``.
        """
        x = rearrange(x, "b l c -> b c l")
        if x_mark is not None:
            x_mark = rearrange(x_mark, "b l c -> b c l")
        if x_mask is not None:
            x_mask = rearrange(x_mask, "b l c -> b c l")
        if self.last_norm:
            # Centre the series on its final time step.
            x_last = x[:, :, [-1]].detach()
            x = x - x_last
            if x_mark is not None:
                x_mark_last = x_mark[:, :, [-1]].detach()
                x_mark = x_mark - x_mark_last

        has_head = self.out_len != 0 and self.out_chn != 0
        y_pred = []
        for i in range(self.Patch_num):
            # The encoder input depends only on the level, not on the branch.
            x_in = x if x_mark is None else torch.cat((x, x_mark), 1)

            Emb = []
            Comp = 0
            for branch in range(self.MST_num):
                padding = self.MST_paddings[branch][i]
                padded = F.pad(x_in, (padding, 0), "constant", 0)

                emb = self.MST_encoder[branch][i](padded)
                # Drop the left padding again after decoding.
                comp = self.MST_decoder[branch][i](emb)[:, :, padding:]
                Emb.append(emb)
                Comp += comp

            Eb = torch.cat(Emb, dim=2)
            pred = self.MST_Forecast[i](Eb)

            if x_mask is not None:
                Comp = Comp * x_mask
            x = x - Comp
            if has_head:
                y_pred.append(pred)

        if has_head:
            y_pred = reduce(torch.stack(y_pred, 0), "h b c l -> b c l",
                            self.reduction)
            if self.last_norm and self.out_chn == self.in_chn:
                # Undo the last-step centring applied to the input.
                y_pred += x_last
            y_pred = rearrange(y_pred, "b c l -> b l c")
            return y_pred, x
        else:
            return None, x
