#-*- coding:utf-8 -*-

from dataset.robomimic_lowdim_dataset import RobomimicReplayLowdimDataset
from dataset.pusht_dataset import PushTStateDataset
from dataset.pusht_dataset import normalize_data
from vecdb import RobotFAISS
from dataset.tasks import *
from tqdm import tqdm
import argparse
import faiss 
import gdown
import os

# Control type handed to every task constructor in this script.
CONTROL_TYPE = ControlType.STATE

# Command-line options as (flag, default) pairs; every option is parsed as a string.
_CLI_OPTIONS = (
    ('--task_type', "TOOLHANG"),
    ('--task_tag', "PH"),
    ('--index-name', "toolhang.index"),
    ('--checkpoint', "./weights/t-push-diffusion-epoch50.pt"),
)

parser = argparse.ArgumentParser()
for _flag, _default in _CLI_OPTIONS:
    parser.add_argument(_flag, type=str, default=_default)
# argparse exposes '--index-name' as ``opt.index_name``.
opt = parser.parse_args()

# Accepted --task_tag values mapped to the name of the TaskTags member they select.
_TAG_MEMBER_NAMES = {"": "NONE", "PH": "PH", "MH": "MH"}

_tag_member = _TAG_MEMBER_NAMES.get(opt.task_tag)
if _tag_member is None:
    raise NotImplementedError(f"Task Tag {opt.task_tag} Not implemented")
# Resolve the member lazily so only the requested tag is looked up.
task_tag = getattr(TaskTags, _tag_member)

# One factory per supported --task_type value. Each returns the pair
# (TaskTypes member, task instance); the lambdas keep construction lazy so
# only the selected task is ever built. PushT and ToolHang take no tag.
_TASK_FACTORIES = {
    'PUSHT': lambda: (TaskTypes.PUSHT, PushT(ctype=CONTROL_TYPE)),
    'LIFT': lambda: (TaskTypes.LIFT, Lift(ctype=CONTROL_TYPE, tag=task_tag)),
    'CAN': lambda: (TaskTypes.CAN, Can(ctype=CONTROL_TYPE, tag=task_tag)),
    'SQUARE': lambda: (TaskTypes.SQUARE, Square(ctype=CONTROL_TYPE, tag=task_tag)),
    'TRANSPORT': lambda: (TaskTypes.TRANSPORT, Transport(ctype=CONTROL_TYPE, tag=task_tag)),
    'TOOLHANG': lambda: (TaskTypes.TOOLHANG, ToolHang(ctype=CONTROL_TYPE)),
}

if opt.task_type not in _TASK_FACTORIES:
    raise NotImplementedError(f"Task {opt.task_type} Not implemented")
task_type, task = _TASK_FACTORIES[opt.task_type]()

device = 'cuda'


def main():
    """Build the retrieval index for the task selected on the command line.

    Loads the demonstration dataset belonging to the module-level ``task``,
    normalizes every sample, and passes two parallel lists to
    ``RobotFAISS.initialize_db``: the flattened first ``obs_horizon``
    observation steps of each sample (``input_vectors``) and the flattened
    normalized action sequence of the same sample (``result_vectors``).

    Reads the module-level ``opt``, ``task`` and ``task_type``. For the PushT
    task the dataset archive is downloaded with ``gdown`` when it is missing.
    """
    pred_horizon = task.pred_horizon
    obs_horizon = task.obs_horizon
    action_horizon = task.action_horizon

    # Each indexed vector is a flattened window of `obs_horizon` observations.
    vector_dim = task.obs_dim * obs_horizon
    vecdb = RobotFAISS(index_name=opt.index_name, vector_dimensions=vector_dim)

    # State-conditioned indexing: no pretrained encoder is loaded.

    # The task type never changes inside main, so evaluate the check once.
    is_pusht = task_type == TaskTypes.PUSHT

    if is_pusht:
        # Check/download the very path that is loaded below, so the download
        # target and the load path can never disagree.
        dataset_path = task.dataset_path
        if not os.path.exists(dataset_path):
            # Make sure the target directory exists before downloading into it.
            parent_dir = os.path.dirname(dataset_path)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)
            # Download demonstration data from Google Drive.
            file_id = "1KY1InLurpMvJDRb14L9NlXT_fEsCvVUq&confirm=t"
            gdown.download(id=file_id, output=dataset_path, quiet=False)
        dataset = PushTStateDataset(
            dataset_path=dataset_path,
            pred_horizon=pred_horizon,
            obs_horizon=obs_horizon,
            action_horizon=action_horizon
        )
    else:
        # NOTE(review): pad_before=1 / pad_after=7 look like obs_horizon - 1
        # and action_horizon - 1 for the default horizons - confirm before
        # changing any task's horizons.
        dataset = RobomimicReplayLowdimDataset(
            dataset_path=task.dataset_path,
            horizon=pred_horizon,
            obs_keys=task.obs_keys,
            abs_action=True,
            pad_before=1,
            pad_after=7,
        )

    print("Dataset Length:", len(dataset))

    obs_vectors, action_vectors = [], []
    for sample in tqdm(dataset):
        obs = sample['obs']
        action = sample['action']

        if is_pusht:
            # PushT samples are normalized with the dataset's per-key stats.
            nobs = normalize_data(obs[:obs_horizon, :], stats=dataset.stats['obs'])
            nobs_cond = nobs.reshape(-1)
            naction = normalize_data(action, stats=dataset.stats['action']).reshape(-1)
        else:
            # The robomimic normalizer output is converted to numpy via
            # .detach().numpy() before being stored.
            nobs = dataset.normalizer['obs'].normalize(obs)
            nobs_cond = nobs[:obs_horizon, :].reshape(-1).detach().numpy()
            naction = dataset.normalizer['action'].normalize(action).reshape(-1).detach().numpy()

        obs_vectors.append(nobs_cond)
        action_vectors.append(naction)

    vecdb.initialize_db(input_vectors=obs_vectors, result_vectors=action_vectors)

    # d = dataset[0]
    # o = d['obs']
    # no = dataset.normalizer['obs'].normalize(o).reshape(-1).detach().numpy()
    # na = vecdb.search(no, k=5)
    # print(len(na))

# Call main() only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()


