from trl import SFTTrainer, SFTConfig, GRPOConfig
from ..utils.seed import set_seed
from unsloth import FastLanguageModel
import pandas as pd
import numpy as np
from datasets import load_dataset, Dataset
import argparse
import torch
import gc
import yaml
from .get_trainer import get_trainer
from vllm import SamplingParams
from ..utils.rewards import math_verify_reward



# Tag pairs the model is asked to wrap its reasoning and its final answer in.
reasoning_start = "<start_working_out>"
reasoning_end = "<end_working_out>"
solution_start = "<SOLUTION>"
solution_end = "</SOLUTION>"

# System message used for both the SFT conversations and the RL prompts.
system_prompt = "\n".join((
    "You are given a problem.",
    "Think about the problem and provide your working out.",
    f"Place it between {reasoning_start} and {reasoning_end}.",
    f"Then, provide your solution between {solution_start}{solution_end}",
))

# Jinja chat template. The pieces are plain (non f-) strings on purpose:
# '{system_prompt}' and '{reasoning_start}' are literal placeholders that
# main() substitutes with str.replace before installing the template on the
# tokenizer.
chat_template = "".join((
    # Emit the system turn: the caller's one if present, else the default.
    "{% if messages[0]['role'] == 'system' %}",
    "{{ messages[0]['content'] + eos_token }}",
    "{% set loop_messages = messages[1:] %}",
    "{% else %}",
    "{{ '{system_prompt}' + eos_token }}",
    "{% set loop_messages = messages %}",
    "{% endif %}",
    # Emit the remaining turns; only assistant turns are followed by EOS.
    "{% for message in loop_messages %}",
    "{% if message['role'] == 'user' %}",
    "{{ message['content'] }}",
    "{% elif message['role'] == 'assistant' %}",
    "{{ message['content'] + eos_token }}",
    "{% endif %}",
    "{% endfor %}",
    # When generating, open the reasoning section for the model.
    "{% if add_generation_prompt %}",
    "{{ '{reasoning_start}' }}",
    "{% endif %}",
))

def format_dataset(x):
    """Turn one dataset row into a system/user/assistant conversation.

    ``x`` is indexed by the keys ``"expected_answer"``, ``"problem"`` and
    ``"generated_solution"``. The returned list holds three message dicts;
    the assistant message is the generated solution (with ``<think>`` tags
    and surrounding whitespace removed) wrapped in the reasoning tags,
    followed by the expected answer wrapped in the solution tags.
    """
    answer = x["expected_answer"]
    question = x["problem"]

    # Drop the source dataset's own think tags; ours are added below.
    reasoning = x["generated_solution"]
    for tag in ("<think>", "</think>"):
        reasoning = reasoning.replace(tag, "")
    reasoning = reasoning.strip()

    completion = "".join((
        reasoning_start, reasoning, reasoning_end,
        solution_start, answer, solution_end,
    ))

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": question},
        {"role": "assistant", "content": completion},
    ]
    return messages

def main():
    """Fine-tune a LoRA adapter: optional SFT warm-up, then RL training.

    Reads the YAML file given by ``--config`` and uses these keys:
    ``seed``, ``random_state``, ``model_name``, ``max_seq_length``,
    ``lora_rank``, ``instruct``, ``loss_type``, the ``use_kl`` /
    ``use_random`` / ``use_token`` / ``use_branch`` / ``use_fkl`` flags, and
    the nested ``sft``, ``rl`` and ``vllm`` sections.

    When ``cfg["instruct"]`` is falsy the model is first trained with SFT on
    ``unsloth/OpenMathReasoning-mini`` using the module-level chat template,
    and the resulting LoRA is saved to the ``baseline`` directory. RL then
    runs on ``open-r1/DAPO-Math-17k-Processed`` with the trainer returned by
    ``get_trainer``; the final LoRA is saved to the run directory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    args = parser.parse_args()
    with open(args.config, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    set_seed(cfg["seed"])
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=cfg["model_name"],
        max_seq_length=cfg["max_seq_length"],
        load_in_4bit=False,  # False for LoRA 16bit
        fast_inference=True,  # Enable vLLM fast inference
        max_lora_rank=cfg["lora_rank"],
        gpu_memory_utilization=0.9,  # Reduce if out of memory
    )

    # "org/name" -> "name". Fall back to the whole string when there is no
    # "/" instead of failing with IndexError.
    name_parts = cfg["model_name"].split("/")
    model_short_name = name_parts[1] if len(name_parts) > 1 else name_parts[0]

    baseline_save_dir = "/outputs/models/" + model_short_name + "/baseline"
    algorithm_name = (
        cfg["loss_type"]
        + ("_no_kl" if not cfg["use_kl"] else "")
        + ("_random" if cfg["use_random"] else "")
        + ("_token" if cfg["use_token"] else "")
        + ("_branch" if cfg["use_branch"] else "")
        + ("_fkl" if cfg["use_fkl"] else "")
    )
    # NOTE(review): there is no "/" between the model name and the algorithm
    # name, so runs land in e.g. ".../<model><algo>" rather than a
    # sub-directory of the model (unlike the baseline dir above). Left as is
    # because other tooling may rely on this path - confirm and fix together.
    model_save_dir = "/outputs/models/" + model_short_name + algorithm_name

    model = FastLanguageModel.get_peft_model(
        model,
        r=cfg["lora_rank"],  # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        lora_alpha=cfg["lora_rank"] * 2,  # *2 speeds up training
        use_gradient_checkpointing="unsloth",  # Reduces memory usage
        random_state=cfg["random_state"],
    )

    if not cfg["instruct"]:
        # Do SFT if model is not instruction tuned
        print("Starting SFT")
        # Fill the placeholders of the module-level template. The result is
        # bound to a *new* name: assigning to `chat_template` here would make
        # it local to main() and the read on the right-hand side would raise
        # UnboundLocalError.
        sft_chat_template = chat_template\
            .replace("'{system_prompt}'", f"'{system_prompt}'")\
            .replace("'{reasoning_start}'", f"'{reasoning_start}'")
        tokenizer.chat_template = sft_chat_template

        dataset = load_dataset("unsloth/OpenMathReasoning-mini", split="cot")
        dataset = dataset.to_pandas()[
            ["expected_answer", "problem", "generated_solution"]
        ]

        # Try converting to number - if not, replace with NaN
        is_number = pd.to_numeric(
            pd.Series(dataset["expected_answer"]), errors="coerce"
        ).notnull()
        # Select only numbers; copy so the column assignments below operate
        # on an independent frame rather than a slice of the original.
        dataset = dataset.iloc[np.where(is_number)[0]].copy()
        dataset["Messages"] = dataset.apply(format_dataset, axis=1)
        dataset["N"] = dataset["Messages"].apply(
            lambda x: len(tokenizer.apply_chat_template(x))
        )
        # Keep only conversations that use at most half of the context.
        dataset = dataset.loc[dataset["N"] <= cfg["max_seq_length"] / 2].copy()
        dataset["text"] = tokenizer.apply_chat_template(
            dataset["Messages"].values.tolist(), tokenize=False
        )
        dataset = Dataset.from_pandas(dataset)

        sft_cfg = cfg["sft"]
        trainer = SFTTrainer(
            model=model,
            tokenizer=tokenizer,
            train_dataset=dataset,
            args=SFTConfig(
                dataset_text_field="text",
                per_device_train_batch_size=sft_cfg["per_device_train_batch_size"],
                gradient_accumulation_steps=sft_cfg["gradient_accumulation_steps"],  # Use GA to mimic batch size!
                warmup_steps=sft_cfg["warmup_steps"],
                num_train_epochs=sft_cfg["num_epochs"],  # Set this for 1 full training run.
                learning_rate=sft_cfg["learning_rate"],  # Reduce to 2e-5 for long training runs
                logging_steps=sft_cfg["logging_steps"],
                optim=sft_cfg["optim"],
                weight_decay=sft_cfg["weight_decay"],
                lr_scheduler_type=sft_cfg["lr_scheduler_type"],
                seed=cfg["seed"],
                report_to="none",
            ),
        )
        trainer.train()
        # Release the SFT data before the RL stage allocates its own.
        del dataset
        torch.cuda.empty_cache()
        gc.collect()
        model.save_lora(baseline_save_dir)

        print("SFT DONE!")

    print("Starting RL")

    dataset = load_dataset("open-r1/DAPO-Math-17k-Processed", "en", split="train")
    dataset = dataset.map(lambda x: {
        "prompt": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": x["prompt"]},
        ],
        "answer": x["solution"],
    })
    # Tokenize every prompt once, only to measure its length.
    tokenized = dataset.map(
        lambda x: {
            "tokens": tokenizer.apply_chat_template(
                x["prompt"], add_generation_prompt=True, tokenize=True
            )
        },
        batched=True,
    )
    tokenized = tokenized.map(lambda x: {"L": len(x["tokens"])})

    # 90th percentile of the prompt lengths (in tokens).
    maximum_length = int(np.quantile(tokenized["L"], 0.9))

    # Keep only prompts no longer than that percentile, then the first N.
    dataset = dataset.select(np.where(np.array(tokenized["L"]) <= maximum_length)[0])
    dataset = dataset.select(range(cfg["rl"]["N"]))
    del tokenized

    max_prompt_length = maximum_length + 1  # + 1 just in case!
    # Whatever context is left after the longest prompt goes to the completion.
    max_completion_length = cfg["max_seq_length"] - max_prompt_length

    rl_cfg = cfg["rl"]
    vllm_sampling_params = SamplingParams(
        min_p=cfg["vllm"]["min_p"],
        top_p=cfg["vllm"]["top_p"],
        top_k=cfg["vllm"]["top_k"],
        seed=cfg["random_state"],
        stop=[tokenizer.eos_token],
        include_stop_str_in_output=True,
    )
    training_args = GRPOConfig(
        vllm_sampling_params=vllm_sampling_params,
        temperature=rl_cfg["temperature"],
        learning_rate=rl_cfg["learning_rate"],
        weight_decay=rl_cfg["weight_decay"],
        warmup_ratio=rl_cfg["warmup_ratio"],
        lr_scheduler_type=rl_cfg["lr_scheduler_type"],
        loss_type=cfg["loss_type"],
        beta=rl_cfg["kl_coef"] if cfg["use_kl"] else 0,  # kl penalty term
        epsilon=rl_cfg["epsilon"],
        # Only the 'dapo' loss gets its own upper clip value; every other
        # loss type reuses `epsilon`.
        epsilon_high=rl_cfg["epsilon_high"] if cfg["loss_type"] == 'dapo' else rl_cfg["epsilon"],
        optim=rl_cfg["optim"],
        logging_steps=rl_cfg["logging_steps"],
        per_device_train_batch_size=rl_cfg["per_device_train_batch_size"],
        gradient_accumulation_steps=rl_cfg["gradient_accumulation_steps"],  # Increase to 4 for smoother training
        num_generations=rl_cfg["num_generations"],  # Decrease if out of memory
        max_prompt_length=max_prompt_length,
        max_completion_length=max_completion_length,
        max_steps=rl_cfg["max_steps"],
        save_steps=rl_cfg["save_steps"],
        report_to=['wandb'],  # Can use Weights & Biases
        output_dir=model_save_dir + "/ckpts",
    )

    trainer = get_trainer(
        model=model,
        tokenizer=tokenizer,
        training_args=training_args,
        dataset=dataset,
        reward=math_verify_reward,
        cfg=cfg,
    )
    trainer.train()
    model.save_lora(model_save_dir)

# Script entry point: run the training pipeline only when executed, not on
# import. NOTE: this file uses relative imports (``..utils``, ``.get_trainer``),
# so it has to be launched as a module of its package (``python -m ...``)
# rather than by file path.
if __name__ == "__main__":
    main()