import argparse
import os
from pathlib import Path

import einops
import lightning as L
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import rootutils
import torch
import wandb
from hydra.utils import instantiate
from mpl_toolkits.mplot3d import Axes3D
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from torchmetrics.functional.segmentation import mean_iou
from torchmetrics.segmentation import MeanIoU
from tqdm import tqdm

# Register the current working directory as the project root; ``pythonpath=True``
# also puts it on the import path (presumably so that project modules named in the
# run config can be imported by hydra's ``instantiate`` — confirm).
rootutils.set_root(path=os.path.abspath(os.curdir), pythonpath=True)


def load_model(cfg, project, run_id, device):
    """Instantiate the model described by ``cfg.model`` and load its checkpoint weights.

    The checkpoint is looked up in
    ``<cfg.trainer.default_root_dir>/<project>/<run_id>/checkpoints`` among files whose
    name starts with ``"epoch"``. If several match, the first one in sorted (lexicographic)
    order is used so the choice is deterministic.

    Args:
        cfg: run config; must provide ``trainer.default_root_dir`` and ``model``.
        project: wandb project name, used as a path component.
        run_id: wandb run ID, used as a path component.
        device: passed to ``torch.load(map_location=...)`` and ``model.to``.

    Returns:
        The instantiated model with weights loaded, moved to ``device``, frozen and in
        eval mode.

    Raises:
        FileNotFoundError: if the checkpoint directory does not exist or contains no
            ``epoch*`` checkpoint.
    """
    ckpt_root_path = Path(cfg.trainer.default_root_dir) / project / run_id / "checkpoints"
    # iterdir() order is filesystem-dependent; sort so repeated runs pick the same file.
    ckpt_files = sorted(
        file for file in ckpt_root_path.iterdir() if file.name.startswith("epoch")
    )
    if not ckpt_files:
        # Fail loudly here instead of passing None on to torch.load / path handling.
        raise FileNotFoundError(
            f"No checkpoint starting with 'epoch' found in {ckpt_root_path}"
        )
    ckpt_path = ckpt_files[0]
    print(f"Checkpoint found: {ckpt_path}")

    ckpt = torch.load(ckpt_path, map_location=device)
    model = instantiate(cfg.model)
    model.load_state_dict(ckpt["state_dict"])
    model.to(device)
    model.freeze()
    model.eval()
    return model


def get_dataloader(cfg, batch_size, num_workers, stage="rollout"):
    """Build the datamodule from ``cfg.data`` and return its validation dataloader.

    ``cfg.data`` is mutated in place with evaluation settings (stage, batch size,
    worker count, no pinned memory, no single-trajectory overfitting) before the
    datamodule is instantiated and set up.

    Args:
        cfg: run config whose ``data`` node describes the datamodule.
        batch_size: batch size written into ``cfg.data``.
        num_workers: number of dataloader workers written into ``cfg.data``.
        stage: value written into ``cfg.data.stage``.

    Returns:
        The result of ``datamodule.val_dataloader()``.
    """
    eval_overrides = {
        "stage": stage,
        "batch_size": batch_size,
        "num_workers": num_workers,
        "pin_memory": False,
        "overfit_single_trajectory": False,
    }
    data_cfg = cfg.data
    for field_name, field_value in eval_overrides.items():
        setattr(data_cfg, field_name, field_value)

    dm = instantiate(data_cfg)
    dm.setup()
    return dm.val_dataloader()


def _split_step(out, gt_occ, gt_vel, batch_size, ndim):
    """Unflatten one decoded step and its ground truth into per-sample CPU tensors.

    Args:
        out: decoder output; ``out[0]`` is occupancy of shape
            ``(batch_size * n_points, n_classes)`` and ``out[1]`` is velocity of shape
            ``(batch_size * n_points, n_time * ndim)``.
        gt_occ: ground-truth occupancy of shape ``(batch_size * n_points,)``.
        gt_vel: ground-truth velocity of shape ``(batch_size * n_points, ndim)``.
        batch_size: number of samples in the batch (plain ``int``).
        ndim: spatial dimensionality, used to split the flattened velocity channels.

    Returns:
        ``(preds_occ, preds_vel, gt_occ, gt_vel)`` on CPU with shapes
        ``(B, P, n_classes)``, ``(B, P, n_time, ndim)``, ``(B, P)`` and ``(B, P, ndim)``.
    """
    preds_occ = einops.rearrange(
        out[0].cpu(),
        "(batch_size n_points) n_classes -> batch_size n_points n_classes",
        batch_size=batch_size,
    )
    preds_vel = einops.rearrange(
        out[1].cpu(),
        "(batch_size n_points) (n_time n_dim) -> batch_size n_points n_time n_dim",
        batch_size=batch_size,
        n_dim=ndim,
    )
    gt_occ = einops.rearrange(
        gt_occ.cpu(),
        "(batch_size n_points) -> batch_size n_points",
        batch_size=batch_size,
    )
    gt_vel = einops.rearrange(
        gt_vel.cpu(),
        "(batch_size n_points) n_dim -> batch_size n_points n_dim",
        batch_size=batch_size,
    )
    return preds_occ, preds_vel, gt_occ, gt_vel


def calc_preds_target_occ(model, dl, device):
    """Roll the model out over every batch of ``dl`` and collect predictions and ground truth.

    For each batch the input state is encoded once and decoded at the input timestep.
    Then, for every target step, the latent is advanced with ``model.push_forward`` and
    decoded at that step's timestep. Per-step results are stacked along a new time axis
    (dim 1) and concatenated over batches (dim 0). Everything collected is moved to CPU
    so GPU memory does not grow with the size of the evaluation set.

    Args:
        model: model exposing ``encode``, ``decode`` and ``push_forward``.
        dl: dataloader yielding ``(input, target)`` batches.
        device: device each batch is moved to.

    Returns:
        dict with keys
          - ``"preds_occ"``: ``(N, T, P, n_classes)``
          - ``"preds_vel"``: ``(N, T, P, n_time, ndim)``
          - ``"gt_occ"``: ``(N, T, P)``
          - ``"gt_vel"``: ``(N, T, P, ndim)``
          - ``"grid_pos"``: ``grid_pos`` of element 0 of the last input batch
          - ``"timestep"``: that element's input timestep followed by its target timesteps
        where ``T = 1 + number of target steps``.

    Raises:
        ValueError: if ``dl`` yields no batches.
    """
    all_preds_occ, all_preds_vel, all_gt_occ, all_gt_vel = [], [], [], []
    inputs = target = None

    model.eval()
    with torch.no_grad():
        for inputs, target in tqdm(dl):
            inputs = inputs.to(device)
            target = target.to(device)
            # Plain int: einops expects integer axis lengths, not a 0-d tensor.
            batch_size = int(inputs.batch.max()) + 1
            ndim = inputs.grid_pos.shape[-1]
            timestep = inputs.timestep

            latent = model.encode(
                enc_pos=inputs.enc_pos,
                enc_field=inputs.enc_field,
                enc_particle_type=inputs.enc_particle_type,
                enc_pos_batch_index=inputs.enc_pos_batch,
                supernode_index=inputs.supernode_index,
                supernode_batch_index=inputs.supernode_index_batch,
                timestep=timestep,
            )
            out = model.decode(
                latent=latent,
                dec_field_pos=inputs.grid_pos,
                dec_occ_pos=inputs.grid_pos,
                timestep=timestep,
            )
            steps = [_split_step(out, inputs.gt_occ, inputs.gt_vel, batch_size, ndim)]

            for t_idx in range(target.dec_pos.shape[1]):
                # Advance from the current timestep, then decode at the next target timestep.
                latent = model.push_forward(latent=latent, timestep=timestep)
                timestep = target.timestep[:, t_idx]
                out = model.decode(
                    latent=latent,
                    dec_field_pos=target.grid_pos,
                    dec_occ_pos=target.grid_pos,
                    timestep=timestep,
                )
                steps.append(
                    _split_step(
                        out,
                        target.gt_occ[:, t_idx],
                        target.gt_vel[:, t_idx],
                        batch_size,
                        ndim,
                    )
                )

            # zip(*steps) regroups per-step tuples into per-quantity sequences over time.
            preds_occ, preds_vel, gt_occ, gt_vel = (
                torch.stack(seq, dim=1) for seq in zip(*steps)
            )
            all_preds_occ.append(preds_occ)
            all_preds_vel.append(preds_vel)
            all_gt_occ.append(gt_occ)
            all_gt_vel.append(gt_vel)

    if inputs is None:
        raise ValueError("Dataloader yielded no batches; nothing to evaluate.")

    return {
        "preds_occ": torch.concat(all_preds_occ, dim=0),
        "preds_vel": torch.concat(all_preds_vel, dim=0),
        "gt_occ": torch.concat(all_gt_occ, dim=0),
        "gt_vel": torch.concat(all_gt_vel, dim=0),
        "grid_pos": inputs[0].grid_pos,
        "timestep": torch.concat([inputs[0].timestep, target[0].timestep.squeeze()]),
    }


def calc_iou(all_preds_occ, all_target_occ):
    """Compute the two-class mean IoU separately for every (sample, timestep) pair.

    Args:
        all_preds_occ: per-class occupancy scores indexed ``[sample, timestep, point, class]``;
            the predicted label of each point is the argmax over the last axis.
        all_target_occ: ground-truth occupancy labels indexed ``[sample, timestep, point]``.

    Returns:
        Tensor of shape ``all_target_occ.shape[:2]`` (default float dtype) holding the
        ``mean_iou`` score of each pair.
    """
    n_samples, n_steps = all_target_occ.shape[0:2]
    scores = torch.zeros(all_target_occ.shape[0:2])
    for sample_idx in range(n_samples):
        for step_idx in range(n_steps):
            pred_labels = all_preds_occ[sample_idx, step_idx].argmax(dim=-1)
            true_labels = all_target_occ[sample_idx, step_idx].to(torch.int64)

            # one_hot yields (n_points, 2); mean_iou wants one-hot (batch, classes, points).
            pred_onehot = torch.nn.functional.one_hot(pred_labels, num_classes=2)
            true_onehot = torch.nn.functional.one_hot(true_labels, num_classes=2)
            pred_onehot = pred_onehot.transpose(0, 1).unsqueeze(dim=0)
            true_onehot = true_onehot.transpose(0, 1).unsqueeze(dim=0)

            scores[sample_idx, step_idx] = mean_iou(pred_onehot, true_onehot, num_classes=2)
    return scores


def main():
    """Evaluate a trained run's rollout and save per-sample metrics to disk.

    Fetches the run config from wandb, loads the local checkpoint, rolls the model out on
    ``--data_split`` and saves a dict with keys ``"iou"``, ``"vel_mse"`` and ``"timestep"``
    to ``<trainer.default_root_dir>/<project>/<run_id>/rollout/results_<split>.pt``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", type=int, default=0, help="GPU index to use for computation")
    parser.add_argument(
        "--entity", type=str, default="add_your_wandb_here", help="wandb entity (user or team)"
    )
    parser.add_argument("--project", type=str, required=True, help="wandb project name")
    parser.add_argument("--run_id", type=str, required=True, help="wandb run ID")
    # Optional: combining required=True with a default left the default unreachable.
    parser.add_argument("--batch_size", type=int, default=64, help="batch size")
    parser.add_argument("--num_workers", type=int, default=8, help="number of workers")
    parser.add_argument(
        "--data_split", type=str, required=False, help="data split to evaluate on", default="test"
    )
    args = parser.parse_args()

    if torch.cuda.is_available():
        device = f"cuda:{args.device}"
    else:
        # Let the evaluation still run (slowly) on machines without CUDA.
        print("CUDA not available, falling back to CPU.")
        device = "cpu"
    project = args.project
    run_id = args.run_id

    api = wandb.Api()
    run = api.run(f"{args.entity}/{project}/{run_id}")
    cfg = OmegaConf.create(run.config)
    cfg.data.eval_split = args.data_split
    model = load_model(cfg=cfg, project=project, run_id=run_id, device=device)

    dl = get_dataloader(cfg=cfg, batch_size=args.batch_size, num_workers=args.num_workers)

    out_dict = calc_preds_target_occ(
        model=model,
        dl=dl,
        device=device,
    )

    preds_occ = out_dict["preds_occ"].cpu()
    save_dict = {"iou": calc_iou(preds_occ, out_dict["gt_occ"].cpu())}
    # Keep only the last of the two consecutive predicted velocities and zero it where the
    # class-1 occupancy output is not positive.
    # NOTE(review): calc_iou treats argmax as the predicted occupancy, whereas this mask
    # uses ``> 0`` on the class-1 output; these agree only for particular output
    # conventions (e.g. a sigmoid-style logit) — confirm the intended criterion.
    occupied = (preds_occ[:, :, :, 1] > 0).unsqueeze(dim=-1)
    preds_vel = out_dict["preds_vel"][:, :, :, -1, :].cpu() * occupied

    # Mean over the point and spatial-dimension axes -> one value per (sample, timestep).
    save_dict["vel_mse"] = ((preds_vel - out_dict["gt_vel"].cpu()) ** 2).mean(dim=[-1, -2])
    save_dict["timestep"] = out_dict["timestep"].cpu()

    rollout_path = Path(cfg.trainer.default_root_dir) / project / run_id / "rollout"
    rollout_path.mkdir(parents=True, exist_ok=True)

    save_path = rollout_path / f"results_{args.data_split}.pt"
    torch.save(save_dict, save_path)
    print(f"Finished calculating metrics for {project}:{run_id}")


if __name__ == "__main__":
    main()
