import os
import pickle
import pdb
from typing import List, Tuple

import torch
from torch.utils.data import Dataset
import numpy as np
import pandas as pd
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2


def load_image_in_PIL_to_Tensor(path, mode='RGB', transform=None):
    """Load an image from disk and convert it to the requested PIL mode.

    :param path: path of the image file.
    :param mode: PIL mode to convert to (e.g. 'RGB', or 'P' for palette images).
    :param transform: optional callable applied to the converted PIL image.
    :return: ``transform(image)`` when a transform is given, otherwise the image as a
        numpy array (despite the function name, no tensor conversion happens here).
    """
    # Open the image as a context manager so the underlying file handle is closed
    # deterministically. ``convert`` loads the pixel data and returns a new image,
    # so ``img_PIL`` stays valid after the source file has been closed.
    with Image.open(path) as img_file:
        img_PIL = img_file.convert(mode)
    if transform:
        return transform(img_PIL)
    return np.asarray(img_PIL)


def load_audio_lm(audio_lm_path):
    """Unpickle the audio log-mel features stored at ``audio_lm_path``.

    The unpickled object must support ``.detach()`` (e.g. a torch tensor); the
    detached result is returned so it carries no autograd history.
    NOTE(review): ``pickle.load`` can execute arbitrary code; only point this at
    trusted feature files.
    """
    with open(audio_lm_path, 'rb') as pkl_file:
        features = pickle.load(pkl_file)
    # presumably shaped [5, 1, 96, 64] (per the original author's note) — confirm
    return features.detach()


def train_collate_fn(batch: List[Tuple]):
    """
    Combine per-sample (image, audio, mask) tuples into batched tensors.
    :param batch: list of samples, e.g. [(img1, audio1, mask1), (img2, audio2, mask2), ...]
    :return: (batch_image, batch_audio, batch_mask), each stacked along a new leading batch dim
    """
    # Transpose the list of samples into one sequence per field.
    columns = list(zip(*batch))
    return tuple(torch.stack(columns[field_idx]) for field_idx in range(3))


def val_collate_fn(batch: List[Tuple]):
    """
    Combine per-sample (image, audio, mask, video_name) tuples into batch data.
    :param batch: list of samples, e.g. [(img1, audio1, mask1, name1), (img2, audio2, mask2, name2), ...]
    :return: (batch_image, batch_audio, batch_mask, video_names) where the first three are
             stacked along a new leading batch dim and video_names is the tuple of names
    """
    # Transpose the list of samples into one sequence per field.
    columns = list(zip(*batch))
    batch_image, batch_audio, batch_mask = (torch.stack(columns[field_idx]) for field_idx in range(3))
    return batch_image, batch_audio, batch_mask, columns[3]


class MS3Dataset(Dataset):
    """Dataset for multiple sound source segmentation.

    Each item corresponds to one row of the annotation CSV and bundles
    ``self.mask_num`` RGB frames, the pickled audio log-mel features and
    ``self.mask_num`` segmentation masks. Files are looked up as:
      * frames: ``<dir_img>/<video>/<video>.mp4_<k>.png``
      * audio:  ``<dir_audio_log_mel>/<split>/<video>.pkl``
      * masks:  ``<dir_mask>/<split>/<video>/<video>_<k>.png``
    with ``k`` running from 1 to ``self.mask_num``.
    """

    def __init__(self, anno_csv: str, dir_img: str, dir_audio_log_mel: str, dir_mask, split='train'):
        """
        :param anno_csv: path to the annotation CSV; it must have a 'split' column, and the
                         video name is read from its first column.
        :param dir_img: root directory of the extracted video frames.
        :param dir_audio_log_mel: root directory of the pickled audio log-mel features.
        :param dir_mask: root directory of the ground-truth masks.
        :param split: split to load; 'train' enables random horizontal flipping and the
                      3-tuple train collate function, any other value uses the val variants.
        """
        super(MS3Dataset, self).__init__()
        self.split = split
        self.dir_img = dir_img
        self.dir_audio_log_mel = dir_audio_log_mel
        self.dir_mask = dir_mask
        self.collate_fn = train_collate_fn if self.split == 'train' else val_collate_fn
        self.mask_num = 5  # number of frames (and masks) loaded per video
        df_all = pd.read_csv(anno_csv, sep=',')
        self.df_split = df_all[df_all['split'] == split]
        print(f"{len(self.df_split)}/{len(df_all)} videos are used for {self.split}")

        # Frame 0 is passed as the primary 'image'/'mask' target; the remaining frames are
        # registered as additional targets so albumentations applies the same sampled
        # augmentation (e.g. the same flip decision) to every frame and mask of a video.
        # Derived from self.mask_num so the frame count is defined in one place.
        self.additional_keys = {f'image{i}': 'image' for i in range(1, self.mask_num)}
        self.additional_keys.update({f'mask{i}': 'mask' for i in range(1, self.mask_num)})

        normalize_and_tensorize = [
            # commonly used ImageNet channel mean / std
            A.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
            ToTensorV2(),
        ]
        if self.split == 'train':
            pipeline = [A.HorizontalFlip(p=0.5)] + normalize_and_tensorize
        else:
            pipeline = normalize_and_tensorize
        self.transform = A.Compose(pipeline, additional_targets=self.additional_keys)

    def __getitem__(self, index):
        """Load, augment and stack all frames/masks of one video.

        :return: (imgs, audio_log_mel, masks) for the 'train' split and
                 (imgs, audio_log_mel, masks, video_name) otherwise; imgs and masks are
                 stacked along a new leading frame dimension of size self.mask_num.
        """
        df_one_video = self.df_split.iloc[index]
        video_name = df_one_video.iloc[0]  # first CSV column holds the video name
        img_base_path = os.path.join(self.dir_img, video_name)
        audio_lm_path = os.path.join(self.dir_audio_log_mel, self.split, video_name + '.pkl')
        mask_base_path = os.path.join(self.dir_mask, self.split, video_name)
        audio_log_mel = load_audio_lm(audio_lm_path)

        frame_ids = range(1, self.mask_num + 1)
        imgs = [load_image_in_PIL_to_Tensor(os.path.join(img_base_path, "%s.mp4_%d.png" % (video_name, frame_id)),
                                            transform=None)
                for frame_id in frame_ids]
        # NOTE(review): masks are read in palette mode and divided by 255, which presumably
        # assumes foreground pixels are stored with index 255 — confirm against the data.
        masks = [load_image_in_PIL_to_Tensor(os.path.join(mask_base_path, "%s_%d.png" % (video_name, frame_id)),
                                             transform=None, mode='P').astype(np.float32) / 255
                 for frame_id in frame_ids]

        extra_ids = range(1, self.mask_num)
        data = self.transform(image=imgs[0],
                              **{f'image{i}': imgs[i] for i in extra_ids},
                              mask=masks[0],
                              **{f'mask{i}': masks[i] for i in extra_ids})
        imgs_tensor = torch.stack([data['image']] + [data[f'image{i}'] for i in extra_ids], dim=0)
        masks_tensor = torch.stack([data['mask']] + [data[f'mask{i}'] for i in extra_ids], dim=0)

        if self.split == 'train':
            return imgs_tensor, audio_log_mel, masks_tensor
        return imgs_tensor, audio_log_mel, masks_tensor, video_name

    def __len__(self):
        """Number of videos in the selected split."""
        return len(self.df_split)

