import os
import random
from pathlib import Path

import hydra
import numpy as np
import torch
import wandb
from datasets import load_dataset, load_from_disk, DownloadMode
from dotenv import load_dotenv
from omegaconf import DictConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GRPOConfig

from rewards import get_reward_function
from trainer import get_advantage_function
from trainer import get_completion_processing_function
from trainer.bon_grpo import BonGRPOTrainer
from trainer.bon_onpolicy_grpo import BonOnPolicyGRPOTrainer
from trainer.custom_grpo import CustomGRPOTrainer


def is_bfloat16_supported():
    """Report whether the current CUDA device can be used with bfloat16.

    Returns:
        ``False`` when CUDA is unavailable; otherwise ``True`` only if the
        major compute capability of the current device is at least 8.
    """
    if not torch.cuda.is_available():
        return False
    major_capability = torch.cuda.get_device_capability()[0]
    return major_capability >= 8


def fix_seed(seed):
    """Seed the torch (CPU and CUDA), NumPy and stdlib random generators.

    Args:
        seed: Value handed unchanged to every seeding function.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)


def setup_distributed_training():
    """Read the distributed-launch environment and select the local CUDA device.

    ``RANK``, ``WORLD_SIZE`` and ``LOCAL_RANK`` are read from the environment,
    falling back to 0, 1 and 0 respectively when unset. When CUDA is
    available, the device with index ``local_rank`` is made current.

    Returns:
        Tuple ``(rank, world_size, local_rank)`` of ints.
    """
    env = os.environ
    fallbacks = (("RANK", 0), ("WORLD_SIZE", 1), ("LOCAL_RANK", 0))
    rank, world_size, local_rank = (
        int(env.get(variable, fallback)) for variable, fallback in fallbacks
    )

    # Bind this process to its own GPU when one is present.
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)

    return rank, world_size, local_rank


def load_and_prepare_dataset(cfg: DictConfig, rank=0, world_size=1):
    """Load the configured dataset and rename its columns to the trainer's names.

    Args:
        cfg: Config whose ``dataset`` node provides ``source`` ("disk" or
            "hf"), ``name``, ``prompt_column``, ``ground_truth_column``,
            ``train_split``, ``eval_split`` and optionally ``test_column``
            (defaults to "test_list").
        rank: Unused in this function.
        world_size: Unused in this function.

    Returns:
        ``(train_data, eval_data)``; ``eval_data`` is ``None`` when
        ``cfg.dataset.eval_split`` is falsy.

    Raises:
        ValueError: If ``cfg.dataset.source`` is neither "disk" nor "hf".
    """
    ds_cfg = cfg.dataset

    source = ds_cfg.source
    if source == "disk":
        raw = load_from_disk(ds_cfg.name)
    elif source == "hf":
        # Hub datasets are always fetched afresh rather than read from cache.
        raw = load_dataset(ds_cfg.name, download_mode=DownloadMode.FORCE_REDOWNLOAD)
    else:
        raise ValueError(f"Unknown dataset source: {source}")

    # Map the dataset's own column names onto the fixed names used downstream.
    renames = {}
    renames[ds_cfg.prompt_column] = "prompt"
    renames[ds_cfg.ground_truth_column] = "ground_truth"
    renames[ds_cfg.get("test_column", "test_list")] = "test_list"

    def _renamed_split(split_name):
        return raw[split_name].rename_columns(renames)

    train_data = _renamed_split(ds_cfg.train_split)
    eval_data = _renamed_split(ds_cfg.eval_split) if ds_cfg.eval_split else None

    return train_data, eval_data


def prepare_model(cfg: DictConfig, rank=0, world_size=1):
    """Load the causal LM and tokenizer, optionally wrapping the model with LoRA.

    Args:
        cfg: Config whose ``model`` node provides ``name``, ``torch_dtype``
            (the name of a ``torch`` dtype attribute, e.g. "bfloat16"),
            ``use_peft`` and, when PEFT is enabled, a ``peft`` node with
            ``method``, ``r``, ``lora_alpha``, ``target_modules``,
            ``lora_dropout``, ``bias`` and ``task_type``.
        rank: Process rank; only rank 0 prints the trainable-parameter summary.
        world_size: Unused in this function.

    Returns:
        ``(model, tokenizer)``.

    Raises:
        ValueError: If PEFT is enabled with a method other than "lora".
    """
    # NOTE(review): trust_remote_code=True executes code shipped with the
    # checkpoint; only point cfg.model.name at sources you trust.
    model = AutoModelForCausalLM.from_pretrained(
        cfg.model.name,
        torch_dtype=getattr(torch, cfg.model.torch_dtype),
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        cfg.model.name,
        trust_remote_code=True,
    )

    if not cfg.model.use_peft:
        return model, tokenizer

    if cfg.model.peft.method != "lora":
        raise ValueError(f"Unsupported PEFT method: {cfg.model.peft.method}")

    # Only run the k-bit preparation on models that are actually quantized.
    # On a regular half-precision model it would upcast the weights to fp32,
    # silently undoing cfg.model.torch_dtype and inflating memory use.
    is_quantized = (
        getattr(model, "is_loaded_in_8bit", False)
        or getattr(model, "is_loaded_in_4bit", False)
        or getattr(model, "is_quantized", False)
    )
    if is_quantized:
        model = prepare_model_for_kbit_training(model)

    # Config containers are not plain lists; convert so the LoRA config sees a
    # real list and stays JSON-serializable when the adapter is saved. A plain
    # string (e.g. a regex or keyword) is passed through untouched.
    target_modules = cfg.model.peft.target_modules
    if target_modules is not None and not isinstance(target_modules, str):
        target_modules = list(target_modules)

    lora_config = LoraConfig(
        r=cfg.model.peft.r,
        lora_alpha=cfg.model.peft.lora_alpha,
        target_modules=target_modules,
        lora_dropout=cfg.model.peft.lora_dropout,
        bias=cfg.model.peft.bias,
        task_type=cfg.model.peft.task_type,
    )

    model = get_peft_model(model, lora_config)
    if rank == 0:  # Avoid duplicate output from the other processes
        model.print_trainable_parameters()

    return model, tokenizer


def _format_visible_devices(value):
    """Return ``value`` as a string usable for ``CUDA_VISIBLE_DEVICES``.

    Environment values must be strings, but a YAML entry such as ``0`` is
    parsed as an int; any other iterable of ids is joined with commas.
    """
    if isinstance(value, (str, int)):
        return str(value)
    return ",".join(str(device) for device in value)


def _log_runtime_setup(rank, world_size, local_rank):
    """Print the distributed and GPU setup of the current process."""
    print("Training setup:")
    print(f"  Rank: {rank}, World Size: {world_size}, Local Rank: {local_rank}")
    print(
        f"  CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', 'Not set')}"
    )
    print(f"  Available GPUs: {torch.cuda.device_count()}")
    if torch.cuda.is_available():
        print(f"  Current GPU: {torch.cuda.current_device()}")
        print(f"  GPU Name: {torch.cuda.get_device_name()}")


def _init_wandb(cfg: DictConfig):
    """Log in to wandb and start a run when ``cfg.wandb.use_wandb`` is set.

    Raises:
        ValueError: If wandb is enabled but ``WANDB_KEY`` or ``WANDB_HOST`` is
            missing from the environment.
    """
    if not cfg.wandb.use_wandb:
        return

    wandb_key = os.environ.get("WANDB_KEY")
    wandb_host = os.environ.get("WANDB_HOST")
    if not wandb_key or not wandb_host:
        raise ValueError("Wandb credentials must be provided in .env file")

    wandb.login(key=wandb_key, host=wandb_host)
    if "experiment_name" in cfg.dataset:
        run_name = f"MBPP-GRPO-{cfg.dataset.experiment_name}"
    else:
        run_name = None
    wandb.init(
        project=cfg.wandb.project,
        name=run_name,
        notes=cfg.wandb.note if cfg.wandb.note else None,
    )


def _build_training_args(cfg: DictConfig, rank):
    """Translate ``cfg.training`` into a ``GRPOConfig``."""
    # Evaluation batch size falls back to the number of generations per prompt.
    if not hasattr(cfg.training, "per_device_eval_batch_size"):
        per_device_eval_batch_size = cfg.training.num_generations
    else:
        per_device_eval_batch_size = cfg.training.per_device_eval_batch_size

    # Report to wandb only from the main process, and only when wandb is
    # enabled; otherwise the trainer would try to start its own wandb run.
    report_to = ["wandb"] if rank == 0 and cfg.wandb.use_wandb else []

    # TODO: change it to be more automatic
    return GRPOConfig(
        # Generation parameters
        max_prompt_length=cfg.training.max_prompt_length,
        num_generations=cfg.training.num_generations,
        temperature=cfg.training.temperature,
        max_completion_length=cfg.training.max_completion_length,
        # vLLM parameters
        use_vllm=cfg.training.use_vllm,
        vllm_mode=cfg.training.vllm_mode,
        vllm_server_base_url=cfg.training.vllm_server_base_url,
        # Basic training parameters
        num_train_epochs=cfg.training.num_train_epochs,
        per_device_train_batch_size=cfg.training.per_device_train_batch_size,
        gradient_accumulation_steps=cfg.training.gradient_accumulation_steps,
        max_grad_norm=cfg.training.max_grad_norm,
        learning_rate=cfg.training.learning_rate,
        gradient_checkpointing=cfg.training.gradient_checkpointing,
        # The trainer re-seeds from its own ``seed`` argument when it is
        # constructed, so pass the configured seed to keep it in effect.
        seed=cfg.seed,
        # GRPO specific parameters
        beta=cfg.training.beta,
        reward_weights=cfg.training.reward_weights,
        epsilon=cfg.training.epsilon,
        epsilon_high=cfg.training.epsilon_high,
        num_iterations=cfg.training.num_iterations,
        loss_type=cfg.training.loss_type,
        # Reference model parameters
        sync_ref_model=cfg.training.sync_ref_model,
        ref_model_mixup_alpha=cfg.training.ref_model_mixup_alpha,
        ref_model_sync_steps=cfg.training.ref_model_sync_steps,
        # Optimizer parameters
        adam_beta2=cfg.training.adam_beta2,
        weight_decay=cfg.training.weight_decay,
        warmup_ratio=cfg.training.warmup_ratio,
        warmup_steps=cfg.training.warmup_steps,
        # TODO: warmup_steps will always override warmup_ratio
        lr_scheduler_type=cfg.training.lr_scheduler_type,
        # Logging and monitoring
        log_completions=cfg.training.log_completions,
        logging_steps=cfg.training.logging_steps,
        report_to=report_to,
        run_name=cfg.training.run_name,
        output_dir=cfg.training.output_dir,
        save_steps=cfg.training.save_steps,
        # Evaluation parameters
        eval_strategy=cfg.training.eval_strategy,
        eval_steps=cfg.training.eval_steps,
        per_device_eval_batch_size=per_device_eval_batch_size,
        bf16=cfg.training.bf16,
        fp16=cfg.training.fp16,
    )


def _log_training_args(training_args, world_size):
    """Print the precision and batch-size settings of ``training_args``."""
    effective_batch_size = (
        training_args.per_device_train_batch_size
        * world_size
        * training_args.gradient_accumulation_steps
    )
    print("\nTraining Arguments:")
    print(f"  bf16: {training_args.bf16}")
    print(f"  fp16: {training_args.fp16}")
    print(
        f"  mixed_precision: {getattr(training_args, 'mixed_precision', 'Not set')}"
    )
    print(
        f"  per_device_train_batch_size: {training_args.per_device_train_batch_size}"
    )
    print(
        f"  gradient_accumulation_steps: {training_args.gradient_accumulation_steps}"
    )
    print(f"  num_generations: {training_args.num_generations}")
    print(f"  Effective batch size: {effective_batch_size}")
    print(
        f"  Is divisible by num_generations: {effective_batch_size % training_args.num_generations == 0}"
    )


def _build_trainer(cfg: DictConfig, **shared_kwargs):
    """Instantiate the trainer class selected by ``cfg.training.trainer``.

    ``shared_kwargs`` are forwarded to whichever trainer is chosen; the
    trainer-specific options are read from ``cfg.training`` with defaults.

    Raises:
        ValueError: If the trainer name is not one of "base", "bon_grpo" or
            "bon_onpolicy_grpo".
    """
    trainer_name = getattr(cfg.training, "trainer", "base")

    if trainer_name == "base":
        return CustomGRPOTrainer(
            **shared_kwargs,
            entropy_coef=getattr(cfg.training, "entropy_coef", 0.0),
        )
    if trainer_name == "bon_grpo":
        return BonGRPOTrainer(
            **shared_kwargs,
            entropy_coef=getattr(cfg.training, "entropy_coef", 0.0),
            best_k=getattr(cfg.training, "best_k", 8),
            var_redaction=getattr(cfg.training, "var_redaction", None),
            clamp_delta=getattr(cfg.training, "clamp_delta", 0.2),
        )
    if trainer_name == "bon_onpolicy_grpo":
        return BonOnPolicyGRPOTrainer(
            **shared_kwargs,
            entropy_coef=getattr(cfg.training, "entropy_coef", 0.0),
            best_k=getattr(cfg.training, "best_k", 8),
            var_redaction=getattr(cfg.training, "var_redaction", None),
        )
    raise ValueError(f"Unknown trainer: {trainer_name}")


@hydra.main(version_base=None, config_path="configs", config_name="config")
def main(cfg: DictConfig):
    """Run GRPO training end to end from the Hydra config.

    Steps, in order: export ``CUDA_VISIBLE_DEVICES``, read the distributed
    environment, load ``.env``, seed the RNGs, start wandb on rank 0 (when
    enabled), load the dataset and model, build the ``GRPOConfig``, reward /
    completion-processing / advantage functions and the trainer, train, and
    save the final model under ``<output_dir>/final_step_<global_step>``.

    Raises:
        ValueError: If wandb is enabled without credentials, if
            ``cfg.reward_function.name`` is missing, or if the trainer name is
            unknown.
    """
    # Set cuda visible devices BEFORE any CUDA operations
    os.environ["CUDA_VISIBLE_DEVICES"] = _format_visible_devices(
        cfg.cuda.visible_devices
    )
    print(f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', 'Not set')}")

    # Setup distributed training AFTER setting CUDA_VISIBLE_DEVICES
    rank, world_size, local_rank = setup_distributed_training()

    # Load tokens/credentials from .env into the environment
    load_dotenv()

    # Fix seed for reproducibility
    fix_seed(cfg.seed)

    # Debug output and wandb are handled by the main process only
    if rank == 0:
        _log_runtime_setup(rank, world_size, local_rank)
        _init_wandb(cfg)

    train_data, eval_data = load_and_prepare_dataset(cfg, rank, world_size)
    model, tokenizer = prepare_model(cfg, rank, world_size)

    training_args = _build_training_args(cfg, rank)
    if rank == 0:
        _log_training_args(training_args, world_size)

    if hasattr(cfg, "reward_function") and hasattr(cfg.reward_function, "name"):
        reward_functions = [
            get_reward_function(cfg.reward_function.name, **cfg.reward_function.kwargs)
        ]
    else:
        raise ValueError(
            "Config file must define `cfg.reward_function.name` "
            "to get reward functions"
        )

    completion_processing_function = get_completion_processing_function(
        cfg.completion_processing_function
    )

    # When no advantage function is configured, pass None to the trainer.
    if (
        hasattr(cfg.training, "advantage_function")
        and cfg.training.advantage_function is not None
    ):
        advantage_function = get_advantage_function(
            cfg.training.advantage_function, cfg.training.advantage_function_kwargs
        )
    else:
        advantage_function = None

    trainer = _build_trainer(
        cfg,
        model=model,
        processing_class=tokenizer,
        reward_funcs=reward_functions,
        advantage_function=advantage_function,
        completion_processing_function=completion_processing_function,
        args=training_args,
        train_dataset=train_data,
        eval_dataset=eval_data,
    )

    trainer.train()
    # Save model at the end of training
    trainer.save_model(
        str(Path(cfg.training.output_dir) / f"final_step_{trainer.state.global_step}")
    )


# Run training only when executed as a script; main() takes no explicit
# argument here because the @hydra.main decorator supplies cfg.
if __name__ == "__main__":
    main()
