import os
import sys
import h5py
import numpy as np
import torch
from pathlib import Path
import hydra
import yaml
import pprint
import torchvision.transforms as T
from PIL import Image
from easydict import EasyDict
from omegaconf import OmegaConf
from r3m import remove_language_head, cleanup_config

from libero.libero import get_libero_path
from libero.libero.benchmark import get_benchmark
from libero.lifelong.utils import control_seed

# Put the grandparent directory of this script on the import path
# (presumably the project root — confirm against the repo layout).
sys.path.append(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)

# Torch device for R3M inference: the GPU when one is visible, else the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"


def load_r3m():
    """Build the R3M ResNet-18 encoder and load its pretrained weights.

    The config and checkpoint paths are relative, so they are resolved against
    the process's current working directory at call time.

    Returns:
        The unwrapped R3M module (without its language head), placed on the
        module-level ``device`` so it matches the inputs ``r3m_embed`` sends.
    """
    modelpath = "../models/r3m/model_resnet18.pt"
    configpath = "../IL/configs/config.yaml"

    modelcfg = OmegaConf.load(configpath)
    cleancfg = cleanup_config(modelcfg)
    rep = hydra.utils.instantiate(cleancfg)
    # Wrap before loading. Presumably the checkpoint was saved from a
    # DataParallel model, so its keys carry the "module." prefix — confirm.
    rep = torch.nn.DataParallel(rep)
    checkpoint = torch.load(modelpath, map_location=torch.device(device))
    rep.load_state_dict(remove_language_head(checkpoint['r3m']))
    # load_state_dict copies values into the existing parameters without moving
    # them. map_location only places the checkpoint tensors. Move the module
    # explicitly, otherwise CUDA inputs hit CPU weights.
    return rep.module.to(device)


# Shared R3M encoder used by r3m_embed, switched to inference mode.
# Module.eval() returns the module itself, so the call can be chained.
r3m = load_r3m().eval()
# Per-frame preprocessing: resize (shorter side to 256 px), then convert to a
# float CHW tensor. No normalization is applied here.
transforms = T.Compose([
    T.Resize(256),
    T.ToTensor(),
])


def r3m_embed(image):
    """Encode one image frame into an R3M feature vector.

    Args:
        image: a numpy array accepted by ``PIL.Image.fromarray`` (presumably an
            HxWx3 uint8 RGB frame — TODO confirm against the dataset).

    Returns:
        The model output as a numpy array with all singleton dimensions
        squeezed out. The batch axis of size 1 is among the dimensions removed.
    """
    pil_frame = Image.fromarray(image)
    # Resize plus ToTensor (values in [0, 1] for uint8 input); no normalization
    # happens here. Then add a leading batch axis and move to the inference device.
    batch = transforms(pil_frame).unsqueeze(0).to(device)

    with torch.no_grad():
        # Scale back to [0, 255]; presumably the R3M forward expects that
        # range and normalizes internally — confirm against the r3m package.
        features = r3m(batch * 255.0)

    # The model returns a tensor (not a dict); drop singleton dims and hand back numpy.
    return features.squeeze().cpu().numpy()


def copy_hdf5_structure(src_file, dst_file):
    """Copy every group, dataset and attribute of ``src_file`` into ``dst_file``.

    ``dst_file`` is opened with mode 'w', so an existing file is truncated.
    Each object below the root is recreated by ``copy_item``.

    Args:
        src_file: path of the HDF5 file to read.
        dst_file: path of the HDF5 file to create.
    """
    with h5py.File(src_file, 'r') as src, h5py.File(dst_file, 'w') as dst:
        # visititems() never visits the root group itself, so attributes set on
        # the file root would otherwise be dropped; copy them explicitly.
        for key, value in src.attrs.items():
            dst.attrs[key] = value
        src.visititems(lambda name, obj: copy_item(name, obj, dst))


def copy_item(name, obj, dst):
    """Recreate one visited HDF5 object, plus its attributes, at ``name`` in ``dst``.

    Used as the ``visititems`` callback body. It must return None, because a
    non-None return value stops the traversal early.

    Args:
        name: path of ``obj`` relative to the group being visited.
        obj: the source object. Only ``h5py.Group`` and ``h5py.Dataset`` are
            recreated; other kinds are not.
        dst: destination file or group the copy is written under.
    """
    if isinstance(obj, h5py.Group):
        dst.create_group(name)
    elif isinstance(obj, h5py.Dataset):
        # Read the full contents into memory and write them back by value.
        dst.create_dataset(name, data=obj[()])

    source_attrs = dict(obj.attrs)
    if source_attrs:
        target_attrs = dst[name].attrs
        for attr_name, attr_value in source_attrs.items():
            target_attrs[attr_name] = attr_value


def modify_embeddings(file_path, r3m_embed, image_keys=("agentview_rgb", "eye_in_hand_rgb")):
    """Replace image observation datasets with per-frame embeddings, in place.

    For every ``<group>/<demo>/obs`` group in the file, each dataset named in
    ``image_keys`` is read fully, and ``r3m_embed`` is applied to every entry
    along its first axis. The dataset is then replaced by a float32 dataset of
    the same name holding the stacked embeddings. Attributes of the replaced
    datasets are not carried over.

    Args:
        file_path: HDF5 file to modify (opened in append mode).
        r3m_embed: callable mapping one image to a 1-D embedding array.
        image_keys: observation dataset names to replace. Names missing from a
            demo are skipped.
    """
    with h5py.File(file_path, 'a') as f:
        for group_name in f.keys():
            grp = f[group_name]
            if not isinstance(grp, h5py.Group):
                continue
            for demo_name in grp.keys():
                demo = grp[demo_name]
                # Skip children that are not demos with an 'obs' group
                # (e.g. bookkeeping groups or datasets).
                if not isinstance(demo, h5py.Group):
                    continue
                obs = demo.get('obs')
                if not isinstance(obs, h5py.Group):
                    continue
                # Loop over a fixed key list, not obs.keys(). The body deletes
                # and recreates members of obs, and mutating an HDF5 group while
                # iterating over it is unsafe.
                for obs_key in image_keys:
                    if obs_key not in obs:
                        continue
                    images = obs[obs_key][:]
                    embeddings = np.array([r3m_embed(image) for image in images], dtype=np.float32)
                    del obs[obs_key]  # Unlinks the image data; file space is reclaimed only by a later repack.
                    obs.create_dataset(obs_key, data=embeddings)


def repack_hdf5(src_file, dst_file):
    """Rewrite ``src_file`` into a fresh ``dst_file`` to drop unreclaimed space.

    HDF5 does not shrink a file when datasets are deleted. Copying every live
    object into a new file leaves the dead space behind. ``dst_file`` is
    opened with mode 'w' (truncated if present). Each object below the root is
    recreated by ``copy_item``.

    Args:
        src_file: path of the HDF5 file to read.
        dst_file: path of the compacted HDF5 file to create.
    """
    with h5py.File(src_file, 'r') as src, h5py.File(dst_file, 'w') as dst:
        # visititems() skips the root group, so copy root-level attributes explicitly.
        for key, value in src.attrs.items():
            dst.attrs[key] = value
        src.visititems(lambda name, obj: copy_item(name, obj, dst))


@hydra.main(config_path="../configs", config_name="config", version_base=None)
def main(hydra_cfg):
    """Write an R3M-embedding copy of every demonstration file in the benchmark.

    For each task of the configured LIBERO benchmark, the source demo file is
    copied to an intermediate file, and its image observations are replaced by
    embeddings there. The result is then repacked into the task's embedding
    path. The intermediate file is always removed, including when a step fails.

    Args:
        hydra_cfg: configuration composed by Hydra from ``../configs/config``.
    """
    # Round-trip through YAML to get a plain dict with attribute access.
    yaml_config = OmegaConf.to_yaml(hydra_cfg)
    cfg = EasyDict(yaml.safe_load(yaml_config))

    # print configs to terminal
    pp = pprint.PrettyPrinter(indent=2)
    pp.pprint(cfg)

    # control seed
    control_seed(cfg.seed)

    # Fall back to LIBERO's default locations for any path left unset.
    cfg.folder = cfg.folder or get_libero_path("datasets")
    cfg.bddl_folder = cfg.bddl_folder or get_libero_path("bddl_files")
    cfg.init_states_folder = cfg.init_states_folder or get_libero_path("init_states")

    benchmark = get_benchmark(cfg.benchmark_name)(cfg.data.task_order_index)
    n_manip_tasks = benchmark.n_tasks

    for i in range(n_manip_tasks):
        task_demo_path = os.path.join(cfg.folder, benchmark.get_task_demonstration(i))
        intermediate_task_demo_path = os.path.join(cfg.folder,
                                                   benchmark.get_task_demonstration_emb(i) + "_intermediate")
        final_task_demo_path = os.path.join(cfg.folder, benchmark.get_task_demonstration_emb(i))
        os.makedirs(os.path.dirname(final_task_demo_path), exist_ok=True)

        print(f"Processing dataset {i} at {task_demo_path}")  # Print dataset path

        try:
            # Work on a copy so the source dataset is never modified.
            copy_hdf5_structure(task_demo_path, intermediate_task_demo_path)

            # Replace image observations with embeddings in the copy.
            modify_embeddings(intermediate_task_demo_path, r3m_embed)

            # Rewrite into the final file so the space left by the deleted
            # image datasets is not carried over.
            repack_hdf5(intermediate_task_demo_path, final_task_demo_path)
        finally:
            # Do not leak a (possibly half-written) intermediate file when a step raises.
            if os.path.exists(intermediate_task_demo_path):
                os.remove(intermediate_task_demo_path)

        print(f"Saved new dataset with embeddings at {final_task_demo_path}")

# Script entry point. The @hydra.main decorator composes the config and passes it to main().
if __name__ == "__main__":
    main()
