from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments,Qwen2ForCausalLM,LlamaForCausalLM
import torch
import argparse
import os
import numpy as np
from functools import partial
from peft import PeftModel
from peft import get_peft_model, LoraConfig, TaskType, prepare_model_for_kbit_training
# Ensure bitsandbytes is available for 8-bit quantization
# import bitsandbytes as bnb
from sklearn.metrics import roc_auc_score, log_loss, accuracy_score
from tqdm import tqdm
from torch.nn import BCEWithLogitsLoss
from transformers import DataCollatorWithPadding,DataCollatorForSeq2Seq,set_seed
from transformers import EarlyStoppingCallback, TrainerCallback
from datasets import concatenate_datasets
import json
from dataset import get_prm_dataset,PRMDataCollator
from modeling_custom_qwen import get_prm_model, get_joint_prm_model
import random
from datasets import disable_caching

# Command-line interface for the PRM fine-tuning script.
parser = argparse.ArgumentParser(
    description="Fine-tune a process reward model (PRM), optionally jointly with a bias expert."
)
parser.add_argument("--model_path", type=str, default="/Qwen2.5-Math-7B-Instruct",
                    help="Base model path; the tokenizer is loaded from the same location")
parser.add_argument("--data_path", type=str, default="data")
parser.add_argument("--per_device_train_batch_size", type=int, default=2,
                    help="Training micro-batch size per device")
parser.add_argument("--per_device_eval_batch_size", type=int, default=4,
                    help="Evaluation batch size per device")
parser.add_argument("--total_batch_size", type=int, default=256,
                    help="Target global batch size; used to derive the gradient accumulation steps")
parser.add_argument("--learning_rate", type=float, default=1e-4,
                    help="Learning rate of the PRM parameters")
parser.add_argument("--bias_expert_lr_multiplier", type=float, default=3.0, help="Bias expert learning rate multiplier")
parser.add_argument("--enable_joint_training", action='store_true', help="Enable joint training with bias expert")
parser.add_argument("--datasets", type=str, default='prm800k_v2',
                    help="Dataset tag used to name the output directory and the eval results file")
parser.add_argument("--server", type=str, default='slurm')
parser.add_argument("--adapter_path", type=str, default=None,
                    help="Optional adapter path forwarded to the model factory")
parser.add_argument("--train_data_path", type=str, default="/phase2_train_new.json",
                    help="Path of the training split")
parser.add_argument("--test_data_path", type=str, default="/data/phase2_test.new.json",
                    help="Path of the test split")
parser.add_argument("--custom_attention_flag", action='store_true',
                    help="Flag forwarded to the PRM data collator")
# The bias-expert checkpoint is configurable; the default keeps the
# location that the script has always used.
parser.add_argument("--bias_expert_path", type=str, default="/qwen0.5B",
                    help="Bias expert model path (only used with --enable_joint_training)")

args = parser.parse_args()

# Module-level alias read by the joint-model construction below.
bias_expert_path = args.bias_expert_path
# Turn off the tokenizers library's own parallelism for this process.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# os.environ['CUDA_VISIBLE_DEVICES'] = '1,2,3'

# Label tokens and separator strings used by the PRM data format.
good_token, bad_token = '+', '-'
step_tag = '\n\n\n\n\n'  # original annotation: "ки"
prompt_tag = '\n\n\n\n'
step_tag2 = '\n\n'

model_path = args.model_path

tokenizer = AutoTokenizer.from_pretrained(model_path, add_eos_token=False)
print(tokenizer.eos_token_id)

# Pad with token id 0 so that padding differs from the EOS token. The original
# note calls id 0 "unk"; whether that holds depends on the tokenizer - TODO confirm.
tokenizer.pad_token_id = 0
tokenizer.padding_side = "left"  # Allow batched inference

# Ids of the good/bad label tokens (author's note: [488, 481]).
candidate_tokens = tokenizer.encode(f" {good_token} {bad_token}")
print(candidate_tokens)

# Encode each tag once; the last id of the encoding identifies the tag
# (author's notes: 76325 for the step tag, 22701 for the prompt tag).
_step_tag_encoding = tokenizer.encode(f" {step_tag}")
_prompt_tag_encoding = tokenizer.encode(f" {prompt_tag}")
step_tag_id = _step_tag_encoding[-1]
prompt_tag_id = _prompt_tag_encoding[-1]
print('step_tag_id:', _step_tag_encoding)
print('prompt_tag_id:', _prompt_tag_encoding)

# 8-bit loading switch forwarded to the model factories.
USE_8bit = False

set_seed(42)

# Locations of the two dataset splits.
DATA_PATH = dict(train=args.train_data_path, test=args.test_data_path)

# Token ids shared by the dataset builder and the collator.
data_config = dict(
    step_tag_id=step_tag_id,
    candidate_tokens=candidate_tokens,
    prompt_tag_id=prompt_tag_id,
)

# Disable the `datasets` cache before the splits are built.
disable_caching()

tokenized_datasets = get_prm_dataset(
    DATA_PATH=DATA_PATH,
    tokenizer=tokenizer,
    data_config=data_config,
    mode='train',
)

# Project-specific collator used to assemble batches.
custom_attention_flag = args.custom_attention_flag
data_collator = PRMDataCollator(tokenizer, data_config, custom_attention_flag)

# Derive the number of gradient-accumulation steps so that
# per_device_batch * world_size * accumulation_steps ~= total_batch_size.
BATCH_SIZE = args.total_batch_size
GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // args.per_device_train_batch_size

device_map = "auto"
print(os.environ.get("WORLD_SIZE", 1))
world_size = int(os.environ.get("WORLD_SIZE", 1))
# world_size=8
ddp = world_size != 1
if ddp:
    # One process per device: pin the whole model to this process's local rank
    # and split the accumulation budget across the processes.
    device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
    GRADIENT_ACCUMULATION_STEPS = GRADIENT_ACCUMULATION_STEPS // world_size

# The floor divisions above yield 0 when the requested total batch size is
# smaller than a single global micro-batch; 0 accumulation steps is not a
# usable training configuration, so fall back to one step.
if GRADIENT_ACCUMULATION_STEPS < 1:
    print(
        f"total_batch_size={BATCH_SIZE} is smaller than "
        f"per_device_train_batch_size * world_size="
        f"{args.per_device_train_batch_size * world_size}; "
        "using gradient_accumulation_steps=1"
    )
    GRADIENT_ACCUMULATION_STEPS = 1

print(world_size)
print(ddp)


if args.enable_joint_training:
    # Joint mode: build the combined PRM + bias-expert model and train it with
    # a Trainer subclass that treats the two sub-models separately.
    model = get_joint_prm_model(
        prm_model_path=model_path,
        bias_expert_path=bias_expert_path,
        candidate_tokens=candidate_tokens,
        USE_8bit=USE_8bit,
        adapter_path=args.adapter_path,
        device_map=device_map
    )

    class JointTrainer(Trainer):
        """Trainer for the joint PRM + bias-expert model.

        Parameters whose name contains ``bias_expert_model`` are treated as
        the bias expert; every other trainable parameter belongs to the PRM.
        """

        def create_optimizer(self):
            """Create (once) an optimizer with one parameter group per sub-model.

            The PRM group uses ``args.learning_rate``; the bias-expert group
            uses ``args.learning_rate * args.bias_expert_lr_multiplier``.
            Both groups use a weight decay of 0.01.

            Returns:
                The optimizer stored in ``self.optimizer``.
            """
            if self.optimizer is None:
                prm_params = []
                bias_expert_params = []
                prm_lora_params = []
                bias_expert_lora_params = []

                # Bucket the trainable parameters by sub-model and by whether
                # they are LoRA weights (name contains "lora_").
                for name, param in self.model.named_parameters():
                    if not param.requires_grad:
                        continue
                    is_lora = 'lora_' in name.lower()
                    if 'bias_expert_model' in name:
                        if is_lora:
                            bias_expert_lora_params.append(param)
                        else:
                            bias_expert_params.append(param)
                    else:
                        if is_lora:
                            prm_lora_params.append(param)
                        else:
                            prm_params.append(param)

                # Within each group the non-LoRA parameters come first.
                all_prm_params = prm_params + prm_lora_params
                all_bias_expert_params = bias_expert_params + bias_expert_lora_params

                optimizer_grouped_parameters = [
                    {
                        "params": all_prm_params,
                        "lr": args.learning_rate,
                        "weight_decay": 0.01,
                    },
                    {
                        "params": all_bias_expert_params,
                        "lr": args.learning_rate * args.bias_expert_lr_multiplier,
                        "weight_decay": 0.01,
                    }
                ]

                optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
                # Default lr for the optimizer; each group above overrides it.
                optimizer_kwargs['lr'] = args.learning_rate
                self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)

            return self.optimizer

        def _save(self, output_dir=None, state_dict=None):
            """Save the model (and tokenizer) into ``output_dir``.

            Args:
                output_dir: Target directory; defaults to ``self.args.output_dir``.
                state_dict: Forwarded to ``Trainer._save`` when that is used.
            """
            output_dir = output_dir if output_dir is not None else self.args.output_dir
            os.makedirs(output_dir, exist_ok=True)

            if hasattr(self.model, 'save_pretrained'):
                self.model.save_pretrained(output_dir)
            else:
                # No custom serializer on the model: use the stock Trainer save.
                super()._save(output_dir, state_dict)
                return

            if hasattr(self, 'tokenizer') and self.tokenizer is not None:
                self.tokenizer.save_pretrained(output_dir)
                print(f"Tokenizer已保存到: {output_dir}")
                # NOTE(review): the stock Trainer save runs here only when a
                # tokenizer is attached, and it runs in addition to the
                # save_pretrained call above - confirm that this is intended.
                super()._save(output_dir, state_dict)

        def training_step(self, model, inputs, num_items_in_batch=None):
            """Run one micro-batch and accumulate its gradients.

            When ``compute_loss`` returns a dict with a truthy
            ``'separate_training'`` entry, the ``'prm_loss'`` gradient is
            accumulated only into the PRM parameters and the ``'bias_loss'``
            gradient only into the bias-expert parameters. The optimizer is
            NOT stepped or zeroed here: the surrounding ``Trainer`` loop does
            that at every accumulation boundary, so stepping here as well
            would apply updates twice and discard accumulated gradients.

            Any other loss value is delegated to ``Trainer.training_step``.

            Args:
                model: The (possibly wrapped) model to train.
                inputs: One batch from the data collator.
                num_items_in_batch: Accepted for compatibility with the
                    ``Trainer`` call signature; not used.

            Returns:
                The detached loss divided by the gradient accumulation steps.
            """
            model.train()
            inputs = self._prepare_inputs(inputs)

            with self.compute_loss_context_manager():
                loss_outputs = self.compute_loss(model, inputs)

            if not (isinstance(loss_outputs, dict) and loss_outputs.get('separate_training', False)):
                # NOTE(review): the parent implementation runs its own forward
                # pass, so this path computes the loss twice per micro-batch.
                return super().training_step(model, inputs)

            prm_loss = loss_outputs['prm_loss']
            bias_loss = loss_outputs['bias_loss']

            # Scale each micro-batch so the accumulated gradient is an average
            # over the accumulation window.
            grad_scale = 1.0 / self.args.gradient_accumulation_steps

            prm_params = []
            bias_params = []
            for name, param in model.named_parameters():
                if not param.requires_grad:
                    continue
                if 'bias_expert_model' in name:
                    bias_params.append(param)
                else:
                    prm_params.append(param)

            def _snapshot(params):
                # Copy the current .grad tensors (None stays None).
                return [None if p.grad is None else p.grad.detach().clone() for p in params]

            def _restore(params, grads):
                for p, g in zip(params, grads):
                    p.grad = g

            # 1) PRM loss: keep its contribution on the PRM parameters, and
            #    undo whatever it added to the bias-expert parameters.
            bias_grads_before = _snapshot(bias_params)
            (prm_loss * grad_scale).backward(retain_graph=True)
            prm_grads_after = _snapshot(prm_params)
            _restore(bias_params, bias_grads_before)

            # 2) Bias loss: keep its contribution on the bias expert, and
            #    undo whatever it added to the PRM parameters. This is the
            #    last backward pass, so the graph need not be retained.
            (bias_loss * grad_scale).backward()
            _restore(prm_params, prm_grads_after)

            total_loss = prm_loss + bias_loss
            return total_loss.detach() / self.args.gradient_accumulation_steps

    TrainerClass = JointTrainer
else:
    # Plain PRM training with the stock Trainer.
    model = get_prm_model(model_path,USE_8bit,adapter_path=args.adapter_path,device_map=device_map)
    TrainerClass = Trainer



# Run tag built from the main hyper-parameters; it names the checkpoint directory.
fp = f'bs_{args.total_batch_size}_lr_{args.learning_rate}_datasets_{args.datasets}'
if args.enable_joint_training:
    fp = fp + f'_joint_bias{args.bias_expert_lr_multiplier}_lambda_r{0.1}_lambda_b{0.7}'
output_path = f'./trained_models/{fp}'

# Text file that collects one metrics line per evaluation.
eval_results_path = f'./eval_result/{args.datasets}.txt'

# Training arguments
_training_arguments_kwargs = dict(
    output_dir=output_path,
    learning_rate=args.learning_rate,
    per_device_train_batch_size=args.per_device_train_batch_size,
    per_device_eval_batch_size=args.per_device_eval_batch_size,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=10,
    save_strategy="epoch",
    save_safetensors=False,
    bf16=True,
    report_to="none",
    dataloader_num_workers=32,
    deepspeed=None,
    ddp_find_unused_parameters=False,
)
training_args = TrainingArguments(**_training_arguments_kwargs)

# Define a custom metric function (e.g., accuracy for binary classification)
def compute_metrics(eval_pred,args,eval_results_path):
    """Compute AUC, log-loss and accuracy and append them to a results file.

    Args:
        eval_pred: Pair ``(predictions, labels)``. ``predictions`` is itself
            a pair ``(scores, gold)`` as produced by
            ``preprocess_logits_for_metrics``; the second element of
            ``eval_pred`` is not used.
        args: Parsed command-line arguments; ``total_batch_size``,
            ``learning_rate`` and ``datasets`` are recorded with the metrics.
        eval_results_path: File to which one line per call is appended.

    Returns:
        Dict with keys ``auc``, ``ll``, ``acc``, ``bs``, ``lr``, ``datasets``.
    """
    pre, _unused_labels = eval_pred
    scores, gold = pre[0], pre[1]

    result = {
        'auc': roc_auc_score(gold, scores),
        'll': log_loss(gold, scores),
        # A score above 0.5 counts as a positive prediction.
        'acc': accuracy_score(gold, scores > 0.5),
        'bs': args.total_batch_size,
        'lr': args.learning_rate,
        'datasets': args.datasets,
    }
    print(result)

    # Make sure the results directory exists so that the append below cannot
    # fail with FileNotFoundError after an evaluation pass has completed.
    results_dir = os.path.dirname(eval_results_path)
    if results_dir:
        os.makedirs(results_dir, exist_ok=True)

    with open(eval_results_path, 'a') as f:
        f.write(str(result)+'\n')
    return result

# Bind the run configuration so that the resulting callable takes only the
# evaluation predictions.
partial_compute_metrics = partial(
    compute_metrics,
    args=args,
    eval_results_path=eval_results_path,
)



def preprocess_logits_for_metrics(logits, labels, candidate_token_ids=None):
    """Reduce raw model outputs to ``(positive-class score, gold label)`` pairs.

    Args:
        logits: Either a dict containing ``'joint_probs'`` and ``'gold'``
            (returned as ``joint_probs[:, 1]`` and ``gold``), a dict
            containing ``'logits'``, or a logits tensor indexed as
            ``[batch, position, vocab]``.
        labels: Label ids indexed as ``[batch, position]``; positions whose
            label equals one of the two candidate token ids are scored.
            Not used on the ``joint_probs`` path.
        candidate_token_ids: Pair ``(good_id, bad_id)``. Defaults to the
            module-level ``candidate_tokens``.

    Returns:
        Tuple ``(prob_good, gold)``: the softmax probability of the good
        token over the (bad, good) pair, and 1/0 gold labels (1 = good).
        Both are empty when no position is labelled.
    """
    if isinstance(logits, dict) and 'joint_probs' in logits and 'gold' in logits:
        return logits['joint_probs'][:, 1], logits['gold']

    if isinstance(logits, dict):
        logits = logits['logits']

    if candidate_token_ids is None:
        candidate_token_ids = candidate_tokens
    good_id, bad_id = candidate_token_ids[0], candidate_token_ids[1]

    # (row, column) coordinates of every labelled position.
    labels_index = torch.argwhere((labels == good_id) | (labels == bad_id))

    if labels_index.size(0) == 0:
        # Keep device and dtype consistent with the non-empty case so that
        # results from different batches can be concatenated.
        return logits.new_empty(0), labels.new_empty(0, dtype=torch.long)

    rows, cols = labels_index[:, 0], labels_index[:, 1]
    gold = torch.where(labels[rows, cols] == bad_id, 0, 1)

    # The label at position t is scored with the logits at position t - 1.
    # NOTE(review): a label at position 0 would index position -1 (the last
    # position); presumably this never occurs - confirm against the dataset.
    pair_logits = logits[rows, cols - 1][:, [bad_id, good_id]]
    prob = torch.softmax(pair_logits, dim=-1)
    return prob[:, 1], gold
    
class StopTrainingCallback(TrainerCallback):
    """Callback that requests a training stop at the end of every epoch.

    Because the request is raised at the first epoch end, training ends
    after one epoch regardless of the configured number of epochs.
    """

    def on_epoch_end(self, args, state, control, **kwargs):
        """Set ``control.should_training_stop`` so the trainer stops."""
        # stop on the first epoch
        control.should_training_stop = True


class JointLossLoggingCallback(TrainerCallback):
    """Placeholder callback for joint-loss logging; currently does nothing."""

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Hook for log events; no action is taken."""
        candidate_model = kwargs.get('model')
        has_joint_logs = logs is not None and hasattr(candidate_model, 'prm_model')
        if not has_joint_logs:
            return None
        # Logs from a model exposing `prm_model`: nothing is recorded yet.
        return None

    def on_step_end(self, args, state, control, **kwargs):
        """Hook for the end of a training step; no action is taken."""
        return None


# Re-seed before the trainer is built.
set_seed(42)

# Initialize the Trainer
_trainer_callbacks = [StopTrainingCallback(), JointLossLoggingCallback()]
_trainer_kwargs = dict(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets['train'],
    # The test split serves as evaluation data; replace with a validation set if available.
    eval_dataset=tokenized_datasets['test'],
    data_collator=data_collator,
    tokenizer=tokenizer,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    compute_metrics=partial_compute_metrics,
    callbacks=_trainer_callbacks,
)
trainer = TrainerClass(**_trainer_kwargs)

# Re-seed once more immediately before training starts.
set_seed(42)
trainer.train()
