# -*- coding: utf-8 -*-
# distill.py
import logging
import os
from io import StringIO
from types import SimpleNamespace

import datasets
import hydra
import torch
import yaml
from accelerate import Accelerator
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf
from peft import LoraConfig, get_peft_model
from rich.console import Console
from rich.panel import Panel
from rich.syntax import Syntax
from rich.text import Text
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import logging as hf_logging
from trl import DataCollatorForCompletionOnlyLM, SFTConfig, SFTTrainer

import wandb
from utils import SYSTEM_PROMPT, init


accelerator = Accelerator()
log = logging.getLogger(__name__)

if not accelerator.is_main_process:
    # Only the main rank should be chatty: on every other rank, hide the
    # datasets / transformers progress bars and drop transformers logging to
    # errors only.
    datasets.disable_progress_bar()
    hf_logging.disable_progress_bar()
    hf_logging.set_verbosity_error()

    def tqdm(x, *args, **kwargs):
        """Pass-through stand-in for ``tqdm`` on non-main ranks: return ``x`` unwrapped.

        NOTE(review): nothing in this file as shown references ``tqdm``.
        """
        return x

def log_color(content, title=""):
    """Show ``content`` in a titled panel on the terminal and mirror it to the log.

    The panel is rendered twice: once in colour (cyan border) to the terminal
    console, and once with highlighting disabled into an in-memory buffer whose
    text is emitted through ``log.info`` so log files get a plain-text copy.

    Args:
        content: A rich renderable (e.g. a ``Syntax`` object) or a plain string.
            Plain strings are wrapped in ``Text`` so they are displayed
            verbatim. Otherwise rich would parse ``[...]`` spans as console
            markup: bracketed text such as ``[a, b]`` would be silently
            dropped, and an unmatched closing tag such as ``[/x]`` would raise
            ``MarkupError``.
        title: Title shown in the panel's top-left corner.
    """
    if isinstance(content, str):
        # Decoded model inputs / traces are arbitrary text, not rich markup.
        content = Text(content)

    console = Console()
    console.print(Panel(content, title=title, border_style="cyan", title_align="left"))

    # Render a second, uncoloured copy into a buffer so the log file gets plain text.
    string_io = StringIO()
    plain_console = Console(file=string_io, highlight=False)
    plain_console.print(Panel(content, title=title, border_style="none", title_align="left"))
    log.info("\n%s", string_io.getvalue())

@hydra.main(config_path=".", config_name="train_config", version_base="1.3")
def main(cfg: DictConfig):
    """Fine-tune a student model on teacher reasoning traces (SFT distillation).

    Steps:
      1. Read the trace-generation config stored at ``<train_traces>.yaml``.
      2. Build the tokenizer (only names containing "llama" are accepted) and
         the student model, optionally wrapped in a LoRA adapter.
      3. On the main process, turn the raw traces into chat-formatted token ids,
         log length statistics, and save them under ``/tmp``; after a barrier
         every process loads the saved copy. Holdout traces are only processed
         when ``cfg.do_eval`` is set.
      4. Train with TRL's ``SFTTrainer`` and a ``DataCollatorForCompletionOnlyLM``
         keyed on the assistant header of the tokenizer's chat template.
      5. Optionally evaluate, then (main process only) merge LoRA weights if
         used and save model + tokenizer to ``<model_path>/final``.

    Args:
        cfg: Hydra config. Keys read here include ``train_traces``,
            ``holdout_traces``, ``do_eval``, ``tokenizer``, ``student``, ``seed``,
            ``wandb``, ``lora`` / ``lora_r`` / ``lora_alpha`` / ``lora_dropout``,
            ``max_length``, ``batch_size``, ``per_device_batch_size``,
            ``eval_epochs``, optimiser/scheduler settings, ``num_epochs``,
            ``model_path``, ``checkpoint``, ``exp_dir`` and ``model_name``.

    Raises:
        AssertionError: if ``train_traces`` (or ``holdout_traces`` when
            ``do_eval``) is not provided.
        ValueError: if the tokenizer family is not supported.
    """
    assert cfg.train_traces, "Please provide the training traces"
    if cfg.do_eval:
        assert cfg.holdout_traces, "Please provide the holdout traces for evaluation"
    cfg.tokenizer = cfg.tokenizer or cfg.student
    with open(cfg.train_traces + ".yaml", "r", encoding="utf-8") as f:
        trace_config = yaml.safe_load(f)
    trace_cfg = SimpleNamespace(**trace_config)

    init(os.getenv("USER"), cfg.seed)

    if accelerator.is_main_process:
        content = Syntax(OmegaConf.to_yaml(cfg, resolve=True), 'yaml', theme="monokai")
        log_color(content, title="Model Config")
        content = Syntax(yaml.dump(trace_config), 'yaml', theme="monokai")
        log_color(content, title="Trace Config")

    if not cfg.wandb:
        os.environ["WANDB_DISABLED"] = "true"

    tokenizer = AutoTokenizer.from_pretrained(
        cfg.tokenizer,
        use_fast=True,
        fast_tokenizer=True,
        trust_remote_code=True,
        padding_side="left"
    )
    if "llama" in cfg.tokenizer.lower():
        # NOTE(review): hard-coded ids assume a Llama-3-family vocabulary
        # (presumably 128001 = <|end_of_text|>, 128004 = a reserved pad token;
        # the trailing <|eot_id|> is 128009) -- verify for other Llama versions.
        tokenizer.pad_token_id = 128004
        tokenizer.eos_token_id = 128001
        tokenizer.add_eos_token = False
    else:
        raise ValueError(f"Unknown tokenizer {cfg.tokenizer}")

    student = AutoModelForCausalLM.from_pretrained(
        cfg.student,
        trust_remote_code=True,
        attn_implementation="flash_attention_2",
        torch_dtype=torch.bfloat16,
        use_cache=True,
    )
    student.generation_config.pad_token_id = tokenizer.pad_token_id
    student.generation_config.add_eos_token = False

    if cfg.lora:
        peft_config = LoraConfig(
            r=cfg.lora_r,
            lora_alpha=cfg.lora_alpha,
            target_modules=['q_proj','k_proj','v_proj','o_proj','gate_proj','up_proj','down_proj'],
            lora_dropout=cfg.lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
        )
        student = get_peft_model(student, peft_config)

    def preprocess_function(examples):
        """Batched ``datasets.map`` fn: raw teacher traces -> chat-formatted token ids.

        For each row, keeps the trace text after the first ``<｜Assistant｜>``
        marker (up to any second marker; raises ``IndexError`` if the marker is
        missing), swaps the ``<｜end▁of▁sentence｜>`` marker for this
        tokenizer's EOS token, and wraps system prompt / problem / response
        into a chat. Conversations longer than ``cfg.max_length`` tokens are
        dropped, so the returned batch may have fewer rows than the input.

        Returns:
            dict with ``input_ids`` (list of token-id lists) and
            ``token_lengths`` (their lengths).
        """
        trace_colname = trace_cfg.trace_colname
        responses = []
        for response in examples[trace_colname]:
            fixed_response = response.split("<｜Assistant｜>")[1]
            fixed_response = fixed_response.replace("<｜end▁of▁sentence｜>", tokenizer.eos_token)
            responses.append(fixed_response)
        messages = [[
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": problem.strip()},
            {"role": "assistant", "content": response.strip()}]
            for problem, response in zip(examples["problem"], responses)]
        tokens = [toks for toks in tokenizer.apply_chat_template(messages, add_generation_prompt=False) if len(toks) <= cfg.max_length]
        tokens = [toks[:-1] for toks in tokens]  # For llama which always emits eot_token
        tok_lengths = [len(toks) for toks in tokens]
        return {"input_ids": tokens, "token_lengths": tok_lengths}

    def preprocess_and_cache(source_path, desc, cache_path):
        """Tokenize the traces at ``source_path``, save them to ``cache_path``.

        Returns the tokenized dataset (without ``token_lengths``) and the
        pandas ``describe()`` summary of its token lengths.
        """
        traces = datasets.load_from_disk(source_path)
        traces = traces.map(
            preprocess_function,
            batched=True,
            batch_size=16384,
            num_proc=96,
            remove_columns=list(traces.column_names),
            desc=desc,
            load_from_cache_file=True,
        )
        length_stats = traces.to_pandas()["token_lengths"].describe()
        traces = traces.remove_columns("token_lengths")
        traces.save_to_disk(cache_path)
        return traces, length_stats

    # NOTE(review): these fixed /tmp paths are shared by every run on the host
    # and are written only by the global main process; concurrent runs will
    # clobber each other, and multi-node jobs need this path on a shared
    # filesystem (/tmp is typically node-local).
    train_cache_dir = "/tmp/cached_train_traces"
    holdout_cache_dir = "/tmp/cached_holdout_traces"
    train_token_length_stats = None
    if accelerator.is_main_process:
        train_traces, train_token_length_stats = preprocess_and_cache(
            cfg.train_traces, "Preprocessing train dataset", train_cache_dir
        )
        log_color(tokenizer.decode(train_traces[0]['input_ids']), title="Example Input")
        log_color(str(train_token_length_stats.round(2)), title="Train Trace Token Lengths")

        # Holdout traces are only required (and asserted above) when evaluating.
        if cfg.do_eval:
            _, holdout_token_length_stats = preprocess_and_cache(
                cfg.holdout_traces, "Preprocessing holdout dataset", holdout_cache_dir
            )
            log_color(str(holdout_token_length_stats.round(2)), title="Holdout Trace Token Lengths")
    accelerator.wait_for_everyone()
    train_traces = datasets.load_from_disk(train_cache_dir)
    holdout_traces = datasets.load_from_disk(holdout_cache_dir) if cfg.do_eval else None

    # The response template must match the chat template that formatted the
    # inputs, and that template comes from the tokenizer, not the student.
    tokenizer_name = cfg.tokenizer.lower()
    if 'llama' in tokenizer_name:
        response_string = "<|start_header_id|>assistant<|end_header_id|>\n\n"
    elif "qwen" in tokenizer_name:
        response_string = "<|im_start|>assistant\n"
    else:
        raise ValueError(f"Unknown tokenizer {cfg.tokenizer}")
    collator = DataCollatorForCompletionOnlyLM(
        response_template=tokenizer.encode(response_string, add_special_tokens=False),
        tokenizer=tokenizer,
        mlm=False
    )
    steps_per_epoch = len(train_traces) // cfg.batch_size
    eval_steps = int(steps_per_epoch * cfg.eval_epochs) if cfg.do_eval else 0
    trainer = SFTTrainer(
        model=student,
        train_dataset=train_traces,
        eval_dataset=holdout_traces,
        processing_class=tokenizer,
        data_collator=collator,
        args=SFTConfig(
            # NOTE(review): ``use_bfloat16`` is a legacy config attribute that is
            # usually absent (-> False) for models loaded via ``torch_dtype``; if
            # so, AMP is off and the bf16 weights are trained directly. Confirm.
            bf16=student.config.use_bfloat16,
            do_eval=cfg.do_eval,
            max_length=cfg.max_length,
            eval_strategy="steps" if cfg.do_eval else "no",
            eval_steps=eval_steps,
            eval_on_start=bool(cfg.do_eval),
            # Integer division: batch_size should be divisible by
            # per_device_batch_size * num_processes, otherwise the effective
            # batch is smaller than requested.
            gradient_accumulation_steps=cfg.batch_size // cfg.per_device_batch_size // accelerator.num_processes,
            max_grad_norm=cfg.max_grad_norm,
            gradient_checkpointing=False,
            gradient_checkpointing_kwargs={"use_reentrant": False},
            learning_rate=cfg.lr,
            weight_decay=cfg.weight_decay,
            log_level="info",
            logging_steps=10,
            logging_strategy="steps",
            lr_scheduler_type=cfg.lr_scheduler_type,
            optim='adamw_torch_fused',
            num_train_epochs=cfg.num_epochs,
            output_dir=cfg.model_path,
            overwrite_output_dir=True,
            per_device_train_batch_size=cfg.per_device_batch_size,
            per_device_eval_batch_size=cfg.per_device_batch_size*2,
            # ``None`` would mean "all installed integrations"; "none" disables reporting.
            report_to="wandb" if cfg.wandb else "none",
            save_strategy="steps",
            save_steps=500,
            save_total_limit=3,
            seed=cfg.seed,
            warmup_ratio=cfg.warmup,
            remove_unused_columns=False,
            label_names=["labels"],
            ddp_find_unused_parameters=False,
            save_safetensors=False,
        ),
    )

    if accelerator.is_main_process and cfg.wandb:
        # Plain, fully-resolved containers: ``{**cfg}`` would hand wandb nested
        # DictConfig objects.
        resolved_cfg = OmegaConf.to_container(cfg, resolve=True)
        wandb_run = wandb.init(
            project="antidistillation",
            name=f"{cfg.exp_dir}/{cfg.model_name}",
            config={**resolved_cfg, "trace_config": trace_config},
        )
        wandb.log({
            "train_trace_raw_accuracy": trace_cfg.stats["raw_accuracy"],
            "train_trace_af_accuracy": trace_cfg.stats["af_accuracy"],
            "trace_token_length/stats": {k: float(v) for k, v in train_token_length_stats.items()},
        })

        hydra_cfg = HydraConfig.get()
        full_cfg = {
            **resolved_cfg,
            "hydra": {
                "run_dir": hydra_cfg.run.dir,
                "job_name": hydra_cfg.job.name,
                "cwd": hydra_cfg.runtime.cwd,
            },
            "wandb_run_id": wandb_run.id,
        }
        yaml_path = cfg.model_path + ".yaml"
        with open(yaml_path, "w", encoding="utf-8") as f:
            OmegaConf.save(full_cfg, f)
    trainer.train(resume_from_checkpoint=cfg.checkpoint)

    if cfg.do_eval:
        metrics = trainer.evaluate()
        metrics["eval_samples"] = len(trainer.eval_dataset)
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    if accelerator.is_main_process:
        trainer.model.config.use_cache = True
        if cfg.lora:
            trainer.model = trainer.model.merge_and_unload()
        torch.cuda.empty_cache()
        final_output_dir = os.path.join(cfg.model_path, "final")
        trainer.save_model(final_output_dir)
        # ``Trainer.tokenizer`` is deprecated in favour of ``processing_class``;
        # save the tokenizer object directly.
        tokenizer.save_pretrained(final_output_dir)

        if cfg.wandb:
            wandb.finish()

    accelerator.end_training()

if __name__ == "__main__":
    main()