import os
import pickle
from tqdm import tqdm

# Per-frame particle counts, one line per frame, as whitespace-separated integers:
#   traj_id frame_id water r g y
info = "/home/htxue/data/mit/visual_dynamics/pixel-nerf/datasets/dynamics_new/shake_extra_new/box_info.txt"
# Per-trajectory simulator metadata; box colours are recovered from its 'scene_params'.
scene_info_template = '/home/htxue/datasets/data_FluidShakeExtra_new/{}/info.p'

water_list = []
r_list = []
g_list = []
y_list = []

# A visible box with fewer particles than this marks its trajectory as invalid.
bar = 27

invalid_traj = []

# Three-flag colour code (as read from scene_params) -> box colour name.
color_dir = {
    '010': 'green',
    '100': 'red',
    '110': 'yellow'
}


def assign_color(name, v_list):
    """Set the flag for colour *name* in ``v_list`` ordered as [red, green, yellow].

    Names other than 'red', 'green' and 'yellow' leave ``v_list`` unchanged.
    """
    if name == 'red':
        v_list[0] = 1
    elif name == 'green':
        v_list[1] = 1
    elif name == 'yellow':
        v_list[2] = 1


def _color_code(scene_info, start):
    """Join the three flags scene_info[start], [start+1], [start+2] into a string like '010'."""
    return ''.join(str(int(scene_info[k])) for k in range(start, start + 3))


def _visible_boxes(traj_id):
    """Return [red, green, yellow] 0/1 flags for the boxes present in trajectory *traj_id*.

    Raises KeyError if a colour code read from scene_params is not in ``color_dir``.
    """
    # NOTE: pickle.load can execute arbitrary code; only run this on trusted local data.
    with open(scene_info_template.format(traj_id), 'rb') as scene_file:
        scene_info = pickle.load(scene_file)['scene_params']

    v_list = [0, 0, 0]
    # Length 41 -> all three boxes; 31 -> two colour codes (at -5 and -15);
    # 21 -> one colour code (at -5). Presumably each box adds a fixed-size
    # parameter chunk — TODO confirm against the simulator.
    if len(scene_info) == 41:
        v_list[0] = v_list[1] = v_list[2] = 1
    elif len(scene_info) == 31:
        assign_color(color_dir[_color_code(scene_info, -5)], v_list)
        assign_color(color_dir[_color_code(scene_info, -15)], v_list)
    elif len(scene_info) == 21:
        assign_color(color_dir[_color_code(scene_info, -5)], v_list)
    # NOTE(review): any other length leaves every flag at 0, so no box count is
    # recorded for that trajectory — confirm this is intended.
    return v_list


with open(info) as f:
    for lines in tqdm(f):
        line_info = [int(tok) for tok in lines.split()]
        if not line_info:
            continue  # tolerate blank lines, e.g. a trailing newline
        traj_id, frame_id, water, r, g, y = line_info

        # Only the first frame of each trajectory is analysed. Testing != 0 (not > 0)
        # guarantees the visibility flags below always belong to this trajectory.
        if frame_id != 0:
            continue

        v_red, v_green, v_yellow = _visible_boxes(traj_id)

        water_list.append(water)
        # Record every visible box's count so each histogram sees all trajectories,
        # then flag the trajectory (once) if any visible box is below the threshold.
        visible_counts = []
        if v_red:
            r_list.append(r)
            visible_counts.append(r)
        if v_green:
            g_list.append(g)
            visible_counts.append(g)
        if v_yellow:
            y_list.append(y)
            visible_counts.append(y)
        if any(count < bar for count in visible_counts):
            invalid_traj.append(traj_id)

# Dump the raw particle counts collected for every analysed trajectory,
# one list per particle type (water, red, green, yellow).
print(water_list)
print(r_list)
print(g_list)
print(y_list)
# Trajectory ids in which some visible box had fewer than `bar` particles.
print(invalid_traj)

import matplotlib.pyplot as plt

# One 50-bin histogram per particle type, each written to its own PNG.
for values, label, filename in (
    (water_list, 'water', './water_hist.png'),
    (r_list, 'red', './red_hist.png'),
    (g_list, 'green', './green_hist.png'),
    (y_list, 'yellow', './yellow_hist.png'),
):
    plt.hist(values, bins=50)
    plt.title('{}_particle distribution'.format(label))
    plt.savefig(filename)
    plt.close()

# Overlay of the three box colours. All three share one bin range so their
# bars line up; independent auto-ranging would give each colour different
# bin edges and make the overlay misleading.
box_counts = r_list + g_list + y_list
box_range = (min(box_counts), max(box_counts)) if box_counts else None
for values, color in ((r_list, 'red'), (g_list, 'green'), (y_list, 'yellow')):
    plt.hist(values, bins=50, range=box_range, alpha=0.5, color=color, label=color)
plt.title('box_particle distribution')
plt.legend()
plt.savefig('./box_hist.png')
plt.close()