import argparse
import os
from pathlib import Path
from time import time

import einops
import lightning as L
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import rootutils
import torch
import wandb
from hydra.utils import instantiate
from mpl_toolkits.mplot3d import Axes3D
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from torchmetrics.functional.segmentation import mean_iou
from torchmetrics.segmentation import MeanIoU
from tqdm import tqdm

# Treat the current working directory as the project root. `pythonpath=True` presumably
# also adds it to the import path so project-local packages (e.g. `src`) resolve —
# confirm against the rootutils docs. Because the root is derived from ".", the script
# is expected to be launched from the repository root.
rootutils.set_root(path=os.path.abspath("."), pythonpath=True)

# from src.utils.logging_utils import load_run_config_from_wb


def load_model(cfg, project, run_id, device):
    """Instantiate the model described by ``cfg`` and load the run's checkpoint weights.

    Checkpoints are searched in
    ``<cfg.trainer.default_root_dir>/<project>/<run_id>/checkpoints`` and must have a
    file name starting with ``"epoch"``. If several match, the first one in sorted
    (lexicographic) order is used so the choice is reproducible across filesystems.

    Args:
        cfg: Run configuration; ``cfg.trainer.default_root_dir`` locates checkpoints and
            ``cfg.model`` is passed to Hydra's ``instantiate``.
        project: Project name (first path component under the root dir).
        run_id: Run identifier (second path component under the root dir).
        device: Device the checkpoint is mapped to and the model is moved to.

    Returns:
        The instantiated model with the checkpoint's ``state_dict`` loaded, moved to
        ``device``, frozen and switched to eval mode.

    Raises:
        FileNotFoundError: If the checkpoint directory is missing or contains no file
            starting with ``"epoch"``.
    """
    ckpt_root_path = Path(cfg.trainer.default_root_dir) / project / run_id / "checkpoints"
    # Directory iteration order is filesystem-dependent; sort for a deterministic pick.
    # glob() on a missing directory simply yields nothing, handled by the check below.
    ckpt_files = sorted(ckpt_root_path.glob("epoch*"))
    if not ckpt_files:
        raise FileNotFoundError(
            f"No checkpoint starting with 'epoch' found in {ckpt_root_path}"
        )

    ckpt_path = ckpt_files[0]
    if len(ckpt_files) > 1:
        print(
            f"{len(ckpt_files)} checkpoints found; using the first in sorted order."
        )
    print(f"Checkpoint found: {ckpt_path}")

    # NOTE(review): newer torch versions default `weights_only=True` in torch.load,
    # which may reject checkpoints containing non-tensor objects — confirm for the
    # torch version in use.
    ckpt = torch.load(ckpt_path, map_location=device)
    model = instantiate(cfg.model)
    model.load_state_dict(ckpt["state_dict"])
    model.to(device)
    model.freeze()
    model.eval()
    return model


def get_dataloader(cfg, batch_size, num_workers):
    """Return the validation dataloader of the datamodule configured for rollouts.

    Side effect: the fields ``stage``, ``batch_size``, ``num_workers``, ``pin_memory``
    and ``overfit_single_trajectory`` of ``cfg.data`` are overwritten in place before
    the datamodule is instantiated from it.

    Args:
        cfg: Run configuration whose ``data`` node describes the datamodule.
        batch_size: Batch size to set on the data config.
        num_workers: Number of dataloader workers to set on the data config.

    Returns:
        Whatever ``datamodule.val_dataloader()`` returns after ``setup()``.
    """
    data_cfg = cfg.data
    overrides = {
        "stage": "rollout",
        "batch_size": batch_size,
        "num_workers": num_workers,
        "pin_memory": False,
        "overfit_single_trajectory": False,
    }
    for field_name, field_value in overrides.items():
        setattr(data_cfg, field_name, field_value)

    datamodule = instantiate(data_cfg)
    datamodule.setup()
    return datamodule.val_dataloader()


def calc_rollout_times(model, dl, iters, device, decode_once):
    """Time full encode -> rollout -> decode passes for every batch in ``dl``.

    For each ``(input, target)`` batch the complete rollout is repeated ``iters`` times
    and each repetition is timed individually with ``time.perf_counter``.

    Args:
        model: Model exposing ``encode``, ``push_forward`` and ``decode`` (called with
            the keyword arguments below).
        dl: Iterable of ``(input, target)`` batches; both must support ``.to(device)``.
        iters: Number of timed repetitions per batch.
        device: Device the batches are moved to, e.g. ``"cuda:0"`` or ``"cpu"``.
        decode_once: If True, decode only the final latent of the rollout; otherwise
            decode the initial latent and every intermediate latent as well.

    Returns:
        Float tensor of shape ``(num_batches, iters)`` with durations in seconds
        (shape ``(0, iters)`` if ``dl`` yields no batches).
    """
    from time import perf_counter

    is_cuda = torch.device(device).type == "cuda"

    def _sync():
        # CUDA kernels execute asynchronously; without a barrier the timer would only
        # capture kernel-launch overhead instead of the actual compute time.
        if is_cuda:
            torch.cuda.synchronize(device)

    all_all_times = []

    model.eval()
    with torch.no_grad():
        for inputs, target in tqdm(dl):
            inputs = inputs.to(device)
            target = target.to(device)

            all_times = []
            for _ in range(iters):
                # Reset every repetition: the rollout loop below advances `timestep`,
                # so reusing it would encode later repetitions at the wrong time.
                timestep = inputs.timestep

                _sync()  # exclude pending work (e.g. the .to(device) copies)
                start = perf_counter()

                latent = model.encode(
                    enc_pos=inputs.enc_pos,
                    enc_field=inputs.enc_field,
                    enc_particle_type=inputs.enc_particle_type,
                    enc_pos_batch_index=inputs.enc_pos_batch,
                    supernode_index=inputs.supernode_index,
                    supernode_batch_index=inputs.supernode_index_batch,
                    timestep=timestep,
                )  # (1, 512, 192)
                if not decode_once:
                    _ = model.decode(
                        latent=latent,
                        dec_field_pos=inputs.grid_pos,
                        dec_occ_pos=inputs.grid_pos,
                        timestep=timestep,
                    )

                for t_idx in range(target.dec_pos.shape[1]):
                    latent = model.push_forward(latent=latent, timestep=timestep)

                    timestep = target.timestep[:, t_idx]
                    if not decode_once:
                        _ = model.decode(
                            latent=latent,
                            dec_field_pos=target.grid_pos,
                            dec_occ_pos=target.grid_pos,
                            timestep=timestep,
                        )
                if decode_once:
                    _ = model.decode(
                        latent=latent,
                        dec_field_pos=target.grid_pos,
                        dec_occ_pos=target.grid_pos,
                        timestep=timestep,
                    )

                _sync()  # wait for all queued kernels before stopping the clock
                all_times.append(perf_counter() - start)
            all_all_times.append(torch.tensor(all_times))

    if not all_all_times:
        # torch.stack raises on an empty list; return a correctly shaped empty result.
        return torch.empty(0, iters)
    return torch.stack(all_all_times)


def main():
    """Command-line entry point: time rollouts of a W&B-tracked run and save them.

    Fetches the run config from W&B, loads the checkpointed model, times rollouts over
    the validation dataloader with batch size 1, and saves the timing tensor to
    ``<default_root_dir>/<project>/<run_id>/rollout/times_<gpu_name>[_decode_once].pt``.
    """

    def str2bool(value):
        # argparse's `type=bool` treats any non-empty string (including "False") as
        # True, so the common spellings are parsed explicitly.
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("1", "true", "t", "yes", "y"):
            return True
        if lowered in ("0", "false", "f", "no", "n"):
            return False
        raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser()
    parser.add_argument("--device", type=int, default=0, help="GPU index to use for computation")
    parser.add_argument("--project", type=str, required=True, help="wandb project name")
    parser.add_argument("--run_id", type=str, required=True, help="wandb run ID")
    parser.add_argument(
        "--entity",
        type=str,
        default="add_your_wandb_here",
        help="wandb entity (user or team) owning the project",
    )
    parser.add_argument(
        "--iters", type=int, default=64, help="number of timed repetitions per batch"
    )
    parser.add_argument(
        "--num_workers", type=int, default=8, help="number of dataloader workers"
    )
    # nargs="?" + const=True accepts both a bare `--decode_once` flag and the
    # explicit `--decode_once true/false` form used by existing invocations.
    parser.add_argument(
        "--decode_once",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
        help="decode only the final rollout state instead of every step",
    )
    args = parser.parse_args()

    device = f"cuda:{args.device}"
    project = args.project
    run_id = args.run_id

    api = wandb.Api()
    run = api.run(f"{args.entity}/{project}/{run_id}")
    cfg = OmegaConf.create(run.config)

    model = load_model(cfg=cfg, project=project, run_id=run_id, device=device)

    # Batch size 1 for time analysis.
    dl = get_dataloader(cfg=cfg, batch_size=1, num_workers=args.num_workers)

    times = calc_rollout_times(
        model=model, dl=dl, iters=args.iters, device=device, decode_once=args.decode_once
    )

    rollout_path = Path(cfg.trainer.default_root_dir) / project / run_id / "rollout"
    rollout_path.mkdir(parents=True, exist_ok=True)

    filename = "times_" + torch.cuda.get_device_name(device).replace(" ", "_")
    if args.decode_once:
        filename += "_decode_once"
    filename += ".pt"

    save_path = rollout_path / filename
    torch.save(times, save_path)

    print(f"Finished calculating timing for {project}:{run_id}")


if __name__ == "__main__":
    # Run the timing benchmark only when executed as a script, not when imported.
    main()
