# Names exported by `from <this module> import *`.
# NOTE(review): `PatchTST_backbone` is not defined in the visible part of this
# file (the classes here are Model, TSTiEncoder, TSTEncoder and
# CrossTimeFeatLayer) — confirm this export list is still accurate.
__all__ = ['PatchTST_backbone']

# Cell
from typing import Callable, Optional
import torch, time
from torch import nn
from torch import Tensor
import torch.nn.functional as F
import numpy as np
from einops import rearrange, repeat
#from collections import OrderedDict
from models.utils.PatchTST_layers import *
from models.utils.RevIN import RevIN
import math
 
# Shared value for the `bias=` argument of the nn.Linear layers built in this file.
global_bias = False

class Model(nn.Module):
    """Patch-based forecaster that processes channels in blocks of ``split_num``.

    The input series is normalised with RevIN, cut into patches along the time
    axis, encoded by ``TSTiEncoder`` and mapped to ``pred_len`` steps by a
    linear head.  In training mode the caller supplies the channel subset
    (``idx``); in eval mode the model repeatedly draws random channel
    partitions itself and averages the resulting predictions.

    Args:
        configs: object exposing ``enc_in``, ``seq_len``, ``pred_len``,
            ``patch_len``, ``stride``, ``e_layers``, ``n_heads``, ``d_model``,
            ``d_ff``, ``dropout``, ``adaptive_dilated_atten``,
            ``no_dilated_atten``, ``multi_query``, ``comp_dim``, ``rep_num``
            and ``split_num``.
    """

    def __init__(self, configs):
        super().__init__()

        c_in = configs.enc_in
        context_window = configs.seq_len
        target_window = configs.pred_len
        patch_len = configs.patch_len
        stride = configs.stride
        padding_patch = "end"

        n_layers = configs.e_layers
        n_heads = configs.n_heads
        d_model = configs.d_model
        d_ff = configs.d_ff
        dropout = configs.dropout
        self.adaptive_dilated_atten = configs.adaptive_dilated_atten
        self.no_dilated_atten = configs.no_dilated_atten
        self.multi_query = configs.multi_query
        self.comp_dim = configs.comp_dim
        self.rep_num = configs.rep_num

        self.revin = True
        self.split_num = configs.split_num

        # RevIn
        if self.revin:
            self.revin_layer = RevIN(c_in)

        # Patching
        self.patch_len = patch_len
        self.stride = stride
        self.padding_patch = padding_patch  # stored only; no padding is applied in this class
        patch_num = int((context_window - patch_len)/stride + 1)

        # Per-layer (block_num, block_size) used to split the patch axis for
        # the two-pass time attention.  Both schedules assume patch_num is a
        # power of two.
        if self.adaptive_dilated_atten:
            stan = (np.log2(patch_num) // 2)
            if n_layers % 2 == 1:
                start_block_num = int(2 ** (stan - (n_layers // 2)))
            else:
                start_block_num = int(2 ** (stan - ((n_layers - 1) // 2)))

            # The block count doubles (and the block size halves) at every layer.
            dilated_attn = [(start_block_num * (2 ** i), (patch_num // start_block_num) // (2 ** i)) for i in range(n_layers)]
        else:
            block_num = int(2 ** (np.log2(patch_num) // 2))
            block_size = int(patch_num // block_num)
            dilated_attn = [(block_num, block_size) for i in range(n_layers)]
        if self.no_dilated_atten:
            # None tells each layer to attend over the full patch axis.
            dilated_attn = [None] * n_layers

        print(dilated_attn)

        # Backbone
        self.backbone = TSTiEncoder(c_in, dilated_attn, patch_num=patch_num, patch_len=patch_len, n_layers=n_layers, d_model=d_model,
                                     n_heads=n_heads, d_ff=d_ff, dropout=dropout, attn_dropout=dropout, split_num=self.split_num, act='gelu', multi_query=self.multi_query)

        # Head: optionally compress d_model -> comp_dim per patch, then map the
        # flattened patch features to the forecast horizon.
        self.n_vars = c_in
        if self.comp_dim == d_model:
            self.comp_layer = None
            self.head_nf = d_model * patch_num
        else:
            self.comp_layer = nn.Linear(d_model, self.comp_dim, bias=global_bias)
            self.head_nf = self.comp_dim * patch_num
        self.head = nn.Linear(self.head_nf, target_window, bias=global_bias)

    def forward(self, z, idx=None, rep_num=None):
        """Run the model.

        Args:
            z: input of shape [bs x seq_len x nvars].
            idx: channel indices forwarded to ``inner_forward`` in training
                mode; ignored in eval mode.
            rep_num: number of random channel partitions averaged in eval
                mode; defaults to ``self.rep_num``.  Ignored in training mode.

        Returns:
            Tuple ``(prediction, backbone_time)``.  In eval mode both entries
            are averages over the ``rep_num`` repetitions and the prediction
            channels are restored to their original order.
        """
        if self.training:
            return self.inner_forward(z, idx)

        batch_size = z.shape[0]
        rep_num = self.rep_num if rep_num is None else rep_num
        z_lst, time_lst = [], []
        for _ in range(rep_num):
            main_idx2, real_idx2, _addi_num = self.feature_block_sampling(batch_size)
            # Row indices paired with the per-sample channel indices.
            main_idx1 = torch.arange(batch_size).unsqueeze(-1).repeat(1, main_idx2.shape[-1])
            real_idx1 = torch.arange(batch_size).unsqueeze(-1).repeat(1, real_idx2.shape[-1])
            # Gather the shuffled (and possibly padded) channels per sample.
            newz = z.transpose(1, 2)[main_idx1, main_idx2].transpose(1, 2)
            ret_z, elapsed = self.inner_forward(newz, main_idx2)
            # Drop the padding channels, then undo the permutation: the
            # argsort of a permutation is its inverse.
            inverse_perm = real_idx2.sort(-1)[1]
            ret_z = ret_z[..., :self.n_vars].transpose(1, 2)[real_idx1, inverse_perm].transpose(1, 2)
            z_lst.append(ret_z)
            time_lst.append(elapsed)
        return sum(z_lst) / rep_num, sum(time_lst) / rep_num

    def feature_block_sampling(self, batch_size):
        """Draw a random channel order per sample, padded to a multiple of ``split_num``.

        Args:
            batch_size: number of independent permutations to draw.

        Returns:
            Tuple ``(main_idx, rand_idx1, addi_num)``: ``rand_idx1`` is a
            [batch_size x n_vars] tensor of per-sample permutations,
            ``main_idx`` is ``rand_idx1`` followed by ``addi_num`` extra
            channel indices, and ``addi_num`` is the number of padding
            channels (0 when ``n_vars`` is already a multiple of
            ``split_num``).

        Raises:
            ValueError: if padding is required but ``n_vars < split_num``, so
                no complete block exists to draw padding channels from.
        """
        rest_num = self.n_vars % self.split_num
        # The outer modulo makes this 0 (not split_num) when n_vars divides evenly.
        addi_num = (self.split_num - rest_num) % self.split_num
        rand_idx1 = torch.stack([torch.randperm(self.n_vars) for _ in range(batch_size)], dim=0)
        if addi_num == 0:
            return rand_idx1, rand_idx1, addi_num

        pool_size = self.n_vars - rest_num
        if pool_size < addi_num:
            raise ValueError(
                f"n_vars ({self.n_vars}) must be at least split_num ({self.split_num}) "
                "to pad the last channel block"
            )
        # Padding channels are drawn without replacement from indices [0, pool_size).
        rand_idx2 = torch.stack([torch.randperm(pool_size)[:addi_num] for _ in range(batch_size)], dim=0)
        main_idx = torch.cat([rand_idx1, rand_idx2], dim=-1)
        return main_idx, rand_idx1, addi_num

    def inner_forward(self, z, idx=None):                       # z: [bs x seq_len x nvars]
        """Normalise, patch, encode and project one channel arrangement.

        Args:
            z: input of shape [bs x seq_len x nvars].
            idx: channel indices passed through to the backbone.

        Returns:
            Tuple ``(prediction, elapsed)`` where ``prediction`` is
            [bs x target_window x nvars] and ``elapsed`` is a one-element
            tensor (same dtype/device as the prediction) holding the
            wall-clock seconds measured around the backbone call.
        """
        if self.revin:
            z = self.revin_layer(z, 'norm')
        z = z.permute(0, 2, 1)                                  # z: [bs x nvars x seq_len]

        # do patching
        z = z.unfold(dimension=-1, size=self.patch_len, step=self.stride)   # z: [bs x nvars x patch_num x patch_len]

        # model
        start_time = time.time()
        z = self.backbone(z, idx, self.split_num)               # z: [bs x nvars x patch_num x d_model]
        end_time = time.time()

        if self.comp_layer is not None:
            z = self.comp_layer(z)                              # z: [bs x nvars x patch_num x comp_dim]
        z = self.head(z.flatten(-2))                            # z: [bs x nvars x target_window]

        # denorm
        z = z.permute(0, 2, 1)                                  # z: [bs x target_window x nvars]
        if self.revin:
            z = self.revin_layer(z, 'denorm')
        return z, torch.tensor([end_time - start_time]).to(z)
    
class TSTiEncoder(nn.Module):
    """Patch embedding plus positional terms in front of a ``TSTEncoder`` stack.

    Each patch is linearly projected from ``patch_len`` to ``d_model``; a
    per-patch positional term (``W_pos``) and a per-channel positional term
    (``C_pos``, selected by ``idx``) are added before dropout and encoding.
    Both positional tensors come from ``positional_encoding`` — presumably
    shaped [patch_num x d_model] and [c_in x d_model]; confirm against that
    helper.
    """

    def __init__(self, c_in, dilated_attn, patch_num, patch_len, n_layers=3, d_model=128, n_heads=16,
                 d_k=None, d_v=None, d_ff=256, norm='BatchNorm', attn_dropout=0., dropout=0., act="gelu",
                 store_attn=False, res_attention=True, pre_norm=False, pe='zeros', learn_pe=True, split_num = 0, multi_query = True):
        super().__init__()

        self.patch_num = patch_num
        self.patch_len = patch_len
        # The attention sequence length is the number of patches.
        self.seq_len = patch_num

        # Input encoding: patch_len -> d_model
        self.W_P = nn.Linear(patch_len, d_model, bias=global_bias)

        # Positional encoding over patches and over channels
        self.W_pos = positional_encoding(pe, learn_pe, patch_num, d_model)
        self.C_pos = positional_encoding(pe, learn_pe, c_in, d_model)
        self.dropout = nn.Dropout(dropout)

        # Encoder stack
        self.encoder = TSTEncoder(
            dilated_attn, patch_num, c_in, d_model, n_heads,
            d_k=d_k, d_v=d_v, d_ff=d_ff, norm=norm,
            attn_dropout=attn_dropout, dropout=dropout,
            pre_norm=pre_norm, activation=act, res_attention=res_attention,
            n_layers=n_layers, store_attn=store_attn,
            split_num=split_num, multi_query=multi_query,
        )

    def forward(self, x, idx = None, split_num = None) -> Tensor:
        """Embed patches, add positional terms and run the encoder.

        Args:
            x: patches of shape [bs x nvars x patch_num x patch_len].
            idx: indices into ``C_pos`` selecting the channel term for each
                channel of ``x``.
            split_num: forwarded unchanged to the encoder.

        Returns:
            The encoder output; with the layers in this file it keeps the
            shape [bs x nvars x patch_num x d_model].
        """
        # One extra axis so the channel term broadcasts over the patch axis.
        channel_pe = self.C_pos[idx].unsqueeze(-2)
        embedded = self.W_P(x)                          # [bs x nvars x patch_num x d_model]
        embedded = embedded + self.W_pos + channel_pe
        return self.encoder(self.dropout(embedded), split_num, idx)
            
            
    
# Cell
class TSTEncoder(nn.Module):
    """Stack of ``CrossTimeFeatLayer`` blocks applied one after another.

    Args:
        dilated_attn: per-layer time-attention block spec; entry ``i`` is
            handed to layer ``i``.
        c_in, d_model, n_heads, d_k, d_v, d_ff, attn_dropout, dropout,
        activation, multi_query: forwarded to every ``CrossTimeFeatLayer``.
        n_layers: number of layers to build.
        split_num: default channel block size used by ``forward`` when the
            caller does not pass one.
        q_len, norm, res_attention, pre_norm, store_attn: accepted but not
            used by this class.
    """

    def __init__(self, dilated_attn, q_len, c_in, d_model, n_heads, d_k=None, d_v=None, d_ff=None,
                        norm='BatchNorm', attn_dropout=0., dropout=0., activation='gelu',
                        res_attention=False, n_layers=1, pre_norm=False, store_attn=False, split_num = 0, multi_query = True):
        super().__init__()

        self.layers1 = nn.ModuleList([CrossTimeFeatLayer(c_in, d_model, n_heads=n_heads, d_k=d_k, d_v=d_v, d_ff=d_ff,
                                                      attn_dropout=attn_dropout, dropout=dropout,
                                                      activation=activation, dilated_attn = dilated_attn[i], multi_query = multi_query) for i in range(n_layers)])

        self.n_layers = n_layers
        self.split_num = split_num

    def forward(self, src:Tensor, split_num = None, idx=None, key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None):
        """Pass ``src`` through every layer in order.

        Args:
            src: input tensor handed to the first layer.
            split_num: channel block size for the layers; falls back to the
                value given at construction when ``None``.
            idx: channel indices forwarded to every layer.
            key_padding_mask, attn_mask: accepted but not used.

        Returns:
            The output tensor of the last layer.
        """
        # Honour an explicit per-call value instead of silently ignoring it.
        if split_num is None:
            split_num = self.split_num

        output = src
        attn_rank = None  # threaded from one layer to the next
        for layer in self.layers1:
            output, attn_rank = layer(output, split_num, idx, attn_rank)
        return output


class CrossTimeFeatLayer(nn.Module):
    """Encoder layer: attention over patches, then attention across channels, then an FFN.

    Time attention runs over the patch axis, either globally
    (``dilated_attn is None``) or in two passes: first inside each of
    ``dilated_attn[0]`` contiguous blocks, then across blocks at a fixed
    in-block position.  Channel attention runs inside contiguous groups of
    ``split_num`` channels and is skipped when ``split_num == 1``.

    Args:
        c_in: total number of channels; sizes ``feat_center``.
        d_model: model width; must be divisible by ``n_heads``.
        n_heads: number of query heads.
        d_k, d_v: per-head key/value widths; default ``d_model // n_heads``.
        d_ff: hidden width of the feed-forward block.
        attn_dropout: dropout on attention weights.
        dropout: dropout on residual branches and inside the FFN.
        bias: accepted but not used (linear layers use ``global_bias``).
        activation: name passed to ``get_activation_fn``.
        dilated_attn: ``(block_num, block_size)`` or ``None``; only
            ``block_num`` is read here.
        multi_query: if true, keys and values use a single head shared by all
            query heads.
    """

    def __init__(self, c_in, d_model, n_heads, d_k=None, d_v=None, d_ff=256, attn_dropout=0, dropout=0., bias=True, activation="gelu", dilated_attn = None, multi_query = True):
        super().__init__()
        assert not d_model%n_heads, f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
        d_k = d_model // n_heads if d_k is None else d_k
        d_v = d_model // n_heads if d_v is None else d_v

        self.multi_query = multi_query
        self.n_heads, self.d_k, self.d_v = n_heads, d_k, d_v

        # Queries always get n_heads heads; keys/values get one shared head in
        # multi-query mode and n_heads heads otherwise.
        kv_heads = 1 if self.multi_query else n_heads
        self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=global_bias)
        self.W_Q_feat = nn.Linear(d_model, d_k * n_heads, bias=global_bias)
        self.W_K = nn.Linear(d_model, d_k * kv_heads, bias=global_bias)
        self.W_V = nn.Linear(d_model, d_v * kv_heads, bias=global_bias)

        self.to_out = nn.Sequential(nn.Linear(n_heads * d_v, d_model, bias=global_bias), nn.Dropout(dropout))
        # Learnable softmax temperatures, one per attention pass.
        self.scale1 = nn.Parameter(torch.tensor(d_k ** -0.5), requires_grad=True)
        self.scale2 = nn.Parameter(torch.tensor(d_k ** -0.5), requires_grad=True)
        self.scale3 = nn.Parameter(torch.tensor(d_k ** -0.5), requires_grad=True)

        # Special case keyed on the channel count: with 321 channels the layer
        # disables attention dropout, skips dropout on some residual branches
        # and has no feat_center.  Presumably this targets the Electricity
        # benchmark (hence the flag name) — verify.
        self.is_electricity = (c_in == 321)
        if self.is_electricity:
            attn_dropout = 0
        self.dropout_attn = nn.Dropout(dropout)
        self.attn_dropout1 = nn.Dropout(attn_dropout)
        self.attn_dropout2 = nn.Dropout(attn_dropout)
        self.attn_dropout3 = nn.Dropout(attn_dropout)

        if not self.is_electricity:
            # Per-channel, per-head offset subtracted from the channel-attention
            # queries.  Sized with d_k so it matches those queries even when an
            # explicit d_k is supplied.
            tmp = torch.zeros(c_in, n_heads, d_k)
            self.feat_center = nn.Parameter(tmp)
            torch.nn.init.kaiming_uniform_(self.feat_center, a=math.sqrt(5))

        self.ff = nn.Sequential(nn.Linear(d_model, d_ff, bias=global_bias),
                                get_activation_fn(activation),
                                nn.Dropout(dropout),
                                nn.Linear(d_ff, d_model, bias=global_bias))

        self.dropout_ffn = nn.Dropout(dropout)

        self.dilated_attn = dilated_attn

    @staticmethod
    def _attend(q, k_t, v, scale, drop):
        """Scaled softmax attention; ``k_t`` already has its last two axes swapped."""
        attn_scores = torch.matmul(q, k_t) * scale
        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_weights = drop(attn_weights)
        return torch.matmul(attn_weights, v)

    def _residual(self, base, update):
        """Add ``update`` to ``base``; the update passes through dropout unless ``is_electricity``."""
        if self.is_electricity:
            return base + update
        return base + self.dropout_attn(update)

    def forward(self, src:Tensor, split_num, idx, attn_rank):
        """Apply time attention, optional channel attention and the FFN.

        Args:
            src: tensor of shape [bs x channels x seq x d_model].
            split_num: channel group size for channel attention; ``1``
                disables it.  The channel count must be a multiple of it.
            idx: indices into ``feat_center`` for the channels in ``src``.
            attn_rank: returned unchanged.

        Returns:
            Tuple ``(output, attn_rank)`` with ``output`` shaped like ``src``.
        """
        batch_size, c_size, seq_size, _ = src.size()

        # Linear (+ split in multiple heads).  Keys/values carry a singleton
        # head axis in multi-query mode and broadcast against the query heads.
        kv_heads = 1 if self.multi_query else self.n_heads
        q_s = self.W_Q(src).view(batch_size, c_size, seq_size, self.n_heads, self.d_k)
        k_s = self.W_K(src).view(batch_size, c_size, seq_size, kv_heads, self.d_k)
        # Values come from their own projection W_V (not W_K).  NOTE: weights
        # saved while values were computed through W_K carry an untrained W_V
        # and will not reproduce their earlier outputs.
        v_s = self.W_V(src).view(batch_size, c_size, seq_size, kv_heads, self.d_v)

        if self.dilated_attn is None:
            # Global attention over the whole patch axis.
            tmp_q = rearrange(q_s, "b c s h d -> b c h s d")
            tmp_k = rearrange(k_s, "b c s h d -> b c h d s")
            v_s   = rearrange(v_s, "b c s h d -> b c h s d")
            v_s_n = self._attend(tmp_q, tmp_k, v_s, self.scale1, self.attn_dropout1)
            # Without a following channel-attention stage the attended values
            # replace v_s instead of being added to it.
            v_s = self._residual(v_s, v_s_n) if split_num != 1 else v_s_n
        else:
            block_num = self.dilated_attn[0]

            # Pass 1: attention inside each contiguous block of patches.
            tmp_q = rearrange(q_s, "b c (block_num block_size) h d -> b c h block_num block_size d", block_num = block_num)
            tmp_k = rearrange(k_s, "b c (block_num block_size) h d -> b c h block_num d block_size", block_num = block_num)
            v_s   = rearrange(v_s, "b c (block_num block_size) h d -> b c h block_num block_size d", block_num = block_num)
            v_s_n = self._attend(tmp_q, tmp_k, v_s, self.scale1, self.attn_dropout1)
            v_s = self._residual(v_s, v_s_n)

            # Pass 2: attention across blocks at the same in-block position.
            tmp_q = rearrange(q_s, "b c (block_num block_size) h d -> b c h block_size block_num d", block_num = block_num)
            tmp_k = rearrange(k_s, "b c (block_num block_size) h d -> b c h block_size d block_num", block_num = block_num)
            v_s   = rearrange(v_s, "b c h block_num block_size d -> b c h block_size block_num d", block_num = block_num)
            v_s_n = self._attend(tmp_q, tmp_k, v_s, self.scale2, self.attn_dropout2)
            v_s = self._residual(v_s, v_s_n) if split_num != 1 else v_s_n

            v_s   = rearrange(v_s, "b c h block_size block_num d -> b c h (block_num block_size) d", block_num = block_num)

        if split_num == 1:
            # No channel attention: just merge the heads.
            v_s = rearrange(v_s, "b c h s d -> b c s (h d)")
        else:
            # Channel attention inside contiguous groups of split_num channels,
            # reusing the keys computed above with a separate query projection.
            q_s = self.W_Q_feat(src).view(batch_size, c_size, seq_size, self.n_heads, self.d_k)
            if not self.is_electricity:
                q_s = q_s - self.feat_center[idx].unsqueeze(2)
            n_blocks = c_size // split_num
            tmp_q = rearrange(q_s, "b (block_num block_size) s h d -> b block_num s h block_size d", block_num = n_blocks)
            tmp_k = rearrange(k_s, "b (block_num block_size) s h d -> b block_num s h d block_size", block_num = n_blocks)
            v_s   = rearrange(v_s, "b (block_num block_size) h s d -> b block_num s h block_size d", block_num = n_blocks)
            v_s = self._attend(tmp_q, tmp_k, v_s, self.scale3, self.attn_dropout3)
            v_s   = rearrange(v_s, "b block_num s h block_size d -> b (block_num block_size) s (h d)")
        out = self.to_out(v_s)

        # Add (no normalisation layer is applied in this block).
        src = src + self.dropout_attn(out)
        src2 = self.ff(src)
        if not self.is_electricity:
            src = src + self.dropout_ffn(src2)
        else:
            src = src + src2
        return src, attn_rank
        

 