import math
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models
from torchvision.models.resnet import BasicBlock

def resnet18_backbone(in_channels=1):
    """Build a slimmed ResNet-18 trunk for small image patches.

    Starting from torchvision's ResNet-18 without pretrained weights:
      * the stem becomes a 3x3, stride-1, unpadded, 32-channel convolution
        with its own BatchNorm, and the max-pool is replaced by an identity;
      * the four stages are rebuilt with widths 32/64/128/256, two
        ``BasicBlock``s each, stages 2-4 downsampling by 2;
      * the final ``fc`` head is dropped, so the trunk ends with ResNet's own
        average-pool stage.

    Args:
        in_channels: number of channels of the input patches.

    Returns:
        An ``nn.Sequential`` of the remaining ResNet children, in their
        original order.
    """
    resnet = torchvision.models.resnet18(weights=None)

    # Stem tailored for tiny patches: no stride, no padding, no max-pool.
    resnet.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, stride=1, padding=0, bias=False)
    resnet.bn1 = nn.BatchNorm2d(32)
    resnet.maxpool = nn.Identity()

    # _make_layer threads channel counts through resnet.inplanes, so reset it to
    # the new stem width and rebuild the stages strictly in order.
    resnet.inplanes = 32
    stage_specs = ((32, 1), (64, 2), (128, 2), (256, 2))
    for stage_no, (planes, stride) in enumerate(stage_specs, start=1):
        setattr(resnet, f"layer{stage_no}", resnet._make_layer(BasicBlock, planes, 2, stride=stride))

    # Everything except the classification head, in registration order.
    trunk = [module for name, module in resnet.named_children() if name != "fc"]
    return nn.Sequential(*trunk)
    
class CEnc(nn.Module):
    """Client-side encoder: maps a local image patch to Gaussian parameters.

    A ResNet-18-style trunk (``resnet18_backbone``) is followed by two linear
    heads of width ``cut_dim``. The second head goes through softplus plus
    1e-12, so it is strictly positive.

    Args:
        in_channels: channels of the input patch.
        cut_dim: width of the produced code.
    """

    def __init__(self, in_channels=1, cut_dim=196):
        super().__init__()
        self.backbone = resnet18_backbone(in_channels)
        # Flattened width of the trunk's last (256-plane) stage.
        feat_dim = 256 * BasicBlock.expansion
        self.mean_fc = nn.Linear(feat_dim, cut_dim)
        self.var_fc = nn.Linear(feat_dim, cut_dim)

    def forward(self, x):
        """Return ``(mean, scale)``, each of shape (B, cut_dim).

        The 4-way unpacking of ``x.shape`` requires a 4-D (B, C, H, W) input.
        """
        batch, _, _, _ = x.shape
        pooled = self.backbone(x).view(batch, -1)
        return self.mean_fc(pooled), F.softplus(self.var_fc(pooled)) + 1e-12

class CMiss(nn.Module):
    """Missingness model: maps a local patch to a probability in (0, 1).

    Three linear layers (the first two followed by BatchNorm1d and ReLU) and
    a final sigmoid.

    Args:
        in_dim: number of features after flattening one input sample.
        mid_dim: hidden width.
    """

    def __init__(self, in_dim, mid_dim=1024):
        super().__init__()
        self.mid_dim = mid_dim
        self.layer = nn.Linear(in_dim, mid_dim)
        self.bn = nn.BatchNorm1d(mid_dim)
        self.layer2 = nn.Linear(mid_dim, mid_dim)
        self.bn2 = nn.BatchNorm1d(mid_dim)
        self.layer3 = nn.Linear(mid_dim, 1)

    def forward(self, x):
        """Return probabilities of shape (B, 1).

        Args:
            x: tensor whose first dimension is the batch; all remaining
                dimensions are flattened and must total ``in_dim``.
        """
        # reshape (not view) so non-contiguous inputs are accepted as well.
        feat = x.reshape(x.shape[0], -1)
        # Functional ops instead of constructing nn.ReLU / nn.Sigmoid modules
        # on every call; results are identical.
        feat = F.relu(self.bn(self.layer(feat)), inplace=True)
        feat = F.relu(self.bn2(self.layer2(feat)), inplace=True)
        return torch.sigmoid(self.layer3(feat))
        
        
class SEnc(nn.Module):
    """Shared encoder: maps a ``cut_dim`` vector to Gaussian parameters over ``z``.

    Two Linear-BatchNorm1d-ReLU blocks of width ``cut_dim`` feed two linear
    heads of width ``z_dim``. The second head goes through softplus plus
    1e-12, so it is strictly positive.
    """

    def __init__(self, cut_dim=196, z_dim=32):
        super().__init__()
        self.mid_dim = cut_dim
        width = self.mid_dim
        layers = [
            nn.Linear(cut_dim, cut_dim), nn.BatchNorm1d(cut_dim), nn.ReLU(inplace=True),
            nn.Linear(cut_dim, width), nn.BatchNorm1d(width), nn.ReLU(inplace=True),
        ]
        self.enc = nn.Sequential(*layers)
        self.mean_fc = nn.Linear(width, z_dim)
        self.var_fc = nn.Linear(width, z_dim)

    def forward(self, x):
        """Return ``(mean, scale)``, each of shape (B, z_dim)."""
        hidden = self.enc(x)
        scale = F.softplus(self.var_fc(hidden)) + 1e-12
        return self.mean_fc(hidden), scale


class SDec(nn.Module):
    """Shared decoder: maps ``z`` back to Gaussian parameters over a ``cut_dim`` vector.

    Two Linear-BatchNorm1d-ReLU blocks lift ``z_dim`` to ``cut_dim``. They
    feed two ``cut_dim -> cut_dim`` linear heads; the second goes through
    softplus plus 1e-12 so it is strictly positive.
    """

    def __init__(self, cut_dim=196, z_dim=32):
        super().__init__()
        self.mid_dim = cut_dim
        width = self.mid_dim
        self.dec = nn.Sequential(
            nn.Linear(z_dim, width), nn.BatchNorm1d(width), nn.ReLU(inplace=True),
            nn.Linear(width, cut_dim), nn.BatchNorm1d(cut_dim), nn.ReLU(inplace=True),
        )
        # Created in order: mean head first, then scale head.
        self.mean_fc, self.var_fc = (nn.Linear(cut_dim, cut_dim) for _ in range(2))

    def forward(self, z):
        """Return ``(mean, scale)``, each of shape (B, cut_dim)."""
        hidden = self.dec(z)
        scale = F.softplus(self.var_fc(hidden)) + 1e-12
        return self.mean_fc(hidden), scale
        
class CDec(nn.Module):
    """Client-side decoder: maps a ``cut_dim`` code to per-pixel Gaussian parameters.

    The code is projected to a ``(hidden, 2, 2)`` map and upsampled by three
    transposed convolutions (2x2 -> 4x4 -> 7x7 -> 14x7). The output patch is
    therefore always 14 rows by 7 columns, whatever ``cut_dim`` is.

    Args:
        out_channels: channels of the reconstructed patch.
        cut_dim: width of the input code.
        hidden: channel width of the first feature map. Each transposed
            convolution halves it, so it must be >= 8 (default 256).
    """

    def __init__(self, out_channels=1, cut_dim=196, hidden=256):
        super().__init__()
        self.cut_dim = cut_dim
        self.out_channels = out_channels
        # Stored so forward() reshapes with the same width __init__ built with
        # (previously the constant was duplicated in forward()).
        self.hidden = hidden

        self.fc = nn.Sequential(
            nn.Linear(cut_dim, hidden * 2 * 2),
            nn.ReLU(inplace=True),
        )
        # 2x2 -> 4x4
        self.deconv1 = nn.Sequential(
            nn.ConvTranspose2d(hidden, hidden // 2, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(hidden // 2),
            nn.ReLU(inplace=True),
        )
        # 4x4 -> 7x7
        self.deconv2 = nn.Sequential(
            nn.ConvTranspose2d(hidden // 2, hidden // 4, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(hidden // 4),
            nn.ReLU(inplace=True),
        )
        # 7x7 -> 14x7: only the row dimension is upsampled.
        self.deconv3 = nn.Sequential(
            nn.ConvTranspose2d(hidden // 4, hidden // 8, kernel_size=(3, 3), stride=(2, 1),
                               padding=(1, 1), output_padding=(1, 0)),
            nn.BatchNorm2d(hidden // 8),
            nn.ReLU(inplace=True),
        )
        # Size-preserving 3x3 output heads.
        self.mean_conv = nn.Conv2d(hidden // 8, out_channels, kernel_size=3, padding=1)
        self.var_conv = nn.Conv2d(hidden // 8, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        """Decode codes of shape (B, cut_dim) into ``(mean, scale)``.

        Both outputs have shape (B, out_channels, 14, 7). ``scale`` is
        softplus(.) + 1e-12, so it is strictly positive.
        """
        out = self.fc(x).view(x.size(0), self.hidden, 2, 2)
        out = self.deconv1(out)
        out = self.deconv2(out)
        out = self.deconv3(out)

        mean_ = self.mean_conv(out)
        var_ = F.softplus(self.var_conv(out)) + 1e-12
        return mean_, var_

class Disc(nn.Module):
    """Classifier head over the aggregated representation; returns raw logits.

    Args:
        cut_dim: width of one client's representation.
        num_clients: only used when ``agg == 'concat'``, where the input is
            ``cut_dim * num_clients`` wide.
        num_classes: number of output logits.
        agg: ``'concat'`` widens the input as above; any other value keeps
            the input width at ``cut_dim``.
    """

    def __init__(self, cut_dim=196, num_clients=8, num_classes=10, agg='mean'):
        super().__init__()
        self.agg = agg
        in_dim = cut_dim * num_clients if agg == 'concat' else cut_dim

        self.net = nn.Sequential(
            nn.Linear(in_dim, in_dim),
            nn.BatchNorm1d(in_dim),
            nn.ReLU(inplace=True),
            nn.Linear(in_dim, num_classes),
        )

    def forward(self, x):
        """Return unnormalised class scores of shape (B, num_classes)."""
        logits = self.net(x)
        return logits


class OurModel(nn.Module):
    """Multi-client latent-variable model over image patches with missing clients.

    Each of ``num_clients`` clients sees one rectangular patch of the input
    image, as given by ``map_idx_to_partition``. The sub-networks are:

    * ``c_enc_list``: per-client encoders producing Gaussian parameters for a
      shared latent ``h``. ``generate_h`` combines them across the clients
      observed for each sample.
    * ``s_enc`` / ``s_dec``: infer a second latent ``z`` from ``h`` and
      decode ``h`` back from ``z``.
    * ``c_dec_list``: per-client decoders that model each patch given ``h``.
    * ``c_miss_list``: per-client models of the probability that a client is
      observed (used when ``args.mnar``).
    * ``disc``: classifier on ``h``.

    Throughout this class, tensors named ``var``/``vars`` hold standard
    deviations, not variances. They multiply unit noise in
    ``reparameterize`` and divide ``x - mu`` in ``normal_logprob``.
    """

    def __init__(self, dataset, num_clients, cuda_id=0, agg='mean'):
        """Build all sub-networks.

        Args:
            dataset: dataset name; only ``'fashionmnist'`` is supported.
            num_clients: number of clients / image patches.
            cuda_id: CUDA device index recorded in ``DEVICE_SET`` when CUDA
                is available (the module itself is not moved here).
            agg: how client means are combined in ``generate_h``: ``'sum'``,
                ``'mean'`` or ``'weighted'`` (precision-weighted mean).

        Raises:
            ValueError: for an unsupported ``dataset``.
        """
        super().__init__()

        self.dataset = dataset
        self.num_clients = num_clients
        self.agg = agg

        use_cuda = torch.cuda.is_available()
        self.DEVICE_SET = [torch.device(f"cuda:{cuda_id}" if use_cuda else "cpu")]

        print("Using Device:", self.DEVICE_SET)

        if dataset in ['fashionmnist']:
            num_classes = 10
            in_channels = 1
        else:
            raise ValueError(f"Dataset {dataset} does not match resnet18 model.")

        cut_dim = get_cut_dim(dataset)
        z_dim = get_z_dim(dataset)

        self.map_idx_to_partition = get_idx_to_partition_map(dataset, num_clients)

        self.c_enc_list = nn.ModuleList([CEnc(in_channels, cut_dim) for _ in range(num_clients)])

        # CMiss consumes a whole flattened local patch (in_channels x h x w).
        # Derive its width from the partition map instead of hard-coding 14*7;
        # this gives 98 for the 8-client fashionmnist split, as before.
        patch_dims = []
        for i in range(num_clients):
            (r0, r1), (c0, c1) = self.map_idx_to_partition[i]
            patch_dims.append(in_channels * (r1 - r0) * (c1 - c0))
        self.c_miss_list = nn.ModuleList([CMiss(d) for d in patch_dims])

        self.s_enc = SEnc(cut_dim, z_dim)
        self.s_dec = SDec(cut_dim, z_dim)

        self.c_dec_list = nn.ModuleList([CDec(out_channels=in_channels, cut_dim=cut_dim)
                                         for _ in range(num_clients)])

        self.disc = Disc(cut_dim=cut_dim, num_clients=num_clients, num_classes=num_classes)

    def forward(self, x, y, mask, args, bs, training=True):
        """Compute the negative importance-weighted bound for one batch.

        Args:
            x: one-element sequence holding the 4-D image batch (see
                ``get_local_input``).
            y: class labels; used unless ``args.pretrain``.
            mask: boolean tensor; ``mask[i]`` flags, per sample, whether
                client ``i`` is observed.
            args: namespace providing ``K``, ``K_test``, ``num_clients``,
                ``pretrain`` and ``mnar``.
            bs: batch size (used to reshape per-sample terms to (K, bs)).
            training: selects ``args.K`` (True) or ``args.K_test`` (False)
                importance samples. When False and not pretraining, the
                number of correct predictions is also computed.

        Returns:
            ``(batch_bound, correct)``. ``batch_bound`` is
            ``-sum_b logsumexp_k log_w[k, b]``; no ``-log K`` term is
            subtracted. ``correct`` is a Python float and is 0.0 unless the
            classifier is being evaluated.
        """
        K = args.K if training else args.K_test

        correct = 0.0
        device = self.DEVICE_SET[0] if len(self.DEVICE_SET) > 0 else torch.device('cpu')
        nan = torch.Tensor([float('nan')]).to(device)

        # Per-client patches (NaN for missing clients) and observed rows only.
        new_x, obs_new_x = self.get_local_input(x, mask)
        # True at observed pixels, tiled K times along the batch dimension.
        new_x_masks_rep = [~torch.isnan(new_x[i]).repeat(K, 1, 1, 1) for i in range(args.num_clients)]

        hgivenxobs_flat, qmeans_hgivenxobs, qvars_hgivenxobs = self.generate_h(obs_new_x, mask, K, args.pretrain)
        zgivenh_flat, qmeans_zgivenh, qvars_zgivenh = self.generate_z(hgivenxobs_flat)
        xmisgivenh_flat, pmeans_xgivenh, pvars_xgivenh = self.generate_xmis(hgivenxobs_flat, new_x_masks_rep)

        # Standard-normal prior on z.
        logp_z = self.normal_logprob(
            zgivenh_flat,
            torch.zeros([]).to(device=device),
            torch.ones([]).to(device=device),
            1,
        ).reshape([K, bs])

        logq_hgivenxobs = self.normal_logprob(
            hgivenxobs_flat,
            qmeans_hgivenxobs.repeat([K, 1]),
            qvars_hgivenxobs.repeat([K, 1]),
            1,
        ).reshape([K, bs])

        logq_zgivenh = self.normal_logprob(
            zgivenh_flat, qmeans_zgivenh, qvars_zgivenh, 1
        ).reshape([K, bs])

        # log p(x_obs | h): decoder parameters are replaced by NaN at missing
        # pixels so normal_logprob drops them; summed over clients.
        logp_xobsgivenh_list = []
        for i in range(args.num_clients):
            observed = new_x_masks_rep[i]
            mean_obs = torch.where(observed, pmeans_xgivenh[i], nan)
            std_obs = torch.where(observed, pvars_xgivenh[i], nan)
            expanded_xi = new_x[i].repeat([K, 1, 1, 1]).to(device)
            logp_xobsgivenh_list.append(
                self.normal_logprob(expanded_xi, mean_obs, std_obs, 3).reshape([K, bs])
            )
        logp_xobsgivenh = sum(logp_xobsgivenh_list)

        pmeans_hgivenz, pvars_hgivenz = self.s_dec(zgivenh_flat)
        logp_hgivenz = self.normal_logprob(
            hgivenxobs_flat, pmeans_hgivenz, pvars_hgivenz, 1
        ).reshape([K, bs])

        # Label-free log importance weights, shape (K, bs).
        if args.mnar:
            logp_mgivenx_list = []
            for i in range(self.num_clients):
                # Completed patch: sampled values at missing pixels (xmis is
                # NaN where observed) plus data at observed pixels.
                x_full = xmisgivenh_flat[i].nan_to_num() + new_x[i].repeat([K, 1, 1, 1]).nan_to_num()
                pi = self.c_miss_list[i](x_full)
                val = self.bernoulli_logprob(mask[i].repeat(K), pi.squeeze(-1))
                logp_mgivenx_list.append(val.reshape([K, bs]))
            logp_mgivenx = sum(logp_mgivenx_list)
            log_r = (logp_mgivenx + logp_xobsgivenh + logp_hgivenz + logp_z
                     - logq_hgivenxobs - logq_zgivenh)
        else:
            log_r = (logp_xobsgivenh + logp_hgivenz + logp_z
                     - logq_hgivenxobs - logq_zgivenh)

        if args.pretrain:
            batch_bound = -torch.sum(torch.logsumexp(log_r, dim=0))
        else:
            probs, logp_ygivenh = self.get_probs_and_logPy(hgivenxobs_flat, y, K)
            logp_ygivenh = logp_ygivenh.reshape([K, bs])
            batch_bound = -torch.sum(torch.logsumexp(logp_ygivenh + log_r, dim=0))
            if not training:
                # Self-normalised importance weights (label-free terms only)
                # average the class probabilities over the K samples.
                w_l = torch.softmax(log_r, dim=0)
                probs_ = probs.reshape([K, bs, -1]).permute(2, 0, 1).to(device)
                p_y = torch.sum(w_l * probs_, 1).permute(1, 0)
                predicted = p_y.argmax(1)
                correct = (predicted == y.to(predicted.device)).float().sum().item()

        return batch_bound, correct

    def unfiltered_tensor(self, tensor, mask):
        """Scatter the rows of ``tensor`` to the positions where ``mask`` is True.

        The result has ``mask.shape[0]`` rows, is NaN everywhere else, and
        lives on ``tensor``'s device.
        """
        output = torch.full((mask.shape[0], *tensor.shape[1:]), float('nan'), device=tensor.device)
        output[mask] = tensor
        return output

    def unfiltered_zero_tensor(self, tensor, mask):
        """Like ``unfiltered_tensor`` but fills unselected rows with zeros."""
        output = torch.zeros(mask.shape[0], *tensor.shape[1:], device=tensor.device)
        output[mask] = tensor
        return output

    def generate_h(self, inputs, masks, sample_size, pretrain=False):
        """Combine per-client encoder outputs into q(h | x_obs) and sample it.

        Args:
            inputs: per-client tensors holding only the observed rows
                (``obs_results`` from ``get_local_input``).
            masks: ``masks[i]`` is the per-sample boolean observation flag of
                client ``i``; used to scatter rows back to batch positions.
            sample_size: number of samples K drawn per batch element.
            pretrain: unused here; kept for interface compatibility.

        Returns:
            ``(h_flat, final_mean, final_std)``. ``h_flat`` holds K tiled
            copies of the batch; the other two have one row per sample. The
            std is always ``(sum_i 1/std_i**2) ** -0.5`` over the observed
            clients; only the mean depends on ``self.agg``.

        Raises:
            RuntimeError: if no client has an observed row in the batch.
            NotImplementedError: for an unknown ``self.agg``.
        """
        dev_1 = self.DEVICE_SET[0] if len(self.DEVICE_SET) > 0 else torch.device('cpu')
        means, precisions = [], []
        for i in range(self.num_clients):
            if len(inputs[i]):
                mean_i, std_i = self.c_enc_list[i](inputs[i])
                means.append(self.unfiltered_tensor(mean_i, masks[i]).to(dev_1))
                precision_i = (std_i ** 2).reciprocal()
                precisions.append(self.unfiltered_tensor(precision_i, masks[i]).to(dev_1))

        if not means:
            raise RuntimeError("generate_h: no client has an observed sample in this batch.")

        # NaN marks (client, sample) pairs that are missing; nan-reductions skip them.
        partial_means = torch.stack(means)
        partial_vars_inv = torch.stack(precisions)

        if self.agg == 'sum':
            final_mean = torch.nansum(partial_means, dim=0)
        elif self.agg == 'mean':
            final_mean = torch.nanmean(partial_means, dim=0)
        elif self.agg == 'weighted':
            weight_sum = torch.nansum(partial_vars_inv, dim=0)
            numerator = torch.nansum(partial_means * partial_vars_inv, dim=0)
            final_mean = numerator / weight_sum
        else:
            raise NotImplementedError("Wrong agg type.")

        final_var = torch.sqrt(torch.nansum(partial_vars_inv, dim=0).reciprocal())

        final_mean_rep = final_mean.repeat([sample_size, 1])
        final_var_rep = final_var.repeat([sample_size, 1])
        hgivenxobs_flat = self.reparameterize(final_mean_rep, final_var_rep)

        return hgivenxobs_flat, final_mean, final_var

    def generate_z(self, h):
        """Sample ``z ~ q(z | h)`` using ``s_enc``; return ``(z, mean, std)``."""
        z_mean, z_var = self.s_enc(h)
        zgivenh_flat = self.reparameterize(z_mean, z_var)
        return zgivenh_flat, z_mean, z_var

    def generate_xmis(self, h, x_masks):
        """Decode every client's patch from ``h`` and sample it.

        Returns:
            Three dicts keyed by client index: sampled patches (NaN where
            ``x_masks[i]`` is True, i.e. at observed pixels), decoder means,
            and decoder stds.
        """
        nan = torch.Tensor([torch.nan]).to(self.DEVICE_SET[0])
        xmis_dict, pmeans_dict, pvars_dict = {}, {}, {}

        for i in range(self.num_clients):
            mean_i, var_i = self.c_dec_list[i](h)
            x_i = self.reparameterize(mean_i, var_i)
            xmis_dict[i] = torch.where(~x_masks[i], x_i, nan)
            pmeans_dict[i] = mean_i
            pvars_dict[i] = var_i

        return xmis_dict, pmeans_dict, pvars_dict

    def get_probs_and_logPy(self, h, y, sample_size):
        """Classify ``h``; return class probabilities and log p(y | h).

        ``y`` is tiled ``sample_size`` times to match the K-tiled rows of
        ``h``. ``log_softmax`` is used so that log p(y | h) stays finite when
        a class probability underflows to 0 in ``softmax``.

        Returns:
            ``(probs, logp)`` with shapes (N, num_classes) and (N, 1).
        """
        log_probs = F.log_softmax(self.disc(h), dim=1)
        labels = y.repeat([sample_size]).to(self.DEVICE_SET[0])
        return log_probs.exp(), log_probs.gather(1, labels.long().unsqueeze(1))

    def normal_logprob(self, x, mu, var, event_dim=0):
        """Elementwise Gaussian log-density, summed over the last ``event_dim`` dims.

        ``var`` is a standard deviation (it divides ``x - mu`` and is passed
        through ``log``). NaNs are neutralised: x and mu become 0 and std
        becomes 1/sqrt(2*pi). An element whose x, mu and std are all NaN
        therefore contributes (numerically) 0 to the sum.
        """
        x_0 = x.nan_to_num(nan=0.0)
        mu_0 = mu.nan_to_num(nan=0.0)
        # math.sqrt: torch.sqrt rejects Python floats and raised TypeError here.
        std_0 = var.nan_to_num(nan=1.0 / math.sqrt(2 * math.pi))

        logp = -0.5 * (((x_0 - mu_0) / std_0) ** 2) - std_0.log() - 0.5 * math.log(2 * math.pi)

        for _ in range(event_dim):
            logp = logp.sum(dim=-1)

        return logp

    def bernoulli_logprob(self, mask_s, pi):
        """Elementwise Bernoulli log-likelihood of boolean ``mask_s`` given ``pi``.

        ``mask_s`` must be boolean (it is inverted with ``~``). The 1e-8
        terms guard against ``log(0)``.
        """
        return mask_s.float() * torch.log(pi + 1e-8) + (~mask_s).float() * torch.log(1 - pi + 1e-8)

    def categorical_logprob(self, x, probs):
        """Return ``log probs[n, x[n]]`` as an (N, 1) tensor.

        NOTE: this yields -inf when the selected probability is exactly 0;
        ``get_probs_and_logPy`` uses ``log_softmax`` instead.
        """
        return probs.gather(1, x.long().unsqueeze(1)).log()

    def reparameterize(self, mu, std):
        """Return ``mu + std * eps`` with ``eps`` drawn from a standard normal."""
        eps = torch.randn_like(mu)
        return mu + std * eps

    def get_local_input(self, inputs, masks):
        """Split the image batch into per-client patches.

        Args:
            inputs: one-element sequence holding a 4-D image batch; patches
                are cut from its last two dimensions.
            masks: boolean tensor with one row per client; ``masks[i]``
                flags, per sample, whether client ``i`` is observed.

        Returns:
            ``(results, obs_results)``. ``results[i]`` is client i's patch for
            the whole batch, NaN for samples where the client is missing.
            ``obs_results[i]`` keeps only the observed samples. The caller's
            tensor is left unmodified.
        """
        num_clients = masks.shape[0]
        [inputs_] = inputs

        results = []
        obs_results = []
        for i in range(num_clients):
            (start_row, end_row), (start_col, end_col) = self.map_idx_to_partition[i]
            # clone(): slicing returns a view, so writing NaN below would
            # otherwise corrupt the caller's input batch in place.
            result = inputs_[:, :, start_row:end_row, start_col:end_col].clone()
            result[~masks[i]] = torch.nan

            results.append(result)
            obs_results.append(result[masks[i]])

        return results, obs_results


def get_cut_dim(dataset):
    """Return the width of each client's local code for *dataset*.

    Raises:
        NotImplementedError: if *dataset* has no configured width.
    """
    if dataset in ['fashionmnist']:
        return 196
    raise NotImplementedError(f"No cut_dim configured for dataset {dataset!r}.")

def get_z_dim(dataset):
    """Return the width of the shared latent ``z`` for *dataset*.

    Raises:
        NotImplementedError: if *dataset* has no configured width.
    """
    if dataset in ['fashionmnist']:
        return 32
    raise NotImplementedError(f"No z_dim configured for dataset {dataset!r}.")

def get_idx_to_partition_map(dataset: str, num_clients: int) -> dict:
    """Map each client index to the image patch it owns.

    The 28x28 image is tiled row-major into a regular grid:
      * 8 clients: 2x4 tiles of 14x7 pixels;
      * 4 clients: 2x2 tiles of 14x14 pixels.

    Returns:
        ``{client: ((row_start, row_end), (col_start, col_end))}`` with
        half-open pixel ranges and keys in order 0..num_clients-1.

    Raises:
        NotImplementedError: for any other dataset / client-count combination.
    """
    if dataset in ["fashionmnist"] and num_clients == 8:
        tiles_down, tiles_across = 2, 4
    elif dataset in ["fashionmnist"] and num_clients == 4:
        tiles_down, tiles_across = 2, 2
    else:
        raise NotImplementedError

    side = 28
    tile_h, tile_w = side // tiles_down, side // tiles_across
    return {
        row * tiles_across + col: (
            (row * tile_h, (row + 1) * tile_h),
            (col * tile_w, (col + 1) * tile_w),
        )
        for row in range(tiles_down)
        for col in range(tiles_across)
    }
