from math import sqrt

import torch
import torch.nn as nn

from transformers import LlamaConfig
from layers.Embed import PatchEmbedding
from uni2ts.model.moirai import MoiraiForecast, MoiraiModule
from huggingface_hub import hf_hub_download
import torch.nn.functional as F
from torch.autograd import grad

import transformers

from layers.StandardNorm import Normalize

# Restrict the transformers library's logger to error-level messages so that
# model/tokenizer loading below does not flood the output with warnings.
transformers.logging.set_verbosity_error()

class FlattenHead(nn.Module):
    """Output head: strided 1-D convolution, flatten, then linear projection.

    The head compresses the hidden axis of the incoming tensor with a
    non-overlapping ``Conv1d`` (kernel 32, stride 32), applies batch norm and
    GELU, flattens the last two axes and maps them to ``target_window``
    values with a linear layer followed by dropout.

    Args:
        n_vars: Number of variables; stored on the instance but not used by
            ``forward``.
        nf: Input feature count of the final linear layer. It must equal
            ``(D // 32) * L`` for an input of shape ``(B, T, D, L)``.
        target_window: Number of output values per series.
        head_dropout: Dropout probability applied to the output.
    """

    def __init__(self, n_vars, nf, target_window, head_dropout=0):
        super().__init__()
        self.n_vars = n_vars
        self.conv1d = nn.Conv1d(1, 1, kernel_size=32, stride=32, padding=0)
        self.bn = nn.BatchNorm1d(1)
        self.gelu = nn.GELU()
        self.flatten = nn.Flatten(start_dim=-2)
        self.linear = nn.Linear(nf, target_window)
        self.dropout = nn.Dropout(head_dropout)

    def forward(self, x):
        """Project a 4-D tensor ``(B, T, D, L)`` to ``(B, T, target_window)``.

        The hidden width ``D`` is taken from the input instead of being
        hard-coded to 4096, and the compressed width is inferred from the
        convolution output instead of being hard-coded to 128. For
        ``D == 4096`` the result is identical to the previous behaviour.
        """
        B, T, D, L = x.shape
        # NOTE(review): the reshape cuts the row-major flattening of
        # (B, T, D, L) into chunks of D consecutive values, so for L > 1 each
        # convolution input mixes positions along D and L. Kept unchanged to
        # preserve existing behaviour - confirm this is intended.
        x = self.conv1d(x.reshape(-1, 1, D))
        x = self.bn(x)
        x = self.gelu(x)
        # -1 resolves to the convolution output length (128 when D == 4096).
        x = self.flatten(x.reshape(B, T, -1, L))
        x = self.linear(x)
        x = self.dropout(x)
        return x

class Model(nn.Module):
    """LLM-backed forecaster combining a frozen LLaMA-3 and a frozen Moirai encoder.

    The input series is patch-embedded, passed through the Moirai encoder,
    aligned to the LLaMA embedding space by cross-attention over a small set
    of word embeddings, sandwiched between a textual prompt and a prediction
    header, and run through LLaMA. The last ``patch_nums`` hidden states are
    projected to ``pred_len`` values by ``FlattenHead``.

    In evaluation mode the model additionally returns a saliency map
    (gradients of the projection output w.r.t. its input), the per-token
    entropy of the LM-head distribution, and their product as a
    "calibrated confidence".

    Args:
        configs: Namespace-like object providing ``task_name``, ``pred_len``,
            ``seq_len``, ``d_ff``, ``patch_len``, ``stride``, ``llm_layers``,
            ``dropout``, ``n_heads`` and ``enc_in``.

    Raises:
        NotImplementedError: If ``configs.task_name`` is neither
            ``'long_term_forecast'`` nor ``'short_term_forecast'``.
    """

    def __init__(self, configs):
        super().__init__()
        self.task_name = configs.task_name
        self.pred_len = configs.pred_len
        self.seq_len = configs.seq_len
        self.d_ff = configs.d_ff
        self.top_k = 5
        self.d_llm = 4096
        self.patch_len = configs.patch_len
        self.stride = configs.stride

        # NOTE(review): this config is populated but never handed to the
        # pipeline created below, so ``configs.llm_layers`` does not affect
        # the model that is actually loaded - confirm whether that is intended.
        self.llama_config = LlamaConfig.from_pretrained('meta-llama/Meta-Llama-3-8B')
        self.llama_config.num_hidden_layers = configs.llm_layers
        self.llama_config.output_attentions = True
        self.llama_config.output_hidden_states = True

        self.llama_pipeline = transformers.pipeline(
            "text-generation",
            model='meta-llama/Meta-Llama-3-8B',
            model_kwargs={"torch_dtype": torch.bfloat16},
        )
        # Keep the transformer trunk and the LM head separately: the trunk
        # yields hidden states, the head is only used for entropy in eval mode.
        self.llama = self.llama_pipeline.model.model
        self.lm_head = self.llama_pipeline.model.lm_head

        self.tokenizer = self.llama_pipeline.tokenizer

        self.moirai = MoiraiForecast(
            module=MoiraiModule.from_pretrained(
                "Salesforce/moirai-1.0-R-Large",
            ),
            prediction_length=configs.pred_len,
            context_length=configs.seq_len,
            patch_size=configs.patch_len,
            num_samples=100,
            target_dim=1,
            feat_dynamic_real_dim=0,
            past_feat_dynamic_real_dim=0,
        )

        # Batched tokenization needs a pad token; reuse EOS when available.
        if self.tokenizer.eos_token:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        else:
            pad_token = '[PAD]'
            self.tokenizer.add_special_tokens({'pad_token': pad_token})
            self.tokenizer.pad_token = pad_token

        # Both pretrained backbones stay frozen.
        for param in self.llama_pipeline.model.parameters():
            param.requires_grad = False

        for param in self.moirai.parameters():
            param.requires_grad = False

        self.patch_embedding = PatchEmbedding(
            1024, self.patch_len, self.stride, configs.dropout)

        prompts = ['pulse', 'peak', 'emission', 'leading edge', 'trailing edge', 'increasing', 'decreasing', 'constant', 'noise', 'oscillating']
        # Only the token id at position 1 of each tokenized prompt is kept
        # (presumably position 0 is the BOS token, so multi-word prompts
        # contribute their first word piece only - TODO confirm).
        self.words = [
            self.tokenizer(word, return_tensors="pt", padding=True, truncation=True)['input_ids'].squeeze(0)[1]
            for word in prompts
        ]
        self.word_embeddings = self.llama.get_input_embeddings()(torch.tensor(self.words).to(self.llama.device)).to(torch.bfloat16).to(self.llama.device)

        self.cross_attention_layer = CrossAttentionLayer(1024, configs.n_heads, self.d_ff, self.d_llm)

        self.patch_nums = int((configs.seq_len - self.patch_len) / self.stride + 2)
        self.head_nf = self.d_ff * self.patch_nums

        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            self.output_projection = FlattenHead(configs.enc_in, self.head_nf, self.pred_len,
                                                 head_dropout=configs.dropout)
        else:
            raise NotImplementedError

        self.normalize_layers = Normalize(configs.enc_in, affine=False)

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        """Run the forecast and keep the last ``pred_len`` steps.

        ``x_mark_enc`` and ``x_mark_dec`` are forwarded to ``forecast`` as
        ``phase_plate`` and ``target_size`` respectively. ``mask`` is unused.

        Returns:
            In training mode, the forecast tensor sliced to the last
            ``pred_len`` steps along dim 1. In eval mode, a 4-tuple
            ``(forecast, calibrated_confidence, saliency, entropy)``.
            ``None`` if the task name is not a forecasting task.
        """
        if self.task_name not in ('long_term_forecast', 'short_term_forecast'):
            return None

        if self.training:
            dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
            return dec_out[:, -self.pred_len:, :]

        dec_out, carli_conf, saliency, entropy = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
        return dec_out[:, -self.pred_len:, :], carli_conf[:, -self.pred_len:], saliency, entropy

    def forecast(self, x_enc, phase_plate, x_dec, target_size):
        """Build the prompt, run the backbones and project to the horizon.

        Args:
            x_enc: 3-D input tensor unpacked as ``(B, T, N)``.
            phase_plate: Indexable per-sample value inserted into the prompt.
            x_dec: Unused.
            target_size: Indexable per-sample value inserted into the prompt.

        Returns:
            Training mode: the de-normalized projection output.
            Eval mode: ``(projection, calibrated_confidence, saliency,
            entropy_bank)``.

        NOTE(review): the eval path calls ``get_saliency_map``, which uses
        ``torch.autograd.grad``; it therefore presumably needs gradient
        tracking to be enabled by the caller - confirm against the
        evaluation loop.
        """
        evaluating = not self.training

        x_enc = self.normalize_layers(x_enc, 'norm')

        B, T, N = x_enc.size()
        # One univariate series per row: (B * N, T, 1).
        x_enc = x_enc.permute(0, 2, 1).contiguous().reshape(B * N, T, 1)

        min_values = torch.min(x_enc, dim=1)[0]
        max_values = torch.max(x_enc, dim=1)[0]
        medians = torch.median(x_enc, dim=1).values
        lags = self.calcute_lags(x_enc)

        # Convert the statistics to Python lists once instead of once per
        # sample inside the loop.
        min_list = min_values.tolist()
        max_list = max_values.tolist()
        median_list = medians.tolist()
        lags_list = lags.tolist()

        prompt = []
        # NOTE(review): phase_plate/target_size are indexed with the flattened
        # row index (0 .. B*N-1); this lines up with the batch only when
        # N == 1 - confirm against the caller.
        for b in range(x_enc.shape[0]):
            min_values_str = str(min_list[b][0])
            max_values_str = str(max_list[b][0])
            median_values_str = str(median_list[b][0])
            lags_values_str = str(lags_list[b])
            prompt_ = (
                f"<|begin_of_text|><|start_header_id|>Dataset description<|end_header_id|>\n"
                f"The Inertial Confinement Fusion (ICF) is an essential method in producing sustainable energy. ICF aims at achieving nuclear fusion by compressing fuel pellets to high densities and temperatures. However, the ICF process is disturbed by hot electrons, making it crucial to study the generation of hot electrons. The hot electrons produced by Two-Plasmon Decay (TPD) in direct drive ICF can thwart ignition by preheating the deuterium and tritium fuel. The number of hot electrons is mainly related to the following variables: phase plate (SG4, SG5, and SG5-650), target size (the radius of the target pellet in micrometers) and the intensity of the input laser.<|eot_id|><|start_header_id|>Task description<|end_header_id|>\n"
                f"For the laser intensity given for each time step ({str(self.seq_len)} steps in total), forecast the hard X-ray  (HXR) energy emitted by hot electrons for each time step ({str(self.pred_len)} steps in total). Also, take the phase plate and target size into account as they also affect the hot electron production. The phase plate and target size remain unchanged for all time steps.<|eot_id|><|start_header_id|>Input statistics<|end_header_id|>\n"
                f"min value: {min_values_str}, max value: {max_values_str}, median value: {median_values_str}, top 5 lags: {lags_values_str}, Phase Plate: {phase_plate[b]}, Target Size: {str(target_size[b])}<|eot_id|><|start_header_id|>Input Data<|end_header_id|>\n"
            )
            prompt.append(prompt_)

        # The prediction header is the same text for every row.
        head_ = "<|eot_id|><|start_header_id|>Prediction<|end_header_id|>\n"
        prompt_head = [head_] * x_enc.shape[0]

        x_enc = x_enc.reshape(B, N, T).permute(0, 2, 1).contiguous()

        prompt = self.tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).input_ids
        heads = self.tokenizer(prompt_head, return_tensors="pt", padding=True, truncation=True).input_ids
        prompt_embeddings = self.llama.get_input_embeddings()(prompt.to(x_enc.device))
        heads_embeddings = self.llama.get_input_embeddings()(heads.to(x_enc.device))

        source_embeddings = self.word_embeddings.to(x_enc.device)

        x_enc = x_enc.permute(0, 2, 1).contiguous()
        enc_out, n_vars = self.patch_embedding(x_enc.to(torch.bfloat16))
        enc_out = self.moirai.module.encoder(enc_out)
        enc_out = self.cross_attention_layer(enc_out, source_embeddings, source_embeddings)

        # Sequence layout fed to LLaMA: [prompt | series patches | prediction header].
        llama_enc_out = torch.cat([prompt_embeddings, enc_out, heads_embeddings], dim=1)
        dec_out = self.llama(inputs_embeds=llama_enc_out).last_hidden_state

        if evaluating:
            if self.llama.config.pretraining_tp > 1:
                lm_head_slices = self.lm_head.weight.split(self.llama.vocab_size // self.llama.config.pretraining_tp, dim=0)
                logits = [F.linear(dec_out, lm_head_slices[i]) for i in range(self.llama.config.pretraining_tp)]
                logits = torch.cat(logits, dim=-1)
            else:
                logits = self.lm_head(dec_out)
            logits = logits.detach().cpu().float()
            entropy_map = get_uncertainty_per_token(logits)

        proj_in = torch.reshape(
            dec_out, (-1, n_vars, dec_out.shape[-2], dec_out.shape[-1]))
        proj_in = proj_in.permute(0, 1, 3, 2).contiguous()  # [B, n_vars, hidden, sequence_len]
        # Only the trailing patch_nums sequence positions feed the head.
        proj_in_2 = proj_in[:, :, :, -self.patch_nums:]

        proj_out = self.output_projection(proj_in_2)

        if evaluating:
            # Must run before proj_out is permuted/denormalized: the saliency
            # is taken w.r.t. the raw head output.
            saliency = get_saliency_map(proj_in_2, proj_out)
            saliency_map = F.softmax(saliency/0.02, dim=1).detach().cpu().float()

        proj_out = proj_out.permute(0, 2, 1).contiguous()

        proj_out = self.normalize_layers(proj_out, 'denorm')

        if not evaluating:
            return proj_out

        entropy_bank = entropy_map[:, -self.patch_nums:]
        entropy_bank = entropy_bank.to(saliency_map.dtype)
        # Per-sample matrix-vector product of the saliency weights with the
        # token entropies. The entropy row is 1-D, so no transpose is needed
        # (``.T`` on non-2-D tensors is deprecated in PyTorch).
        calibrated_confidence = torch.stack(
            [saliency_map[i, :, :] @ entropy_bank[i, :] for i in range(saliency_map.shape[0])],
            dim=0,
        )
        return proj_out, calibrated_confidence, saliency, entropy_bank

    def calcute_lags(self, x_enc):
        """Return the indices of the ``top_k`` largest autocorrelation lags.

        The autocorrelation is computed through the FFT (spectrum times its
        own conjugate, then inverse transform), averaged over dim 1 of the
        permuted tensor, and the ``self.top_k`` highest-valued lag indices
        are returned per row.
        """
        # The signal is correlated with itself, so one forward FFT suffices.
        spectrum = torch.fft.rfft(x_enc.permute(0, 2, 1).contiguous(), dim=-1)
        corr = torch.fft.irfft(spectrum * torch.conj(spectrum), dim=-1)
        mean_value = torch.mean(corr, dim=1)
        _, lags = torch.topk(mean_value, self.top_k, dim=-1)
        return lags

class CrossAttentionLayer(nn.Module):
    """Multi-head cross-attention from a batched query onto a shared memory.

    Queries come from ``target_embedding`` (width ``d_model``); keys and
    values come from 2-D source tensors (width ``d_llm``) that are shared by
    every batch element. The attention output is concatenated with the
    original queries and projected to ``d_llm``.

    Args:
        d_model: Feature width of the query input.
        n_heads: Number of attention heads.
        d_keys: Per-head width; defaults to ``d_model // n_heads``.
        d_llm: Feature width of the key/value input and of the output.
        attention_dropout: Dropout probability on the attention weights.
    """

    def __init__(self, d_model, n_heads, d_keys=None, d_llm=None, attention_dropout=0.1):
        super().__init__()

        d_keys = d_keys or (d_model // n_heads)
        inner_dim = d_keys * n_heads

        self.query_projection = nn.Linear(d_model, inner_dim)
        self.key_projection = nn.Linear(d_llm, inner_dim)
        self.value_projection = nn.Linear(d_llm, inner_dim)
        # forward() concatenates the raw queries (width d_model) with the
        # attention output (width inner_dim), so the input width must follow
        # d_model rather than a fixed 1024.
        self.out_projection = nn.Linear(inner_dim + d_model, d_llm)
        self.n_heads = n_heads
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, target_embedding, source_embedding, value_embedding):
        """Attend from ``target_embedding`` to the source/value memory.

        Args:
            target_embedding: 3-D tensor unpacked as ``(B, L, d_model)``.
            source_embedding: 2-D tensor unpacked as ``(S, d_llm)``.
            value_embedding: Tensor viewed as ``(S, H, -1)`` after projection.

        Returns:
            Output of ``out_projection``, with last dimension ``d_llm``.
        """
        B, L, _ = target_embedding.shape
        S, _ = source_embedding.shape
        H = self.n_heads

        queries = self.query_projection(target_embedding).view(B, L, H, -1)
        keys = self.key_projection(source_embedding).view(S, H, -1)
        values = self.value_projection(value_embedding).view(S, H, -1)

        attended = self.cross_attention(queries, keys, values).reshape(B, L, -1)

        # Skip-style concatenation keeps the un-attended query features.
        combined = torch.cat([target_embedding, attended], dim=-1)

        return self.out_projection(combined)

    def cross_attention(self, target_embedding, source_embedding, value_embedding):
        """Scaled dot-product attention with dropout on the weights.

        ``target_embedding`` is ``(B, L, H, E)``; ``source_embedding`` and
        ``value_embedding`` are ``(S, H, E)``. Returns ``(B, L, H, E)``.
        """
        B, L, H, E = target_embedding.shape

        scale = 1. / sqrt(E)

        scores = torch.einsum("blhe,she->bhls", target_embedding, source_embedding)

        A = self.dropout(torch.softmax(scale * scores, dim=-1))
        return torch.einsum("bhls,she->blhe", A, value_embedding)


def get_uncertainty_per_token(logits, temperature=0.02, eps=1e-10):
    """Return the Shannon entropy of the temperature-scaled softmax of *logits*.

    Args:
        logits: Tensor whose last dimension indexes the classes.
        temperature: Positive softmax temperature; the logits are divided by
            it before the softmax. Defaults to the previously hard-coded 0.02.
        eps: Lower clamp applied to the probabilities so that ``log`` never
            receives zero. Defaults to the previously hard-coded 1e-10.

    Returns:
        Tensor with the last dimension of *logits* reduced away, holding the
        entropy (in nats) of each distribution.

    Raises:
        ValueError: If *temperature* is not strictly positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    probabilities = F.softmax(logits / temperature, dim=-1)
    # Clamp instead of masking: keeps 0 * log(0) from producing NaN.
    probabilities = torch.clamp(probabilities, min=eps)

    return -torch.sum(probabilities * torch.log(probabilities), dim=-1)

def get_saliency_map(input, output):
    """Gradient-based saliency of *output* with respect to *input*.

    For every batch index ``b`` and every index ``t`` along dim 2 of
    *output*, the gradient of ``output[b, :, t]`` w.r.t. *input* is computed
    and its ``b``-th slice kept. The slices are concatenated along their
    first axis per sample, stacked over the batch, and summed over dim 2.

    The autograd graph is retained between calls because the same graph is
    differentiated once per ``(b, t)`` pair.
    """
    n_batch = output.shape[0]
    n_steps = output.shape[2]

    per_sample = []
    for b in range(n_batch):
        step_grads = [
            grad(output[b, :, t], input, retain_graph=True)[0][b, :, :, :]
            for t in range(n_steps)
        ]
        per_sample.append(torch.cat(step_grads, dim=0))

    stacked = torch.stack(per_sample, dim=0)
    return stacked.sum(dim=2)
