import argparse
import sys
import os
import omegaconf
import hydra
import shutil

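# Set TOKENIZERS_PARALLELISM to "false" before the imports below run.
# TODO: find a cleaner way to configure this than mutating os.environ at import time.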
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import numpy as np
import torch.nn as nn
import torch
import multiprocessing
import time
import matplotlib.pyplot as plt
from PIL import Image
from pathlib import Path
from collections import defaultdict

from libero.libero import get_libero_path
from libero.libero.benchmark import get_benchmark
from libero.libero.envs import OffScreenRenderEnv, SubprocVectorEnv
from libero.libero.utils.time_utils import Timer
from libero.libero.utils.video_utils import VideoWriter
from libero.lifelong.DataModule import DataModule
from libero.lifelong.algos import *
from libero.lifelong.datasets import get_dataset, SequenceVLDataset, GroupedTaskDataset, TruncatedSequenceDataset
from libero.lifelong.metric import (
    evaluate_loss,
    evaluate_success,
    raw_obs_to_tensor_obs,
)
from libero.lifelong.utils import (
    control_seed,
    safe_device,
    torch_load_model,
    NpEncoder,
    compute_flops,
)
from r3m import remove_language_head, cleanup_config
from libero.lifelong.main import get_task_embs
from torch.utils.data import DataLoader, ConcatDataset

import robomimic.utils.obs_utils as ObsUtils
import robomimic.utils.tensor_utils as TensorUtils


# CLI benchmark name -> benchmark folder name used when building the
# experiment path (currently an identity mapping).
benchmark_map = {
    name: name
    for name in (
        "libero_90",
        "libero_10",
        "libero_spatial",
        "libero_object",
        "libero_goal",
    )
}

# CLI algorithm key -> algorithm name; used both as a folder name in the
# experiment path and as the name resolved to build the algorithm object.
algo_map = dict(
    base="Sequential",
    agem="AGEM",
    er="ER",
    ewc="EWC",
    packnet="PackNet",
    multitask="Multitask",
)

# CLI policy key -> policy name used in the experiment folder name.
policy_map = dict(
    bc_rnn_policy="BCRNNPolicy",
    bc_transformer_policy="BCTransformerPolicy",
    bc_vilt_policy="BCViLTPolicy",
    bc_transformer_policy_r3m="BCTransformerPolicyR3M",
)

def set_seed(seed):
    """Seed NumPy and PyTorch (CPU and CUDA) with ``seed``.

    Also sets ``torch.backends.cudnn.deterministic = True`` and
    ``torch.backends.cudnn.benchmark = False``.
    """
    np.random.seed(seed)
    torch_seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in torch_seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def parse_args():
    """Parse and validate the command-line arguments of the evaluation script.

    Returns:
        argparse.Namespace: the parsed options, plus two derived fields:
        ``device_id`` is rewritten as the string ``"cuda:<id>"`` and
        ``save_dir`` is set to ``"<experiment_dir>_saved"``.

    Exits through ``parser.error`` (usage message, status 2) when ``--ep``
    (multitask) or ``--load_task`` (all other algorithms) is missing or outside
    its allowed range.
    """
    parser = argparse.ArgumentParser(description="Evaluation Script")
    parser.add_argument("--experiment_dir", type=str, default="experiments")
    # for which task suite
    parser.add_argument(
        "--benchmark",
        type=str,
        required=True,
        choices=["libero_90", "libero_10", "libero_spatial", "libero_object", "libero_goal"],
    )
    parser.add_argument("--task_id", type=int, required=True)
    # method detail
    parser.add_argument(
        "--algo",
        type=str,
        required=True,
        choices=["base", "agem", "er", "ewc", "packnet", "multitask"],
    )
    parser.add_argument(
        "--policy",
        type=str,
        required=True,
        choices=["bc_rnn_policy", "bc_transformer_policy", "bc_vilt_policy", "bc_transformer_policy_r3m"],
    )
    parser.add_argument("--seed", type=int, required=True)
    parser.add_argument("--ep", type=int)
    parser.add_argument("--load_task", type=int)
    # Default to GPU 0: without a default the device string became "cuda:None".
    parser.add_argument("--device_id", type=int, default=0)
    # Default matches the adaptation functions' own default; without it None
    # was forwarded to them and `range(None)` raised.
    parser.add_argument("--epochs", type=int, default=8)
    parser.add_argument("--save-videos", action="store_true")
    parser.add_argument("--save-images", action="store_true", help="Flag to save images")
    parser.add_argument("--weighted-adapt", action="store_true", help="Flag to adapt with weights")
    args = parser.parse_args()
    args.device_id = f"cuda:{args.device_id}"
    args.save_dir = f"{args.experiment_dir}_saved"

    # Validate with parser.error rather than assert: asserts are stripped when
    # Python runs with -O, which would silently skip these checks.
    if args.algo == "multitask":
        if args.ep not in range(0, 50, 5):
            parser.error("[error] ep should be in [0, 5, ..., 45]")
    else:
        n_loadable_tasks = 20 if args.benchmark == "libero_90" else 10
        if args.load_task not in range(n_loadable_tasks):
            parser.error(
                f"[error] load_task should be in [0, ..., {n_loadable_tasks - 1}]"
            )
    return args

def visualize_image(image_tensor, image_type="demo", image_id=0):
    """Draw an image tensor with matplotlib, save it as a PNG, then show it.

    The tensor is squeezed on dim 0, moved to the CPU, permuted with
    ``(1, 2, 0)`` and flipped vertically before being drawn. The figure is
    always written to ``"{image_type}_{image_id}.png"`` (relative path).
    """
    pixels = image_tensor.squeeze(0).cpu().permute(1, 2, 0).numpy()
    upright = np.flipud(pixels)
    plt.imshow(upright)
    plt.axis('off')

    # File name encodes the image kind and its running id.
    output_name = f"{image_type}_{image_id}.png"
    plt.savefig(output_name)
    plt.show()


# def load_r3m(device):
#     modelpath = "../models/r3m/model_resnet18.pt"
#     configpath = "../IL/configs/config.yaml"
#
#     modelcfg = omegaconf.OmegaConf.load(configpath)
#     cleancfg = cleanup_config(modelcfg)
#     rep = hydra.utils.instantiate(cleancfg)
#     rep = torch.nn.DataParallel(rep)
#     r3m_state_dict = remove_language_head(torch.load(modelpath, map_location=torch.device(device))['r3m'])
#     rep.load_state_dict(r3m_state_dict)
#     return rep.module

def custom_percentile(a, p):
    """Return the value lying ``p`` percent of the way from ``min(a)`` to ``max(a)``.

    Despite the name this is not a rank-based percentile such as
    ``np.percentile``: only the smallest and largest elements matter, and the
    result is the linear interpolation ``min + (max - min) * p / 100``.

    Args:
        a: Non-empty array-like of numbers.
        p: Percentage in the closed interval [0, 100].

    Returns:
        The interpolated value (a NumPy scalar).

    Raises:
        ValueError: If ``p`` is outside [0, 100] or ``a`` is empty.
    """
    if not 0 <= p <= 100:
        raise ValueError("Percentile must be between 0 and 100")

    values = np.asarray(a)
    if values.size == 0:
        # Fail with an explicit message instead of NumPy's reduction error.
        raise ValueError("Cannot compute a percentile of an empty sequence")

    lowest = values.min()
    highest = values.max()
    return lowest + (highest - lowest) * p / 100.0

def adapt(algo, retrieved_episodic_memory, cfg, epochs=8):
    """Fine-tune ``algo.policy`` on the retrieved episodic memory (uniform weights).

    Args:
        algo: Object exposing ``policy`` (with ``train``, ``parameters`` and
            ``compute_loss``), ``cfg.train.batch_size``, ``cfg.train.grad_clip``,
            ``loss_scale`` and ``map_tensor_to_device``.
        retrieved_episodic_memory: Sequence of datasets; they are concatenated
            and iterated in shuffled mini-batches.
        cfg: Config providing ``train.optimizer.name`` and
            ``train.optimizer.kwargs``.
        epochs: Number of passes over the retrieved data. ``None`` is treated
            as the default of 8 so an unset CLI option does not crash.

    Returns:
        None. The policy is updated in place and left in training mode.
    """
    if not retrieved_episodic_memory:
        # ConcatDataset rejects an empty list; there is nothing to adapt on.
        print("No episodic memory retrieved; skipping local adaptation")
        return
    if epochs is None:
        epochs = 8

    algo.policy.train()
    # NOTE(security): the optimizer class is resolved by eval() on a config
    # string; this is only safe while the config comes from a trusted source.
    optimizer = eval(cfg.train.optimizer.name)(
        algo.policy.parameters(), **cfg.train.optimizer.kwargs
    )
    retrieved_episodic_memory_combined = ConcatDataset(retrieved_episodic_memory)

    local_adapt_dataloader = DataLoader(
        retrieved_episodic_memory_combined,
        batch_size=algo.cfg.train.batch_size,
        num_workers=0,
        pin_memory=True,
        shuffle=True,
    )

    grad_clip = algo.cfg.train.grad_clip
    for _ in range(epochs):
        for data in local_adapt_dataloader:
            data = algo.map_tensor_to_device(data)
            optimizer.zero_grad()
            loss = algo.policy.compute_loss(data)
            (algo.loss_scale * loss).backward()
            if grad_clip is not None:
                nn.utils.clip_grad_norm_(algo.policy.parameters(), grad_clip)
            optimizer.step()

    torch.cuda.empty_cache()
    print(f"Local adaptation completed over {epochs} epochs")

def weighted_adapt(video_folder, algo, retrieved_episodic_memory, cfg, args, failed_rollouts_embeddings, 
                   benchmark, epochs=8):
    """
    Perform local adaptation using the retrieved episodic memory with selective weighting.

    For every retrieved demonstration the function embeds its frames with
    ``algo.policy.r3m``, compares them with the embeddings of the most similar
    failed rollouts, raises the per-sample weight of one window of the
    demonstration, writes the resulting weights into ``memory.weights`` and
    finally fine-tunes ``algo.policy`` with the weighted loss.

    Args:
        video_folder: Root folder; one ``demo_<idx>`` sub-folder is created per
            demonstration (frames and plots are only written there when
            ``args.save_images`` is set).
        algo: Object exposing ``policy`` (with ``r3m``, ``compute_loss``,
            ``parameters``, ``train``), ``cfg.train``, ``loss_scale`` and
            ``map_tensor_to_device``.
        retrieved_episodic_memory: Iterable of datasets exposing
            ``demo_start_indices``, ``demo_lengths``, ``demo_ids``, ``weights``
            and integer indexing.
        cfg: Config providing ``train.optimizer`` and ``data.seq_len``.
        args: Parsed CLI arguments; only ``save_images`` is read here.
        failed_rollouts_embeddings: Mapping rollout index -> sequence of
            per-step embeddings of that failed rollout.
        benchmark: Not referenced in this function body.
        epochs: Number of passes over the retrieved data.
    """
    time0 = time.time()
    # NOTE(review): the optimizer class is resolved with eval() on a config
    # string; only safe while the config is trusted.
    optimizer = eval(cfg.train.optimizer.name)(
        algo.policy.parameters(), **cfg.train.optimizer.kwargs
    )

    # Generate embeddings as queries for selective weighting
    current_demo_embeddings = []
    seq_len = cfg.data.seq_len
    # (rollout index, first embedding of that rollout) for every failed rollout.
    rollout_initial_embeddings = [(idx, embedding[0]) for idx, embedding in failed_rollouts_embeddings.items()]

    for memory in retrieved_episodic_memory:
        demo_start_indices = memory.demo_start_indices
        demo_lengths = memory.demo_lengths
        demo_ids = memory.demo_ids

        for demo_idx, demo_id in enumerate(demo_ids):
            # Create subfolder for each retrieved demonstration
            # NOTE(review): demo_idx restarts at 0 for every memory, so demos of
            # different memories share the same folder name — confirm intended.
            demo_folder = os.path.join(video_folder, f"demo_{demo_idx}")
            os.makedirs(demo_folder, exist_ok=True)

            start_idx = demo_start_indices[demo_id]
            demo_length = demo_lengths[demo_id]
            demo_weights = np.ones(demo_length)  # Initialize weights for the current demonstration
            demo_weights_vector = np.ones((demo_length, seq_len))

            # Embed the demonstration in chunks of at most batch_size samples.
            for data_idx in range(start_idx, start_idx + demo_length, algo.cfg.train.batch_size):
                data_batch = [memory[i] for i in
                              range(data_idx, min(data_idx + algo.cfg.train.batch_size, start_idx + demo_length))]
                # for item in data_batch:
                #     print("shapes!: ", item["obs"]["agentview_rgb"].shape, item["obs"]["agentview_rgb"][0, :, :, :].shape)
                # Take index 0 along the leading axis of each sample's agent-view
                # images (presumably the first frame of the sample's sequence
                # window — TODO confirm) and stack them into one batch.
                agentview_images = torch.cat(
                    [torch.tensor(item["obs"]["agentview_rgb"][0, :, :, :]).unsqueeze(0).to(algo.policy.r3m.device)
                     for item in data_batch],
                    dim=0
                )

                with torch.no_grad():
                    # Images are scaled by 255 before being passed to r3m.
                    embeddings = algo.policy.r3m(agentview_images * 255)
                # extend() adds one embedding per sample (iterates dim 0).
                current_demo_embeddings.extend(embeddings)

                # Save demonstration images
                if args.save_images:
                    for img_idx, img in enumerate(agentview_images):
                        # Channels-last layout, flipped along the first axis.
                        flipped_image = img.cpu().numpy().transpose(1, 2, 0)[::-1]
                        image_path = os.path.join(demo_folder, f"frame_{data_idx + img_idx - start_idx:04d}.png")
                        image = Image.fromarray((flipped_image * 255).astype(np.uint8))
                        image.save(image_path)

            # Select the top 5 (up to 5) most similar failed rollouts for each demo based on the initial frame
            distances_init = torch.cdist(
                current_demo_embeddings[0].unsqueeze(0),
                torch.stack([embed for _, embed in rollout_initial_embeddings])
            )
            # largest=False -> the k smallest distances.
            _, top_k_indices = torch.topk(distances_init, min(5, len(rollout_initial_embeddings)), dim=1, largest=False)

            for j in top_k_indices.squeeze(0):
                # j indexes rollout_initial_embeddings; map it back to the rollout index.
                rollout_index = rollout_initial_embeddings[j.item()][0]
                rollout_embedding = torch.stack(failed_rollouts_embeddings[rollout_index])
                with torch.no_grad():
                    # For every demo step: distance to the closest rollout step.
                    distances = torch.cdist(torch.stack(current_demo_embeddings), rollout_embedding)
                    min_distances, _ = torch.min(distances, dim=1)

                    smoothed_distances = []
                    min_distances = min_distances.cpu().numpy()
                    window_size = 5 # Number of steps before and after the current step to include in the average

                    # Centered moving average, truncated at both ends of the demo.
                    for k in range(len(min_distances)):
                        start_index = max(0, k - window_size)
                        end_index = min(len(min_distances) - 1, k + window_size)
                        window_avg = min_distances[start_index:end_index + 1].mean()
                        smoothed_distances.append(window_avg)

                    print("min and max values", np.min(smoothed_distances), np.max(smoothed_distances))
                    # Thresholds are fractions of the min..max value range
                    # (see custom_percentile), not rank percentiles.
                    boundary_threshold_75 = custom_percentile(smoothed_distances, 75)
                    # Last step whose smoothed distance reaches the 75% threshold.
                    step_75 = max(k for k, val in enumerate(smoothed_distances) if val >= boundary_threshold_75)
                    # NOTE(review): if step_75 is 0 the prefix below is empty and
                    # custom_percentile (np.min) will presumably raise — confirm
                    # whether that case can occur and how it should be handled.
                    lower_threshold = custom_percentile(smoothed_distances[:step_75], 12.5)
                    upper_threshold = custom_percentile(smoothed_distances[:step_75], 33.3)

                    # Last steps (before step_75) at or below each threshold.
                    step_A = max(k for k, val in enumerate(smoothed_distances[:step_75]) if val <= lower_threshold)
                    step_B = max(k for k, val in enumerate(smoothed_distances[:step_75]) if val <= upper_threshold)

                    # Widen the window by 15 steps on each side, clamped to the demo.
                    highlight_start = max(0, step_A - 15)
                    highlight_end = min(len(smoothed_distances) - 1, step_B + 15)

                    # Add weight to the important section
                    demo_weights[highlight_start:highlight_end + 1] += 0.3

                    # Plot and save the smoothed distance curve with thresholds
                    # NOTE(review): the figure is built even when it is not saved,
                    # and the legend label formats `j` (a tensor element) whereas
                    # the title uses rollout_index — confirm which is intended.
                    plt.figure(figsize=(10, 6))
                    plt.plot(range(len(smoothed_distances)), smoothed_distances, label=f'Demo {demo_idx} vs Rollout {j}', color='black')
                    plt.plot(range(highlight_start, highlight_end + 1), smoothed_distances[highlight_start:highlight_end + 1], color='black', linewidth=3)

                    plt.axhline(y=lower_threshold, color='blue', linestyle='--',
                                label=f'Lower Threshold (1/8): {lower_threshold:.2f}')
                    plt.axhline(y=upper_threshold, color='green', linestyle='--',
                                label=f'Upper Threshold (1/3): {upper_threshold:.2f}')
                    plt.axhline(y=boundary_threshold_75, color='red', linestyle='--',
                                label=f'Top Threshold (75%): {boundary_threshold_75:.2f}')
                    plt.axvline(x=step_A, color='blue', linestyle=':', label=f'Step A: {step_A}')
                    plt.axvline(x=step_B, color='green', linestyle=':', label=f'Step B: {step_B}')
                    plt.axvline(x=step_75, color='red', linestyle=':', label=f'Step 75%: {step_75}')

                    plt.xlabel('Time Step (Demonstration)')
                    plt.ylabel('Average Minimum Distance')
                    plt.title(f'Similarity Curve: Demo {demo_idx} vs Rollout {rollout_index}: Add weights to Samples '
                              f'{highlight_start} - {highlight_end}')
                    plt.legend()
                    plt.grid(True)

                    if args.save_images:
                        figure_path = os.path.join(video_folder, f"demo_{demo_idx}",
                                                   f"similarity_curve_demo_{demo_idx}_vs_rollout_{rollout_index}.png")
                        plt.savefig(figure_path)
                    plt.close()

            # Start from an empty embedding list for the next demonstration.
            current_demo_embeddings = []

            # Clip weights and normalize
            # Weights are capped at 2, then rescaled so that they average to 1.
            demo_weights = np.clip(demo_weights, 1, 2)
            demo_weights = demo_weights / demo_weights.sum() * len(demo_weights)

            # Update the dataset with the computed weights
            for idx in range(demo_length):
                for t in range(seq_len):
                    # the logic here is relatively simple but the implementation is a bit tricky
                    # Entry t is the mean of the per-step weights over steps
                    # idx..idx+t, truncated at the end of the demonstration.
                    demo_weights_vector[idx][t] = demo_weights[idx:min(demo_length, idx + t + 1)].mean()

                memory.weights[start_idx + idx] = demo_weights_vector[idx]

    time1 = time.time()

    # Fine-tune on all retrieved memories together, in shuffled mini-batches.
    retrieved_episodic_memory_combined = ConcatDataset(retrieved_episodic_memory)
    local_adapt_dataloader = DataLoader(
        retrieved_episodic_memory_combined,
        batch_size=algo.cfg.train.batch_size,
        num_workers=0,
        pin_memory=True,
        shuffle=True,
    )
    algo.policy.train()
    # NOTE(review): range(epochs) raises if the caller passes epochs=None.
    for epoch in range(epochs):
        for data in local_adapt_dataloader:
            data = algo.map_tensor_to_device(data)
            optimizer.zero_grad()
            # The per-sample weights written above are read back from the batch.
            loss = algo.policy.compute_loss(data, weights=data['weights'])
            (algo.loss_scale * loss).backward()
            if algo.cfg.train.grad_clip is not None:
                nn.utils.clip_grad_norm_(
                    algo.policy.parameters(), algo.cfg.train.grad_clip
                )
            optimizer.step()

    time2 = time.time()

    torch.cuda.empty_cache()
    print(f"Weighted local adaptation completed over {epochs} epochs, with the embedding of demonstrations and"
          f" the similarity comparison taking {time1 - time0} seconds, "
          f"and the local adaptation taking {time2 - time1} seconds.")



def main():
    """Evaluate a checkpoint on one task, adapt it locally, then re-evaluate.

    Steps:
      1. Parse the CLI arguments, seed the RNGs, pick the ``run_*`` folder
         with the highest id and load the requested checkpoint.
      2. Build the per-task datasets and embed the first frame of every
         demonstration kept in the episodic memory with ``algo.policy.r3m``.
      3. Roll the policy out (10 rollouts), collecting the embeddings of the
         failed rollouts, and save the pre-adaptation statistics.
      4. Retrieve the 10% of demonstrations closest to the rollouts (visual
         distance combined with task-description distance) and fine-tune the
         policy on them with ``adapt`` or ``weighted_adapt``.
      5. Roll out again (20 rollouts) and save the post-adaptation success
         rate.

    Exits with status 1 when no run folder is found or the checkpoint cannot
    be loaded.

    Raises:
        NotImplementedError: If ``cfg.data.task_group_size`` is not 1.
    """
    args = parse_args()
    set_seed(args.seed)

    # Auxiliary tensors must live on the same device as the model. A bare
    # "cuda" resolves to the current CUDA device, which differs from the
    # model's device whenever a non-default --device_id is given.
    device = args.device_id if torch.cuda.is_available() else "cpu"

    experiment_dir = os.path.join(
        args.experiment_dir,
        f"{benchmark_map[args.benchmark]}/"
        + f"{algo_map[args.algo]}/"
        + f"{policy_map[args.policy]}_seed{args.seed}",
    )

    # find the checkpoint: use the run folder with the highest numeric id
    experiment_id = 0
    for path in Path(experiment_dir).glob("run_*"):
        if not path.is_dir():
            continue
        try:
            folder_id = int(str(path).split("run_")[-1])
        except ValueError:
            # Folder name has a non-numeric suffix; not a run folder.
            continue
        experiment_id = max(experiment_id, folder_id)
    if experiment_id == 0:
        print(f"[error] cannot find the checkpoint under {experiment_dir}")
        sys.exit(1)

    run_folder = os.path.join(experiment_dir, f"run_{experiment_id:03d}")
    if args.algo == "multitask":
        model_path = os.path.join(run_folder, f"multitask_model_ep{args.ep}.pth")
    else:
        model_path = os.path.join(run_folder, f"task{args.load_task}_model.pth")
    try:
        sd, cfg, previous_mask = torch_load_model(
            model_path, map_location=args.device_id
        )
    except Exception as err:
        # Report the real cause; a missing file is only one way loading fails.
        print(f"[error] cannot load the checkpoint at {model_path}: {err}")
        sys.exit(1)

    cfg.folder = get_libero_path("datasets")
    cfg.bddl_folder = get_libero_path("bddl_files")
    cfg.init_states_folder = get_libero_path("init_states")

    cfg.device = args.device_id
    # NOTE(security): the algorithm class is resolved with eval(); the name
    # comes from the fixed algo_map table, not from free-form user input.
    algo = safe_device(eval(algo_map[args.algo])(10, cfg), cfg.device)
    algo.policy.previous_mask = previous_mask

    if cfg.lifelong.algo == "PackNet":
        algo.eval()
        for module_idx, module in enumerate(algo.policy.modules()):
            if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear):
                weight = module.weight.data
                mask = algo.previous_masks[module_idx].to(cfg.device)
                weight[mask.eq(0)] = 0.0
                weight[mask.gt(args.task_id + 1)] = 0.0
                # we never train norm layers
            if "BatchNorm" in str(type(module)) or "LayerNorm" in str(type(module)):
                module.eval()

    algo.policy.load_state_dict(sd)
    algo.eval()

    if not hasattr(cfg.data, "task_order_index"):
        cfg.data.task_order_index = 0

    # get the benchmark the task belongs to
    benchmark = get_benchmark(cfg.benchmark_name)(cfg.data.task_order_index)
    n_manip_tasks = benchmark.n_tasks

    # prepare datasets from the benchmark
    manip_datasets = []
    descriptions = []
    dm = DataModule(benchmark, cfg)

    for i in range(n_manip_tasks):
        task_i_dataset, _ = dm.generate_dataset_by_task(
            obs_modality=cfg.data.obs.modality,
            dataset_modality=cfg.data.dataset.modality,
            task_list=[i],
            initialize_obs_utils=(i == 0),
            seq_len=cfg.data.seq_len,
            para_task_description=True,
        )
        descriptions.append(benchmark.get_task(i).language)
        manip_datasets.append(task_i_dataset)

    task_embs_run = get_task_embs(cfg, descriptions)
    benchmark.set_task_embs(task_embs_run)

    gsz = cfg.data.task_group_size
    if gsz != 1:
        # Only "each manipulation task is its own lifelong learning task" is
        # implemented; other group sizes used to fail later with a NameError.
        raise NotImplementedError(
            f"task_group_size={gsz} is not supported; only 1 is implemented"
        )

    datasets = [dm.process_dataset(dataset, cfg) for dataset in manip_datasets]

    demos_per_dataset = cfg.lifelong.n_demos
    n_memory_tasks = args.load_task + 1

    # for each demonstration, there is one single description embedding.
    # Description embeddings are collected only for the datasets that end up
    # in the episodic memory, so they line up one-to-one with the image
    # embeddings computed below. Collecting them for every task made the two
    # distance vectors differ in length whenever load_task was not the last
    # task.
    task_emb_list = []
    for dataset_idx, dataset in enumerate(datasets):
        benchmark.set_augmented_task_embs(dataset.data[0]['demo_emb_list'])
        if dataset_idx >= n_memory_tasks:
            continue
        for i in range(demos_per_dataset):
            task_emb_list.append(dataset.data[i]['demo_emb_list'][-1].unsqueeze(0))
    task_embs = torch.cat(task_emb_list, dim=0)

    datasets_pre = datasets[:n_memory_tasks]
    episodic_memory = [dataset[:demos_per_dataset] for dataset in datasets_pre]

    # compute the embeddings for the initial images of the trajectories
    initial_image_tensors = []
    for trajectories in episodic_memory:
        print("len of trajectories: ", trajectories.n_demos)
        for traj_start_id in trajectories.demo_start_indices.values():
            initial_image_array = trajectories[traj_start_id]["obs"]["agentview_rgb"][0]
            initial_image_tensors.append(
                torch.tensor(initial_image_array).unsqueeze(0).to(device)
            )

    with torch.no_grad():
        demo_embeddings = []
        embed_batch_size = algo.cfg.train.batch_size
        for i in range(0, len(initial_image_tensors), embed_batch_size):
            batch_tensor = torch.cat(initial_image_tensors[i:i + embed_batch_size])
            # Images are scaled by 255 before being passed to r3m.
            demo_embeddings.append(algo.policy.r3m(batch_tensor * 255.0))

        demo_img_embeddings = torch.cat(demo_embeddings, dim=0)
        demo_desc_embs = task_embs
        print("\n\n\n\nshapes of demos and tasks: ", demo_img_embeddings.shape, task_embs.shape, demo_desc_embs.shape)

    task = benchmark.get_task(args.task_id)

    ### ======================= start evaluation ============================

    # 1. dataset loss is not evaluated here; 0.0 is stored as a placeholder so
    #    the saved stats keep their "loss" key.
    test_loss = 0.0

    # 2. evaluate success rate
    if args.algo == "multitask":
        save_folder = os.path.join(
            args.save_dir,
            f"{args.benchmark}_{args.algo}_{args.policy}_{args.seed}_ep{args.ep}_on{args.task_id}.stats",
        )
    else:
        save_folder = os.path.join(
            args.save_dir,
            f"{args.benchmark}_{args.algo}_{args.policy}_{args.seed}_load{args.load_task}_on{args.task_id}.stats",
        )

    video_folder = os.path.join(
        args.save_dir,
        f"{args.benchmark}_{args.algo}_{args.policy}_{args.seed}_load{args.load_task}_on{args.task_id}_videos",
    )

    total_rollouts = 10  # rollouts before adaptation
    rollout_init_img_embeddings = []
    failed_rollouts_embeddings = defaultdict(list)
    env_num = 10
    rollout_epi = total_rollouts // env_num
    with Timer() as t, VideoWriter(video_folder, args.save_videos) as video_writer:
        env_args = {
            "bddl_file_name": os.path.join(
                cfg.bddl_folder, task.problem_folder, task.bddl_file
            ),
            "camera_heights": cfg.data.img_h,
            "camera_widths": cfg.data.img_w,
        }

        env = SubprocVectorEnv(
            [lambda: OffScreenRenderEnv(**env_args) for _ in range(env_num)]
        )
        num_success = 0

        for epi in range(rollout_epi):
            env.reset()
            env.seed(cfg.seed)
            algo.reset()

            init_states_path = os.path.join(
                cfg.init_states_folder, task.problem_folder, task.init_states_file
            )
            init_states = torch.load(init_states_path)
            indices = np.arange(env_num) % init_states.shape[0]
            init_states_ = init_states[indices]

            dones = [False] * env_num
            steps = 0
            obs = env.set_init_state(init_states_)
            task_emb = benchmark.get_task_emb(args.task_id)

            for _ in range(5):  # simulate the physics without any actions
                obs, _, _, _ = env.step(np.zeros((env_num, 7)))

            with torch.no_grad():
                while steps < cfg.eval.max_steps:
                    steps += 1

                    data = raw_obs_to_tensor_obs(obs, task_emb, cfg)
                    actions = algo.policy.get_action(data, save_emb=True)
                    obs, _, done, _ = env.step(actions)
                    video_writer.append_vector_obs(
                        obs, dones, camera_name="agentview_image"
                    )

                    for i in range(env_num):
                        if not dones[i]:
                            # Create a subfolder for each rollout (failed or not)
                            subfolder = os.path.join(video_folder, f"rollout_{epi * env_num + i}")
                            os.makedirs(subfolder, exist_ok=True)

                            img = obs[i]["agentview_image"]
                            flipped_image = np.flipud(img)

                            if args.save_images:
                                image_path = os.path.join(subfolder, f"frame_{steps:04d}.png")
                                image = Image.fromarray(flipped_image)
                                image.save(image_path)

                    # check whether succeed
                    for k in range(env_num):
                        dones[k] = dones[k] or done[k]
                    if all(dones):
                        break

                for k in range(env_num):
                    rollout_init_img_embeddings.append(algo.policy.rollouts_embeddings[k][0])
                    rollout_index = epi * env_num + k
                    subfolder = os.path.join(video_folder, f"rollout_{rollout_index}")

                    if not dones[k]:
                        # Keep every step embedding of a failed rollout for
                        # the weighted adaptation.
                        failed_rollouts_embeddings[rollout_index].\
                            extend(torch.cat(algo.policy.rollouts_embeddings[k], dim=0))
                        print(f"Rollout {rollout_index} of {k} at episode {epi}  failed")
                    else:
                        # If rollout was successful, delete the corresponding subfolder
                        if os.path.exists(subfolder):
                            shutil.rmtree(subfolder)
                        num_success += 1

            algo.policy.rollouts_embeddings.clear()

        success_rate = num_success / total_rollouts
        env.close()

        eval_stats = {
            "loss": test_loss,
            "num_success": num_success,
            "success_rate": success_rate,
        }

        # os.makedirs instead of shelling out to `mkdir -p`: portable and safe
        # for paths containing spaces or shell metacharacters.
        os.makedirs(args.save_dir, exist_ok=True)
        torch.save(eval_stats, save_folder)
    print(
        f"[info] finish for ckpt at {run_folder} in {t.get_elapsed_time()} sec for rollouts"
    )
    print(f"Results are saved at {save_folder}")
    print(test_loss, num_success, success_rate)

    with torch.no_grad():
        rollout_img_embeddings = torch.cat(rollout_init_img_embeddings)
        # Visual distance: mean distance from each demo's first frame to the
        # first frames of all rollouts.
        dist_img_embeddings = torch.cdist(demo_img_embeddings, rollout_img_embeddings)
        dist_visual_avg = dist_img_embeddings.mean(dim=1).cpu()
        # Language distance: each demo description vs the evaluated task's embedding.
        dist_desc_embeddings = torch.cdist(demo_desc_embs, task_emb.unsqueeze(0)).squeeze()

        # Benchmark-specific mix of the visual and language distances.
        if cfg.benchmark_name == "libero_goal":
            print("libero_goal benchmark\n\n")
            dist_demo_rollout = 0.5 * dist_visual_avg + dist_desc_embeddings
        elif cfg.benchmark_name == "libero_90":
            print("libero_90 benchmark\n\n")
            dist_demo_rollout = dist_visual_avg + 0.1 * dist_desc_embeddings
        else:
            print("libero_spatial or libero_object benchmarks\n\n")
            dist_demo_rollout = dist_visual_avg + 0.5 * dist_desc_embeddings
        print("dist: ", "\n", dist_visual_avg, "\n", dist_desc_embeddings, "\n",
              dist_demo_rollout, "\n\n\n\n")

        # Keep the 10% of demonstrations with the smallest combined distance.
        num_elements_to_select = int(0.10 * len(dist_demo_rollout))
        _, indices = torch.topk(-dist_demo_rollout, num_elements_to_select)
        sorted_indices = indices.sort().values

        # below is a temporary solution of selecting the most similar demonstrations from the datasets
        demo_to_dataset = {i: [] for i in range(len(episodic_memory))}

        # Categorize each demo index into the correct dataset
        for idx in sorted_indices:
            dataset_index = torch.div(idx, demos_per_dataset, rounding_mode='floor').item()
            demo_index_within_dataset = (idx % demos_per_dataset).item()
            demo_to_dataset[dataset_index].append(demo_index_within_dataset)

        retrieved_episodic_memory = []
        for dataset_idx, demo_indices in demo_to_dataset.items():
            if demo_indices:  # Only proceed if there are demo indices to process
                retrieved_episodic_memory.append(episodic_memory[dataset_idx].slice_by_list(demo_indices))

    # Weighted local adaptation based on the rollout results and retrieved trajectories
    if args.weighted_adapt:
        # Weighting needs at least one failed rollout; skip when all succeeded.
        if success_rate < 1.00:
            print("Weighted adaptation")
            weighted_adapt(video_folder, algo, retrieved_episodic_memory, cfg, args,
                           failed_rollouts_embeddings, benchmark, epochs=args.epochs)
    else:
        print("Uniform adaptation")
        adapt(algo, retrieved_episodic_memory, cfg, epochs=args.epochs)

    algo.eval()

    total_rollouts = 20  # rollouts after adaptation
    rollout_epi = total_rollouts // env_num
    num_success = 0
    with Timer() as t, VideoWriter(video_folder, args.save_videos) as video_writer:
        env_args = {
            "bddl_file_name": os.path.join(
                cfg.bddl_folder, task.problem_folder, task.bddl_file
            ),
            "camera_heights": cfg.data.img_h,
            "camera_widths": cfg.data.img_w,
        }

        env = SubprocVectorEnv(
            [lambda: OffScreenRenderEnv(**env_args) for _ in range(env_num)]
        )

        for epi in range(rollout_epi):
            env.reset()
            env.seed(cfg.seed)
            algo.reset()

            init_states_path = os.path.join(
                cfg.init_states_folder, task.problem_folder, task.init_states_file
            )
            init_states = torch.load(init_states_path)
            indices = np.arange(env_num) % init_states.shape[0]
            init_states_ = init_states[indices]

            dones = [False] * env_num
            steps = 0
            obs = env.set_init_state(init_states_)
            task_emb = benchmark.get_task_emb(args.task_id)

            for _ in range(5):  # simulate the physics without any actions
                obs, _, _, _ = env.step(np.zeros((env_num, 7)))

            # Inference only: no autograd graph is needed, matching the
            # pre-adaptation rollout above.
            with torch.no_grad():
                while steps < cfg.eval.max_steps:
                    steps += 1

                    data = raw_obs_to_tensor_obs(obs, task_emb, cfg)
                    actions = algo.policy.get_action(data)
                    obs, _, done, _ = env.step(actions)
                    video_writer.append_vector_obs(
                        obs, dones, camera_name="agentview_image"
                    )

                    # check whether succeed
                    for k in range(env_num):
                        dones[k] = dones[k] or done[k]
                    if all(dones):
                        break

            for k in range(env_num):
                num_success += int(dones[k])

        success_rate = num_success / total_rollouts
        env.close()

        eval_stats = {
            "test_success_rate": success_rate,
        }

        test_save_folder = os.path.join(
            args.save_dir,
            f"{args.benchmark}_{args.algo}_{args.policy}_{args.seed}_test_after_adaptation.stats"
        )

        os.makedirs(args.save_dir, exist_ok=True)
        torch.save(eval_stats, test_save_folder)

    print(
        f"[info] Testing finished in {t.get_elapsed_time()} sec for {total_rollouts} rollouts"
    )
    print(f"Testing results are saved at {test_save_folder}")
    print(f"Test success rate: {success_rate}")


if __name__ == "__main__":
    # Force the "spawn" start method for multiprocessing unless it is already
    # the configured one, then run the evaluation.
    current_start_method = multiprocessing.get_start_method(allow_none=True)
    if current_start_method != "spawn":
        multiprocessing.set_start_method("spawn", force=True)
    main()
