import os
import argparse
import torch
import torch.nn as nn
import transformers
import wandb
from accelerate import Accelerator
from peft import LoraConfig, get_peft_model
from torch.utils.data import DataLoader
import utils
import pickle as pkl
from loguru import logger
from utils import set_logger, train_loop, ClassificationModel, TokenizerCollator


def get_target_modules(model_name):
    """Return the module names LoRA should be attached to for ``model_name``.

    The model family is detected by substring match on ``model_name``:

    * names containing "Qwen2" or "Llama" -> the separate q/k/v projections,
    * names containing "OpenELM" -> the single ``qkv_proj`` module.

    Args:
        model_name: Model identifier string (e.g. a hub id).

    Returns:
        A new list of module-name strings, or ``None`` when ``model_name``
        matches none of the families above.
    """
    has_split_projections = any(tag in model_name for tag in ("Qwen2", "Llama"))
    if has_split_projections:
        return ["q_proj", "k_proj", "v_proj"]
    if "OpenELM" in model_name:
        return ["qkv_proj"]
    # Unrecognized family: keep the "no explicit targets" result.
    return None


def get_num_attention_layers(model_name, model_config):
    """Return the layer count stored in ``model_config`` for ``model_name``.

    The config attribute that holds the count depends on the model family,
    which is detected by substring match on ``model_name``:

    * "Qwen2" or "Llama" -> ``model_config.num_hidden_layers``
    * "OpenELM"          -> ``model_config.num_transformer_layers``

    Args:
        model_name: Model identifier string (e.g. a hub id).
        model_config: Config object exposing the attribute listed above.

    Returns:
        The value of the family-specific layer-count attribute.

    Raises:
        ValueError: If ``model_name`` matches none of the supported families.
            Previously this case returned ``None`` implicitly, which only
            failed later when the result was compared against an integer.
    """
    if "Qwen2" in model_name or "Llama" in model_name:
        return model_config.num_hidden_layers
    if "OpenELM" in model_name:
        return model_config.num_transformer_layers
    raise ValueError(
        f"Unsupported model '{model_name}': cannot determine the number of "
        "attention layers (expected a Qwen2, Llama, or OpenELM model)."
    )


def main(args):
    """Run one LoRA fine-tuning job end to end and save its artifacts.

    Steps, in order: build a run identifier and output directory, exit early
    if that directory already exists, load the dataset / tokenizer / model,
    wrap the model with LoRA adapters on layers ``start_layer`` and above,
    wrap it in ``ClassificationModel``, train via ``train_loop``, and finally
    (main process only) save the model plus pickled train/eval curves.

    Args:
        args: Parsed command-line namespace. Fields read here:
            ``dataset_name``, ``model_name``, ``start_layer``,
            ``token_selection``, ``result_path``, ``log_dir``, ``lora_rank``,
            ``learning_rate``, ``n_steps``, ``batch_size``, ``num_workers``,
            ``wandb_entity`` and ``wandb_project``. The whole namespace is
            also forwarded to ``train_loop`` and logged as the wandb config.

    Raises:
        ValueError: If ``args.start_layer`` is negative or not smaller than
            the model's layer count.
    """
    accelerator = Accelerator(mixed_precision="bf16")

    identifier = f"{args.dataset_name.replace('/', '_')}_{args.model_name.replace('/', '_')}_layer{args.start_layer}_{args.token_selection}_token_LoRA"
    output_dir = os.path.join(args.result_path, identifier)

    # An existing output directory is treated as "this run is already done".
    if os.path.exists(output_dir):
        if accelerator.is_main_process:
            print("Output directory already exists. Exiting.")
        accelerator.end_training()
        return

    # Only the main process configures file logging.
    if accelerator.is_main_process:
        logfile_path = os.path.join(args.log_dir, f"{identifier}.log")
        set_logger(logfile_path=logfile_path)
        logger.info(f"Logging to {logfile_path}")

    # Load the dataset.
    dataset = utils.dataset_load(args.dataset_name)
    train_dataset, test_dataset = dataset["train"], dataset["test"]
    if accelerator.is_main_process:
        logger.info(f"Prepared dataset: {args.dataset_name}, train size: {len(train_dataset)}, test size: {len(test_dataset)}")

    # Initialize the tokenizer and model.
    tokenizer, model = utils.tokenizer_model_load(args.model_name, torch_dtype=torch.bfloat16)
    if accelerator.is_main_process:
        logger.info(f"Loaded tokenizer and model: {args.model_name}")

    # Replace the LM head with a pass-through for Qwen/Llama models.
    # NOTE(review): presumably ClassificationModel consumes the pre-head
    # outputs rather than vocabulary logits -- confirm against utils.
    if 'Qwen' in args.model_name or 'Llama' in args.model_name:
        model.lm_head = nn.Identity()

    collate_fn = TokenizerCollator(tokenizer)

    # Number of classes = number of distinct label values in the train split.
    num_labels = len(set(sample["labels"] for sample in train_dataset))

    total_layers = get_num_attention_layers(args.model_name, model.config)
    if accelerator.is_main_process:
        logger.info(f"Total layers: {total_layers}")
    if args.start_layer >= total_layers:
        accelerator.end_training()
        raise ValueError(f"start_layer must be less than total_layers ({total_layers}).")
    if args.start_layer < 0:
        accelerator.end_training()
        raise ValueError("start_layer must be a non-negative integer.")
    # LoRA is applied to every layer from start_layer through the last one.
    layers = list(range(args.start_layer, total_layers))

    config = LoraConfig(
        r=args.lora_rank,
        lora_alpha=args.lora_rank,
        target_modules=get_target_modules(args.model_name),
        layers_to_transform=layers,
        bias="none",
        lora_dropout=0.05,
        task_type="CAUSAL_LM",
    )

    model = get_peft_model(model, config)
    model.config.use_cache = False  # Avoid extra memory usage since we are not doing generation

    model = ClassificationModel(model, num_labels=num_labels, token_selection=args.token_selection)

    if accelerator.is_main_process:
        for name, param in model.named_parameters():
            if param.requires_grad:
                logger.info(f"Trainable parameter: {name}")

    # Optimizer and Scheduler.
    # Adafactor is run with a fixed external learning rate (relative_step,
    # scale_parameter and warmup_init all disabled); the cosine schedule
    # below supplies warmup over the first 5% of steps.
    optimizer = transformers.Adafactor(
        params=[p for p in model.parameters() if p.requires_grad],
        lr=args.learning_rate,
        eps=(1e-30, 1e-3),
        clip_threshold=1.0,
        decay_rate=-0.8,
        beta1=None,
        weight_decay=1e-5,
        relative_step=False,
        scale_parameter=False,
        warmup_init=False,
    )
    scheduler = transformers.get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.05 * args.n_steps),
        num_training_steps=args.n_steps,
    )

    # DataLoaders
    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                                  collate_fn=collate_fn, num_workers=args.num_workers)
    eval_dataloader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False,
                                 collate_fn=collate_fn, num_workers=args.num_workers)

    model, optimizer, scheduler, train_dataloader, eval_dataloader = accelerator.prepare(
        model, optimizer, scheduler, train_dataloader, eval_dataloader
    )

    if accelerator.is_main_process:
        # Pass the entity/project chosen on the command line; without them
        # the run is filed under wandb's defaults even though --wandb_entity
        # is a required argument.
        wandb.init(
            entity=args.wandb_entity,
            project=args.wandb_project,
            name=identifier,
            config=vars(args),
        )

    # Run training.
    if accelerator.is_main_process:
        logger.info("Starting training...")
    train_curve, eval_curve = train_loop(model, train_dataloader, eval_dataloader, optimizer, scheduler, accelerator, args)

    # Save the final model checkpoint only on the main process.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"Saving the model checkpoint and train/val curves to {output_dir}")
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained(output_dir, safe_serialization=True)
        with open(os.path.join(output_dir, "train_curve.pkl"), "wb") as f:
            pkl.dump(train_curve, f)
        with open(os.path.join(output_dir, "eval_curve.pkl"), "wb") as f:
            pkl.dump(eval_curve, f)
        wandb.finish()

    accelerator.end_training()


if __name__ == "__main__":
    @logger.catch()
    def run():
        """Parse the command-line options and hand them to ``main``."""
        # One (flag, add_argument keyword options) pair per option, kept in
        # the order they should be registered and shown in --help.
        cli_options = (
            ("--dataset_name", dict(type=str, default="rotten_tomatoes",
                                    help="Name of the dataset to be used")),
            ("--model_name", dict(type=str, default="Qwen/Qwen2-0.5B",
                                  help="Name of the pre-trained model")),
            ("--batch_size", dict(type=int, default=1,
                                  help="Batch size for training and evaluation")),
            ("--n_steps", dict(type=int, default=5000,
                               help="Total number of training steps")),
            ("--eval_steps", dict(type=int, default=500,
                                  help="Interval (in steps) for evaluation")),
            ("--logging_steps", dict(type=int, default=100,
                                     help="Interval (in steps) for logging training loss")),
            ("--learning_rate", dict(type=float, default=1e-4,
                                     help="Learning rate for the optimizer")),
            ("--lora_rank", dict(type=int, default=32,
                                 help="LoRA rank setting")),
            ("--num_workers", dict(type=int, default=4,
                                   help="Number of worker threads for data loading")),
            ("--start_layer", dict(type=int, default=0,
                                   help="The starting layer index to apply LoRA to")),
            ("--token_selection", dict(type=str, default="mean",
                                       help="Token selection method: 'last' or 'mean'")),
            ("--result_path", dict(type=str, default="./lora",
                                   help="Path to save the best model checkpoint and train/val curves")),
            ("--log_dir", dict(type=str, default="./lora_logs",
                               help="Directory to save log files")),
            ("--wandb_entity", dict(type=str, required=True,
                                    help="Wandb entity name")),
            ("--wandb_project", dict(type=str, default="NTPS",
                                     help="Wandb project name")),
        )

        cli = argparse.ArgumentParser(description="LoRA Fine-Tuning")
        for flag, options in cli_options:
            cli.add_argument(flag, **options)

        main(cli.parse_args())

    run()