import argparse
from transformers import AutoModelForCausalLM, AutoTokenizer
from data_process.data_factory import data_provider
from utils.tools import EarlyStopping, adjust_learning_rate, visual
from utils.metrics import metric
import torch
import torch.nn as nn
from torch import optim
import os
import time
import warnings
import numpy as np
from model import get_model_class
import random
warnings.filterwarnings('ignore')
import json
from math import sqrt
from data_process.data_factory import data_dict
import matplotlib.pyplot as plt

# Presumably the forecast horizon; TITSP's ConvHead below also targets 96 outputs.
global_predict_len = 96

# Single source of truth for the checkpoint so the model and tokenizer always match.
_LLM_CHECKPOINT = "Qwen/Qwen2-1.5B-Instruct"

# NOTE: these run at import time, so importing this module loads (and may
# download) the checkpoint. torch_dtype="auto" keeps the dtype stored in the
# checkpoint; device_map="auto" delegates device placement to transformers.
llm_model = AutoModelForCausalLM.from_pretrained(
    _LLM_CHECKPOINT,
    torch_dtype="auto",
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(_LLM_CHECKPOINT)

# LLM input-embedding matrix, (vocab_size, hidden_size). Tensor.to() is a no-op
# when the dtype already matches, in which case this is the model's own Parameter.
word_embeds = llm_model.get_input_embeddings().weight.to(llm_model.dtype)

class ConvHead(nn.Module):
    """Convolutional head mapping (batch, seq_len, nf) to (batch, target_window).

    A stack of ``num_layers`` length-preserving Conv1d(kernel_size=3) + ReLU
    pairs mixes features along the sequence axis; the resulting (seq_len, nf)
    map is then flattened and linearly projected to ``target_window`` outputs.

    Args:
        nf: Feature/channel width of the input (kept constant by the convs).
        target_window: Number of outputs produced per sample.
        num_layers: Number of Conv1d + ReLU pairs.
        head_dropout: Dropout probability applied to the projected output.
        seq_len: Input sequence length the final linear layer is sized for.
            Defaults to 24, the length this head was originally fixed to.
    """

    def __init__(self, nf, target_window, num_layers, head_dropout=0, seq_len=24):
        super().__init__()
        self.conv_layers = nn.ModuleList()
        for _ in range(num_layers):
            # kernel_size=3 with padding=1 keeps the sequence length unchanged.
            self.conv_layers.append(nn.Conv1d(nf, nf, kernel_size=3, padding=1))
            self.conv_layers.append(nn.ReLU())
        self.flatten = nn.Flatten(start_dim=-2)
        # The flattened (seq_len, nf) map is seq_len * nf wide, which fixes the
        # input length this head accepts.
        self.linear = nn.Linear(nf * seq_len, target_window)
        self.dropout = nn.Dropout(head_dropout)

    def forward(self, x):
        """Project ``x`` of shape (batch, seq_len, nf) to (batch, target_window)."""
        # Conv1d wants channels first: (batch, nf, seq_len).
        x = x.permute(0, 2, 1)
        for layer in self.conv_layers:
            x = layer(x)
        # Back to (batch, seq_len, nf) before flattening, so features are laid
        # out position-major (all nf features of step 0, then step 1, ...).
        x = self.flatten(x.permute(0, 2, 1))
        x = self.linear(x)
        return self.dropout(x)


class AdaptiveFactorizedBilinearFusion(nn.Module):
    """Fuse two sequences through a factorized, position-weighted product.

    Both inputs are projected into a ``d_factor``-wide space, multiplied
    element-wise together with a learnable (seq_len, d_factor) weight ``W``,
    and projected back to ``d_model``::

        Z = P(V(Y) * W * U(X))

    ``W`` is 2-D, so it broadcasts over any leading (e.g. batch) axes; the
    inputs' second-to-last axis must therefore be ``seq_len`` (or 1).

    Args:
        d_model: Last-axis width of ``X``, ``Y`` and of the output.
        d_factor: Width of the factorized interaction space.
        seq_len: Number of rows of ``W``. Defaults to 24, the length this
            module was originally fixed to.
    """

    def __init__(self, d_model, d_factor, seq_len=24):
        super().__init__()
        self.U = nn.Linear(d_model, d_factor)
        self.V = nn.Linear(d_model, d_factor)
        self.P = nn.Linear(d_factor, d_model)
        # Created after the Linear layers so RNG consumption, and therefore
        # seeded initialization, matches the layer order above.
        self.W = nn.Parameter(torch.randn(seq_len, d_factor))

    def forward(self, X, Y):
        """Return ``P(V(Y) * W * U(X))``; the last axis of the result is d_model."""
        U_X = self.U(X)
        V_Y = self.V(Y)
        Z = V_Y * self.W * U_X
        return self.P(Z)

class TITSP(torch.nn.Module):
    """Fuse data embeddings with prompt embeddings, run them through the
    module-level ``llm_model``, and project the result to a forecast.

    ``forward`` pipeline:
      1. cross-attention: data positions (queries) attend over prompt
         positions (keys/values);
      2. ``AdaptiveFactorizedBilinearFusion`` of the data embeddings with the
         attention output;
      3. the fused sequence is fed to ``llm_model`` as ``inputs_embeds``;
      4. ``ConvHead`` maps the LLM output to 96 values per sample.

    Args:
        args: Experiment configuration; accepted for interface compatibility,
            not read by this constructor.
    """

    def __init__(self, args):
        super().__init__()
        # LLM input-embedding matrix, (vocab_size, hidden_size).
        # NOTE(review): when the weight is already in llm_model.dtype, Tensor.to()
        # returns the same nn.Parameter, so this assignment registers the LLM
        # embedding as a parameter of TITSP (it then shows up in parameters()
        # and state_dict()). Confirm that is intended.
        self.word_embeddings = llm_model.get_input_embeddings().weight.to(llm_model.dtype)
        self.vocab_size = self.word_embeddings.shape[0]
        # TODO: to be revised. (Original note said "these are all float16"; the
        # layers below actually use llm_model.dtype, which torch_dtype="auto"
        # takes from the checkpoint.)
        self.num_tokens = 512
        # Width of the LLM embeddings (1536 for Qwen2-1.5B); every layer below
        # must match it because the fused sequence is fed to the LLM.
        hidden_size = self.word_embeddings.shape[1]
        self.bfusion = AdaptiveFactorizedBilinearFusion(hidden_size, hidden_size).to(llm_model.dtype)
        self.output_projection = ConvHead(hidden_size, 96, 3, 0.15).to(llm_model.dtype)
        self.attn = nn.MultiheadAttention(embed_dim=hidden_size, num_heads=8).to(llm_model.dtype)
        self.top_k = 5

    def forward(self, prompt_embeds, data_embeds, ret_attn=False):
        """Produce forecasts from prompt and data embeddings.

        Args:
            prompt_embeds: (batch, prompt_len, hidden) embeddings, used as the
                attention keys and values.
            data_embeds: (batch, hidden, seq_len) embeddings, permuted to
                (batch, seq_len, hidden) and used as attention queries.
                seq_len must be 24 to match ``bfusion.W`` and
                ``output_projection``'s linear layer.
            ret_attn: If True, also return the attention weights.

        Returns:
            ``dec_out`` (batch, 96), or ``(dec_out, attn_weights)`` when
            ``ret_attn`` is True. With nn.MultiheadAttention defaults,
            ``attn_weights`` is (batch, seq_len, prompt_len), averaged over heads.
        """
        # (batch, hidden, seq_len) -> (batch, seq_len, hidden)
        data_embeds = data_embeds.permute(0, 2, 1)
        # self.attn uses the default batch_first=False: inputs are (seq, batch, embed).
        query = data_embeds.permute(1, 0, 2)
        key_value = prompt_embeds.permute(1, 0, 2)
        attn_output, attn_weights = self.attn(query, key_value, key_value)
        attn_output = attn_output.permute(1, 0, 2)  # back to (batch, seq_len, hidden)

        enc_out = self.bfusion(data_embeds, attn_output)
        # NOTE(review): `mode` is not a standard Hugging Face argument; this
        # presumably relies on a patched model that returns a
        # (batch, seq_len, hidden) tensor rather than a ModelOutput. TODO confirm.
        dec_out = llm_model(inputs_embeds=enc_out, mode='infer-reprogram')
        dec_out = self.output_projection(dec_out)
        if ret_attn:
            return dec_out, attn_weights
        return dec_out
