import numpy as np
import torch
import torch.nn.functional as F

import os

import glob
import json

import pdb

# Directory holding the model's per-sample segmentation outputs
# ("<sample>_seg.txt"), read with load_result in the __main__ block.
# NOTE(review): "seg_reslut" looks like a typo of "seg_result"; confirm it
# matches the on-disk directory before changing it.
RESULT_PATH = "./output/shapellm-7b-general-v1.0-lora/0505_1325seg/logs/epoch_3/pred/test/seg_reslut"
# Directory holding per-sample JSON records ("<sample>.json"); __main__ reads
# their "point_cloud", "pred_answers" and "gt_answers" fields.
JSON_PATH = "./output/shapellm-7b-general-v1.0-lora/0505_1325seg/logs/epoch_3/pred/test/seg"
# Directory holding per-sample point-cloud text files ("<sample>.txt") whose
# label columns are parsed by extract_point_file.
TXT_PATH = "./playground/data/urdf/point_clouds"

# Part-label vocabulary (107 names). Order matters: get_seg_label uses a name's
# position in this list as the column index into the label matrix returned by
# extract_point_file, so entries must not be reordered, inserted or removed
# unless the label data is regenerated to match.
seg_label_list = ["alarm_ring", "ball", "board", "bottle_body", "box_body",
            "bucket_body", "button", "camera_body", "cap", "cart_body",
            "caster", "chair_leg", "circle", "clock_body", "coffee_machine_body",
            "connector", "container", "cover_lid", "dishwasher_body", "dispenser_body",
            "display_base", "door", "door_frame", "drawer", "fan_frame",
            "fastener", "fastener_connector", "faucet_base", "foot_pad", "furniture_body",
            "glasses_body", "globe_frame", "hand", "handle", "head",
            "kettle_body", "key", "keyboard_base", "knife_body", "knob",
            "lamp_base", "laptop_base", "leg", "lens", "lever",
            "lid", "lighter_body", "lock", "microwave_body", "mouse_body",
            "nose", "oven_body", "pen_body", "phone_base", "portafilter",
            "pot_body", "pressing_lid", "printer_body", "pump_lid", "refrigerator_body",
            "remote_base", "rotation_bar", "rotation_blade", "rotation_body", "rotation_button",
            "rotation_container", "rotation_door", "rotation_handle", "rotation_lid", "rotation_screen",
            "rotation_slider", "rotation_tray", "rotation_window", "rotor", "safe_body",
            "screen", "seat", "shelf", "slider", "slot",
            "sphere", "spout", "stapler_base", "stapler_body", "steering_wheel",
            "stem", "suitcase_body", "switch", "switch_frame", "tilt_leg",
            "toaster_body", "toggle_button", "toilet_body", "translation_bar", "translation_blade",
            "translation_door", "translation_handle", "translation_lid", "translation_screen", "translation_tray",
            "translation_window", "trashcan_body", "usb_body", "usb_rotation", "washing_machine_body",
            "wheel", "window_frame"
        ]
def load_result(file_path):
    """Load a whitespace-delimited integer matrix from ``file_path``.

    ``ndmin=2`` keeps the result two-dimensional even when the file holds a
    single row or column, so callers can always transpose and index rows.

    Returns:
        np.ndarray of ints with at least two dimensions.
    """
    matrix = np.loadtxt(file_path, dtype=int, ndmin=2)
    return matrix
def load_json(file_path):
    """Parse the JSON document stored at ``file_path``.

    The file is decoded explicitly as UTF-8 (the encoding RFC 8259 requires for
    JSON) rather than the platform default. The platform default differs on
    some systems (e.g. Windows code pages) and would mangle non-ASCII model
    answers.

    Returns:
        The decoded JSON value (a dict for the per-sample records used here).
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)

def extract_point_file(path):
    """Read the per-point label columns from a point-cloud text file.

    Each non-blank line is split on any whitespace. The first two tokens are
    skipped and the remaining tokens are parsed as floats. Of those numeric
    columns, the first six are dropped and the rest are returned. The original
    author notes a result shape of (2048, 107); TODO confirm for every file.

    Splitting on arbitrary whitespace (instead of a single space) means
    repeated spaces or tabs no longer produce empty tokens that fail
    ``float('')``. Blank lines (e.g. a trailing newline) are skipped instead of
    producing a ragged row.

    Args:
        path: Path to the point-cloud text file.

    Returns:
        float64 np.ndarray with one row per non-blank line and one column per
        numeric field after the sixth.

    Raises:
        ValueError: if the file contains no data rows, a token is not numeric,
            or rows have differing lengths.
    """
    with open(path, 'r', encoding='utf-8') as f:
        rows = [
            [float(token) for token in line.split()[2:]]
            for line in f
            if line.strip()
        ]
    if not rows:
        raise ValueError(f"no point rows found in {path}")
    return np.array(rows, dtype=float)[:, 6:]

def get_seg_label(seg_type, label):
    """Select the columns of ``label`` named by ``seg_type``.

    Each name is looked up in the module-level ``seg_label_list``. Its position
    there is used as the column index into ``label``.

    Args:
        seg_type: Sequence of part names, each expected in ``seg_label_list``.
        label: 2-D array indexed as ``label[:, column]``.

    Returns:
        Tuple ``(masks, columns)``. ``masks`` is a tensor holding one selected
        column of ``label`` per name. ``columns`` is a tensor of the column
        indices that were looked up.

    Raises:
        ValueError: if a name is not present in ``seg_label_list``.
    """
    columns, masks = [], []
    for part_name in seg_type:
        col = seg_label_list.index(part_name)
        masks.append(label[:, col])
        columns.append(col)
    return torch.tensor(np.array(masks)), torch.tensor(np.array(columns))

def calculate_miou(gt_list, pred_list, file_name, pred, gt):
    """Return the mean IoU between paired ground-truth and predicted masks.

    Row ``i`` of ``gt_list`` is compared with row ``i`` of ``pred_list``, and
    the per-row IoUs are averaged with equal weight. Pairing is purely
    positional, so both sequences must list parts in the same order.
    NOTE(review): in this script the gt rows follow the JSON ``point_cloud``
    order, while the predicted rows follow ``sorted()`` link names. Confirm
    those two orders agree.

    Args:
        gt_list: Sequence of ground-truth masks, one per part (numpy arrays,
            torch tensors, or anything ``torch.as_tensor`` accepts).
        pred_list: Sequence of predicted masks, same form as ``gt_list``.
        file_name, pred, gt: Unused; accepted so existing callers keep working.

    Returns:
        Mean IoU as a float. Returns 0.0 when the sequences differ in length
        (the sample is scored as a failed prediction) or are empty. A pair
        whose union is empty (both masks all background) scores 1.0.
    """
    if len(gt_list) != len(pred_list):
        # Wrong number of predicted parts: score the whole sample as 0.
        return 0.0
    if len(gt_list) == 0:
        return 0.0  # avoid division by zero

    total_iou = 0.0
    for gt_i, pred_i in zip(gt_list, pred_list):
        # Values are truncated to int (as before), then any nonzero entry is
        # foreground. Counting with logical ops instead of summing the result
        # of bitwise &/| on raw ints avoids inflated unions for non-0/1 labels
        # (e.g. 2 | 1 == 3).
        gt_mask = torch.as_tensor(gt_i).int() != 0
        pred_mask = torch.as_tensor(pred_i).int() != 0
        intersection = torch.logical_and(gt_mask, pred_mask).sum().item()
        union = torch.logical_or(gt_mask, pred_mask).sum().item()
        total_iou += 1.0 if union == 0 else intersection / union

    return total_iou / len(gt_list)

import re

def find_link_seg_indices(input_string):
    """Attribute each ``[SEG]`` token to the ``link_<n>`` mention before it.

    ``[SEG]`` tokens are numbered 0, 1, 2, ... in order of appearance. A token
    belongs to the closest preceding ``link_<digits>`` match. Tokens that occur
    before the first link are ignored.

    Args:
        input_string: Model answer text.

    Returns:
        List of ``(link_name, seg_index)`` tuples, ordered by link occurrence
        and then by seg index. Empty if the text has no ``[SEG]`` or no link.
    """
    seg_starts = [m.start() for m in re.finditer(r'\[SEG]', input_string)]
    if not seg_starts:
        return []

    links = list(re.finditer(r'link_\d+', input_string))
    # Each link's span ends where the next link begins (or at end of text).
    span_ends = [m.start() for m in links[1:]] + [len(input_string)]

    pairs = []
    for link, span_end in zip(links, span_ends):
        span_start = link.start()
        pairs.extend(
            (link.group(0), k)
            for k, pos in enumerate(seg_starts)
            if span_start < pos < span_end
        )
    return pairs

def reorganize_tensor_by_link(original_string, seg_link_map, input_tensor, threshold=0.9):
    """Merge per-[SEG] score rows into one binary mask per link named in a string.

    Every distinct ``link_<digits>`` substring of ``original_string`` gets one
    output row. Rows of ``input_tensor`` are thresholded to 0/1. Each
    ``(link_id, seg_index)`` pair in ``seg_link_map`` ORs row ``seg_index`` into
    that link's output row. Links that receive no segment stay all-zero.

    Args:
        original_string: Text containing ``link_<n>`` names.
        seg_link_map: Sequence of ``(link_id, seg_index)`` pairs (e.g. from
            ``find_link_seg_indices``). It is iterated twice, so it must not be
            a one-shot generator. Pairs whose link is absent from the string,
            or whose index is negative, are skipped.
        input_tensor: 2-D numpy array or torch tensor, one row per segment.
        threshold: Scores ``>= threshold`` become 1, all others 0.

    Returns:
        ``(output_tensor, links)``. ``output_tensor`` is an int32 tensor of
        shape ``(len(links), input_tensor.shape[1])``. ``links`` is the sorted
        list of distinct link names. The sort is lexicographic, so
        ``"link_10"`` precedes ``"link_2"``. If ``input_tensor`` has no rows,
        an empty ``(0, width)`` tensor and ``[]`` are returned, even when the
        string names links.

    Raises:
        TypeError: if ``input_tensor`` is neither a numpy array nor a tensor.
        IndexError: if any ``seg_index`` in ``seg_link_map`` is >= the number
            of rows of ``input_tensor``.
    """
    if isinstance(input_tensor, np.ndarray):
        input_tensor = torch.from_numpy(input_tensor).float()
    elif not isinstance(input_tensor, torch.Tensor):
        raise TypeError("input_tensor must be a numpy.ndarray or torch.Tensor")

    if input_tensor.shape[0] == 0:
        # Keep the (empty) input's width so this branch agrees with the
        # no-links branch below. Fall back to the previous fixed width of
        # 2048 only for 1-D input, which has no width.
        width = input_tensor.shape[1] if input_tensor.dim() > 1 else 2048
        return torch.empty((0, width), dtype=torch.int32), []

    binary_tensor = (input_tensor >= threshold).int()

    link_pattern = r'link_\d+'
    all_link_matches_in_string = re.finditer(link_pattern, original_string)
    all_unique_links_in_string = sorted(set(match.group(0) for match in all_link_matches_in_string))

    if not all_unique_links_in_string:
        return torch.empty((0, binary_tensor.shape[1]), dtype=torch.int32), []

    link_to_output_index = {link_id: i for i, link_id in enumerate(all_unique_links_in_string)}
    num_links = len(all_unique_links_in_string)

    output_tensor = torch.zeros((num_links, binary_tensor.shape[1]), dtype=torch.int32)

    if seg_link_map:
        max_seg_index = max(seg_index for _, seg_index in seg_link_map)
        if max_seg_index >= binary_tensor.shape[0]:
            raise IndexError(f"Segment index {max_seg_index} found in seg_link_map "
                             f"is out of bounds for input tensor with {binary_tensor.shape[0]} rows.")

    for link_id_from_map, seg_index in seg_link_map:
        if link_id_from_map not in link_to_output_index:
            continue

        # After the upper-bound check above, only negative indices can fail here.
        if seg_index < 0 or seg_index >= binary_tensor.shape[0]:
            continue

        link_output_index = link_to_output_index[link_id_from_map]
        # Element-wise max of 0/1 masks is a logical OR.
        output_tensor[link_output_index, :] = torch.max(output_tensor[link_output_index, :], binary_tensor[seg_index, :])

    return output_tensor, all_unique_links_in_string

def fix_result_data(raw_data, pred_seg):
    """Group raw per-[SEG] masks by the link each [SEG] token follows in ``pred_seg``.

    Args:
        raw_data: 2-D numpy array or torch tensor whose rows are indexed by
            [SEG] position.
        pred_seg: Model answer text containing ``link_<n>`` names and
            ``[SEG]`` tokens.

    Returns:
        ``(masks, links)`` exactly as produced by ``reorganize_tensor_by_link``
        with its default threshold.
    """
    seg_to_link = find_link_seg_indices(pred_seg)
    return reorganize_tensor_by_link(pred_seg, seg_to_link, raw_data)

if __name__ == "__main__":
    # Score every "<sample>_seg.txt" prediction in RESULT_PATH against the
    # labels in TXT_PATH/"<sample>.txt". Part names and model answers come from
    # JSON_PATH/"<sample>.json". Print the mean of the per-file mIoU.
    # The same *.txt list drives both the loop and the denominator, so stray
    # non-.txt files in RESULT_PATH can no longer skew (or break) the average.
    result_paths = sorted(glob.glob(os.path.join(RESULT_PATH, "*.txt")))
    if not result_paths:
        raise SystemExit(f"No *.txt result files found in {RESULT_PATH}")

    sum_miou = 0.0
    for result_path in result_paths:
        file_name = os.path.basename(result_path)

        # Transpose so that row i is the mask reorganize_tensor_by_link looks
        # up for the i-th [SEG] token.
        result_data_raw = load_result(result_path).transpose()

        json_path = os.path.join(JSON_PATH, file_name.replace("_seg.txt", ".json"))
        json_data = load_json(json_path)
        # Part names of every non-"base" entry, in the JSON's key order.
        seg_types = [
            part_name
            for key, part_name in json_data["point_cloud"].items()
            if key != "base"
        ]

        txt_path = os.path.join(TXT_PATH, file_name.replace("_seg.txt", ".txt"))
        seg_all_label = extract_point_file(txt_path)
        seg_label, _ = get_seg_label(seg_types, seg_all_label)

        result_data, _ = fix_result_data(result_data_raw, json_data["pred_answers"])

        sum_miou += calculate_miou(seg_label, result_data, file_name,
                                   json_data["pred_answers"], json_data["gt_answers"])

    mean_miou = sum_miou / len(result_paths)
    print(f"{mean_miou}")