import torch.nn as nn 
from model.subblocks import LayerNorm, Mish
from model.blocks import DDSConv  
from utils.tools import get_mask_from_lengths 
from utils.dpp_tools import dpp_collate    
from utils.dpp_kernel_build import dpp_kernel_build  
import torch.nn.functional as F 
import torch 

# Module-level default device: CUDA when a GPU is visible to torch, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class PDPP_model(nn.Module):
    """DPP model for pitch.

    Generates ``num_can`` candidate pitch sequences per sample. Each candidate
    is produced by driving the external ``spp`` module in reverse with learned,
    input-conditioned noise. Candidates are scored with
    ``spp.density_estimation`` and the result is assembled into a DPP kernel
    via ``dpp_collate`` / ``dpp_kernel_build``.
    """

    # Channel count of the noise tensor ``e_q`` fed to ``spp``; must match what
    # ``spp`` expects (the original code hard-coded 3 everywhere).
    _noise_channels = 3

    def __init__(self, model_config):
        """Build the conditioning and noise-shaping sub-networks.

        Args:
            model_config: nested dict read for
                ``["encoder"]["encoder_hidden"]`` (input channels),
                ``["SPP"]["filter_channels"]``, ``["SPP"]["kernel_size"]``,
                ``["DPP"]["num_can"]`` (candidates per sample),
                ``["DPP"]["dropout"]`` and ``["DPP"]["sweight_pitch"]``
                (forwarded to ``dpp_kernel_build``).
        """
        super().__init__()
        in_channels = model_config["encoder"]["encoder_hidden"]
        filter_channels = model_config["SPP"]["filter_channels"]

        self.nc = model_config["DPP"]["num_can"]
        self.kernel_size = model_config["SPP"]["kernel_size"]
        self.dropout = model_config["DPP"]["dropout"]
        self.sweight = model_config["DPP"]["sweight_pitch"]

        # Conditioning branch: hidden sequence -> filter_channels features.
        self.pre = nn.Conv1d(in_channels, filter_channels, 1)
        self.convs = DDSConv(filter_channels, self.kernel_size, n_layers=3, p_dropout=self.dropout)
        self.proj = nn.Conv1d(filter_channels, filter_channels, 1)

        # Noise branch: random noise -> shaped noise, conditioned on the branch above.
        self.pre_z = nn.Conv1d(self._noise_channels, filter_channels, 1)
        self.convs_z = DDSConv(filter_channels, self.kernel_size, n_layers=3, p_dropout=self.dropout)
        self.proj_z = nn.Conv1d(filter_channels, self._noise_channels, 1)
        self.norm = LayerNorm(filter_channels)
        self.norm2 = LayerNorm(filter_channels)
        self.mish = Mish()

    def forward(
        self,
        chunk_ids,
        np_ids,
        lcw_ids,
        rcw_ids,
        np_len,
        lcw_len,
        rcw_len,
        chunk_mask,
        np_mask,
        lcw_mask,
        rcw_mask,
        spp,
        h_seq,
        h_mask,
        noise_scale=0.8,
    ):
        """Sample pitch candidates for each target span and build the DPP kernel.

        Args:
            chunk_ids: [b, T_c] time indices into ``h_seq`` / the inferred pitch
                that form each chunk (left context + target + right context).
            np_ids: unused here (kept for interface compatibility).
            lcw_ids: [b, T_l] time indices of the left context window.
            rcw_ids: [b, T_r] time indices of the right context window.
            np_len: per-sample target length inside the chunk.
            lcw_len: per-sample left-context length, i.e. the target's offset
                inside the chunk.
            rcw_len: per-sample right-context length (forwarded to ``dpp_collate``).
            chunk_mask: [b, T_c] mask; it is inverted with ``~`` before use
                (presumably True marks padding -- confirm against caller).
            np_mask, lcw_mask, rcw_mask: unused here.
            spp: pitch module called as ``spp(x=, x_mask=, e_q=, reverse=True)``
                and exposing ``density_estimation(x=, x_mask=, p=)``.
            h_seq: [b, t, h] hidden sequence.
            h_mask: mask for ``h_seq``; inverted and unsqueezed before use.
            noise_scale: multiplier applied to the Gaussian noise samples.

        Returns:
            Tuple ``(kernel, kernel_mask, pitch_vector, pitch_seq, h_seq, h_mask)``.
            ``kernel`` and ``kernel_mask`` come from ``dpp_kernel_build``,
            ``pitch_vector`` from ``dpp_collate``, and ``pitch_seq`` is the
            initial full-sequence pitch inference with dim 1 squeezed.
        """
        b = h_seq.size(0)

        # Gather the chunk hidden sequences (context + target).
        chunk_h_idxs = chunk_ids.unsqueeze(-1).expand(-1, -1, h_seq.size(-1))
        chunk_h_seq = torch.gather(h_seq, 1, chunk_h_idxs)  # [b, T_c, h]
        T_c = chunk_h_seq.size(1)

        # Pitch inference (1): one pass over the whole sequence. Noise is drawn
        # and then moved with .to() as before, so seeded runs draw the same values.
        z = torch.randn(b, self._noise_channels, h_seq.size(1)).to(
            device=h_seq.device, dtype=h_seq.dtype
        ) * noise_scale
        pitch_seq = spp(x=h_seq, x_mask=~h_mask.unsqueeze(1), e_q=z, reverse=True)

        # Chunk and context pitch sequences taken from that inference.
        chunk_p_seq = torch.gather(pitch_seq, dim=2, index=chunk_ids.unsqueeze(1))  # [b, 1, T_c]
        l_p_seq = torch.gather(pitch_seq, dim=2, index=lcw_ids.unsqueeze(1))  # [b, 1, T_l]
        r_p_seq = torch.gather(pitch_seq, dim=2, index=rcw_ids.unsqueeze(1))  # [b, 1, T_r]

        # The context windows get a fixed quality score of 10.0. Allocate it on
        # the input's device rather than hard-coding .cuda(), which fails on CPU.
        quality = torch.full((b, 1), 10.0, device=h_seq.device)
        lcw_quality, rcw_quality = quality, quality  # [b, 1]

        # ---------------------------- PDM + inference (2) ----------------------------
        # Tensor.repeat tiles the whole batch nc times, so row k*b + i is
        # candidate k of sample i (candidate-major layout).
        chunk_h_seq = chunk_h_seq.repeat(self.nc, 1, 1)  # [nc*b, T_c, h]
        chunk_mask = chunk_mask.repeat(self.nc, 1).unsqueeze(1)  # [nc*b, 1, T_c]
        inv_chunk_mask = ~chunk_mask
        e_q = torch.randn(b * self.nc, self._noise_channels, T_c).to(
            device=chunk_h_seq.device, dtype=chunk_h_seq.dtype
        ) * noise_scale

        # Conditioning features from the chunk hidden states.
        x = chunk_h_seq.transpose(-1, -2)  # [nc*b, h, T_c]
        x = self.pre(x)
        x = self.convs(x, inv_chunk_mask)
        x = self.norm(x)
        x = self.proj(x) * inv_chunk_mask

        # Shape the raw noise, conditioned on x.
        z = self.pre_z(e_q)
        z = self.convs_z(z, inv_chunk_mask, g=x)
        z = self.norm2(z)
        z = self.proj_z(z) * inv_chunk_mask
        z = self.mish(z)

        target_p_seq = spp(x=chunk_h_seq, x_mask=inv_chunk_mask, e_q=z, reverse=True)
        target_p_seq = target_p_seq.view(
            self.nc, -1, target_p_seq.size(-2), target_p_seq.size(-1)
        )  # [nc, b, 1, T_c]
        # -----------------------------------------------------------------------------

        # Replace the target span of each chunk with the new candidates.
        chunk_p_seq = chunk_p_seq.unsqueeze(0).expand(self.nc, -1, -1, -1).clone()  # [nc, b, 1, T_c]
        for i in range(b):
            offset, length = lcw_len[i], np_len[i]
            chunk_p_seq[:, i, :, offset:offset + length] = target_p_seq[:, i, :, offset:offset + length]
        chunk_p_seq = chunk_p_seq.view(-1, T_c)  # [nc*b, T_c], same candidate-major order

        # Quality scores for the new pitches, one per (candidate, sample) row.
        target_quality = spp.density_estimation(x=chunk_h_seq, x_mask=inv_chunk_mask, p=chunk_p_seq)
        # The rows are candidate-major (k*b + i). Reshape to [nc, b] and then
        # transpose to [b, nc]; .view(-1, nc) would mix samples whenever b > 1.
        target_quality = target_quality.reshape(self.nc, b).transpose(0, 1).contiguous()  # [b, nc]

        # Extract each sample's target span and right-pad to the longest one.
        max_len = max(np_len)
        target_p_seq_ = torch.stack(
            [
                F.pad(
                    target_p_seq[:, i, :, lcw_len[i]:lcw_len[i] + np_len[i]],
                    (0, max_len - np_len[i]),
                )
                for i in range(b)
            ],
            dim=1,
        )  # [nc, b, 1, T_t]

        # Build the pitch and quality vectors, then the DPP kernel.
        pitch_vector, quality_vector, d_len, mask_idxs = dpp_collate(
            l_p_seq, r_p_seq, target_p_seq_, lcw_quality,
            rcw_quality, target_quality, lcw_len, rcw_len, np_len, self.nc,
        )

        kernel, kernel_mask = dpp_kernel_build(quality_vector, pitch_vector, d_len, mask_idxs, self.sweight)

        return kernel, kernel_mask, pitch_vector, pitch_seq.squeeze(1), h_seq, h_mask