import bisect
import logging
import os
import random
from typing import Any, Dict, List, Union

import mmengine
import numpy as np
import torch
from mmcv import BaseTransform
from mmengine import print_log
from mmengine.dataset import BaseDataset
from mmengine.dataset import Compose
from tqdm import tqdm

from mmhug.registry import DATASETS, HF_MODELS

# Size of the identity-id space: the 934 subject ids enumerated by the
# "subjects" entries of `unitalker_dataset_config` (D0-D7) plus 1 extra
# "common" id.
ALL_UNITALKER_ID_NUM = 935


@DATASETS.register_module(force=True)
class UnitalkerMultiDataset(BaseDataset):
    """Concatenate several sub-datasets, each optionally repeated.

    Every entry of ``subsets`` is built through the ``DATASETS`` registry.
    Sub-dataset ``i`` contributes ``len(sub_i) * duplicate[i]`` consecutive
    indices to this dataset; an index inside that range is mapped back to
    the sub-dataset with a modulo, so repetitions simply cycle through the
    sub-dataset again.

    Note:
        ``BaseDataset.__init__`` is not called here; only the attributes set
        in ``__init__`` below exist on an instance.

    Args:
        subsets: Registry configs, one per sub-dataset.
        duplicate: Repeat count for each sub-dataset. Defaults to ``1`` for
            every sub-dataset when ``None``.

    Raises:
        ValueError: If ``subsets`` and ``duplicate`` differ in length, or if
            any repeat count is negative.
    """

    def __init__(self, subsets: List[Dict], duplicate: List[int] = None):
        if duplicate is None:
            duplicate = [1] * len(subsets)
        # Validate before storing anything: a length mismatch would otherwise
        # be silently truncated by ``zip`` when assertions are disabled (-O).
        if len(subsets) != len(duplicate):
            raise ValueError(
                f"subsets must have same length as duplicate, but got {len(subsets)} != {len(duplicate)}"
            )
        # A negative repeat count would make the cumulative lengths
        # non-monotonic and break the index lookup in ``__getitem__``.
        if any(repeat < 0 for repeat in duplicate):
            raise ValueError(
                f"duplicate must only contain non-negative integers, but got {list(duplicate)}"
            )
        self.dup = duplicate

        # Build every sub-dataset from its registry config.
        self.datasets = [DATASETS.build(subset_cfg) for subset_cfg in subsets]
        # Un-repeated length of every sub-dataset.
        self.original_lengths = [len(ds) for ds in self.datasets]
        # cumulative_lengths[i] is the exclusive end index (in the global
        # index space) of sub-dataset i after applying its repeat count.
        self.cumulative_lengths = []
        running_total = 0
        for length, repeat in zip(self.original_lengths, self.dup):
            running_total += length * repeat
            self.cumulative_lengths.append(running_total)

        print_log(
            f"UnitalkerMultiDataset init done, total length: {len(self)}",
            logger="current",
        )

    def __len__(self):
        """Return the sum over sub-datasets of ``length * repeat count``."""
        return self.cumulative_lengths[-1] if self.cumulative_lengths else 0

    def __getitem__(self, idx: int):
        """Return the sample at global index ``idx``.

        Negative indices count from the end, as for a list.

        Raises:
            IndexError: If ``idx`` is out of range. ``IndexError`` (rather
                than an assertion) is what lets ``for sample in dataset``
                terminate under Python's sequence-iteration protocol.
        """
        total = len(self)
        if idx < 0:
            idx += total
        if not 0 <= idx < total:
            raise IndexError(f"索引 {idx} 越界，总长度 {total}")

        # First sub-dataset whose exclusive end index is strictly greater
        # than idx. Sub-datasets contributing zero samples are skipped
        # because their end index equals the previous one.
        ds_idx = bisect.bisect_right(self.cumulative_lengths, idx)

        prev_cum = self.cumulative_lengths[ds_idx - 1] if ds_idx > 0 else 0
        local_idx = idx - prev_cum
        # Fold repeated copies back onto the real sub-dataset index.
        actual_idx = local_idx % self.original_lengths[ds_idx]

        return self.datasets[ds_idx][actual_idx]

    def load_data_list(self) -> List[dict]:
        """No-op: samples are served by the sub-datasets built in ``__init__``."""
        pass

    def full_init(self):
        """No-op: all sub-datasets are already built in ``__init__``."""
        pass


@DATASETS.register_module(force=True)
class UnitalkerSingleDataset(BaseDataset):
    """One UniTalker sub-dataset backed by a single annotation file.

    The annotation file is loaded eagerly in ``__init__``. Each sample is
    described by a dict produced by :meth:`prepare_data` and then passed
    through the configured ``pipeline``.

    Note:
        ``BaseDataset.__init__`` is not called here; only the attributes set
        in ``__init__`` below exist on an instance.

    Args:
        data_dir: Directory that the relative ``annot_path`` / ``audio_path``
            entries of the annotation file are joined onto.
        anno_file: Path of the annotation file (stored as ``self.ann_file``).
        pipeline: Transform(s) or transform config(s) handed to ``Compose``.
        template_file: Path of the identity template file (``.npy`` or
            ``.pth``). Defaults to ``id_template.npy`` located next to the
            annotation file.
        refetch: If True, a sample whose pipeline raises is replaced by
            another randomly chosen sample instead of propagating the error.
        max_refetch: Maximum number of replacement samples tried per
            ``__getitem__`` call when ``refetch`` is enabled.

    Raises:
        FileNotFoundError: If the template file does not exist.
        ValueError: If the template file extension is not supported.
    """

    def __init__(
        self,
        data_dir: str = "data/unitalker_data_release_V1/D5_unitalker_faceforensics++",
        anno_file: str = "data/unitalker_data_release_V1/D5_unitalker_faceforensics++/train.json",
        pipeline: Union[Dict, BaseTransform, List[Union[Dict, BaseTransform]]] = None,
        template_file: str = None,
        refetch=True,
        max_refetch: int = 1000,
    ):
        self.data_dir = data_dir
        # Stored under ``ann_file`` because ``load_data_list`` reads that name.
        self.ann_file = anno_file
        self._metainfo = {}

        self.pipeline = Compose(pipeline)
        self.data_list = self.load_data_list()
        self.refetch = refetch
        self.max_refetch = max_refetch

        if template_file is None:
            template_file = os.path.join(
                os.path.dirname(self.ann_file), "id_template.npy"
            )
        self.template_file = template_file
        self.template = self._load_template(self.template_file)

    @staticmethod
    def _load_template(template_file: str):
        """Load the identity template from a ``.npy`` or ``.pth`` file."""
        if not os.path.exists(template_file):
            raise FileNotFoundError(f"Template file not found: {template_file}")

        if template_file.endswith(".npy"):
            return torch.from_numpy(np.load(template_file))
        if template_file.endswith(".pth"):
            # NOTE(review): torch.load unpickles the file; only point this at
            # trusted template files.
            return torch.load(template_file, map_location="cpu")

        # Previously an unknown extension left the template unbound and
        # failed later with an unrelated UnboundLocalError.
        raise ValueError(
            f"Unsupported template file extension (expected .npy or .pth): {template_file}"
        )

    def load_data_list(self) -> List[dict]:
        """Load the annotation file and return its ``"data"`` list.

        Unitalker dataset annotation format:
        {
            "info": ...
            "data": [{data_1...}, {data_2...}, ...]
        }

        Entries of ``"info"`` are merged into ``self._metainfo`` without
        overwriting keys that are already present.

        Raises:
            TypeError: If the loaded annotation is not a dict.
        """
        annotations = mmengine.load(self.ann_file)
        if not isinstance(annotations, dict):
            raise TypeError(
                f"The annotations loaded from annotation file "
                f"should be a dict, but got {type(annotations)}!"
            )

        meta_info = annotations["info"]
        for k, v in meta_info.items():
            self._metainfo.setdefault(k, v)

        raw_data_list = annotations["data"]

        return raw_data_list

    def __len__(self) -> int:
        """Return the number of annotated samples."""
        return len(self.data_list)

    def prepare_data(self, idx) -> Any:
        """Build the pipeline input dict for the sample at ``idx``.

        Per-dataset settings (scale, PCA flag, template frame, identity
        offset) come from the module-level ``unitalker_dataset_config``,
        keyed by the sample's ``"dataset"`` field.
        """
        raw_data_info = self.data_list[idx]
        dataset_cfg = unitalker_dataset_config[raw_data_info["dataset"]]

        # The template file is [t, d]; only a single frame is used as the
        # template for a given dataset.
        template = self.template[dataset_cfg["template_idx"]]

        data_info = {
            "motion_path": os.path.join(self.data_dir, raw_data_info["annot_path"]),
            "motion_metadata": {
                "fps": raw_data_info["fps"],
                "motion_type": raw_data_info["annot_type"],
                "scale": dataset_cfg["scale"],
            },
            "audio_path": os.path.join(self.data_dir, raw_data_info["audio_path"]),
            "audio_metadata": {},
            "dataset_id": raw_data_info["dataset"],  # D0 - D7
            "template": template,
            "use_pca": dataset_cfg["pca"],
            # Per-dataset id shifted by the dataset's offset so that it is
            # unique across all datasets.
            "identity": raw_data_info["id"] + dataset_cfg["id_index_offset"],
        }
        return data_info

    def __getitem__(self, idx: int) -> dict:
        """Return the pipeline output for the sample at ``idx``.

        The exact keys depend on the configured pipeline; the layout the
        original author documented is:
        {
            "motion": torch.rand([T, D]),
            "motion_metadata": {
                "motion_path": motion_path,
                "motion_type": motion_type,
                "motion_scale": motion_scale
                "duration": T / fps,
                "num_frames": T,
                "fps": fps,
            },
            "audio": torch.rand([T]),
            "audio_metadata": {
                "audio_path": audio_path,
                'sr': SR,
                'duration': T,
                'num_frames': T * SR,
            },

        }

        If the pipeline raises and ``refetch`` is enabled, a warning is
        logged and another sample is tried, up to ``max_refetch`` times.
        Errors raised by :meth:`prepare_data` are never swallowed.

        Raises:
            Exception: The original pipeline error when ``refetch`` is False.
            RuntimeError: When ``refetch`` is enabled and no sample could be
                loaded within the allowed number of attempts.
        """
        num_samples = len(self.data_list)
        # One initial attempt plus at most ``max_refetch`` replacements. A
        # bounded loop replaces the former unbounded recursion, which ended
        # in RecursionError when every sample failed.
        attempts = max(self.max_refetch, 0) + 1
        last_error = None

        for _ in range(attempts):
            sample = self.prepare_data(idx)
            try:
                return self.pipeline(sample)
            except Exception as e:
                if not self.refetch:
                    raise
                last_error = e
                # ``logger="current"`` is required for ``level`` to take
                # effect; without a logger print_log falls back to print().
                print_log(
                    f"Get error when loading idx={idx}: fetching another instead, error: {e}",
                    logger="current",
                    level=logging.WARNING,
                )

            # Pick a replacement; if the random pick equals the failed
            # index, shift to the next one instead.
            new_idx = random.randint(0, num_samples - 1)
            if new_idx == idx:
                new_idx = (idx + 1) % num_samples
            idx = new_idx

        raise RuntimeError(
            f"Cannot load a valid sample after {attempts} attempts"
        ) from last_error


# Per-dataset settings, keyed by dataset id ("D0" .. "D7").
#
# Fields:
#   dirname          sub-directory name of the dataset
#   annot_type       name of the annotation format
#   scale            scale value forwarded as motion metadata
#   annot_dim        dimensionality of one annotation frame
#   subjects         number of identities in the dataset
#   template_idx     frame of the template file used as the template
#   pca / pca_path   whether PCA is used and, if so, where its file lives
#   id_index_offset  sum of "subjects" over all preceding datasets, so that
#                    (per-dataset id + offset) is unique across datasets
unitalker_dataset_config = dict(
    D0=dict(
        dirname="D0_BIWI",
        annot_type="BIWI_23370_vertices",
        scale=0.2,
        annot_dim=23370 * 3,
        subjects=6,
        template_idx=3,
        pca=True,
        pca_path="data/unitalker_data_release_V1/D0_BIWI/pca.npz",
        id_index_offset=0,
    ),
    D1=dict(
        dirname="D1_vocaset",
        annot_type="FLAME_5023_vertices",
        scale=1.0,
        annot_dim=5023 * 3,
        subjects=12,
        template_idx=3,
        pca=True,
        pca_path="data/unitalker_data_release_V1/D1_vocaset/pca.npz",
        id_index_offset=6,
    ),
    D2=dict(
        dirname="D2_meshtalk",
        annot_type="meshtalk_6172_vertices",
        scale=0.001,
        annot_dim=6172 * 3,
        subjects=13,
        template_idx=0,
        pca=True,
        pca_path="data/unitalker_data_release_V1/D2_meshtalk/pca.npz",
        id_index_offset=6 + 12,
    ),
    D3=dict(
        dirname="D3D4_3DETF/D3_HDTF",
        annot_type="3DETF_blendshape_weight",
        scale=1.0,
        annot_dim=52,
        subjects=141,
        template_idx=0,
        pca=False,
        id_index_offset=6 + 12 + 13,
    ),
    D4=dict(
        dirname="D3D4_3DETF/D4_RAVDESS",
        annot_type="3DETF_blendshape_weight",
        scale=1.0,
        annot_dim=52,
        subjects=24,
        template_idx=0,
        pca=False,
        id_index_offset=6 + 12 + 13 + 141,
    ),
    D5=dict(
        dirname="D5_unitalker_faceforensics++",
        annot_type="flame_params_from_dadhead",
        scale=1.0,
        annot_dim=413,
        subjects=719,
        template_idx=0,
        pca=False,
        id_index_offset=6 + 12 + 13 + 141 + 24,
    ),
    D6=dict(
        dirname="D6_unitalker_Chinese_speech",
        annot_type="inhouse_blendshape_weight",
        scale=1.0,
        annot_dim=51,
        subjects=8,
        template_idx=4,
        pca=False,
        id_index_offset=6 + 12 + 13 + 141 + 24 + 719,
    ),
    D7=dict(
        dirname="D7_unitalker_song",
        annot_type="inhouse_blendshape_weight",
        scale=1.0,
        annot_dim=51,
        subjects=11,
        template_idx=0,
        pca=False,
        id_index_offset=6 + 12 + 13 + 141 + 24 + 719 + 8,
    ),
)


if __name__ == "__main__":
    # Smoke test: build a multi-dataset over the D5, D6 and D7 subsets,
    # visit every sample once in random order and print the motion / audio
    # shapes produced by the pipeline.
    from mmengine import init_default_scope

    init_default_scope("mmhug")

    release_root = "data/unitalker_data_release_V1"

    # Loading steps shared by every subset.
    loading_steps = [
        {"type": "LoadUnitalkerMotion", "motion_path_key": "motion_path"},
        {
            "type": "LoadAudio",
            "audio_path_key": "audio_path",
            "sampling_rate": 16000,
            "mono": True,
        },
    ]

    # One config per subset; they differ only in the directory name.
    subset_cfgs = [
        {
            "type": "UnitalkerSingleDataset",
            "data_dir": f"{release_root}/{dirname}",
            "anno_file": f"{release_root}/{dirname}/train_val.json",
            "pipeline": loading_steps,
            "refetch": True,
        }
        for dirname in (
            "D5_unitalker_faceforensics++",
            "D6_unitalker_Chinese_speech",
            "D7_unitalker_song",
        )
    ]

    multi_dataset = DATASETS.build(
        {
            "type": "UnitalkerMultiDataset",
            "subsets": subset_cfgs,
            "duplicate": None,
        }
    )

    visit_order = list(range(len(multi_dataset)))
    random.shuffle(visit_order)

    for sample_idx in tqdm(visit_order):
        sample = multi_dataset[sample_idx]
        motion, audio = sample["motion"], sample["audio"]

        print(motion.shape)
        print(audio.shape)
