import os
import copy
import json
import torch
import pandas as pd
import numpy as np
from tqdm import tqdm
from torch_geometric.data import InMemoryDataset, Data
from torch_geometric.data.separate import separate
    
# Values written into each node's ``movable_mask`` entry (a bool tensor, so
# FIXED is stored as False and MOVABLE as True).
FIXED, MOVABLE = 0, 1

class GRNDataset(InMemoryDataset):
    """Scene-graph dataset: one ``torch_geometric`` ``Data`` graph per scene.

    Nodes are the objects of ``<root>/scenes/<scene_id>.json``. Node features
    are the object ``dimensions`` followed by ``abs_pose[:3]`` and
    ``abs_pose[-1]`` (7 values; column 6 is the angle that :meth:`scale` wraps
    to ``[0, 2*pi)``). Every movable object receives an incoming edge from its
    ``frame_id`` object and from each object accepted by :meth:`is_neighbor`.
    Per-object labels come from ``<root>/data/processed_gnn_data.json``.

    In ``"train"`` mode ``args.augmentations`` selects quarter-turn rotation
    augmentation: ``"dimswitch"`` adds the 3 rotations of the movable objects;
    ``"dimswitch_all"`` adds every combination of movable-object and
    fixed-object rotations (15 extra graphs per scene).

    Note: the processed file is cached under a name containing
    ``args.augmentations``; PyG only re-runs :meth:`process` when it is missing.
    """

    # Grasp types, in the column order of every per-grasp label tensor.
    _GRASPS = ("Top", "Front", "Rear", "Right", "Left")

    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None, mode=None, args=None):
        # mode/args must exist before super().__init__, which reads
        # processed_file_names and may call process().
        self.args = args
        self.mode = mode
        super().__init__(root, transform, pre_transform, pre_filter)
        print(self.processed_paths)
        # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True,
        # which may refuse the pickled Data object saved by process() - confirm.
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        """Names of the entries in ``<root>/scenes``."""
        return os.listdir(os.path.join(self.root, "scenes"))

    @property
    def processed_file_names(self):
        """Cache file name; train-mode caches are keyed by ``args.augmentations``."""
        if self.mode == "train":
            return ["grouped_" + self.args.augmentations + "_data.pt"]
        return ["grouped_data.pt"]

    def process(self):
        """Build one graph per labelled scene (plus train-time augmentations) and save them."""
        gnn_dataset = pd.read_json(os.path.join(self.root, "data", "processed_gnn_data.json"))
        # One row per scene; every column becomes the list of its per-object values.
        dataset = gnn_dataset.groupby("scene_id").agg(list)
        data_list = []
        for scene_name in tqdm(gnn_dataset.scene_id.unique()):
            datapoint = dataset.loc[scene_name]
            with open(os.path.join(self.root, "scenes", scene_name + ".json"), encoding="utf-8") as f:
                scene = json.load(f)

            data = self._build_graph(scene_name, scene, datapoint)
            if self.pre_filter is not None and not self.pre_filter(data):
                continue
            if self.pre_transform is not None:
                data = self.pre_transform(data)

            data_list.append(data)
            if self.mode == "train":
                data_list.extend(self._augment(data))

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

    def _build_graph(self, scene_name, scene, datapoint):
        """Return the ``Data`` graph of one scene.

        ``scene`` is the parsed scene JSON. ``datapoint`` is that scene's row of
        the grouped label table: each field is a list with one entry per
        labelled object, aligned with ``datapoint.object_id``.
        """
        # Position of each labelled (movable) object inside the label lists.
        movable_indices = {obj: k for k, obj in enumerate(datapoint.object_id)}
        objects = list(scene["objects"].keys())
        indices = {obj: k for k, obj in enumerate(objects)}
        n_objects = len(objects)

        nodes = torch.zeros((n_objects, 7))
        pos = torch.zeros((n_objects, 4))
        movable_mask = torch.zeros(n_objects, dtype=torch.bool)
        frame_ids = torch.zeros(n_objects, dtype=torch.long)
        IK_labels = torch.zeros((n_objects, 5))
        F_labels = torch.zeros((n_objects, 6))  # fixed objects keep an all-zero row
        edge_list = []  # [source, target] node-index pairs
        GO_rows = []    # one 5-vector per edge, aligned with edge_list

        for i, obj in enumerate(objects):
            object_ = scene["objects"][obj]
            frame_id = object_["frame_id"]
            fixed = object_["fixed"]
            pose = object_["abs_pose"][:3] + [object_["abs_pose"][-1]]

            # ---------------- Nodes ----------------
            movable_mask[i] = FIXED if fixed else MOVABLE
            nodes[i] = torch.tensor(object_["dimensions"] + pose)
            pos[i] = torch.tensor(pose)
            frame_ids[i] = -1 if frame_id in ("world", "odom_combined") else indices[frame_id]

            if fixed:
                IK_labels[i] = -1
                continue

            # ---------------- Labels (movable objects only) ----------------
            row = movable_indices[obj]
            IK_labels[i] = torch.tensor([datapoint[grasp + "_IK"][row] for grasp in self._GRASPS])
            F_labels[i] = torch.tensor(
                [datapoint["feasibility"][row]] + [datapoint[grasp][row] for grasp in self._GRASPS]
            )

            # ---------------- Edges into this movable object ----------------
            # The support frame comes first, then every nearby object other than
            # the object itself, its frame and "base" (base is only linked as a frame).
            sources = [frame_id]
            for neighbor in objects:
                if neighbor in (obj, frame_id, "base"):
                    continue
                other = scene["objects"][neighbor]
                if self.is_neighbor(object_["dimensions"], other["dimensions"],
                                    object_["abs_pose"], other["abs_pose"]):
                    sources.append(neighbor)
            for source in sources:
                edge_list.append([indices[source], i])
                GO_rows.append(self._go_features(datapoint, row, source))

        # Assemble edge tensors once (growing them with torch.cat per edge is quadratic).
        n_edges = len(edge_list)
        edges = torch.tensor(edge_list, dtype=torch.long).reshape(n_edges, 2).t().contiguous()
        edge_features = torch.tensor([1.0, 0.0], dtype=torch.float32).repeat(n_edges, 1)
        proximity_mask = torch.ones((n_edges, 1), dtype=torch.bool)
        GO_labels = torch.stack(GO_rows) if GO_rows else torch.empty((0, 5))
        base_mask = torch.zeros(n_objects, dtype=torch.bool)
        base_mask[indices["base"]] = True

        return Data(x=nodes, movable_mask=movable_mask, frame_ids=frame_ids,
                    edge_index=edges, proximity_mask=proximity_mask, edge_attr=edge_features,
                    IK_labels=IK_labels, F_labels=F_labels, GO_labels=GO_labels, pos=pos,
                    base_mask=base_mask,
                    scene=torch.tensor(int(scene_name.replace("scene_", ""))))

    def _go_features(self, datapoint, row, other):
        """Return the 5 per-grasp GO values of labelled object ``row`` w.r.t. object ``other``.

        For each grasp type, the value is the ``<grasp>_scores`` entry paired
        with ``other`` in ``<grasp>_GO``, divided by ``Nb_<grasp>_grasps``; it
        stays 0 when ``other`` is not listed.
        """
        features = torch.zeros(5)
        for g, grasp in enumerate(self._GRASPS):
            go_objects = datapoint[grasp + "_GO"][row]
            if other in go_objects:
                score = datapoint[grasp + "_scores"][row][go_objects.index(other)]
                features[g] = score / datapoint["Nb_" + grasp + "_grasps"][row]
        return features

    def _augment(self, data):
        """Return the rotated copies of ``data`` requested by ``args.augmentations``.

        ``"dimswitch"`` yields the 3 quarter-turn rotations of the movable
        objects. ``"dimswitch_all"`` also yields the 3 rotations of the fixed
        objects alone, plus the 3 fixed-object rotations of every movable
        rotation. ``data`` itself is never modified.
        """
        augmentations = self.args.augmentations
        if "dimswitch" not in augmentations:
            return []
        with_fixed = "dimswitch_all" in augmentations
        augmented = []
        if with_fixed:
            fixed_rot = data
            for _ in range(3):
                fixed_rot = self.dimswitch_fixed(fixed_rot)
                augmented.append(fixed_rot)
        movable_rot = data
        for _ in range(3):
            movable_rot = self.dimswitch(movable_rot)
            augmented.append(movable_rot)
            if with_fixed:
                both_rot = movable_rot
                for _ in range(3):
                    both_rot = self.dimswitch_fixed(both_rot)
                    augmented.append(both_rot)
        return augmented

    def compute_distance(self, pose1, pose2):
        """Euclidean distance between two coordinate sequences."""
        return np.linalg.norm(np.array(pose1) - np.array(pose2))

    def compute_threshold(self, dim1, dim2):
        """Half of (largest extent of dim1 + largest extent of dim2 + 0.6 slack)."""
        return (max(dim1) + max(dim2) + 0.6) / 2

    def is_neighbor(self, dim1, dim2, pose1, pose2):
        """True unless the objects' first-two-coordinate distance exceeds the threshold.

        Only the first two pose coordinates and first two dimensions are used
        (presumably the horizontal plane - confirm against the scene format).
        """
        distance = self.compute_distance(pose1[:2], pose2[:2])
        threshold = self.compute_threshold(dim1[:2], dim2[:2])
        # Written as "not >" (rather than "<=") to keep the original NaN handling.
        return not distance > threshold

    def get(self, idx):
        """Return graph ``idx``, enriched with ground-truth features when ``args.IK_GO_mode == "gt"``.

        In "gt" mode:
        - 5 GO columns are appended to each stored edge's [1, 0] type features.
          They are ``GO_labels`` if "GO" is in ``args.edge_features``, else zeros.
        - A self-loop of type [0, 1] is added on every movable node. It carries
          5 IK columns: ``1 - IK_labels`` if "IK" is requested, else zeros.
          It also gets a False ``proximity_mask`` entry and zero ``GO_labels``.
        - If neither "IK" nor "GO" is requested, only the 2 type columns are kept.

        The angle column is finally wrapped by :meth:`scale`.
        """
        data = separate(
            cls=self._data.__class__,
            batch=self._data,
            idx=idx,
            slice_dict=self.slices,
            decrement=False,
        )

        if self.args.IK_GO_mode == "gt":
            requested = self.args.edge_features
            n_edges = data.edge_attr.shape[0]
            go_columns = data.GO_labels if "GO" in requested else torch.zeros(n_edges, 5)
            data.edge_attr = torch.cat((data.edge_attr, go_columns), dim=1)

            loop_nodes = torch.where(data.movable_mask == MOVABLE)[0]
            n_loops = loop_nodes.shape[0]
            data.edge_index = torch.cat((data.edge_index, loop_nodes.repeat(2, 1)), dim=1)
            # Explicit (n_loops, 2) shape so graphs without movable nodes still concatenate.
            loop_type = torch.tensor([0.0, 1.0]).repeat(n_loops, 1)
            if "IK" in requested:
                ik_columns = torch.ones(n_loops, 5) - data.IK_labels[loop_nodes]
            else:
                ik_columns = torch.zeros(n_loops, 5)
            data.edge_attr = torch.cat((data.edge_attr, torch.cat((loop_type, ik_columns), dim=1)), dim=0)
            # Match the stored mask's trailing shape (process() stores it as (E, 1));
            # a 1-D mask here cannot be concatenated with it.
            loop_mask = torch.zeros((n_loops,) + tuple(data.proximity_mask.shape[1:]),
                                    dtype=data.proximity_mask.dtype)
            data.proximity_mask = torch.cat((data.proximity_mask, loop_mask), dim=0)
            data.GO_labels = torch.cat((data.GO_labels, torch.zeros(n_loops, 5)), dim=0)

            if "IK" not in requested and "GO" not in requested:
                data.edge_attr = data.edge_attr[:, :2]

        return self.scale(data)

    def scale(self, data):
        """Wrap the angle node feature (column 6) into ``[0, 2*pi)`` in place and return ``data``."""
        data.x[:, 6] = (data.x[:, 6] % (2*np.pi))
        return data

    def dimswitch(self, data):
        """Return a copy of ``data`` with every movable object turned a quarter turn.

        For movable nodes, the first two node-feature columns are swapped and
        ``pi/2`` is added to the angle column (wrapped to ``[0, 2*pi)``).

        The per-grasp columns of ``F_labels``, ``IK_labels`` and ``GO_labels``
        are remapped so that old Front, Rear, Right and Left move to Right,
        Left, Rear and Front respectively. ``data`` is not modified.
        """
        movable_indices = data.movable_mask == MOVABLE
        switched = copy.deepcopy(data)
        switched.x[movable_indices, 0], switched.x[movable_indices, 1] = data.x[movable_indices, 1], data.x[movable_indices, 0]
        switched.x[movable_indices, 6] = (data.x[movable_indices, 6] + np.pi/2) % (2*np.pi)
        # (destination, source) F_labels columns; IK/GO columns are shifted by one
        # because F_labels has a leading feasibility column.
        comb = [(2, 5), (3, 4), (4, 2), (5, 3)]
        for i, j in comb:
            switched.F_labels[:, i] = data.F_labels[:, j]
            switched.IK_labels[:, i-1] = data.IK_labels[:, j-1]
            switched.GO_labels[:, i-1] = data.GO_labels[:, j-1]
        return switched

    def dimswitch_fixed(self, data):
        """Return a copy of ``data`` with every fixed non-base object turned a quarter turn.

        The first two node-feature columns are swapped and ``pi/2`` is added to
        the angle column; labels are left untouched. ``data`` is not modified.
        """
        fixed_indices = torch.logical_and(data.movable_mask == FIXED, ~data.base_mask)
        switched = copy.deepcopy(data)
        switched.x[fixed_indices, 0], switched.x[fixed_indices, 1] = data.x[fixed_indices, 1], data.x[fixed_indices, 0]
        switched.x[fixed_indices, 6] = (data.x[fixed_indices, 6] + np.pi/2) % (2*np.pi)
        return switched

   
class GODataset(torch.utils.data.Dataset):
    """In-memory dataset of pre-built ``(input, label, mask)`` tensor triples.

    Loads ``inputs.pt``, ``labels.pt`` and ``masks.pt`` from ``<path>/data``
    onto ``args.device``. Labels and masks are cast to float.

    Inputs presumably hold 7 columns per object for two objects; ``scale``
    wraps columns 6 and 13 as angles. Confirm this against the data builder.

    In ``"train"`` mode the data is rotation-augmented when ``"dimswitch"`` is
    in ``args.augmentations``, then shuffled once.
    """

    def __init__(self, path, mode, args):
        self.mode = mode
        self.args = args
        data_dir = os.path.join(path, "data")
        self.inputs = torch.load(os.path.join(data_dir, "inputs.pt")).to(args.device)
        self.labels = torch.load(os.path.join(data_dir, "labels.pt")).to(args.device).float()
        self.masks = torch.load(os.path.join(data_dir, "masks.pt")).to(args.device).float()
        if self.mode == "train" and "dimswitch" in self.args.augmentations:
            self.inputs, self.labels, self.masks = self.dimswitch_all()

        self.scaled_inputs = self.scale()
        self.scaled_labels = self.labels
        self.scaled_masks = self.masks
        if self.mode == "train":
            # One permutation applied to inputs, labels AND masks: __getitem__
            # pairs them by index, so all three must stay aligned.
            indices = torch.randperm(self.scaled_inputs.shape[0])
            self.scaled_inputs = self.scaled_inputs[indices]
            self.scaled_labels = self.scaled_labels[indices]
            self.scaled_masks = self.scaled_masks[indices]

    def __len__(self):
        return self.scaled_inputs.shape[0]

    def __getitem__(self, index):
        """Return the aligned ``(input, label, mask)`` of sample ``index``."""
        return self.scaled_inputs[index], self.scaled_labels[index], self.scaled_masks[index]

    def scale(self):
        """Return a copy of the inputs with columns 6 and 13 wrapped into ``[0, 2*pi)``."""
        scaled_inputs = self.inputs.clone()
        scaled_inputs[:, 6] = self.inputs[:, 6] % (2*np.pi)
        scaled_inputs[:, 13] = self.inputs[:, 13] % (2*np.pi)
        return scaled_inputs

    def dimswitch(self, inputs, labels, masks, obj_idx):
        """Turn object ``obj_idx`` (1-based) a quarter turn, modifying the tensors in place.

        The object's first two input columns (``7*(obj_idx-1)`` and ``+1``) are
        swapped and ``pi/2`` is added to its angle column (``+6``), wrapped to
        ``[0, 2*pi)``.

        For object 1, label and mask columns 1-4 are also remapped: old 1->3,
        2->4, 3->2 and 4->1. This is the same permutation ``IKDataset.dimswitch``
        applies to its Front/Rear/Right/Left columns.

        Returns the (mutated) ``inputs, labels, masks``.
        """
        base = 7 * (obj_idx - 1)
        # Advanced indexing on the right-hand side copies, so the swap is safe in place.
        inputs[:, [base, base + 1]] = inputs[:, [base + 1, base]]
        inputs[:, base + 6] = (inputs[:, base + 6] + np.pi/2) % (2*np.pi)
        if obj_idx == 1:
            labels[:, [1, 2, 3, 4]] = labels[:, [4, 3, 1, 2]]
            masks[:, [1, 2, 3, 4]] = masks[:, [4, 3, 1, 2]]
        return inputs, labels, masks

    def dimswitch_all(self):
        """Return inputs/labels/masks followed by their 3 successive quarter-turn rotations of object 1."""
        augmented_inputs = [self.inputs]
        augmented_labels = [self.labels]
        augmented_masks = [self.masks]
        switched_inputs, switched_labels, switched_masks = self.inputs.clone(), self.labels.clone(), self.masks.clone()
        for _ in range(3):
            switched_inputs, switched_labels, switched_masks = self.dimswitch(switched_inputs, switched_labels, switched_masks, 1)
            # Snapshot: the next dimswitch call mutates these tensors in place.
            augmented_inputs.append(switched_inputs.clone())
            augmented_labels.append(switched_labels.clone())
            augmented_masks.append(switched_masks.clone())
        return torch.cat(augmented_inputs), torch.cat(augmented_labels), torch.cat(augmented_masks)
    
    
class IKDataset(torch.utils.data.Dataset):
    """Per-object IK dataset built from ``<path>/data/processed_gnn_data.json``.

    Each input row has 7 values: the 3 entries of the ``dim`` column, then the
    4 entries of the ``pose`` column. Column 6 is treated as an angle and
    wrapped to ``[0, 2*pi)``.

    Labels are the columns ``Top_IK, Front_IK, Rear_IK, Right_IK, Left_IK``.

    In ``"train"`` mode the data is rotation-augmented when ``"dimswitch"`` is
    in ``args.augmentations``, then shuffled once.
    """

    def __init__(self, path, mode, args):
        self.mode = mode
        self.args = args
        self.data = pd.read_json(os.path.join(path, "data", "processed_gnn_data.json"))
        inputs = torch.zeros((len(self.data), 7))
        inputs[:, :3] = torch.tensor(self.data["dim"].values.tolist())
        inputs[:, 3:] = torch.tensor(self.data["pose"].values.tolist())
        labels = torch.tensor(self.data[["Top_IK", "Front_IK", "Rear_IK", "Right_IK", "Left_IK"]].values.tolist())
        self.inputs = inputs.to(args.device)
        self.labels = labels.float().to(args.device)
        # Row of self.data each sample comes from; kept in step with the
        # augmentation and the shuffle so that get() returns the matching row.
        row_ids = torch.arange(len(self.data))
        if self.mode == "train" and "dimswitch" in self.args.augmentations:
            switched_inputs = self.inputs.clone()
            switched_labels = self.labels.clone()
            for _ in range(3):
                # dimswitch mutates its arguments in place; torch.cat copies them.
                switched_inputs, switched_labels = self.dimswitch(switched_inputs, switched_labels)
                self.inputs = torch.cat((self.inputs, switched_inputs))
                self.labels = torch.cat((self.labels, switched_labels))
            row_ids = row_ids.repeat(4)

        self.scaled_inputs = self.scale().to(args.device)
        self.scaled_labels = self.labels.to(args.device)
        self._row_ids = row_ids
        if self.mode == "train":
            indices = torch.randperm(self.scaled_inputs.shape[0])
            self.scaled_inputs = self.scaled_inputs[indices]
            self.scaled_labels = self.scaled_labels[indices]
            self._row_ids = self._row_ids[indices]

    def __len__(self):
        return self.scaled_inputs.shape[0]

    def __getitem__(self, index):
        """Return ``(input, label)`` of sample ``index``."""
        return self.scaled_inputs[index], self.scaled_labels[index]

    def get(self, index):
        """Return ``(input, label, datapoint)``, where ``datapoint`` is the source row(s) of ``self.data``."""
        x = self.scaled_inputs[index]
        label = self.scaled_labels[index]
        # Map through _row_ids: after augmentation/shuffling, sample indices no
        # longer coincide with row positions in self.data.
        datapoint = self.data.iloc[self._row_ids[index].tolist()]
        return x, label, datapoint

    def scale(self):
        """Return a copy of the inputs with the angle column (6) wrapped into ``[0, 2*pi)``."""
        scaled_inputs = self.inputs.clone()
        scaled_inputs[:, 6] = self.inputs[:, 6] % (2*np.pi)
        return scaled_inputs

    def dimswitch(self, inputs, labels):
        """Turn every sample a quarter turn, modifying both tensors in place.

        Input columns 0 and 1 are swapped and ``pi/2`` is added to column 6
        (wrapped to ``[0, 2*pi)``).

        Label columns Front(1), Rear(2), Right(3) and Left(4) move to Right,
        Left, Rear and Front respectively.

        Returns the (mutated) ``inputs, labels``.
        """
        # Advanced indexing on the right-hand side copies, so the swap is safe in place.
        inputs[:, [0, 1]] = inputs[:, [1, 0]]
        inputs[:, 6] = (inputs[:, 6] + np.pi/2) % (2*np.pi)
        labels[:, [1, 2, 3, 4]] = labels[:, [4, 3, 1, 2]]
        return inputs, labels