from collections import defaultdict
from collections.abc import Callable, Iterator
from pathlib import Path

import thingsvision
import torch
from loguru import logger
from thingsvision import get_extractor
from torch.utils.data import DataLoader
from tqdm import tqdm

from src.models.collate_fn import CollateContextManager


class ThingsvisionModel:
    """Wrapper class for thingsvision models."""

    def __init__(
        self,
        extractor: thingsvision.core.extraction.base.BaseExtractor,
        module_names: str | list[str],
        feature_alignment: str | None = None,
        flatten_acts: bool = False,
    ) -> None:
        self._extractor = extractor
        self._module_names = [module_names] if isinstance(module_names, str) else module_names
        self._extractor.model = self._extractor.model.to(extractor.device)
        self._extractor.activations = {}
        self._output_type = "tensor"
        self._alignment_type = feature_alignment
        self._flatten_acts = flatten_acts

    def _check_modules_path_exists(self, output_dir: str) -> bool:
        """Check if the model path exists."""
        for module_name in self._module_names:
            path = Path(output_dir) / module_name
            logger.info(path)
            if not path.exists():
                path.mkdir(parents=True, exist_ok=True)
                logger.warning(f"Module path does not exist: {path}. Creating it.")

    def extract_features(self, batches: DataLoader, output_dir: str, split: str = "_train") -> None:
        """Extract features from model for all batches in a dataloader of images. Features are extracted with the
        thingsvision extractor and saved in the output_dir.

        Args:
            batches: DataLoader of images.
            output_dir: Directory to save the features.
            split: Suffix for the filename of the features.
        """
        self._check_modules_path_exists(output_dir)

        with CollateContextManager(batches) as dataloader:
            self._extract_features(
                batches=dataloader,
                module_names=self._module_names,
                flatten_acts=self._flatten_acts,
                output_type=self._output_type,
                output_dir=output_dir,
                file_name_suffix=split,
                save_in_one_file=True,
            )

    def _extract_features(
        self,
        batches: DataLoader,
        module_names: list[str] | None = None,
        flatten_acts: bool = False,
        output_type: str = "tensor",
        output_dir: str | None = None,
        step_size: int | None = None,
        file_name_suffix: str = "",
        save_in_one_file: bool = False,
    ) -> None:
        self._extractor.model = self._extractor.model.to(self._extractor.device)
        self._extractor.activations = {}
        self._extractor._register_hooks(module_names=module_names)

        if module_names is None:
            module_names = self._module_names

        self._extractor._module_and_output_check(module_names, output_type)

        if output_dir is None:
            raise ValueError("output_dir is required")

        Path(output_dir).mkdir(parents=True, exist_ok=True)

        if not step_size:
            step_size = 8000 // (len(next(iter(batches))) * 3) + 1

        features = defaultdict(list)
        feature_file_names = defaultdict(list)
        image_ct, last_image_ct = 0, 0
        for i, batch in tqdm(enumerate(batches, start=1), desc="Batch", total=len(batches)):
            modules_features = self._extractor._extract_batch(
                batch=batch, module_names=module_names, flatten_acts=flatten_acts
            )

            image_ct += len(batch)
            del batch

            for module_name in module_names:
                features[module_name].append(modules_features[module_name])

                if i % step_size == 0 or i == len(batches):
                    features_subset = torch.cat(features[module_name])
                    features_subset_file = (
                        Path(output_dir) / f"{module_name}/features{file_name_suffix}_{last_image_ct}-{image_ct}.pt"
                    )
                    torch.save(features_subset, features_subset_file)

                    features[module_name] = []
                    last_image_ct = image_ct
                    feature_file_names[module_name].append(features_subset_file)

        if save_in_one_file:
            for module_name in module_names:
                all_features = [torch.load(file) for file in feature_file_names[module_name]]
                features_file = Path(output_dir) / f"{module_name}/features{file_name_suffix}.pt"
                torch.save(torch.cat(all_features), features_file)
                for file in feature_file_names[module_name]:
                    file.unlink()

        self._extractor._unregister_hooks()

    def extract_targets(self, batches: DataLoader, output_dir: str, split: str = "train") -> None:
        """Extract targets from a batch of images."""
        new_fn_path = Path(output_dir) / f"targets{split}.pt"
        if new_fn_path.exists():
            logger.warning(f"Targets file already exists at {new_fn_path}")
        else:
            targets = torch.cat([target for _, target in batches])
            torch.save(targets, new_fn_path)

        for module_name in self._module_names:
            module_dir = Path(output_dir) / module_name
            module_dir.mkdir(parents=True, exist_ok=True)
            symlink_path = module_dir / f"targets{split}.pt"
            try:
                if symlink_path.exists():
                    logger.warning(f"Extracted targets file already exists at {symlink_path}. Skipping.")
                    continue
                elif symlink_path.is_symlink():
                    logger.warning(f"Extracted targets file is a symlink at {symlink_path}. Removing it.")
                    symlink_path.unlink()
                symlink_path.symlink_to(new_fn_path.resolve())
            except Exception as e:
                logger.warning(f"Could not create symlink {symlink_path} -> {new_fn_path}: {e}")
                raise e

    def parameters(self) -> Iterator[torch.nn.parameter.Parameter]:
        """Get the parameters of the model."""
        return self._extractor.model.parameters()

    def n_parameters(self) -> int:
        """Get the number of parameters of the model."""
        return sum(p.numel() for p in self.parameters())

    @property
    def module_names(self) -> list[str]:
        """Get the module names of the model."""
        return self._module_names


def load_thingsvision_model(
    model_name: str,
    source: str,
    device: str | torch.device,
    model_parameters: dict | None,
    module_names: str | list[str],
    feature_alignment: str | None = None,
    *,
    resize_dim: int = 256,
    crop_dim: int = 224,
) -> tuple[ThingsvisionModel, Callable]:
    """Load a pretrained thingsvision extractor and wrap it in a ``ThingsvisionModel``.

    Activations are flattened unless token-level output is requested, i.e. unless
    ``model_parameters`` has a truthy ``"extract_cls_token"`` or contains the key
    ``"token_extraction"``. A ``"token_extraction"`` value of ``"all_tokens"`` is
    consumed here and not forwarded to ``get_extractor``; the caller's dict is left
    unmodified.

    Args:
        model_name: Model name passed to ``get_extractor``.
        source: Model source passed to ``get_extractor``.
        device: Device passed to ``get_extractor``.
        model_parameters: Extra model parameters; ``None`` is treated as "no parameters".
        module_names: Module name(s) the returned wrapper extracts features from.
        feature_alignment: Alignment type stored on the returned wrapper.
        resize_dim: Resize dimension for the extractor's transformations.
        crop_dim: Crop dimension for the extractor's transformations.

    Returns:
        The wrapped model and the transformation returned by the extractor.
    """
    requested = model_parameters or {}
    token_level = bool(requested.get("extract_cls_token")) or "token_extraction" in requested
    flatten_acts = not token_level

    if token_level and requested.get("token_extraction") == "all_tokens":
        # Build a new dict instead of popping so the caller's dict stays untouched.
        model_parameters = {key: value for key, value in requested.items() if key != "token_extraction"}

    extractor = get_extractor(
        model_name=model_name,
        source=source,
        device=device,
        pretrained=True,
        model_parameters=model_parameters,
    )

    model = ThingsvisionModel(
        extractor=extractor,
        module_names=module_names,
        feature_alignment=feature_alignment,
        flatten_acts=flatten_acts,
    )
    transform = extractor.get_transformations(resize_dim=resize_dim, crop_dim=crop_dim)
    return model, transform
