import json
import pathlib
import random
import warnings

import librosa
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from transformers import AutoProcessor, CLIPProcessor

from .ave import Datapool


# CREMA-D emotion categories; a sample's integer label is its index in this list.
CLASS_NAMES = "Anger Disgust Fear Happy Neutral Sad".split()

class CREMAD(Datapool):
    """CREMA-D audio-visual emotion-recognition dataset.

    Each sample pairs a ``.wav`` clip from ``<root>/AudioWAV`` with a directory
    of extracted ``frame_*.jpg`` images under ``<root>/VideoFrames``. The sample
    index is parsed from ``<root>/<split>_creamad.txt`` on first use and cached
    as ``<root>/<split>.json`` for subsequent runs.

    Args:
        split: ``"train"`` or ``"test"``.
        dataset_root_dir: Root directory containing the files described above.
        missing_info: Optional config dict (stored by reference). When
            ``missing_info["simulate_missing"]`` is truthy and ``split`` is
            ``"train"``, each sample randomly gets its video or its audio zeroed
            out (or neither). Omitting it, or omitting the key, disables the
            simulation.
    """

    # Number of frames sampled from each clip.
    _NUM_FRAMES = 3
    # Clips with fewer extracted frames than this trigger a warning.
    _MIN_EXPECTED_FRAMES = 10

    def __init__(self, split: str, dataset_root_dir: str, missing_info=None):
        assert split in ["train", "test"], f"split must be 'train' or 'test', got {split!r}"
        self.mode = split
        self.dataset_root_dir = pathlib.Path(dataset_root_dir)
        assert self.dataset_root_dir.exists(), f"Cannot find {self.dataset_root_dir}"
        # ``None`` default avoids a shared mutable ``{}``; a caller-supplied dict
        # is kept by reference so later changes to it are still observed.
        self.missing_info = {} if missing_info is None else missing_info

        cache_path = self.dataset_root_dir / f"{self.mode}.json"
        if cache_path.exists():
            with open(cache_path, "r", encoding="utf-8") as file:
                self.data = json.load(file)
        else:
            self.data = self._build_index()
            with open(cache_path, "w", encoding="utf-8") as file:
                json.dump(self.data, file)

        # Sort first so the seeded shuffle gives the same order on every run.
        self.all_ids = sorted(self.data.keys())
        random.Random(0).shuffle(self.all_ids)

        super().__init__(self.all_ids, self.mode)

        self.video_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16").image_processor
        self.audio_processor = AutoProcessor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")

        self.train_transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])

        self.val_transform = transforms.Compose([
            transforms.Resize(size=(224, 224)),
            transforms.ToTensor(),
        ])

    def _build_index(self):
        """Parse ``<split>_creamad.txt`` into ``{clip_id: {label, audio_path, video_path}}``.

        Each non-blank line is expected to look like ``<clip_id>.flv <label>``.

        Raises:
            ValueError: if a non-blank line lacks the ``".flv "`` separator.
            AssertionError: if a clip's frame directory does not exist.
        """
        audio_dir = self.dataset_root_dir / "AudioWAV"
        video_dir = self.dataset_root_dir / "VideoFrames"
        index_path = self.dataset_root_dir / f"{self.mode}_creamad.txt"

        data = {}
        with open(index_path, "r", encoding="utf-8") as file:
            for line_no, raw_line in enumerate(file, start=1):
                line = raw_line.strip()
                if not line:
                    # Tolerate blank/trailing lines instead of crashing on them.
                    continue
                filename, sep, label = line.partition(".flv ")
                if not sep:
                    raise ValueError(
                        f"{index_path}:{line_no}: expected '<clip>.flv <label>', got {line!r}"
                    )

                audio_path = audio_dir / f"{filename}.wav"
                video_path = video_dir / filename
                assert video_path.is_dir(), f"Cannot find {video_path}"

                data[filename] = {
                    "label": label,
                    "audio_path": str(audio_path),
                    "video_path": str(video_path),
                }
        return data

    def process_audio(self, audio_file_path):
        """Load a clip as 16 kHz mono and run it through the AST feature extractor.

        Returns the processor output (a mapping with ``"input_values"``) as
        PyTorch tensors.
        """
        audio_file_path = pathlib.Path(audio_file_path)
        assert audio_file_path.exists(), f"Cannot find {audio_file_path}"

        sample, _ = librosa.load(audio_file_path, sr=16000, mono=True)
        return self.audio_processor(sample, return_tensors="pt", sampling_rate=16000)

    def load_frames(self, video_file_path):
        """Select ``_NUM_FRAMES`` frames from a clip's frame directory and transform them.

        In train mode the frames are a random, temporally ordered subset, and
        the random train augmentation is applied. Otherwise they are evenly
        spaced across the clip and only resized.

        Returns:
            Tensor stacked as ``[num_frames, channels, H, W]``.

        Raises:
            FileNotFoundError: if the directory contains no ``frame_*.jpg`` files.
        """
        video_file_path = pathlib.Path(video_file_path)
        assert video_file_path.exists(), f"Cannot find {video_file_path}"

        frame_paths = sorted(str(path) for path in video_file_path.glob("frame_*.jpg"))
        if not frame_paths:
            raise FileNotFoundError(f"No frame_*.jpg files found in {video_file_path}")
        if len(frame_paths) < self._MIN_EXPECTED_FRAMES:
            warnings.warn(f"Only {len(frame_paths)} frames found in {video_file_path}")

        num_frames = self._NUM_FRAMES
        if self.mode == "train":
            # Draw distinct frames when possible; fall back to sampling with
            # replacement only for clips shorter than num_frames.
            select_index = np.sort(
                np.random.choice(
                    len(frame_paths),
                    size=num_frames,
                    replace=len(frame_paths) < num_frames,
                )
            )
        else:
            select_index = np.linspace(0, len(frame_paths) - 1, num=num_frames, dtype=int)

        frame_transform = self.train_transform if self.mode == "train" else self.val_transform
        frames = []
        for index in select_index:
            with Image.open(frame_paths[index]) as image:
                # convert() materializes the pixels before the file is closed.
                frames.append(frame_transform(image.convert("RGB")))
        return torch.stack(frames, dim=0)  # [frames, channels, H, W]

    def __len__(self):
        return len(self.sample_ids)

    def __getitem__(self, idx):
        """Load one audio-visual sample.

        :param idx: position in ``self.sample_ids``.
        :return: dict with keys
            - ``"video"``: CLIP-preprocessed ``pixel_values`` for the selected
              frames; presumably ``(3, 3, 224, 224)`` = (frames, C, H, W).
              TODO confirm against the processor config.
            - ``"audio"``: AST ``input_values`` with the leading batch dim
              squeezed. Shape depends on the AST processor config; TODO confirm.
            - ``"label"``: 0-dim ``torch.long`` index into ``CLASS_NAMES``.
            - ``"filename"``: name of the clip's frame directory.
            - ``"missing_type"``: 0 = video zeroed, 1 = audio zeroed, 2 = none.
        """
        sample_id = self.sample_ids[idx]
        entry = self.data[sample_id]

        spectrogram = self.process_audio(entry["audio_path"])["input_values"].squeeze(0)

        frames = self.load_frames(entry["video_path"])
        # Frames are already scaled to [0, 1] by ToTensor, so skip rescaling.
        frames = self.video_processor(frames, return_tensors="pt", do_rescale=False)['pixel_values']

        label = torch.tensor(CLASS_NAMES.index(entry["label"]), dtype=torch.long)

        # Simulate a missing modality during training (0 -> video, 1 -> audio, 2 -> none).
        simulate_missing_type = 2
        if self.mode == "train" and self.missing_info.get('simulate_missing', False):
            simulate_missing_type = random.choice([0, 1, 2])
            if simulate_missing_type == 0:
                frames = torch.zeros_like(frames, dtype=torch.float)
            elif simulate_missing_type == 1:
                spectrogram = torch.zeros_like(spectrogram, dtype=torch.float)

        # NOTE: manual toggles; uncomment one pair to force a modality to be
        # missing for every sample (e.g. for missing-modality evaluation).
        # Video Missing
        # frames = torch.zeros(frames.size()).float()
        # simulate_missing_type = 0

        # Audio Missing
        # spectrogram = torch.zeros(spectrogram.size()).float()
        # simulate_missing_type = 1

        sample = {
            "video": frames,
            "audio": spectrogram,
            "label": label,
            # Path.name is separator-agnostic (the old split("/") broke on Windows).
            "filename": pathlib.Path(entry["video_path"]).name,
            "missing_type": simulate_missing_type
        }
        return sample
