"""Solves random TSP instances with a random scaling factor."""

import contextlib
from functools import partial
from pathlib import Path

import hydra
import src.model.attention as attention
import src.model.layers as layers
import torch
from configs.template import MainConfig
from src.model import Model
from src.solve import solve
from src.utils import solution_cost
from tqdm import tqdm

# Let float32 matrix multiplications use faster reduced-precision kernels (e.g. TensorFloat32 on
# supported GPUs). PyTorch emits a warning suggesting this setting when such hardware is available.
torch.set_float32_matmul_precision("high")
# `torch.vmap` seems to trigger a lot of recompilations (presumably of `torch.compile`d code inside
# `src`, possibly also because instance sizes vary between calls). We raise the recompilation limit
# to try to capture them all.
# NOTE(review): this option name may not exist on older PyTorch releases (previously
# `cache_size_limit`) — confirm against the pinned torch version.
torch._dynamo.config.recompile_limit = 32


@torch.inference_mode(True)
def random_solve(
    model: Model,
    n_cities: int,
    scaling_factor: float,
    batch_size: int,
    chunk_size: int | None,
    device: torch.device,
) -> float:
    """Solve one batch of random TSP instances and return their mean tour cost.

    City coordinates are drawn uniformly from the unit square, then multiplied by
    ``scaling_factor`` before being handed to the model. Tour costs are computed on
    the *unscaled* coordinates, presumably so that results stay comparable across
    scaling factors.

    Args:
        model: Model to evaluate. It is switched to eval mode as a side effect.
        n_cities: Number of cities per instance.
        scaling_factor: Multiplier applied to the coordinates the model sees.
        batch_size: Number of instances solved together as one batch.
        chunk_size: When not ``None``, chunked attention with this chunk size is
            used while solving (see ``replace_attention``).
        device: Device on which the instances are generated.

    Returns:
        Mean tour cost over the batch, as a Python float.
    """
    model.eval()
    cities = torch.rand((batch_size, n_cities, 2), device=device)
    scaled_cities = cities * scaling_factor

    with replace_attention(chunk_size):
        tours = solve(scaled_cities, model)

    # Cost is evaluated on the original (unscaled) coordinates, one instance at a time via vmap.
    costs = torch.vmap(solution_cost)(cities, tours)
    return costs.float().mean().cpu().item()


@contextlib.contextmanager
def replace_attention(chunk_size: int | None):
    """Temporarily swap ``layers.dot_product_attention`` for the chunked version.

    Inside the ``with`` block, ``layers.dot_product_attention`` is bound to
    ``attention.chunked_attention`` with ``chunk_size`` as its first positional
    argument. The previous implementation is restored on exit, even if the body
    raises. When ``chunk_size`` is ``None``, nothing is patched.
    """
    if chunk_size is not None:
        saved_attention = layers.dot_product_attention
        chunked = partial(attention.chunked_attention, chunk_size)
        layers.dot_product_attention = chunked
        try:
            yield
        finally:
            # Always undo the monkeypatch so later callers see the default attention.
            layers.dot_product_attention = saved_attention
    else:
        # No chunk size requested: run the body with the default attention untouched.
        yield


def load(experiment_dir: Path, device: torch.device) -> Model:
    """Rebuild a trained model from a Hydra training-run directory.

    The run's ``config`` is composed from ``<experiment_dir>/.hydra/`` to
    instantiate the architecture. The weights are then restored from the
    ``"model-state"`` entry of ``<experiment_dir>/checkpoint.pth``.

    Args:
        experiment_dir: Directory of the training run.
        device: Device the model is moved to and the checkpoint is mapped onto.

    Returns:
        The model with its trained weights loaded.
    """
    hydra_config_dir = str(experiment_dir.absolute() / ".hydra/")
    with hydra.initialize_config_dir(config_dir=hydra_config_dir, version_base="1.3"):
        config = MainConfig.from_dict(hydra.compose("config"))

    arch = config.model
    model = Model(
        arch.hidden_dim,
        arch.ff_dim,
        arch.n_heads,
        arch.n_layers,
        use_alibi=arch.use_alibi,
        use_coords=arch.use_coords,
        use_random_ids=arch.use_random_ids,
        use_rope=arch.use_rope,
        use_ssmax=arch.use_ssmax,
    ).to(device)

    # NOTE(review): torch.load may unpickle arbitrary objects depending on the torch version's
    # `weights_only` default — only point this at trusted experiment directories.
    checkpoint_path = experiment_dir / "checkpoint.pth"
    state = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state["model-state"])
    return model


if __name__ == "__main__":
    import argparse
    import csv
    import random

    parser = argparse.ArgumentParser(description="Solves random TSPs with random scaling factors and saves the results into a CSV file.")
    parser.add_argument("-b", "--batch-size", type=int, default=128, help="Number of instances solved per sample; the CSV stores their mean cost")
    parser.add_argument("-c", "--chunk-size", type=int, default=None, help="If set, use chunked attention with this chunk size")
    parser.add_argument("-e", "--experiment-dir", required=True, type=Path, help="Path of the training run")
    parser.add_argument("-nc", "--n-cities", nargs="+", type=int, default=[50, 100, 200, 250, 500, 1000], help="Candidate instance sizes; one is drawn uniformly per sample")
    parser.add_argument("-smax", "--scaling-factor-max", type=float, default=1.6, help="Upper bound of the uniformly drawn coordinate scaling factor")
    parser.add_argument("-smin", "--scaling-factor-min", type=float, default=0.8, help="Lower bound of the uniformly drawn coordinate scaling factor")
    parser.add_argument("-t", "--total-samples", type=int, default=1000, help="Number of samples, i.e. rows appended to the CSV file")
    parser.add_argument("--seed", type=int, default=None, help="Seed for the Python and torch RNGs that generate the instances (default: unseeded)")
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = load(args.experiment_dir, device)

    # Seed only after the model is built, so the instance draws do not depend on how much RNG state
    # the model's initialization consumed. `random` draws the instance size and scaling factor;
    # torch draws the city coordinates.
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)

    csv_path = args.experiment_dir / "random-solve.csv"
    for _ in tqdm(range(args.total_samples), "Solving random TSPs"):
        n_cities = random.choice(args.n_cities)
        scaling_factor = random.uniform(args.scaling_factor_min, args.scaling_factor_max)
        cost = random_solve(model, n_cities, scaling_factor, args.batch_size, args.chunk_size, device)

        # Append and close after every row so results collected so far are kept on disk even if the
        # run is interrupted. Repeated runs keep appending to the same file.
        with open(csv_path, "a", newline="") as csvfile:
            csv.writer(csvfile).writerow([n_cities, scaling_factor, cost])
