import shelve
import sys
from pathlib import Path
import hashlib
import datetime

import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from data_provider.describe_generator import describe
from transformers import AutoModelForCausalLM, AutoTokenizer
from layers.SelfAttention_Family import AttentionLayer, FullAttention


class PositionalEmbedding(nn.Module):
    """Fixed sinusoidal positional encoding.

    A ``(1, max_len, d_model)`` table is computed once at construction and
    stored as the non-trainable buffer ``pe``. Even feature columns hold
    sines, odd columns hold cosines.

    Args:
        d_model: Size of the feature dimension (odd values are supported).
        max_len: Number of positions to precompute.
    """

    def __init__(self, d_model, max_len=5000):
        super(PositionalEmbedding, self).__init__()
        # Compute the positional encodings once in log space.
        position = torch.arange(0, max_len).float().unsqueeze(1)
        div_term = (
            torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
        ).exp()
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model).float()
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one cosine column fewer than sine
        # columns, so trim the angles to the number of odd-indexed slots.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Buffers are never handed to the optimizer, so no requires_grad
        # bookkeeping is needed on the table itself.
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        """Return the first ``x.size(1)`` rows of the table, shape ``(1, x.size(1), d_model)``.

        Only the size of dimension 1 of ``x`` is used; its values are ignored.
        """
        return self.pe[:, : x.size(1)]


class TokenEmbedding(nn.Module):
    """Project input channels to ``d_model`` with a circular 1-D convolution.

    Args:
        c_in: Number of input channels (size of the last input dimension).
        d_model: Number of output channels.
    """

    def __init__(self, c_in, d_model):
        super(TokenEmbedding, self).__init__()
        # The amount of circular padding depends on the torch release: 1 from
        # torch 1.5 onwards, 2 before that.
        padding = 1 if self._torch_at_least(1, 5) else 2
        self.tokenConv = nn.Conv1d(
            in_channels=c_in,
            out_channels=d_model,
            kernel_size=3,
            padding=padding,
            padding_mode="circular",
            bias=False,
        )
        nn.init.kaiming_normal_(
            self.tokenConv.weight, mode="fan_in", nonlinearity="leaky_relu"
        )

    @staticmethod
    def _torch_at_least(major, minor):
        """Return True when the installed torch is at least ``major.minor``.

        The version is compared numerically so the result does not depend on
        lexical string ordering (under which "1.10" sorts before "1.5").
        """
        parts = torch.__version__.split("+")[0].split(".")
        try:
            current = (int(parts[0]), int(parts[1]))
        except (IndexError, ValueError):
            # Unrecognised version string: assume a current release.
            return True
        return current >= (major, minor)

    def forward(self, x):
        """Embed ``x`` of shape ``(batch, length, c_in)``.

        The convolution runs over the length axis; the result is transposed
        back so that channels are last again.
        """
        x = self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2)
        return x


class FixedEmbedding(nn.Module):
    """Embedding lookup whose table is a frozen sinusoidal encoding.

    Row ``i`` of the table holds sines in the even columns and cosines in the
    odd columns for index ``i``. The table is stored as a parameter with
    ``requires_grad=False``.

    Args:
        c_in: Number of distinct indices (rows of the table).
        d_model: Embedding size (odd values are supported).
    """

    def __init__(self, c_in, d_model):
        super(FixedEmbedding, self).__init__()

        position = torch.arange(0, c_in).float().unsqueeze(1)
        div_term = (
            torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
        ).exp()
        angles = position * div_term  # (c_in, ceil(d_model / 2))

        w = torch.zeros(c_in, d_model).float()
        w[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one cosine column fewer than sine
        # columns, so trim the angles to the number of odd-indexed slots.
        w[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        self.emb = nn.Embedding(c_in, d_model)
        # Replace the randomly initialised weights with the frozen table.
        self.emb.weight = nn.Parameter(w, requires_grad=False)

    def forward(self, x):
        """Look up integer indices ``x``; the result is detached from autograd."""
        return self.emb(x).detach()


class TemporalEmbedding(nn.Module):
    """Sum of per-field embeddings for integer calendar features.

    The last axis of the input is read as ``[month, day, weekday, hour]``,
    plus a fifth ``minute`` column that is only used when ``freq == "t"``.

    Args:
        d_model: Embedding size of every field.
        embed_type: ``"fixed"`` selects ``FixedEmbedding``; any other value
            selects a learnable ``nn.Embedding``.
        freq: When equal to ``"t"`` a minute embedding is created as well.
    """

    def __init__(self, d_model, embed_type="fixed", freq="h"):
        super().__init__()

        embed_cls = FixedEmbedding if embed_type == "fixed" else nn.Embedding

        # Table sizes: minute 4, hour 24, weekday 7, day 32, month 13.
        # The creation order below is kept stable so that module registration
        # (and random initialisation) order does not change.
        if freq == "t":
            self.minute_embed = embed_cls(4, d_model)
        self.hour_embed = embed_cls(24, d_model)
        self.weekday_embed = embed_cls(7, d_model)
        self.day_embed = embed_cls(32, d_model)
        self.month_embed = embed_cls(13, d_model)

    def forward(self, x):
        """Embed ``x`` (cast to long) and return the sum over all fields."""
        marks = x.long()

        # The minute term degenerates to a scalar zero when no minute table
        # was created.
        if hasattr(self, "minute_embed"):
            minute_part = self.minute_embed(marks[:, :, 4])
        else:
            minute_part = 0.0

        month_part = self.month_embed(marks[:, :, 0])
        day_part = self.day_embed(marks[:, :, 1])
        weekday_part = self.weekday_embed(marks[:, :, 2])
        hour_part = self.hour_embed(marks[:, :, 3])

        return hour_part + weekday_part + day_part + month_part + minute_part


class TimeFeatureEmbedding(nn.Module):
    """Bias-free linear projection of time features to ``d_model``.

    Args:
        d_model: Output size of the projection.
        embed_type: Accepted for interface compatibility; not used here.
        freq: Frequency code that determines the expected number of input
            features. An unknown code raises ``KeyError``.
    """

    def __init__(self, d_model, embed_type="timeF", freq="h"):
        super().__init__()

        # Number of input features expected for each frequency code.
        features_per_freq = {
            "h": 4,
            "t": 5,
            "s": 6,
            "m": 1,
            "a": 1,
            "w": 2,
            "d": 3,
            "b": 3,
        }
        self.embed = nn.Linear(features_per_freq[freq], d_model, bias=False)

    def forward(self, x):
        """Project ``x`` along its last dimension."""
        projected = self.embed(x)
        return projected


class DataEmbedding(nn.Module):
    """Value + temporal + positional embedding followed by dropout.

    Args:
        c_in: Number of input channels for the value embedding.
        d_model: Embedding size.
        embed_type: ``"timeF"`` selects ``TimeFeatureEmbedding`` for the time
            marks; any other value selects ``TemporalEmbedding``.
        freq: Frequency code forwarded to the temporal embedding.
        dropout: Dropout probability applied to the summed embedding.
    """

    def __init__(self, c_in, d_model, embed_type="fixed", freq="h", dropout=0.1):
        super().__init__()

        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
        self.position_embedding = PositionalEmbedding(d_model=d_model)

        if embed_type == "timeF":
            temporal = TimeFeatureEmbedding(
                d_model=d_model, embed_type=embed_type, freq=freq
            )
        else:
            temporal = TemporalEmbedding(
                d_model=d_model, embed_type=embed_type, freq=freq
            )
        self.temporal_embedding = temporal

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        """Embed ``x``; the temporal term is skipped when ``x_mark`` is None."""
        embedded = self.value_embedding(x)
        if x_mark is not None:
            embedded = embedded + self.temporal_embedding(x_mark)
        embedded = embedded + self.position_embedding(x)
        return self.dropout(embedded)


class DataEmbedding_inverted(nn.Module):
    """Embed each variate's whole series as one token via a linear layer.

    Args:
        c_in: Size of the time axis that the linear layer consumes.
        d_model: Embedding size.
        embed_type: Accepted for interface compatibility; not used here.
        freq: Accepted for interface compatibility; not used here.
        dropout: Dropout probability applied to the embedding.
    """

    def __init__(self, c_in, d_model, embed_type="fixed", freq="h", dropout=0.1):
        super().__init__()
        self.value_embedding = nn.Linear(c_in, d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        """Embed ``x`` of shape [Batch, Time, Variate].

        When ``x_mark`` is given, its columns are appended as additional
        tokens along the variate axis before the projection.
        """
        # [Batch, Time, Variate] -> [Batch, Variate, Time]
        tokens = x.permute(0, 2, 1)
        if x_mark is not None:
            tokens = torch.cat([tokens, x_mark.permute(0, 2, 1)], 1)
        # [Batch, tokens, Time] -> [Batch, tokens, d_model]
        embedded = self.value_embedding(tokens)
        return self.dropout(embedded)


class DataEmbedding_wo_pos(nn.Module):
    """Value + temporal embedding followed by dropout, without a positional term.

    A ``PositionalEmbedding`` is still constructed and registered as
    ``position_embedding``, but ``forward`` never adds it.

    Args:
        c_in: Number of input channels for the value embedding.
        d_model: Embedding size.
        embed_type: ``"timeF"`` selects ``TimeFeatureEmbedding`` for the time
            marks; any other value selects ``TemporalEmbedding``.
        freq: Frequency code forwarded to the temporal embedding.
        dropout: Dropout probability applied to the summed embedding.
    """

    def __init__(self, c_in, d_model, embed_type="fixed", freq="h", dropout=0.1):
        super().__init__()

        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
        self.position_embedding = PositionalEmbedding(d_model=d_model)

        if embed_type == "timeF":
            temporal = TimeFeatureEmbedding(
                d_model=d_model, embed_type=embed_type, freq=freq
            )
        else:
            temporal = TemporalEmbedding(
                d_model=d_model, embed_type=embed_type, freq=freq
            )
        self.temporal_embedding = temporal

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        """Embed ``x``; the temporal term is skipped when ``x_mark`` is None."""
        embedded = self.value_embedding(x)
        if x_mark is not None:
            embedded = embedded + self.temporal_embedding(x_mark)
        return self.dropout(embedded)


class NumEmbedding(nn.Module):
    """Patch a series and embed each patch numerically.

    Args:
        d_model: Embedding size.
        patch_len: Number of time steps per patch.
        stride: Step between the starts of consecutive patches.
        padding: Number of replicated values appended to the end of the series.
        dropout: Dropout probability applied to the embedding.
    """

    def __init__(self, d_model, patch_len, stride, padding, dropout):
        super().__init__()
        # Patching configuration; the series is right-padded by replication.
        self.patch_len = patch_len
        self.stride = stride
        self.padding_patch_layer = nn.ReplicationPad1d((0, padding))

        # Bias-free projection of a raw patch into the model dimension.
        self.value_embedding = nn.Linear(patch_len, d_model, bias=False)

        # Sinusoidal position of each patch within its series.
        self.position_embedding = PositionalEmbedding(d_model)

        # Residual dropout.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return ``(embedding, n_vars)`` for ``x`` of shape ``(B, N, L)``.

        The embedding has shape ``(B * N, num_patches, d_model)`` with
        ``num_patches = (L + padding - patch_len) // stride + 1``;
        ``n_vars`` is ``N``.
        """
        n_vars = x.shape[1]

        padded = self.padding_patch_layer(x)  # (B, N, L + padding)
        patches = padded.unfold(
            dimension=-1, size=self.patch_len, step=self.stride
        )  # (B, N, num_patches, patch_len)

        # Fold the variable axis into the batch axis.
        batch, variables, num_patches, patch_len = (
            patches.shape[0],
            patches.shape[1],
            patches.shape[2],
            patches.shape[3],
        )
        patches = patches.reshape(batch * variables, num_patches, patch_len)

        values = self.value_embedding(patches)  # (B*N, num_patches, d_model)
        positions = self.position_embedding(patches)  # (1, num_patches, d_model)
        return self.dropout(values + positions), n_vars


class TextEmbedding(nn.Module):
    """Embed numeric patches through text descriptions and a causal LLM.

    Each patch is turned into a text description with ``describe``, the
    description is run through the LLM under ``torch.no_grad()``, and the last
    hidden layer is mean-pooled over non-padding tokens. Pooled vectors are
    cached on disk in a ``shelve`` database keyed by the SHA-256 of the
    description, so each distinct description is encoded only once.

    NOTE(review): an earlier docstring mentioned LoRA fine-tuning, but nothing
    in this class adds adapters, and the LLM is only ever called without
    gradients.

    NOTE: ``shelve`` stores values with ``pickle``; only point ``cache_dir``
    at a location whose contents you trust.

    Args:
        d_model: Accepted for interface compatibility; not used here.
        data_name: Dataset name, forwarded to ``describe`` and used in the
            default cache directory name.
        seq_len: Sequence length, used in the default cache directory name.
        cache_dir: Directory for the shelve database. Defaults to
            ``D:\\db_inst\\<data_name>_<seq_len>``.
        device: Concrete torch device string used both to load the LLM and to
            place tokenizer outputs and cached tensors.
    """

    def __init__(self, d_model, data_name, seq_len, cache_dir=None, device="cuda"):
        super(TextEmbedding, self).__init__()
        # LLM
        self.data_name = data_name
        self._device = device
        cache_name = data_name + "_" + str(seq_len)
        model_name = "meta-llama/Llama-3.2-1B-Instruct"
        self.llm = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map=device,
        )
        self.llm.config.use_cache = False
        # Length a pooled vector must have for a cache entry to be accepted.
        self._hidden_size = self.llm.config.hidden_size

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name, trust_remote_code=True
        )
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "right"

        if cache_dir is None:
            cache_dir = f"D:\\db_inst\\{cache_name}"
        Path(cache_dir).mkdir(parents=True, exist_ok=True)
        cache_file = os.path.join(cache_dir, "embedding_cache.db")
        self.cache = shelve.open(cache_file, flag="c")
        print(f"Using cache file at: {os.path.abspath(cache_file)}")

    def close(self):
        """Flush and close the on-disk cache.

        The module must not be called again after closing.
        """
        cache = self.__dict__.get("cache")
        if cache is not None:
            cache.close()

    def forward(self, x):
        """Return pooled LLM embeddings for every patch in ``x``.

        ``x`` is indexed as ``x[series][patch]``; each patch is flattened
        before being described. The result has shape
        ``(x.shape[0], x.shape[1], hidden_size)``.
        """
        return torch.stack([self._embed_series(series) for series in x])

    def _embed_series(self, series):
        """Return a ``(num_patches, hidden_size)`` tensor for one series."""
        texts = [
            describe(
                np.array(patch.flatten().cpu()),
                self.data_name,
                self.tokenizer,
            )
            for patch in series
        ]

        pooled = [None] * len(texts)
        # key -> (description, positions that need it); insertion-ordered so
        # that it lines up with the rows returned by the LLM below.
        missing = {}
        for pos, text in enumerate(texts):
            key = hashlib.sha256(text.encode()).hexdigest()
            if key in missing:
                # Same description seen earlier in this series: encode once.
                missing[key][1].append(pos)
                continue
            cached = self._load_cached(key)
            if cached is None:
                missing[key] = (text, [pos])
            else:
                pooled[pos] = cached.to(self._device)

        if missing:
            fresh = self._encode([text for text, _ in missing.values()])
            for (key, (_, positions)), vector in zip(missing.items(), fresh):
                # Store a compact CPU copy rather than a view into the batch.
                self.cache[key] = vector.detach().cpu().clone()
                for pos in positions:
                    pooled[pos] = vector

        return torch.stack(pooled)

    def _load_cached(self, key):
        """Return the cached tensor for ``key``, or None if absent or unusable."""
        try:
            cached = self.cache[key]
            if cached.shape[0] != self._hidden_size:
                # Entry has the wrong size: recompute it.
                print("resave")
                return None
            return cached
        except KeyError:
            return None
        except Exception as e:
            # Best effort: a corrupt entry is regenerated instead of failing.
            print(
                f"Warning: Could not load cache for key {key}. Regenerating. Error: {e}"
            )
            return None

    def _encode(self, texts):
        """Run ``texts`` through the LLM and mean-pool the last hidden layer."""
        inputs = self.tokenizer(
            texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
        ).to(self._device)

        with torch.no_grad():
            outputs = self.llm(**inputs, output_hidden_states=True)

        last_hidden_state = outputs.hidden_states[-1]
        # Mask out padding tokens so they do not contribute to the mean.
        attn_mask = inputs["attention_mask"].unsqueeze(-1)
        return (last_hidden_state * attn_mask).sum(1) / attn_mask.sum(1).clamp(min=1)


class GatingNetwork(nn.Module):
    """
    Two-layer MLP that turns an input summary into expert weights.

    The output passes through a softmax over the last dimension, so the
    weights of the ``num_experts`` experts sum to one.
    """

    def __init__(self, input_dim, num_experts=2, hidden_dim_ratio=0.5):
        super(GatingNetwork, self).__init__()
        # Hidden width is a (truncated) fraction of the input width.
        width = int(input_dim * hidden_dim_ratio)
        layers = [
            nn.Linear(input_dim, width),
            nn.ReLU(),
            nn.Linear(width, num_experts),
            nn.Softmax(dim=-1),
        ]
        self.gate = nn.Sequential(*layers)

    def forward(self, x):
        # x: summary tensor whose last dimension is input_dim
        weights = self.gate(x)
        return weights

class TextNumEmbedding(nn.Module):
    """
    Patch embedding that fuses a numerical and a textual embedding with
    weights produced by a gating network.

    The gate sees the mean of all patches of a series, so one pair of weights
    is produced per series and shared by every patch of that series.
    """

    def __init__(
        self, d_model, patch_len, stride, padding, dropout, data_name, seq_len
    ):
        super().__init__()
        # Patching configuration; the series is right-padded by replication.
        self.patch_len = patch_len
        self.stride = stride
        self.padding_patch_layer = nn.ReplicationPad1d((0, padding))

        # Numerical expert: bias-free projection of the raw patch.
        self.value_embedding = nn.Linear(patch_len, d_model, bias=False)

        # Textual expert: LLM embedding projected down to d_model.
        llm_dim = 2048  # width of the vectors returned by TextEmbedding
        self.text_embedding = TextEmbedding(d_model, data_name, seq_len)
        self.text_project = nn.Linear(llm_dim, d_model)

        # Gate over the two experts, fed with patch_len-sized summaries.
        self.gating_network = GatingNetwork(input_dim=patch_len, num_experts=2)

        # Shared components.
        self.position_embedding = PositionalEmbedding(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return ``(embedding, n_vars)`` for ``x`` of shape ``(B, N, L)``.

        The embedding has shape ``(B * N, num_patches, d_model)``.
        """
        n_vars = x.shape[1]

        # Patching: pad, slice into windows, fold variables into the batch.
        padded = self.padding_patch_layer(x)
        windows = padded.unfold(dimension=-1, size=self.patch_len, step=self.stride)
        x_patched = windows.reshape(
            x.shape[0] * x.shape[1], windows.shape[2], windows.shape[3]
        )  # (B*N, num_patches, patch_len)

        # Expert outputs, both (B*N, num_patches, d_model).
        numerical_emb = self.value_embedding(x_patched)
        textual_emb = self.text_project(self.text_embedding(x_patched))

        # Gate input: mean over the patch axis -> (B*N, patch_len).
        weights = self.gating_network(x_patched.mean(dim=1))  # (B*N, 2)

        # Reshape each expert's weight to (B*N, 1, 1) for broadcasting.
        w_numerical = weights[:, 0:1].unsqueeze(-1)
        w_textual = weights[:, 1:2].unsqueeze(-1)

        # Weighted sum of the experts, then the positional term.
        fused_emb = (numerical_emb * w_numerical) + (textual_emb * w_textual)
        final_emb = fused_emb + self.position_embedding(x_patched)

        return self.dropout(final_emb), n_vars

class TextOnlyEmbedding(nn.Module):
    """
    Patch embedding that uses only the textual path: each patch is embedded
    by ``TextEmbedding``, projected to ``d_model`` and given a positional
    encoding. No numerical expert or gating network is involved.
    """

    def __init__(
        self, d_model, patch_len, stride, padding, dropout, data_name, seq_len
    ):
        super().__init__()
        # Patching configuration; the series is right-padded by replication.
        self.patch_len = patch_len
        self.stride = stride
        self.padding_patch_layer = nn.ReplicationPad1d((0, padding))

        # Textual path: LLM embedding projected down to d_model.
        llm_dim = 2048  # assumed width of the vectors returned by TextEmbedding
        self.text_embedding = TextEmbedding(d_model, data_name, seq_len)
        self.text_project = nn.Linear(llm_dim, d_model)

        # Shared components.
        self.position_embedding = PositionalEmbedding(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return ``(embedding, n_vars)`` for ``x`` of shape ``(B, N, L)``.

        The embedding has shape ``(B * N, num_patches, d_model)``.
        """
        n_vars = x.shape[1]

        # Patching: pad, slice into windows, fold variables into the batch.
        padded = self.padding_patch_layer(x)
        windows = padded.unfold(dimension=-1, size=self.patch_len, step=self.stride)
        x_patched = windows.reshape(
            x.shape[0] * x.shape[1], windows.shape[2], windows.shape[3]
        )  # (B*N, num_patches, patch_len)

        # Text embedding of every patch, projected to the model dimension.
        textual_emb = self.text_project(self.text_embedding(x_patched))

        # Add the positional encoding of each patch.
        final_emb = textual_emb + self.position_embedding(x_patched)

        return self.dropout(final_emb), n_vars
