import torch
import torch.nn.functional as F
from trl.trainer.reward_trainer import RewardTrainer
import os
from datasets import load_from_disk, load_dataset, concatenate_datasets
from trl import RewardTrainer, RewardConfig, setup_chat_format
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, AutoModelForSequenceClassification
from peft import PeftModel, LoraConfig, get_peft_model
import argparse

import os
# Pin the distributed-launch variables to one local process (rank 0 of a
# world of size 1) and tell DeepSpeed not to use MPI. Any values already
# present in the environment are overwritten.
os.environ.update(
    {
        "DEEPSPEED_USE_MPI": "0",
        "MASTER_ADDR": "localhost",
        "MASTER_PORT": "12345",
        "RANK": "0",
        "WORLD_SIZE": "1",
        "LOCAL_RANK": "0",
    }
)


class LastTokenRewardTrainer(RewardTrainer):
    """RewardTrainer whose reward is read at the last attended token of each sequence.

    The model output's ``logits`` are indexed as ``(batch, position, class)``.
    A sequence's reward is the softmax probability of class index 1 at the
    last position whose attention mask is non-zero.
    """

    @staticmethod
    def _last_token_reward(logits, mask):
        """Return the class-1 probability at each row's last attended position.

        Args:
            logits: Tensor indexed as ``(batch, position, class)`` with at
                least two classes.
            mask: Attention mask indexed as ``(batch, position)``; non-zero
                entries mark real tokens.

        Returns:
            A 1-D tensor holding one reward per row.
        """
        # Take the largest position whose mask is non-zero. For right-padded
        # batches this equals ``mask.sum(dim=1) - 1``; unlike that formula it
        # also picks the correct token when the batch is left-padded, and it
        # never produces a negative index (an all-zero row maps to position 0).
        positions = torch.arange(mask.shape[1], device=mask.device)
        last_pos = (positions * (mask != 0)).argmax(dim=1)

        # Shape the index as (batch, 1, class) so gather pulls one full
        # class vector per row.
        gather_idx = last_pos.to(logits.device).view(-1, 1, 1)
        gather_idx = gather_idx.expand(-1, -1, logits.shape[-1])
        last_logits = logits.gather(1, gather_idx).squeeze(1)

        # NOTE: the reward is a probability, so the chosen-minus-rejected
        # margin fed to the loss is confined to (-1, 1). The unbounded
        # alternative is the logit difference
        # ``last_logits[:, 1] - last_logits[:, 0]``.
        return F.softmax(last_logits, dim=-1)[:, 1]

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute the pairwise preference loss for a chosen/rejected batch.

        Args:
            model: Callable as ``model(input_ids, attention_mask=...)``; its
                output must expose ``.logits``.
            inputs: Batch mapping with the keys ``input_ids_chosen``,
                ``attention_mask_chosen``, ``input_ids_rejected`` and
                ``attention_mask_rejected``.
            return_outputs: When true, also return the per-example rewards.
            **kwargs: Accepted and ignored.

        Returns:
            The scalar loss ``-logsigmoid(r_chosen - r_rejected).mean()``, or a
            ``(loss, {"rewards_chosen": ..., "rewards_rejected": ...})`` tuple
            when ``return_outputs`` is true.
        """
        out_c = model(inputs["input_ids_chosen"],
                      attention_mask=inputs["attention_mask_chosen"])
        out_r = model(inputs["input_ids_rejected"],
                      attention_mask=inputs["attention_mask_rejected"])

        rewards_c = self._last_token_reward(
            out_c.logits, inputs["attention_mask_chosen"]
        )
        rewards_r = self._last_token_reward(
            out_r.logits, inputs["attention_mask_rejected"]
        )

        loss = -F.logsigmoid(rewards_c - rewards_r).mean()

        if return_outputs:
            return loss, {
                "rewards_chosen": rewards_c,
                "rewards_rejected": rewards_r,
            }
        return loss

def load_model_and_tokenizer(model_name, cache_dir, lora_checkpoint=None, is_lora=False):
    """Load the base model and tokenizer, optionally attaching saved LoRA adapters.

    Args:
        model_name: Model identifier passed to ``from_pretrained``.
        cache_dir: Cache directory passed to ``from_pretrained``.
        lora_checkpoint: Location of a saved LoRA adapter; required when
            ``is_lora`` is true.
        is_lora: When true, wrap the base model with the adapter found at
            ``lora_checkpoint`` and keep the adapter weights trainable.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        ValueError: If ``is_lora`` is true but ``lora_checkpoint`` is empty.
    """
    # Fail fast: without this check the whole base model is loaded first and
    # the problem only surfaces later inside PeftModel.from_pretrained.
    if is_lora and not lora_checkpoint:
        raise ValueError(
            "is_lora=True requires lora_checkpoint to point to a saved LoRA adapter"
        )

    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir, trust_remote_code=True)

    print("Loading model...")
    model = AutoModel.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        trust_remote_code=True,
        cache_dir=cache_dir,
    )

    if is_lora:
        print("Loading LoRA checkpoint...")
        model = PeftModel.from_pretrained(
            model,
            lora_checkpoint,
            is_trainable=True,    # keep adapter params trainable
        )

        # Sanity check: report how many parameters will receive gradients.
        total = sum(p.numel() for p in model.parameters())
        trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"> Loaded LoRA adapters: {trainable}/{total} params trainable")

    # Keep the model's padding id in sync with the tokenizer's.
    model.config.pad_token_id = tokenizer.pad_token_id

    # Install a chat format only when the tokenizer does not ship a template.
    if tokenizer.chat_template is None:
        model, tokenizer = setup_chat_format(model, tokenizer)

    return model, tokenizer


def get_lora_config():
    """Return the LoraConfig used when fresh adapters are created.

    Adapters of rank 64 (alpha 16, dropout 0.05, no bias terms) are placed
    on the four projection modules named below.
    """
    projection_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]
    lora_settings = {
        "r": 64,
        "lora_alpha": 16,
        "target_modules": projection_modules,
        "lora_dropout": 0.05,
        "bias": "none",
        "task_type": "CAUSAL_LM",
    }
    return LoraConfig(**lora_settings)


def load_and_prepare_dataset(dataset_path):
    """Load a dataset saved on disk and return it shuffled with a fixed seed.

    Args:
        dataset_path: Directory passed to ``load_from_disk``.

    Returns:
        The loaded dataset after ``shuffle(seed=42)``.
    """
    print("Loading dataset...")
    # A fixed seed keeps the example order reproducible between runs.
    shuffled = load_from_disk(dataset_path).shuffle(seed=42)
    print("Dataset shuffled")
    return shuffled


def get_training_config(output_dir):
    """Build the RewardConfig for a training run.

    Args:
        output_dir: Directory handed to the config as ``output_dir``.

    Returns:
        A ``RewardConfig`` with the hyperparameters listed below.
    """
    settings = dict(
        output_dir=output_dir,
        logging_steps=10,
        # 2 examples per device x 4 accumulation steps per optimizer update.
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        fp16=True,
        save_total_limit=40,
        num_train_epochs=4,
        save_steps=500,
        # NOTE(review): this value is only stored on the config here; confirm
        # that the caller forwards it to ``trainer.train()``.
        resume_from_checkpoint=True,
        ddp_backend=None,
    )
    return RewardConfig(**settings)


def train_model(model, tokenizer, train_dataset, training_args, lora_config, output_dir):
    """Train with LastTokenRewardTrainer and save the trained model.

    Args:
        model: Model handed to the trainer.
        tokenizer: Passed to the trainer as ``processing_class``.
        train_dataset: Training dataset.
        training_args: Trainer configuration. Its ``resume_from_checkpoint``
            and ``output_dir`` fields decide whether and from where training
            resumes.
        lora_config: PEFT config forwarded as ``peft_config`` (may be None).
        output_dir: Directory the final model is saved to.
    """
    # Local import: the file's top-level import block does not provide it.
    from transformers.trainer_utils import get_last_checkpoint

    print("Initializing trainer...")
    trainer = LastTokenRewardTrainer(
        model=model,
        args=training_args,
        processing_class=tokenizer,
        train_dataset=train_dataset,
        peft_config=lora_config,
    )

    # The config field alone does not make the trainer resume; the checkpoint
    # has to be passed to ``train()`` explicitly. A string is used as given;
    # any other truthy value means "latest checkpoint in the output
    # directory, if one exists".
    requested = getattr(training_args, "resume_from_checkpoint", None)
    if isinstance(requested, str):
        resume_from = requested
    elif requested and os.path.isdir(training_args.output_dir):
        resume_from = get_last_checkpoint(training_args.output_dir)
    else:
        resume_from = None
    if resume_from:
        print(f"Resuming from checkpoint: {resume_from}")

    print("Starting training...")
    trainer.train(resume_from_checkpoint=resume_from)

    # Save through the trainer rather than the ``model`` argument: the trainer
    # may have wrapped the model (e.g. with adapters built from
    # ``peft_config``), so the object passed in is not necessarily the one
    # that was trained.
    trainer.save_model(output_dir)


def main():
    """Command-line entry point: parse options, load everything, then train."""
    parser = argparse.ArgumentParser(description="Train a preference reward model")

    # (flag, default, help) for every string-valued option, in display order.
    string_options = (
        ("--model_name", "Qwen/Qwen2.5-Math-PRM-7B",
         "Name of the base model to train"),
        ("--cache_dir", "/cmlscratch/agrawal5/cache",
         "Directory for model cache"),
        ("--dataset_path", "./curriculum_learning/split_1_CL1_dynamic/",
         "Path to the preference dataset"),
        ("--output_dir", "curriculum_learning/Qwen2.5-Math-PRM-7B-pref_0.5_to_1_CL0_wo_0rating",
         "Directory to save the trained model"),
        ("--lora_checkpoint", None,
         "Path to the LoRA checkpoint"),
    )
    for flag, default_value, help_text in string_options:
        parser.add_argument(flag, type=str, default=default_value, help=help_text)
    parser.add_argument("--is_lora", action="store_true",
                        help="Whether to use LoRA")
    args = parser.parse_args()

    model, tokenizer = load_model_and_tokenizer(
        args.model_name,
        args.cache_dir,
        args.lora_checkpoint,
        args.is_lora,
    )

    # With --is_lora the adapters were already attached while loading, so no
    # new LoRA config is handed to the trainer.
    fresh_lora_config = get_lora_config()
    lora_config = None if args.is_lora else fresh_lora_config

    train_dataset = load_and_prepare_dataset(args.dataset_path)
    training_args = get_training_config(args.output_dir)

    train_model(
        model,
        tokenizer,
        train_dataset,
        training_args,
        lora_config,
        args.output_dir,
    )


# Run the training pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
