import logging
import time
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Literal, Tuple, Union

import accelerate
import numpy as np
import pandas as pd
import torch
import torch.utils.data
from torch import Tensor
from tqdm import tqdm

from src.data import create_dataset
from src.eval import METRICS_NAMES_PRETTY, get_ood_results
from src.methods import DetectorWrapper
from src.methods.templates import Detector
from src.pipelines import register_pipeline
from src.pipelines.base import Pipeline
from src.utils import ConcatDatasetsDim1

# Module-level logger; the pipelines below report progress and warnings through it.
_logger = logging.getLogger(__name__)


class OODBenchmarkPipeline(Pipeline, ABC):
    """Abstract out-of-distribution (OOD) detection benchmark pipeline.

    Subclasses implement :meth:`_setup_datasets`, which must populate
    ``fit_dataset``, ``in_dataset``, ``out_dataset`` and ``out_datasets``.
    The pipeline then:

    1. fits a detector on a random subset of ``fit_dataset`` (:meth:`preprocess`);
    2. scores a random subset of the concatenated in-distribution + OOD test set
       (:meth:`run`). Label ``0`` marks in-distribution samples and label
       ``i + 1`` marks samples of the ``i``-th entry of ``out_datasets``;
    3. computes per-OOD-dataset and averaged metrics (:meth:`postprocess`);
    4. renders them as a text table (:meth:`report`).

    Args:
        in_dataset_name: Name of the in-distribution dataset.
        out_datasets_names_splits: Mapping ``{ood_dataset_name: split}``. Its
            insertion order defines the OOD label of each dataset.
        transform: Stored as ``self.transform`` for use by :meth:`_setup_datasets`.
        batch_size: Batch size of both dataloaders.
        num_workers: ``DataLoader`` worker processes (``0`` loads in the main process).
        pin_memory: Forwarded to ``DataLoader``.
        prefetch_factor: Forwarded to ``DataLoader`` only when ``num_workers > 0``.
        limit_fit: Fraction of ``fit_dataset`` used for fitting (``None`` means 1.0).
        limit_run: Fraction of the test set that is scored (``None`` means 1.0).
            Values above 1 are clamped to the full test set.
        seed: Passed to ``accelerate.utils.set_seed`` before the datasets are built.
        accelerator: Optional ``accelerate`` accelerator. If given, the dataloaders
            are prepared with it and scores/labels are gathered with
            ``gather_for_metrics``.
    """

    def __init__(
        self,
        in_dataset_name: str,
        out_datasets_names_splits: Dict[str, Any],
        transform: Callable,
        batch_size: int,
        num_workers: int = 4,
        pin_memory: bool = True,
        prefetch_factor: int = 2,
        limit_fit: float = 1.0,
        limit_run: float = 1.0,
        seed: int = 42,
        accelerator=None,
    ) -> None:
        self.in_dataset_name = in_dataset_name
        self.out_datasets_names_splits = out_datasets_names_splits
        self.out_datasets_names = list(out_datasets_names_splits.keys())
        self.limit_fit = limit_fit
        self.limit_run = limit_run
        self.transform = transform
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.prefetch_factor = prefetch_factor
        self.seed = seed
        self.accelerator = accelerator

        self.fit_dataset = None
        self.in_dataset = None
        self.out_dataset = None
        self.out_datasets = None

        # Seed before setup so the random fit/test subsets below are reproducible.
        accelerate.utils.set_seed(seed)
        _logger.info("Setting up datasets...")
        self.setup()

    @abstractmethod
    def _setup_datasets(self):
        """Setup `in_dataset`, `out_dataset`, `fit_dataset` and `out_datasets`."""
        ...

    def _progress_disabled(self) -> bool:
        """Return True when progress bars should be hidden (non-main accelerate process)."""
        return self.accelerator is not None and not self.accelerator.is_main_process

    def _dataloader_kwargs(self) -> Dict[str, Any]:
        """Return the ``DataLoader`` keyword arguments shared by the fit and test loaders."""
        kwargs: Dict[str, Any] = {
            "batch_size": self.batch_size,
            "num_workers": self.num_workers,
            "pin_memory": self.pin_memory,
        }
        # DataLoader raises ValueError when `prefetch_factor` is supplied without
        # worker processes, so only forward it when workers are actually used.
        if self.num_workers > 0:
            kwargs["prefetch_factor"] = self.prefetch_factor
        return kwargs

    def _setup_dataloaders(self):
        """Subsample the datasets and build ``fit_dataloader`` and ``test_dataloader``.

        Raises:
            ValueError: If :meth:`_setup_datasets` left any dataset unset.
        """
        if self.fit_dataset is None or self.in_dataset is None or self.out_datasets is None or self.out_dataset is None:
            raise ValueError("Datasets are not set.")

        if self.limit_fit is None:
            self.limit_fit = 1.0
        if self.limit_run is None:
            self.limit_run = 1.0
        loader_kwargs = self._dataloader_kwargs()

        # Random fit subset. The sample count stays local so `self.limit_fit`
        # keeps meaning "fraction", as passed to the constructor.
        n_fit_total = len(self.fit_dataset)
        n_fit = min(int(self.limit_fit * n_fit_total), n_fit_total)
        subset = np.random.choice(np.arange(n_fit_total), n_fit, replace=False).tolist()
        self.fit_dataset = torch.utils.data.Subset(self.fit_dataset, subset)
        self.fit_dataloader = torch.utils.data.DataLoader(self.fit_dataset, shuffle=True, **loader_kwargs)

        self.test_dataset = torch.utils.data.ConcatDataset([self.in_dataset, self.out_dataset])
        # Label 0 = in-distribution, label i + 1 = i-th OOD dataset. This assumes
        # `out_dataset` concatenates `out_datasets` in the same order.
        test_labels = torch.utils.data.TensorDataset(
            torch.cat(
                [torch.zeros(len(self.in_dataset))]  # type: ignore
                + [torch.ones(len(d)) * (i + 1) for i, d in enumerate(self.out_datasets.values())]  # type: ignore
            ).long()
        )

        # Pair every test sample with its OOD label; each batch unpacks as
        # (x, y, label) in `run`.
        self.test_dataset = ConcatDatasetsDim1([self.test_dataset, test_labels])

        # Random subset drawn without replacement, which also shuffles the order.
        # Clamped like the fit subset so limit_run > 1 cannot make choice() raise.
        n_test_total = len(self.test_dataset)
        n_test = min(int(self.limit_run * n_test_total), n_test_total)
        subset = np.random.choice(np.arange(n_test_total), n_test, replace=False).tolist()
        self.test_dataset = torch.utils.data.Subset(self.test_dataset, subset)
        self.test_dataloader = torch.utils.data.DataLoader(self.test_dataset, shuffle=False, **loader_kwargs)

        if self.accelerator is not None:
            self.fit_dataloader = self.accelerator.prepare(self.fit_dataloader)
            self.test_dataloader = self.accelerator.prepare(self.test_dataloader)

        _logger.info(f"Using {len(self.fit_dataset)} samples for fitting.")
        _logger.info(f"Using {len(self.test_dataset)} samples for testing.")

    def setup(self):
        """Build the datasets (subclass hook), then the dataloaders."""
        self._setup_datasets()
        self._setup_dataloaders()

    def preprocess(self, method: Union[DetectorWrapper, Detector]) -> Union[DetectorWrapper, Detector]:
        """Fit ``method`` on ``fit_dataloader`` and return it.

        Calls ``method.start`` once (with the inputs of one batch as ``example``
        and the fit-set size), then ``method.update(x, y)`` per batch, then
        ``method.end()``. Returns ``method`` untouched when there is no fit
        dataset or ``method.detector`` has no ``update`` attribute.
        """
        if self.fit_dataset is None:
            _logger.warning("Fit dataset is not set or not supported. Returning.")
            return method

        if not hasattr(method.detector, "update"):
            _logger.warning("Detector does not support fitting. Returning.")
            return method

        progress_bar = tqdm(range(len(self.fit_dataloader)), desc="Fitting", disable=self._progress_disabled())

        fit_length = len(self.fit_dataloader.dataset)
        example = next(iter(self.fit_dataloader))[0]
        method.start(example=example, fit_length=fit_length)
        for x, y in self.fit_dataloader:
            method.update(x, y)
            progress_bar.update(1)
        progress_bar.close()
        method.end()
        return method

    def run(self, method: Union[DetectorWrapper, Detector]) -> Dict[str, Any]:
        """Fit ``method``, score the test set and compute metrics.

        Returns:
            ``{"results": <postprocess output>, "scores": float32 tensor,
            "labels": int64 tensor}``. The tensors are 1-D and truncated to the
            number of samples actually scored.
        """
        self.method = method

        _logger.info("Running pipeline...")
        self.method = self.preprocess(self.method)

        # Preallocate based on dataset size.
        dataset_size = len(self.test_dataloader.dataset)
        test_labels = torch.empty(dataset_size, dtype=torch.int64)
        test_scores = torch.empty(dataset_size, dtype=torch.float32)
        _logger.debug("test_labels shape: %s", test_labels.shape)
        _logger.debug("test_scores shape: %s", test_scores.shape)

        infer_times: List[float] = []
        idx = 0
        progress_bar = tqdm(range(len(self.test_dataloader)), desc="Inference", disable=self._progress_disabled())
        for x, y, labels in self.test_dataloader:
            t1 = time.perf_counter()
            score = self.method(x)
            if torch.cuda.is_initialized():
                # CUDA kernels run asynchronously; wait for them so the interval
                # measures the computation rather than just the kernel launches.
                torch.cuda.synchronize()
            t2 = time.perf_counter()

            if self.accelerator is not None:
                score = self.accelerator.gather_for_metrics(score)
                labels = self.accelerator.gather_for_metrics(labels)

            infer_times.append(t2 - t1)
            test_labels[idx : idx + labels.shape[0]] = labels.cpu()
            test_scores[idx : idx + score.shape[0]] = score.cpu()

            idx += labels.shape[0]
            progress_bar.update(1)
        progress_bar.close()

        # Mean per-batch inference time in seconds; NaN when nothing was scored
        # (avoids np.mean's empty-slice RuntimeWarning).
        self.infer_times = np.mean(infer_times) if infer_times else float("nan")
        test_scores = test_scores[:idx]
        test_labels = test_labels[:idx]
        res_obj = self.postprocess(test_scores, test_labels)

        return {"results": res_obj, "scores": test_scores, "labels": test_labels}

    def postprocess(self, test_scores: Tensor, test_labels: Tensor):
        """Compute OOD metrics for each OOD dataset, plus their average.

        Each OOD dataset (label ``i + 1``) is compared against the
        in-distribution scores (label ``0``) with ``get_ood_results``. Every
        entry gets a ``"time"`` key holding the mean inference time. The
        ``"average"`` entry is only added when there is at least one OOD dataset.
        """
        _logger.info("Computing metrics...")
        in_scores = test_scores[test_labels == 0]

        results = {}
        for i, ood_dataset_name in enumerate(self.out_datasets_names):
            ood_scores = test_scores[test_labels == (i + 1)]
            results[ood_dataset_name] = get_ood_results(in_scores, ood_scores)
            results[ood_dataset_name]["time"] = self.infer_times

        if self.out_datasets_names:
            metric_names = results[self.out_datasets_names[0]].keys()
            results["average"] = {
                k: np.mean([results[ds][k] for ds in self.out_datasets_names]) for k in metric_names
            }
            results["average"]["time"] = self.infer_times

        return results

    def report(self, results: Dict[str, Dict[str, Any]]) -> str:
        """Render results as a table with one row per dataset and 4-decimal floats.

        Accepts either the output of :meth:`run` (unwrapped via its
        ``"results"`` key) or the output of :meth:`postprocess` directly.
        """
        if "results" in results:
            results = results["results"]

        # Build all rows first and concatenate once: avoids quadratic growth and
        # pandas' deprecated concatenation with an empty DataFrame.
        frames = [pd.DataFrame(res, index=[ood_dataset]) for ood_dataset, res in results.items()]
        df = pd.concat(frames) if frames else pd.DataFrame()
        df.columns = [METRICS_NAMES_PRETTY[k] for k in df.columns]
        return df.to_string(index=True, float_format="{:.4f}".format)


@register_pipeline("ood_benchmark_imagenet")
class OODImageNetBenchmarkPipelineAll(OODBenchmarkPipeline):
    """ImageNet (``ilsvrc2012``) OOD benchmark against eight OOD datasets.

    Fits on the ``train`` split of ``ilsvrc2012`` and uses its ``val`` split as
    the in-distribution test set. Each OOD dataset is created with the split
    listed in :attr:`OUT_DATASETS_SPLITS` (``None`` for all of them) and
    ``download=True``.

    Extra keyword arguments named in :attr:`_FORWARDED_KWARGS` (e.g.
    ``num_workers`` or ``accelerator``) are passed to the base class. Any other
    extra keyword arguments are ignored with a warning.
    """

    # OOD dataset name -> split passed to `create_dataset`. The insertion order
    # defines the OOD labels (1..8) assigned by the base class.
    OUT_DATASETS_SPLITS: Dict[str, Any] = {
        "inaturalist_clean": None,
        "species_clean": None,
        "places_clean": None,
        "openimage_o_clean": None,
        "ssb_easy": None,
        "textures_clean": None,
        "ninco": None,
        "ssb_hard": None,
    }

    # Base-constructor arguments that may arrive through **kwargs. Previously
    # they were silently dropped, so e.g. an `accelerator` was never used.
    _FORWARDED_KWARGS = ("num_workers", "pin_memory", "prefetch_factor", "accelerator")

    def __init__(self, transform: Callable, limit_fit=1.0, limit_run=1.0, batch_size=64, seed=42, **kwargs) -> None:
        forwarded = {k: kwargs.pop(k) for k in self._FORWARDED_KWARGS if k in kwargs}
        if kwargs:
            _logger.warning("Ignoring unsupported pipeline arguments: %s", sorted(kwargs))
        super().__init__(
            "ilsvrc2012",
            # Copy so the instance never shares (or mutates) the class-level table.
            dict(self.OUT_DATASETS_SPLITS),
            limit_fit=limit_fit,
            limit_run=limit_run,
            transform=transform,
            batch_size=batch_size,
            seed=seed,
            **forwarded,
        )

    def _setup_datasets(self):
        """Create the ImageNet fit/test datasets and the OOD datasets."""
        _logger.info("Loading In-distribution dataset...")
        self.fit_dataset = create_dataset(self.in_dataset_name, split="train", transform=self.transform)
        self.in_dataset = create_dataset(self.in_dataset_name, split="val", transform=self.transform)

        _logger.info("Loading OOD datasets...")
        self.out_datasets = {
            ds: create_dataset(ds, split=split, transform=self.transform, download=True)
            for ds, split in self.out_datasets_names_splits.items()
        }
        # Concatenate in the same order as `out_datasets` so the base class's
        # per-dataset labels line up with the samples.
        self.out_dataset = torch.utils.data.ConcatDataset(list(self.out_datasets.values()))
