import sys
sys.path.append("./common")
sys.path.append("./auto_LiRPA")
from auto_LiRPA import BoundedModule, BoundedTensor
from auto_LiRPA.perturbations import PerturbationLpNorm
from argparser import argparser
import copy
import numpy as np
from read_config import load_config
from attacks import attack
from common.wrappers import make_atari, wrap_deepmind, wrap_pytorch, make_atari_cart
from models_smooth import QNetwork, model_setup
import torch.optim as optim
import torch
import torch.autograd as autograd
import time
import os
import argparse
import random
from datetime import datetime
from scipy.stats import norm
from PIL import Image

from utils import Logger, get_acrobot_eps, test_plot 
from async_env import AsyncEnv
from config_v_table import get_V_table

from dataclasses import dataclass, field
from typing import Any
from queue import PriorityQueue

# Unsigned integer dtypes. Observations of one of these dtypes are divided by
# 255 before being passed to the model (see take_action).
UINTS = [np.uint8, np.uint16, np.uint32, np.uint64]

# True when PyTorch reports an available CUDA device.
USE_CUDA = torch.cuda.is_available()


def Variable(*args, **kwargs):
    """Build an ``autograd.Variable`` and move it to the GPU when ``USE_CUDA`` is set.

    All positional and keyword arguments are forwarded unchanged to
    ``autograd.Variable``.
    """
    var = autograd.Variable(*args, **kwargs)
    if USE_CUDA:
        return var.cuda()
    return var

# Running counter; each Elem takes the current value as its id and bumps it.
global_id = 0


class Elem(object):
    """A deferred branch of the search, ordered by its radius.

    Attributes:
        gid: unique, increasing id taken from the module-level ``global_id``.
        env: environment object to resume from (``take_action`` stores a deep copy).
        s:   state at which the branch starts.
        a:   action to execute first when the branch is resumed.
        rad: radius associated with the branch; the only field used for ordering.
        re:  cumulative reward collected before reaching ``s``.
        no:  step index of ``s``.
    """

    def __init__(self, env, s, a, rad, re, no):
        global global_id
        # Right-hand side is evaluated first, so gid gets the pre-increment value.
        self.gid, global_id = global_id, global_id + 1

        self.env, self.s, self.a = env, s, a
        self.rad, self.re, self.no = rad, re, no

    def __lt__(self, other):
        """Order elements by radius only."""
        smaller = self.rad < other.rad
        return smaller

# Smallest cumulative reward reached by any explored trajectory so far
# (updated by expand); 1e100 stands for "nothing found yet".
re_min = 1e100

# radius -> value of re_min at the moment a branch of that radius is dequeued.
certify_map = dict()

# Raw bytes of every state that take_action has been called on.
state_set = set()
# Memo used by expand: raw bytes of a successor state -> value returned for it.
state_dict = {}

@dataclass(order=True)
class PrioritizedItem:
    """Priority-queue entry that is ordered by ``priority`` alone.

    The payload is excluded from comparisons (``compare=False``), so two
    entries with the same priority never fall back to comparing payloads.
    This keeps the heap usable with payloads that define no ordering.
    """
    # Sort key; callers pass float radii, so the annotation is float rather than int.
    priority: float
    # Arbitrary payload carried alongside the key; never compared.
    item: Any = field(compare=False)

# Queue of deferred branches, smallest radius first.
p_que = PriorityQueue()

# Seed the queue with a placeholder entry of radius 1e100 that carries no
# environment, state or action (step index -1). It appears to act as a
# sentinel that sorts after every real entry.
p_que.put(
    PrioritizedItem(
        priority=1e100,
        item=Elem(env=None, s=None, a=None, rad=1e100, re=0, no=-1),
    )
)


def take_action(env, state, rad_lim, re_cur, no):
    """Split the actions available at ``state`` according to their radius.

    The Q-values of ``state`` are computed once and clamped to
    ``[v_lo, v_hi]``. Every action is then tried from a snapshot of the
    emulator, and actions that end the episode are dropped. For each
    remaining action ``a`` a radius is derived from the gap between the
    greedy action's value and ``a``'s value:

    * if the gap is at least ``2 * delta`` the radius is
      ``sigma / 2 * (ppf((q* - delta - v_lo) / (v_hi - v_lo))
      - ppf((q_a + delta - v_lo) / (v_hi - v_lo)))``;
    * otherwise the radius is 0.

    Actions with ``rad <= rad_lim`` are returned. Every other action is
    deferred: a deep copy of the environment, restored to the snapshot, is
    pushed onto ``p_que`` keyed by the radius. Deferral is skipped when
    ``state`` has been seen before.

    Args:
        env: environment exposing ``ale.cloneState`` / ``ale.restoreState``
            and either ``step`` or ``async_step`` / ``wait_step``.
        state: numpy observation the environment is currently in.
        rad_lim: radius threshold at or below which actions are returned.
        re_cur: cumulative reward so far; only stored in queued elements.
        no: step index of ``state``; only stored in queued elements.

    Returns:
        List of action indices whose radius does not exceed ``rad_lim``.

    Raises:
        AssertionError: if the arg-max of the clamped Q-values differs from
            the action reported by the model.
        NotImplementedError: if a computed radius is NaN.

    Side effects:
        Adds ``state`` to ``state_set``, may push onto ``p_que``, and leaves
        the emulator in whatever state the last tried action produced; the
        caller is expected to restore its own snapshot.
    """
    global sigma, model, state_set

    # im = Image.fromarray(state.squeeze(), 'L')
    # im.save(os.path.join("frames", '{:05d}.bmp'.format(no)))

    # check whether this state was ever visited before
    state_key = state.tobytes()
    vis = state_key in state_set
    if vis:
        logger.log(f'########################################### duplicated states encountered, state_set = {len(state_set)}')
    state_set.add(state_key)

    state_tensor = torch.from_numpy(np.ascontiguousarray(state)).unsqueeze(0)
    # Only move to the GPU when one is available, consistent with USE_CUDA
    # being used elsewhere in this file.
    if USE_CUDA:
        state_tensor = state_tensor.cuda()
    state_tensor = state_tensor.to(torch.float32)
    if dtype in UINTS:
        state_tensor /= 255

    a_star, tilde_q = model.act(state_tensor, return_q=True, cert=False)

    # in case the values fall outside the tabulated range
    tilde_q = torch.clamp(tilde_q, min=v_lo, max=v_hi)

    a_star = a_star[0]
    tilde_q = tilde_q[0]

    assert(torch.argmax(tilde_q) == a_star)

    # Value of the greedy action; identical for every candidate below.
    val_1 = tilde_q[a_star].cpu()

    a_list = []

    snapshot = env.ale.cloneState()
    for a in range(action_no):
        env.ale.restoreState(snapshot)

        # The step is executed only to find out whether `a` ends the episode;
        # the observation and reward are not needed here.
        if training_config['use_async_env']:
            env.async_step(a)
            _, _, done, _ = env.wait_step()
        else:
            _, _, done, _ = env.step(a)

        if done:
            continue

        val_2 = tilde_q[a].cpu()

        if val_1 - delta >= val_2 + delta:
            rad = sigma / 2 * (
                                norm.ppf((val_1 - delta - v_lo) / (v_hi - v_lo)) - 
                                norm.ppf((val_2 + delta - v_lo) / (v_hi - v_lo))
                                )
            logger.log(f'certified: radius = {rad}')
        else:
            rad = 0     # cannot certify
            logger.log(f'cannot certify: val_1 = {val_1}, val_2 = {val_2}, delta = {delta}, radius = 0')

        if np.isnan(rad):
            logger.log(f'a_star: {tilde_q[a_star]}, a: {tilde_q[a]}, v_lo: {v_lo}, v_hi: {v_hi}')
            raise NotImplementedError

        if rad <= rad_lim:
            a_list.append(a)
        elif not vis:
            env.ale.restoreState(snapshot)  # revert back for storing

            elem = Elem(copy.deepcopy(env), state, a, rad, re_cur, no)
            p_que.put(PrioritizedItem(elem.rad, elem))

    return a_list


def expand(env, state, rad_lim=0, re_cur=0, no=0):
    """Depth-first exploration of every action allowed under ``rad_lim``.

    Starting from ``state`` (the environment must currently be in that
    state), the actions returned by ``take_action`` are executed one by one
    from a common emulator snapshot and explored recursively. The global
    ``re_min`` is lowered whenever a trajectory ends: either the frame limit
    is reached or no action is returned.

    Args:
        env: environment positioned at ``state``.
        state: current observation.
        rad_lim: radius threshold forwarded to ``take_action``.
        re_cur: cumulative (shaped) reward collected before ``state``.
        no: step index of ``state``.

    Returns:
        The minimum, over the explored actions, of the step reward plus the
        value returned for the successor state. 0 is returned when the
        branch is pruned, the frame limit is reached, or no action is
        available.
    """
    global state_dict
    global re_min

    # logger.log(f'no = {no}, re_cur = {re_cur}, rad_lim = {rad_lim}, sizeq = {len(p_que.queue)}, map = {len(certify_map)}')

    # Prune: this branch already collected at least the best known minimum.
    if re_cur >= re_min:
        logger.log(f'************************************************ pruning at no={no}')
        return 0

    # Frame limit reached: the trajectory ends here.
    if no >= max_frames_per_episode:
        re_min = min(re_min, re_cur)
        logger.log(f'================================================ run to the end with re={re_min}')
        return 0

    # take_action steps the emulator while probing; undo that afterwards.
    snapshot = env.ale.cloneState()
    a_list = take_action(env, state, rad_lim, re_cur, no)
    env.ale.restoreState(snapshot)

    # logger.log('a_list', len(a_list))

    # No action to follow: treat as the end of the trajectory.
    if not a_list:
        re_min = min(re_min, re_cur)
        return 0

    future_min = 1e100
    for a in a_list:
        env.ale.restoreState(snapshot)

        if training_config['use_async_env']:
            env.async_step(a)
            step_result = env.wait_step()
        else:
            step_result = env.step(a)
        next_state, reward, done, _ = step_result

        assert not done

        # reward shaping for Pong
        reward = 0 if reward == -1 else reward

        logger.log(f'no = {no+1}, re_cur = {re_cur+reward}, rad_lim = {rad_lim}, sizeq = {len(p_que.queue)}, map = {len(certify_map)}, a = {a}, a_list = {a_list}')

        # Successor already explored under the current memo: reuse its value.
        memo_key = next_state.tobytes()
        if memo_key in state_dict:
            logger.log(f'+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+- skipping duplicated states, state_dict = {len(state_dict)}')
            future_min = min(future_min, state_dict[memo_key] + reward)
            re_min = min(re_min, re_cur + future_min)
            continue

        ret = expand(env, next_state, rad_lim, re_cur + reward, no + 1)

        # The key is read again after the recursion, exactly where the
        # memo is written, then the smaller of old and new value is kept.
        memo_key = next_state.tobytes()
        state_dict[memo_key] = min(state_dict.get(memo_key, ret), ret)

        future_min = min(future_min, ret + reward)

    return future_min


# ---------------------------------------------------------------------------
# Configuration, output directory and model path
# ---------------------------------------------------------------------------
args = argparser()

config = load_config(args)
prefix = config['env_id']
training_config = config['training_config']
test_config = config['test_config']
sample_config = config['sample_config']
attack_config = test_config["attack_config"]
certify_global = attack_config['certify_global']
n_runs = test_config['num_episodes']

# Output directory: [<path_prefix>/]<env_id>_tree_[<name_suffix>]
prefix += '_tree_'
if config['name_suffix']:
    prefix += config['name_suffix']
if config['path_prefix']:
    prefix = os.path.join(config['path_prefix'], prefix)

if 'load_model_path' in test_config and os.path.isfile(test_config['load_model_path']):
    # An explicit checkpoint was given: create the output directory on demand.
    model_path = test_config['load_model_path']
    os.makedirs(prefix, exist_ok=True)
    test_log = os.path.join(prefix, test_config['log_name'])
else:
    # No usable explicit checkpoint: fall back to the checkpoints in `prefix`.
    if not os.path.exists(prefix):
        # The path is now actually substituted into the message.
        raise ValueError('Path {} not exists, please specify test model path.'.format(prefix))
    test_log = os.path.join(prefix, test_config['log_name'])

    # model_path is needed right below (V table lookup), so it has to be
    # resolved here. Selection rule: highest-numbered 'best_frame_<n>.pth',
    # otherwise highest-numbered 'frame_<n>.pth'.
    _ckpt_files = [f for f in os.listdir(prefix)
                   if os.path.isfile(os.path.join(prefix, f)) and os.path.splitext(f)[1] == '.pth']
    _best_idx = [int(f[11:-4]) for f in _ckpt_files if 'best' in f]
    _plain_idx = [int(f[6:-4]) for f in _ckpt_files if 'best' not in f]
    if _best_idx:
        model_path = os.path.join(prefix, 'best_frame_{}.pth'.format(max(_best_idx)))
    elif _plain_idx:
        model_path = os.path.join(prefix, 'frame_{}.pth'.format(max(_plain_idx)))
    else:
        raise ValueError('No .pth checkpoint found in {}, please specify test model path.'.format(prefix))

# Smoothing settings: flag, m and sigma as given in the test config.
smooth = test_config['smooth']
m = test_config['m']
sigma = test_config['sigma']

# Value range [v_lo, v_hi] looked up for this model and sigma.
V_table = get_V_table(model_path)
v_lo, v_hi = V_table[sigma]

# Deviation term (v_hi - v_lo) * sqrt(log(1 / conf_bound) / (2 m)),
# added to / subtracted from Q-values before a radius is computed.
delta = ((v_hi - v_lo) * np.sqrt(np.log(1 / sample_config['conf_bound']) / (2 * m)))

# Encode the smoothing parameters in the log file name.
if smooth:
    test_log = test_log + f"_m-{m}_sigma-{sigma}"

# Derive a tag from the first keyword found in the model path. The order of
# the checks matters because a path may contain several keywords.
if 'nat' in model_path:
    model_tag = '_nat'
elif 'pgd' in model_path:
    model_tag = '_pgd'
elif any(keyword in model_path for keyword in ('cov', 'convex')):
    model_tag = '_cov'
elif 'aug' in model_path:
    # The value sits between the last '_' and the last '.' of the path.
    p1 = model_path.rfind('_')
    p2 = model_path.rfind('.')
    aug_rad = model_path[p1+1:p2]
    model_tag = f'_aug-{aug_rad}'
elif 'adv' in model_path:
    model_tag = '_adv'
elif 'frame' in model_path:
    # Same slicing rule as for 'aug'.
    p1 = model_path.rfind('_')
    p2 = model_path.rfind('.')
    frame_cnt = model_path[p1+1:p2]
    model_tag = f'_frame-{frame_cnt}'
else:
    raise NotImplementedError(f'model_path = {model_path} type unrecognized!')
test_log += model_tag

num_episodes = test_config['num_episodes']
# Frame limit per trajectory; also encoded in the log file name.
max_frames_per_episode = test_config['max_frames_per_episode']
test_log = test_log + f'_max-frames-{max_frames_per_episode}'

# The log file is opened for writing (truncating any previous run).
log_file = open(test_log, "w")
logger = Logger(log_file)

# expand() recurses once per step, so the default recursion limit is raised.
sys.setrecursionlimit(20000)
logger.log('set recursion limit at 20000 done!')

# Record how the run was started.
logger.log('Command line:', " ".join(sys.argv))
logger.log(args)
logger.log(config)
logger.log(f'v_lo = {v_lo : .6f}, v_hi = {v_hi : .6f}')

certify = test_config.get('certify', False)

# Reward clipping and per-life episodes are switched off for this run.
# Note this writes into the dict held by training_config.
env_params = training_config['env_params']
env_params['clip_rewards'] = False
env_params['episode_life'] = False
env_id = config['env_id']

# Ids containing "NoFrameskip" get the Atari wrapper stack; everything else
# goes through make_atari_cart.
if "NoFrameskip" in env_id:
    env = wrap_pytorch(wrap_deepmind(make_atari(env_id), **env_params))
else:
    env = make_atari_cart(env_id)

state = env.reset()

# Observation dtype and number of actions, read as globals by take_action.
dtype = state.dtype
action_no = env.action_space.n
logger.log("env_shape: {}, num of actions: {}".format(env.observation_space.shape, action_no))

# ---------------------------------------------------------------------------
# Build the model and load its weights
# ---------------------------------------------------------------------------
model_width = training_config['model_width']
robust_model = certify
dueling = training_config.get('dueling', True)

model = model_setup(test_config, attack_config, sample_config, env_id, env.observation_space.shape, env.action_space.n,
                    robust_model, logger, USE_CUDA, dueling, model_width, 
                    smooth, m, sigma, v_lo, v_hi)

if 'load_model_path' in test_config and os.path.isfile(test_config['load_model_path']):
    model_path = test_config['load_model_path']
else:
    # Pick the highest-numbered 'best_frame_<n>.pth', or else the
    # highest-numbered 'frame_<n>.pth', from the output directory.
    logger.log("choosing the best model from " + prefix)
    all_idx = [int(f[6:-4]) for f in os.listdir(prefix) if os.path.isfile(os.path.join(prefix, f)) and os.path.splitext(f)[1]=='.pth' and 'best' not in f]
    all_best_idx = [int(f[11:-4]) for f in os.listdir(prefix) if os.path.isfile(os.path.join(prefix, f)) and os.path.splitext(f)[1]=='.pth' and 'best' in f]
    if all_best_idx:
        model_frame_idx = max(all_best_idx)
        model_name = 'best_frame_{}.pth'.format(model_frame_idx)
    else:
        model_frame_idx = max(all_idx)
        model_name = 'frame_{}.pth'.format(model_frame_idx)
    model_path = os.path.join(prefix, model_name)

logger.log('model loaded from ' + model_path)

# Without a GPU, remap the stored tensors to the CPU so that checkpoints
# saved on a CUDA device can still be loaded; with a GPU the default
# behaviour (map_location=None) is kept.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
checkpoint = torch.load(model_path, map_location=None if USE_CUDA else 'cpu')

# Checkpoints with 'pgd' in their path are loaded into the whole model;
# all others are loaded into `model.features`.
if 'pgd' in model_path:
    model.load_state_dict(checkpoint)
else:
    model.features.load_state_dict(checkpoint)
logger.log(f'model={model_path} loading done!')



# Bookkeeping variables (not read again in the visible part of this script).
all_rewards, episode_reward = [], 0

# Seed the environment with a fresh random seed and start a new episode.
seed = random.randint(0, sys.maxsize)
logger.log('reseting env with seed', seed)
env.seed(seed)
state = env.reset()
start_time = time.time()

if training_config['use_async_env']:
    # Create an environment in a separate process, run asynchronously
    # NOTE(review): this passes args.seed, not the `seed` drawn above, and
    # `state` still comes from the previous env object -- confirm intended.
    env = AsyncEnv(
        env_id,
        result_path=test_log,
        draw=training_config['show_game'],
        record=training_config['record_game'],
        save_frames=test_config['save_frames'],
        env_params=env_params,
        seed=args.seed,
    )

episode_idx = this_episode_frame = 1

if certify:
    certified = 0

# Valid value range of a (scaled) observation: [0, 1] for unsigned integer
# observations, unbounded otherwise.
state_min, state_max = (0.0, 1.0) if dtype in UINTS else (float('-inf'), float('inf'))


# First pass with rad_lim=0: take_action returns only actions whose radius
# is <= 0 and queues the others in p_que.
expand(env, state, rad_lim=0, re_cur=0, no=0)

# Resume the queued branches in order of increasing radius.
while not p_que.empty():
    elem = p_que.get().item
    env, s, a = elem.env, elem.s, elem.a
    rad, re, no = elem.rad, elem.re, elem.no

    # The memo used by expand is reset for every dequeued branch.
    state_dict = dict()

    logger.log(f'start from {no} with rad={rad} and re={re}')

    # Record the current re_min under this radius.
    certify_map[rad] = re_min
    logger.log(f'------------------------ putting elem into certify_map: {rad} : {re_min}')

    # Drop every further entry with exactly this radius. The first entry with
    # a different radius stays in the queue and provides rad_prime.
    while not p_que.empty():
        elem_prime = p_que.queue[0].item  # peek
        rad_prime = elem_prime.rad

        if not (rad_prime == rad):
            break
        # pop out elements of same radius
        p_que.get()

    # Nothing left to supply the next radius limit: done.
    if p_que.empty():
        break

    assert rad_prime > rad, f'rad_prime ({rad_prime}) <= rad ({rad}), gid_prime ({elem_prime.gid}) gid ({elem.gid})'

    # Execute the deferred action on the stored environment copy.
    if training_config['use_async_env']:
        env.async_step(a)
        step_result = env.wait_step()
    else:
        step_result = env.step(a)
    next_state, reward, done, _ = step_result

    # reward shaping for Pong
    if reward == -1:
        reward = 0

    assert not done

    # Continue below the deferred action, using the next radius in the queue
    # as the new limit.
    expand(env, next_state, rad_prime, re+reward, no+1)

logger.log(f'time = {time.time() - start_time}')

# Persist the radius -> re_min map next to the log file.
certify_map_path = test_log + '_certify-map.pt'
torch.save(certify_map, certify_map_path)
logger.log(f'certify map saved to {certify_map_path}')
