import argparse
import pickle
import pprint
import os
import copy
from glob import glob
import random

import numpy as np
from sklearn.metrics import precision_score, f1_score
import torch
from PIL import Image


# Goal object categories (stored in the output pickle as "goal_types"), in sorted order.
TARGET_OBJECT_TYPES = tuple(sorted((
    "AlarmClock",
    "Apple",
    "Book",
    "Bowl",
    "Box",
    "Candle",
    "GarbageCan",
    "HousePlant",
    "Laptop",
    "SoapBottle",
    "Television",
    "Toaster",
)))

# Object classes whose per-frame presence is recorded (stored as "classes"); the
# order defines the index of each entry in the presence vectors.
CLASSES = [
    'Drawer', 'Floor', 'CounterTop', 'Mirror', 'StoveBurner',
    'Cabinet', 'SideTable', 'Sink', 'GarbageBag', 'CoffeeMachine',
    'Toaster', 'Pot', 'Plate', 'GarbageCan', 'StoveKnob',
    'Fork', 'SoapBottle', 'Fridge', 'Pan', 'Window',
    'Egg', 'Spatula', 'Microwave', 'Cup', 'SinkBasin',
    'SaltShaker', 'PepperShaker', 'ButterKnife', 'DishSponge', 'LightSwitch',
    'Spoon', 'Knife', 'Shelf', 'Mug', 'DiningTable',
    'Blinds', 'AluminumFoil', 'Faucet', 'Bowl', 'Chair',
]

# Discrete action names (stored in the output pickle as "actions").
ACTION_TYPES = ["MoveAhead", "RotateLeft", "RotateRight", "End", "LookUp", "LookDown"]

def class_mask(semantic_frame, class_color):
    """Return a boolean (H, W) mask of pixels whose colour equals *class_color*.

    The colour is compared across the last axis of *semantic_frame*. When
    *class_color* is None (class not present in the colour map) an all-False
    mask with the frame's first two dimensions is returned.
    """
    if class_color is not None:
        matches = semantic_frame == class_color
        return matches.all(axis=-1)
    return np.zeros(semantic_frame.shape[:2], dtype=bool)

def obj_presence(class_masks):
    """Return, for each mask in a stacked (C, H, W) array, whether its pixel sum is positive."""
    per_class_total = np.sum(class_masks, axis=(1, 2))
    return per_class_total > 0

def select_few_traj(mdp_keys, dict_data, shot):
    """Build a few-shot copy of *dict_data* keeping *shot* random tasks per MDP.

    Prints per-MDP task listings and per-object counts, both for the full data
    and for the sampled subset.

    Args:
        mdp_keys: Keys of *dict_data* whose values map task keys to trajectory
            dicts; each trajectory must have "instruction" and "reward" entries.
        dict_data: The full dataset dict. It is not modified.
        shot: Number of tasks to sample (without replacement) from each MDP.

    Returns:
        A dict with the same keys as *dict_data*: non-MDP entries are deep
        copies, and each ``mdp_keys`` entry maps ``0..shot-1`` to deep copies of
        the sampled trajectories.

    Raises:
        KeyError: if an entry of *mdp_keys* is missing from *dict_data*.
        ValueError: if *shot* exceeds the number of tasks in an MDP.
    """
    mdp_keys = list(mdp_keys)  # iterated several times below
    mdp_set = set(mdp_keys)

    # Deep-copy only the entries we keep; the MDP entries are rebuilt from the
    # sampled tasks, so copying all of their trajectories would be wasted work.
    # A shared memo preserves object sharing between the copied entries.
    memo = {}
    fewshot_dict_data = {
        k: ({} if k in mdp_set else copy.deepcopy(v, memo))
        for k, v in dict_data.items()
    }

    # Statistics over the original data.
    cnt_dict = {}
    for mdp in mdp_keys:
        print(mdp)
        for task, traj in dict_data[mdp].items():
            target_obj = traj["instruction"][0]
            print(task, target_obj, len(traj["reward"]))
            cnt_dict[target_obj] = cnt_dict.get(target_obj, 0) + 1
    print(cnt_dict)

    # Selection.
    cnt_dict = {}
    for mdp in mdp_keys:
        print(mdp)
        # random.sample requires a sequence; passing a dict view raises
        # TypeError on Python 3.11+.
        selected_tasks = random.sample(list(dict_data[mdp]), shot)
        print(selected_tasks)
        for idx, task in enumerate(selected_tasks):
            target_traj = dict_data[mdp][task]
            fewshot_dict_data[mdp][idx] = copy.deepcopy(target_traj)
            obj = target_traj["instruction"][0]
            cnt_dict[obj] = cnt_dict.get(obj, 0) + 1
    print(cnt_dict)
    return fewshot_dict_data

def _new_traj():
    """Return an empty trajectory container in the layout written to the output pickle."""
    return {
        "frame": [],
        "goal": [],
        "action": [],
        "reward": [],
        "metadata": [],
    }


def _append_step(traj, episode_data, t, presence):
    """Append timestep *t* of *episode_data*, plus its object-presence vector, to *traj*."""
    traj["frame"].append(episode_data["frames"][t])
    traj["goal"].append(episode_data["goals"][t])
    traj["action"].append(episode_data["actions"][t])
    traj["reward"].append(episode_data["rewards"][t])
    traj["metadata"].append(presence)


def _frame_presence(episode_data, t, classes):
    """Return a boolean vector telling whether each class in *classes* is visible at step *t*.

    A class with no entry in that step's ``object_id_to_color`` map is reported
    as absent.
    """
    seg_frame = episode_data['semantic_segmentation_frame'][t]
    colors = episode_data['object_id_to_color'][t]
    masks = np.array([class_mask(seg_frame, colors.get(o, None)) for o in classes])
    return obj_presence(masks)


# Per-timestep sequences of an episode that this script indexes by step. They must
# be trimmed together so that index t refers to the same step in every one of them.
_PER_STEP_KEYS = (
    "frames",
    "goals",
    "actions",
    "rewards",
    "metadata",
    "semantic_segmentation_frame",
    "object_id_to_color",
)


def _trim_to_goal_switch(episode_data):
    """Drop, in place, the leading steps that share the first step's goal.

    Keeps the steps from the first goal change onward. If the goal never
    changes, the episode is left whole.
    """
    goals = episode_data["goals"]
    start = next(
        (t for t in range(1, len(episode_data["frames"])) if goals[t] != goals[0]),
        0,
    )
    for name in _PER_STEP_KEYS:
        episode_data[name] = episode_data[name][start:]


def _align_to_base(base_trajs, task, presence, used_base_tasks):
    """Choose the base-episode key under which a factor episode is stored.

    Candidates are base episodes whose first goal equals *task*. They are ranked
    by F1, then accuracy, of their first-frame presence vector against
    *presence*, with ties broken by the smaller key. The best candidate is taken
    whenever its F1 is non-zero, even if it is already used. Otherwise, the
    first not-yet-used candidate with non-zero accuracy is taken. The chosen key
    is added to *used_base_tasks*.

    Returns None when no acceptable base episode exists.
    """
    candidates = []
    for k, traj in base_trajs.items():
        if traj["goal"][0] == task:
            f1 = f1_score(traj["metadata"][0], presence, pos_label=1)
            acc = np.mean(traj["metadata"][0] == presence)
            candidates.append((k, f1, acc))
    if not candidates:
        return None
    candidates.sort(key=lambda c: (-c[1], -c[2], c[0]))

    best_key, best_f1, _ = candidates[0]
    if best_f1:
        used_base_tasks.add(best_key)
        return best_key

    for k, _f1, acc in candidates:
        if k in used_base_tasks:
            continue
        if not acc:
            # Every F1 is 0 here, so candidates are in descending-accuracy order:
            # no later candidate can have non-zero accuracy either.
            return None
        used_base_tasks.add(k)
        return k
    return None


def main(args):
    """Convert the expert trajectories of one factor into a single aligned pickle.

    Every entry of ``<args.root>/<args.factor>`` is assigned an integer MDP key
    (in sorted name order). Its ``expert_data.pkl`` is then read. The entry whose
    name contains "base" is the reference: its episodes are stored by episode
    index. Episodes of every entry are stored under the key of the base episode
    they align to; see ``_align_to_base``.

    The result is written to ``<args.out_dir>/<args.factor>/<args.file_name>``.
    It maps each MDP key to ``{key: {"frame", "goal", "action", "reward",
    "metadata"}}``, plus "classes", "goal_types" and "actions".

    Raises:
        FileNotFoundError: if no entry under the factor directory contains "base".
    """
    global TARGET_OBJECT_TYPES, CLASSES, ACTION_TYPES

    factor_root = os.path.join(args.root, args.factor)

    # Sorting makes the MDP key assignment independent of listdir order.
    factor_traj = {}
    base_key = None
    for mdp, instance_dir in enumerate(sorted(os.listdir(factor_root))):
        print(instance_dir)
        if "base" in instance_dir:
            base_key = mdp
        factor_traj[mdp] = os.path.join(instance_dir, "expert_data.pkl")
    print(factor_traj)
    if base_key is None:
        raise FileNotFoundError(
            f"no instance directory containing 'base' under {factor_root}")

    global_dict = {base_key: {}}

    # --- base instance -------------------------------------------------------
    # NOTE: pickle.load can execute arbitrary code; only use trusted trajectory files.
    with open(os.path.join(factor_root, factor_traj[base_key]), 'rb') as f:
        data = pickle.load(f)
    print(data.keys())  # dict_keys(['object_types', 'action_types', 'domain_conf', 'traj_data'])
    print(data["traj_data"][0].keys())  # dict_keys(['frames', 'actions', 'rewards', 'dones', 'goals', 'infos', 'semantic_segmentation_frame', 'object_id_to_color', 'metadata'])

    for ep, episode_data in enumerate(data["traj_data"]):
        if CLASSES is None:
            # Only reached if CLASSES is set to None at module level: derive the
            # class list from the object names in the first step's metadata.
            temp = {}
            for o in episode_data["metadata"][0]:
                print(o)
                temp[o["name"].split("_")[0]] = 0
            print(list(temp))
            CLASSES = list(temp)
        global_dict["classes"] = CLASSES

        traj = _new_traj()
        for t in range(len(episode_data["frames"])):
            _append_step(traj, episode_data, t, _frame_presence(episode_data, t, CLASSES))
        global_dict[base_key][ep] = traj

    # --- all instances, aligned to the base episodes -------------------------
    for i, key in enumerate(factor_traj):
        f_name = factor_traj[key]
        print(i, key, f_name)
        # NOTE(review): this loop also re-processes the base instance and writes
        # its re-aligned episodes back into global_dict[base_key]; confirm intended.
        global_dict.setdefault(key, {})

        with open(os.path.join(factor_root, f_name), 'rb') as f:
            data = pickle.load(f)

        used_base_tasks = set()
        for ep, episode_data in enumerate(data["traj_data"]):
            n_frames = len(episode_data["frames"])
            if n_frames > 1000:
                continue
            if n_frames > 500:
                print("episode length:", n_frames)
                print(episode_data["rewards"][-1])
                _trim_to_goal_switch(episode_data)
                print("episode length:", len(episode_data["frames"]))
                print(episode_data["rewards"][-1])
            if len(episode_data["frames"]) == 0:
                print("skip episode", ep)
                continue

            first_presence = _frame_presence(episode_data, 0, CLASSES)
            align_key = _align_to_base(
                global_dict[base_key], episode_data["goals"][0],
                first_presence, used_base_tasks)
            if align_key is None:
                # Skip entirely: nothing is written, so no aligned episode is clobbered.
                print("skip episode", ep)
                continue

            traj = _new_traj()
            _append_step(traj, episode_data, 0, first_presence)
            for j in range(1, len(episode_data["frames"])):
                _append_step(traj, episode_data, j, _frame_presence(episode_data, j, CLASSES))
            global_dict[key][align_key] = traj

    global_dict["goal_types"] = TARGET_OBJECT_TYPES
    global_dict["actions"] = ACTION_TYPES
    print(global_dict.keys())

    # Few-shot subsetting is available via select_few_traj (currently not applied).
    out_dir = os.path.join(args.out_dir, args.factor)
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, args.file_name), 'wb') as f:
        pickle.dump(global_dict, f, pickle.HIGHEST_PROTOCOL)

if __name__ == '__main__':
    # Command-line entry point: parse the options and run the preprocessing.
    cli = argparse.ArgumentParser(description='Expert Trajectory Preprocessing')
    # dataset parameters
    cli.add_argument('root', metavar='DIR', help='root path of trajectory')
    cli.add_argument('--factor', type=str, default='FOV')
    cli.add_argument('--out-dir', type=str, default='trajdata/ObjNav/original')
    cli.add_argument('--file-name', type=str, default='train_dataset.pkl')
    cli.add_argument('--few-shot', type=int, default=8)
    main(cli.parse_args())
