import os, sys, urllib, warnings, errno, logging, time

import numpy as np
import matplotlib.pyplot as plt

from sklearn.datasets import make_moons, make_circles, make_classification
from sklearn.model_selection import train_test_split

import torch.nn as nn

# Make the Iterative-Alignment-Flows sources (and the destructive-deep-learning
# / ddl code nested inside them) importable. The paths are relative, so the
# same three sub-directories are registered for three possible depths of the
# current working directory; the order below is the lookup order on sys.path.
sys.path.extend([
    # one level up
    '../domain_adaptation/Iterative-Alignment-Flows/',
    '../domain_adaptation/Iterative-Alignment-Flows/destructive-deep-learning/',
    '../domain_adaptation/Iterative-Alignment-Flows/destructive-deep-learning/ddl/',
    # three levels up
    '../../../domain_adaptation/Iterative-Alignment-Flows/',
    '../../../domain_adaptation/Iterative-Alignment-Flows/destructive-deep-learning/',
    '../../../domain_adaptation/Iterative-Alignment-Flows/destructive-deep-learning/ddl/',
    # two levels up
    '../../domain_adaptation/Iterative-Alignment-Flows/',
    '../../domain_adaptation/Iterative-Alignment-Flows/destructive-deep-learning/',
    '../../domain_adaptation/Iterative-Alignment-Flows/destructive-deep-learning/ddl/',
])
# Dump the resulting search path so import problems can be diagnosed from the log.
print(sys.path)
from two_barycenter import *
from util import load_data, plot_images, create_deep_density_cd, create_deep_tree_cd, wrap_transform
from metrics import *
from ddl.base import (CompositeDestructor, DestructorMixin, create_inverse_transformer, 
                      BoundaryWarning, DataConversionWarning, IdentityDestructor)
from ddl.independent import IndependentDensity, IndependentDestructor, IndependentInverseCdf
from ddl.linear import LinearProjector
from ddl.univariate import HistogramUnivariateDensity, ScipyUnivariateDensity
from ddl.deep import DeepDestructor
from ddl.linear import LinearProjector, RandomOrthogonalEstimator, BestLinearReconstructionDestructor
from weakflow import *

# These filters are installed at import time on the global warnings filter
# list, so they stay in effect for everything that runs after this module has
# been imported, not only for the wrappers defined below.
warnings.simplefilter('ignore', BoundaryWarning) # Ignore boundary warnings from ddl
warnings.simplefilter('ignore', DataConversionWarning) # Ignore data conversion warnings from ddl



def inb_wrapper(reference, query, n_layers=3, ndim=30):
    """Impute the NaN entries of ``query`` with an mSWD barycenter flow.

    ``reference`` and a zero-filled copy of ``query`` are stacked into a
    single matrix whose rows are labelled 0 (reference) and 1 (query). A
    ``MSWDBaryClassifierDestructor`` is initialised on that data, extended
    with ``n_layers`` layers through ``add_one_layer(..., 'nb', ndim=ndim)``
    and finalised with its ``end`` method. The stacked data is then mapped
    through the fitted model and inverted with the labels flipped
    (``1 - y``). Only the rows that belong to ``query`` are kept, and every
    entry that was observed (not NaN) in ``query`` is reset to its original
    value, so only the formerly missing entries come from the model.

    Parameters
    ----------
    reference : array
        Rows of the reference dataset. Must be concatenable with ``query``
        along the first axis.
    query : array
        Rows to impute; missing entries are marked with NaN. The input is
        not modified (a copy is zero-filled instead).
    n_layers : int, default 3
        Number of ``add_one_layer`` calls used to build the model.
    ndim : int, default 30
        Forwarded unchanged as the ``ndim`` keyword of ``add_one_layer``.
        NOTE(review): its exact meaning is defined by ``add_one_layer``,
        which is not visible here -- confirm before tuning. The default
        reproduces the value that used to be hard-coded.

    Returns
    -------
    array
        The rows of the inverse-mapped data that correspond to ``query``,
        with the observed entries of ``query`` restored.

    Notes
    -----
    Prints one progress line per layer and the total fitting time.
    """
    n_reference = reference.shape[0]
    # Locate the missing entries once; the mask is needed both for the
    # zero-fill below and for restoring the observed values at the end.
    missing = np.isnan(query)

    # pre-processing
    query_0 = query.copy()
    query_0[missing] = 0
    X = np.concatenate([reference, query_0])
    y = np.concatenate([np.zeros(n_reference), np.ones(query_0.shape[0])])

    start = time.time()

    cd_swd_nb = MSWDBaryClassifierDestructor()
    # add an inverse normal CDF
    Z = cd_swd_nb.initialize(X, y)

    # add mSWD-bayes layers
    for i in range(n_layers):
        print('log:', i)
        cd_swd_nb, Z = add_one_layer(cd_swd_nb, Z, y, 'nb', ndim=ndim)

    # add a normal CDF
    Z = cd_swd_nb.end(Z, y)
    # The training-time representation is not used again; release it before
    # the inference pass allocates its own arrays.
    del Z
    print(f'fitting time: {time.time()-start} s')

    # Inference: map everything forward, then invert with the domain labels
    # swapped so that query rows are decoded under the reference label.
    Z_temp = cd_swd_nb(X, y)
    Xflip_temp = cd_swd_nb.inverse(Z_temp, 1 - y)

    # post-processing: keep the query rows and put the observed values back
    query_imputed = Xflip_temp[n_reference:, :]
    observed = ~missing
    query_imputed[observed] = query[observed]

    return query_imputed



def dd_wrapper(reference, query, n_canonical_destructors=2):
    """Impute the NaN entries of ``query`` with a deep-density model.

    The reference rows (label 0) and a zero-filled copy of the query rows
    (label 1) are stacked and used to fit the object returned by
    ``create_deep_density_cd``. The stacked data is then transformed and
    inverse-transformed with the labels flipped (``1 - y``). The rows that
    belong to ``query`` are returned, with every entry that was not NaN in
    ``query`` reset to its original value.

    Parameters
    ----------
    reference : array
        Rows of the reference dataset. Must be concatenable with ``query``
        along the first axis.
    query : array
        Rows to impute; missing entries are marked with NaN. The input is
        not modified (a copy is zero-filled instead).
    n_canonical_destructors : int, default 2
        Forwarded unchanged to ``create_deep_density_cd``.

    Returns
    -------
    array
        The rows of the inverse-transformed data that correspond to
        ``query``, with the observed entries of ``query`` restored.

    Notes
    -----
    Prints the fitting time. All warnings are suppressed while fitting.
    """
    n_reference = reference.shape[0]

    # Zero-fill a copy of the query so the caller's data stays untouched.
    filled_query = query.copy()
    filled_query[np.isnan(filled_query)] = 0

    # Stack both datasets; the label tells the model which rows are which.
    X = np.concatenate([reference, filled_query])
    y = np.concatenate([np.zeros(n_reference), np.ones(filled_query.shape[0])])

    start = time.time()
    cd = create_deep_density_cd(n_canonical_destructors=n_canonical_destructors)
    with warnings.catch_warnings():
        # Ignore every warning raised while fitting (boundary warnings etc.).
        warnings.simplefilter('ignore')
        cd.fit_transform(X, y)
    print(f'fitting time: {time.time()-start} s')

    # Forward pass with the true labels, inverse pass with swapped labels.
    flipped = cd.inverse_transform(cd.transform(X, y), 1 - y)

    # Keep only the query rows and restore the values that were observed.
    query_imputed = flipped[n_reference:, :]
    observed = ~np.isnan(query)
    query_imputed[observed] = query[observed]

    return query_imputed