import matplotlib.pyplot as plt 
import argparse 
import torch 
import tqdm 
import numpy as np 

from embedder_models import FinalStatePredictionDino
from image_models import VAE, ResNet18Dec

from embedder_datasets import MultiviewDataset
from torch.utils.data import DataLoader
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

import random 
import cv2 
import imageio
import torchvision 
import os 

# Show tensors in fixed-point notation instead of scientific notation when printed.
torch.set_printoptions(sci_mode=False)

# Observation key of the camera every evaluation below reads from.
# Alternatives used with other datasets: "image", "third_person".
MAIN_CAMERA = "camera0_rgb"

def generate_TSNE_points(embeddings, dims, perplexity=30.0):
    """Project ``embeddings`` into ``dims`` dimensions with t-SNE.

    Args:
        embeddings: 2-D array of shape (n_samples, n_features).
        dims: number of output dimensions (passed as ``n_components``).
        perplexity: requested t-SNE perplexity (30 matches sklearn's default).
            It is clamped to ``n_samples - 1`` because sklearn rejects a
            perplexity that is not smaller than the number of samples. That
            happens, for example, when projecting one end-state per demo for a
            dataset with 30 or fewer demos.

    Returns:
        Array of shape (n_samples, dims) with the projected points.
    """
    n_samples = len(embeddings)
    effective_perplexity = min(float(perplexity), max(n_samples - 1, 1))
    tsne = TSNE(dims, perplexity=effective_perplexity, verbose=1)
    return tsne.fit_transform(embeddings)

def generate_PCA_points(embeddings, dims):
    """Fit a ``dims``-component PCA to ``embeddings`` and project them onto it.

    The explained-variance ratio of each component is printed.

    Returns:
        A tuple ``(projected_points, fitted_pca)``.
    """
    reducer = PCA(n_components=dims).fit(embeddings)  # fit() returns the estimator
    print(reducer.explained_variance_ratio_)
    projected = reducer.transform(embeddings)
    return projected, reducer

def prepare(data, device = "cuda"):
    """Move a tensor, or every value of a plain ``dict`` of tensors, to ``device`` as float32.

    Only an exact ``dict`` is treated as a mapping; anything else must support ``.to``.
    """
    def _convert(tensor):
        return tensor.to(device).to(torch.float32)

    if type(data) is not dict:
        return _convert(data)
    return {key: _convert(value) for key, value in data.items()}

def prepare_np(data, device = "cuda"):
    """Build a float32 tensor on ``device`` from array-like data, or from each value of a plain ``dict``.

    ``torch.tensor`` copies its input, so the results never alias ``data``.
    """
    def _to_tensor(array):
        return torch.tensor(array).to(device).to(torch.float32)

    if type(data) is dict:
        return {name: _to_tensor(array) for name, array in data.items()}
    return _to_tensor(data)

def calculate_simple_probe(features, labels):
    """Fit a gradient-boosting probe on ``features`` and print held-out metrics.

    The last 20% of rows (at least one row) form the test split; the rest are
    used for training. Labels and predictions are both binarised with
    ``> 0``, so the metrics agree whether labels are encoded as {0, 1} or
    {-1, 1}. Precision, recall and F1 fall back to 0.0 when their denominator
    is empty instead of dividing by zero.

    Args:
        features: array of shape (n_samples, n_features).
        labels: 1-D numeric array with n_samples entries.

    Raises:
        ValueError: if fewer than two samples are given.
    """
    from sklearn.ensemble import GradientBoostingClassifier

    labels = np.asarray(labels)
    n_samples = labels.shape[0]
    if n_samples < 2:
        raise ValueError("calculate_simple_probe needs at least two samples")
    # int(0.2 * n) is 0 for n < 5, and [:-0] would be an empty training split
    split = max(1, int(0.2 * n_samples))
    train_features, test_features = features[:-split], features[-split:]
    train_labels, test_labels = labels[:-split], labels[-split:]

    model = GradientBoostingClassifier(
        random_state=5, n_estimators=200, max_depth=10, learning_rate=0.2
    )
    model.fit(train_features, train_labels)

    predicted_pos = np.asarray(model.predict(test_features)) > 0
    actual_pos = test_labels > 0
    true_pos = np.sum(predicted_pos & actual_pos)

    accuracy = float(np.mean(predicted_pos == actual_pos))
    precision = float(true_pos / predicted_pos.sum()) if predicted_pos.any() else 0.0
    recall = float(true_pos / actual_pos.sum()) if actual_pos.any() else 0.0
    f1_score = (2 * precision * recall) / (precision + recall) if precision + recall > 0 else 0.0
    set_balance = float(np.mean(actual_pos))  # fraction of positive test labels
    print("Set Balance:", set_balance, "Accuracy: ", accuracy, "Precision: ", precision, "Recall", recall, "F1 Score", f1_score)


def plot_connected_traj(model, mixed_dataset, output_dir, step):
    """Plot t-SNE projections of per-step state embeddings, one line per demo.

    Saves two figures named ``output_dir + str(step) + suffix``:
      * ``_connected_traj.png``: each demo's projected trajectory as a line
        (blue when its label is 1, green otherwise), with black start points
        and label-coloured end points;
      * ``_end_projection.png``: a separate t-SNE of only the final-step
        embedding of each demo, coloured the same way.

    Assumes ``mixed_dataset`` enumerates samples demo by demo, in demo order,
    with ``lengths_list[d]`` samples for demo ``d``. The slicing below relies on
    that ordering; TODO confirm against MultiviewDataset.
    """
    embedding_list = []
    end_embedding_list = []
    demo_label = {}
    for i in tqdm.tqdm(range(len(mixed_dataset))):
        sample = mixed_dataset.get_labeled_item(i, flatten_action = False)
        demo, idx_in_demo = mixed_dataset.parse_idx(i)
        # only the state and label are used; add the batch dimension to the state
        state = {k: torch.unsqueeze(v, dim = 0) for k, v in prepare_np(sample[0]).items()}
        demo_label.setdefault(demo, sample[2])

        with torch.no_grad():
            embedding = model.state_embedding(state).detach().cpu().numpy()[0]
        embedding_list.append(embedding)
        if idx_in_demo == mixed_dataset.lengths_list[demo] - 1:
            end_embedding_list.append(embedding)

    projection = generate_TSNE_points(np.stack(embedding_list, axis = 0), dims = 2)
    end_projection = generate_TSNE_points(np.stack(end_embedding_list, axis = 0), dims = 2)

    lengths = mixed_dataset.lengths_list
    # offsets[d]:offsets[d + 1] is the slice of `projection` belonging to demo d
    offsets = np.concatenate(([0], np.cumsum(lengths)))
    color_list = ["blue" if demo_label[demo] == 1 else "green" for demo in range(len(lengths))]
    starts, ends = [], []
    for demo, color in enumerate(color_list):
        segment = projection[offsets[demo] : offsets[demo + 1]]
        plt.plot(segment[:, 0], segment[:, 1], color = color, linewidth = 1, alpha = 0.5)
        starts.append(segment[0])
        ends.append(segment[-1])
    starts, ends = np.array(starts), np.array(ends)
    plt.scatter(starts[:, 0], starts[:, 1], color = "black", s = 8)
    plt.scatter(ends[:, 0], ends[:, 1], color = color_list, s = 8)

    plt.title("Plotted Trajectories")
    plt.savefig(output_dir + str(step) + "_connected_traj.png", dpi = 300)
    plt.close()

    plt.scatter(end_projection[:, 0], end_projection[:, 1], color = color_list, s = 8)
    plt.title("End Projection")
    plt.savefig(output_dir + str(step) + "_end_projection.png", dpi = 300)
    plt.close()

def jiggle_action(model, mixed_dataset, output_dir, step, num_samples = 10):
    """Measure how sensitive the predicted embedding is to perturbed actions.

    For every step of one randomly chosen demo, ``num_samples`` random action
    chunks are drawn uniformly in [-m, m]. Here m is the real chunk's mean
    absolute value, taken over dim 1 of the batched action. Assuming actions
    are (chunk_length, action_dim), which is TODO confirm, that is a per-action-
    dimension magnitude. The real chunk (slot 0) and the random ones are
    embedded together with ``model.state_action_embedding(..., normalize=False)``.

    Saves ``{output_dir}{step}_action_jiggle_{demo}.png`` with two curves:
      * the mean per-feature std of the perturbed embeddings;
      * the mean absolute difference between perturbed and real embeddings.
    """
    print("Jiggling Action!")
    demo = random.randint(0, len(mixed_dataset.lengths_list) - 1)
    start, end = mixed_dataset.get_bounds_of_demo(demo)
    batch_size = num_samples + 1  # slot 0 holds the real action
    var_list = []
    dist_list = []
    for i in tqdm.tqdm(range(start, end)):
        sample = mixed_dataset.get_labeled_item(i, flatten_action = False)
        state, action = prepare_np(sample[0]), prepare_np(sample[1])
        action = torch.unsqueeze(action, dim = 0)  # add the batch dimension
        with torch.no_grad():
            # replicate every observation key (not only the main camera, so inputs
            # such as proprio stay available) along a new batch dimension
            batch_state = {k: v.unsqueeze(0).repeat(batch_size, *([1] * v.dim())) for k, v in state.items()}
            action_mags = torch.mean(torch.abs(action), dim = 1)
            # uniform noise shaped like the real chunk, scaled by its magnitudes
            batch_action = 2 * (torch.rand((batch_size, *action.shape[1:]), device = action.device) - 0.5)
            batch_action = batch_action * action_mags
            batch_action[0] = action[0]  # first slot is the base
            z_hats = model.state_action_embedding(batch_state, batch_action, normalize = False).detach().cpu().numpy()

        base_z_hat, other_z_hats = z_hats[0], z_hats[1:]
        var_list.append(np.mean(np.std(other_z_hats, axis = 0)))
        dist_list.append(np.mean(np.abs(other_z_hats - base_z_hat)))

    fig, axs = plt.subplots(2, 1)
    axs[0].plot(var_list)
    axs[0].set_title(f"Z hat noise variances ({num_samples} samples)")
    axs[1].plot(dist_list)
    axs[1].set_title(f"Z hat distance from accepted action ({num_samples} samples)")
    plt.tight_layout()
    plt.savefig(output_dir + str(step) + f"_action_jiggle_{demo}.png", dpi=300)
    plt.close()
    

def replay_through_reconstruction(model, mixed_dataset, output_dir, step, num_trials = 10):
    """Write a video comparing current frames, reconstructions and last-state frames.

    For ``num_trials`` randomly chosen demos (chosen with replacement), every
    step is rendered as a vertical stack of three images:
      * the current ``MAIN_CAMERA`` frame divided by 255;
      * the reconstruction output of ``model(state, action)``, resized to
        224x224 and clipped to [0, 1];
      * the dataset's ``last_state`` ``MAIN_CAMERA`` frame divided by 255.
    Camera frames are assumed to be 0-255 images of width 224 so they stack
    with the reconstruction; TODO confirm.

    The writer is closed even if an error occurs, so the mp4 is finalised.
    """
    print("Replay through reconstruction")
    resizer = torchvision.transforms.Resize((224, 224))
    color_output = imageio.get_writer(output_dir + str(step) +  "reconstructions.mp4")
    try:
        for _ in tqdm.tqdm(range(num_trials)):
            demo = random.randint(0, len(mixed_dataset.lengths_list) - 1)
            start, end = mixed_dataset.get_bounds_of_demo(demo)
            for i in range(start, end):
                sample = mixed_dataset[i]
                state, action, last_state = prepare_np(sample[0]), prepare_np(sample[1]), prepare_np(sample[2])
                # add the batch dimension everywhere
                state = {k: torch.unsqueeze(v, dim = 0) for k, v in state.items()}
                last_state = {k: torch.unsqueeze(v, dim = 0) for k, v in last_state.items()}
                action = torch.unsqueeze(action, dim = 0)

                with torch.no_grad():
                    _, reco = model(state, action)
                    reco = torch.clip(resizer(reco), 0, 1)
                # stack along the height axis of (C, H, W) images
                combined_frame = torch.concatenate(
                    (state[MAIN_CAMERA][0] / 255, reco[0], last_state[MAIN_CAMERA][0] / 255), dim = 1
                ).detach().cpu().numpy()
                color_output.append_data(np.transpose(combined_frame, (1, 2, 0)))
    finally:
        color_output.close()

def plot_prediction_similarity(model, mixed_dataset, output_dir, step, action_mod = None, num_demos = 10):
    """Plot per-step latent prediction error for randomly chosen demos.

    For each of ``num_demos`` random demos (chosen with replacement), every
    step's ``model.state_embedding(last_state)`` is compared by MSE with
    ``model.state_action_embedding(state, action)``. One labelled line per
    demo is saved to ``{output_dir}{step}_end_prediction_error_{mod}.png``.

    action_mod:
        None: use the real action chunk.
        "noise": replace the action with standard-normal noise, keeping the
            last entry of the last axis from the real action.
        "negate": negate every entry of the last axis except the last one.
    Both modifications assume the last axis is the per-step action dimension
    with the gripper last (TODO confirm; ``__getitem__`` might return
    flattened actions).
    """
    print("Plotting Prediction Similarity")
    mse_loss = torch.nn.MSELoss()
    for _ in range(num_demos):
        demo = random.randint(0, len(mixed_dataset.lengths_list) - 1)
        start, end = mixed_dataset.get_bounds_of_demo(demo)
        error_list = []
        for j in tqdm.tqdm(range(start, end)):
            sample = mixed_dataset[j]
            state, action, last_state = prepare_np(sample[0]), prepare_np(sample[1]), prepare_np(sample[2])
            state = {k: torch.unsqueeze(v, dim = 0) for k, v in state.items()}
            last_state = {k: torch.unsqueeze(v, dim = 0) for k, v in last_state.items()}
            action = torch.unsqueeze(action, dim = 0)  # add the batch dimension
            # index the last (feature) axis: axis 0 is the size-1 batch axis, so
            # slicing it would leave the action unchanged
            if action_mod == "noise":
                action_noised = torch.randn_like(action)
                action_noised[..., -1] = action[..., -1]
                action = action_noised
            elif action_mod == "negate":
                action[..., :-1] = -action[..., :-1]

            with torch.no_grad():
                last_state_embed = model.state_embedding(last_state)
                predicted_last_state_embed = model.state_action_embedding(state, action)
                loss = mse_loss(last_state_embed, predicted_last_state_embed)
            error_list.append(loss.item())

        plt.plot(error_list, label = f"demo {demo}")

    plt.title("End State Prediction")
    plt.xlabel("Step")
    plt.ylabel("MSE Prediction Error")
    plt.legend()
    mod = action_mod if action_mod is not None else ""
    plt.savefig(output_dir + str(step) + f"_end_prediction_error_{mod}.png", dpi=300)
    plt.close()


def generate_conf_matr(matrix, img_size, row_imgs, col_imgs):
    """Render a 2-D score matrix as a tiled image with image headers.

    Each entry ``matrix[i, j]`` becomes an ``img_size`` x ``img_size`` tile
    filled with that value in all three channels. ``row_imgs`` (one
    (img_size, img_size, 3) image per matrix row) are stacked down the left
    edge, and ``col_imgs`` (one per column) run across the top. The top-left
    corner is a black tile. ``row_imgs`` and ``col_imgs`` are not modified.

    Returns:
        Float array of shape ((R + 1) * img_size, (C + 1) * img_size, 3) for an
        R x C ``matrix`` with R row images and C column images.
    """
    matrix = np.asarray(matrix)
    # upscale each entry to an img_size x img_size tile, replicated over 3 channels
    tiles = np.kron(matrix, np.ones((img_size, img_size)))
    conf_matrix = np.repeat(tiles[:, :, None], 3, axis = 2)
    corner = np.zeros((img_size, img_size, 3))
    left_column = np.concatenate([corner, *row_imgs], axis = 0)
    top_row = np.concatenate(col_imgs, axis = 1)
    right_half = np.concatenate([top_row, conf_matrix], axis = 0)
    return np.concatenate([left_column, right_half], axis = 1)

def final_state_sa_interpolation(model, good_dataset, output_dir, step, num_frames = 21):
    """Render a video interpolating from state embeddings to state-action embeddings.

    For each good demo, two embeddings of its last step are computed:
    ``state_embedding(..., normalize=False)`` and
    ``state_action_embedding(..., normalize=False)``. Then, for every demo ``i``
    except the last, the video walks linearly from demo ``i``'s state embedding
    to the state-action embedding of a random later demo. Each point is decoded
    with ``model.image_reconstruct``, and 10 black separator frames follow each
    walk.

    Output: ``{output_dir}{step}good_s_to_sa_interpolations.mp4``.

    Args:
        num_frames: interpolation points per walk, including both endpoints
            (the default of 21 gives tau = 0, 0.05, ..., 1).
    """
    state_embeddings = []
    sa_embeddings = []
    print("Precomputing good")
    idx = 0
    for length in tqdm.tqdm(good_dataset.lengths_list):
        idx += length
        # last step of this demo
        sample = good_dataset.get_labeled_item(idx - 1, flatten_action = False)
        state = {k: torch.unsqueeze(v, dim = 0) for k, v in prepare_np(sample[0]).items()}
        action = torch.unsqueeze(prepare_np(sample[1]), dim = 0)  # add the batch dimension
        with torch.no_grad():
            state_embeddings.append(model.state_embedding(state, normalize = False).clone())
            sa_embeddings.append(model.state_action_embedding(state, action, normalize = False).clone())

    num_demos = len(state_embeddings)
    color_output = imageio.get_writer(output_dir + str(step) +  "good_s_to_sa_interpolations.mp4")
    try:
        for i in range(num_demos - 1):
            point_a = state_embeddings[i]
            point_b = sa_embeddings[random.randint(i + 1, num_demos - 1)]
            frame = None
            # linspace hits tau == 1 exactly; repeatedly adding 0.05 can overshoot it
            for tau in np.linspace(0.0, 1.0, num_frames):
                midpoint = float(tau) * (point_b - point_a) + point_a
                with torch.no_grad():
                    reconstruction = model.image_reconstruct(midpoint).detach().cpu().numpy()[0]
                frame = np.transpose(reconstruction, (1, 2, 0))
                color_output.append_data(frame)
            # separator frames sized like the reconstructions so the encoder accepts them
            blank = np.zeros_like(frame) if frame is not None else np.zeros((128, 128, 3))
            for _ in range(10):
                color_output.append_data(blank)
    finally:
        color_output.close()

# TODO: a variant where we bin by category and get average L2 distance 


def compare_final_state_similarity_by_category(model, good_dataset, mixed_dataset, output_dir, step, good_key):
    """Bar-plot the mean end-state distance to good demos for each label.

    The last-step ``state_embedding(..., normalize=False)`` of every good demo
    is precomputed (flattened per sample). For each mixed demo, the mean L2
    distance from its last-step embedding to all good embeddings is computed
    and grouped by ``str(label)``. Per-label averages are drawn as bars (the
    ``good_key`` bar green, others blue) and saved to
    ``{output_dir}{step}_end_state_category_distr.png``. The y-axis is zoomed
    to the data range plus a 10% margin.
    """
    print("Precomputing good")
    good_embeddings_list = []
    idx = 0
    for length in tqdm.tqdm(good_dataset.lengths_list):
        idx += length
        sample = good_dataset.get_labeled_item(idx - 1, flatten_action = False)
        state = {k: torch.unsqueeze(v, dim = 0) for k, v in prepare_np(sample[0]).items()}
        with torch.no_grad():
            good_embedding = model.state_embedding(state, normalize = False).flatten(start_dim = 1)
        good_embeddings_list.append(good_embedding.clone())
    good_embeddings = torch.concatenate(good_embeddings_list, dim = 0)

    category_lists = {}
    idx = 0
    print("Calculating mixed")
    for length in tqdm.tqdm(mixed_dataset.lengths_list):
        idx += length
        sample = mixed_dataset.get_labeled_item(idx - 1, flatten_action = False)
        state = {k: torch.unsqueeze(v, dim = 0) for k, v in prepare_np(sample[0]).items()}
        with torch.no_grad():
            s_embedding = model.state_embedding(state, normalize = False).flatten(start_dim = 1)
        mean_distance = torch.cdist(good_embeddings, s_embedding, p=2.0).mean().item()
        # str() lets numeric labels match the string key
        category_lists.setdefault(str(sample[2]), []).append(mean_distance)

    categories = list(category_lists)
    averages = [sum(v) / len(v) for v in category_lists.values()]
    color_list = ["green" if k == good_key else "blue" for k in categories]
    plt.bar(categories, averages, color = color_list)
    plt.title("End State L2 Distance Average")
    plt.xticks(rotation=90)
    plt.tight_layout()
    low, high = min(averages), max(averages)
    # margin scales with the data so small distance ranges stay visible
    margin = 0.1 * (high - low) or 0.1 * abs(high) or 1.0
    plt.ylim(low - margin, high + margin)
    plt.savefig(output_dir + str(step) + "_end_state_category_distr.png")
    plt.close()


def compare_final_state_similarity(model, good_dataset, mixed_dataset, output_dir, step , good_key):
    """Histogram each mixed demo's distance to the nearest good end state.

    The last-step ``state_embedding(..., normalize=False)`` of every good demo
    is precomputed (flattened per sample). For each mixed demo, the minimum L2
    distance from its last-step embedding to any good embedding is recorded.
    Demos whose ``str(label) == good_key`` go in the "good" group; all others
    go in "bad". Both groups are drawn as a 20-edge histogram and saved to
    ``{output_dir}{step}_end_state_distr.png``.
    """
    print("Precomputing good")
    good_embeddings_list = []
    idx = 0
    for length in tqdm.tqdm(good_dataset.lengths_list):
        idx += length
        sample = good_dataset.get_labeled_item(idx - 1, flatten_action = False)
        state = {k: torch.unsqueeze(v, dim = 0) for k, v in prepare_np(sample[0]).items()}
        with torch.no_grad():
            good_embedding = model.state_embedding(state, normalize = False).flatten(start_dim = 1)
        good_embeddings_list.append(good_embedding.clone())
    good_embeddings = torch.concatenate(good_embeddings_list, dim = 0)

    good_list = []
    bad_list = []
    idx = 0
    print("Calculating mixed")
    for length in tqdm.tqdm(mixed_dataset.lengths_list):
        idx += length
        sample = mixed_dataset.get_labeled_item(idx - 1, flatten_action = False)
        state = {k: torch.unsqueeze(v, dim = 0) for k, v in prepare_np(sample[0]).items()}
        with torch.no_grad():
            s_embedding = model.state_embedding(state, normalize = False).flatten(start_dim = 1)
        nearest_distance = torch.cdist(good_embeddings, s_embedding, p=2.0).min().item()
        # str() is needed to adapt to numerical labels for the pymunk data
        if str(sample[2]) == good_key:
            good_list.append(nearest_distance)
        else:
            bad_list.append(nearest_distance)

    # bin over both groups together so either one may be empty
    all_distances = good_list + bad_list
    low, high = min(all_distances), max(all_distances)
    if high <= low:  # all distances equal: widen so the bin edges still increase
        high = low + 1.0
    bins = np.linspace(low, high, 20)
    plt.hist([good_list, bad_list], bins = bins, color = ["green", "blue"], label = ["good", "bad"])
    plt.legend(loc='upper right')
    plt.title("End State L2 Distance to Nearest Good Demo (lower better)")
    plt.savefig(output_dir + str(step) + "_end_state_distr.png")
    plt.close()


def jiggle_action_with_reco(model, mixed_dataset, output_dir, step):
    """Decode predicted embeddings for the real action chunk and three rotated copies.

    Every 10th step of one random demo, four action chunks are built from the
    real one by changing only entries 0 and 1 of its last axis. Those are
    assumed to be planar x/y motion (TODO confirm):
      0: unchanged, 1: (x, y) -> (-x, -y), 2: (x, y) -> (y, -x), 3: (x, y) -> (-y, x).

    Each chunk is embedded with ``state_action_embedding(..., normalize=False)``.
    The output is reshaped as (directions, steps, features), every per-step
    embedding is decoded with ``image_reconstruct``, resized to 224x224, and
    written to ``{output_dir}{step}jiggle_reco_{demo}.mp4``. A red frame goes
    between consecutive directions, and 10 black frames end each sampled step.
    """
    resizer = torchvision.transforms.Resize((224, 224))
    num_dirs = 4
    red_frame = np.zeros((224, 224, 3))
    red_frame[:, :, 0] = 1
    black_frame = np.zeros((224, 224, 3))

    demo = random.randint(0, len(mixed_dataset.lengths_list) - 1)
    color_output = imageio.get_writer(output_dir + str(step) +  f"jiggle_reco_{demo}.mp4")
    start, end = mixed_dataset.get_bounds_of_demo(demo)
    try:
        for i in tqdm.tqdm(range(start, end, 10)):  # every 10th step
            sample = mixed_dataset.get_labeled_item(i, flatten_action = False)
            state, action = prepare_np(sample[0]), prepare_np(sample[1])
            with torch.no_grad():
                # replicate every observation key and the action along a new batch dimension
                batch_state = {k: v.unsqueeze(0).repeat(num_dirs, *([1] * v.dim())) for k, v in state.items()}
                batch_action = action.unsqueeze(0).repeat(num_dirs, *([1] * action.dim()))
                # x, y view `action`, not the copied batch, so the assignments don't alias
                x, y = action[..., 0], action[..., 1]
                batch_action[1, ..., 0], batch_action[1, ..., 1] = -x, -y
                batch_action[2, ..., 0], batch_action[2, ..., 1] = y, -x  # orthogonal direction
                batch_action[3, ..., 0], batch_action[3, ..., 1] = -y, x
                z_hats = model.state_action_embedding(batch_state, batch_action, normalize = False)
                B, S = z_hats.shape[0], z_hats.shape[1]
                reco = model.image_reconstruct(z_hats.reshape(B * S, z_hats.shape[2]))
                reco = resizer(reco).detach().cpu().numpy()

            for frame in range(reco.shape[0]):
                if frame > 0 and frame % S == 0:  # boundary between two directions
                    color_output.append_data(red_frame)
                color_output.append_data(np.transpose(reco[frame], (1, 2, 0)))

            for _ in range(10):  # denotes end of jiggle trial
                color_output.append_data(black_frame)
    finally:
        color_output.close()
    


def plot_valid_trajectory(model, good_dataset, mixed_dataset, output_dir, step, good_label):
    """Visualise how close each step's predicted embedding is to good end states.

    Takes the datasets directly, not DataLoaders.

    Reference set: for each good demo, the ``state_embedding`` of the step 16
    before its end, clamped to the demo's first step for shorter demos. Every
    step of every mixed demo, and separately of every good demo, is scored as
    the MSE between its ``state_action_embedding`` and the whole reference set.

    Saved figures (prefix ``{output_dir}{step}``):
      * ``_individual_similarities.png``: one curve per mixed demo, green when
        ``str(label) == str(good_label)``, else red;
      * ``_tiled_plots.pdf``: up to 50 good and 50 bad mixed demos in the first
        10 columns of a 10x11 grid, and up to 10 good-dataset demos (blue) in
        the last column;
      * ``_average_similarities.png``: per-step mean +/- std for good and bad.
    """
    print("Precomputing good")
    good_embeddings_list = []
    idx = 0
    for length in tqdm.tqdm(good_dataset.lengths_list):
        idx += length
        # 16 steps before the end, but never before this demo's first step
        sample = good_dataset.get_labeled_item(max(idx - 16, idx - length), flatten_action = False)
        state = {k: torch.unsqueeze(v, dim = 0) for k, v in prepare_np(sample[0]).items()}
        with torch.no_grad():
            good_embedding = model.state_embedding(state).flatten(start_dim = 1)
        good_embeddings_list.append(good_embedding.clone())
    good_embeddings = torch.concatenate(good_embeddings_list, dim = 0)

    def _score(dataset, i):
        """Return (demo, idx_in_demo, label, MSE to the reference set) for sample ``i``."""
        sample = dataset.get_labeled_item(i, flatten_action = False)
        demo, idx_in_demo = dataset.parse_idx(i)
        state = {k: torch.unsqueeze(v, dim = 0) for k, v in prepare_np(sample[0]).items()}
        action = torch.unsqueeze(prepare_np(sample[1]), dim = 0)  # add the batch dimension
        with torch.no_grad():
            embedding = model.state_action_embedding(state, action).flatten(start_dim = 1)
        # broadcasting compares the embedding with every reference row
        mse = torch.mean(torch.square(good_embeddings - embedding)).item()
        return demo, idx_in_demo, sample[2], mse

    good_dict, bad_dict = {}, {}
    good_step_list, bad_step_list = [], []
    print("Evaluating trajectories")
    for i in tqdm.tqdm(range(len(mixed_dataset))):
        demo, idx_in_demo, label, mse = _score(mixed_dataset, i)
        # str() on both sides so numeric labels match the string key
        if str(label) == str(good_label):
            traj_dict, step_list = good_dict, good_step_list
        else:
            traj_dict, step_list = bad_dict, bad_step_list
        traj_dict.setdefault(demo, []).append(mse)
        while len(step_list) <= idx_in_demo:
            step_list.append([])
        step_list[idx_in_demo].append(mse)

    print("Compute self-similarity")
    key_dict = {}
    for i in tqdm.tqdm(range(len(good_dataset))):
        demo, _, _, mse = _score(good_dataset, i)
        key_dict.setdefault(demo, []).append(mse)

    for traj in good_dict.values():
        plt.plot(traj, color = "green")
    for traj in bad_dict.values():
        plt.plot(traj, color = "red")
    plt.xlabel("Steps")
    plt.ylabel("Classifier Output")
    plt.title("Individual Trajectories")
    plt.savefig(output_dir + str(step) + "_individual_similarities.png")
    plt.close()

    def _tile(ax, traj, color):
        """Draw one trajectory on ``ax`` with both axes hidden."""
        ax.plot(traj, color = color)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

    # first 10 columns: good demos on odd cells, bad demos on even cells (row-major);
    # last column: good-dataset self-similarity
    fig, ax_tuple = plt.subplots(ncols = 11, nrows = 10)
    for cell, traj in zip(range(1, 100, 2), good_dict.values()):
        _tile(ax_tuple[cell // 10, cell % 10], traj, "green")
    for cell, traj in zip(range(0, 100, 2), bad_dict.values()):
        _tile(ax_tuple[cell // 10, cell % 10], traj, "red")
    for row, traj in zip(range(10), key_dict.values()):
        _tile(ax_tuple[row, 10], traj, "blue")
    plt.savefig(output_dir + str(step) + "_tiled_plots.pdf")
    plt.close(fig)

    def _mean_std(step_list):
        """Per-step mean and std arrays of the collected scores."""
        means = np.array([np.mean(k) for k in step_list])
        stds = np.array([np.std(k) for k in step_list])
        return means, stds

    good_mean, good_std = _mean_std(good_step_list)
    bad_mean, bad_std = _mean_std(bad_step_list)
    plt.plot(good_mean, "green")
    plt.plot(bad_mean, "red")
    plt.fill_between(np.arange(len(good_mean)), good_mean - good_std, good_mean + good_std, color = "green", alpha = 0.2)
    plt.fill_between(np.arange(len(bad_mean)), bad_mean - bad_std, bad_mean + bad_std, color = "red", alpha = 0.2)
    plt.xlabel("Steps")
    plt.ylabel("Classifier Output")
    plt.title("Average Trajectories")
    plt.savefig(output_dir + str(step) + "_average_similarities.png")
    plt.close()


def main(args):
    """Load a checkpoint and the good/mixed datasets, then run the evaluation plots.

    Outputs are written to ``<exp_dir>/tests/``, which is created (with parents)
    if missing. File names are prefixed with the checkpoint's base name up to
    its first ``.``.
    """
    # must match the action dimensionality the model was trained with
    ACTION_DIM = 7 # for Calvin
    # ACTION_DIM = 10 # for UMI

    # for calvin
    proprio_dim = 15
    proprio = "proprio" # set to None if you want to exclude proprioception

    # for umi
    cameras = [MAIN_CAMERA] # you can change this; it's hardcoded
    padding = True
    pad_mode = "repeat" # "zeros" for calvin

    model = FinalStatePredictionDino(ACTION_DIM, args.action_chunk_length, cameras=cameras, reconstruction = True,
                                     proprio = proprio, proprio_dim = proprio_dim)
    # load onto the CPU first so checkpoints saved on any device load the same way
    model.load_state_dict(torch.load(args.checkpoint, map_location = "cpu"))
    model.to("cuda")
    model.eval()

    good_dataset = MultiviewDataset(args.good_hdf5, action_chunk_length = args.action_chunk_length, cameras = cameras, padding = padding,
                                    pad_mode = pad_mode, proprio = proprio)
    # for calvin environment (UMI used MultiviewDatasetUMI without proprio)
    mixed_dataset = MultiviewDataset(args.mixed_hdf5, action_chunk_length = args.action_chunk_length, cameras = cameras, padding = padding,
                                     pad_mode = pad_mode, proprio = proprio)

    checkpoint_number = os.path.basename(args.checkpoint).split(".")[0]

    # trailing separator kept because the plotting helpers concatenate file names onto it
    log_dir = os.path.join(args.exp_dir, "tests", "")
    os.makedirs(log_dir, exist_ok = True)

    # these are the tests that you can run
    compare_final_state_similarity(model, good_dataset, mixed_dataset, log_dir, checkpoint_number, args.key)
    compare_final_state_similarity_by_category(model, good_dataset, mixed_dataset, log_dir, checkpoint_number, args.key)

    plot_prediction_similarity(model, mixed_dataset, log_dir, checkpoint_number, action_mod = "noise")
    plot_prediction_similarity(model, mixed_dataset, log_dir, checkpoint_number, action_mod = "negate")
    plot_prediction_similarity(model, mixed_dataset, log_dir, checkpoint_number)

    replay_through_reconstruction(model, mixed_dataset, log_dir, checkpoint_number)

    plot_valid_trajectory(model, good_dataset, mixed_dataset, log_dir, checkpoint_number, args.key)
    jiggle_action(model, mixed_dataset, log_dir, checkpoint_number)



if __name__ == "__main__":
    # every argument is used by main(); a missing one used to surface later as a
    # confusing TypeError, so argparse now rejects it up front
    parser = argparse.ArgumentParser(description="Run embedding-model evaluation plots.")
    parser.add_argument(
        "--exp_dir",
        type=str,
        required=True,
        help="experiment directory; plots are written to <exp_dir>/tests/",
    )
    parser.add_argument(
        "--key",
        type=str,
        required=True,
        help="label value (compared as a string) that marks a good demo in the mixed dataset",
    )
    parser.add_argument(
        "--good_hdf5",
        type=str,
        required=True,
        help="path to the good-demonstration dataset file",
    )
    parser.add_argument(
        "--mixed_hdf5",
        type=str,
        required=True,
        help="path to the mixed (good and bad) dataset file",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="model state_dict file loaded with torch.load",
    )
    parser.add_argument(
        "--action_chunk_length",
        type=int,
        required=True,
        help="action chunk length used by the model and the datasets",
    )
    args = parser.parse_args()

    main(args)