# https://huggingface.co/docs/transformers/en/training
import os
import json
import pickle
import argparse
from pathlib import Path
from functools import partial
from collections import defaultdict

import requests
import numpy as np
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.attention import SDPBackend
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    ProcessorMixin,
    LlamaConfig,
)
from accelerate import Accelerator

from models import LlamaForCausalLM, LlamaDraftForCausalLM, AutoModelForCausalLM
from utils import Timer
from criterion import DraftCausalLMDistillLoss, ForDraftCausalLMMetric, ForDraftCausalLMRLMetric, DraftCausalLMChainedLoss, DraftCausalLMChainedRLLoss, DraftCausalLMTreeLoss, DraftCausalLMGRPOLoss, DraftCausalLMQLoss
from dataloader import DataCollator, Dataset
from preprocess import get_tokenizer
from trainer import CustomGRPOTrainer, CustomSLTrainer


def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options of this training script.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) argparse falls back to ``sys.argv[1:]``, so existing
            ``parse_args()`` callers behave exactly as before.

    Returns:
        An ``argparse.Namespace`` with one attribute per option defined below.

    Raises:
        SystemExit: Raised by argparse on invalid input, including when both
            ``--bf16`` and ``--fp16`` are given.
    """
    parser = argparse.ArgumentParser()

    # Run mode, data and model locations.
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--data_dir", type=Path, default="preprocessed_data")
    parser.add_argument("--model_name", type=str, default="meta-llama/Meta-Llama-3-8B-Instruct")
    parser.add_argument("--pretrained_model", type=str, default=None)
    parser.add_argument("--output_dir", type=Path, default="output-llama-finetune-repro")
    parser.add_argument("--overwrite_output_dir", action="store_true")

    # Schedule: the *_epochs options are floats expressed in epochs.
    parser.add_argument("--num_train_epochs", type=int, default=10)
    parser.add_argument("--per_device_train_batch_size", type=int, default=4)
    parser.add_argument("--per_device_eval_batch_size", type=int, default=4)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
    parser.add_argument("--logging_steps", type=int, default=50)
    parser.add_argument("--eval_epochs", type=float, default=1)
    parser.add_argument("--save_epochs", type=float, default=1)
    parser.add_argument("--warmup_epochs", type=float, default=1)
    parser.add_argument("--learning_rate", type=float, default=3e-4)

    # The two half-precision modes cannot be combined; rejecting the pair at
    # parse time makes the run fail before any dataset or model is loaded.
    precision = parser.add_mutually_exclusive_group()
    precision.add_argument("--bf16", action="store_true")
    precision.add_argument("--fp16", action="store_true")

    parser.add_argument("--dataloader_num_workers", type=int, default=16)
    parser.add_argument("--report_to", type=str, default="tensorboard")
    parser.add_argument("--gradient_checkpointing", action="store_true")
    parser.add_argument("--use_liger_kernel", action="store_true")

    # Loss configuration.
    parser.add_argument("--top_k", type=int, default=0)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument(
        "--loss_type",
        choices=[
            "cross_entropy",
            "distill_cross_entropy",
            "chained",
            "chained-rl",
            "grpo",
            "rl-q",
            "tree",
            "double-chained",
            "double-chained-rl",
        ],
        default="distill_cross_entropy",
    )
    parser.add_argument("--auto_regressive", action="store_true", help="Whether to use auto regressive decoding")

    # Optimizer.
    parser.add_argument("--adam_beta1", type=float, default=0.9)
    parser.add_argument("--adam_beta2", type=float, default=0.95)
    parser.add_argument("--grad_clip", type=float, default=0.5)
    parser.add_argument("--p_weight", type=float, default=0.1)

    parser.add_argument("--eval_on_start", action="store_true", help="Whether to evaluate on start")
    parser.add_argument("--max_length", type=int, default=2048)
    parser.add_argument("--depth", type=int, default=3)
    parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
    return parser.parse_args(argv)


# Loss types this script knows how to build. The command-line parser accepts
# a wider set of names, so the selection is re-checked against this tuple.
_SUPPORTED_LOSS_TYPES = ("distill_cross_entropy", "chained", "chained-rl", "tree")


def _check_loss_type(loss_type):
    """Raise ``ValueError`` when ``loss_type`` has no loss builder in this script."""
    if loss_type not in _SUPPORTED_LOSS_TYPES:
        raise ValueError(f"Invalid loss type: {loss_type}")


def _build_loss_func(args, model):
    """Instantiate the loss object selected by ``args.loss_type``.

    Args:
        args: Parsed options; reads ``loss_type``, ``p_weight``, ``top_k``,
            ``gradient_accumulation_steps``, ``per_device_train_batch_size``
            and, depending on the loss, ``depth`` and ``auto_regressive``.
        model: The model handed to the loss constructor.

    Returns:
        The loss object to pass to the trainer as ``compute_loss_func``.

    Raises:
        ValueError: If ``args.loss_type`` is not in ``_SUPPORTED_LOSS_TYPES``.
    """
    _check_loss_type(args.loss_type)

    # Keyword arguments shared by every supported loss.
    shared = {
        "model": model,
        "p_weight": args.p_weight,
        "mult": 1 / args.gradient_accumulation_steps,
        "num_seqs": args.per_device_train_batch_size,
    }

    if args.loss_type == "distill_cross_entropy":
        return DraftCausalLMDistillLoss(top_k=args.top_k, **shared)
    if args.loss_type == "chained":
        return DraftCausalLMChainedLoss(
            top_k=args.top_k,
            auto_regressive=args.auto_regressive,
            depth=args.depth,
            **shared,
        )
    # NOTE: assertions forcing batch size 1 for the "chained-rl" and "tree"
    # losses used to sit here but were already disabled (commented out).
    if args.loss_type == "chained-rl":
        return DraftCausalLMChainedRLLoss(top_k=args.top_k, depth=args.depth, **shared)
    # Only "tree" is left; its constructor takes --top_k under the name top_draft.
    return DraftCausalLMTreeLoss(top_draft=args.top_k, depth=args.depth, **shared)


def main():
    """Train the draft model end to end.

    Steps, in order: parse options, load the tokenizer and the train/valid
    datasets, build the draft config, load the base model on CPU and the
    draft model on CUDA, derive the step-based schedule from the epoch-based
    options, build the loss and the trainer, train, and save the final model
    and tokenizer under ``<output_dir>/final``.
    """
    args = parse_args()

    # Fail fast: the parser accepts loss names this script cannot build, and
    # discovering that only after the models are loaded wastes the whole setup.
    _check_loss_type(args.loss_type)

    with Timer("Loading tokenizer..."):
        tokenizer = get_tokenizer(args.model_name, max_length=args.max_length)

    # Feature noise is enabled for the training split only.
    train_dataset = Dataset(
        args.data_dir,
        split="train",
        debug=args.debug,
        add_feature_noise=True,
        tokenizer=tokenizer,
        max_length=args.max_length,
    )
    valid_dataset = Dataset(
        args.data_dir,
        split="valid",
        debug=args.debug,
        add_feature_noise=False,
        tokenizer=tokenizer,
        max_length=args.max_length,
    )

    data_collator = DataCollator(tokenizer=tokenizer, pad_to_multiple_of=8)
    # The result is discarded; presumably a smoke test so that collation
    # problems surface before the (slow) model loading below.
    data_collator([train_dataset[0], train_dataset[1]])

    with Timer("Loading config..."):
        config = LlamaConfig.from_pretrained(
            args.model_name,
            trust_remote_code=False,
            device_map="cuda",
            use_cache=False,
        )
        config = config.to_draft_config()

    accelerator = Accelerator()

    with Timer("Loading base model..."):
        base_model = LlamaForCausalLM.from_pretrained(
            args.model_name,
            torch_dtype=torch.bfloat16,
            device_map="cpu",
            use_cache=False,
        )

    with Timer("Loading model..."):
        with torch.device("cuda"):
            if args.pretrained_model:
                # Resume from a saved draft model; only the output embeddings
                # are taken from the base model in this branch.
                model = LlamaDraftForCausalLM.from_pretrained(
                    args.pretrained_model,
                    torch_dtype=torch.float32,
                    device_map="cuda",
                )
                model.set_output_embeddings(base_model.get_output_embeddings())
            else:
                # Fresh draft model built from the draft config; it receives
                # the base model's input and output embeddings, and
                # set_first_layer() is called with the base model
                # (project-specific hook -- see models.py for what it does).
                model = AutoModelForCausalLM.from_config(
                    config=config,
                    torch_dtype=torch.float32,
                )
                model.set_first_layer(base_model)
                model.set_input_embeddings(base_model.get_input_embeddings())
                model.set_output_embeddings(base_model.get_output_embeddings())

        print(f"base model dtype: {base_model.model.dtype}")
        print(f"model dtype: {model.model.dtype}")

    num_processes = accelerator.num_processes

    # Estimate of the optimizer steps each process performs over the run:
    # samples per process * epochs / (per-device batch * accumulation steps).
    train_length = len(train_dataset) // num_processes
    train_steps = train_length * args.num_train_epochs // (
        args.per_device_train_batch_size * args.gradient_accumulation_steps
    )
    print(f"Train length per device: {train_length}")
    print(f"Train steps: {train_steps}")

    # Convert the epoch-based intervals into step counts.
    steps_per_epoch = train_steps // args.num_train_epochs
    eval_steps = int(steps_per_epoch * args.eval_epochs)
    save_steps = int(steps_per_epoch * args.save_epochs)
    warmup_steps = int(steps_per_epoch * args.warmup_epochs)

    print(f"Eval steps: {eval_steps}")
    print(f"Save steps: {save_steps}")
    print(f"Warmup steps: {warmup_steps}")

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        overwrite_output_dir=args.overwrite_output_dir,
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        eval_strategy="steps",
        save_strategy="steps",
        logging_steps=args.logging_steps,
        eval_steps=eval_steps,
        save_steps=save_steps,
        warmup_steps=warmup_steps,
        learning_rate=args.learning_rate,
        bf16=args.bf16,
        fp16=args.fp16,
        bf16_full_eval=args.bf16,
        fp16_full_eval=args.fp16,
        dataloader_num_workers=args.dataloader_num_workers,
        report_to=args.report_to,
        gradient_checkpointing=args.gradient_checkpointing,  # gradient checkpointing is buggy when using multi-gpu
        use_liger_kernel=args.use_liger_kernel,
        label_names=[
            "labels",
            "loss_masks",
            "hidden_states",
            "base_hidden_states",
            "output_topk_ids",
        ],
        batch_eval_metrics=True,
        remove_unused_columns=False,
        max_grad_norm=args.grad_clip,
        # Alternative scheduler: lr_scheduler_type="constant_with_warmup"
        lr_scheduler_type="cosine_with_restarts",
        torch_empty_cache_steps=1,
        dataloader_drop_last=True,
        eval_on_start=args.eval_on_start,
        seed=args.seed,
    )

    loss_func = _build_loss_func(args, model)

    trainer = CustomSLTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=valid_dataset,
        data_collator=data_collator,
        processing_class=tokenizer,
        compute_loss_func=loss_func,
        compute_metrics=ForDraftCausalLMRLMetric(
            model,
            num_seqs=args.per_device_eval_batch_size * num_processes,
        ),
        # AdamW over the parameters the model reports as trainable, with no
        # weight decay.
        optimizer_cls_and_kwargs=(
            torch.optim.AdamW,
            {
                "params": model.get_trainable_parameters(),
                "lr": args.learning_rate,
                "betas": (args.adam_beta1, args.adam_beta2),
                "weight_decay": 0.0,
            },
        ),
    )

    trainer.train()

    # Store the final weights and the tokenizer side by side.
    final_dir = Path(args.output_dir) / "final"
    trainer.save_model(final_dir)
    tokenizer.save_pretrained(final_dir)


if __name__ == "__main__":
    main()
