import json
import pprint
import random
import time
import torch
import torch.multiprocessing as mp
from models.nn.resnet import Resnet
from data.preprocess import Dataset
from importlib import import_module
import networks.CLIP.clip.clip as clip

from wrapt_timeout_decorator import *

class Eval(object):
    '''
    Base class for evaluating an agent on one split of trajectories using
    several worker processes.

    Subclasses must implement ``run``, ``evaluate``, ``save_results`` and
    ``create_stats``. ``spawn_threads`` hands ``self.successes``,
    ``self.failures`` and ``self.results`` to every worker. ``create_stats``
    (called from ``__init__``) is presumably responsible for creating those
    attributes; confirm against the subclasses.
    '''

    # tokens
    STOP_TOKEN = "<<stop>>"
    # NOTE(review): the attribute is named SEQ_TOKEN but its text is "<<seg>>";
    # both are kept unchanged because other code may rely on either.
    SEQ_TOKEN = "<<seg>>"
    TERMINAL_TOKENS = [STOP_TOKEN, SEQ_TOKEN]

    def __init__(self, args, agent, manager):
        '''
        Load the splits file, optionally preprocess the dataset, build a
        frozen ResNet feature extractor and initialise the stats containers.

        :param args: run configuration. This method reads ``device``,
            ``splits``, ``preprocess`` and ``fast_epoch``, and it overwrites
            ``args.visual_model``.
        :param agent: the model to evaluate, stored as ``self.model``.
            Preprocessing reads ``agent.args`` and ``agent.vocab``.
        :param manager: multiprocessing manager used to create the shared
            task queue and lock.
        '''
        # args and manager
        self.device = torch.device(args.device)
        self.args = args
        self.manager = manager

        # load splits: maps split name -> list of trajectory entries
        with open(self.args.splits) as f:
            self.splits = json.load(f)
        pprint.pprint({k: len(v) for k, v in self.splits.items()})

        # the model is supplied by the caller rather than loaded from disk here
        self.model = agent

        # preprocess and save
        if args.preprocess:
            print("\nPreprocessing dataset and saving to %s folders ... This is will take a while. Do this once as required:" % self.model.args.pp_folder)
            self.model.args.fast_epoch = self.args.fast_epoch
            dataset = Dataset(self.model.args, self.model.vocab)
            dataset.preprocess_splits(self.splits)

        # load resnet18 with gradients disabled. share_memory=True is presumably
        # so the worker processes can reuse it; confirm in Resnet.
        # NOTE: this overwrites visual_model on the caller's args object.
        args.visual_model = 'resnet18'
        self.resnet = Resnet(args, eval=True, share_memory=True, use_conv_feat=True, device=args.device)
        for param in self.resnet.resnet_model.model.parameters():
            param.requires_grad = False

        # success and failure lists (subclass hook)
        self.create_stats()

        # set random seed for shuffling
        random.seed(int(time.time()))

    def queue_tasks(self):
        '''
        Create a manager queue of trajectories to be evaluated.

        The split named by ``args.eval_split`` is copied first, so shuffling
        never reorders ``self.splits`` itself. Then:
        - ``args.fast_epoch`` truncates it to the first 16 entries;
        - ``args.shuffle`` shuffles it;
        - a non-negative ``args.num_eval_file`` keeps only that many entries.
          A negative value keeps all of them.

        :return: ``self.manager.Queue()`` filled with the selected entries.
        '''
        task_queue = self.manager.Queue()
        # copy so that random.shuffle below does not mutate self.splits
        files = list(self.splits[self.args.eval_split])

        # debugging: fast epoch
        if self.args.fast_epoch:
            files = files[:16]

        if self.args.shuffle:
            random.shuffle(files)

        if self.args.num_eval_file >= 0:
            files = files[:self.args.num_eval_file]

        for traj in files:
            task_queue.put(traj)
        return task_queue

    def join_threads(self):
        '''
        Wait for every worker process started by ``spawn_threads``, then
        collect results.

        :return: the two values produced by ``save_results`` (subclass hook).
        '''
        for t in self.threads:
            t.join()

        sr, srw = self.save_results()
        return sr, srw

    def spawn_threads(self):
        '''
        Start ``args.num_threads`` worker processes (``mp.Process``, despite
        the name) that each run ``self.run`` against a shared task queue.
        '''
        self.task_queue = self.queue_tasks()

        # start threads
        self.threads = []
        self.lock = self.manager.Lock()
        for _ in range(self.args.num_threads):
            # A torch.no_grad() block here would only cover constructing the
            # Process object. start() runs after such a block exits, so the
            # workers are unaffected; run() must disable gradients itself.
            thread = mp.Process(target=self.run, args=(self.model, self.resnet, self.task_queue, self.args, self.lock,
                                                       self.successes, self.failures, self.results))
            thread.start()
            self.threads.append(thread)

    @staticmethod
    # @timeout(60)
    def reset_env(env, scene_name, object_poses, object_toggles, dirty_and_empty, traj_data, config, reward_type, r_idx):
        '''
        Reset ``env`` to ``scene_name``, restore the recorded object state,
        apply the trajectory's initial action and register the task.

        :param r_idx: index into ``traj_data['turk_annotations']['anns']``
            whose ``task_desc`` is printed.
        :return: the same ``env`` object.
        '''
        env.reset(scene_name)
        env.restore_scene(object_poses, object_toggles, dirty_and_empty)

        # initialize to start position
        env.step(dict(traj_data['scene']['init_action']))

        # print goal instr
        print("Task: %s" % (traj_data['turk_annotations']['anns'][r_idx]['task_desc']))

        # setup task for reward
        env.set_task(traj_data, config, reward_type=reward_type)

        return env

    @classmethod
    def setup_scene(cls, env, traj_data, r_idx, args, reward_type='dense'):
        '''
        Initialize the scene and agent from the task info in ``traj_data``.

        :return: True on success. On failure, it prints the exception and
            returns False. Any exception counts as failure, not only timeouts,
            even though the message says "Timeout".
        '''
        # scene setup
        scene_num = traj_data['scene']['scene_num']
        object_poses = traj_data['scene']['object_poses']
        dirty_and_empty = traj_data['scene']['dirty_and_empty']
        object_toggles = traj_data['scene']['object_toggles']

        scene_name = 'FloorPlan%d' % scene_num
        try:
            cls.reset_env(env, scene_name, object_poses, object_toggles, dirty_and_empty, traj_data, args, reward_type, r_idx)
        except Exception as e:
            # deliberate best-effort: the caller decides whether to retry
            print(e)
            print("Timeout. Retry.")
            return False

        return True

    @classmethod
    def run(cls, model, resnet, task_queue, args, lock, successes, failures, results=None):
        '''
        Worker entry point, to be implemented by subclasses.

        ``spawn_threads`` calls it with 8 positional arguments, the last one
        being ``self.results``; ``results`` is accepted here so the base
        signature matches that call.
        '''
        raise NotImplementedError()

    @classmethod
    def evaluate(cls, env, model, r_idx, resnet, traj_data, args, lock, successes, failures):
        '''Evaluate one trajectory annotation; subclass hook.'''
        raise NotImplementedError()

    def save_results(self):
        '''Aggregate and persist results; must return two values (see join_threads).'''
        raise NotImplementedError()

    def create_stats(self):
        '''Create the shared stats containers used by the workers; subclass hook.'''
        raise NotImplementedError()
    
    