from collections.abc import Mapping
import os
import random
import sys
from typing import Any, Dict, List, Union

import yaml
import numpy as np
import torch
from dotenv import load_dotenv
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
)
from transformers.data.data_collator import _torch_collate_batch
from peft import LoraConfig, get_peft_model

# So we can import sibling modules
sys.path.append(os.getcwd())

from src.utils.dataset import get_dataset


def set_seed(seed_val=42):
    """Seed every random number generator this script depends on.

    Seeds, in order: Python's ``random`` module, NumPy's global RNG, torch's CPU
    RNG, and the RNGs of all CUDA devices.

    Args:
        seed_val: Integer seed applied to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed_val)


class SamePaddingAndEosTokenCollator(DataCollatorForLanguageModeling):
    """Custom data collator so that GPT2 model family can learn to generate the EOS token.

    When the tokenizer reuses its EOS token as the padding token, the parent
    collator (with ``mlm=False``) sets every pad-token position in ``labels`` to
    ``-100``. This includes the EOS position the model should learn to emit.
    This collator restores the label of the first masked position in each row.
    """

    def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
        """
        Collate ``examples`` like the parent, then un-mask the first pad/EOS label.

        The method extends the logic of its parent so that the first trailing padding
        token is not zeroed out if it is also used as an EOS token. Rows with no
        masked position (sequences that fill the whole padded length) are left
        untouched.

        NOTE(review): assumes right-side padding. With left padding, the first
        masked position would be a leading pad rather than the trailing EOS.
        Likewise, an EOS token appearing mid-sequence would be the one restored.

        Args:
            examples: Items accepted by the parent collator's ``torch_call``.

        Returns:
            The parent's batch dict, with ``batch["labels"]`` modified in place
            when pad and EOS share a token id and ``mlm`` is disabled.
        """
        batch = super().torch_call(examples)
        tokenizer = self.tokenizer
        pad_is_eos = tokenizer.pad_token_id is not None and tokenizer.pad_token_id == tokenizer.eos_token_id
        if not self.mlm and pad_is_eos:
            labels = batch["labels"]
            masked = labels == -100
            # argmax over a 0/1 tensor returns the index of the first 1, i.e. the
            # first masked (pad == EOS) position in each row.
            first_masked = masked.int().argmax(dim=1)
            # A row with no masked position also yields argmax == 0. Restrict the
            # update to rows that actually contain padding, so an unpadded row's
            # first real label is not overwritten with EOS.
            rows = masked.any(dim=1).nonzero(as_tuple=True)[0]
            labels[rows, first_masked[rows]] = tokenizer.eos_token_id
        return batch


if __name__ == "__main__":
    load_dotenv()
    CONFIG_DIR = os.getenv("CONFIG_DIR", "gpt2")
    DATASET = os.getenv("DATASET", "tldr")
    DATA_DIR = os.getenv("DATA_DIR", ".")
    WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))

    with open(os.path.join("configs", CONFIG_DIR, "sft.yaml")) as file:
        config = yaml.safe_load(file)
    if "gradient_accumulation_steps" in config["training_kwargs"] and WORLD_SIZE > 1:
        # If the model is too large and we cannot fit the entire batch to the GPU at once, 
        # we can do the same by accumulating the gradients over multiple smaller batches.
        # This checks if training in parallel on multiple GPUs, as DDP is an implicit gradient
        # accumulation on its own, so that the overall number of samples stays the same
        # regardless of the number of GPUs used.
        old_grad_val = config["training_kwargs"]["gradient_accumulation_steps"]
        new_grad_val = max(1, old_grad_val // WORLD_SIZE)
        config["training_kwargs"]["gradient_accumulation_steps"] = new_grad_val
        print(
            f"Reducing gradient accumulation from {old_grad_val} to {new_grad_val} due to DDP world size {WORLD_SIZE}"
        )
    if os.getenv("WANDB_NAME", None) is None:
        os.environ["WANDB_NAME"] = CONFIG_DIR + " SFT " + DATASET
    set_seed(42)

    tokenizer = AutoTokenizer.from_pretrained(config["model_name"])
    model = AutoModelForCausalLM.from_pretrained(config["model_name"], use_cache=False)
    if not tokenizer.pad_token:
        tokenizer.pad_token = tokenizer.eos_token
        model.config.pad_token_id = model.config.eos_token_id
    if "lora_config" in config:
        peft_config = LoraConfig(**config["lora_config"])
        model = get_peft_model(model, peft_config)

    # Set up the datasets
    train_dataset = get_dataset(DATASET, tokenizer, "sft", split="train", max_length=config["max_length"])
    valid_dataset = get_dataset(DATASET, tokenizer, "sft", split="valid", max_length=config["max_length"])

    # Create a preprocessing function to extract out the proper logits from the model output
    def preprocess_logits_for_metrics(logits, labels):
        if isinstance(logits, tuple):
            logits = logits[0]
        return logits.argmax(dim=-1)

    # Prepare the trainer and start training
    training_args = TrainingArguments(
        output_dir=os.path.join(DATA_DIR, "data/models", DATASET, "sft-models", config["model_directory"]),
        **config["training_kwargs"],
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=valid_dataset,
        data_collator=SamePaddingAndEosTokenCollator(tokenizer, mlm=False),
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )
    trainer.train()
    trainer.save_model(os.path.join(DATA_DIR, "data/models", DATASET, "sft-models", config["model_directory"]))
    tokenizer.save_pretrained(os.path.join(DATA_DIR, "data/models", DATASET, "sft-models", config["model_directory"]))
