import os
from queue import Queue
import sys
import json
import numpy as np
from PIL import Image
from datetime import datetime
from env.thor_env import ThorEnv
from eval.eval import Eval
import torch
import networks.CLIP.clip.clip as clip
import pickle
import pprint
from models.nn.resnet import Resnet
import random
import time
import copy
from collections import deque
import threading
from wrapt_timeout_decorator import *
from queue import Queue


class EvalSubgoals():
    '''
    Evaluate subgoals by teacher-forcing expert demonstrations.

    For every selected subgoal of every queued trajectory, the expert's
    low-level actions for all preceding subgoals are replayed in the THOR
    simulator. The model then acts on its own until one of the following happens:
    - it completes the target subgoal;
    - it predicts a terminal token;
    - it exceeds `config.max_fails` failed interactions;
    - it runs out of steps.

    Per-subgoal success rate (SR) and path-length-weighted SR are accumulated
    in `self.successes` / `self.failures` / `self.results`, and written to a
    JSON file by `save_results`.
    '''
    ALL_SUBGOALS = ['GotoLocation', 'PickupObject', 'PutObject', 'CoolObject', 'HeatObject', 'CleanObject', 'SliceObject', 'ToggleObject']
    STOP_TOKEN = "<<stop>>"
    SEQ_TOKEN = "<<seg>>"
    TERMINAL_TOKENS = [STOP_TOKEN, SEQ_TOKEN]
    # How many times a single subgoal evaluation is retried (with a fresh
    # simulator) when `evaluate` reports a scene-setup failure or timeout.
    MAX_EVAL_ATTEMPTS = 5

    def __init__(self, config, device):
        '''
        Args:
            config: evaluation config. `splits` and `eval_split` are read here;
                the other methods read further fields. Note that
                `config.visual_model` is overwritten with 'resnet18'.
            device: torch device for CLIP, the ResNet feature extractor and
                the recurrent state tensors.
        '''
        # CLIP model whose text encoder embeds the subgoal instructions.
        self.clip, _ = clip.load("RN50", device=device)
        # Cache of CLIP text embeddings, keyed by token ids (see get_goal_tensor).
        self.goal_tensor_record = {}

        self.config = config

        # load splits
        with open(self.config.splits) as f:
            self.splits = json.load(f)
        pprint.pprint({k: len(v) for k, v in self.splits.items()})

        self.device = device

        # ResNet used to featurize the current egocentric frame each step.
        self.config.visual_model = 'resnet18'
        self.resnet = Resnet(self.config, eval=True, share_memory=True, use_conv_feat=True, device=device)

        self.files = self.splits[self.config.eval_split]

    def reset_task_queue(self):
        '''
        Rebuild `self.task_queue` from `self.files`.

        The following config fields are honored:
        - `config.fast_epoch`: keep only the first 16 files (this persists in
          `self.files`).
        - `config.shuffle`: shuffle `self.files` in place.
        - `config.num_eval_file`: a negative value means all files; otherwise
          only the first `num_eval_file` files are kept.
        '''
        if self.config.fast_epoch:
            self.files = self.files[:16]

        if self.config.shuffle:
            random.shuffle(self.files)

        num_eval = self.config.num_eval_file
        selected = self.files if num_eval < 0 else self.files[:num_eval]
        # Copy so that popping tasks never mutates self.files.
        self.task_queue = list(selected)

    @staticmethod
    def _start_env():
        '''Create a ThorEnv, retrying until its `flag` attribute is truthy.'''
        # NOTE(review): `flag` is assumed to signal a successful simulator
        # start-up (it is checked right after construction); confirm in ThorEnv.
        while True:
            env = ThorEnv()
            if env.flag:
                return env
            env.stop()

    @staticmethod
    def _stop_env(env):
        '''Stop `env` after replacing its response queue with a fresh one.'''
        # NOTE(review): presumably the fresh queue discards a stale response
        # left by a timed-out action before shutdown; confirm in ThorEnv.
        env.response_queue = Queue(maxsize=1)
        env.stop()

    def run(self, model):
        '''
        Evaluate `model` on every selected subgoal of every queued trajectory.

        Args:
            model: the agent. It must provide `reset`, `load_task_json`,
                `featurize`, `get_action`, `step` and `vocab`.

        Returns:
            (sr, sr_plw) as returned by `save_results`.
        '''
        self.model = model
        self.reset_task_queue()

        # start THOR
        env = self._start_env()
        try:
            self.model.reset()

            # make subgoals list
            if self.config.subgoals.lower() == "all":
                requested = self.ALL_SUBGOALS
            else:
                requested = [sg.strip() for sg in self.config.subgoals.split(',')]
            subgoals_to_evaluate = [sg for sg in requested if sg in self.ALL_SUBGOALS]
            print("Subgoals to evaluate: %s" % str(subgoals_to_evaluate))

            # create empty stats per subgoal
            self.successes = {sg: [] for sg in subgoals_to_evaluate}
            self.failures = {sg: [] for sg in subgoals_to_evaluate}
            self.results = {sg: {'sr': 0., 'successes': 0., 'evals': 0., 'sr_plw': 0.}
                            for sg in subgoals_to_evaluate}

            while self.task_queue:
                task = self.task_queue.pop(0)

                try:
                    traj = model.load_task_json(task)

                    r_idx = task['repeat_idx']
                    subgoal_idxs = [sg['high_idx'] for sg in traj['plan']['high_pddl']
                                    if sg['discrete_action']['action'] in subgoals_to_evaluate]

                    for eval_idx in subgoal_idxs:
                        print("No. of trajectories left: %d" % (len(self.task_queue)))

                        evaluated = False
                        for _ in range(self.MAX_EVAL_ATTEMPTS):
                            evaluated = self.evaluate(env, eval_idx, r_idx, traj)
                            if evaluated:
                                break
                            # Scene setup failed or the simulator timed out:
                            # restart it before retrying.
                            self._stop_env(env)
                            env = self._start_env()
                        if not evaluated:
                            print("Too many failures...")
                except Exception as e:
                    # Top-level boundary for one trajectory: log and continue
                    # with the next one.
                    import traceback
                    traceback.print_exc()
                    print("Error: " + repr(e))

            sr, sr_plw = self.save_results()
        finally:
            # Always shut the simulator down, even if evaluation/saving raised.
            self._stop_env(env)
        return sr, sr_plw

    def get_goal_tensor(self, goal_ids):
        '''
        Return the cached CLIP text embedding for the tokenized instruction `goal_ids`.

        Args:
            goal_ids: token-id tensor as produced by `clip.tokenize`. Only row 0
                is used to build the cache key.

        Returns:
            A float32 embedding tensor, detached from autograd.
        '''
        goal_ids_str = "_".join([str(int(i)) for i in goal_ids[0]])
        if goal_ids_str not in self.goal_tensor_record:
            # Inference only: no_grad avoids building an autograd graph that
            # detach() would discard anyway.
            with torch.no_grad():
                goal_tensor = self.clip.encode_text(goal_ids.to(torch.int).to(self.device)).to(torch.float32)
            self.goal_tensor_record[goal_ids_str] = goal_tensor.detach()

        return self.goal_tensor_record[goal_ids_str]

    @timeout(30)
    def va_interact(self, env, action_words, interact_mask):
        '''
        Call `env.va_interact` with a 30 s timeout.

        Raises:
            TimeoutError: from the timeout decorator, if the call takes longer
                than 30 s.
        '''
        return env.va_interact(action_words, interact_mask=interact_mask, smooth_nav=self.config.smooth_nav, debug=self.config.debug)

    def evaluate(self, env, eval_idx, r_idx, traj_data):
        '''
        Run one teacher-forced evaluation of subgoal `eval_idx` and record the outcome.

        Args:
            env: a running ThorEnv.
            eval_idx: index into traj_data['plan']['high_pddl'] of the subgoal
                to evaluate.
            r_idx: annotation (repeat) index used to pick the instruction.
            traj_data: trajectory dict as returned by `model.load_task_json`.

        Returns:
            False if the scene could not be set up or the simulator timed out;
            the caller should restart the env and retry. True once the attempt
            has been appended to self.successes / self.failures and
            self.results has been refreshed.
        '''
        flag = self.setup_scene(env, traj_data, r_idx, self.config, reward_type='dense')
        if not flag:
            return False

        self.model.reset()

        # Expert actions for every subgoal before eval_idx.
        low_actions = traj_data['plan']['low_actions']
        expert_init_actions = [a['discrete_action'] for a in low_actions if a['high_idx'] < eval_idx]

        # subgoal info
        subgoal_action = traj_data['plan']['high_pddl'][eval_idx]['discrete_action']['action']
        subgoal_instr = traj_data['num']['lang_instr'][eval_idx]
        subgoal_ids = clip.tokenize(traj_data['turk_annotations']['anns'][r_idx]['high_descs'][eval_idx])
        subgoal_tensor = self.get_goal_tensor(subgoal_ids)

        print("Evaluating: %s\nSubgoal %s (%d)\nInstr: %s" % (traj_data['root'], subgoal_action, eval_idx, subgoal_instr))

        # Featurize the trajectory. feat['frames'] is replaced with the current
        # frame's ResNet features at every step below.
        feat = self.model.featurize([copy.deepcopy(traj_data)], load_mask=False)

        # Previous action passed to model.step during the expert unroll.
        # NOTE(review): 15 looks like a start/placeholder action index; confirm
        # against the model vocab.
        prev_action = [15]

        subgoal_success = False
        fails = 0
        t = 0        # total steps (expert + agent)
        in_t = 0     # agent-only steps
        reward = 0

        # Recurrent state for model.get_action: (h, c), each of shape
        # (2, 1, feature_size). Presumably a 2-layer LSTM; confirm in the model.
        h_0 = torch.zeros(2, 1, self.config.feature_size).to(self.device)
        c_0 = torch.zeros(2, 1, self.config.feature_size).to(self.device)
        ht = (h_0, c_0)
        out = None

        max_total_steps = self.config.max_steps + len(expert_init_actions)
        while t < max_total_steps and in_t < self.config.max_steps_taken:
            # extract visual feats of the current frame
            curr_image = Image.fromarray(np.uint8(env.last_event.frame))
            feat['frames'] = self.resnet.featurize([curr_image], batch=1).unsqueeze(0)

            if t < len(expert_init_actions):
                # expert teacher-forcing up to the subgoal
                action = expert_init_actions[t]
                compressed_mask = action['args'].get('mask')
                mask = env.decompress_mask(compressed_mask) if compressed_mask is not None else None

                # forward model
                if not self.config.skip_model_unroll_with_expert:
                    # NOTE(review): self.state_buffer is never assigned in this
                    # class, so this raises AttributeError unless it is set
                    # externally. prev_action also becomes a str/None here.
                    # Confirm before running with skip_model_unroll_with_expert=False.
                    self.model.step(torch.stack(list(self.state_buffer), dim=0), subgoal_tensor, prev_action=prev_action)
                    prev_action = action['action'] if not self.config.no_teacher_force_unroll_with_expert else None

                # execute expert action
                try:
                    success, _, _, err, _ = self.va_interact(env, action['action'], interact_mask=mask)
                except TimeoutError:
                    # Raised by the @timeout decorator on va_interact.
                    print("Timeout in expert")
                    return False

                if not success:
                    if isinstance(err, TimeoutError):
                        print("Timeout in expert")
                        return False

                    print("expert initialization failed")
                    break

                # Called for its effect on the env's reward bookkeeping
                # (presumably); the returned values are unused during the expert replay.
                env.get_transition_reward()

            else:
                # subgoal evaluation: the model acts on its own
                in_t += 1
                # feat['frames'] has a leading dim from unsqueeze(0) and holds
                # a single frame; [0][0] selects it.
                action, out, ht = self.model.get_action(feat['frames'][0][0].to(self.device), subgoal_tensor, ht, out)
                action_words = self.model.vocab['action_low'].index2word(list(action)[0])

                if self.config.debug:
                    print("Pred: ", action)

                if action_words not in self.TERMINAL_TOKENS:
                    # Masks are not predicted, so interact with mask=None.
                    try:
                        t_success, _, _, err, _ = self.va_interact(env, action_words, interact_mask=None)
                    except TimeoutError:
                        print("Timeout in agent")
                        return False

                    if not t_success:
                        if isinstance(err, TimeoutError):
                            print("Timeout in agent")
                            return False

                        fails += 1
                        if fails >= self.config.max_fails:
                            print("Interact API failed %d times" % (fails) + "; latest error '%s'" % err)
                            break

                # next time-step
                t_reward, _ = env.get_transition_reward()
                reward += t_reward

                # The env's current subgoal index reaching eval_idx counts as
                # success.
                if env.get_subgoal_idx() == eval_idx:
                    subgoal_success = True
                    break

                # Terminal token predicted. This compares the decoded word, not
                # the raw model output, which is never equal to a token string.
                if action_words in self.TERMINAL_TOKENS:
                    print("predicted %s" % action_words)
                    break

            # increment time index
            t += 1

        # metrics
        pl = float(t - len(expert_init_actions)) + 1  # +1 for last action
        expert_pl = len([ll for ll in low_actions if ll['high_idx'] == eval_idx])

        s_spl = (1 if subgoal_success else 0) * min(1., expert_pl / (pl + sys.float_info.epsilon))
        plw_s_spl = s_spl * expert_pl

        log_entry = {'trial': traj_data['task_id'],
                     'type': traj_data['task_type'],
                     'repeat_idx': int(r_idx),
                     'subgoal_idx': int(eval_idx),
                     'subgoal_type': subgoal_action,
                     'subgoal_instr': subgoal_instr,
                     'subgoal_success_spl': float(s_spl),
                     'subgoal_path_len_weighted_success_spl': float(plw_s_spl),
                     'subgoal_path_len_weight': float(expert_pl),
                     'reward': float(reward)}
        record = self.successes if subgoal_success else self.failures
        record[subgoal_action].append(log_entry)

        self._update_results()
        return True

    def _update_results(self):
        '''Recompute self.results from self.successes / self.failures and print a summary.'''
        print("-------------")
        for sg in sorted(self.successes):
            sg_successes, sg_failures = self.successes[sg], self.failures[sg]
            num_successes = len(sg_successes)
            num_evals = num_successes + len(sg_failures)
            if num_evals == 0:
                continue

            sr = float(num_successes) / num_evals
            total_path_len_weight = sum([entry['subgoal_path_len_weight'] for entry in sg_successes]) + \
                                    sum([entry['subgoal_path_len_weight'] for entry in sg_failures])
            weighted_success = sum([entry['subgoal_path_len_weighted_success_spl'] for entry in sg_successes]) + \
                               sum([entry['subgoal_path_len_weighted_success_spl'] for entry in sg_failures])
            # Guard: if every expert path length is 0, the weight sum is 0.
            sr_plw = float(weighted_success) / total_path_len_weight if total_path_len_weight > 0 else 0.

            self.results[sg] = {
                'sr': sr,
                'successes': num_successes,
                'evals': num_evals,
                'sr_plw': sr_plw
            }

            print("%s ==========" % sg)
            print("SR: %d/%d = %.3f" % (num_successes, num_evals, sr))
            print("PLW SR: %.3f" % (sr_plw))
        print("------------")

    @staticmethod
    def _overall_rates(successes, failures):
        '''Return (SR, PLW SR) pooled over every subgoal type in `successes` / `failures`.'''
        succ = [entry for entries in successes.values() for entry in entries]
        entries = succ + [entry for group in failures.values() for entry in group]
        if not entries:
            return 0., 0.
        sr = float(len(succ)) / len(entries)
        weight = sum(entry['subgoal_path_len_weight'] for entry in entries)
        weighted = sum(entry['subgoal_path_len_weighted_success_spl'] for entry in entries)
        sr_plw = float(weighted) / weight if weight > 0 else 0.
        return sr, sr_plw

    def save_results(self):
        '''
        Dump per-subgoal successes, failures and summary stats to a JSON file.

        The file is written under
        `<dirname(config.model_path)>/results/<model_type>-<if_clip>/`, and its
        name embeds a MMDD_HHMM timestamp, `<model_type>-<if_clip>` and
        `config.eval_split`. After this call, `self.results` holds the full
        nested dict that was saved.

        Returns:
            (sr, sr_plw). If 'GotoLocation' was evaluated, these are its values.
            Otherwise they are pooled over all evaluated subgoals, or
            (0.0, 0.0) when nothing was evaluated.
        '''
        results = {'eval successes': dict(self.successes),
                   'eval failures': dict(self.failures),
                   'eval results': dict(self.results)}

        self.results = results

        model_path = f"{self.config.model_type}-{self.config.if_clip}"
        # os.path.join keeps the path relative when model_path has no
        # directory component, instead of writing under "/results".
        results_dir = os.path.join(os.path.dirname(self.config.model_path), 'results', model_path)
        os.makedirs(results_dir, exist_ok=True)

        file_name = f'{datetime.now().strftime("%m%d_%H%M_")}_{model_path}_{self.config.eval_split}.json'
        save_path = os.path.join(results_dir, file_name)

        with open(save_path, 'w') as r:
            json.dump(results, r, indent=4, sort_keys=True)

        eval_results = results['eval results']
        if 'GotoLocation' in eval_results:
            goto = eval_results['GotoLocation']
            return goto['sr'], goto['sr_plw']
        # GotoLocation was not among the evaluated subgoals: report pooled rates.
        return self._overall_rates(results['eval successes'], results['eval failures'])

    @timeout(60)
    def reset_env(self, env, scene_name, object_poses, object_toggles, dirty_and_empty, traj_data, config, reward_type, r_idx):
        '''
        Reset `env` to `scene_name`, restore the recorded object state, apply
        the initial action and register the task for reward computation.

        Raises:
            TimeoutError: from the timeout decorator, if the call takes longer
                than 60 s.

        Returns:
            The same `env`.
        '''
        env.reset(scene_name)
        env.restore_scene(object_poses, object_toggles, dirty_and_empty)

        # initialize to start position
        env.step(dict(traj_data['scene']['init_action']))

        # print goal instr
        print("Task: %s" % (traj_data['turk_annotations']['anns'][r_idx]['task_desc']))

        # setup task for reward
        env.set_task(traj_data, config, reward_type=reward_type)

        return env

    def setup_scene(self, env, traj_data, r_idx, config, reward_type='dense'):
        '''
        Initialize the scene and agent from the task info.

        Returns:
            True on success. False if resetting the env timed out or raised;
            the caller is expected to restart the env and retry.
        '''
        # scene setup
        scene_num = traj_data['scene']['scene_num']
        object_poses = traj_data['scene']['object_poses']
        dirty_and_empty = traj_data['scene']['dirty_and_empty']
        object_toggles = traj_data['scene']['object_toggles']

        scene_name = 'FloorPlan%d' % scene_num

        try:
            self.reset_env(env, scene_name, object_poses, object_toggles, dirty_and_empty, traj_data, config, reward_type, r_idx)
        except TimeoutError as e:
            print(e)
            print("Timeout. Retry.")
            return False
        except Exception as e:
            # Any other reset failure is also reported as retryable.
            print(e)
            print("Scene setup failed. Retry.")
            return False

        return True