from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments,Qwen2ForCausalLM,LlamaForCausalLM
import torch
import torch.nn as nn
from peft import PeftModel
from peft import get_peft_model, LoraConfig, TaskType, prepare_model_for_kbit_training, PeftConfig


class JointPRMModel(nn.Module):
    """Pair a process reward model (PRM) with a bias-expert model.

    Both wrapped models are called on the same inputs. Only their logits at
    the two ``candidate_tokens`` are used: the softmax probability of
    ``candidate_tokens[0]`` over the two candidates is treated as each
    model's "reward".

    Attributes:
        prm_model: Model whose forward output exposes ``.logits``.
        bias_expert_model: Second model with the same output contract.
        candidate_tokens: Two token ids. ``candidate_tokens[0]`` maps to
            gold class 1 and ``candidate_tokens[1]`` to gold class 0.
    """

    # Weight of the PRM-only cross-entropy inside the PRM objective.
    LAMBDA_G = 0.3
    # Weight of the squared reward/length correlation penalty for the PRM.
    LAMBDA_R = 0.1
    # Weight of the squared reward/length correlation term subtracted from
    # the bias-expert objective.
    LAMBDA_B = 0.7

    def __init__(self, prm_model, bias_expert_model, candidate_tokens):
        super().__init__()
        self.prm_model = prm_model
        self.bias_expert_model = bias_expert_model
        self.candidate_tokens = candidate_tokens

    def save_pretrained(self, save_directory, **kwargs):
        """Save both sub-models and a small JSON config under *save_directory*.

        Layout written: ``prm_model/``, ``bias_expert_model/`` (each via the
        sub-model's own ``save_pretrained``) and ``joint_config.json``.
        Extra ``**kwargs`` are accepted for call compatibility and ignored.
        """
        import os
        import json

        os.makedirs(save_directory, exist_ok=True)

        self.prm_model.save_pretrained(os.path.join(save_directory, "prm_model"))
        self.bias_expert_model.save_pretrained(
            os.path.join(save_directory, "bias_expert_model")
        )

        config = {
            # Plain ints so tensor / numpy token ids are JSON serialisable.
            "candidate_tokens": [int(token) for token in self.candidate_tokens],
            "model_type": "JointPRMModel",
        }
        config_path = os.path.join(save_directory, "joint_config.json")
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2)

    @classmethod
    def from_pretrained(cls, model_directory, prm_base_model_path, bias_expert_base_model_path, USE_8bit=False, device_map='auto'):
        """Rebuild a joint model from a directory written by ``save_pretrained``.

        Args:
            model_directory: Directory containing ``joint_config.json``,
                ``prm_model/`` and ``bias_expert_model/``.
            prm_base_model_path: Base checkpoint for the PRM; the saved
                adapter in ``prm_model/`` is loaded on top of it.
            bias_expert_base_model_path: Base checkpoint for the bias expert;
                the saved adapter in ``bias_expert_model/`` is loaded on top.
            USE_8bit: Forwarded to ``get_prm_model``.
            device_map: Forwarded to both model loaders.

        Returns:
            A new instance of this class.
        """
        import os
        import json

        config_path = os.path.join(model_directory, "joint_config.json")
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
        candidate_tokens = config["candidate_tokens"]

        prm_model = get_prm_model(
            prm_base_model_path,
            USE_8bit,
            adapter_path=os.path.join(model_directory, "prm_model"),
            device_map=device_map,
        )

        bias_expert_model = AutoModelForCausalLM.from_pretrained(
            bias_expert_base_model_path,
            device_map=device_map,
            torch_dtype=torch.bfloat16,
        )
        bias_expert_model = PeftModel.from_pretrained(
            bias_expert_model,
            os.path.join(model_directory, "bias_expert_model"),
        )

        return cls(prm_model, bias_expert_model, candidate_tokens)

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        """Run both models and return joint losses or joint logits.

        Args:
            input_ids: Token ids passed to both wrapped models.
            attention_mask: Optional mask passed to both wrapped models.
            labels: Optional ``[batch_size, seq_len]`` label ids. When given,
                the joint-loss dict is returned; otherwise the prediction
                dict is returned.
            **kwargs: Forwarded unchanged to both wrapped models.
        """
        # ``labels`` is deliberately not forwarded: only ``.logits`` is read
        # below, so a loss computed inside the wrapped model would be wasted.
        prm_outputs = self.prm_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )
        bias_outputs = self.bias_expert_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )

        if labels is None:
            return self._compute_joint_prediction(prm_outputs, bias_outputs)
        return self._compute_joint_loss(prm_outputs, bias_outputs, labels)

    def _compute_joint_loss(self, prm_outputs, bias_outputs, labels):
        """Compute the joint, PRM and bias-expert losses.

        Args:
            prm_outputs: PRM output with ``.logits`` of shape
                ``[batch_size, seq_len, vocab_size]``.
            bias_outputs: Bias-expert output with the same logits shape.
            labels: ``[batch_size, seq_len]`` label ids; only positions
                holding one of the two candidate tokens are supervised.

        Returns:
            A dict with ``loss`` (the joint cross-entropy), ``prm_loss`` and
            ``bias_loss``. When at least one supervised position exists,
            ``logits`` is a nested dict (``logits``, ``joint_probs``,
            ``gold``) and ``separate_training`` is ``True``. When none
            exists, the three losses are zero and ``logits`` is the raw PRM
            logits tensor.
        """
        first_token = self.candidate_tokens[0]
        second_token = self.candidate_tokens[1]

        label_mask = (labels == first_token) | (labels == second_token)
        # A label at position 0 has no preceding position to predict it from;
        # keeping it would index position -1 (the last token) below.
        label_mask[:, :1] = False
        labels_index = torch.argwhere(label_mask)  # [num_valid_labels, 2]

        if labels_index.size(0) == 0:
            logits_device = prm_outputs.logits.device

            def _zero():
                return torch.tensor(0.0, device=logits_device, requires_grad=True)

            return {
                'loss': _zero(),
                'logits': prm_outputs.logits,
                'prm_loss': _zero(),
                'bias_loss': _zero()
            }

        rows = labels_index[:, 0]
        cols = labels_index[:, 1]

        # gold == 1 for candidate_tokens[0], gold == 0 for candidate_tokens[1].
        gold = torch.where(labels[rows, cols] == second_token, 0, 1)  # [num_valid_labels]

        # The logits at position t - 1 predict the label at position t.
        pred_cols = cols - 1

        # Column 0 <-> candidate_tokens[0], column 1 <-> candidate_tokens[1].
        # Upcast so the softmax / log arithmetic runs in float32.
        prm_logits = prm_outputs.logits[rows, pred_cols][:, self.candidate_tokens].float()
        bias_logits = bias_outputs.logits[rows, pred_cols][:, self.candidate_tokens].float()

        prm_rewards = prm_logits.softmax(dim=-1)[:, 0]
        bias_rewards = bias_logits.softmax(dim=-1)[:, 0]

        joint_reward_prob = prm_rewards * torch.sigmoid(bias_rewards)

        eps = 1e-8
        joint_reward_prob_clamped = torch.clamp(joint_reward_prob, eps, 1 - eps)
        # Two-class log-probabilities: index 1 is the gold == 1 class.
        joint_logits = torch.stack([
            torch.log(1 - joint_reward_prob_clamped),
            torch.log(joint_reward_prob_clamped)
        ], dim=1)

        loss_fct = nn.CrossEntropyLoss()
        # prm_logits keeps candidate_tokens[0] in column 0 while gold marks it
        # as class 1, so the class index must be flipped for this loss.
        prm_gold_loss = loss_fct(prm_logits, 1 - gold)
        joint_loss = loss_fct(joint_logits, gold)

        prm_ce_loss = (1 - self.LAMBDA_G) * joint_loss + self.LAMBDA_G * prm_gold_loss
        bias_ce_loss = joint_loss

        # Per supervised position: 1 + the largest supervised position within
        # the same batch row.
        positions = torch.arange(labels.size(1), device=labels.device)
        max_label_pos = (label_mask * positions).max(dim=1).values  # [batch_size]
        lengths = (max_label_pos[rows] + 1).to(
            dtype=torch.float32, device=prm_logits.device
        )

        if prm_rewards.numel() > 1:
            prm_corr_loss = self._compute_correlation(prm_rewards, lengths) ** 2
            bias_corr_loss = self._compute_correlation(bias_rewards, lengths) ** 2

            # The PRM is penalised for reward/length correlation, whereas the
            # same quantity is subtracted from the bias-expert objective.
            prm_loss = prm_ce_loss + self.LAMBDA_R * prm_corr_loss
            bias_loss = bias_ce_loss - self.LAMBDA_B * bias_corr_loss
        else:
            # A correlation needs at least two samples.
            prm_loss = prm_ce_loss
            bias_loss = bias_ce_loss

        return {
            'loss': joint_loss,
            'logits': {
                'logits': prm_outputs.logits,
                'joint_probs': joint_reward_prob,
                'gold': gold
            },
            'prm_loss': prm_loss,
            'bias_loss': bias_loss,
            'separate_training': True
        }

    def _compute_joint_prediction(self, prm_outputs, bias_outputs):
        """Return PRM logits with the candidate columns replaced by joint scores.

        For the two candidate tokens, the score is the sum of both models'
        log-probabilities (softmax restricted to the candidates). All other
        vocabulary columns keep the PRM's original logits.
        """
        prm_logits = prm_outputs.logits
        bias_logits = bias_outputs.logits

        prm_probs = torch.softmax(prm_logits[:, :, self.candidate_tokens], dim=-1)
        bias_probs = torch.softmax(bias_logits[:, :, self.candidate_tokens], dim=-1)

        # Small constant keeps log() finite when a probability underflows to 0.
        joint_logits = torch.log(prm_probs + 1e-8) + torch.log(bias_probs + 1e-8)

        joint_logits_full = prm_logits.clone()
        joint_logits_full[:, :, self.candidate_tokens] = joint_logits

        return {
            'logits': joint_logits_full
        }

    @property
    def device(self):
        """Device of the first parameter of the PRM."""
        return next(self.prm_model.parameters()).device

    def _compute_correlation(self, x, y):
        """Return the Pearson correlation of two 1-D tensors as a scalar tensor.

        Population (biased) moments are used, and small constants are added so
        that a constant input yields a finite result instead of a division by
        zero.
        """
        x = x.float()
        y = y.float()

        x_centered = x - torch.mean(x)
        y_centered = y - torch.mean(y)

        covariance = torch.mean(x_centered * y_centered)

        # Small constant added to avoid division by zero.
        x_std = torch.sqrt(torch.mean(x_centered ** 2) + 1e-8)
        y_std = torch.sqrt(torch.mean(y_centered ** 2) + 1e-8)

        return covariance / (x_std * y_std + 1e-8)


def get_prm_model(model_path,USE_8bit,adapter_path=None,device_map='auto'):
    """Load a causal LM in bfloat16 and attach a LoRA adapter to it.

    Args:
        model_path: Checkpoint passed to ``AutoModelForCausalLM.from_pretrained``.
        USE_8bit: Kept for call compatibility. It is not applied here: the
            model is always loaded in bfloat16 without quantization.
        adapter_path: When ``None``, a fresh LoRA adapter (r=8, alpha=32,
            dropout=0.1 on ``q_proj`` / ``v_proj``) is created. Otherwise the
            adapter stored at this path is loaded onto the base model.
        device_map: Passed to ``AutoModelForCausalLM.from_pretrained``.

    Returns:
        The PEFT-wrapped model.
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map=device_map,
        torch_dtype=torch.bfloat16,
    )

    print(model)

    if adapter_path is None:
        # The LoRA settings are only needed when creating a new adapter; a
        # saved adapter carries its own configuration.
        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            r=8,
            lora_alpha=32,
            lora_dropout=0.1,
            target_modules=["q_proj", "v_proj"],
        )
        model = get_peft_model(model, lora_config)
    else:
        model = PeftModel.from_pretrained(model, adapter_path)

    print(model.device)
    return model


def get_joint_prm_model(prm_model_path, bias_expert_path, candidate_tokens, USE_8bit=False, adapter_path=None, device_map='auto'):
    """Build a ``JointPRMModel`` from two base checkpoints.

    The PRM is produced by ``get_prm_model``. The bias expert is loaded in
    bfloat16 from *bias_expert_path* and wrapped with a new LoRA adapter
    (r=4, alpha=16, dropout=0.1 on ``q_proj`` / ``v_proj``).

    Args:
        prm_model_path: Checkpoint for the PRM base model.
        bias_expert_path: Checkpoint for the bias-expert base model.
        candidate_tokens: Token ids handed to ``JointPRMModel``.
        USE_8bit: Forwarded to ``get_prm_model``.
        adapter_path: Forwarded to ``get_prm_model``.
        device_map: Used when loading both base models.

    Returns:
        The assembled ``JointPRMModel``.
    """
    prm = get_prm_model(prm_model_path, USE_8bit, adapter_path, device_map)

    expert_base = AutoModelForCausalLM.from_pretrained(
        bias_expert_path,
        device_map=device_map,
        torch_dtype=torch.bfloat16,
    )
    expert_lora_settings = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=4,
        lora_alpha=16,
        lora_dropout=0.1,
        target_modules=["q_proj", "v_proj"],
    )
    expert = get_peft_model(expert_base, expert_lora_settings)

    joint_model = JointPRMModel(prm, expert, candidate_tokens)

    print(f"Joint model device: {joint_model.device}")
    return joint_model

def load_saved_joint_prm_model(saved_model_path, prm_base_model_path, bias_expert_base_model_path, USE_8bit=False, device_map='auto'):
    """Load a saved joint model; thin wrapper over ``JointPRMModel.from_pretrained``.

    All arguments are passed through unchanged, with *saved_model_path* used
    as the model directory.
    """
    loader_kwargs = {
        "model_directory": saved_model_path,
        "prm_base_model_path": prm_base_model_path,
        "bias_expert_base_model_path": bias_expert_base_model_path,
        "USE_8bit": USE_8bit,
        "device_map": device_map,
    }
    return JointPRMModel.from_pretrained(**loader_kwargs)
