import sys
import os
import argparse
import numpy as np
import torch
from evaluate_dataset import load
from model.policy.VGM.Policy import VGMPolicy
sys.path.append('../GameAgent')
sys.path.append('../GameAgent/DataCollection')
from IGNet.navigation_evaluator import NavigationEvaluator
from PIL import Image
from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
from env_utils.env_wrapper.env_graph_wrapper import GraphWrapper
from env_utils.env_wrapper.graph import Graph
import joblib
import glob

# ---------------------------------------------------------------------------
# Command-line interface.
# Arguments are parsed at import time; the resulting module-level `args`
# object is read directly by the functions and classes defined below.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument("--seed", type=int, default=1)
parser.add_argument("--num-episodes", type=int, default=1400)
parser.add_argument("--config", type=str, required=True)
# GPU id(s) exported through CUDA_VISIBLE_DEVICES; pass 'cpu' to leave it unset.
parser.add_argument("--gpu", type=str, default="0")
parser.add_argument("--version-name", type=str, required=True)
parser.add_argument("--stop", action='store_true', default=False)
parser.add_argument("--new", action='store_true', default=False)
parser.add_argument("--diff", choices=['random', 'easy', 'medium', 'hard'], default='hard')
parser.add_argument("--split", choices=['val', 'train', 'min_val'], default='val')
# A checkpoint file, a directory of checkpoints, or a glob prefix.
parser.add_argument('--eval-ckpt', type=str, required=True)
parser.add_argument('--render', action='store_true', default=False)
parser.add_argument('--record', choices=['0','1','2','3'], default='0') # 0: no record 1: env.render 2: pose + action numerical traj 3: features
parser.add_argument('--th', type=str, default='0.75') # s_th
parser.add_argument('--record-dir', type=str, default='data/video_dir')
args = parser.parse_args()
# Both options are accepted as strings on the command line and converted to
# their numeric form here.
args.record = int(args.record)
args.th = float(args.th)

# Restrict the visible GPUs unless the caller explicitly asked for CPU.
if args.gpu != 'cpu':
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
# Silence third-party logging.
os.environ['GLOG_minloglevel'] = "3"
os.environ['MAGNUM_LOG'] = "quiet"
os.environ['HABITAT_SIM_LOG'] = "quiet"
# The cuDNN switch is called `enabled`; the previous `enable` spelling only
# created an unused attribute and had no effect.
torch.backends.cudnn.enabled = True
torch.set_num_threads(5)
from configs.default import get_config, CN
import time

class VGMShooterPolicy(VGMPolicy):
    """Stateful, single-environment wrapper around ``VGMPolicy``.

    The instance keeps the recurrent hidden state (``prev_h``) and the last
    action (``prev_a``) between calls, so the policy can be queried one
    observation at a time with ``policy(obs)``.
    """

    def __init__(
        self,
        observation_space,
        action_space,
        goal_sensor_uuid="pointgoal_with_gps_compass",
        hidden_size=512,
        num_recurrent_layers=2,
        rnn_type="LSTM",
        resnet_baseplanes=32,
        backbone="resnet50",
        normalize_visual_inputs=True,
        cfg=None,
        new=False,
    ):
        """Build the underlying policy and allocate zeroed recurrent state.

        All arguments are forwarded unchanged to ``VGMPolicy.__init__``.
        """
        super().__init__(
            observation_space, action_space, goal_sensor_uuid, hidden_size,
            num_recurrent_layers, rnn_type, resnet_baseplanes, backbone,
            normalize_visual_inputs, cfg, new,
        )
        self.B = 1  # batch size: one environment at a time
        # Use CUDA when present, otherwise fall back to CPU instead of
        # failing inside `.cuda()`.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.prev_h = torch.zeros(
            self.net.num_recurrent_layers, self.B,
            self.net._hidden_size,
            device=device,
        )
        self.prev_a = torch.zeros([self.B], device=device)

    def __call__(self, obs: dict, deterministic=True):
        """Select an action for ``obs`` and advance the stored recurrent state.

        Args:
            obs: observation dict passed straight to ``self.act``.
            deterministic: forwarded to ``self.act``.

        Returns:
            The ``actions`` tensor returned by ``self.act``.
        """
        done = False
        # `done` is constant here, so the mask is always all-ones.
        masks = torch.ones(self.B, 1, device=self.prev_h.device) * (1 - done)
        # Inference only: without no_grad the hidden state carried to the next
        # call would keep the autograd graph of every earlier step alive.
        with torch.no_grad():
            (_values, actions, _log_probs, hidden_states,
             _logits, _preds, _act_features) = self.act(
                obs, self.prev_h, self.prev_a, masks, deterministic
            )
        self.prev_h = hidden_states
        self.prev_a = actions
        return actions

class GraphShooterEnv(GraphWrapper):
    """Single-environment adapter between the game ``player`` and the graph wrapper.

    Screenshots written by the player are loaded from ``evaluation_path``,
    turned into an observation batch of size one, and passed through the
    inherited ``embed_obs`` / ``localize`` / ``get_global_memory`` /
    ``update_obs`` helpers before being returned.
    """

    def __init__(self, player, config, sampler, evaluation_path):
        """Store collaborators, build the graph and load the visual encoder.

        Args:
            player: game controller used to act, teleport and save screenshots.
            config: configuration forwarded to ``GraphWrapper`` and ``Graph``.
            sampler: object whose ``sample()`` yields start/goal poses.
            evaluation_path: directory in which screenshots are written.
        """
        super().__init__(player, config)
        self.env = self.envs = None
        self.player = player
        self.sampler = sampler
        self.curr_step = 0
        self.evaluation_path = evaluation_path
        self.goal_path = f'{evaluation_path}/goal.bmp'
        self.B = 1 # 1 env at a time now
        # Fall back to CPU when CUDA is not available.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.graph = Graph(config, self.B, self.device)
        self.img_size = (64, 64)
        self.feature_dim = 512
        # only feature_dim is needed for loading visual encoder
        self.visual_encoder = self.load_visual_encoder(None, None, self.feature_dim).to(self.device)
        self.reset_all_memory()

    def _save_screenshot(self, image_path, settle_seconds=0.0):
        """Have the player write a screenshot to ``image_path``.

        Any existing file is removed first so that a failed capture surfaces
        as a missing file instead of silently reusing an old image.
        """
        if os.path.exists(image_path):
            os.remove(image_path)
        self.player.save_images_path(image_path)
        if settle_seconds > 0:
            time.sleep(settle_seconds)

    def _read_image(self, image_path):
        """Load an image, resize it to ``img_size`` and return a (1, H, W, 3) array."""
        with Image.open(image_path) as img:
            return np.asarray(img.resize(self.img_size).convert('RGB'))[None, ...]

    def _build_obs_batch(self, state_path):
        """Assemble the raw observation dict from the current and goal screenshots."""
        # PIL sizes are (width, height); array shapes are (height, width).
        width, height = self.img_size
        empty_channel = np.zeros((1, height, width, 1))
        rgb = self._read_image(state_path)
        # The goal image gets a fourth, all-zero channel appended.
        tar = np.concatenate((self._read_image(self.goal_path), empty_channel), axis=3)
        pos, _rot = self.player.get_position_rotation()
        return {
            'panoramic_rgb': torch.tensor(rgb).to(device=self.device).float(),
            'panoramic_depth': torch.tensor(empty_channel).to(device=self.device).float(),
            'target_goal': torch.tensor(tar).to(device=self.device).float(),
            'position': np.array([pos]),
            # NOTE(review): 'step' is always 1.0 regardless of curr_step --
            # confirm this is what the graph/policy expect.
            'step': torch.tensor([1.]).to(device=self.device).float()
        }

    def _localize_and_augment(self, state_path, done):
        """Build the observation, update the graph and attach global memory."""
        obs_batch = self._build_obs_batch(state_path)
        curr_vis_embedding = self.embed_obs(obs_batch)
        self.localize(curr_vis_embedding, obs_batch['position'], obs_batch['step'], [done])
        global_memory_dict = self.get_global_memory()
        obs_batch = self.update_obs(obs_batch, global_memory_dict)
        self.player.sleep_until_action_finished()
        return obs_batch

    def step(self, action):
        """Execute ``action`` in the game and return the updated observation batch."""
        self.curr_step += 1
        self.player.take_action(action)
        self.player.sleep_until_action_finished()
        state_path = f'{self.evaluation_path}/state-{self.curr_step}.bmp'
        # Short pause after the capture request before the file is read back.
        self._save_screenshot(state_path, settle_seconds=0.1)
        return self._localize_and_augment(state_path, done=False)

    def reset(self, evaluation_path = None):
        """Start a new episode and return its first observation batch.

        Two poses are drawn from the sampler (start first, then goal). The
        player is moved to the goal to record ``goal.bmp``, then moved to the
        start pose where ``state-0.bmp`` is recorded.

        Args:
            evaluation_path: optional new output directory for this episode.
        """
        if evaluation_path is not None:
            self.evaluation_path = evaluation_path
        self.goal_path = f'{self.evaluation_path}/goal.bmp'
        self.curr_step = 0
        # Sample start point and goal (the first element of each sample is unused).
        _, *start_point = self.sampler.sample()
        _, *goal_point = self.sampler.sample()
        # Record goal image
        self.player.set_position_rotation(goal_point[:3], (0, goal_point[-1], 0))
        time.sleep(1.0)
        self._save_screenshot(self.goal_path)
        # Go to respawn position and record image
        self.player.set_position_rotation(start_point[:3], (0, start_point[-1], 0))
        time.sleep(1.0)
        state_path = f'{self.evaluation_path}/state-{self.curr_step}.bmp'
        self._save_screenshot(state_path, settle_seconds=1.0)
        return self._localize_and_augment(state_path, done=True)



class VGMShooterEvaluator(NavigationEvaluator):
    """Runs a ``VGMShooterPolicy`` in the game and records one video per episode."""

    def __init__(
        self, 
        binary_path, port, cuda_id, video_storage_path='videos', num_test_episodes=10, max_steps=100
    ):
        """Initialise the base evaluator and install the key-name action list.

        The keyword arguments are forwarded as given; previously hard-coded
        literals were passed, so caller-supplied values were ignored.
        """
        super().__init__(
            binary_path, port, cuda_id,
            video_storage_path=video_storage_path,
            num_test_episodes=num_test_episodes,
            max_steps=max_steps,
        )
        # Kept locally so the episode loop does not depend on how the base
        # class stores the limit.
        self._step_limit = max_steps
        self.player.action_list = ['H', 'W', 'S', 'A', 'D', 'Left', 'Right', 'Spacebar']

    @staticmethod
    def _reset_policy_state(policy):
        """Zero the policy's carried recurrent state at an episode boundary."""
        if hasattr(policy, 'prev_h'):
            policy.prev_h = torch.zeros_like(policy.prev_h)
        if hasattr(policy, 'prev_a'):
            policy.prev_a = torch.zeros_like(policy.prev_a)

    @staticmethod
    def _collect_frames(evaluation_path, last_step):
        """Load ``state-0`` .. ``state-<last_step>`` as RGB arrays, stopping at the first gap."""
        frames = []
        for step in range(last_step + 1):
            state_path = f'{evaluation_path}/state-{step}.bmp'
            if not os.path.exists(state_path):
                break
            with Image.open(state_path) as image:
                frames.append(np.asarray(image.convert('RGB')))
        return frames

    def evaluate(
        self,
        policy: VGMShooterPolicy,
        evaluate_subdir_path: str
    ):
        """Run ``num_test_episodes`` episodes and write a video for each.

        Args:
            policy: callable mapping an observation batch to an action tensor.
            evaluate_subdir_path: sub-directory of ``video_storage_path`` that
                receives one ``iter-<n>`` folder per episode.

        Returns:
            dict with a single ``'success_rate'`` entry.
        """
        os.makedirs(f'{self.video_storage_path}/{evaluate_subdir_path}', exist_ok=True)
        # NOTE(review): nothing in this method increments success_time, so the
        # reported success rate is always 0 -- a success criterion is missing.
        success_time = 0

        for iteration in range(self.num_test_episodes):
            print(f'In iteration {iteration}')
            # Absolute path derived from the working directory, matching the
            # directory created above (replaces a hard-coded home directory).
            evaluation_path = os.path.abspath(os.path.join(
                self.video_storage_path, evaluate_subdir_path, f'iter-{iteration}'
            ))
            os.makedirs(evaluation_path, exist_ok=True)
            env = GraphShooterEnv(
                player=self.player,
                config=eval_config(args),
                sampler=self.sampler,
                evaluation_path=evaluation_path
            )
            # Episodes are independent: do not leak recurrent state across them.
            self._reset_policy_state(policy)
            obs = env.reset()
            for _ in range(self._step_limit):
                action = policy(obs)
                action = action.detach().cpu().numpy()[0]
                print("\n\nproposed action: ", action)
                action = int(action)
                obs = env.step(action)
                # Action index 0 ends the episode (after it has been executed).
                if action == 0:
                    break

            time.sleep(0.1)
            # Only frames produced by this episode are used, so leftovers from
            # an earlier, longer run cannot end up in the video, and the final
            # frame of a full-length episode is included.
            frames = self._collect_frames(evaluation_path, env.curr_step)
            if frames:
                clip = ImageSequenceClip(frames, with_mask=True, fps=5)
                clip.write_videofile(f'{evaluation_path}/video.mp4', logger=None)
            else:
                print(f'No frames found in {evaluation_path}; skipping video')

        performance = {
            'success_rate': success_time / self.num_test_episodes
        }
        return performance

def eval_config(args):
    """Load the config named by ``args.config`` and adapt it for evaluation.

    Depth is switched off on both the top-level and task config, and the
    action dimension / possible actions are set according to ``args.stop``.
    The returned config is frozen.
    """
    cfg = get_config(args.config)
    cfg.defrost()

    # was True, False for ShooterGame
    cfg.use_depth = False
    cfg.TASK_CONFIG.use_depth = False
    print(args.config)

    task = cfg.TASK_CONFIG.TASK
    if not args.stop:
        cfg.ACTION_DIM = 7
        task.POSSIBLE_ACTIONS = ["MOVE_FORWARD", "TURN_LEFT", "TURN_RIGHT"]
        task.SUCCESS.TYPE = "Success_woSTOP"
    else:
        cfg.ACTION_DIM = 8
        task.POSSIBLE_ACTIONS = ["STOP", "MOVE_FORWARD", "TURN_LEFT", "TURN_RIGHT"]

    cfg.freeze()
    return cfg

def evaluate(eval_config, ckpt, graph_file_name = None):
    """Evaluate one checkpoint in the shooter game, then terminate the process.

    Args:
        eval_config: evaluation config; merged into the checkpoint's own
            config when the checkpoint carries one.
        ckpt: path of the checkpoint to load.
        graph_file_name: accepted for interface compatibility; not used here.

    Note:
        The function ends with ``os._exit(0)``, so it never returns on
        success and a caller looping over checkpoints only evaluates the
        first one.
    """
    if args.record > 0:
        # makedirs also creates missing parents (the default record dir is nested).
        os.makedirs(os.path.join(args.record_dir, args.version_name), exist_ok=True)
        ckpt_name = ckpt.split('/')[-1]
        # One timestamp shared by both directories so their names always match.
        timestamp = str(time.ctime())
        VIDEO_DIR = os.path.join(args.record_dir, args.version_name + '_video_' + ckpt_name + '_' + timestamp)
        os.makedirs(VIDEO_DIR, exist_ok=True)
        if args.record > 1:
            OTHER_DIR = os.path.join(args.record_dir, args.version_name + '_other_' + ckpt_name + '_' + timestamp)
            os.makedirs(OTHER_DIR, exist_ok=True)
    state_dict, ckpt_config = load(ckpt)

    if ckpt_config is not None:
        # Start from the checkpoint's config, override the evaluation-specific
        # entries, then fill in any keys only the evaluation config has.
        task_config = eval_config.TASK_CONFIG
        ckpt_config.defrost()
        task_config.defrost()
        ckpt_config.TASK_CONFIG = task_config
        ckpt_config.runner = eval_config.runner
        ckpt_config.AGENT_TASK = 'search'
        ckpt_config.DIFFICULTY = eval_config.DIFFICULTY
        ckpt_config.ACTION_DIM = eval_config.ACTION_DIM
        ckpt_config.memory = eval_config.memory
        ckpt_config.scene_data = eval_config.scene_data
        ckpt_config.WRAPPER = eval_config.WRAPPER
        ckpt_config.REWARD_METHOD = eval_config.REWARD_METHOD
        ckpt_config.ENV_NAME = eval_config.ENV_NAME
        for k, v in eval_config.items():
            if k not in ckpt_config:
                ckpt_config.update({k:v})
            if isinstance(v, CN):
                for kk, vv in v.items():
                    if kk not in ckpt_config[k]:
                        ckpt_config[k].update({kk: vv})
        ckpt_config.freeze()
        eval_config = ckpt_config
    print(eval_config.memory)
    eval_config.defrost()
    eval_config.th = args.th

    eval_config.record = False # record from this side , not in env
    eval_config.render_map = args.record > 0 or args.render or 'hand' in args.config
    eval_config.noisy_actuation = True
    eval_config.freeze()
    obs_width = 64
    depth_high = 0

    from gym.spaces.dict import Dict as SpaceDict
    from gym.spaces.box import Box
    from gym.spaces.discrete import Discrete
    observation_space = SpaceDict({
        'panoramic_rgb': Box(low=0, high=256, shape=(64, obs_width, 3), dtype=np.float32),
        'panoramic_depth': Box(low=0, high=depth_high, shape=(64, obs_width, 1), dtype=np.float32),
        'target_goal': Box(low=0, high=256, shape=(64, obs_width, 3), dtype=np.float32),
        'step': Box(low=0, high=500, shape=(1,), dtype=np.float32),
        'prev_act': Box(low=0, high=7, shape=(1,), dtype=np.int32),
        'gt_action': Box(low=0, high=7, shape=(1,), dtype=np.int32)
    })
    action_space = Discrete(8)

    agent = VGMShooterPolicy(
        observation_space=observation_space,
        action_space=action_space,
        hidden_size=eval_config.features.hidden_size,
        rnn_type=eval_config.features.rnn_type,
        num_recurrent_layers=eval_config.features.num_recurrent_layers,
        backbone=eval_config.features.backbone,
        goal_sensor_uuid=eval_config.TASK_CONFIG.TASK.GOAL_SENSOR_UUID,
        normalize_visual_inputs=True,
        cfg=eval_config,
        new=args.new
    ).eval()
    if torch.cuda.device_count() > 0:
        agent.cuda()
    agent.load_state_dict(state_dict)
    game_configs = {
        'binary_path': '/home/baiting/GameAgent/ShooterAgent/binaries/LinuxNoEditor/ShooterGame.sh', 
        'port': 4001, 
        'cuda_id': 0
    }
    evaluator = VGMShooterEvaluator(
        **game_configs
    )
    try:
        evaluator.evaluate(agent, 'ShooterEval3')
        print("success")
    finally:
        # Shut the game down even when evaluation raises.
        evaluator.game_binary.close()
    # os._exit skips the interpreter's normal shutdown, including the stdio
    # flush, so flush explicitly or redirected output would be lost.
    sys.stdout.flush()
    sys.stderr.flush()
    os._exit(0)

if __name__=='__main__':
    # Script entry point: resolve --eval-ckpt into a list of checkpoint paths
    # (directory listing, existing path, or glob prefix) and evaluate each.
    cfg = eval_config(args)
    if os.path.isdir(args.eval_ckpt):
        print('eval_ckpt ', args.eval_ckpt, ' is directory')
        ckpts = [
            os.path.join(args.eval_ckpt, name)
            for name in sorted(os.listdir(args.eval_ckpt), reverse=True)
        ]
    elif os.path.exists(args.eval_ckpt):
        ckpts = args.eval_ckpt.split(",")
    else:
        # Treat the argument as a prefix; matches are visited in reverse
        # sorted order.
        ckpts = sorted(glob.glob(args.eval_ckpt+'*'), reverse=True)
    print(f'evaluate total {len(ckpts)} ckpts')
    for ckpt in ckpts:
        # Only paths that mention 'pt' and are not notebook artefacts are used.
        looks_like_ckpt = 'pt' in ckpt and 'ipynb' not in ckpt
        if not looks_like_ckpt:
            continue
        print('============================', ckpt.split('/')[-1], '==================')
        print(ckpt, type(ckpt))
        # Optional companion graph file: the path minus its last three
        # characters, plus "_graph.pkl".
        graph_file_name = f"{ckpt[:-3]}_graph.pkl"
        graph_file_name = graph_file_name if os.path.isfile(graph_file_name) else None
        result = evaluate(cfg, ckpt, graph_file_name)
        # joblib.dump(data, eval_data_name)
