import os
import pickle
import shlex
import subprocess
import sys
from typing import Tuple, List

from tqdm import tqdm

# Make the repository root (the parent of this file's directory) importable so
# the project-local `util` package imported below resolves when this file is
# run directly. abspath() keeps the entry valid even if __file__ is relative
# and the working directory later changes; `sys` is used directly instead of
# the undocumented `os.sys` attribute.
BASEDIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if BASEDIR not in sys.path:
    sys.path.insert(0, BASEDIR)

from util.dotenv_util import SSH_IP, SSH_PASSWORD, SSH_PORT, OBJECT_PATH


def _dataset_name(path):
    """Return the third-from-last '/'-separated component of ``path``, or None if it has fewer than three."""
    parts = str(path).split('/')
    return parts[-3] if len(parts) >= 3 else None


def identify_obj_id_from_mesh_name(mesh_name: str, obj_filename_candidates: List[str]) -> int:
    """
    Identifies the object ID from the mesh name.

    Matching is done on the file basename without its extension, so e.g. an
    ``.sdf`` path can be resolved against ``.obj`` candidates. Only when several
    candidates share that basename is the third-from-last '/'-separated path
    component compared as well (presumably the dataset directory -- confirm
    against the asset layout).

    Args:
        mesh_name (str): The mesh file name (path-like objects are converted
            with ``str``).
        obj_filename_candidates (List[str]): List of candidate object filenames.

    Returns:
        int: The index of the object in the candidate list, or -1 if not found.
        In the ambiguous case -1 is also returned when no candidate shares the
        dataset component, or when ``mesh_name`` is too short to have one.
    """
    mesh_name = str(mesh_name)
    target_basename = os.path.splitext(os.path.basename(mesh_name))[0]
    candidate_basenames = [
        os.path.splitext(os.path.basename(candidate))[0]
        for candidate in obj_filename_candidates
    ]

    match_count = candidate_basenames.count(target_basename)
    if match_count == 0:
        return -1
    if match_count == 1:
        # Unique basename: no dataset component needed, so short paths work too.
        return candidate_basenames.index(target_basename)

    # Ambiguous basename: disambiguate by the dataset path component. It is only
    # computed here, so paths with fewer than three components no longer raise.
    target_dataset = _dataset_name(mesh_name)
    if target_dataset is None:
        return -1
    for idx, (candidate, basename) in enumerate(zip(obj_filename_candidates, candidate_basenames)):
        # Candidates too short to carry a dataset component yield None and never match.
        if basename == target_basename and _dataset_name(candidate) == target_dataset:
            return idx
    return -1


class AssetPaths:
    """Registry of object / SDF asset paths listed by a model checkpoint.

    Relative paths are read from ``obj_dirs.txt`` and ``sdf_dirs.txt`` inside
    the checkpoint directory. The position of an entry in ``rel_obj_paths`` is
    the object id returned by :meth:`get_obj_id` and accepted by
    :meth:`obj_path_by_idx`. Absolute paths are built under ``asset_base_dir``
    (``OBJECT_PATH``).
    """

    def __init__(self, model_ckpt_dir, download_assets=True):
        """
        Args:
            model_ckpt_dir (str): Checkpoint directory containing
                ``sdf_dirs.txt`` and ``obj_dirs.txt``.
            download_assets (bool): If True, immediately fetch any object file
                listed in ``obj_dirs.txt`` that is missing locally.
        """
        self.model_ckpt_dir = model_ckpt_dir
        # Set up-front so the path accessors also work when download_assets is
        # False. Previously it was only assigned by the download methods.
        self.asset_base_dir = OBJECT_PATH

        # load sdf dirs from txt; strip known machine-specific absolute
        # prefixes so every entry is relative to the object-set root
        with open(os.path.join(model_ckpt_dir, 'sdf_dirs.txt'), 'r') as f:
            raw_sdf_lines = f.readlines()
        self.rel_sdf_paths = [
            line.strip()
            .replace('/home/dongwon/research/object_set/', '')
            .strip()
            .replace('/dataset/object_set/', '')
            for line in raw_sdf_lines
        ]

        # load obj dirs from txt
        with open(os.path.join(model_ckpt_dir, 'obj_dirs.txt'), 'r') as f:
            self.rel_obj_paths = [line.strip() for line in f]

        if download_assets:
            self.enrol_or_download_assets()

    def _download_missing(self, rel_paths):
        """rsync every entry of ``rel_paths`` that does not yet exist under OBJECT_PATH.

        Each entry is fetched from ``{SSH_IP}:research/object_set/<rel_path>``.
        The argv is built as a list, so paths and passwords containing spaces
        or quotes are not re-split. The password reaches ``sshpass`` through the
        ``SSHPASS`` environment variable (``sshpass -e``) rather than argv, so it
        is not visible in the process list.

        Raises:
            subprocess.CalledProcessError: if a transfer exits non-zero.
        """
        self.asset_base_dir = OBJECT_PATH
        os.makedirs(OBJECT_PATH, exist_ok=True)
        print('Downloading assets...')
        ssh_command = f'ssh -o StrictHostKeyChecking=no -p {SSH_PORT}'
        child_env = {**os.environ, 'SSHPASS': str(SSH_PASSWORD)}
        for rel_path in tqdm(rel_paths):
            local_path = os.path.join(OBJECT_PATH, rel_path)
            if os.path.exists(local_path):
                continue
            os.makedirs(os.path.dirname(local_path), exist_ok=True)
            remote_path = f"research/object_set/{rel_path}"
            argv = [
                'sshpass', '-e',
                'rsync', '--ignore-existing', '-av', '-e', ssh_command,
                f"{SSH_IP}:{remote_path}", local_path,
            ]
            # check=True raises on failure, so no separate return-code check is needed.
            subprocess.run(argv, check=True, env=child_env)

    def enrol_or_download_assets(self):
        """Download every object file in ``rel_obj_paths`` that is missing under OBJECT_PATH.

        Raises:
            subprocess.CalledProcessError: if a transfer fails.
        """
        self._download_missing(self.rel_obj_paths)

    def enrol_or_download_sdf_files(self):
        """Download every SDF file in ``rel_sdf_paths`` that is missing under OBJECT_PATH.

        Raises:
            subprocess.CalledProcessError: if a transfer fails.
        """
        self._download_missing(self.rel_sdf_paths)

    def obj_path_by_idx(self, idx):
        """Return the absolute object path for object id ``idx`` (Python indexing rules apply)."""
        # Index the single entry instead of rebuilding the whole obj_paths list.
        return os.path.join(self.asset_base_dir, self.rel_obj_paths[idx].strip())

    def sdf_path_by_idx(self, idx):
        """Return the absolute SDF path at position ``idx`` of ``rel_sdf_paths``."""
        return os.path.join(self.asset_base_dir, self.rel_sdf_paths[idx].strip())

    def get_obj_id(self, path):
        '''
        input path could be any sdf or obj path (both relative or absolute)

        Returns the index in ``rel_obj_paths`` whose basename (without
        extension) matches ``path``, or -1 if there is none.
        '''
        return identify_obj_id_from_mesh_name(path, self.rel_obj_paths)

    def get_obj_path_from_rel_path(self, rel_path):
        """Return the absolute object path registered for ``rel_path``.

        Raises:
            AssertionError: if no registered object matches ``rel_path``.
        """
        oid = self.get_obj_id(rel_path)
        if oid == -1:
            # Explicit raise instead of `assert`: under `python -O` the assert is
            # stripped and index -1 would silently return the *last* object's path.
            raise AssertionError(f"no registered object matches {rel_path!r}")
        return self.obj_path_by_idx(oid)

    def get_encoded_obj(self, obj_path, pretrain_ckpt_id):
        """Load the pickled encoding stored for ``obj_path``.

        The file is ``assets_oriCORNs/<pretrain_ckpt_id>/<dataset>/<stem>.pkl``,
        where ``<dataset>`` is the third-from-last '/'-separated component of
        ``obj_path`` and ``<stem>`` is the file name cut at its first '.'.
        """
        dataset_name = obj_path.split('/')[-3]
        # NOTE(review): cuts at the first '.', unlike the splitext used by
        # get_obj_id; presumably matches how the .pkl files were named -- confirm.
        obj_basename = obj_path.split('/')[-1].split('.')[0]
        # Resolved relative to the current working directory.
        oriCORN_asset_path = os.path.join('assets_oriCORNs', pretrain_ckpt_id, dataset_name, obj_basename+'.pkl')

        # pickle can execute arbitrary code; only load trusted local asset files.
        with open(oriCORN_asset_path, 'rb') as f:
            return pickle.load(f)

    @property
    def obj_paths(self):
        """Absolute paths of all registered object files, in object-id order."""
        return [os.path.join(self.asset_base_dir, f.strip()) for f in self.rel_obj_paths]

    @property
    def sdf_paths(self):
        """Absolute paths of all registered SDF files, in ``rel_sdf_paths`` order."""
        return [os.path.join(self.asset_base_dir, f.strip()) for f in self.rel_sdf_paths]

if __name__ == '__main__':
    # Manual entry point: load the pretrained checkpoint with download_assets=True,
    # then fetch the SDF files through its asset-path helper
    # (presumably an AssetPaths instance -- confirm in util.model_util).
    import util.model_util as mutil

    pretrained = mutil.Models().load_pretrained_models('rep_ckpt/ccol_dec_v2', download_assets=True)
    asset_util = pretrained.asset_path_util
    asset_util.enrol_or_download_sdf_files()

    print(1)  # completion marker
