import os
import pandas as pd
import torch
import torch.optim as optim
from torch.nn.utils.clip_grad import clip_grad_norm_
import numpy as np
import matplotlib.pyplot as plt
import random
from time import time
from logging import getLogger
import torch.nn as nn

from causally.start import TabDDPMdiff
from causally.utils.arguments import Torch_models
from causally.utils.utils import ensure_dir, get_local_time, early_stopping
from causally.trainer.AbstractTrainer import AbstractTrainer
import causally.start.diffusion as diff
import causally.start.process_edited as pce

import copy
from causally.start.TabDDPMdiff import loss_fn

class CARD_trainer(AbstractTrainer):

    def __init__(self, config, model, policy, ds, real_df, writer):
        """Store data handles and hyper-parameters, then build both optimizers.

        Args:
            config: Run configuration indexed by string key. It is forwarded
                to the base-class constructor and afterwards also read via
                ``self.config`` (presumably assigned there -- confirm in
                ``AbstractTrainer``).
            model: Forwarded to the base-class constructor.
            policy: Forwarded to the base-class constructor. A deep copy of
                ``self.policy`` is kept as ``self.pretrain_policy``.
            ds: Indexable bundle; items 4, 0, 5 and 1 are stored as
                ``encoder``, ``decoder``, ``parser`` and ``latent_features``.
            real_df: Stored unchanged as ``self.real_df``.
            writer: Stored as ``self.writer``.
        """
        super().__init__(config, model, policy)

        # Handles unpacked from the dataset bundle.
        self.ds = ds
        self.encoder = ds[4]
        self.decoder = ds[0]
        self.parser = ds[5]
        self.latent_features = ds[1]
        self.writer = writer
        self.real_df = real_df

        # Logging and optimizer settings for the inference model.
        self.logger = getLogger(config['logfilename'])
        self.learner = config['optimizer']
        self.learning_rate = config['learning_rate']

        # Training schedule; evaluation can never be rarer than the run length.
        self.epochs = config['epochs']
        self.eval_step = min(config['eval_step'], self.epochs)
        self.stopping_step = config['stopping_step']

        self.clip_grad_norm = config['clip_grad_norm']
        self.valid_metric_bigger = config['valid_metric_bigger']

        self.test_batch_size = config['eval_batch_size']
        self.device = config['device']

        # Checkpoint location: <checkpoint_dir>/RL-<algo>-<order>-<model>-<dataset>-<exp_info>.pth
        self.checkpoint_dir = config['checkpoint_dir']
        print(self.checkpoint_dir)
        ensure_dir(self.checkpoint_dir)
        cfg = self.config
        self.saved_name = (
            f"RL-{cfg['v_rl_algo']}-{cfg['start_order']}-"
            f"{cfg['model']}-{cfg['dataset']}-{cfg['exp_info']}"
        )
        self.saved_model_file = os.path.join(self.checkpoint_dir, self.saved_name + '.pth')

        # Early-stopping bookkeeping.
        self.start_epoch = 0
        self.cur_step = 0
        self.best_valid_score = -1000000

        self.train_loss_dict = dict()  # key: epoch, value: loss
        # The inference-model optimizer only exists for models listed in Torch_models.
        if self.config['model'] in Torch_models:
            self.optimizer = self._build_optimizer()

        # Optimizer for the diffusion policy.
        self.df_learner = self.config['df_optimizer']
        self.df_learning_rate = self.config['df_learning_rate']
        self.df_optimizer = self._build_optimizer_df()

        # Step counters; all start at zero.
        self.writer_global_step = 0
        self.writer_global_step_out_infer = 0
        self.writer_global_step_out_policy = 0
        self.writer_global_step_single_RL = 0
        self.writer_global_step_single_SL = 0

        # Snapshot of the policy as it is at construction time.
        self.pretrain_policy = copy.deepcopy(self.policy)

    def _build_optimizer(self):
        if self.learner.lower() == 'adam':
            optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)
        elif self.learner.lower() == 'sgd':
            optimizer = optim.SGD(self.model.parameters(), lr=self.learning_rate)
        elif self.learner.lower() == 'adagrad':
            optimizer = optim.Adagrad(self.model.parameters(), lr=self.learning_rate)
        elif self.learner.lower() == 'rmsprop':
            optimizer = optim.RMSprop(self.model.parameters(), lr=self.learning_rate)
        else:
            self.logger.warning('Received unrecognized optimizer, set default Adam optimizer')
            optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)
        return optimizer
    
    def _build_optimizer_df(self):
        if self.df_learner.lower() == 'adam':
            optimizer = optim.AdamW(self.policy.parameters(), lr=self.df_learning_rate)     
        elif self.df_learner.lower() == 'sgd':
            optimizer = optim.SGD(self.policy.parameters(), lr=self.df_learning_rate)
        elif self.df_learner.lower() == 'adagrad':
            optimizer = optim.Adagrad(self.policy.parameters(), lr=self.df_learning_rate)
        elif self.df_learner.lower() == 'rmsprop':
            optimizer = optim.RMSprop(self.policy.parameters(), lr=self.df_learning_rate)
        else:
            self.logger.warning('Received unrecognized optimizer, set default Adam optimizer')
            optimizer = optim.Adam(self.policy.parameters(), lr=self.df_learning_rate)
        return optimizer

    def _train_in_batch(self,train_data):
        total_loss = None
        total_policy_loss = None
        total_reward = None

        for batch_idx, (x,t,y,w) in enumerate(train_data):
            x,t,y,w = self.to_device(x,t,y,w)
            loss = self._inference_training_for_batch(x,t,y,w)
            reward, policy_loss = self._policy_training_for_batch(x,t,y,w, loss)

            total_loss = loss if total_loss is None else total_loss + loss
            total_reward = reward if total_reward is None else total_reward + reward
            total_policy_loss = policy_loss if total_policy_loss is None else total_policy_loss + policy_loss

        return total_loss, total_policy_loss, total_reward

    def _train_out_batch(self, train_data, epoch_idx):
        total_loss = None
        total_policy_loss = None
        total_reward = None

        for batch_idx, (x,t,y,w) in enumerate(train_data):
            x,t,y,w = self.to_device(x,t,y,w)
            loss = self._inference_training_for_batch(x,t,y,w)
            total_loss = loss if total_loss is None else total_loss + loss
 
        if epoch_idx % self.config['policy_K_for_train'] == 0: 
            for batch_idx, (x,t,y,w) in enumerate(train_data):
                x,t,y,w = self.to_device(x,t,y,w)
                reward, policy_loss = self._policy_training_for_batch(x,t,y,w,loss)
                total_reward = reward if total_reward is None else total_reward + reward
                total_policy_loss = policy_loss if total_policy_loss is None else total_policy_loss + policy_loss

        return total_loss, total_policy_loss, total_reward

    def _inference_training_for_batch(self, x, t, y, w):
        """Run one optimizer step of the inference model on a single batch.

        Extra covariates are sampled from the policy with
        ``diff.Euler_Maruyama_sampling`` and concatenated to ``x`` along the
        last dimension before the loss is computed.

        Args:
            x: Batch covariates; ``x.shape[0]`` and ``x.shape[1]`` are passed
                to the sampler.
            t, y, w: Passed through to ``self.model.calculate_loss``.

        Returns:
            The batch loss as a Python number (``loss.item()``).
        """
        self.model.train()
        compute_loss = self.model.calculate_loss
        n_rows, n_cols = x.shape[0], x.shape[1]
        sampled = diff.Euler_Maruyama_sampling(self.policy, self.config['n_steps'], n_rows, n_cols, self.device)

        model_input = torch.cat([x, sampled], dim=-1)
        self.optimizer.zero_grad()
        batch_loss = compute_loss(model_input, t, y, w)
        self._check_nan(batch_loss)
        batch_loss.backward()
        # A falsy clip_grad_norm disables gradient clipping.
        if self.clip_grad_norm:
            clip_grad_norm_(self.model.parameters(), self.clip_grad_norm)
        self.optimizer.step()
        return batch_loss.item()
        
    def _policy_training_for_batch(self, x, t, y, w, inference_loss):
        self.policy.train()
        if self.config['v_rl_algo'] == 'pg':
            reward, policy_loss = self._policy_gradient(x, t, y, w, inference_loss)

        return reward, policy_loss


    def _policy_gradient(self, x, t, y, w, inference_loss):
        """Perform one policy-gradient step on the diffusion policy.

        A trajectory is sampled with ``diff.Euler_Maruyama_sampling_for_RL``.
        Each intermediate state is scored by ``self.model.get_reward``, the
        per-step rewards are turned into discounted returns (gamma = 0.99),
        and the returns are standardised per diffusion step. The loss combines
        the policy-gradient term with ``loss_fn`` evaluated on a random draw
        (with replacement) of rows from ``self.latent_features``.

        Args:
            x: Batch covariates; each sampled state is concatenated to it
                along the last dimension.
            t, y, w: Passed through to ``self.model.get_reward``.
            inference_loss: Accepted for interface compatibility; not used.

        Returns:
            ``(total_reward, policy_loss)`` as Python numbers: the sum of all
            per-step rewards and the policy-gradient loss.
        """
        # The sampler returns five values; only the per-step states and
        # log-probabilities are needed for the update.
        _, x_seqs, log_probs, _, _ = diff.Euler_Maruyama_sampling_for_RL(
            self.policy, self.pretrain_policy, self.config['n_steps'],
            x.shape[0], x.shape[1], self.device)

        # One flattened reward vector per diffusion step.
        reward = [
            self.model.get_reward(torch.cat([x, x_step], dim=-1), t, y, w).reshape(-1)
            for x_step in x_seqs
        ]

        # Discounted return G_k = r_k + gamma * G_{k+1}, built back-to-front
        # and then reversed into step order.
        gamma = 0.99
        returns = []
        discounted_reward = 0
        for r in reversed(reward):
            discounted_reward = r.squeeze() + gamma * discounted_reward
            returns.append(discounted_reward)
        returns.reverse()

        # Standardise across dim 1 separately for every diffusion step.
        returns = torch.stack(returns).to(self.device)
        step_mean = returns.mean(dim=1).reshape(-1, 1)
        step_std = returns.std(dim=1).reshape(-1, 1)
        returns_normalized = (returns - step_mean) / (step_std + 1e-8)

        # Loss of the policy on randomly chosen latent rows (with replacement).
        batch_idx = random.choices(range(self.latent_features.shape[0]), k=x.shape[0])
        batch_X = self.latent_features[batch_idx, :]
        loss_values = loss_fn(self.policy, batch_X, self.config['n_steps'], self.config['autodiff_eps'])
        mse_loss = torch.mean(loss_values)

        log_probs = torch.stack(log_probs).mean(dim=-1)
        policy_loss = -(log_probs * returns_normalized).mean()
        # NOTE(review): the pretraining term is *subtracted*, which pushes
        # mse_loss up during minimisation -- confirm this sign is intended.
        sum_loss = policy_loss - self.config['pretrain_ratio'] * mse_loss

        self.df_optimizer.zero_grad()
        self._check_nan(sum_loss)
        sum_loss.backward()
        if self.clip_grad_norm:
            clip_grad_norm_(self.policy.parameters(), self.clip_grad_norm)
        self.df_optimizer.step()

        return torch.stack(reward).sum().item(), policy_loss.item()

    def _train_epoch(self, train_data, epoch_idx, loss_func=None):
        self.policy.train()
        self.model.train()
        n_samples = train_data.get_X_size()[0]
        # loss_func = loss_func or self.model.calculate_loss
        if self.config['v_batch_type'] == 'in_batch':
            total_loss, total_policy_loss, total_reward = self._train_in_batch(train_data)
        elif self.config['v_batch_type'] == 'out_batch':
            total_loss, total_policy_loss, total_reward = self._train_out_batch(train_data, epoch_idx)

        avg_inference_loss = total_loss / n_samples
        if self.config['is_record_tensorboard']:
            self.writer.add_scalar('order_{}/Epoch/Inference Loss'.format(self.config['start_order']), avg_inference_loss, epoch_idx)
        return avg_inference_loss

    def _save_checkpoint(self, epoch):
        """Write the current training state to ``self.saved_model_file``.

        The saved dict holds the config, the epoch and early-stopping
        counters, and for both the inference model and the policy: the state
        dict, the optimizer state dict and the module object itself.

        Args:
            epoch: Epoch index recorded in the checkpoint.
        """
        progress = {
            'config': self.config,
            'epoch': epoch,
            'cur_step': self.cur_step,
            'best_valid_score': self.best_valid_score,
        }
        inference_part = {
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'model': self.model,
        }
        policy_part = {
            'policy_state_dict': self.policy.state_dict(),
            'policy_optimizer': self.df_optimizer.state_dict(),
            'policy_model': self.policy,
        }
        torch.save({**progress, **inference_part, **policy_part}, self.saved_model_file)


    def fit(self, train_data, valid_data=None, my_val_data=None, test_treated_data=None, test_control_data=None,
            verbose=True, saved=True):
        """Train with periodic validation, checkpointing and early stopping.

        Args:
            train_data: Training batches, passed to ``_train_epoch``.
            valid_data: Validation data passed to ``_valid_epoch``. When it is
                falsy (or ``eval_step <= 0``) no validation is run and a
                checkpoint is written after every epoch instead.
            my_val_data: Second validation source passed to ``_valid_epoch``.
            test_treated_data: Accepted for interface compatibility; not used.
            test_control_data: Accepted for interface compatibility; not used.
            verbose: Whether progress messages are logged.
            saved: Whether checkpoints are written.

        Returns:
            ``-self.best_valid_score`` after the loop ends.

        Raises:
            ValueError: If validation is needed and
                ``config['valid_metric_criterion']`` is neither ``'factual'``
                nor ``'pehe'``.
        """
        for epoch_idx in range(self.start_epoch, self.epochs):
            # ---- train ----
            training_start_time = time()
            train_loss = self._train_epoch(train_data, epoch_idx)
            self.train_loss_dict[epoch_idx] = sum(train_loss) if isinstance(train_loss, tuple) else train_loss
            training_end_time = time()
            train_loss_output = self._generate_train_loss_output(
                epoch_idx, training_start_time, training_end_time, train_loss)
            if verbose:
                self.logger.info(train_loss_output)

            # ---- no validation configured: checkpoint every epoch ----
            if self.eval_step <= 0 or not valid_data:
                if saved:
                    self._save_checkpoint(epoch_idx)
                    update_output = 'Saving current: %s' % self.saved_model_file
                    if verbose:
                        self.logger.info(update_output)
                continue

            # Only every ``eval_step``-th epoch is evaluated.
            if (epoch_idx + 1) % self.eval_step != 0:
                continue

            if self.config['model'] in ('XNet', 'RNet', 'DRNet'):
                # These models skip validation: always checkpoint, and stop
                # once the configured number of epochs has passed.
                update_flag = True
                stop_flag = epoch_idx > self.config['no_validation_epoch']
            else:
                criterion_name = self.config['valid_metric_criterion']
                if criterion_name not in ('factual', 'pehe'):
                    # Fail before the validation pass rather than with an
                    # unbound local after it.
                    raise ValueError(
                        "Unsupported valid_metric_criterion {!r}; expected 'factual' or 'pehe'".format(
                            criterion_name)
                    )
                valid_start_time = time()
                ate, pehe, yf_score = self._valid_epoch(valid_data, my_val_data)
                # Both metrics are errors; negating them makes larger better.
                valid_score = -yf_score if criterion_name == 'factual' else -pehe
                self.best_valid_score, self.cur_step, stop_flag, update_flag = early_stopping(
                    valid_score, self.best_valid_score, self.cur_step,
                    max_step=self.stopping_step, bigger=self.valid_metric_bigger)
                valid_end_time = time()
                valid_score_output = "epoch %d evaluating [time: %.2fs, valid_score: %f]" % \
                                     (epoch_idx, valid_end_time - valid_start_time, -valid_score)
                if verbose:
                    self.logger.info(valid_score_output)

            if update_flag and saved:
                self._save_checkpoint(epoch_idx)
                update_output = 'Saving current best: %s' % self.saved_model_file
                if verbose:
                    self.logger.info(update_output)

            if stop_flag:
                stop_output = 'Finished training, best eval result in epoch %d' % \
                              (epoch_idx - self.cur_step * self.eval_step)
                if verbose:
                    self.logger.info(stop_output)
                break

        return -self.best_valid_score


    @torch.no_grad()
    def _valid_epoch(self, val_data, my_val_data):
        self.policy.eval()
        self.model.eval()

        treated_data, control_data = val_data['treated'], val_data['control']
        n_samples = treated_data.get_X_size()[0]
        
        criterion = nn.MSELoss(reduction='sum')
        pehe2, true_mean_ate, pred_mean_ate = None, None, None

        for batch_treat, batch_control, batch_my_val in zip(treated_data, control_data, my_val_data):
            ground_truth = batch_treat[2].to(self.device) - batch_control[2].to(self.device)
            treat_covariates = batch_treat[0].to(self.device)
            control_covariates = batch_control[0].to(self.device)

            true_treatment = batch_my_val[1].to(self.device)
            true_yf = batch_my_val[2].to(self.device)

            cov = self.generate_cov(gene_type='diffusion',N=treat_covariates.shape[0], P=treat_covariates.shape[1])
            treat_covariates = torch.cat([treat_covariates, cov], dim=-1) 
            control_covariates = torch.cat([control_covariates, cov], dim=-1)

            preds = self.model(treat_covariates, batch_treat[1].to(self.device)) \
                    - self.model(control_covariates, batch_control[1].to(self.device))

            pred_mu0 = self.model(control_covariates, batch_control[1].to(self.device))
            pred_mu1 = self.model(treat_covariates, batch_treat[1].to(self.device))
            pred_yf = torch.where(true_treatment == 1, pred_mu1, pred_mu0)
            yf_score = criterion(pred_yf, true_yf)

            hat_ite = criterion(preds, ground_truth)
            pehe2 = hat_ite if pehe2 is None else pehe2 + hat_ite

            hat_ate = sum(preds)
            pred_mean_ate = hat_ate if pred_mean_ate is None else pred_mean_ate + hat_ate

            true_ate = sum(ground_truth)
            true_mean_ate = true_ate if true_mean_ate is None else true_mean_ate + true_ate

        pehe2 = pehe2 / n_samples
        pred_mean_ate = pred_mean_ate / n_samples
        true_mean_ate = true_mean_ate / n_samples
        yf_score = yf_score / n_samples

        ate = torch.abs(pred_mean_ate - true_mean_ate).item()
        pehe = np.sqrt(pehe2.item())

        return ate, pehe, yf_score

    @torch.no_grad()
    def evaluate(self, treated_data,control_data, load_best_model=True, model_file=None):
        if load_best_model:
            if model_file:
                checkpoint_file = model_file
            else:
                checkpoint_file = self.saved_model_file
            checkpoint = torch.load(checkpoint_file, weights_only=False)
            self.model.load_state_dict(checkpoint['state_dict'])
            self.policy.load_state_dict(checkpoint['policy_state_dict'])
            message_output = 'Loading model structure and parameters from {}'.format(checkpoint_file)
            self.logger.info(message_output)
        self.model.eval()
        self.policy.eval()

        n_samples = treated_data.get_X_size()[0]

        criterion = nn.MSELoss(reduction='sum')
        pehe2,true_mean_ate,pred_mean_ate = None,None,None

        for batch_treat,batch_control in zip(treated_data, control_data):
            ground_truth = batch_treat[2].to(self.device) - batch_control[2].to(self.device)

            treat_covariates = batch_treat[0].to(self.device)     
            control_covariates = batch_control[0].to(self.device) 
            cov = self.generate_cov(gene_type='diffusion',N=treat_covariates.shape[0], P=treat_covariates.shape[1])
            treat_covariates = torch.cat([treat_covariates, cov], dim=-1)
            control_covariates = torch.cat([control_covariates, cov], dim=-1)

            if self.config['model'] == 'XNet':
                y1, y0, ps = self.model(treat_covariates, batch_treat[1].to(self.device))
                preds = ps * y0 + (1 - ps) * y1
            elif self.config['model'] == 'RNet' or self.config['model'] == 'DRNet':
                preds = self.model(treat_covariates, batch_treat[1].to(self.device))
            else:
                preds = self.model(treat_covariates, batch_treat[1].to(self.device)) \
                        - self.model(control_covariates, batch_control[1].to(self.device))

            hat_ite = criterion(preds.reshape(-1),ground_truth.reshape(-1))
            pehe2 = hat_ite if pehe2 is None else pehe2+hat_ite

            hat_ate = sum(preds)
            pred_mean_ate = hat_ate if pred_mean_ate is None else pred_mean_ate+hat_ate
            true_ate = sum(ground_truth)
            true_mean_ate = true_ate if true_mean_ate is None else true_mean_ate + true_ate


        pehe2 = pehe2 / n_samples
        pred_mean_ate = pred_mean_ate / n_samples
        true_mean_ate = true_mean_ate / n_samples

        ate = torch.abs(pred_mean_ate-true_mean_ate).item()
        pehe = np.sqrt(pehe2.item())

        return {'pehe': pehe, 'ate': ate}

    @torch.no_grad()
    def adversarial_measure_evaluate(self, treated_data, control_data, measure_ratio, load_best_model=True, model_file=None):
        if load_best_model:
            if model_file:
                checkpoint_file = model_file
            else:
                checkpoint_file = self.saved_model_file
            checkpoint = torch.load(checkpoint_file,weights_only=False)#,weights_only=True)
            self.model.load_state_dict(checkpoint['state_dict'])
            self.policy.load_state_dict(checkpoint['policy_state_dict'])
            message_output = 'Loading model structure and parameters from {}'.format(checkpoint_file)
            self.logger.info(message_output)
        self.model.eval()
        self.policy.eval()

        n_samples = treated_data.get_X_size()[0]

        criterion = nn.MSELoss(reduction='sum')
        pehe2,true_mean_ate,pred_mean_ate = None,None,None

        res = {}
        measure = measure_ratio
        for mea in measure:
            for batch_treat,batch_control in zip(treated_data, control_data):
                ground_truth = batch_treat[2].to(self.device) - batch_control[2].to(self.device)

                covariates_perturbed = self._get_measure_data(batch_treat[0], mea)
                treat_covariates = batch_treat[0].to(self.device)  + covariates_perturbed
                control_covariates = batch_control[0].to(self.device) + covariates_perturbed

                cov = self.generate_cov(gene_type='diffusion',N=treat_covariates.shape[0], P=treat_covariates.shape[1])
                treat_covariates = torch.cat([treat_covariates, cov], dim=-1)
                control_covariates = torch.cat([control_covariates, cov], dim=-1)

                if self.config['model'] == 'XNet':
                    y1, y0, ps = self.model(treat_covariates, batch_treat[1].to(self.device))
                    preds = ps * y0 + (1 - ps) * y1
                elif self.config['model'] == 'RNet' or self.config['model'] == 'DRNet':
                    preds = self.model(treat_covariates, batch_treat[1].to(self.device))
                else:
                    preds = self.model(treat_covariates, batch_treat[1].to(self.device)) \
                            - self.model(control_covariates, batch_control[1].to(self.device))

                hat_ite = criterion(preds.reshape(-1),ground_truth.reshape(-1))
                pehe2 = hat_ite if pehe2 is None else pehe2+hat_ite

            hat_ate = sum(preds)
            pred_mean_ate = hat_ate if pred_mean_ate is None else pred_mean_ate+hat_ate
            true_ate = sum(ground_truth)
            true_mean_ate = true_ate if true_mean_ate is None else true_mean_ate + true_ate

            pehe2 = pehe2 / n_samples
            pred_mean_ate = pred_mean_ate / n_samples
            true_mean_ate = true_mean_ate / n_samples

            ate = torch.abs(pred_mean_ate-true_mean_ate).item()
            pehe = np.sqrt(pehe2.item())
    
            res['ate_{}'.format(mea)] = ate
            res['pehe_{}'.format(mea)] = pehe

        return res

    @torch.no_grad()
    def adversarial_missing_evaluate(self, treated_data, control_data, missing_ratio, load_best_model=True, model_file=None):
        if load_best_model:
            if model_file:
                checkpoint_file = model_file
            else:
                checkpoint_file = self.saved_model_file
            checkpoint = torch.load(checkpoint_file,weights_only=False)#,weights_only=True)
            self.model.load_state_dict(checkpoint['state_dict'])
            self.policy.load_state_dict(checkpoint['policy_state_dict'])
            message_output = 'Loading model structure and parameters from {}'.format(checkpoint_file)
            self.logger.info(message_output)
        self.model.eval()
        self.policy.eval()

        n_samples = treated_data.get_X_size()[0]

        criterion = nn.MSELoss(reduction='sum')
        pehe2,true_mean_ate,pred_mean_ate = None,None,None

        res = {}
        missing = missing_ratio

        for mea in missing:
            for batch_treat,batch_control in zip(treated_data, control_data):
                ground_truth = batch_treat[2].to(self.device) - batch_control[2].to(self.device)

                covariates = self._get_missing_data(batch_treat[0], mea)
                treat_covariates = covariates.to(self.device)  
                control_covariates = covariates.to(self.device) 

                cov = self.generate_cov(gene_type='diffusion',N=treat_covariates.shape[0], P=treat_covariates.shape[1])
                treat_covariates = torch.cat([treat_covariates, cov], dim=-1)
                control_covariates = torch.cat([control_covariates, cov], dim=-1)

                if self.config['model'] == 'XNet':
                    y1, y0, ps = self.model(treat_covariates, batch_treat[1].to(self.device))
                    preds = ps * y0 + (1 - ps) * y1
                elif self.config['model'] == 'RNet' or self.config['model'] == 'DRNet':
                    preds = self.model(treat_covariates, batch_treat[1].to(self.device))
                else:
                    preds = self.model(treat_covariates, batch_treat[1].to(self.device)) \
                            - self.model(control_covariates, batch_control[1].to(self.device))

                hat_ite = criterion(preds.reshape(-1),ground_truth.reshape(-1))
                pehe2 = hat_ite if pehe2 is None else pehe2+hat_ite


            hat_ate = sum(preds)
            pred_mean_ate = hat_ate if pred_mean_ate is None else pred_mean_ate+hat_ate
            true_ate = sum(ground_truth)
            true_mean_ate = true_ate if true_mean_ate is None else true_mean_ate + true_ate

            pehe2 = pehe2 / n_samples
            pred_mean_ate = pred_mean_ate / n_samples
            true_mean_ate = true_mean_ate / n_samples

            ate = torch.abs(pred_mean_ate-true_mean_ate).item()
            pehe = np.sqrt(pehe2.item())

            res['ate_{}'.format(mea)] = ate
            res['pehe_{}'.format(mea)] = pehe

        return res

    @torch.no_grad()
    def adversarial_hidden_evaluate(self, treated_data, control_data, hidden_ratio, load_best_model=True, model_file=None):
        if load_best_model:
            if model_file:
                checkpoint_file = model_file
            else:
                checkpoint_file = self.saved_model_file
            checkpoint = torch.load(checkpoint_file,weights_only=False)#,weights_only=True)
            self.model.load_state_dict(checkpoint['state_dict'])
            message_output = 'Loading model structure and parameters from {}'.format(checkpoint_file)
            self.logger.info(message_output)
        self.model.eval()
        self.policy.eval()

        n_samples = treated_data.get_X_size()[0]

        criterion = nn.MSELoss(reduction='sum')
        pehe2,true_mean_ate,pred_mean_ate = None,None,None

        res = {}
        hidden = hidden_ratio
        for mea in hidden:
            for batch_treat,batch_control in zip(treated_data, control_data):
                ground_truth = batch_treat[2].to(self.device) - batch_control[2].to(self.device)

                # hidden groud_truth 
                ground_truth = self._get_hidden_data(batch_treat[0], mea)

                treat_covariates = batch_treat[0].to(self.device) 
                control_covariates = batch_control[0].to(self.device) 

                cov = self.generate_cov(gene_type='diffusion',N=treat_covariates.shape[0], P=treat_covariates.shape[1])
                treat_covariates = torch.cat([treat_covariates, cov], dim=-1)
                control_covariates = torch.cat([control_covariates, cov], dim=-1)

                if self.config['model'] == 'XNet':
                    y1, y0, ps = self.model(treat_covariates, batch_treat[1].to(self.device))
                    preds = ps * y0 + (1 - ps) * y1
                elif self.config['model'] == 'RNet' or self.config['model'] == 'DRNet':
                    preds = self.model(treat_covariates, batch_treat[1].to(self.device))
                else:
                    preds = self.model(treat_covariates, batch_treat[1].to(self.device)) \
                            - self.model(control_covariates, batch_control[1].to(self.device))

                hat_ite = criterion(preds.reshape(-1),ground_truth.reshape(-1))
                pehe2 = hat_ite if pehe2 is None else pehe2+hat_ite


            hat_ate = sum(preds)
            pred_mean_ate = hat_ate if pred_mean_ate is None else pred_mean_ate+hat_ate
            true_ate = sum(ground_truth)
            true_mean_ate = true_ate if true_mean_ate is None else true_mean_ate + true_ate

            pehe2 = pehe2 / n_samples
            pred_mean_ate = pred_mean_ate / n_samples
            true_mean_ate = true_mean_ate / n_samples


            ate = torch.abs(pred_mean_ate-true_mean_ate).item()
            pehe = np.sqrt(pehe2.item())
            res['ate_{}'.format(mea)] = ate
            res['pehe_{}'.format(mea)] = pehe

        return res
