import argparse
import math
import os

import torch
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from accelerate import Accelerator
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    default_data_collator,
    get_linear_schedule_with_warmup,
    set_seed,
)
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, PeftModel, TaskType
from bitsandbytes.optim import AdamW8bit

# Process-wide runtime knobs, applied once when the module is imported.
# Let cuDNN benchmark and cache the fastest algorithms for the input shapes it sees.
torch.backends.cudnn.benchmark = True
# Allow reduced-precision internal math (e.g. TF32) for float32 matmuls.
torch.set_float32_matmul_precision('high')
# Explicitly enable the HuggingFace tokenizers thread pool.
os.environ["TOKENIZERS_PARALLELISM"] = "1"

import matplotlib.pyplot as plt

def plot_token_length_distribution(dataset, tokenizer, output_path="hist.png", num_proc=4):
    """Plot and summarize prompt+answer token lengths for a QA dataset.

    Each example is formatted with the ``"Question: ...\\nAnswer:"`` prompt
    (tokenized with special tokens), the answer is tokenized without special
    tokens, and one extra token is counted for the EOS appended to every
    target. Nothing is truncated or filtered, so the histogram shows the raw
    lengths that ``--max_length`` has to accommodate.

    Args:
        dataset: a ``datasets.Dataset`` with ``question`` and ``answer`` columns.
        tokenizer: HuggingFace-style tokenizer callable returning ``input_ids``.
        output_path: file the histogram image is written to (default ``"hist.png"``).
        num_proc: worker processes for ``dataset.map`` (default 4).

    Returns:
        The list of per-example total token lengths; ``[]`` for an empty dataset
        (in which case no plot is written).
    """
    if len(dataset) == 0:
        print("Dataset is empty; nothing to plot.")
        return []

    def compute_length(samples):
        inputs = [f"Question: {q}\nAnswer:" for q in samples['question']]
        model_inputs = tokenizer(inputs, add_special_tokens=True, truncation=False)
        labels = tokenizer(list(samples['answer']), add_special_tokens=False, truncation=False)
        # +1 accounts for the EOS token appended to every target in preprocessing.
        total_lengths = [
            len(input_ids) + len(label_ids) + 1
            for input_ids, label_ids in zip(model_inputs["input_ids"], labels["input_ids"])
        ]
        return {"total_length": total_lengths}

    tokenized = dataset.map(
        compute_length,
        batched=True,
        num_proc=num_proc,
        desc="Computing token lengths for plotting",
    )

    # Materialize as a plain list so sum/sorted and matplotlib all see the
    # same concrete sequence regardless of the datasets column type.
    all_lengths = list(tokenized["total_length"])

    fig = plt.figure(figsize=(10, 6))
    try:
        plt.hist(all_lengths, bins=50, color='blue', edgecolor='black')
        plt.title("Token Length Distribution (input + output)")
        plt.xlabel("Total tokens")
        plt.ylabel("Number of examples")
        plt.grid(True)
        plt.savefig(output_path)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)

    sorted_lengths = sorted(all_lengths)
    print(f"Mean token length: {sum(all_lengths)/len(all_lengths):.2f}")
    print(f"95th percentile token length: {sorted_lengths[int(0.95*len(sorted_lengths))]}")
    print(f"Max token length: {sorted_lengths[-1]}")
    return all_lengths

def parse_args():
    """Build the command-line interface for LoRA fine-tuning and parse ``sys.argv``.

    Returns:
        argparse.Namespace holding data/batching, model/LoRA and optimization
        settings.
    """
    # (flag, type, default) in registration order; the order also fixes the
    # key order of ``vars(args)``.
    option_table = (
        # data / batching
        ("--max_length", int, 256),
        ("--batch_size", int, 4),
        ("--num_workers", int, 8),
        # model and LoRA adapter
        ("--model_name_or_path", str, "TinyLlama/TinyLlama-1.1B-step-50K-105b"),
        ("--rank", int, 8),
        ("--lora_alpha", int, 16),
        ("--lora_dropout", float, 0.1),
        # optimization and bookkeeping
        ("--lr", float, 1e-4),
        ("--num_epochs", int, 10),
        ("--num_warmup_steps", int, 30),
        ("--seed", int, 42),
        ("--run_dir", str, None),
        ("--output_dir", str, "lora_test"),
        ("--gradient_accumulation_steps", int, 4),
    )

    cli = argparse.ArgumentParser()
    for flag, kind, fallback in option_table:
        cli.add_argument(flag, type=kind, default=fallback)
    return cli.parse_args()

def preprocess_func(samples, tokenizer, max_length):
    """Tokenize a batch of QA pairs into fixed-length, left-padded causal-LM rows.

    Each example becomes ``prompt_ids + answer_ids + [eos]``. The prompt is
    ``"Question: {question}\\nAnswer:"`` tokenized with special tokens; the
    answer is tokenized without them. Only answer and EOS positions carry
    labels: prompt and padding positions get ``-100``. Rows are left-padded
    with ``tokenizer.pad_token_id`` to exactly ``max_length``. Examples longer
    than ``max_length`` are dropped (not truncated), so the output may have
    fewer rows than the input batch.

    Args:
        samples: batch dict with ``"question"`` and ``"answer"`` lists, as
            passed by ``datasets.Dataset.map(batched=True)``.
        tokenizer: HF-style tokenizer callable returning ``{"input_ids": ...}``;
            its ``pad_token_id`` and ``eos_token_id`` are used.
        max_length: length of every output row.

    Returns:
        dict of ``torch.long`` tensors ``input_ids``, ``attention_mask`` and
        ``labels``, each shaped ``(kept_examples, max_length)``; this is
        ``(0, max_length)`` when every example in the batch was dropped.
    """
    inputs = [f"Question: {q}\nAnswer:" for q in samples['question']]
    targets = list(samples['answer'])

    model_inputs = tokenizer(inputs, add_special_tokens=True, truncation=False)
    labels = tokenizer(targets, add_special_tokens=False, truncation=False)

    batch_input_ids, batch_labels, batch_attention_mask = [], [], []

    for input_ids, label_ids in zip(model_inputs["input_ids"], labels["input_ids"]):
        # Build a new list rather than appending in place, so the tokenizer
        # output is left untouched.
        label_ids = list(label_ids) + [tokenizer.eos_token_id]
        input_with_label = list(input_ids) + label_ids

        # Drop (rather than truncate) examples that do not fit: truncation
        # would cut off the answer/EOS the model is meant to learn.
        padding_length = max_length - len(input_with_label)
        if padding_length < 0:
            continue

        # Left padding: pad tokens go in front, real tokens end the row.
        batch_input_ids.append([tokenizer.pad_token_id] * padding_length + input_with_label)
        batch_attention_mask.append([0] * padding_length + [1] * len(input_with_label))
        batch_labels.append([-100] * (padding_length + len(input_ids)) + label_ids)

    def _as_long_tensor(rows):
        # torch.tensor([]) would be a 1-D float tensor; keep dtype and rank
        # stable for batches in which every example was dropped.
        if not rows:
            return torch.empty((0, max(max_length, 0)), dtype=torch.long)
        return torch.tensor(rows, dtype=torch.long)

    return {
        "input_ids": _as_long_tensor(batch_input_ids),
        "attention_mask": _as_long_tensor(batch_attention_mask),
        "labels": _as_long_tensor(batch_labels),
    }




def _load_model(args):
    """Return a trainable LoRA-wrapped causal LM built from ``args``.

    With ``args.run_dir`` unset, a fresh LoRA adapter is attached to the
    q/k/v/o and gate/up/down projection modules. Otherwise the adapter saved
    under ``<output_dir>/<run_dir>/checkpoints`` is loaded in trainable mode.
    Both paths load the base model identically (bfloat16 weights,
    FlashAttention-2), so a resumed run uses the same setup as a fresh one.
    """
    base_model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
    )
    if args.run_dir is None:
        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            target_modules=["q_proj", "v_proj", "up_proj", "down_proj", "gate_proj", "k_proj", "o_proj"],
            inference_mode=False,
            r=args.rank,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
        )
        return get_peft_model(base_model, lora_config)
    checkpoint_dir = os.path.join(args.output_dir, args.run_dir, "checkpoints")
    return PeftModel.from_pretrained(base_model, checkpoint_dir, is_trainable=True)


def main(args):
    """Fine-tune a causal LM with LoRA on the math_genie parquet training split.

    Steps:
    1. Seed the RNGs.
    2. Tokenize the training split with ``preprocess_func``.
    3. Build or resume the LoRA model.
    4. Train with 8-bit AdamW, a linear warmup/decay schedule and gradient
       accumulation under Accelerate (bf16 mixed precision).

    The adapter is saved to ``args.output_dir`` after every epoch.
    """
    set_seed(args.seed)

    accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision="bf16")

    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True, trust_remote_code=True)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = "left"

    # NOTE(review): Accelerator is created without ``log_with``, so no tracker
    # is registered and init_trackers()/log() have nothing to write to. The
    # project name also hard-codes "mistral7b" regardless of the model.
    # Confirm the intended tracker.
    accelerator.init_trackers(
        project_name=f"mistral7b-lora-r{args.rank}-math_genie-epochs{args.num_epochs}-adapter",
        config=vars(args)
    )

    # NOTE(review): the validation split is loaded but not used below.
    dataset = load_dataset("parquet", data_files={
        'train': 'math_genie/patch/train/math_genie_patch.parquet',
        'validation': 'math_genie/patch/val/math_genie_patch_valid.parquet'
    })

    with accelerator.main_process_first():
        processed_train_ds = dataset["train"].map(
            lambda samples: preprocess_func(samples, tokenizer, args.max_length),
            batched=True,
            num_proc=8,
            remove_columns=dataset["train"].column_names,
            load_from_cache_file=False,
            desc="Tokenizing dataset",
        )

    train_dataloader = DataLoader(
        processed_train_ds,
        shuffle=True,
        collate_fn=default_data_collator,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=True,
    )

    model = _load_model(args)
    # accelerator.device is this process's device (not always cuda:0).
    model = model.to(dtype=torch.bfloat16, device=accelerator.device)
    model.enable_input_require_grads()

    model.print_trainable_parameters()

    # Only LoRA parameters are trainable; keep frozen base weights out of the optimizer.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = AdamW8bit(trainable_params, lr=args.lr, weight_decay=0.01)

    # Accelerate also syncs gradients at the end of each epoch's dataloader,
    # so every epoch performs ceil(batches / accumulation) optimizer updates.
    updates_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.num_warmup_steps,
        num_training_steps=updates_per_epoch * args.num_epochs,
    )

    model, train_dataloader, optimizer, lr_scheduler = accelerator.prepare(model, train_dataloader, optimizer, lr_scheduler)

    torch.cuda.empty_cache()

    if accelerator.is_main_process:
        os.makedirs(args.output_dir, exist_ok=True)

    for epoch in range(args.num_epochs):
        model.train()
        total_loss = 0.0

        progress = tqdm(train_dataloader, desc=f"Epoch {epoch}", disable=not accelerator.is_local_main_process)
        for batch in progress:
            with accelerator.accumulate(model):
                outputs = model(**batch)
                loss = outputs.loss
                accelerator.backward(loss)

                if accelerator.sync_gradients:
                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad()

                # A single device->host copy per step, reused for logging and the epoch sum.
                step_loss = loss.detach().float().item()
                accelerator.log({"iter_loss": step_loss})
                total_loss += step_loss

        # Mean per-batch loss on this process; guard against an empty dataloader.
        average_loss = total_loss / max(len(train_dataloader), 1)
        if accelerator.is_main_process:
            print(f"Epoch {epoch}: Average Loss = {average_loss}")

        accelerator.log({"epoch_loss": average_loss})

        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            # NOTE(review): this saves directly into output_dir, while resuming
            # (_load_model) reads output_dir/run_dir/checkpoints. Confirm the
            # intended checkpoint layout.
            unwrapped_model = accelerator.unwrap_model(model)
            unwrapped_model.save_pretrained(args.output_dir)

    accelerator.end_training()


if __name__ == "__main__":
    # Script entry point: parse the command-line flags, then run training.
    args = parse_args()
    main(args)

    

