import torch
import torch.nn as nn
from transformers import LlamaPreTrainedModel, LlamaModel
from transformers.modeling_outputs import ModelOutput
from typing import Any, Dict, List, Optional, Union, Tuple
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from accelerate import Accelerator
import numpy as np

from base_trainer import RewardTrainer
from transformers import Trainer, TrainingArguments
from transformers.trainer_utils import PredictionOutput
from torch.utils.data import DataLoader, Dataset, DistributedSampler
from transformers import PreTrainedModel, PretrainedConfig, AutoModelForSequenceClassification

from dataclasses import dataclass, field


from accelerate.utils import (
    DistributedType,
)

def get_last_token_hidden_state(hidden_states, attention_mask):
    """Gather the hidden state of the last attended token of every sequence.

    The index of the last attended token is the position of the last non-zero
    entry of ``attention_mask``. Unlike ``mask.sum() - 1`` (only correct for
    right padding), this also works for left-padded batches, and it gives the
    same result as before for right-padded ones.

    Args:
        hidden_states: tensor of shape ``[B, T, H]``.
        attention_mask: tensor of shape ``[B, T]``; non-zero marks real tokens.

    Returns:
        Tensor of shape ``[B, H]``. Rows whose mask is all zeros fall back to
        position 0.
    """
    device = hidden_states.device
    attended = attention_mask.to(device) != 0
    positions = torch.arange(attended.size(1), device=device, dtype=torch.int32)
    # Padded positions become 0, so the maximum remaining index is the last real
    # token; argmax returns the first max, i.e. 0 for an all-zero mask.
    last_idx = (positions * attended).argmax(dim=1)
    batch_idx = torch.arange(hidden_states.size(0), device=device)
    return hidden_states[batch_idx, last_idx]

def post_process_RBv1(rewards_j, rewards_k, metric_key_prefix):
    """Compute pairwise preference metrics from chosen/rejected rewards.

    Args:
        rewards_j: rewards of the chosen responses, one per pair.
        rewards_k: rewards of the rejected responses, aligned with ``rewards_j``.
        metric_key_prefix: prefix for every metric key (e.g. ``"eval"``).

    Returns:
        Dict of Python floats: accuracy (fraction of pairs with chosen >
        rejected), mean chosen reward, mean rejected reward and mean margin
        (chosen - rejected).
    """
    # Work in float64: casting to float16 before averaging (as done previously)
    # visibly rounds the reported accuracy and margin.
    chosen = np.asarray(rewards_j, dtype=np.float64)
    rejected = np.asarray(rewards_k, dtype=np.float64)

    acc = (chosen > rejected).mean()
    margin = (chosen - rejected).mean()
    return {
        f"{metric_key_prefix}_RBv1_accuracy": float(acc),
        f"{metric_key_prefix}_RBv1_avg_rewards_chosen": float(chosen.mean()),
        f"{metric_key_prefix}_RBv1_avg_rewards_rejected": float(rejected.mean()),
        f"{metric_key_prefix}_RBv1_avg_margin": float(margin),
    }

def post_process_RMBench(
    expanded_domains: List[str],
    domains_to_analyze: List[str],
    scores: np.ndarray,
    metric_key_prefix: str
) -> Dict[str, float]:
    """Compute RM-Bench per-domain and hard/normal/easy accuracies.

    ``scores`` holds one ``(chosen, rejected)`` reward pair per row. Rows are
    grouped into consecutive blocks of 3 belonging to the same sample
    (presumably the three RM-Bench response styles -- confirm against the data
    pipeline). For every sample, chosen entry ``i`` is compared with rejected
    entry ``j``; the win rates averaged over samples form a 3x3 matrix:

    * hard   -- mean of the strict upper triangle (``i < j``),
    * normal -- mean of the diagonal (``i == j``),
    * easy   -- mean of the strict lower triangle (``i > j``).

    Args:
        expanded_domains: one domain label per row of ``scores``. A row belongs
            to domain ``d`` when its label starts with ``d``, so a prefix also
            matches longer labels.
        domains_to_analyze: domain prefixes to report.
        scores: array of shape ``[N, 2]`` (column 0 chosen, column 1 rejected).
        metric_key_prefix: prefix for every metric key.

    Returns:
        ``{prefix}_RMBench_{Domain}`` for each analyzed domain (``Domain`` is the
        capitalized prefix; the value is the mean of its hard/normal/easy
        accuracies), then ``{prefix}_RMBench_Hard`` / ``_Normal`` / ``_Easy``
        (averaged over domains) and ``{prefix}_RMBench_total`` (mean of the
        domain values). A domain with no rows reports NaN and is left out of
        the cross-domain averages instead of raising ``KeyError``.

    Raises:
        ValueError: if a domain's row count is not a multiple of 3.
    """
    MATRIX_SIZE = 3
    off_diagonal_count = MATRIX_SIZE * (MATRIX_SIZE - 1) / 2

    def compute_h_n_e_accuracy(scores_subset: np.ndarray) -> Dict[str, float]:
        """Return hard/normal/easy accuracy for one domain's rows (NaN if empty)."""
        num_rows = scores_subset.shape[0]
        if num_rows % MATRIX_SIZE != 0:
            raise ValueError(
                f"RM-Bench rows must come in groups of {MATRIX_SIZE}; got {num_rows} rows"
            )
        if num_rows == 0:
            return {"hard_acc": np.nan, "normal_acc": np.nan, "easy_acc": np.nan}

        grouped = scores_subset.reshape(num_rows // MATRIX_SIZE, MATRIX_SIZE, 2)
        chosen = grouped[:, :, 0]
        rejected = grouped[:, :, 1]
        # acc_matrix[i, j] = fraction of samples where chosen i beats rejected j.
        acc_matrix = (chosen[:, :, None] > rejected[:, None, :]).mean(axis=0)
        return {
            "hard_acc": np.triu(acc_matrix, 1).sum() / off_diagonal_count,
            "normal_acc": np.diag(acc_matrix).mean(),
            "easy_acc": np.tril(acc_matrix, -1).sum() / off_diagonal_count,
        }

    def mean_ignoring_nan(values) -> float:
        """Mean of the non-NaN values, or NaN when none remain (no warning)."""
        kept = [v for v in values if not np.isnan(v)]
        return float(np.mean(kept)) if kept else float("nan")

    # dtype=str keeps np.char usable even for an empty label list.
    labels = np.asarray(list(expanded_domains), dtype=str)
    scores = np.asarray(scores)

    domain_results = {
        domain: compute_h_n_e_accuracy(scores[np.char.startswith(labels, domain)])
        for domain in domains_to_analyze
    }
    domain_scores = {
        domain: mean_ignoring_nan(result.values())
        for domain, result in domain_results.items()
    }

    metrics = {
        f"{metric_key_prefix}_RMBench_{domain.capitalize()}": value
        for domain, value in domain_scores.items()
    }
    for acc_key, label in (("hard_acc", "Hard"), ("normal_acc", "Normal"), ("easy_acc", "Easy")):
        metrics[f"{metric_key_prefix}_RMBench_{label}"] = mean_ignoring_nan(
            result[acc_key] for result in domain_results.values()
        )
    metrics[f"{metric_key_prefix}_RMBench_total"] = mean_ignoring_nan(domain_scores.values())

    return metrics

# Helper: Pearson correlation coefficient.
def pearson_correlation(x, y, eps=1e-8):
    """Compute the Pearson correlation coefficient of ``x`` and ``y``.

    Means and sums are taken over all elements (i.e. over the batch dimension
    for 1-D reward vectors).

    Args:
        x, y: tensors of matching shape.
        eps: small constant added to the denominator so constant inputs yield
            0 instead of dividing by zero. It must be tiny relative to the
            denominator; the former default of 19 pulled every correlation
            toward 0.

    Returns:
        A scalar tensor, approximately in ``[-1, 1]``.
    """
    x_centered = x - x.mean()
    y_centered = y - y.mean()
    numerator = (x_centered * y_centered).sum()
    denominator = torch.sqrt((x_centered ** 2).sum() * (y_centered ** 2).sum())
    return numerator / (denominator + eps)


class ALBMModel(LlamaPreTrainedModel):
    """Llama backbone plus three scalar heads consumed by ``ALBMRewardTrainer``.

    Each head is a bias-free ``Linear(hidden_size, 1)``:

    * ``score`` -- quality reward,
    * ``length_head`` -- length reward,
    * ``prompt_analyzer_head`` -- weight applied to the length reward.

    ``forward`` only runs the backbone; the heads are applied by the trainer.
    """

    def __init__(self, config):
        super().__init__(config)

        self.model = LlamaModel(config)

        hidden_size = config.hidden_size
        self.score = nn.Linear(hidden_size, 1, bias=False)
        self.length_head = nn.Linear(hidden_size, 1, bias=False)
        self.prompt_analyzer_head = nn.Linear(hidden_size, 1, bias=False)

        # Explicit N(0, 0.02) initialization of the three heads.
        # NOTE(review): `post_init()` below may re-initialize these Linear layers
        # through the model's `_init_weights` hook -- confirm against the
        # installed transformers version.
        head_init_std = 0.02
        for head in (self.score, self.length_head, self.prompt_analyzer_head):
            head.weight.data.normal_(mean=0.0, std=head_init_std)

        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        """Run the Llama backbone and return its outputs unchanged.

        Extra keyword arguments are forwarded to ``LlamaModel``.
        """
        backbone = self.model
        return backbone(input_ids=input_ids, attention_mask=attention_mask, **kwargs)


class ALBMRewardTrainer(RewardTrainer):
    """Reward trainer for ``ALBMModel`` with decoupled quality/length heads.

    The objective (see ``compute_loss``) is
    ``lambda_r * L_R + lambda_dr * L_DR + lambda_il * L_IL``.

    Args:
        lambda_r: weight of the combined-reward ranking loss ``L_R``.
        lambda_dr: weight of the decoupled per-head ranking loss ``L_DR``.
        lambda_il: weight of the head-orthogonality penalty ``L_IL``.
        lambda_el: weight reserved for the explicit length loss ``L_EL``,
            which is currently disabled.
        **kwargs: forwarded unchanged to ``RewardTrainer``.
    """

    def __init__(
        self,
        *,
        lambda_r: float = 0.33,
        lambda_dr: float = 0.33,
        lambda_il: float = 0.33,
        lambda_el: float = 0.33,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.lambda_r = lambda_r
        self.lambda_dr = lambda_dr
        self.lambda_il = lambda_il
        self.lambda_el = lambda_el

    def evaluate(
        self,
        eval_dataset: Optional[Dataset] = None,
        ignore_keys: Optional[list[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> Dict[str, float]:
        """Run evaluation; currently a pass-through to the parent implementation.

        Kept as an explicit override point: dataset-specific handling (e.g. for
        RewardBench-V2) can be added here.
        """
        return super().evaluate(eval_dataset, ignore_keys, metric_key_prefix)

    def evaluation_loop(
        self,
        dataloader: DataLoader,
        description: str = "Evaluating",
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[list[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> PredictionOutput:
        """Run the parent evaluation loop and add benchmark-specific metrics.

        ``output.predictions`` is read as an ``[N, 2]`` array of
        (chosen, rejected) rewards that the parent loop has already gathered
        across processes. Benchmark metrics are computed on the main process
        only; other processes add nothing.
        """
        output = super().evaluation_loop(
            dataloader,
            description,
            prediction_loss_only,
            ignore_keys,
            metric_key_prefix,
        )

        rewards = output.predictions
        # No predictions (e.g. prediction_loss_only=True): nothing to score.
        if rewards is None:
            return output

        metrics: Dict[str, float] = {}
        if self.accelerator.is_main_process:
            metrics = self._benchmark_metrics(dataloader, rewards, metric_key_prefix)

        output.metrics.update(metrics)
        return output

    def _eval_source_dataset(self, dataloader):
        """Return the dataset whose rows produced the predictions of ``dataloader``.

        Prefer ``dataloader.dataset`` when it still carries the ``dataset``
        column. This is the correct source when ``evaluate()`` received a
        dataset other than ``self.eval_dataset``. Otherwise fall back to
        ``self.eval_dataset``, e.g. when unused columns were stripped from the
        dataloader's copy.
        """
        candidate = getattr(dataloader, "dataset", None)
        columns = getattr(candidate, "column_names", None) or []
        if "dataset" in columns:
            return candidate
        return self.eval_dataset

    def _benchmark_metrics(self, dataloader, rewards, metric_key_prefix):
        """Dispatch to the post-processor matching the eval set's ``dataset`` name."""
        source = self._eval_source_dataset(dataloader)
        # First name in row order: deterministic, unlike iterating a set of strings.
        names = list(dict.fromkeys(source["dataset"]))
        if not names:
            return {}
        dataset_name = names[0]

        rewards_j = rewards[:, 0]
        rewards_k = rewards[:, 1]
        if dataset_name in ("RBv1", "RBv2"):
            # RBv2 currently reuses the RBv1 pairwise metrics (and their key names).
            return post_process_RBv1(rewards_j, rewards_k, metric_key_prefix)
        if dataset_name == "RMBench":
            domains = ["chat", "math", "code", "safety"]
            return post_process_RMBench(source["domain"], domains, rewards, metric_key_prefix)
        # Unknown benchmark: no extra metrics (previously an UnboundLocalError).
        return {}

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the ALBM training loss for a batch of chosen/rejected rows.

        Even rows of ``inputs`` are treated as chosen responses and odd rows as
        rejected responses. Each response is represented by the backbone's
        hidden state at its last attended token. The three loss terms are:

        * ``L_R``: ``-logsigmoid(r_c - r_r)`` on the combined reward
          ``score + prompt_weight * length_reward``. ``prompt_weight`` comes
          from the chosen response and is shared by both sides of the pair.
        * ``L_DR``: the same ranking loss applied to the quality head and to the
          length head separately, summed.
        * ``L_IL``: ``mean(|W_q @ W_l^T|)`` between the quality and length head
          weights.

        Args:
            model: the (possibly wrapped) ``ALBMModel``.
            inputs: batch dict with ``input_ids`` and ``attention_mask``.
            return_outputs: if True, also return the quality rewards.
            num_items_in_batch: accepted for ``Trainer`` compatibility; unused.

        Returns:
            ``final_loss``, or ``(final_loss, {"rewards_j": ..., "rewards_k": ...})``
            where the rewards are quality-head scores of shape ``[B/2, 1]``.
        """
        # The heads are attributes of the bare ALBMModel. Wrappers such as DDP do
        # not forward attribute access, so resolve the heads on the unwrapped model.
        albm = self.accelerator.unwrap_model(model)

        attention_masks = inputs["attention_mask"]
        # Only the final layer is used, so per-layer hidden states are not requested.
        outputs = model(input_ids=inputs["input_ids"], attention_mask=attention_masks)
        last_hidden_states = outputs.last_hidden_state

        output_chosen = get_last_token_hidden_state(last_hidden_states[::2], attention_masks[::2])
        output_rejected = get_last_token_hidden_state(last_hidden_states[1::2], attention_masks[1::2])

        quality_reward_chosen = albm.score(output_chosen)
        quality_reward_rejected = albm.score(output_rejected)
        length_reward_chosen = albm.length_head(output_chosen)
        length_reward_rejected = albm.length_head(output_rejected)
        prompt_weight = albm.prompt_analyzer_head(output_chosen)

        # L_R: ranking (Bradley-Terry style) loss on the combined reward.
        final_reward_chosen = quality_reward_chosen + prompt_weight * length_reward_chosen
        final_reward_rejected = quality_reward_rejected + prompt_weight * length_reward_rejected
        loss_r = -F.logsigmoid(final_reward_chosen - final_reward_rejected).mean()

        # L_DR: decoupled ranking losses on each head.
        loss_dr_quality = -F.logsigmoid(quality_reward_chosen - quality_reward_rejected).mean()
        loss_dr_length = -F.logsigmoid(length_reward_chosen - length_reward_rejected).mean()
        loss_dr = loss_dr_quality + loss_dr_length

        # L_IL: implicit orthogonality penalty between the two head weight vectors.
        loss_il = torch.abs(torch.mm(albm.score.weight, albm.length_head.weight.t())).mean()

        # L_EL (explicit length loss: |corr(quality, length)| - corr(length_reward, length),
        # built on `pearson_correlation`) is disabled because it easily led to NaN
        # gradients. `self.lambda_el` is kept for when it is re-enabled.
        final_loss = self.lambda_r * loss_r + self.lambda_dr * loss_dr + self.lambda_il * loss_il

        if return_outputs:
            return final_loss, {"rewards_j": quality_reward_chosen, "rewards_k": quality_reward_rejected}
        return final_loss