import numpy as np
import os
import subprocess
import time
from unrealcv import Client
import sys
sys.path.append('../DataCollection')
from player import Player
from sampling_clean import get_position_sampler
from PIL import Image
from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
from io import BytesIO
import json
import pickle

# The environment runner
class UE4BinaryBase(object):
    """Base class for platform-specific wrappers that start and stop a UE4 game binary.

    Subclasses implement ``start()`` and ``close()``. Used as a context manager, the
    wrapper starts the binary on entry (only if ``binary_path`` exists as a file or
    directory) and closes it on exit. ``close()`` is also called when the wrapper is
    garbage-collected.

    Usage:
        binary = LinuxBinary('/tmp/RealisticRendering/RealisticRendering', 9000, 0)
        with binary:
            client.request('vget /camera/0/lit test.png')
    """

    def __init__(self, binary_path, port, cuda_id):
        """Store launch parameters; the binary is not started here.

        Args:
            binary_path: Path of the game binary or launcher script.
            port: Port handed to the binary by the platform subclass.
            cuda_id: GPU selection handed to the binary by the platform subclass.
        """
        self.binary_path = binary_path
        self.port = port
        self.cuda_id = cuda_id

    def start(self):
        """Start the binary. Must be implemented by platform subclasses."""
        raise NotImplementedError

    def close(self):
        """Stop the binary. Must be implemented by platform subclasses."""
        raise NotImplementedError

    def __enter__(self):
        """Start the binary if its path exists and return the wrapper itself."""
        if os.path.isfile(self.binary_path) or os.path.isdir(self.binary_path):
            self.start()
        else:
            print("Binary %s can not be found" % self.binary_path)
        # Returning self makes ``with binary as b:`` bind the wrapper, not None.
        return self

    def __exit__(self, type, value, traceback):
        """Close the binary; exceptions from the with-body are not suppressed."""
        self.close()

    def __del__(self):
        """Close the binary when the wrapper is garbage-collected."""
        self.close()
        print("Game binary file is closed")
        
class LinuxBinary(UE4BinaryBase):
    """Start and stop a UE4 game binary on Linux.

    The binary runs in its own session/process group. ``close()`` can therefore
    signal the whole group: the launcher script *and* any game process it spawns,
    not only the immediate child. Because the game is then outside the terminal's
    foreground group, ``close()`` is also registered with ``atexit``.
    """

    # Seconds to wait after launching before returning from start().
    startup_delay = 6
    # Seconds to wait for a graceful exit after SIGTERM before sending SIGKILL.
    shutdown_timeout = 10

    def start(self):
        """Launch ``binary_path <port>`` with ``CUDA_VISIBLE_DEVICES=<cuda_id>``.

        stdout/stderr are discarded. Blocks for ``startup_delay`` seconds.
        """
        import atexit

        env = dict(os.environ, CUDA_VISIBLE_DEVICES=str(self.cuda_id))
        # Argument list instead of a shell string: no quoting problems with paths, and
        # the pid we hold is the launcher itself rather than an intermediate shell.
        self.popen_obj = subprocess.Popen(
            [self.binary_path, str(self.port)],
            env=env,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            start_new_session=True,
        )
        self.pid = self.popen_obj.pid
        # Ctrl-C no longer reaches the game's process group, so make sure it is
        # stopped when the interpreter exits.
        atexit.register(self.close)
        time.sleep(self.startup_delay)

    def close(self):
        """Terminate the binary's whole process group. Safe to call repeatedly or before start()."""
        import atexit
        import signal

        popen_obj = getattr(self, 'popen_obj', None)
        if popen_obj is None:
            return
        self.popen_obj = None
        atexit.unregister(self.close)

        # start_new_session=True makes the child a session/group leader, so its pid is the pgid.
        pgid = popen_obj.pid
        try:
            os.killpg(pgid, signal.SIGTERM)
        except ProcessLookupError:
            pass  # every process in the group has already exited
        try:
            popen_obj.wait(timeout=self.shutdown_timeout)
        except subprocess.TimeoutExpired:
            try:
                os.killpg(pgid, signal.SIGKILL)
            except ProcessLookupError:
                pass
            popen_obj.wait()

class NavigationEvaluator:
    """Run a navigation policy inside the UE4/UnrealCV game and score it.

    The constructor launches the game binary, connects an UnrealCV client and loads
    the ``Highrise`` level. ``evaluate`` then replays start/goal episodes, feeding the
    policy the current and goal camera images and recording the trajectory. Episodes
    come from a pickled evaluation file (per difficulty) or, without one, from the
    position sampler (``num_test_episodes`` episodes labelled ``'uniform-sample'``).

    Start/goal points are used as ``point[:3]`` = position and ``point[-1]`` = yaw.
    """

    # Action id returned by the policy to end an episode early.
    STOP_ACTION = 8
    # An episode succeeds once the agent gets closer than this to the goal
    # (distances are measured after scaling by ``_AXIS_SCALE``).
    SUCCESS_THRESHOLD = 800.0
    # Per-axis weights applied to positions before measuring distances: z counts double.
    _AXIS_SCALE = np.array([1.0, 1.0, 2.0])

    def __init__(self, binary_path, port, cuda_id,
                 video_storage_path='videos', num_test_episodes=10, max_steps=200,
                 eval_file_path=None) -> None:
        """Launch the game, connect to it and prepare the evaluation episodes.

        Args:
            binary_path: Game launcher path, passed to ``LinuxBinary``.
            port: UnrealCV port; passed to the binary and used by the client.
            cuda_id: GPU selection passed to ``LinuxBinary``.
            video_storage_path: Root directory for per-episode images and videos.
            num_test_episodes: Episodes per run when ``eval_file_path`` is None.
            max_steps: Maximum number of policy actions per episode.
            eval_file_path: Optional pickle holding ``start_pos``/``goal_pos`` for each
                of the difficulties ``'medium'``, ``'hard'`` and ``'easy'``.
        """
        os.makedirs(video_storage_path, exist_ok=True)
        self.video_storage_path = video_storage_path
        self.num_test_episodes = num_test_episodes
        self.max_steps = max_steps

        if eval_file_path is not None:
            # NOTE: pickle.load can execute arbitrary code; only load trusted evaluation files.
            with open(eval_file_path, 'rb') as file:
                eval_data = pickle.load(file)
            # Fixed evaluation order; the file's own key order is deliberately not used.
            self.difficulty_list = ['medium', 'hard', 'easy']
            self.eval_data = {
                difficulty: {
                    'start_pos': eval_data[difficulty]['start_pos'],
                    'goal_pos': eval_data[difficulty]['goal_pos'],
                }
                for difficulty in self.difficulty_list
            }
        else:
            self.difficulty_list = ['uniform-sample']
            self.eval_data = None

        self.binary = LinuxBinary(binary_path, port, cuda_id)
        self.binary.start()
        self.client = Client(('localhost', port))
        self.client.connect()
        print(self.client.request("vget /unrealcv/version"))
        print(self.client.request("vget /unrealcv/status"))
        # Poll until `vget /cameras` answers exactly "PawnSensor". A None reply is
        # treated as "not ready yet" instead of crashing on .strip().
        print("Wait...", end="")
        while (self.client.request("vget /cameras") or "").strip() != "PawnSensor":
            time.sleep(0.1)
            print(".", end="")
        print('Connect successful')
        self.client.request('vset /action/game/level Highrise')
        time.sleep(1.0)

        self.player = Player(self.client, self.video_storage_path)
        self.sampler, self.map_data = get_position_sampler(self.client)
        self.test_positions = []
        self.test_goals = []

    def get_eval_dataset(self, n=0):
        """Sample ``n`` start/goal pairs per difficulty and render each goal view.

        Returns:
            dict mapping ``'easy'``, ``'medium'`` and ``'hard'`` to a dict with
            ``'start_pos'`` and ``'goal_pos'`` (as returned by
            ``sampler.sample_diff_batch``) and ``'goal_rgb'``. That is a float array
            of shape ``(n, H, W, 3)`` with the RGB channels of each goal image; H, W
            follow the rendered images and default to 480, 640 if none is rendered.
        """
        out_all = {}
        for diff in ['easy', 'medium', 'hard']:
            rgbs = None
            start_pos, goal_pos = self.sampler.sample_diff_batch(n, diff)
            for i, (x, y, z, yaw) in enumerate(goal_pos):
                self.player.set_position_rotation((x, y, z), (0, yaw, 0))
                png = self.player.game_client.request("vget /camera/0/lit png")
                frame = np.asarray(Image.open(BytesIO(png)))[..., :3]  # drop alpha
                if rgbs is None:
                    # Size the buffer from the first rendered frame, not a fixed resolution.
                    rgbs = np.zeros((n,) + frame.shape)
                rgbs[i] = frame
            if rgbs is None:
                rgbs = np.zeros((n, 480, 640, 3))
            out_all[diff] = {
                'start_pos': start_pos,
                'goal_pos': goal_pos,
                'goal_rgb': rgbs,
            }
        return out_all

    def evaluate(self, policy, evaluate_subdir_path):
        """Run every evaluation episode with ``policy`` and return the metrics.

        Args:
            policy: Callable ``policy(state_image_path, goal_image_path) -> action``;
                returning ``STOP_ACTION`` ends the episode.
            evaluate_subdir_path: Sub-directory under ``video_storage_path`` (images,
                videos) and under ``eval_data/`` (results pickle).

        Returns:
            dict: for each difficulty ``d``, key ``d`` maps to the mean of each metric
            and ``f'{d}-raw'`` to the list of per-episode records. The same dict is
            pickled to ``eval_data/<evaluate_subdir_path>/data.pickle``.
        """
        os.makedirs(f'{self.video_storage_path}/{evaluate_subdir_path}', exist_ok=True)

        eval_data_path = f'eval_data/{evaluate_subdir_path}/'
        os.makedirs(eval_data_path, exist_ok=True)

        evaluate_data = {f'{difficulty}-raw': [] for difficulty in self.difficulty_list}

        for difficulty in self.difficulty_list:
            print(f'Evaluating difficulty {difficulty}')
            performance = {}
            if self.eval_data is not None:
                num_test_episode = len(self.eval_data[difficulty]['start_pos'])
            else:
                num_test_episode = self.num_test_episodes

            for iteration in range(num_test_episode):
                print(f'In iteration {iteration} for {difficulty} difficulty')
                evaluation_path = f'{self.video_storage_path}/{evaluate_subdir_path}/{difficulty}/iter-{iteration}'
                os.makedirs(evaluation_path, exist_ok=True)

                start_point, goal_point = self._episode_endpoints(difficulty, iteration)
                path_position, path_rotation, state_paths = self._run_episode(
                    policy, evaluation_path, start_point, goal_point)
                self._save_episode_video(evaluation_path, state_paths)

                metrics = self.compute_metrics(start_point, goal_point, path_position, path_rotation)
                for key, val in metrics.items():
                    performance[key] = performance.get(key, 0.0) + val
                for key, val in metrics.items():
                    print(f'Current ({difficulty}): {key} = {val}, Average: {key} = {performance[key] / (iteration + 1)}')

                evaluate_data[f'{difficulty}-raw'].append(
                    {
                        'metrics': metrics,
                        'start_point': start_point,
                        'goal_point': goal_point,
                        'path_position': path_position,
                        'path_rotation': path_rotation,
                    }
                )

            evaluate_data[difficulty] = {key: total / num_test_episode for key, total in performance.items()}

        with open(os.path.join(eval_data_path, 'data.pickle'), 'wb') as file:
            pickle.dump(evaluate_data, file)

        return evaluate_data

    def _episode_endpoints(self, difficulty, iteration):
        """Return fresh list copies of the start and goal points for one episode."""
        if self.eval_data is not None:
            start_point = self.eval_data[difficulty]['start_pos'][iteration]
            goal_point = self.eval_data[difficulty]['goal_pos'][iteration]
        else:
            _, *start_point = self.sampler.sample()
            _, *goal_point = self.sampler.sample()
        # Copies, so that snapping the goal in _run_episode never mutates self.eval_data
        # (and so it also works when stored points are tuples).
        return list(start_point), list(goal_point)

    def _run_episode(self, policy, evaluation_path, start_point, goal_point):
        """Roll out ``policy`` for one episode.

        Saves the goal view as ``goal.bmp`` and each observation as
        ``state-<step>.bmp`` in ``evaluation_path``. ``goal_point[:3]`` is overwritten
        in place with the position the player reports after being placed at the goal.

        Returns:
            (path_position, path_rotation, state_paths): the pose after every action
            taken and the paths of the observation images saved in this episode.
        """
        initial_distance = np.linalg.norm(np.array(start_point[:3]) - np.array(goal_point[:3]))
        path_position, path_rotation, state_paths = [], [], []

        # Record the goal image, then snap the goal to where the player actually ended up.
        goal_path = f'{evaluation_path}/goal.bmp'
        self.player.set_position_rotation(goal_point[:3], (0, goal_point[-1], 0))
        time.sleep(1.0)
        self.player.save_images_path(goal_path)
        goal_position, _ = self.player.get_position_rotation()
        goal_point[:3] = goal_position

        self.player.set_position_rotation(start_point[:3], (0, start_point[-1], 0))
        time.sleep(1.0)

        goal_xyz = np.asarray(goal_point[:3], dtype=float)
        for curr_step in range(self.max_steps):
            self.player.sleep_until_action_finished()
            state_path = f'{evaluation_path}/state-{curr_step}.bmp'
            if os.path.exists(state_path):
                os.remove(state_path)  # leftover from an earlier run in this directory
            self.player.save_images_path(state_path)
            # Wait for the image file to appear before handing it to the policy.
            while not os.path.exists(state_path):
                time.sleep(0.01)
            time.sleep(0.01)
            state_paths.append(state_path)

            action = policy(state_path, goal_path)
            if action == self.STOP_ACTION:
                break
            self.player.take_action(action)
            current_position, current_rotation = self.player.get_position_rotation()
            current_distance = np.linalg.norm(np.asarray(current_position) - goal_xyz)
            print(f'Current distance: {current_distance}, initial distance: {initial_distance}')
            path_position.append(current_position)
            path_rotation.append(current_rotation)

        return path_position, path_rotation, state_paths

    def _save_episode_video(self, evaluation_path, state_paths):
        """Write the given observation images as ``video.mp4`` (5 fps) in ``evaluation_path``."""
        time.sleep(0.1)
        frames = []
        # Only this episode's frames: stale state-*.bmp files from a longer earlier
        # run in the same directory must not leak into the video.
        for state_path in state_paths:
            with Image.open(state_path) as image:
                frames.append(np.asarray(image.convert('RGB')))
        if not frames:
            return  # nothing was rendered (e.g. max_steps == 0)
        clip = ImageSequenceClip(frames, with_mask=True, fps=5)
        clip.write_videofile(f'{evaluation_path}/video.mp4', logger=None)

    def compute_metrics(self, start_point, goal_point, path_position, path_rotation):
        """Score one episode.

        Positions are compared after scaling by ``_AXIS_SCALE`` (z weighted by 2).

        Args:
            start_point, goal_point: Sequences whose first three entries are x, y, z.
            path_position: Positions reported after each action; may be empty when
                the policy stopped immediately (the start is then the only position).
            path_rotation: Rotations after each action; not used by the metrics.

        Returns:
            dict with
            - ``'success'``: 1.0 if any visited position is closer than
              ``SUCCESS_THRESHOLD`` to the goal, else 0.0.
            - ``'success_weighted_path_length'``: ``success * l / max(p, l)``. Here
              ``l`` is the straight start-goal distance (no true shortest path is
              available), and ``p`` is the length travelled from the start to the
              first successful position plus the remaining distance to the goal.
            - ``'distance_decrement_rate'``: ``(l - closest distance reached) / l``,
              or 0.0 when ``l`` is 0.
        """
        threshold = self.SUCCESS_THRESHOLD
        scale = self._AXIS_SCALE

        # Float arrays and out-of-place multiplication: `*=` on an int array (integer
        # coordinates) raises a numpy casting error.
        start_position = np.asarray(start_point[:3], dtype=float) * scale
        goal_position = np.asarray(goal_point[:3], dtype=float) * scale
        visited = np.asarray(path_position, dtype=float).reshape(-1, 3) * scale
        if len(visited) == 0:
            visited = start_position[None, :]

        distance_euclidean = np.linalg.norm(start_position - goal_position)
        path_distance = np.linalg.norm(visited - goal_position, axis=-1)
        min_distance = np.min(path_distance)
        success = float(min_distance < threshold)

        if success < 1.0:
            success_weighted_path_length = 0.0
        else:
            first_success_index = np.nonzero(path_distance < threshold)[0][0]
            # Prepend the start so that the first move (start -> first position) counts.
            trajectory = np.vstack([start_position, visited[:first_success_index + 1]])
            distance_before_success = np.linalg.norm(np.diff(trajectory, axis=0), axis=-1).sum()
            distance_to_goal = np.linalg.norm(trajectory[-1] - goal_position)
            distance_path = distance_before_success + distance_to_goal
            denominator = max(distance_path, distance_euclidean)
            # Start == goal with a zero-length path is an optimal success.
            success_weighted_path_length = distance_euclidean / denominator if denominator > 0 else 1.0

        if distance_euclidean > 0:
            distance_decrement_rate = (distance_euclidean - min_distance) / distance_euclidean
        else:
            distance_decrement_rate = 0.0

        return {
            'success': success,
            'success_weighted_path_length': success_weighted_path_length,
            'distance_decrement_rate': distance_decrement_rate,
        }

    def end(self):
        """Stop the game binary started by the constructor."""
        self.binary.close()

if __name__ == '__main__':
    game_configs = {
        'binary_path': '../binaries/LinuxNoEditor_08_01_3_starts/ShooterGame.sh',
        'port': 4001,
        'cuda_id': 0,
    }

    def policy(*args):
        """Placeholder policy for smoke-testing: always returns action 0."""
        return 0

    navigation_evaluator = NavigationEvaluator(**game_configs)
    try:
        performance = navigation_evaluator.evaluate(policy, 'test')
    finally:
        # Stop the game even if evaluation fails part-way.
        navigation_evaluator.end()
    print(f'performance: {performance}')