import os
import json
import hashlib
from pathlib import Path
from itertools import chain
from typing import Literal
from collections.abc import Callable, Mapping, Sequence

import torch
from torch.utils.data import DataLoader

import numpy as np
from tabulate import tabulate

from .clip import load_clip
from .datasets import build_dataset, make_condition, make_subset


# Public API of this module: the names exported by a star-import.
__all__ = [
    "get_dataset",
    "get_cached_features",
    "build_table",
    "save_table",
]


def get_dataset(
    dataset_name: str,
    condition_config: dict,
    preprocess: Callable,
    *,
    splits: Sequence[Literal["train", "train_all", "valid", "test"]] = ("train", "train_all", "valid", "test"),
    verbose: bool = True,
    print_fn: Callable = print,
):
    """Build the requested dataset splits together with their attribute tensors.

    The ``"train"`` split is restricted to the samples accepted by the
    condition created from ``condition_config``. ``"train_all"`` is the same
    underlying ``"train"`` data without that restriction. Every other split
    is used as built.

    Args:
        dataset_name: Name forwarded to ``build_dataset``.
        condition_config: Configuration forwarded to ``make_condition``.
        preprocess: Transform forwarded to ``build_dataset`` as ``transform``.
        splits: Split names to build, in order.
        verbose: When true, report the size of each split through ``print_fn``.
        print_fn: Callable used for the size report.

    Returns:
        A dict mapping each requested split name to ``(subset, attrs)``, where
        ``attrs`` is the stack of the attribute entries kept for that split.
    """
    loaded = {}

    for split in splits:
        # "train_all" is backed by the regular "train" data.
        source_split = "train" if split == "train_all" else split
        full_set = build_dataset(dataset_name, source_split, transform=preprocess)

        # NOTE: attributes are read straight from ``full_set.attr`` rather than
        # by indexing the dataset, which avoids loading any image.
        if split != "train":
            selected = full_set
            kept_attrs = list(full_set.attr)
        else:
            keep = make_condition(full_set.attr_names, condition_config)
            selected = make_subset(full_set, keep)
            kept_attrs = [row for row in full_set.attr if keep(row)]

        attrs = torch.stack(kept_attrs)

        if verbose:
            percent = len(selected) / len(full_set) * 100
            print_fn(f"{split} set size: {len(selected)} ({percent:.2f}%)")

        loaded[split] = (selected, attrs)

    return loaded


def get_cached_features(
    model_name: str,
    dataset_name: str,
    condition_config: dict,
    *,
    splits: Sequence[Literal["train", "train_all", "valid", "test"]] = ("train", "train_all", "valid", "test"),
    device: str | torch.device = "cuda" if torch.cuda.is_available() else "cpu",
    verbose: bool = True,
    print_fn: Callable = print,
    cache_dir: str = "./.cache",
    flush: bool = False,
):
    """Load a CLIP model and per-split image features, caching features on disk.

    Features are cached under ``cache_dir/<md5>/<split>.pt``, where the hash
    is derived from ``model_name``, ``dataset_name`` and ``condition_config``.
    A split is recomputed when its cache file is missing or ``flush`` is true.

    The ``"train"`` split is restricted by the condition created from
    ``condition_config``; ``"train_all"`` is the unrestricted ``"train"`` data.

    Args:
        model_name: Name forwarded to ``load_clip``.
        dataset_name: Name forwarded to ``build_dataset``.
        condition_config: Configuration forwarded to ``make_condition``; must
            be JSON-serialisable because it is part of the cache key.
        splits: Split names to load, in order.
        device: Device the model is loaded on and the returned tensors are
            moved to.
        verbose: When true, report progress through ``print_fn`` and show a
            tqdm progress bar while computing features.
        print_fn: Callable used for progress messages.
        cache_dir: Root directory of the feature cache.
        flush: When true, ignore existing cache files and recompute.

    Returns:
        ``(model, data)`` where ``data[split]`` is ``(features, attrs)``; both
        tensors are on ``device`` and ``features`` is cast to float32.
    """
    key = json.dumps(
        {
            "model_name": model_name,
            "dataset_name": dataset_name,
            "condition_config": condition_config,
        },
        sort_keys=True,
        ensure_ascii=True,
    )
    # md5 is only used to derive a stable directory name, not for security.
    hashkey = hashlib.md5(key.encode("utf-8")).hexdigest()
    cache_path = Path(cache_dir) / hashkey

    model, preprocess = load_clip(model_name, device=device)
    data = {}

    for split in splits:
        split_path = cache_path / f"{split}.pt"

        if split_path.exists() and not flush:
            if verbose:
                print_fn(f"Loading {split} data from cache '{hashkey}'...")

            # NOTE: weights_only=False unpickles arbitrary objects. This is
            # only acceptable because the file is written by this function;
            # never point cache_dir at untrusted data.
            features, attrs = torch.load(split_path, weights_only=False)
            data[split] = (features.to(device).float(), attrs.to(device))
            continue

        # "train_all" is backed by the regular "train" data.
        _split = split if split != "train_all" else "train"
        dataset = build_dataset(dataset_name, _split, transform=preprocess)

        if split == "train":
            condition = make_condition(dataset.attr_names, condition_config)
            subset = make_subset(dataset, condition)
        else:
            subset = dataset

        dataloader = DataLoader(subset, batch_size=64, num_workers=4, shuffle=False)

        if verbose:
            print_fn(f"{split} set size: {len(subset)} ({len(subset) / len(dataset) * 100:.2f}%)")

            from tqdm.auto import tqdm
            load_iter = tqdm(dataloader, desc=f"Computing {split} features", ncols=80, leave=False)
        else:
            load_iter = dataloader

        # Collect features and attributes in a single pass: iterating the
        # loader a second time would reload and preprocess every image again.
        feature_batches = []
        attr_batches = []
        with torch.no_grad():
            # NOTE(review): batches are handed to encode_image exactly as the
            # DataLoader yields them; presumably the model returned by
            # load_clip takes care of device placement — confirm.
            for images, batch_attrs in load_iter:
                feature_batches.append(model.encode_image(images))
                attr_batches.append(batch_attrs)
            features = torch.cat(feature_batches)
            attrs = torch.cat(attr_batches)

        split_path.parent.mkdir(parents=True, exist_ok=True)
        torch.save((features.cpu(), attrs.cpu()), split_path)
        # Same placement and dtype as the cache-hit branch above.
        data[split] = (features.to(device).float(), attrs.to(device))

        if verbose:
            print_fn(f"Saved {split} data to cache '{hashkey}'")

    return model, data


def build_table(
    metrics: Mapping[str, Mapping[str, Mapping[str, float] | Sequence[Mapping[str, float]]]],
    group_headers: Sequence[str] | None = None,
    label_headers: Sequence[str] | None = None,
    types: Sequence[str] = ("auroc", "auprc", "accuracy", "f1", "fpr95"),
    floatfmt: str = ".3f",
):
    formatted = {k: {kk: {} for kk in v} for k, v in metrics.items()}
    max_num_group_cols = 1
    group_names = list(list(metrics.values())[0].keys())

    if label_headers is None:
        label_headers = list(metrics.keys())
    elif len(label_headers) != len(metrics):
        raise ValueError(f"Expected {len(metrics)} label headers, got {len(label_headers)}")

    for k, v in metrics.items():
        for kk, vv in v.items():
            if isinstance(vv, dict):
                for t in types:
                    if t not in vv:
                        raise ValueError(f"Missing metric '{t}' for '{k}' / '{kk}'")
                    formatted[k][kk][t] = f"{vv[t]:{floatfmt}}"
            else:
                for t in types:
                    vs = []
                    for vvv in vv:
                        if t not in vvv:
                            raise ValueError(f"Missing metric '{t}' for '{k}' / '{kk}'")
                        vs.append(vvv[t])
                    formatted[k][kk][t] = f"{np.mean(vs):{floatfmt}} ± {np.std(vs):{floatfmt}}"

            num_group_cols = len(kk.split("/"))
            if max_num_group_cols < num_group_cols:
                max_num_group_cols = num_group_cols

    if group_headers is None:
        group_headers = [""] * max_num_group_cols
    elif len(group_headers) != max_num_group_cols:
        raise ValueError(f"Expected {max_num_group_cols} group headers, got {len(group_headers)}")

    types_headers = {
        "auroc": "AUROC",
        "auprc": "AUPRC",
        "accuracy": "Acc.",
        "f1": "F1",
        "fpr95": "FPR95",
        "similarity": "Sim.",
    }

    table_label_headers = (
        [""] * max_num_group_cols +
        list(chain(*[[l.capitalize()] + [""] * (len(types)-1) for l in label_headers]))
    )
    table_metric_headers = list(group_headers) + [types_headers[t] for t in types] * len(label_headers)
    table_content = [
        [f"{v0}\n{v1}" for v0, v1 in zip(table_label_headers, table_metric_headers)]
    ]

    for group_name in group_names:
        cur_row = [v for v in group_name.split("/")]
        cur_row += [""] * (max_num_group_cols - len(cur_row))

        for k, v in formatted.items():
            if group_name in v:
                cur_row.extend(v[group_name].values())
            else:
                cur_row.extend([""] * len(types))

        table_content.append(cur_row)

    table = tabulate(
        table_content,
        headers="firstrow",
        colalign=("left",) * len(table_content[0]),
        disable_numparse=True,
    )
    return table


def save_table(table: str, path: str):
    """Write ``table`` to ``path`` as UTF-8, creating parent directories.

    Args:
        table: Text to write.
        path: Destination file; an existing file is overwritten.
    """
    parent = os.path.dirname(path)
    # A bare file name has no directory part, and os.makedirs("") raises
    # FileNotFoundError, so only create directories when there is one.
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        f.write(table)
