import numpy as np
import torch

from tqdm import tqdm


from .learners import G1D_learner, INTRACTABLE_learner, RBM_learner, MNIST_learner
from .utils import train_weighted_KSD 
from .distributions import G1D, INTRACTABLE, RBM, MNIST
from .debiaser import Debiaser
from .utils import weighted_KSD


class experiment_G1D:
    """One-shot experiment driver built around the ``G1D`` sampler.

    Calling an instance draws a sample, computes three sets of debiasing
    weights (AIPW, IPW and PI) and fits a fresh ``G1D_learner`` for each of
    them with ``train_weighted_KSD``.
    """

    def __init__(self, n, n_steps, pi_hat_model="logistic", beta_hat_model="CME", mean = 0., sigma = 1, coef_ax = 1, stdx = 1):
        """Record the run configuration and build the sampler.

        Args:
            n: Sample size passed to ``sampler.sample_Z`` on every call.
            n_steps: Forwarded to ``train_weighted_KSD`` as ``n_steps``.
            pi_hat_model: Forwarded to ``Debiaser`` as ``pi_hat_model``.
            beta_hat_model: Forwarded to ``Debiaser`` as ``beta_hat_model``.
            mean: Forwarded unchanged to ``G1D``.
            sigma: Forwarded unchanged to ``G1D``.
            coef_ax: Forwarded unchanged to ``G1D``.
            stdx: Forwarded unchanged to ``G1D``.
        """
        self.n, self.n_steps = n, n_steps
        self.sampler = G1D(mean=mean, sigma=sigma, coef_ax=coef_ax, stdx=stdx)
        self.pi_hat_model, self.beta_hat_model = pi_hat_model, beta_hat_model

    def __call__(self, random_seed = 0):
        """Run the experiment once.

        Args:
            random_seed: Forwarded to ``sampler.sample_Z``.

        Returns:
            A length-3 numpy array. Entry ``k`` is element 0 of the first
            parameter yielded by the trained learner's ``parameters()``
            (presumably the fitted mean -- confirm against ``G1D_learner``),
            for the AIPW, IPW and PI weights in that order.
        """
        X, A, Y = self.sampler.sample_Z(n = self.n, random_seed = random_seed)

        debiaser = Debiaser(pi_hat_model = self.pi_hat_model, beta_hat_model = self.beta_hat_model)
        # "min_node_size" and "num_trees" for DRF, "gamma" for CME.
        kwargs = {'min_node_size': 15, 'num_trees': 100, 'gamma': 1e-3}

        # All weights are computed before any training starts; the tuple is
        # evaluated left to right, so the order is AIPW, then IPW, then PI.
        weight_sets = (
            debiaser.find_weights_AIPW(X, A, Y, kwargs=kwargs),
            debiaser.find_weights_IPW(X, A, Y),
            debiaser.find_weights_PI(X, A, Y, kwargs=kwargs),
        )

        values = np.zeros(len(weight_sets))
        for slot, weights in enumerate(weight_sets):
            # Every fit starts from an identically initialised learner.
            learner = G1D_learner(mean = torch.Tensor([1.]), var = torch.Tensor([1]))
            train_weighted_KSD(Y, learner, weights, save_out = True, n_steps = self.n_steps, batch = len(Y), lr=1)
            leading_param = next(learner.parameters())
            values[slot] = leading_param.detach().numpy()[0]

        return values



class experiment_INTRACTABLE:
    """One-shot experiment driver built around the ``INTRACTABLE`` sampler.

    Calling an instance draws a sample, computes AIPW, IPW and PI debiasing
    weights and fits a fresh ``INTRACTABLE_learner`` for each weight set with
    ``train_weighted_KSD``.
    """

    def __init__(self, n, n_steps, pi_hat_model="logistic", beta_hat_model="CME", coef=None):
        """Record the run configuration and build the sampler.

        Args:
            n: Sample size passed to ``sampler.sample_Z`` on every call.
            n_steps: Forwarded to ``train_weighted_KSD`` as ``n_steps``.
            pi_hat_model: Forwarded to ``Debiaser`` as ``pi_hat_model``.
            beta_hat_model: Forwarded to ``Debiaser`` as ``beta_hat_model``.
            coef: Forwarded to ``INTRACTABLE`` as ``coef``. When omitted (or
                ``None``) a new ``torch.tensor([1, 1, 1, 1, 1])`` is used.
        """
        if coef is None:
            # Build the default per instance. A tensor literal in the
            # signature is created once at definition time and would be
            # shared (and mutable in place) across all instances.
            coef = torch.tensor([1, 1, 1, 1, 1])
        self.n = n
        self.n_steps = n_steps
        self.sampler = INTRACTABLE(coef=coef)
        self.pi_hat_model = pi_hat_model
        self.beta_hat_model = beta_hat_model

    def __call__(self, random_seed = 0):
        """Run the experiment once.

        Args:
            random_seed: Forwarded to ``sampler.sample_Z``.

        Returns:
            A length-6 numpy array laid out as consecutive pairs for the
            AIPW, IPW and PI weights (in that order). Each pair holds
            element 0 of the first and of the second parameter yielded by
            the trained learner's ``parameters()`` (presumably theta4 and
            theta5 -- confirm against ``INTRACTABLE_learner``).
        """
        n = self.n
        n_steps = self.n_steps

        X, A, Y = self.sampler.sample_Z(n = n, random_seed = random_seed)

        debiaser = Debiaser(pi_hat_model = self.pi_hat_model, beta_hat_model = self.beta_hat_model)
        kwargs = {'min_node_size': 15,'num_trees': 200, 'gamma': 1} # "min_node_size" and "num_trees" for DRF, "gamma" for CME.

        # Weights are all computed before training, in the order AIPW, IPW, PI
        # (tuple elements are evaluated left to right).
        weight_sets = (
            debiaser.find_weights_AIPW(X, A, Y, kwargs=kwargs),
            debiaser.find_weights_IPW(X, A, Y),
            debiaser.find_weights_PI(X, A, Y, kwargs=kwargs),
        )

        values = np.zeros(2 * len(weight_sets))
        for k, weights in enumerate(weight_sets):
            # Each fit starts from an identically initialised learner.
            learner = INTRACTABLE_learner(theta4 = torch.Tensor([1]), theta5 = torch.Tensor([1]))
            train_weighted_KSD(Y, learner, weights, save_out = True, n_steps = n_steps, batch = len(Y), lr=1)
            # Consume the parameter iterator twice: first two parameters only.
            params = learner.parameters()
            values[2 * k] = next(params).detach().numpy()[0]
            values[2 * k + 1] = next(params).detach().numpy()[0]

        return values



class experiment_MNIST:
    """One-shot experiment driver built around the ``MNIST`` sampler.

    Calling an instance draws a sample, computes AIPW, IPW and PI debiasing
    weights and fits a fresh ``MNIST_learner`` for each weight set with
    ``train_weighted_KSD``.
    """

    def __init__(self, n, n_steps, pi_hat_model="logistic", beta_hat_model="CME", digit=0, n_layer=10):
        """Record the run configuration and build the sampler.

        Args:
            n: Sample size passed to ``sampler.sample_Z``; also forwarded to
                ``MNIST_learner`` as ``n``.
            n_steps: Forwarded to ``train_weighted_KSD`` as ``n_steps``.
            pi_hat_model: Forwarded to ``Debiaser`` as ``pi_hat_model``.
            beta_hat_model: Forwarded to ``Debiaser`` as ``beta_hat_model``.
            digit: Forwarded to both ``MNIST`` and ``MNIST_learner``.
            n_layer: Forwarded to both ``MNIST`` and ``MNIST_learner``; also
                the width of each per-estimator slice of the result.
        """
        self.n, self.n_steps = n, n_steps
        self.sampler = MNIST(digit=digit, n_layer=n_layer)
        self.pi_hat_model, self.beta_hat_model = pi_hat_model, beta_hat_model
        self.digit, self.n_layer = digit, n_layer

    def __call__(self, random_seed = 0):
        """Run the experiment once.

        Args:
            random_seed: Forwarded to ``sampler.sample_Z``.

        Returns:
            A numpy array of length ``3 * n_layer``. The three consecutive
            ``n_layer``-wide slices hold the first parameter yielded by the
            trained learner's ``parameters()`` for the AIPW, IPW and PI
            weights, in that order.
        """
        width = self.n_layer

        X, A, Y = self.sampler.sample_Z(n = self.n, random_seed = random_seed)

        debiaser = Debiaser(pi_hat_model = self.pi_hat_model, beta_hat_model = self.beta_hat_model)
        # "min_node_size" and "num_trees" for DRF, "gamma" for CME.
        kwargs = {'min_node_size': 500, 'num_trees': 100, 'gamma': .01}

        # Computed up front and left to right: AIPW, then IPW, then PI.
        weight_sets = (
            debiaser.find_weights_AIPW(X, A, Y, kwargs=kwargs),
            debiaser.find_weights_IPW(X, A, Y),
            debiaser.find_weights_PI(X, A, Y, kwargs=kwargs),
        )

        values = np.zeros(width*3)
        for slot, weights in enumerate(weight_sets):
            # Every fit starts from an all-ones coefficient vector.
            learner = MNIST_learner(coef = torch.ones(self.n_layer), n = self.n, digit=self.digit, n_layer=self.n_layer)
            train_weighted_KSD(Y, learner, weights, save_out = True, n_steps = self.n_steps, batch = len(Y), lr=1)
            fitted = next(learner.parameters()).detach().numpy()
            values[slot * width:(slot + 1) * width] = fitted

        return values


class experiment_RBM:
    """Experiment driver that maps the weighted-KSD loss for the ``RBM`` sampler.

    Calling an instance draws a sample, computes AIPW debiasing weights and
    evaluates ``weighted_KSD`` on a 2-D grid over two entries of the
    learner's ``B`` matrix.
    """

    def __init__(self, n, n_burn, pi_hat_model="logistic", beta_hat_model="CME", dvisible = 5, dhidden = 5):
        """Record the run configuration and build the sampler.

        Args:
            n: Sample size passed to ``sampler.sample_Z`` on every call.
            n_burn: Forwarded to ``sampler.sample_Z`` as ``n_burn``.
            pi_hat_model: Forwarded to ``Debiaser`` as ``pi_hat_model``.
            beta_hat_model: Forwarded to ``Debiaser`` as ``beta_hat_model``.
            dvisible: Forwarded to ``RBM``; also stored on the instance.
            dhidden: Forwarded to ``RBM``; also stored on the instance.
        """
        self.n = n
        self.n_burn = n_burn
        self.sampler = RBM(dvisible=dvisible, dhidden=dhidden)
        self.pi_hat_model = pi_hat_model
        self.beta_hat_model = beta_hat_model
        self.dvisible = dvisible
        self.dhidden = dhidden

    def __call__(self, random_seed = 0):
        """Evaluate the weighted-KSD loss on a grid.

        Args:
            random_seed: Forwarded to ``sampler.sample_Z``.

        Returns:
            A ``(len(grid), len(grid))`` numpy array where ``loss[i, j]`` is
            the ``weighted_KSD`` value with ``B[0, 0] = grid[i]`` and
            ``B[1, 0] = grid[j]``; ``grid`` is ``np.arange(-5, 5, .1)``.
        """
        n = self.n

        X, A, Y, bias_visible, bias_hidden, B = self.sampler.sample_Z(n = n, n_burn=self.n_burn, random_seed = random_seed)

        debiaser = Debiaser(pi_hat_model = self.pi_hat_model, beta_hat_model = self.beta_hat_model)
        kwargs = {'min_node_size': 15,'num_trees': 200, 'gamma': 1} # "min_node_size" and "num_trees" for DRF, "gamma" for CME.
        weights_AIPW = debiaser.find_weights_AIPW(X, A, Y, kwargs=kwargs)

        grid = np.arange(start=-5, stop=5, step=.1)
        loss = np.zeros((len(grid), len(grid)))
        bias_visible_init = torch.Tensor(bias_visible)
        bias_hidden_init = torch.Tensor(bias_hidden)
        # NOTE(review): the sampled B is only used for its shape; every entry
        # other than [0, 0] and [1, 0] stays at zero -- confirm this is intended.
        B_init = torch.zeros(B.shape)
        # Wrap the sized array (not the enumerate object) so tqdm can infer the
        # total and display a percentage/ETA.
        for i, g1 in enumerate(tqdm(grid)):
            for j, g2 in enumerate(grid):
                # B_init is updated in place and reused for every grid point.
                B_init[0, 0] = g1
                B_init[1, 0] = g2

                learner = RBM_learner(B=B_init, bias_visible=bias_visible_init, bias_hidden=bias_hidden_init)
                loss[i, j] = weighted_KSD(samples = Y, score_func=learner.score, weights = weights_AIPW, gamma = .1)

        return loss




