from typing import *
import gym.wrappers
import gym
import exp_utils as PQ
from loguru import logger
import torch
import numpy as np

from safe import *
from safe.envs import make_env
from rl_utils.runner import merge_episode_stats

# Default torch device for this module: the GPU when CUDA is usable, else the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


def evaluate(step, runner, policy, tag, *, n_eval_samples):
    """Roll out `policy` with `runner` and report per-key episode statistics.

    The runner is reset, then `runner.run(policy, n_eval_samples)` is called and
    its result is passed through `merge_episode_stats`. For every key of the
    merged stats the mean, standard deviation and number of entries are written
    to `PQ.writer` under `{tag}/{key}/mean`, `/std` and `/n`, and a summary line
    is logged.

    Args:
        step: Global step used in the log line and as `global_step` for the writer.
        runner: Object exposing `reset()` and `run(policy, n_samples)`.
        policy: Policy handed to `runner.run`.
        tag: Prefix for the log line and the scalar names.
        n_eval_samples: Second argument forwarded to `runner.run`.
    """
    runner.reset()
    ep_infos = runner.run(policy, n_eval_samples)

    for key, value in merge_episode_stats(ep_infos).items():
        value = np.array(value)
        mean, std = np.mean(value), np.std(value)
        message = f'# {step}, tag = {tag}, {key} = {mean:.6f} ± {std:.6f} over {len(value)} episodes.'
        if key == 'episode.unsafe':
            # Only warn when at least one entry is positive. `np.any` is required:
            # the truth value of a multi-element array comparison raises ValueError.
            if np.any(value > 0):
                logger.warning(message)
        else:
            logger.info(message)
        PQ.writer.add_scalar(f'{tag}/{key}/mean', mean, global_step=step)
        PQ.writer.add_scalar(f'{tag}/{key}/std', std, global_step=step)
        PQ.writer.add_scalar(f'{tag}/{key}/n', len(value), global_step=step)


class Debugger:
    """Collection of periodic diagnostics for a training run.

    The constructor only stores the objects it is given. Each diagnostic lives
    in a `_do_<name>(t)` method and is triggered through
    `evaluate(t, <name>=True, ...)`. `update(step_type)` counts policy/barrier
    updates between two `_do_updates` reports.
    """

    # Kept as forward references so that creating the class does not evaluate
    # `gym.wrappers.Monitor`; Monitor wrapping is currently disabled and both
    # attributes hold the env returned by `make_env()` (see `init_video_maker`).
    video_env_det: 'gym.wrappers.Monitor'
    video_env_rand: 'gym.wrappers.Monitor'
    # Observation lists returned by `render_video`; only set once `_do_video` ran.
    det_traj: List
    rand_traj: List

    def __init__(self, env, policy, mean_policy, L, model, runner, horizon, s0, s_opt, s_opt_grad, s_opt_sample, L_opt,
                 fns, obj_eval, buffer_out, safe_invariant, FLAGS):
        """Store all collaborators and create the video environment(s)."""
        self.env = env
        self.policy = policy
        self.model = model
        self.mean_policy = mean_policy
        self.runner = runner
        self.horizon = horizon
        self.L = L
        self.s0 = s0
        self.s_opt = s_opt
        self.s_opt_sample = s_opt_sample
        self.s_opt_grad = s_opt_grad
        self.L_opt = L_opt
        self.fns = fns
        self.obj_eval = obj_eval
        # Counters incremented by `update` and reset by `_do_updates`.
        self.n_policy_updates = 0
        self.n_barrier_updates = 0
        self.buffer_out = buffer_out
        # Free-form results of diagnostics (e.g. 'policy_virt_safe').
        self.status = {}
        self.safe_invariant = safe_invariant
        self.FLAGS = FLAGS

        self.init_video_maker()

    def init_video_maker(self):
        """Create the `videos` directory under the log dir and the video env(s)."""
        video_path = PQ.log_dir / 'videos'
        # exist_ok: re-using an existing log directory must not crash here.
        video_path.mkdir(exist_ok=True)
        env = make_env()
        # NOTE: recording through gym.wrappers.Monitor is disabled. When enabled,
        # each env was wrapped as
        #   Monitor(env, video_path, force=True, video_callable=lambda episode_id: True, uid=...)
        # with uid 'det' / 'rand'. For now both attributes share one unwrapped env.
        self.video_env_det = self.video_env_rand = env

    def update(self, step_type):
        """Count one update of kind `step_type` ('policy' or 'barrier').

        Raises:
            AssertionError: If `step_type` is neither 'policy' nor 'barrier'.
        """
        if step_type == 'policy':
            self.n_policy_updates += 1
        elif step_type == 'barrier':
            self.n_barrier_updates += 1
        else:
            # Explicit raise instead of `assert 0`, which `python -O` strips.
            raise AssertionError(f'unknown step_type: {step_type!r}')

    def _do_policy(self, t):
        """Evaluate `self.policy` with the module-level `evaluate`, tagged 'policy'."""
        evaluate(t, self.runner, self.policy, 'policy', n_eval_samples=self.horizon)

    def _do_mean_policy(self, t):
        """Evaluate `self.mean_policy` with the module-level `evaluate`, tagged 'mean_policy'."""
        evaluate(t, self.runner, self.mean_policy, 'mean_policy', n_eval_samples=self.horizon)

    def _do_L_grad(self, t):
        """Log the norm of all available gradients of `self.L` and the value `L.net(s0)`."""
        # Collect every existing gradient; a parameter without `.grad` (or a
        # module without parameters) simply contributes nothing.
        grads = [p.grad for p in self.L.parameters() if p.grad is not None]
        if grads:
            L_grad_norm = torch.nn.utils.parameters_to_vector(grads).norm().item()
        else:
            L_grad_norm = 0.

        PQ.writer.add_scalar('L/grad_norm', L_grad_norm, global_step=t)
        g_s0 = self.L.net(self.s0).item()
        logger.debug(f"g(s0) = {g_s0:.6f}, L grad norm = {L_grad_norm:.6f}, ")

    def _do_updates(self, t):
        """Report update counts divided by `FLAGS.n_eval_iters`, then reset the counters."""
        p_policy_update = self.n_policy_updates / self.FLAGS.n_eval_iters
        p_barrier_update = self.n_barrier_updates / self.FLAGS.n_eval_iters
        PQ.writer.add_scalar('policy/p_updates', p_policy_update, global_step=t)
        PQ.writer.add_scalar('barrier/p_updates', p_barrier_update, global_step=t)
        self.n_policy_updates = 0
        self.n_barrier_updates = 0
        logger.debug(f"Pr(policy updates) = {p_policy_update:.3f}, Pr(barrier updates) = {p_barrier_update:.3f}")

    def _do_s_grad(self, t):
        """Re-initialise `s_opt_grad`, run it and `s_opt` for 10000 steps each.

        Both optimisers are evaluated every 1000 steps (before the step). The
        'optimal' entry of the last `s_opt_grad` evaluation is written to
        `grad_opt/optimal`.
        """
        self.s_opt_grad.reinit()
        for i in range(10000):
            if i % 1000 == 0:
                # Always runs at i == 0, so `grad_opt_info` is defined below.
                grad_opt_info = self.s_opt_grad.evaluate(step=t)
            self.s_opt_grad.step()

        PQ.writer.add_scalar('grad_opt/optimal', grad_opt_info['optimal'], global_step=t)

        for i in range(10000):
            if i % 1000 == 0:
                self.s_opt.evaluate(step=t)
            self.s_opt.step()

    def _do_s(self, t):
        """Evaluate the state optimisers and flush the `opt_progress/*` meters."""
        self.s_opt_sample.evaluate(step=t)
        self.s_opt.evaluate(step=t)
        self.L_opt.evaluate(self.s_opt.s, step=t)

        for i in range(self.FLAGS.opt_s.n_steps):
            PQ.writer.add_scalar(f's/step_{i}', PQ.meters[f'opt_progress/{i}'].mean, global_step=t)
        PQ.meters.purge('opt_progress/')

    def _do_plot(self, t):
        """Save `fig-{t}.png` for the two supported env ids; no-op for any other id.

        Reads `self.det_traj`, which is only assigned by `_do_video`.
        NOTE(review): `plot_pendulum_set` is not defined in this file; presumably
        it comes from `from safe import *` -- confirm.
        """
        if self.FLAGS.env.id == 'MyPendulum-v0':
            clouds = {}
            if self.FLAGS.opt_s.method in ['grad', 'MALA', 'CEM', 'metropolis']:
                clouds['s'] = self.s_opt.s.cpu().detach().numpy()
            clouds['traj'] = np.array(self.det_traj)

            plot_fns = {key: self.fns[key] for key in ['L', 'hardD', 'softD']}
            plot_pendulum_set(plot_fns, device, clouds, PQ.log_dir / f'fig-{t}.png', f'# {t}')
        elif self.FLAGS.env.id == 'SafeInvertedPendulum-v2':
            Vt = torch.eye(4, device=device, dtype=torch.float32)

            plot_fns = {
                # Embed the 2-D grid points into 4-D via the first two rows of the
                # identity (remaining two coordinates are zero) before calling L.
                'L': lambda xs: self.L(xs.reshape(-1, 2).mm(Vt[:2]).reshape(201, 201, 4)),
                # 'hardD': lambda xs: self.fns['hardD'](xs.reshape(-1, 2).mm(Vt[:2]).reshape(201, 201, 4)),
            }
            traj = np.array(self.det_traj)
            clouds = {
                'traj': traj,
            }

            plot_pendulum_set(plot_fns, device, clouds, PQ.log_dir / f'fig-{t}.png', f'# {t}',
                              y_max=0.2, y_min=-0.2, x_max=1.0, x_min=-1.0, xlabel="pos", ylabel="angle")

    def _do_save(self, t):
        """Write the state dicts of L, s_opt, policy, model and safe_invariant to `ckpt-{t}.pt`."""
        torch.save({
            'L': self.L.state_dict(),
            's': self.s_opt.state_dict(),
            'policy': self.policy.state_dict(),
            'models': self.model.state_dict(),
            'safe_invariant': self.safe_invariant.state_dict(),
        }, PQ.log_dir / f'ckpt-{t}.pt')

    def _do_expl(self, t):
        """Log the mean of the `expl/backup` meter and reset it."""
        logger.debug(f"[expl]: backup prob = {PQ.meters['expl/backup'].mean:.6f}")
        PQ.meters['expl/backup'].reset()

    def evaluate(self, t, **kwargs):
        """Run `_do_<key>(t)` for every keyword argument with a truthy value.

        Handlers run in keyword order; a key without a matching `_do_<key>`
        method raises AttributeError.
        """
        keys = [key for key, value in kwargs.items() if value]
        if not keys:
            return
        for key in keys:
            getattr(self, f'_do_{key}')(t)

    @torch.no_grad()
    def _do_virt_safe(self, t):
        """Roll `mean_policy` through `model.models[0]` from `s0` and track L.

        Runs up to 10000 model steps and stops early the first time L crosses
        from below 1 to above 1. Stores `max_L <= 1` in
        `status['policy_virt_safe']`.
        NOTE(review): `self.L` is switched to eval mode and not switched back.
        """
        self.L.eval()
        state = self.s0
        cur_L = self.L(state[None])[0].item()
        max_L = cur_L
        last_L, last_s = cur_L, state
        # Separate loop variable so the `t` argument is not shadowed.
        for step_idx in range(10000):
            action = self.mean_policy(state)
            state = self.model.models[0](state, action)
            cur_L = self.L(state[None])[0].item()
            max_L = max(max_L, cur_L)

            if last_L < 1 < cur_L:
                print('virt vio', step_idx, last_s.cpu().numpy(), last_L, cur_L)
                break
            last_L, last_s = cur_L, state
        self.status['policy_virt_safe'] = max_L <= 1

        logger.debug(f"[virt safe] max L = {max_L:.6f}")

    def _do_video(self, t):
        """Roll out `mean_policy` ('det') and `policy` ('rand') and keep their observations."""
        self.det_traj = render_video(t, self.video_env_det, self.mean_policy, 'det')
        self.rand_traj = render_video(t, self.video_env_rand, self.policy, 'rand')

    def _do_buf_out(self, t):
        """Log `buffer_out.index`."""
        logger.debug(f"[buf out] size = {self.buffer_out.index}")


def render_video(t, video_env, policy, tag):
    """Run one episode of `policy` in `video_env` and return every observation.

    The env is reset first, then its `episode_id` attribute is set to `t`.
    Actions come from `policy.get_actions(observation)`; stepping continues
    until the env reports `done`. The summed reward and the last `info` are
    logged at debug level.

    Returns:
        List with the reset observation followed by one observation per step.
    """
    obs = video_env.reset()
    observations = [obs]
    video_env.episode_id = t
    return_ = 0.

    # At least one step is always taken; stop as soon as the env says done.
    while True:
        obs, reward, done, info = video_env.step(policy.get_actions(obs))
        return_ += reward
        observations.append(obs)
        if done:
            break

    logger.debug(f'[video {tag}] iter = {t}: return = {return_:.6f}, last info = {info}')
    return observations
