# %%
import math
import os
from pathlib import Path
from types import SimpleNamespace
from typing import Literal, Union, Optional

from peft import LoraConfig
import randomname

from core import BASE_PATH
from core.llm import LLM
from core.messages import Message, Role
from training.train_student import train
from training.utils import augment_filenames, set_seed

# %%
# Dataset names accepted by ``main``. Every entry uses the same file-naming
# scheme, so the train/validation filenames are derived from the name itself
# (see ``_dataset_files``).
_KNOWN_DATASETS = (
    "nyt_cot",
    "nyt_default",
    "amazon_cot",
    "amazon_default",
    "reddit_cot",
    "reddit_default",
    "new_wiki_cot",
    "new_wiki_default",
)

# LoRA target modules. Llama and Qwen bases currently share the same list.
_LORA_TARGET_MODULES = (
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "gate_proj",
    "up_proj",
    "down_proj",
    "lm_head",
)


def _dataset_files(
    dataset: str,
    bonito_model: str,
    bonito_questions: int,
    bonito_temperature: float,
    bonito_max_items_train: int,
    bonito_max_items_test: int,
) -> tuple:
    """Return ``(train_files, val_files)`` base filenames for *dataset*.

    Each is a one-element list of ``.xml`` filenames, before
    ``augment_filenames`` decorates them.

    Raises:
        RuntimeError: if *dataset* is not one of ``_KNOWN_DATASETS``.
    """
    if dataset not in _KNOWN_DATASETS:
        raise RuntimeError(f"Dataset {dataset} undefined.")
    train_files = [
        f"{dataset}_{bonito_model}_{bonito_questions}_{bonito_temperature}_{bonito_max_items_train}_train.xml",
    ]
    val_files = [
        f"{dataset}_{bonito_max_items_test}_test.xml",
    ]
    return train_files, val_files


def main(
    devices: int = 1,
    base: str = "llama3-8b-instruct",
    project_name: str = "huggingface",
    group_name: Optional[str] = None,
    eval_interval: int = -1,
    save_interval: float = float("inf"),
    log_interval: int = 1,
    generation_interval: int = -1,
    # Hyperparameters
    learning_rate: float = 1e-5,
    batch_size: int = 4,
    micro_batch_size: int = 4,
    n_epochs: int = 10,
    train_temperature: float = 2.0,
    kl_div_loss: bool = True,
    token_loss_weight: float = 0.0,
    logit_loss_weight: float = 1.0,
    teacher: Union[Literal["student", "student_base"], str] = "student_base",
    lora_r: int = 1024,
    lora_type: str = "full",
    weight_decay: float = 0.1,
    warmup_steps: Optional[int] = None,
    warmup_ratio: float = 0.1,
    run_name: Optional[str] = None,
    dataset: str = "nyt_default",
    lesson_temp: float = 1.5,
    exam_temp: float = 0.25,
    lesson_num_choices: int = 1,
    exam_num_choices: int = 1,
    use_wandb: bool = False,
    validate: bool = True,
    dataset_llama3_8b: bool = True,
    dataset_llama3_8b_test: bool = True,
    dataset_llama3_70b: bool = False,
    dataset_qwen25_3b: bool = False,
    dataset_qwen25_14b: bool = False,
    dataset_qwen25_72b: bool = False,
    seed: int = 0,
    mixed_precision: str = "bf16",
    max_grad_norm: float = 1.0,
    decay: bool = True,
    save: bool = True,
    save_during_training: bool = False,
    checkpoint_interval: int = 0,
    checkpoint_interval_seconds: int = 0,
    datapath: Path = Path("data"),
    bonito_model: str = "llama3-8b-instruct",
    bonito_questions: int = 30,
    bonito_temperature: float = 1.5,
    bonito_max_items_train: int = 1000,
    bonito_max_items_test: int = 1000,
    dataset_vllm: bool = True,
    max_length: int = 0,
    max_total_length: int = 0,
    closed_book_token_loss: bool = True,
    distractor_dataset: str = "",
    deepspeed_path: str = "",
    deepspeed_path_teacher: str = "",
    reverse_kl: bool = False,
    partition_idx: Optional[int] = None,
    partition_type: Optional[str] = None,
    tulu: bool = False,
    tulu_batch_size: int = 2,
):
    """Build data, model, and hyperparameters for one run and call ``train``.

    Most arguments are not interpreted here. They are collected (together
    with a few derived values) into the ``hparams`` namespace passed to
    ``training.train_student.train``, which defines their meaning.

    Arguments used directly by this function:

    * ``seed``: passed to ``set_seed`` only when truthy. The default ``0``
      therefore leaves the RNGs unseeded.
    * ``dataset`` and the ``bonito_*`` values: select the train/validation
      XML base filenames. ``dataset`` must be one of ``_KNOWN_DATASETS``.
    * ``lesson_*`` / ``exam_*``, ``dataset_vllm``, ``dataset_*`` flags, and
      ``partition_*``: forwarded to ``augment_filenames`` for the train
      (lesson) and validation (exam) files respectively.
    * ``base``: name of the student LLM. If it contains ``"llama3"``, a fixed
      system opening message is used and ``teacher`` is validated. It must
      contain ``"llama"`` or ``"qwen"`` so LoRA target modules are known.
    * ``lora_r``: LoRA rank. ``lora_alpha`` is set to ``2 * lora_r`` and
      dropout to 0.05.
    * ``batch_size`` / ``devices`` / ``micro_batch_size``: used to derive the
      number of micro-batches per batch.
    * ``run_name`` / ``group_name``: default to a random name and
      ``"initial_runs"`` when empty.

    Raises:
        RuntimeError: unknown ``dataset``, or ``base`` is neither a Llama nor
            a Qwen model.
        ValueError: ``base`` is Llama-3 but ``teacher`` is not ``"student"``,
            ``"student_base"``, or a name containing ``"llama"``.
    """
    # Falsy seeds (including the default 0) deliberately skip seeding.
    if seed:
        set_seed(seed)
    project_path = BASE_PATH / "checkpoints" / project_name

    train_files, val_files = _dataset_files(
        dataset,
        bonito_model,
        bonito_questions,
        bonito_temperature,
        bonito_max_items_train,
        bonito_max_items_test,
    )
    train_files = augment_filenames(train_files, lesson_temp, lesson_num_choices, dataset_vllm,
                                    llama3_8b=dataset_llama3_8b, llama3_70b=dataset_llama3_70b,
                                    qwen25_3b=dataset_qwen25_3b, qwen25_14b=dataset_qwen25_14b, qwen25_72b=dataset_qwen25_72b,
                                    partition_idx=partition_idx, partition_type=partition_type)
    val_files = augment_filenames(
        files=val_files,
        temperature=exam_temp,
        n_choices=exam_num_choices,
        vllm=dataset_vllm,
        llama3_8b=dataset_llama3_8b_test,
    )
    data = [train_files, val_files]

    # Resolve LoRA targets before constructing the LLM so that an unsupported
    # base fails fast instead of after the model has been set up.
    if "llama" in base or "qwen" in base:
        target_modules = list(_LORA_TARGET_MODULES)
    else:
        raise RuntimeError(f"Lora target modules undefined for base {base}")

    if "llama3" in base:
        # Explicit check (not ``assert``) so it still runs under ``python -O``.
        if not (teacher in ["student", "student_base"] or "llama" in teacher):
            raise ValueError(
                f"teacher {teacher!r} is incompatible with base {base!r}: "
                "expected 'student', 'student_base', or a Llama model"
            )
        opening_message = Message(
            Role.SYSTEM,
            "You are a knowledgeable assistant trained to provide accurate and helpful information. Please respond to the user's queries promptly."
        )
    else:
        opening_message = None
    base_llm = LLM(base, opening_message=opening_message)

    if not run_name:
        run_name = randomname.get_name()
    if not group_name:
        group_name = "initial_runs"

    peft_config = LoraConfig(
        r=lora_r,
        lora_alpha=lora_r * 2,
        target_modules=target_modules,
        lora_dropout=0.05,
    )

    # Parameters important for balancing losses
    logit_loss_micro_batch_size = micro_batch_size
    token_loss_micro_batch_size = micro_batch_size
    n_logit_micro_batches_per_batch = math.ceil(batch_size / devices / micro_batch_size)
    n_token_micro_batches_per_batch = n_logit_micro_batches_per_batch

    # NOTE: every local of a scalar/str/dict/LoraConfig/Path/Message type (or
    # None) ends up in ``hparams``. Lists such as ``train_files`` and
    # ``target_modules`` are excluded. Name any new helper local with a leading
    # underscore to keep it out.
    hparams = {
        k: v for k, v in locals().items()
        if (isinstance(v, (int, float, str, bool, dict, LoraConfig, Path, Message)) or v is None)
        and not k.startswith("_")
        and not k.isupper()
    }

    hparams = SimpleNamespace(**hparams)

    train(
        project_path=project_path,
        base_llm=base_llm,
        data=data,
        hparams=hparams,
    )

# %%
if __name__ == "__main__":
    # Build the command-line interface from ``main``'s signature and run it.
    import jsonargparse

    jsonargparse.CLI(main)
