import numpy as np
import torch
import torch.nn.functional as F

import os
try:
    from llava import conversation as conversation_lib
except:
    import os
    import sys
    sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
    from llava import conversation as conversation_lib
import glob
import json
import random
from torch.utils.data import DataLoader
# Special prompt tokens. DEFAULT_POINT_TOKEN prefixes every question template
# below; the patch/start/end tokens are not referenced by the dataset code in
# this module and are presumably consumed by tokenizer/model setup elsewhere
# (TODO confirm).
DEFAULT_POINT_TOKEN = "<point>"
DEFAULT_POINT_PATCH_TOKEN = "<pt_patch>"
DEFAULT_PT_START_TOKEN = "<pt_start>"
DEFAULT_PT_END_TOKEN = "<pt_end>"

# Question templates: the point token, a newline, then the question sentence
# substituted for "{sent}". ReasonSegDataset only uses the first entry.
LONG_QUESTION_LIST = [
    DEFAULT_POINT_TOKEN + "\n" + "{sent}",
]

# Answer templates: the raw answer text is substituted for "{sent}" unchanged.
ANSWER_LIST = ["{sent}"]

def get_info_from_json(json_path):
    """Load the question and answer strings from an annotation JSON file.

    The file is decoded as UTF-8 (a leading UTF-8 BOM is tolerated). If the
    bytes are not valid UTF-8, it is re-read as cp1252. Only decoding failures
    trigger the fallback: JSON syntax errors and missing files propagate.

    Args:
        json_path: Path to a JSON object with "question" and "answer" keys.

    Returns:
        Tuple ``(question, answer)`` taken from those two keys.

    Raises:
        FileNotFoundError: If ``json_path`` does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If "question" or "answer" is missing.
    """
    try:
        # Explicit encoding: the locale default (e.g. cp1252 on Windows) would
        # silently mis-decode UTF-8 files instead of raising.
        with open(json_path, "r", encoding="utf-8-sig") as r:
            anno = json.load(r)
    except UnicodeDecodeError:
        with open(json_path, "r", encoding="cp1252") as r:
            anno = json.load(r)
    return anno["question"], anno["answer"]

def pc_normalize(pc):
    """Center a point cloud on its mean and scale it so the farthest point has norm 1.

    Args:
        pc: 2-D array whose rows are points (axis 0) and whose columns are
            coordinates (axis 1).

    Returns:
        A new array: ``pc`` minus its per-column mean, divided by the largest
        Euclidean row norm of the centered points. If all points coincide
        (largest norm is 0), the centered array (all zeros) is returned
        instead of dividing by zero and producing NaNs.
    """
    centroid = np.mean(pc, axis=0, keepdims=True)
    centered = pc - centroid
    radius = np.max(np.sqrt(np.sum(centered ** 2, axis=1)))
    if radius == 0:
        # Degenerate cloud (e.g. a single point): nothing to scale.
        return centered
    return centered / radius

def read_text_file(file_path):
    """Return the lines of a UTF-8 text file with surrounding whitespace stripped.

    Best-effort: failures are reported with ``print`` instead of raised.
    A missing file yields an empty list; any other error yields the lines
    collected before the error occurred.
    """
    collected = []
    try:
        with open(file_path, 'r', encoding='utf-8') as handle:
            # readline() returns '' only at end of file, ending the iteration.
            for raw in iter(handle.readline, ''):
                collected.append(raw.strip())
    except FileNotFoundError:
        print(f"The file {file_path} was not found.")
    except Exception as exc:
        print(f"An error occurred: {exc}")
    return collected

class ReasonSegDataset(torch.utils.data.Dataset):
    """Point-cloud reasoning-segmentation dataset backed by files on disk.

    Each sample is a string id. For id ``X`` the dataset reads:

    * ``<json_questions_path>/<predict_type>/X.json``: question/answer text
      and, for ``predict_type == "seg"``, a "point_cloud" mapping whose values
      are the part names to segment.
    * ``<point_path>/X.txt``: whitespace-separated point rows (see
      ``extract_point_file``).

    The id list is built from the file stems (basename up to the first '.')
    listed in ``<json_path>/json_<run_type>_<predict_type>.txt``. It is
    restricted to ids also listed in
    ``<json_path>/point_<run_type>_<predict_type>.txt`` and, when
    ``idx_file`` is not None, to ids appearing as lines of ``idx_file``.
    The order of the first list is preserved.
    """

    def __init__(self,
                 tokenizer=None,
                 samples_per_epoch=None,
                 run_type="train",
                 if_infer=False,
                 json_path="./playground/data/urdf/train_test_txt",
                 json_questions_path="./playground/data/urdf/json_questions",
                 point_path="./playground/data/urdf/split_same_obj2",
                 idx_file="./playground/data/urdf/all_5.txt",
                 predict_type="seg"
                 ):
        self.samples_per_epoch = samples_per_epoch
        self.tokenizer = tokenizer
        self.long_question_list = LONG_QUESTION_LIST
        self.answer_list = ANSWER_LIST
        self.json_path = json_path
        self.json_questions_path = json_questions_path
        self.point_path = point_path
        self.predict_type = predict_type
        self.if_infer = if_infer

        # Part-name vocabulary; get_seg_label maps a name to a label column
        # by its position in this list.
        self.seg_label_list = ["alarm_ring", "ball", "board", "bottle_body", "box_body",
            "bucket_body", "button", "camera_body", "cap", "cart_body",
            "caster", "chair_leg", "circle", "clock_body", "coffee_machine_body",
            "connector", "container", "cover_lid", "dishwasher_body", "dispenser_body",
            "display_base", "door", "door_frame", "drawer", "fan_frame",
            "fastener", "fastener_connector", "faucet_base", "foot_pad", "furniture_body",
            "glasses_body", "globe_frame", "hand", "handle", "head",
            "kettle_body", "key", "keyboard_base", "knife_body", "knob",
            "lamp_base", "laptop_base", "leg", "lens", "lever",
            "lid", "lighter_body", "lock", "microwave_body", "mouse_body",
            "nose", "oven_body", "pen_body", "phone_base", "portafilter",
            "pot_body", "pressing_lid", "printer_body", "pump_lid", "refrigerator_body",
            "remote_base", "rotation_bar", "rotation_blade", "rotation_body", "rotation_button",
            "rotation_container", "rotation_door", "rotation_handle", "rotation_lid", "rotation_screen",
            "rotation_slider", "rotation_tray", "rotation_window", "rotor", "safe_body",
            "screen", "seat", "shelf", "slider", "slot",
            "sphere", "spout", "stapler_base", "stapler_body", "steering_wheel",
            "stem", "suitcase_body", "switch", "switch_frame", "tilt_leg",
            "toaster_body", "toggle_button", "toilet_body", "translation_bar", "translation_blade",
            "translation_door", "translation_handle", "translation_lid", "translation_screen", "translation_tray",
            "translation_window", "trashcan_body", "usb_body", "usb_rotation", "washing_machine_body",
            "wheel", "window_frame"
        ]

        list_json_file_path = os.path.join(self.json_path, f'json_{run_type}_{self.predict_type}.txt')
        list_txt_file_path = os.path.join(self.json_path, f'point_{run_type}_{self.predict_type}.txt')

        # Keep the JSON list's order; filter against sets so the intersection
        # is linear rather than quadratic in the number of ids.
        self.idxes = self._read_sample_ids(list_json_file_path)
        point_ids = set(self._read_sample_ids(list_txt_file_path))
        self.idxes = [idx for idx in self.idxes if idx in point_ids]

        if idx_file is not None:
            with open(idx_file, 'r') as f:
                allowed_ids = {x.strip() for x in f.read().split('\n')}
            self.idxes = [idx for idx in self.idxes if idx in allowed_ids]

        self.num_samples = len(self.idxes)
        self.run_type = run_type

    @staticmethod
    def _read_sample_ids(list_path):
        """Return the file stem (basename up to the first '.') of each non-blank line of list_path."""
        with open(list_path, 'r') as f:
            return [line.strip().split('/')[-1].split('.')[0] for line in f if line.strip()]

    def __len__(self):
        """Return the number of sample ids that survived filtering."""
        return self.num_samples

    def __getitem__(self, idx):
        """Load sample ``idx`` and return it as a dict.

        Keys: "points" (normalized with pc_normalize) and "rgb" as float
        tensors; "conversations" (one human/gpt turn pair per question);
        "questions"; "seg_types" (part names); "seg_label" (float tensor, one
        row per label column of the point file); "logist_label" (0..K-1 for
        the K seg_types); "json_path".

        Raises:
            ValueError: If the number of "[SEG]" tokens in the answer differs
                from the number of label rows in the point file.
        """
        idx_str = self.idxes[idx]
        json_path = f"{self.json_questions_path}/{self.predict_type}/{idx_str}.json"
        txt_path = f"{self.point_path}/{idx_str}.txt"

        with open(json_path, 'r') as f:
            data = json.load(f)

        if self.predict_type != "seg":
            seg_types = ["slot"]
        else:
            # Part names are the values of the "point_cloud" object, in key order.
            seg_types = list(data["point_cloud"].values())

        points, colors, seg_label = self.extract_point_file(txt_path)
        points = pc_normalize(points)
        seg_idx = self.get_seg_idx(seg_types)

        # Reuse the already-parsed JSON for the question/answer text.
        sent, answer_raw = data["question"], data["answer"]

        questions = [self.long_question_list[0].format(sent=sent)]
        answers = [self.answer_list[0].format(sent=answer_raw)]

        all_conversations_structured = [
            [{"from": "human", "value": question}, {"from": "gpt", "value": answer}]
            for question, answer in zip(questions, answers)
        ]

        # Each "[SEG]" token in the answer must correspond to one label row.
        num_seg_tokens = answers[0].count("[SEG]")
        if seg_label.shape[0] != num_seg_tokens:
            raise ValueError(
                f"seg_label has {seg_label.shape[0]} rows but the answer has "
                f"{num_seg_tokens} [SEG] tokens:\n{json_path}\n{txt_path}"
            )

        return {
            "points": torch.from_numpy(points).float(),
            "rgb": torch.from_numpy(colors).float(),
            "conversations": all_conversations_structured,
            "questions": questions,
            "seg_types": seg_types,
            "seg_label": torch.from_numpy(seg_label).float(),
            # Already a fresh LongTensor; torch.tensor(tensor) would copy and warn.
            "logist_label": seg_idx,
            "json_path": json_path,
        }

    def get_seg_idx(self, seg_type):
        """Return a LongTensor ``[0, 1, ..., len(seg_type) - 1]``, one index per part name."""
        return torch.arange(len(seg_type))

    def get_seg_label(self, seg_type, label):
        """Select label columns by part name.

        Args:
            seg_type: Part names, each of which must be in ``self.seg_label_list``
                (``list.index`` raises ValueError otherwise).
            label: 2-D array indexed as ``label[:, column]``. Presumably
                (num_points, len(seg_label_list)); TODO confirm against caller.

        Returns:
            ``(labels, indexs)``: a tensor stacking the selected columns as rows,
            and a tensor of the column indices used.
        """
        labels = []
        indexs = []

        for name in seg_type:
            index = self.seg_label_list.index(name)
            indexs.append(index)
            labels.append(label[:, index])

        return torch.tensor(np.array(labels)), torch.tensor(np.array(indexs))

    def extract_point_file(self, path):
        """Parse a whitespace-separated point file.

        Each non-blank line is split on whitespace. The first two fields are
        skipped and the remaining fields are parsed as floats:

        * columns 0-2 become the point coordinates;
        * columns 3-5 become the colors;
        * any further columns are per-point label values.

        Returns:
            ``(points, colors, seg_label)``. ``points`` and ``colors`` are
            (N, 3) float arrays. ``seg_label`` is transposed to
            (num_label_columns, N), so each row holds one label column for
            all points.
        """
        coordinates = []
        with open(path, 'r') as f:
            for line in f:
                fields = line.split()
                if not fields:
                    # Tolerate blank (e.g. trailing) lines instead of producing ragged rows.
                    continue
                coordinates.append([float(x) for x in fields[2:]])
        data_array = np.array(coordinates)
        points_coordinates = data_array[:, 0:3]
        colors = data_array[:, 3:6]
        seg_label = data_array[:, 6:]

        return points_coordinates, colors, seg_label.T

if __name__ == "__main__":
    # Smoke test: load every training sample for each prediction type and
    # check that the answer carries one "[SEG]" token per label row.
    predict_types = [
        "seg",
    ]
    for predict_type in predict_types:
        train_dataset = ReasonSegDataset(
            predict_type=predict_type,
            run_type="train",
            if_infer=True,
        )
        for idx in range(len(train_dataset)):
            example = train_dataset[idx]
            # __getitem__ returns a dict: the answer text is the "gpt" turn of
            # the first conversation (integer indexing raised KeyError).
            text = example["conversations"][0][1]["value"]
            label = example["seg_label"]
            json_path = example["json_path"]
            assert len(text.split('[SEG]')) - 1 == label.shape[0], json_path