# The original code was too messy, so this file has been rewritten.
# The key requirements are that the train/test split stays consistent with AVoiD-DF, and that the missing-modality simulation modes are read in the same way.

import torch
import cv2
import librosa
import numpy as np
import pandas as pd
import os
import argparse
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from PIL import Image
from tqdm import tqdm

# Number of video frames kept per sample; shorter clips are padded and longer
# clips are cropped to this length. (The value 8 was presumably an earlier
# setting -- TODO confirm.)
TARGET_FRAMES = 16
# Frame rate (frames per second) that videos are resampled to when frame
# indices are selected. Together with TARGET_FRAMES it also fixes the audio
# clip length: AUDIO_SAMPLING_RATE * TARGET_FRAMES // VIDEO_FPS samples.
VIDEO_FPS = 25
# Side length in pixels; every frame is resized to FRAME_SIZE x FRAME_SIZE.
FRAME_SIZE = 224
# Sampling rate in Hz that audio is resampled to when it is loaded.
AUDIO_SAMPLING_RATE = 16000


def _pad_or_crop_frames(frames, target_frames):
    """
    Pads or crops the sequence of frames to a target length.

    Args:
        frames (np.ndarray): The array of video frames.
        target_frames (int): The desired number of frames.

    Returns:
        np.ndarray: The processed array of frames.
    """
    num_frames = len(frames)
    if num_frames < target_frames:
        # Pad the frames by repeating the last frame.
        padding_count = target_frames - num_frames
        last_frame = frames[-1:]
        padding = np.repeat(last_frame, padding_count, axis=0)
        processed_frames = np.concatenate([frames, padding], axis=0)
    elif num_frames > target_frames:
        # Crop the frames from the center.
        start_index = (num_frames - target_frames) // 2
        processed_frames = frames[start_index:start_index + target_frames]
    else:
        processed_frames = frames
    return processed_frames


def extract_features(video_path, audio_only=False, video_only=False):
    """
    Extracts a fixed-length clip of video frames and audio from a video file.

    Video frames are sampled at VIDEO_FPS, resized to FRAME_SIZE x FRAME_SIZE,
    converted with `transforms.ToTensor`, and padded/cropped to TARGET_FRAMES
    frames. Audio is loaded at AUDIO_SAMPLING_RATE and padded/cropped to the
    duration of TARGET_FRAMES frames at VIDEO_FPS. A modality that is skipped,
    or a video from which no frame could be read, is returned as all zeros.

    NOTE(review): frames are cropped around the center of the clip while audio
    is cropped from the beginning, so for long files the two windows do not
    cover the same time span. Left unchanged here to keep the produced data
    identical -- confirm whether this is intended.

    Args:
        video_path (str): The path to the video file.
        audio_only (bool): If True, only extract audio. Defaults to False.
        video_only (bool): If True, only extract video. Defaults to False.

    Returns:
        tuple: `(video, audio)` float tensors. `video` has shape
            (TARGET_FRAMES, FRAME_SIZE, FRAME_SIZE, 3) and `audio` has shape
            (AUDIO_SAMPLING_RATE * TARGET_FRAMES // VIDEO_FPS,).

    Raises:
        FileNotFoundError: If the video file cannot be opened by OpenCV.
    """
    # Expected audio length for the target number of video frames.
    target_audio_length = AUDIO_SAMPLING_RATE * TARGET_FRAMES // VIDEO_FPS

    # Zero-filled float32 placeholders, used as-is for a skipped modality.
    video_frames = np.zeros((TARGET_FRAMES, FRAME_SIZE, FRAME_SIZE, 3), dtype=np.float32)
    audio_waveform = np.zeros(target_audio_length, dtype=np.float32)

    # --- Video Processing ---
    if not audio_only:
        # Define the image transformation pipeline.
        transform = transforms.Compose([
            transforms.Resize((FRAME_SIZE, FRAME_SIZE)),
            transforms.ToTensor(),
        ])

        frames = []
        cap = cv2.VideoCapture(video_path)
        try:
            if not cap.isOpened():
                raise FileNotFoundError(f"Cannot open video file: {video_path}")

            original_fps = cap.get(cv2.CAP_PROP_FPS)
            if not (np.isfinite(original_fps) and original_fps > 0):
                # Some containers report no usable FPS; a zero/NaN step would
                # make np.arange fail, so assume the file is already at the
                # target rate.
                original_fps = VIDEO_FPS

            frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT)
            if not (np.isfinite(frame_count) and frame_count > 0):
                frame_count = 0

            # Source-frame indices that resample the video to VIDEO_FPS.
            frame_indices = np.arange(0, frame_count, original_fps / VIDEO_FPS).astype(int)

            for i in frame_indices:
                cap.set(cv2.CAP_PROP_POS_FRAMES, int(i))
                ret, frame = cap.read()
                if not ret:
                    # Skip frames that cannot be decoded.
                    continue
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frame_tensor = transform(Image.fromarray(frame_rgb))
                # Rearrange from (C, H, W) to (H, W, C) for consistency.
                frames.append(frame_tensor.permute(1, 2, 0).numpy())
        finally:
            # Always free the capture handle, even if decoding raised.
            cap.release()

        if frames:
            video_frames = _pad_or_crop_frames(np.array(frames), TARGET_FRAMES)

    # --- Audio Processing ---
    if not video_only:
        y, _ = librosa.load(video_path, sr=AUDIO_SAMPLING_RATE)

        if len(y) > target_audio_length:
            # If audio is longer, crop from the beginning.
            audio_waveform = y[:target_audio_length]
        else:
            # If audio is shorter, pad with zeros.
            audio_waveform[:len(y)] = y

    return torch.FloatTensor(video_frames), torch.FloatTensor(audio_waveform)


def _get_data_splits(data_map):
    """
    Splits data into reproducible training and testing sets based on categories.

    Args:
        data_map (dict): A dictionary where keys are category names and
                         values are lists of data samples.

    Returns:
        tuple: A tuple containing the training set and testing set.
    """
    train_set, test_set = [], []
    for category_data in data_map.values():
        total_samples = len(category_data)
        train_len = int(total_samples * 0.7)
        
        # Ensure reproducible split by taking the first 70% for training.
        train_set.extend(category_data[:train_len])
        test_set.extend(category_data[train_len:])
    
    # Shuffle the combined sets to ensure random batching during training.
    np.random.shuffle(train_set)
    np.random.shuffle(test_set)
    return train_set, test_set


class FAVDataset(Dataset):
    """
    A PyTorch Dataset for the FakeAVCeleb dataset.

    On construction it reads `meta_data.csv` from the data root, groups the
    usable rows by manipulation type, and keeps either the training or the
    testing portion of the split. Each item is decoded from disk on access.
    """

    # Values of `opt.mode` that select a split.
    _SUPPORTED_MODES = ('train', 'test')

    def __init__(self, opt, audio_only=False, video_only=False):
        """
        Initializes the dataset.

        Args:
            opt (argparse.Namespace): Configuration options. Must contain 'dataroot' and 'mode'.
            audio_only (bool): If True, only audio data will be loaded. Defaults to False.
            video_only (bool): If True, only video data will be loaded. Defaults to False.

        Raises:
            ValueError: If both `audio_only` and `video_only` are True, or if
                `opt.mode` is neither 'train' nor 'test'.
        """
        super().__init__()
        self.opt = opt
        self.dataroot = opt.dataroot
        self.mode = opt.mode
        self.audio_only = audio_only
        self.video_only = video_only

        # The two single-modality flags are mutually exclusive.
        if self.audio_only and self.video_only:
            raise ValueError("Both audio_only and video_only cannot be True simultaneously.")

        # Fail fast on a bad mode, before the metadata is read and categorized.
        if self.mode not in self._SUPPORTED_MODES:
            raise ValueError(f"Mode '{self.mode}' is not supported. Use 'train' or 'test'.")

        metadata_path = os.path.join(self.dataroot, 'meta_data.csv')
        self.metadata = pd.read_csv(metadata_path)

        self.samples = self._load_samples()

    def _load_samples(self):
        """
        Loads and splits the dataset samples according to the specified mode.

        Rows whose 'washed' column equals 0 and rows whose 'type' is not one of
        the four known categories are skipped.

        Returns:
            list: One dict per sample with keys 'path', 'v_label' and
                'a_label'; a label is 1.0 when that modality is real.

        Raises:
            ValueError: If `self.mode` is neither 'train' nor 'test'.
        """
        # Categorize data based on the type of manipulation.
        data_map = {
            "RealVideo-RealAudio": [],
            "FakeVideo-RealAudio": [],
            "FakeVideo-FakeAudio": [],
            "RealVideo-FakeAudio": []
        }

        print("Loading and categorizing metadata...")
        for _, row in tqdm(self.metadata.iterrows(), total=self.metadata.shape[0]):
            if row['washed'] == 0:
                continue

            label_type = row['type']
            if label_type not in data_map:
                continue

            # Metadata paths carry a 'FakeAVCeleb/' component that is dropped
            # before joining with the data root.
            video_path = os.path.join(self.dataroot, row['path'].replace('FakeAVCeleb/', ''), row['name'])
            data_map[label_type].append({
                "path": video_path,
                "v_label": float("RealVideo" in label_type),
                "a_label": float("RealAudio" in label_type)
            })

        train_set, test_set = _get_data_splits(data_map)

        if self.mode == 'train':
            print(f"Loaded {len(train_set)} samples for training.")
            return train_set
        if self.mode == 'test':
            print(f"Loaded {len(test_set)} samples for testing.")
            return test_set
        raise ValueError(f"Mode '{self.mode}' is not supported. Use 'train' or 'test'.")

    def __len__(self):
        """
        Returns the total number of samples in the dataset.
        """
        return len(self.samples)

    def __getitem__(self, index):
        """
        Retrieves a single sample from the dataset.

        If decoding the file fails for any reason, the error is printed and
        all-zero video and audio tensors are returned with the sample's labels.

        Args:
            index (int): The index of the sample to retrieve.

        Returns:
            tuple: A tuple containing (video, audio, v_label, a_label, av_label).
        """
        sample_info = self.samples[index]
        video_path = sample_info["path"]
        v_label = sample_info["v_label"]
        a_label = sample_info["a_label"]

        # The cross-modal label is 1.0 only if both modalities are real.
        av_label = v_label * a_label

        try:
            video_tensor, audio_tensor = extract_features(
                video_path,
                self.audio_only,
                self.video_only
            )
        except Exception as e:
            # Deliberate best effort: one unreadable file must not abort a
            # whole epoch, so it is reported and replaced by a zero sample.
            print(f"Error processing file {video_path}: {e}. Returning a dummy sample.")
            video_tensor = torch.zeros((TARGET_FRAMES, FRAME_SIZE, FRAME_SIZE, 3))
            audio_tensor = torch.zeros(AUDIO_SAMPLING_RATE * TARGET_FRAMES // VIDEO_FPS)

        return video_tensor, audio_tensor, v_label, a_label, av_label