import matplotlib.pyplot as plt
import numpy as np
import os
import wandb
import pickle

# Outer outline of the map as a list of (x, y) vertices.
# plot_map() draws it. calc_are() computes its area with polygon_area(), which
# wraps from the last vertex back to the first, so the list does not repeat its
# starting point (first and last entries differ).
# NOTE(review): calc_are() scales each axis by 1/100 to report m^2, which suggests
# these coordinates are in centimetres; confirm against the data source.
map_boundary_2d = [
    (10374, -7377), (7265, -7329), (5901, -6723), (4813, -4819), (4017, -4027), (2609, -3632), 
    (865, -3326), (-518, -3012), (-1440, -1870), (-1534, 1298), (-874, 2747), (971, 3326), 
    (2501, 3631), (3490, 3738), (4607, 4485), (5355, 6204), (6471, 7106), (10037, 7377), 
    (9936, 6781), (8838, 6416), (8737, 4530), (8722, 2114), (10093, 2072), (10079, -2071), 
    (8737, -2130), (8737, -4529), (8751, -4895), (8952, -6729), (9910, -6778), (10145, -7192)
]

# Per-floor outlines: floor name -> list of (x, y) vertices, plotted on the same
# axes as map_boundary_2d.
# plot_map() draws each polygon closed by re-appending its first vertex.
# calc_are() sums their areas.
# NOTE: the variable name is misspelled ("bounary"). It is kept as-is because
# plot_map() and calc_are() refer to it by this name.
# NOTE(review): there is no 'Floor_3' entry; confirm this is intentional.
floor_bounary_2d = {
    'Floor_1': [(8751, 4895), (8968, 6582), (9912, 6778), (10374, 7377), (7340, 7347), (6234, 6988), (5335, 6176), 
                (4851, 4919), (4520, 4405), (5238, 3599), (5794, 4026), (7575, 4836)],
    'Floor_2': [(8751, -4895), (9092, -6722), (9939, -6778), (10374, -7377), (7456, -7369), (6179, -6960), (5358, -6208),
                (4851, -4916), (4462, -4361), (5190, -3551), (5857, -4055), (7374, -4756)], 
    'Floor_4': [(4191, 2179), (4780, 3189), (3972, 3993), (3283, 3677), (2501, 3607), (2464, 2172), (3884, 1984)], 
    'Floor_5': [(1046, -3326), (865, -3326), (-518, -3012), (-1440, -1870), (-1534, 1298), (-874, 2747), (971, 3326), 
                (1192, 2505), (1460, 1362), (3064, 1362), (3158, 954), (4877, 972), (4880, -991), (3104, -939), 
                (3107, -1362), (1178, -1254), (1350, -2025), (1330, -2516)],
    'Floor_6': [(2501, -2021), (2504, -3632), (3207, -3659), (3971, -3993), (4822, -3231), (4415, -2779), (4187, -2179), (3884, -1935)], 
    'Floor_7': [(5105, -2900), (4682, -2802), (4377, -2242), (4350, -1687), (5333, -1635), (5356, -455), (4919, -411), (4919, 410), 
                (5321, 422), (5333, 1637), (4350, 1686), (4445, 2466), (4837, 2885), (5126, 2909), (5906, 3727), (7713, 4497), 
                (8737, 4530), (8732, 2095), (10093, 2071), (10093, -2071), (8791, -2071), (8737, -4530), (7765, -4514), 
                (5766, -3643)],
    'Floor_8': [(821, -317), (821, 773), (1097, 1228), (1690, 1489), (3394, 1519), (3390, 949), (3959, 949), (4004, 1519), 
                (4897, 1386), (4946, -1393), (4009, -1519), (4014, -949), (3454, -949), (3440, -1519), (1308, -1325), (821, -666)], 
}

def plot_map():
    """Draw the outer map outline and every floor polygon on a new figure.

    A new 10x10-inch figure is created and becomes the current pyplot figure.
    Both the map outline and each floor outline are drawn as closed polygons:
    the first vertex is repeated at the end, since neither data list does that
    itself. The y-axis is inverted and equal axis scaling is used. The figure
    is neither shown nor closed, so callers can draw on top of it.

    Returns:
        matplotlib.figure.Figure: the newly created figure. Callers that do not
        display it should ``plt.close()`` it to avoid accumulating open figures.
    """
    fig = plt.figure(figsize=(10, 10))

    # map_boundary_2d's last vertex differs from its first, so close it explicitly
    # (same treatment as the floor polygons below).
    xs, ys = zip(*map_boundary_2d)
    plt.plot(xs + xs[:1], ys + ys[:1], label='Map Boundary')
    plt.gca().invert_yaxis()

    for floor_name, floor_boundary in floor_bounary_2d.items():
        xs, ys = zip(*floor_boundary)
        plt.plot(xs + xs[:1], ys + ys[:1], label=floor_name)

    plt.axis('equal')
    plt.legend(loc='upper right')
    plt.xlabel('X Coordinate')
    plt.ylabel('Y Coordinate')
    plt.title('Map and Floor Boundaries')
    plt.grid(True)
    return fig


def visualize(data, outputs, run_name, data_split, epoch, iteration, use_wandb=True):
    """Plot ground-truth vs. predicted map poses over the map and save them as PNGs.

    Only entries whose ``location_types`` value is not 1 are plotted. Positions are
    de-normalized with ``position_std`` / ``position_mean`` from
    ``../map-pretrain-data/data-split/dataset_stats.pkl``. Samples are drawn
    ``N`` (8) per image on top of ``plot_map()``: ground truth as circles,
    predictions as squares, each with an arrow along its heading vector. Images
    are written as ``position_{k}.png`` with ``k`` counting up from 0, into
    ``plot/{run_name}/epoch-{epoch}/iter-{iteration}/{data_split}``. Writing stops
    after image ``k == max_image_num`` (4), i.e. at most 5 images.

    Args:
        data: dict whose ``'location_types'``, ``'map_positions'`` and
            ``'map_yaws'`` values support ``.cpu().numpy()``. Presumably these are
            torch tensors; TODO confirm. ``'map_yaws'`` is treated as degrees.
        outputs: dict whose ``'map_positions_pred'`` value supports
            ``.detach().cpu().numpy()``. Channels ``0:3`` are read as the position
            and ``3:5`` as the heading vector.
        run_name, data_split, epoch, iteration: used to build the output path
            (``data_split`` also appears in the wandb key).
        use_wandb: if True and at least one image was written, log the images via
            ``wandb.log(..., commit=False)``.
    """
    plot_path = f'plot/{run_name}/epoch-{epoch}/iter-{iteration}/{data_split}'
    os.makedirs(plot_path, exist_ok=True)

    location_types = data['location_types'].cpu().numpy()

    # Ground truth and predictions. The prediction array packs the position
    # (channels 0:3) and the heading vector (channels 3:5) together.
    map_position3d_inputs = data['map_positions'].cpu().numpy()
    map_positions_pred = outputs['map_positions_pred'].detach().cpu().numpy()
    map_position3d_pred = map_positions_pred[..., :3]
    map_yaw_pred = map_positions_pred[..., 3:5]

    # Yaw in degrees -> unit (cos, sin) heading vector, comparable with the prediction.
    map_yaw_inputs = data['map_yaws'].cpu().numpy() * np.pi / 180
    map_yaw_inputs = np.stack([np.cos(map_yaw_inputs), np.sin(map_yaw_inputs)], axis=-1)

    # Only visualize the entries whose location type is not 1.
    unknown_positions = (location_types != 1)
    positions_truth = map_position3d_inputs[unknown_positions]
    positions_pred = map_position3d_pred[unknown_positions]
    directions_truth = map_yaw_inputs[unknown_positions]
    directions_pred = map_yaw_pred[unknown_positions]

    # De-normalization statistics (a local, trusted file; never unpickle untrusted data).
    with open('../map-pretrain-data/data-split/dataset_stats.pkl', 'rb') as file:
        normalization_data = pickle.load(file)
    position_mean = normalization_data['position_mean']
    position_std = normalization_data['position_std']

    color_cycle = plt.rcParams['axes.prop_cycle'].by_key()['color']
    N = 8  # samples per image
    max_image_num = 4  # index of the last image written (so at most 5 images)

    # One image per consecutive chunk of N samples. Indexing by chunk (not by
    # (i + 1) // N - 1) keeps a trailing partial chunk from overwriting the
    # previous image or being saved as position_-1.png.
    image_paths = []
    for index, start in enumerate(range(0, len(positions_truth), N)):
        stop = start + N
        image_path = f'{plot_path}/position_{index}.png'
        _plot_position_chunk(
            positions_truth[start:stop], positions_pred[start:stop],
            directions_truth[start:stop], directions_pred[start:stop],
            position_mean, position_std, color_cycle, image_path,
        )
        image_paths.append(image_path)
        if index == max_image_num:
            break

    if use_wandb and image_paths:
        image_logs = {
            f'position_prediction ({data_split})': [wandb.Image(path) for path in image_paths]
        }
        wandb.log(image_logs, commit=False)


def _plot_position_chunk(positions_truth, positions_pred, directions_truth, directions_pred,
                         position_mean, position_std, color_cycle, image_path):
    """Draw one chunk of truth/prediction pose pairs over the map and save it to ``image_path``."""
    plot_map()
    fig = plt.gcf()  # plot_map() created a fresh figure; always close it so figures don't leak
    try:
        for k, (position_truth, position_pred, direction_truth, direction_pred) in enumerate(zip(
            positions_truth, positions_pred, directions_truth, directions_pred
        )):
            color = color_cycle[k % len(color_cycle)]
            position_truth = position_truth * position_std + position_mean
            position_pred = position_pred * position_std + position_mean
            _plot_pose(position_truth, direction_truth, 'o', color, f'truth-{k}')
            _plot_pose(position_pred, direction_pred, 's', color, f'pred-{k}')
        plt.legend()
        # NOTE(review): plot_map() already inverted the y-axis. This second call
        # flips it back, so saved images use matplotlib's default orientation.
        # Kept as in the original; confirm intent.
        plt.gca().invert_yaxis()
        plt.savefig(image_path, dpi=160)
    finally:
        plt.close(fig)


def _plot_pose(position, direction, marker, color, label):
    """Plot one pose as a marker at ``position`` plus an arrow along ``direction`` scaled by 400."""
    plt.plot(position[0], position[1], marker, color=color, markersize=4, alpha=1, label=label)
    plt.arrow(position[0], position[1], direction[0] * 400, direction[1] * 400,
              color=color, head_width=240, head_length=240, width=0.001, alpha=0.6)

def polygon_area(coords):
    """Return the unsigned area of a simple polygon using the shoelace formula.

    ``coords`` is an indexable sequence of ``(x, y)`` vertices. The closing edge
    from the last vertex back to the first is included implicitly, so the polygon
    need not repeat its starting point. Vertex order (clockwise or counter-
    clockwise) does not matter. The result is in squared input units.
    """
    vertex_count = len(coords)

    def edge_cross(k):
        # z-component of the cross product of vertex k with its (wrapping) successor.
        (xa, ya), (xb, yb) = coords[k], coords[(k + 1) % vertex_count]
        return (xa * yb) - (xb * ya)

    twice_signed_area = sum(edge_cross(k) for k in range(vertex_count))
    return abs(twice_signed_area) / 2

def calc_are(conversion_factor=1 / 100):
    """Print and return the map area and the floor areas in square metres.

    Areas are computed with ``polygon_area`` on ``map_boundary_2d`` and on every
    polygon in ``floor_bounary_2d``, then scaled by ``conversion_factor ** 2``.
    The printed line reports the map area and the sum of all floor areas.

    Args:
        conversion_factor: metres per coordinate unit. The default 1/100 assumes
            the coordinates are centimetres; confirm against the data source.

    Returns:
        tuple: ``(map_area_m2, floor_areas_m2)``, where ``floor_areas_m2`` maps
        each floor name to its area in m^2. Callers that ignored the former
        ``None`` return are unaffected.
    """
    # Linear scale factor squared converts squared coordinate units to m^2.
    area_conversion_factor = conversion_factor ** 2

    map_area_m2 = polygon_area(map_boundary_2d) * area_conversion_factor
    floor_areas_m2 = {
        floor: polygon_area(coords) * area_conversion_factor
        for floor, coords in floor_bounary_2d.items()
    }
    sum_floor = sum(floor_areas_m2.values(), 0.0)

    print(f"Map area: {map_area_m2:.2f} m^2; All floor area: {sum_floor:.2f} m^2")
    return map_area_m2, floor_areas_m2

if __name__ == "__main__":
    def _show_map_preview():
        """Draw the map and floor outlines and open them in a pyplot window."""
        plot_map()
        plt.show()

    _show_map_preview()


