import io
import json
import os
import pickle
import random
import time
from multiprocessing import shared_memory
from os.path import join as pjoin
import re
import joblib
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.csv as pv
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only
from torch.utils.data import Dataset

from .data_collator import data_collator_token
from common.hdfs import HdfsClient


def clean_text(text):
    """Strip the ``Answer:`` tail that follows a newline in ``text``.

    Every occurrence of a newline immediately followed by ``Answer:`` is
    removed together with the rest of that line. The pattern does not use
    DOTALL, so text on later lines is kept.

    Args:
        text (str): Raw text to clean.

    Returns:
        str: ``text`` with all ``"\\nAnswer:..."`` line tails removed.
    """
    answer_tail = re.compile(r"\nAnswer:.*")
    return answer_tail.sub("", text)


class FaceDataset(Dataset):
    """Dataset of face feature arrays paired with audio tokens and instructions.

    Each sample is assembled lazily in ``__getitem__`` from several files that
    are located by rewriting the feature file's basename (see the ``_get_*``
    helpers). Samples that fail to load are remembered in ``missing_videos``
    and yield ``None`` from then on.
    """

    def __init__(
        self,
        metadata_path,
        split,
        task_path,
        face_token_dir,
        instruction_path,
        load_mode="metadata",
        pose_type="coco",
        body_parts=("body", "face", "hand"),
        rescale=False,
        confidence=False,
        mot_token_path=None,
        mean_path=None,
        std_path=None,
        max_retries=3,
        timeout=0,
        locale="cn",
        share_index=True,
        debug=False,
        face_dir="",
        audio_dir="",
        **kwargs,
    ):
        """
        Read the metadata CSV, the task list and optional normalization stats.

        No feature, audio or token file is opened here; those are loaded per
        sample in ``__getitem__``.

        Args:
            metadata_path (str): CSV file, or a directory that contains
                ``{split}.csv``. The first column lists the feature file
                paths; only their basenames are used later.
            split (str): Split name, used to pick the CSV when
                ``metadata_path`` is a directory.
            task_path (str): JSON-lines file, one task object per line. One
                task is drawn at random for every returned sample.
            face_token_dir (str | None): Directory holding motion-token
                ``.npy`` files named like the feature files. ``None``
                disables motion tokens.
            instruction_path (str | None): Root directory of the instruction
                JSON files. ``None`` makes every sample use a fixed default
                instruction with an empty response.
            load_mode, pose_type, body_parts, rescale, confidence,
                max_retries, timeout, locale, share_index: Stored on the
                instance; this class does not read them again.
                ``body_parts`` must be non-empty.
            mot_token_path: Accepted for compatibility; unused.
            mean_path (str, optional): ``.npy`` file with the feature mean.
            std_path (str, optional): ``.npy`` file with the feature standard
                deviation. Normalization is applied only when ``mean_path``
                is given.
            debug (bool): Keep only the first 100 metadata rows.
            face_dir (str): Directory the feature basenames are resolved in.
            audio_dir (str): Root directory used to build audio paths.
            **kwargs: Ignored.
        """

        self.pose_type = pose_type
        self.body_parts = body_parts
        self.face_token_dir = face_token_dir
        self.instruction_path = instruction_path
        assert self.body_parts, "At least one body part should be selected."

        self.max_retries = max_retries
        self.timeout = timeout
        self.rescale = rescale
        self.confidence = confidence
        self.load_mode = load_mode
        self.split = split
        # Taken from the path as passed in, before the directory -> CSV
        # resolution below.
        self.metadata_root = os.path.dirname(metadata_path)
        self.share_index = share_index
        self.locale = locale
        self.face_dir = face_dir
        self.audio_dir = audio_dir

        # Framerate
        self.fps = 25

        # Resolve a directory to the CSV of the requested split.
        if os.path.isdir(metadata_path):
            metadata_path = pjoin(metadata_path, f"{split}.csv")

        # Get metadata for all videos
        with open(metadata_path, "rb") as f:
            self.metadata = pv.read_csv(f).to_pandas()

        if debug:
            self.metadata = self.metadata.head(100)

        # The first CSV column holds the feature file paths.
        self.pose_paths = self.metadata.iloc[:, 0].values
        rank_zero_info(f"Loaded {len(self.pose_paths)} pose paths")

        self.data_collator = data_collator_token

        # Load normalization data if provided
        self.mean = np.load(mean_path) if mean_path else None
        self.std = np.load(std_path) if std_path else None

        with open(task_path, "r", encoding="utf-8") as f:
            self.tasks = [json.loads(line) for line in f]

        # Indices whose files failed to load; they are skipped afterwards.
        self.missing_videos = []
        self.loaded_data = {}
        self.loaded_num = 0

    @property
    def nfeats(self):
        """Return the constant 512 (presumably the feature width - confirm)."""
        return 512

    def _get_audio_path(self, face_path):
        """Map a feature file path to its ``.wav`` path under ``audio_dir``.

        Only the basename is used. ``.npy`` becomes ``.wav`` and every
        underscore becomes ``/``, so a flat file name expands into a nested
        relative path.
        """
        face_filename = os.path.basename(face_path)
        face_filename = face_filename.replace(".npy", ".wav").replace("_", "/")
        audio_path = os.path.join(self.audio_dir, face_filename)
        return audio_path

    def _get_audio_token_path(self, audio_path):
        """Map an audio path to its token file by plain substring replacement."""
        return audio_path.replace("audio_CLEAN", "audio_tokens").replace(".wav", ".npy")

    def _get_instruction_path(self, audio_path):
        """Map an audio path to its instruction JSON under ``instruction_path``.

        The last two ``/``-separated components of ``audio_path`` are kept and
        ``.wav`` is replaced by ``.json``.
        """
        return os.path.join(
            self.instruction_path,
            "/".join(audio_path.split("/")[-2:]).replace(".wav", ".json"),
        )

    def _load_single_pose(self, pose_path):
        """Load one feature array, normalizing it when a mean is configured.

        Returns:
            dict: ``feature`` (the array) and ``length`` (size of its first
            axis).
        """
        data = np.load(pose_path, allow_pickle=False)
        if self.mean is not None:
            data = self.normalize(data)

        return {
            "feature": data,
            "length": data.shape[0],
        }

    def _load_audio_tokens(self, audio_path):
        """Load the audio token array that belongs to ``audio_path``.

        Returns:
            dict: ``audio_tokens`` and ``audio_lengths``.
        """
        audio_token_path = self._get_audio_token_path(audio_path)
        audio_data = np.load(audio_token_path, allow_pickle=False)
        return {
            "audio_tokens": audio_data,
            "audio_lengths": len(audio_data),
        }

    def _load_instruction_response(self, audio_path):
        """Load the instruction/response pair for ``audio_path``.

        Without an ``instruction_path`` a fixed default instruction and an
        empty response are returned. Otherwise the JSON file's ``question``
        (passed through ``clean_text``) and ``answer`` are used.

        Returns:
            dict: ``instruction`` and ``response``.
        """
        if self.instruction_path is None:
            return {
                "instruction": "Randomly speak something",
                "response": "",
            }
        instruction_path = self._get_instruction_path(audio_path)
        with open(instruction_path, "r", encoding="utf-8") as f:
            instruction = json.load(f)
        return {
            "instruction": clean_text(instruction["question"]),
            "response": instruction["answer"],
        }

    def _load_motion_token(self, pose_basename):
        """Load the motion tokens stored under ``face_token_dir``.

        Args:
            pose_basename (str): File name shared with the feature file.

        Returns:
            dict: Empty when ``face_token_dir`` is ``None``; otherwise
            ``motion_tokens`` and ``mot_token_lengths``.

        Raises:
            ValueError: If any token value is greater than 2048.
        """
        if self.face_token_dir is None:
            return {}

        mot_token_path = os.path.join(self.face_token_dir, pose_basename)
        data = np.load(mot_token_path, allow_pickle=False)
        # Compare element-wise. The previous ``data.any() > 2048`` compared a
        # single boolean with 2048 and therefore could never be true.
        if np.any(data > 2048):
            raise ValueError(f"Motion token value exceeds 2048: {mot_token_path}")
        return {
            "motion_tokens": data,
            "mot_token_lengths": data.shape[0],
        }

    def __len__(self):
        """Return the number of entries listed in the metadata."""
        return len(self.pose_paths)

    def __getitem__(self, idx):
        """Assemble the sample at ``idx``.

        Returns:
            dict | None: Feature, audio-token, instruction and (optional)
            motion-token entries plus ``audio_path``, ``pose_path`` and a
            randomly chosen ``tasks`` entry. ``None`` when loading fails or
            has failed before for this index.
        """
        # Known-bad sample: skip without touching the file system again.
        if idx in self.missing_videos:
            return None

        # Only the basename from the metadata is used; it is resolved in face_dir.
        pose_basename = os.path.basename(self.pose_paths[idx])
        pose_path = os.path.join(self.face_dir, pose_basename)
        try:
            data = self._load_single_pose(pose_path)
            audio_path = self._get_audio_path(pose_path)
            data.update(self._load_audio_tokens(audio_path))
            data.update(self._load_instruction_response(audio_path))
            data.update(self._load_motion_token(pose_basename))

            data["audio_path"] = audio_path
            data["pose_path"] = pose_path
        except Exception as e:
            # Best effort: log, remember the index and let the caller drop it.
            rank_zero_info(f"Error loading data for index {idx}: {e}")
            self.missing_videos.append(idx)
            return None

        task = random.choice(self.tasks)
        data["tasks"] = task

        return data

    def normalize(self, poses):
        """Return ``(poses - mean) / std``, or ``poses`` unchanged without a mean."""
        if self.mean is not None:
            return (poses - self.mean) / self.std
        return poses

    def denormalize(self, poses):
        """Return ``poses * std + mean``, or ``poses`` unchanged without a mean."""
        if self.mean is not None:
            return poses * self.std + self.mean
        return poses


class FaceDatasetCrop(FaceDataset):
    """FaceDataset variant that returns a random temporal window per sample."""

    def __init__(self, window_size=64, **kwargs):
        """
        Initialize the parent dataset and remember the crop window size.

        Sequences shorter than the window are not filtered out; they are
        returned whole by ``__getitem__``.

        Args:
            window_size (int): Maximum number of feature frames per sample.
            **kwargs: Arguments forwarded to the ``FaceDataset`` constructor.
        """
        super().__init__(**kwargs)
        self.window_size = window_size

    def __getitem__(self, idx):
        """
        Load the sample at ``idx`` and crop it to at most ``window_size`` frames.

        Sequences longer than the window get a uniformly random start
        position; shorter ones are returned whole. Motion tokens, when
        present, are sliced with the same indices as the features. Audio
        tokens are sliced with the start and length halved (integer division).

        Args:
            idx (int): Index of the sample in the dataset.

        Returns:
            dict | None: The parent's sample dict with ``feature``,
            ``audio_tokens`` and, if present, ``motion_tokens`` cropped and
            their length entries updated; ``None`` if the parent could not
            load the sample.
        """
        data = super().__getitem__(idx)
        if data is None:
            return None

        poses = data["feature"]
        audio_tokens = data["audio_tokens"]
        # The parent adds no motion-token keys when face_token_dir is None.
        mot_tokens = data.get("motion_tokens")
        length = data["length"]

        if length <= self.window_size:
            start_idx = 0
        else:
            # randint's upper bound is exclusive; + 1 keeps the final valid
            # window (start == length - window_size) reachable.
            start_idx = np.random.randint(0, length - self.window_size + 1)
            length = self.window_size

        audio_start_idx = start_idx // 2
        audio_length = length // 2

        # Extract the windowed subset of every time-aligned array.
        stop_idx = start_idx + self.window_size
        poses = poses[start_idx:stop_idx]
        audio_tokens = audio_tokens[audio_start_idx : audio_start_idx + audio_length]

        data["feature"] = poses
        data["length"] = len(poses)
        if mot_tokens is not None:
            mot_tokens = mot_tokens[start_idx:stop_idx]
            data["motion_tokens"] = mot_tokens
            # Report the cropped token count, not the pose window length.
            data["mot_token_lengths"] = len(mot_tokens)
        data["audio_tokens"] = audio_tokens
        data["audio_lengths"] = len(audio_tokens)
        return data
