import torch
import torch.nn as nn
from transformers.models.gpt2.modeling_gpt2 import GPT2Model
from layers.mlp import MLP
from .TSCC import AlignFusionModel
from transformers import AutoModel,AutoModelForCausalLM, AutoTokenizer

class TimeLayer(nn.Module):
    """Residual adapter wrapped around an existing projection layer.

    The adapter path is ``Linear -> LSTM -> LSTM -> Linear`` and its output is
    added to the output of the wrapped layer. The last linear layer is fully
    zero-initialised (weight and bias), so a freshly built ``TimeLayer``
    returns exactly what the wrapped layer returns.

    Args:
        original_layer: The layer being wrapped; it is called as
            ``original_layer(x)``. When it exposes ``in_features`` /
            ``out_features`` (as ``nn.Linear`` does) the adapter is sized from
            them; otherwise the historical sizes 896 (input) and 128 (output)
            are used.
        rank: Width of the bottleneck between the linear and LSTM stages.
    """

    # Sizes used when the wrapped layer does not advertise its own.
    _DEFAULT_IN_FEATURES = 896
    _DEFAULT_OUT_FEATURES = 128

    def __init__(self, original_layer, rank=16):
        super().__init__()
        self.original_layer = original_layer

        in_features = getattr(
            original_layer, "in_features", self._DEFAULT_IN_FEATURES)
        out_features = getattr(
            original_layer, "out_features", self._DEFAULT_OUT_FEATURES)

        self.lora_A = nn.Linear(in_features, rank)
        self.lora_C = nn.LSTM(rank, in_features, batch_first=True)
        self.lora_D = nn.LSTM(in_features, rank, batch_first=True)
        self.lora_B = nn.Linear(rank, out_features)

        nn.init.normal_(self.lora_A.weight, std=0.02)
        # Zero both weight and bias: with only the weight zeroed, the randomly
        # initialised bias would still shift the wrapped layer's output.
        nn.init.zeros_(self.lora_B.weight)
        nn.init.zeros_(self.lora_B.bias)

    def forward(self, x):
        """Return ``original_layer(x)`` plus the adapter's correction.

        Args:
            x: Input tensor whose last dimension matches the adapter's input
                size. The LSTMs are built with ``batch_first=True``, so a
                batched input is presumably ``(batch, seq, features)``.

        Returns:
            The element-wise sum of the wrapped layer's output and the
            adapter output.
        """
        original_output = self.original_layer(x)

        reduced = self.lora_A(x)
        # nn.LSTM returns (output, (h_n, c_n)); only the per-step output is used.
        expanded, _ = self.lora_C(reduced)
        squeezed, _ = self.lora_D(expanded)
        delta = self.lora_B(squeezed)

        return original_output + delta

def add_time_adapter(model, rank=8):
    """Wrap the key and value projections of every layer with a ``TimeLayer``.

    For each entry of ``model.layers`` the ``self_attn.k_proj`` and
    ``self_attn.v_proj`` attributes are replaced in place by a ``TimeLayer``
    that wraps the previous projection.

    Args:
        model: Object exposing an iterable ``layers`` attribute whose items
            have ``self_attn.k_proj`` and ``self_attn.v_proj``.
        rank: Bottleneck width passed to each ``TimeLayer``. Note that the
            default here (8) differs from ``TimeLayer``'s own default (16).

    Returns:
        The same ``model`` object, modified in place.
    """
    for layer in model.layers:
        attention = layer.self_attn
        for proj_name in ("k_proj", "v_proj"):
            base_proj = getattr(attention, proj_name)
            adapter = TimeLayer(base_proj, rank)

            # New submodules are created on the default device. Put them where
            # the wrapped projection already lives so the adapter is usable
            # even if the caller moved the model before installing adapters.
            reference = next(base_proj.parameters(), None)
            if reference is not None:
                adapter.to(reference.device)

            setattr(attention, proj_name, adapter)
    return model

class Model(nn.Module):
    """Forecaster that feeds tokenised time series through a frozen backbone.

    The input series is normalised per instance, cut into non-overlapping
    segments of ``token_len`` steps, embedded by an MLP encoder, fused with
    prototypes derived from the backbone's input-embedding matrix, passed
    through the backbone, and projected back by an MLP decoder.

    The backbone parameters are frozen before the time adapters are installed,
    so the adapters, encoder, decoder, prototype projection and fusion module
    are the parts left trainable.
    """

    def __init__(self, configs):
        """Build the model.

        Args:
            configs: Settings object. This constructor reads ``token_len``,
                ``use_multi_gpu``, ``local_rank`` or ``gpu``,
                ``mlp_hidden_dim``, ``mlp_hidden_layers``, ``dropout`` and
                ``mlp_activation``.
        """
        super().__init__()
        self.token_len = configs.token_len

        gpu_index = configs.local_rank if configs.use_multi_gpu else configs.gpu
        self.device = f"cuda:{gpu_index}"
        print(self.device)

        backbone = AutoModel.from_pretrained("Qwen/Qwen-0.5B-GRPO")
        self.gpt2 = backbone.to(self.device)

        self.hidden_dim_of_gpt2 = 896

        # Freeze every backbone weight; adapters are added afterwards and
        # therefore keep their default trainable state.
        for weight in self.gpt2.parameters():
            weight.requires_grad = False

        mlp_options = (
            configs.mlp_hidden_dim,
            configs.mlp_hidden_layers,
            configs.dropout,
            configs.mlp_activation,
        )
        self.encoder = MLP(
            self.token_len, self.hidden_dim_of_gpt2, *mlp_options)
        self.decoder = MLP(
            self.hidden_dim_of_gpt2, self.token_len, *mlp_options)

        self.word_embeddings = self.gpt2.get_input_embeddings().weight
        self.vocab_size = self.word_embeddings.shape[0]
        # Maps the vocabulary axis down to 1500 prototype rows.
        self.text_prototype_linear = nn.Linear(self.vocab_size, 1500)
        self.fusion = AlignFusionModel()

        self.gpt2 = add_time_adapter(self.gpt2)

    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        """Run the forecasting pipeline on ``x_enc``.

        Args:
            x_enc: 3-D tensor; dimension 1 is the axis that is normalised over
                and split into tokens, dimension 2 indexes the variables.
            x_mark_enc: Unused.
            x_dec: Unused.
            x_mark_dec: Unused.

        Returns:
            Tensor of shape ``(batch, token_num * token_len, n_vars)`` in the
            original scale of ``x_enc``.
        """
        # Per-instance normalisation over dimension 1.
        means = x_enc.mean(1, keepdim=True).detach()
        normed = x_enc - means
        stdev = torch.sqrt(
            torch.var(normed, dim=1, keepdim=True, unbiased=False) + 1e-5)
        normed.div_(stdev)

        bs, _, n_vars = normed.shape

        # Treat every variable as its own sequence: (bs * n_vars, length).
        flat_series = normed.permute(0, 2, 1).reshape(bs * n_vars, -1)

        # Non-overlapping windows of token_len steps.
        segments = flat_series.unfold(
            dimension=-1, size=self.token_len, step=self.token_len)
        token_num = segments.shape[1]

        times_embeds = self.encoder(segments)

        vocab_matrix = self.gpt2.get_input_embeddings().weight
        prompt_key = self.text_prototype_linear(
            vocab_matrix.transpose(0, 1)).transpose(0, 1)

        fused = self.fusion(times_embeds, prompt_key)

        hidden = self.gpt2(inputs_embeds=fused).last_hidden_state

        dec_out = self.decoder(hidden).reshape(bs, n_vars, -1).permute(0, 2, 1)

        # Undo the normalisation.
        out_len = token_num * self.token_len
        scale = stdev[:, 0, :].unsqueeze(1).repeat(1, out_len, 1)
        shift = means[:, 0, :].unsqueeze(1).repeat(1, out_len, 1)

        return dec_out * scale + shift

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        """Delegate to :meth:`forecast` with the same arguments."""
        return self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)