import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import causally.start.diffusion as diff
import causally.start.process_edited as pce
from causally.data.acic import sigmoid, generate_inner

class AbstractTrainer(object):
    """Base class for trainers.

    Stores the experiment ``config``, the ``model`` and an optional ``policy``
    and provides shared helpers: moving batches to ``config['device']``, NaN
    checks on the loss, epoch log formatting, covariate perturbations
    (missing values, measurement noise, hidden covariates) and covariate
    generation.

    Config keys read in this class: ``'device'``, ``'pertubation_fix'`` and
    ``'seed'`` (perturbation helpers), ``'n_steps'`` (diffusion sampling).
    """

    def __init__(self, config, model, policy=None, **kwargs):
        self.config = config
        self.model = model
        self.policy = policy
        self.device = config['device']

    def fit(self, train_data):
        """Train on ``train_data``; must be implemented by subclasses."""
        raise NotImplementedError('Method [fit] should be implemented.')

    def to_device(self, x, t, y, w):
        """Return ``(x, t, y, w)`` with each tensor moved to ``self.device``."""
        return (x.to(self.device), t.to(self.device),
                y.to(self.device), w.to(self.device))

    def _check_nan(self, loss):
        """Abort training if ``loss`` contains a NaN.

        Args:
            loss: loss tensor (scalar or multi-element).

        Returns:
            False when no element of ``loss`` is NaN.

        Raises:
            ValueError: if any element of ``loss`` is NaN.
        """
        # .any() makes this work for per-sample loss tensors as well as scalars
        # (truth-testing a multi-element tensor raises in PyTorch).
        if torch.isnan(loss).any():
            print('Training loss is nan, raise and break!!!!!!!!!!!')
            raise ValueError('Training loss is nan')
        return False

    def _generate_train_loss_output(self, epoch_idx, s_time, e_time, losses):
        """Format a one-line training summary for an epoch.

        Args:
            epoch_idx: epoch number shown in the message.
            s_time, e_time: start and end timestamps; their difference is
                printed in seconds.
            losses: a single loss value, or a tuple of loss values which are
                printed as ``train_loss1``, ``train_loss2``, ...

        Returns:
            The formatted string, e.g.
            ``'epoch 1 training [time: 1.50s, train loss: 0.2500]'``.
        """
        train_loss_output = 'epoch %d training [time: %.2fs, ' % (epoch_idx, e_time - s_time)
        if isinstance(losses, tuple):
            train_loss_output += ', '.join(
                'train_loss%d: %.4f' % (idx + 1, loss) for idx, loss in enumerate(losses)
            )
        else:
            train_loss_output += 'train loss: %.4f' % losses
        return train_loss_output + ']'

    def _perturbation_rng(self):
        """Return the random source used by the perturbation helpers.

        With ``config['pertubation_fix']`` truthy, a fresh ``RandomState``
        seeded with ``config['seed']`` is created on every call, so each
        helper replays the same draws; otherwise the global ``np.random``
        module (same method names) is returned.
        """
        if self.config['pertubation_fix']:
            return np.random.RandomState(self.config['seed'])
        return np.random

    def _get_missing_data(self, covariates, missing_rate):
        """Randomly mask covariate entries and fill them back by imputation.

        Each entry is set to NaN when its uniform draw is ``<= missing_rate``;
        the masked matrix is then filled with scikit-learn's
        ``IterativeImputer(max_iter=5)``. The input tensor is not modified.

        Args:
            covariates: covariate tensor (converted through NumPy).
            missing_rate: threshold on a U[0, 1) draw for masking an entry.

        Returns:
            The imputed covariates as a tensor on ``self.device``.
        """
        # enable_iterative_imputer must be imported (for its side effect)
        # before IterativeImputer can be imported from sklearn.impute.
        from sklearn.experimental import enable_iterative_imputer  # noqa: F401
        from sklearn.impute import IterativeImputer

        # Tensor.numpy() shares memory with the tensor; copy so the NaN mask
        # below does not corrupt the caller's covariates.
        covariates_numpy = covariates.detach().cpu().numpy().copy()

        rng = self._perturbation_rng()
        random_probs = rng.random(size=covariates_numpy.shape)
        covariates_numpy[random_probs <= missing_rate] = np.nan

        # random_state=None is IterativeImputer's default (unseeded path).
        seed = self.config['seed'] if self.config['pertubation_fix'] else None
        imputer = IterativeImputer(random_state=seed, max_iter=5)
        covariates_numpy = imputer.fit_transform(covariates_numpy)

        return torch.tensor(covariates_numpy).to(self.device)

    def _get_measure_data(self, covariates, measure_rate):
        """Draw Gaussian noise with the same shape as ``covariates``.

        Only the noise (mean 0, standard deviation ``measure_rate``) is
        returned; it is not added to ``covariates`` here.

        Returns:
            A float32 tensor on ``self.device``.
        """
        noise = self._perturbation_rng().normal(0, measure_rate, size=covariates.shape)
        return torch.tensor(noise, dtype=torch.float32).to(self.device)

    def _get_hidden_data(self, covariates, hidden_rate):
        """Append hidden covariates and compute a synthetic ``tau_x`` from them.

        ``hidden_rate`` extra columns are drawn from U(-3, 3) and concatenated
        to ``covariates``. ``tau_x`` is a sparse (Bernoulli(0.1) coefficients)
        linear combination of all pairwise products ``X[:, i] * X[:, j]``
        (``i <= j``) divided by 10, plus N(0, 0.1) noise.

        Returns:
            ``tau_x`` of shape ``(covariates.shape[0], 1)`` as a float32
            tensor on ``self.device``.
        """
        rng = self._perturbation_rng()

        hidden_cov = rng.uniform(low=-3, high=3, size=(covariates.shape[0], hidden_rate))
        X = np.concatenate((covariates.detach().cpu().numpy(), hidden_cov), axis=1)

        # The treatment draw is not used in the result, but it advances `rng`;
        # it is kept so tau_x stays identical to earlier runs for a given seed.
        beta_for_T = rng.binomial(1, 0.1, X.shape[1])
        prob_t = sigmoid(x=X, beta_for_T=beta_for_T, xi=1).squeeze()
        rng.binomial(1, prob_t, X.shape[0])

        # Likewise, beta_for_Y is only drawn to keep the RNG stream aligned.
        X_for_Y = generate_inner(X, 2)
        rng.binomial(1, 0.1, X_for_Y.shape[1])

        # All products X[:, i] * X[:, j] with i <= j, in the same column order
        # as a nested i/j loop (np.triu_indices enumerates row by row).
        rows, cols = np.triu_indices(X.shape[1])
        X_for_tau = X[:, rows] * X[:, cols]

        rho = 0.1
        beta_for_tau = rng.binomial(1, rho, X_for_tau.shape[1]).reshape(-1, 1)
        tau_x = np.matmul(X_for_tau, beta_for_tau) / 10 + rng.normal(0, 0.1, (X.shape[0], 1))

        return torch.tensor(tau_x, dtype=torch.float32).to(self.device)

    def generate_cov(self, gene_type, N, P, set_type=None):
        """Generate ``N`` samples of ``P`` covariates.

        Args:
            gene_type: ``'diffusion'`` (``diff.Euler_Maruyama_sampling`` with
                ``self.policy`` and ``config['n_steps']``), ``'random'`` or
                ``'normal'`` (standard normal), or ``'uniform'`` (U[0, 1)).
            N: number of samples.
            P: number of covariates.
            set_type: unused; kept for interface compatibility.

        Returns:
            For the torch generators, an ``(N, P)`` tensor created on the
            default device (not moved to ``self.device``); for
            ``'diffusion'``, whatever the sampler returns.

        Raises:
            ValueError: if ``gene_type`` is not one of the names above.
        """
        if gene_type == 'diffusion':
            return diff.Euler_Maruyama_sampling(self.policy, self.config['n_steps'],
                                                N, P, self.device)
        if gene_type == 'random':
            return torch.randn(N, P)
        if gene_type == 'normal':
            return torch.normal(0, 1, size=(N, P))
        if gene_type == 'uniform':
            return torch.rand(N, P)
        raise ValueError('Unknown gene_type: %r' % (gene_type,))