import torch.nn as nn 
from model.blocks import DDSConv    
from model.subblocks import Mish, LayerNorm 
from utils.tools import get_mask_from_lengths 
from utils.dpp_tools import dpp_collate, log2exp   
from utils.dpp_kernel_build import dpp_kernel_build  
import torch.nn.functional as F 
import torch 

# Module-level default device: CUDA when a GPU is visible, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

class DPP_model(nn.Module):
    """DPP model for duration.

    Samples ``nc`` candidate duration sequences for the target span of a
    chunk (left context + target + right context) and scores each candidate
    with ``sdp.density_estimation``. It then builds a DPP kernel via
    ``dpp_kernel_build`` over the left/right context durations plus the
    ``nc`` candidates.
    """

    def __init__(self, model_config):
        """Build the candidate-noise network.

        Args:
            model_config: nested config mapping. Reads
                ``encoder.encoder_hidden``, ``SDP.filter_channels``,
                ``SDP.kernel_size``, ``DPP.num_can``, ``DPP.dropout`` and
                ``DPP.sweight_dur``.
        """
        super().__init__()
        in_channels = model_config["encoder"]["encoder_hidden"]
        filter_channels = model_config["SDP"]["filter_channels"]

        self.nc = model_config['DPP']['num_can']  # number of candidates per item
        self.kernel_size = model_config["SDP"]["kernel_size"]
        self.dropout = model_config["DPP"]["dropout"]
        self.sweight = model_config["DPP"]["sweight_dur"]  # forwarded to dpp_kernel_build

        # Conditioning branch: encodes the chunk hidden states.
        self.pre = nn.Conv1d(in_channels, filter_channels, 1)
        self.convs = DDSConv(filter_channels, self.kernel_size, n_layers=3, p_dropout= self.dropout)
        self.proj = nn.Conv1d(filter_channels, filter_channels, 1)

        # Noise branch: maps 2-channel Gaussian noise to the 2-channel e_q
        # handed to sdp, conditioned on the branch above via g=.
        self.pre_z = nn.Conv1d(2, filter_channels, 1)
        self.convs_z = DDSConv(filter_channels, self.kernel_size, n_layers=3, p_dropout= self.dropout)
        self.proj_z = nn.Conv1d(filter_channels, 2, 1)
        self.mish = Mish()
        self.norm = LayerNorm(filter_channels)
        self.norm2 = LayerNorm(filter_channels)

    def forward(
        self,
        chunk_ids,
        np_ids,
        lcw_ids,
        rcw_ids,
        np_len,
        lcw_len,
        rcw_len,
        chunk_mask,
        np_mask,
        lcw_mask,
        rcw_mask,
        sdp,
        h_seq,
        h_mask,
        noise_scale = 0.8
    ):
        """Sample duration candidates for the target span and build a DPP kernel.

        Shapes come from the in-code shape comments. They are not checked
        here, so confirm them against the caller.

        Args:
            chunk_ids: [b, T_c] time indices into ``h_seq`` for the chunk
                (left context + target + right context).
            np_ids: unused here.
            lcw_ids, rcw_ids: [b, T_l] / [b, T_r] time indices of the left
                and right context windows.
            np_len: per-item target lengths. Item ``i``'s target occupies
                chunk positions ``lcw_len[i]:lcw_len[i] + np_len[i]``.
            lcw_len, rcw_len: per-item left / right context lengths.
                A tensor or a sequence of ints is accepted.
            chunk_mask: [b, T_c] mask. It is inverted (``~``) before use as
                ``x_mask``, so it is presumably True on padding.
            np_mask, lcw_mask, rcw_mask: unused here.
            sdp: duration predictor. It is called with ``reverse=True`` to
                sample and through ``density_estimation`` to score.
            h_seq: [b, t, h] hidden sequence.
            h_mask: [b, t] mask, inverted before use like ``chunk_mask``.
            noise_scale: multiplier on the Gaussian noise fed to ``sdp``.

        Returns:
            ``(kernel, kernel_mask, duration_vector, duration_seq.squeeze(1),
            h_seq, h_mask)``. ``h_seq`` and ``h_mask`` are the inputs passed
            through unchanged.
        """
        # Gather chunk hidden states (context + target).
        chunk_h_idxs = chunk_ids.unsqueeze(-1).expand(-1, -1, h_seq.size(-1))
        chunk_h_seq = torch.gather(h_seq, 1, chunk_h_idxs)  # shape = [b, T_c, h]
        b = h_seq.size(0)
        T_c = chunk_h_seq.size(1)

        # Python-int copies of the span lengths. Slicing and padding with
        # them avoids per-element device syncs and gives F.pad plain ints.
        lcw_len_list = self._as_int_list(lcw_len)
        np_len_list = self._as_int_list(np_len)

        # Duration inference (1): sample durations for the whole sequence.
        # Noise is drawn on the default device, then moved to h_seq's device/dtype.
        z = torch.randn(h_seq.size(0), 2, h_seq.size(1)).to(device=h_seq.device, dtype=h_seq.dtype) * noise_scale
        duration_seq = sdp(x=h_seq.transpose(-1, -2), x_mask=~h_mask.unsqueeze(1), e_q=z, reverse=True)
        duration_seq = log2exp(duration_seq, ~h_mask.unsqueeze(1))  # logw -> w

        # Durations of the chunk and of the left/right context windows.
        chunk_d_seq = torch.gather(duration_seq, dim=2, index=chunk_ids.unsqueeze(1))  # shape = [b, 1, T_c]
        l_d_seq = torch.gather(duration_seq, dim=2, index=lcw_ids.unsqueeze(1))  # shape = [b, 1, T_l]
        r_d_seq = torch.gather(duration_seq, dim=2, index=rcw_ids.unsqueeze(1))  # shape = [b, 1, T_r]

        # Both context windows get a fixed quality score of 10.0. The tensor
        # is allocated on h_seq's device so the module also runs on CPU.
        quality = torch.full((b, 1), 10.0, device=h_seq.device)
        lcw_quality, rcw_quality = quality, quality  # shape = [b, 1]

        ################################## PDM + inference (2) ##################################
        # repeat() tiles the whole batch nc times: row c*b + i is candidate c of item i.
        chunk_h_seq = chunk_h_seq.repeat(self.nc, 1, 1)  # shape = [nc*b, T_c, h]
        chunk_mask = chunk_mask.repeat(self.nc, 1).unsqueeze(1)  # shape = [nc*b, 1, T_c]
        x_mask = ~chunk_mask
        e_q = torch.randn(b * self.nc, 2, T_c).to(device=chunk_h_seq.device, dtype=chunk_h_seq.dtype) * noise_scale

        x = chunk_h_seq.transpose(-1, -2)  # shape = [nc*b, h, T_c]
        x = self.pre(x)
        x = self.convs(x, x_mask)
        x = self.norm(x)
        x = self.proj(x) * x_mask

        z = self.pre_z(e_q)
        z = self.convs_z(z, x_mask, g=x)
        z = self.norm2(z)
        z = self.proj_z(z) * x_mask
        z = self.mish(z)

        target_d_seq = sdp(x=chunk_h_seq.transpose(-1, -2), x_mask=x_mask, e_q=z, reverse=True)
        target_d_seq = log2exp(target_d_seq, x_mask)  # logw -> w
        target_d_seq = target_d_seq.view(self.nc, -1, target_d_seq.size(-2), target_d_seq.size(-1))  # shape = [nc, b, 1, T_c]
        #########################################################################################

        # Replace each item's target span with the new candidates (context spans keep the step-(1) durations).
        chunk_d_seq = chunk_d_seq.unsqueeze(0).expand(self.nc, -1, -1, -1).clone()  # shape = [nc, b, 1, T_c]
        for i in range(b):
            offset, length = lcw_len_list[i], np_len_list[i]
            chunk_d_seq[:, i, :, offset:offset + length] = target_d_seq[:, i, :, offset:offset + length]
        chunk_d_seq = chunk_d_seq.view(-1, T_c)  # shape = [nc*b, T_c], same row order as chunk_h_seq

        # Quality scores for the new durations.
        target_quality = sdp.density_estimation(x=chunk_h_seq.transpose(-1, -2),
                                                x_mask=x_mask, logw=chunk_d_seq)
        # Rows are candidate-major (c*b + i), so regroup as [nc, b] and then
        # transpose. A plain view(-1, nc) would mix items and candidates when
        # b > 1 and nc > 1. This assumes density_estimation returns one score
        # per input row, in input order (TODO confirm).
        target_quality = self._candidate_major_to_batch_major(target_quality, self.nc)  # shape = [b, nc]

        # Extract each item's target span and right-pad to the longest one.
        max_len = max(np_len_list)
        target_d_seq_ = torch.stack(
            [
                F.pad(target_d_seq[:, i, :, lcw_len_list[i]:lcw_len_list[i] + np_len_list[i]],
                      (0, max_len - np_len_list[i]))
                for i in range(b)
            ],
            dim=1,
        )  # shape = [nc, b, 1, T_t]

        # Build the duration and quality vectors; shape = [b, 2+nc, t, 1], [b, 2+nc].
        duration_vector, quality_vector, d_len, mask_idxs = dpp_collate(l_d_seq, r_d_seq, target_d_seq_, lcw_quality,
                                                                        rcw_quality, target_quality, lcw_len, rcw_len, np_len, self.nc)

        kernel, kernel_mask = dpp_kernel_build(quality_vector, duration_vector, d_len, mask_idxs, self.sweight)  # presumably [b, 2+nc, 2+nc]

        return kernel, kernel_mask, duration_vector, duration_seq.squeeze(1), h_seq, h_mask

    @staticmethod
    def _as_int_list(lengths):
        """Return *lengths* (a 1-D tensor or a sequence of ints) as a list of Python ints."""
        if torch.is_tensor(lengths):
            return [int(v) for v in lengths.tolist()]
        return [int(v) for v in lengths]

    @staticmethod
    def _candidate_major_to_batch_major(scores, nc):
        """Regroup per-row scores laid out candidate-major (row c*b + i) into shape [b, nc]."""
        return scores.reshape(nc, -1).transpose(0, 1).contiguous()