import os
os.environ["TOKENIZERS_PARALLELISM"] = "1"

import argparse
import random
from tqdm.auto import tqdm

import torch
from torch.utils.data import DataLoader
from accelerate import Accelerator
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    default_data_collator,
    get_linear_schedule_with_warmup,
    set_seed,
)
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, PeftModel, TaskType
from bitsandbytes.optim import AdamW8bit
import matplotlib.pyplot as plt

# Global torch runtime knobs: allow reduced-precision (e.g. TF32) kernels for
# float32 matmuls and let cuDNN benchmark/auto-select convolution algorithms.
# TOKENIZERS_PARALLELISM is set once at the very top of the file (before the
# Hugging Face imports), so it is not assigned a second time here.
torch.set_float32_matmul_precision('high')
torch.backends.cudnn.benchmark = True

def plot_token_length_distribution(dataset, tokenizer, num_proc=4):
    """Print token-length statistics for both answer variants of each Winogrande row.

    Despite the name, nothing is plotted. For every row the ``_`` blank in
    ``sentence`` is filled once with ``option1`` and once with ``option2``. The
    text is prefixed with the training prompt and tokenized, and one extra
    position is counted for a trailing EOS token. The mean, 95th-percentile and
    maximum lengths over *all* of these variants are then printed. This is
    useful for choosing ``--max_length``.

    Args:
        dataset: a ``datasets.Dataset``-like object with ``sentence``,
            ``option1`` and ``option2`` columns; every ``sentence`` must contain
            ``_`` (``str.index`` raises ``ValueError`` otherwise).
        tokenizer: callable tokenizer returning a mapping with ``input_ids``.
        num_proc: number of worker processes passed to ``dataset.map``.

    Returns:
        dict with ``mean``, ``p95`` and ``max`` keys, or ``{}`` when the
        dataset yields no examples.
    """
    prompt = "Complete the sentence:\n"

    def compute_length(example):
        s = example["sentence"]
        idx = s.index("_")
        lengths = []
        for option in (example["option1"], example["option2"]):
            text = prompt + s[:idx] + option + s[idx + 1:]
            ids = tokenizer(text, add_special_tokens=True, truncation=False)["input_ids"]
            # +1 accounts for the EOS token appended after tokenization.
            lengths.append(len(ids) + 1)
        return {"total_length": lengths}

    mapped = dataset.map(
        compute_length,
        batched=False,
        num_proc=num_proc,
        desc="Computing token lengths (both-choice)",
    )

    # Each row holds [len(option1 variant), len(option2 variant)]; flatten so
    # that both choices are counted rather than only the first one.
    all_lengths = [n for pair in mapped["total_length"] for n in pair]
    if not all_lengths:
        print("No examples to measure.")
        return {}

    ordered = sorted(all_lengths)
    stats = {
        "mean": sum(ordered) / len(ordered),
        "p95": ordered[int(0.95 * len(ordered))],
        "max": ordered[-1],
    }

    print(f"Mean token length: {stats['mean']:.2f}")
    print(f"95th percentile token length: {stats['p95']}")
    print(f"Max token length: {stats['max']}")
    return stats

def parse_args():
    """Build the command-line interface for the LoRA/Winogrande trainer and parse ``sys.argv``.

    Returns:
        argparse.Namespace with data, model/LoRA, optimisation and output settings.
    """
    # (flag, type, default, help) — declared in the order they appear in --help.
    option_specs = (
        ("--max_length", int, 256, "Maximum token length"),
        ("--batch_size", int, 4, "Batch size"),
        ("--num_workers", int, 16, "Dataloader workers"),
        ("--model_name_or_path", str, "mistralai/Mistral-7B-v0.1", "Base model"),
        ("--rank", int, 8, None),
        ("--lora_alpha", int, 16, None),
        ("--lora_dropout", float, 0.1, None),
        ("--lr", float, 1e-4, None),
        ("--num_epochs", int, 10, None),
        ("--num_warmup_steps", int, 30, None),
        ("--seed", int, 42, None),
        ("--output_dir", str, "lora_winogrande", None),
        ("--gradient_accumulation_steps", int, 4, None),
        ("--sample_size", int, 30000, "Number of samples for training"),
    )

    cli = argparse.ArgumentParser()
    for flag, value_type, default_value, help_text in option_specs:
        # argparse's own default for `help` is None, so passing None is a no-op.
        cli.add_argument(flag, type=value_type, default=default_value, help=help_text)
    return cli.parse_args()

def preprocess_winogrande(samples, tokenizer, max_length):
    """Turn a batch of Winogrande rows into left-padded causal-LM training examples.

    For each row the ``_`` blank in ``sentence`` is filled with the option
    selected by ``answer``: ``"1"`` selects ``option1`` and ``"2"`` selects
    ``option2``. The filled sentence is prefixed with a fixed prompt and
    tokenized (truncated to ``max_length``). Labels equal the input ids, so the
    loss covers prompt and sentence alike; padding positions get label ``-100``.

    Only the correct option is emitted. A row for the wrong option would have
    all-``-100`` labels: it adds no gradient, and a batch consisting solely of
    such rows makes the mean cross-entropy 0/0 = NaN. Rows whose ``answer`` is
    empty or not 1/2 (e.g. unlabeled splits) are skipped for the same reason.
    As a result, the output may have fewer rows than the input batch.

    Args:
        samples: batched mapping with ``sentence``, ``option1``, ``option2`` and
            ``answer`` lists (as supplied by ``Dataset.map(batched=True)``).
        tokenizer: callable tokenizer accepting a list of strings and returning
            ``input_ids``; ``pad_token_id`` must be set.
        max_length: fixed sequence length of every output row.

    Returns:
        dict of int64 tensors ``input_ids``, ``attention_mask`` and ``labels``,
        each of shape ``(num_kept_rows, max_length)``.
    """
    prompt = "Complete the sentence:\n"

    texts = []
    for s, a, b, ans in zip(samples["sentence"], samples["option1"], samples["option2"], samples["answer"]):
        if ans is None or str(ans).strip() == "":
            continue  # unlabeled row: nothing to supervise
        choice = int(ans) - 1
        if choice not in (0, 1):
            continue
        idx = s.index("_")
        texts.append(prompt + s[:idx] + (a, b)[choice] + s[idx + 1:])

    input_rows, mask_rows, label_rows = [], [], []
    if texts:
        encoded = tokenizer(texts, add_special_tokens=True, truncation=True, max_length=max_length)
        for ids in encoded["input_ids"]:
            ids = list(ids)[-max_length:]
            pad = max_length - len(ids)
            # LEFT padding: real tokens are right-aligned.
            input_rows.append([tokenizer.pad_token_id] * pad + ids)
            mask_rows.append([0] * pad + [1] * len(ids))
            label_rows.append([-100] * pad + ids)

    def as_tensor(rows):
        # view() keeps a (0, max_length) shape when every row was skipped.
        return torch.tensor(rows, dtype=torch.long).view(-1, max_length)

    return {
        "input_ids": as_tensor(input_rows),
        "attention_mask": as_tensor(mask_rows),
        "labels": as_tensor(label_rows),
    }



def _build_dataloaders(args, tokenizer, accelerator):
    """Load Winogrande-XL, split 95/5 into train/validation, tokenize and wrap both in DataLoaders.

    ``args.sample_size`` caps the number of training rows (values <= 0 keep all).
    Returns ``(train_dataloader, val_dataloader)``.
    """
    dataset = load_dataset("winogrande", "winogrande_xl", split="train", trust_remote_code=True)

    random.seed(args.seed)
    dataset = dataset.shuffle(seed=args.seed)

    split = dataset.train_test_split(test_size=0.05, seed=args.seed)
    train_dataset, val_dataset = split["train"], split["test"]
    if 0 < args.sample_size < len(train_dataset):
        # Honour --sample_size by keeping the first N (already shuffled) training rows.
        train_dataset = train_dataset.select(range(args.sample_size))

    def tokenize(split_ds, desc):
        return split_ds.map(
            lambda samples: preprocess_winogrande(samples, tokenizer, args.max_length),
            batched=True,
            num_proc=4,
            remove_columns=split_ds.column_names,
            load_from_cache_file=False,
            desc=desc,
        )

    # The main process runs first; with load_from_cache_file=False the other
    # ranks still tokenize on their own afterwards.
    with accelerator.main_process_first():
        processed_train_ds = tokenize(train_dataset, "Tokenizing Winogrande Train")
        processed_val_ds = tokenize(val_dataset, "Tokenizing Winogrande Validation")

    def make_loader(ds, shuffle):
        return DataLoader(
            ds,
            shuffle=shuffle,
            collate_fn=default_data_collator,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            pin_memory=True,
        )

    return make_loader(processed_train_ds, True), make_loader(processed_val_ds, False)


def _build_lora_model(args, device):
    """Load the base causal LM in bf16, attach LoRA adapters and move the model to *device*."""
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        trust_remote_code=True,
        attn_implementation="flash_attention_2",
        # flash-attention-2 only runs in fp16/bf16, and loading in fp32 would
        # need twice the host memory before the cast below.
        torch_dtype=torch.bfloat16,
    )
    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        target_modules=["q_proj", "v_proj", "up_proj", "down_proj", "gate_proj", "k_proj", "o_proj"],
        inference_mode=False,
        r=args.rank,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
    )
    model = get_peft_model(model, lora_config)
    # Also casts the freshly initialised LoRA weights to bf16; use the
    # accelerator's device so each rank lands on its own GPU.
    model = model.to(dtype=torch.bfloat16, device=device)

    model.enable_input_require_grads()
    model.print_trainable_parameters()
    return model


def _train_one_epoch(model, dataloader, optimizer, lr_scheduler, accelerator, epoch):
    """Run one epoch with gradient accumulation; return this process's mean per-batch loss."""
    model.train()
    total_loss = 0.0
    for batch in tqdm(dataloader, desc=f"Epoch {epoch}", disable=not accelerator.is_local_main_process):
        with accelerator.accumulate(model):
            loss = model(**batch).loss
            accelerator.backward(loss)

            if accelerator.sync_gradients:
                # The accelerator-aware version also handles gradient unscaling
                # and sharded (FSDP/DeepSpeed) parameters.
                accelerator.clip_grad_norm_(model.parameters(), 2.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

        step_loss = loss.detach().float().item()
        accelerator.log({"train/iter_loss": step_loss})
        total_loss += step_loss
    return total_loss / max(len(dataloader), 1)


def _evaluate(model, dataloader, accelerator, epoch):
    """Return the mean validation loss, gathered over all processes (NaN if there are no batches)."""
    model.eval()
    batch_losses = []
    for batch in tqdm(dataloader, desc=f"Validation Epoch {epoch}", disable=not accelerator.is_local_main_process):
        with torch.no_grad():
            loss = model(**batch).loss
        # Gather per-rank losses so every process computes the same average.
        batch_losses.append(accelerator.gather(loss.detach().float().reshape(1)))
    if not batch_losses:
        return float("nan")
    return torch.cat(batch_losses).mean().item()


def main(args):
    """Fine-tune a causal LM with LoRA on Winogrande and keep the adapter with the best validation loss.

    Trains for exactly ``args.num_epochs`` epochs. After each epoch the
    validation loss is averaged across processes, and the main process saves
    the unwrapped PEFT model to ``args.output_dir`` whenever that loss
    improves.
    """
    import math

    set_seed(args.seed)

    accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision="bf16")

    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True, trust_remote_code=True)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # NOTE(review): Accelerator is constructed without `log_with`, so no tracker
    # is registered and accelerator.log() is effectively a no-op — confirm
    # whether a tracker (wandb/tensorboard/...) was intended.
    accelerator.init_trackers(
        project_name=f"mistral7b-lora-r{args.rank}-winogrande-epochs{args.num_epochs}",
        config=vars(args),
    )

    train_dataloader, val_dataloader = _build_dataloaders(args, tokenizer, accelerator)
    model = _build_lora_model(args, accelerator.device)

    # Only the LoRA parameters require grad; keep frozen base weights out of the optimizer.
    optimizer = AdamW8bit([p for p in model.parameters() if p.requires_grad], lr=args.lr, weight_decay=0.01)

    # One optimizer step per `gradient_accumulation_steps` batches. accelerate
    # also syncs at the end of each epoch's dataloader, hence the per-epoch ceil.
    steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.num_warmup_steps,
        num_training_steps=steps_per_epoch * args.num_epochs,
    )

    model, train_dataloader, val_dataloader, optimizer, lr_scheduler = accelerator.prepare(
        model, train_dataloader, val_dataloader, optimizer, lr_scheduler
    )

    torch.cuda.empty_cache()

    if accelerator.is_main_process:
        os.makedirs(args.output_dir, exist_ok=True)

    best_val_loss = float("inf")
    for epoch in range(args.num_epochs):
        train_loss = _train_one_epoch(model, train_dataloader, optimizer, lr_scheduler, accelerator, epoch)
        accelerator.log({"train/epoch_loss": train_loss})
        if accelerator.is_main_process:
            print(f"Epoch {epoch} Train Loss: {train_loss:.4f}")

        val_loss = _evaluate(model, val_dataloader, accelerator, epoch)
        accelerator.log({"val/epoch_loss": val_loss})
        if accelerator.is_main_process:
            print(f"Epoch {epoch} Validation Loss: {val_loss:.4f}")

        accelerator.wait_for_everyone()
        if accelerator.is_main_process and val_loss < best_val_loss:
            best_val_loss = val_loss
            accelerator.unwrap_model(model).save_pretrained(args.output_dir)

    accelerator.end_training()


if __name__ == "__main__":
    # Script entry point: parse CLI flags and hand them straight to training.
    main(parse_args())


    

