import os

# CUDA debugging switches. They are exported here, ahead of the `import torch`
# further down, so they are already present in the environment when torch loads.
# CUDA_LAUNCH_BLOCKING=1 asks the CUDA runtime to launch kernels synchronously.
# NOTE(review): TORCH_USE_CUDA_DSA (device-side assertions) may only have an
# effect on PyTorch builds compiled with that option -- confirm.
os.environ.update(
    {
        "TORCH_USE_CUDA_DSA": "True",
        "CUDA_LAUNCH_BLOCKING": "1",
    }
)

import logging
import sys

import torch
from setproctitle import setproctitle
from sscompiler.compiler import AbstractTransformer, PortableLoRAAdapter
from sscompiler.utils.argument_classes import ExperimentOptions, SlimscaleParser
from sscompiler.utils.constants import SUPERGLUE_DATASETS, TARGET_MODULES
from sscompiler.utils.experiments import finetune_at
from sscompiler.utils.tokenization import tokenize_glue
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    set_seed,
)

from datasets import Dataset, load_dataset

# Seed the random number generators once at import time (fixed seed 0) so that
# repeated runs of this script start from the same random state.
set_seed(0)


def lora(at: AbstractTransformer, r: int = 8):
    """Inject LoRA adapters into the "key" and "value" modules of ``at``.

    Calls ``at.inject_adapter`` with the module names ``["key", "value"]`` and
    a factory that wraps each matched module ``x`` in a
    ``PortableLoRAAdapter`` whose ``in_features`` / ``out_features`` are copied
    from ``x``.

    This function used to be defined twice in a row with identical bodies; the
    second definition silently replaced the first, so a single definition is
    kept here.

    Args:
        at: Transformer wrapper to modify in place.
            NOTE(review): "key"/"value" are presumably the attention key and
            value projection groups -- confirm against ``inject_adapter``.
        r: Value forwarded as the adapter's ``r`` argument (the LoRA rank).
            Defaults to 8, the value that was previously hard-coded.
    """
    at.inject_adapter(
        ["key", "value"],
        lambda x: PortableLoRAAdapter(
            x,
            in_features=x.in_features,
            out_features=x.out_features,
            r=r,
        ),
    )


def longest_example(examples):
    """Return the example from ``examples`` with the most ``input_ids`` entries.

    Args:
        examples: Iterable of dict-like tokenized examples.

    Returns:
        The first example whose ``input_ids`` sequence is longest. An example
        without an ``input_ids`` key is treated as having length 0.

    Raises:
        ValueError: If ``examples`` yields no items.
    """
    longest = max(
        examples,
        # An empty tuple (not -1) is the fallback so that len() is always valid.
        key=lambda example: len(example.get("input_ids", ())),
        default=None,
    )
    if longest is None:
        raise ValueError("cannot pick the longest example: the dataset is empty")
    return longest


def main():
    """Measure peak CUDA memory while fine-tuning on worst-case-length inputs.

    Steps, in order:
      1. Parse the CLI options (``ExperimentOptions`` plus a ``--lora`` flag).
      2. Attach a file handler to the root logger under ``./logs`` next to
         this file and log the parsed options.
      3. Load the tokenizer and a sequence-classification model for
         ``cli.model``, wrapped in an ``AbstractTransformer``.
      4. Optionally inject LoRA adapters.
      5. Tokenize the task, pick the longest training example, and build a
         synthetic training set consisting only of copies of that example.
      6. Fine-tune on the synthetic set and print
         ``torch.cuda.max_memory_allocated`` for every visible CUDA device.
    """
    parser = SlimscaleParser()
    parser.add_arguments(ExperimentOptions, dest="", prefix="")
    parser.add_argument("--lora", action="store_true")
    cli = parser.parse_args()

    setproctitle(f"CLAM Memory Consumption, {cli.model} {cli.task}")

    model_name = cli.model.split("/")[-1]
    # NOTE(review): the "_task_[...]" segment appears twice. It is kept as-is
    # so that the log file path stays the same as in earlier runs.
    log_name = (
        "memory_consumption"
        f"_model_[{model_name}]"
        f"_task_[{cli.task}]"
        f"_batch_size_[{cli.batch_size}]"
        f"_max_length_[{cli.max_length}]"
        f"_task_[{cli.task}]"
        f"_lora_[{cli.lora}]"
    )
    log_dir = os.path.join(os.path.dirname(__file__), "logs")
    os.makedirs(log_dir, exist_ok=True)
    log_file = os.path.join(log_dir, f"{log_name}.out")

    # Root logger: messages are appended to the per-configuration log file.
    logger = logging.getLogger()
    logger.addHandler(logging.FileHandler(log_file, mode="a", encoding="utf-8"))
    logger.setLevel(logging.INFO)

    logger.info(cli)

    tokenizer = AutoTokenizer.from_pretrained(
        cli.model,
        model_max_length=cli.max_length,
        padding_side="left",
        use_fast=False,
    )

    # Number of output labels for the classification head. boolq and stsb are
    # hard-coded; for every other task the dataset is loaded and the count is
    # read from its "label" feature.
    if cli.task == "boolq":
        num_labels = 2
    elif cli.task == "stsb":
        num_labels = 1
    else:
        raw_datasets = load_dataset(
            "super_glue" if cli.task in SUPERGLUE_DATASETS else "nyu-mll/glue",
            cli.task,
        )
        num_labels = len(raw_datasets["train"].features["label"].names)

    auto_config = AutoConfig.from_pretrained(
        cli.model,
        num_labels=num_labels,
        finetuning_task=cli.task,
    )
    auto_model = AutoModelForSequenceClassification.from_pretrained(
        cli.model,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        config=auto_config,
        ignore_mismatched_sizes=True,
    )
    at = AbstractTransformer(
        model_dir=cli.model,
        groups=TARGET_MODULES[cli.model],
        auto_model=auto_model,
    )

    if cli.lora:
        lora(at)

    # Fall back to the requested length when the config does not expose a
    # positional-embedding limit.
    max_pos_embeddings = getattr(
        auto_config, "max_position_embeddings", cli.max_length
    )

    # Clamp the tokenization length to what the model can address. This branch
    # used to read the misspelled attribute `cli.max_lenth` and therefore
    # raised AttributeError; the clamped value was also never used afterwards.
    # On every other path the length passed to tokenize_glue is unchanged.
    effective_max_length = cli.max_length
    if tokenizer.model_max_length > max_pos_embeddings:
        effective_max_length = min(cli.max_length, max_pos_embeddings)
        logger.warning(
            "Tokenizer max length %s exceeds the model's %s position "
            "embeddings; tokenizing with max_length=%s instead.",
            tokenizer.model_max_length,
            max_pos_embeddings,
            effective_max_length,
        )

    tokenized_train, tokenized_eval = tokenize_glue(
        tokenizer=tokenizer,
        task=cli.task,
        model=at.auto_model,
        should_pad=cli.should_pad,
        max_length=effective_max_length,
        full_train=True,
    )

    # Train only on copies of the longest example so that every batch has the
    # largest sequence length found in the training split.
    longest_item = longest_example(tokenized_train)

    # 25 * batch_size identical rows (presumably 25 full batches per epoch).
    batches_per_epoch = 25
    new_dataset = Dataset.from_list(
        [longest_item] * (batches_per_epoch * cli.batch_size)
    )

    finetune_at(
        at=at,
        task=cli.task,
        tokenizer=tokenizer,
        tokenized_train=new_dataset,
        tokenized_eval=tokenized_eval,
        epochs=5,
        batch_size=cli.batch_size,
        train_head=True,
    )

    # Peak allocated memory, one entry per visible CUDA device.
    max_memories = [
        torch.cuda.max_memory_allocated(device)
        for device in range(torch.cuda.device_count())
    ]
    print(max_memories)


if __name__ == "__main__":
    main()
