import logging
import os

# CUDA debugging switches, exported before `torch` is imported further down.
# CUDA_LAUNCH_BLOCKING=1 asks the CUDA runtime to run kernel launches
# synchronously. NOTE(review): TORCH_USE_CUDA_DSA is presumably meant to
# enable device-side assertions -- confirm it has an effect as a runtime
# environment variable for the installed PyTorch build.
os.environ.update(
    {
        "TORCH_USE_CUDA_DSA": "True",
        "CUDA_LAUNCH_BLOCKING": "1",
    }
)

import sys
import time

import torch
from accelerate import dispatch_model, infer_auto_device_map
from datasets import Dataset

# Hugging Face PEFT imports
from peft import IA3Config, LoraConfig, TaskType, get_peft_model
from setproctitle import setproctitle
from sscompiler.compiler import (
    AbstractTransformer,
    PortableLoRAAdapter,
    mark_adapters_as_trainable,
)
from sscompiler.utils.constants import TARGET_MODULES
from sscompiler.utils.argument_classes import ExperimentOptions, SlimscaleParser
from sscompiler.utils.gsm8k_utils import finetune_math, finetune_math_hf, tokenize_math
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    set_seed,
)

# Seed the random number generators through transformers' `set_seed` helper
# with a fixed value of 0. This runs at module level, before any model or
# dataset is created below.
set_seed(0)


if __name__ == "__main__":
    # Script purpose: load a model (optionally wrapped with PEFT LoRA and/or
    # IA3 adapters), fine-tune it for one epoch on copies of the single longest
    # tokenized training example, and log the peak CUDA memory per device plus
    # the elapsed wall-clock time.
    parser = SlimscaleParser()
    parser.add_arguments(ExperimentOptions, dest="", prefix="")
    parser.add_argument("--lora", action="store_true")
    parser.add_argument("--lora-rank", type=int, default=8)
    parser.add_argument("--ia3", action="store_true")
    cli = parser.parse_args()

    setproctitle(f"CLAM Memory Consumption, {cli.model} {cli.task}")

    # Label used in the log file name. If both flags are passed, both adapters
    # are applied below and the label reports "ia3".
    technique = "none"
    if cli.lora:
        technique = "lora"
    if cli.ia3:
        technique = "ia3"

    model_name = cli.model.split("/")[-1]
    # NOTE: the task field appears twice in the name; it is kept that way so
    # that log file names produced by earlier runs remain stable.
    log_name = (
        "memory_consumption"
        f"_model_[{model_name}]"
        f"_task_[{cli.task}]"
        f"_batch_size_[{cli.batch_size}]"
        f"_max_length_[{cli.max_length}]"
        f"_task_[{cli.task}]"
        f"_method_[{technique}]"
    )
    log_dir = os.path.join(os.path.dirname(__file__), "logs")
    os.makedirs(log_dir, exist_ok=True)
    log_file = os.path.join(log_dir, f"{log_name}.out")

    # Append to a per-configuration log file through the root logger.
    logger = logging.getLogger()
    logger.addHandler(logging.FileHandler(log_file, mode="a", encoding="utf-8"))
    logger.setLevel(logging.INFO)

    logger.info(cli)

    tokenizer = AutoTokenizer.from_pretrained(
        cli.model,
        # model_max_length=cli.max_length,
        padding_side="left",
        use_fast=False,
    )
    auto_config = AutoConfig.from_pretrained(cli.model)

    # Model family is decided by a substring match on the model identifier.
    is_t5 = "t5" in cli.model
    if is_t5:
        model_cls, model_dtype = AutoModelForSeq2SeqLM, torch.float32
    else:
        model_cls, model_dtype = AutoModelForCausalLM, torch.bfloat16
    auto_model = model_cls.from_pretrained(
        cli.model,
        torch_dtype=model_dtype,
        device_map="auto",
        config=auto_config,
        ignore_mismatched_sizes=True,
    )

    if tokenizer.unk_token is None and tokenizer.pad_token is None:
        # raw llama3
        print("adding a special padding token...")
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        need_resize = True
    else:
        # Reuse the unknown token for padding, but only when one exists:
        # assigning a missing unk_token would erase an existing pad token.
        if tokenizer.unk_token is not None:
            tokenizer.pad_token = tokenizer.unk_token
        need_resize = False

    # PEFT task type must follow the model family: T5 is loaded as a
    # sequence-to-sequence model, everything else as a causal LM.
    peft_task_type = TaskType.SEQ_2_SEQ_LM if is_t5 else TaskType.CAUSAL_LM

    # Projection module names targeted for non-T5 models.
    decoder_only_modules = [
        "q_proj",
        "v_proj",
        "o_proj",
        "k_proj",
        "gate_proj",
        "down_proj",
        "up_proj",
    ]

    if cli.lora:
        # Hugging Face PEFT: apply LoRA to the attention and MLP projections.
        lora_config = LoraConfig(
            task_type=peft_task_type,
            target_modules=list(decoder_only_modules),
            r=cli.lora_rank,
            lora_alpha=8,
        )
        auto_model = get_peft_model(auto_model, lora_config)
        logger.info("applied PEFT lora")

    if cli.ia3:
        # Hugging Face PEFT: apply IA3 to the attention and feed-forward
        # modules; module names differ between T5 and the other models.
        if is_t5:
            target_modules = ["q", "k", "v", "wi", "wo"]
            feedforward_modules = ["wi", "wo"]
        else:
            target_modules = list(decoder_only_modules)
            feedforward_modules = ["down_proj", "up_proj"]
        ia3_config = IA3Config(
            task_type=peft_task_type,
            target_modules=target_modules,
            feedforward_modules=feedforward_modules,
        )
        auto_model = get_peft_model(auto_model, ia3_config)
        logger.info("Applied Hugging Face PEFT ia3")

    if need_resize:
        # A "[PAD]" token was added above, so the embedding table must grow.
        auto_model.resize_token_embeddings(len(tokenizer))

    # The eval split is tokenized as well but is not used by this script.
    tokenized_train, tokenized_eval = tokenize_math(
        tokenizer=tokenizer,
        validation_set="test",
        # max_length=cli.max_length,
        # padding=cli.should_pad,
    )

    # Pick the training example with the most input ids (first one on ties).
    longest_item = max(
        tokenized_train,
        key=lambda example: len(example.get("input_ids", "")),
        default=None,
    )
    if not isinstance(longest_item, dict):
        raise RuntimeError(
            "could not find a longest training example: the tokenized "
            "training set is empty or does not yield dict rows"
        )
    max_length = len(longest_item.get("input_ids", ""))
    logger.info("max length: %d", max_length)

    # Build a dataset of 20 batches that all contain the longest example.
    num_batches = 20
    longest_item_list = [longest_item] * (num_batches * cli.batch_size)
    new_dataset = Dataset.from_list(longest_item_list)

    # perf_counter is monotonic, so the interval is immune to clock changes.
    start_time = time.perf_counter()

    # Only the memory/time side effects matter here; the returned values
    # are not inspected.
    _result, _history = finetune_math_hf(
        auto_model=auto_model,
        tokenizer=tokenizer,
        tokenized_train=new_dataset,
        tokenized_eval=new_dataset,
        epochs=1,
        batch_size=cli.batch_size,
        learning_rate=1e-4,
        train_head=True,
        use_multi_lr=False,
    )

    batch_time = time.perf_counter() - start_time

    # Query each device explicitly: without the argument,
    # max_memory_allocated() reports only the current device.
    bytes_per_gib = 1024**3
    max_memories = [
        torch.cuda.max_memory_allocated(device) / bytes_per_gib
        for device in range(torch.cuda.device_count())
    ]
    for device, peak_gib in enumerate(max_memories):
        logger.info("device %d max memory: %.4f GiB", device, peak_gib)

    logger.info(max_memories)
    logger.info(sum(max_memories))
    logger.info(batch_time)

    sys.exit(0)
