import os
from wave import _wave_params
import torch
import torch.nn as nn
from torch.utils.data import Dataset

import numpy as np
import pandas as pd
import pickle

import cv2
from PIL import Image
from torchvision import transforms

from config import cfg
import pdb
import soundfile as sf
import torchaudio


def load_image_in_PIL_to_Tensor(path, mode='RGB', transform=None):
    """Load an image file and return it as a tensor.

    Args:
        path: Path of the image file, handed to ``Image.open``.
        mode: PIL mode the image is converted to (this file uses ``'RGB'``,
            ``'1'`` and ``'P'``).
        transform: Optional callable applied to the converted PIL image. When
            supplied (and truthy) its result is returned as-is.

    Returns:
        ``transform(image)`` when a transform is given, otherwise a
        ``torch.LongTensor`` built from the image's NumPy array.
    """
    # Open inside a context manager so the underlying file handle is released
    # immediately instead of whenever the garbage collector runs. ``convert``
    # loads the pixel data and returns a separate image object, so the result
    # stays usable after the file is closed.
    with Image.open(path) as img_file:
        img_PIL = img_file.convert(mode)
    if transform:
        return transform(img_PIL)
    return torch.LongTensor(np.array(img_PIL))


def load_audio_lm(audio_lm_path):
    """Read a pickled audio feature object from disk and detach it.

    Args:
        audio_lm_path: Path of the pickle file holding the stored object.

    Returns:
        The result of calling ``.detach()`` on the unpickled object. The
        original note on this function gives its shape as [5, 1, 96, 64];
        that is not checked here.
    """
    # NOTE: unpickling can execute arbitrary code -- only load trusted files.
    with open(audio_lm_path, 'rb') as handle:
        stored = pickle.load(handle)
    return stored.detach()


class S4Dataset(Dataset):
    """Dataset for single sound source segmentation.

    Each item describes one video of the requested split: the directory of its
    frames, its ground-truth mask(s), the audio directory, its category index
    and, when a text CSV was supplied, a text description.
    """

    def __init__(self, frame_shift=10, split='train', text_path = None):
        """Read the annotation CSV and prepare transforms and text lookup.

        Args:
            frame_shift: Stored on the instance; this class does not use it itself.
            split: Value matched against the ``split`` column of the
                annotation CSV (``cfg.DATA.ANNO_CSV4``).
            text_path: Optional CSV with text per video (or per frame when it
                has a ``frame`` column). Ignored when falsy or equal to ``'no'``.
        """
        super(S4Dataset, self).__init__()
        self.split = split
        self.frame_shift = frame_shift
        # The train split loads one mask per video, every other split five.
        self.mask_num = 1 if self.split == 'train' else 5

        df_all = pd.read_csv(cfg.DATA.ANNO_CSV4, sep=',')
        self.df_split = df_all[df_all['split'] == split]
        print("{}/{} videos are used for {}".format(len(self.df_split), len(df_all), self.split))

        self.img_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))  ## For maskclip, normalize later
        ])
        self.mask_transform = transforms.Compose([
            transforms.ToTensor(),
        ])

        # Text attributes always exist so that get_data can report a clear
        # error instead of failing with AttributeError when no text was loaded.
        self.text_out = False
        self.df_text = None
        self.text_col = None
        if text_path and text_path != 'no':
            df_text = pd.read_csv(text_path)
            # Prefer the 'object' column when the CSV has one.
            text_col = 'object' if 'object' in df_text else 'text'
            df_text[text_col] = df_text[text_col].fillna('sounding object')
            self.df_text = df_text
            self.text_out = True
            self.text_col = text_col

        # Category indices start at 1 and cover every split of the CSV.
        all_text = sorted(df_all['category'].unique())
        self.text2idx = {text: idx + 1 for idx, text in enumerate(all_text)}
        self.num_classes = len(all_text)

    def _text_for_video(self, video_name):
        """Build the text element that ``__getitem__`` returns for one video."""
        # Filter first: sorting only this video's rows gives the same order as
        # sorting the whole table, without re-sorting it on every item.
        rows = self.df_text[self.df_text['name'] == video_name]
        if self.mask_num == 1:
            return rows[self.text_col].values[0]
        if 'frame' in rows.columns:
            per_frame = rows.sort_values(by=['name', 'frame'])[self.text_col].values.tolist()
        else:
            # One text per video: repeat it once per mask.
            per_frame = [rows[self.text_col].values[0]] * self.mask_num
        return "<sep>".join(per_frame)

    def __getitem__(self, index):
        """Return the sample at position ``index`` of the split.

        Returns:
            ``(img_base_path, masks_tensor, video_name, audio_path, 5,
            category_idx)``; when a text CSV was loaded a seventh element
            ``text`` is appended (a single string, per-frame texts joined
            with ``"<sep>"`` when more than one mask is used). Frames are not
            loaded here, only the directory that contains them is returned.
        """
        row = self.df_split.iloc[index]
        video_name, category = row['name'], row['category']
        img_base_path = os.path.join(cfg.DATA.DIR_IMG4, self.split, category, video_name)
        audio_path = os.path.join(cfg.DATA.DIR_AUDIO4, self.split, category)
        mask_base_path = os.path.join(cfg.DATA.DIR_MASK4, self.split, category, video_name)

        # Mask files are named "<video>_<k>.png" with k starting at 1.
        masks = [
            load_image_in_PIL_to_Tensor(
                os.path.join(mask_base_path, "%s_%d.png" % (video_name, mask_id)),
                transform=self.mask_transform, mode='1')
            for mask_id in range(1, self.mask_num + 1)
        ]
        masks_tensor = torch.stack(masks, dim=0)

        sample = (img_base_path, masks_tensor, video_name, audio_path, 5, self.text2idx[category])
        if not self.text_out:
            return sample
        return sample + (self._text_for_video(video_name),)

    def get_data(self, name, frame_id):
        """Return one frame of video ``name`` together with its text.

        Args:
            name: Video name, matched against the ``name`` columns.
            frame_id: Index used in the frame file name
                ``<name>_<frame_id>.png``. When the text CSV has a ``frame``
                column the text row with ``frame == frame_id - 1`` is used.

        Returns:
            ``(img_tensor, img_base_path, audio_path, text)``.

        Raises:
            RuntimeError: If the dataset was created without a text CSV.
        """
        if self.df_text is None:
            raise RuntimeError("get_data requires the dataset to be created with a text_path")

        video_name = name
        category = self.df_split.loc[(self.df_split['name'] == name), 'category'].values[0]
        img_base_path = os.path.join(cfg.DATA.DIR_IMG4, self.split, category, video_name)
        audio_path = os.path.join(cfg.DATA.DIR_AUDIO4, self.split, category)
        img_tensor = load_image_in_PIL_to_Tensor(
            os.path.join(img_base_path, "%s_%d.png" % (video_name, frame_id)),
            transform=self.img_transform)

        same_video = self.df_text['name'] == video_name
        if 'frame' in self.df_text.columns:
            # NOTE(review): text frames are addressed as frame_id - 1 here,
            # i.e. presumably 0-based in the CSV -- confirm against the data.
            selector = same_video & (self.df_text['frame'] == frame_id - 1)
        else:
            selector = same_video
        text = self.df_text.loc[selector, self.text_col].values[0]

        return img_tensor, img_base_path, audio_path, text

    def __len__(self):
        """Return the number of videos in the selected split."""
        return len(self.df_split)


class MS3Dataset(Dataset):
    """Dataset for multiple sound source segmentation.

    Each item describes one video of the requested split: the directory of its
    frames, its five ground-truth masks, the audio directory and either text
    descriptions (when a text CSV was supplied) or a multi-hot class vector.
    """

    def __init__(self, split='train', text_path="", text_base=""):
        """Read the annotation CSV and prepare transforms and text lookup.

        Args:
            split: Value matched against the ``split`` column of the
                annotation CSV (``cfg.DATA.ANNO_CSV3``).
            text_path: Optional CSV with per-frame text. Ignored when falsy or
                equal to ``'no'``.
            text_base: Optional second text CSV, merged with ``text_path`` on
                ``name`` and ``frame``. Ignored when falsy or equal to ``'no'``.
        """
        super(MS3Dataset, self).__init__()
        self.split = split
        self.mask_num = 5

        df_all = pd.read_csv(cfg.DATA.ANNO_CSV3, sep=',')
        self.df_split = df_all[df_all['split'] == split]
        print("{}/{} videos are used for {}".format(len(self.df_split), len(df_all), self.split))

        self.img_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        ])
        self.mask_transform = transforms.Compose([
            transforms.ToTensor(),
        ])

        # Text attributes always exist so that later lookups never fail with
        # AttributeError, whichever combination of CSVs was supplied.
        self.text_out = False
        self.df_text = None
        self.text_col = None
        self.text_major = None
        self.text_base = None
        if text_path and text_path != 'no':
            df_text = pd.read_csv(text_path)
            if text_base and text_base != 'no':
                # Merged mode: two text columns are returned per item.
                df_text_base = pd.read_csv(text_base)
                df_text = pd.merge(df_text, df_text_base, on=['name', 'frame'], suffixes=['', '_b'])
                self.df_text = df_text.fillna('sounding object')
                self.text_major, self.text_base = 'object', 'text'
            else:
                # Single-CSV mode: prefer the 'object' column when present.
                text_col = 'object' if 'object' in df_text else 'text'
                df_text[text_col] = df_text[text_col].fillna('sounding object')
                self.df_text = df_text
                self.text_col = text_col
            self.text_out = True

        # A 'class' value holds several labels joined by '_'; '-' inside a
        # label is turned into a space. Indices start at 1.
        all_text = sorted({
            label.replace('-', ' ')
            for joined in df_all['class'].unique()
            for label in joined.split('_')
        })
        self.text2idx = {text: idx + 1 for idx, text in enumerate(all_text)}
        self.num_classes = len(all_text)

    def __getitem__(self, index):
        """Return the sample at position ``index`` of the split.

        Returns:
            With text: ``(img_base_path, masks_tensor, video_name, audio_path,
            mask_num, text, text_base)`` where both texts are per-frame
            strings joined with ``"<sep>"`` (``text_base`` is ``""`` when no
            base CSV was merged).
            Without text: ``(img_base_path, masks_tensor, video_name,
            audio_path, 5, cats)`` where ``cats`` is a multi-hot vector of
            length ``num_classes + 1``.
            Frames are not loaded here; only their directory is returned.
        """
        row = self.df_split.iloc[index]
        video_name = row['video_id']
        img_base_path = os.path.join(cfg.DATA.DIR_IMG3, video_name)
        mask_base_path = os.path.join(cfg.DATA.DIR_MASK3, self.split, video_name)
        audio_path = os.path.join(cfg.DATA.DIR_AUDIO3, self.split)

        # Mask files are named "<video>_<k>.png" with k starting at 1.
        masks = [
            load_image_in_PIL_to_Tensor(
                os.path.join(mask_base_path, "%s_%d.png" % (video_name, mask_id)),
                transform=self.mask_transform, mode='P')
            for mask_id in range(1, self.mask_num + 1)
        ]
        masks_tensor = torch.stack(masks, dim=0)

        if self.text_out:
            # Filter to this video before sorting so the full text table is
            # not re-sorted for every item.
            rows = self.df_text[self.df_text['name'] == video_name].sort_values(by=['name', 'frame'])
            if self.text_major:
                text_major = rows[self.text_major].values.tolist()
                text_base = rows[self.text_base].values.tolist()
                return (img_base_path, masks_tensor, video_name, audio_path, self.mask_num,
                        "<sep>".join(text_major), "<sep>".join(text_base))
            text = rows[self.text_col].values.tolist()
            return (img_base_path, masks_tensor, video_name, audio_path, self.mask_num,
                    "<sep>".join(text), "")

        categories = [label.replace('-', ' ') for label in row['class'].split('_')]
        catids = [self.text2idx[c] for c in categories]
        # Index 0 is never set because class indices start at 1.
        cats = torch.zeros(self.num_classes + 1)
        cats[catids] = 1
        return img_base_path, masks_tensor, video_name, audio_path, 5, cats

    def get_data(self, name, frame_id):
        """Return one frame of video ``name`` together with its text.

        Args:
            name: Video id, used for the frame path and matched against the
                ``name`` column of the text table.
            frame_id: Index used in the frame file name
                ``<name>.mp4_<frame_id>.png`` and matched against the
                ``frame`` column of the text table.

        Returns:
            ``(img_tensor, img_base_path, audio_path, text)``.

        Raises:
            RuntimeError: If the dataset was created without a text CSV.
        """
        if self.df_text is None:
            raise RuntimeError("get_data requires the dataset to be created with a text_path")
        # In merged mode no single text column was configured; fall back to
        # the major column. NOTE(review): confirm this is the wanted column.
        text_col = self.text_major or self.text_col

        video_name = name
        img_base_path = os.path.join(cfg.DATA.DIR_IMG3, video_name)
        audio_path = os.path.join(cfg.DATA.DIR_AUDIO3, self.split)
        img_tensor = load_image_in_PIL_to_Tensor(
            os.path.join(img_base_path, "%s.mp4_%d.png" % (video_name, frame_id)),
            transform=self.img_transform)

        selector = (self.df_text['name'] == video_name) & (self.df_text['frame'] == frame_id)
        text = self.df_text.loc[selector, text_col].values[0]
        return img_tensor, img_base_path, audio_path, text

    def __len__(self):
        """Return the number of videos in the selected split."""
        return len(self.df_split)


def split(frame_name_lists, sample_num):
    """Split a list into ``sample_num`` contiguous, nearly equal chunks.

    A list shorter than ``sample_num`` is first padded by repeating its last
    element, so every chunk holds at least one item. When the length is not a
    multiple of ``sample_num`` the leading chunks get one extra item.

    Args:
        frame_name_lists: List to split. It is not modified.
        sample_num: Number of chunks to produce.

    Returns:
        A list of ``sample_num`` lists.

    Raises:
        IndexError: If ``frame_name_lists`` is empty and ``sample_num`` is
            positive (there is no last element to pad with).
        ZeroDivisionError: If ``sample_num`` is 0.
    """
    missing = sample_num - len(frame_name_lists)
    if missing > 0:   ###padding with the last frame
        # Build a new list: ``+=`` would extend the caller's list in place.
        frame_name_lists = frame_name_lists + [frame_name_lists[-1]] * missing
    k, m = divmod(len(frame_name_lists), sample_num)
    return [frame_name_lists[i * k + min(i, m):(i + 1) * k + min(i + 1, m)]
            for i in range(sample_num)]


class AudioMapper:
    """Map an audio id to normalized filterbank slices via ``mapper[id_]``."""

    def __init__(self, audio_dir, melbins, target_length, frame_shift, sample_num, training = True):
        """Store the feature-extraction settings.

        Args:
            audio_dir: Directory searched for ``<id>.wav`` / ``<id>.mkv``.
            melbins: Number of mel bins requested from the filterbank.
            target_length: Number of frames each returned slice is padded to.
            frame_shift: Frame shift handed to the filterbank computation.
            sample_num: Number of slices in the all-zero fallback that is
                returned when no audio file is found.
            training: When true only the first slice is returned, otherwise five.
        """
        self.audio_dir = audio_dir
        self.melbins = melbins
        self.target_length = target_length
        # Fixed statistics used to normalize the filterbank features.
        self.mean, self.std = -4.2677393, 4.5689974
        self.training = training
        self.frame_shift = frame_shift
        self.sample_num = sample_num

    def _locate_audio(self, id_):
        """Return the first existing audio path for ``id_``, or ``None``."""
        wav_file = os.path.join(self.audio_dir, id_ + '.wav')
        candidates = (
            wav_file,
            # Legacy lookup, kept for backward compatibility: it rewrites every
            # "wav" in the path, including directory names and the id itself.
            wav_file.replace('wav', 'mkv'),
            # Extension-only variant, so "<id>.mkv" is also found when the
            # directory or the id happens to contain "wav".
            os.path.join(self.audio_dir, id_ + '.mkv'),
        )
        for candidate in candidates:
            if os.path.exists(candidate):
                return candidate
        return None

    def __getitem__(self, id_):
        """Load the audio for ``id_`` and return its normalized filterbank slices.

        Returns:
            * A zero tensor of shape ``(sample_num, melbins, target_length)``
              when no audio file exists for ``id_``.
            * ``None`` when the file exists but loading or feature extraction
              raises; the exception is printed.
            * Otherwise a tensor with one slice in training mode and five
              otherwise, each slice laid out like the zero fallback
              (presumably ``(melbins, target_length)`` -- confirm against the
              torchaudio fbank output).
        """
        audio_file = self._locate_audio(id_)
        if audio_file is None:
            return torch.zeros(self.sample_num, self.melbins, self.target_length)

        try:
            waveform, sr = torchaudio.load(audio_file)

            waveform = waveform - waveform.mean()
            fbank = torchaudio.compliance.kaldi.fbank(waveform, htk_compat=True, sample_frequency=sr, use_energy=False,
                                                    window_type='hanning', num_mel_bins=self.melbins, dither=0.0, frame_shift=self.frame_shift)

            # The clip is cut into five equal segments along the frame axis;
            # training keeps only the first one.
            num_segments = 5
            segment_len = int(np.ceil(fbank.shape[0] / num_segments))
            out_num = 1 if self.training else num_segments

            output_slices = []
            for i in range(out_num):
                src = fbank[i * segment_len:min((i + 1) * segment_len, fbank.shape[0])]
                # Pad at the end of the frame axis up to target_length.
                # NOTE(review): a segment longer than target_length yields a
                # negative pad (presumably cropping) -- confirm this is intended.
                pad_len = self.target_length - src.shape[0]
                output_slices.append(torch.nn.ZeroPad2d((0, 0, 0, pad_len))(src))

            fbank = torch.stack(output_slices, dim=0).permute(0, 2, 1)

            ### normalization
            return (fbank - self.mean) / (self.std * 2)

        except Exception as e:
            # Best effort: report the problem and signal failure with None
            # rather than raising.
            print(e)
            return None


if __name__ == "__main__":
    # Manual smoke test: iterate over the S4 training split and stop in the
    # debugger on every batch so it can be inspected interactively.
    # ``split`` is passed by keyword because the first positional parameter of
    # S4Dataset is ``frame_shift``.
    train_dataset = S4Dataset(split='train')
    train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                                     batch_size=2,
                                                     shuffle=False,
                                                     num_workers=8,
                                                     pin_memory=True)

    n_iter = -1  # stays -1 when the loader yields no batch
    for n_iter, batch_data in enumerate(train_dataloader):
        # Without a text CSV each sample has six fields; masks are presumably
        # [bs, 1, 1, H, W] for the train split (one mask per video).
        img_base_path, masks, video_name, audio_path, num_frames, category_idx = batch_data
        pdb.set_trace()
    print('n_iter', n_iter)
    pdb.set_trace()
