import torch
import numpy as np
import torch.nn as nn
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.distributions import Normal,Independent,kl_divergence
import torch.nn.functional as F
import os
import math

class DepthwiseSeparableConv(nn.Module):
    """Depthwise-separable 2-D convolution: a per-channel spatial conv followed by a 1x1 mix.

    Args:
        in_channels: channels of the input; the depthwise stage keeps this count.
        out_channels: channels produced by the 1x1 pointwise stage.
        kernel_size, stride, padding: forwarded to the depthwise ``nn.Conv2d``.
        bias: whether both convolutions carry a bias term.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True):
        super().__init__()
        # groups == in_channels gives every input channel its own spatial filter
        self.depthwise = nn.Conv2d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=in_channels,
            bias=bias,
        )
        # 1x1 conv that mixes channels and sets the output width
        self.pointwise = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            bias=bias,
        )

    def forward(self, x):
        """Apply the depthwise stage, then the pointwise stage."""
        return self.pointwise(self.depthwise(x))

class PreActivationBlock(nn.Module):
    """Pre-activation residual block: two (BatchNorm -> ReLU -> separable conv) stages plus identity skip.

    Both convolutions use 3x3 kernels with padding 1 and keep ``channels`` channels,
    so the input can be added back to the result directly.
    """

    def __init__(self, channels, bias=False):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv1 = DepthwiseSeparableConv(channels, channels, 3, padding=1, bias=bias)
        self.bn2 = nn.BatchNorm2d(channels)
        self.conv2 = DepthwiseSeparableConv(channels, channels, 3, padding=1, bias=bias)

    def forward(self, x):
        residual = x
        # First stage uses (bn1, conv1), second stage uses (bn2, conv2)
        for norm, conv in ((self.bn1, self.conv1), (self.bn2, self.conv2)):
            residual = conv(F.relu(norm(residual)))
        return x + residual

class ResNet(nn.Module):
    """Separable-conv stem followed by a stack of pre-activation residual blocks.

    Args:
        in_channels: channels of the input tensor.
        hidden_channels: channels produced by the stem and kept through every block.
        bias: whether the separable convolutions carry a bias term.
        num_blocks: number of residual blocks; defaults to 9, the previously fixed depth.
    """

    def __init__(self, in_channels, hidden_channels=256, bias=False, num_blocks=9):
        super().__init__()
        self.initial_conv = DepthwiseSeparableConv(in_channels, hidden_channels, 3, padding=1, bias=bias)

        # Residual trunk; state_dict keys stay "blocks.0" .. "blocks.{num_blocks-1}"
        self.blocks = nn.ModuleList(
            [PreActivationBlock(hidden_channels, bias=bias) for _ in range(num_blocks)]
        )

    def forward(self, x):
        x = self.initial_conv(x)
        for block in self.blocks:
            x = block(x)
        return x
class conv_encoder(nn.Module):
    """Encode a masked image into channels-last latent Gaussian parameters.

    The observed pixels ``y * mask`` are concatenated with the mask. The result passes
    through a separable conv whose weights are forced non-negative, then a ResNet trunk,
    and finally a 1x1 conv emitting ``2 * latent_channels`` maps, which are split into
    ``mu`` and ``log_var``.
    """

    def __init__(self, img_channels=1, latent_channels=128, hidden_channels=128):
        super().__init__()
        self.img_channels = img_channels
        self.latent_channels = latent_channels
        self.hidden_channels = hidden_channels
        # +1 input channel for the mask appended in forward(); assumes the mask has a
        # single channel — TODO confirm with callers
        self.conv_theta = self.create_abs_conv(
            img_channels + 1, hidden_channels, kernel_size=11, padding=5
        )
        self.pre_latent_cnn = ResNet(hidden_channels, hidden_channels)
        # 1x1 conv producing mean and log-variance maps side by side along channels
        self.latent_mlp = nn.Conv2d(hidden_channels, latent_channels * 2, kernel_size=1)

    def create_abs_conv(self, in_channels, out_channels, kernel_size, padding):
        """Return the bias-free separable conv whose weights forward() keeps non-negative."""
        return DepthwiseSeparableConv(
            in_channels, out_channels, kernel_size, padding=padding, bias=False
        )

    def forward(self, y, mask):
        """Encode ``y`` under ``mask``.

        Returns:
            (mu, log_var): the two halves of the 1x1 conv output, permuted channels-last,
            i.e. (batch, H, W, latent_channels) for NCHW input.
        """
        # Replace the filter weights with their absolute values outside autograd
        with torch.no_grad():
            for layer in (self.conv_theta.depthwise, self.conv_theta.pointwise):
                layer.weight.data = torch.abs(layer.weight.data)
        features = torch.cat([y * mask, mask], dim=1)
        features = self.pre_latent_cnn(self.conv_theta(features))
        params = self.latent_mlp(features).permute(0, 2, 3, 1)
        mu, log_var = params.chunk(2, dim=-1)
        return mu, log_var

class conv_decoder(nn.Module):
    """Decode a channels-first latent map into per-pixel output parameters.

    When ``is_global`` is set, the first ``local_channels`` latent channels are kept
    as-is (local). The remaining channels are averaged over the two spatial dimensions
    and broadcast back (global). A ResNet trunk and a 1x1 output conv follow.

    Args:
        img_channels: output image channels; ``2 * img_channels`` values are emitted per pixel.
        latent_channels: channels of the incoming latent map.
        hidden_channels: width of the ResNet trunk.
        is_global: enable the local/global split described above.
        local_channels: leading latent channels kept local; defaults to 64, the value
            that used to be hard-coded.
    """

    def __init__(self, img_channels=1, latent_channels=128, hidden_channels=128, is_global=True,
                 local_channels=64):
        super().__init__()
        self.is_global = is_global
        self.img_channels = img_channels
        self.hidden_channels = hidden_channels
        self.latent_channels = latent_channels
        self.local_channels = local_channels
        self.post_latent_cnn = ResNet(latent_channels, hidden_channels)
        self.output_layer = nn.Conv2d(hidden_channels, img_channels * 2, 1)

    def forward(self, z):
        """Decode ``z`` (batch, latent_channels, H, W).

        Returns:
            Tensor of shape (batch, H * W, 2 * img_channels), channels-last per pixel.
        """
        if self.is_global:
            z_local = z[:, :self.local_channels]
            # Take *every* remaining channel as global so the concatenation below keeps
            # latent_channels channels for any latent width (a fixed upper bound of 128
            # silently dropped channels and broke the ResNet stem for wider latents).
            z_global = z[:, self.local_channels:]
            z_global = z_global.mean(dim=(-2, -1), keepdim=True).expand_as(z_global)
            z = torch.cat([z_local, z_global], dim=1)
        z_processed = self.post_latent_cnn(z)
        output = self.output_layer(z_processed).permute(0, 2, 3, 1)
        # reshape (not view) so the permuted, non-contiguous tensor is always accepted
        return output.reshape(z.size(0), -1, 2 * self.img_channels)
class conv_net(nn.Module):
    """Convolutional latent-variable model trained with particle-based objectives.

    ``conv_encoder`` maps the masked image to a per-pixel diagonal Gaussian over latents.
    Latent particles drawn from it are decoded by ``conv_decoder`` into a per-pixel
    Gaussian over the image. ``forward`` returns the loss selected by ``args.type``:
    'IWNPs', 'TDRO-NPs', 'CVaR-NPs', 'GDRO-NPs' or 'OS-NPs'.
    """

    def __init__(self, args):
        super().__init__()
        self.x_dim = args.x_dim
        self.y_dim = args.y_dim
        self.z_dim = args.z_conv
        self.h1 = args.hidden_dim1
        self.h2 = args.hidden_dim2
        self.type = args.type
        self.num_particles = args.num_particles
        self.cp_samples = args.cp_samples
        self.alpha = args.alpha
        self.eta = args.eta
        self.eval_mode = args.eval_mode
        self.is_global = args.is_global
        # Not read anywhere in this class; presumably used by external training code — verify.
        self.his_nll = None
        self.delta = None
        self.w_z = 0.
        self.boxes = args.boxes
        self.encoder = conv_encoder(img_channels=self.y_dim, latent_channels=self.z_dim, hidden_channels=self.h1)
        self.decoder = conv_decoder(img_channels=self.y_dim, latent_channels=self.z_dim, hidden_channels=self.h2, is_global=self.is_global)
        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal (fan_out, ReLU) init for every Conv2d weight; zero every Conv2d bias."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    @staticmethod
    def _logvar_to_std(logvar):
        """Map the encoder's ``logvar`` head to a standard deviation in (1e-3, 1).

        Despite the name, the value is squashed through ``sigmoid(-logvar)`` rather
        than exponentiated.
        """
        return (1 - 1e-3) * torch.sigmoid(-logvar) + 1e-3

    def reparameterization(self, mu, logvar):
        """Draw one reparameterized latent sample and move channels first.

        ``mu``/``logvar`` are channels-last (B, H, W, C); the sample is returned as (B, C, H, W).
        """
        std = self._logvar_to_std(logvar)
        eps = torch.randn_like(std)
        return (mu + eps * std).permute(0, 3, 1, 2)

    def prior_encoder(self, y, mask):
        """Encode ``(y, mask)`` into a latent distribution.

        Returns:
            (prior_dist, mu, logvar). ``prior_dist`` is a Normal whose last (channel)
            dimension is treated as the event dimension.
        """
        mu, logvar = self.encoder(y, mask)
        prior_dist = Independent(Normal(mu, self._logvar_to_std(logvar)), 1)
        return prior_dist, mu, logvar

    def prior_sampler(self, prior_dist, num_particles):
        """Draw ``num_particles`` reparameterized samples laid out as (B, P, C, H, W).

        The raw draw is (P, B, H, W, C) for a channels-last (B, H, W, C) distribution.
        """
        sampled_batch_z = prior_dist.rsample((num_particles,))
        return sampled_batch_z.permute((1, 0, 4, 2, 3))

    def _decode(self, batch_z):
        """Decode latent samples (B, S, C, H, W) into Gaussian parameters.

        Returns:
            (mu_tensor, sigma_tensor), each (B, S, H*W, y_dim). The mean is passed
            through a sigmoid and the scale is floored at 0.01.
        """
        flat_z = batch_z.reshape(-1, *batch_z.size()[-3:])
        output = self.decoder(flat_z)
        output = output.view(batch_z.size(0), -1, *output.size()[1:])
        mu_tensor = torch.sigmoid(output[..., :self.y_dim])
        sigma_tensor = 0.01 + 0.99 * F.softplus(output[..., self.y_dim:])
        return mu_tensor, sigma_tensor

    def cond_evidence_loglike_single(self, batch_z, y_t):
        """Per-pixel log-likelihood of ``y_t`` under each latent particle.

        Args:
            batch_z: particles (B, P, C, H, W).
            y_t: target image; presumably (B, y_dim, H, W) — it is flattened to (B, y_dim, H*W).
        Returns:
            (decoder_dist, log_y_t) where log_y_t has shape (B, H*W, P).
        """
        mu_tensor, sigma_tensor = self._decode(batch_z)
        decoder_dist = MultivariateNormal(mu_tensor, scale_tril=sigma_tensor.diag_embed())
        y_flat = y_t.view(batch_z.size(0), y_t.size(1), -1).permute(0, 2, 1)
        y_exp = y_flat.unsqueeze(1).expand(-1, batch_z.size(1), -1, -1)
        log_y_t = decoder_dist.log_prob(y_exp)
        return decoder_dist, log_y_t.permute(0, 2, 1)

    def _particle_loglike(self, prior_dist, y, num_particles):
        """Sample ``num_particles`` latents; return their image log-likelihoods, shape (B, P)."""
        batch_z = self.prior_sampler(prior_dist, num_particles)
        _, log_y_t = self.cond_evidence_loglike_single(batch_z, y)
        return torch.sum(log_y_t, dim=-2)

    def update_weights(self, group_weights, group_losses):
        """Exponentiated-gradient reweighting: w <- w * exp(eta * loss), renormalized on the last dim.

        The per-row maximum is subtracted before ``exp`` for numerical stability.
        ``group_weights`` is not modified; a new tensor is returned.
        """
        scaled = self.eta * group_losses
        scaled = scaled - scaled.max(dim=-1, keepdim=True)[0]
        weights = group_weights * torch.exp(scaled)
        return weights / torch.sum(weights, dim=-1, keepdim=True)

    def forward(self, y, mask):
        """Compute the training loss selected by ``self.type``.

        Returns:
            A (B,) tensor of negative objectives, except for 'OS-NPs', which returns a tuple.
            That tuple is ``(neg_loglik, log_m)``, where ``log_m`` (B, boxes) holds the mean
            log-likelihood of each consecutive group of the sorted particles.
        Raises:
            ValueError: unknown ``self.type``, or 'OS-NPs' with more boxes than particles.
        """
        prior_dist, _, _ = self.prior_encoder(y, mask)
        if self.type == 'IWNPs':
            log_y_t = self._particle_loglike(prior_dist, y, self.num_particles)
            return -(torch.logsumexp(log_y_t, dim=-1) - math.log(self.num_particles))
        elif self.type == 'TDRO-NPs':
            T = 10
            log_y_t = self._particle_loglike(prior_dist, y, self.num_particles)
            # NOTE(review): log(num_particles) is not multiplied by T, so this differs from
            # T * log-mean-exp(log_y_t / T) by the constant (T - 1) * log(num_particles);
            # gradients are unaffected — confirm which normalisation is intended.
            return -(T * torch.logsumexp(log_y_t / T, dim=-1) - math.log(self.num_particles))
        elif self.type == 'CVaR-NPs':
            # Oversample, then keep the num_particles lowest-likelihood particles
            n_draw = int(self.num_particles / (1 - self.alpha))
            log_y_t = self._particle_loglike(prior_dist, y, n_draw)
            worst, _ = torch.topk(log_y_t, k=self.num_particles, dim=-1, largest=False, sorted=False)
            return -(torch.logsumexp(worst, dim=-1) - math.log(self.num_particles))
        elif self.type == 'GDRO-NPs':
            log_y_t = self._particle_loglike(prior_dist, y, self.num_particles)
            weights = self.update_weights(torch.ones_like(log_y_t), -log_y_t)
            # Weights act as fixed importance coefficients: no gradient flows through them
            return -torch.logsumexp(log_y_t + torch.log(weights).detach(), dim=-1)
        elif self.type == 'OS-NPs':
            log_y_t = self._particle_loglike(prior_dist, y, self.num_particles)
            h = self.boxes
            n = log_y_t.size(1) // h
            if n == 0:
                raise ValueError('boxes must not exceed the number of particles for OS-NPs')
            sorted_ll = torch.sort(log_y_t, dim=-1)[0]
            # Mean of each of the h consecutive groups of n sorted particles; any remainder
            # particles beyond h * n are excluded (same as the per-box loop this replaces).
            log_m = sorted_ll[:, :h * n].reshape(log_y_t.size(0), h, n).mean(dim=-1)
            clnp = torch.logsumexp(sorted_ll, dim=-1) - math.log(self.num_particles)
            return -clnp, log_m
        else:
            raise ValueError('Invalid argument for self.type')

    def complete(self, y, mask, z_samples=1):
        """Predict per-pixel Gaussian parameters from the context ``(y, mask)`` without gradients.

        Returns:
            (mu_tensor, sigma_tensor), each (B, z_samples, H*W, y_dim).
        """
        with torch.no_grad():
            enc_dist, mu, logvar = self.prior_encoder(y, mask)
            if z_samples == 1:
                batch_z = self.reparameterization(mu, logvar).unsqueeze(1)
            else:
                # (B, S, C, H, W): channels first, as the decoder expects
                batch_z = self.prior_sampler(enc_dist, z_samples)
            return self._decode(batch_z)

    def conditional_predict(self, y, mask, mask_o=None):
        """Evaluate the masked negative log-likelihood of ``y`` using ``self.cp_samples`` latent samples.

        Args:
            y: target image, flattened per sample to (B, y_dim, H*W).
            mask: context mask passed to the encoder.
            mask_o: per-pixel weights on the log-likelihood; defaults to all ones.
        Returns:
            (mu, logvar, b_nll, y_mean). ``b_nll`` is a scalar. ``y_mean`` is the decoded mean:
            (B, 1, H*W, y_dim) for one sample, otherwise averaged over samples to (B, H*W, y_dim).
        Raises:
            ValueError: ``cp_samples > 1`` with an ``eval_mode`` other than 'average' or 'CVAR'.
        """
        num_samples = self.cp_samples
        if mask_o is None:
            mask_o = torch.ones_like(mask)
        with torch.no_grad():
            enc_dist, mu, logvar = self.prior_encoder(y, mask)
            if num_samples == 1:
                # reparameterization already returns (B, C, H, W); add the sample axis only
                batch_z = self.reparameterization(mu, logvar).unsqueeze(1)
            elif self.eval_mode == 'average':
                batch_z = self.prior_sampler(enc_dist, num_samples)
            elif self.eval_mode == 'CVAR':
                batch_z = self.prior_sampler(enc_dist, int(num_samples / (1 - self.alpha)))
            else:
                raise ValueError('Invalid argument for self.eval_mode')

            mu_tensor, sigma_tensor = self._decode(batch_z)
            decoder_dist = MultivariateNormal(mu_tensor, scale_tril=sigma_tensor.diag_embed())
            batch_size = batch_z.size(0)
            y_obs = y.view(batch_size, y.size(1), -1).permute(0, 2, 1).unsqueeze(1)
            log_y_pred = decoder_dist.log_prob(y_obs.expand(-1, batch_z.size(1), -1, -1))
            # mask_o broadcasts over the sample axis; sum over pixels -> (B, S)
            log_y_pred = (mask_o.view(batch_size, 1, -1) * log_y_pred).sum(-1)

            if num_samples == 1:
                b_nll = -log_y_pred.mean()
                y_mean = mu_tensor
            else:
                if self.eval_mode == 'CVAR':
                    log_y_pred, _ = torch.topk(log_y_pred, k=num_samples, dim=-1, largest=False, sorted=False)
                b_nll = -(torch.logsumexp(log_y_pred, dim=-1).mean() - math.log(num_samples))
                y_mean = mu_tensor.mean(1)
        return mu, logvar, b_nll, y_mean
