import json
import pathlib
import random
import librosa
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset
from transformers import AutoProcessor, CLIPProcessor
from .ave import Datapool


# Kinetics-Sounds class labels. A sample's integer label is the position of its
# label string in this list, so the order must not change.
CLASS_NAMES = [
    "bowling",
    "ripping paper",
    "playing xylophone",
    "playing organ",
    "playing bass guitar",
    "tapping guitar",
    "playing accordion",
    "playing guitar",
    "dribbling basketball",
    "playing piano",
    "playing bagpipes",
    "playing saxophone",
    "playing harmonica",
    "tickling",
    "blowing nose",
    "tapping pen",
    "chopping wood",
    "blowing out candles",
    "tap dancing",
    "stomping grapes",
    "playing clarinet",
    "laughing",
    "playing trombone",
    "shoveling snow",
    "playing trumpet",
    "playing violin",
    "singing",
    "shuffling cards",
    "playing keyboard",
    "mowing lawn",
    "playing drums",
]


class KineticsSound(Datapool):
    """Kinetics-Sounds audio-visual classification dataset.

    Expected layout under ``<dataset_root_dir>/<split>/``:
      * ``audio/<label>/<vid>.wav``            -- one audio file per clip
      * ``video/<label>/<vid>/frame_*.jpg``    -- pre-extracted frames per clip
    The discovered index is cached as ``<dataset_root_dir>/<split>.json`` and
    reused on later runs instead of re-scanning the directories.

    Each item is a dict with CLIP-processed ``video`` frames, an AST-processed
    ``audio`` spectrogram, an integer ``label`` (index into ``CLASS_NAMES``),
    the clip ``filename`` and a ``missing_type`` flag
    (0 -> video zeroed, 1 -> audio zeroed, 2 -> nothing zeroed).
    """

    def __init__(self, split: str, dataset_root_dir: str, missing_info=None):
        """
        :param split: "train" or "test".
        :param dataset_root_dir: directory containing the ``train``/``test`` sub-directories.
        :param missing_info: optional options dict. When
            ``missing_info.get("simulate_missing")`` is truthy, training samples
            randomly get one modality zeroed out. Defaults to no simulation. The
            dict is kept by reference, so later changes by the caller are seen.
        """
        assert split in ["train", "test"], f"split must be 'train' or 'test', got {split!r}"
        self.mode = split
        self.dataset_root_dir = pathlib.Path(dataset_root_dir) / self.mode
        assert self.dataset_root_dir.exists(), f"Cannot find {self.dataset_root_dir}"
        # Fresh dict per instance: never share a mutable default argument.
        self.missing_info = {} if missing_info is None else missing_info

        cache_path = self.dataset_root_dir.parent / f"{self.mode}.json"
        if cache_path.exists():
            with open(cache_path, 'r', encoding="utf-8") as file:
                self.data = json.load(file)
        else:
            self.data = self._scan_dataset()
            with open(cache_path, 'w', encoding="utf-8") as file:
                json.dump(self.data, file)

        # Sort first so the seeded shuffle yields the same order on every machine.
        self.all_ids = sorted(self.data.keys())
        random.Random(0).shuffle(self.all_ids)

        # NOTE(review): Datapool presumably sets self.sample_ids from these ids/mode -- confirm in .ave.
        super().__init__(self.all_ids, self.mode)

        self.video_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16").image_processor
        self.audio_processor = AutoProcessor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")

        self.train_transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])

        self.val_transform = transforms.Compose([
            transforms.Resize(size=(224, 224)),
            transforms.ToTensor(),
        ])

    def _scan_dataset(self):
        """Index every ``audio/<label>/<vid>.wav`` file and pair it with its frame directory.

        :return: dict mapping video id -> {"label", "audio_path", "video_path"}.
        """
        audio_dir = self.dataset_root_dir / "audio"
        video_dir = self.dataset_root_dir / "video"
        data = {}
        for audio_path in audio_dir.rglob("*.wav"):
            vid = audio_path.stem
            label = audio_path.parent.name
            video_path = video_dir / label / vid
            assert video_path.is_dir(), f"Cannot find {video_path}"
            data[vid] = {
                "label": label,
                "audio_path": str(audio_path),
                "video_path": str(video_path)
            }
        return data

    def process_audio(self, audio_file_path):
        """Load an audio file as 16 kHz mono and run the AST feature extractor on it.

        :param audio_file_path: path to the audio file.
        :return: the processor output (with ``return_tensors="pt"``).
        """
        audio_file_path = pathlib.Path(audio_file_path)
        assert audio_file_path.exists(), f"Cannot find {audio_file_path}"

        sample, _ = librosa.load(audio_file_path, sr=16000, mono=True)
        return self.audio_processor(sample, return_tensors="pt", sampling_rate=16000)

    def load_frames(self, video_file_path, num_frames=3):
        """Load ``num_frames`` frames from a directory of ``frame_*.jpg`` images.

        Training draws a random, temporally ordered subset of frames. Frames are
        drawn with replacement only when the clip is shorter than ``num_frames``.
        Evaluation takes evenly spaced frames.

        :param video_file_path: directory containing ``frame_*.jpg`` files.
        :param num_frames: number of frames to return.
        :return: tensor [num_frames, channels, H, W] produced by ``ToTensor``.
        :raises ValueError: if the directory contains no ``frame_*.jpg`` files.
        """
        video_file_path = pathlib.Path(video_file_path)
        assert video_file_path.exists(), f"Cannot find {video_file_path}"

        frame_paths = sorted(str(path) for path in video_file_path.glob("frame_*.jpg"))
        if not frame_paths:
            raise ValueError(f"No frame_*.jpg files found in {video_file_path}")

        if self.mode == "train":
            # Sampling without replacement fails on clips shorter than num_frames;
            # fall back to sampling with replacement so every clip still yields num_frames.
            select_index = np.random.choice(len(frame_paths), size=num_frames,
                                            replace=len(frame_paths) < num_frames)
            select_index.sort()
            transform = self.train_transform
        else:
            select_index = np.linspace(0, len(frame_paths) - 1, num=num_frames, dtype=int)
            transform = self.val_transform

        frames = []
        for i in select_index:
            with open(frame_paths[i], "rb") as f:
                # convert() forces PIL's lazy decode while the file is still open.
                image = Image.open(f).convert("RGB")
            frames.append(transform(image))
        return torch.stack(frames, dim=0)  # [frames, channels, H, W]

    def __len__(self):
        """Number of samples (``self.sample_ids`` is presumably set by Datapool)."""
        return len(self.sample_ids)

    def __getitem__(self, idx):
        """Return one audio-visual sample.

        :param idx: position in ``self.sample_ids``.
        :return: dict with
            video: CLIP image-processor ``pixel_values`` of shape
                [num_frames, channels, H, W] (presumably [3, 3, 224, 224] for
                clip-vit-base-patch16 -- confirm),
            audio: AST ``input_values`` with the batch dim squeezed
                (presumably [1024, 128] float32 -- confirm),
            label: scalar long tensor, index into CLASS_NAMES,
            filename: name of the clip's frame directory (the video id),
            missing_type: 0 -> video zeroed, 1 -> audio zeroed, 2 -> none.
        """
        sample_id = self.sample_ids[idx]
        record = self.data[sample_id]

        spectrogram = self.process_audio(record["audio_path"])["input_values"].squeeze(0)

        frames = self.load_frames(record["video_path"])
        # ToTensor already scaled pixels to [0, 1], so the processor must not rescale again.
        frames = self.video_processor(frames, return_tensors="pt", do_rescale=False)['pixel_values']

        label = torch.tensor(CLASS_NAMES.index(record["label"]), dtype=torch.long)

        # Simulate missing modalities (training only, opt-in via missing_info).
        # 0 -> video, 1 -> audio, 2 -> none
        simulate_missing_type = 2
        if self.mode == "train" and self.missing_info.get('simulate_missing', False):
            simulate_missing_type = random.choice([0, 1, 2])
            if simulate_missing_type == 0:
                frames = torch.zeros_like(frames, dtype=torch.float)
            elif simulate_missing_type == 1:
                spectrogram = torch.zeros_like(spectrogram, dtype=torch.float)

        # Manual toggles: uncomment one pair to force a modality to be missing.
        # Video Missing
        # frames = torch.zeros(frames.size()).float()
        # simulate_missing_type = 0

        # Audio Missing
        # spectrogram = torch.zeros(spectrogram.size()).float()
        # simulate_missing_type = 1

        sample = {
            "video": frames,
            "audio": spectrogram,
            "label": label,
            # Path.name is separator-agnostic, unlike splitting on "/".
            "filename": pathlib.Path(record["video_path"]).name,
            "missing_type": simulate_missing_type
        }
        return sample
