import csv
import math
from functools import partial
from pathlib import Path

import hydra
import matplotlib.pyplot as plt
import numpy as np
import src.model.layers as layers
import torch
from configs.template import MainConfig
from src.model import Model
from torch import Tensor

# Let float32 matrix multiplications use faster reduced-precision internal
# kernels (e.g. TF32) where the hardware supports them. This was enabled in
# response to a torch warning recommending it.
torch.set_float32_matmul_precision("high")


@torch.inference_mode()
def do_one_step(model: Model, instance: Tensor, origin: Tensor, destination: Tensor) -> Tensor:
    """Run one forward pass and capture the origin token's attention weights.

    ``layers.dot_product_attention`` is temporarily replaced by this module's
    ``dot_product_attention``, bound to record batch element 0 and query
    position ``origin`` on every attention call. The original function is
    always restored, even if the forward pass raises.

    NOTE(review): this only works if the project's layers look up
    ``layers.dot_product_attention`` as a module attribute at call time
    (not via ``from ... import dot_product_attention``) — confirm.

    Args:
        model: The model to run. It is called with a batch of size 1.
        instance: City tensor whose first dimension is the number of cities
            (2-D, as enforced by the unpacking below).
        origin: Scalar index tensor; used both as the query position whose
            attention is recorded and as the model's origin input.
        destination: Scalar tensor forwarded to the model.

    Returns:
        The recorded attention rows stacked as ``[n_calls, n_heads, n_keys]``,
        one entry per attention call (presumably one per layer).

    Raises:
        RuntimeError: if the patched attention function was never invoked.
    """
    n_cities, _ = instance.shape
    # Every city is a valid key.
    mask = torch.ones((n_cities,), dtype=torch.bool, device=instance.device)
    activations: list[Tensor] = []

    # Fetch the original before entering `try`, so that `finally` can never
    # reference an unbound name and mask the real error.
    original_attn = layers.dot_product_attention
    layers.dot_product_attention = partial(dot_product_attention, activations, 0, origin.item())
    try:
        model(instance[None], mask[None], origin[None], destination[None])
    finally:
        layers.dot_product_attention = original_attn

    if not activations:
        raise RuntimeError(
            "No attention weights were recorded: layers.dot_product_attention was never called "
            "during the forward pass, so the monkeypatch had no effect."
        )

    return torch.stack(activations)  # [n_calls (≈ n_layers), n_heads, n_cities]


def dot_product_attention(
    activations: list,
    batch_id: int,
    token_id: int,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    mask: Tensor,
    ssmax: Tensor | None,
    bias: Tensor | None,
) -> Tensor:
    _, _, _, embed_dim = query.shape
    qk = torch.einsum("bhle,bhse->bhls", query, key) / math.sqrt(embed_dim)

    if bias is not None:
        qk = qk + bias

    if ssmax is not None:
        qk = torch.einsum("bhls,bl->bhls", qk, ssmax)

    qk = torch.where(mask[:, None], qk, -torch.inf)
    a = torch.softmax(qk, dim=-1)  # [batch_size, n_heads, n_queries, n_keys]
    y = torch.einsum("bhls,bhse->bhle", a, value)
    activations.append(a[batch_id, :, token_id].cpu())
    return y

def best_factor(n_cities: int, exp_dir: Path) -> float:
    """Return the scaling factor for ``n_cities``, interpolated from a lookup table.

    The table is ``<exp_dir>/best-factors.csv`` with header-less rows
    ``n,factor``. The result is linearly interpolated between the nearest
    tabulated sizes at or below and at or above ``n_cities``. Outside the
    tabulated range, the factor of the nearest endpoint is returned (no
    extrapolation). Blank lines in the file are ignored.

    Args:
        n_cities: Problem size to look up.
        exp_dir: Experiment directory containing ``best-factors.csv``.

    Returns:
        The (interpolated) scaling factor.

    Raises:
        ValueError: if the file has no data rows, or a row is not ``int,float``.
    """
    csv_path = exp_dir / "best-factors.csv"
    with open(csv_path, newline="") as csvfile:
        # csv.reader yields blank lines as [] (or whitespace-only cells), which
        # would otherwise break the two-value unpacking below.
        rows = [row for row in csv.reader(csvfile) if any(cell.strip() for cell in row)]

    if not rows:
        raise ValueError(f"No scaling factors found in {csv_path}")

    factors = sorted(((int(n), float(s)) for n, s in rows), key=lambda e: e[0])

    # Smallest tabulated size >= n_cities (largest size if n_cities exceeds all).
    max_n, max_s = next(((n, s) for n, s in factors if n >= n_cities), factors[-1])
    # Largest tabulated size <= n_cities (smallest size if n_cities is below all).
    min_n, min_s = next(((n, s) for n, s in reversed(factors) if n <= n_cities), factors[0])

    if max_n == min_n:
        return max_s

    return (max_n - n_cities) * min_s / (max_n - min_n) + (n_cities - min_n) * max_s / (max_n - min_n)

def load_model(experiment_dir: Path, device: torch.device) -> Model:
    """Rebuild a trained model from an experiment directory, ready for inference.

    Composes the Hydra config named ``config`` from ``<experiment_dir>/.hydra/``,
    builds a ``Model`` from its ``model`` section, loads the weights stored
    under ``"model-state"`` in ``<experiment_dir>/checkpoint.pth``, and switches
    the model to evaluation mode.

    Args:
        experiment_dir: Directory produced by a training run.
        device: Device the model and checkpoint tensors are placed on.

    Returns:
        The model on ``device``, in eval mode.
    """
    with hydra.initialize_config_dir(
        config_dir=str(experiment_dir.absolute() / ".hydra/"),
        version_base="1.3",
    ):
        dict_config = hydra.compose("config")
        config = MainConfig.from_dict(dict_config)

    model = Model(
        config.model.hidden_dim,
        config.model.ff_dim,
        config.model.n_heads,
        config.model.n_layers,
        use_alibi=config.model.use_alibi,
        use_coords=config.model.use_coords,
        use_random_ids=config.model.use_random_ids,
        use_rope=config.model.use_rope,
        use_ssmax=config.model.use_ssmax,
    ).to(device)

    checkpoint = torch.load(experiment_dir / "checkpoint.pth", map_location=device)
    model.load_state_dict(checkpoint["model-state"])
    # The model is only used for inference: put train-time-only behaviour
    # (e.g. dropout, if any) into evaluation mode. `torch.inference_mode`
    # disables autograd but does not do this.
    model.eval()
    return model


def load_instance(filepath: Path, seed: int, device: torch.device) -> tuple[Tensor, Tensor, Tensor]:
    """Pick a random instance and random origin/destination from an ``.npz`` file.

    Reads the ``coords`` array (indexed ``[instance, city, coord]``) and drops
    its last entry along the city axis. It then draws, in this order, an
    instance index, an origin and a destination from a NumPy generator seeded
    with ``seed``, so results are reproducible. Origin and destination are
    drawn independently and may coincide.

    Args:
        filepath: Path to an ``.npz`` archive containing ``coords``.
        seed: Seed for ``np.random.default_rng``.
        device: Device for the returned tensors.

    Returns:
        ``(instance, origin, destination)``:
            - ``instance`` is float32 ``[n_cities, n_coords]``.
            - ``origin`` and ``destination`` are 0-d long tensors.
    """
    rng = np.random.default_rng(seed)

    # np.load on an .npz returns an NpzFile that holds the file open; close it
    # deterministically. Indexing reads the array into memory, so the slice
    # stays valid after the file is closed.
    with np.load(filepath) as data:
        # NOTE(review): the last city is dropped — presumably a duplicate that
        # closes the tour; confirm against the data generator.
        cities = data["coords"][:, :-1]

    n_instances, n_cities, _ = cities.shape
    instance_id = rng.integers(n_instances)
    origin_id = rng.integers(n_cities)
    destination_id = rng.integers(n_cities)

    instance = torch.tensor(cities[instance_id], dtype=torch.float32, device=device)
    origin = torch.tensor(origin_id, dtype=torch.long, device=device)
    destination = torch.tensor(destination_id, dtype=torch.long, device=device)
    return instance, origin, destination


def _log_attention(model: Model, instance: Tensor, origin: Tensor, destination: Tensor) -> np.ndarray:
    """Return, per city, the log of the origin token's attention averaged over calls and heads."""
    activations = do_one_step(model, instance, origin, destination)  # [n_calls, n_heads, n_cities]
    return activations.mean((0, 1)).log().cpu().numpy()


def _plot_attention_panel(
    ax,
    coords: np.ndarray,
    origin_idx: int,
    destination_idx: int,
    log_attn: np.ndarray,
    vmin: float | None,
    vmax: float | None,
) -> None:
    """Draw cities coloured by ``log_attn`` plus origin/destination markers on ``ax``.

    ``vmin``/``vmax`` fall back to this panel's own min/max when ``None``.
    """
    vmin = log_attn.min() if vmin is None else vmin
    vmax = log_attn.max() if vmax is None else vmax

    ax.scatter(coords[:, 0], coords[:, 1], c=log_attn, s=1, vmin=vmin, vmax=vmax)
    ax.scatter(
        coords[origin_idx, 0], coords[origin_idx, 1],
        c="red", marker="o", edgecolors="black", linewidths=0.5, label="Origin",
    )
    ax.scatter(
        coords[destination_idx, 0], coords[destination_idx, 1],
        c="red", marker="*", s=80, edgecolors="black", linewidths=0.5, label="Destination",
    )
    ax.set_xticks([])
    ax.set_yticks([])
    # Axes fixed to [0, 1]; presumably coordinates lie in the unit square.
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("-d", "--data", required=True, type=Path, help="Instance to evaluate")
    parser.add_argument("-fmax", "--factor-max", type=float, default=3.0, help="Scaling factor maximum")
    parser.add_argument("-fmin", "--factor-min", type=float, default=1.0, help="Scaling factor minimum")
    parser.add_argument("-n", "--n-scaling-points", type=int, default=5, help="Number of points uniformly spaced between fmin and fmax")
    parser.add_argument("-o", "--output-file", required=True, type=Path, help="Path where the output will be written")
    parser.add_argument("-r", "--seed", type=int, default=1337)
    parser.add_argument("-t", "--title", type=str, default=None, help="Figure title")
    parser.add_argument("-vmax", "--vmax", type=float, default=None, help="Maximum value for the color map")
    parser.add_argument("-vmin", "--vmin", type=float, default=None, help="Minimum value for the color map")
    parser.add_argument("experiment_dirs", type=Path, nargs="+", help="Experiments to plot")
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    plt.style.use("./analysis/paper.mplstyle")
    n_rows = len(args.experiment_dirs)
    n_cols = args.n_scaling_points + 2  # +spacer, +best-scaling
    width_ratios = [1.0] * args.n_scaling_points + [0.1, 1.0]
    # squeeze=False keeps `axes` 2-D even for a single experiment (one row), so
    # iterating rows and indexing `axes[0, ...]` below always work.
    fig, axes = plt.subplots(
        n_rows,
        n_cols,
        figsize=(n_cols, n_rows),
        gridspec_kw={"width_ratios": width_ratios},
        squeeze=False,
    )

    factors = np.linspace(args.factor_min, args.factor_max, args.n_scaling_points)
    instance, origin, destination = load_instance(args.data, args.seed, device)
    n_cities = instance.shape[0]
    coords = instance.cpu().numpy()
    origin_idx, destination_idx = origin.item(), destination.item()

    for exp_dir, exp_axes in zip(args.experiment_dirs, axes):
        model = load_model(exp_dir, device)

        # One panel per uniformly spaced scaling factor.
        for scaling_factor, ax in zip(factors, exp_axes[: len(factors)]):
            log_attn = _log_attention(model, scaling_factor * instance, origin, destination)
            _plot_attention_panel(ax, coords, origin_idx, destination_idx, log_attn, args.vmin, args.vmax)

        # Spacer column, then the panel for this experiment's best scaling factor.
        exp_axes[-2].axis("off")
        bf = best_factor(n_cities, exp_dir)
        log_attn = _log_attention(model, bf * instance, origin, destination)
        _plot_attention_panel(exp_axes[-1], coords, origin_idx, destination_idx, log_attn, args.vmin, args.vmax)
        exp_axes[-1].set_ylabel(f"{bf:.2f}")

        exp_axes[0].set_ylabel(exp_dir.name)

    for scaling_factor, ax in zip(factors, axes[0, : len(factors)]):
        ax.set_title(f"{scaling_factor:.2f}")
    axes[0, -1].set_title("Best")

    if args.title is not None:
        fig.suptitle(args.title)
    fig.savefig(str(args.output_file))
