import math
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

class CEnc(nn.Module):
    """Client-side encoder: Linear -> BatchNorm -> ReLU, then two output heads.

    Maps ``in_dim`` input features to a ``cut_dim``-wide mean and a strictly
    positive scale (softplus + 1e-12). The scale is returned under the name
    ``var`` but is consumed as a standard deviation by ``OurModel``.
    """

    def __init__(self, in_dim=70, cut_dim=128):
        super().__init__()
        self.in_dim, self.cut_dim = in_dim, cut_dim
        # Submodule names and order define the state_dict keys; keep them fixed.
        self.linear = nn.Linear(in_dim, cut_dim)
        self.bn = nn.BatchNorm1d(cut_dim)
        self.mean_fc = nn.Linear(cut_dim, cut_dim)
        self.var_fc = nn.Linear(cut_dim, cut_dim)

    def forward(self, x):
        hidden = F.relu(self.bn(self.linear(x)), inplace=True)
        scale = F.softplus(self.var_fc(hidden)) + 1e-12  # keep strictly > 0
        return self.mean_fc(hidden), scale

class CMiss(nn.Module):
    """Missingness model: two Linear-BatchNorm-ReLU layers, then a sigmoid.

    Produces one probability in (0, 1) per input row.
    """

    def __init__(self, in_dim, mid_dim=1024):
        super().__init__()
        self.mid_dim = mid_dim
        # Submodule names and order define the state_dict keys; keep them fixed.
        self.layer = nn.Linear(in_dim, mid_dim)
        self.bn = nn.BatchNorm1d(mid_dim)
        self.layer2 = nn.Linear(mid_dim, mid_dim)
        self.bn2 = nn.BatchNorm1d(mid_dim)
        self.layer3 = nn.Linear(mid_dim, 1)

    def forward(self, x):
        hidden = x
        for linear, norm in ((self.layer, self.bn), (self.layer2, self.bn2)):
            hidden = F.relu(norm(linear(hidden)), inplace=True)
        return torch.sigmoid(self.layer3(hidden))
        
class SEnc(nn.Module):
    """Server encoder: two Linear-BatchNorm-ReLU layers, then heads for z.

    Maps a ``cut_dim``-wide input to a ``z_dim``-wide mean and a strictly
    positive scale (softplus + 1e-12).
    """

    def __init__(self, cut_dim=128, z_dim=64, mid_dim=1024):
        super().__init__()
        self.mid_dim = mid_dim
        # Same layer order as before, so the Sequential indices (enc.0, enc.1,
        # enc.3, enc.4) in the state_dict are unchanged.
        trunk = []
        for width_in in (cut_dim, mid_dim):
            trunk += [nn.Linear(width_in, mid_dim), nn.BatchNorm1d(mid_dim), nn.ReLU(inplace=True)]
        self.enc = nn.Sequential(*trunk)
        self.mean_fc = nn.Linear(mid_dim, z_dim)
        self.var_fc = nn.Linear(mid_dim, z_dim)

    def forward(self, x):
        hidden = self.enc(x)
        return self.mean_fc(hidden), F.softplus(self.var_fc(hidden)) + 1e-12


class SDec(nn.Module):
    """Server decoder: maps z (``z_dim``) to a mean and scale over ``cut_dim`` outputs.

    Two Linear-BatchNorm-ReLU layers of width ``mid_dim`` feed a mean head and
    a strictly positive scale head (softplus + 1e-12).
    """

    def __init__(self, cut_dim=128, z_dim=64, mid_dim=1024):
        super().__init__()
        self.mid_dim = mid_dim
        layers = [
            nn.Linear(z_dim, mid_dim), nn.BatchNorm1d(mid_dim), nn.ReLU(inplace=True),
            nn.Linear(mid_dim, mid_dim), nn.BatchNorm1d(mid_dim), nn.ReLU(inplace=True),
        ]
        self.dec = nn.Sequential(*layers)
        self.mean_fc = nn.Linear(mid_dim, cut_dim)
        self.var_fc = nn.Linear(mid_dim, cut_dim)

    def forward(self, z):
        trunk_out = self.dec(z)
        scale = F.softplus(self.var_fc(trunk_out)).add(1e-12)
        loc = self.mean_fc(trunk_out)
        return loc, scale
        
class CDec(nn.Module):
    """Client decoder: Linear -> BatchNorm -> ReLU, then heads over ``out_dim`` features.

    Returns a mean and a strictly positive scale (softplus + 1e-12). The scale
    is consumed as a standard deviation by ``OurModel``.
    """

    def __init__(self, out_dim=70, cut_dim=128):
        super().__init__()
        self.out_dim = out_dim
        self.cut_dim = cut_dim

        # Submodule names and order define the state_dict keys; keep them fixed.
        self.linear = nn.Linear(cut_dim, cut_dim)
        self.bn = nn.BatchNorm1d(cut_dim)
        self.mean_fc = nn.Linear(cut_dim, out_dim)
        self.var_fc = nn.Linear(cut_dim, out_dim)

    def forward(self, x):
        hidden = self.linear(x)
        hidden = F.relu(self.bn(hidden), inplace=True)
        return self.mean_fc(hidden), F.softplus(self.var_fc(hidden)) + 1e-12

class Disc(nn.Module):
    """Classifier head: Linear -> BatchNorm -> ReLU -> Linear producing class logits.

    With ``agg='concat'`` the input width is ``cut_dim * num_clients``;
    otherwise it is ``cut_dim``.
    """

    def __init__(self, cut_dim=128, num_clients=8, num_classes=12, agg='mean'):
        super().__init__()
        self.agg = agg
        width = cut_dim * num_clients if agg == 'concat' else cut_dim

        layers = [nn.Linear(width, width), nn.BatchNorm1d(width), nn.ReLU(inplace=True)]
        layers.append(nn.Linear(width, num_classes))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        logits = self.net(x)
        return logits


class OurModel(nn.Module):
    """Multi-client latent-variable model trained with an importance-weighted objective.

    Each of ``num_clients`` clients owns one feature slice. Per-client encoders
    (``CEnc``) produce Gaussian parameters for a shared latent ``h``. These are
    aggregated over the clients that observed each sample (see ``agg`` and
    ``generate_h``). ``SEnc``/``SDec`` model a second latent ``z`` given ``h``.
    Per-client decoders (``CDec``) model each client's features given ``h``.
    ``CMiss`` models each client's missingness indicator (used when
    ``args.mnar`` is set). ``Disc`` predicts the class from ``h``.

    Note: every tensor named ``var``/``vars`` below is used as a *standard
    deviation* (see ``reparameterize`` and ``normal_logprob``).
    """

    def __init__(self, dataset, num_clients, cuda_id=0, agg='mean'):
        super().__init__()

        self.dataset = dataset
        self.num_clients = num_clients
        # Aggregation rule for the per-client posteriors: 'sum', 'mean' or 'weighted'.
        self.agg = agg

        USE_CUDA = torch.cuda.is_available()
        self.DEVICE_SET = [torch.device(f"cuda:{cuda_id}" if USE_CUDA else "cpu")]

        print("Using Device:", self.DEVICE_SET)

        if dataset == 'hapt':
            num_classes = 12
            in_channels = 70
        elif dataset == 'isolet':
            num_classes = 26
            in_channels = 77
        else:
            raise ValueError(f"Dataset {dataset} does not match resnet18 model.")

        cut_dim = get_cut_dim(dataset)
        z_dim = get_z_dim(dataset)

        # Submodule registration order is unchanged, so state_dict keys match.
        self.c_enc_list = nn.ModuleList([CEnc(in_channels, cut_dim) for _ in range(num_clients)])
        self.c_miss_list = nn.ModuleList([CMiss(in_channels) for _ in range(num_clients)])
        self.s_enc = SEnc(cut_dim, z_dim)
        self.s_dec = SDec(cut_dim, z_dim)
        self.c_dec_list = nn.ModuleList([CDec(out_dim=in_channels, cut_dim=cut_dim)
                                         for _ in range(num_clients)])
        # Disc keeps its default agg='mean', i.e. it consumes the cut_dim-wide aggregated h.
        self.disc = Disc(cut_dim=cut_dim, num_clients=num_clients, num_classes=num_classes)

    def forward(self, x, y, mask, args, bs, training=True):
        """Compute the negative importance-weighted objective for one batch.

        Args:
            x: one-element sequence holding the batch tensor, indexed per
                client as ``[:, i]`` (presumably (bs, num_clients, in_dim) —
                TODO confirm against the data loader).
            y: class labels, one per sample (only used when not pretraining).
            mask: boolean tensor (num_clients, bs); ``mask[i][b]`` is True when
                client ``i`` observed sample ``b``.
            args: namespace providing ``K``, ``K_test``, ``pretrain`` and ``mnar``.
            bs: batch size.
            training: selects ``args.K`` (True) or ``args.K_test`` (False)
                importance samples. When False and not pretraining, accuracy is
                also computed.

        Returns:
            ``(batch_bound, correct)``: the negative logsumexp-over-K objective
            summed over the batch, and the number of correct predictions
            (0.0 unless evaluating outside pretraining).
        """
        K = args.K if training else args.K_test
        correct = 0.0
        device = self.DEVICE_SET[0] if len(self.DEVICE_SET) > 0 else torch.device('cpu')

        new_x, obs_new_x = self.get_local_input(x, mask)
        # Element-wise "is observed" masks, tiled to line up with the K*bs samples.
        new_x_masks_rep = [~torch.isnan(new_x[i]).repeat(K, 1) for i in range(self.num_clients)]

        hgivenxobs_flat, qmeans_hgivenxobs, qvars_hgivenxobs = self.generate_h(
            obs_new_x, mask, K, args.pretrain)
        zgivenh_flat, qmeans_zgivenh, qvars_zgivenh = self.generate_z(hgivenxobs_flat)
        xmisgivenh_flat, pmeans_xgivenh, pvars_xgivenh = self.generate_xmis(
            hgivenxobs_flat, new_x_masks_rep)

        # Prior p(z) = N(0, I).
        logp_z = self.normal_logprob(
            zgivenh_flat, torch.zeros([], device=device), torch.ones([], device=device), 1
        ).reshape([K, bs])
        # q(h | x_obs): generate_h returns one row per sample, so tile it K times.
        logq_hgivenxobs = self.normal_logprob(
            hgivenxobs_flat, qmeans_hgivenxobs.repeat([K, 1]), qvars_hgivenxobs.repeat([K, 1]), 1
        ).reshape([K, bs])
        logq_zgivenh = self.normal_logprob(
            zgivenh_flat, qmeans_zgivenh, qvars_zgivenh, 1
        ).reshape([K, bs])
        logp_xobsgivenh = self._logp_xobs_given_h(
            new_x, new_x_masks_rep, pmeans_xgivenh, pvars_xgivenh, K, bs, device)
        pmeans_hgivenz, pvars_hgivenz = self.s_dec(zgivenh_flat)
        logp_hgivenz = self.normal_logprob(
            hgivenxobs_flat, pmeans_hgivenz, pvars_hgivenz, 1
        ).reshape([K, bs])

        # Log importance weights without the label term, shape (K, bs). At eval
        # time they also weight the K class-probability samples.
        log_w = logp_xobsgivenh + logp_hgivenz + logp_z - logq_hgivenxobs - logq_zgivenh
        if args.mnar:
            log_w = self._logp_m_given_x(xmisgivenh_flat, new_x, mask, K, bs) + log_w

        if args.pretrain:
            batch_bound = -torch.sum(torch.logsumexp(log_w, dim=0))
            return batch_bound, correct

        probs, logp_ygivenh = self.get_probs_and_logPy(hgivenxobs_flat, y, K)
        logp_ygivenh = logp_ygivenh.reshape([K, bs])
        batch_bound = -torch.sum(torch.logsumexp(logp_ygivenh + log_w, dim=0))

        if not training:
            # Self-normalised importance weights over the K samples of each example.
            w_l = torch.softmax(log_w, dim=0)
            probs_ = probs.reshape([K, bs, -1]).to(device)
            p_y = torch.sum(w_l.unsqueeze(-1) * probs_, dim=0)  # (bs, num_classes)
            predicted = p_y.argmax(1)
            # y may live on a different device (get_probs_and_logPy also moves it).
            correct = (predicted == y.to(predicted.device)).float().sum().item()

        return batch_bound, correct

    def _logp_xobs_given_h(self, new_x, masks_rep, pmeans, pvars, K, bs, device):
        """Sum over clients of log p(x_obs | h), shape (K, bs); missing entries contribute 0."""
        total = 0
        for i in range(self.num_clients):
            nan = pmeans[i].new_tensor(float('nan'))
            # NaN out decoder outputs wherever x is missing, so normal_logprob
            # scores those positions as zero.
            mean_obs = torch.where(masks_rep[i], pmeans[i], nan)
            std_obs = torch.where(masks_rep[i], pvars[i], nan)
            x_rep = new_x[i].repeat([K, 1]).to(device)
            total = total + self.normal_logprob(x_rep, mean_obs, std_obs, 1).reshape([K, bs])
        return total

    def _logp_m_given_x(self, xmis, new_x, mask, K, bs):
        """Sum over clients of the Bernoulli log-likelihood of ``mask``, shape (K, bs).

        Each client's missingness model sees the observed entries combined with
        the sampled missing entries (NaNs in each part are zeroed before adding).
        """
        total = 0
        for i in range(self.num_clients):
            x_full = xmis[i].nan_to_num() + new_x[i].repeat([K, 1]).nan_to_num()
            pi = self.c_miss_list[i](x_full)
            total = total + self.bernoulli_logprob(mask[i].repeat(K), pi.squeeze()).reshape([K, bs])
        return total

    def unfiltered_tensor(self, tensor, mask):
        """Scatter the rows of ``tensor`` back to a full batch, filling unselected rows with NaN.

        ``tensor`` holds one row per True entry of the boolean ``mask``. The
        result has ``mask.shape[0]`` rows and lives on ``tensor``'s device.
        """
        output = torch.full((mask.shape[0], *tensor.shape[1:]), float('nan'), device=tensor.device)
        output[mask] = tensor
        return output

    def unfiltered_zero_tensor(self, tensor, mask):
        """Like ``unfiltered_tensor`` but fills unselected rows with zeros."""
        output = torch.zeros(mask.shape[0], *tensor.shape[1:], device=tensor.device)
        output[mask] = tensor
        return output

    def generate_h(self, inputs, masks, sample_size, pretrain=False):
        """Aggregate the per-client posteriors over h and draw ``sample_size`` samples.

        Args:
            inputs: per-client tensors holding only the observed rows. A client
                with no observed rows is skipped.
            masks: per-client boolean row masks over the full batch.
            sample_size: number of samples K drawn per example.
            pretrain: accepted for interface compatibility; unused here.

        Returns:
            ``(samples, mean, std)``. ``samples`` stacks K draws as K*bs rows
            (all rows of draw 0, then draw 1, ...). ``mean`` and ``std`` have
            one row per example.

        Raises:
            NotImplementedError: if ``self.agg`` is not 'sum', 'mean' or 'weighted'.
        """
        dev_1 = self.DEVICE_SET[0] if len(self.DEVICE_SET) > 0 else torch.device('cpu')
        partial_means = {}
        partial_vars_inv = {}
        for i in range(self.num_clients):
            if len(inputs[i]):
                mean_i, var_i = self.c_enc_list[i](inputs[i])
                # Unobserved rows become NaN so the nan-reductions below ignore them.
                partial_means[i] = self.unfiltered_tensor(mean_i, masks[i]).to(dev_1)
                # var_i is a std, so 1/var_i**2 is the precision.
                var_inv_i = (var_i ** 2).reciprocal()
                partial_vars_inv[i] = self.unfiltered_tensor(var_inv_i, masks[i]).to(dev_1)

        partial_means = torch.stack(list(partial_means.values()))
        partial_vars_inv = torch.stack(list(partial_vars_inv.values()))

        if self.agg == 'sum':
            final_mean = torch.nansum(partial_means, dim=0)
        elif self.agg == 'mean':
            final_mean = torch.nanmean(partial_means, dim=0)
        elif self.agg == 'weighted':
            # Precision-weighted mean of the observed clients' means.
            weight_sum = torch.nansum(partial_vars_inv, dim=0)
            numerator = torch.nansum(partial_means * partial_vars_inv, dim=0)
            final_mean = numerator / weight_sum
        else:
            raise NotImplementedError("Wrong agg type.")

        # Std from the summed precisions, regardless of how the mean was aggregated.
        final_var = torch.sqrt(torch.nansum(partial_vars_inv, dim=0).reciprocal())

        final_mean_rep = final_mean.repeat([sample_size, 1])
        final_var_rep = final_var.repeat([sample_size, 1])
        hgivenxobs_flat = self.reparameterize(final_mean_rep, final_var_rep)

        return hgivenxobs_flat, final_mean, final_var

    def generate_z(self, h):
        """Encode ``h`` with ``s_enc`` and draw one z per row; returns ``(z, mean, std)``."""
        z_mean, z_var = self.s_enc(h)
        zgivenh_flat = self.reparameterize(z_mean, z_var)
        return zgivenh_flat, z_mean, z_var

    def generate_xmis(self, h, x_masks):
        """Decode ``h`` per client and sample values for the unobserved entries.

        Returns three dicts keyed by client index: the samples (NaN where
        ``x_masks[i]`` is True, i.e. at observed entries), the decoder means and
        the decoder stds.
        """
        xmis_dict = {}
        pmeans_dict = {}
        pvars_dict = {}

        for i in range(self.num_clients):
            mean_i, var_i = self.c_dec_list[i](h)

            x_i = self.reparameterize(mean_i, var_i)
            xmis_i = torch.where(~x_masks[i], x_i, mean_i.new_tensor(float('nan')))

            xmis_dict[i] = xmis_i
            pmeans_dict[i] = mean_i
            pvars_dict[i] = var_i

        return xmis_dict, pmeans_dict, pvars_dict

    def get_probs_and_logPy(self, h, y, sample_size):
        """Return class probabilities for ``h`` and log p(y | h) for ``y`` tiled ``sample_size`` times.

        The log-probability is gathered from ``log_softmax`` of the logits rather
        than the log of the softmax, so a probability that underflows to 0 does
        not become ``-inf``. ``logp`` has shape ``(len(h), 1)``.
        """
        logits = self.disc(h)
        probs = F.softmax(logits, dim=1)
        labels = y.repeat([sample_size]).to(logits.device)
        logp = F.log_softmax(logits, dim=1).gather(1, labels.long().unsqueeze(1))
        return probs, logp

    def normal_logprob(self, x, mu, var, event_dim=0):
        """Element-wise Gaussian log-density log N(x; mu, var**2) that scores NaN entries as 0.

        ``var`` is a standard deviation. NaNs in ``x``/``mu`` become 0 and NaNs in
        ``var`` become 1/sqrt(2*pi). At positions where all three are NaN, the
        log-density is therefore zero (up to rounding). The trailing
        ``event_dim`` dimensions are summed out.
        """
        x_0 = x.nan_to_num(nan=0.0)
        mu_0 = mu.nan_to_num(nan=0.0)
        # math.sqrt, not torch.sqrt: the argument is a Python float, which
        # torch.sqrt rejects with a TypeError.
        std_0 = var.nan_to_num(nan=1 / math.sqrt(2 * math.pi))

        logp = -0.5 * ((x_0 - mu_0) / std_0) ** 2 - std_0.log() - 0.5 * math.log(2 * math.pi)

        for _ in range(event_dim):
            logp = logp.sum(dim=-1)

        return logp

    def bernoulli_logprob(self, mask_s, pi):
        """Bernoulli log-likelihood of the boolean ``mask_s`` under probability ``pi`` (1e-8 guarded)."""
        return mask_s.float() * torch.log(pi + 1e-8) + (~mask_s).float() * torch.log(1 - pi + 1e-8)

    def categorical_logprob(self, x, probs):
        """Log of ``probs`` at class index ``x`` per row; shape ``(len(x), 1)``."""
        return probs.gather(1, x.long().unsqueeze(1)).log()

    def reparameterize(self, mu, std):
        """Draw ``mu + std * eps`` with ``eps ~ N(0, 1)`` (differentiable w.r.t. mu and std)."""
        eps = torch.randn_like(mu)
        return mu + std * eps

    def get_local_input(self, inputs, masks):
        """Split the batch into per-client slices with unobserved rows set to NaN.

        ``inputs`` must be a one-element sequence whose tensor is indexed per
        client as ``[:, i]``. ``masks[i]`` is the boolean row mask of client ``i``.
        Returns ``(results, obs_results)``: per client, the full-batch slice with
        NaN in unobserved rows, and only the observed rows.

        Each slice is cloned, so the caller's tensor is not modified in place.
        """
        [inputs_] = inputs

        results = []
        obs_results = []
        for i in range(masks.shape[0]):
            result = inputs_[:, i].clone()
            result[~masks[i]] = torch.nan
            results.append(result)
            obs_results.append(result[masks[i]])

        return results, obs_results


def get_cut_dim(dataset):
    """Return the ``cut_dim`` width used for ``dataset``.

    Raises:
        ValueError: if ``dataset`` has no configured width. Previously this
            surfaced as an UnboundLocalError.
    """
    cut_dims = {'hapt': 128, 'isolet': 128}
    try:
        return cut_dims[dataset]
    except KeyError:
        raise ValueError(f"No cut_dim configured for dataset {dataset!r}.") from None

def get_z_dim(dataset):
    """Return the ``z_dim`` latent width used for ``dataset``.

    Raises:
        ValueError: if ``dataset`` has no configured width. Previously this
            surfaced as an UnboundLocalError.
    """
    z_dims = {'hapt': 64, 'isolet': 64}
    try:
        return z_dims[dataset]
    except KeyError:
        raise ValueError(f"No z_dim configured for dataset {dataset!r}.") from None