import os
import json

import torch
import torch.nn.functional as F
import numpy as np
import matplotlib
from scipy.io import wavfile
from scipy.interpolate import make_interp_spline 
from matplotlib import pyplot as plt
import random 

# Select the non-interactive Agg backend: this module only hands figures to a
# logger and never displays them.
matplotlib.use("Agg")


# Module-wide default device: CUDA when available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


def to_device(data, device):
    """Convert the numpy arrays of a collated batch to torch tensors on ``device``.

    Two batch layouts are accepted:

    * 9 fields: ``(ids, raw_texts, speakers, texts, src_lens, max_src_len,
      np_lens, npids, cwids)``
    * 8 fields: the same without ``np_lens``.

    ``speakers`` and ``texts`` are cast to int64. Every other array keeps the
    dtype ``torch.from_numpy`` gives it. ``ids``, ``raw_texts`` and
    ``max_src_len`` are passed through untouched.

    Args:
        data: tuple/list with 8 or 9 fields as described above.
        device: target ``torch.device`` (or device string).

    Returns:
        A tuple with the same field order and length as ``data``.

    Raises:
        ValueError: if ``data`` does not have 8 or 9 fields.
    """
    if len(data) not in (8, 9):
        raise ValueError(
            f"to_device expects a batch with 8 or 9 fields, got {len(data)}"
        )

    ids, raw_texts, speakers, texts, src_lens, max_src_len = data[:6]

    speakers = torch.from_numpy(speakers).long().to(device)
    texts = torch.from_numpy(texts).long().to(device)
    src_lens = torch.from_numpy(src_lens).to(device)
    # Trailing fields ([np_lens,] npids, cwids) are moved without a dtype cast.
    tail = tuple(torch.from_numpy(array).to(device) for array in data[6:])

    return (ids, raw_texts, speakers, texts, src_lens, max_src_len) + tail

def _smooth_series(values, upsample=8):
    """Return ``(x, y)`` for a spline-smoothed version of a 1-D series.

    ``x`` samples the original index range ``[0, T-1]`` at ``upsample * T``
    points. The spline degree is capped at ``T - 1`` so short series still
    interpolate. A series with fewer than 2 points is returned unsmoothed.
    """
    n_points = len(values)
    x = np.arange(n_points, dtype=float)
    if n_points < 2:
        return x, values
    spline = make_interp_spline(x, values, k=min(3, n_points - 1))
    x_dense = np.linspace(x[0], x[-1], upsample * n_points)
    return x_dense, spline(x_dense)


def _log_dpp_vs_vanilla(logger, title, pair, step, log_scale=False):
    """Plot ``pair[0]`` (dpp) against ``pair[1]`` (vanilla) and log the figure under ``title``."""
    dpp = pair[0].detach().cpu().numpy()
    vanilla = pair[1].detach().cpu().numpy()
    if log_scale:
        # Clamp first: a zero (or negative) value would become -inf/nan,
        # which make_interp_spline rejects.
        dpp = np.log(np.clip(dpp, 1e-5, None))
        vanilla = np.log(np.clip(vanilla, 1e-5, None))

    x_dpp, y_dpp = _smooth_series(dpp)
    x_vanilla, y_vanilla = _smooth_series(vanilla)

    # Create the Axes first, then style it. Calling plt.grid() before
    # plt.axes() styles an Axes that the new one then covers.
    fig, ax = plt.subplots()
    ax.set_facecolor('#ededed')
    ax.grid(True, axis='y', color='white', linestyle='--')
    ax.plot(x_dpp, y_dpp, color='#ef626c', label='dpp')
    ax.plot(x_vanilla, y_vanilla, color='#358866', label='vanilla')
    ax.legend(loc='best')
    logger.add_figure(title, fig, global_step=step)
    # Release the figure even if the logger did not close it itself.
    plt.close(fig)


def log(logger, step=None, loss=None, model=None, fig=None, audio=None, duration=None, pitch=None, sampling_rate=22050, tag=""):
    """Write training diagnostics to a TensorBoard-style ``logger``.

    Every argument other than ``logger`` is optional; only the ones provided
    are logged:

    * ``loss``: scalar logged as ``"Loss/diversity_loss"``.
    * ``model``: one histogram per named parameter. The model's train/eval
      mode is restored afterwards.
    * ``audio``: waveform peak-normalised to ``[-1, 1]`` (silent audio is
      left as-is), logged under ``tag`` at ``sampling_rate``.
    * ``duration`` / ``pitch``: ``[2, T]``-like tensors whose row 0 is the
      dpp prediction and row 1 the vanilla one. Each is drawn as a smoothed
      comparison plot; durations are plotted in log scale.

    ``fig`` is accepted for interface compatibility but is not used.
    """
    if loss is not None:
        logger.add_scalar("Loss/diversity_loss", loss, step)

    if model is not None:
        was_training = model.training
        model.eval()
        for name, weight in model.named_parameters():
            logger.add_histogram(name, weight, step)
        model.train(was_training)

    if audio is not None:
        peak = abs(audio).max()
        # Avoid 0/0 -> NaN for an all-silent clip.
        if peak > 0:
            audio = audio / peak
        logger.add_audio(tag, audio, global_step=step, sample_rate=sampling_rate)

    if duration is not None:
        _log_dpp_vs_vanilla(logger, 'Duration plot', duration, step, log_scale=True)

    if pitch is not None:
        _log_dpp_vs_vanilla(logger, 'Pitch plot', pitch, step)

def random_select(np_ids, lcw_ids, rcw_ids, lengths):
    """Pick one entry per batch row, uniformly at random, from three aligned tensors.

    Args:
        np_ids, lcw_ids, rcw_ids: tensors of shape [B, M, 2] sharing one device.
        lengths: [B] number of valid entries per row. Row ``i`` draws an index
            from ``[0, lengths[i] - 1]`` with ``random.randint``, so a zero
            length raises ``ValueError``.

    Returns:
        ``(np_id, lcw_id, rcw_id)``, each of shape [B, 2]. The same drawn
        entry is taken from all three tensors.
    """
    picks = [random.randint(0, int(lengths[i]) - 1) for i in range(np_ids.size(0))]
    # Build the index on the inputs' device (previously hard-coded to CUDA).
    idx = torch.tensor(picks, dtype=torch.long, device=np_ids.device)  # [B]
    idx = idx.view(-1, 1, 1).expand(-1, -1, 2)                          # [B, 1, 2]

    np_id = np_ids.gather(1, idx).squeeze(1)
    lcw_id = lcw_ids.gather(1, idx).squeeze(1)
    rcw_id = rcw_ids.gather(1, idx).squeeze(1)
    return np_id, lcw_id, rcw_id

def get_mask_from_lengths(lengths, max_len=None):
    """Build a boolean padding mask from sequence lengths.

    Args:
        lengths: [B] integer tensor of valid lengths.
        max_len: mask width. Defaults to ``lengths.max()``.

    Returns:
        Bool tensor of shape [B, max(max_len, 1)] on ``lengths.device``.
        ``mask[b, t]`` is True where ``t >= lengths[b]``, i.e. at padding.
        If every length is 0, the mask keeps a single (fully masked) column,
        matching the one dummy index that ``get_idxs_from_position`` emits in
        that case.
    """
    batch_size = lengths.shape[0]
    if max_len is None:
        max_len = torch.max(lengths).item()
    # Never emit a zero-width mask. The previous code tried to widen this
    # case but expanded lengths to width 0, so broadcasting collapsed the
    # result back to [B, 0].
    width = max(int(max_len), 1)
    ids = torch.arange(0, width, device=lengths.device).unsqueeze(0).expand(batch_size, -1)
    mask = ids >= lengths.unsqueeze(1)
    return mask

def get_chunk_position(lcw_ids, np_ids, rcw_ids):
    """Return the [start, end] span covering a phrase plus its context words.

    Each argument is a [B, 2] tensor of (start, end) positions. A row that
    contains -1 marks an absent context. In that case the chunk starts at the
    phrase start (no left context) and/or ends at the phrase end (no right
    context).

    Returns:
        float32 tensor of shape [B, 2] on the inputs' device.
    """
    no_left = (lcw_ids == -1).any(dim=1)
    no_right = (rcw_ids == -1).any(dim=1)
    # Vectorised selection: no per-row Python loop and no host sync per row.
    start = torch.where(no_left, np_ids[:, 0], lcw_ids[:, 0])
    end = torch.where(no_right, np_ids[:, 1], rcw_ids[:, 1])
    # float32 keeps the dtype existing callers received from this function.
    return torch.stack([start, end], dim=1).to(torch.float32)

def expand(input):
    """Map every position ``p`` to ``2 * p + 1``, elementwise.

    Applied to [B, 2] position tensors (phonemes are interspersed with
    transition tokens). The missing-context marker -1 maps to itself,
    because ``2 * -1 + 1 == -1``.
    """
    doubled = input * 2
    return doubled + 1

def get_len(input):
    """Return the number of positions each [start, end] span covers.

    Present spans are inclusive, so their length is ``end - start + 1``.
    A row containing -1 marks a missing span and yields ``end - start``
    (0 for ``[-1, -1]``).

    Args:
        input: tensor of shape [B, 2] (integer or float positions).

    Returns:
        int64 tensor of shape [B].
    """
    span = input[:, 1] - input[:, 0]
    missing = (input == -1).any(dim=1)
    # Vectorised: the old per-row `-1 in row` test forced a host sync per row on GPU.
    return torch.where(missing, span, span + 1).long()

def get_idxs_from_position(idxs, lengths):
    """Expand [start, end] spans into zero-padded index rows.

    Args:
        idxs: [B, 2] span positions. A row containing -1 marks a missing span.
        lengths: [B] span lengths. Only their maximum sets the output width
            (and ``lengths.size(0)`` is used when every length is 0).

    Returns:
        [B, N] tensor with ``N = max(lengths)``. Present spans hold
        ``start..end``; missing spans are all zeros; trailing slots are
        zero-padded. When the maximum length is 0 (which can happen), the
        result is a [B, 1] tensor of zeros.
    """
    max_len = max(lengths)
    if not max_len:
        # Every span is empty: emit one dummy index (0) per row.
        return torch.arange(0, 1).expand(lengths.size(0), -1)

    rows = []
    for span in idxs:
        if -1 in span:
            row = F.pad(torch.arange(0, 1), (0, max_len - 1))
        else:
            width = int(span[1] - span[0] + 1)
            row = F.pad(torch.arange(span[0], span[1] + 1), (0, max_len - width))
        rows.append(row)
    return torch.stack(rows)

def get_random_phrase(np_num, np_ids, cw_ids):
    """Sample one noun phrase per batch element and build its index tensors and masks.

    Args:
        np_num: [B] number of valid phrases per element (used by ``random_select``).
        np_ids: [B, M, 2] phrase (start, end) positions.
        cw_ids: context-word positions; columns ``:2`` hold the left context
            and ``2:`` the right context (presumably [B, M, 4]; confirm
            against the caller).

    Returns:
        Tuple of
        ``(chunk_ids, np_ids, lcw_ids, rcw_ids,
        np_len, lcw_len, rcw_len,
        chunk_mask, np_mask, lcw_mask, rcw_mask)``.
        The four index tensors are int64 and zero-padded. The lengths are
        [B] tensors. Masks come from ``get_mask_from_lengths`` (True at
        padding). Index and length tensors are placed on ``np_ids``'s device.
    """
    target_device = np_ids.device

    lcw_ids, rcw_ids = cw_ids[:, :, :2], cw_ids[:, :, 2:]
    np_ids, lcw_ids, rcw_ids = random_select(np_ids, lcw_ids, rcw_ids, np_num)
    # Phonemes are interspersed with transition tokens, so map positions p -> 2p+1.
    np_ids, lcw_ids, rcw_ids = expand(np_ids), expand(lcw_ids), expand(rcw_ids)
    chunk_ids = get_chunk_position(lcw_ids, np_ids, rcw_ids)

    np_len, lcw_len, rcw_len = get_len(np_ids), get_len(lcw_ids), get_len(rcw_ids)
    chunk_len = get_len(chunk_ids)

    chunk_mask = get_mask_from_lengths(chunk_len)
    np_mask = get_mask_from_lengths(np_len)
    lcw_mask = get_mask_from_lengths(lcw_len)
    rcw_mask = get_mask_from_lengths(rcw_len)

    # Turn [start, end] positions into padded index rows.
    chunk_idx = get_idxs_from_position(chunk_ids, chunk_len)
    np_idx = get_idxs_from_position(np_ids, np_len)
    lcw_idx = get_idxs_from_position(lcw_ids, lcw_len)
    rcw_idx = get_idxs_from_position(rcw_ids, rcw_len)

    return (chunk_idx.long().to(target_device), np_idx.long().to(target_device),
            lcw_idx.long().to(target_device), rcw_idx.long().to(target_device),
            np_len.to(target_device), lcw_len.to(target_device), rcw_len.to(target_device),
            chunk_mask, np_mask, lcw_mask, rcw_mask)


def dpp_collate(l_seq, r_seq, target_seq, lcw_quality, rcw_quality, target_quality,
                lcw_len, rcw_len, np_len, nc):
    """Pad context and candidate sequences to one length and stack them per sample.

    Input shapes:
        l_seq: [b, 1, T_l], r_seq: [b, 1, T_r], target_seq: [nc, b, 1, T_t]
        lcw_quality: [b, 1], rcw_quality: [b, 1], target_quality: [b, nc]

    Rows whose right context is empty (``rcw_len == 0``) have their two
    contexts swapped, so the empty (dummy) context always sits in slot 0.

    Returns:
        vectors: [b, nc+2, t, 1]
        quality_vectors: [b, nc+2, 1]
        d_len: [nc+2, b]
        mask_idxs: list of row indices whose slot 0 is a dummy context
            (swapped rows first, then rows with an empty left context).
    """
    batch = rcw_len.size(0)

    # Common time length t. Context maxima are clamped to >= 1 so a batch
    # whose contexts are all empty still keeps one padded frame.
    max_np = max(np_len)
    max_lcw = torch.clamp(max(lcw_len), min=1)
    max_rcw = torch.clamp(max(rcw_len), min=1)
    pad_len = max(max_lcw, max_rcw, max_np)

    l_seq = F.pad(l_seq, (0, pad_len - max_lcw))
    r_seq = F.pad(r_seq, (0, pad_len - max_rcw))
    target_seq = F.pad(target_seq, (0, pad_len - max_np))

    swap_idxs = [i for i in range(batch) if rcw_len[i] == 0]
    zero_idxs = [i for i in range(lcw_len.size(0)) if lcw_len[i] == 0]
    swapped = set(swap_idxs)

    context_rows = []
    quality_rows = []
    for i in range(batch):
        if i in swapped:
            seq_pair = [r_seq[i, :, :], l_seq[i, :, :]]
            quality_pair = [rcw_quality[i, :], lcw_quality[i, :]]
        else:
            seq_pair = [l_seq[i, :, :], r_seq[i, :, :]]
            quality_pair = [lcw_quality[i, :], rcw_quality[i, :]]
        context_rows.append(torch.stack(seq_pair, dim=0))        # [2, 1, t]
        quality_rows.append(torch.stack(quality_pair, dim=0))    # [2, 1]

    contexts = torch.stack(context_rows, dim=0)                  # [b, 2, 1, t]
    candidates = target_seq.transpose(0, 1)                      # [b, nc, 1, t]
    vectors = torch.cat([contexts, candidates], dim=1).squeeze(-2).unsqueeze(-1)  # [b, nc+2, t, 1]

    quality_vectors = torch.stack(quality_rows, dim=0).squeeze(-1)         # [b, 2]
    quality_vectors = torch.cat([quality_vectors, target_quality], dim=1)  # [b, nc+2]

    n_rows = len(lcw_len)
    first_len = torch.stack([rcw_len[i] if i in swapped else lcw_len[i] for i in range(n_rows)]).unsqueeze(0)
    second_len = torch.stack([lcw_len[i] if i in swapped else rcw_len[i] for i in range(n_rows)]).unsqueeze(0)
    candidate_len = np_len.unsqueeze(0).expand(nc, -1)                     # [nc, b]
    d_len = torch.cat([first_len, second_len, candidate_len], dim=0)       # [nc+2, b]

    mask_idxs = swap_idxs + zero_idxs
    return vectors, quality_vectors.unsqueeze(-1), d_len, mask_idxs

def log2exp(logw, mask):
    """Exponentiate log-domain values and multiply them by ``mask``.

    With a 0/1 (or boolean) mask this zeroes the masked-off positions.
    """
    return logw.exp() * mask

def dpp_inference(kernel, num_cw):
    """Pick, per sample, the candidate whose DPP subset has the largest log-determinant.

    Args:
        kernel: [b, cw+N_c, cw+N_c] kernel. Indices 0 and 1 are the two
            context slots; indices 2.. are the candidates.
        num_cw: per-sample context count. When it equals 2, each subset is
            ``{0, 1, candidate}``; otherwise it is ``{1, candidate}``.

    Returns:
        List of ints: for each sample, the position (0-based, among the
        candidates) of the subset with maximal ``logdet``.
    """
    candidates = list(range(2, kernel.size(1)))
    choices = []
    for b, n_ctx in enumerate(num_cw):
        context = [0, 1] if n_ctx == 2 else [1]
        subsets = [context + [cand] for cand in candidates]
        sub_kernel = torch.stack([kernel[b, subset][:, subset] for subset in subsets])
        choices.append(torch.argmax(torch.logdet(sub_kernel)).item())
    return choices

def synthesize(mel_output, mel_lens, vocoder, model_config, preprocess_config):
    """Vocode the dpp and vanilla mel-spectrograms of a batch.

    Row 0 of ``mel_output`` is the dpp output and row 1 the vanilla output.
    Each row is trimmed to its entry in ``mel_lens``, transposed, and passed
    (with a leading batch axis) to ``vocoder_infer``.

    Returns:
        ``(wav_vanilla, wav_dpp)``: the first waveform returned by each
        ``vocoder_infer`` call.

    Raises:
        ValueError: if ``vocoder`` is None. This is checked explicitly instead
            of with ``assert``, which ``python -O`` strips.
    """
    if vocoder is None:
        raise ValueError("synthesize() requires a vocoder, got None")

    from .model import vocoder_infer

    dpp_output = mel_output[0][:mel_lens[0]].transpose(0, 1)
    vanilla_output = mel_output[1][:mel_lens[1]].transpose(0, 1)

    wav_vanilla = vocoder_infer(
        vanilla_output.unsqueeze(0),
        vocoder,
        model_config,
        preprocess_config,
    )[0]

    wav_dpp = vocoder_infer(
        dpp_output.unsqueeze(0),
        vocoder,
        model_config,
        preprocess_config,
    )[0]
    return wav_vanilla, wav_dpp


if __name__ == '__main__':
    # Disabled sanity check for get_random_phrase (kept for reference).
    # Its fixtures call .cuda(), so it needs a GPU.
    '''
    cw_list = [[1,3,7,8], [-1,-1,6,9], [1,2,-1,-1], [2,3,9,10], [-1,-1,7,8], [1,2,9,11], [2,4,6,13]]
    print("---Method 1 sanity check---")
    B = 6
    np_num = torch.randint(1,5,(B,)).cuda()
    np_id = [] 
    for i in range(B):
        id = torch.Tensor([4,5]*np_num[i]).view(-1,2)
        id = F.pad(id ,(0,0,0,max(np_num) - np_num[i])) 
        np_id.append(id)
    np_ids = torch.stack(np_id, dim=0).cuda()

    cw_ids = []
    for i in range(B):
        id = torch.Tensor(random.choice(cw_list)*np_num[i]).view(-1,4)
        id = F.pad(id, (0,0,0, max(np_num) - np_num[i]))
        cw_ids.append(id)
    cw_ids = torch.stack(cw_ids, dim=0).cuda()

    _  = get_random_phrase(np_num, np_ids, cw_ids)
    
    #print(chunk_ids)
    #print(chunk_mask)
    print("---Method 1 sanity check done!---")
    '''

    print("---Method2 sanity check---")
    # Place fixtures on the module-level `device` so this check also runs on
    # CPU-only machines (dpp_collate itself is device-agnostic).
    l_dr_seq = torch.randn(4, 1, 6).to(device)
    r_dr_seq = torch.randn(4, 1, 1).to(device)
    target_d_seq = torch.randn(5, 4, 1, 10).to(device)
    lcw_quality = torch.randn(4, 1).to(device)
    rcw_quality = torch.randn(4, 1).to(device)
    target_quality = torch.randn(4, 5).to(device)
    lcw_len = torch.LongTensor([3, 6, 3, 2]).to(device)
    rcw_len = torch.LongTensor([0, 0, 0, 0]).to(device)
    np_len = torch.LongTensor([3, 10, 5, 6]).to(device)
    num_can = 5
    duration_vector, quality_vector, d_len, mask_idxs = dpp_collate(l_dr_seq, r_dr_seq, target_d_seq,
                lcw_quality, rcw_quality, target_quality, lcw_len, rcw_len, np_len, num_can)

    print(duration_vector.size(), quality_vector.size(), d_len.size())
    print(mask_idxs)
    print("---Method2 sanity check Done!---")

    # Disabled sanity check for dpp_inference (kept for reference).
    # Its kernel fixture calls .cuda(), so it needs a GPU.
    '''
    print("---DPP inference sanity check---")
    a = torch.randn(6)* 2
    b = torch.randn(6)* 3
  
    kernel = torch.zeros(2,6,6).cuda()
    for i in range(6):
        for j in range(6):
            dist = torch.norm(a[i]-a[j]) ** 2
            kernel[0][i][j] = torch.exp(-0.1*dist)  

    for i in range(6):
        for j in range(6):
            dist = torch.norm(b[i]-b[j]) ** 2
            kernel[1][i][j] = torch.exp(-0.1*dist)  

    argmax_idx = dpp_inference(kernel, num_cw=[2,1])

    print(argmax_idx)

    print("--- sanity check done!---")
    '''