from dotenv import load_dotenv
load_dotenv()

import argparse
import json
import os
import random
import time
from datetime import datetime
from pathlib import Path
from typing import Optional

import numpy as np
import torch
from datasets import load_dataset
from torch.utils.data import Dataset
from trl import DPOConfig, DPOTrainer

import sys
sys.path.append(".")
from examples.datasets import load_dataset_and_prompt_handlers
from src.dataset import MutationDataset
from src.utils.training import init_distributed_mode, setup_wandb, get_peft_config, get_ds_config, initialize_models
from src.utils.logger_utils import setup_logger

# Module-level logger obtained from the project's logging helper; shared by all functions below.
logger = setup_logger()


def parse_args():
    """Parse the command-line arguments for DPO training.

    Returns:
        argparse.Namespace holding model/dataset selection, DPO and optimizer
        hyper-parameters, LoRA, wandb and DeepSpeed settings.
    """
    parser = argparse.ArgumentParser(description="Adversarial training between attacker and solver")
    
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--model_peft_checkpoint", type=str, help="Path to PEFT checkpoint to load")
    parser.add_argument("--splits_file", type=str, required=True, help="Path to the splits file.")
    parser.add_argument("--dataset", type=str, required=True, choices=["bigcodebench-complete", "kodcode-complete"])
    parser.add_argument("--method", type=str, required=True, choices=["ours", "codegen"])
    # Only these two columns are understood downstream (run naming, DPO dataset building).
    parser.add_argument("--input_column", type=str, required=True, choices=["response", "solutions"], help="Column to use for responses ('response' or 'solutions')")
    parser.add_argument("--iteration", type=int, required=True)
    parser.add_argument("--save_dir", type=str, required=True)
    # Required: it prefixes the run name and names the HF dataset repo that is loaded/built.
    parser.add_argument("--hf_repo_name", type=str, required=True, help="HuggingFace dataset to load")
    parser.add_argument("--max_model_len", type=int, default=1024, help="Maximum number of input tokens for the model.")
    parser.add_argument("--max_tokens", type=int, default=1024, help="Maximum number of new generated tokens.")
    parser.add_argument("--consistency", action="store_true", help="Whether to add consistency score to reward")
    parser.add_argument("--temperature", type=float, default=0.7, help="Temperature for sampling")
    parser.add_argument("--static_solver", type=str, default=None, help="Fixed solver model to evaluate against; enables the static-solver eval callback and ranks checkpoints by eval_<name>_score")
    parser.add_argument("--solver_is_nonstatic", action="store_true", help="Treat the solver as non-static (the eval callback gets solver_is_static=False)")
    parser.add_argument("--static_eval_size", type=int, default=10)
    
    # Training parameters
    parser.add_argument("--num_train_examples", type=int, required=True)
    parser.add_argument("--num_val_examples", type=int, required=True)
    parser.add_argument("--margin_threshold", type=float, required=True)
    parser.add_argument("--eval_ratio", type=float, default=0.2)
    parser.add_argument("--learning_rate", type=float, default=5e-5)
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument("--per_device_train_batch_size", type=int, default=4)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
    parser.add_argument("--eval_steps", type=int, default=50)
    parser.add_argument("--save_steps", type=int, default=None, help="Checkpoint interval; defaults to --eval_steps")
    parser.add_argument("--save_total_limit", type=int, default=3)
    parser.add_argument("--beta", type=float, default=0.1)
    parser.add_argument("--warmup_ratio", type=float, default=0.1)
    parser.add_argument("--weight_decay", type=float, default=0.01)
    parser.add_argument("--gradient_checkpointing", action="store_true")
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--optim", type=str, default="adamw_torch")
    parser.add_argument("--loss_type", type=str, default="sigmoid")
    parser.add_argument("--lr_scheduler_type", type=str, default="cosine")
    
    # LoRA parameters
    parser.add_argument("--lora_r", type=int, default=16)
    parser.add_argument("--lora_alpha", type=int, default=32)
    parser.add_argument("--lora_dropout", type=float, default=0.05)
    
    # Wandb config
    parser.add_argument("--wandb_project", type=str, default=None)
    parser.add_argument("--wandb_entity", type=str, default=None)
    
    # DeepSpeed config
    parser.add_argument("--deepspeed", action="store_true")
    parser.add_argument("--zero_stage", type=int, default=2)
    
    # Other necessary parameters
    parser.add_argument("--seed", type=int, default=42)

    return parser.parse_args()


def get_run_name(args):
    """Build a descriptive, minute-timestamped run name.

    The name is ``<hf_repo_name>_<role>_<model>_ex<N>_lr<LR>_bs<B>_ep<E>_marg<M>_<date>``.
    ``role`` is ``att_iter<i>`` / ``sol_iter<i>`` for ``method == "ours"``
    with ``input_column`` ``"response"`` / ``"solutions"``, and ``bon`` for
    ``method == "codegen"``. ``B`` is the per-device batch size times gradient
    accumulation steps times the number of visible CUDA devices (at least 1).

    Raises:
        ValueError: if ``method``/``input_column`` do not map to a known role.
    """
    date_str = datetime.now().strftime("%Y-%m-%d-%H-%M")
    base_model = args.model.split("/")[-1]
    # device_count() is 0 on CPU-only hosts; count that as one device so the
    # batch size in the name is not reported as 0.
    num_devices = max(torch.cuda.device_count(), 1)
    global_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps * num_devices
    if args.method == "ours" and args.input_column == "response":
        role = f"att_iter{args.iteration}"
    elif args.method == "ours" and args.input_column == "solutions":
        role = f"sol_iter{args.iteration}"
    elif args.method == "codegen":
        role = "bon"
    else:
        # Previously fell through and crashed with UnboundLocalError on `role`.
        raise ValueError(
            f"Cannot derive run role for method={args.method!r}, input_column={args.input_column!r}"
        )
    suffix = f"{role}_{base_model}_ex{args.num_train_examples}_lr{args.learning_rate}_bs{global_batch_size}_ep{args.num_train_epochs}_marg{args.margin_threshold}_{date_str}"
    return f"{args.hf_repo_name}_{suffix}"


def train_dpo(
    args,
    model,
    tokenizer,
    ds_config,
    train_dataset: Dataset,
    eval_dataset: Optional[Dataset],
    output_dir: str,
    callbacks=None,
    wandb_run=None,
    run_name=None
) -> str:
    """Train ``model`` with DPO on preferred vs. non-preferred examples.

    Args:
        args: Parsed CLI namespace (see ``parse_args``); supplies DPO,
            optimizer, evaluation and checkpointing hyper-parameters.
        model: Policy model. ``ref_model`` is passed as ``None`` together with
            a PEFT config built from ``args``.
        tokenizer: Passed to ``DPOTrainer`` as ``processing_class`` and saved
            alongside the final weights.
        ds_config: DeepSpeed config, or ``None`` to train without DeepSpeed.
        train_dataset: Preference dataset used for training.
        eval_dataset: Optional preference dataset for step-wise evaluation;
            evaluation is disabled when it is ``None``.
        output_dir: Checkpoint directory (str or path-like); the final model
            is written to ``<output_dir>/best``.
        callbacks: Optional list of trainer callbacks.
        wandb_run: Active wandb run, or ``None`` to disable metric reporting.
        run_name: Run name forwarded to the trainer.

    Returns:
        Path of the directory the final model and tokenizer were saved to.
    """
    logger.info("\nStarting training...")
    peft_config = get_peft_config(args)
    # Callers pass a pathlib.Path; normalise so the returned path is a plain str.
    output_dir = str(output_dir)
    
    # Checkpoint ranking metric: the static solver's score, the trained model's
    # own score when the solver is non-static, otherwise the DPO eval loss.
    if args.static_solver:
        metric_for_best_model = f"eval_{args.static_solver}_score"
    elif args.solver_is_nonstatic:
        metric_for_best_model = f"eval_{args.model}_score"
    else:
        metric_for_best_model = "eval_loss"
    print(f"Metric for best model: {metric_for_best_model}")

    # eval_dataset is Optional, and the HF Trainer raises when a step-wise eval
    # strategy is configured without an eval set, so only evaluate if given one.
    has_eval = eval_dataset is not None
    training_args = DPOConfig(
        # Core DPO parameters
        beta=args.beta,
        loss_type=args.loss_type,
        # Training parameters
        learning_rate=args.learning_rate,
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        # NOTE(review): in TRL, max_length bounds prompt+completion while
        # max_prompt_length bounds the prompt; here the "new tokens" flag feeds
        # max_length. Confirm this is the intended truncation budget.
        max_length=args.max_tokens,
        max_prompt_length=args.max_model_len,
        # Optimizer settings
        optim=args.optim,
        lr_scheduler_type=args.lr_scheduler_type,
        warmup_ratio=args.warmup_ratio,
        weight_decay=args.weight_decay,
        max_grad_norm=1.0,
        # Mixed precision and performance
        fp16=args.fp16,
        gradient_checkpointing=args.gradient_checkpointing,
        gradient_checkpointing_kwargs={'use_reentrant': False},
        # Logging and evaluation
        logging_steps=1,
        eval_steps=args.eval_steps,
        save_strategy='steps',
        save_steps=args.save_steps if args.save_steps else args.eval_steps,
        eval_strategy="steps" if has_eval else "no",
        save_total_limit=args.save_total_limit,
        metric_for_best_model=metric_for_best_model,
        # Lower is treated as better for every metric above, including the
        # *_score ones. NOTE(review): confirm that is intended for the scores.
        greater_is_better=False,
        # transformers treats report_to=None as "all installed integrations",
        # so the string "none" is needed to actually disable reporting.
        report_to="wandb" if wandb_run else "none",
        do_eval=has_eval,
        # Dataset handling
        remove_unused_columns=True,
        # Paths and config
        output_dir=output_dir,
        run_name=run_name,
        deepspeed=ds_config,
        # load_best_model_at_end=True
    )

    def compute_metrics(eval_preds):
        """Return externally-supplied eval metrics (placeholder logic).

        Reads ``train_dpo.latest_metrics`` if present; nothing in this file sets
        it, so it is presumably populated by a callback — verify.
        """
        metrics = {}
        if hasattr(train_dpo, "latest_metrics"):
            metrics.update(train_dpo.latest_metrics)
        return metrics

    dpo_trainer = DPOTrainer(
        model=model,
        ref_model=None,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset, 
        # tokenizer=tokenizer, # for trl==0.14.0
        processing_class=tokenizer,# for trl==0.16.0
        peft_config=peft_config,
        callbacks=callbacks,
        compute_metrics=compute_metrics,
    )

    # Train the model
    start_time = time.time()
    train_result = dpo_trainer.train()
    train_time = time.time() - start_time
    
    best_model_dir = os.path.join(output_dir, "best")
    # NOTE: load_best_model_at_end is not enabled above, so this saves the
    # weights from the end of training, not necessarily the best checkpoint.
    dpo_trainer.save_model(best_model_dir)
    # Mirror save_model's own gate so ranks don't race writing the same files.
    if dpo_trainer.args.should_save:
        tokenizer.save_pretrained(best_model_dir)
    logger.info(f"Saved best model to {best_model_dir}")
    
    if wandb_run:
        wandb_run.log({
            'train/final_loss': train_result.training_loss,
            'train/total_steps': train_result.global_step,
            'train/total_time': train_time,
        })
    return best_model_dir


def _ensure_distributed_env_defaults():
    """Set single-process torch.distributed env vars when no launcher provided them.

    Acts only when ``LOCAL_RANK`` is absent, so launcher-provided values are
    never overwritten.
    """
    if "LOCAL_RANK" in os.environ:
        return
    os.environ["LOCAL_RANK"] = "0"  # Default to single GPU
    os.environ["RANK"] = "0"
    os.environ["WORLD_SIZE"] = "1"
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"


def _load_preference_splits(args, tokenizer):
    """Return ``(train_dataset, val_dataset)`` for DPO training.

    If ``args.hf_repo_name`` contains "dpo", the dataset
    ``<HF_ENTITY>/<hf_repo_name>`` is loaded directly. Otherwise preference
    pairs are built from that repo via ``MutationDataset.to(format="dpo")``,
    with ``<HF_ENTITY>/<hf_repo_name>_dpo_<num_train_examples>`` passed as its
    ``hf_repo_or_local_dir``.

    Raises:
        ValueError: if ``HF_ENTITY`` is unset/empty or ``args.dataset`` is unknown.
    """
    hf_entity = os.getenv("HF_ENTITY")
    if not hf_entity:
        # Without it every repo id below would silently become "None/<name>".
        raise ValueError("HF_ENTITY environment variable must be set to resolve HuggingFace dataset repos")

    if args.dataset == 'bigcodebench-complete':
        from examples.datasets.bigcodebench_complete import dataset_engine
    elif args.dataset == 'kodcode-complete':
        from examples.datasets.kodcode_complete import dataset_engine
    else:
        raise ValueError(f"Invalid dataset: {args.dataset}")
    trainset, _ = dataset_engine(splits_file=args.splits_file, return_val=False)

    if "dpo" in args.hf_repo_name:
        # Already DPO-formatted: load it as-is.
        dataset = load_dataset(f"{hf_entity}/{args.hf_repo_name}")
    else:
        dpo_dataset_name = f"{hf_entity}/{args.hf_repo_name}_dpo_{args.num_train_examples}"
        print(f"Making DPO dataset {dpo_dataset_name}")
        mutation_dataset = MutationDataset(
            hf_repo_or_local_dir=f"{hf_entity}/{args.hf_repo_name}",
            tokenizer=tokenizer,
            trainset=trainset
        )
        dataset = mutation_dataset.to(
            format="dpo",
            eval_ratio=args.eval_ratio,
            num_train_examples=args.num_train_examples,
            num_val_examples=args.num_val_examples,
            tokenize=False,
            apply_chat_template=False,
            margin_threshold=args.margin_threshold,
            input_column=args.input_column,
            hf_repo_or_local_dir=dpo_dataset_name,
            consistency=args.consistency,
        )
    return dataset['train'], dataset['test']


def _build_static_solver_callbacks(args, model, tokenizer):
    """Return ``[StaticSolverEvalCallback]`` when ``--static_solver`` is set, else ``None``."""
    if not args.static_solver:
        return None

    from src.callback import StaticSolverEvalCallback
    from examples.verdict.codegen import verdict_engine

    # Only the dataset engine and the mutator/solver prompt handlers are needed here.
    dataset_engine, mutator_prompt_handler, solver_prompt_handler, _, _ = load_dataset_and_prompt_handlers(args.dataset)
    start_time = time.time()
    _, _, raw_dataset = dataset_engine(splits_file=args.splits_file, return_val=True)
    print(f"Loaded static solver eval dataset with {len(raw_dataset)} examples in {time.time() - start_time} seconds")
    verdict = verdict_engine(args.dataset)

    def filter_func(result):
        # Keep a result only if generation succeeded and the verdict says the
        # mutation does NOT pass; the verdict's info is returned alongside.
        passed, info = verdict(problem=result['problem'], completion=result['mutation'])
        return result['success'] and not passed, info

    generation_kwargs = {
        "max_tokens": args.max_model_len,
        "max_new_tokens": args.max_tokens,
        "temperature": args.temperature
    }
    static_solver_callback = StaticSolverEvalCallback(
        dataset_name=args.dataset,
        attacker_model_name=args.model,
        eval_dataset=raw_dataset[:args.static_eval_size],
        mutator_prompt_handler=mutator_prompt_handler,
        solver_prompt_handler=solver_prompt_handler,
        solver_model_name=args.model,
        static_solver=args.static_solver,
        solver_is_static=not args.solver_is_nonstatic,
        filter_func=filter_func,
        model=model,
        tokenizer=tokenizer,
        **generation_kwargs
    )
    return [static_solver_callback]


def run_dpo():
    """CLI entry point: prepare the environment and data, then train with DPO.

    Steps:
    1. Default the single-process distributed env vars and parse args.
    2. Seed the RNGs and optionally initialise DeepSpeed.
    3. Set up wandb and the output directory (``args.json`` is written on local rank 0).
    4. Load the model/tokenizer and the train/val preference splits.
    5. Optionally attach the static-solver eval callback.
    6. Call ``train_dpo``.
    """
    _ensure_distributed_env_defaults()

    args = parse_args()
    print(args)
    init_distributed_mode()

    # Seed torch, numpy and Python's `random`.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    ds_config = None
    if args.deepspeed:
        import deepspeed
        deepspeed.init_distributed()
        ds_config = get_ds_config(args)

    run_name = get_run_name(args)
    logger.info(f"Run name: {run_name}")
    wandb_run = setup_wandb(args, run_name=run_name)

    output_dir = Path(args.save_dir) / run_name
    output_dir.mkdir(parents=True, exist_ok=True)
    if os.environ.get('LOCAL_RANK', '0') == '0':
        with open(output_dir / 'args.json', 'w', encoding='utf-8') as f:
            json.dump(vars(args), f, indent=2)

    # A PEFT checkpoint, when given, takes precedence over the base model name.
    model_path = args.model_peft_checkpoint if args.model_peft_checkpoint else args.model
    model, tokenizer = initialize_models(args, model_path)

    train_dataset, val_dataset = _load_preference_splits(args, tokenizer)
    logger.info(f"Loaded {len(train_dataset)} train and {len(val_dataset)} validation examples")

    callbacks = _build_static_solver_callbacks(args, model, tokenizer)

    train_dpo(
        args=args,
        model=model,
        tokenizer=tokenizer,
        ds_config=ds_config,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        output_dir=output_dir,
        wandb_run=wandb_run,
        callbacks=callbacks,
        run_name=run_name
    )

    # NOTE: pushing the trained model/tokenizer to the HuggingFace Hub is currently disabled.
    if wandb_run:
        wandb_run.finish()


# Run the DPO pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    run_dpo()

