import argparse
import os
from pathlib import Path
from time import time

import rootutils
import torch
import wandb
from omegaconf import OmegaConf
from torch import nn
from tqdm import tqdm

rootutils.set_root(path=os.path.abspath("."), pythonpath=True)

# from src.utils.logging_utils import load_run_config_from_wb
from scripts.evaluate_rollout import get_dataloader, load_model

# from src.models.gns.eval_rollout import eval_single_rollout
# from torchmetrics.functional.segmentation import mean_iou
from src.utils.metric import mean_iou as pyg_mean_iou
from src.utils.metric import my_mse


def eval_single_rollout_barebone(simulator, features, num_steps, device):
    """Autoregressively roll out ``simulator`` for ``num_steps`` steps, without metrics.

    Stripped-down rollout intended for timing: it only predicts the next position,
    overwrites kinematic particles with their ground truth, and shifts the position
    history window. No losses or metrics are computed and no trajectory is stored.

    Args:
        simulator: object exposing ``predict_positions(positions,
            n_particles_per_example=..., particle_types=...)``, expected to return the
            next position per node with shape ``(n_nodes, dim)``.
        features: mapping with keys ``"enc_pos"`` (initial position history; axis 1
            indexes history steps, the last axis is the spatial dimension),
            ``"target_pos"`` (ground-truth positions indexed by rollout step along
            axis 1), ``"n_particles_per_example"`` and ``"particle_type"``.
        num_steps: number of rollout steps to perform.
        device: device the kinematic mask is moved to.

    Returns:
        The final position-history window, with the same shape as
        ``features["enc_pos"]``. Callers that only time the rollout may ignore it.
    """
    ground_truth_positions = features["target_pos"]
    particle_types = features["particle_type"]
    n_particles_per_example = features["n_particles_per_example"]
    current_positions = features["enc_pos"]
    dim = current_positions.shape[-1]

    # Particle type 3 marks kinematic particles, which follow a prescribed trajectory.
    # The mask is identical at every step, so build it once instead of re-allocating
    # it inside the (timed) loop.
    kinematic_mask = (particle_types == 3).to(device)[:, None].expand(-1, dim)

    for step in range(num_steps):
        next_position = simulator.predict_positions(
            current_positions,
            n_particles_per_example=n_particles_per_example,
            particle_types=particle_types,
        )  # (n_nodes, dim)
        # Kinematic particles take their ground-truth position instead of the prediction.
        next_position = torch.where(
            kinematic_mask, ground_truth_positions[:, step], next_position
        )
        # Slide the history window: drop the oldest step, append the new one.
        current_positions = torch.cat(
            [current_positions[:, 1:], next_position[:, None, :]], dim=1
        )
    return current_positions


def calc_rollout_times(model, dl, iters, device):
    """Time full rollouts for every batch yielded by ``dl``.

    Args:
        model: wrapper whose ``.model`` is the simulator passed to
            ``eval_single_rollout_barebone`` and whose ``.device`` is passed to
            ``set_metadata_device``.
        dl: dataloader of rollout batches; ``dl.dataset.n_jumps`` gives the number
            of rollout steps.
        iters: number of timed repetitions per batch.
        device: device each batch is moved to (e.g. ``"cuda:0"``).

    Returns:
        Tensor of shape ``(n_batches, iters)`` holding the elapsed seconds of each
        rollout, with one row per batch yielded by ``dl``.
    """
    from time import perf_counter

    model.model.set_metadata_device(model.device)
    num_steps = dl.dataset.n_jumps

    # CUDA kernels are launched asynchronously; without an explicit sync the clock
    # could start while earlier GPU work is still pending, or stop before the rollout
    # has actually finished on the device.
    is_cuda = torch.device(device).type == "cuda"

    def _sync():
        """Block until all queued work on ``device`` is done (no-op off CUDA)."""
        if is_cuda:
            torch.cuda.synchronize(torch.device(device))

    print("Start timing runs.")
    all_all_times = []
    with torch.no_grad():
        for counter, data in enumerate(tqdm(dl)):
            data = data.to(device)

            all_times = []
            for _ in range(iters):
                _sync()
                start = perf_counter()
                eval_single_rollout_barebone(
                    simulator=model.model,
                    features=data,
                    num_steps=num_steps,  # 995 for waterdrop
                    device=device,
                )
                _sync()
                all_times.append(perf_counter() - start)
            all_all_times.append(torch.tensor(all_times))

            print(f"{counter}/{len(dl)}, times: {all_times}")

    return torch.stack(all_all_times)


def main():
    """Time GNS rollouts for a wandb run and save the per-trajectory timings.

    Loads the run config and model from wandb, times ``--iters`` rollouts for every
    batch of the test split (batch size 1), and prints the mean/std over trajectories
    of the *last* repetition. It then saves the full ``(n_traj, iters)`` timing tensor
    to ``<trainer.default_root_dir>/<project>/<run_id>/rollout/times_<GPU name>.pt``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", type=int, default=0, help="GPU index to use for computation")
    parser.add_argument("--project", type=str, required=True, help="wandb project name")
    parser.add_argument("--run_id", type=str, required=True, help="wandb run ID")
    parser.add_argument(
        "--entity",
        type=str,
        default="add_your_wandb_here",
        help="wandb entity (user or team) owning the project",
    )
    parser.add_argument(
        "--iters",
        type=int,
        required=False,
        default=2,
        help="number of timed rollouts per trajectory; all are saved, but only the "
        "last one is used for the printed summary",
    )
    parser.add_argument(
        "--num_workers", type=int, required=False, help="number of workers", default=0
    )
    args = parser.parse_args()
    # times[:, -1] below needs at least one repetition per trajectory.
    if args.iters < 1:
        parser.error("--iters must be at least 1")

    device = f"cuda:{args.device}"
    project = args.project
    run_id = args.run_id

    api = wandb.Api()
    run = api.run(f"{args.entity}/{project}/{run_id}")
    cfg = OmegaConf.create(run.config)

    model = load_model(cfg=cfg, project=project, run_id=run_id, device=device)

    dl = get_dataloader(
        cfg=cfg,
        batch_size=1,
        num_workers=args.num_workers,
        stage="rollout_gns_f",
        split="test",
    )

    times = calc_rollout_times(model=model, dl=dl, iters=args.iters, device=device)
    # Summarize only the final repetition per trajectory; presumably the earlier ones
    # absorb warm-up effects. The full tensor is still saved below.
    last = times[:, -1]
    print(f"Inference time per traj {last.mean().item():.3f}+/-{last.std().item():.3f}")

    rollout_path = Path(cfg.trainer.default_root_dir) / project / run_id / "rollout"
    rollout_path.mkdir(parents=True, exist_ok=True)

    gpu_name = torch.cuda.get_device_name(device).replace(" ", "_")
    save_path = rollout_path / f"times_{gpu_name}.pt"
    torch.save(times, save_path)
    print(f"Finished calculating timing for {project}:{run_id}")


if __name__ == "__main__":
    # Script entry point: parse CLI args, time the rollouts and save the results.
    main()

# cayman:
# python scripts/evaluate_rollout_gns_f_times.py --device=7 --project=waterdrop --run_id=304hm65l
# nohup python scripts/evaluate_rollout_gns_f_times.py --device=7 --project=waterdrop_xl --run_id=muinbl5x > wdxl_times_gns_f.log 2>&1 &
# nohup python scripts/evaluate_rollout_gns_f_times.py --device=6 --project=dam3d --run_id=no9lk2o5 > dam3d_times_gns_f.log 2>&1 &
