import torch
import torch.nn as nn
import numpy as np
from math import sqrt
from transformers import AutoModelForCausalLM, AutoTokenizer

# from utils.masking import TimeCausalMask, DiagMask
import math
from torch.nn.parameter import Parameter

import torch

# Sentinel marking label positions that must not contribute to the loss or
# gradients (used by expand_future_labels / get_future_logits). The value -100
# presumably mirrors the default `ignore_index` of torch.nn.CrossEntropyLoss.
INVALID_LABEL = -100


def expand_future_labels(
    labels: torch.Tensor, future_steps: int, invalid_value: int = INVALID_LABEL
):
    """
    Build a sliding window of "future" labels for every aligned start position.

    For each start position t in [0, T - f], the output row holds
    ``labels[:, t], labels[:, t + 1], ..., labels[:, t + f - 1]``. The time axis
    is truncated to T - f + 1 positions so every window lies fully inside the
    sequence; no padding is written. Invalid entries already present in
    ``labels`` (e.g. ``invalid_value``) are copied through unchanged.

    Args:
        labels: [B, T] label tensor (generally Long); may contain invalid_value.
        future_steps: f >= 1, window length.
        invalid_value: Sentinel for invalid labels. Kept for API compatibility;
            this function only propagates it from ``labels``.

    Returns:
        expanded_labels: [B, T - f + 1, f], same dtype/device as ``labels``.
        The result is a freshly allocated contiguous tensor that does not alias
        ``labels``.

    Raises:
        ValueError: if future_steps < 1 or T < future_steps.
    """
    if future_steps < 1:
        raise ValueError("future_steps must be >= 1")

    _, T = labels.shape
    if T < future_steps:
        raise ValueError(
            f"seq_len ({T}) must be >= future_steps ({future_steps})"
        )

    # unfold(dim=1, size=f, step=1) yields [B, T - f + 1, f] with
    # out[b, t, i] == labels[b, t + i]. This is exactly the per-column copy,
    # without pre-filling slots that would all be overwritten anyway.
    # unfold returns a view (for f == 1 even a contiguous one), so force a real
    # contiguous copy to keep the result independent of `labels`.
    return labels.unfold(1, future_steps, 1).clone(
        memory_format=torch.contiguous_format
    )


def get_future_logits(
    logits: torch.Tensor,
    expanded_labels: torch.Tensor,
    future_steps: int,
    invalid_value: int = INVALID_LABEL,
):
    """
    Gather, for every aligned position t and future offset i, the logit that
    position t assigns to the label found at t + i.

    ``out[b, t, i] = logits[b, t, expanded_labels[b, t, i]]`` for valid labels.
    The output is 0 wherever ``expanded_labels[b, t, i] == invalid_value``.
    The gather and the masking are out-of-place ops, so gradients flow back
    into ``logits`` for valid entries. Zeroed (invalid) entries receive no
    gradient.

    Args:
        logits: [B, T, V] float logits.
        expanded_labels: [B, T - f + 1, f] integer labels, typically produced
            by ``expand_future_labels``.
        future_steps: f >= 1.
        invalid_value: Sentinel marking invalid labels (default -100).

    Returns:
        gathered_logits: [B, T - f + 1, f] tensor with the dtype of ``logits``.
        Invalid positions are 0. Only this tensor is returned; the validity
        mask can be recomputed as ``expanded_labels != invalid_value``.

    Raises:
        ValueError: if future_steps < 1, T < future_steps, or the last
            dimension of ``expanded_labels`` is not ``future_steps``.
    """
    if future_steps < 1:
        raise ValueError("future_steps must be >= 1")

    _, T, V = logits.shape
    T_eff = T - future_steps + 1
    if T_eff < 1:
        raise ValueError(
            f"seq_len ({T}) must be >= future_steps ({future_steps})"
        )
    if expanded_labels.shape[-1] != future_steps:
        raise ValueError(
            f"expanded_labels last dim ({expanded_labels.shape[-1]}) must equal "
            f"future_steps ({future_steps})"
        )

    # Logits of the first T_eff positions: [B, T_eff, V]
    shifted_logits = logits[:, :T_eff, :]

    valid_mask = expanded_labels != invalid_value  # [B, T_eff, f]

    # Clamp so sentinels such as -100 are legal gather indices. Those entries
    # are zeroed below, so the clamped value never leaks into the output.
    safe_idx = expanded_labels.clamp(0, V - 1).long()  # [B, T_eff, f]

    # A single gather over the vocab dim handles all f offsets at once.
    gathered_logits = torch.gather(shifted_logits, dim=2, index=safe_idx)

    return gathered_logits.masked_fill(~valid_mask, 0.0)


def Generate_TemporalMap(device, Batch_size, leng_Q, leng_S, sigma=2):
    """
    Build a column-normalized temporal transition map of shape [B, Q, S, S].

    Entry (i, j) starts as a Gaussian kernel of the step offset i - j:
    ``exp(-(i - j)^2 / (2 * sigma^2)) / (sqrt(2 * pi) * sigma)``.
    Entries flagged by ``TimeCausalMask(leng_S, width=leng_S)`` are set to -inf.
    A softmax is then taken over dim -2 (applied to the kernel values
    themselves, not their log), so each column j is a distribution over i.
    The diagonal is never masked, so no column is entirely -inf.
    The same S x S map is broadcast (expand, no copy) over batch and query axes.

    Args:
        device: target device.
        Batch_size: B.
        leng_Q: Q, number of query positions.
        leng_S: S, number of future steps.
        sigma: Gaussian bandwidth.
    """
    steps = torch.arange(0, leng_S, 1).to(device)
    # offset[i, j] = i - j (row index minus column index) via broadcasting
    offset = steps.unsqueeze(1) - steps.unsqueeze(0)
    normaliser = math.sqrt(2 * math.pi) * sigma
    kernel = torch.exp(-(offset * offset) / (2 * sigma * sigma)) / normaliser

    # Masked entries get -inf so they receive zero weight after the softmax.
    band = TimeCausalMask(leng_S, device, width=leng_S)
    kernel.masked_fill_(band.mask, -np.inf)

    kernel = kernel.view(1, 1, leng_S, leng_S).expand(
        Batch_size, leng_Q, leng_S, leng_S
    )
    return torch.softmax(kernel, dim=-2)


def Generate_StartPoint(
    device, Batch_size, leng_Q, leng_S, mode="one_hot", decay=0.9
):
    """
    Generate the restart distribution e of the random walk, shape [B, Q, S, 1].

    Modes:
        - "one_hot": despite the name, puts 0.8 on t=0 and 0.2 on t=1, with all
          other steps 0. When leng_S == 1 there is no t=1, so all mass (1.0)
          goes on t=0 and e stays a probability distribution.
        - "time_decay": w_t = decay ** t for t in [0, S), normalized to sum to
          1 over S. Requires 0 < decay < 1.

    The [S, 1] column is broadcast with ``expand`` over batch and query axes.
    The result is therefore a view; clone it before writing to it in place.

    Args:
        device: target device.
        Batch_size: B.
        leng_Q: Q, number of query positions.
        leng_S: S, number of future steps.
        mode: "one_hot" or "time_decay".
        decay: geometric decay factor used by "time_decay".

    Returns:
        float32 tensor of shape [B, Q, S, 1].

    Raises:
        ValueError: on an unknown mode or a decay outside (0, 1).
    """
    dtype = torch.float32
    if mode == "one_hot":
        e = torch.zeros([leng_S, 1], dtype=dtype, device=device)
        if leng_S == 1:
            # A single future step has no t=1 slot; keep all mass on t=0.
            e[0, 0] = 1.0
        else:
            e[0, 0] = 0.8
            e[1, 0] = 0.2
    elif mode == "time_decay":
        # Reject singular/degenerate decay factors before building weights.
        decay = float(decay)
        if not (0.0 < decay < 1.0):
            raise ValueError("decay must be in (0, 1).")
        # t=0 is the earliest step; weights decay geometrically afterwards.
        t = torch.arange(leng_S, device=device, dtype=dtype)
        w = decay**t  # [S]
        e = (w / w.sum()).view(leng_S, 1)  # [S, 1]
    else:
        raise ValueError("Invalid mode. Choose 'one_hot' or 'time_decay'.")

    return e.view(1, 1, leng_S, 1).expand(Batch_size, leng_Q, leng_S, 1)


def random_walk(prob, e, transition, pre_d):
    """
    Perform one step of a random walk with restart:
    ``d = prob * (transition @ pre_d) + (1 - prob) * e``.

    ``prob``, ``transition`` and ``pre_d`` are converted to ``e``'s dtype. ``prob``
    may be a Python number or a tensor. A tensor keeps its autograd graph,
    because ``torch.as_tensor`` does not detach, unlike ``torch.tensor``.

    Args:
        prob: weight on the propagated mass; 1 - prob is the restart weight.
        e: restart distribution, e.g. [B, Q, S, 1].
        transition: transition matrices, matmul-compatible with ``pre_d``.
        pre_d: distribution from the previous step.

    Returns:
        The next distribution, with ``e``'s dtype.
    """
    device = e.device
    dtype = e.dtype

    # as_tensor avoids the copy-construct warning and preserves gradients
    # when prob is a (possibly learnable) tensor.
    prob = torch.as_tensor(prob, device=device, dtype=dtype)
    transition = transition.to(dtype=dtype)
    pre_d = pre_d.to(dtype=dtype)

    return prob * torch.matmul(transition, pre_d) + (1 - prob) * e


def combine_time_spatial_map(
    model, labels, TimeMap, conbined_type, hidden_states=None
):
    """
    Combine the temporal map with a content-similarity map of the label tokens.

    Content similarity is the cosine similarity between the embeddings of the
    f future labels at each position. The embeddings are looked up under
    ``torch.no_grad``. The similarity is softmax-normalized over dim -2, so
    every column sums to 1.

    Args:
        model: causal LM, optionally wrapped in an object exposing ``.module``
            (e.g. DistributedDataParallel). After unwrapping ``.module``, the
            embedding layer is taken from ``.model.model.embed_tokens``, or
            failing that from ``.model.embed_tokens``.
        labels: [B, L, f] integer labels; -100 entries are embedded as token 0.
        TimeMap: [B, L, f, f] temporal map; cast to the similarity dtype.
        conbined_type: "addition" (mean of the two maps),
            "time_left_multiplication" (TimeMap @ content) or
            "content_left_multiplication" (content @ TimeMap).
        hidden_states: unused; kept for API compatibility.

    Returns:
        [B, L, f, f] combined transition map.

    Raises:
        ValueError: on an unknown ``conbined_type``.
        AttributeError: if no embedding layer can be located on ``model``.
    """
    # Unwrap DDP-style wrappers first so the fallback lookup also works for
    # wrapped models (probing the wrapper itself can never succeed).
    base = getattr(model, "module", model)
    try:
        embeddings = base.model.model.embed_tokens
    except AttributeError:
        embeddings = base.model.embed_tokens

    with torch.no_grad():
        # Map ignored labels (-100, same value as INVALID_LABEL) to token 0 so
        # the lookup stays in range. masked_fill returns a copy, so `labels`
        # itself is left untouched.
        valid_labels = labels.masked_fill(labels == -100, 0)
        future_steps_embed = embeddings(valid_labels)  # [B, L, f, D]

    # L2-normalize so dot products are cosine similarities (eps avoids 0/0).
    norm = torch.norm(future_steps_embed, p=2, dim=-1, keepdim=True)
    normalized_embed = future_steps_embed / (norm + 1e-12)

    cosine_similarity = torch.matmul(
        normalized_embed, normalized_embed.transpose(-1, -2)
    )  # [B, L, f, f]
    content_similarity = torch.softmax(cosine_similarity, dim=-2)

    # Align dtypes (e.g. fp16 model embeddings vs fp32 temporal map).
    TimeMap = TimeMap.to(content_similarity.dtype)
    if conbined_type == "addition":
        return (TimeMap + content_similarity) / 2
    if conbined_type == "time_left_multiplication":
        return torch.matmul(TimeMap, content_similarity)
    if conbined_type == "content_left_multiplication":
        return torch.matmul(content_similarity, TimeMap)
    raise ValueError(
        "Invalid combination type. Choose from 'addition', 'time_left_multiplication', or 'content_left_multiplication'."
    )


def random_walk_analytic(prob, e, transition):
    """
    Closed-form fixed point of the random-walk-with-restart update.

    Solves ``d = prob * transition @ d + (1 - prob) * e``, i.e.
    ``d = (1 - prob) * (I - prob * transition)^{-1} @ e``, for every matrix in
    the batch. The solution exists when ``I - prob * transition`` is
    invertible.

    Args:
        prob: weight on the propagated mass (float or tensor).
        e: restart distribution of shape [..., S, 1].
        transition: transition matrices of shape [..., S, S]. Any leading
            batch dims are supported, e.g. [B, Q, S, S].

    Returns:
        d of shape [..., S, 1].
    """
    leng_S = transition.shape[-1]

    # torch.eye defaults to float32. Combined with a float16 transition this
    # promotes the system to float32, so the linear solve also works for
    # half-precision models.
    identity = torch.eye(leng_S, device=transition.device)
    system = identity - prob * transition  # broadcasts over batch dims

    # solve(A, B) is cheaper and numerically more stable than inv(A) @ B.
    d = torch.linalg.solve(system, e.to(system.dtype))
    return (1 - prob) * d


def Generate_Decay_Prob(
    model, labels, _iter, prob, combined_type, type="iter", hidden_states=None
):
    """
    Compute random-walk-with-restart weights over the f future steps of every
    position.

    Args:
        model: causal LM passed to ``combine_time_spatial_map``.
        labels: [B, L, f] expanded labels (see ``expand_future_labels``).
        _iter: number of iterative ``random_walk`` steps (used only when
            type == "iter").
        prob: weight on the propagated mass of the walk.
        combined_type: forwarded to ``combine_time_spatial_map``.
        type: "iter" runs ``_iter`` iterative steps; ANY other value uses the
            closed form ``random_walk_analytic``.
        hidden_states: forwarded to ``combine_time_spatial_map``.

    Returns:
        [B, L, f] weights (trailing singleton dim squeezed).
    """
    batch, seq, steps = labels.shape
    dims = dict(
        device=labels.device, Batch_size=batch, leng_Q=seq, leng_S=steps
    )

    restart = Generate_StartPoint(**dims)  # [B, L, f, 1]
    # The iterative walk starts from the restart distribution itself.
    state = Generate_StartPoint(**dims)
    time_map = Generate_TemporalMap(**dims)
    # Spatio-temporal joint transition matrix, [B, L, f, f]
    transition = combine_time_spatial_map(
        model,
        labels,
        time_map,
        conbined_type=combined_type,
        hidden_states=hidden_states,
    )

    if type != "iter":
        state = random_walk_analytic(
            prob=prob, e=restart, transition=transition
        )
    else:
        for _ in range(_iter):
            state = random_walk(
                prob=prob, e=restart, transition=transition, pre_d=state
            )

    return state.squeeze(-1)


class TimeCausalMask:
    """
    Band mask over an L x L grid of time steps.

    Despite the name, the mask is symmetric rather than causal.
    ``mask[i, j]`` is True (masked out) when ``|i - j| > 2``. When
    ``width == 1`` it instead masks every off-diagonal entry and keeps only
    ``i == j``. Any other ``width`` value has no effect, because the band
    half-width is fixed at 2.

    The mask is stored as an int tensor in ``_mask`` and exposed as bool via
    the ``mask`` property.
    """

    def __init__(self, L, device="cpu", width=9):
        band = 2
        positions = torch.arange(L, device=device)
        # distance[i, j] = |i - j|
        distance = (positions.unsqueeze(0) - positions.unsqueeze(1)).abs()
        limit = 0 if width == 1 else band
        self._mask = (distance > limit).to(torch.int)

    @property
    def mask(self):
        return self._mask.bool()


if __name__ == "__main__":
    # Smoke run: expand random labels and compute random-walk weights with a
    # real causal LM. Point model_path at a local checkpoint.
    model_path = "/path/to/model"
    # torch.manual_seed seeds the CPU generator and all CUDA generators.
    torch.manual_seed(0)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = AutoModelForCausalLM.from_pretrained(
        model_path, device_map="auto", torch_dtype=torch.float16
    )
    batch_size = 2
    seq_len = 10
    vocab_size = 10
    future_steps = 5
    labels = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
    expand_labels = expand_future_labels(labels, future_steps)
    # [batch_size, seq_len - future_steps + 1, future_steps]
    print(expand_labels.shape)
    # `_iter` is a required parameter of Generate_Decay_Prob. It only matters
    # when type == "iter"; type="acc" selects the closed-form solver.
    d = Generate_Decay_Prob(
        model,
        expand_labels,
        _iter=10,
        prob=0.6,
        combined_type="addition",
        type="acc",
    )
    # [batch_size, seq_len - future_steps + 1, future_steps]
    print(d.shape)
