import sys
sys.path.append("./common")
sys.path.append("./auto_LiRPA")
from auto_LiRPA import BoundedModule, BoundedTensor
from auto_LiRPA.perturbations import PerturbationLpNorm
from argparser import argparser
import numpy as np
from read_config import load_config
from attacks import attack
from common.wrappers import make_atari, wrap_deepmind, wrap_pytorch, make_atari_cart
from models import QNetwork, model_setup
import torch.optim as optim
import torch
import torch.autograd as autograd
import time
import os
import argparse
import random
from datetime import datetime
from utils import Logger, get_acrobot_eps, test_plot 
from async_env import AsyncEnv
from train import get_logits_lower_bound
import pickle

# from model_smooth import CnnDQN
# from environment import atari_env
# from utils import read_config

# Unsigned integer dtypes that an environment observation may have.
UINTS = [np.uint8, np.uint16, np.uint32, np.uint64]
# True when a CUDA device is visible to PyTorch at import time.
USE_CUDA = torch.cuda.is_available()


def Variable(*args, **kwargs):
    """Build an ``autograd.Variable`` and move it to the GPU when CUDA is available.

    All positional and keyword arguments are forwarded unchanged to
    ``autograd.Variable``.
    """
    var = autograd.Variable(*args, **kwargs)
    return var.cuda() if USE_CUDA else var


def _select_best_checkpoint(prefix):
    """Return the path of the checkpoint with the highest frame index in ``prefix``.

    Files named ``best_frame_<N>.pth`` are preferred; when none exist the
    newest ``frame_<N>.pth`` is used instead.

    Raises:
        ValueError: if ``prefix`` contains no ``.pth`` file, or a ``.pth`` file
            whose name does not follow one of the two patterns above.
    """
    checkpoints = [
        f for f in os.listdir(prefix)
        if os.path.isfile(os.path.join(prefix, f)) and os.path.splitext(f)[1] == '.pth'
    ]
    # 'frame_' is 6 characters, 'best_frame_' is 11, '.pth' is 4.
    all_idx = [int(f[6:-4]) for f in checkpoints if 'best' not in f]
    all_best_idx = [int(f[11:-4]) for f in checkpoints if 'best' in f]
    if all_best_idx:
        model_name = 'best_frame_{}.pth'.format(max(all_best_idx))
    elif all_idx:
        model_name = 'frame_{}.pth'.format(max(all_idx))
    else:
        raise ValueError('No .pth checkpoint found in {}, please specify test model path.'.format(prefix))
    return os.path.join(prefix, model_name)


def _model_log_suffix(model_path):
    """Return the log-file suffix derived from keywords found in ``model_path``.

    Raises:
        NotImplementedError: if no known keyword occurs in ``model_path``.
    """
    if 'nat' in model_path:
        return '_nat'
    if 'pgd' in model_path:
        return '_pgd'
    if 'cov' in model_path or 'convex' in model_path:
        return '_cov'
    if 'aug' in model_path or 'adv' in model_path or 'frame' in model_path:
        # Text between the last '_' and the last '.' of the path.
        trailing = model_path[model_path.rfind('_') + 1:model_path.rfind('.')]
        if 'aug' in model_path:
            return f'_aug-{trailing}'
        if 'adv' in model_path:
            return '_adv'
        return f'_frame-{trailing}'
    raise NotImplementedError(f'model_path = {model_path} type unrecognized!')


def _training_label(model_path, order):
    """Return the first label of ``order`` whose keyword occurs in ``model_path``, else ``None``."""
    for label in order:
        keywords = ('cov', 'convex') if label == 'cov' else (label,)
        if any(keyword in model_path for keyword in keywords):
            return label
    return None


def main(args):
    """Evaluate a trained Q-network and record the per-episode rewards.

    The configuration is read with ``load_config(args)``. The checkpoint is
    ``test_config['load_model_path']`` when that file exists, otherwise the
    newest checkpoint found in the run directory. The agent is then run for
    up to ``num_episodes`` episodes of at most ``max_frames_per_episode``
    frames each. Progress is written to a log file whose name encodes the
    test settings, and the reward snapshots taken every 100 frames of each
    episode are pickled to ``<env_id>_<label>_GS_<eps>.txt``.

    Args:
        args: parsed command-line arguments, forwarded to ``load_config``;
            ``args.seed`` is used when an asynchronous environment is created.

    Raises:
        ValueError: if no model file is given and the run directory does not
            exist or holds no checkpoint.
        NotImplementedError: if the model path contains no known keyword.
    """
    config = load_config(args)
    prefix = config['env_id']
    training_config = config['training_config']
    test_config = config['test_config']
    attack_config = test_config["attack_config"]
    certify_global = attack_config['certify_global']
    num_episodes = test_config['num_episodes']
    v_lo = config['v_lo']
    v_hi = config['v_hi']

    if config['name_suffix']:
        prefix += config['name_suffix']
    if config['path_prefix']:
        prefix = os.path.join(config['path_prefix'], prefix)

    # Resolve the checkpoint first: the log-file name below depends on it.
    explicit_model = 'load_model_path' in test_config and os.path.isfile(test_config['load_model_path'])
    if explicit_model:
        model_path = test_config['load_model_path']
        os.makedirs(prefix, exist_ok=True)
    elif os.path.exists(prefix):
        model_path = _select_best_checkpoint(prefix)
    else:
        raise ValueError('Path {} not exists, please specify test model path.'.format(prefix))
    test_log = os.path.join(prefix, test_config['log_name'])

    smooth = test_config['smooth']
    m = test_config['m']
    eps = test_config['eps']

    # Encode the test settings in the log-file name.
    if test_config['attack']:
        if certify_global:
            test_log += '_attack-eps-certify-all'
        else:
            attack_eps = attack_config['params']['epsilon']
            test_log += f'_attack-eps-{attack_eps}'

    test_log += f'_nr-{num_episodes}'
    if not test_config['attack']:
        test_log += '_clean'
    else:
        method = attack_config.get('method', 'pgd')
        test_log += f'_{method}'

        attack_freq = attack_config['params']['freq']
        if attack_freq < 1:
            test_log += f'_freq-{attack_freq}'

    if smooth:
        test_log += f"_m-{m}_eps-{eps}"

    test_log += test_config['process_mode']
    test_log += _model_log_suffix(model_path)

    logger = Logger(open(test_log, "w"))
    logger.log('Command line:', " ".join(sys.argv[:]))
    logger.log(args)
    logger.log(config)
    certify = test_config.get('certify', False)
    env_params = training_config['env_params']
    # Evaluate on raw rewards and full episodes.
    env_params['clip_rewards'] = False
    env_params['episode_life'] = False
    env_id = config['env_id']

    if "NoFrameskip" not in env_id:
        env = make_atari_cart(env_id)
    else:
        env = make_atari(env_id)
        env = wrap_deepmind(env, **env_params)
        env = wrap_pytorch(env)

    state = env.reset()
    dtype = state.dtype
    logger.log("env_shape: {}, num of actions: {}".format(env.observation_space.shape, env.action_space.n))

    model_width = training_config['model_width']
    robust_model = certify
    dueling = training_config.get('dueling', True)

    model = model_setup(test_config, attack_config, None, env_id, env, robust_model, logger, USE_CUDA, dueling, model_width, smooth, m, eps, v_lo, v_hi)

    if not explicit_model:
        logger.log("choosing the best model from " + prefix)
    logger.log('model loaded from ' + model_path)

    model.features.load_state_dict(torch.load(model_path))
    max_frames_per_episode = test_config['max_frames_per_episode']

    all_rewards = []
    episode_reward = 0

    seed = random.randint(0, sys.maxsize)
    logger.log('reseting env with seed', seed)
    env.seed(seed)
    state = env.reset()
    if training_config['use_async_env']:
        # Create an environment in a separate process, run asynchronously
        async_env = AsyncEnv(env_id, result_path=test_log, draw=training_config['show_game'], record=training_config['record_game'], save_frames=test_config['save_frames'], env_params=env_params, seed=args.seed)

    episode_idx = 1
    this_episode_frame = 1

    if dtype in UINTS:
        state_max = 1.0
        state_min = 0.0
    else:
        state_max = float('inf')
        state_min = float('-inf')

    if certify:
        certified = 0
        # NOTE(review): the original loop passed state_ub, state_lb, eps_v and
        # beta to get_logits_lower_bound without ever defining them. The values
        # below are a reconstruction (perturbation radius taken from the attack
        # config, bounds clamped to the valid state range) - confirm that this
        # is the intended certification setting.
        beta = training_config.get('convex_final_beta', 0)
        certify_eps = attack_config['params']['epsilon']
        if env_id == 'Acrobot-v1':
            eps_v = get_acrobot_eps(certify_eps)
            if USE_CUDA:
                eps_v = eps_v.cuda()
        else:
            eps_v = certify_eps

    # Follow USE_CUDA (also handed to model_setup) instead of assuming a GPU.
    device = torch.device('cuda' if USE_CUDA else 'cpu')

    episodes_reward = []
    this_episode_rewards = []

    for frame_idx in range(1, num_episodes * max_frames_per_episode + 1):

        state_tensor = torch.from_numpy(np.ascontiguousarray(state)).unsqueeze(0).to(device).to(torch.float32)
        # Normalize input pixel to 0-1
        if dtype in UINTS:
            state_tensor /= 255

        if test_config['attack']:
            attack_config['params']['robust_model'] = certify
            # NOTE(review): attack_frame is drawn but never read in this
            # function; kept so the numpy random stream is unchanged.
            attack_frame = np.random.rand() < attack_freq

        action = model.act(state_tensor, perturb=0 if certify_global else -1)[0]

        if certify:
            state_ub = torch.clamp(state_tensor + eps_v, max=state_max)
            state_lb = torch.clamp(state_tensor - eps_v, min=state_min)
            max_logit = torch.tensor([action])
            # One row per non-selected action: e_selected - e_other.
            c = torch.eye(model.num_actions).type_as(state_tensor)[max_logit].unsqueeze(1) - torch.eye(model.num_actions).type_as(state_tensor).unsqueeze(0)
            I = (~(max_logit.data.unsqueeze(1) == torch.arange(model.num_actions).type_as(max_logit.data).unsqueeze(0)))
            c = (c[I].view(state_tensor.size(0), model.num_actions-1, model.num_actions))
            logits_diff_lb = get_logits_lower_bound(model, state_tensor, state_ub, state_lb, eps_v, c, beta)
            # Counted as certified when every lower bound is positive.
            if torch.min(logits_diff_lb[0], 0)[0].data.cpu().numpy() > 0:
                certified += 1

        if training_config['use_async_env']:
            async_env.async_step(action)
            next_state, reward, done, _ = async_env.wait_step()
        else:
            next_state, reward, done, _ = env.step(action)

        state = next_state
        episode_reward += reward

        if frame_idx % test_config['print_frame']==0:
            logger.log('\ntotal frame {}/{}, episode {}/{}, episode frame {}/{}, latest episode reward: {:.6g}, avg 10 episode reward: {:.6g}'.format(frame_idx, num_episodes*max_frames_per_episode, episode_idx, num_episodes, this_episode_frame, max_frames_per_episode,
                all_rewards[-1] if all_rewards else np.nan,
                np.average(all_rewards[:-11:-1]) if all_rewards else np.nan))
            if certify:
                logger.log('certified action: {}, certified action ratio: {:.6g}'.format(certified, certified*1.0/frame_idx))

        if this_episode_frame % 100 == 0 and test_config['GS'] == True:
            logger.log('smooth eps: {}, episode idx: {}, num of frames: {}, emp reward: {}'.format(eps, episode_idx, this_episode_frame, episode_reward))

        # Snapshot the running episode reward every 100 frames.
        if this_episode_frame % 100 == 0:
            this_episode_rewards.append(episode_reward)

        if this_episode_frame == max_frames_per_episode:
            logger.log('maximum number of frames reached in this episode, reset environment!')
            done = True
            if training_config['use_async_env']:
                async_env.epi_reward = 0

        if done:
            logger.log('reseting env')
            if training_config['use_async_env']:
                state = async_env.reset()
            else:
                state = env.reset()

            all_rewards.append(episode_reward)
            episode_reward = 0

            # reset
            model.init_record_list()

            this_episode_frame = 1
            episode_idx += 1

            episodes_reward.append(this_episode_rewards)
            this_episode_rewards = []

            if episode_idx > num_episodes:
                break
        else:
            this_episode_frame += 1

    logger.log('\navg reward' + (' and avg certify:' if certify else ':'))
    logger.log(np.mean(all_rewards),'+-',np.std(all_rewards))
    if certify:
        logger.log(certified*1.0/frame_idx)

    if test_config['Emp'] == True:
        if 'Pong' in model_path:
            emp_log = 'Emp_Pong.log'
        elif 'Freeway' in model_path:
            emp_log = 'Emp_Freeway.log'
        else:
            # Previously an unbound-variable crash; skip so the results below
            # are still saved.
            emp_log = None
            logger.log('model_path = {} names neither Pong nor Freeway, skipping Emp log'.format(model_path))

        if emp_log is not None:
            emp_label = _training_label(model_path, ('nat', 'pgd', 'cov', 'aug', 'adv'))
            with open(emp_log, "a") as emp_file:
                if emp_label is not None:
                    Logger(emp_file).log('{} frames, {} model, attack: {}, attack eps: {}, reward: {}'.format(test_config['max_frames_per_episode'], emp_label, attack_config['method'], attack_config['params']['epsilon'], np.mean(all_rewards)))

    # 'frame' is the fallback for checkpoints recognised only by that keyword;
    # _model_log_suffix has already rejected paths without any known keyword.
    gs_label = _training_label(model_path, ('nat', 'pgd', 'cov', 'adv', 'aug', 'frame'))
    with open('{}_{}_GS_{}.txt'.format(config['env_id'], gs_label, eps), "wb") as reward_file:
        pickle.dump(episodes_reward, reward_file)



if __name__ == "__main__":
    # Script entry point: parse the command line, then run the evaluation.
    cli_args = argparser()
    main(cli_args)
