import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch
torch._inductor.config.triton.cudagraph_skip_dynamic_graphs = True
from peft import LoraConfig, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)
from trl import SFTTrainer, SFTConfig, DPOTrainer, DPOConfig
import wandb
import torch
from datasets import load_dataset
import datetime
import typing
import torch.nn
from torcheval.metrics import Perplexity
from transformers.tokenization_utils_fast import PaddingStrategy, TruncationStrategy


def compute_metrics(eval_pred, compute_result=False):
    """Accumulate corpus-level perplexity over evaluation batches.

    Designed for ``batch_eval_metrics=True``: the Trainer calls this once per
    eval batch and passes ``compute_result=True`` on the last one. Running
    totals live on the function object (``compute_metrics._ppl_state``) and
    are reset after each final call.

    Corpus perplexity = exp(total token NLL / number of scored tokens), where
    tokens whose label is -100 are ignored.

    Args:
        eval_pred: ``EvalPrediction`` or ``(logits, labels)`` pair. Assumed
            to be logits of shape (batch, seq, vocab) and unshifted labels of
            shape (batch, seq), as produced for causal LMs in ``transformers``.
        compute_result: If True, return the accumulated perplexity and reset.

    Returns:
        ``{"perplexity": float}`` on the final call (``nan`` if no token was
        scored), otherwise ``{"perplexity": None}``.
    """
    logits, labels = eval_pred[0], eval_pred[1]

    # Causal LM: the logit at position t scores the label at position t+1,
    # so shift before comparing (the model applies the same shift internally).
    shift_logits = logits[..., :-1, :]
    shift_labels = labels[..., 1:].to(logits.device).long()

    with torch.no_grad():
        # reduction="none" keeps per-token losses in the logits' dtype (no
        # full-size float copy); summing in float64 avoids precision loss.
        # Positions labelled -100 contribute 0.
        token_nll = torch.nn.functional.cross_entropy(
            shift_logits.transpose(1, 2),
            shift_labels,
            ignore_index=-100,
            reduction="none",
        )
        batch_nll = token_nll.double().sum()
        batch_count = (shift_labels != -100).sum()

    prev_nll, prev_count = getattr(compute_metrics, "_ppl_state", (0.0, 0))
    total_nll = prev_nll + batch_nll
    total_count = prev_count + batch_count

    if not compute_result:
        compute_metrics._ppl_state = (total_nll, total_count)
        return {"perplexity": None}

    # Reset so the next evaluation run starts from scratch.
    compute_metrics._ppl_state = (0.0, 0)
    # 0 / 0 yields nan when no token was scored.
    return {"perplexity": torch.exp(total_nll / total_count).item()}


class PerplexitySFTTrainer(SFTTrainer):
    """SFTTrainer that also logs the model's perplexity on the ground-truth text.

    On every ``compute_loss`` call (training and evaluation) it appends
    mean(exp(per-sequence mean token NLL)) for the batch to
    ``self._metrics[mode]["mean_perplexity"]``. This follows the same pattern
    ``SFTTrainer`` uses for ``mean_token_accuracy``.
    """

    # Per-token cross-entropy. Positions labelled -100 (the default
    # ignore_index) come back as 0.
    ce_loss_fct = torch.nn.CrossEntropyLoss(reduction="none")

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute SFTTrainer's loss and additionally log ground-truth perplexity.

        Per sequence, perplexity = exp(mean NLL over tokens whose shifted label
        is not -100). The batch mean is appended to
        ``self._metrics[mode]["mean_perplexity"]``. Logging is skipped when
        ``inputs`` has no ``labels``, or when ``use_liger_kernel`` is set
        (presumably because the fused loss does not return usable
        ``outputs.logits`` — verify).

        Returns:
            ``loss``, or ``(loss, outputs)`` when ``return_outputs`` is True.
        """
        mode = "train" if self.model.training else "eval"
        (loss, outputs) = super().compute_loss(
            model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch
        )

        if "labels" in inputs and not self.args.use_liger_kernel:
            with torch.no_grad():
                # Logit t predicts token t+1. The slice is a view, so no
                # explicit full-size copy of the logits is made here.
                shift_logits = outputs.logits[..., :-1, :]
                shift_labels = inputs["labels"][..., 1:].contiguous()
                mask = shift_labels != -100

                # Upcast the small (batch, seq) loss tensor before summing;
                # summing many bf16 values loses precision.
                token_nll = self.ce_loss_fct(shift_logits.transpose(1, 2), shift_labels).float()
                n_tokens = mask.sum(1)

                # A sequence with no supervised tokens (e.g. a truncated-away
                # completion) would give 0/0 = NaN and poison the batch mean.
                has_tokens = n_tokens > 0
                if has_tokens.any():
                    seq_nll = (token_nll * mask).sum(1)[has_tokens] / n_tokens[has_tokens]
                    perplexity = torch.exp(seq_nll)
                    self._metrics[mode]["mean_perplexity"].append(perplexity.mean().item())

        return (loss, outputs) if return_outputs else loss



def _format_messages_sft(x):
    """Split one chat example into a conversational prompt/completion pair.

    Messages before the first assistant turn form the prompt. That turn and
    everything after it form the completion, which is the span the loss
    covers when ``completion_only_loss`` is set.

    Message lists are returned instead of pre-rendered strings so that TRL
    applies the chat template itself. Rendering and then splitting the text
    by hand requires model-specific header strings, and it is easy to lose
    the newline separator that follows the assistant header. That separator
    is part of the generation prompt used at inference time.

    Raises:
        ValueError: if there is no assistant message, or nothing precedes it.
    """
    messages = x["messages"]
    first_assistant = next(
        (i for i, message in enumerate(messages) if message["role"] == "assistant"),
        None,
    )
    if first_assistant is None or first_assistant == 0:
        raise ValueError(f"No assistant message preceded by a prompt found in messages: {messages}")
    return {"prompt": messages[:first_assistant], "completion": messages[first_assistant:]}


def _format_messages_dpo(x):
    """Map one preference record to TRL's conversational DPO format.

    The prompt comes from ``x["input"]["messages"]``. ``chosen`` and
    ``rejected`` come from the ``preferred_output`` / ``non_preferred_output``
    message lists. DPOTrainer applies the chat template itself.

    Do not pre-render these with ``apply_chat_template``:
    - Rendering an assistant-only list separately makes each response start
      with the template's BOS token and system header (Llama 3.1 template).
    - Rendering the prompt without ``add_generation_prompt`` leaves off the
      assistant header.
    """
    return {
        "prompt": x["input"]["messages"],
        "chosen": x["preferred_output"],
        "rejected": x["non_preferred_output"],
    }


def main(loss_type: typing.Literal["SFT", "DPO"], split_type: str, valid: bool, collection: typing.Literal["breadth", "depth"], epochs: int):
    """QLoRA-fine-tune Llama-3.1-8B-Instruct with SFT or DPO on one data split.

    Loads ``<collection>_<split_type>_{train,test}.jsonl`` from the data
    directory for ``loss_type``. DPO files carry an additional model-name
    prefix. Training logs go to W&B. The final adapter and trainer state are
    saved under ``../../finetuned_models/<run name>``.

    Args:
        loss_type: "SFT" (PerplexitySFTTrainer on prompt/completion pairs) or
            "DPO" (DPOTrainer on prompt/chosen/rejected triples).
        split_type: Split name used in the data filenames (e.g. "round").
        valid: Currently unused; kept for call-site compatibility.
        collection: "breadth" or "depth" data collection.
        epochs: Number of training epochs.

    Returns early (after printing a message) if either split is empty. If
    the output directory already holds a ``checkpoint-N`` directory, training
    resumes from the newest one.
    """
    # Local import keeps the module-level import block untouched.
    from transformers.trainer_utils import get_last_checkpoint

    model_name = "meta-llama/Llama-3.1-8B-Instruct"

    # Load datasets directly from JSONL files; only DPO files are prefixed
    # with the model name.
    base_dir = f"/mnt/dv/wid/projects3/XXXX-3-XXXX-5-human-ai/mini-twitter-llm-agent-binary/data/{loss_type.lower()}_data_formatted"
    file_prefix = f"{model_name.split('/')[1]}_" if loss_type == "DPO" else ""
    stem = f"{base_dir}/{file_prefix}{collection}_{split_type}"
    ds = load_dataset("json", data_files=f"{stem}_train.jsonl", split="train")
    ds_test = load_dataset("json", data_files=f"{stem}_test.jsonl", split="train")
    print(f"Loaded {len(ds)} train examples and {len(ds_test)} test examples")
    if len(ds) == 0 or len(ds_test) == 0:
        print(f"No examples found for {model_name}_{collection}_{split_type}")
        return

    # device_map is a model-placement option; tokenizers do not take it.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.set_truncation_and_padding(
        padding_strategy=PaddingStrategy.LONGEST,
        pad_to_multiple_of=2,
        truncation_strategy=TruncationStrategy.DO_NOT_TRUNCATE,
        max_length=tokenizer.model_max_length,
        stride=0,
        padding_side=tokenizer.padding_side,
    )

    # 4-bit NF4 weights with bf16 compute (QLoRA).
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        # bnb_4bit_use_double_quant=True,
    )

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
        attn_implementation="flash_attention_3",
        dtype=torch.bfloat16,
    )

    model = prepare_model_for_kbit_training(model, gradient_checkpointing_kwargs={'use_reentrant': True})

    peft_config = LoraConfig(
        lora_alpha=16,
        lora_dropout=0.00,
        r=16,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules="all-linear",
    )

    # The date is part of the run name and the output directory, so a run is
    # only resumed from checkpoints written on the same date.
    date_str = datetime.datetime.now().strftime("%Y%m%d")
    output_model_name = f"ft:Llama-3.1-8B-Instruct-{loss_type}-{date_str}:{collection}-{split_type}-{epochs}epochs"
    wandb.init(project="llama_finetuning", name=output_model_name)

    # Convert rows to TRL's conversational formats and keep only the fields
    # each trainer expects (drops e.g. the raw 'messages' column).
    format_messages = _format_messages_sft if loss_type == "SFT" else _format_messages_dpo
    ds_columns_to_keep = ['prompt', 'completion'] if loss_type == "SFT" else ['prompt', 'chosen', 'rejected']

    def _prepare(split):
        split = split.map(format_messages)
        return split.remove_columns([c for c in split.column_names if c not in ds_columns_to_keep])

    ds = _prepare(ds)
    ds_test = _prepare(ds_test)

    config_class = SFTConfig if loss_type == "SFT" else DPOConfig
    specific_training_args = {
        "SFT": {
            "completion_only_loss": True,
            "metric_for_best_model": "eval_loss",
            "learning_rate": 5e-5,
            "per_device_train_batch_size": 8,
            "per_device_eval_batch_size": 8,
            "gradient_accumulation_steps": 32,
            # "padding_free": True,
            "group_by_length": True,
        },
        "DPO": {
            "metric_for_best_model": "eval_rewards/margins",
            "greater_is_better": True,
            "learning_rate": 5e-5,
            "per_device_train_batch_size": 4,
            "per_device_eval_batch_size": 4,
            "gradient_accumulation_steps": 4,
        },
    }

    output_dir = f"../../finetuned_models/{output_model_name}"
    # Newest checkpoint-N directory, or None when starting fresh. Passing the
    # explicit path makes the printed checkpoint the one actually resumed.
    last_checkpoint = get_last_checkpoint(output_dir) if os.path.isdir(output_dir) else None
    if last_checkpoint is not None:
        print(f"Resuming from checkpoint {os.path.basename(last_checkpoint)}")

    training_arguments = config_class(
        output_dir=output_dir,
        eval_strategy="epoch",
        eval_steps=1,  # only consulted when eval_strategy="steps"
        logging_strategy="steps",
        logging_steps=1,
        save_strategy="epoch",
        do_eval=True,
        optim="paged_adamw_8bit",
        eval_on_start=True,
        num_train_epochs=epochs,
        warmup_ratio=0.1,
        lr_scheduler_type="cosine",
        bf16=True,
        tf32=True,
        report_to="wandb",
        save_total_limit=2,
        load_best_model_at_end=True,
        max_length=None,
        # max_length = tokenizer.model_max_length,
        torch_compile=True,
        torch_compile_backend="inductor",
        torch_compile_mode="max-autotune",
        batch_eval_metrics=True,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={'use_reentrant': True},
        **specific_training_args[loss_type],
    )

    trainer_class = PerplexitySFTTrainer if loss_type == "SFT" else DPOTrainer
    specific_trainer_args = {
        "SFT": {
            "compute_metrics": compute_metrics,
        },
        "DPO": {},
    }

    trainer = trainer_class(
        model=model,
        train_dataset=ds,
        eval_dataset=ds_test,
        peft_config=peft_config,
        processing_class=tokenizer,
        args=training_arguments,
        **specific_trainer_args[loss_type],
    )

    trainer.train(resume_from_checkpoint=last_checkpoint)
    trainer.save_model()
    trainer.save_state()
    wandb.finish()


if __name__ == "__main__":
    # Sweep the full run grid. The order matches nested loops: collection
    # outermost, then loss type, then split type innermost.
    for collection, loss_type, split_type in [
        (coll, loss, split)
        for coll in ("depth", "breadth")
        for loss in ("SFT", "DPO")
        for split in ("round", "topic", "group")
    ]:
        main(
            loss_type=loss_type,
            split_type=split_type,
            valid=True,
            collection=collection,
            epochs=5,
        )
