"""
A utility module for handling various operations needed in machine learning experiments.
This includes functions for saving and loading objects, loading and preprocessing data, 
configuring models, and manipulating model training states.
"""

import os
import json
import string
from typing import Any, List, Tuple, Optional
import pickle
from PIL import Image
import torch
from torchvision import datasets, transforms
from transformers import (
    BertForMaskedLM,
    BertForSequenceClassification,
    BertTokenizerFast,
)
from models.vit import ViTForClassification


def save_object(obj: Any, filename: str) -> None:
    """
    Serialize a Python object to disk with pickle.

    The target is opened in binary write mode, so an existing file at the
    same path is overwritten. The highest pickle protocol available to the
    running interpreter is used.

    Args:
        obj: The Python object to serialize.
        filename: Path of the file the pickled object is written to.
    """
    with open(filename, "wb") as sink:
        pickler = pickle.Pickler(sink, protocol=pickle.HIGHEST_PROTOCOL)
        pickler.dump(obj)


def load_object(filename: str) -> Any:
    """
    Deserialize a Python object from a pickle file.

    Note:
        Unpickling can execute arbitrary code, so only load files that come
        from a trusted source.

    Args:
        filename: Path of the pickle file to read.

    Returns:
        The object reconstructed from the file.
    """
    with open(filename, "rb") as source:
        return pickle.Unpickler(source).load()


def load_raw_images(img_dir: str) -> Tuple[torch.Tensor, List[str]]:
    """
    Load every image in a directory as a normalized 28x28 grayscale tensor.

    Files are selected by extension (case-insensitive) and visited in sorted
    filename order so the returned batch is reproducible. Each image is
    converted to grayscale (mode "L"), resized to 28x28 when it has a
    different size, converted to a tensor, and normalized with mean 0.5 and
    standard deviation 0.5.

    Args:
        img_dir: The directory from which images are loaded. Sub-directories
            are not traversed.

    Returns:
        A tuple ``(images, names)`` where ``images`` is the stacked batch of
        image tensors and ``names`` holds, in the same order, each filename
        with its final extension removed.

    Raises:
        RuntimeError: If no matching image file is found (raised by
            ``torch.stack`` when given an empty list).
    """
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize(0.5, 0.5)]
    )
    image_extensions = (".jpg", ".jpeg", ".png", ".gif", ".bmp")
    images = []
    images_names = []
    # os.listdir returns entries in arbitrary order; sort for reproducibility.
    for filename in sorted(os.listdir(img_dir)):
        if not filename.lower().endswith(image_extensions):
            continue
        path = os.path.join(img_dir, filename)
        if not os.path.isfile(path):
            continue
        # The context manager releases the underlying file handle, which
        # Pillow may otherwise keep open (e.g. for multi-frame files).
        # convert() returns an independent, fully loaded image.
        with Image.open(path) as raw_image:
            image = raw_image.convert("L")
        if image.size != (28, 28):
            image = image.resize((28, 28))
        images.append(transform(image))
        # splitext strips only the final extension, so "digit.v2.png" keeps
        # its full stem instead of being cut at the first dot.
        images_names.append(os.path.splitext(filename)[0])
    return torch.stack(images), images_names


def load_raw_sents(txt_dir: str) -> Tuple[List[str], List[str]]:
    """
    Load the first line of every ``.txt`` file in a directory.

    Files are matched by a case-insensitive ``.txt`` suffix and visited in
    sorted filename order so the result is reproducible. Only the first line
    of each file is read; it is returned as-is, including its trailing
    newline when the file has one.

    Args:
        txt_dir: The directory from which text files are loaded.
            Sub-directories are not traversed.

    Returns:
        A tuple ``(sentences, names)`` of equal-length lists, where ``names``
        holds each filename with its final extension removed.

    Raises:
        IndexError: If a matching file is empty.
    """
    txts = []
    txts_names = []
    # os.listdir returns entries in arbitrary order; sort for reproducibility.
    for filename in sorted(os.listdir(txt_dir)):
        path = os.path.join(txt_dir, filename)
        if not (os.path.isfile(path) and filename.lower().endswith(".txt")):
            continue
        with open(path, "r", encoding="utf-8") as file:
            # Only the first line is needed, so avoid reading the whole file.
            first_line = file.readline()
        if not first_line:
            # Same exception type an empty file has always produced, but
            # with a message that identifies the offending file.
            raise IndexError(f"Text file is empty: {path}")
        txts.append(first_line)
        # splitext strips only the final extension, so "sent.v2.txt" keeps
        # its full stem instead of being cut at the first dot.
        txts_names.append(os.path.splitext(filename)[0])
    return txts, txts_names


def load_raw_sent(txt_dir: str, sentence_filename: str) -> Tuple[str, str]:
    """
    Load the first line of a single text file within a directory.

    The file ``<txt_dir>/<sentence_filename>.txt`` is opened as UTF-8 and its
    first line is returned as-is, including its trailing newline when the
    file has one.

    Args:
        txt_dir: The directory containing the text file.
        sentence_filename: The name of the text file (without the extension).

    Returns:
        A tuple containing the sentence and the filename without extension,
        i.e. ``sentence_filename`` unchanged.

    Raises:
        IndexError: If the file is empty.
    """
    path = os.path.join(txt_dir, sentence_filename + ".txt")
    with open(path, "r", encoding="utf-8") as file:
        # Only the first line is needed, so avoid reading the whole file.
        txt = file.readline()
    if not txt:
        # Same exception type an empty file has always produced, but with a
        # message that identifies the offending file.
        raise IndexError(f"Text file is empty: {path}")
    # The argument is already extension-less; returning it untouched keeps
    # dotted names such as "sent.v2" intact instead of truncating them.
    return txt, sentence_filename


def prepare_data(
    out_dir: str = ".data/MNIST/",
    batch_size: int = 128,
    num_workers: int = 2,
    test: bool = True,
    train_sample_size: Optional[int] = None,
    test_sample_size: Optional[int] = None,
) -> Tuple[Any, ...]:
    """
    Build MNIST data loaders, optionally restricted to a random subset.

    Images are converted to tensors and normalized with mean 0.5 and standard
    deviation 0.5. The training loader shuffles its data; the test loader
    does not.

    Args:
        out_dir: Directory where the MNIST dataset is downloaded and stored.
        batch_size: Number of samples per batch for both loaders.
        num_workers: Number of worker processes used by each loader.
        test: Whether a test loader is built and returned as well.
        train_sample_size: If given, size of the random training subset.
        test_sample_size: If given, size of the random test subset.

    Returns:
        ``(trainloader, testloader, classes)`` when ``test`` is true,
        otherwise ``(trainloader, classes)``; ``classes`` is the tuple of
        class indices 0 through 9.
    """

    def _maybe_subsample(dataset, sample_size):
        # Keep the full dataset unless a sample size was requested; otherwise
        # take the leading entries of a random permutation of its indices.
        if sample_size is None:
            return dataset
        chosen = torch.randperm(len(dataset))[:sample_size]
        return torch.utils.data.Subset(dataset, chosen)

    class_ids = tuple(range(10))
    normalize = transforms.Normalize((0.5,), (0.5,))
    pipeline = transforms.Compose([transforms.ToTensor(), normalize])

    train_data = _maybe_subsample(
        datasets.MNIST(out_dir, download=True, train=True, transform=pipeline),
        train_sample_size,
    )
    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )

    if not test:
        return train_loader, class_ids

    test_data = _maybe_subsample(
        datasets.MNIST(out_dir, download=True, train=False, transform=pipeline),
        test_sample_size,
    )
    test_loader = torch.utils.data.DataLoader(
        test_data, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )
    return train_loader, test_loader, class_ids


def load_config(config_path: str) -> dict:
    """
    Load a configuration file in JSON format.

    The file is decoded as UTF-8 regardless of the platform's default locale
    encoding, so non-ASCII values are read identically on every system.

    Args:
        config_path: The file path of the configuration file.

    Returns:
        The parsed configuration (a dictionary for the usual JSON-object
        config file).

    Raises:
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    with open(config_path, "r", encoding="utf-8") as f:
        return json.load(f)


def load_model(
    model_path: str, config_path: str, device: torch.device
) -> Tuple[ViTForClassification, dict]:
    """
    Build a ViT classification model from a config file and load its weights.

    The checkpoint is deserialized with its tensors remapped onto ``device``,
    so a file saved on one accelerator (or accelerator index) can be loaded
    on another or on the CPU. The state dict is then copied into a freshly
    constructed model. This function does not call ``.to(device)`` on the
    model itself.

    Args:
        model_path: The file path to the model's state dictionary.
        config_path: The configuration file path for initializing the model.
        device: The torch device onto which checkpoint tensors are mapped
            while loading.

    Returns:
        A tuple containing the loaded model and its configuration.
    """
    # Always remap to the requested device: loading without map_location
    # restores tensors to the device they were saved from, which fails when
    # that device does not exist on the current machine.
    checkpoint = torch.load(model_path, map_location=device)
    config = load_config(config_path)
    model = ViTForClassification(config)
    model.load_state_dict(checkpoint)
    return model, config


def deactivate_dropout_layers(model: torch.nn.Module) -> None:
    """
    Set the drop probability of a model's dropout layers to zero, in place.

    Only ``ViTForClassification``, ``BertForSequenceClassification`` and
    ``BertForMaskedLM`` instances are handled; any other model is left
    untouched.

    Args:
        model: The model whose dropout layers are deactivated.
    """
    if isinstance(model, ViTForClassification):
        model.embedding.dropout.p = 0.0
        for encoder_block in model.encoder.blocks:
            attention = encoder_block.attention
            attention.attn_dropout.p = 0.0
            attention.output_dropout.p = 0.0
            encoder_block.mlp.dropout.p = 0.0
        return

    if not isinstance(model, (BertForSequenceClassification, BertForMaskedLM)):
        return

    model.bert.embeddings.dropout.p = 0.0
    # Not every handled BERT variant necessarily exposes a top-level dropout.
    if hasattr(model, "dropout"):
        model.dropout.p = 0.0
    for bert_layer in model.bert.encoder.layer:
        bert_layer.attention.self.dropout.p = 0.0
        bert_layer.attention.output.dropout.p = 0.0
        bert_layer.output.dropout.p = 0.0
        # Cross-attention is only present on some layers.
        if hasattr(bert_layer, "crossattention"):
            bert_layer.crossattention.self.dropout.p = 0.0
            bert_layer.crossattention.output.dropout.p = 0.0


def load_bert_model(
    model_name: str, mask_or_cls: str
) -> Tuple[BertTokenizerFast, BertForSequenceClassification]:
    """
    Load a BERT tokenizer together with a masked-LM or classification model.

    With ``mask_or_cls == "mask"`` a ``BertForMaskedLM`` is returned. For any
    other value a ``BertForSequenceClassification`` is returned, with an
    extra ``decoder`` callable attached that passes its input through the
    encoder and masked-LM head of a separately loaded "bert-base-uncased"
    model.

    Args:
        model_name: The name of the BERT model variant to load.
        mask_or_cls: "mask" for masked language modeling; any other value
            (e.g. "cls") selects sequence classification.

    Returns:
        A tuple of the tokenizer and the loaded BERT model.
    """
    tokenizer = BertTokenizerFast.from_pretrained(model_name)
    if mask_or_cls == "mask":
        return tokenizer, BertForMaskedLM.from_pretrained(model_name)

    classifier = BertForSequenceClassification.from_pretrained(model_name)
    mlm = BertForMaskedLM.from_pretrained("bert-base-uncased")

    def _decode(x):
        # Encoder output index 0 is fed to the masked-LM head.
        hidden = mlm.bert.encoder(x)[0]
        return mlm.cls(hidden)

    classifier.decoder = _decode
    return tokenizer, classifier


def is_punctuation(token: str) -> bool:
    """
    Tell whether every character of a token is ASCII punctuation.

    The check is made against ``string.punctuation``. An empty token has no
    non-punctuation character and therefore yields True.

    Args:
        token: The string token to check.

    Returns:
        True if the token contains no character outside
        ``string.punctuation``, otherwise False.
    """
    return set(token).issubset(string.punctuation)


def get_allowed_tokens(tokenizer: BertTokenizerFast) -> List[int]:
    """
    Collect the vocabulary IDs of tokens that are usable as ordinary words.

    A token is excluded when it consists solely of punctuation, is one of the
    tokenizer's special tokens, or starts with "[unused".

    Args:
        tokenizer: The tokenizer whose vocabulary is to be filtered.

    Returns:
        A list of allowable token IDs, in vocabulary iteration order.
    """
    # Read the special tokens once, as a set, instead of re-reading the
    # tokenizer attribute and scanning a list for every vocabulary entry.
    special_tokens = set(tokenizer.all_special_tokens)
    # NOTE: word-piece continuations ("##...") and all-digit tokens are not
    # filtered out; filters for them existed but were disabled.
    return [
        idx
        for token, idx in tokenizer.vocab.items()
        if not (
            is_punctuation(token)
            or token in special_tokens
            or token.startswith("[unused")
        )
    ]
