from math import sqrt

import torch
import torch.nn as nn

from transformers import LlamaConfig, LlamaModel, LlamaTokenizer, GPT2Config, GPT2Model, GPT2Tokenizer, BertConfig, \
    BertModel, BertTokenizer
from layers.Embed import PatchEmbedding
import transformers
from layers.StandardNorm import Normalize

transformers.logging.set_verbosity_error()


class FlattenHead(nn.Module):
    """Projection head: merge the last two axes, then map them to the forecast window.

    Args:
        n_vars: Number of variables; stored on the instance but not used by
            ``forward``.
        nf: Size of the flattened feature axis fed to the linear layer.
        target_window: Number of output steps produced per variable.
        head_dropout: Dropout probability applied after the linear layer.
    """

    def __init__(self, n_vars, nf, target_window, head_dropout=0):
        super().__init__()
        self.n_vars = n_vars
        self.flatten = nn.Flatten(start_dim=-2)
        self.linear = nn.Linear(in_features=nf, out_features=target_window)
        self.dropout = nn.Dropout(p=head_dropout)

    def forward(self, x):
        """Return ``dropout(linear(flatten(x)))``; the last two axes of ``x`` collapse into one."""
        flat = self.flatten(x)
        projected = self.linear(flat)
        return self.dropout(projected)


class MMVQVAE_VER2(nn.Module):
    """Forecast a single-variable series through a frozen LLM, with a VQ-VAE auxiliary loss.

    Pipeline implemented by ``forecast``:

    1. normalise the input window and embed it as patches (``PatchEmbedding``);
    2. map the patches into the LLM embedding space with ``ReprogrammingLayer``,
       attending over ``num_tokens`` learned linear mixes of the LLM's word
       embeddings;
    3. run the frozen LLM on those embeddings and project the first ``d_ff``
       hidden features to ``pred_len`` steps with ``FlattenHead``;
    4. in parallel, vector-quantise the reprogrammed patches, decode them back
       to the patch embeddings and store the resulting loss on
       ``self.loss_vae`` (see ``calculate_loss``).

    Args:
        config: Mapping-like object read with ``config[key]``.  Keys used:
            ``pred_len``, ``seq_len``, ``d_ff``, ``llm_dim``, ``patch_len``,
            ``stride``, ``llm_model`` (``'llama'``, ``'GPT2'`` or ``'BERT'``),
            ``llm_layers``, ``prompt_domain``, ``content`` (only read when
            ``prompt_domain`` is truthy), ``dropout``, ``d_model``,
            ``n_heads`` and ``vqvae_weight``.

    Raises:
        ValueError: If ``config["llm_model"]`` is not one of the supported names.
    """

    def __init__(self, config):
        super(MMVQVAE_VER2, self).__init__()

        self.pred_len = config["pred_len"]
        self.seq_len = config["seq_len"]
        self.d_ff = config["d_ff"]
        self.top_k = 5
        self.d_llm = config["llm_dim"]
        self.patch_len = config["patch_len"]
        self.stride = config["stride"]
        self.llm_model_type = config["llm_model"]
        self.llm_layers = config["llm_layers"]

        # Runtime messages are kept verbatim from the previous implementation.
        model_missing = "Local model files not found. Attempting to download..."
        tokenizer_missing = "Local tokenizer files not found. Atempting to download them.."

        if self.llm_model_type == 'llama':
            self.llama_config = LlamaConfig.from_pretrained('huggyllama/llama-7b')
            # Fixed: this used to read `config.llm_layers`, which fails for a
            # config that only supports item access (as every other read
            # here assumes).  Use the value already stored on the instance.
            self.llama_config.num_hidden_layers = self.llm_layers
            self.llama_config.output_attentions = True
            self.llama_config.output_hidden_states = True
            self.llm_model = self._load_local_first(
                LlamaModel.from_pretrained, 'huggyllama/llama-7b', model_missing,
                config=self.llama_config,
            )
            self.tokenizer = self._load_local_first(
                LlamaTokenizer.from_pretrained, 'huggyllama/llama-7b', tokenizer_missing,
            )
        elif self.llm_model_type == 'GPT2':
            # GPT-2 is always read from this fixed path; there is no fallback.
            local_model_path = "/root/MMTSF/NewsForecasting/pretrain_models/GPT2"
            self.gpt2_config = GPT2Config.from_pretrained(local_model_path)
            self.gpt2_config.num_hidden_layers = self.llm_layers
            self.gpt2_config.output_attentions = True
            self.gpt2_config.output_hidden_states = True
            self.llm_model = GPT2Model.from_pretrained(
                local_model_path,
                trust_remote_code=True,
                local_files_only=False,
                config=self.gpt2_config,
            )
            self.tokenizer = GPT2Tokenizer.from_pretrained(
                local_model_path,
                trust_remote_code=True,
                local_files_only=False,
            )
        elif self.llm_model_type == 'BERT':
            self.bert_config = BertConfig.from_pretrained('google-bert/bert-base-uncased')
            self.bert_config.num_hidden_layers = self.llm_layers
            self.bert_config.output_attentions = True
            self.bert_config.output_hidden_states = True
            self.llm_model = self._load_local_first(
                BertModel.from_pretrained, 'google-bert/bert-base-uncased', model_missing,
                config=self.bert_config,
            )
            self.tokenizer = self._load_local_first(
                BertTokenizer.from_pretrained, 'google-bert/bert-base-uncased', tokenizer_missing,
            )
        else:
            # ValueError is a subclass of Exception, so existing
            # `except Exception` handlers keep working.
            raise ValueError('LLM model is not defined')

        # Batched tokenisation with padding=True needs a pad token: reuse EOS
        # when the tokenizer has one, otherwise register '[PAD]'.
        if self.tokenizer.eos_token:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        else:
            pad_token = '[PAD]'
            self.tokenizer.add_special_tokens({'pad_token': pad_token})
            self.tokenizer.pad_token = pad_token

        # The LLM backbone is frozen; only the layers created below train.
        for param in self.llm_model.parameters():
            param.requires_grad = False

        self.prompt_domain = config["prompt_domain"]
        if self.prompt_domain:
            self.description = config["content"]
        else:
            self.description = 'The Electricity Transformer Temperature (ETT) is a crucial indicator in the electric power long-term deployment.'

        self.dropout = nn.Dropout(config["dropout"])

        self.patch_embedding = PatchEmbedding(
            config["d_model"], self.patch_len, self.stride, config["dropout"])

        # The LLM's (frozen) input-embedding matrix, shape (vocab_size, hidden).
        self.word_embeddings = self.llm_model.get_input_embeddings().weight
        self.vocab_size = self.word_embeddings.shape[0]
        self.num_tokens = 1000
        # Learns `num_tokens` linear combinations of the vocabulary embeddings.
        self.mapping_layer = nn.Linear(self.vocab_size, self.num_tokens)

        self.reprogramming_layer = ReprogrammingLayer(
            config["d_model"], config["n_heads"], self.d_ff, self.d_llm)

        # NOTE(review): the "+ 2" presumably accounts for padding applied
        # inside PatchEmbedding -- confirm against layers/Embed.py.
        self.patch_nums = int((self.seq_len - self.patch_len) / self.stride + 2)
        self.head_nf = self.d_ff * self.patch_nums

        self.output_projection = FlattenHead(
            1, self.head_nf, self.pred_len, head_dropout=config["dropout"])

        self.normalize_layers = Normalize(1, affine=False)

        self.codebook_dim = 512
        self.n_embeddings = 64
        self.beta = 0.25
        # The quantiser sees all patches of a sample flattened into one vector.
        self.vector_quantization = VectorQuantizer(
            self.n_embeddings, self.codebook_dim, self.beta, self.d_llm * self.patch_nums)

        self.vq_weight = config['vqvae_weight']
        # Fixed: the decoder output is compared element-wise with the patch
        # embeddings in `forecast`, whose width is d_model (it is the input of
        # ReprogrammingLayer.query_projection).  It used to be built with
        # out_dim=d_ff, which only worked when d_ff == d_model.
        self.decoder = Decoder(
            in_dim=self.codebook_dim, out_dim=config["d_model"], patch_num=self.patch_nums)

    @staticmethod
    def _load_local_first(loader, name, missing_msg, **kwargs):
        """Call ``loader`` on local files only; on ``EnvironmentError`` print ``missing_msg`` and allow download."""
        try:
            return loader(name, trust_remote_code=True, local_files_only=True, **kwargs)
        except EnvironmentError:
            print(missing_msg)
            return loader(name, trust_remote_code=True, local_files_only=False, **kwargs)

    def forward(self, x_enc, x_text):
        """Return the forecast for ``x_enc``; ``x_text`` is passed through to ``forecast``, which ignores it."""
        return self.forecast(x_enc, x_text)

    def forecast(self, x_enc, x_text):
        """Predict ``pred_len`` future steps and refresh ``self.loss_vae``.

        Args:
            x_enc: 2-D tensor ``(batch, seq_len)``; a trailing variable axis of
                size 1 is added internally.
            x_text: Unused.

        Returns:
            The de-normalised forecast after ``squeeze()``.

        Side effects:
            Stores the VQ loss plus reconstruction MSE on ``self.loss_vae``.
        """
        x_enc = x_enc.unsqueeze(-1)
        x_enc = self.normalize_layers(x_enc, 'norm')

        B, T, N = x_enc.size()
        # One row per (sample, variable) for the per-series statistics.
        x_enc = x_enc.permute(0, 2, 1).contiguous().reshape(B * N, T, 1)
        prompt = self._build_prompts(x_enc)

        x_enc = x_enc.reshape(B, N, T).permute(0, 2, 1).contiguous()
        prompt = self.tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=1000).input_ids
        # NOTE(review): prompt_embeddings is computed but never fed to the LLM
        # below.  Kept so behaviour is unchanged; either concatenate it with
        # the patch embeddings or drop the prompt path entirely.
        prompt_embeddings = self.llm_model.get_input_embeddings()(prompt.to(x_enc.device))  # (batch, prompt_token, dim)

        # (vocab, hidden) -> (hidden, vocab) -> mapping -> (num_tokens, hidden)
        source_embeddings = self.mapping_layer(self.word_embeddings.permute(1, 0)).permute(1, 0)

        x_enc = x_enc.permute(0, 2, 1).contiguous()
        enc_out, n_vars = self.patch_embedding(x_enc)
        x_former = enc_out  # reconstruction target for the VQ-VAE branch

        enc_out = self.reprogramming_layer(enc_out, source_embeddings, source_embeddings)

        # VQ-VAE branch: quantise the reprogrammed patches and decode them
        # back towards the original patch embeddings.
        vq_loss, z_q, _, _, _ = self.vector_quantization(enc_out)
        x_hat = self.decoder(z_q)
        loss_recon = torch.mean((x_hat - x_former) ** 2)
        self.loss_vae = vq_loss + loss_recon

        # Forecast branch: frozen LLM over the reprogrammed patches, keeping
        # only the first d_ff hidden features.
        dec_out = self.llm_model(inputs_embeds=enc_out).last_hidden_state
        dec_out = dec_out[:, :, :self.d_ff]

        dec_out = torch.reshape(
            dec_out, (-1, n_vars, dec_out.shape[-2], dec_out.shape[-1]))
        dec_out = dec_out.permute(0, 1, 3, 2).contiguous()

        dec_out = self.output_projection(dec_out[:, :, :, -self.patch_nums:])
        dec_out = dec_out.permute(0, 2, 1).contiguous()

        dec_out = self.normalize_layers(dec_out, 'denorm')
        # NOTE(review): squeeze() without a dim also removes the batch axis
        # when B == 1 (and the time axis when pred_len == 1); squeeze(-1)
        # would be safer if callers can accept the shape change.
        dec_out = dec_out.squeeze()
        return dec_out

    def _build_prompts(self, x_enc):
        """Return one statistics prompt string per row of ``x_enc`` (shape ``(rows, T, 1)``)."""
        min_values = torch.min(x_enc, dim=1)[0]
        max_values = torch.max(x_enc, dim=1)[0]
        medians = torch.median(x_enc, dim=1).values
        lags = self.calcute_lags(x_enc)
        # Sum of first differences, i.e. last value minus first value.
        trends = x_enc.diff(dim=1).sum(dim=1)

        prompts = []
        for b in range(x_enc.shape[0]):
            min_values_str = str(min_values[b].tolist()[0])
            max_values_str = str(max_values[b].tolist()[0])
            median_values_str = str(medians[b].tolist()[0])
            lags_values_str = str(lags[b].tolist())
            prompts.append(
                f"<|start_prompt|>Dataset description: {self.description}"
                f"Task description: forecast the next {self.pred_len} steps given the previous {self.seq_len} steps information; "
                "Input statistics: "
                f"min value {min_values_str}, "
                f"max value {max_values_str}, "
                f"median value {median_values_str}, "
                f"the trend of input is {'upward' if trends[b] > 0 else 'downward'}, "
                f"top 5 lags are : {lags_values_str}<|<end_prompt>|>"
            )
        return prompts

    def calculate_loss(self):
        """Return the auxiliary loss from the latest ``forecast`` call, scaled by ``vq_weight``."""
        return self.vq_weight * self.loss_vae

    def calcute_lags(self, x_enc):
        """Return the ``top_k`` lags with the largest FFT-based autocorrelation.

        Args:
            x_enc: Tensor ``(rows, T, channels)``.

        Returns:
            Integer tensor ``(rows, top_k)`` of lag indices, averaged over the
            channel axis before ranking.
        """
        # Query and key are the same series, so one FFT is enough.
        spectrum = torch.fft.rfft(x_enc.permute(0, 2, 1).contiguous(), dim=-1)
        res = spectrum * torch.conj(spectrum)
        corr = torch.fft.irfft(res, dim=-1)
        mean_value = torch.mean(corr, dim=1)
        _, lags = torch.topk(mean_value, self.top_k, dim=-1)
        return lags
  


class ReprogrammingLayer(nn.Module):
    """Multi-head cross-attention from patch embeddings onto a bank of source embeddings.

    Queries come from the target (patch) embeddings of width ``d_model``;
    keys and values come from 2-D source embeddings of width ``d_llm``.  The
    attended result is projected to width ``d_llm``.

    Args:
        d_model: Feature width of the target embeddings.
        n_heads: Number of attention heads.
        d_keys: Per-head width; falls back to ``d_model // n_heads`` when falsy.
        d_llm: Feature width of the source embeddings and of the output.
        attention_dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, d_model, n_heads, d_keys=None, d_llm=None, attention_dropout=0.1):
        super(ReprogrammingLayer, self).__init__()

        if not d_keys:
            d_keys = d_model // n_heads
        inner_dim = d_keys * n_heads

        self.query_projection = nn.Linear(d_model, inner_dim)
        self.key_projection = nn.Linear(d_llm, inner_dim)
        self.value_projection = nn.Linear(d_llm, inner_dim)
        self.out_projection = nn.Linear(inner_dim, d_llm)
        self.n_heads = n_heads
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, target_embedding, source_embedding, value_embedding):
        """Attend from ``target_embedding`` (B, L, d_model) over ``source_embedding`` (S, d_llm).

        Returns a tensor of shape (B, L, d_llm).
        """
        # Source embedding = word-token embeddings passed through a linear map.
        # Target embedding = patch embedding of the input series.
        batch, n_patches, _ = target_embedding.shape
        n_source, _ = source_embedding.shape
        heads = self.n_heads

        # Project everything into the shared latent space and split into heads.
        queries = self.query_projection(target_embedding).view(batch, n_patches, heads, -1)
        keys = self.key_projection(source_embedding).view(n_source, heads, -1)
        values = self.value_projection(value_embedding).view(n_source, heads, -1)

        # Multiply queries with keys to obtain the attention, then mix values.
        attended = self.reprogramming(queries, keys, values)

        merged_heads = attended.reshape(batch, n_patches, -1)
        return self.out_projection(merged_heads)

    def reprogramming(self, target_embedding, source_embedding, value_embedding):
        """Scaled dot-product attention with inputs already split into heads.

        ``target_embedding`` is (B, L, H, E); ``source_embedding`` and
        ``value_embedding`` are (S, H, E).  Returns (B, L, H, E).
        """
        _, _, _, head_dim = target_embedding.shape
        scale = 1. / sqrt(head_dim)

        # Query/key products for every head: (B, H, L, S).
        scores = torch.einsum("blhe,she->bhls", target_embedding, source_embedding)
        weights = torch.softmax(scale * scores, dim=-1)
        weights = self.dropout(weights)

        return torch.einsum("bhls,she->blhe", weights, value_embedding)
import torch
import torch.nn as nn
import torch.nn.functional as F

class VectorQuantizer(nn.Module):
    """Quantise one flattened vector per sample against a learned codebook.

    Each sample is flattened, projected to ``e_dim`` by a linear layer and
    replaced by its nearest codebook entry (squared Euclidean distance).
    Gradients reach the projection through a straight-through estimator.

    Args:
        n_e: Number of codebook entries.
        e_dim: Width of each codebook entry.
        beta: Weight of the second loss term (codebook entry vs. detached input).
        input_dim: Size of a sample after flattening everything but the batch axis.
    """

    def __init__(self, n_e, e_dim, beta, input_dim):
        super(VectorQuantizer, self).__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        self.beta = beta

        # Project input to embedding dimension
        self.linear = nn.Linear(input_dim, e_dim)
        # Project back to original input dimension.  Not used by forward();
        # kept so existing checkpoints (state_dict keys) stay loadable.
        self.linear2 = nn.Linear(e_dim, input_dim)

        # Codebook, initialised uniformly in [-1/n_e, 1/n_e].
        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

    def forward(self, z):
        """Quantise ``z``.

        Args:
            z: Tensor whose first axis is the batch; the remaining axes must
                flatten to ``input_dim``.  It may be non-contiguous.

        Returns:
            Tuple ``(loss, z_q, perplexity, min_encodings, min_encoding_indices)``:
            the scalar VQ loss, the quantised vectors ``(B, e_dim)`` with
            straight-through gradients, the codebook-usage perplexity of the
            batch, the one-hot assignments ``(B, n_e)`` and the chosen indices
            ``(B,)``.
        """
        B = z.shape[0]

        # reshape (not view) so non-contiguous inputs are accepted as well.
        z_merged = z.reshape(B, -1)
        z_flattened = self.linear(z_merged)  # (B, input_dim) -> (B, e_dim)

        codebook = self.embedding.weight
        # Squared distances via ||z||^2 + ||e||^2 - 2 z.e, shape (B, n_e).
        d = torch.sum(z_flattened.unsqueeze(1) ** 2, dim=-1) + \
            torch.sum(codebook ** 2, dim=-1) - \
            2 * torch.matmul(z_flattened, codebook.t())
        min_encoding_indices = torch.argmin(d, dim=-1)  # (B,)

        # One-hot assignments, allocated directly on the input's device.
        min_encodings = torch.zeros(B, self.n_e, device=z.device)
        min_encodings.scatter_(1, min_encoding_indices.unsqueeze(1), 1)

        # An embedding lookup selects the same rows as one-hot @ codebook.
        z_q = self.embedding(min_encoding_indices)

        loss = torch.mean((z_q.detach() - z_flattened) ** 2) + \
            self.beta * torch.mean((z_q - z_flattened.detach()) ** 2)

        # Straight-through estimator: forward value is z_q, gradient flows to
        # z_flattened as if quantisation were the identity.
        z_q = z_flattened + (z_q - z_flattened).detach()

        # Perplexity of the batch's code usage.
        e_mean = torch.mean(min_encodings, dim=0)
        perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))

        return loss, z_q, perplexity, min_encodings, min_encoding_indices
class Decoder(nn.Module):
    """MLP that expands one code vector per sample into a ``(patch_num, out_dim)`` grid.

    Args:
        in_dim: Width of the input vectors.
        out_dim: Feature width of each reconstructed patch.
        patch_num: Number of patches reconstructed per sample.
    """

    def __init__(self, in_dim, out_dim, patch_num):
        super(Decoder, self).__init__()
        self.P = patch_num
        self.D = out_dim

        flat_size = self.P * self.D
        layers = [
            nn.Linear(in_dim, 1024),
            nn.ReLU(),
            nn.Linear(1024, flat_size * 2),
            nn.ReLU(),
            nn.Linear(flat_size * 2, flat_size),
        ]
        self.net = nn.Sequential(*layers)
        self.patch_num = patch_num

    def forward(self, x):
        """Map ``x`` of shape (B, in_dim) to a tensor of shape (B, patch_num, out_dim)."""
        batch = x.size(0)
        flat = self.net(x)
        # Unflatten the last axis into (patches, features).
        return flat.view(batch, self.P, self.D)
