
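"""Evaluation script for the wind prediction task.

Compute every metric for each wind model checkpoint found under the given path
and write the results to ``res.csv`` in the current directory.

The training and validation datasets are loaded from the paths in the
``TRAIN_PATH`` and ``VAL_PATH`` environment variables (a ``.env`` file is also
honoured).

How to run:
    python eval.py --path outputs
"""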
import argparse
import csv
import os
from pathlib import Path

import torch
import yaml
from dotenv import load_dotenv

from exttfs.datasets.wind_dataset import Scale, ScaleDataset, get_mean_std
from exttfs.models.gen import GEN, GraphStructure, kmeans_from_dataset, neighbors_edges
from exttfs.models.gka import GKA
from exttfs.models.msa import MSA
from exttfs.models.nps import CNP
from exttfs.models.transformer import TFS
from exttfs.wind.metrics import metrics

# Reproducibility: with use_deterministic_algorithms(True), cuBLAS operations
# (CUDA >= 10.2) refuse to run unless a fixed workspace config is provided.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
torch.use_deterministic_algorithms(True)

# Model registries, keyed by the ``model.name`` entry of a run's Hydra config.
# Classes in ``default_models`` are built from the config params alone; classes
# in ``gen_models`` additionally receive a GraphStructure as first argument.
default_models = dict(tfs=TFS, np=CNP, gka=GKA, msa=MSA)
gen_models = dict(gen=GEN)

# CLI: --path is the root directory searched for wind checkpoints.
parser = argparse.ArgumentParser()
parser.add_argument("--path", default="outputs", type=str)
args = parser.parse_args()


def _select_checkpoints(root):
    """Return one wind checkpoint per run directory under ``root``.

    Checkpoints are ordered by the number after the last ``-`` in their file
    stem (presumably a validation score — confirm with the training code), and
    for each parent directory only the one with the smallest number is kept.
    The result is ordered by that number, ascending.
    """
    ranked = sorted(
        root.glob("**/wind/**/*.pt"),
        key=lambda p: float(p.stem.split("-")[-1]),
    )
    best_per_run = {}
    for p in ranked:
        # First hit per directory is the smallest number thanks to the sort;
        # dict insertion order keeps the ascending ordering of the survivors.
        best_per_run.setdefault(p.parent, p)
    return list(best_per_run.values())


def _load_config(pt):
    """Load the Hydra config (``.hydra/config.yaml``) stored next to checkpoint ``pt``."""
    with (pt.parent / ".hydra/config.yaml").open() as f:
        return yaml.safe_load(f)


def _require_env(name):
    """Return the value of environment variable ``name``; raise RuntimeError if unset or empty."""
    value = os.getenv(name)
    if not value:
        raise RuntimeError(
            f"Environment variable {name} is not set (export it or add it to .env)"
        )
    return value


def _build_model(cfg, train_dataset):
    """Instantiate the untrained model described by ``cfg["model"]``.

    Raises:
        ValueError: if ``cfg["model"]["name"]`` is not a registered model.
    """
    model_name = cfg["model"]["name"]
    if model_name in default_models:
        return default_models[model_name](**cfg["model"]["params"])
    if model_name in gen_models:
        # GEN takes a graph: node positions from kmeans_from_dataset (k=1000)
        # plus whatever neighbors_edges(pos, 3) returns, unpacked after pos.
        pos = kmeans_from_dataset(k=1000, dataset=train_dataset)
        gs = GraphStructure(pos, *neighbors_edges(pos, 3), fixed=False)
        return gen_models[model_name](gs, **cfg["model"]["params"])
    raise ValueError(
        f"Unknown model name {model_name!r}; expected one of "
        f"{sorted({**default_models, **gen_models})}"
    )


def _predict(model, dataset, scale, device):
    """Run ``model`` on each sample of ``dataset`` (batch size 1).

    Each sample's outputs are mapped through ``scale`` and the results are
    concatenated along dim 0. Sample targets are ignored here.
    """
    outputs = []
    with torch.no_grad():  # evaluation only; also safe if main() is called directly
        for inputs, _ in dataset:
            cx, cy, tx = (x.unsqueeze(0).to(device) for x in inputs)
            outputs.append(scale(model(cx, cy, tx)).detach().squeeze(0))
    return torch.cat(outputs, dim=0)


def main():
    """Load the relevant wind checkpoints and evaluate the models according to different metrics.

    For every run directory under ``args.path``, the selected checkpoint is
    loaded into a model rebuilt from the run's ``.hydra/config.yaml``. Every
    function in ``metrics`` is then evaluated against the unscaled validation
    targets. Results are printed and written to ``res.csv``. Runs on ``cuda:0``
    when CUDA is available, otherwise on CPU.

    Raises:
        RuntimeError: if ``TRAIN_PATH`` or ``VAL_PATH`` is not set.
        ValueError: if a run config names an unregistered model.
    """
    load_dotenv()

    pts = _select_checkpoints(Path(args.path))
    for i, p in enumerate(pts):
        print(i, p.name)

    cfgs = [_load_config(p) for p in pts]

    # torch.load unpickles arbitrary objects: only point these at trusted files.
    # NOTE(review): torch >= 2.6 defaults to weights_only=True, which may reject
    # pickled dataset objects — confirm the pinned torch version.
    train_dataset = torch.load(_require_env("TRAIN_PATH"))
    raw_val_dataset = torch.load(_require_env("VAL_PATH"))

    meanx, stdx, meany, stdy = get_mean_std(train_dataset)
    val_dataset = ScaleDataset(raw_val_dataset, meanx, stdx, meany, stdy)

    # Everything below is independent of the checkpoint, so compute it once.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    scale = Scale(meany, stdy).to(device)
    # Metrics compare against the raw (unscaled) validation targets.
    targets = torch.cat([t for _, t in raw_val_dataset], dim=0).to(device)

    metric_names = [m.__name__ for m in metrics]

    with open("res.csv", "w", newline="") as file:
        writer = csv.writer(file)
        writer.writerow(["model", "class"] + metric_names)
        print(",".join([f"{'model':>10}"] + [f"{name:>10}" for name in metric_names]))

        for cfg, pt in zip(cfgs, pts):
            # Re-seed before construction so any randomness in building the
            # model does not depend on which checkpoints were evaluated before.
            torch.manual_seed(cfg["experiment"]["seed"])
            model = _build_model(cfg, train_dataset)
            model.load_state_dict(torch.load(pt, map_location="cpu"))
            model.to(device)
            model.eval()

            outputs = _predict(model, val_dataset, scale, device)
            res = [metric(targets, outputs).item() for metric in metrics]

            model_class = pt.name.split("-")[0]
            print(",".join([f"{model_class:>10}"] + [f"{r:>10.4f}" for r in res]))
            writer.writerow([pt.name, model_class] + res)


if __name__ == "__main__":
    # Pure evaluation run: disable autograd for everything main() does.
    with torch.set_grad_enabled(False):
        main()
