import torch
from torch.utils.data.dataset import Dataset

import os
import io
import random
import requests
import traceback
import torchaudio
from tqdm import tqdm
import numpy as np
import pandas as pd
import json

from .data_collator import data_collator_token
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only
from torch.utils.data import Dataset, get_worker_info
from .ltc_utils import load_kps_conf_normalized


class PoseDataset(Dataset):
    """Map-style dataset of pose clips listed in an LTC parquet index.

    Each row of ``{ltc_parquet_path}/{split}.parquet`` is read for its
    ``rtmpose_path``, ``resolution`` and ``question`` columns. A sample holds
    either keypoint features decoded from the JSON at ``rtmpose_path`` or, when
    ``motion_token_path`` is set, a pre-computed ``.npy`` token array, together
    with placeholder motion/audio tokens and a task drawn at random from the
    JSONL file at ``task_path``.
    """

    def __init__(
        self,
        ltc_parquet_path,
        task_path,
        split,
        motion_token_path=None,
        audio_token_path=None,
        cache_path=None,
        load_asr=True,
        load_wav=False,
        select_speaker=False,
        hidden_states_path=None,
        mot_token_path=None,
        data_scale_factor=10000,
        audio_fps=12.5,
        motion_fps=25,
        audio_pad_id=8520,
        mean_path=None,
        std_path=None,
        debug=False,
        **kwargs,
    ):
        """Load the parquet index, optional normalization stats and task list.

        Args:
            ltc_parquet_path: Directory containing ``{split}.parquet``.
            task_path: JSONL file with one task object per non-blank line.
            split: Split name used to pick the parquet file.
            motion_token_path: Optional directory of ``{name}.npy`` files; when
                set they are loaded instead of decoding pose JSON.
            hidden_states_path: Optional directory of ``.pt`` hidden-state files.
            data_scale_factor: ``__len__`` reports ``rows * data_scale_factor``;
                indices are wrapped modulo the real row count.
            audio_fps, motion_fps: Rates used to size the placeholder audio
                tokens from the motion length.
            audio_pad_id: Fill value of the placeholder audio tokens.
            mean_path, std_path: Optional ``.npy`` files used by
                :meth:`normalize` / :meth:`denormalize`.
            debug: Keep only the first 128 parquet rows.

        ``mot_token_path`` is stored but not read by this class;
        ``audio_token_path``, ``cache_path``, ``load_asr``, ``load_wav``,
        ``select_speaker`` and ``**kwargs`` are accepted but unused here.
        """
        self.hidden_states_path = hidden_states_path
        self.mot_token_path = mot_token_path
        self.audio_fps = audio_fps
        self.motion_fps = motion_fps
        self.audio_pad_id = audio_pad_id
        self.split = split
        self.motion_token_path = motion_token_path

        # Works around training hangs (original note). __len__ reports
        # length * data_scale_factor, and __getitem__ wraps indices back.
        self.data_scale_factor = data_scale_factor

        self.video_path_list = []
        self.audio_path_list = []
        self.motion_path_list = []

        self.videos_parquet = pd.read_parquet(
            os.path.join(ltc_parquet_path, f"{split}.parquet")
        )

        if debug:
            self.videos_parquet = self.videos_parquet.head(128)

        self.length = len(self.videos_parquet)
        self.data_collator = data_collator_token
        rank_zero_info(f"data scale: {self.length}")

        # Load normalization data if provided
        self.mean = np.load(mean_path) if mean_path else None
        self.std = np.load(std_path) if std_path else None

        self.tasks = self._load_tasks(task_path)

        # Indices whose loading failed; they are skipped on later calls.
        self.missing_videos = []
        self.loaded_data = {}
        self.loaded_num = 0

    @staticmethod
    def _load_tasks(task_path):
        """Parse a JSONL file into a list of objects, skipping blank lines.

        Blank lines (e.g. a trailing newline at EOF) would otherwise make
        ``json.loads`` raise and abort dataset construction.
        """
        with open(task_path, "r", encoding="utf-8") as f:
            return [json.loads(line) for line in f if line.strip()]

    def __len__(self):
        """Return the row count multiplied by ``data_scale_factor``."""
        return self.length * self.data_scale_factor

    @property
    def nfeats(self):
        """Feature width: 512 when motion tokens are used, otherwise 262."""
        if self.motion_token_path:
            return 512
        else:
            return 262

    def _load_motion(self, path, resolution, duration=None, timestamp=None):
        """Decode a pose JSON file into a flattened per-frame feature array.

        ``load_kps_conf_normalized`` returns features and confidences; the
        confidences are discarded. ``duration`` and ``timestamp`` are unused.

        Returns:
            dict with ``feature`` (shape ``(length, -1)``) and ``length``.
        """
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        feat, _conf = load_kps_conf_normalized(data, resolution)
        length = feat.shape[0]
        feat = feat.reshape(length, -1)
        return {
            "feature": feat,
            "length": length,
        }

    def _load_motion_token(self, video_path):
        """Load ``{motion_token_path}/{name}.npy`` for the clip at ``video_path``.

        ``name`` is the basename of ``video_path`` up to its first dot.

        Returns:
            dict with ``feature`` (the loaded array) and ``length`` (its first dim).
        """
        name = video_path.split("/")[-1].split(".")[0]
        motion_token_path = os.path.join(self.motion_token_path, f"{name}.npy")
        feat = np.load(motion_token_path)
        length = feat.shape[0]
        return {
            "feature": feat,
            "length": length,
        }

    def _load_hidden_states(self, asr_path):
        """Return the hidden-state path (and tensor, if the file exists).

        Returns ``{}`` when ``hidden_states_path`` is not configured. Otherwise
        the result always contains ``hidden_states_path`` (the basename of
        ``asr_path`` with ``.json`` replaced by ``.pt``), plus ``hidden_states``
        when that path is a regular file.
        """
        if self.hidden_states_path is None:
            return {}
        hidden_states_path = os.path.join(
            self.hidden_states_path,
            os.path.basename(asr_path).replace(".json", ".pt"),
        )
        return_dict = {"hidden_states_path": hidden_states_path}
        # isfile rather than exists: an empty asr_path joins to the directory
        # itself, and torch.load on a directory raises, which would fail the
        # whole sample.
        if os.path.isfile(hidden_states_path):
            return_dict["hidden_states"] = torch.load(
                hidden_states_path, weights_only=True
            )
        return return_dict

    def __getitem__(self, idx):
        """Return the sample dict for ``idx`` (wrapped modulo the row count).

        Returns ``None`` if the row failed to load earlier or fails now; the
        failing index is appended to ``self.missing_videos`` so later calls
        skip it without retrying.
        """
        idx = idx % self.length

        if idx in self.missing_videos:
            return None

        metadata = self.videos_parquet.iloc[idx]

        try:
            motion_path = metadata["rtmpose_path"]
            resolution = metadata["resolution"]

            data = {"instruction": metadata["question"], "response": ""}
            # No ASR file is used for these rows, so an empty path is passed.
            data.update(self._load_hidden_states(""))
            if self.motion_token_path:
                data.update(self._load_motion_token(motion_path))
            else:
                data.update(self._load_motion(motion_path, resolution))

            length = int(data["length"])
            # Placeholder motion tokens, one per motion frame.
            data["motion_tokens"] = torch.zeros(length)

            # Placeholder audio tokens covering the same duration at audio_fps.
            audio_length = int(length * self.audio_fps / self.motion_fps)
            data["audio_tokens"] = torch.ones(audio_length) * self.audio_pad_id
            data["audio_lengths"] = audio_length

            data["video_path"] = motion_path
            data["audio_path"] = ""
            data["pose_path"] = motion_path
        except Exception as e:
            # Best effort: log, remember the bad index, and return None.
            rank_zero_info(f"Error loading data for index {idx}: {e}")
            self.missing_videos.append(idx)
            return None

        data["tasks"] = random.choice(self.tasks)
        return data

    def normalize(self, poses):
        """Normalize the pose features using the mean and standard deviation."""
        if self.mean is not None:
            return (poses - self.mean) / self.std
        return poses

    def denormalize(self, poses):
        """Denormalize the pose features using the mean and standard deviation."""
        if self.mean is not None:
            return poses * self.std + self.mean
        return poses


class PoseLTCDatasetCrop(PoseDataset):
    """PoseDataset variant that returns a random window of at most
    ``window_size`` motion frames from each clip, with matching audio tokens."""

    def __init__(self, window_size=64, **kwargs):
        """
        Initialize the cropping dataset.

        No entries are filtered: clips shorter than or equal to ``window_size``
        are returned whole.

        Args:
            window_size (int): Maximum number of motion frames per sample.
            **kwargs: Arguments forwarded to the PoseDataset parent class.
        """
        super().__init__(**kwargs)
        self.window_size = window_size

    def __getitem__(self, idx):
        """
        Return the parent sample cropped to a random temporal window.

        Args:
            idx (int): Dataset index (wrapped by the parent class).

        Returns:
            dict or None: The parent's sample dict with ``feature``,
            ``motion_tokens`` and ``audio_tokens`` cropped and ``length`` /
            ``audio_lengths`` updated, or ``None`` when the parent failed to
            load the sample.
        """
        data = super().__getitem__(idx)
        if data is None:
            return None

        poses = data["feature"]
        audio_tokens = data["audio_tokens"]
        length = data["length"]

        if length <= self.window_size:
            start_idx = 0
            length = poses.shape[0]
        else:
            # randint's upper bound is exclusive; +1 so the last full window
            # (starting at length - window_size) can also be selected.
            start_idx = np.random.randint(0, length - self.window_size + 1)
            length = self.window_size

        # Map motion-frame positions onto the audio-token timeline using the
        # same fps ratio the parent used to size the audio tokens
        # (defaults 12.5 / 25 = 0.5, i.e. the previous "// 2").
        ratio = self.audio_fps / self.motion_fps
        audio_start_idx = int(start_idx * ratio)
        audio_length = int(length * ratio)

        end_idx = start_idx + self.window_size
        poses = poses[start_idx:end_idx]
        audio_tokens = audio_tokens[audio_start_idx : audio_start_idx + audio_length]

        data["feature"] = poses
        data["length"] = len(poses)
        data["audio_tokens"] = audio_tokens
        data["audio_lengths"] = len(audio_tokens)
        # Keep the per-frame placeholder aligned with the cropped features.
        if "motion_tokens" in data:
            data["motion_tokens"] = data["motion_tokens"][start_idx:end_idx]
        return data
