import pickle
import random
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch.distributed as dist
import math
from pathlib import Path
from glob import glob
import os
from tqdm import tqdm
from utils.indexed_datasets import IndexedDataset
from utils.image_utils import get_images, get_random_spk_img, get_spk_face_img
from utils.text_process import string_to_phoneme_id_list, ph_id2ctc_ph_id
# from utils.utils import collate_nd
from utils.hparams import hparams as global_hparams
from utils.video_aug_utils import get_video_aug, RandomAddSil
import cv2
from modules.fastspeech.videotts.resnet50_ft_dag import MEAN_RGB
from vidaug import augmentors as va
from utils.pitch_utils import norm_interp_f0



from torch.utils.data.sampler import Sampler


def copy_tensor(src, dst):
    """Copy ``src`` into ``dst`` in place after checking that their element counts match.

    Raises:
        AssertionError: with message ``(dst.numel(), src.numel())`` when the counts differ.
    """
    counts = (dst.numel(), src.numel())
    assert counts[0] == counts[1], counts
    dst.copy_(src)


def collate_nd(values, pad_idx=0):
    """
    Stack a list of n-d tensors into one (n+1)-d tensor, padding along the first axis.

    (T, X, Y, *) -> (B, T_max, X, Y, *). The first dimension is the one that may
    differ between entries; every other dimension is taken from ``values[0]``.
    The result has the dtype/device of ``values[0]`` and is filled with ``pad_idx``
    wherever an entry is shorter than ``T_max``. Requires n >= 1.
    """
    assert len(values) > 0
    first = values[0]
    longest = max(v.shape[0] for v in values)
    out_shape = (len(values), longest) + tuple(first.shape[1:])
    res = first.new_empty(out_shape).fill_(pad_idx)
    for row, v in zip(res, values):
        # `row` is a view of `res`, so writing into its prefix fills the output.
        target = row[:v.shape[0]]
        assert target.numel() == v.numel(), (target.numel(), v.numel())
        target.copy_(v)
    return res


class BinnedLengthSampler(Sampler):
    """
    Binned length sampler which supports DDP mode.

    Sample indices are ordered by ascending length and cut into consecutive bins
    of ``bin_size`` indices (plus one trailing partial bin). Each bin is shuffled
    internally, so a batch taken from a bin holds items of similar length.
    Without DDP the order of the full bins is shuffled as well; with DDP the bins
    stay in ascending-length order.
    """
    def __init__(self, lengths, batch_size, bin_size, use_ddp=False, num_replicas=1, rank=0, shuffle=True, seed=0):
        """
        Args:
            lengths: list/np.array of per-sample lengths.
            batch_size: int.
            bin_size: int, must be a multiple of ``batch_size`` (bin_size == K * batch_size).
            use_ddp: if True, ``num_replicas`` and ``rank`` are read from
                ``torch.distributed`` (the passed values are ignored).
            num_replicas, rank: replica layout; without ``use_ddp`` only
                ``num_replicas == 1`` is accepted.
            shuffle: whether ``divide_to_replicas`` randomly permutes samples
                before splitting them between replicas.
            seed: base seed for that permutation (combined with the epoch).
        """
        if use_ddp:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
            rank = dist.get_rank()
        self.use_ddp = use_ddp
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        # Per-replica sample count; the total is padded up to a multiple of num_replicas.
        self.num_samples = int(math.ceil(len(lengths) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle
        self.seed = seed

        self.lengths = np.asarray(lengths)
        # self.idx[k] is the dataset index of the k-th shortest sample.
        _, self.idx = torch.sort(torch.as_tensor(lengths).long())
        self.batch_size = batch_size
        self.bin_size = bin_size
        assert self.bin_size % self.batch_size == 0
        assert self.num_replicas == 1 or self.use_ddp, (self.num_replicas, self.use_ddp)

    def __iter__(self):
        # Need to change to numpy since there's a bug in random.shuffle(tensor).
        # `Tensor.numpy()` shares memory with the tensor, so the non-DDP path works
        # on a copy: shuffling the bin slices in place must not rewrite self.idx.
        if self.use_ddp and self.num_replicas > 1:
            sorted_pos_idx = self.divide_to_replicas()
            idx = self.idx.numpy()[sorted_pos_idx]  # fancy indexing already copies
        else:
            idx = self.idx.numpy().copy()

        bin_size = self.bin_size
        bin_nums = len(idx) // bin_size
        bins = []
        for i in range(bin_nums):
            this_bin = idx[i * bin_size:(i + 1) * bin_size]  # view into the local copy
            random.shuffle(this_bin)
            bins.append(this_bin)

        pieces = []
        if bins:
            if not self.use_ddp:
                random.shuffle(bins)
            pieces.append(np.stack(bins).reshape(-1))

        binned_len = bin_nums * bin_size
        if binned_len < len(idx):
            last_bin = idx[binned_len:]
            random.shuffle(last_bin)
            pieces.append(last_bin)

        # An empty dataset yields an empty iterator instead of failing on torch.as_tensor(None).
        binned_idx = np.concatenate(pieces) if pieces else idx[:0]
        return iter(torch.as_tensor(binned_idx).long())

    def __len__(self):
        return self.num_samples

    def divide_to_replicas(self):
        """
        Return this rank's share of positions into the length-sorted ``self.idx``.

        Positions are optionally permuted with a seed of ``seed + epoch``, padded
        by repeating the head so they split evenly, strided by rank, and finally
        sorted ascending (i.e. in ascending sample length).
        """
        if self.shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.lengths), generator=g).tolist()
        else:
            indices = list(range(len(self.lengths)))

        # add extra samples to make it evenly divisible
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        # sort in ascending order
        indices.sort()  # The sort is in-place
        assert len(indices) == self.num_samples

        return indices

    def set_epoch(self, epoch):
        r"""
        Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
        use a different random ordering for each epoch. Otherwise, the next iteration of this
        sampler will yield the same ordering.

        Arguments:
            epoch (int): Epoch number.
        """
        self.epoch = epoch

###################################################################################
# Tacotron/TTS Dataset ############################################################
###################################################################################


def get_tts_dataloader(hparams, prefix='train', use_ddp=False):
    """
    Build the DataLoader for one data split and pick an example id for attention plots.

    Args:
        hparams: hyper-parameters; reads 'batch_size', 'dataloader_num_workers',
            'task_type', 'binary_data_root' and, for non-'tts' tasks on the
            'train' split, 'use_vid_aug'.
        prefix: split name; selects the binary files, and video augmentation is
            requested only when it is 'train'.
        use_ddp: forwarded to BinnedLengthSampler.

    Returns:
        (dataloader, attn_example_id): attn_example_id is the 'id' of the item
        with the largest entry in ``dataset.sizes`` (first one on ties).
    """
    batch_size = hparams['batch_size']
    num_workers = hparams['dataloader_num_workers']
    if hparams['task_type'] == 'tts':
        my_dataset = TtsDataset(hparams, prefix)
        collate_fn = collate_tts
    else:
        # The video transform is only used by the video dataset, so only build it here;
        # augmentation is requested for the training split only.
        use_vid_aug = hparams['use_vid_aug'] if prefix == 'train' else False
        transform = get_video_aug(hparams, use_vid_aug)
        my_dataset = VideoTtsDataset(hparams, prefix, transform=transform)
        collate_fn = collate_video_tts

    # Length-binned sampling with bins of 3 batches.
    sampler = BinnedLengthSampler(my_dataset.sizes, batch_size, batch_size * 3, use_ddp=use_ddp)

    my_dataloader = DataLoader(my_dataset,
                               collate_fn=collate_fn,
                               batch_size=batch_size,
                               sampler=sampler,
                               num_workers=num_workers,
                               pin_memory=True)

    # Separate handle: the dataset opens its own IndexedDataset lazily in _get_item,
    # so it is not forced open here in the main process.
    indexed_ds = IndexedDataset(f"{hparams['binary_data_root']}/{prefix}")
    longest_idx = my_dataset.sizes.index(max(my_dataset.sizes))
    # Used to evaluate attention during training process
    attn_example_id = indexed_ds[longest_idx]['id']

    return my_dataloader, attn_example_id


class TtsDataset(Dataset):
    """
    Text-to-speech dataset backed by a binarized IndexedDataset.

    Files read from ``hparams['binary_data_root']``:
      * ``{prefix}_lengths.npy``: per-item lengths, kept as the list ``self.sizes``;
      * ``train_f0s_mean_std.npy``: f0 mean/std. They are kept on the instance and
        written back into ``hparams['f0_mean']`` / ``hparams['f0_std']`` as floats;
      * ``{prefix}``: the IndexedDataset itself, opened lazily on first access.
    """

    def __init__(self, hparams, prefix):
        self.hparams = hparams
        self.prefix = prefix
        self.data_dir = data_dir = hparams['binary_data_root']
        self.sizes = np.load(f'{data_dir}/{prefix}_lengths.npy').tolist()
        # pitch stats: always taken from the training split
        mean, std = np.load(f'{data_dir}/train_f0s_mean_std.npy')
        self.f0_mean, self.f0_std = mean, std
        hparams['f0_mean'], hparams['f0_std'] = float(mean), float(std)
        self.indexed_ds = None  # opened on first _get_item call

    def _get_item(self, index):
        """Return raw record ``index``, opening the IndexedDataset on first use."""
        ds = self.indexed_ds
        if ds is None:
            ds = self.indexed_ds = IndexedDataset(f'{self.data_dir}/{self.prefix}')
        return ds[index]

    def __getitem__(self, index):
        item = self._get_item(index)
        mel = item['mel']  # (80, T)
        n_frames = mel.shape[1]
        sample = {
            'token': item['token'],
            'mel': mel,
            'item_id': item['id'],
            'mel_len': n_frames,
        }
        if 'f0' in item:
            # NOTE(review): normalised with global_hparams while __init__ writes the f0
            # stats into self.hparams; presumably the same object, confirm.
            sample['f0'], sample['uv'] = norm_interp_f0(item["f0"][:n_frames], global_hparams)
        if 'pitch' in item:
            sample['pitch'] = item['pitch'][:n_frames]
        return sample

    def __len__(self):
        return len(self.sizes)


class VideoTtsDataset(TtsDataset):
    """
    TtsDataset that also returns strided, optionally transformed video frames,
    trimmed so that each frame corresponds to ``vid_mel_repeat_num`` mel frames.
    """

    def __init__(self, hparams, prefix, transform=None):
        super().__init__(hparams, prefix)
        self.transform = transform
        fps = self.hparams['video_fps']
        # Applied (with probability 0.5) to (images, mel) pairs on the train split.
        self.rand_add_sil = va.Sometimes(
            0.5, RandomAddSil(fps // 2, fps * 2, self.hparams['vid_mel_repeat_num']))

    def _speaker_image(self, item):
        """Return the resized, MEAN_RGB-subtracted speaker image for ``item``."""
        if global_hparams['random_spk_img']:
            face = get_random_spk_img(item['id'], global_hparams['preprocessed_root'])
        else:
            face = item['spk_img']
        side = self.hparams['spk_img_size']
        return cv2.resize(face, (side, side)).astype(np.float32) - MEAN_RGB

    def __getitem__(self, index):
        hp = self.hparams
        item = self._get_item(index)
        mel = item['mel']  # (80, T)
        frames = item['img'][::hp['img_stride']]  # (T) (H, W, *)
        if self.transform is not None:
            frames = self.transform(frames)  # list
        frames = np.asarray(frames, dtype=np.float32) / 255.0  # (T, H, W, C)
        if frames.ndim < 4:
            frames = frames[..., None]  # add a trailing channel axis
        frames, mel = self.mul_video_and_mel(frames, mel, hp['vid_mel_repeat_num'])
        if hp['use_add_front_sil'] and self.prefix == 'train':
            frames, mel = self.rand_add_sil((frames, mel))
        n_frames = frames.shape[0]
        n_mel = mel.shape[1]  # mel shape: (80, T)

        sample = {
            'token': item['token'],
            'mel': mel,
            'item_id': item['id'],
            'mel_len': n_mel,
            'images': frames,
            'img_len': n_frames,
        }
        # NOTE(review): if rand_add_sil lengthened mel, f0/pitch below are only cut to
        # the new mel length, not shifted/padded to match; confirm alignment.
        if 'f0' in item:
            sample['f0'], sample['uv'] = norm_interp_f0(item["f0"][:n_mel], global_hparams)
        if 'pitch' in item:
            sample['pitch'] = item['pitch'][:n_mel]
        if global_hparams['use_img_spk_embed'] and 'spk_img' in item:
            sample['spk_img'] = self._speaker_image(item)
        if hp['use_mutual_information']:
            # Items with no more frames than tokens are dropped (None is filtered by the collate fns).
            if sample['img_len'] <= len(sample['token']):
                print(f"| Skip item {sample['item_id']}, img_len = {sample['img_len']}, token_len = {len(sample['token'])}")
                return None
            sample['ctc_token'] = ph_id2ctc_ph_id(sample['token'])
        return sample

    @staticmethod
    def mul_video_and_mel(images, mels, vid_mel_repeat_num):
        """
        Trim ``images`` (first axis = frames) and ``mels`` (last axis = time) so
        that ``mels.shape[1] == images.shape[0] * vid_mel_repeat_num``.

        The two lengths must differ by fewer than 50 mel frames (asserted). Extra
        mel frames are cut; if the video is too long, frames are dropped to
        ``mel_len // vid_mel_repeat_num`` and mel is cut to match.
        """
        expected_mel = images.shape[0] * vid_mel_repeat_num
        actual_mel = mels.shape[-1]
        assert abs(expected_mel - actual_mel) < 50, (expected_mel, actual_mel)
        if expected_mel < actual_mel:
            mels = mels[:, :expected_mel]
        elif expected_mel > actual_mel:
            kept_frames = actual_mel // vid_mel_repeat_num
            images = images[:kept_frames]
            mels = mels[:, :kept_frames * vid_mel_repeat_num]
        assert images.shape[0] * vid_mel_repeat_num == mels.shape[1], f'{images.shape}, {mels.shape}'
        return images, mels


def pad1d(x, max_len):
    """Right-pad 1-d array ``x`` with zeros up to length ``max_len``."""
    n_pad = max_len - len(x)
    return np.pad(x, pad_width=[(0, n_pad)], mode='constant', constant_values=0)


def pad2d(x, max_len):
    """Right-pad the last axis of 2-d array ``x`` with zeros up to ``max_len``."""
    extra = max_len - x.shape[-1]
    no_pad = (0, 0)
    return np.pad(x, pad_width=(no_pad, (0, extra)), mode='constant', constant_values=0)


def collate_tts(batch):
    """
    Collate dataset items into one zero-padded batch.

    Each item is a dict as produced by ``TtsDataset.__getitem__`` or
    ``VideoTtsDataset.__getitem__``: 'token' (1-d id sequence), 'mel' (80, T),
    'item_id', 'mel_len', and optionally 'f0', 'uv', 'pitch', 'ctc_token'.
    ``None`` items (skipped samples) are dropped first. Optional keys are
    collated only when the FIRST remaining item has them.

    Returns a dict with:
        ids, token_lens, mel_lens: python lists
        tokens:     LongTensor (B, max_token_len)
        mels:       (B, T_max, 80); mapped from [0, 1] to [-4, 4] when
                    global_hparams['mel_gen_mode'] == 'wavernn'
        f0s, uvs:   FloatTensor (B, T_max), if present
        pitches:    LongTensor (B, T_max), if present
        ctc_tokens: LongTensor (B, max_token_len), if present

    Raises:
        ValueError: if every item in ``batch`` is None.
    """
    batch = [b for b in batch if b is not None]
    if not batch:
        raise ValueError("collate_tts: every item in the batch was None (skipped)")

    x_lens = [len(x['token']) for x in batch]
    max_x_len = max(x_lens)
    tokens = np.stack([pad1d(x['token'], max_x_len) for x in batch])  # (B, max_x_len)
    tokens = torch.as_tensor(tokens).long()

    spec_lens = [x['mel'].shape[-1] for x in batch]
    max_spec_len = max(spec_lens)
    mels = np.stack([pad2d(x['mel'], max_spec_len) for x in batch])  # (B, 80, T)
    mels = torch.as_tensor(mels)
    mel_lens = [x['mel_len'] for x in batch]

    ids = [x['item_id'] for x in batch]

    # scale spectrograms from 0 <--> 1 to -4 <--> 4
    if global_hparams['mel_gen_mode'] == 'wavernn':
        mels = (mels * 8.) - 4.

    ret = dict()
    ret['ids'] = ids
    ret['tokens'] = tokens
    ret['token_lens'] = x_lens
    ret['mels'] = mels.transpose(1, 2)  # (B, T, H)
    ret['mel_lens'] = mel_lens

    # Frame-level features are padded to the mel length.
    if 'f0' in batch[0]:
        f0s = np.stack([pad1d(x['f0'], max_spec_len) for x in batch])
        ret['f0s'] = torch.as_tensor(f0s).float()
    if 'uv' in batch[0]:
        uvs = np.stack([pad1d(x['uv'], max_spec_len) for x in batch])
        ret['uvs'] = torch.as_tensor(uvs).float()
    if 'pitch' in batch[0]:
        pitches = np.stack([pad1d(x['pitch'], max_spec_len) for x in batch])
        ret['pitches'] = torch.as_tensor(pitches).long()
    if 'ctc_token' in batch[0]:
        # Padded to the token length; assumes ctc_token is never longer than token.
        ctc_tokens = np.stack([pad1d(x['ctc_token'], max_x_len) for x in batch])
        ret['ctc_tokens'] = torch.as_tensor(ctc_tokens).long()  # (B, max_x_len)
    return ret


def collate_video_tts(batch):
    """
    Collate VideoTtsDataset items: everything ``collate_tts`` returns, plus
        imgs:     FloatTensor (B, T_img_max, H, W, C), padded with 0 along time
        img_lens: list of per-item frame counts
        spk_img:  FloatTensor (B, H, W, 3), only if the first item has 'spk_img'
    ``None`` items (skipped samples) are dropped first.
    """
    batch = [b for b in batch if b is not None]
    ret = collate_tts(batch)
    ret['imgs'] = collate_nd([torch.from_numpy(x['images']).float() for x in batch])  # (B, T, H, W, C)
    ret['img_lens'] = [x['img_len'] for x in batch]
    if 'spk_img' in batch[0]:
        spk_imgs = np.stack([x['spk_img'] for x in batch], axis=0)  # (B, H, W, 3)
        # Cast explicitly: subtracting MEAN_RGB in the dataset may promote the image to
        # float64; the inference path (get_input_from_line) also casts with .float().
        ret['spk_img'] = torch.from_numpy(spk_imgs).float()
    return ret


def get_input_from_line(line, hparams):
    """
    Parse one line of the test list into a batch-of-one model input dict.

    Supported '|'-separated formats:
      * ``text_id|text``: video_id defaults to text_id.
      * ``text_id|text|video_id``.
      * ``text_id|text|face_img_path`` when ``hparams['change_spk_face']`` is set.
        video_id defaults to text_id and the speaker image is loaded from
        face_img_path. The face path, without extension and with '/' -> '_',
        is appended to text_id.
      * ``text_id|text|video_id|raw_mp4_path|imgs_path``.

    For the 2/3-field formats video_id is split on its last '-' into
    ``<ep_id>-<cut_id>``. The paths are then derived from
    ``hparams['sent_video_path']`` and ``hparams['preprocessed_root']``.

    Returns:
        dict with 'tokens' (1, L) long, 'token_lens', 'token_id', 'video_id',
        'imgs' (1, T, H, W, *) float, 'img_lens', 'video_path' and
        'spk_img' (1, H, W, 3) float.

    Raises:
        ValueError: if the line has an unsupported number of fields, or if
            ``change_spk_face`` is set but the line carries no face image path.
    """
    infos = line.strip().split('|')
    spk_face_img_path = None
    if len(infos) in [2, 3]:
        text_id, text, video_id = infos[0], infos[1], infos[0]
        if len(infos) == 3:
            if hparams['change_spk_face']:
                # The third field is a face image path, not a video id.
                spk_face_img_path = infos[2]
                face_stem = os.path.splitext(spk_face_img_path)[0]
                text_id = text_id + '_' + face_stem.replace('/', '_')
            else:
                video_id = infos[2]
        ep_id, cut_id = video_id.rsplit('-', 1)
        raw_mp4_path = Path(hparams['sent_video_path']) / f'{ep_id}' / f'{cut_id}.mp4'
        imgs_path = Path(hparams['preprocessed_root']) / f'{ep_id}' / f'{cut_id}'
    elif len(infos) == 5:
        text_id, text, video_id, raw_mp4_path, imgs_path = infos
    else:
        raise ValueError("Info is wrong!")

    # Fail early (before loading any images) instead of hitting an unbound path later.
    if hparams['change_spk_face'] and spk_face_img_path is None:
        raise ValueError(
            f"change_spk_face requires 'text_id|text|face_img_path' lines, got: {line.strip()!r}")

    video, spk_img = get_images(imgs_path, hparams)
    if spk_face_img_path is not None:
        spk_img = get_spk_face_img(spk_face_img_path, hparams)
    # shape: (T, H, W, C), T == the frame num in video
    video = video[::hparams['img_stride']]  # (T, H, W, *)
    if hparams['add_sil_video_test']:
        # Prepend 75 copies of the first frame.
        video = np.concatenate((np.stack([video[0]] * 75, axis=0), video), axis=0)
    text_id_seq = string_to_phoneme_id_list(text.strip())

    ret = dict()
    ret['tokens'] = torch.as_tensor(text_id_seq).unsqueeze(0).long()
    ret['token_lens'] = [len(text_id_seq)]
    ret['token_id'] = text_id
    ret['video_id'] = video_id
    ret['imgs'] = torch.as_tensor(video).unsqueeze(0).float()
    ret['img_lens'] = [video.shape[0]]
    ret['video_path'] = raw_mp4_path
    ret['spk_img'] = torch.as_tensor(spk_img).unsqueeze(0).float()  # (1, H, W, 3)
    return ret


def get_test_dataloader(hparams):
    """
    Load the test inputs listed in ``hparams['videotts_test_file']``.

    Despite its name this returns a plain list with one dict per non-blank line
    (see get_input_from_line), not a torch DataLoader. As a side effect the
    training-set f0 mean/std are stored into ``hparams['f0_mean']`` and
    ``hparams['f0_std']`` as python floats.
    """
    # pitch stats (same file and keys as TtsDataset.__init__)
    f0_stats_fn = f'{hparams["binary_data_root"]}/train_f0s_mean_std.npy'
    f0_mean, f0_std = np.load(f0_stats_fn)
    hparams['f0_mean'] = float(f0_mean)
    hparams['f0_std'] = float(f0_std)

    text_file = hparams['videotts_test_file']
    # Explicit UTF-8 so parsing does not depend on the platform's default encoding.
    with open(text_file, 'r', encoding='utf-8') as f:
        # Blank lines (e.g. a trailing newline) are skipped rather than rejected as malformed.
        inputs = [get_input_from_line(line, hparams) for line in f if line.strip()]
    return inputs


