import torch
from torch.utils.data.dataset import Dataset

import os
import io
import random
import requests
import traceback
import torchaudio
from tqdm import tqdm
import numpy as np
import pandas as pd
import json

from .data_collator import data_collator_token
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only
from .dataset_tos import FaceDatasetCrop
from torch.utils.data import Dataset, get_worker_info


class FaceDataset(Dataset):
    """Map-style dataset of precomputed face-motion features for the LTC split.

    Each row of ``<ltc_parquet_path>/<split>.parquet`` supplies a ``question``
    (used as the instruction) and a ``behavior_v2_path`` pointing to a ``.npy``
    motion-feature file. No real audio is loaded: the audio track is a run of
    ``audio_pad_id`` tokens whose length is the motion length rescaled from
    ``motion_fps`` to ``audio_fps``.
    """

    def __init__(
        self,
        ltc_parquet_path,
        audio_token_path,
        cache_path,
        task_path,
        split,
        load_asr=True,
        load_wav=False,
        select_speaker=False,
        hidden_states_path=None,
        mot_token_path=None,
        data_scale_factor=10000,
        audio_fps=12.5,
        motion_fps=25,
        audio_pad_id=8520,
        mean_path=None,
        std_path=None,
        debug=False,
        **kwargs,
    ):
        """
        Args:
            ltc_parquet_path: Directory containing ``<split>.parquet``.
            audio_token_path, cache_path, load_asr, load_wav, select_speaker:
                Accepted for config compatibility; not read by this class.
            task_path: JSON-lines file; each line is parsed with ``json.loads``
                and one entry is attached at random to every returned sample.
            split: Stem of the parquet file to load.
            hidden_states_path: Optional directory holding ``<name>.pt`` files.
            mot_token_path: Stored on the instance; not otherwise used here.
            data_scale_factor: Virtual repetition factor; ``len(self)`` is
                ``rows * data_scale_factor`` and indices wrap modulo ``rows``.
            audio_fps, motion_fps: Rates used to size the padded audio track.
            audio_pad_id: Value used to fill the audio token track.
            mean_path, std_path: Optional ``.npy`` files used to normalize the
                motion features. They must be given together.
            debug: If True, keep only the first 128 parquet rows.
            **kwargs: Ignored (lets sibling datasets share one kwargs dict).

        Raises:
            ValueError: If exactly one of ``mean_path`` / ``std_path`` is given.
        """
        # Fail fast on a half-specified normalization config. With a mean but
        # no std, normalize() would raise for every item, and __getitem__ would
        # silently turn every sample into None.
        if bool(mean_path) != bool(std_path):
            raise ValueError(
                "mean_path and std_path must be provided together "
                f"(got mean_path={mean_path!r}, std_path={std_path!r})"
            )

        self.hidden_states_path = hidden_states_path
        self.mot_token_path = mot_token_path
        self.audio_fps = audio_fps
        self.motion_fps = motion_fps
        self.audio_pad_id = audio_pad_id

        # Workaround for training hangs (translated from the original note):
        # the dataset is virtually repeated data_scale_factor times; see
        # __len__ and the modulo in __getitem__.
        self.data_scale_factor = data_scale_factor

        # Kept for attribute compatibility; never populated by this class.
        self.video_path_list = []
        self.audio_path_list = []
        self.motion_path_list = []

        self.videos_parquet = pd.read_parquet(
            os.path.join(ltc_parquet_path, f"{split}.parquet")
        )

        if debug:
            self.videos_parquet = self.videos_parquet.head(128)

        self.length = len(self.videos_parquet)
        self.data_collator = data_collator_token
        rank_zero_info(f"data scale: {self.length}")

        # Optional normalization statistics (both or neither, enforced above).
        self.mean = np.load(mean_path) if mean_path else None
        self.std = np.load(std_path) if std_path else None

        with open(task_path, "r", encoding="utf-8") as f:
            self.tasks = [json.loads(line) for line in f]

        # Row indices whose loading failed; they return None from then on
        # (tracked per process, e.g. per DataLoader worker).
        self.missing_videos = []
        self.loaded_data = {}
        self.loaded_num = 0

    def __len__(self):
        """Virtual length: real row count times ``data_scale_factor``."""
        return self.length * self.data_scale_factor

    @property
    def nfeats(self):
        """Feature dimensionality reported to model code (fixed at 512)."""
        return 512

    def _load_motion(self, path, duration=None, timestamp=None):
        """Load a ``.npy`` motion array and normalize it if statistics are set.

        ``duration`` and ``timestamp`` are accepted but unused.

        Returns:
            dict: ``{"feature": array, "length": array.shape[0]}``.
        """
        feat = np.load(path)

        if self.mean is not None:
            feat = self.normalize(feat)

        return {
            "feature": feat,
            "length": feat.shape[0],
        }

    def _load_hidden_states(self, asr_path):
        """Locate, and load if present, cached hidden states for ``asr_path``.

        Returns ``{}`` when ``hidden_states_path`` is unset. Otherwise it returns
        ``{"hidden_states_path": <dir>/<basename with ".json" -> ".pt">}``, plus
        ``"hidden_states"`` when that path is an existing regular file.
        """
        if self.hidden_states_path is None:
            return {}

        hidden_states_path = os.path.join(
            self.hidden_states_path,
            os.path.basename(asr_path).replace(".json", ".pt"),
        )
        return_dict = {"hidden_states_path": hidden_states_path}
        # Use isfile, not exists: an empty asr_path (as __getitem__ passes)
        # resolves to the directory itself. torch.load on a directory raises,
        # which would mark every sample as missing.
        if os.path.isfile(hidden_states_path):
            return_dict["hidden_states"] = torch.load(
                hidden_states_path, weights_only=True
            )
        return return_dict

    def __getitem__(self, idx):
        """Build one sample dict for ``idx`` (wrapped modulo the row count).

        Returns ``None`` if the row previously failed or fails now. Any
        exception while loading is logged on rank zero, and the row is
        remembered in ``missing_videos``.
        """
        idx = idx % self.length

        if idx in self.missing_videos:
            return None

        metadata = self.videos_parquet.iloc[idx]

        try:
            motion_path = metadata["behavior_v2_path"]

            data = {}
            data.update({"instruction": metadata["question"], "response": ""})
            # These rows carry no ASR json, so an empty path is passed.
            data.update(self._load_hidden_states(""))
            data.update(self._load_motion(motion_path))
            length = int(data["length"])
            data.update(
                {
                    "motion_tokens": torch.zeros(length),
                }
            )
            # Padded audio track, rescaled from motion frames to audio frames.
            # NOTE(review): torch.ones(...) * pad_id yields a float tensor;
            # confirm the collator expects that rather than integer ids.
            audio_length = int(length * self.audio_fps / self.motion_fps)
            data.update(
                {
                    "audio_tokens": torch.ones(audio_length) * self.audio_pad_id,
                    "audio_lengths": audio_length,
                }
            )
            data["video_path"] = metadata["behavior_v2_path"]
            data["audio_path"] = ""
            data["pose_path"] = motion_path

        except Exception as e:
            rank_zero_info(f"Error loading data for index {idx}: {e}")
            self.missing_videos.append(idx)
            return None

        task = random.choice(self.tasks)
        data["tasks"] = task

        return data

    def normalize(self, poses):
        """Return ``(poses - mean) / std``, or ``poses`` unchanged if no mean is set."""
        if self.mean is not None:
            return (poses - self.mean) / self.std
        return poses

    def denormalize(self, poses):
        """Invert :meth:`normalize`: ``poses * std + mean`` (no-op if no mean is set)."""
        if self.mean is not None:
            return poses * self.std + self.mean
        return poses


class FaceLTCDatasetCrop(FaceDataset):
    """:class:`FaceDataset` with a random fixed-size temporal crop.

    Sequences no longer than ``window_size`` frames are returned whole. Longer
    ones are cropped to a uniformly random ``window_size``-frame window, and
    the audio token track is cropped to the time-aligned span.
    """

    def __init__(self, window_size=64, **kwargs):
        """
        Args:
            window_size (int): Maximum number of motion frames kept per sample.
            **kwargs: Forwarded unchanged to :class:`FaceDataset`.
        """
        super().__init__(**kwargs)
        self.window_size = window_size

    def __getitem__(self, idx):
        """Return the parent's sample dict with motion and audio cropped.

        Args:
            idx (int): Index into the (virtually repeated) dataset.

        Returns:
            dict | None: ``None`` if the parent returned ``None``. Otherwise
            the parent's dict with ``"feature"``, ``"length"``,
            ``"audio_tokens"`` and ``"audio_lengths"`` replaced by the
            cropped values.
        """
        data = super().__getitem__(idx)
        if data is None:
            return None

        poses = data["feature"]
        audio_tokens = data["audio_tokens"]
        length = data["length"]

        if length <= self.window_size:
            start_idx = 0
            length = poses.shape[0]
        else:
            # np.random.randint excludes its upper bound. The +1 lets the last
            # valid window (ending on the final frame) be sampled too.
            start_idx = np.random.randint(0, length - self.window_size + 1)
            length = self.window_size

        # Map motion frames onto the audio timeline with the same rescaling
        # FaceDataset uses to size the audio track. For the default 25 fps
        # motion / 12.5 fps audio this equals the previous hard-coded "// 2".
        audio_start_idx = int(start_idx * self.audio_fps / self.motion_fps)
        audio_length = int(length * self.audio_fps / self.motion_fps)

        poses = poses[start_idx : start_idx + self.window_size]
        audio_tokens = audio_tokens[audio_start_idx : audio_start_idx + audio_length]

        data["feature"] = poses
        data["length"] = len(poses)
        data["audio_tokens"] = audio_tokens
        data["audio_lengths"] = len(audio_tokens)
        return data


class MergedFaceDataset:
    """Weighted mixture of the QA crop dataset and the LTC crop dataset.

    ``len(self)`` is ``int(qa_rows * w_qa + ltc_rows * w_ltc) * data_scale_factor``.
    Here ``*_rows`` are the sub-dataset lengths divided by ``data_scale_factor``.
    The incoming index only selects the source: indices below
    ``dataset_weights[0] * len(self)`` return a random QA item, and the rest
    return a random LTC item.
    """

    def __init__(self, dataset_weights=(0.5, 0.5), **kwargs):
        """
        Args:
            dataset_weights (Sequence[float]): ``(qa_weight, ltc_weight)``.
                Defaults to ``(0.5, 0.5)``. An immutable default avoids the
                shared-mutable-default pitfall.
            **kwargs: Forwarded unchanged to both ``FaceDatasetCrop`` and
                ``FaceLTCDatasetCrop``. ``data_scale_factor`` (default 10000)
                is also read here.
        """
        super().__init__()
        # Store a fresh list, so the attribute stays a list (as before) and
        # never aliases the caller's object.
        self.dataset_weights = list(dataset_weights)
        self.data_scale_factor = kwargs.get("data_scale_factor", 10000)
        self.dataset_qa = FaceDatasetCrop(**kwargs)
        self.dataset_ltc = FaceLTCDatasetCrop(**kwargs)
        # Undo the sub-datasets' virtual repetition to get their row counts.
        self.dataset_qa_length = len(self.dataset_qa) // self.data_scale_factor
        self.dataset_ltc_length = len(self.dataset_ltc) // self.data_scale_factor
        self.length = (
            int(
                self.dataset_qa_length * self.dataset_weights[0]
                + self.dataset_ltc_length * self.dataset_weights[1]
            )
            * self.data_scale_factor
        )

    @property
    def nfeats(self):
        """Feature dimensionality reported to model code (fixed at 512)."""
        return 512

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return a random sample from the QA or LTC dataset, chosen by ``idx``."""
        worker_info = get_worker_info()
        if worker_info is not None:
            # Reseed Python's RNG per (worker, index). worker_info.seed is
            # base_seed + worker id, so this seed is unique for every
            # (worker, idx) pair. The old "id + idx" collided (worker 0 / idx 1
            # and worker 1 / idx 0 drew the same sample) and ignored the
            # DataLoader's base seed.
            random.seed(worker_info.seed + idx * worker_info.num_workers)

        if idx < self.length * self.dataset_weights[0]:
            sample_idx = random.randint(0, self.dataset_qa_length - 1)
            return self.dataset_qa[sample_idx]

        sample_idx = random.randint(0, self.dataset_ltc_length - 1)
        return self.dataset_ltc[sample_idx]
