import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torch import distributions
import numpy as np


class batch_EM(nn.Module):
    """Mini-batch EM-style imputer based on a single multivariate Gaussian.

    The module keeps a running mean ``mu`` and covariance ``sigma``.  For every
    row with missing entries, the missing block is replaced by its conditional
    mean given the observed block.  In training mode the parameters are then
    blended with the batch statistics using the step size
    ``const * t ** (-gamma)``.

    Args:
        gamma: Decay exponent of the step size.
        gamma_scheduler: Optional object whose ``step()`` return value replaces
            ``gamma`` (capped at 0.90) in :meth:`rate_init`.
        const: Multiplicative constant of the step size.
        const_scheduler: Optional object whose ``step()`` return value replaces
            ``const`` (floored at 0.55) in :meth:`rate_init`.
        beta: Relative inflation applied to the diagonal of ``sigma`` after
            each parameter update.
    """

    def __init__(self, gamma=0.8, gamma_scheduler=None, const=0.99, const_scheduler=None, beta=0.):
        super(batch_EM, self).__init__()
        self.mu = None
        self.sigma = None
        self.init = False
        self.t = 1

        self.gamma = gamma
        self.gamma_scheduler = gamma_scheduler

        # Honour the constructor argument (it used to be overwritten by 0.99).
        self.const = const
        self.const_scheduler = const_scheduler

        self.beta = beta

    @staticmethod
    def _target_device(X):
        """Return the device on which parameters and results are kept.

        CUDA inputs stay on their own device.  For CPU inputs CUDA is used when
        it is available (the historical behaviour); otherwise everything stays
        on the CPU so the class also works on machines without a GPU.
        """
        if X.is_cuda:
            return X.device
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def _init_param(self, X):
        """Initialise ``mu`` and ``sigma`` from ``X``, ignoring NaN entries.

        ``mu`` is the per-column NaN-aware mean.  ``sigma`` is the sample
        covariance of the rows without NaN when there are enough of them, and
        otherwise a diagonal matrix of NaN-aware per-column variances.
        """
        device = self._target_device(X)
        X_copy = X.detach().cpu().numpy()
        n_features = X_copy.shape[1]

        mean = np.nanmean(X_copy, axis=0).astype(X_copy.dtype)
        self.mu = torch.from_numpy(mean).to(device)

        # Rows in which every entry is observed (no NaN).
        observed_rows = np.where(~np.isnan(X_copy).any(axis=1))[0]
        if len(observed_rows) >= max(2, n_features):
            S = np.cov(X_copy[observed_rows].T)
        else:
            diag = np.nanvar(X_copy, axis=0)
            # Replace (near-)zero variances so the matrix stays invertible.
            diag[diag < 1e-8] = np.median(diag) + 1e-8
            S = np.diag(diag)
        # Same dtype as ``mu``/``X`` in both branches; np.cov returns a 0-d
        # array for a single feature, hence atleast_2d.
        S = np.atleast_2d(S).astype(X_copy.dtype)
        self.sigma = torch.from_numpy(S).to(device)

    @staticmethod
    def cov(m, rowvar=True, bias=False, inplace=False):
        """Estimate a covariance matrix.

        Args:
            m: 1-D or 2-D tensor of observations.
            rowvar: If True each row is a variable and each column an
                observation; if False the input is transposed first (unless it
                has a single row).
            bias: If True normalise by N, otherwise by N - 1.
            inplace: If True the mean is subtracted from ``m`` in place.

        Returns:
            The covariance matrix with singleton dimensions squeezed out.

        Raises:
            ValueError: If ``m`` has more than two dimensions.
        """
        if m.dim() > 2:
            raise ValueError('m has more than 2 dimensions')
        if m.dim() < 2:
            m = m.view(1, -1)
        if not rowvar and m.size(0) != 1:
            m = m.t()
        if bias:
            fact = 1.0 / m.size(1)
        else:
            fact = 1.0 / (m.size(1) - 1)
        if inplace:
            m -= torch.mean(m, dim=1, keepdim=True)
        else:
            m = m - torch.mean(m, dim=1, keepdim=True)
        mt = m.t()  # if complex: mt = m.t().conj()
        return fact * m.matmul(mt).squeeze()

    def rate_update(self):
        """Advance the step counter used by the step-size schedule."""
        self.t += 1

    def rate_init(self):
        """Reset the step counter and pull new values from the schedulers."""
        self.t = 1
        if self.const_scheduler:
            self.const = max(0.55, self.const_scheduler.step())
        if self.gamma_scheduler:
            self.gamma = min(0.90, self.gamma_scheduler.step())

    def reset(self):
        """Forget the current parameters; the next call re-initialises them."""
        self.init = False
        self.mu = None
        self.sigma = None

    def _update_params(self, X_tilde, S_tilde, mu, sigma):
        """Blend ``mu``/``sigma`` with the statistics of the completed batch.

        Args:
            X_tilde: Completed batch.
            S_tilde: Sum over rows of the conditional covariances of the
                missing blocks, scattered into a full-size matrix.
            mu: Current mean, on the same device as ``X_tilde``.
            sigma: Current covariance, on the same device as ``X_tilde``.

        Returns:
            The updated ``(mu, sigma)`` pair.  Also advances the step counter.
        """
        mu_tilde = torch.mean(X_tilde, dim=0)
        sigma_tilde = self.cov(X_tilde, rowvar=False, bias=True) + S_tilde / X_tilde.shape[0]

        rate = self.const * self.t ** (-self.gamma)
        new_mu = rate * mu_tilde + (1 - rate) * mu
        new_sigma = rate * sigma_tilde + (1 - rate) * sigma
        # Optional diagonal inflation controlled by ``beta``.
        new_sigma = new_sigma + self.beta * torch.diag(torch.diag(new_sigma))
        self.rate_update()
        return new_mu, new_sigma

    def complete(self, X, C, mode='train'):
        """Impute missing entries row by row, computing on the CPU.

        Args:
            X: 2-D data tensor.  Entries flagged in ``C`` are overwritten; when
                the parameters are not initialised yet they are estimated from
                ``X`` with NaN entries ignored.
            C: Mask with the same shape as ``X``; non-zero marks a missing entry.
            mode: ``'test'`` only imputes; any other value also updates
                ``mu`` and ``sigma``.

        Returns:
            In test mode the completed tensor.  Otherwise a tuple of the
            completed tensor and a ``MultivariateNormal`` built from the
            updated parameters.  Results live on the device chosen by
            ``_target_device``.
        """
        device = self._target_device(X)
        X = X.cpu()
        C = C.cpu().type(torch.bool)

        if not self.init:
            self._init_param(X)
            self.init = True

        # Work on CPU copies so the stored parameters keep their device, also
        # when returning early in test mode.
        mu = self.mu.cpu()
        sigma = self.sigma.cpu()

        n_features = X.shape[1]
        X_tilde = X.clone()
        S_tilde = torch.zeros(n_features, n_features, dtype=sigma.dtype)
        for i in range(X.shape[0]):
            if not C[i].any():
                continue

            M_i = torch.where(C[i])[0]
            O_i = torch.where(~C[i])[0]

            S_MM = sigma[M_i][:, M_i]
            S_MO = sigma[M_i][:, O_i]
            S_OO = sigma[O_i][:, O_i]

            # Regression coefficients of the missing block on the observed one.
            gain = torch.mm(S_MO, torch.inverse(S_OO))
            residual = (X_tilde[i, O_i] - mu[O_i]).reshape(-1, 1)
            X_tilde[i, M_i] = mu[M_i] + torch.mm(gain, residual).squeeze(1)

            # Conditional covariance of the missing block, accumulated at the
            # (M_i x M_i) positions via broadcast indexing.
            S_tilde[M_i[:, None], M_i] += S_MM - torch.mm(gain, S_MO.t())

        if mode == 'test':
            return X_tilde.to(device)

        mu, sigma = self._update_params(X_tilde, S_tilde, mu, sigma)
        self.mu = mu.to(device)
        self.sigma = sigma.to(device)

        return X_tilde.to(device), distributions.MultivariateNormal(self.mu, self.sigma)

    def complete_gpu(self, X, C, start=0, mode='train'):
        """Impute missing entries with one batched solve on the target device.

        Per-row blocks of ``sigma`` are copied into padded buffers (identity
        padding for the blocks that are inverted or subtracted from, zero
        padding elsewhere) so that all rows can be handled by batched matrix
        operations.

        Args:
            X: 2-D data tensor (see :meth:`complete`).
            C: Mask with the same shape as ``X``; non-zero marks a missing entry.
            start: Unused; kept for interface compatibility.
            mode: ``'test'`` only imputes; any other value also updates
                ``mu`` and ``sigma``.

        Returns:
            Same as :meth:`complete`.
        """
        device = self._target_device(X)
        X = X.to(device)
        C = C.to(device).type(torch.bool)

        if not self.init:
            self._init_param(X)
            self.init = True

        self.mu = self.mu.to(device)
        self.sigma = self.sigma.to(device)
        mu, sigma = self.mu, self.sigma

        n_features = X.shape[1]
        X_tilde = X.clone()
        S_tilde = torch.zeros(n_features, n_features, dtype=sigma.dtype, device=device)

        # Indices of the rows that have at least one missing entry.
        miss_idx = torch.where(C.any(dim=1))[0]
        n_rows = len(miss_idx)

        if n_rows > 0:
            n_missing = C[miss_idx].sum(dim=1)
            # Buffer sizes: the largest missing block and the largest observed
            # block over all incomplete rows.
            miss_upper = int(n_missing.max())
            obs_upper = n_features - int(n_missing.min())
            opts = dict(dtype=sigma.dtype, device=device)

            S_OO_lst = torch.eye(obs_upper, **opts).repeat(n_rows, 1, 1)
            S_MO_lst = torch.zeros(n_rows, miss_upper, obs_upper, **opts)
            S_MM_lst = torch.eye(miss_upper, **opts).repeat(n_rows, 1, 1)
            X_tilde_O = torch.zeros(n_rows, obs_upper, **opts)
            mu_M = torch.zeros(n_rows, miss_upper, **opts)
            mu_O = torch.zeros(n_rows, obs_upper, **opts)

            missing_cols = []
            for k, i in enumerate(miss_idx):
                M_i = torch.where(C[i])[0]
                O_i = torch.where(~C[i])[0]
                missing_cols.append(M_i)
                n_m, n_o = len(M_i), len(O_i)

                S_MM_lst[k, :n_m, :n_m] = sigma[M_i][:, M_i]
                S_MO_lst[k, :n_m, :n_o] = sigma[M_i][:, O_i]
                S_OO_lst[k, :n_o, :n_o] = sigma[O_i][:, O_i]
                X_tilde_O[k, :n_o] = X_tilde[i, O_i]
                mu_M[k, :n_m] = mu[M_i]
                mu_O[k, :n_o] = mu[O_i]

            gain = torch.matmul(S_MO_lst, torch.inverse(S_OO_lst))
            residual = (X_tilde_O - mu_O).unsqueeze(-1)
            M_tilde = mu_M.unsqueeze(-1) + torch.matmul(gain, residual)
            S_MM_O_lst = S_MM_lst - torch.matmul(gain, S_MO_lst.transpose(1, 2))

            # Scatter the un-padded results back into X_tilde and S_tilde.
            for k, (i, M_i) in enumerate(zip(miss_idx, missing_cols)):
                n_m = len(M_i)
                X_tilde[i, M_i] = M_tilde[k, :n_m, 0]
                S_tilde[M_i[:, None], M_i] += S_MM_O_lst[k, :n_m, :n_m]

        if mode == 'test':
            return X_tilde

        self.mu, self.sigma = self._update_params(X_tilde, S_tilde, mu, sigma)

        return X_tilde, distributions.MultivariateNormal(self.mu, self.sigma)
