import argparse
import warnings

import numpy as np
import torch as t
from matplotlib import pyplot as plt
from rich import print
from torch.distributions import Categorical

from .args import add_visualise_args
from .simulators import SIMULATOR
from .solvers import get_solver

# Silence DeprecationWarning and UserWarning output (needed for the Procgen
# gym environment).
for _silenced_category in (DeprecationWarning, UserWarning):
    warnings.filterwarnings("ignore", category=_silenced_category)

# Restrict torch's intra-op CPU parallelism to a single thread.
t.set_num_threads(1)

# Print tensors and arrays in full (never summarised), with two decimals and
# without scientific notation, so per-step values are easy to compare.
_SHARED_PRINT_OPTIONS = {"precision": 2, "threshold": np.inf}
t.set_printoptions(sci_mode=False, **_SHARED_PRINT_OPTIONS)
np.set_printoptions(suppress=True, **_SHARED_PRINT_OPTIONS)

def main():
    """Roll out one episode of a trained solver in ``SIMULATOR`` and print each step.

    Command-line arguments are registered by ``add_visualise_args``; this
    function reads ``args.weights`` (checkpoint passed to ``solver.load``) and
    ``args.debug`` from them. The solver is moved to CUDA when available,
    otherwise it runs on the CPU.

    At every step the solver output for the current simulator state is
    treated as logits of a categorical distribution and an action is
    *sampled* from it (not taken greedily). The episode ends when the
    simulator reports a terminal state or after ``SIMULATOR.max_steps`` steps,
    whichever comes first, and the accumulated reward is printed.
    """
    parser = argparse.ArgumentParser()
    add_visualise_args(parser)
    args = parser.parse_args()

    args.device = "cuda" if t.cuda.is_available() else "cpu"

    solver = get_solver(args)
    solver.load(args.weights, strict=True, verbose=True)
    solver.to(args.device)
    solver.eval()

    print(
        f"\n------------------->  Summary  <-------------------"
        f"\nEnvironment:        {SIMULATOR.__name__}"
        f"\nSolver:             {solver}"
        f"\nWeights file:       {args.weights}"
        f"\nDevice:             {args.device}"
        f"\nDebugging Mode:     {args.debug}"
    )

    # Frames are rendered as images unless debugging mode is requested.
    as_image = not args.debug

    # Inference only: no gradients are needed while rolling out the episode.
    with t.no_grad():
        simulator = SIMULATOR()
        simulator.reset()

        if as_image:
            # Interactive mode lets the figure update between steps instead of
            # blocking on the first show() call.
            plt.ion()
            plt.show()

        episode_reward = 0
        simulator.render(as_image=as_image)
        for _ in range(SIMULATOR.max_steps):
            # Index 0 selects the first entry of the solver output
            # (presumably the single batch element -- TODO confirm).
            q_values = solver(simulator.state_tensor().to(args.device))[0]
            # Sample from softmax(q_values); this is the same distribution
            # that is printed as "Policy" below.
            action = Categorical(logits=q_values).sample().item()

            # The returned state is not needed here: the next input is read
            # back through simulator.state_tensor().
            _, reward, terminal = simulator.step(action)
            episode_reward += reward

            print(
                f"\n\n"
                f"Q-Values: {q_values.cpu().numpy()}\n"
                f"Policy:   {t.softmax(q_values, -1).cpu().numpy()}\n"
                f"Action:   {SIMULATOR.ACTIONS[action]}\n"
                f"Reward:   {reward}\n"
                f"Terminal: {terminal}\n"
            )
            simulator.render(as_image=as_image)

            if terminal:
                break

    print(f"\n\nTotal episode reward: {episode_reward}")


if __name__ == "__main__":
    main()
