import json
import pathlib
import random
import librosa
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset
from transformers import AutoProcessor, CLIPProcessor


class Datapool(Dataset):
    """Pool of sample ids split into labeled and unlabeled subsets.

    The pool tracks three id lists derived from ``all_ids``:

    * ``labeled_ids``   - ids whose labels have been queried so far,
    * ``unlabeled_ids`` - the remaining ids, in ``all_ids`` order,
    * ``sample_ids``    - the ids currently exposed through ``__len__``
      (switched by :meth:`train` and :meth:`query`).

    Ids must be unique and hashable. Subclasses are expected to implement
    ``__getitem__`` and index into ``sample_ids``.
    """

    def __init__(self, all_ids, mode):
        """Store the ids and start with everything unlabeled.

        :param all_ids: list of unique, hashable sample ids.
        :param mode: free-form mode tag; overwritten by ``train()``/``query()``.
        """
        self.all_ids = all_ids
        self.mode = mode
        assert len(set(self.all_ids)) == len(self.all_ids), "dataset has duplicated ids"
        self.unlabeled_ids = self.all_ids.copy()
        self.labeled_ids = []
        self.sample_ids = self.all_ids.copy()  # default for val and test datapool

    def _refresh_unlabeled(self):
        """Recompute ``unlabeled_ids`` as ``all_ids`` minus ``labeled_ids``."""
        # A set gives O(1) membership tests; scanning the labeled list for
        # every id made each refresh quadratic in the pool size.
        labeled = set(self.labeled_ids)
        self.unlabeled_ids = [
            sample_id for sample_id in self.all_ids if sample_id not in labeled
        ]

    def initialize(self, query_budget: int):
        """Label the first ``query_budget`` unlabeled ids.

        The selection itself is not random: it takes the head of
        ``unlabeled_ids``, so any randomness must come from the order of
        ``all_ids`` supplied by the caller.

        :param query_budget: number of labels queried each round.
        """
        self.labeled_ids = self.unlabeled_ids[:query_budget]
        self._refresh_unlabeled()

    def query_for_label(self, queried_ids: list):
        """Move ``queried_ids`` (chosen by a query strategy) to the labeled set.

        ``labeled_ids`` is extended in place, so a ``sample_ids`` reference
        obtained via :meth:`train` sees the new ids as well.

        :param queried_ids: ids to mark as labeled.
        :raises AssertionError: if the labeled/unlabeled split no longer
            partitions ``all_ids`` (e.g. an id was queried twice or is unknown).
        """
        self.labeled_ids += queried_ids
        self._refresh_unlabeled()
        assert len(self.labeled_ids) + len(self.unlabeled_ids) == len(self.all_ids)

    def query(self):
        """Expose the unlabeled ids, for label querying."""
        self.mode = "query"
        print("dataset for querying")
        self.sample_ids = self.unlabeled_ids

    def train(self):
        """Expose the labeled ids, for model training."""
        self.mode = "train"
        print("dataset for training")
        self.sample_ids = self.labeled_ids

    def __len__(self):
        """Return the number of ids currently exposed via ``sample_ids``."""
        return len(self.sample_ids)

    def __getitem__(self, idx):
        """Load the sample at position ``idx`` of ``sample_ids``.

        :raises NotImplementedError: always; subclasses must override this.
        """
        # Failing loudly beats silently handing ``None`` samples to a loader.
        raise NotImplementedError("Subclasses of Datapool should implement __getitem__.")
        

# Event class names; a sample's integer label is its position in this list.
CLASS_NAMES = [
    'Accordion',
    'Acoustic guitar',
    'Baby cry, infant cry',
    'Banjo',
    'Bark',
    'Bus',
    'Cat',
    'Chainsaw',
    'Church bell',
    'Clock',
    'Female speech, woman speaking',
    'Fixed-wing aircraft, airplane',
    'Flute',
    'Frying (food)',
    'Goat',
    'Helicopter',
    'Horse',
    'Male speech, man speaking',
    'Mandolin',
    'Motorcycle',
    'Race car, auto racing',
    'Rodents, rats, mice',
    'Shofar',
    'Toilet flush',
    'Train horn',
    'Truck',
    'Ukulele',
    'Violin, fiddle',
]

class AVE(Datapool):
    """Audio-visual event dataset built on top of ``Datapool``.

    Each sample pairs a ``<filename>.wav`` file under ``<root>/<split>/audio``
    with a directory of ``frame_*.jpg`` images under
    ``<root>/<split>/video/<filename>``. Sample metadata is parsed from
    ``<root>/<split>Set.txt`` on first use and cached as ``<root>/<split>.json``.

    Note that ``self.mode`` starts as the split name but is overwritten by
    ``Datapool.train()`` / ``Datapool.query()``; frame sampling, augmentation
    and missing-modality simulation all key off ``self.mode == "train"``.
    """

    def __init__(self, split: str, dataset_root_dir: str, missing_info=None):
        """Index the split and set up the audio/video preprocessors.

        :param split: ``"train"`` or ``"test"``.
        :param dataset_root_dir: directory containing the split folders and
            the annotation / cache files.
        :param missing_info: optional mapping; when its ``'simulate_missing'``
            entry is truthy, training samples randomly have one modality
            zeroed out. Defaults to an empty mapping (no simulation).
        """
        assert split in ["train", "test"]
        self.mode = split
        self.dataset_root_dir = pathlib.Path(dataset_root_dir) / self.mode
        assert self.dataset_root_dir.exists()
        self.data = {}
        audio_dir = self.dataset_root_dir / "audio"
        video_dir = self.dataset_root_dir / "video"
        # ``None`` sentinel instead of a shared mutable ``{}`` default.
        self.missing_info = {} if missing_info is None else missing_info

        cache_path = self.dataset_root_dir.parent / f"{self.mode}.json"
        annotation_path = self.dataset_root_dir.parent / f"{self.mode}Set.txt"

        if cache_path.exists():
            with open(cache_path, 'r') as file:
                self.data = json.load(file)
        else:
            with open(annotation_path, 'r') as file:
                lines = file.readlines()

            for line in lines:
                line = line.strip()
                if not line:
                    # Tolerate blank / trailing empty lines in the annotations.
                    continue
                # Fields are '&'-separated: label, filename, <unused>, start, end.
                parts = line.split('&')
                label = parts[0]
                filename = parts[1]
                start_time = int(parts[3])
                end_time = int(parts[4])

                audio_path = audio_dir / f"{filename}.wav"
                video_path = video_dir / filename
                assert video_path.is_dir(), f"Cannot find {video_path}"

                self.data[filename] = {
                    "label": label,
                    "audio_path": str(audio_path),
                    "video_path": str(video_path),
                    "start_time": start_time,
                    "end_time": end_time,
                }

            with open(cache_path, 'w') as file:
                json.dump(self.data, file)

        # Sort first so the seeded shuffle yields the same order on every run.
        self.all_ids = sorted(self.data.keys())
        random.Random(0).shuffle(self.all_ids)

        super().__init__(self.all_ids, self.mode)

        self.video_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16").image_processor
        self.audio_processor = AutoProcessor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")

        # Augmenting transform used while ``self.mode == "train"``.
        self.train_transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])

        # Deterministic transform used in every other mode.
        self.val_transform = transforms.Compose([
            transforms.Resize(size=(224, 224)),
            transforms.ToTensor(),
        ])

    def process_audio(self, audio_file_path):
        """Load a waveform at 16 kHz mono and run it through the audio processor.

        :param audio_file_path: path to the audio file.
        :return: the processor output (indexed with ``"input_values"`` by
            ``__getitem__``).
        """
        audio_file_path = pathlib.Path(audio_file_path)
        assert audio_file_path.exists()

        waveform, _ = librosa.load(audio_file_path, sr=16000, mono=True)
        return self.audio_processor(waveform, return_tensors="pt", sampling_rate=16000)

    def load_frames(self, video_file_path):
        """Pick 3 frames from a frame directory and stack them into a tensor.

        In train mode the frames are drawn at random (kept in temporal order)
        and augmented; otherwise they are evenly spaced and only resized.

        :param video_file_path: directory holding ``frame_*.jpg`` files.
        :return: tensor of shape ``[3, channels, H, W]``.
        :raises ValueError: if the directory contains no matching frames.
        """
        video_file_path = pathlib.Path(video_file_path)
        assert video_file_path.exists()

        frame_paths = sorted(str(path) for path in video_file_path.glob("frame_*.jpg"))
        num_available = len(frame_paths)
        if num_available == 0:
            raise ValueError(f"No frame_*.jpg files found in {video_file_path}")

        is_training = self.mode == "train"
        if is_training:
            # Sampling without replacement needs at least 3 frames; for shorter
            # clips fall back to sampling with replacement instead of crashing.
            select_index = np.random.choice(num_available, size=3, replace=num_available < 3)
            select_index.sort()
        else:
            select_index = np.linspace(0, num_available - 1, num=3, dtype=int)

        transform = self.train_transform if is_training else self.val_transform
        frames = []
        for i in select_index:
            with open(frame_paths[i], "rb") as f:
                # ``convert`` forces the pixel data to be read before the file closes.
                frame = Image.open(f).convert("RGB")
            frames.append(transform(frame))
        return torch.stack(frames, dim=0)  # [frames, channels, H, W]

    def __len__(self):
        """Return the number of ids currently exposed via ``sample_ids``."""
        return len(self.sample_ids)

    def __getitem__(self, idx):
        """Build the sample at position ``idx`` of ``sample_ids``.

        :param idx: index into ``self.sample_ids``.
        :return: dict with
            ``video``: ``pixel_values`` produced by the video processor for
                the 3 selected frames (presumably ``[3, 3, 224, 224]`` float32
                - TODO confirm against the processor config),
            ``audio``: ``input_values`` of the audio processor with the leading
                dimension squeezed (presumably ``[1024, 128]`` - TODO confirm),
            ``label``: scalar ``torch.long`` index into ``CLASS_NAMES``,
            ``filename``: last component of the sample's video path,
            ``missing_type``: 0 -> video zeroed, 1 -> audio zeroed, 2 -> none,
            ``idx``: ``torch.LongTensor([idx])``.
        """
        sample_id = self.sample_ids[idx]
        record = self.data[sample_id]

        spectrogram = self.process_audio(record["audio_path"])["input_values"].squeeze(0)

        frames = self.load_frames(record["video_path"])
        # Frames come out of ``ToTensor`` already scaled, hence ``do_rescale=False``.
        frames = self.video_processor(frames, return_tensors="pt", do_rescale=False)['pixel_values']

        label = torch.tensor(CLASS_NAMES.index(record["label"]), dtype=torch.long)

        # Simulate a missing modality during training only.
        # ``.get`` keeps the default (empty) ``missing_info`` from raising KeyError.
        simulate_missing_type = 2
        if self.mode == "train" and self.missing_info.get('simulate_missing', False):
            # 0 -> video, 1 -> audio, 2 -> none
            simulate_missing_type = random.choice([0, 1, 2])
            if simulate_missing_type == 0:
                frames = torch.zeros(frames.size()).float()
            elif simulate_missing_type == 1:
                spectrogram = torch.zeros(spectrogram.size()).float()

        return {
            "video": frames,
            "audio": spectrogram,
            "label": label,
            # ``Path.name`` also handles non-POSIX separators, unlike split("/").
            "filename": pathlib.Path(record["video_path"]).name,
            "missing_type": simulate_missing_type,
            "idx": torch.LongTensor([idx]),
        }
