import os
import re

import numpy as np
import pandas as pd
from PIL import Image

import torch
from torch.utils.data import Dataset
import torchvision.transforms as T
import torchaudio


class AudioVisualDataset(Dataset):
    """Dataset yielding ``(waveform, frames, label)`` for one clip per CSV row.

    The annotations file is read with ``pd.read_csv``. Column 0 holds the clip
    name; an optional ``.mp4`` suffix is stripped. Column 1 holds the label
    string. For clip ``name``:

    * audio is read from ``<audio_dir>/<name>.wav``;
    * frames are the ``.jpg``/``.jpeg`` files in ``<img_dir>/<name>/``, taken
      in natural (numeric-aware) filename order.

    A regular item is made of three parts:

    * ``waveform``: float32 tensor of shape ``(1, L)``. It holds the first
      audio channel, resampled so that ``L`` is roughly
      ``_SAMPLES_PER_TOKEN * num_tokens``.
    * ``frames``: float32 tensor of shape
      ``(num_tokens, 3, frame_size, frame_size)``. The frames are sampled
      evenly across the clip and normalized with the ImageNet channel
      mean/std.
    * ``label``: ``0`` for a truthful clip, ``1`` for a deceptive one.

    Degraded items:

    * If the audio file is missing, or the frame directory contains no JPEG,
      zero placeholders are returned with label ``torch.tensor(-1)``.
    * If an ``IndexError``, ``FileNotFoundError`` or ``ValueError`` is raised
      while loading, a warning is printed and ``None`` is returned. The
      DataLoader presumably needs a ``collate_fn`` that drops ``None`` items.
    """

    # The waveform is resampled to ``_SAMPLES_PER_TOKEN * num_tokens / duration``
    # Hz, so its length ends up near ``_SAMPLES_PER_TOKEN * num_tokens`` samples.
    # NOTE(review): the origin of this exact value is not documented here; confirm.
    _SAMPLES_PER_TOKEN = 321.893491124260

    # Accepted label spellings, compared after strip() + lower().
    _TRUTHFUL_LABELS = frozenset({'truth', 'truthful'})
    _DECEPTIVE_LABELS = frozenset({'deception', 'deceptive', 'lie'})

    # Frame file extensions, compared case-insensitively.
    _FRAME_EXTENSIONS = ('.jpg', '.jpeg')

    # Splits a filename into alternating text / digit-run parts for natural sorting.
    _DIGIT_RUN = re.compile(r'(\d+)')

    def __init__(self, annotations_file, audio_dir, img_dir, num_tokens=64, frame_size=160):
        super().__init__()

        self.annos = pd.read_csv(annotations_file)
        self.audio_dir = audio_dir
        self.num_tokens = num_tokens

        self.img_dir = img_dir
        self.frame_size = frame_size
        self.transforms = T.Compose([
            T.ToTensor(),
            T.ConvertImageDtype(torch.float32),
            # Explicit (H, W): every frame is frame_size x frame_size, matching
            # the zero placeholders. An int would only fix the shorter edge, so
            # non-square frames could not be stacked with the placeholders.
            T.Resize((frame_size, frame_size)),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        # One video frame is sampled per token.
        self.number_of_target_frames = num_tokens

    def __len__(self):
        return len(self.annos)

    def __getitem__(self, idx):
        try:
            # str() tolerates non-string cells (e.g. numeric clip ids read by pandas).
            clip_name = str(self.annos.iloc[idx, 0]).strip()
            if clip_name.endswith('.mp4'):
                clip_name = clip_name[:-4]

            audio_path = os.path.join(self.audio_dir, clip_name + '.wav')
            if not os.path.exists(audio_path):
                print(f"Warning: Audio file {audio_path} not found. Returning empty audio tensor.")
                return torch.zeros(1, 1), self._empty_frames(), torch.tensor(-1)

            mono_waveform = self._load_audio(audio_path, clip_name)

            clip_dir = os.path.join(self.img_dir, clip_name)
            # os.listdir order is arbitrary. Sort so the evenly spaced sampling
            # below follows the frame order encoded in the filenames.
            frame_names = sorted(
                (name for name in os.listdir(clip_dir)
                 if os.path.splitext(name)[1].lower() in self._FRAME_EXTENSIONS),
                key=self._natural_key,
            )
            if not frame_names:
                print(f"Warning: No valid frames found for clip: {clip_name}. Returning empty frame tensor.")
                return mono_waveform, self._empty_frames(), torch.tensor(-1)

            face_frames = self._load_frames(clip_dir, frame_names)
            label = self._parse_label(self.annos.iloc[idx, 1])
            return mono_waveform, face_frames, label

        except (IndexError, FileNotFoundError, ValueError) as e:
            print(f"Warning: {e}")
            return None

    def _empty_frames(self):
        """Return an all-zero placeholder shaped like a regular ``frames`` item."""
        return torch.zeros(self.number_of_target_frames, 3, self.frame_size, self.frame_size)

    def _load_audio(self, audio_path, clip_name):
        """Load the first channel of ``audio_path`` and resample it to a token-based length.

        Returns a float32 tensor of shape ``(1, L)``.

        Raises:
            ValueError: if the file contains no samples.
        """
        waveform, sample_rate = torchaudio.load(audio_path)
        # Only the first channel is kept; this is not a downmix of all channels.
        waveform = waveform[0]

        clip_duration = len(waveform) / sample_rate
        if clip_duration == 0:
            raise ValueError(f"Invalid audio duration for clip: {clip_name}")

        new_sample_rate = int(self._SAMPLES_PER_TOKEN * self.num_tokens / clip_duration)
        waveform = torchaudio.functional.resample(waveform, sample_rate, new_sample_rate)
        return waveform.unsqueeze(0).type(torch.float32)

    def _load_frames(self, clip_dir, frame_names):
        """Load ``number_of_target_frames`` frames spread evenly over ``frame_names``.

        The indices come from ``np.linspace`` rounded to int. Frames therefore
        repeat when the clip has fewer frames than requested. A listed file
        that no longer exists is replaced by a zero frame.

        Returns a float32 tensor of shape
        ``(number_of_target_frames, 3, frame_size, frame_size)``.
        """
        indices = np.around(
            np.linspace(0, len(frame_names) - 1, num=self.number_of_target_frames)
        ).astype(int)

        frames = []
        for i in indices:
            path = os.path.join(clip_dir, frame_names[i])
            if os.path.exists(path):
                frames.append(self._load_frame(path))
            else:
                print(f"Warning: Image {path} not found. Using a zero frame.")
                frames.append(torch.zeros(3, self.frame_size, self.frame_size))
        return torch.stack(frames, 0).type(torch.float32)

    def _load_frame(self, path):
        """Read one image as RGB, scale it to [0, 1] and apply ``self.transforms``."""
        # The ``with`` block closes the file handle. convert('RGB') guarantees
        # the three channels Normalize expects; grayscale/RGBA would break it.
        with Image.open(path) as img:
            pixels = np.asarray(img.convert('RGB')) / 255.0
        return self.transforms(pixels)

    @classmethod
    def _parse_label(cls, raw):
        """Map a label cell to ``0`` (truthful) or ``1`` (deceptive).

        Raises:
            ValueError: if the normalized label is not a known spelling.
        """
        str_label = str(raw).strip().lower()
        if str_label in cls._TRUTHFUL_LABELS:
            return 0
        if str_label in cls._DECEPTIVE_LABELS:
            return 1
        raise ValueError(f"Undefined label: {str_label}")

    @classmethod
    def _natural_key(cls, name):
        """Sort key that orders embedded numbers numerically ('f2' < 'f10').

        The raw name is appended as a tie-breaker, so the order is
        deterministic, e.g. for 'f1' vs 'f01'.
        """
        parts = cls._DIGIT_RUN.split(name)
        # With a capturing group, the digit runs sit at the odd indices.
        return [int(p) if i % 2 else p for i, p in enumerate(parts)], name
