import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import LlamaForCausalLM
from peft import PeftModelForCausalLM, LoraConfig, TaskType


def contrastive_loss(q_current, rewards, epsilon=1e-10):
    """Contrastive (InfoNCE-style) loss that scores every positive against all negatives.

    Entries whose reward equals 1 are positives and entries whose reward equals 0
    are negatives; any other reward value is ignored. For each positive logit
    ``p_i`` and the set of negative logits ``n_j`` the per-positive loss is

        -log( exp(p_i) / (exp(p_i) + sum_j exp(n_j) + epsilon) )

    and the result is the mean over positives. Everything is evaluated in log
    space (``logsumexp`` / ``logaddexp``), so large logits do not overflow
    ``exp`` to ``inf`` and turn the loss into NaN.

    Args:
        q_current: Tensor of logits; must have the same shape as ``rewards``
            (it is indexed with boolean masks built from ``rewards``).
        rewards: Tensor marking positives (== 1) and negatives (== 0).
        epsilon: Small constant added to the denominator for numerical stability.

    Returns:
        Scalar tensor. When there are no positives, a zero that is still
        connected to ``q_current``'s autograd graph is returned instead of the
        NaN produced by averaging an empty tensor.
    """
    positive_logits = q_current[rewards == 1]
    negative_logits = q_current[rewards == 0]

    if positive_logits.numel() == 0:
        # Nothing to contrast: keep the graph intact but contribute no loss.
        return q_current.sum() * 0.0

    # log(sum_j exp(n_j) + epsilon): treat log(epsilon) as one extra "negative" logit.
    log_eps = torch.full((1,), epsilon, dtype=q_current.dtype, device=q_current.device).log()
    log_negative_sum = torch.logsumexp(torch.cat([negative_logits, log_eps]), dim=0)

    # log(exp(p_i) + sum_j exp(n_j) + epsilon) for every positive, computed stably.
    log_denominator = torch.logaddexp(positive_logits, log_negative_sum)

    # -log(exp(p_i) / denominator_i) == log(denominator_i) - p_i, averaged over positives.
    return (log_denominator - positive_logits).mean()


class QValueEncoder(nn.Module):
    """LoRA-adapted LLaMA encoder whose embedding similarities serve as Q-values.

    States and actions are token-id tensors encoded by the same backbone into
    L2-normalised vectors; ``Q(s, a)`` is their dot product divided by
    ``sim_temperature``.

    Args:
        llama_model_path: Path/id passed to ``LlamaForCausalLM.from_pretrained``.
        lora_path: Directory of a saved LoRA adapter. When ``None``, fresh LoRA
            adapters are created on ``q_proj`` / ``v_proj``.
        lora_rank: LoRA rank; used only when ``lora_path`` is None.
        lora_alpha: LoRA scaling alpha; used only when ``lora_path`` is None.
        tau: Temperature of the soft (log-sum-exp) max over next-state Q-values.
        pooling_method: ``"mean"`` (mean over non-EOS tokens) or ``"lasttoken"``
            (hidden state at the first EOS token, else at the last position).
        eos_token_id: Token id treated as EOS / padding. The default 128009 is
            presumably Llama-3's ``<|eot_id|>`` — confirm against the tokenizer.
        is_trainable: Forwarded to ``PeftModelForCausalLM.from_pretrained``;
            only consulted when ``lora_path`` is given.
        sim_temperature: Divisor applied to embedding dot products to form
            Q-values (default 0.1).
    """

    def __init__(self, llama_model_path, lora_path=None, lora_rank=16, lora_alpha=32, tau=1.0,
                 pooling_method="mean", eos_token_id=128009, is_trainable=False, sim_temperature=0.1):
        super().__init__()

        self.tau = tau
        self.pooling_method = pooling_method
        self.eos_token_id = eos_token_id
        self.sim_temperature = sim_temperature

        # Pre-trained base causal LM; LoRA adapters are attached to it below.
        self.llama = LlamaForCausalLM.from_pretrained(llama_model_path)

        if lora_path is None:
            lora_config = LoraConfig(
                r=lora_rank,
                lora_alpha=lora_alpha,
                lora_dropout=0.1,
                target_modules=["q_proj", "v_proj"],  # attention query / value projections
                task_type=TaskType.CAUSAL_LM,
            )
            self.llama = PeftModelForCausalLM(self.llama, lora_config)
        else:
            self.llama = PeftModelForCausalLM.from_pretrained(self.llama, lora_path, is_trainable=is_trainable)
        self.llama.print_trainable_parameters()

    def _backbone(self):
        """Return the bare decoder stack (``LlamaModel``) beneath the PEFT wrapper.

        PEFT injects the LoRA layers into the wrapped model's modules in place,
        so calling this module still applies the adapters. It skips the LM head
        and never materialises the (batch, seq_len, vocab_size) logits tensor,
        which is not needed to obtain the final hidden states.
        """
        return self.llama.get_base_model().model

    @staticmethod
    def _mean_pool(hidden, attention_mask):
        """Average ``hidden`` (batch, seq, dim) over positions where ``attention_mask`` is 1.

        Rows with no valid position divide by 1 instead of 0, giving a zero
        vector rather than NaN.
        """
        mask = attention_mask.unsqueeze(-1)    # (batch, seq, 1)
        summed = (hidden * mask).sum(dim=1)    # (batch, dim)
        counts = mask.sum(dim=1).clamp(min=1)  # (batch, 1)
        return summed / counts

    @staticmethod
    def _last_token_pool(hidden, token_ids, eos_token_id):
        """Select, per row, the hidden state at the first ``eos_token_id``, else at the last position."""
        is_eos = token_ids == eos_token_id
        # argmax over a 0/1 integer mask yields the index of the first 1; the mask is
        # converted to long because argmax on bool tensors is not supported everywhere.
        first_eos = is_eos.long().argmax(dim=1)
        last_position = torch.full_like(first_eos, token_ids.size(1) - 1)
        positions = torch.where(is_eos.any(dim=1), first_eos, last_position)
        rows = torch.arange(token_ids.size(0), device=token_ids.device)
        return hidden[rows, positions]

    @staticmethod
    def _soft_max(q_values, tau):
        """Temperature-scaled log-sum-exp over the last dim: ``tau * logsumexp(q / tau)``."""
        return tau * torch.logsumexp(q_values / tau, dim=-1)

    def _q_value(self, state_embeds, action_embeds):
        """Scaled dot product along the last dim (broadcasts over leading dims)."""
        return torch.sum(state_embeds * action_embeds, dim=-1) / self.sim_temperature

    def encode(self, token_ids):
        """Encode a batch of token-id sequences into L2-normalised embeddings.

        Args:
            token_ids: Integer tensor indexed as (batch, seq_len).

        Returns:
            Tensor of shape (batch, hidden_size) with unit L2 norm on the last dim.

        Raises:
            NotImplementedError: if ``pooling_method`` is not "mean" or "lasttoken".
        """
        if self.pooling_method == "mean":
            # EOS positions are treated as padding: masked from attention and from the average.
            attention_mask = (token_ids != self.eos_token_id).long()
            hidden = self._backbone()(input_ids=token_ids, attention_mask=attention_mask).last_hidden_state
            pooled = self._mean_pool(hidden, attention_mask)
        elif self.pooling_method == "lasttoken":
            # NOTE: no attention mask is passed. Under causal attention the state at the
            # first EOS only sees earlier tokens, so trailing EOS padding cannot affect it;
            # left padding would — confirm inputs are right-padded.
            hidden = self._backbone()(input_ids=token_ids).last_hidden_state
            pooled = self._last_token_pool(hidden, token_ids, self.eos_token_id)
        else:
            raise NotImplementedError(f"Unknown pooling_method: {self.pooling_method!r}")
        return F.normalize(pooled, p=2, dim=-1)

    def get_state_embedding(self, states):
        return self.encode(states)

    def get_action_embedding(self, actions):
        return self.encode(actions)

    def forward(self, states=None, actions=None, next_states=None, candidate_actions=None, rewards=None,
                ratio=1.0, mode="loss"):
        """Compute the training loss or return embeddings, depending on ``mode``.

        Modes:
            ``"loss"``: blend of a soft-Bellman MSE term and ``contrastive_loss``,
                ``(1 - ratio) * mse + ratio * contrastive``. ``candidate_actions``
                is reshaped as (batch, K, seq_len) — K candidate actions per
                next state.
            ``"action_embedding"``: returns ``encode(actions)``.
            ``"state_embedding"``: returns ``encode(states)``.

        Raises:
            ValueError: for any other ``mode``.
        """
        if mode == "action_embedding":
            return self.get_action_embedding(actions)
        if mode == "state_embedding":
            return self.get_state_embedding(states)
        if mode != "loss":
            raise ValueError(f"Unknown mode: {mode!r}")

        # Q(s, a) for the observed transitions: shape (batch,).
        q_current = self._q_value(self.get_state_embedding(states), self.get_action_embedding(actions))

        # Q(s', a') for every candidate: embed all batch * K candidates in a single pass.
        next_state_embeds = self.get_state_embedding(next_states)                    # (batch, dim)
        flat_candidates = candidate_actions.reshape(-1, candidate_actions.size(-1))  # (batch * K, seq)
        candidate_embeds = self.get_action_embedding(flat_candidates)
        candidate_embeds = candidate_embeds.reshape(candidate_actions.size(0), -1, candidate_embeds.size(-1))
        q_next = self._q_value(next_state_embeds.unsqueeze(1), candidate_embeds)     # (batch, K)

        # Soft Bellman target: r + tau * logsumexp(Q(s', .) / tau).
        # NOTE(review): the target is not detached, so gradients also flow through the
        # next-state term (residual-gradient style); confirm this is intended.
        target = rewards + self._soft_max(q_next, self.tau)
        mse_loss = F.mse_loss(q_current, target)

        cont_loss = contrastive_loss(q_current, rewards)
        return (1 - ratio) * mse_loss + ratio * cont_loss


def case1():
    """Demo: wrap the base model with fresh LoRA adapters, list what will train, then save the adapter."""
    base_model_path = ".../hf_models/Llama-3.2-1B-Instruct"  # Example LLaMA model from Hugging Face
    encoder = QValueEncoder(base_model_path, is_trainable=True)  # lora_path=None -> new adapters

    print("\n### Trainable Parameters in Training Mode:")
    trainable = [(name, param) for name, param in encoder.named_parameters() if param.requires_grad]
    for name, param in trainable:
        print(f"{name}: {param.shape}")
    total_trainable_params = sum(param.numel() for _, param in trainable)
    print(f"Total Trainable Parameters: {total_trainable_params}")

    adapter_dir = "cache2/saved_lora_model"
    encoder.llama.save_pretrained(adapter_dir)
    print(f"Saved LoRA parameters to: {adapter_dir}")


def case2():
    """Demo: reload a saved LoRA adapter in trainable mode, report its parameters, and save a copy."""
    base_model_path = ".../hf_models/Llama-3.2-1B-Instruct"  # Example LLaMA model from Hugging Face
    source_adapter = "cache2/saved_lora_model"
    target_adapter = "cache2/saved_lora_model2"

    encoder = QValueEncoder(base_model_path, source_adapter, is_trainable=True)

    print("\n### Trainable Parameters in Training Mode:")
    total_trainable_params = 0
    for name, param in filter(lambda item: item[1].requires_grad, encoder.named_parameters()):
        print(f"{name}: {param.shape}")
        total_trainable_params += param.numel()
    print(f"Total Trainable Parameters: {total_trainable_params}")

    encoder.llama.save_pretrained(target_adapter)
    print(f"Saved LoRA parameters to: {target_adapter}")


def case3():
    """Demo: load a saved LoRA adapter frozen (``is_trainable=False``) for inference.

    Prints every parameter that still requires gradients and their total count.
    With a frozen adapter the total is presumably 0 — confirm against PEFT's
    loading behaviour.
    """
    llama_model_path = ".../hf_models/Llama-3.2-1B-Instruct"  # Example LLaMA model from Hugging Face
    lora_path = "cache2/saved_lora_model2"
    model = QValueEncoder(llama_model_path, lora_path, is_trainable=False)
    model.eval()  # inference: switch dropout layers off

    # This scenario is the inference check, so label the report accordingly.
    print("\n### Trainable Parameters in Inference Mode:")
    total_trainable_params = 0
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(f"{name}: {param.shape}")
            total_trainable_params += param.numel()
    print(f"Total Trainable Parameters: {total_trainable_params}")


if __name__ == "__main__":
    # Pick one demo scenario to run:
    #   case1() - create fresh LoRA adapters and save them to cache2/saved_lora_model
    #   case2() - reload that adapter trainably and save it to cache2/saved_lora_model2
    #   case3() - load the second adapter frozen for inference
    case3()
