import contextlib
import csv
import time
from functools import partial
from pathlib import Path

import hydra
import src.model.attention as attention
import src.model.layers as layers
import torch
from configs.template import MainConfig
from src.dataset import TSPDataset, TSPLIBDataset
from src.model import Model
from src.solve import solve
from src.utils import solution_cost
from torch import Tensor
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

# Use the "high" float32 matmul precision setting, as suggested by a PyTorch warning.
torch.set_float32_matmul_precision("high")
# `torch.vmap` seems to trigger a lot of recompilations. Raise the recompilation limit to try
# to capture them all.
torch._dynamo.config.recompile_limit = 32


@contextlib.contextmanager
def replace_attention(chunk_size: int | None):
    """Temporarily replace dot_product_attention by the chunked version."""
    if chunk_size is None:
        yield
        return

    original = layers.dot_product_attention
    try:
        layers.dot_product_attention = partial(attention.chunked_attention, chunk_size)
        yield
    finally:
        layers.dot_product_attention = original


def best_factor(n_cities: int, exp_dir: Path) -> float:
    """Look up the scaling factor for `n_cities` in the experiment's 'best-factors.csv'.

    Each row of the file is a pair `n,s` (size, factor). The factor is linearly interpolated
    between the two rows whose sizes bracket `n_cities`. Sizes outside the tabulated range get the
    factor of the nearest row.

    ---
    Args:
        n_cities: Instance size to find a factor for.
        exp_dir: Directory containing 'best-factors.csv'.

    ---
    Returns:
        The (possibly interpolated) scaling factor.

    ---
    Raises:
        ValueError: If the file contains no rows, or a row is not a valid `n,s` pair.
    """
    csv_path = exp_dir / "best-factors.csv"
    with open(csv_path, newline="") as csvfile:
        # `csv.reader` yields an empty list for blank lines; skip them instead of crashing.
        rows = [row for row in csv.reader(csvfile) if row]

    if not rows:
        raise ValueError(f"No scaling factors found in {csv_path}")

    factors = sorted(((int(n), float(s)) for n, s in rows), key=lambda e: e[0])

    # Smallest tabulated size >= n_cities; falls back to the largest entry.
    max_n, max_s = next(((n, s) for n, s in factors if n >= n_cities), factors[-1])
    # Largest tabulated size <= n_cities; falls back to the smallest entry.
    min_n, min_s = next(((n, s) for n, s in reversed(factors) if n <= n_cities), factors[0])

    if max_n == min_n:
        # Exact match, or n_cities lies outside the table: nothing to interpolate.
        return max_s

    return (max_n - n_cities) * min_s / (max_n - min_n) + (n_cities - min_n) * max_s / (max_n - min_n)


@torch.inference_mode(True)
def evaluate(
    model: Model,
    dataloader: DataLoader,
    n_estimates: int,
    apply_scaling_factor: bool,
    device: torch.device,
    exp_dir: Path,
    chunk_size: int | None = None,
) -> dict[str, float]:
    """Evaluate the model on the given dataset.

    ---
    Args:
        model: Model to evaluate.
        dataloader: Dataloader used to fetch the tuples (cities, optimal_cost).
        n_estimates: Number of times the dataset is being solved, to get a better estimate of the gap.
        apply_scaling_factor: If true, the cities' coordinates are multiplied by the factor returned
            by `best_factor` before being fed to the model. Otherwise a factor of 1.0 is used.
        device: Device the batches are moved to.
        exp_dir: Experiment directory, forwarded to `best_factor` to locate the scaling factors.
        chunk_size: Split attention computation in chunks to save memory.

    ---
    Returns:
        A dict with two entries:
            "average-gap": mean of `(model_cost - opt_cost) / opt_cost` over every instance of every
                pass.
            "total-time": summed solve time of all batches, divided by `n_estimates`.
    """
    gaps, times = [], []
    # `best_factor` re-reads its CSV on every call, so remember the factor per instance size.
    factor_cache: dict[int, float] = {}
    model.eval()

    for _ in tqdm(range(n_estimates), f"Evaluating {dataloader.dataset.name}", leave=False):
        # A nested progress bar is only useful when there is more than one batch.
        batches = tqdm(dataloader, leave=False) if len(dataloader) != 1 else dataloader
        for x, opt_cost in batches:
            x, opt_cost = x.to(device), opt_cost.to(device)

            # Randomly permute the order of the cities within the batch before solving.
            _, n_cities, _ = x.shape
            x = x[:, torch.randperm(n_cities)]

            if apply_scaling_factor:
                if n_cities not in factor_cache:
                    factor_cache[n_cities] = best_factor(n_cities, exp_dir)
                scaling_factor = factor_cache[n_cities]
            else:
                scaling_factor = 1.0

            # CUDA kernels run asynchronously: synchronize around the timed region so that the
            # measurement covers the solve itself and nothing that was queued before it.
            if x.is_cuda:
                torch.cuda.synchronize(x.device)
            start = time.perf_counter()
            with replace_attention(chunk_size):
                s = solve(x * scaling_factor, model)
            if x.is_cuda:
                torch.cuda.synchronize(x.device)
            end = time.perf_counter()

            # The cost is computed on the unscaled coordinates.
            model_cost = torch.vmap(solution_cost)(x, s)
            gaps.append((model_cost - opt_cost) / opt_cost)
            times.append(end - start)

    return {
        "average-gap": torch.concat(gaps, dim=0).float().mean().cpu().item(),
        "total-time": sum(times) / n_estimates,
    }


def load_dataset(path: Path, batch_size: int) -> DataLoader:
    """Build a shuffled dataloader over the `TSPDataset` stored in the npz file at `path`.

    Every batch is collated into `(x, v)`, where `x` stacks the first element of each sample and
    `v` is `solution_cost` applied per sample to `x` and the stacked second elements
    (presumably the reference solutions — confirm against `TSPDataset`).
    """

    def collate_fn(batch: list[tuple[Tensor, Tensor]]) -> tuple[Tensor, Tensor]:
        cities = torch.stack([instance for instance, _ in batch])
        solutions = torch.stack([solution for _, solution in batch])
        return cities, torch.vmap(solution_cost)(cities, solutions)

    return DataLoader(
        TSPDataset.from_npz(path),
        batch_size,
        shuffle=True,
        drop_last=False,
        collate_fn=collate_fn,
    )


def load_tsplib(path: Path, min_cities: int | None, max_cities: int | None) -> DataLoader:
    """Build a shuffled dataloader over the `TSPLIBDataset` found in the directory `path`.

    `min_cities` and `max_cities` are forwarded to `TSPLIBDataset.from_dir`. Batches hold a single
    instance each.
    """
    return DataLoader(
        TSPLIBDataset.from_dir(path, min_cities, max_cities),
        batch_size=1,
        shuffle=True,
        drop_last=False,
    )


def load_fake() -> DataLoader:
    """Build a shuffled dataloader over 5 random instances of 50 two-dimensional points.

    Each sample is a pair `(x, v)` drawn from a standard normal distribution, with `x` of shape
    (50, 2) and `v` a scalar. Batches hold a single sample each.
    """
    n_instances, n_cities = 5, 50

    class FakeDataset(Dataset):
        def __init__(self):
            super().__init__()
            self.name = "fake dataset"
            # `x` is drawn before `v`, which fixes the order in which the RNG is consumed.
            self.x = torch.randn(n_instances, n_cities, 2, dtype=torch.float32)
            self.v = torch.randn(n_instances, dtype=torch.float32)

        def __len__(self) -> int:
            return self.x.shape[0]

        def __getitem__(self, i: int) -> tuple[Tensor, Tensor]:
            return self.x[i], self.v[i]

    return DataLoader(FakeDataset(), batch_size=1, shuffle=True, drop_last=False)


def load(experiment_dir: Path, device: torch.device) -> Model:
    """Rebuild the model of a training run and restore its weights.

    ---
    Args:
        experiment_dir: Directory of the run. Its '.hydra/' sub-directory provides the "config"
            used to build the model, and 'checkpoint.pth' provides the weights.
        device: Device the model is moved to and the checkpoint is mapped to.

    ---
    Returns:
        The model with the checkpoint's "model-state" loaded.
    """
    hydra_dir = experiment_dir.absolute() / ".hydra/"
    with hydra.initialize_config_dir(config_dir=str(hydra_dir), version_base="1.3"):
        config = MainConfig.from_dict(hydra.compose("config"))

    hparams = config.model
    model = Model(
        hparams.hidden_dim,
        hparams.ff_dim,
        hparams.n_heads,
        hparams.n_layers,
        use_alibi=hparams.use_alibi,
        use_coords=hparams.use_coords,
        use_random_ids=hparams.use_random_ids,
        use_rope=hparams.use_rope,
        use_ssmax=hparams.use_ssmax,
    )
    model = model.to(device)

    checkpoint_path = experiment_dir / "checkpoint.pth"
    model_state = torch.load(checkpoint_path, map_location=device)["model-state"]
    model.load_state_dict(model_state)
    return model


# Command-line entry point: load a trained model, evaluate it on each given dataset, and append
# one row per dataset to 'evaluations.csv' inside the experiment directory.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Load a trained model and evaluate it on multiple datasets.")
    parser.add_argument("-e", "--experiment-dir", required=True, type=Path, help="Path of the training run")
    parser.add_argument("-n", "--n-estimates", type=int, default=1, help="Number of times the dataset is being evaluated")
    parser.add_argument("-s", "--scaling-factor", action="store_true", help="Find and apply the best scaling factor by interpolating values contained in 'best-factors.csv'")
    parser.add_argument("-b", "--batch-size", type=int, default=128, help="Batch size")
    parser.add_argument("-c", "--chunk-size", type=int, default=None, help="Use chunked attention")
    parser.add_argument("-tmin", "--tsplib-min", type=int, default=None, help="Minimum TSPLIB size")
    parser.add_argument("-tmax", "--tsplib-max", type=int, default=None, help="Maximum TSPLIB size")
    parser.add_argument("datasets", type=Path, nargs="+", help="Datasets stored as an npz file")
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # `torch.cuda.get_device_name` only accepts CUDA devices. Querying it on a CPU-only machine
    # would crash the script after the evaluation, before the results row is written.
    device_name = torch.cuda.get_device_name(device) if device.type == "cuda" else "cpu"
    model = load(args.experiment_dir, device)

    print(f"Evaluating model {args.experiment_dir.name}")
    for path in args.datasets:
        # A directory is treated as a TSPLIB collection, anything else as an npz dataset.
        dataloader = (
            load_tsplib(path, args.tsplib_min, args.tsplib_max)
            if path.is_dir()
            else load_dataset(path, args.batch_size)
        )
        metrics = evaluate(
            model,
            dataloader,
            n_estimates=args.n_estimates,
            apply_scaling_factor=args.scaling_factor,
            exp_dir=args.experiment_dir,
            device=device,
            chunk_size=args.chunk_size,
        )

        # Append mode: results accumulate across runs of this script.
        with open(args.experiment_dir / "evaluations.csv", "a", newline="") as csvfile:
            csv.writer(csvfile).writerow(
                [
                    dataloader.dataset.name,
                    metrics["average-gap"],
                    metrics["total-time"],
                    args.scaling_factor,
                    args.batch_size,
                    args.chunk_size,
                    args.n_estimates,
                    device_name,
                ]
            )

        print(f"{dataloader.dataset.name}: {metrics['average-gap']:.3%} (s={args.scaling_factor})")
