import argparse
import math
import multiprocessing as mp
import os
import warnings
from copy import deepcopy
from time import sleep

import numpy as np
import torch as t
from rich import print
from tqdm import tqdm

from .args import add_evaluate_args
from .simulators import SIMULATOR
from .solvers import get_solver

# For Procgen gym environment: silence the DeprecationWarning and UserWarning
# categories for the whole process.
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
# Limit PyTorch to a single intra-op CPU thread in this process -- presumably
# so that several evaluation workers running side by side do not oversubscribe
# the CPU (TODO confirm).
t.set_num_threads(1)


def worker(
    wid,
    solver,
    episode_rewards,
    success,
    progress,
    start,
    end,
    max_batch_size,
):
    """Roll out ``solver`` on the environments with ids in ``[start, end)``.

    One ``SIMULATOR`` instance is created per environment id. All environments
    are stepped in lock-step for at most ``SIMULATOR.max_steps`` timesteps;
    an environment is dropped from the rollout as soon as it reports a
    terminal transition.

    Args:
        wid: Worker index. Selects the CUDA device (``wid`` modulo the device
            count) when CUDA is available, and the slot of ``progress`` this
            worker writes to.
        solver: Object exposing ``to``, ``eval`` and ``predict``. It is
            deep-copied, so the caller's instance is not modified.
        episode_rewards: Mapping indexed by environment id; receives the
            accumulated reward of every environment in the range.
        success: Mapping indexed by environment id; receives the value of
            ``is_solved()`` taken when the environment terminated (``False``
            for environments that never terminated).
        progress: Sequence indexed by ``wid``; receives
            ``(timestep, solved_fraction)`` tuples, the last one carrying
            ``SIMULATOR.max_steps`` as timestep to signal completion.
        start: First environment id handled by this worker (inclusive).
        end: Last environment id handled by this worker (exclusive).
        max_batch_size: Maximum number of states passed to a single
            ``solver.predict`` call.
    """
    use_cuda = t.cuda.is_available()
    if use_cuda:
        t.cuda.set_device(wid % t.cuda.device_count())
    device = "cuda" if use_cuda else "cpu"

    solver = deepcopy(solver)
    solver.to(device)
    solver.eval()

    # Denominator of the reported success fraction. Clamped so that a worker
    # that was handed an empty range (end <= start) reports 0 instead of
    # raising ZeroDivisionError before it can signal completion.
    n_envs = max(end - start, 1)

    with t.inference_mode():
        alive = list(range(start, end))

        local_episode_rewards = dict.fromkeys(alive, 0)
        local_success = dict.fromkeys(alive, False)
        all_envs = {env_id: SIMULATOR() for env_id in alive}

        for env in all_envs.values():
            env.reset()

        n_success = 0
        for ts in range(SIMULATOR.max_steps):
            progress[wid] = (ts, n_success / n_envs)

            # Query the solver in chunks of at most `max_batch_size` states.
            # NOTE(review): the states are concatenated along dim 0, so each
            # state tensor presumably carries a leading batch dimension --
            # confirm against SIMULATOR.state_tensor.
            states = [all_envs[env_id].state_tensor() for env_id in alive]
            actions = []
            for offset in range(0, len(states), max_batch_size):
                batch = t.cat(states[offset : offset + max_batch_size], dim=0)
                actions.extend(solver.predict(batch.to(device)))

            finished = set()
            for env_id, action in zip(alive, actions):
                _, reward, terminal = all_envs[env_id].step(int(action))
                local_episode_rewards[env_id] += reward

                if terminal:
                    local_success[env_id] = all_envs[env_id].is_solved()
                    n_success += local_success[env_id]
                    finished.add(env_id)

            # Drop every finished environment in one pass, keeping the order
            # of the remaining ids.
            if finished:
                alive = [env_id for env_id in alive if env_id not in finished]

            if not alive:
                break

    # Reporting `max_steps` tells the parent that this worker is done.
    progress[wid] = (SIMULATOR.max_steps, n_success / n_envs)
    for env_id in range(start, end):
        episode_rewards[env_id] = local_episode_rewards[env_id]
        success[env_id] = local_success[env_id]

    del solver


if __name__ == "__main__":
    # Evaluation entry point: parse the arguments, load the solver weights,
    # fan the levels out over worker processes, then print and save a summary.

    # "spawn" starts every worker in a fresh interpreter; the second argument
    # (force=True) overrides a start method that may already have been set.
    mp.set_start_method("spawn", True)

    parser = argparse.ArgumentParser()
    add_evaluate_args(parser)
    args = parser.parse_args()

    # Both values are used as divisors / range sizes below.
    if args.n_levels < 1 or args.n_workers < 1:
        parser.error("n_levels and n_workers must both be at least 1")

    # Default batch size: a worker's whole share of levels in a single batch.
    if not args.batch_size:
        args.batch_size = math.ceil(args.n_levels / args.n_workers)

    solver = get_solver(args)
    # Explicit check instead of `assert`: assert statements are removed under
    # `python -O`, which would silently skip the load call itself.
    if not solver.load(args.weights, strict=True, verbose=True):
        raise RuntimeError(f"Error while loading the weights from file: {args.weights}")

    for k, v in sorted(vars(args).items()):
        print(f"{k.ljust(20)} : {v}")

    levels_per_worker = math.ceil(args.n_levels / args.n_workers)
    # Rounding `levels_per_worker` up can leave trailing workers without any
    # level (e.g. 6 levels over 4 workers); only start the workers that
    # actually receive a non-empty range.
    n_active_workers = math.ceil(args.n_levels / levels_per_worker)

    # A single manager process serves all shared containers and is shut down
    # when the `with` block exits.
    with mp.Manager() as manager:
        episode_rewards = manager.dict({i: 0 for i in range(args.n_levels)})
        success = manager.dict({i: False for i in range(args.n_levels)})
        progress = manager.list([(0, 0)] * n_active_workers)

        print("starting evaluators...\n")
        processes = []
        for wid in range(n_active_workers):
            first_level = levels_per_worker * wid
            last_level = min(first_level + levels_per_worker, args.n_levels)
            processes.append(
                mp.Process(
                    target=worker,
                    args=(
                        wid,
                        solver,
                        episode_rewards,
                        success,
                        progress,
                        first_level,
                        last_level,
                        args.batch_size,
                    ),
                )
            )

        try:
            for proc in processes:
                proc.start()

            # The context manager closes the bar even when interrupted.
            with tqdm(total=SIMULATOR.max_steps) as pbar:
                prev_ts, pbar_ts = 0, 0
                while pbar_ts < SIMULATOR.max_steps:
                    sleep(1)
                    # Sample liveness BEFORE reading the progress so that the
                    # read following the last worker's exit sees its final
                    # update.
                    any_running = any(proc.is_alive() for proc in processes)

                    pbar_ts, pbar_success = zip(*progress)
                    pbar_ts = math.ceil(np.mean(pbar_ts))

                    pbar.update(pbar_ts - prev_ts)
                    pbar.set_description(f"Success: {np.mean(pbar_success) * 100:.02f}%  |  Timestep")
                    prev_ts = pbar_ts

                    # A worker that died never reports `max_steps`; stop
                    # polling once nobody is left to make progress.
                    if not any_running:
                        break

            for proc in processes:
                proc.join()

            n_failed = sum(proc.exitcode != 0 for proc in processes)
            if n_failed:
                print(
                    f"Warning: {n_failed} worker(s) exited abnormally; "
                    "their levels keep the default reward and success values."
                )

        except KeyboardInterrupt:
            print("Stopping evaluation...")
            for proc in processes:
                # Processes that were never started or already exited cannot
                # be killed; skip them.
                if proc.is_alive():
                    proc.kill()
                    proc.join()

        # Copy the results out of the manager before it shuts down.
        episode_rewards = np.array(list(episode_rewards.values()))
        success = np.array(list(success.values()))

    # Only applicable to Navigation
    collisions = np.sum(episode_rewards <= SIMULATOR.failed)
    absolute_collision_rate = collisions / len(success)
    # Collisions relative to the number of unsolved levels (guarded against
    # division by zero when every level was solved).
    relative_collision_rate = collisions / max(len(success) - np.sum(success), 1e-10)

    print(
        f"\nSuccess rate: {np.mean(success) * 100:.02f} ",
        # Standard error of the mean success rate.
        f"(±{np.std(success) / np.sqrt(len(success)) * 100:.02f}%)",
        f"\nAbsolute Collision Rate: {absolute_collision_rate * 100:.02f}%",
        f"\nRelative Collision Rate: {relative_collision_rate * 100:.02f}%",
        f"\nMean Episode Reward: {np.mean(episode_rewards):.02f} ",
        f"(±{np.std(episode_rewards):.02f})",
    )

    # `dirname` is empty for a bare file name, and os.makedirs("") raises.
    save_dir = os.path.dirname(args.save_summary_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    t.save(
        {
            "simulator": SIMULATOR.__name__,
            "solver": str(solver),
            "episode_rewards": episode_rewards,
            "success_rate": np.mean(success),
            "absolute_collision_rate": absolute_collision_rate,
            "relative_collision_rate": relative_collision_rate,
        },
        args.save_summary_path,
    )
