import os
import time

import numpy as np
import torch
import wandb
from torch_geometric.loader import DataLoader
from tqdm import tqdm

from propose.datasets.human36m.Human36mDataset import Human36mDataset
from propose.evaluation.mpjpe import mpjpe, pa_mpjpe
from propose.evaluation.pck import human36m_joints_to_use, pck
from propose.models.flows import CondGraphFlow
from propose.utils.reproducibility import set_random_seed


def evaluate(flow, test_dataloader, temperature=1.0):
    """
    Sample poses from ``flow`` for every batch and collect per-batch metrics.

    For each batch, 200 poses are sampled from the flow, a zero joint is
    inserted at index 0 of both the ground truth and the samples, and the
    MPJPE, PA-MPJPE and PCK metrics are computed on the rescaled poses.

    :param flow: Model exposing ``eval()``, ``device`` and
        ``sample(n, batch, temperature=...)``.
    :param test_dataloader: Iterable yielding ``(batch, _, action)`` triples.
    :param temperature: Sampling temperature forwarded to ``flow.sample``.
    :return: Six lists with one entry per batch, in the order unpacked by
        ``mpjpe_experiment``: minimum MPJPE over the samples, minimum
        PA-MPJPE over the samples, MPJPE of the first sample, PA-MPJPE of
        the first sample, maximum PCK score and mean PCK score.
    """
    # NOTE(review): 0.0036 looks like a dataset unit-normalisation factor
    # applied before every metric -- confirm its meaning.
    scale = 0.0036
    n_samples = 200

    ms = []  # raw per-sample MPJPE, only used for the running quantile print

    mpjpes = []
    pa_mpjpes = []
    single_mpjpes = []
    pa_single_mpjpes = []
    pck_scores = []
    mean_pck_scores = []

    flow.eval()
    for batch, _, _ in tqdm(test_dataloader):
        batch.to(flow.device)

        with torch.no_grad():
            samples = flow.sample(n_samples, batch, temperature=temperature)

        true_pose = batch["x"].x.cpu().numpy().reshape(-1, 16, 1, 3)
        sample_poses = (
            samples["x"].x.detach().cpu().numpy().reshape(-1, 16, n_samples, 3)
        )

        # Prepend an all-zero joint so the poses have 17 joints
        # (presumably the root joint the model does not predict -- confirm).
        true_pose = np.insert(true_pose, 0, 0, axis=1)
        sample_poses = np.insert(sample_poses, 0, 0, axis=1)

        pck_score = pck(
            true_pose[:, human36m_joints_to_use] / scale,
            sample_poses[:, human36m_joints_to_use] / scale,
        )
        pck_scores.append(pck_score.max().unsqueeze(0).numpy())
        mean_pck_scores.append(pck_score.mean().unsqueeze(0).numpy())

        m = mpjpe(true_pose / scale, sample_poses / scale, dim=1)
        ms.append(m)
        single_mpjpes.append(m[..., 0])
        mpjpes.append(np.min(m, axis=-1))

        # NOTE(review): assumes pa_mpjpe accepts the same (truth, samples, dim)
        # arguments as mpjpe and returns one value per sample -- confirm.
        # np.asarray guards against it returning a torch tensor.
        pa_m = np.asarray(pa_mpjpe(true_pose / scale, sample_poses / scale, dim=1))
        pa_single_mpjpes.append(pa_m[..., 0])
        pa_mpjpes.append(np.min(pa_m, axis=-1))

        # Running mean of the per-pose error quantiles over all batches so far.
        print(
            np.quantile(
                np.concatenate(ms), np.array([0, 0.25, 0.5, 0.75, 0.9]), axis=1
            ).mean(1)
        )

    return (
        mpjpes,
        pa_mpjpes,
        single_mpjpes,
        pa_single_mpjpes,
        pck_scores,
        mean_pck_scores,
    )


def mpjpe_experiment(flow, config, name="test", **kwargs):
    """
    Build a Human36m dataset and loader, evaluate ``flow`` and average each metric.

    :param flow: Model forwarded to ``evaluate``.
    :param config: Configuration dictionary; ``config["dataset"]`` is expanded
        into the ``Human36mDataset`` constructor.
    :param name: Prefix used for every key of the returned metrics dictionary.
    :param kwargs: Extra keyword arguments passed to ``Human36mDataset``.
    :return: Tuple ``(metrics, dataset, dataloader)`` where ``metrics`` maps
        ``"<name>/<metric>"`` to the mean of the concatenated values.
    """
    dataset = Human36mDataset(**config["dataset"], **kwargs)
    loader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=True,
        pin_memory=False,
        num_workers=0,
    )

    # ``evaluate`` is expected to hand back exactly six sequences of arrays.
    best, best_pa, single, single_pa, pck_any, pck_mean = evaluate(flow, loader)

    labelled = (
        ("test_res", best),
        ("test_res_pa", best_pa),
        ("test_res_single", single),
        ("test_res_pa_single", single_pa),
        ("test_res_pck", pck_any),
        ("test_res_mean_pck", pck_mean),
    )
    summary = {
        f"{name}/{label}": np.concatenate(values).mean()
        for label, values in labelled
    }

    return summary, dataset, loader


def run(
    use_wandb: bool = False,
    config: dict = None,
    output_path: str = "/scripts/mpjpes.npy",
):
    """
    Evaluate a pretrained CondGraphFlow on the occluded joints of the
    Human36m hard subset.

    For every item of the hard subset, 200 poses are sampled from the flow,
    the MPJPE is computed on the joints selected by the item's occlusion mask,
    and the per-sample errors of all items are saved to ``output_path``.

    :param use_wandb: Whether to initialise a wandb run for this evaluation.
    :param config: A dictionary of configuration parameters. Reads ``"seed"``,
        ``"dataset"`` (including ``"dirname"``), ``"experiment_name"`` and,
        optionally, ``"tags"`` and ``"group"``. ``"cuda_accelerated"`` is
        written back into it.
    :param output_path: File the concatenated per-sample errors are saved to.
    :raises TypeError: If ``config`` is ``None``.
    :raises ValueError: If the hard subset contains no items.
    """
    if config is None:
        raise TypeError("run() requires a `config` dictionary, got None")

    set_random_seed(config["seed"])

    # Point the dataset at the test split on a copy: mutating the caller's
    # config in place would append "/test" again on every repeated call.
    dataset_config = dict(config["dataset"])
    dataset_config["dirname"] = dataset_config["dirname"] + "/test"

    if use_wandb:
        wandb.init(
            project="propose_human36m",
            entity=os.environ["WANDB_USER"],
            config={**config, "dataset": dataset_config},
            job_type="evaluation",
            name=f"{config['experiment_name']}_human36m_{time.strftime('%d/%m/%Y::%H:%M:%S')}",
            tags=config.get("tags"),
            group=config.get("group"),
        )

    flow = CondGraphFlow.from_pretrained(
        f'ANONYMOUS/propose_human36m/{config["experiment_name"]}:latest'
    )

    config["cuda_accelerated"] = flow.set_device()
    flow.eval()

    # NOTE: the full test-set and hard-subset runs through `mpjpe_experiment`
    # (and the corresponding wandb logging) are currently disabled; only the
    # occluded-joint evaluation below is performed.
    hard_dataset = Human36mDataset(
        **dataset_config,
        hardsubset=True,
        occlusion_fractions=[],
    )

    ms = []  # per-sample errors of every item, saved at the end
    mpjpes = []  # best (minimum over samples) error of every item
    pbar = tqdm(range(len(hard_dataset)))
    for i in pbar:
        batch = hard_dataset[i][0]
        # Same device handling as in `evaluate`: the flow may have been moved
        # by `set_device()`, so the inputs have to follow it.
        batch.to(flow.device)

        with torch.no_grad():
            samples = flow.sample(200, batch)

        # Joint selector: the item's occlusion mask with an extra `False`
        # inserted at index 9 (presumably to line it up with the 16-joint
        # pose layout -- confirm).
        occluded = np.insert(hard_dataset.occlusions[i], 9, False)

        true_pose = batch["x"].x.cpu().numpy().reshape(-1, 16, 1, 3)[:, occluded]
        sample_poses = (
            samples["x"]
            .x.detach()
            .cpu()
            .numpy()
            .reshape(-1, 16, 200, 3)[:, occluded]
        )

        m = mpjpe(true_pose / 0.0036, sample_poses / 0.0036, dim=1)

        ms.append(m)
        # nanmean: the errors can contain NaN (e.g. when the mask selects no
        # joints), and those entries must not poison the running statistics.
        print(
            np.nanmean(
                np.quantile(
                    np.concatenate(ms), np.array([0, 0.25, 0.5, 0.75, 0.9]), axis=1
                ),
                axis=1,
            )
        )

        mpjpes.append(np.min(m, axis=-1).tolist())

        pbar.set_description(f"MPJPE: {np.nanmean(np.concatenate(mpjpes)):.4f}")

    if not ms:
        raise ValueError("The hard subset is empty; there are no errors to save.")

    np.save(output_path, np.concatenate(ms))
