import os
import argparse
import torch
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from accelerate import Accelerator
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    default_data_collator,
    get_linear_schedule_with_warmup,
    set_seed,
    BitsAndBytesConfig
)
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, PeftModel, TaskType
from bitsandbytes.optim import AdamW8bit

# Process-wide runtime configuration, applied once at import time.
# Select PyTorch's 'high' float32 matmul precision mode (see the PyTorch docs
# for the exact speed/accuracy trade-off this mode implies).
torch.set_float32_matmul_precision('high')
# Environment switch read by the Hugging Face tokenizers library.
# NOTE(review): the documented values are "true"/"false" -- confirm that "1"
# is interpreted as "enabled" by the installed tokenizers version.
os.environ["TOKENIZERS_PARALLELISM"] = "1"

from datasets import load_dataset
import torch
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import LambdaLR

def plot_token_length_distribution(dataset, tokenizer):
    """Plot and summarize prompt+answer token lengths for an ARC-style dataset.

    Every example is rendered with the same prompt template used for
    training (``Question: ...`` followed by the labelled choices and
    ``Answer:``), tokenized without truncation, and its total length is
    counted as ``len(prompt tokens) + len(answer tokens) + 1`` (the ``+ 1``
    accounts for the EOS token appended after the answer).

    Side effects:
        * Saves a 50-bin histogram to ``hist.png`` (relative path).
        * Prints the mean, 95th-percentile and maximum length.

    Args:
        dataset: Object exposing a ``datasets``-style ``map`` method whose
            rows carry ``question``, ``choices`` (with ``label`` and ``text``
            lists) and ``answerKey`` columns.
        tokenizer: Callable tokenizer returning a mapping with ``input_ids``.

    Returns:
        None.
    """
    def compute_length(samples):
        # Batched map function: returns one total length per example.
        inputs, targets = [], []

        for q, choices, answer_key in zip(samples['question'], samples['choices'], samples['answerKey']):
            choice_texts = [f"({label}) {text}" for label, text in zip(choices['label'], choices['text'])]
            formatted_choices = "\n".join(choice_texts)
            inputs.append(f"Question: {q}\n{formatted_choices}\n\nAnswer:")

            # Mirror the training preprocessing: an answer key that is not
            # among the choice labels yields an empty target instead of
            # raising ValueError from list.index().
            if answer_key in choices["label"]:
                targets.append(choices["text"][choices["label"].index(answer_key)])
            else:
                targets.append("")

        model_inputs = tokenizer(inputs, add_special_tokens=True, truncation=False)
        labels = tokenizer(targets, add_special_tokens=False, truncation=False)

        # "+ 1" is the EOS token that training appends after the answer.
        total_lengths = [
            len(input_ids) + len(label_ids) + 1
            for input_ids, label_ids in zip(model_inputs["input_ids"], labels["input_ids"])
        ]
        return {"total_length": total_lengths}

    tokenized = dataset.map(
        compute_length,
        batched=True,
        num_proc=4,
        desc="Computing token lengths for plotting",
    )

    # Materialize as a plain list so the statistics below do not depend on
    # the concrete column type returned by the dataset.
    all_lengths = list(tokenized["total_length"])

    if not all_lengths:
        # Nothing to plot; avoid ZeroDivisionError / max() on an empty list.
        print("No examples found; skipping token length plot.")
        return

    fig = plt.figure(figsize=(10, 6))
    plt.hist(all_lengths, bins=50, color='blue', edgecolor='black')
    plt.title("Token Length Distribution (input + output)")
    plt.xlabel("Total tokens")
    plt.ylabel("Number of examples")
    plt.grid(True)
    plt.savefig("hist.png")
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

    # Print basic stats
    sorted_lengths = sorted(all_lengths)
    count = len(sorted_lengths)
    print(f"Mean token length: {sum(sorted_lengths)/count:.2f}")
    print(f"95th percentile token length: {sorted_lengths[int(0.95*count)]}")
    print(f"Max token length: {sorted_lengths[-1]}")

def parse_args():
    """Parse the training script's command-line options from ``sys.argv``.

    Returns:
        argparse.Namespace with the dataset, model (LoRA) and training
        settings; every option has a default, so no flag is required.
    """
    # (flag, type, default) triples, grouped by concern.
    option_table = (
        # Dataset
        ("--max_length", int, 256),
        ("--batch_size", int, 4),
        ("--num_workers", int, 8),
        # Model
        ("--model_name_or_path", str, "TinyLlama/TinyLlama-1.1B-step-50K-105b"),
        ("--rank", int, 8),
        ("--lora_alpha", int, 16),
        ("--lora_dropout", float, 0.1),
        # Training
        ("--lr", float, 1e-4),
        ("--num_epochs", int, 10),
        ("--num_warmup_steps", int, 30),
        ("--seed", int, 42),
        ("--run_dir", str, None),
        ("--output_dir", str, "lora_arc_easy"),
        ("--gradient_accumulation_steps", int, 4),
    )

    cli = argparse.ArgumentParser()
    for flag, value_type, fallback in option_table:
        cli.add_argument(flag, type=value_type, default=fallback)

    return cli.parse_args()

def load_arc_easy_dataset(tokenizer, max_length, test_size=0.1, seed=42):
    """Load ARC-Easy, split off a validation set and build the preprocessor.

    The ``train`` split of ``ai2_arc``/``ARC-Easy`` is divided into train and
    validation parts. The returned preprocessing function is meant to be
    passed to ``Dataset.map(..., batched=True)``.

    Args:
        tokenizer: Tokenizer used to encode prompts and answers. Its
            ``padding_side`` is set to ``"left"`` as a side effect.
        max_length: Fixed length of every produced sequence; must be > 0.
        test_size: Fraction of the train split held out for validation.
        seed: Seed for the train/validation split.

    Returns:
        Tuple ``(train_dataset, val_dataset, format_arc)``.

    Raises:
        ValueError: If ``max_length`` is not positive.
    """
    if max_length <= 0:
        # With max_length == 0 the "[-max_length:]" slices below would keep
        # the whole sequence instead of truncating it.
        raise ValueError(f"max_length must be positive, got {max_length}")

    dataset = load_dataset("ai2_arc", "ARC-Easy")

    full_train = dataset["train"].train_test_split(test_size=test_size, seed=seed)
    train_dataset = full_train["train"]
    val_dataset = full_train["test"]

    # Set once here rather than on every mapped batch. The tokenizer calls
    # below do not pad; left padding is applied manually in format_arc.
    tokenizer.padding_side = "left"  # safe for causal LM

    def format_arc(samples):
        """Turn a batch of raw ARC rows into fixed-length LM training rows.

        Each row becomes ``prompt + answer + EOS``, left-padded (or
        left-truncated) to ``max_length``. Labels are ``-100`` on padding and
        prompt positions so only answer and EOS tokens are supervised.
        """
        inputs, targets = [], []

        for q, choices, answer_key in zip(samples['question'], samples['choices'], samples['answerKey']):
            choice_texts = [f"({label}) {text}" for label, text in zip(choices['label'], choices['text'])]
            formatted_choices = "\n".join(choice_texts)
            inputs.append(f"Question: {q}\n{formatted_choices}\n\nAnswer:")

            # An answer key that matches no choice label yields an empty
            # target, so the only supervised token for that row is EOS.
            if answer_key in choices["label"]:
                targets.append(choices["text"][choices["label"].index(answer_key)])
            else:
                targets.append("")

        model_inputs = tokenizer(inputs, add_special_tokens=True, truncation=True, max_length=max_length)
        labels = tokenizer(targets, add_special_tokens=False, truncation=True, max_length=max_length)

        eos_token_id = tokenizer.eos_token_id
        # Fall back to EOS when no pad token is configured; otherwise the
        # padded rows would contain None and torch.tensor() would fail.
        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = eos_token_id

        batch_input_ids, batch_labels, batch_attention_mask = [], [], []

        for input_ids, label_ids in zip(model_inputs["input_ids"], labels["input_ids"]):
            # Build a new list instead of mutating the tokenizer's output.
            label_ids = label_ids + [eos_token_id]

            sequence = input_ids + label_ids
            label_row = [-100] * len(input_ids) + label_ids
            padding_length = max(max_length - len(sequence), 0)

            # Keep the rightmost max_length positions: over-long rows lose
            # tokens from the start of the prompt, never from the answer.
            batch_input_ids.append(([pad_token_id] * padding_length + sequence)[-max_length:])
            batch_attention_mask.append(([0] * padding_length + [1] * len(sequence))[-max_length:])
            batch_labels.append(([-100] * padding_length + label_row)[-max_length:])

        return {
            "input_ids": torch.tensor(batch_input_ids),
            "attention_mask": torch.tensor(batch_attention_mask),
            "labels": torch.tensor(batch_labels),
        }

    return train_dataset, val_dataset, format_arc


def main(args):
    """Fine-tune a causal LM with LoRA on ARC-Easy and keep the best adapter.

    Steps: seed everything, tokenize the train/validation splits, build (or
    resume) the LoRA-wrapped model, then train for ``args.num_epochs`` epochs
    with gradient accumulation. After each epoch the validation loss is
    averaged across all processes, and the adapter is saved to
    ``args.output_dir`` whenever that loss improves.

    Args:
        args: Namespace produced by ``parse_args`` (reads ``seed``,
            ``gradient_accumulation_steps``, ``model_name_or_path``, ``rank``,
            ``lora_alpha``, ``lora_dropout``, ``num_epochs``, ``max_length``,
            ``batch_size``, ``num_workers``, ``run_dir``, ``output_dir``,
            ``lr`` and ``num_warmup_steps``).
    """
    set_seed(args.seed)

    accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision="bf16")

    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True, trust_remote_code=True)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # NOTE(review): the Accelerator above is created without `log_with`, so
    # presumably no tracker is active and init_trackers()/log() record
    # nothing -- confirm whether a tracker was intended.
    accelerator.init_trackers(
        project_name=f"lora_arc_easy-r{args.rank}-epochs{args.num_epochs}",
        config=vars(args)
    )

    train_dataset, val_dataset, preprocess_func = load_arc_easy_dataset(tokenizer, args.max_length)

    with accelerator.main_process_first():
        processed_train_ds = train_dataset.map(
            preprocess_func,
            batched=True,
            num_proc=4,
            remove_columns=train_dataset.column_names,
            load_from_cache_file=False,
            desc="Tokenizing ARC-Easy train",
        )

        processed_val_ds = val_dataset.map(
            preprocess_func,
            batched=True,
            num_proc=4,
            remove_columns=val_dataset.column_names,
            load_from_cache_file=False,
            desc="Tokenizing ARC-Easy validation",
        )

    train_dataloader = DataLoader(
        processed_train_ds,
        shuffle=True,
        collate_fn=default_data_collator,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=True,
    )

    val_dataloader = DataLoader(
        processed_val_ds,
        shuffle=False,
        collate_fn=default_data_collator,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=True,
    )

    if args.run_dir is None:
        # Fresh run: wrap the base model with new LoRA adapters.
        model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, trust_remote_code=True, attn_implementation="flash_attention_2")
        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            target_modules=["q_proj", "v_proj", "up_proj", "down_proj", "gate_proj", "k_proj", "o_proj"],
            inference_mode=False,
            r=args.rank,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
        )
        model = get_peft_model(model, lora_config)
    else:
        # Resume: load existing adapters in trainable mode.
        # NOTE(review): this reads from <output_dir>/<run_dir>/checkpoints,
        # while the training loop below saves directly into <output_dir>;
        # it also omits attn_implementation, unlike the fresh-run branch.
        # Confirm both are intended.
        base_model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, trust_remote_code=True)
        model = PeftModel.from_pretrained(base_model, os.path.join(args.output_dir, args.run_dir, "checkpoints"), is_trainable=True)

    model.config.use_cache = False
    model.gradient_checkpointing_enable()
    model.enable_input_require_grads()

    model.print_trainable_parameters()

    optimizer = AdamW8bit(model.parameters(), lr=args.lr, weight_decay=0.01)

    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.num_warmup_steps,
        num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps,
    )

    model, train_dataloader, val_dataloader, optimizer, lr_scheduler = accelerator.prepare(model, train_dataloader, val_dataloader, optimizer, lr_scheduler)

    torch.cuda.empty_cache()

    if accelerator.is_main_process:
        os.makedirs(args.output_dir, exist_ok=True)

    # Start from +inf so the first epoch always produces a checkpoint,
    # whatever the scale of the loss.
    min_loss = float("inf")
    for epoch in range(args.num_epochs):
        accelerator.wait_for_everyone()  # Sync all processes
        model.train()
        total_loss = 0.0

        for batch in tqdm(train_dataloader, desc=f"Epoch {epoch}"):
            with accelerator.accumulate(model):
                outputs = model(**batch)
                loss = outputs.loss
                accelerator.backward(loss)

                # Only step when gradients are synchronized, i.e. at the end
                # of an accumulation window.
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), max_norm=1.0)

                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad()

                # One device-to-host transfer per step, reused for both
                # logging and the running total.
                loss_value = loss.detach().float().item()
                accelerator.log({"iter_loss": loss_value})
                total_loss += loss_value

        if accelerator.is_main_process:
            average_loss = total_loss / len(train_dataloader)
            print(f"Epoch {epoch}: Total Loss = {average_loss:.4f}")

        model.eval()
        # Accumulate on the accelerator device so the cross-process
        # reduction below can operate on the tensor.
        val_loss_sum = torch.zeros((), device=accelerator.device)

        for batch in tqdm(val_dataloader, desc=f"Validation Epoch {epoch}"):
            with torch.no_grad():
                outputs = model(**batch)
            val_loss_sum += outputs.loss.detach().float()

        # Average over this process's batches, then over all processes, so
        # the checkpoint decision reflects the whole validation set rather
        # than only the main process's shard.
        val_loss = accelerator.reduce(val_loss_sum / len(val_dataloader), reduction="mean").item()
        accelerator.log({"validation_loss": val_loss})

        if accelerator.is_main_process:
            print(f"Validation Loss after epoch {epoch}: {val_loss:.4f}")

            # Keep only the best adapter seen so far.
            if val_loss < min_loss:
                min_loss = val_loss
                unwrapped_model = accelerator.unwrap_model(model)
                unwrapped_model.save_pretrained(args.output_dir)

    accelerator.end_training()


# Script entry point: parse the command-line options and start training, but
# only when this file is executed directly (not when it is imported).
if __name__ == "__main__":
    args = parse_args()
    main(args)