import os
os.environ["TOKENIZERS_PARALLELISM"] = "1"

import argparse
from tqdm.auto import tqdm

import torch
from torch.utils.data import DataLoader
from accelerate import Accelerator
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    default_data_collator,
    get_linear_schedule_with_warmup,
    set_seed,
)
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, PeftModel, TaskType
from bitsandbytes.optim import AdamW8bit
from torch.optim.lr_scheduler import LambdaLR

# Verbalizer: BoolQ integer label -> answer string the model is trained to emit.
TARGET_MAPPING = {0: "no", 1: "yes"}

# Global backend knobs: allow reduced-precision (TF32-class) fp32 matmuls and
# let cuDNN autotune kernels for the fixed input shapes used in training.
torch.set_float32_matmul_precision('high')
torch.backends.cudnn.benchmark = True

def plot_token_length_distribution(dataset, tokenizer, output_path="hist.png", num_proc=4):
    """Plot and summarize untruncated BoolQ sequence lengths (prompt + answer + EOS).

    Builds the same prompt format as ``preprocess_boolq`` but without truncation,
    so the histogram shows how many examples a given ``--max_length`` would cut.

    Args:
        dataset: A ``datasets.Dataset`` with ``passage``, ``question`` and ``label``
            columns.
        tokenizer: Tokenizer used for training (must define ``eos_token_id``).
        output_path: Where the histogram image is written (default ``"hist.png"``).
        num_proc: Worker processes for ``dataset.map``.

    Returns:
        The list of per-example total token lengths (empty if ``dataset`` is empty).
    """
    # `plt` was referenced but never imported, so this function always raised
    # NameError. Import lazily: only this diagnostic helper needs matplotlib.
    import matplotlib.pyplot as plt

    def compute_length(samples):
        inputs = [f"{p}\nQuestion: {q}?\nAnswer:" for p, q in zip(samples['passage'], samples['question'])]
        targets = [TARGET_MAPPING[label] for label in samples['label']]

        prompt_ids = tokenizer(inputs, add_special_tokens=True, truncation=False)["input_ids"]
        target_ids = tokenizer(targets, add_special_tokens=False, truncation=False)["input_ids"]

        # +1 accounts for the EOS token appended to every target during training.
        return {"total_length": [len(p) + len(t) + 1 for p, t in zip(prompt_ids, target_ids)]}

    tokenized = dataset.map(
        compute_length,
        batched=True,
        num_proc=num_proc,
        # Only the length column is needed; drop the raw text to save memory.
        remove_columns=dataset.column_names,
        desc="Computing token lengths for plotting",
    )

    all_lengths = tokenized["total_length"]
    if not all_lengths:
        # Avoid ZeroDivisionError / IndexError in the statistics below.
        print("No examples to plot.")
        return []

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.hist(all_lengths, bins=50, color='blue', edgecolor='black')
        ax.set_title("Token Length Distribution (input + output)")
        ax.set_xlabel("Total tokens")
        ax.set_ylabel("Number of examples")
        ax.grid(True)
        fig.savefig(output_path)
    finally:
        # Release the figure; pyplot otherwise keeps it alive for the process lifetime.
        plt.close(fig)

    sorted_lengths = sorted(all_lengths)
    print(f"Mean token length: {sum(all_lengths)/len(all_lengths):.2f}")
    print(f"95th percentile token length: {sorted_lengths[int(0.95*len(sorted_lengths))]}")
    print(f"Max token length: {sorted_lengths[-1]}")
    return all_lengths


def parse_args(argv=None):
    """Parse command-line options for LoRA fine-tuning on BoolQ.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]`` exactly as before; passing a list lets tests and
            notebooks build a configuration without touching ``sys.argv``.

    Returns:
        ``argparse.Namespace`` holding data, LoRA, optimization and I/O settings.
    """
    parser = argparse.ArgumentParser(description="LoRA fine-tuning of a causal LM on SuperGLUE BoolQ.")

    parser.add_argument("--max_length", type=int, default=256, help="Maximum token length")
    parser.add_argument("--batch_size", type=int, default=4, help="Batch size")
    parser.add_argument("--num_workers", type=int, default=32, help="Number of dataloader workers")

    parser.add_argument("--model_name_or_path", type=str, default="mistralai/Mistral-7B-v0.1", help="Model to finetune")
    parser.add_argument("--rank", type=int, default=8, help="LoRA rank r")
    parser.add_argument("--lora_alpha", type=int, default=16, help="LoRA scaling alpha")
    parser.add_argument("--lora_dropout", type=float, default=0.1, help="Dropout applied inside LoRA layers")

    parser.add_argument("--lr", type=float, default=1e-4, help="Peak learning rate")
    parser.add_argument("--num_epochs", type=int, default=10, help="Maximum number of training epochs")
    parser.add_argument("--num_warmup_steps", type=int, default=30, help="Linear warmup steps for the LR schedule")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--run_dir", type=str, default=None, help="LoRA weights path if resuming")
    parser.add_argument("--output_dir", type=str, default="lora_boolq", help="Directory for the best LoRA adapter")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4, help="Micro-batches per optimizer step")

    return parser.parse_args(argv)


def preprocess_boolq(samples, tokenizer, max_length):
    """Turn a batch of BoolQ examples into fixed-length causal-LM training rows.

    Each row is ``<prompt tokens><answer tokens><eos>``, left-padded to exactly
    ``max_length`` tokens. Prompt and padding positions get label -100, so the
    loss covers only the answer and the EOS token.

    Args:
        samples: Batched columns (as passed by ``datasets.map(batched=True)``)
            with ``passage``, ``question`` and ``label`` lists.
        tokenizer: Tokenizer with ``pad_token_id`` and ``eos_token_id`` set.
        max_length: Total length of every returned row, including padding.

    Returns:
        Dict with ``input_ids``, ``attention_mask`` and ``labels``, each a
        ``torch.long`` tensor with one ``max_length``-long row per example.

    Raises:
        ValueError: If ``max_length`` leaves no room for the prompt once the
            answer tokens and EOS are reserved.
    """
    inputs = [f"{p}\nQuestion: {q}?\nAnswer:" for p, q in zip(samples['passage'], samples['question'])]
    targets = [TARGET_MAPPING[label] for label in samples['label']]

    labels = tokenizer(targets, add_special_tokens=False)
    answer_ids_batch = [label_ids + [tokenizer.eos_token_id] for label_ids in labels["input_ids"]]

    # Reserve space for the answer (+ EOS) BEFORE truncating the prompt.
    # Previously the prompt was truncated to max_length and the answer appended
    # afterwards. The final [:max_length] slice then dropped exactly the answer
    # tokens, leaving long examples with no supervised token at all.
    answer_budget = max((len(ids) for ids in answer_ids_batch), default=0)
    prompt_budget = max_length - answer_budget
    if prompt_budget <= 0:
        raise ValueError(
            f"max_length={max_length} cannot fit the answer tokens plus EOS ({answer_budget} tokens)"
        )

    # NOTE(review): which prompt tokens are dropped follows
    # tokenizer.truncation_side. With right-side truncation the question and
    # "Answer:" are cut from long passages; consider "left" -- confirm intent.
    model_inputs = tokenizer(inputs, add_special_tokens=True, truncation=True, max_length=prompt_budget)

    batch_input_ids = []
    batch_attention_mask = []
    batch_labels = []

    for prompt_ids, answer_ids in zip(model_inputs["input_ids"], answer_ids_batch):
        sequence = prompt_ids + answer_ids
        # Non-negative: len(prompt_ids) <= prompt_budget and len(answer_ids) <= answer_budget.
        padding_length = max_length - len(sequence)

        batch_input_ids.append([tokenizer.pad_token_id] * padding_length + sequence)
        batch_attention_mask.append([0] * padding_length + [1] * len(sequence))
        batch_labels.append([-100] * (padding_length + len(prompt_ids)) + answer_ids)

    return {
        "input_ids": torch.tensor(batch_input_ids, dtype=torch.long),
        "attention_mask": torch.tensor(batch_attention_mask, dtype=torch.long),
        "labels": torch.tensor(batch_labels, dtype=torch.long),
    }


def _make_dataloader(dataset, args, *, shuffle):
    """Wrap a pre-tokenized, pre-padded dataset in a DataLoader configured from CLI args."""
    return DataLoader(
        dataset,
        shuffle=shuffle,
        collate_fn=default_data_collator,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=True,
        # persistent_workers=True raises ValueError when num_workers == 0.
        persistent_workers=args.num_workers > 0,
    )


def _load_model(args, device):
    """Load the base model with a fresh LoRA adapter, or with the adapter saved in ``args.run_dir``."""
    # Load directly in bf16. FlashAttention-2 only runs in fp16/bf16, and this
    # avoids materializing a full fp32 copy of the base model first.
    base_model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        trust_remote_code=True,
        attn_implementation="flash_attention_2",
        torch_dtype=torch.bfloat16,
    )

    if args.run_dir is None:
        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            target_modules=["q_proj", "v_proj", "up_proj", "down_proj", "gate_proj", "k_proj", "o_proj"],
            inference_mode=False,
            r=args.rank,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
        )
        model = get_peft_model(base_model, lora_config)
    else:
        # Previously this branch was missing and `model` was unbound when resuming.
        # The adapter's own saved config (rank/alpha/dropout) is used here, not the
        # CLI flags. Only adapter weights are restored; optimizer/scheduler state is not.
        model = PeftModel.from_pretrained(base_model, args.run_dir, is_trainable=True)

    return model.to(dtype=torch.bfloat16, device=device)


def _mean_loss(model, dataloader, accelerator, desc):
    """Return the mean per-batch loss over ``dataloader``, averaged over all processes."""
    model.eval()
    loss_sum = torch.zeros((), device=accelerator.device)
    num_batches = torch.zeros((), device=accelerator.device)
    with torch.no_grad():
        for batch in tqdm(dataloader, desc=desc, disable=not accelerator.is_local_main_process):
            loss_sum += model(**batch).loss.float()
            num_batches += 1
    # Sum over ranks so every process ends up with the identical global mean.
    totals = accelerator.reduce(torch.stack([loss_sum, num_batches]), reduction="sum")
    return (totals[0] / totals[1].clamp(min=1)).item()


def main(args):
    """Fine-tune ``args.model_name_or_path`` on BoolQ with LoRA under Accelerate.

    Runs up to ``args.num_epochs`` epochs and evaluates the validation loss after
    each one. The adapter is saved to ``args.output_dir`` whenever that loss
    improves, and training stops early after ``patience_limit`` consecutive epochs
    without improvement.
    """
    import math

    patience_limit = 5

    set_seed(args.seed)

    accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision="bf16")

    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True, trust_remote_code=True)
    tokenizer.padding_side = "left"
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # NOTE(review): Accelerator is created without `log_with`; confirm trackers are
    # configured, otherwise the accelerator.log calls below record nothing.
    accelerator.init_trackers(
        project_name=f"mistral7b-lora-r{args.rank}-boolq-epochs{args.num_epochs}",
        config=vars(args)
    )

    dataset = load_dataset("super_glue", "boolq", trust_remote_code=True)

    train_dataset = dataset["train"]
    val_dataset = dataset["validation"]

    with accelerator.main_process_first():
        processed_train_ds = train_dataset.map(
            lambda samples: preprocess_boolq(samples, tokenizer, args.max_length),
            batched=True,
            num_proc=4,
            remove_columns=train_dataset.column_names,
            load_from_cache_file=False,
            desc="Tokenizing BoolQ Train",
        )

        processed_val_ds = val_dataset.map(
            lambda samples: preprocess_boolq(samples, tokenizer, args.max_length),
            batched=True,
            num_proc=4,
            remove_columns=val_dataset.column_names,
            load_from_cache_file=False,
            desc="Tokenizing BoolQ Validation",
        )

    train_dataloader = _make_dataloader(processed_train_ds, args, shuffle=True)
    val_dataloader = _make_dataloader(processed_val_ds, args, shuffle=False)

    model = _load_model(args, accelerator.device)
    model.enable_input_require_grads()
    model.print_trainable_parameters()

    # Only the LoRA parameters require grad; keep frozen weights out of the optimizer.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = AdamW8bit(trainable_params, lr=args.lr, weight_decay=0.01)

    # Accelerate also syncs on the final (possibly partial) accumulation window of
    # each epoch, so the number of optimizer updates per epoch rounds up.
    updates_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.num_warmup_steps,
        num_training_steps=updates_per_epoch * args.num_epochs,
    )

    model, train_dataloader, val_dataloader, optimizer, lr_scheduler = accelerator.prepare(
        model, train_dataloader, val_dataloader, optimizer, lr_scheduler
    )

    torch.cuda.empty_cache()

    if accelerator.is_main_process:
        os.makedirs(args.output_dir, exist_ok=True)

    best_val_loss = float("inf")
    epochs_without_improvement = 0
    for epoch in range(args.num_epochs):
        accelerator.wait_for_everyone()
        model.train()
        train_loss_sum = 0.0

        for batch in tqdm(train_dataloader, desc=f"Epoch {epoch}", disable=not accelerator.is_local_main_process):
            with accelerator.accumulate(model):
                loss = model(**batch).loss
                accelerator.backward(loss)

                if accelerator.sync_gradients:
                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad()

            step_loss = loss.detach().float().item()
            accelerator.log({"train/iter_loss": step_loss})
            train_loss_sum += step_loss

        # Log the same per-batch average that is printed (previously the raw sum was logged).
        train_loss = train_loss_sum / max(len(train_dataloader), 1)
        accelerator.log({"train/epoch_loss": train_loss})
        if accelerator.is_main_process:
            print(f"Epoch {epoch} Train Loss: {train_loss:.4f}")

        val_loss = _mean_loss(model, val_dataloader, accelerator, desc=f"Validation Epoch {epoch}")
        accelerator.log({"val/epoch_loss": val_loss})
        if accelerator.is_main_process:
            print(f"Epoch {epoch} Validation Loss: {val_loss:.4f}")

        # val_loss is all-reduced, so every process takes the same branch here and
        # early stopping stays in lockstep without broadcasting a stop flag.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_without_improvement = 0
            if accelerator.is_main_process:
                accelerator.unwrap_model(model).save_pretrained(args.output_dir)
                print("Saved Model")
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience_limit:
                break

    accelerator.end_training()


if __name__ == "__main__":
    # Script entry point: read CLI flags, then hand them straight to training.
    main(parse_args())


    

