import os
from pathlib import Path
from typing import List, Tuple
import torch
from torch.utils.data import Dataset

import numpy as np
import pandas as pd
import pickle
import json

from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2
from model.torchvggish.vggish_input import wavfile_to_examples



def get_v2_pallete(label_to_idx_path, num_cls=71):
    """Build the RGB palette that maps class indices to mask colours.

    The palette is the unified one for AVSBench-object (V1) and
    AVSBench-semantic (V2); 71 is the total category number of the V2 dataset
    and should not be changed. The colour of class ``k`` is obtained by
    spreading the bits of ``k`` over the R, G and B channels, three bits at a
    time, filling each channel from its most significant bit downwards.

    Args:
        label_to_idx_path: path to a JSON file with one top-level entry per
            class. Only its length is used, as a consistency check.
        num_cls: number of palette rows (classes) to generate.

    Returns:
        np.ndarray of shape ``(num_cls, 3)``; row ``k`` is the colour of class ``k``.

    Raises:
        AssertionError: if the JSON file does not have exactly ``num_cls`` entries.
    """
    def _colour_of(class_idx):
        rgb = [0, 0, 0]
        bit_pos = 7  # next bit (from the top) to fill in every channel
        while class_idx > 0:
            for channel in range(3):
                rgb[channel] |= ((class_idx >> channel) & 1) << bit_pos
            bit_pos -= 1
            class_idx >>= 3
        return rgb

    with open(label_to_idx_path, 'r') as fr:
        label_to_pallete_idx = json.load(fr)
    colours = [_colour_of(class_idx) for class_idx in range(num_cls)]
    v2_pallete = np.array(colours).reshape(-1, 3)
    assert len(v2_pallete) == len(label_to_pallete_idx)
    return v2_pallete


def crop_resize_img(crop_size, img, img_is_mask=False):
    """Resize ``img`` so its shorter side is ``crop_size``, then centre-crop a square.

    Used for the training split. Masks are resized with nearest-neighbour
    interpolation so colour codes are never blended; images use bilinear.

    Args:
        crop_size: side length, in pixels, of the square output.
        img: a PIL image.
        img_is_mask: True when ``img`` is a colour-coded label mask.

    Returns:
        A ``crop_size`` x ``crop_size`` PIL image.
    """
    width, height = img.size
    if width > height:
        new_size = (int(1.0 * width * crop_size / height), crop_size)
    else:
        new_size = (crop_size, int(1.0 * height * crop_size / width))
    resample = Image.NEAREST if img_is_mask else Image.BILINEAR
    resized = img.resize(new_size, resample)

    # centre crop
    width, height = resized.size
    left = int(round((width - crop_size) / 2.))
    top = int(round((height - crop_size) / 2.))
    return resized.crop((left, top, left + crop_size, top + crop_size))


def resize_img(crop_size, img, img_is_mask=False):
    """Resize ``img`` to a ``crop_size`` x ``crop_size`` square without cropping.

    Used for the val./test. splits; the aspect ratio is not preserved. Masks
    use nearest-neighbour interpolation, images use bilinear.

    Args:
        crop_size: side length, in pixels, of the square output.
        img: a PIL image.
        img_is_mask: True when ``img`` is a colour-coded label mask.

    Returns:
        The resized PIL image.
    """
    resample = Image.NEAREST if img_is_mask else Image.BILINEAR
    return img.resize((crop_size, crop_size), resample)


def color_mask_to_label(mask, v_pallete):
    """Convert a colour-coded mask into a per-pixel class-index map.

    Each pixel gets the index of the first palette row whose colour equals the
    pixel on every channel (the last axis). Pixels matching no palette colour
    get label 0.

    Unlike a one-hot stack followed by ``argmax``, this keeps only an
    ``(H, W)`` label map and an ``(H, W)`` "still unassigned" mask in memory,
    instead of an ``(H, W, num_classes)`` float array.

    Args:
        mask: PIL image or array-like whose last axis holds the colour channels
            (e.g. ``(H, W, 3)`` for an RGB mask).
        v_pallete: iterable of colours, one per class, each broadcastable
            against the last axis of ``mask``.

    Returns:
        np.ndarray of integer (``np.intp``) class indices with the shape of
        ``mask`` minus its last axis.

    Raises:
        ValueError: if ``v_pallete`` is empty.
    """
    mask_array = np.array(mask).astype('int32')
    label = None
    unassigned = None
    for class_idx, colour in enumerate(v_pallete):
        hit = np.all(np.equal(mask_array, colour), axis=-1)
        if label is None:
            label = np.zeros(hit.shape, dtype=np.intp)
            unassigned = np.ones(hit.shape, dtype=bool)
        # first match wins, matching argmax's first-occurrence tie-breaking
        hit &= unassigned
        label[hit] = class_idx
        unassigned &= ~hit
    if label is None:
        raise ValueError("v_pallete must contain at least one colour")
    return label


def load_image_in_PIL_to_Tensor(path, crop_img_and_mask, crop_size, split='train', mode='RGB', transform=None):
    """Load an image file as a numpy array, optionally resized to ``crop_size``.

    Despite the name, the result is an ``np.ndarray`` (e.g. ``(H, W, 3)`` for
    ``mode='RGB'``), not a tensor.

    Args:
        path: image file path.
        crop_img_and_mask: if True, resize-and-centre-crop (``split == 'train'``)
            or plain square resize (any other split) to ``crop_size``.
        crop_size: output side length in pixels when resizing.
        split: dataset split name; selects the resizing strategy.
        mode: PIL mode the image is converted to.
        transform: unused; kept for backward compatibility.

    Returns:
        The image as an ``np.ndarray``.
    """
    image = Image.open(path).convert(mode)
    if not crop_img_and_mask:
        return np.asarray(image)
    resize_fn = crop_resize_img if split == 'train' else resize_img
    return np.asarray(resize_fn(crop_size, image, img_is_mask=False))



def load_color_mask_in_PIL_to_Tensor(path, crop_img_and_mask, crop_size, v_pallete, split='train', mode='RGB'):
    """Load a colour-coded mask file and convert it to a class-index label map.

    Args:
        path: mask file path.
        crop_img_and_mask: if True, resize-and-centre-crop (``split == 'train'``)
            or plain square resize (any other split) to ``crop_size``, using
            nearest-neighbour interpolation.
        crop_size: output side length in pixels when resizing.
        v_pallete: colours indexed by class, passed to ``color_mask_to_label``.
        split: dataset split name; selects the resizing strategy.
        mode: PIL mode the mask is converted to before colour matching.

    Returns:
        ``np.ndarray`` of shape ``[1, H, W]`` holding class indices
        (not a tensor, despite the name).
    """
    mask_img = Image.open(path).convert(mode)
    if crop_img_and_mask:
        resize_fn = crop_resize_img if split == 'train' else resize_img
        mask_img = resize_fn(crop_size, mask_img, img_is_mask=True)
    # obtain semantic label [H, W], then add a leading axis -> [1, H, W]
    return color_mask_to_label(mask_img, v_pallete)[np.newaxis]


def load_audio_lm(audio_lm_path):
    """Load a pickled audio log-mel tensor and detach it from any autograd graph.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only point this at
    trusted, locally generated feature files.

    Args:
        audio_lm_path: path to the pickle file.

    Returns:
        The unpickled tensor, detached (expected ``[5, 1, 96, 64]`` per the
        original comment — not checked here).
    """
    with open(audio_lm_path, 'rb') as handle:
        return pickle.load(handle).detach()


def train_collate_fn(batch: List[Tuple]):
    """
    Stack per-sample training tuples into batch tensors.
    :param batch: list of samples, each
        (imgs, audio, labels, vid_temporal_mask_flag, gt_temporal_mask_flag, ...);
        only the first five fields are used.
    :return: tuple of five tensors, each stacked along a new leading batch dimension,
        in the same field order.
    """
    columns = list(zip(*batch))
    return tuple(torch.stack(columns[field]) for field in range(5))


def val_collate_fn(batch: List[Tuple]):
    """
    Stack per-sample val./test. tuples into batch data.
    :param batch: list of samples, each
        (imgs, audio, labels, vid_temporal_mask_flag, gt_temporal_mask_flag, video_name, ...).
    :return: the five tensor fields stacked along a new leading batch dimension,
        followed by the tuple of video names (not stacked).
    """
    columns = list(zip(*batch))
    stacked = [torch.stack(column) for column in columns[:5]]
    return (*stacked, columns[5])


class V2Dataset(Dataset):
    """Dataset for audio visual semantic segmentation of AVSBench-semantic (V2).

    One item is one video, described by a row of the metadata CSV (columns used:
    ``uid``, ``label``, ``split``). Files are read from
    ``<dir_base>/<label>/<uid>/``: ``frames/<i>.jpg``, ``labels_rgb/<i>.png`` and
    ``audio.wav``. Frames, labels and audio are zero-padded to 10 temporal slots,
    so 5-second AVSBench-object clips (``v1s``/``v1m``) can be batched together
    with 10-second AVSBench-semantic clips (``v2``).
    """

    # Number of temporal slots every sample is padded to.
    _NUM_SLOTS = 10
    # subset label -> (#slots holding real frames, #slots with ground-truth masks)
    _SUBSET_SLOTS = {
        'v1s': (5, 1),   # AVSBench-object single-source: 5s, gt is only the first annotated frame
        'v1m': (5, 5),   # AVSBench-object multi-sources: 5s, all 5 extracted frames are annotated
        'v2': (10, 10),  # newly collected AVSBench-semantic videos: 10s, all 10 frames annotated
    }

    def __init__(self, mask_num, meta_cvs_path, label_idx_path, num_classes, dir_base, crop_img_and_mask, crop_size,
                 split='train', debug_flag=False):
        """
        Args:
            mask_num: stored as ``self.mask_num``; not otherwise used by this class.
            meta_cvs_path: path to the metadata CSV.
            label_idx_path: path to the label-to-palette-index JSON file.
            num_classes: number of classes used to build the colour palette.
            dir_base: dataset root directory.
            crop_img_and_mask: if True, resize (and for ``train`` also centre-crop)
                frames and masks to ``crop_size``.
            crop_size: output side length in pixels when resizing.
            split: rows of the CSV whose ``split`` column equals this value are used.
            debug_flag: if True, keep only the first 100 videos of the split.
        """
        super().__init__()
        self.split = split
        self.mask_num = mask_num
        self.meta_cvs_path = meta_cvs_path
        self.label_idx_path = label_idx_path
        self.dir_base = dir_base
        self.crop_img_and_mask = crop_img_and_mask
        self.crop_size = crop_size
        self.collate_fn = train_collate_fn if self.split == 'train' else val_collate_fn
        df_all = pd.read_csv(meta_cvs_path, sep=',')
        self.df_split = df_all[df_all['split'] == split]
        if debug_flag:
            self.df_split = self.df_split[:100]
            print(f"{len(self.df_split)}/{len(df_all)} videos are used for {self.split}")

        # Train and val./test. currently share the same pipeline; train-only
        # augmentations (e.g. A.HorizontalFlip, A.ColorJitter) would go here.
        self.transform = A.Compose([
            A.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
            ToTensorV2()],
            is_check_shapes=False)

        self.v2_pallete = get_v2_pallete(label_idx_path, num_cls=num_classes)

    @classmethod
    def _temporal_flags(cls, subset, video_name):
        """Return ``(vid_temporal_mask_flag, gt_temporal_mask_flag)``, float tensors of length 10.

        Raises:
            ValueError: if ``subset`` is not one of the known subset labels.
        """
        try:
            n_frames, n_gt = cls._SUBSET_SLOTS[subset]
        except KeyError:
            raise ValueError(f"unknown subset label {subset!r} for video {video_name!r}; "
                             f"expected one of {sorted(cls._SUBSET_SLOTS)}") from None
        vid_temporal_mask_flag = torch.zeros(cls._NUM_SLOTS)
        vid_temporal_mask_flag[:n_frames] = 1
        gt_temporal_mask_flag = torch.zeros(cls._NUM_SLOTS)
        gt_temporal_mask_flag[:n_gt] = 1
        return vid_temporal_mask_flag, gt_temporal_mask_flag

    @staticmethod
    def _count_with_suffix(directory, suffix):
        """Count the entries of ``directory`` whose file name ends with ``suffix``."""
        return sum(1 for path in directory.iterdir() if path.name.endswith(suffix))

    def _load_frames(self, img_base_path):
        """Load ``0.jpg .. (n-1).jpg`` as arrays and pad with black frames to 10 slots."""
        # 5 for v1, 10 for new v2; only the "<i>.jpg" files are read below.
        imgs_num = self._count_with_suffix(img_base_path, ".jpg")
        if imgs_num == 0:
            raise FileNotFoundError(f"no .jpg frames found in {img_base_path}")
        imgs = [load_image_in_PIL_to_Tensor(img_base_path / f"{img_id}.jpg", split=self.split,
                                            crop_img_and_mask=self.crop_img_and_mask, crop_size=self.crop_size)
                for img_id in range(imgs_num)]
        # Pad missing temporal slots with black frames of the same shape/dtype.
        imgs.extend(np.zeros_like(imgs[-1]) for _ in range(self._NUM_SLOTS - imgs_num))
        return imgs

    def _load_labels(self, color_mask_base_path, subset):
        """Load ``0.png .. (n-1).png`` as ``[1, H, W]`` label maps and pad with zeros to 10 slots.

        Raises:
            ValueError: outside ``train``, if the number of masks differs from the
                number of frames expected for ``subset``.
            FileNotFoundError: if no masks are present.
        """
        mask_num = self._count_with_suffix(color_mask_base_path, ".png")
        if self.split != 'train':
            # every extracted frame is annotated for val./test.: 10 for v2, 5 otherwise
            expected = self._SUBSET_SLOTS[subset][0]
            if mask_num != expected:
                raise ValueError(f"expected {expected} masks in {color_mask_base_path} "
                                 f"for split {self.split!r}, found {mask_num}")
        if mask_num == 0:
            raise FileNotFoundError(f"no .png masks found in {color_mask_base_path}")
        labels = [load_color_mask_in_PIL_to_Tensor(color_mask_base_path / f"{mask_id}.png",
                                                   crop_img_and_mask=self.crop_img_and_mask,
                                                   crop_size=self.crop_size, v_pallete=self.v2_pallete,
                                                   split=self.split)
                  for mask_id in range(mask_num)]
        # Pad missing temporal slots with all-zero label maps.
        labels.extend(np.zeros_like(labels[-1]) for _ in range(self._NUM_SLOTS - mask_num))
        return labels

    def _load_audio(self, audio_path):
        """Return audio examples zero-padded (or truncated) to 10 along the first axis."""
        audio_tensor = wavfile_to_examples(audio_path)  # presumably [5 or 10, 1, 96, 64]
        n_audio = audio_tensor.shape[0]
        if n_audio != self._NUM_SLOTS:
            padded = torch.zeros(self._NUM_SLOTS, 1, 96, 64)
            # truncate overly long clips instead of failing on a shape mismatch
            keep = min(n_audio, self._NUM_SLOTS)
            padded[:keep] = audio_tensor[:keep]
            audio_tensor = padded
        return audio_tensor

    def __getitem__(self, index):
        """Return one video's padded frames, audio, labels and temporal flags.

        Returns:
            For ``train``: ``(imgs, audio, labels, vid_temporal_mask_flag,
            gt_temporal_mask_flag)``. For other splits the video name (``uid``)
            is appended. ``imgs``/``labels`` stack the 10 transformed slots;
            the flags are float tensors of length 10.
        """
        df_one_video = self.df_split.iloc[index]
        video_name, subset = df_one_video['uid'], df_one_video['label']
        video_dir = Path(self.dir_base) / subset / video_name
        vid_temporal_mask_flag, gt_temporal_mask_flag = self._temporal_flags(subset, video_name)

        imgs = self._load_frames(video_dir / 'frames')
        labels = self._load_labels(video_dir / 'labels_rgb', subset)

        # Normalize images and convert images/labels to tensors slot by slot.
        transformed = [self.transform(image=img, mask=label) for img, label in zip(imgs, labels)]
        imgs_tensor = torch.stack([data['image'] for data in transformed], dim=0)
        labels_tensor = torch.stack([data['mask'] for data in transformed], dim=0)

        audio_tensor = self._load_audio(video_dir / 'audio.wav')
        if self.split == "train":
            return imgs_tensor, audio_tensor, labels_tensor, vid_temporal_mask_flag, gt_temporal_mask_flag
        return imgs_tensor, audio_tensor, labels_tensor, \
            vid_temporal_mask_flag, gt_temporal_mask_flag, video_name

    def __len__(self):
        """Number of videos in the selected split."""
        return len(self.df_split)

    @property
    def num_classes(self):
        """Number of categories (including background)."""
        return len(self.classes)

    @property
    def classes(self):
        """Category names (the keys of the label-index JSON file, re-read on each access)."""
        with open(self.label_idx_path, 'r') as fr:
            classes = json.load(fr)
        return classes.keys()
