import os
from typing import List, Tuple

import torch
from torch.utils.data import Dataset
import numpy as np
import pandas as pd
import pickle
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2

import pdb


def load_image_in_PIL_to_Tensor(path, mode='RGB', transform=None):
    """Load an image from disk and convert it to the requested PIL mode.

    :param path: path of the image file to open.
    :param mode: PIL mode passed to ``Image.convert`` (e.g. 'RGB', or '1' for bilevel masks).
    :param transform: optional callable applied to the converted PIL image.
    :return: ``transform(image)`` when a transform is given; otherwise the converted
             image as a numpy array (despite the name, no tensor conversion happens then).
    """
    # Close the file handle deterministically instead of leaking it until GC;
    # ``convert`` returns a new, fully loaded image, so it stays valid after close.
    with Image.open(path) as img_file:
        img_PIL = img_file.convert(mode)
    if transform:
        return transform(img_PIL)
    return np.asarray(img_PIL)


def load_audio_lm(audio_lm_path):
    """Unpickle an audio log-mel object and return it detached from autograd.

    :param audio_lm_path: path of the ``.pkl`` file to read.
    :return: the unpickled object's ``.detach()`` result.

    NOTE(review): the original annotated the shape as [5, 1, 96, 64]; this is
    not checked here -- confirm against the feature-extraction step.
    NOTE: ``pickle.load`` can execute arbitrary code; only use on trusted files.
    """
    with open(audio_lm_path, mode='rb') as handle:
        log_mel = pickle.load(handle)
    # Strip any autograd history that was stored alongside the tensor.
    return log_mel.detach()


def train_collate_fn(batch: List[Tuple]):
    """
    Merge a list of training samples into one batch.
    :param batch: list of per-sample tuples, e.g. [(img1, audio1, mask1), (img2, audio2, mask2), ...];
                  only the first three fields of each tuple are used.
    :return: tuple (batch_image, batch_audio, batch_mask), each stacked along a new leading dim.
    """
    # Transpose samples -> fields, then stack the first three tensor fields.
    columns = list(zip(*batch))
    return tuple(torch.stack(columns[field]) for field in range(3))


def val_collate_fn(batch: List[Tuple]):
    """
    Merge a list of evaluation samples into one batch.
    :param batch: list of per-sample tuples,
                  e.g. [(img1, audio1, mask1, category1, name1), (img2, audio2, mask2, category2, name2), ...]
    :return: (batch_image, batch_audio, batch_mask, category_list, video_name_list); the first three
             are stacked tensors, the last two are tuples of the per-sample values.
    """
    columns = list(zip(*batch))
    stacked = [torch.stack(columns[field]) for field in range(3)]
    # Categories and video names are non-tensor metadata: pass them through unstacked.
    return (*stacked, columns[3], columns[4])


class S4Dataset(Dataset):
    """Dataset for single sound source segmentation.

    Rows of ``anno_csv`` whose ``split`` column equals ``split`` are kept;
    positional column 0 is read as the video name and column 2 as its category.
    Per video, files are resolved as::

        <dir_img>/<split>/<category>/<video>/<video>_<k>.png    (frames)
        <dir_mask>/<split>/<category>/<video>/<video>_<k>.png   (masks)
        <dir_audio_log_mel>/<split>/<category>/<video>.pkl      (audio)

    For ``split == 'train'`` only frame/mask ``k = 1`` and the first audio slice
    are used, with a random horizontal flip; any other split loads frames and
    masks ``k = 1..5`` and only normalizes them.
    """

    def __init__(self, anno_csv: str, dir_img: str, dir_audio_log_mel: str, dir_mask, split='train'):
        super().__init__()
        self.split = split
        self.dir_img = dir_img
        self.dir_audio_log_mel = dir_audio_log_mel
        self.dir_mask = dir_mask

        training = split == 'train'
        self.collate_fn = train_collate_fn if training else val_collate_fn
        self.mask_num = 1 if training else 5

        annotations = pd.read_csv(anno_csv, sep=',')
        self.df_split = annotations[annotations['split'] == split]
        print(f"{len(self.df_split)}/{len(annotations)} videos are used for {self.split}")

        # albumentations applies the same sampled parameters to every named
        # target, keeping the extra frames/masks aligned with 'image'/'mask'.
        self.additional_keys = {f'image{k}': 'image' for k in range(1, 5)}
        if not training:
            self.additional_keys.update({f'mask{k}': 'mask' for k in range(1, 5)})

        steps = [A.HorizontalFlip(p=0.5)] if training else []
        # Normalize with the ImageNet channel mean / std, then convert to tensors.
        steps += [A.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), ToTensorV2()]
        self.transform = A.Compose(steps, additional_targets=self.additional_keys)

    def __getitem__(self, index):
        record = self.df_split.iloc[index]
        video_name, category = record.iloc[0], record.iloc[2]
        training = self.split == 'train'

        frame_dir = os.path.join(self.dir_img, self.split, category, video_name)
        audio_path = os.path.join(self.dir_audio_log_mel, self.split, category, video_name + '.pkl')
        mask_dir = os.path.join(self.dir_mask, self.split, category, video_name)
        audio_log_mel = load_audio_lm(audio_path)

        # Training reads only the first frame/mask; other splits read all of them.
        frame_ids = range(1, 6)
        mask_ids = range(1, self.mask_num + 1)
        if training:
            frame_ids, mask_ids = frame_ids[:1], mask_ids[:1]

        frames = [
            load_image_in_PIL_to_Tensor(os.path.join(frame_dir, "%s_%d.png" % (video_name, k)),
                                        transform=None)
            for k in frame_ids
        ]
        masks = [
            load_image_in_PIL_to_Tensor(os.path.join(mask_dir, "%s_%d.png" % (video_name, k)),
                                        transform=None, mode='1').astype(np.float32)
            for k in mask_ids
        ]

        if training:
            augmented = self.transform(image=frames[0], mask=masks[0])
            frames = [augmented['image']]
            masks = [augmented['mask']]
            audio_log_mel = audio_log_mel[0:1]
        else:
            # '' addresses the primary 'image'/'mask' target, '1'..'4' the additional ones.
            suffixes = [''] + [str(k) for k in range(1, 5)]
            targets = {f'image{s}': frames[i] for i, s in enumerate(suffixes)}
            targets.update({f'mask{s}': masks[i] for i, s in enumerate(suffixes)})
            augmented = self.transform(**targets)
            frames = [augmented[f'image{s}'] for s in suffixes]
            masks = [augmented[f'mask{s}'] for s in suffixes]

        imgs_tensor = torch.stack(frames, dim=0)
        masks_tensor = torch.stack(masks, dim=0)

        if training:
            return imgs_tensor, audio_log_mel, masks_tensor
        return imgs_tensor, audio_log_mel, masks_tensor, category, video_name

    def __len__(self):
        return len(self.df_split)


if __name__ == "__main__":
    # Smoke test: build the dataset, iterate its DataLoader and report batch shapes.
    # (Previously this called S4Dataset('train'), which passed 'train' as anno_csv and
    # omitted the required directory arguments, so it always raised TypeError.)
    import argparse

    parser = argparse.ArgumentParser(description="Smoke-test S4Dataset loading.")
    parser.add_argument('anno_csv', help="annotation CSV containing a 'split' column")
    parser.add_argument('dir_img', help="root directory of the frame images")
    parser.add_argument('dir_audio_log_mel', help="root directory of the pickled audio log-mels")
    parser.add_argument('dir_mask', help="root directory of the ground-truth masks")
    parser.add_argument('--split', default='train')
    parser.add_argument('--batch-size', type=int, default=2)
    parser.add_argument('--num-workers', type=int, default=8)
    parser.add_argument('--debug', action='store_true', help="drop into pdb after each batch")
    args = parser.parse_args()

    dataset = S4Dataset(args.anno_csv, args.dir_img, args.dir_audio_log_mel, args.dir_mask,
                        split=args.split)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.num_workers,
                                             pin_memory=True,
                                             collate_fn=dataset.collate_fn)

    n_iter = -1  # stays -1 if the loader yields no batches (avoids NameError below)
    for n_iter, batch_data in enumerate(dataloader):
        # 'train' batches are (imgs, audio, mask); other splits append (categories, video_names).
        imgs, audio, mask = batch_data[:3]
        print(n_iter, tuple(imgs.shape), tuple(audio.shape), tuple(mask.shape))
        if args.debug:
            pdb.set_trace()
    print('n_iter', n_iter)
