import argparse
import sys
import os
import omegaconf
import hydra

# TODO: find a better way for this?
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import numpy as np
import torch.nn as nn
import torch
import multiprocessing
import time
import matplotlib.pyplot as plt
from PIL import Image
from pathlib import Path

from libero.libero import get_libero_path
from libero.libero.benchmark import get_benchmark
from libero.libero.envs import OffScreenRenderEnv, SubprocVectorEnv
from libero.libero.utils.time_utils import Timer
from libero.libero.utils.video_utils import VideoWriter
from libero.lifelong.algos import *
from libero.lifelong.datasets import get_dataset, SequenceVLDataset, GroupedTaskDataset, TruncatedSequenceDataset
from libero.lifelong.metric import (
    evaluate_loss,
    evaluate_success,
    raw_obs_to_tensor_obs,
)
from libero.lifelong.utils import (
    control_seed,
    safe_device,
    torch_load_model,
    NpEncoder,
    compute_flops,
)
from r3m import remove_language_head, cleanup_config
from libero.lifelong.main import get_task_embs

import robomimic.utils.obs_utils as ObsUtils
import robomimic.utils.tensor_utils as TensorUtils

import time


# CLI --benchmark value -> benchmark directory name (currently the identity).
benchmark_map = {
    suite: suite
    for suite in ("libero_10", "libero_spatial", "libero_object", "libero_goal")
}

# CLI --algo value -> name of the lifelong-learning algorithm class that
# main() instantiates.
algo_map = dict(
    base="Sequential",
    er="ER",
    ewc="EWC",
    packnet="PackNet",
    multitask="Multitask",
)

# CLI --policy value -> policy class name, used to locate the experiment dir.
policy_map = dict(
    bc_rnn_policy="BCRNNPolicy",
    bc_transformer_policy="BCTransformerPolicy",
    bc_vilt_policy="BCViLTPolicy",
    bc_transformer_policy_r3m="BCTransformerPolicyR3M",
)

def set_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Args:
        seed: integer seed applied to every random number generator.
    """
    import random

    # Python's own RNG was previously left unseeded, so any library relying on
    # `random` made runs non-reproducible.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def parse_args():
    """Parse and validate the evaluation script's command-line arguments.

    Derived fields added to the returned namespace:
      * ``device_id`` is rewritten to a ``"cuda:<id>"`` string (defaults to
        ``"cuda:0"`` when ``--device_id`` is omitted).
      * ``save_dir`` is ``"<experiment_dir>_saved"``.

    Validation (reported through ``parser.error``, which exits with status 2,
    rather than ``assert``, which is stripped under ``python -O``):
      * ``--algo multitask`` requires ``--ep`` in ``{0, 5, ..., 50}``.
      * every other algo requires ``--load_task`` in ``{0, ..., 9}``.

    Returns:
        argparse.Namespace with the parsed and derived arguments.
    """
    parser = argparse.ArgumentParser(description="Evaluation Script")
    parser.add_argument("--experiment_dir", type=str, default="experiments")
    # for which task suite
    parser.add_argument(
        "--benchmark",
        type=str,
        required=True,
        choices=["libero_10", "libero_spatial", "libero_object", "libero_goal"],
    )
    parser.add_argument("--task_id", type=int, required=True)
    # method detail
    parser.add_argument(
        "--algo",
        type=str,
        required=True,
        choices=["base", "er", "ewc", "packnet", "multitask"],
    )
    parser.add_argument(
        "--policy",
        type=str,
        required=True,
        choices=["bc_rnn_policy", "bc_transformer_policy", "bc_vilt_policy", "bc_transformer_policy_r3m"],
    )
    parser.add_argument("--seed", type=int, required=True)
    parser.add_argument("--ep", type=int)
    parser.add_argument("--load_task", type=int)
    # Without a default, an omitted flag previously produced "cuda:None".
    parser.add_argument("--device_id", type=int, default=0)
    parser.add_argument("--save-videos", action="store_true")
    parser.add_argument("--save-images", action="store_true", help="Flag to save images")
    args = parser.parse_args()

    if args.algo == "multitask":
        # range(0, 51, 5) so the accepted values match the message (0..50).
        if args.ep not in range(0, 51, 5):
            parser.error("[error] ep should be in [0, 5, ..., 50]")
    elif args.load_task not in range(10):
        parser.error("[error] load_task should be in [0, ..., 9]")

    args.device_id = f"cuda:{args.device_id}"
    args.save_dir = f"{args.experiment_dir}_saved"
    return args

def visualize_image(image_tensor, image_type="demo", image_id=0, show=True):
    """Save an image tensor to ``{image_type}_{image_id}.png`` and optionally show it.

    Args:
        image_tensor: image tensor, optionally with a leading batch dim of 1;
            it is squeezed and permuted from channel-first to channel-last
            before plotting (assumes (C, H, W) layout — confirm at call sites).
        image_type: filename prefix distinguishing the image source.
        image_id: numeric filename suffix.
        show: when True (the default, as before) call ``plt.show()``.

    Returns:
        The filename the image was saved to.
    """
    image_array = image_tensor.squeeze(0).cpu().permute(1, 2, 0).numpy()
    # Use a dedicated figure and always close it, so repeated calls neither
    # stack images on one shared axes nor leak open figures.
    fig, ax = plt.subplots()
    try:
        # Vertical flip, matching how main() flips rollout frames before saving.
        ax.imshow(np.flipud(image_array))
        ax.axis("off")
        image_name = f"{image_type}_{image_id}.png"
        fig.savefig(image_name)
        if show:
            plt.show()
    finally:
        plt.close(fig)
    return image_name


def load_r3m(
    device,
    modelpath="../models/r3m/model_resnet18.pt",
    configpath="../IL/configs/config.yaml",
):
    """Build an R3M encoder from its config file and load pretrained weights.

    Args:
        device: device used as ``map_location`` when reading the checkpoint.
        modelpath: R3M checkpoint path; its ``"r3m"`` entry holds the weights.
        configpath: OmegaConf config describing the model to instantiate.
            Both path defaults are relative to the current working directory.

    Returns:
        The unwrapped encoder module (``DataParallel(...).module``).
    """
    modelcfg = omegaconf.OmegaConf.load(configpath)
    cleancfg = cleanup_config(modelcfg)
    rep = hydra.utils.instantiate(cleancfg)
    # Load through a DataParallel wrapper (presumably because the checkpoint
    # keys carry its "module." prefix), then hand back the inner module.
    rep = torch.nn.DataParallel(rep)
    checkpoint = torch.load(modelpath, map_location=torch.device(device))
    rep.load_state_dict(remove_language_head(checkpoint["r3m"]))
    return rep.module


from torch.utils.data import DataLoader, RandomSampler


def local_adapt_network(algo, retrieved_episodic_memory, cfg, epochs=8):
    """
    Perform local adaptation using the retrieved episodic memory.

    Fine-tunes ``algo.policy`` in place for ``epochs`` passes over the union of
    every dataset in ``retrieved_episodic_memory`` (previously only the first
    one was used), with a fresh optimizer built from ``cfg.train.optimizer``.
    The policy is left in train mode; the caller must switch it back
    (main() does so via ``algo.eval()``).

    Args:
        algo: algorithm wrapper exposing ``policy`` (with ``compute_loss``),
            ``cfg.train.batch_size``, ``cfg.train.grad_clip``, ``loss_scale``
            and ``map_tensor_to_device``.
        retrieved_episodic_memory: list of torch ``Dataset`` objects. If it is
            empty, adaptation is skipped and the policy is left untouched.
        cfg: config providing ``train.optimizer.name`` and ``.kwargs``.
        epochs: number of passes over the retrieved data.
    """
    from torch.utils.data import ConcatDataset

    if not retrieved_episodic_memory:
        print("[warning] no retrieved episodic memory; skipping local adaptation")
        return

    algo.policy.train()
    # NOTE(review): eval() on a config string — acceptable only because the
    # config comes from a trusted checkpoint (e.g. "torch.optim.AdamW").
    optimizer = eval(cfg.train.optimizer.name)(
        algo.policy.parameters(), **cfg.train.optimizer.kwargs
    )
    train_dataloader = DataLoader(
        ConcatDataset(retrieved_episodic_memory),
        batch_size=algo.cfg.train.batch_size,
        num_workers=0,
        pin_memory=True,
        shuffle=True,
    )

    for _ in range(epochs):
        for data in train_dataloader:
            data = algo.map_tensor_to_device(data)
            optimizer.zero_grad()
            loss = algo.policy.compute_loss(data)
            (algo.loss_scale * loss).backward()
            if algo.cfg.train.grad_clip is not None:
                nn.utils.clip_grad_norm_(
                    algo.policy.parameters(), algo.cfg.train.grad_clip
                )
            optimizer.step()

    print(f"Local adaptation completed over {epochs} epochs")


def _find_latest_run_id(experiment_dir):
    """Return the largest N among ``run_N`` sub-directories of *experiment_dir*, or 0 if none.

    Note a ``run_000`` folder is therefore indistinguishable from "no run".
    """
    latest = 0
    for path in Path(experiment_dir).glob("run_*"):
        if not path.is_dir():
            continue
        try:
            folder_id = int(path.name.split("run_")[-1])
        except ValueError:  # e.g. "run_old": not a numbered run folder
            continue
        latest = max(latest, folder_id)
    return latest


def _load_manip_datasets(cfg, benchmark):
    """Load every task's demonstration dataset and language description.

    Returns ``(manip_datasets, descriptions)``, index-aligned with the
    benchmark's task ids. A loading error is reported and re-raised instead of
    silently reusing the previous task's dataset.
    """
    manip_datasets = []
    descriptions = []
    for i in range(benchmark.n_tasks):
        dataset_path = os.path.join(
            cfg.folder, benchmark.get_task_demonstration_new_rendering(i)
        )
        # currently we assume tasks from same benchmark have the same shape_meta
        try:
            task_i_dataset, _ = get_dataset(
                dataset_path=dataset_path,
                obs_modality=cfg.data.obs.modality,
                initialize_obs_utils=(i == 0),
                seq_len=cfg.data.seq_len,
            )
        except Exception as e:
            print(
                f"[error] failed to load task {i} name {benchmark.get_task_names()[i]}"
            )
            print(f"[error] {e}")
            raise
        print(dataset_path)
        descriptions.append(benchmark.get_task(i).language)
        manip_datasets.append(task_i_dataset)
    return manip_datasets, descriptions


def _embed_memory_initial_images(r3m, episodic_memory, device, batch_size):
    """Embed the first ``agentview_rgb`` frame of every stored demo with R3M.

    Returns ``(embeddings, memory_indices)``: row ``j`` of ``embeddings``
    belongs to the demo ``memory_indices[j] == (dataset_idx, demo_position)``.
    """
    initial_images = []
    memory_indices = []
    for dataset_idx, trajectories in enumerate(episodic_memory):
        print("len of trajectories: ", trajectories.n_demos)
        for demo_pos, traj_start_id in enumerate(trajectories.demo_start_indices.values()):
            first_frame = trajectories[traj_start_id]["obs"]["agentview_rgb"][0]
            initial_images.append(torch.tensor(first_frame).unsqueeze(0).to(device))
            memory_indices.append((dataset_idx, demo_pos))

    with torch.no_grad():
        # Frames are scaled by 255 before R3M (assumes they are stored in [0, 1]).
        chunks = [
            r3m(torch.cat(initial_images[start:start + batch_size]) * 255.0)
            for start in range(0, len(initial_images), batch_size)
        ]
    return torch.cat(chunks, dim=0), memory_indices


def _retrieve_episodic_memory(
    r3m,
    obs,
    device,
    demo_img_embeddings,
    demo_desc_embs,
    task_emb,
    memory_indices,
    episodic_memory,
    select_ratio=0.10,
    desc_weight=0.5,
):
    """Return the memory datasets whose demos best match the current rollout.

    Each stored demo is scored by its mean R3M distance to the rollout's
    ``agentview_image`` frames plus ``desc_weight`` times the distance between
    its task's language embedding and ``task_emb``. The lowest-scoring
    ``select_ratio`` fraction of demos (at least one) is kept, and each memory
    dataset owning a kept demo is returned once, in ascending dataset order.
    """
    with torch.no_grad():
        rollout_img_embeddings = torch.cat([
            r3m(
                ObsUtils.process_obs(torch.from_numpy(env_obs["agentview_image"]), "rgb")
                .unsqueeze(0)
                .to(device)
                * 255
            )
            for env_obs in obs
        ])
        dist_visual_avg = torch.cdist(demo_img_embeddings, rollout_img_embeddings).mean(dim=1).cpu()
        # squeeze(1), not squeeze(): keep a 1-D result even with a single demo.
        dist_desc_embeddings = torch.cdist(demo_desc_embs, task_emb.unsqueeze(0)).squeeze(1).cpu()
        dist_demo_rollout = dist_visual_avg + desc_weight * dist_desc_embeddings
        print("dist: ", "\n", dist_visual_avg, "\n", dist_desc_embeddings, "\n",
              dist_demo_rollout, "\n\n\n\n")

        # At least one demo, otherwise small memories would retrieve nothing.
        num_to_select = max(1, int(select_ratio * len(dist_demo_rollout)))
        _, top_idx = torch.topk(-dist_demo_rollout, num_to_select)

    selected = sorted({memory_indices[j][0] for j in top_idx.tolist()})
    return [episodic_memory[d] for d in selected]


def main():
    """Evaluate a checkpoint with R3M-based memory retrieval and local adaptation.

    Steps:
      1. Load the newest ``run_*`` checkpoint selected by the CLI arguments.
      2. Build an episodic memory (each task dataset truncated to
         ``cfg.lifelong.n_demos`` demos) for every task the checkpoint has
         seen: all tasks for ``multitask``, tasks ``0..load_task`` otherwise.
      3. Before the first episode, retrieve the best-matching memory datasets
         and fine-tune the policy on them with ``local_adapt_network``.
      4. Run 15 rollouts in 5 parallel envs and ``torch.save`` the stats.

    Exits with status 1 when the checkpoint cannot be found or loaded.
    """
    args = parse_args()
    set_seed(args.seed)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    r3m = load_r3m(device).eval()
    batch_size = 32

    experiment_dir = os.path.join(
        args.experiment_dir,
        f"{benchmark_map[args.benchmark]}/{algo_map[args.algo]}/"
        f"{policy_map[args.policy]}_seed{args.seed}",
    )

    # find the checkpoint
    experiment_id = _find_latest_run_id(experiment_dir)
    if experiment_id == 0:
        print(f"[error] cannot find the checkpoint under {experiment_dir}")
        sys.exit(1)

    run_folder = os.path.join(experiment_dir, f"run_{experiment_id:03d}")
    if args.algo == "multitask":
        model_path = os.path.join(run_folder, f"multitask_model_ep{args.ep}.pth")
    else:
        model_path = os.path.join(run_folder, f"task{args.load_task}_model.pth")
    try:
        sd, cfg, previous_mask = torch_load_model(model_path, map_location=args.device_id)
    except Exception as e:  # missing or unreadable checkpoint
        print(f"[error] cannot find the checkpoint at {str(model_path)}")
        print(f"[error] {e}")
        sys.exit(1)

    cfg.folder = get_libero_path("datasets")
    cfg.bddl_folder = get_libero_path("bddl_files")
    cfg.init_states_folder = get_libero_path("init_states")

    cfg.device = args.device_id
    # The class name comes from algo_map (keys restricted by argparse choices);
    # the classes themselves come from `from libero.lifelong.algos import *`.
    algo = safe_device(eval(algo_map[args.algo])(10, cfg), cfg.device)
    algo.policy.previous_mask = previous_mask

    if cfg.lifelong.algo == "PackNet":
        algo.eval()
        for module_idx, module in enumerate(algo.policy.modules()):
            if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear):
                weight = module.weight.data
                mask = algo.previous_masks[module_idx].to(cfg.device)
                weight[mask.eq(0)] = 0.0
                weight[mask.gt(args.task_id + 1)] = 0.0
                # we never train norm layers
            if "BatchNorm" in str(type(module)) or "LayerNorm" in str(type(module)):
                module.eval()

    algo.policy.load_state_dict(sd)
    algo.eval()

    if not hasattr(cfg.data, "task_order_index"):
        cfg.data.task_order_index = 0

    # get the benchmark the task belongs to
    benchmark = get_benchmark(cfg.benchmark_name)(cfg.data.task_order_index)
    manip_datasets, descriptions = _load_manip_datasets(cfg, benchmark)
    task_embs = get_task_embs(cfg, descriptions)
    benchmark.set_task_embs(task_embs)

    if cfg.data.task_group_size != 1:
        raise NotImplementedError(
            "[error] only cfg.data.task_group_size == 1 is supported by this script"
        )

    # add language to the vision dataset, hence we call vl_dataset
    datasets = [
        SequenceVLDataset(ds, emb) for (ds, emb) in zip(manip_datasets, task_embs)
    ]
    # A multitask checkpoint has seen every task; a lifelong one only 0..load_task.
    n_memory_tasks = len(datasets) if args.algo == "multitask" else args.load_task + 1
    demos_per_dataset = cfg.lifelong.n_demos
    episodic_memory = [
        TruncatedSequenceDataset(dataset, demos_per_dataset)
        for dataset in datasets[:n_memory_tasks]
    ]

    demo_img_embeddings, memory_indices = _embed_memory_initial_images(
        r3m, episodic_memory, device, batch_size
    )
    # One language embedding per stored demo, aligned row-for-row with the
    # image embeddings (no assumption about how many demos each task kept).
    demo_desc_embs = torch.stack([task_embs[dataset_idx] for dataset_idx, _ in memory_indices])
    print("\n\n\n\nshapes of demos and tasks: ", demo_img_embeddings.shape,
          task_embs.shape, demo_desc_embs.shape)

    task = benchmark.get_task(args.task_id)

    ### ======================= start evaluation ============================
    # Dataset-loss evaluation is not performed; the loss is reported as 0.0.
    test_loss = 0.0

    run_tag = f"ep{args.ep}" if args.algo == "multitask" else f"load{args.load_task}"
    stem = f"{args.benchmark}_{args.algo}_{args.policy}_{args.seed}_{run_tag}_on{args.task_id}"
    save_folder = os.path.join(args.save_dir, f"{stem}.stats")
    video_folder = os.path.join(args.save_dir, f"{stem}_videos")
    if args.save_images:
        os.makedirs(video_folder, exist_ok=True)

    total_rollouts = 15
    env_num = 5
    rollout_epi = total_rollouts // env_num

    init_states_path = os.path.join(
        cfg.init_states_folder, task.problem_folder, task.init_states_file
    )
    init_states = torch.load(init_states_path)
    task_emb = benchmark.get_task_emb(args.task_id)

    num_success = 0
    with Timer() as t, VideoWriter(video_folder, args.save_videos) as video_writer:
        env_args = {
            "bddl_file_name": os.path.join(
                cfg.bddl_folder, task.problem_folder, task.bddl_file
            ),
            "camera_heights": cfg.data.img_h,
            "camera_widths": cfg.data.img_w,
        }

        env = SubprocVectorEnv(
            [lambda: OffScreenRenderEnv(**env_args) for _ in range(env_num)]
        )
        try:
            for epi in range(rollout_epi):
                env.reset()
                env.seed(cfg.seed)
                algo.reset()

                # Offset by episode so the rollouts use distinct init states
                # instead of repeating the first env_num ones every episode.
                state_idx = np.arange(epi * env_num, (epi + 1) * env_num) % init_states.shape[0]
                dones = [False] * env_num
                steps = 0
                obs = env.set_init_state(init_states[state_idx])

                for _ in range(5):  # simulate the physics without any actions
                    obs, _, _, _ = env.step(np.zeros((env_num, 7)))

                if epi == 0:
                    retrieved_episodic_memory = _retrieve_episodic_memory(
                        r3m, obs, device, demo_img_embeddings, demo_desc_embs,
                        task_emb, memory_indices, episodic_memory,
                    )
                    # Needs gradients, so it runs outside any no_grad context.
                    local_adapt_network(algo, retrieved_episodic_memory, cfg)

                algo.eval()
                with torch.no_grad():
                    while steps < cfg.eval.max_steps:
                        steps += 1

                        data = raw_obs_to_tensor_obs(obs, task_emb, cfg)
                        actions = algo.policy.get_action(data)
                        obs, _, done, _ = env.step(actions)
                        video_writer.append_vector_obs(
                            obs, dones, camera_name="agentview_image"
                        )

                        # Save env 0's frames during the first episode until it succeeds.
                        if args.save_images and epi == 0 and not dones[0]:
                            frame = Image.fromarray(np.flipud(obs[0]["agentview_image"]))
                            frame.save(os.path.join(video_folder, f"rol_frame_{steps:04d}.png"))

                        # check whether succeed
                        for k in range(env_num):
                            dones[k] = dones[k] or done[k]
                        if all(dones):
                            break

                num_success += sum(int(d) for d in dones)
        finally:
            env.close()

        success_rate = num_success / total_rollouts

        eval_stats = {
            "loss": test_loss,
            "num_success": num_success,
            "success_rate": success_rate,
        }

        os.makedirs(args.save_dir, exist_ok=True)
        torch.save(eval_stats, save_folder)
    print(
        f"[info] finish for ckpt at {run_folder} in {t.get_elapsed_time()} sec for rollouts"
    )
    print(f"Results are saved at {save_folder}")
    print(test_loss, num_success, success_rate)


if __name__ == "__main__":
    # Switch to the "spawn" start method (unless it is already active)
    # before main() runs.
    start_method = multiprocessing.get_start_method(allow_none=True)
    if start_method != "spawn":
        multiprocessing.set_start_method("spawn", force=True)
    main()
