import torch
from torch.utils.data.dataset import Dataset

import os
import random
import traceback
import torchaudio
from tqdm import tqdm
import numpy as np
import pandas as pd
import json

from .data_collator import data_collator_token
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only


def get_all_videos(
    parquet_path,
    cache_path,
    selected_cols=None,
):
    """Collect per-video metadata from a directory of parquet files, with caching.

    If ``cache_path`` exists it is loaded and returned. Otherwise every
    ``*.parquet`` file directly inside ``parquet_path`` is read (restricted to
    ``selected_cols``), the frames are concatenated, rows without a usable
    ``behavior_path`` are dropped, and the result is written to ``cache_path``.

    The filter is applied *before* caching and again when a cache is loaded,
    so a fresh build and a cached load return the same rows (caches written
    without the filter are cleaned up on load as well).

    Args:
        parquet_path: Directory containing the source ``.parquet`` files.
        cache_path: Path of the combined parquet cache file.
        selected_cols: Columns to read from each source file. Defaults to
            ``["video_path", "seq_vocals_path", "behavior_path", "syncnet_path"]``.

    Returns:
        pandas.DataFrame: The combined metadata; an empty frame with
        ``selected_cols`` as columns when no parquet file is found.
    """
    if selected_cols is None:
        selected_cols = [
            "video_path",
            "seq_vocals_path",
            "behavior_path",
            "syncnet_path",
        ]
    else:
        selected_cols = list(selected_cols)

    if os.path.exists(cache_path):
        rank_zero_info(f"Loading cached video paths from {cache_path}...")
        return _drop_rows_without_behavior(pd.read_parquet(cache_path))

    # os.listdir order is arbitrary; sort so that every process that builds
    # the table sees the rows in the same order.
    parquet_files = sorted(
        f for f in os.listdir(parquet_path) if f.endswith(".parquet")
    )
    dfs = [
        pd.read_parquet(os.path.join(parquet_path, parquet_file), columns=selected_cols)
        for parquet_file in tqdm(parquet_files, desc="Reading parquet files")
    ]

    if not dfs:
        rank_zero_info(f"No parquet files found in {parquet_path}")
        return pd.DataFrame(columns=selected_cols)

    combined_df = _drop_rows_without_behavior(pd.concat(dfs, ignore_index=True))
    combined_df.to_parquet(cache_path)
    rank_zero_info(f"Video paths cached to {cache_path}.")
    return combined_df


def _drop_rows_without_behavior(df):
    """Drop rows whose ``behavior_path`` is null, ``""`` or ``"{}"``.

    Frames without a ``behavior_path`` column are returned untouched.
    """
    if "behavior_path" not in df.columns:
        return df
    behavior = df["behavior_path"]
    keep = behavior.notna() & (behavior != "") & (behavior != "{}")
    return df[keep].reset_index(drop=True)


# Cache of resamplers keyed by (source rate, target rate) so each transform is
# only constructed once per rate pair.
_resample_buffer: dict[tuple[int, int], torchaudio.transforms.Resample] = {}


def load_audio(audio_path, sample_rate=16000):
    """Load an audio file and resample it to ``sample_rate`` if necessary.

    Args:
        audio_path: Path of the audio file, passed to ``torchaudio.load``.
        sample_rate: Target sampling rate in Hz.

    Returns:
        The waveform returned by ``torchaudio.load``, resampled to
        ``sample_rate`` when the file was stored at a different rate.
    """
    # Keep the file's rate in its own name: reusing ``sample_rate`` here would
    # overwrite the requested target and make the comparison below always
    # false, so no resampling would ever happen.
    audio, orig_sample_rate = torchaudio.load(audio_path)
    if orig_sample_rate != sample_rate:
        key = (orig_sample_rate, sample_rate)
        resampler = _resample_buffer.get(key)
        if resampler is None:
            resampler = torchaudio.transforms.Resample(
                orig_freq=orig_sample_rate, new_freq=sample_rate
            )
            _resample_buffer[key] = resampler
        audio = resampler(audio)
    return audio


def clip_by_timestamp(data, duration=None, timestamp=None):
    """Slice ``data`` to the portion covered by a ``(start, end)`` timestamp.

    The timestamp is converted to indices proportionally:
    ``index = len(data) * (time / duration)``, truncated with ``int``. A
    ``None`` start means "from the beginning", a ``None`` end means "to the
    end". When either ``duration`` or ``timestamp`` is ``None`` the input is
    returned unchanged.

    Args:
        data: Any sliceable sequence supporting ``len``.
        duration: Total duration that ``len(data)`` corresponds to.
        timestamp: Two-element sequence ``(start, end)``; entries may be ``None``.

    Returns:
        The sliced data (or ``data`` itself when nothing is clipped).
    """
    if timestamp is None or duration is None:
        return data

    n_frames = len(data)
    # Fractional positions are kept as-is for the log line and only truncated
    # when used as slice bounds.
    start_frame = 0 if timestamp[0] is None else n_frames * (timestamp[0] / duration)
    end_frame = n_frames if timestamp[1] is None else n_frames * (timestamp[1] / duration)

    clipped = data[int(start_frame) : int(end_frame)]
    rank_zero_info(f"Clipped audio from {start_frame} to {end_frame}")
    return clipped


class FaceDataset(Dataset):
    """Dataset pairing motion features with audio tokens and optional ASR text.

    Each item is a dict built from one row of the metadata table returned by
    ``get_all_videos``. Items that fail to load are logged, remembered in
    ``missing_videos`` and returned as ``None``.
    """

    def __init__(
        self,
        parquet_path,
        cache_path,
        task_path,
        split,
        load_asr=True,
        select_speaker=False,
        hidden_states_path=None,
        audio_token_path=None,
        mot_token_path=None,
        mean_path=None,
        std_path=None,
        debug=False,
        **kwargs,
    ):
        """Load the metadata table, normalization statistics and task list.

        Args:
            parquet_path: Directory with the source parquet metadata files.
            cache_path: Path of the combined metadata cache.
            task_path: JSON-lines file; one task object per line.
            split: Split name, stored on the instance.
            load_asr: Whether ``__getitem__`` loads ASR results.
            select_speaker: ``"longest"`` picks the longest speaker segment,
                any other truthy value picks a random segment, a falsy value
                uses the whole transcript.
            hidden_states_path: Optional directory with ``.pt`` hidden states.
            audio_token_path: Optional directory with audio token ``.npy`` files.
            mot_token_path: Stored on the instance; not read by this class.
            mean_path: Optional ``.npy`` file with the normalization mean.
            std_path: Optional ``.npy`` file with the normalization std.
            debug: When true, only the first 512 rows are kept.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        self.hidden_states_path = hidden_states_path
        self.select_speaker = select_speaker
        self.load_asr = load_asr
        self.mot_token_path = mot_token_path
        self.audio_token_path = audio_token_path
        self.split = split
        self.video_path_list = []
        self.audio_path_list = []
        self.motion_path_list = []

        self.videos_parquet = get_all_videos(
            parquet_path,
            cache_path,
            selected_cols=[
                "video_path",
                "seq_vocals_path",
                "duration",
                "behavior_path",
                "lang_tag_whisperv3",
                "asr_result_paraformer_path",
                "asr_result_whisperv3_path",
                "syncnet_path",
            ],
        )

        if debug:
            self.videos_parquet = self.videos_parquet.head(512)

        self.length = len(self.videos_parquet)
        self.data_collator = data_collator_token
        print(f"data scale: {self.length}")

        # Load normalization data if provided
        self.mean = np.load(mean_path) if mean_path else None
        self.std = np.load(std_path) if std_path else None

        with open(task_path, "r", encoding="utf-8") as f:
            self.tasks = [json.loads(line) for line in f]

        # Indices that failed to load once; they are skipped on later accesses.
        self.missing_videos = []
        self.loaded_data = {}
        self.loaded_num = 0

    def __len__(self):
        """Return the number of rows in the metadata table."""
        return self.length

    @property
    def nfeats(self):
        """Feature dimension used for the placeholder motion array."""
        return 512

    def _get_audio_token_path(self, audio_path):
        """Map an audio ``.wav`` path to the path of its token ``.npy`` file.

        Without ``audio_token_path`` the token path is derived from the audio
        path by substring replacement; otherwise the file is looked up by
        basename inside ``audio_token_path``.
        """
        if not self.audio_token_path:
            return (
                audio_path.replace("audio_CLEAN", "audio_tokens")
                .replace("/seq/talk/", "/seq/talk_glm_tokens/")
                .replace(".wav", ".npy")
            )
        return os.path.join(
            self.audio_token_path,
            os.path.basename(audio_path).replace(".wav", ".npy"),
        )

    def _load_motion(self, motion_path, audio_length, duration=None, timestamp=None):
        """Load, normalize and clip the motion features of one sample.

        On any loading error a zero array of shape
        ``(audio_length * 2, nfeats)`` is returned instead, so the sample
        can still be used.

        Returns:
            dict with ``"feature"`` (array) and ``"length"`` (first dimension).
        """
        try:
            data = np.load(motion_path, allow_pickle=False)
            # normalize() is a no-op when no mean was loaded.
            data = self.normalize(data)

            # Clip by timestamp
            data = clip_by_timestamp(data, duration, timestamp)

            return {
                "feature": data,
                "length": data.shape[0],
            }
        except Exception as e:
            # Deliberate best-effort fallback; Exception (not a bare except)
            # lets KeyboardInterrupt/SystemExit propagate.
            rank_zero_info(f"Error loading motion data from {motion_path}: {e}")
            # Placeholder length is twice the audio token count
            # (presumably two motion frames per audio token - TODO confirm).
            placeholder_length = audio_length * 2
            return {
                "feature": np.zeros((placeholder_length, self.nfeats)),
                "length": placeholder_length,
            }

    def _load_audio_tokens(self, audio_path, duration=None, timestamp=None):
        """Load and clip the audio tokens belonging to ``audio_path``.

        Raises:
            FileNotFoundError: If the token file does not exist.

        Returns:
            dict with ``"audio_tokens"`` and ``"audio_lengths"``.
        """
        audio_token_path = self._get_audio_token_path(audio_path)

        if not os.path.exists(audio_token_path):
            raise FileNotFoundError(
                f"Audio token file not found: {audio_token_path}"
            )
        audio_data = np.load(audio_token_path, allow_pickle=False)

        # Clip by timestamp
        audio_data = clip_by_timestamp(audio_data, duration, timestamp)

        return {
            "audio_tokens": audio_data,
            "audio_lengths": len(audio_data),
        }

    def _load_instruction(self, instruction_path):
        """Return the instruction text, or a default one when no path is given."""
        if instruction_path is None:
            return {
                "instruction": "Randomly speak something",
            }
        with open(instruction_path, "r", encoding="utf-8") as f:
            instruction = json.load(f)
        return {
            "instruction": instruction["question"],
        }

    def _load_asr(self, asr_path):
        """Read an ASR result file and pick the text span to use.

        Depending on ``select_speaker``:

        * ``"longest"``: the speaker segment with the largest
          ``end - start`` (first one on ties); its text is stripped.
        * other truthy value: a randomly chosen speaker segment.
        * falsy: the full transcript ``data["text"]`` with no timestamp.
          The ``"speakers"`` list is not required in this mode.

        Returns:
            dict with ``"timestamp"``, ``"duration"`` and ``"response"``,
            plus ``"speaker"`` when a segment was selected. ``"duration"`` is
            the end time of the last speaker segment.
        """
        with open(asr_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        if not self.select_speaker:
            return {
                "timestamp": None,
                "duration": None,
                "response": data["text"],
            }

        segments = data["speakers"]
        total_duration = segments[-1]["timestamp"][-1]

        if self.select_speaker == "longest":

            def _segment_duration(segment):
                start, end = segment["timestamp"]
                return end - start

            longest = max(segments, key=_segment_duration)
            return {
                # Speaker label, consistent with the random branch below
                # (not the positional index of the segment).
                "speaker": longest["speaker"],
                "timestamp": longest["timestamp"],
                "response": longest["text"].strip(),
                "duration": total_duration,
            }

        # Randomly select a speaker segment
        item = random.choice(segments)
        return {
            "speaker": item["speaker"],
            "timestamp": item["timestamp"],
            "response": item["text"],
            "duration": total_duration,
        }

    def _load_hidden_states(self, asr_path):
        """Load precomputed hidden states matching ``asr_path``, if configured.

        Returns an empty dict when ``hidden_states_path`` is unset. Otherwise
        the dict always holds ``"hidden_states_path"`` and additionally
        ``"hidden_states"`` when the ``.pt`` file exists.
        """
        if self.hidden_states_path is None:
            return {}

        hidden_states_path = os.path.join(
            self.hidden_states_path,
            os.path.basename(asr_path).replace(".json", ".pt"),
        )
        return_dict = {"hidden_states_path": hidden_states_path}
        if os.path.exists(hidden_states_path):
            return_dict["hidden_states"] = torch.load(
                hidden_states_path, weights_only=True
            )
        return return_dict

    def __getitem__(self, idx):
        """Retrieve the motion features, audio tokens and metadata for ``idx``.

        Returns:
            dict with the loaded fields, or ``None`` when the sample could not
            be loaded (now or on an earlier access).
        """
        if idx in self.missing_videos:
            return None

        metadata = self.videos_parquet.iloc[idx]

        try:
            # NOTE(review): "behavior_v2_path" and "audio_vocals_path" are not
            # among the columns requested in __init__; unless the cache file
            # provides them, these fallbacks raise KeyError and the sample is
            # skipped by the handler below.
            motion_path = metadata["behavior_path"]
            try:
                if not os.path.exists(motion_path):
                    motion_path = metadata["behavior_v2_path"]
            except (TypeError, ValueError):
                # behavior_path is not a usable path value (e.g. None / NaN).
                motion_path = metadata["behavior_v2_path"]

            audio_path = metadata["seq_vocals_path"]
            if not audio_path:
                audio_path = metadata["audio_vocals_path"]

            data = {}
            if self.load_asr:
                if metadata["lang_tag_whisperv3"] == "chinese":
                    asr_path = metadata["asr_result_paraformer_path"]
                elif metadata["lang_tag_whisperv3"] == "english":
                    asr_path = metadata["asr_result_whisperv3_path"]
                else:
                    raise ValueError("Unknown language tag")

                data.update(self._load_asr(asr_path))
                data.update(self._load_instruction(None))
                data.update(self._load_hidden_states(asr_path))
            else:
                # Both keys are read below; without "duration" every sample
                # would fail with KeyError when ASR loading is disabled.
                data.update({"timestamp": None, "duration": None})

            data.update(
                self._load_audio_tokens(audio_path, data["duration"], data["timestamp"])
            )
            data.update(
                self._load_motion(
                    motion_path,
                    data["audio_lengths"],
                    data["duration"],
                    data["timestamp"],
                )
            )

            if not os.path.exists(audio_path):
                audio_path = audio_path.replace("_vocals.wav", "_input_vocals.wav")

            data["video_path"] = metadata["video_path"]
            data["audio_path"] = audio_path
            data["pose_path"] = motion_path

        except Exception as e:
            # Best-effort loading: log, remember the index, and skip the sample.
            rank_zero_info(f"Error loading data for index {idx}: {e}")
            self.missing_videos.append(idx)
            return None

        data["tasks"] = random.choice(self.tasks)
        data["motion_tokens"] = np.zeros_like(data["audio_tokens"])

        return data

    def normalize(self, poses):
        """Normalize the pose features using the mean and standard deviation."""
        if self.mean is not None:
            return (poses - self.mean) / self.std
        return poses

    def denormalize(self, poses):
        """Denormalize the pose features using the mean and standard deviation."""
        if self.mean is not None:
            return poses * self.std + self.mean
        return poses


class FaceDatasetCrop(FaceDataset):
    """FaceDataset variant that crops each sample to a fixed-size window."""

    def __init__(self, window_size=64, **kwargs):
        """
        Initializes the FaceDatasetCrop object.

        Args:
            window_size (int): Maximum number of motion frames kept per sample.
            **kwargs: Additional arguments passed to the FaceDataset parent class.
        """
        super().__init__(**kwargs)
        self.window_size = window_size

    def __getitem__(self, idx):
        """
        Retrieves the sample for the given index and crops it to the window.

        Sequences no longer than ``window_size`` are kept whole. Longer
        sequences are cropped to ``window_size`` frames: at a random offset
        when ``split == "train"``, from the start otherwise. The audio tokens
        are cropped with offset and length halved (integer division).

        Args:
            idx (int): Index of the sample in the dataset.

        Returns:
            dict | None: The parent's sample dict with ``"feature"``,
            ``"length"``, ``"audio_tokens"`` and ``"audio_lengths"`` replaced
            by their cropped versions, or ``None`` when the parent returned
            ``None``.
        """
        data = super().__getitem__(idx)
        if data is None:
            return None

        poses = data["feature"]
        audio_tokens = data["audio_tokens"]
        length = data["length"]

        if length <= self.window_size:
            start_idx = 0
            length = poses.shape[0]
        elif self.split == "train":
            # randint's upper bound is exclusive; +1 makes the last valid
            # window start (length - window_size) reachable.
            start_idx = np.random.randint(0, length - self.window_size + 1)
            length = self.window_size
        else:
            start_idx = 0
            length = self.window_size

        # Audio indices are half the motion indices
        # (presumably two motion frames per audio token - TODO confirm).
        audio_start_idx = start_idx // 2
        audio_length = length // 2

        # Extract the windowed subset of motion features and audio tokens
        poses = poses[start_idx : start_idx + self.window_size]
        audio_tokens = audio_tokens[audio_start_idx : audio_start_idx + audio_length]

        # NOTE(review): "motion_tokens" set by the parent keeps the shape of
        # the uncropped audio tokens; confirm whether it should be cropped too.
        data["feature"] = poses
        data["length"] = len(poses)
        data["audio_tokens"] = audio_tokens
        data["audio_lengths"] = len(audio_tokens)

        return data
