import os
import cv2
import numpy as np
from tqdm import tqdm
import pickle

import torch
from torch.utils.data import Dataset
import torchvision.transforms as T
import torchaudio
from PIL import Image
from . import DATASETS
from .utils import DataCollection, af_pad_sequence
from moviepy.editor import VideoFileClip
from collections import namedtuple


@DATASETS.register("RealLife")
class RealLife_Dataset(Dataset):
    """RealLife clip dataset yielding visual and/or audio samples with labels.

    Expected layout: ``<dataset_path>/Clips/<tag>/<video files>`` for each tag
    in ``("Deceptive", "Truthful")``. During construction, the audio track of
    every clip is extracted (only if not already present) to
    ``<dataset_path>/Audios/<tag>/<stem>_audio.wav``.

    Labels: ``1`` for ``Deceptive`` clips, ``0`` for ``Truthful`` clips.
    """

    # After resampling in ``read_aud``, a clip has approximately
    # ``_AUDIO_SAMPLES_PER_FRAME * n_sample_frames`` samples regardless of its
    # duration. NOTE(review): origin of this exact value is not documented
    # here — presumably tuned to a downstream audio encoder; confirm.
    _AUDIO_SAMPLES_PER_FRAME = 321.893491124260

    def __init__(
            self,
            dataset_path,
            transform,
            frame_size=160,
            n_sample_frames=64,
            modalities=("visual", "audio"),
    ) -> None:
        """Index all clips and make sure an extracted audio file exists for each.

        Args:
            dataset_path: root directory containing ``Clips/``; ``Audios/`` is
                created under it.
            transform: callable applied in ``read_vid`` to a list of PIL frames
                with ``return_tensors="pt"`` and ``video_seq_length=...``; it
                must return a mapping with a ``"pixel_values"`` tensor.
            frame_size: stored as ``self.frame_size``; not read anywhere in
                this class.
            n_sample_frames: number of video frames sampled per clip; also
                scales the audio resampling rate in ``read_aud``.
            modalities: which of ``"visual"`` / ``"audio"`` ``__getitem__``
                loads.

        Raises:
            Whatever ``extract_audio_from_video`` raises for a clip whose audio
            cannot be extracted.
        """
        super().__init__()
        self.path = dataset_path
        self.frame_size = frame_size
        self.n_sample_frames = n_sample_frames
        # Copy into a fresh list: the default is an immutable tuple so it can
        # never be shared/mutated across instances, while the attribute stays a
        # list as before.
        self.modalities = list(modalities)
        self.transform = transform

        self.tags = ["Deceptive", "Truthful"]
        self.clip_files = []
        self.audio_files = []
        self.labels = []

        for tag in sorted(self.tags):
            clip_dir = os.path.join(self.path, "Clips", tag)
            clip_paths = sorted(
                os.path.join(clip_dir, name) for name in os.listdir(clip_dir)
            )

            # Created once per tag (was once per video).
            audio_save_dir = os.path.join(self.path, "Audios", tag)
            os.makedirs(audio_save_dir, exist_ok=True)

            label = 1 if tag == "Deceptive" else 0
            for video_path in clip_paths:
                audio_path = os.path.join(
                    audio_save_dir, self._audio_filename(video_path)
                )
                # Extraction is cached on disk: only done when the WAV is missing.
                if not os.path.exists(audio_path):
                    self.extract_audio_from_video(video_path, audio_path)

                self.clip_files.append(video_path)
                self.audio_files.append(audio_path)
                self.labels.append(label)

        assert len(self.labels) == len(self.clip_files) == len(self.audio_files)

    @staticmethod
    def _audio_filename(video_path):
        """Return ``<stem>_audio.wav`` for the basename of ``video_path``.

        Uses ``os.path.splitext`` so any extension (``.mp4``, ``.MP4``, ``.avi``)
        is replaced by ``.wav``, rather than only a lowercase ``.mp4``.
        """
        stem, _ext = os.path.splitext(os.path.basename(video_path))
        return stem + "_audio.wav"

    def extract_audio_from_video(self, video_path: str, audio_path: str) -> None:
        """Write the audio track of ``video_path`` to ``audio_path``.

        The audio is written at ``fps=16000`` with codec ``pcm_s16le``. On any
        failure the paths and error are printed and the exception is re-raised.
        The moviepy clips are closed on every path, including failures.

        Raises:
            ValueError: if the video has no audio track.
        """
        video_clip = None
        try:
            video_clip = VideoFileClip(video_path)
            audio_clip = video_clip.audio
            if audio_clip is None:
                raise ValueError(f"video {video_path} has no audio track")
            try:
                audio_clip.write_audiofile(
                    audio_path,
                    fps=16000,
                    codec="pcm_s16le",
                    verbose=False,
                    logger=None
                )
            finally:
                audio_clip.close()
        except Exception as e:
            print(f"failed: {video_path} -> {audio_path}")
            print(f"problem: {str(e)}")
            raise
        finally:
            # Previously leaked when writing failed or the clip had no audio.
            if video_clip is not None:
                video_clip.close()

    def __len__(self):
        """Return the number of clips."""
        return len(self.labels)

    def read_aud(self, path):
        """Load the first channel of ``path`` and resample it to a fixed length.

        The target sample rate is chosen so the resampled waveform contains
        roughly ``_AUDIO_SAMPLES_PER_FRAME * n_sample_frames`` samples,
        independent of the clip duration.

        Returns:
            float32 tensor of shape ``(1, num_samples)``.

        Raises:
            ValueError: if the audio file contains no samples.
        """
        waveform, sample_rate = torchaudio.load(path)
        # Keep only the first channel (no down-mix of multi-channel audio).
        waveform = waveform[0]
        num_samples = waveform.shape[0]
        if num_samples == 0:
            # Would otherwise surface as a ZeroDivisionError below.
            raise ValueError(f"audio file {path} contains no samples")

        clip_duration = num_samples / sample_rate
        new_sample_rate = int(
            np.round(
                self._AUDIO_SAMPLES_PER_FRAME * self.n_sample_frames / clip_duration,
                decimals=0,
            )
        )
        waveform = torchaudio.functional.resample(waveform, sample_rate, new_sample_rate)
        # `Tensor.type()` is not in-place; the result must be kept.
        return waveform.unsqueeze(0).to(torch.float32)

    def read_vid(self, path):
        """Decode ``path`` and return ``n_sample_frames`` evenly spaced frames.

        All frames are decoded and converted BGR->RGB. Indices are evenly
        spaced over the clip via ``np.linspace`` (truncated to int), the chosen
        frames are wrapped as PIL images and passed to ``self.transform``.

        Returns:
            ``processed["pixel_values"]`` with its leading dimension squeezed,
            as float32; or zeros of shape ``(n_sample_frames, 3, 224, 224)``
            when no frame could be decoded.
        """
        capture = cv2.VideoCapture(path)
        frames = []
        try:
            while capture.isOpened():
                ok, frame = capture.read()
                if not ok:
                    break
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        finally:
            capture.release()

        if not frames:
            # NOTE(review): the 224x224 fallback does not follow self.frame_size
            # or the transform's output size — confirm it matches real samples.
            return torch.zeros((self.n_sample_frames, 3, 224, 224))

        # linspace ends at len(frames) - 1, so every index is in range.
        indices = np.linspace(0, len(frames) - 1, num=self.n_sample_frames, dtype=int)
        pil_frames = [Image.fromarray(frames[idx]) for idx in indices]
        processed = self.transform(
            pil_frames,
            return_tensors="pt",
            video_seq_length=self.n_sample_frames
        )
        return processed["pixel_values"].squeeze(0).float()

    def __getitem__(self, index):
        """Load the requested modalities for clip ``index``.

        Modalities not listed in ``self.modalities`` are returned as ``None``.
        """
        audio = (
            self.read_aud(self.audio_files[index])
            if "audio" in self.modalities
            else None
        )
        video = (
            self.read_vid(self.clip_files[index])
            if "visual" in self.modalities
            else None
        )
        return DataCollection(visual=video, audio=audio, label=self.labels[index])
