import os
import yaml
import torch
import mlflow
import argparse
import hydra
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf
from unsloth import FastLanguageModel
from unsloth.chat_templates import standardize_sharegpt
from datasets import load_dataset, DatasetDict, Dataset, load_from_disk
import pandas as pd
from trl import SFTTrainer, SFTConfig

# Single seed reused for LoRA initialisation (random_state), dataset shuffling and the
# trainer's `seed`, so repeated runs with the same config draw the same random numbers.
RANDOM_SEED = 42

def generate_conversation(examples):
    """Turn a batch of problem/solution pairs into two-turn chat conversations.

    Args:
        examples: Batched columns, as handed over by ``Dataset.map(..., batched=True)``;
            must provide parallel ``"problem"`` and ``"generated_solution"`` sequences.

    Returns:
        ``{"conversations": [...]}`` where every entry is a user turn holding the
        problem followed by an assistant turn holding the solution. Pairing stops
        at the shorter of the two sequences.
    """
    pairs = zip(examples["problem"], examples["generated_solution"])
    return {
        "conversations": [
            [
                {"role": "user", "content": question},
                {"role": "assistant", "content": answer},
            ]
            for question, answer in pairs
        ],
    }


def rename_messages_to_conversations(example):
    """Expose a row's ``messages`` list under the ``conversations`` key.

    The row is modified in place (``messages`` is kept as well) and the same
    object is returned, as ``Dataset.map`` expects from a per-row function.
    """
    conversations = example["messages"]
    example["conversations"] = conversations
    return example


def _to_plain(value):
    """Return ``value`` as a plain Python object if it is an OmegaConf node.

    ``ListConfig``/``DictConfig`` are converted with ``OmegaConf.to_container``
    (interpolations resolved); anything else (str, int, None, ...) is returned
    unchanged. Used for arguments forwarded to third-party libraries, which may
    type-check against ``list``/``dict``.
    """
    if OmegaConf.is_config(value):
        return OmegaConf.to_container(value, resolve=True)
    return value


def _flatten_config(node, prefix=""):
    """Flatten nested dicts into ``{"group.key": leaf}`` pairs.

    Lists and scalars are kept as leaf values; empty dicts produce no entries.
    """
    flat = {}
    for key, value in node.items():
        name = f"{prefix}{key}"
        if isinstance(value, dict):
            flat.update(_flatten_config(value, prefix=f"{name}."))
        else:
            flat[name] = value
    return flat


def _texts_to_dataset(texts):
    """Build a single-column (``text``) Dataset from rendered strings, shuffled with RANDOM_SEED."""
    return Dataset.from_dict({"text": list(texts)}).shuffle(seed=RANDOM_SEED)


def _gib(num_bytes):
    """Convert a byte count to GiB, rounded to three decimals."""
    return round(num_bytes / 1024 / 1024 / 1024, 3)


@hydra.main(config_path=None, config_name=None, version_base=None)
def main(cfg: DictConfig):
    """Fine-tune a LoRA adapter with Unsloth + TRL and track the run in MLflow.

    Reads the ``model``, ``peft``, ``training``, ``trainer`` and ``mlflow``
    config groups. Optional keys, whose defaults reproduce the previous
    hard-coded behaviour:

    * ``data.train_file`` / ``data.val_file`` / ``data.test_file``: JSON(L)
      files whose rows carry a ``messages`` chat list (normalised with
      ``standardize_sharegpt``). The test split is loaded but not used here.
    * ``data.include_reasoning``: also train on
      ``unsloth/OpenMathReasoning-mini`` (default ``False``).

    Raises:
        ValueError: if there are no training examples at all.
    """
    # --- Run naming / output directory -------------------------------------
    hydra_cfg = HydraConfig.get()
    # Name the run after the config file ("exp/lora_r16.yaml" -> "lora_r16").
    # NOTE(review): job.config_name can be None when no --config-name is given;
    # fall back to Hydra's job name in that case.
    config_name = hydra_cfg.job.config_name or hydra_cfg.job.name
    run_name = config_name.split('/')[-1].split('.')[0]
    output_dir = os.path.join(cfg.trainer.output_dir, run_name)
    os.makedirs(output_dir, exist_ok=True)

    # --- MLflow tracking ---------------------------------------------------
    mlflow.set_tracking_uri(cfg.mlflow.tracking_uri)
    mlflow.set_experiment(cfg.mlflow.experiment_name)

    # --- Model, tokenizer and LoRA adapter ---------------------------------
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name = cfg.model.model_name,
        max_seq_length = cfg.model.max_seq_length,
        load_in_4bit = cfg.model.load_in_4bit,
        load_in_8bit = cfg.model.load_in_8bit,
        full_finetuning = cfg.model.full_finetuning,
        # token = "hf_...",      # use one if using gated models
    )
    model = FastLanguageModel.get_peft_model(
        model,
        r = cfg.peft.lora_rank,
        # Pass plain list/dict objects instead of ListConfig/DictConfig.
        target_modules = _to_plain(cfg.peft.target_modules),
        lora_alpha = cfg.peft.lora_alpha,
        lora_dropout = cfg.peft.lora_dropout,
        bias = cfg.peft.bias,
        use_gradient_checkpointing = cfg.peft.use_gradient_checkpointing,
        random_state = RANDOM_SEED,
        use_rslora = cfg.peft.use_rslora,
        loftq_config = _to_plain(cfg.peft.loftq_config),
    )
    print(f"Model dtype: {model.dtype}")

    # --- Chat data ---------------------------------------------------------
    data_files = {
        "train": OmegaConf.select(cfg, "data.train_file", default="/path/to/train.jsonl"),
        "val": OmegaConf.select(cfg, "data.val_file", default="/path/to/val.jsonl"),
        "test": OmegaConf.select(cfg, "data.test_file", default="/path/to/test.jsonl"),
    }
    dataset = load_dataset("json", data_files=data_files).map(rename_messages_to_conversations)
    # Apply Unsloth's standardizer to each split separately.
    dataset = DatasetDict({split: standardize_sharegpt(ds) for split, ds in dataset.items()})

    # Render conversations to plain strings with the model's chat template.
    train_texts = tokenizer.apply_chat_template(dataset["train"]["conversations"], tokenize = False)
    val_texts = tokenizer.apply_chat_template(dataset["val"]["conversations"], tokenize = False)

    # Reasoning data is opt-in, so it is only downloaded and templated when it will be used.
    # Reasoning examples (if any) come first; the combined set is shuffled below.
    all_train_texts = []
    if OmegaConf.select(cfg, "data.include_reasoning", default=False):
        reasoning_dataset = load_dataset("unsloth/OpenMathReasoning-mini", split = "cot")
        all_train_texts.extend(tokenizer.apply_chat_template(
            reasoning_dataset.map(generate_conversation, batched = True)["conversations"],
            tokenize = False,
        ))
    all_train_texts.extend(train_texts)
    if not all_train_texts:
        raise ValueError(
            "No training examples found: the train split is empty and no reasoning data is included."
        )

    combined_dataset = _texts_to_dataset(all_train_texts)
    dataset_eval = _texts_to_dataset(val_texts)

    with mlflow.start_run(run_name=run_name):
        # Log every leaf of the resolved config under a dotted name such as
        # "training.learning_rate". Logging is best effort: failures are reported, not fatal.
        resolved_cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
        for key, value in _flatten_config(resolved_cfg).items():
            try:
                mlflow.log_param(key, value)
            except Exception as err:  # a bad param must not abort training
                print(f"Skipping MLflow param {key!r}: {err}")

        # output_dir / eval_strategy / save_strategy / dataset_kwargs are
        # SFTConfig (TrainingArguments) fields, not SFTTrainer constructor arguments.
        sft_args = SFTConfig(
            output_dir = output_dir,
            eval_strategy = cfg.trainer.eval_strategy,
            save_strategy = cfg.trainer.save_strategy,
            dataset_text_field = cfg.training.dataset_text_field,
            per_device_train_batch_size = cfg.training.per_device_train_batch_size,
            gradient_accumulation_steps = cfg.training.gradient_accumulation_steps,
            warmup_ratio = cfg.training.warmup_ratio,
            max_steps = cfg.training.max_steps,
            # float() is defensive, in case the value arrives as a string.
            learning_rate = float(cfg.training.learning_rate),
            logging_steps = cfg.training.logging_steps,
            optim = cfg.training.optim,
            weight_decay = cfg.training.weight_decay,
            lr_scheduler_type = cfg.training.lr_scheduler_type,
            seed = RANDOM_SEED,
            report_to = "mlflow",
            dataset_kwargs = {"skip_prepare_dataset": False},  # True only for pre-tokenized datasets
        )

        trainer = SFTTrainer(
            model = model,
            tokenizer = tokenizer,
            train_dataset = combined_dataset,
            eval_dataset = dataset_eval,
            args = sft_args,
        )

        # Show current memory stats
        gpu_stats = torch.cuda.get_device_properties(0)
        start_gpu_memory = _gib(torch.cuda.max_memory_reserved())
        max_memory = _gib(gpu_stats.total_memory)
        print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
        print(f"{start_gpu_memory} GB of memory reserved.")

        # Launch training
        trainer_stats = trainer.train()

    # Show final memory and time stats
    used_memory = _gib(torch.cuda.max_memory_reserved())
    used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
    used_percentage = round(used_memory / max_memory * 100, 3)
    lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)
    print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.")
    print(f"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.")
    print(f"Peak reserved memory = {used_memory} GB.")
    print(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
    print(f"Peak reserved memory % of max memory = {used_percentage} %.")
    print(f"Peak reserved memory for training % of max memory = {lora_percentage} %.")

# Script entry point. main() is called without arguments because the @hydra.main
# decorator builds `cfg` from the command line (e.g. --config-path/--config-name
# plus key=value overrides).
if __name__ == '__main__':
    main()