import os
import sys
import torch
import argparse
import numpy as np
import torchvision.transforms as T
import torchvision.transforms.functional as TF
import matplotlib.pyplot as plt
from PIL import Image, UnidentifiedImageError
from libero.libero import get_libero_path
from r3m import remove_language_head, cleanup_config
from libero.lifelong.algos import *
import omegaconf
import hydra
from libero.lifelong.utils import (
    control_seed,
    safe_device,
    torch_load_model,
    NpEncoder,
    compute_flops,
)
from pathlib import Path

# Put the grandparent directory of this script on the import path.
# NOTE(review): this runs after every import above, so it only affects imports
# performed later at runtime.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Module-wide default device ("cuda" when a GPU is visible, otherwise "cpu").
device = "cuda" if torch.cuda.is_available() else "cpu"

def set_seed(seed):
    """Seed NumPy and PyTorch RNGs and force deterministic cuDNN behaviour.

    Args:
        seed: Integer seed applied to NumPy, the torch CPU generator, the
            current CUDA device and all CUDA devices.
    """
    seeders = (
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    # Autotuning picks kernels non-deterministically, so turn it off.
    cudnn.benchmark = False

# CLI benchmark name -> benchmark directory component of the experiment path.
# Every suite uses its own name unchanged.
benchmark_map = {
    suite: suite
    for suite in ("libero_10", "libero_spatial", "libero_object", "libero_goal")
}

# CLI algorithm name -> name of the algorithm class (looked up among the names
# star-imported from libero.lifelong.algos) and experiment path component.
algo_map = dict(
    base="Sequential",
    er="ER",
    ewc="EWC",
    packnet="PackNet",
    multitask="Multitask",
)

# CLI policy name -> policy name used in the experiment directory path.
policy_map = dict(
    bc_rnn_policy="BCRNNPolicy",
    bc_transformer_policy="BCTransformerPolicy",
    bc_vilt_policy="BCViLTPolicy",
    bc_transformer_policy_r3m="BCTransformerPolicyR3M",
)

def parse_args(argv=None):
    """Parse the command line of the R3M retrieval script.

    ``--ep`` is required when ``--algo multitask`` is given, and ``--load_task``
    is required for every other algorithm, because the checkpoint file name is
    built from one of them. A missing value is reported as a usage error
    (``SystemExit`` with status 2). Without this check it would silently turn
    into a ``...None...`` checkpoint path.

    Args:
        argv: Argument list to parse. ``None`` (the default) means
            ``sys.argv[1:]``, so existing ``parse_args()`` calls are unchanged.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="Retrieval with R3M")
    parser.add_argument("--experiment_dir", type=str, default="experiments")
    parser.add_argument("--benchmark", type=str, required=True, choices=["libero_10", "libero_spatial", "libero_object", "libero_goal"])
    parser.add_argument("--task_id", type=int, required=True)
    parser.add_argument("--algo", type=str, required=True, choices=["base", "er", "ewc", "packnet", "multitask"])
    parser.add_argument("--policy", type=str, required=True, choices=["bc_rnn_policy", "bc_transformer_policy", "bc_vilt_policy", "bc_transformer_policy_r3m"])
    parser.add_argument("--seed", type=int, required=True)
    parser.add_argument("--ep", type=int, help="epoch of the multitask checkpoint (required for --algo multitask)")
    parser.add_argument("--load_task", type=int, help="task index of the checkpoint to load (required unless --algo multitask)")
    parser.add_argument("--device_id", type=int, default=0)
    args = parser.parse_args(argv)

    # parser.error() prints usage and raises SystemExit(2).
    if args.algo == "multitask" and args.ep is None:
        parser.error("--ep is required when --algo is multitask")
    if args.algo != "multitask" and args.load_task is None:
        parser.error("--load_task is required unless --algo is multitask")
    return args

def load_r3m_from_algo_policy():
    """Restore a trained LIBERO policy from disk and return its R3M encoder.

    The command line (see ``parse_args``) selects the checkpoint. The
    numbered ``run_*`` folder with the highest number under
    ``<experiment_dir>/<benchmark>/<Algo>/<Policy>_seed<seed>`` is used. Inside
    it, ``multitask_model_ep<ep>.pth`` is loaded for the multitask algorithm
    and ``task<load_task>_model.pth`` for every other algorithm.

    Returns:
        ``algo.policy.r3m`` of the restored policy, switched to eval mode.

    Exits the process with status 1 if no run folder exists or the
    checkpoint cannot be loaded.
    """
    args = parse_args()
    device = f"cuda:{args.device_id}" if torch.cuda.is_available() else "cpu"
    set_seed(args.seed)

    experiment_dir = os.path.join(
        args.experiment_dir,
        benchmark_map[args.benchmark],
        algo_map[args.algo],
        f"{policy_map[args.policy]}_seed{args.seed}",
    )

    # Find the run folder with the largest number. Remember its real path so
    # that folders that are not zero-padded (e.g. "run_7") are still found.
    experiment_id = 0
    run_folder = None
    for path in Path(experiment_dir).glob("run_*"):
        if not path.is_dir():
            continue
        try:
            folder_id = int(path.name.split("run_")[-1])
        except ValueError:
            continue  # not a numbered run folder, e.g. "run_old"
        if folder_id > experiment_id:
            experiment_id = folder_id
            run_folder = str(path)

    if run_folder is None:
        print(f"[error] cannot find the checkpoint under {experiment_dir}")
        sys.exit(1)

    if args.algo == "multitask":
        model_path = os.path.join(run_folder, f"multitask_model_ep{args.ep}.pth")
    else:
        model_path = os.path.join(run_folder, f"task{args.load_task}_model.pth")

    try:
        sd, cfg, _ = torch_load_model(model_path, map_location=device)
    except Exception as err:  # report any load failure (missing/corrupt file)
        print(f"[error] cannot find the checkpoint at {model_path}: {err}")
        sys.exit(1)

    cfg.folder = get_libero_path("datasets")
    cfg.bddl_folder = get_libero_path("bddl_files")
    cfg.init_states_folder = get_libero_path("init_states")

    cfg.device = args.device_id
    # Resolve the algorithm class by name from the star-imported
    # libero.lifelong.algos names instead of eval()-ing a string.
    algo_cls = globals()[algo_map[args.algo]]
    # NOTE(review): 10 is presumably the number of tasks in the suite - confirm.
    # The model goes to the default "cuda" device (not cuda:<device_id>), which
    # matches the module-level `device` that preprocess_image() moves inputs to.
    algo = safe_device(algo_cls(10, cfg), "cuda")
    algo.policy.load_state_dict(sd)

    return algo.policy.r3m.eval()

def load_r3m(modelpath="../../models/r3m/model_resnet18.pt",
             configpath="../configs/config.yaml"):
    """Build a standalone pretrained R3M encoder from a config and a checkpoint.

    Args:
        modelpath: Checkpoint file whose ``'r3m'`` entry holds the encoder
            weights. Relative paths resolve against the current working
            directory. The default is the previously hard-coded path.
        configpath: OmegaConf YAML passed through ``cleanup_config`` and
            instantiated with ``hydra.utils.instantiate``.

    Returns:
        The unwrapped encoder module with the checkpoint weights loaded (after
        ``remove_language_head``). This function does not call ``.eval()``;
        callers should do so before inference.
    """
    modelcfg = omegaconf.OmegaConf.load(configpath)
    cleancfg = cleanup_config(modelcfg)
    rep = hydra.utils.instantiate(cleancfg)
    # Wrap in DataParallel before loading and unwrap afterwards. Presumably the
    # checkpoint keys carry DataParallel's "module." prefix - confirm.
    rep = torch.nn.DataParallel(rep)
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(modelpath, map_location=torch.device(device))
    r3m_state_dict = remove_language_head(checkpoint['r3m'])
    rep.load_state_dict(r3m_state_dict)
    return rep.module

def get_image_paths(folder_path):
    """Return the image files directly inside ``folder_path`` in frame order.

    Files ending in ``.png``, ``.jpg`` or ``.jpeg`` (case-insensitive) are kept.
    Names are sorted "naturally": runs of digits compare as numbers, so
    ``frame_2.png`` precedes ``frame_10.png`` even when frame indices are not
    zero-padded. For zero-padded names this gives the same order as a plain
    string sort.

    Args:
        folder_path: Directory to scan (non-recursive).

    Returns:
        List of ``os.path.join(folder_path, name)`` paths.
    """
    import re

    def natural_key(name):
        # re.split with a capture group alternates text / digit chunks, so odd
        # positions are always digit runs. Compare those numerically and break
        # ties on the raw name to keep the ordering total.
        chunks = re.split(r"(\d+)", name)
        return ([int(c) if i % 2 else c for i, c in enumerate(chunks)], name)

    extensions = ('.png', '.jpg', '.jpeg')
    names = [name for name in os.listdir(folder_path) if name.lower().endswith(extensions)]
    names.sort(key=natural_key)
    return [os.path.join(folder_path, name) for name in names]

def preprocess_image(image_path, transforms):
    """Load one image file as a single-item batch tensor on the module device.

    The image is converted to RGB, flipped vertically and passed through
    ``transforms``. A leading batch axis is then added and the result is moved
    to the module-level ``device``.

    Args:
        image_path: Path of the image file to read.
        transforms: Callable mapping a PIL image to a tensor (e.g. a
            torchvision ``Compose`` ending in ``ToTensor``).

    Returns:
        The transformed tensor with a leading batch dimension of size 1, or
        ``None`` (after printing an error) if PIL cannot identify the file as
        an image.
    """
    try:
        # The context manager closes the file handle once the pixels are
        # decoded. convert() returns an independent image object.
        with Image.open(image_path) as raw:
            image = raw.convert("RGB")
    except UnidentifiedImageError:
        print(f"Error: Cannot identify image file {image_path}")
        return None
    # NOTE(review): the vertical flip presumably undoes an upside-down camera
    # render in the saved frames - confirm it matches how they were written.
    image = TF.vflip(image)
    return transforms(image).unsqueeze(0).to(device)


def _moving_average(values, window_size):
    """Return the centred moving average of ``values``, truncated at the ends.

    Element ``i`` is the mean of ``values[i - window_size : i + window_size + 1]``
    clipped to the valid index range. ``values`` must be a NumPy array.
    """
    n = len(values)
    return [
        values[max(0, i - window_size):min(n, i + window_size + 1)].mean()
        for i in range(n)
    ]


def plot_similarity_curve(min_distances, window_size=5):
    """Plot the smoothed per-frame minimum distances and mark their extremes.

    Args:
        min_distances: 1-D sequence with one distance per demonstration frame.
            This is either a torch tensor (as produced in ``main``) or any
            NumPy array-like.
        window_size: Number of steps before and after each step included in
            the moving average (default 5).

    The smallest smoothed values are marked in red and the largest in blue,
    and the figure is shown with ``plt.show()``. An empty input only prints a
    message.
    """
    if isinstance(min_distances, torch.Tensor):
        min_distances = min_distances.detach().cpu().numpy()
    else:
        min_distances = np.asarray(min_distances)
    if min_distances.size == 0:
        print("No distances to plot.")
        return

    smoothed_distances = _moving_average(min_distances, window_size)

    # Compute the extremes once instead of inside the comprehensions.
    lowest = min(smoothed_distances)
    highest = max(smoothed_distances)
    min_value_indices = [i for i, v in enumerate(smoothed_distances) if v == lowest]
    max_value_indices = [i for i, v in enumerate(smoothed_distances) if v == highest]

    plt.figure(figsize=(10, 6))
    plt.plot(range(len(smoothed_distances)), smoothed_distances, label='Smoothed Minimum Distance', color='black')

    plt.scatter(min_value_indices, [smoothed_distances[i] for i in min_value_indices], color='red', label='Smallest Values', s=100, edgecolors='black', zorder=5)
    plt.scatter(max_value_indices, [smoothed_distances[i] for i in max_value_indices], color='blue', label='Largest Values', s=100, edgecolors='black', zorder=5)

    plt.xlabel('Time Step (Demonstration)')
    plt.ylabel('Average Minimum Distance')
    plt.title('Average Minimum Distance over Time Steps')
    plt.legend()
    plt.grid(True)
    plt.show()


def _embed_images(model, image_paths, transforms, batch_size=32):
    """Encode image files with ``model`` in mini-batches, skipping unreadable ones.

    All images in one batch must transform to the same tensor size, because
    they are concatenated along dim 0.

    Returns:
        ``(embeddings, kept_paths)``. ``embeddings`` is the concatenation of
        the per-batch model outputs along dim 0, or ``None`` if no image could
        be loaded. ``kept_paths`` lists, in row order, the paths that were
        actually encoded.
    """
    batches = []
    kept_paths = []
    for start in range(0, len(image_paths), batch_size):
        images = []
        for path in image_paths[start:start + batch_size]:
            image = preprocess_image(path, transforms)
            if image is not None:
                images.append(image)
                kept_paths.append(path)
        if not images:
            continue
        with torch.no_grad():
            # R3M expects image input in [0, 255]; ToTensor() scaled it to [0, 1].
            batches.append(model(torch.cat(images, dim=0) * 255.0))
    if not batches:
        return None, kept_paths
    return torch.cat(batches, dim=0), kept_paths


def main():
    """Match every demonstration frame to its closest rollout frame in R3M space.

    Frames from the demo and rollout folders are embedded with the encoder
    returned by ``load_r3m_from_algo_policy``. For each demo frame, the rollout
    frame at minimum Euclidean distance (``torch.cdist``) is printed, and the
    smoothed curve of those minimum distances is plotted.
    """
    # Alternative: a standalone pretrained encoder via load_r3m() + .eval().
    r3m = load_r3m_from_algo_policy()
    transforms = T.Compose([T.Resize(128),
                            T.ToTensor()])  # ToTensor() divides by 255

    demo_folder = "demos_video/libero_object_task_0_videos/"
    # demo_folder = "../../libero/demos_video/initials/"
    demo_embeddings, demo_paths = _embed_images(r3m, get_image_paths(demo_folder), transforms)
    if demo_embeddings is None:
        print("No valid demonstration images found.")
        return

    rollout_folder = "experiments_saved/libero_object_er_bc_transformer_policy_r3m_1_load9_on0_videos/"
    # rollout_folder = "experiments_saved/special_one_task0/"
    # rollout_folder = "../../libero/experiments_saved/initials/"
    rollout_embeddings, rollout_paths = _embed_images(
        r3m, get_image_paths(rollout_folder), transforms, batch_size=32
    )
    if rollout_embeddings is None:
        print("No valid images found.")
        return

    # Pairwise distances: one row per demo frame, one column per rollout frame.
    with torch.no_grad():
        distances = torch.cdist(demo_embeddings, rollout_embeddings)
        print("distances", distances.shape, distances)

    # Closest rollout frame for each demo frame.
    min_distances, min_indices = torch.min(distances, dim=1)

    for i, (min_distance, min_index) in enumerate(zip(min_distances, min_indices)):
        j = min_index.item()
        # Indices count only successfully loaded images; file names are shown
        # so the match can be traced back to the files on disk.
        print(
            f"Demo image {i} ({os.path.basename(demo_paths[i])}) has minimum distance "
            f"{min_distance.item()} with rollout image {j} ({os.path.basename(rollout_paths[j])})"
        )

    plot_similarity_curve(min_distances)


if __name__ == "__main__":
    main()
