import matplotlib.pyplot as plt 
import h5py
import numpy as np 
import torch 
import imageio.v2 as iio
import imageio 
import pickle 
import tqdm 
import json 

from matplotlib.collections import LineCollection
import matplotlib.patches as patches
import matplotlib.transforms as transforms

# Directory where every figure (PDF/PNG) and CubeTouch_*.json summary is written.
# Placeholder path -- point it at a real, existing directory before running.
RESULTS_DIR = "/yourfolderhere/final_figure_generation/"
import cv2 


def blend_images(img1, img2):
    """Alpha-composite ``img1`` over ``img2`` (Porter-Duff "over").

    Both inputs are 4-channel 8-bit images whose last channel is alpha. The
    three colour channels are treated identically, so BGRA vs RGBA order does
    not matter.

    Args:
        img1: Foreground image, shape (H, W, 4), values in [0, 255].
        img2: Background image with the same shape as ``img1``.

    Returns:
        The blended (H, W, 4) ``uint8`` image. Pixels where both inputs are
        fully transparent come out as all zeros.
    """
    fg = np.asarray(img1, dtype=np.float64) / 255.0
    bg = np.asarray(img2, dtype=np.float64) / 255.0
    fg_rgb, fg_a = fg[..., :3], fg[..., 3:4]
    bg_rgb, bg_a = bg[..., :3], bg[..., 3:4]

    # Combined coverage of the two layers.
    out_a = fg_a + bg_a * (1.0 - fg_a)

    # Premultiplied colour sum, un-premultiplied by the combined alpha.
    # Divide exactly (only guarding alpha == 0) instead of adding an epsilon
    # to the divisor; the epsilon turned an opaque 255 into 254.
    premult = fg_rgb * fg_a + bg_rgb * bg_a * (1.0 - fg_a)
    out_rgb = np.divide(premult, out_a, out=np.zeros_like(premult), where=out_a > 0)

    blended = np.concatenate([out_rgb, out_a], axis=-1) * 255.0
    # Round rather than truncate: callers apply this repeatedly to accumulate
    # many frames, and truncation would darken the result on every call.
    return np.clip(np.rint(blended), 0, 255).astype(np.uint8)

def composite_on_white(img):
    """Flatten a 4-channel image onto an opaque white background.

    Args:
        img: (H, W, 4) 8-bit image; channels 0-2 are colour and channel 3 is
            alpha. The colour channels are treated identically, so BGRA vs
            RGBA order is preserved in the output.

    Returns:
        (H, W, 3) ``uint8`` image in the same colour-channel order.

    Raises:
        ValueError: if ``img`` is not a 3-D array with an alpha channel.
    """
    img = np.asarray(img)
    # Check ndim first so 2-D (grayscale) input gets this error instead of an
    # IndexError from ``img.shape[2]``.
    if img.ndim != 3 or img.shape[2] < 4:
        raise ValueError("Image must have an alpha channel (RGBA)")

    normalized = img.astype(np.float64) / 255.0
    colour = normalized[..., :3]
    alpha = normalized[..., 3:4]

    # "over" compositing with a white (1.0) background: a * c + (1 - a) * 1.
    out = alpha * colour + (1.0 - alpha)

    # Round instead of truncating so partially transparent pixels are not
    # biased one level darker.
    return np.clip(np.rint(out * 255.0), 0, 255).astype(np.uint8)

def random_rotate(rect, x, y, width, height, ax):
    """Rotate ``rect`` about its centre by a random angle in [0, 90) degrees.

    The angle comes from NumPy's global RNG (one draw per call). The rotation
    is composed with ``ax.transData`` and installed as the patch's transform,
    so ``rect`` is modified in place. It is also returned, so the call can be
    passed straight to ``ax.add_patch``.

    Args:
        rect: Patch to rotate (callers pass a ``Rectangle``).
        x, y: Lower-left corner of the rectangle in data coordinates.
        width, height: Rectangle size in data units.
        ax: Axes whose data transform the rotation is composed with.

    Returns:
        ``rect`` itself.
    """
    degrees = np.random.random() * 90
    pivot_x, pivot_y = x + width / 2, y + height / 2
    rotation = transforms.Affine2D().rotate_deg_around(pivot_x, pivot_y, degrees)
    rect.set_transform(rotation + ax.transData)
    return rect

def _draw_cubes(ax, states, cube_size, alpha):
    """Add the four cubes described by ``states`` to ``ax`` as rotated squares.

    ``states[0:8]`` is read as four (x, y) pairs for the Blue, Red, Green and
    Yellow cubes, in that order. Each pair is the lower-left corner of a
    ``cube_size`` square, which ``random_rotate`` then turns by a random
    angle. This draws one value from NumPy's global RNG per cube, in that
    order.
    """
    facecolors = ("blue", "red", "green", "#AAAA00")
    for i, facecolor in enumerate(facecolors):
        cube_x, cube_y = states[2 * i], states[2 * i + 1]
        rect = patches.Rectangle((cube_x, cube_y), cube_size, cube_size, facecolor=facecolor, alpha=alpha)
        # Each patch is added exactly once, so all four colours share one opacity.
        ax.add_patch(random_rotate(rect, cube_x, cube_y, cube_size, cube_size, ax))


def generate_trajectory_traces_and_rerender(name, hdf5_name, representative_example=False):
    """Plot every demo's agent trajectory over re-drawn cube patches and save it.

    Instead of reusing recorded camera frames, the cube positions are
    re-drawn as vector patches so the figure does not look pixelated. Demos
    with fewer than 20 ``agent_pos`` steps are skipped.

    Args:
        name: Output file stem; writes ``{RESULTS_DIR}/{name}.pdf`` and
            ``{RESULTS_DIR}/{name}.png``.
        hdf5_name: Path to an HDF5 file with a ``data`` group of demos, each
            holding ``obs/agent_pos`` and ``obs/states``.
        representative_example: If True, draw the per-demo cubes very faintly
            and overlay the mean cube layout across demos at high opacity.

    Raises:
        ValueError: if ``representative_example`` is set but no demo has at
            least 20 steps, so no mean layout can be computed.
    """
    fig, ax = plt.subplots()
    try:
        fig.set_size_inches(5, 5, forward=True)

        cube_size = 0.15
        # Per-demo cubes are drawn faintly; fainter still when the mean layout
        # is overlaid on top.
        cube_alpha = 0.05 if representative_example else 0.2
        cmap = plt.get_cmap("plasma")  # purple -> yellow along each trajectory
        state_average = []

        with h5py.File(hdf5_name, "r") as dataset:
            data_grp = dataset["data"]
            for demo in tqdm.tqdm(data_grp.keys()):
                obs = data_grp[demo]["obs"]
                if obs["agent_pos"].shape[0] < 20:
                    continue  # skip very short demos
                positions = obs["agent_pos"][:, -1]
                # First entry along axis 0 (presumably the first timestep).
                states = obs["states"][:, -1][0]
                state_average.append(states)

                _draw_cubes(ax, states, cube_size, cube_alpha)

                # Start-position marker, drawn before the trace is shifted below.
                ax.add_patch(patches.Circle((positions[0][0], positions[0][1]), 0.05, facecolor="black", alpha=0.2))

                # Gradient-coloured trace: one segment per consecutive pair of points.
                points = positions[:, np.newaxis, :]
                # NOTE: ``points`` is a view, so this also shifts ``positions``;
                # the end marker below therefore sits at the end of the shifted trace.
                points[:, :, 1] += 0.1
                segments = np.concatenate([points[:-1], points[1:]], axis=1)
                norm = plt.Normalize(0, positions.shape[0])
                lc = LineCollection(segments, cmap=cmap, norm=norm, linewidth=2, alpha=0.7, zorder=2)
                lc.set_array(np.linspace(0, positions.shape[0], positions.shape[0] - 1))
                ax.add_collection(lc)
                # End-position marker.
                ax.scatter(positions[-1, 0:1], positions[-1, 1:], color="black", s=5, zorder=3, alpha=0.7)

        if representative_example:
            if not state_average:
                raise ValueError(
                    f"No demo in {hdf5_name} has at least 20 steps; cannot compute a mean cube layout"
                )
            # Mean cube layout across demos, drawn at high opacity as a visual reference.
            mean_states = np.mean(np.stack(state_average, axis=0), axis=0)
            _draw_cubes(ax, mean_states, cube_size, 0.8)

        ax.set_xlim(-1, 1)
        ax.set_ylim(-1, 1)
        for spine in ax.spines.values():
            spine.set_visible(False)
        ax.set_xticks([])
        ax.set_yticks([])

        for ext in ("pdf", "png"):
            fig.savefig(f"{RESULTS_DIR}/{name}.{ext}", dpi=300, transparent=True)
    finally:
        # Always release the figure, so a failing seed does not leak it.
        plt.close(fig)



def generate_trajectory_traces(name, hdf5_name):
    """Plot every demo's agent trajectory over a blend of their first camera frames.

    The first frame of each demo has its bright pixels made transparent, and
    all frames are accumulated with ``blend_images``. The result is flattened
    onto white and drawn beneath the trajectories. This only makes sense when
    the scene layout does not vary between demos.

    Args:
        name: Output file stem; writes ``{RESULTS_DIR}/{name}.pdf``.
        hdf5_name: Path to an HDF5 file with a ``data`` group of demos, each
            holding ``obs/agent_pos`` and ``obs/image``.

    Raises:
        ValueError: if the file contains no demos.
    """
    # SUGGESTIONS:
    #   Transparency/color gradient based on time
    #   A little marker at the end
    #   Lines thicker [done]
    #   Reduce transparency of the cubes [done]
    #   Highlight one trajectory
    fig, ax = plt.subplots()
    try:
        cmap = plt.get_cmap("plasma")  # purple -> yellow along each trajectory
        average_image = None

        with h5py.File(hdf5_name, "r") as dataset:
            data_grp = dataset["data"]
            for demo in tqdm.tqdm(data_grp.keys()):
                obs = data_grp[demo]["obs"]
                positions = obs["agent_pos"][:, -1]
                start_image = obs["image"][0, -1]

                # Pixels whose grayscale value exceeds 180 become fully
                # transparent; everything else gets a faint alpha of 50.
                start_image_alpha = cv2.cvtColor(start_image, cv2.COLOR_BGR2BGRA)
                gray = cv2.cvtColor(start_image, cv2.COLOR_BGR2GRAY)
                mask = cv2.threshold(gray, 180, 255, cv2.THRESH_BINARY)[1]
                start_image_alpha[:, :, 3][mask == 255] = 0
                start_image_alpha[:, :, 3][mask != 255] = 50

                if average_image is None:
                    average_image = start_image_alpha
                else:
                    average_image = blend_images(average_image, start_image_alpha)

                # Gradient-coloured trace: one segment per consecutive pair of points.
                points = positions[:, np.newaxis, :]
                segments = np.concatenate([points[:-1], points[1:]], axis=1)
                norm = plt.Normalize(0, positions.shape[0])
                lc = LineCollection(segments, cmap=cmap, norm=norm, linewidth=2, alpha=0.7, zorder=2)
                lc.set_array(np.linspace(0, positions.shape[0], positions.shape[0] - 1))
                ax.add_collection(lc)
                # End-position marker.
                ax.scatter(positions[-1, 0:1], positions[-1, 1:], color="black", s=5, zorder=3, alpha=0.7)

        if average_image is None:
            raise ValueError(f"No demos found in {hdf5_name}")

        blended_image = composite_on_white(average_image)
        # NOTE(review): the frame is handled as BGR by OpenCV above, but imshow
        # interprets channels as RGB -- confirm the stored image order.
        # The flip puts the frame's first row at y = -1, presumably to match
        # the agent_pos coordinate frame.
        ax.imshow(np.flipud(blended_image), extent=[-1, 1, -1, 1], zorder=1)
        fig.savefig(f"{RESULTS_DIR}/{name}.pdf", dpi=300, transparent=True)
    finally:
        # Always release the figure, even if reading or plotting fails.
        plt.close(fig)


def plot_cubes_color(name, hdf5_name, target):
    """Summarise which cube each successful demo ended closest to.

    For every demo whose final reward is >= 0.5, find the cube closest to
    the agent's final position. Then save two outputs:

    - a bar chart of per-cube touch frequency (target cubes green, others
      black) to ``{RESULTS_DIR}/CubeTouch_{name}.png``;
    - a JSON summary to ``{RESULTS_DIR}/CubeTouch_{name}.json``.

    Args:
        name: Label used in the output filenames and as the chart title.
        hdf5_name: Path to an HDF5 file with a ``data`` group of demos, each
            holding ``obs/agent_pos``, ``obs/states`` (reshaped to 4 cubes x
            2 coordinates) and ``rewards``.
        target: Length-4 0/1 sequence or array marking the target cube(s),
            ordered Blue, Red, Green, Yellow.

    Returns:
        Dict mapping cube name to the fraction of successful demos that ended
        closest to that cube (all zeros if there were no successes).
    """
    cube_names = ["Blue", "Red", "Green", "Yellow"]
    target = np.asarray(target)
    touch_counts = {k: 0 for k in cube_names}
    successes = 0
    target_count = 0
    near_target = 0

    with h5py.File(hdf5_name, "r") as dataset:
        data_grp = dataset["data"]
        total = len(data_grp.keys())
        for demo in tqdm.tqdm(data_grp.keys()):
            demo_grp = data_grp[demo]
            positions = demo_grp["obs"]["agent_pos"][:, -1]
            cubes_pos = demo_grp["obs"]["states"][:, -1]
            cubes_pos = np.reshape(cubes_pos, (cubes_pos.shape[0], 4, 2))

            if demo_grp["rewards"][-1] < 0.5:  # don't include failed demos
                print("Rejected fail!")
                continue

            successes += 1
            last_position = positions[-1]
            last_cube_position = cubes_pos[-1]
            distances = np.linalg.norm(last_cube_position - last_position, axis=1)
            closest_cube = int(np.argmin(distances))
            touch_counts[cube_names[closest_cube]] += 1

            if target[closest_cube] == 1:
                target_count += 1
            # Partial success: ended within 0.35 of any target cube.
            if any(target[i] == 1 and distances[i] < 0.35 for i in range(4)):
                near_target += 1

    # Guard every ratio so a file with no (successful) demos does not crash.
    cube_touch_distr = {k: (v / successes if successes else 0.0) for k, v in touch_counts.items()}
    target_rate = target_count / successes if successes else 0.0
    near_target_rate = near_target / successes if successes else 0.0
    success_rate = successes / total if total else 0.0

    colors = ["black" if v == 0 else "green" for v in target]
    fig, ax = plt.subplots()
    try:
        ax.bar(list(cube_touch_distr.keys()), list(cube_touch_distr.values()), color=colors)
        ax.set_title(name)
        fig.savefig(f"{RESULTS_DIR}/CubeTouch_{name}.png")
    finally:
        plt.close(fig)

    with open(f"{RESULTS_DIR}/CubeTouch_{name}.json", "w") as f:
        json.dump({"hdf5_name": hdf5_name, "target": target.tolist(), "target rate": target_rate,
                   "near target rate": near_target_rate, "success rate": success_rate,
                   "distr": cube_touch_distr}, f, indent=2)
    return cube_touch_distr


# CHERRY-PICKED VISUAL RESULTS
# Experiment to render; uncomment one of the alternatives to switch.
name = "EarlyDecision_Control"
# name = "LateDecision_Control"
# name = "LateDecision_0_5"
# name = "EarlyDecision_0_5"
# name = "Avoid_Wall_Control"
# name = "Avoid_Wall_0_5"
seeds = 6  # one HDF5 file per seed: {name}_{seed}/{name}_{seed}.hdf5
for seed in range(seeds):
    hdf5_name = f"/yourfolderhere/FINAL_EXPERIMENTS/Pymunk_Qualitative/{name}_{seed}/{name}_{seed}.hdf5"
    try:
        # Pass representative_example=False for the plain per-demo overlay.
        generate_trajectory_traces_and_rerender(f"{name}_{seed}", hdf5_name, representative_example=True)
    except Exception as err:
        # Keep going with the remaining seeds, but say why this one failed.
        # (Exception, not a bare except, so Ctrl-C still stops the script.)
        print("error with ", hdf5_name, f"({type(err).__name__}: {err})")