import argparse
import os
from pathlib import Path

import rootutils
import torch
import wandb
from omegaconf import OmegaConf
from torch import nn
from tqdm import tqdm

# Register the current working directory as the project root and (pythonpath=True) put it on
# sys.path so the `scripts.` / `src.` imports below resolve; this must run before them.
# NOTE(review): assumes the script is launched from the repository root, as in the example
# invocations at the bottom of this file.
rootutils.set_root(path=os.path.abspath("."), pythonpath=True)

# from src.utils.logging_utils import load_run_config_from_wb
from scripts.evaluate_rollout import get_dataloader, load_model
from src.models.gns.eval_rollout import eval_single_rollout

# from torchmetrics.functional.segmentation import mean_iou
from src.utils.metric import mean_iou as pyg_mean_iou
from src.utils.metric import my_mse


def rollout_occ_and_field_wrapper(model, ds, data, device):
    """Prepare ``model`` for grid-based metrics and return a per-batch metric closure.

    Side effects, performed once here:
      * attaches stub ``model.trainer.datamodule.val_dataset = ds`` objects and then calls
        ``model.setup_grid_occ_and_field()`` (presumably it reads the dataset from there,
        since this script runs without a trainer);
      * calls ``model.model.set_metadata_device(model.device)``;
      * calls ``ds.get_grid_occ_and_field`` once with entries 0 and 1 of ``data.enc_pos``,
        after which ``ds.grid_pos`` is read, moved to ``device`` and captured by the returned
        closure (the same grid is reused for every batch).

    Args:
        model: Object exposing ``setup_grid_occ_and_field``, ``get_grid_occ_and_field``,
            ``model.set_metadata_device``, ``device`` and ``hparams.num_classes``.
        ds: Dataset providing ``get_grid_occ_and_field`` and, afterwards, ``grid_pos``.
        data: One dataloader batch; only ``data.enc_pos`` is read.
        device: Device on which the grid and metric computations run.

    Returns:
        ``rollout_occ_and_field(output_dict, particle_batch_idx)``, which returns
        ``(iou_trajs, field_trajs)`` as CPU tensors laid out ``(batch_size, num_rollout_steps)``.
    """
    # Minimal stand-ins for the trainer/datamodule attributes used by the model's setup.
    model.trainer = type("Trainer", (object,), {})
    model.trainer.datamodule = type("Datamodule", (object,), {})
    model.trainer.datamodule.val_dataset = ds
    model.setup_grid_occ_and_field()
    model.model.set_metadata_device(model.device)

    # Invoke the dataset's grid builder once so that ``ds.grid_pos`` is populated.
    # NOTE(review): an earlier comment described enc_pos as (time, n_nodes, dim), but
    # calc_iou_and_field uses enc_pos.shape[1] as the input length -- confirm which axis is
    # time before relying on ``enc_pos[1] - enc_pos[0]`` being a velocity.
    enc_pos = data.enc_pos
    ds.get_grid_occ_and_field(
        particle_pos=enc_pos[1].cpu(),
        particle_vel=enc_pos[1].cpu() - enc_pos[0].cpu(),
    )
    grid_pos = ds.grid_pos.to(device)

    def get_pos_and_vel(pos_tmin1, traj_pos):
        """Return ``(positions, velocities)`` of a rollout as tensors on ``device``.

        Velocities are finite differences along axis 0. ``pos_tmin1`` (the last known
        frame, with a leading length-1 axis) is prepended so that the first rollout step
        also gets a velocity; both outputs therefore have the same length as ``traj_pos``.
        Accepts numpy arrays or tensors (``as_tensor`` avoids an extra copy for the latter).
        """
        traj_pos = torch.as_tensor(traj_pos, device=device)
        pos_tmin1 = torch.as_tensor(pos_tmin1, device=device)
        traj_temp = torch.cat([pos_tmin1, traj_pos], dim=0)
        traj_vel = traj_temp[1:] - traj_temp[:-1]
        return traj_pos, traj_vel

    def rollout_occ_and_field(output_dict, particle_batch_idx):
        """Compute per-step grid IoU and field MSE for one batch of rollouts.

        Args:
            output_dict: Rollout output; reads ``"initial_positions"``,
                ``"predicted_rollout"`` and ``"ground_truth_rollout"``.
            particle_batch_idx: Per-particle graph index (presumably the PyG batch
                vector); ``max() + 1`` is taken as the number of graphs in the batch.

        Returns:
            ``(iou_trajs, field_trajs)`` on the CPU, each ``(batch_size, num_rollout_steps)``.
        """
        # The last input frame seeds the velocity of the first rollout step.
        seed_frame = output_dict["initial_positions"][-1:]
        traj_pos, traj_vel = get_pos_and_vel(seed_frame, output_dict["predicted_rollout"])
        traj_pos_gt, traj_vel_gt = get_pos_and_vel(
            seed_frame, output_dict["ground_truth_rollout"]
        )

        # Replicate the grid once per graph and build the matching batch index
        # [0]*G + [1]*G + ... directly on ``device`` (G = number of grid nodes).
        batch_size = particle_batch_idx.max().item() + 1
        num_grid_nodes = grid_pos.shape[0]
        grid_pos_batched = grid_pos.repeat(batch_size, 1)
        grid_batch_idx = torch.arange(batch_size, device=device).repeat_interleave(
            num_grid_nodes
        )
        num_classes = model.hparams.num_classes

        iou_values, field_mse_values = [], []
        for step in range(traj_pos.shape[0]):
            pred_grid_occ, pred_grid_field = model.get_grid_occ_and_field(
                particle_pos=traj_pos[step],
                particle_vel=traj_vel[step],
                particle_batch_idx=particle_batch_idx,
                grid_pos=grid_pos_batched,
                grid_batch_idx=grid_batch_idx,
            )
            target_grid_occ, target_grid_field = model.get_grid_occ_and_field(
                particle_pos=traj_pos_gt[step],
                particle_vel=traj_vel_gt[step],
                particle_batch_idx=particle_batch_idx,
                grid_pos=grid_pos_batched,
                grid_batch_idx=grid_batch_idx,
            )

            iou = pyg_mean_iou(
                pred=pred_grid_occ.to(torch.long),
                target=target_grid_occ.to(torch.long),
                num_classes=num_classes,
                batch=grid_batch_idx,
            )
            field_mse = my_mse(pred_grid_field, target_grid_field, grid_batch_idx)
            iou_values.append(iou.cpu())
            field_mse_values.append(field_mse.cpu())

        # Stack to (num_rollout_steps, batch_size), then transpose to per-trajectory rows.
        iou_trajs = torch.stack(iou_values).T
        field_trajs = torch.stack(field_mse_values).T
        return iou_trajs, field_trajs

    return rollout_occ_and_field


def calc_iou_and_field(model, dl, device):
    """Roll out every batch of ``dl`` and collect per-step grid IoU and field MSE.

    The metric closure from ``rollout_occ_and_field_wrapper`` is built lazily from the
    first batch; each batch is then rolled out with ``eval_single_rollout`` under
    ``torch.no_grad()`` with ``model`` in eval mode.

    Args:
        model: Model wrapper; ``model.model`` is the simulator passed to
            ``eval_single_rollout``.
        dl: DataLoader; ``dl.dataset.n_jumps`` is used as the number of rollout steps.
        device: Device each batch is moved to.

    Returns:
        dict with
          * ``"iou"``: per-trajectory, per-step mean IoU, ``(num_trajs, num_steps)``;
          * ``"vel_mse"``: per-trajectory, per-step grid field MSE, ``(num_trajs, num_steps)``;
          * ``"timestep"``: ``arange(L, L + num_steps)`` with ``L = data.enc_pos.shape[1]``
            of the last batch.

    Raises:
        ValueError: if ``dl`` yields no batches (previously this surfaced as a NameError).
    """
    model.eval()
    num_steps = dl.dataset.n_jumps
    bs, batches = dl.batch_size, len(dl)
    if batches == 0:
        raise ValueError("Dataloader yields no batches; nothing to evaluate.")
    print(f"Calculating metrics for {batches} batches of size {bs}...")

    iou_list = []
    field_mse_list = []
    rollout_occ_and_field = None
    with torch.no_grad():
        for batch_idx, data in enumerate(tqdm(dl)):
            if rollout_occ_and_field is None:
                # Built from the first (still CPU-side) batch; reused for all batches.
                rollout_occ_and_field = rollout_occ_and_field_wrapper(
                    model, dl.dataset, data, device
                )

            data = data.to(device)

            output_dict, loss = eval_single_rollout(
                simulator=model.model,
                features=data,
                num_steps=num_steps,
                device=device,
            )

            iou_values, field_mse_values = rollout_occ_and_field(output_dict, data.enc_pos_batch)

            iou_list.append(iou_values)
            field_mse_list.append(field_mse_values)

            print(
                f"{batch_idx}/{batches}, iou: {iou_values.mean().item():.4f}, mse: {field_mse_values.mean().item():.4f}"
            )

    # NOTE(review): enc_pos.shape[1] is taken as the number of input frames here; confirm it
    # is the time axis (another comment in this file describes enc_pos as (time, n_nodes, dim)).
    num_input_steps = data.enc_pos.shape[1]
    return {
        "iou": torch.concat(iou_list, dim=0),
        "vel_mse": torch.concat(field_mse_list, dim=0),
        "timestep": torch.arange(num_input_steps, num_input_steps + num_steps),
    }


def main():
    """Evaluate grid IoU / field MSE on test-split rollouts of a W&B run and save them.

    Fetches the run config from W&B, forces ``cfg.data.eval_split = "test"``, loads the model
    and a ``"rollout_gns_f"`` dataloader, runs ``calc_iou_and_field``, prints summary
    statistics, and saves the result dict to
    ``<cfg.trainer.default_root_dir>/<project>/<run_id>/rollout/results.pt``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", type=int, default=0, help="GPU index to use for computation")
    # Defaults to the original hard-coded placeholder so existing behaviour is unchanged.
    parser.add_argument(
        "--entity", type=str, default="add_your_wandb_here", help="wandb entity (user or team)"
    )
    parser.add_argument("--project", type=str, required=True, help="wandb project name")
    parser.add_argument("--run_id", type=str, required=True, help="wandb run ID")
    # These were `required=True`, which made their defaults unreachable; now optional.
    parser.add_argument("--batch_size", type=int, help="batch size", default=64)
    parser.add_argument("--num_workers", type=int, help="number of workers", default=8)
    args = parser.parse_args()

    device = f"cuda:{args.device}"
    project = args.project
    run_id = args.run_id

    api = wandb.Api()
    run = api.run(f"{args.entity}/{project}/{run_id}")
    cfg = OmegaConf.create(run.config)

    cfg.data.eval_split = "test"
    model = load_model(cfg=cfg, project=project, run_id=run_id, device=device)

    dl = get_dataloader(
        cfg=cfg,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        stage="rollout_gns_f",
    )

    out_dict = calc_iou_and_field(model=model, dl=dl, device=device)
    print("Inference done with:")
    for k, v in out_dict.items():
        if v.dtype == torch.long:  # timestep
            print(f"{k} {tuple(v.size())}: min={v[0]}, max={v[-1]}")
        else:  # iou, vel_mse
            print(f"{k} {tuple(v.size())}: mean={v.mean().item():.4f}, std={v.std().item():.4f}")

    rollout_path = Path(cfg.trainer.default_root_dir) / project / run_id / "rollout"
    rollout_path.mkdir(parents=True, exist_ok=True)

    save_path = rollout_path / "results.pt"
    torch.save(out_dict, save_path)
    print(f"Finished calculating metrics for {project}:{run_id}")


# Run the evaluation only when executed as a script (see the example invocations below),
# not when this module is imported.
if __name__ == "__main__":
    main()

# Example invocations:
# python scripts/evaluate_rollout_gns_f.py --device=7 --project=waterdrop --run_id=304hm65l --batch_size=15 --num_workers=15
# viper: nohup python scripts/evaluate_rollout_gns_f.py --device=0 --project=waterdrop_xl --run_id=muinbl5x --batch_size=10 --num_workers=10  > eval_rollout_waterdrop_xl_gns_f_4.log 2>&1 &
# viper: nohup python scripts/evaluate_rollout_gns_f.py --device=3 --project=dam3d --run_id=no9lk2o5 --batch_size=1 --num_workers=0 > eval_rollout_dam3d_gns_f_4.log 2>&1 &
