import os
import json
import copy
import math
from collections import OrderedDict

import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F

from utils.tools import get_mask_from_lengths, pad, quality_compute, quality_compute_pitch 
from utils.dpp_tools import log2exp
from .blocks import DN_block, DDSConv, Nystromer
from .subblocks import Mish
from model import flows 

# Module-level device handle: CUDA when it is available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class VarianceAdaptor(nn.Module):
    """Variance adaptor: adds pitch/energy embeddings, then expands by duration.

    ``__init__`` builds a ``LengthRegulator``, a stochastic duration predictor
    (``SDP``), a stochastic pitch predictor (``SPP``) and an
    ``EnergyPredictor``.  Pitch and energy values are quantised with
    ``torch.bucketize`` against bin edges derived from ``stats.json`` and
    looked up in embedding tables whose width is the encoder hidden size.
    """

    def __init__(self, preprocess_config, model_config):
        """Build the predictors, quantisation bins and embedding tables.

        Args:
            preprocess_config: mapping read for
                ``["preprocessing"]["pitch"|"energy"]["feature"]`` and
                ``["path"]["preprocessed_path"]`` (directory that contains
                ``stats.json``).
            model_config: mapping read for ``["variance_embedding"]``
                (``pitch_quantization``, ``energy_quantization``, ``n_bins``),
                ``["variance_predictor"]["noise_scale"]`` and
                ``["encoder"]["encoder_hidden"]``; it is also handed to the
                predictor constructors.

        Raises:
            AssertionError: if a feature level is not ``"phoneme_level"`` or a
                quantisation mode is neither ``"linear"`` nor ``"log"``.
        """
        super(VarianceAdaptor, self).__init__()
        self.length_regulator = LengthRegulator()
        self.dp = SDP(model_config)
        self.pp = SPP(model_config)
        self.ep = EnergyPredictor(model_config)

        self.pitch_feature_level = preprocess_config["preprocessing"]["pitch"][
            "feature"
        ]
        self.energy_feature_level = preprocess_config["preprocessing"]["energy"][
            "feature"
        ]

        assert self.pitch_feature_level == "phoneme_level"
        assert self.energy_feature_level == "phoneme_level"

        pitch_quantization = model_config["variance_embedding"]["pitch_quantization"]
        energy_quantization = model_config["variance_embedding"]["energy_quantization"]
        n_bins = model_config["variance_embedding"]["n_bins"]
        assert pitch_quantization in ["linear", "log"]
        assert energy_quantization in ["linear", "log"]
        self.noise_scale = model_config['variance_predictor']['noise_scale']

        stats_path = os.path.join(
            preprocess_config["path"]["preprocessed_path"], "stats.json"
        )
        # Explicit encoding so the read does not depend on the host locale.
        with open(stats_path, encoding="utf-8") as f:
            stats = json.load(f)
        # Only the first two entries of each stats list are used (min, max).
        pitch_min, pitch_max = stats["pitch"][:2]
        energy_min, energy_max = stats["energy"][:2]

        # Registration order (pitch bins, energy bins, then the embeddings) is
        # kept as-is so parameter ordering does not change.
        self.pitch_bins = self._build_bins(
            pitch_quantization, pitch_min, pitch_max, n_bins
        )
        self.energy_bins = self._build_bins(
            energy_quantization, energy_min, energy_max, n_bins
        )

        self.pitch_embedding = nn.Embedding(
            n_bins, model_config["encoder"]["encoder_hidden"]
        )
        self.energy_embedding = nn.Embedding(
            n_bins, model_config["encoder"]["encoder_hidden"]
        )

    @staticmethod
    def _build_bins(quantization, value_min, value_max, n_bins):
        """Return ``n_bins - 1`` frozen bucket edges spanning [value_min, value_max].

        ``"log"`` spaces the edges evenly in log-space; any other mode spaces
        them linearly.  The result is an ``nn.Parameter`` with
        ``requires_grad=False``.
        """
        if quantization == "log":
            edges = torch.exp(
                torch.linspace(np.log(value_min), np.log(value_max), n_bins - 1)
            )
        else:
            edges = torch.linspace(value_min, value_max, n_bins - 1)
        return nn.Parameter(edges, requires_grad=False)

    def get_pitch_embedding(self, x, target, mask, control=1.0):
        """Return ``(prediction, embedding)`` for pitch.

        With ``target`` given, the pitch predictor is run in its forward
        direction and ``prediction`` is its output divided by the number of
        ``True`` entries in ``~mask``; the embedding is looked up from
        ``target``.  Without ``target``, pitch is sampled from the predictor in
        reverse mode, multiplied by ``control``, quantised and embedded.

        Args:
            x: hidden states; ``x.size(1)`` is used as the noise length
                (assumes a (B, T, C) layout -- TODO confirm).
            target: ground-truth pitch, or ``None`` to sample.
            mask: padding mask; its inverse is what the predictor receives.
            control: multiplier applied to the sampled pitch.
        """
        # prediction.shape = [B] at training, [B,1,T] as inference.
        if target is not None:
            e_q = torch.randn(x.size(0), 2, x.size(1)).to(device=x.device, dtype=x.dtype)
            prediction = self.pp(x, ~mask, p=target, e_q=e_q)
            prediction = prediction / torch.sum(~mask)
            embedding = self.pitch_embedding(torch.bucketize(target, self.pitch_bins))
        else:
            # The sampling noise is pre-scaled here, so the predictor's own
            # noise_scale argument is left at its default.
            e_q = torch.randn(x.size(0), 3, x.size(1)).to(device=x.device, dtype=x.dtype) * self.noise_scale
            pitch_prediction = self.pp(x, ~mask, e_q=e_q, reverse=True)
            prediction = pitch_prediction * control
            # Drop the singleton axis that the predictor's output carries.
            embedding = self.pitch_embedding(
                torch.bucketize(prediction, self.pitch_bins)
            ).squeeze(1)
        return prediction, embedding

    def get_energy_embedding(self, x, target, mask, control):
        """Return ``(prediction, embedding)`` for energy.

        The energy predictor always runs.  The embedding is looked up from
        ``target`` when it is given, otherwise from the prediction multiplied
        by ``control`` (in which case the scaled prediction is returned).
        """
        prediction = self.ep(x, mask)
        if target is not None:
            embedding = self.energy_embedding(torch.bucketize(target, self.energy_bins))
        else:
            prediction = prediction * control
            embedding = self.energy_embedding(
                torch.bucketize(prediction, self.energy_bins)
            )
        return prediction, embedding

    def forward(
        self,
        x,
        src_mask,
        mel_mask=None,
        max_len=None,
        pitch_target=None,
        energy_target=None,
        duration_target=None,
        p_control=1.0,
        e_control=1.0,
        d_control=1.0,
    ):
        """Add pitch and energy embeddings to ``x`` and expand it by duration.

        When ``duration_target`` is given, the duration predictor runs in its
        forward direction and its output is divided by the number of ``True``
        entries in ``~x_mask``.  Otherwise log-durations are sampled in reverse
        mode, exponentiated, masked, scaled by ``d_control`` and rounded up;
        the rounded values then serve as both ``duration_prediction`` and
        ``duration_target``.

        Returns:
            Tuple ``(output, pitch_prediction, energy_prediction,
            duration_prediction, duration_target, mel_len, mel_mask)``.
            ``mel_mask`` is computed from ``mel_len`` when not supplied.
        """
        x_mask = src_mask.unsqueeze(1)

        pitch_prediction, pitch_embedding = self.get_pitch_embedding(x, pitch_target, x_mask, p_control)

        x = x + pitch_embedding
        energy_prediction, energy_embedding = self.get_energy_embedding(
            x, energy_target, src_mask, e_control
        )
        x = x + energy_embedding

        if duration_target is not None:
            w = duration_target.float()
            duration_prediction = self.dp(x.transpose(-1, -2), ~x_mask, w, g=None)
            duration_prediction = duration_prediction / torch.sum(~x_mask)
        else:
            logw = self.dp(x.transpose(-1, -2), ~x_mask, reverse=True, noise_scale=self.noise_scale)
            w = torch.exp(logw) * ~x_mask * d_control
            w_ceil = torch.ceil(w)
            duration_prediction, duration_target = w_ceil, w_ceil

        output, mel_len = self.length_regulator(x, duration_target, max_len)

        if mel_mask is None:
            mel_mask = get_mask_from_lengths(mel_len)

        return (
            output,
            pitch_prediction,               # nll at training
            energy_prediction,
            duration_prediction,            # nll at training
            duration_target,
            mel_len,
            mel_mask,
        )

    def inference(
        self,
        x,
        src_mask,
        duration=None,
        pitch=None,
        p_control=1.0,
        e_control=1.0,
        d_control=1.0,
    ):
        """Inference path with optional externally supplied pitch/duration.

        ``pitch`` (if given) is scaled by ``p_control``, quantised and
        embedded; ``duration`` (if given) is exponentiated, masked, scaled by
        ``d_control`` and rounded up.  Missing values are sampled from the
        predictors.

        Returns:
            Tuple ``(x, pitch, energy_prediction, duration_rounded, mel_lens,
            mel_masks)``.  The second element is the ``pitch`` argument as it
            was passed in, so it is ``None`` when pitch was sampled.
        """
        x_mask = src_mask.unsqueeze(1)

        if pitch is not None:
            pitch_embedding = self.pitch_embedding(torch.bucketize(pitch * p_control, self.pitch_bins))
        else:
            _, pitch_embedding = self.get_pitch_embedding(x, pitch, x_mask, p_control)
            # Batch item 1 is forced to reuse item 0's sampled pitch embedding.
            # NOTE(review): the purpose of this tie is not visible in this
            # file -- confirm with the caller.  The guard only prevents an
            # IndexError for a batch of one.
            if pitch_embedding.size(0) > 1:
                pitch_embedding[1] = pitch_embedding[0].clone()

        x = x + pitch_embedding
        energy_prediction, energy_embedding = self.get_energy_embedding(
            x, None, src_mask, e_control)
        x = x + energy_embedding

        if duration is not None:
            w = torch.exp(duration) * ~src_mask * d_control
            duration_rounded = torch.ceil(w)
        else:
            # NOTE(review): 0.8 is hard-coded here while forward() passes
            # self.noise_scale -- confirm which is intended.
            logw = self.dp(x.transpose(-1, -2), ~x_mask, reverse=True, noise_scale=0.8).squeeze(1)
            w = torch.exp(logw) * ~src_mask * d_control
            duration_rounded = torch.ceil(w)
            # Same tie as for pitch: item 1 reuses item 0's sampled durations.
            if duration_rounded.size(0) > 1:
                duration_rounded[1] = duration_rounded[0].clone()

        x, mel_lens = self.length_regulator(x, duration_rounded, max_len=None)
        mel_masks = get_mask_from_lengths(mel_lens)

        return (
            x,
            pitch,
            energy_prediction,
            duration_rounded,
            mel_lens,
            mel_masks,
        )

    def inference2(self, x, src_mask, pitch=None, p_control=1.0, e_control=1.0):
        """Add pitch and energy embeddings to ``x`` without length regulation.

        ``pitch`` is passed to ``get_pitch_embedding`` as the target, so
        ``None`` means "sample pitch".  Energy is always predicted.
        """
        x_mask = src_mask.unsqueeze(1)

        _, pitch_embedding = self.get_pitch_embedding(x, pitch, x_mask, p_control)

        x = x + pitch_embedding

        _, energy_embedding = self.get_energy_embedding(
            x, None, src_mask, e_control)
        x = x + energy_embedding

        return x

    def expand(self, x, duration, mask):
        """Length-regulate ``x`` by ``duration``; return ``(x, mel_lens, mel_masks)``.

        ``mask`` is accepted but not used.
        """
        x, mel_lens = self.length_regulator(x, duration, max_len=None)
        mel_masks = get_mask_from_lengths(mel_lens)
        return x, mel_lens, mel_masks

class LengthRegulator(nn.Module):
    """Length regulator: repeats every input frame by its duration."""

    def __init__(self):
        super(LengthRegulator, self).__init__()

    def LR(self, x, duration, max_len):
        """Expand each sequence of the batch and pad the results together.

        Args:
            x: batch of sequences, iterated along dim 0 in lockstep with
                ``duration``.
            duration: per-sequence repeat counts (see ``expand``).
            max_len: forwarded to ``pad`` when it is not ``None``.

        Returns:
            ``(output, mel_len)``: the value returned by ``pad`` and a long
            tensor holding the expanded length of every sequence.
        """
        expanded = [self.expand(seq, dur) for seq, dur in zip(x, duration)]
        mel_len = [seq.shape[0] for seq in expanded]

        if max_len is not None:
            output = pad(expanded, max_len)
        else:
            output = pad(expanded)

        # Keep the lengths on the same device as the data instead of a
        # module-level default, so CPU runs on CUDA-capable hosts stay coherent.
        target_device = expanded[0].device if expanded else None
        return output, torch.tensor(mel_len, dtype=torch.long, device=target_device)

    def expand(self, batch, predicted):
        """Repeat row ``i`` of ``batch`` ``int(predicted[i])`` times.

        Args:
            batch: 2-D tensor, one row per input frame.
            predicted: repeat counts, 1-D or with a leading singleton axis.
                Fractions are truncated toward zero and negative counts are
                treated as zero.  Entries beyond ``batch.size(0)`` are ignored.

        Returns:
            The repeated rows concatenated along dim 0.
        """
        # Only strip a leading batch axis; squeezing a 1-D tensor of length 1
        # would yield a 0-dim tensor and break single-frame sequences.
        if predicted.dim() > 1 and predicted.size(0) == 1:
            predicted = predicted.squeeze(0)

        repeats = predicted[: batch.size(0)].reshape(-1)
        repeats = repeats.to(device=batch.device, dtype=torch.long).clamp(min=0)
        # One vectorised call replaces a per-frame .item()/expand/cat loop.
        return torch.repeat_interleave(batch, repeats, dim=0)

    def forward(self, x, duration, max_len):
        """Run ``LR`` and return ``(output, mel_len)``."""
        output, mel_len = self.LR(x, duration, max_len)
        return output, mel_len

class EnergyPredictor(nn.Module):
    """Predict one energy value per position from detached hidden states."""

    def __init__(self, model_config):
        super().__init__()
        hidden_dim = model_config["variance_predictor"]["filter_size"]
        # A single DN_block stage followed by a projection to one scalar.
        self.blocks = nn.ModuleList([DN_block(model_config)])
        self.proj = nn.Linear(hidden_dim, 1)

    def forward(self, x, x_mask):
        """Return per-position energy with masked positions set to 0.

        Args:
            x: hidden states; detached, so no gradient flows back through
                them.  The last two axes are swapped before the blocks run
                and swapped back before the projection.
            x_mask: boolean mask; positions that are ``True`` are filled with
                ``0.0`` in the output, and the inverted mask (with an extra
                axis at dim 1) is what each block receives.
        """
        keep = ~x_mask.unsqueeze(1)
        hidden = x.detach().transpose(-1, -2)
        for stage in self.blocks:
            hidden = stage(hidden, keep)

        energy = self.proj(hidden.transpose(-1, -2)).squeeze(-1)
        return energy.masked_fill(x_mask, 0.0)

class SDP(nn.Module):
    """Stochastic Duration Predictor.

    A flow-based model over per-position durations conditioned on hidden text
    features.  ``forward`` returns a negative-log-likelihood style bound when
    durations are supplied, and samples log-durations when ``reverse=True``.
    """

    def __init__(self, model_config):
        """Build the conditioning stack, the main flows and the posterior flows.

        Reads ``model_config["encoder"]["encoder_hidden"]``, the ``"SDP"``
        section (``filter_channels``, ``kernel_size``, ``dropout``,
        ``n_flows``, ``gin_channels``, ``density_sample``) and the ``"DPP"``
        section (``num_can``, ``duration_threshold``, ``kappa``).
        """
        super(SDP, self).__init__()
        in_channels = model_config["encoder"]["encoder_hidden"]
        filter_channels = model_config["SDP"]["filter_channels"]
        self.kernel_size = model_config["SDP"]["kernel_size"]
        self.dropout = model_config["SDP"]["dropout"]
        self.n_flows = model_config["SDP"]["n_flows"]
        self.gin_channels = model_config["SDP"]["gin_channels"]
        self.density_sample = model_config["SDP"]["density_sample"]
        self.num_can = model_config["DPP"]["num_can"]
        self.thresh = model_config["DPP"]["duration_threshold"]
        self.kappa = model_config["DPP"]["kappa"]

        self.log_flow = flows.Log()
        self.flows = nn.ModuleList()                # [d-u,v] -> noise
        self.flows.append(flows.ElementwiseAffine(2))
        for _ in range(self.n_flows):
            self.flows.append(flows.ConvFlow(2, filter_channels, self.kernel_size, n_layers=3, attention=True))
            self.flows.append(flows.Flip())

        self.post_pre = nn.Conv1d(1, filter_channels, 1)
        self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.post_convs = DDSConv(filter_channels, self.kernel_size, n_layers=3, p_dropout=self.dropout)
        self.post_flows = nn.ModuleList()           # noise -> [u,v]
        self.post_flows.append(flows.ElementwiseAffine(2))
        for _ in range(4):
            self.post_flows.append(flows.ConvFlow(2, filter_channels, self.kernel_size, n_layers=3, attention=False))
            self.post_flows.append(flows.Flip())

        self.pre = nn.Conv1d(in_channels, filter_channels, 1)
        self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.convs = DDSConv(filter_channels, self.kernel_size, n_layers=3, p_dropout=self.dropout)

        # The speaker-conditioning layer only exists when gin_channels != 0.
        if self.gin_channels != 0:
            self.cond = nn.Conv1d(self.gin_channels, filter_channels, 1)

    def forward(self, x, x_mask, w=None, e_q=None, g=None, reverse=False, noise_scale=1.0):
        """Score durations (``reverse=False``) or sample them (``reverse=True``).

        Args:
            x: hidden text features with channels on dim 1 and time on dim 2;
                detached, so the predictor does not back-propagate into them.
            x_mask: mask multiplied onto features and noise.
            w: durations, one per position; required when ``reverse`` is False.
            e_q: optional noise with 2 channels.  When omitted, standard normal
                noise is drawn on ``x``'s device.
            g: optional global conditioning (detached) added via ``self.cond``.
            reverse: selects sampling instead of likelihood evaluation.
            noise_scale: multiplier for internally sampled noise on the
                reverse path.  Caller-supplied ``e_q`` is used unchanged.

        Returns:
            ``nll + logq`` summed over dims 1 and 2 when ``reverse`` is False;
            otherwise the first channel of the inverted flow output (``logw``).
        """
        # x = h_text, w = duration, g=speaker
        x = torch.detach(x)
        x = self.pre(x)
        if g is not None:
            g = torch.detach(g)
            x = x + self.cond(g)
        x = self.convs(x, x_mask)
        x = self.proj(x) * x_mask

        if e_q is None:
            # Follow x's device rather than forcing .cuda(), which fails on
            # CPU-only hosts and when x is not on the current CUDA device.
            e_q = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
            if reverse:
                # Honour the sampling temperature the callers pass in; the
                # likelihood path needs unit-variance noise for logq.
                e_q = e_q * noise_scale
            e_q = e_q * x_mask

        if not reverse:
            flows = self.flows
            assert w is not None

            w = w.unsqueeze(1)
            h_w = self.post_pre(w)              # Duration preprocessing
            h_w = self.post_convs(h_w, x_mask)
            h_w = self.post_proj(h_w) * x_mask

            z_q = e_q * x_mask
            logdet_tot_q = 0
            for flow in self.post_flows:
                z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
                logdet_tot_q += logdet_q

            z_u, z1 = torch.split(z_q, [1, 1], 1)     # [u,v]
            u = torch.sigmoid(z_u) * x_mask
            z0 = (w - u) * x_mask
            # Log-determinant contribution of the sigmoid applied to z_u.
            logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2])
            logq = torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q ** 2)) * x_mask, [1, 2]) - logdet_tot_q

            logdet_tot = 0
            z0, logdet = self.log_flow(z0, x_mask)
            logdet_tot += logdet
            z = torch.cat([z0, z1], 1)

            for flow in flows:
                z, logdet = flow(z, x_mask, g=x, reverse=reverse)
                logdet_tot = logdet_tot + logdet
            nll = torch.sum(0.5 * (math.log(2 * math.pi) + (z ** 2)) * x_mask, [1, 2]) - logdet_tot
            return nll + logq   # lower bound of duration predictor
        else:
            flows = list(reversed(self.flows))
            # Drop the second-to-last flow of the reversed chain, keeping the
            # final one.
            flows = flows[:-2] + [flows[-1]]
            z = e_q * x_mask
            for flow in flows:
                z = flow(z, x_mask, g=x, reverse=reverse)
            z0, z1 = torch.split(z, [1, 1], 1)
            logw = z0
            return logw

    def density_estimation(self, x, x_mask, logw=None, g=None):
        """Estimate a quality score from the averaged duration log-likelihood.

        x.shape = [b*nc,C,T], x_mask.shape=[b*nc,1,T], output.shape=[b*nc].
        ``logw`` is rounded up with ``torch.ceil`` and masked before being
        scored as the duration; ``g`` is accepted but not used.  The inputs
        are repeated ``self.density_sample`` times, each copy is scored with
        fresh noise, the per-position log-likelihoods are averaged over the
        copies, and the result is passed to ``quality_compute``.
        """
        w = torch.ceil(logw) * x_mask.squeeze(1)
        x_, x_mask_, = x.repeat(self.density_sample, 1, 1), x_mask.repeat(self.density_sample, 1, 1)
        # Noise follows the input's device instead of a hard-coded .cuda().
        z = torch.randn(x_.size(0), 2, x_.size(2)).to(x_.device)
        w_ = w.repeat(self.density_sample, 1)
        log_likelihood = -self.forward(x_, x_mask_, w=w_, e_q=z) / torch.sum(x_mask_, [1, 2])
        quality = torch.sum(log_likelihood.view(self.density_sample, -1), dim=0) / self.density_sample
        quality = quality_compute(quality, threshold=self.thresh, kappa=self.kappa)

        return quality


class SPP(nn.Module):
    """Stochastic Pitch Predictor.

    A flow-based model over per-position pitch conditioned on hidden text
    features.  ``forward`` returns a negative-log-likelihood style bound when
    pitch is supplied, and samples pitch when ``reverse=True``.
    """

    def __init__(self, model_config):
        """Build the conditioning stack, the main flows and the posterior flows.

        Reads ``model_config["encoder"]["encoder_hidden"]``, the ``"SPP"``
        section (``filter_channels``, ``kernel_size``, ``dropout``,
        ``n_flows``, ``gin_channels``, ``density_sample``) and the ``"DPP"``
        section (``num_can``, ``pitch_threshold``, ``kappa``).
        """
        super(SPP, self).__init__()
        in_channels = model_config["encoder"]["encoder_hidden"]
        filter_channels = model_config["SPP"]["filter_channels"]
        self.kernel_size = model_config["SPP"]["kernel_size"]
        self.dropout = model_config["SPP"]["dropout"]
        self.n_flows = model_config["SPP"]["n_flows"]
        self.gin_channels = model_config["SPP"]["gin_channels"]
        self.density_sample = model_config["SPP"]["density_sample"]
        self.num_can = model_config["DPP"]["num_can"]
        self.thresh = model_config["DPP"]["pitch_threshold"]
        self.kappa = model_config["DPP"]["kappa"]

        self.log_flow = flows.Log()
        self.flows = nn.ModuleList()                # [p,v] -> noise
        self.flows.append(flows.ElementwiseAffine(3))
        for _ in range(self.n_flows):
            self.flows.append(flows.ConvFlow(3, filter_channels, self.kernel_size, n_layers=3, attention=True))
            self.flows.append(flows.Flip())

        self.post_pre = nn.Conv1d(1, filter_channels, 1)
        self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.post_convs = DDSConv(filter_channels, self.kernel_size, n_layers=3, p_dropout=self.dropout)
        self.post_flows = nn.ModuleList()           # noise -> [v]
        self.post_flows.append(flows.ElementwiseAffine(2))
        for _ in range(4):
            self.post_flows.append(flows.ConvFlow(2, filter_channels, self.kernel_size, n_layers=3, attention=False))
            self.post_flows.append(flows.Flip())

        self.pre = nn.Conv1d(in_channels, filter_channels, 1)
        self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.convs = DDSConv(filter_channels, self.kernel_size, n_layers=3, p_dropout=self.dropout)

        # The speaker-conditioning layer only exists when gin_channels != 0.
        if self.gin_channels != 0:
            self.cond = nn.Conv1d(self.gin_channels, filter_channels, 1)

    def forward(self, x, x_mask, p=None, g=None, e_q=None, reverse=False, noise_scale=1.0):
        """Score pitch (``reverse=False``) or sample it (``reverse=True``).

        Args:
            x: hidden text features; detached and transposed on the last two
                axes to [B,C,T] before the convolutions.
            x_mask: mask multiplied onto features and noise.
            p: pitch, one value per position; required when ``reverse`` is
                False.
            g: optional global conditioning (detached) added via ``self.cond``.
            e_q: optional noise: 2 channels for the likelihood path (the
                posterior flows), 3 channels for the reverse path (the main
                flows).  When omitted, standard normal noise of the matching
                width is drawn on ``x``'s device.
            reverse: selects sampling instead of likelihood evaluation.
            noise_scale: multiplier for internally sampled noise on the
                reverse path.  Caller-supplied ``e_q`` is used unchanged.

        Returns:
            ``nll + logq`` summed over dims 1 and 2 when ``reverse`` is False;
            otherwise the first channel of the inverted flow output.
        """
        # x = h_text, w = duration(B,1,T), p=pitch(B,T), g=speaker
        x = x.detach().transpose(-1, -2) # [B,C,T]
        x = self.pre(x)
        if g is not None:
            g = torch.detach(g)
            x = x + self.cond(g)

        x = self.convs(x, x_mask)
        x = self.proj(x) * x_mask

        if e_q is None:
            # Without this fallback a missing e_q fails with
            # "NoneType * Tensor".  The likelihood path keeps unit variance
            # because logq below assumes standard normal noise.
            n_channels = 3 if reverse else 2
            e_q = torch.randn(x.size(0), n_channels, x.size(2)).to(device=x.device, dtype=x.dtype)
            if reverse:
                e_q = e_q * noise_scale

        if not reverse:
            flows = self.flows
            assert p is not None

            p = p.unsqueeze(1)
            h_p = self.post_pre(p)              # Pitch preprocessing
            h_p = self.post_convs(h_p, x_mask)
            h_p = self.post_proj(h_p) * x_mask

            e_q = e_q * x_mask
            z_q = e_q
            logdet_tot_q = 0
            for flow in self.post_flows:
                z_q, logdet_q = flow(z_q, x_mask, g=(x + h_p))
                logdet_tot_q += logdet_q

            logq = torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q ** 2)) * x_mask, [1, 2]) - logdet_tot_q

            logdet_tot = 0
            # One pitch channel plus the two posterior channels feed the
            # 3-channel main flows.
            z = torch.cat([p, z_q], 1)

            for flow in flows:
                z, logdet = flow(z, x_mask, g=x, reverse=reverse)
                logdet_tot = logdet_tot + logdet
            nll = torch.sum(0.5 * (math.log(2 * math.pi) + (z ** 2)) * x_mask, [1, 2]) - logdet_tot
            return nll + logq   # lower bound of pitch predictor
        else:
            flows = list(reversed(self.flows))
            # Drop the second-to-last flow of the reversed chain, keeping the
            # final one.
            flows = flows[:-2] + [flows[-1]]
            z = e_q * x_mask
            for flow in flows:
                z = flow(z, x_mask, g=x, reverse=reverse)
            z0, z1 = torch.split(z, [1, 2], 1)
            p = z0
            return p

    def density_estimation(self, x, x_mask, p=None, g=None):
        """Estimate a quality score from the averaged pitch log-likelihood.

        x.shape = [B*nc,T,C], p.shape=[B*nc,T], x_mask.shape=[B*nc,1,T].
        ``g`` is accepted but not used.  The inputs are repeated
        ``self.density_sample`` times, each copy is scored with fresh noise,
        the per-position log-likelihoods are averaged over the copies, and the
        result is passed to ``quality_compute_pitch``.
        """
        x_, x_mask_, p_ = x.repeat(self.density_sample, 1, 1), x_mask.repeat(self.density_sample, 1, 1), p.repeat(self.density_sample, 1)
        # x_ is still [.., T, C] here, so size(1) is the time axis.  Noise
        # follows the input's device instead of a hard-coded .cuda().
        z = torch.randn(x_.size(0), 2, x_.size(1)).to(x_.device)
        log_likelihood = -self.forward(x_, x_mask_, p=p_, e_q=z) / torch.sum(x_mask_, [1, 2])
        quality = torch.sum(log_likelihood.view(self.density_sample, -1), dim=0) / self.density_sample
        quality = quality_compute_pitch(quality, threshold=self.thresh, kappa=self.kappa)

        return quality
