import torch
import torch.nn as nn
import torch.nn.functional as F 
import torch.distributions as distribution
import math
import numpy as np
import time

from nde.FM import NFM


class FMGGC(nn.Sequential):
    """
    Flow Matching Based Generalized Gaussian copula.

    Two separate flows are used: ``maf1`` transforms ``x`` and ``maf2``
    transforms ``y``. After both flows are trained, Gaussians are fitted to
    the transformed samples (see :meth:`learn`).

    Attributes populated by :meth:`learn`:
        mu, V:      empirical mean / covariance of the concatenated
                    transformed samples ``[f(x), g(y)]``.
        mu2, V2:    copy of ``mu`` paired with an identity covariance.
        mx, Vx:     leading ``d`` entries of ``mu`` / top-left block of ``V``.
        my, Vy:     remaining entries of ``mu`` / bottom-right block of ``V``.
        normal, normal2, normal_x, normal_y:
                    ``MultivariateNormal`` objects built from the pairs above.
    """

    def __init__(self, n_inputs):
        super().__init__()
        self.maf1 = NFM(n_inputs)
        self.maf2 = NFM(n_inputs)

    def forward(self, x, y):
        """Map ``x`` through ``maf1`` and ``y`` through ``maf2``.

        Only the first element returned by each flow's ``forward`` is kept;
        the second element is discarded.
        """
        xx, _ = self.maf1.forward(x)
        yy, _ = self.maf2.forward(y)
        return xx, yy

    def learn(self, x, y):
        """Train both flows, then fit the Gaussians on the transformed data.

        Args:
            x: 2-D tensor of samples, one row per sample.
            y: 2-D tensor of samples, one row per sample.

        Returns:
            None. Results are stored on ``self`` (see the class docstring).
        """
        _, d = x.size()
        # learn f, g
        self.maf1.learn(x)
        self.maf2.learn(y)
        with torch.no_grad():
            xx, yy = self.forward(x, y)
        # learn the inner Gaussian
        self.mu, self.V = self.empirical_params(xx, yy)
        # Build the identity covariance from V itself so that its size, dtype
        # and device always agree with mu2 (a default-dtype identity breaks
        # log_prob when the data is e.g. float64).
        self.mu2 = self.mu.clone()
        self.V2 = torch.eye(self.V.size(0), dtype=self.V.dtype, device=self.V.device)
        self.Vx, self.mx = self.V[0:d, 0:d], self.mu[0:d]
        self.Vy, self.my = self.V[d:, d:], self.mu[d:]
        mvn = distribution.multivariate_normal.MultivariateNormal
        self.normal = mvn(self.mu, self.V)
        self.normal2 = mvn(self.mu2, self.V2)
        self.normal_x = mvn(self.mx, self.Vx)
        self.normal_y = mvn(self.my, self.Vy)
        return

    def empirical_params(self, x, y):
        """Return the empirical mean and covariance of ``[x, y]``.

        Args:
            x, y: 2-D tensors with the same number of rows.

        Returns:
            ``(mu, V)`` where ``mu`` is the flattened column-wise mean of the
            concatenation and ``V`` is the centred scatter matrix divided by
            ``n + 1`` (``n`` = number of rows).
        """
        z = torch.cat([x, y], dim=1)
        n = z.size(0)
        mu = z.mean(dim=0, keepdim=True)
        centered = z - mu
        # NOTE(review): the normaliser is n + 1 (not n or n - 1) -- kept as in
        # the original; confirm this is the intended estimator.
        V = centered.t() @ centered / (n + 1)
        return mu.view(-1), V

    def print(self):
        """Print the fitted mean and the covariance truncated to 2 decimals."""
        print('mu=', self.mu)
        print('V=',  (self.V*100).int()/100.0)

        
        
        
        
        
        
        
class FMG(nn.Sequential):
    """
    Flow Matching Based Gaussian (only used for ablation study).

    A single flow ``maf`` transforms the concatenation ``[x, y]``; Gaussians
    are then fitted to the transformed samples (see :meth:`learn`).

    Attributes populated by :meth:`learn`:
        mu, V:      empirical mean / covariance of the transformed samples.
        mu2, V2:    copy of ``mu`` paired with an identity covariance.
        mx, Vx:     leading ``d`` entries of ``mu`` / top-left block of ``V``,
                    where ``d`` is the width of ``x``.
        my, Vy:     remaining entries of ``mu`` / bottom-right block of ``V``.
        normal, normal2, normal_x, normal_y:
                    ``MultivariateNormal`` objects built from the pairs above.
    """

    def __init__(self, n_inputs):
        super().__init__()
        self.maf = NFM(n_inputs)

    def forward(self, x, y):
        """Map ``[x, y]`` through the flow and split the result back.

        The output is cut at column ``d`` (the width of ``x``). Only the first
        element returned by the flow's ``forward`` is used.
        """
        _, d = x.size()
        z = torch.cat([x, y], dim=1)
        zz, _ = self.maf.forward(z)
        xx, yy = zz[:, 0:d], zz[:, d:]
        return xx, yy

    def learn(self, x, y):
        """Train the joint flow, then fit the Gaussians on the transformed data.

        Args:
            x: 2-D tensor of samples, one row per sample.
            y: 2-D tensor of samples with the same number of rows.

        Returns:
            None. Results are stored on ``self`` (see the class docstring).
        """
        _, d = x.size()
        z = torch.cat([x, y], dim=1)
        # learn f, g
        self.maf.learn(z)
        with torch.no_grad():
            xx, yy = self.forward(x, y)
        # learn the inner Gaussian
        self.mu, self.V = self.empirical_params(xx, yy)
        # Size the identity covariance from V rather than 2 * d: the joint
        # dimension is width(x) + width(y), which differs from 2 * d when the
        # two inputs have different widths. Matching V's dtype/device also
        # keeps normal2 usable for non-default dtypes such as float64.
        self.mu2 = self.mu.clone()
        self.V2 = torch.eye(self.V.size(0), dtype=self.V.dtype, device=self.V.device)
        self.Vx, self.mx = self.V[0:d, 0:d], self.mu[0:d]
        self.Vy, self.my = self.V[d:, d:], self.mu[d:]
        mvn = distribution.multivariate_normal.MultivariateNormal
        self.normal = mvn(self.mu, self.V)
        self.normal2 = mvn(self.mu2, self.V2)
        self.normal_x = mvn(self.mx, self.Vx)
        self.normal_y = mvn(self.my, self.Vy)
        return

    def empirical_params(self, x, y):
        """Return the empirical mean and covariance of ``[x, y]``.

        Args:
            x, y: 2-D tensors with the same number of rows.

        Returns:
            ``(mu, V)`` where ``mu`` is the flattened column-wise mean of the
            concatenation and ``V`` is the centred scatter matrix divided by
            ``n + 1`` (``n`` = number of rows).
        """
        z = torch.cat([x, y], dim=1)
        n = z.size(0)
        mu = z.mean(dim=0, keepdim=True)
        centered = z - mu
        # NOTE(review): the normaliser is n + 1 (not n or n - 1) -- kept as in
        # the original; confirm this is the intended estimator.
        V = centered.t() @ centered / (n + 1)
        return mu.view(-1), V

    def print(self):
        """Print the fitted mean and the covariance truncated to 2 decimals."""
        print('mu=', self.mu)
        print('V=',  (self.V*100).int()/100.0)