import numpy as np
import numpy.linalg as LA
import cv2
import matplotlib.pyplot as plt
import random
from modeling.utils.baseline_utils import apply_color_to_map, pose_to_coords, gen_arrow_head_marker, read_map_npy, read_occ_map_npy, plus_theta_fn, crop_map, spatial_transform_map
from core import cfg
import modeling.utils.frontier_utils as fr_utils
from modeling.localNavigator_Astar import localNav_Astar
import networkx as nx
from random import Random
from timeit import default_timer as timer
from itertools import islice
import os
import multiprocessing
import pickle
from skimage.morphology import skeletonize
import torch
import math
from scipy import ndimage
import bz2
import _pickle as cPickle


def get_region(robot_pos, H, W, size=2):
    y, x = robot_pos
    y1 = max(0, y - size)
    y2 = min(H - 1, y + size)
    x1 = max(0, x - size)
    x2 = min(W - 1, x + size)

    return (y1, x1, y2, x2)


class Data_Gen_MP3D:
    '''
            generate partial map training data for each MP3D scene.
    '''

    def __init__(self, split, scene_name, saved_dir=''):
        self.split = split
        self.scene_name = scene_name
        self.random = Random(cfg.GENERAL.RANDOM_SEED)

        # ============= create scene folder =============
        scene_folder = f'{saved_dir}/{scene_name}'
        if not os.path.exists(scene_folder):
            os.mkdir(scene_folder)
        self.scene_folder = scene_folder

        self.init_scene()

    def init_scene(self):
        scene_name = self.scene_name
        print(f'init new scene: {scene_name}')

        # ================================= read in pre-built occupancy and semantic map =============================
        sem_map_npy = np.load(
            f'{cfg.SAVE.SEMANTIC_MAP_PATH}/{self.split}/{scene_name}/BEV_semantic_map.npy', allow_pickle=True).item()
        self.gt_sem_map, self.pose_range, self.coords_range, self.WH = read_map_npy(
            sem_map_npy)
        occ_map_npy = np.load(
            f'{cfg.SAVE.OCCUPANCY_MAP_PATH}/{self.split}/{scene_name}/BEV_occupancy_map.npy', allow_pickle=True).item()
        gt_occ_map, _, _, _ = read_occ_map_npy(occ_map_npy)

        if cfg.NAVI.D_type == 'Skeleton':
            self.skeleton = skeletonize(gt_occ_map)
            if cfg.NAVI.PRUNE_SKELETON:
                self.skeleton = fr_utils.prune_skeleton(
                    gt_occ_map, self.skeleton)

        gt_occupancy_map = gt_occ_map.copy()
        gt_occupancy_map = np.where(
            gt_occupancy_map == 1, cfg.FE.FREE_VAL, gt_occupancy_map)  # free cell
        self.gt_occupancy_map = np.where(
            gt_occupancy_map == 0, cfg.FE.COLLISION_VAL, gt_occupancy_map)  # occupied cell

        self.M_c = np.stack((self.gt_occupancy_map, self.gt_sem_map))
        self.H, self.W = self.gt_sem_map.shape

        # initialize path planner
        self.LN = localNav_Astar(
            self.pose_range, self.coords_range, self.WH, scene_name)

        # find the largest connected component on the map
        self.G = self.LN.get_G_from_map(gt_occupancy_map)
        self.largest_cc = list(max(nx.connected_components(self.G), key=len))

    def write_to_file(self, num_samples=100):
        count_sample = 0
        # =========================== process each episode
        for idx_epi in range(num_samples):
            print(f'idx_epi = {idx_epi}')

            # ====================================== generate (start, goal) locs, compute path P==========================
            start_loc, goal_loc = self.random.choices(self.largest_cc, k=2)
            path = nx.shortest_path(self.G,
                                    source=start_loc,
                                    target=goal_loc)

            M_p = np.zeros(self.M_c.shape, dtype=np.int16)
            observed_area_flag = np.zeros((self.H, self.W), dtype=bool)
            # i_loc = 0
            end_i_loc = self.random.choice(list(range(len(path)+1)))

            # while i_loc < len(path):
            for i_loc in range(end_i_loc):
                robot_loc = path[i_loc]

                # t0 = timer()
                # =================================== generate partial map M_p ==================================
                roi = get_region(robot_loc, self.H, self.W,
                                 size=cfg.PRED.PARTIAL_MAP.NEIGHBOR_SIZE)
                M_p[:, roi[0]:roi[2]+1, roi[1]:roi[3] +
                    1] = self.M_c[:, roi[0]:roi[2]+1, roi[1]:roi[3]+1]
                observed_area_flag[roi[0]:roi[2]+1, roi[1]:roi[3]+1] = True
                # t1 = timer()
                # print(f't1 - t0 = {t1 - t0}')

            # t2 = timer()
            # ================================= compute area at frontier points ========================
            U_a = np.zeros((self.H, self.W), dtype=np.float32)
            U_d = np.zeros((self.H, self.W, 3), dtype=np.float32)
            observed_occupancy_map = M_p[0]
            frontiers = fr_utils.get_frontiers(observed_occupancy_map)
            # t3 = timer()
            # print(f'get frontier time = {t3 - t2}')
            agent_map_pose = (robot_loc[1], robot_loc[0])
            frontiers = self.LN.filter_unreachable_frontiers_temp(
                frontiers, agent_map_pose, observed_occupancy_map)
            # t4 = timer()
            # print(f'filter unreachable frontiers time = {t4 - t3}')
            frontiers = fr_utils.compute_frontier_potential(frontiers, observed_occupancy_map, self.gt_occupancy_map,
                                                            observed_area_flag, None, self.skeleton)
            # t5 = timer()
            # print(f'compute frontier potential time = {t5 - t4}')

            for fron in frontiers:
                points = fron.points.transpose()  # N x 2
                U_a[points[:, 0], points[:, 1]] = 1. * \
                    fron.R / cfg.PRED.PARTIAL_MAP.DIVIDE_AREA
                U_d[points[:, 0], points[:, 1], 0] = 1. * \
                    fron.D / cfg.PRED.PARTIAL_MAP.DIVIDE_D
                U_d[points[:, 0], points[:, 1], 1] = 1. * \
                    fron.Din / cfg.PRED.PARTIAL_MAP.DIVIDE_D
                U_d[points[:, 0], points[:, 1], 2] = 1. * \
                    fron.Dout / cfg.PRED.PARTIAL_MAP.DIVIDE_D

            # =================================== visualize M_p =========================================
            if cfg.PRED.PARTIAL_MAP.FLAG_VISUALIZE_PRED_LABELS:
                occ_map_Mp = M_p[0]
                sem_map_Mp = M_p[1]
                color_sem_map_Mp = apply_color_to_map(sem_map_Mp)

                fig, ax = plt.subplots(nrows=3, ncols=2, figsize=(20, 30))
                ax[0][0].imshow(occ_map_Mp, cmap='gray')
                ax[0][0].get_xaxis().set_visible(False)
                ax[0][0].get_yaxis().set_visible(False)
                ax[0][0].set_title('input: occupancy_map_Mp')
                ax[0][1].imshow(color_sem_map_Mp)
                ax[0][1].get_xaxis().set_visible(False)
                ax[0][1].get_yaxis().set_visible(False)
                ax[0][1].set_title('input: semantic_map_Mp')

                ax[1][0].imshow(occ_map_Mp, cmap='gray')
                x_coord_lst = [path[i][1] for i in range(i_loc+1)]
                z_coord_lst = [path[i][0] for i in range(i_loc+1)]
                ax[1][0].plot(x_coord_lst, z_coord_lst,
                              lw=5, c='blue', zorder=3)
                for f in frontiers:
                    ax[1][0].scatter(f.points[1], f.points[0],
                                     c='yellow', zorder=2)
                    ax[1][0].scatter(
                        f.centroid[1], f.centroid[0], c='red', zorder=2)
                ax[1][0].get_xaxis().set_visible(False)
                ax[1][0].get_yaxis().set_visible(False)
                ax[1][0].set_title('observed_occ_map + frontiers')

                ax[1][1].imshow(U_a, vmin=0.0)
                ax[1][1].get_xaxis().set_visible(False)
                ax[1][1].get_yaxis().set_visible(False)
                ax[1][1].set_title('output: U_a')

                ax[2][0].imshow(U_d[:, :, 0], vmin=0.0)
                ax[2][0].get_xaxis().set_visible(False)
                ax[2][0].get_yaxis().set_visible(False)
                ax[2][0].set_title('output: U_d_0')

                ax[2][1].imshow(U_d[:, :, 1], vmin=0.0)
                ax[2][1].get_xaxis().set_visible(False)
                ax[2][1].get_yaxis().set_visible(False)
                ax[2][1].set_title('output: U_d_1')

                fig.tight_layout()
                plt.show()

            # ==========================crop the image =====================
            # print(f'M_p.shape = {M_p.shape}')
            # print(f'U_a.shape = {U_a.shape}')
            # print(f'U_d.shape = {U_d.shape}')
            # M_p = np.transpose(M_p, (1, 2, 0))
            U_d = np.transpose(U_d, (2, 0, 1))
            tensor_M_p = torch.tensor(M_p).float().unsqueeze(0)
            tensor_U_a = torch.tensor(U_a).float().unsqueeze(0).unsqueeze(1)
            tensor_U_d = torch.tensor(U_d).float().unsqueeze(0)

            if self.split == 'train':
                _, H, W = M_p.shape
                Wby2, Hby2 = W // 2, H // 2
                tform_trans = torch.Tensor(
                    [[agent_map_pose[0] - Wby2, agent_map_pose[1] - Hby2, 0]])
                crop_center = torch.Tensor(
                    [[W / 2.0, H / 2.0]]) + tform_trans[:, :2]
                '''
				# Crop a large-enough map around agent
				_, N, H, W = tensor_M_p.shape
				crop_center = torch.Tensor([[W / 2.0, H / 2.0]]) + tform_trans[:, :2]
				map_size = int(2 * cfg.PRED.PARTIAL_MAP.OUTPUT_MAP_SIZE / cfg.SEM_MAP.CELL_SIZE)
				tensor_M_p = crop_map(tensor_M_p, crop_center, map_size)
				tensor_U_a = crop_map(tensor_U_a, crop_center, map_size)
				tensor_U_d = crop_map(tensor_U_d, crop_center, map_size)
				# Rotate the map
				rot = random.uniform(-math.pi, math.pi)
				tform_rot = torch.Tensor([[0, 0, rot]])
				tensor_M_p = spatial_transform_map(tensor_M_p, tform_rot, 'nearest')
				tensor_U_a = spatial_transform_map(tensor_U_a, tform_rot, 'nearest')
				tensor_U_d = spatial_transform_map(tensor_U_d, tform_rot, 'nearest')
				'''
                # Crop out the appropriate size of the map
                # _, N, H, W = tensor_M_p.shape
                # map_center = torch.Tensor([[W / 2.0, H / 2.0]])
                map_size = int(
                    cfg.PRED.PARTIAL_MAP.OUTPUT_MAP_SIZE / cfg.SEM_MAP.CELL_SIZE)
                tensor_M_p = crop_map(
                    tensor_M_p, crop_center, map_size, 'nearest')
                tensor_U_a = crop_map(
                    tensor_U_a, crop_center, map_size, 'nearest')
                tensor_U_d = crop_map(
                    tensor_U_d, crop_center, map_size, 'nearest')
            elif self.split == 'val':
                _, H, W = M_p.shape
                Wby2, Hby2 = W // 2, H // 2
                tform_trans = torch.Tensor(
                    [[agent_map_pose[0] - Wby2, agent_map_pose[1] - Hby2, 0]])
                crop_center = torch.Tensor(
                    [[W / 2.0, H / 2.0]]) + tform_trans[:, :2]
                # Crop out the appropriate size of the map
                # _, N, H, W = tensor_M_p.shape
                # map_center = torch.Tensor([[W / 2.0, H / 2.0]])
                map_size = int(
                    cfg.PRED.PARTIAL_MAP.OUTPUT_MAP_SIZE / cfg.SEM_MAP.CELL_SIZE)
                tensor_M_p = crop_map(
                    tensor_M_p, crop_center, map_size, 'nearest')
                tensor_U_a = crop_map(
                    tensor_U_a, crop_center, map_size, 'nearest')
                tensor_U_d = crop_map(
                    tensor_U_d, crop_center, map_size, 'nearest')

            # change back to numpy
            M_p = tensor_M_p.squeeze(0).numpy()
            U_a = tensor_U_a.squeeze(0).squeeze(0).numpy()
            # print(f'tensor_U_d.shape = {tensor_U_d.shape}')
            U_d = tensor_U_d.squeeze(0).numpy().transpose((1, 2, 0))

            # print(f'end M_p.shape = {M_p.shape}')
            # print(f'end U_a.shape = {U_a.shape}')
            # print(f'end U_d.shape = {U_d.shape}')

            # rotate
            if False:
                if self.split == 'train':
                    rot = random.choice((0, 45, 90, 135, 180, -45, -90, -135))
                    for i in range(M_p.shape[0]):
                        M_p[i] = ndimage.rotate(
                            M_p[i], rot, order=1, reshape=False)
                        M_p[i] = np.where(M_p[i] < 1, 0, M_p[i])
                    U_a = ndimage.rotate(U_a, rot, order=1, reshape=False)
                    U_a = np.where(U_a < 1, 0, U_a)
                    for i in range(U_d.shape[2]):
                        U_d[:, :, i] = ndimage.rotate(
                            U_d[:, :, i], rot, order=1, reshape=False)
                        U_d[:, :, i] = np.where(
                            U_d[:, :, i] < 1, 0, U_d[:, :, i])

            # =================================== visualize M_p =========================================
            if cfg.PRED.PARTIAL_MAP.FLAG_VISUALIZE_PRED_LABELS:
                occ_map_Mp = M_p[0]
                sem_map_Mp = M_p[1]
                color_sem_map_Mp = apply_color_to_map(sem_map_Mp)

                fig, ax = plt.subplots(nrows=3, ncols=2, figsize=(20, 30))
                ax[0][0].imshow(occ_map_Mp, cmap='gray')
                ax[0][0].get_xaxis().set_visible(False)
                ax[0][0].get_yaxis().set_visible(False)
                ax[0][0].set_title('input: occupancy_map_Mp')
                ax[0][1].imshow(color_sem_map_Mp)
                ax[0][1].get_xaxis().set_visible(False)
                ax[0][1].get_yaxis().set_visible(False)
                ax[0][1].set_title('input: semantic_map_Mp')

                ax[1][0].imshow(occ_map_Mp, cmap='gray')
                '''
				x_coord_lst = [path[i][1] for i in range(i_loc+1)]
				z_coord_lst = [path[i][0] for i in range(i_loc+1)]
				ax[1][0].plot(x_coord_lst, z_coord_lst, lw=5, c='blue', zorder=3)
				for f in frontiers:
					ax[1][0].scatter(f.points[1], f.points[0], c='yellow', zorder=2)
					ax[1][0].scatter(f.centroid[1], f.centroid[0], c='red', zorder=2)
				'''
                ax[1][0].get_xaxis().set_visible(False)
                ax[1][0].get_yaxis().set_visible(False)
                ax[1][0].set_title('observed_occ_map + frontiers')

                ax[1][1].imshow(U_a, vmin=0.0)
                ax[1][1].get_xaxis().set_visible(False)
                ax[1][1].get_yaxis().set_visible(False)
                ax[1][1].set_title('output: U_a')

                ax[2][0].imshow(U_d[:, :, 0], vmin=0.0)
                ax[2][0].get_xaxis().set_visible(False)
                ax[2][0].get_yaxis().set_visible(False)
                ax[2][0].set_title('output: U_d_0')

                ax[2][1].imshow(U_d[:, :, 1], vmin=0.0)
                ax[2][1].get_xaxis().set_visible(False)
                ax[2][1].get_yaxis().set_visible(False)
                ax[2][1].set_title('output: U_d_1')

                fig.tight_layout()
                plt.show()

            # =========================== save data =========================
            eps_data = {}
            eps_data['Mp'] = M_p.copy()
            eps_data['Ua'] = U_a.copy()
            eps_data['Ud'] = U_d.copy()

            sample_name = str(count_sample).zfill(len(str(num_samples)))
            # np.save(f'{self.scene_folder}/{sample_name}.npy', eps_data)
            # with open(f'{self.scene_folder}/{sample_name}.pkl', 'wb') as pk_file:
            # pickle.dump(obj=frontiers, file=pk_file)
            with bz2.BZ2File(f'{self.scene_folder}/{sample_name}.pbz2', 'w') as fp:
                cPickle.dump(
                    eps_data,
                    fp
                )

            # ===================================================================
            count_sample += 1

            if count_sample == num_samples:
                return

            # t3 = timer()
            # print(f't3 - t2 = {t3 - t2}')


def multi_run_wrapper(args):
    """ wrapper for multiprocessor """
    gen = Data_Gen_MP3D(args[0], args[1], saved_dir=args[2])
    gen.write_to_file(
        num_samples=cfg.PRED.PARTIAL_MAP.NUM_GENERATED_SAMPLES_PER_SCENE)


if __name__ == "__main__":
    cfg.merge_from_file('configs/exp_train_input_partial_map_occ_and_sem.yaml')
    cfg.freeze()

    SEED = cfg.GENERAL.RANDOM_SEED
    random.seed(SEED)
    np.random.seed(SEED)

    split = cfg.MAIN.SPLIT
    if split == 'train':
        scene_list = cfg.MAIN.TRAIN_SCENE_LIST
    elif split == 'val':
        scene_list = cfg.MAIN.VAL_SCENE_LIST
    elif split == 'test':
        scene_list = cfg.MAIN.TEST_SCENE_LIST

    output_folder = cfg.PRED.PARTIAL_MAP.GEN_SAMPLES_SAVED_FOLDER
    if not os.path.exists(output_folder):
        os.mkdir(output_folder)

    split_folder = f'{output_folder}/{split}'
    if not os.path.exists(split_folder):
        os.mkdir(split_folder)

    if cfg.PRED.PARTIAL_MAP.multiprocessing == 'single':  # single process
        for scene in scene_list:
            gen = Data_Gen_MP3D(split, scene, saved_dir=split_folder)
            gen.write_to_file(
                num_samples=cfg.PRED.PARTIAL_MAP.NUM_GENERATED_SAMPLES_PER_SCENE)
    elif cfg.PRED.PARTIAL_MAP.multiprocessing == 'mp':
        with multiprocessing.Pool(processes=cfg.PRED.PARTIAL_MAP.NUM_PROCESS) as pool:
            args0 = [split for _ in range(len(scene_list))]
            args1 = [scene for scene in scene_list]
            args2 = [split_folder for _ in range(len(scene_list))]
            pool.map(multi_run_wrapper, list(zip(args0, args1, args2)))
            pool.close()
    elif cfg.PRED.PARTIAL_MAP.multiprocessing == 'mpi4y':
        from mpi4py.futures import MPIPoolExecutor
        args0 = [split for _ in range(len(scene_list))]
        args1 = [scene for scene in scene_list]
        args2 = [split_folder for _ in range(len(scene_list))]
        executor = MPIPoolExecutor()
        prime_sets = executor.map(
            multi_run_wrapper, list(zip(args0, args1, args2)))
        executor.shutdown()
