import argparse

import numpy as np
import torch
from loguru import logger
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset, Subset

from src.data import build_dataset, get_dataset_collate_fn
from src.models import load_model


def get_fewshot_indices(targets: torch.Tensor, fewshot_k: int) -> list[int]:
    """Get the indices of the samples that are used for training the linear probe.

    Samples are visited in a random order (``torch.randperm``). At most
    ``fewshot_k`` indices are kept per class. A negative ``fewshot_k`` keeps every
    index, and ``fewshot_k == 0`` keeps none.

    Args:
        targets: Per-sample class labels; each ``targets[i]`` must support ``.item()``.
        fewshot_k: Number of samples to keep per class, or a negative value to keep all.

    Returns:
        The selected indices as Python ints, in the random visiting order.

    Raises:
        ValueError: If ``fewshot_k > 0`` and some class present in ``targets`` has
            fewer than ``fewshot_k`` samples.
    """
    # .tolist() yields Python ints, exactly like calling .item() on each element.
    perm = torch.randperm(len(targets)).tolist()
    idxs: list[int] = []
    counts: dict = {}

    for p in perm:
        target = targets[p].item()
        counts.setdefault(target, 0)

        # A negative fewshot_k means "no per-class cap".
        if fewshot_k < 0 or counts[target] < fewshot_k:
            counts[target] += 1
            idxs.append(p)

    # Every class that appears must have contributed exactly fewshot_k samples.
    for c, val in counts.items():
        if fewshot_k > 0 and val != fewshot_k:
            raise ValueError(
                f"insufficient data for eval with {fewshot_k} samples per class, only {val} samples for class {c}"
            )

    return idxs


def count_samples(dataset: Dataset) -> int:
    """Return how many items ``dataset`` yields when iterated once.

    The dataset is fully consumed (every sample is produced), so this works for
    iterable datasets that do not define ``__len__``.
    """
    return sum(1 for _sample in dataset)


def get_extraction_model_n_dataloader(
    args: argparse.Namespace,
    dataset_root: str,
    task: str,  # noqa: ARG001
) -> tuple[torch.nn.Module, DataLoader | None, DataLoader]:
    """Get the extraction model and the dataloaders for the train and evaluation splits.

    Args:
        args: Parsed command-line arguments. Fields read here: ``model``,
            ``model_source``, ``model_parameters``, ``module_names``,
            ``feature_alignment``, ``device``, ``dataset``, ``split``,
            ``train_split``, ``wds_cache_dir``, ``batch_size`` and ``num_workers``.
        dataset_root: Root directory passed to ``build_dataset``.
        task: Unused; kept for interface compatibility.

    Returns:
        A tuple ``(model, train_dataloader, eval_dataloader)``. ``train_dataloader``
        is ``None`` when ``build_dataset`` returns a falsy value for
        ``args.train_split``. Both loaders iterate in order (``shuffle=False``).

    Raises:
        TypeError: If ``args.model`` is not a string.
    """
    # TODO: Add type hints for the arguments # noqa: TD002, TD003
    if not isinstance(args.model, str):
        raise TypeError("Model should be a string")

    logger.info(
        f"Load model and use {'no' if args.feature_alignment is None else args.feature_alignment} feature alignment",
    )
    model, transform = load_model(
        model_name=args.model,
        source=args.model_source,
        model_parameters=args.model_parameters,
        module_names=args.module_names,
        feature_alignment=args.feature_alignment,
        device=args.device,
    )
    logger.info(f"> Load datasets and dataloaders for {args.dataset}. ")

    eval_dataset = build_dataset(
        dataset_name=args.dataset,
        root=dataset_root,
        transform=transform,
        split=args.split,  # by default this is the test split
        download=True,
        wds_cache_dir=args.wds_cache_dir,
    )

    # Count by iterating once, then attach the length so len() works downstream.
    nr_samples_eval = count_samples(eval_dataset)
    eval_dataset = eval_dataset.with_length(nr_samples_eval)

    collate_fn = get_dataset_collate_fn(args.dataset)

    def _make_loader(dataset: Dataset) -> DataLoader:
        # Shared loader settings for both splits; order is kept (no shuffling).
        return DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.num_workers,
            collate_fn=collate_fn,
            pin_memory=True,
        )

    logger.info(f"Dataset size: {nr_samples_eval}")
    logger.info(f"Dataset split: {args.split}")

    classes = getattr(eval_dataset, "classes", None)
    if classes:
        logger.info(f"Dataset classes: {classes}")
        logger.info(f"Dataset number of classes: {len(classes)}")

    # Get the dataloader for the split we want to evaluate on, by default this is the test split
    logger.info("Eval dataset data loader creation.")
    eval_dataloader = _make_loader(eval_dataset)

    # We also need the train split for linear probing. Use the same webdataset
    # cache directory as the eval split.
    train_dataset = build_dataset(
        dataset_name=args.dataset,
        root=dataset_root,
        transform=transform,
        split=args.train_split,
        download=True,
        wds_cache_dir=args.wds_cache_dir,
    )
    if train_dataset:
        logger.info("Train dataset data loader creation.")
        nr_samples_train = count_samples(train_dataset)
        train_dataloader = _make_loader(train_dataset.with_length(nr_samples_train))
    else:
        train_dataloader = None
        nr_samples_train = None

    logger.info(f"> Train dataset size: {nr_samples_train}")
    logger.info(f"> Eval dataset size: {nr_samples_eval}")

    return model, train_dataloader, eval_dataloader


class SubsetWithTargets(Subset):
    """``Subset`` that also carries the labels of the selected samples.

    ``targets`` is a numpy array holding ``dataset.targets[i]`` for each ``i`` in
    ``indices``, in that order. ``feature_dims`` is forwarded from the wrapped dataset.
    """

    def __init__(self, dataset: Dataset, indices: list[int]) -> None:
        super().__init__(dataset, indices)
        source_targets = dataset.targets
        selected = [source_targets[position] for position in indices]
        self.targets = np.array(selected)

    @property
    def feature_dims(self) -> int:
        """Feature dimensionality reported by the wrapped dataset."""
        wrapped = self.dataset
        return wrapped.feature_dims


def create_train_val_loaders(
    train_loader: DataLoader,
    val_proportion: float = 0.1,
    seed: int = 42,
    pin_memory: bool = True,
) -> tuple[DataLoader, DataLoader]:
    """Create train and validation loaders from a given train loader.

    The samples of ``train_loader.dataset`` are split with a stratified
    ``train_test_split`` on ``dataset.targets``. The resulting subsets are wrapped
    in new loaders that reuse the original loader's batch size, worker count and
    collate function. Each new loader's batch size is increased where needed so
    that its final batch never contains exactly one sample (unless the subset
    itself has only one sample).

    Args:
        train_loader (DataLoader): The training data loader to split. Its dataset
            must expose a ``targets`` sequence, which is used for stratification.
        val_proportion (float, optional): Proportion of the dataset to include in the validation set. Defaults to 0.1.
        seed (int, optional): Random seed for the train/validation split. Defaults to 42.
        pin_memory (bool, optional): Forwarded to both new loaders. Defaults to True.

    Returns:
        tuple[DataLoader, DataLoader]: A tuple containing:
            - train_loader: DataLoader for the training split (shuffled)
            - val_loader: DataLoader for the validation split (not shuffled)
    """
    train_dataset = train_loader.dataset
    targets = np.array(train_dataset.targets)
    train_indices, val_indices = train_test_split(
        np.arange(targets.shape[0]),
        test_size=val_proportion,
        stratify=targets,
        random_state=seed,
    )
    tmp_train_dataset = SubsetWithTargets(train_dataset, indices=train_indices)
    tmp_val_dataset = SubsetWithTargets(train_dataset, indices=val_indices)

    def check_batch_size(dataset: Dataset) -> int:
        # Grow the batch size until the final batch does not hold a lone sample.
        # Presumably this avoids layers such as BatchNorm failing on size-1
        # batches; confirm. A single bump is not enough (e.g. 21 samples with
        # batch size 4 -> 5 still leaves 21 % 5 == 1), hence the loop. The loop
        # always stops: once batch_size reaches n_samples, the remainder is 0.
        n_samples = len(dataset)
        batch_size = train_loader.batch_size
        while batch_size < n_samples and n_samples % batch_size == 1:
            batch_size += 1
        return batch_size

    tmp_train_batch_size = check_batch_size(tmp_train_dataset)
    tmp_val_batch_size = check_batch_size(tmp_val_dataset)

    # Preserve the source loader's collation. DataLoader always stores a concrete
    # collate_fn, falling back to its default when none was given.
    collate_fn = train_loader.collate_fn

    tmp_train_loader = DataLoader(
        tmp_train_dataset,
        batch_size=tmp_train_batch_size,
        pin_memory=pin_memory,
        shuffle=True,
        num_workers=train_loader.num_workers,
        collate_fn=collate_fn,
    )
    tmp_val_loader = DataLoader(
        tmp_val_dataset,
        batch_size=tmp_val_batch_size,
        pin_memory=pin_memory,
        shuffle=False,
        num_workers=train_loader.num_workers,
        collate_fn=collate_fn,
    )
    return tmp_train_loader, tmp_val_loader
