import torch
import torch.nn as nn
from math import sqrt
import numpy as np
from typing import Tuple
from functools import partial
from einops import rearrange, repeat
from icecream import ic

from layers.STEmbedding import SNIP, AttentionPrompt
import torch.nn.functional as F

class MultiHeadsAttention(nn.Module):
    """Scaled dot-product attention over inputs already split into heads.

    Args:
        scale: Multiplier applied to the raw scores before the softmax.
            A falsy value (``None`` or ``0``) selects ``1 / sqrt(E)``, where
            ``E`` is the last dimension of ``queries``.
        attention_dropout: Dropout probability applied to the attention
            weights.
        returnA: When true, ``forward`` also returns the attention weights.
    """

    def __init__(self, scale=None, attention_dropout=0.1, returnA=False):
        super().__init__()
        self.returnA = returnA
        self.scale = scale
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, queries, keys, values):
        """Attend from ``queries`` to ``keys`` and aggregate ``values``.

        Args:
            queries: Tensor laid out as ``(B, L, H, E)``.
            keys: Tensor laid out as ``(B, S, H, E)``.
            values: Tensor laid out as ``(B, S, H, D)``.

        Returns:
            ``(context, weights)``: ``context`` is ``(B, L, H, D)``;
            ``weights`` is ``(B, H, L, S)`` when ``returnA`` is set and
            ``None`` otherwise.
        """
        # The unpacking doubles as a check that both inputs are 4-D.
        n_batch, q_len, n_heads, head_dim = queries.shape
        _, kv_len, _, value_dim = values.shape

        temperature = self.scale if self.scale else 1.0 / sqrt(head_dim)

        logits = torch.einsum("blhe,bshe->bhls", queries, keys)
        weights = torch.softmax(temperature * logits, dim=-1)
        weights = self.dropout(weights)
        context = torch.einsum("bhls,bshd->blhd", weights, values).contiguous()

        returned_weights = weights.contiguous() if self.returnA else None
        return context, returned_weights

class AttentionLayer(nn.Module):
    """Multi-head attention with learned Q/K/V and output projections.

    Args:
        hid_dim: Feature size of the inputs and of the output.
        n_heads: Number of attention heads.
        d_keys: Per-head query/key size; defaults to ``hid_dim // n_heads``.
        d_values: Per-head value size; defaults to ``hid_dim // n_heads``.
        mix: When true, the length and head axes of the attention output are
            swapped before the heads are flattened (see ``forward``).
        dropout: Dropout probability applied to the attention weights.
        returnA: When true, ``forward`` also returns the attention weights.
        att_type: Attention variant; ``'full'`` and ``'proxy'`` both use
            ``MultiHeadsAttention``.

    Raises:
        ValueError: If ``att_type`` is not one of the supported variants.
    """

    _SUPPORTED_ATT_TYPES = ('full', 'proxy')

    def __init__(self, hid_dim, n_heads, d_keys=None, d_values=None, mix=True, dropout=0.1, returnA=False,
                 att_type='full'):
        super(AttentionLayer, self).__init__()

        # Fail at construction time: otherwise ``inner_attention`` would be
        # left undefined and the error would only surface inside ``forward``.
        if att_type not in self._SUPPORTED_ATT_TYPES:
            raise ValueError(
                f"Unsupported att_type {att_type!r}; expected one of {self._SUPPORTED_ATT_TYPES}")

        d_keys = d_keys or (hid_dim//n_heads)
        d_values = d_values or (hid_dim//n_heads)

        self.inner_attention = MultiHeadsAttention(
            scale=None, attention_dropout=dropout, returnA=returnA)
        self.query_projection = nn.Linear(hid_dim, d_keys * n_heads)
        self.key_projection = nn.Linear(hid_dim, d_keys * n_heads)
        self.value_projection = nn.Linear(hid_dim, d_values * n_heads)
        self.out_projection = nn.Linear(d_values * n_heads, hid_dim)
        self.n_heads = n_heads
        self.returnA = returnA
        self.mix = mix

    def forward(self, queries, keys, values):
        """Project the inputs, run attention and project the result back.

        Args:
            queries: Tensor of shape ``(B, L, hid_dim)``.
            keys: Tensor of shape ``(B, S, hid_dim)``.
            values: Tensor of shape ``(B, S, hid_dim)``.

        Returns:
            ``(out, A)``: ``out`` is ``(B, L, hid_dim)``; ``A`` is the
            attention weights when ``returnA`` is set, otherwise ``None``.
        """
        B, L, _ = queries.shape
        _, S, _ = keys.shape
        H = self.n_heads

        # Split the projected features into H heads: (B, len, H, d_head).
        queries = self.query_projection(queries).view(B, L, H, -1)
        keys = self.key_projection(keys).view(B, S, H, -1)
        values = self.value_projection(values).view(B, S, H, -1)

        out, A = self.inner_attention(
            queries,
            keys,
            values,
        )
        if self.mix:
            # (B, L, H, D) -> (B, H, L, D); the view below then flattens this
            # head-major layout into (B, L, -1).
            out = out.transpose(2, 1).contiguous()
        out = out.view(B, L, -1)
        out = self.out_projection(out)
        if self.returnA:
            return out, A
        else:
            return out, None




class STEncoderLayer(nn.Module):
    """Encoder layer that routes node features through a small set of proxy tokens.

    The time axis is first compressed from ``input_length`` to ``time_factor``
    steps with 1x1 convolutions. For every compressed step, proxy tokens
    attend to the nodes (``node2proxy``) and the nodes then attend back to the
    proxies (``proxy2node``). A residual feed-forward block follows, and the
    time axis is finally expanded back to ``input_length``.

    Args:
        factor: Number of learnable proxy tokens (only used when
            ``use_meta_emb_as_proxy`` is false).
        d_model: Feature size.
        n_heads: Number of attention heads.
        input_length: Length of the time axis of the input.
        time_factor: Length of the compressed time axis.
        use_meta_emb_as_proxy: When true, the proxies are supplied to
            ``forward`` instead of being learned by this layer.
        d_ff: Hidden size of the feed-forward block; defaults to
            ``4 * d_model``.
        dropout: Dropout probability on the residual branches.
        att_dropout: Dropout probability on the attention weights.
        return_att: When true, ``forward`` also returns both attention maps.
    """

    def __init__(self, factor, d_model, n_heads, input_length, time_factor, use_meta_emb_as_proxy=False,
                 d_ff=None, dropout=0.1, att_dropout=0.1, return_att=False):
        super().__init__()
        d_ff = d_ff or 4*d_model
        self.return_att = return_att
        self.use_meta_emb_as_proxy = use_meta_emb_as_proxy

        if not use_meta_emb_as_proxy:
            # Learnable proxy tokens, shared across batch and time steps.
            self.proxy_token = nn.Parameter(torch.zeros(1, factor, d_model))

        self.time_factor = time_factor
        # Conv2d treats axis 1 as channels, so these 1x1 convolutions mix
        # information along the time axis only.
        self.time_readout = nn.Sequential(nn.Conv2d(input_length, time_factor, kernel_size=(1,1)),
                                          nn.GELU(),
                                          nn.Conv2d(time_factor, time_factor, kernel_size=(1,1)))
        self.time_recover = nn.Sequential(nn.Conv2d(time_factor, time_factor, kernel_size=(1,1)),
                                          nn.GELU(),
                                          nn.Conv2d(time_factor, input_length, kernel_size=(1,1)))

        self.node2proxy = AttentionLayer(d_model, n_heads, dropout=att_dropout, returnA=return_att)
        self.proxy2node = AttentionLayer(d_model, n_heads, dropout=att_dropout, returnA=return_att)

        self.dropout = nn.Dropout(dropout)

        self.FFN = nn.Sequential(nn.Linear(d_model, d_ff),
                                  nn.GELU(),
                                  nn.Linear(d_ff, d_model))
        self.d_head = d_model // n_heads

    def forward(self, data, proxy=None):
        """Encode ``data`` through the proxy bottleneck.

        Args:
            data: Tensor of shape ``(B, input_length, N, d_model)``.
            proxy: Tensor of shape ``(M, d_model)`` used as proxy tokens.
                Required when the layer was built with
                ``use_meta_emb_as_proxy=True``; ignored otherwise.

        Returns:
            ``(out, atts)``: ``out`` has shape ``(B, input_length, N, d_model)``;
            ``atts`` is ``[A1, A2]`` (node->proxy and proxy->node attention
            maps, each reshaped to ``(B, time_factor, heads, L, S)``) when
            ``return_att`` is set, otherwise ``None``.

        Raises:
            ValueError: If the layer expects an external proxy but ``proxy``
                is ``None``.
        """
        batch = data.shape[0]

        if self.use_meta_emb_as_proxy:
            if proxy is None:
                # No learnable ``proxy_token`` exists in this configuration,
                # so there is nothing to fall back to.
                raise ValueError(
                    "STEncoderLayer was built with use_meta_emb_as_proxy=True; "
                    "a `proxy` tensor must be passed to forward().")
            z_proxy = repeat(proxy, 'm d -> (b) m d', b=batch*self.time_factor)
        else:
            z_proxy = repeat(self.proxy_token, 'o m d -> (o b) m d', b=batch*self.time_factor)

        # Compress time, then fold the compressed steps into the batch axis.
        kv_data = self.time_readout(data)
        kv_data = rearrange(kv_data, 'b t n c-> (b t) n c')

        proxy_feature, A1 = self.node2proxy(z_proxy, kv_data, kv_data)
        node_feature, A2 = self.proxy2node(kv_data, proxy_feature, proxy_feature)
        enc_feature = kv_data + self.dropout(node_feature)
        enc_feature = enc_feature + self.dropout(self.FFN(enc_feature))

        final_out = rearrange(enc_feature, '(b T) N d -> b T N d', b=batch)
        final_out = self.time_recover(final_out)

        if self.return_att:
            A1 = rearrange(A1, '(b t) h l s -> b t h l s', b=batch)
            A2 = rearrange(A2, '(b t) h l s -> b t h l s', b=batch)
            return final_out, [A1, A2]
        else:
            return final_out, None

class getTimeEmbedding(nn.Module):
    """Sum of a time-of-day embedding and a day-of-week embedding.

    Args:
        hid_dim: Size of each embedding vector.
        slice_size_per_day: Number of entries in the time-of-day table.
        emb_dropout: Dropout probability applied to the summed embedding.
    """

    def __init__(self, hid_dim, slice_size_per_day,emb_dropout=0.1):
        super().__init__()
        self.time_in_day_embedding = nn.Embedding(slice_size_per_day, hid_dim)
        self.day_in_week_embedding = nn.Embedding(7, hid_dim)
        self.dropout = nn.Dropout(emb_dropout)

    def forward(self, time_data):
        """Embed integer time indices.

        Args:
            time_data: Integer tensor whose last axis holds the time-of-day
                index at position 0 and the day-of-week index at position 1.

        Returns:
            Tensor of shape ``time_data.shape[:-1] + (1, hid_dim)``; the
            width-1 slices keep a singleton axis in front of the features.
        """
        slot_index, weekday_index = time_data[..., 0:1], time_data[..., 1:2]
        combined = (self.time_in_day_embedding(slot_index)
                    + self.day_in_week_embedding(weekday_index))
        return self.dropout(combined)
    
class getSpatialEmbedding(nn.Module):
    '''
    Learnable per-node embedding (marked "not used" by the original author).

    Args:
        hid_dim: Size of each embedding vector.
        num_nodes: Number of rows in the embedding table.
        emb_dropout: Dropout probability applied to the embedding.
    '''
    def __init__(self, hid_dim, num_nodes, emb_dropout=0.1):
        super().__init__()
        self.spatial_embedding = nn.Embedding(num_nodes, hid_dim)
        self.dropout = nn.Dropout(emb_dropout)

    def forward(self, x, spatial_indexs=None):
        '''
        Look up node embeddings, broadcastable over batch and time.

        Args:
            x: 4-D tensor; only its third axis (node count) and its device
                are read, and only when ``spatial_indexs`` is ``None``.
            spatial_indexs: Optional integer tensor of node indices. Defaults
                to ``0 .. x.shape[2] - 1``.

        Returns:
            Tensor of shape ``(1, 1) + spatial_indexs.shape + (hid_dim,)``.
        '''
        if spatial_indexs is None:
            _, _, num_nodes, _ = x.shape
            # Build the index directly on x's device instead of creating it on
            # the CPU with the legacy LongTensor constructor and copying it.
            spatial_indexs = torch.arange(num_nodes, device=x.device)  # (N,)
        spatial_emb = self.spatial_embedding(spatial_indexs).unsqueeze(0).unsqueeze(1)
        spatial_emb = self.dropout(spatial_emb)
        return spatial_emb


class DataEncoding(nn.Module):
    """Two-layer MLP that lifts raw features to the hidden size.

    Args:
        in_dim: Feature size of ``x`` (and of ``latestX``).
        hid_dim: Output feature size.
        hasCross: When true, ``x`` and ``latestX`` are concatenated on the
            last axis before the first linear layer.
        activation: ``'relu'`` or ``'gelu'``.

    Raises:
        ValueError: If ``activation`` is not a supported name.
    """

    def __init__(self, in_dim, hid_dim, hasCross=True, activation='relu'):
        super().__init__()
        # Explicit raise instead of ``assert``: assertions are stripped under
        # ``python -O``, which would silently select ReLU for any unknown name.
        if activation not in ('gelu', 'relu'):
            raise ValueError(
                f"Unsupported activation {activation!r}; expected 'gelu' or 'relu'")
        in_units = in_dim*2 if hasCross else in_dim
        self.hasCross = hasCross
        self.linear1 = nn.Linear(in_units, hid_dim)
        self.activation = nn.GELU() if activation == 'gelu' else nn.ReLU()
        self.linear2 = nn.Linear(hid_dim, hid_dim)

    def forward(self, x, latestX):
        """Encode ``x`` (optionally concatenated with ``latestX``).

        Args:
            x: Tensor whose last axis has size ``in_dim``.
            latestX: Tensor concatenated with ``x`` on the last axis when
                ``hasCross`` is true; ignored otherwise.

        Returns:
            Tensor with the same leading axes as ``x`` and last axis
            ``hid_dim``.
        """
        if self.hasCross:
            data = torch.cat([x, latestX], dim=-1)
        else:
            data = x
        data = self.linear1(data)
        data = self.activation(data)
        data = self.linear2(data)
        return data


class Model(nn.Module):
    """Spatio-temporal forecasting model built around proxy-token attention.

    Pipeline (each stage is optional unless noted):
      1. optional instance normalisation of the input (``revin``);
      2. data encoding of the input concatenated with its last time step
         (always on);
      3. temporal / raw spatial embeddings and a prompt embedding
         (``SNIP`` or ``AttPrompt``);
      4. a short temporal convolution (``tau`` in {2, 3});
      5. a stack of ``STEncoderLayer`` blocks with a skip connection;
      6. a linear head mapping ``input_length * hid_dim`` features per node
         to ``predict_length * pre_dim`` outputs (always on).

    All hyper-parameters are read from attributes of ``args`` in
    ``__init__``.
    """

    # Axes reduced when computing normalisation statistics on a (B, T, N, C)
    # input, keyed by ``revin_type``.
    _REVIN_DIMS = {'ST': (1, 2), 'S': 1, 'T': 2}

    def __init__(self, args):
        super().__init__()
        self.input_length = args.input_length
        self.predict_length = args.predict_length
        # predict_length == 0 means "predict as many steps as were observed".
        if self.predict_length ==0:
            self.predict_length = self.input_length
        self.in_dim = args.in_dim
        self.pre_dim = args.pre_dim
        self.num_nodes = args.num_nodes # not used when args.hasRawSemb == False
        self.tau = args.tau
        self.useTCN = self.tau > 0
        self.hid_dim = args.hid_dim
        self.n_heads = args.n_heads
        self.num_layers = args.num_layers
        self.norm_flag = 'none'
        self.att_type = 'proxy'
        self.M = args.M
        self.addLatestX = True
        self.hasCross = True

        self.hasTemb = bool(args.hasTemb)
        self.hasRawSemb = bool(args.hasRawSemb)
        self.prompt_type = args.prompt_type
        assert self.prompt_type in ['SNIP', 'AttPrompt', 'none']
        self.hasSTencoder= bool(args.hasSTencoder)

        self.return_att = False
        self.activation_data = args.activation_data
        self.activation_enc = args.activation_enc
        self.activation_dec = args.activation_dec
        self.emb_dropout = args.emb_dropout
        self.te_emb_dropout = args.te_emb_dropout
        self.se_emb_dropout = args.se_emb_dropout
        self.enc_dropout = args.enc_dropout
        self.att_dropout = args.att_dropout
        self.adj_dropout = args.adj_dropout

        self.finetune_drop = bool(args.finetune_drop)

        self.slice_size_per_day = args.slice_size_per_day
        self.revin = bool(args.revin)
        self.revin_type = args.revin_type

        ### for SNIP calculation
        if self.prompt_type == 'SNIP':
            self.static_feats_dim_list = args.static_feats_dim_list
            self.static_func_type = args.static_func_type
            self.dynamic_func_type = args.dynamic_func_type
            self.support_len = args.support_len
            self.use_meta_emb_as_proxy = bool(args.use_meta_emb_as_proxy)
            self.use_drop_data_emb_for_snip = bool(args.use_drop_data_emb_for_snip)
        else:
            self.use_meta_emb_as_proxy = False

        ### Data encoding
        self.data_encoding = DataEncoding(
            self.in_dim, self.hid_dim, self.hasCross, activation=self.activation_data)
        self.data_emb_dropout_layer = nn.Dropout(p=self.emb_dropout)

        ### Topology dropout
        if self.adj_dropout > 0:
            self.adj_dropout_layer = nn.Dropout(self.adj_dropout)

        ### Temporal embedding
        if self.hasTemb:
            self.getTemb = getTimeEmbedding(self.hid_dim, self.slice_size_per_day, self.te_emb_dropout)

        ### Spatial-Temporal embedding
        if self.hasRawSemb:
            self.getSemb = getSpatialEmbedding(self.hid_dim, self.num_nodes, self.se_emb_dropout)

        if self.prompt_type=='SNIP':
            self.getSNIP = SNIP(self.static_feats_dim_list, self.hid_dim, support_len=self.support_len,
                                static_func_type = self.static_func_type, dynamic_func_type= self.dynamic_func_type,
                                dropout_rate=self.se_emb_dropout)

        elif self.prompt_type=='AttPrompt':
            self.getAttPrompt = AttentionPrompt(args.M, self.hid_dim)

        ### TCN encoding
        if self.useTCN:
            assert self.tau in [2,3]
            # Pad so that a kernel of width tau keeps the time length unchanged.
            tcn_pad_l, tcn_pad_r = self.tau-2, 1
            self.padding = nn.ReplicationPad2d((tcn_pad_l, tcn_pad_r, 0, 0))  # x must be like (B,C,N,T)
            self.time_conv = nn.Conv2d(self.hid_dim, self.hid_dim, (1, self.tau))

        ### Spatial-Temporal Feature Extraction
        if self.hasSTencoder:
            if self.att_type=='proxy':
                self.spatial_agg_list = nn.ModuleList([
                    STEncoderLayer(self.M, self.hid_dim, n_heads=self.n_heads, input_length=self.input_length, time_factor =3,
                                   use_meta_emb_as_proxy=self.use_meta_emb_as_proxy) for _ in range(self.num_layers)])

        ### Prediction Layer
        self.output_proj = nn.Linear(
            self.input_length * self.hid_dim, self.predict_length * self.pre_dim
        )

    def forward(self, x, mode='val'):
        """Run the model on one batch.

        Args:
            x: Tuple ``(x, x_time, y_time, aux_data)``. Per the original
                inline note ``x`` is ``(B, T, N, C)`` and ``x_time`` is
                ``(B, T, 2)``; ``y_time`` is not used here. ``aux_data`` is a
                mapping that must provide ``'adj_mx'``, ``'period_feats'``,
                ``'stci_feats'`` and ``'structure_feats'``.
            mode: When ``'finetune'``, topology dropout is skipped unless
                ``finetune_drop`` is enabled.

        Returns:
            ``(main_output, prompt_emb)``: ``main_output`` has shape
            ``(B, predict_length, N, pre_dim)``; ``prompt_emb`` is the prompt
            embedding added to the features, or ``None`` when
            ``prompt_type`` is ``'none'``.

        Raises:
            ValueError: If ``revin`` is enabled with an unknown
                ``revin_type``.
        """
        x, x_time, y_time, aux_data = x  # x: (B,T,N,C) x_time:(B,T,2) aux_data:{...}
        B, T, N, _ = x.shape  # also enforces a 4-D input
        assert x.shape[-1] == self.in_dim
        adj_mx = aux_data['adj_mx']
        period_feats = aux_data['period_feats']
        stci_feats = aux_data['stci_feats']
        structure_feats = aux_data['structure_feats']

        if self.revin:
            if self.revin_type not in self._REVIN_DIMS:
                # Without statistics the de-normalisation at the end of this
                # method could not run; fail early with a clear message.
                raise ValueError(
                    f"Unsupported revin_type {self.revin_type!r}; "
                    f"expected one of {sorted(self._REVIN_DIMS)}")
            reduce_dims = self._REVIN_DIMS[self.revin_type]
            means = x.mean(dim=reduce_dims, keepdim=True).detach()
            stdev = torch.sqrt(torch.var(x, dim=reduce_dims, keepdim=True, unbiased=False) + 1e-5)
            # Out-of-place, so the caller's tensor is left untouched.
            x = (x - means) / stdev

        # Topology dropout: only while training, and during fine-tuning only
        # when finetune_drop is enabled.
        apply_adj_dropout = (
            self.training
            and self.adj_dropout > 0
            and (self.finetune_drop or mode != 'finetune')
        )
        if apply_adj_dropout:
            # Work on a copy. Writing the dropped matrices back into the
            # caller's container would compound dropout (and its rescaling)
            # on every batch that reuses the same aux_data.
            adj_mx = adj_mx.clone() if torch.is_tensor(adj_mx) else list(adj_mx)
            # Only the first two supports are dropped, as in the original.
            for idx in range(min(len(adj_mx), 2)):
                adj_mx[idx] = self.adj_dropout_layer(adj_mx[idx])

        # Last observed step, kept with a singleton time axis for broadcasting.
        last_step = x[:, -1:, :, :]
        latestX = last_step.repeat([1, self.input_length, 1, 1])
        data_emb = self.data_encoding(x, latestX)
        data = self.data_emb_dropout_layer(data_emb)

        if self.hasTemb:
            x_time_emb = self.getTemb(x_time)
            data = data + x_time_emb

        if self.hasRawSemb:
            x_spa_emb = self.getSemb(data)
            data = data + x_spa_emb

        proxy = None
        if self.prompt_type=='SNIP':
            input_static_feat_list = [period_feats,stci_feats,structure_feats]

            # SNIP is fed either the dropped-out features (plus embeddings)
            # or the raw data embedding, depending on the flag.
            if self.use_drop_data_emb_for_snip:
                prompt_emb = self.getSNIP(data, input_static_feat_list, adj_mx)
            else:
                prompt_emb = self.getSNIP(data_emb, input_static_feat_list, adj_mx)

            data = data + prompt_emb

            if self.use_meta_emb_as_proxy:
                proxy = self.getSNIP.get_meta_emb()

        elif self.prompt_type == 'AttPrompt':
            prompt_emb = self.getAttPrompt(data)
            data = data + prompt_emb
        elif self.prompt_type.lower() == 'none':
            prompt_emb = None

        if self.useTCN:
            data = data.transpose(1, 3)  # (B,T,N,C)->(B,C,N,T)
            if self.tau > 1:
                data = self.padding(data)
            data = self.time_conv(data)
            data = data.transpose(1, 3)  # (B,C,N,T)->(B,T,N,C)
            assert data.shape[1] == self.input_length

        if self.hasSTencoder:
            skip = data
            A_list = []
            for i in range(self.num_layers):
                data, A = self.spatial_agg_list[i](data, proxy)
                A_list.append(A)
            data += skip

        # Flatten time and features per node, then project to the horizon.
        data = rearrange(data, 'b t n d -> b n (t d)')
        main_output = self.output_proj(data)
        main_output = rearrange(main_output, 'b n (t c) -> b t n c', t=self.predict_length, c=self.pre_dim)

        if self.addLatestX:
            # latestX is the last step repeated along time, so broadcasting
            # the single step is equivalent for predict_length <= input_length
            # and additionally works when predict_length > input_length.
            main_output += last_step

        if self.revin:
            # NOTE(review): for revin_type 'T' the statistics keep the input
            # time axis, so this broadcast presumably requires
            # predict_length == input_length - confirm with callers.
            main_output = main_output * stdev
            main_output = main_output + means

        return main_output, prompt_emb