import os
import sys
import json
import numpy as np
from PIL import Image
from datetime import datetime
from env.thor_env import ThorEnv
from eval.eval import Eval
import torch
import networks.CLIP.clip.clip as clip
import pickle
import pprint
from models.nn.resnet import Resnet
import random
import time
import copy

class EvalSubgoals:
    '''
    evaluate subgoals by teacher-forcing expert demonstrations
    '''
    # high-level subgoal types that can be selected for evaluation
    ALL_SUBGOALS = [
        'GotoLocation',
        'PickupObject',
        'PutObject',
        'CoolObject',
        'HeatObject',
        'CleanObject',
        'SliceObject',
        'ToggleObject',
    ]

    # action words treated as terminal (they are never sent to the environment)
    STOP_TOKEN = "<<stop>>"
    SEQ_TOKEN = "<<seg>>"
    TERMINAL_TOKENS = [STOP_TOKEN, SEQ_TOKEN]
    
    def __init__(self, args, agent):
        '''
        load the splits, prepare the agent and visual encoder, and fill the task queue

        args:  run configuration; reads splits, gpu, eval_split, fast_epoch,
               shuffle and num_eval_file, and overwrites visual_model
        agent: the model to evaluate (moved to CUDA when args.gpu is set)
        '''
        self.args = args

        # read the split file and report how many entries each split holds
        with open(self.args.splits) as split_file:
            self.splits = json.load(split_file)
            pprint.pprint({name: len(entries) for name, entries in self.splits.items()})

        self.model = agent

        # visual encoder: the backbone is forced to resnet18 here
        self.args.visual_model = 'resnet18'
        self.resnet = Resnet(self.args, eval=True, share_memory=True, use_conv_feat=True)

        # move the agent to the GPU when requested
        if self.args.gpu:
            self.model = self.model.to(torch.device('cuda'))

        # time-based seed, so the shuffle order differs between runs
        random.seed(int(time.time()))

        self.task_queue = []
        candidates = self.splits[self.args.eval_split]

        # quick runs only look at the first 16 entries
        if self.args.fast_epoch:
            candidates = candidates[:16]

        # NOTE: without fast_epoch this shuffles the list stored in self.splits in place
        if self.args.shuffle:
            random.shuffle(candidates)

        # a negative num_eval_file means "use everything"
        limit = self.args.num_eval_file
        chosen = candidates if limit < 0 else candidates[:limit]
        self.task_queue.extend(chosen)
                

    def run(self):
        '''
        evaluation loop

        Pops trajectories off self.task_queue and rolls out every selected
        subgoal with self.evaluate(). Rollouts that succeed, or whose predicted
        action words equal the expert's, are skipped; every other rollout is
        written to disk as a 'noise' sample and registered in
        data/noise/id2pt.pt. Finally self.save_results() is called.

        The simulator is always stopped on exit, even when an error escapes
        the loop or save_results() fails.
        '''
        # no new task is started once this many samples have been written
        max_samples = 12000

        # start THOR
        env = ThorEnv()

        try:
            id2pt = {}
            num = 0

            # make subgoals list
            if self.args.subgoals.lower() == "all":
                requested = self.ALL_SUBGOALS
            else:
                requested = self.args.subgoals.split(',')
            subgoals_to_evaluate = [sg for sg in requested if sg in self.ALL_SUBGOALS]
            print ("Subgoals to evaluate: %s" % str(subgoals_to_evaluate))

            # create empty stats per subgoal
            # NOTE(review): nothing in this method updates these stats, so
            # save_results() writes the empty defaults -- confirm this is intended
            self.successes = {}
            self.failures = {}
            self.results = {}
            for sg in subgoals_to_evaluate:
                self.successes[sg] = list()
                self.failures[sg] = list()
                self.results[sg] = {
                    'sr': 0.,
                    'successes': 0.,
                    'evals': 0.,
                    'sr_plw': 0.
                }

            while self.task_queue and num < max_samples:
                task = self.task_queue.pop(0)

                try:
                    traj = self.model.load_task_json(task)

                    r_idx = task['repeat_idx']
                    subgoal_idxs = [sg['high_idx'] for sg in traj['plan']['high_pddl'] if sg['discrete_action']['action'] in subgoals_to_evaluate]
                    for eval_idx in subgoal_idxs:
                        print("No. of trajectories left: %d" % (len(self.task_queue)), f"\n {num} has been generated")
                        success, frames, actions_num, actions_real, actions_expert = self.evaluate(env, eval_idx, r_idx, traj)

                        # only failed rollouts are kept
                        if success:
                            continue
                        print("Failed")

                        if actions_real == actions_expert:
                            print("The same.")
                            continue

                        # evaluate() can return before the model took any step
                        # (e.g. the expert replay failed); there is nothing to
                        # save then, and the empty tensors below would raise and
                        # abort the remaining subgoals of this trajectory
                        if not actions_real:
                            print("No model actions recorded; skipping.")
                            continue

                        id2pt[num] = self._save_failed_rollout(traj, eval_idx, r_idx, num, frames, actions_num)
                        num += 1

                        # rewrite the index after every sample so a crash loses nothing
                        with open('data/noise/id2pt.pt', 'wb') as f:
                            pickle.dump(id2pt, f)

                except Exception as e:
                    # best effort: report the problem and move on to the next trajectory
                    import traceback
                    traceback.print_exc()
                    print("Error: " + repr(e))

            self.save_results()
        finally:
            env.stop()

    def _save_failed_rollout(self, traj, eval_idx, r_idx, num, frames, actions_num):
        '''
        write one failed rollout to disk and return the path of its pickled record

        traj:        task dict as returned by self.model.load_task_json()
        eval_idx:    index of the evaluated subgoal
        r_idx:       annotation (repeat) index
        num:         running sample id, used in the record's file name
        frames:      per-step visual features collected by evaluate()
        actions_num: per-step predicted action indices (non-empty)
        '''
        steps = len(actions_num)

        # the rollout failed, so no step is terminal: done and reward are all zeros
        done = torch.zeros(steps, 1, dtype=torch.float64)
        reward = done.clone()

        # one-hot encoding over 15 action classes
        actions_onehot = torch.zeros(steps, 15)
        for step, action_idx in enumerate(actions_num):
            actions_onehot[step][action_idx] = 1

        root = traj['root'] + f'_{eval_idx}_{r_idx}'

        # features go under .../noise/<last two components of root>/feat_conv.pt
        image_roots = os.path.join('envs/alfred/data/json_feat_2.1.0', 'noise', *(root.split('/')[-2:]))
        os.makedirs(image_roots, exist_ok=True)

        # detach instead of torch.tensor(tensor), which copies and emits a warning
        stacked = torch.cat([frame.detach() for frame in frames], dim=0)[:, 0]
        torch.save(stacked, os.path.join('.', image_roots, 'feat_conv.pt'))

        subgoal = traj['turk_annotations']['anns'][r_idx]['high_descs'][eval_idx]

        data_processed = {
            'id': None,
            'subgoal_type': traj['plan']['high_pddl'][eval_idx]['discrete_action']['action'],
            'split': 'noise',
            'root': root,
            'goal': traj['ann']['goal'],
            'index': (0, steps),
            'reward': reward,
            'done': done,
            'actions': actions_onehot,
            'actions_low_mask': None,
            'valid_interacts': None,
            'subgoal': None,
            'mc': None,
            'subgoal_instr': subgoal,
            'subgoal_ori': traj['num']['lang_instr'][eval_idx],
        }
        print(subgoal)

        file_folder = 'data/noise/trajs'
        file_path = file_folder + f'/id_{num}_preprocessed.pt'
        os.makedirs(file_folder, exist_ok=True)

        # the record is a plain pickle despite the .pt extension
        with open(file_path, 'wb') as f:
            pickle.dump(data_processed, f)

        return file_path

    def evaluate(self, env, eval_idx, r_idx, traj_data):
        '''
        roll out a single subgoal: replay the expert up to eval_idx, then let the model act

        env:       simulator wrapper (va_interact, get_transition_reward, get_subgoal_idx, ...)
        eval_idx:  index of the high-level subgoal to evaluate
        r_idx:     annotation (repeat) index, used for scene setup
        traj_data: task dict with 'plan', 'num', 'root', 'scene', ...

        returns (subgoal_success, frames, actions_num, actions_real, actions_expert):
            subgoal_success: True if env.get_subgoal_idx() reached eval_idx
            frames:          visual features fed to the model at each model step
            actions_num:     predicted action indices, one per model step
            actions_real:    predicted action words, one per model step
            actions_expert:  expert action words for this subgoal
        '''
        # setup scene
        reward_type = 'dense'
        self.setup_scene(env, traj_data, r_idx, self.args, reward_type=reward_type)

        low_actions = traj_data['plan']['low_actions']

        # expert demonstration to reach eval_idx-1
        expert_init_actions = [a['discrete_action'] for a in low_actions if a['high_idx'] < eval_idx]

        # expert action words for the evaluated subgoal, returned for comparison
        actions_expert = [a['discrete_action']['action'] for a in low_actions if a['high_idx'] == eval_idx]

        # subgoal info
        subgoal_action = traj_data['plan']['high_pddl'][eval_idx]['discrete_action']['action']
        subgoal_instr = traj_data['num']['lang_instr'][eval_idx]

        # print subgoal info
        print("Evaluating: %s\nSubgoal %s (%d)\nInstr: %s" % (traj_data['root'], subgoal_action, eval_idx, subgoal_instr))

        # extract language features
        feat = self.model.featurize([copy.deepcopy(traj_data)], load_mask=False)

        # numericalized instruction as a (1, N) tensor; built once up front so
        # it is also defined while the expert prefix is being replayed
        subgoal_ids = torch.tensor(subgoal_instr).view(1, -1)

        # previous action fed to the model on its first step
        # NOTE(review): index 15 is presumably a "no previous action" slot -- confirm
        prev_action = torch.tensor([15])

        subgoal_success = False
        fails = 0
        t = 0

        actions_real = []
        actions_num = []
        frames = []

        # the step budget counts model steps only; the expert prefix is added on top
        max_t = self.args.max_steps + len(expert_init_actions)

        while t < max_t:
            # extract visual feats
            curr_image = Image.fromarray(np.uint8(env.last_event.frame))
            feat['frames'] = self.resnet.featurize([curr_image], batch=1).unsqueeze(0)

            # expert teacher-forcing upto subgoal
            if t < len(expert_init_actions):
                # get expert action
                action = expert_init_actions[t]
                compressed_mask = action['args'].get('mask')
                mask = env.decompress_mask(compressed_mask) if compressed_mask is not None else None

                # forward model
                if not self.args.skip_model_unroll_with_expert:
                    self.model.step(feat, subgoal_ids, prev_action=prev_action)
                    # NOTE(review): this stores the expert action *name*, while the
                    # model branch stores action[0] -- confirm model.step accepts both
                    prev_action = action['action'] if not self.args.no_teacher_force_unroll_with_expert else None

                # execute expert action
                success, _, _, err, _ = env.va_interact(action['action'], interact_mask=mask, smooth_nav=self.args.smooth_nav, debug=self.args.debug)
                if not success:
                    print ("expert initialization failed")
                    break

                # update transition reward (return value unused here; the call is
                # kept in case it advances the env's bookkeeping -- not visible here)
                env.get_transition_reward()

            # subgoal evaluation
            else:
                action, action_words, mask = self.model.step(feat, subgoal_ids, prev_action=prev_action)

                actions_real.append(action_words)
                actions_num.append(int(action))
                frames.append(feat['frames'])

                mask = np.squeeze(mask, axis=0) if self.model.has_interaction(action_words) else None

                prev_action = action[0]

                # debug
                if self.args.debug:
                    print("Pred: ", action)

                if action_words not in self.TERMINAL_TOKENS:
                    # use predicted action and mask (if provided) to interact with the env
                    t_success, _, _, err, _ = env.va_interact(action_words, interact_mask=mask, smooth_nav=self.args.smooth_nav, debug=self.args.debug)
                    if not t_success:
                        fails += 1
                        if fails >= self.args.max_fails:
                            print("Interact API failed %d times" % (fails) + "; latest error '%s'" % err)
                            break

                # next time-step (return value unused here, see above)
                env.get_transition_reward()

                # update subgoals
                curr_subgoal_idx = env.get_subgoal_idx()
                if curr_subgoal_idx == eval_idx:
                    subgoal_success = True
                    break

                # terminal tokens predicted: compare the action *word*, since the
                # numeric action index can never equal a string token
                if action_words in self.TERMINAL_TOKENS:
                    print("predicted %s" % action_words)
                    break

            # increment time index
            t += 1

        return subgoal_success, frames, actions_num, actions_real, actions_expert

    def save_results(self):
        '''
        write the collected evaluation stats to a timestamped JSON file

        The file is created in <dirname(args.model_path)>/results/<args.model_name>/
        and named <mmdd_HHMM_>_<model_name>_<eval_split>.json.

        Side effect: self.results is replaced by the combined dict that is
        written ('eval successes', 'eval failures', 'eval results').
        '''
        results = {'eval successes': dict(self.successes),
                   'eval failures': dict(self.failures),
                   'eval results': dict(self.results)}

        self.results = results

        model_name = self.args.model_name

        # os.path.join keeps the path relative when model_path has no directory
        # part; plain string formatting would yield '/results/...' at the
        # filesystem root in that case
        results_dir = os.path.join(os.path.dirname(self.args.model_path), 'results', model_name)
        os.makedirs(results_dir, exist_ok=True)

        file_name = f'{datetime.now().strftime("%m%d_%H%M_")}_{model_name}_{self.args.eval_split}.json'
        save_path = os.path.join(results_dir, file_name)

        with open(save_path, 'w') as r:
            json.dump(results, r, indent=4, sort_keys=True)
            
        # print(results)
    
    @classmethod
    def setup_scene(cls, env, traj_data, r_idx, args, reward_type='dense'):
        '''
        initialize the scene and agent from the task info

        Declared as a classmethod (the first parameter was already named `cls`),
        so it works both as EvalSubgoals.setup_scene(...) and self.setup_scene(...).

        env:         simulator wrapper providing reset, restore_scene, step and set_task
        traj_data:   task dict; reads 'scene' and 'turk_annotations'
        r_idx:       annotation index whose task description is printed
        args:        run configuration, passed through to env.set_task
        reward_type: reward scheme passed through to env.set_task
        '''
        # scene setup
        scene = traj_data['scene']
        scene_num = scene['scene_num']
        object_poses = scene['object_poses']
        dirty_and_empty = scene['dirty_and_empty']
        object_toggles = scene['object_toggles']

        env.reset('FloorPlan%d' % scene_num)
        env.restore_scene(object_poses, object_toggles, dirty_and_empty)

        # initialize to start position (a copy is passed, so the stored action dict is not handed to env)
        env.step(dict(scene['init_action']))

        # print goal instr
        print("Task: %s" % (traj_data['turk_annotations']['anns'][r_idx]['task_desc']))

        # setup task for reward
        env.set_task(traj_data, args, reward_type=reward_type)