import os
import matplotlib
import scipy.stats

matplotlib.use('AGG')
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from betty.engine import Engine
from betty.configs import Config, EngineConfig
from betty.problems import ImplicitProblem
import numpy as np

from betty.engine import Engine
from betty.configs import Config, EngineConfig
from betty.problems import ImplicitProblem
from networks import DiscreteTreatmentChildNet, DiscreteTreatmentParentNet, ContinuousTreatmentChildNet, ContinuousTreatmentParentNet
from metrics import compute_pehe_score, compute_policy_risk, compute_ate
from utils import PolyakAdam, ReduceLROnPlateau, save_model, load_model, load_dict, save_dict, InfiniteIterator
from torch.utils.data.sampler import WeightedRandomSampler


def _make_balanced_sampler(labels):
    """Return a sampler that draws each class with equal total probability.

    Every sample is weighted by the inverse of how often its label occurs
    (labels are counted with ``np.bincount``, so they must be non-negative
    integers).  The sampler yields ``len(labels)`` indices per pass.
    """
    inverse_frequency = 1. / np.bincount(labels)
    per_sample_weight = inverse_frequency[labels]
    num_draws = len(per_sample_weight)
    return WeightedRandomSampler(per_sample_weight, num_draws)



def pin_ball_loss(eps, tau):
    """Element-wise pinball (quantile) loss.

    Args:
        eps: Tensor of residuals (callers pass ``targets - predictions``).
        tau: Quantile level(s), broadcastable against ``eps``.

    Returns:
        Tensor equal to ``tau * eps`` where ``eps >= 0`` and to
        ``(tau - 1) * eps`` elsewhere.  No reduction is applied.
    """
    # Select one branch per element rather than summing two masked products:
    # the masked form evaluates ``0 * inf`` (NaN) for an infinite residual,
    # whereas the selection returns the correct infinite loss.
    return torch.where(eps >= 0, tau * eps, (tau - 1) * eps)


class EarlyStopEngine(Engine):
    """Engine whose training loop stops once the tracked criterion stalls.

    After every training step the ``min_criterion`` attribute of the last
    problem in ``self.problems`` is compared with the best value seen so far.
    Training ends as soon as more than ``early_stop`` consecutive steps pass
    without a strict improvement.
    """

    def __init__(self, problems, config=None, dependencies=None, env=None, early_stop=10):
        super().__init__(problems, config, dependencies, env)
        # Number of tolerated non-improving steps before the loop is left.
        self.early_stop = early_stop
        # Consecutive steps without improvement so far.
        self.patience = 0
        # Best (lowest) criterion observed; starts at a huge sentinel.
        self.min_criterion = 1e99

    def run(self):
        """
        Execute multilevel optimization by running gradient descent for leaf
        problems, for at most ``self.train_iters`` steps.
        """
        self.train()
        for _ in range(self.train_iters):
            self.global_step += 1
            self.train_step()
            latest = self.problems[-1].min_criterion
            improved = latest < self.min_criterion
            if improved:
                self.min_criterion = latest
            self.patience = 0 if improved else self.patience + 1
            if self.patience > self.early_stop:
                break


class QuantileRegressor(ImplicitProblem):
    """Lower-level problem: fits the child network with the pinball loss.

    The methods read several attributes that are attached to the instance
    after construction rather than passed to ``__init__``: ``device``,
    ``test_iter``, ``num_step``, ``log_iter``, ``cf_bound``, ``val_loader``,
    ``val_metric``, ``min_criterion`` and ``checkpoint_path``.
    """

    def training_step(self, batch):
        """Return the mean pinball loss of the child network on ``batch``.

        Args:
            batch: ``(covariate, treatment, targets)`` tensors.

        Returns:
            Scalar loss tensor.
        """
        covariate, treatment, targets = (t.to(self.device) for t in batch)
        # The pinball loss is minimised on the training samples for every tau
        # produced from a batch of ``self.test_iter``.  That iterator may be
        # fed by a different loader than the training one (e.g. when only a
        # single sample is of interest).
        test_covariate, test_treatment, test_targets = next(self.test_iter)
        tau = self.outer(
            test_covariate.to(self.device),
            test_treatment.to(self.device),
            test_targets.to(self.device),
        )
        # Pair every training sample with every tau: the batch is tiled
        # len(tau) times while each tau is repeated len(covariate) times.
        num_taus = len(tau)
        flat_covariate = covariate.repeat(num_taus, 1)
        flat_treatment = treatment.repeat(num_taus, 1)
        flat_targets = targets.repeat(num_taus, 1)
        flat_taus = tau.repeat_interleave(len(covariate), dim=0).unsqueeze(-1)
        assert len(flat_covariate) == len(flat_taus)
        outs = self.module(flat_covariate, flat_treatment, flat_taus)
        loss = pin_ball_loss(flat_targets - outs, flat_taus).mean()
        self.num_step += 1
        if self.num_step % self.log_iter == 0:
            print('[%d] loss: %.4f' % (self.num_step, loss.item()))
        return loss

    @torch.no_grad()
    def forward_loader(self, loader):
        """Run both networks over ``loader`` and collect the predictions.

        Args:
            loader: Iterable of batches whose first three entries are
                ``(covariate, treatment, targets)``.  Any further entries are
                concatenated along dim 1 and returned under ``'others'``.

        Returns:
            Dict of CPU tensors with keys ``'yf_hat'`` (prediction for the
            observed treatment), ``'ycf_hat'`` (prediction for
            ``self.cf_bound - treatment``), ``'ts'``, ``'ys'``, ``'others'``
            (an empty list when batches carry no extra entries) and ``'tau'``.
        """
        others, yf_hat, ycf_hat, ts, ys, taus = [], [], [], [], [], []
        self.module.eval()
        self.outer.module.eval()
        try:
            for batch in loader:
                covariate, treatment, targets = (t.to(self.device) for t in batch[:3])
                if len(batch) > 3:
                    others.append(torch.cat(batch[3:], dim=1))
                tau = self.outer.forward(covariate, treatment, targets).unsqueeze(-1)
                yf_hat.append(self.module(covariate, treatment, tau).detach())
                ycf_hat.append(self.module(covariate, self.cf_bound - treatment, tau).detach())
                ts.append(treatment)
                ys.append(targets)
                taus.append(tau)
        finally:
            # Always hand the networks back in training mode, even if a batch
            # raised, so a failed evaluation cannot leave them in eval mode.
            self.module.train()
            self.outer.module.train()
        if len(others) > 0:
            others = torch.cat(others, 0).cpu()
        return {
            'yf_hat': torch.cat(yf_hat, 0).cpu(),
            'ycf_hat': torch.cat(ycf_hat, 0).cpu(),
            'ts': torch.cat(ts, 0).cpu(),
            'ys': torch.cat(ys, 0).cpu(),
            'others': others,
            # Moved to CPU like every other entry so the dict can be stored
            # and reloaded without the compute device.
            'tau': torch.cat(taus, 0).cpu(),
        }

    @torch.no_grad()
    def compute_metric(self, yf_hat, ycf_hat, ts, ys, others, val_metric):
        """Evaluate one named metric on collected predictions.

        Args:
            yf_hat, ycf_hat, ts, ys, others: Entries of the dict returned by
                ``forward_loader``.
            val_metric: One of ``'pehe'``, ``'ate'``, ``'rpol'``, ``'mae'``,
                ``'mse'``, ``'rmse'``, ``'cf_mae'``, ``'cf_mse'``,
                ``'cf_rmse'`` or ``'time'``.

        Returns:
            The metric value, or ``None`` when the name is not recognised or
            when ``'pehe'``/``'ate'``/``'rpol'`` is requested while ``others``
            is empty.
        """
        if val_metric == 'pehe' and len(others) != 0:
            mu0, mu1 = others.chunk(2, dim=1)
            return compute_pehe_score(yf_hat, ycf_hat, ts, mu0, mu1)
        elif val_metric == 'ate' and len(others) != 0:
            mu0, mu1 = others.chunk(2, dim=1)
            return compute_ate(yf_hat, ycf_hat, ts, mu0, mu1)
        elif val_metric == 'rpol' and len(others) != 0:
            return compute_policy_risk(yf_hat, ycf_hat, ts, others, ys)
        elif val_metric == 'mae':
            return torch.mean(torch.abs(yf_hat - ys))
        elif val_metric == 'mse':
            return torch.mean((yf_hat - ys) ** 2)
        elif val_metric == 'rmse':
            return torch.sqrt(torch.mean((yf_hat - ys) ** 2))
        elif val_metric == 'cf_mae':
            # For the counterfactual metrics ``others`` holds the true
            # counterfactual outcome.
            ycf = others
            assert ycf_hat.size() == ycf.size()
            return torch.mean(torch.abs(ycf_hat - ycf))
        elif val_metric == 'cf_mse':
            return torch.mean((ycf_hat - others) ** 2)
        elif val_metric == 'cf_rmse':
            return torch.sqrt(torch.mean((ycf_hat - others) ** 2))
        elif val_metric == 'time':
            # Decreases with every step, so the latest model always wins.
            return -self.num_step
        return None

    def _validation_metric(self):
        """Score the currently active weights on ``self.val_loader``."""
        out = self.forward_loader(self.val_loader)
        return self.compute_metric(out['yf_hat'], out['ycf_hat'], out['ts'], out['ys'],
                                   out['others'], self.val_metric)

    def optimizer_step(self, *args, **kwargs):
        """Apply one optimizer update, then run validation-based checkpointing.

        After the parameter update the validation metric is computed twice,
        once between the two ``self.optimizer.swap()`` calls (the "EMA
        models" of the optimizer) and once with the regular weights.  When
        the smaller of the two does not exceed ``self.min_criterion`` the
        criterion is updated and a checkpoint is written.

        The held-out test loader is deliberately not evaluated here: its
        scores were never used for selection.
        """
        # NOTE(review): this first half looks like a copy of Betty's stock
        # optimizer step -- confirm it still matches when upgrading Betty.
        if self.is_implemented("custom_optimizer_step"):
            assert (
                not self._is_default_fp16()
            ), "[!] FP16 training is not supported for custom optimizer step."
            if self.gradient_clipping > 0.0:
                self.clip_grad()
            self.custom_optimizer_step(*args, **kwargs)
        else:
            if self._is_default_fp16():
                if self.gradient_clipping > 0.0:
                    self.scaler.unscale_(self.optimizer)
                    self.clip_grad()
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                if self.gradient_clipping > 0.0:
                    self.clip_grad()
                self.optimizer.step()

        # using EMA models
        self.optimizer.swap()
        cur_ema_metric = self._validation_metric()
        self.optimizer.swap()
        cur_metric = self._validation_metric()

        cur_min_metric = min(cur_metric, cur_ema_metric)
        if cur_min_metric <= self.min_criterion:
            self.min_criterion = cur_min_metric
            # NOTE(review): the value returned by save_model is called
            # immediately; presumably it is a callable that performs the
            # write -- confirm in utils.
            save_model(self.outer.module, self.module, self.outer.optimizer, self.optimizer, self.checkpoint_path)()


class QuantileEstimator(ImplicitProblem):
    """Upper-level problem: trains the network that outputs quantile levels.

    Reads ``device``, ``loss_type`` and ``num_step``, which are attached to
    the instance after construction.
    """

    def training_step(self, batch):
        """Return the mean prediction loss of the lower-level network.

        The module of this problem maps the batch to quantile levels; the
        module of ``self.inner`` then predicts the targets at those levels.

        Args:
            batch: Sequence whose first three entries are
                ``(covariate, treatment, targets)``.

        Returns:
            Scalar loss tensor.  Its Python value is also stored in
            ``self.loss``.

        Raises:
            ValueError: If ``self.loss_type`` is not ``'l1'``, ``'mse'`` or
                ``'log'``.
        """
        covariate, treatment, targets = (t.to(self.device) for t in batch[:3])
        taus = self.module.forward(covariate, treatment, targets).unsqueeze(-1)
        outs = self.inner.module.forward(covariate, treatment, taus)
        if self.loss_type == 'l1':
            loss = torch.abs(outs - targets)
        elif self.loss_type == 'mse':
            loss = (outs - targets) ** 2
        elif self.loss_type == 'log':
            # Sigmoid squeezed into (0.0025, 0.9975) so both logs stay finite.
            probs = 0.995 / (1.0 + torch.exp(-outs)) + 0.0025
            # Binary cross-entropy is the NEGATIVE log-likelihood; without
            # the minus sign minimising the loss would push the predictions
            # away from the targets.
            loss = -(targets * torch.log(probs) + (1.0 - targets) * torch.log(1.0 - probs))
        else:
            raise ValueError("Unsupported loss_type: %r" % (self.loss_type,))
        loss = torch.mean(loss)
        self.loss = loss.item()
        self.num_step += 1
        return loss




class ITETrainer(nn.Module):
    """Assemble and run the bilevel (parent/child) training of one run.

    Construction builds the child problem, then the parent problem, then the
    early-stopping engine coupling the two.  ``run`` trains;
    ``evaluate_and_record`` reloads the stored checkpoint, scores it and hands
    the results to ``save_dict``.
    """

    def __init__(self, dataset_dict, config, device, run_id):
        super().__init__()
        self.device, self.run_id = device, run_id
        self.setup_child(dataset_dict, config)
        self.setup_parent(dataset_dict, config)
        self.setup_bilevel(config)

    def setup_child(self, dataset_dict, config):
        """Create loaders, network, optimizer and problem of the lower level.

        The test and validation loaders carry extra ground-truth columns
        (``mu0``/``mu1`` for 'pehe' and 'ate', ``e`` for 'rpol') when the
        dataset provides them and the validation metric needs them.
        """
        child_config = config.model.child
        child_problem_config = Config(type="darts", unroll_steps=config.inner_iters)
        self.val_metric = config.val_metric
        make_dataset = torch.utils.data.TensorDataset
        make_loader = torch.utils.data.DataLoader

        def capped_batch(split_key):
            # Never request more samples per batch than the split contains.
            return min(int(child_config.batch_size), len(dataset_dict[split_key]))

        def eval_loader(x_key, t_key, y_key, mu0_key, mu1_key, e_key):
            # Sequential, complete pass over one split.
            columns = [dataset_dict[x_key], dataset_dict[t_key], dataset_dict[y_key]]
            if mu0_key in dataset_dict and mu1_key in dataset_dict and self.val_metric in ['pehe', 'ate']:
                columns += [dataset_dict[mu0_key], dataset_dict[mu1_key]]
            elif e_key in dataset_dict and self.val_metric == 'rpol':
                columns += [dataset_dict[e_key]]
            return make_loader(make_dataset(*columns), batch_size=capped_batch(x_key),
                               shuffle=False, drop_last=False)

        # prepare data
        train_dataset = make_dataset(dataset_dict['x_train'], dataset_dict['t_train'],
                                     dataset_dict['y_f_train'])
        train_loader = make_loader(train_dataset, batch_size=capped_batch('x_train'),
                                   shuffle=True, drop_last=True)
        true_test_loader = eval_loader('x_test', 't_test', 'y_f_test',
                                       'mu0_test', 'mu1_test', 'e_test')
        val_loader = eval_loader('x_valid', 't_valid', 'y_f_valid',
                                 'mu0_valid', 'mu1_valid', 'e_valid')

        # network
        if config.data.discrete_treatment:
            print('>>>> Using Discrete Child Treatment <<<<')
            net_cls = DiscreteTreatmentChildNet
        else:
            print('>>>> Using Continuous Child Treatment <<<<')
            net_cls = ContinuousTreatmentChildNet
        self.child_net = net_cls(child_config, config.data)

        self.child_optim = PolyakAdam(
            self.child_net.parameters(),
            lr=child_config.lr,
            weight_decay=child_config.weight_decay,
            amsgrad=False,
            polyak=child_config.polyak,
        )
        self.child_problem = QuantileRegressor(
            name='inner',
            module=self.child_net,
            optimizer=self.child_optim,
            train_data_loader=train_loader,
            config=child_problem_config,
        )

        # State the problem reads at run time is attached after construction.
        problem = self.child_problem
        problem.num_step = 0
        problem.log_iter = child_config.log_iter
        problem.device = self.device
        problem.init_lr = child_config.lr
        problem.true_test_loader = true_test_loader
        problem.val_loader = val_loader
        os.makedirs(config.checkpoint_dir, exist_ok=True)
        self.checkpoint_dir = config.checkpoint_dir
        checkpoint_name = 'exp%04d_run%03d_best_model.pth' % (config.exp_num, self.run_id)
        problem.checkpoint_path = os.path.join(config.checkpoint_dir, checkpoint_name)
        self.exp_num = config.exp_num
        problem.cf_bound = config.data.cf_bound
        problem.val_metric = self.val_metric
        problem.min_criterion = 1e99
        problem.train_iters = config.train_iters
        problem.patience = 0

    def setup_parent(self, dataset_dict, config):
        """
        Create loaders, network, optimizer and problem of the upper level.

        In ITE the training samples also serve as the samples of the
        higher-level problem; the validation samples are only used for model
        selection.
        """
        parent_config = config.model.parent
        parent_config.batch_size = int(parent_config.batch_size)
        parent_problem_config = Config(log_step=parent_config.log_iter, first_order=True,
                                       retain_graph=True)
        make_loader = torch.utils.data.DataLoader

        test_dataset = torch.utils.data.TensorDataset(
            dataset_dict['x_train'], dataset_dict['t_train'], dataset_dict['y_f_train'])
        train_size = len(dataset_dict['x_train'])
        child_batch = min(int(config.model.child.batch_size), train_size)

        if config.data.discrete_treatment:
            # Draw the treatment arms with equal probability.
            labels = dataset_dict['t_train'].long().numpy().reshape(-1)
            test_loader = make_loader(
                test_dataset,
                batch_size=min(parent_config.batch_size, train_size),
                shuffle=False,
                drop_last=True,
                sampler=_make_balanced_sampler(labels),
            )
        else:
            test_loader = make_loader(test_dataset, batch_size=child_batch, shuffle=True,
                                      drop_last=True)

        # Endless stream of batches from which the child problem derives taus.
        tmp_test_loader = make_loader(test_dataset, batch_size=child_batch, shuffle=True,
                                      drop_last=True)
        self.child_problem.test_iter = InfiniteIterator(tmp_test_loader)
        print('Parent>>>> ', len(test_dataset), len(test_dataset))

        if config.data.discrete_treatment:
            print('>>>> Using Discrete Treatment <<<<')
            net_cls = DiscreteTreatmentParentNet
        else:
            print('>>>> Using Continuous Treatment <<<<')
            net_cls = ContinuousTreatmentParentNet
        self.parent_net = net_cls(parent_config, config.data)
        self.parent_optim = PolyakAdam(
            self.parent_net.parameters(),
            lr=parent_config.lr,
            weight_decay=parent_config.weight_decay,
            amsgrad=False,
            polyak=parent_config.polyak,
        )

        self.parent_problem = QuantileEstimator(
            name='outer',
            module=self.parent_net,
            optimizer=self.parent_optim,
            train_data_loader=test_loader,
            config=parent_problem_config,
        )
        self.parent_problem.num_step = 0
        self.parent_problem.loss_type = parent_config.loss_type
        self.parent_problem.device = self.device

    def setup_bilevel(self, config):
        """Couple parent and child problem in an early-stopping engine."""
        engine_config = EngineConfig(train_iters=config.train_iters * config.inner_iters,
                                     logger_type="none")
        dependencies = {
            "l2u": {self.child_problem: [self.parent_problem]},
            "u2l": {self.parent_problem: [self.child_problem]},
        }
        self.engine = EarlyStopEngine(
            config=engine_config,
            problems=[self.parent_problem, self.child_problem],
            dependencies=dependencies,
            early_stop=config.early_stop,
        )

    def run(self):
        """Train by delegating to the engine."""
        self.engine.run()

    def evaluate_and_record(self, config, dataset_dict, metrics, run_id):
        """Score the stored checkpoint on train and test data and save it.

        Both weight sets of the child optimizer (before and after
        ``swap()``) are evaluated.  The one with the lower validation
        criterion is reported; on a tie the set evaluated between the two
        ``swap()`` calls wins.

        Args:
            config: Experiment configuration; stored alongside the results.
            dataset_dict: Mapping with the train/test tensors.
            metrics: Names of the metrics to compute.
            run_id: Run index used in the result file name.
        """
        def single_sample_loader(x_key, t_key, y_key, mu0_key, mu1_key, e_key):
            # One sample per batch, in dataset order.
            columns = [dataset_dict[x_key], dataset_dict[t_key], dataset_dict[y_key]]
            if mu0_key in dataset_dict and mu1_key in dataset_dict and ('pehe' in metrics or 'ate' in metrics):
                columns += [dataset_dict[mu0_key], dataset_dict[mu1_key]]
            elif e_key in dataset_dict and self.val_metric == 'rpol':
                columns += [dataset_dict[e_key]]
            dataset = torch.utils.data.TensorDataset(*columns)
            return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)

        true_train_loader = single_sample_loader('x_train', 't_train', 'y_f_train',
                                                 'mu0_train', 'mu1_train', 'e_train')
        true_test_loader = single_sample_loader('x_test', 't_test', 'y_f_test',
                                                'mu0_test', 'mu1_test', 'e_test')

        problem = self.child_problem
        load_model(self.parent_problem.module, problem.module, self.parent_problem.optimizer,
                   problem.optimizer, problem.checkpoint_path)()

        def score(outputs, metric_name):
            return problem.compute_metric(outputs['yf_hat'], outputs['ycf_hat'], outputs['ts'],
                                          outputs['ys'], outputs['others'], metric_name)

        def snapshot():
            # Evaluate whichever weights are currently active.
            test_out = problem.forward_loader(true_test_loader)
            train_out = problem.forward_loader(true_train_loader)
            test_scores = {met: score(test_out, met) for met in metrics}
            train_scores = {met: score(train_out, met) for met in metrics}
            val_out = problem.forward_loader(problem.val_loader)
            test_scores['selection_criterion'] = score(val_out, self.val_metric)
            return test_out, train_out, test_scores, train_scores

        problem.optimizer.swap()
        ema_snapshot = snapshot()
        problem.optimizer.swap()
        raw_snapshot = snapshot()

        chosen = raw_snapshot
        if raw_snapshot[2]['selection_criterion'] >= ema_snapshot[2]['selection_criterion']:
            chosen = ema_snapshot
        result_dict, train_result_dict, metric_dict, train_metric_dict = chosen

        # Training-set numbers are stored next to the test ones, prefixed 'In-'.
        for source in (train_metric_dict, train_result_dict):
            for k in source:
                result_dict['In-%s' % k] = source[k]
        result_dict.update(metric_dict)
        result_dict['config'] = config
        print(metric_dict, train_metric_dict)
        result_name = 'exp%04d_run%03d_test_result.pkl' % (self.exp_num, run_id)
        save_dict(result_dict, os.path.join(self.checkpoint_dir, result_name))

class CFRTrainer(ITETrainer):
    """Trainer variant for data that ships the true counterfactual outcome.

    Every evaluation loader carries ``y_cf_*`` as a fourth tensor, and the
    loader used as validation loader is built from the training split.
    """

    def __init__(self, dataset_dict, config, device, run_id):
        super().__init__(dataset_dict, config, device, run_id)

    def setup_child(self, dataset_dict, config):
        """Create loaders, network, optimizer and problem of the lower level."""
        child_config = config.model.child
        child_problem_config = Config(type="darts", unroll_steps=config.inner_iters)
        self.val_metric = config.val_metric
        make_dataset = torch.utils.data.TensorDataset
        make_loader = torch.utils.data.DataLoader

        def capped_batch():
            # All three loaders cap the batch size at the training-set size.
            return min(int(child_config.batch_size), len(dataset_dict['x_train']))

        # prepare data
        train_dataset = make_dataset(dataset_dict['x_train'], dataset_dict['t_train'],
                                     dataset_dict['y_f_train'])
        train_loader = make_loader(train_dataset, batch_size=capped_batch(), shuffle=True,
                                   drop_last=True)

        true_test_dataset = make_dataset(dataset_dict['x_test'], dataset_dict['t_test'],
                                         dataset_dict['y_f_test'], dataset_dict['y_cf_test'])
        true_test_loader = make_loader(true_test_dataset, batch_size=capped_batch(),
                                       shuffle=False, drop_last=False)
        print('Child>>>> ', len(train_dataset), len(true_test_dataset))

        val_dataset = make_dataset(dataset_dict['x_train'], dataset_dict['t_train'],
                                   dataset_dict['y_f_train'], dataset_dict['y_cf_train'])
        val_loader = make_loader(val_dataset, batch_size=capped_batch(), shuffle=False,
                                 drop_last=False)
        print('>>>>> CFR Dataset', len(train_dataset), len(val_dataset))

        # network
        if config.data.discrete_treatment:
            print('>>>> Using Discrete Child Treatment <<<<')
            net_cls = DiscreteTreatmentChildNet
        else:
            print('>>>> Using Continuous Child Treatment <<<<')
            net_cls = ContinuousTreatmentChildNet
        self.child_net = net_cls(child_config, config.data)

        self.child_optim = PolyakAdam(
            self.child_net.parameters(),
            lr=child_config.lr,
            weight_decay=child_config.weight_decay,
            amsgrad=False,
            polyak=child_config.polyak,
        )
        self.child_problem = QuantileRegressor(
            name='inner',
            module=self.child_net,
            optimizer=self.child_optim,
            train_data_loader=train_loader,
            config=child_problem_config,
        )

        # State the problem reads at run time is attached after construction.
        problem = self.child_problem
        problem.num_step = 0
        problem.log_iter = child_config.log_iter
        problem.device = self.device
        problem.init_lr = child_config.lr
        problem.true_test_loader = true_test_loader
        problem.val_loader = val_loader
        os.makedirs(config.checkpoint_dir, exist_ok=True)
        self.checkpoint_dir = config.checkpoint_dir
        checkpoint_name = 'exp%04d_run%03d_best_model.pth' % (config.exp_num, self.run_id)
        problem.checkpoint_path = os.path.join(config.checkpoint_dir, checkpoint_name)
        self.exp_num = config.exp_num
        problem.cf_bound = config.data.cf_bound
        problem.val_metric = self.val_metric
        problem.min_criterion = 1e99
        problem.train_iters = config.train_iters
        problem.patience = 0

    def evaluate_and_record(self, config, dataset_dict, metrics, run_id):
        """Score the stored checkpoint on train and test data and save it.

        Both weight sets of the child optimizer (before and after
        ``swap()``) are evaluated.  The one with the lower validation
        criterion is reported; on a tie the set evaluated between the two
        ``swap()`` calls wins.

        Args:
            config: Experiment configuration; stored alongside the results.
            dataset_dict: Mapping with the train/test tensors, including the
                counterfactual outcomes ``y_cf_train`` and ``y_cf_test``.
            metrics: Names of the metrics to compute.
            run_id: Run index used in the result file name.
        """
        def single_sample_loader(x_key, t_key, y_key, ycf_key):
            # One sample per batch, in dataset order.
            dataset = torch.utils.data.TensorDataset(
                dataset_dict[x_key], dataset_dict[t_key], dataset_dict[y_key],
                dataset_dict[ycf_key])
            return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)

        true_train_loader = single_sample_loader('x_train', 't_train', 'y_f_train', 'y_cf_train')
        true_test_loader = single_sample_loader('x_test', 't_test', 'y_f_test', 'y_cf_test')

        problem = self.child_problem
        load_model(self.parent_problem.module, problem.module, self.parent_problem.optimizer,
                   problem.optimizer, problem.checkpoint_path)()

        def score(outputs, metric_name):
            return problem.compute_metric(outputs['yf_hat'], outputs['ycf_hat'], outputs['ts'],
                                          outputs['ys'], outputs['others'], metric_name)

        def snapshot():
            # Evaluate whichever weights are currently active.
            test_out = problem.forward_loader(true_test_loader)
            train_out = problem.forward_loader(true_train_loader)
            test_scores = {met: score(test_out, met) for met in metrics}
            train_scores = {met: score(train_out, met) for met in metrics}
            val_out = problem.forward_loader(problem.val_loader)
            test_scores['selection_criterion'] = score(val_out, self.val_metric)
            return test_out, train_out, test_scores, train_scores

        problem.optimizer.swap()
        ema_snapshot = snapshot()
        problem.optimizer.swap()
        raw_snapshot = snapshot()

        chosen = raw_snapshot
        if raw_snapshot[2]['selection_criterion'] >= ema_snapshot[2]['selection_criterion']:
            chosen = ema_snapshot
        result_dict, train_result_dict, metric_dict, train_metric_dict = chosen

        # Training-set numbers are stored next to the test ones, prefixed 'In-'.
        for source in (train_metric_dict, train_result_dict):
            for k in source:
                result_dict['In-%s' % k] = source[k]
        result_dict.update(metric_dict)
        result_dict['config'] = config
        print(metric_dict, train_metric_dict)
        result_name = 'exp%04d_run%03d_test_result.pkl' % (self.exp_num, run_id)
        save_dict(result_dict, os.path.join(self.checkpoint_dir, result_name))

class Trainer(ITETrainer):
    """Trainer that wires an inner ``QuantileRegressor`` problem to an outer
    ``QuantileEstimator`` problem and records evaluation results.

    ``setup_child`` builds the inner (child) problem on the training split,
    ``setup_parent`` builds the outer (parent) problem on the test split, and
    ``evaluate_and_record`` evaluates the checkpointed child with both of its
    weight sets and pickles the better one.
    """

    def __init__(self, dataset_dict, config, device, run_id):
        super().__init__(dataset_dict, config, device, run_id)

    def setup_child(self, dataset_dict, config):
        """Create the child network, its optimizer and the inner problem.

        Args:
            dataset_dict: mapping that provides at least ``x_train``,
                ``t_train``, ``y_f_train``, ``x_test``, ``t_test`` and
                ``y_f_test`` tensors.
            config: experiment configuration; reads ``model.child``,
                ``inner_iters``, ``val_metric``, ``data``, ``checkpoint_dir``,
                ``exp_num`` and ``train_iters``.

        Side effects:
            Sets ``self.val_metric``, ``self.child_net``, ``self.child_optim``,
            ``self.child_problem``, ``self.checkpoint_dir`` and
            ``self.exp_num``; creates ``config.checkpoint_dir`` on disk.
        """
        child_config = config.model.child
        child_problem_config = Config(type="darts", unroll_steps=config.inner_iters)
        self.val_metric = config.val_metric

        # NOTE(review): every loader below caps its batch size with the
        # *training* set length, including the test/val loaders. Kept as in
        # the original; confirm whether the test length was intended there.
        batch_size = min(int(child_config.batch_size), len(dataset_dict['x_train']))

        # prepare data
        train_dataset = torch.utils.data.TensorDataset(
            dataset_dict['x_train'], dataset_dict['t_train'], dataset_dict['y_f_train'])
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

        # The "true test" loader and the validation loader are both built from
        # the same test-split tensors (NOTE(review): presumably intentional,
        # since model selection then happens on the test observations).
        eval_tensors = [dataset_dict['x_test'], dataset_dict['t_test'], dataset_dict['y_f_test']]
        true_test_loader = torch.utils.data.DataLoader(
            torch.utils.data.TensorDataset(*eval_tensors),
            batch_size=batch_size, shuffle=False, drop_last=False)
        val_loader = torch.utils.data.DataLoader(
            torch.utils.data.TensorDataset(*eval_tensors),
            batch_size=batch_size, shuffle=False, drop_last=False)

        if config.data.discrete_treatment:
            print('>>>> Using Discrete Child Treatment <<<<')
            self.child_net = DiscreteTreatmentChildNet(child_config, config.data)
        else:
            print('>>>> Using Continuous Child Treatment <<<<')
            self.child_net = ContinuousTreatmentChildNet(child_config, config.data)

        self.child_optim = PolyakAdam(self.child_net.parameters(), lr=child_config.lr,
                                      weight_decay=child_config.weight_decay, amsgrad=False,
                                      polyak=child_config.polyak)
        self.child_problem = QuantileRegressor(name='inner', module=self.child_net,
                                               optimizer=self.child_optim, train_data_loader=train_loader,
                                               config=child_problem_config)

        # Extra state attached to the problem object after construction.
        self.child_problem.num_step = 0
        self.child_problem.log_iter = child_config.log_iter
        self.child_problem.device = self.device
        self.child_problem.init_lr = child_config.lr
        self.child_problem.true_test_loader = true_test_loader
        self.child_problem.val_loader = val_loader

        os.makedirs(config.checkpoint_dir, exist_ok=True)
        self.checkpoint_dir = config.checkpoint_dir
        self.child_problem.checkpoint_path = os.path.join(
            config.checkpoint_dir, 'exp%04d_run%03d_best_model.pth' % (config.exp_num, self.run_id))
        self.exp_num = config.exp_num
        self.child_problem.cf_bound = config.data.cf_bound
        self.child_problem.val_metric = self.val_metric
        # Same initial value as EarlyStopEngine.min_criterion.
        self.child_problem.min_criterion = 1e99
        self.child_problem.train_iters = config.train_iters
        self.child_problem.patience = 0

    def setup_parent(self, dataset_dict, config):
        """Create the parent network, its optimizer and the outer problem.

        Must run after ``setup_child`` because it attaches ``test_iter`` to
        ``self.child_problem``.

        Args:
            dataset_dict: mapping that provides ``x_test``, ``t_test`` and
                ``y_f_test`` tensors.
            config: experiment configuration; reads ``model.parent``,
                ``model.child.batch_size`` and ``data``.

        Side effects:
            Casts ``config.model.parent.batch_size`` to ``int`` in place and
            sets ``self.parent_net``, ``self.parent_optim`` and
            ``self.parent_problem``.
        """
        parent_config = config.model.parent
        parent_config.batch_size = int(parent_config.batch_size)
        parent_problem_config = Config(log_step=parent_config.log_iter, first_order=True, retain_graph=True)

        n_test = len(dataset_dict['x_test'])
        child_batch_size = min(int(config.model.child.batch_size), n_test)

        test_dataset = torch.utils.data.TensorDataset(
            dataset_dict['x_test'], dataset_dict['t_test'], dataset_dict['y_f_test'])

        if config.data.discrete_treatment:
            # Class-balanced sampling over the treatment labels. ``.cpu()`` is
            # a no-op for CPU tensors and keeps ``.numpy()` from raising when
            # the labels live on an accelerator.
            treatment_labels = dataset_dict['t_test'].long().cpu().numpy().reshape(-1)
            test_loader = torch.utils.data.DataLoader(test_dataset,
                                                      batch_size=min(parent_config.batch_size, n_test),
                                                      shuffle=False,
                                                      drop_last=True,
                                                      sampler=_make_balanced_sampler(treatment_labels))
        else:
            # NOTE(review): this branch uses the child batch size rather than
            # the parent's; kept as in the original.
            test_loader = torch.utils.data.DataLoader(test_dataset,
                                                      batch_size=child_batch_size,
                                                      shuffle=True,
                                                      drop_last=True)

        # Separate shuffled loader over the test split, wrapped in an
        # InfiniteIterator and handed to the child problem.
        tmp_test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=child_batch_size,
                                                      shuffle=True, drop_last=True)
        self.child_problem.test_iter = InfiniteIterator(tmp_test_loader)

        if config.data.discrete_treatment:
            print('>>>> Using Discrete Treatment <<<<')
            self.parent_net = DiscreteTreatmentParentNet(parent_config, config.data)
        else:
            print('>>>> Using Continuous Treatment <<<<')
            self.parent_net = ContinuousTreatmentParentNet(parent_config, config.data)

        self.parent_optim = PolyakAdam(self.parent_net.parameters(), lr=parent_config.lr,
                                       weight_decay=parent_config.weight_decay, amsgrad=False,
                                       polyak=parent_config.polyak)
        self.parent_problem = QuantileEstimator(name='outer', module=self.parent_net,
                                                optimizer=self.parent_optim, train_data_loader=test_loader,
                                                config=parent_problem_config)
        self.parent_problem.num_step = 0
        self.parent_problem.loss_type = parent_config.loss_type
        self.parent_problem.device = self.device

    def _compute_metric_from_results(self, results, metric):
        """Score one ``forward_loader`` output dict with the child's ``compute_metric``."""
        return self.child_problem.compute_metric(results['yf_hat'], results['ycf_hat'],
                                                 results['ts'], results['ys'], results['others'],
                                                 metric)

    def _evaluate_current_weights(self, true_test_loader, true_train_loader, metrics):
        """Evaluate the child with whatever weights are currently loaded.

        Runs the test loader, the train loader and finally the child's
        validation loader, in that order.

        Returns:
            ``(result_dict, train_result_dict, metric_dict, train_metric_dict)``
            where ``metric_dict`` additionally holds ``'selection_criterion'``,
            i.e. ``self.val_metric`` computed on the validation loader.
        """
        result_dict = self.child_problem.forward_loader(true_test_loader)
        train_result_dict = self.child_problem.forward_loader(true_train_loader)
        metric_dict = {met: self._compute_metric_from_results(result_dict, met) for met in metrics}
        train_metric_dict = {met: self._compute_metric_from_results(train_result_dict, met)
                             for met in metrics}

        val_result_dict = self.child_problem.forward_loader(self.child_problem.val_loader)
        metric_dict['selection_criterion'] = self._compute_metric_from_results(val_result_dict,
                                                                               self.val_metric)
        return result_dict, train_result_dict, metric_dict, train_metric_dict

    def evaluate_and_record(self, config, dataset_dict, metrics, run_id):
        """Evaluate the checkpointed model and pickle the selected results.

        The checkpoint at ``self.child_problem.checkpoint_path`` is loaded,
        then the child is evaluated twice: once between a pair of
        ``optimizer.swap()`` calls (results named ``ema_*``; presumably the
        Polyak-averaged weights -- confirm against ``PolyakAdam``) and once
        with the weights as loaded. The ``ema_*`` results are kept when the
        other pass's selection criterion is greater than or equal to theirs.

        Args:
            config: experiment configuration, stored under ``'config'`` in
                the saved dict.
            dataset_dict: mapping that provides ``x_*``, ``t_*``, ``y_f_*``
                and ``y_cf_*`` tensors for both ``train`` and ``test``.
            metrics: iterable of metric identifiers accepted by the child's
                ``compute_metric``.
            run_id: run index used in the output file name.

        Side effects:
            Prints the selected metric dicts and writes
            ``exp%04d_run%03d_test_result.pkl`` into ``self.checkpoint_dir``.
        """
        true_train_dataset = torch.utils.data.TensorDataset(
            dataset_dict['x_train'], dataset_dict['t_train'],
            dataset_dict['y_f_train'], dataset_dict['y_cf_train'])
        true_train_loader = torch.utils.data.DataLoader(true_train_dataset, batch_size=1, shuffle=False)

        true_test_dataset = torch.utils.data.TensorDataset(
            dataset_dict['x_test'], dataset_dict['t_test'],
            dataset_dict['y_f_test'], dataset_dict['y_cf_test'])
        true_test_loader = torch.utils.data.DataLoader(true_test_dataset, batch_size=1, shuffle=False)

        # load_model returns a callable here, and that callable is invoked
        # immediately (kept exactly as in the original).
        load_model(self.parent_problem.module, self.child_problem.module, self.parent_problem.optimizer, self.child_problem.optimizer,
                   self.child_problem.checkpoint_path)()

        # First pass runs between two swap() calls; the second swap() is
        # expected to restore the weights that were loaded above.
        self.child_problem.optimizer.swap()
        (ema_result_dict, ema_train_result_dict,
         ema_metric_dict, ema_train_metric_dict) = self._evaluate_current_weights(
            true_test_loader, true_train_loader, metrics)
        self.child_problem.optimizer.swap()

        (result_dict, train_result_dict,
         metric_dict, train_metric_dict) = self._evaluate_current_weights(
            true_test_loader, true_train_loader, metrics)

        # Keep the swapped-weights results unless the plain weights achieve a
        # strictly smaller selection criterion.
        if metric_dict['selection_criterion'] >= ema_metric_dict['selection_criterion']:
            result_dict = ema_result_dict
            train_metric_dict = ema_train_metric_dict
            train_result_dict = ema_train_result_dict
            metric_dict = ema_metric_dict

        # In-sample (training split) entries are stored under an 'In-' prefix.
        for key, value in train_metric_dict.items():
            result_dict['In-%s' % key] = value
        for key, value in train_result_dict.items():
            result_dict['In-%s' % key] = value
        result_dict.update(metric_dict)
        result_dict['config'] = config
        print(metric_dict, train_metric_dict)
        save_dict(result_dict, os.path.join(self.checkpoint_dir, 'exp%04d_run%03d_test_result.pkl' % (self.exp_num, run_id)))

class ToyTrainer(Trainer):
    """``Trainer`` variant whose evaluation produces a scatter plot and a
    ``tau.txt`` file instead of a pickled metric dict."""

    def __init__(self, dataset_dict, config, device, run_id):
        super().__init__(dataset_dict, config, device, run_id)

    @torch.no_grad()
    def evaluate_and_record(self, config, dataset_dict, metrics, run_id):
        """Plot the child's predictions against the reference outcomes.

        The checkpoint at ``self.child_problem.checkpoint_path`` is loaded,
        the parent module maps the test observation to ``tau``, and the child
        module predicts outcomes on the ``*_val`` inputs conditioned on that
        ``tau``. Linear quantile regressions at 0.1 / 0.5 / 0.9, fitted on the
        training split, are drawn as baselines.

        Args:
            config: experiment configuration; ``config.dataset`` names the
                output figure.
            dataset_dict: mapping that provides ``x_val``, ``t_val``,
                ``y_f_val``, ``x_test``, ``t_test``, ``y_f_test``, ``e_test``,
                ``x_train``, ``t_train`` and ``y_f_train`` tensors.
            metrics: unused; kept for interface compatibility.
            run_id: unused; kept for interface compatibility.

        Side effects:
            Prints the true/estimated tau and the training design matrix,
            writes ``<Dataset>_scatter.pdf`` and ``tau.txt`` into
            ``self.checkpoint_dir``.

        Raises:
            ValueError: if ``e_test`` or ``tau`` holds more than one element
                (both are reduced with ``.item()``).
        """
        # Aliased so that it does not shadow this file's own
        # QuantileRegressor(ImplicitProblem) inside the method.
        from scipy.stats import norm
        from sklearn.linear_model import QuantileRegressor as LinearQuantileRegressor

        # load_model returns a callable here, and that callable is invoked
        # immediately (kept exactly as in the original).
        load_model(self.parent_problem.module, self.child_problem.module, self.parent_problem.optimizer, self.child_problem.optimizer,
                   self.child_problem.checkpoint_path)()

        # Query points and the outcomes plotted as 'Truth'.
        x = dataset_dict['x_val'].to(self.device)
        t = dataset_dict['t_val'].to(self.device)
        y_cf = dataset_dict['y_f_val'].to(self.device)

        # observations
        x_obs = dataset_dict['x_test'].to(self.device)
        t_obs = dataset_dict['t_test'].to(self.device)
        y_obs = dataset_dict['y_f_test'].to(self.device)

        tau = self.parent_problem.module(x_obs, t_obs, y_obs).unsqueeze(-1)

        # The same tau is repeated for every query row.
        y_cf_hat = self.child_problem.module(x, t, tau.repeat(len(x), 1)).detach()

        # The standard-normal CDF of the test noise is reported as the true
        # tau. ``.item()`` makes the scalar conversion explicit rather than
        # relying on '%f' % ndarray, which NumPy >= 1.25 deprecates.
        true_tau = norm.cdf(dataset_dict['e_test'].numpy()).item()
        print('True Tau: %.4f , Estimated Tau: %.4f' % (true_tau, tau.item()))

        train_data = torch.cat([dataset_dict['x_train'], dataset_dict['t_train']], 1).cpu().numpy()
        print(train_data)
        train_y = dataset_dict['y_f_train'].cpu().numpy()

        # Loop-invariant inputs for the baselines and the plot.
        query_data = torch.cat([x, t], 1).cpu().numpy()
        t_cpu = t.cpu()

        fig, ax = plt.subplots()
        ax.scatter(t_cpu, y_cf.cpu(), label='Truth', marker='s')
        ax.scatter(t_cpu, y_cf_hat.cpu(), label='Ours', marker='*')
        for quantile in [0.1, 0.5, 0.9]:
            # alpha=0 disables the L1 penalty of the linear baseline.
            baseline = LinearQuantileRegressor(quantile=quantile, alpha=0, solver='highs')
            baseline.fit(train_data, train_y)
            ax.scatter(t_cpu, baseline.predict(query_data), label='Linear Quantile: %.1f' % quantile)

        name = config.dataset.capitalize()
        fn = os.path.join(self.checkpoint_dir, name + '_scatter.pdf')
        ax.legend(fontsize=16)
        ax.set_xlabel('Treatment', fontdict={'size': 22})
        ax.set_ylabel('Outcome', fontdict={'size': 22})
        ax.set_title(r'$E^2, cos(E)$', fontdict={'size': 22})
        ax.grid()
        fig.savefig(fn, bbox_inches='tight')
        plt.close(fig)

        np.savetxt(os.path.join(self.checkpoint_dir, 'tau.txt'), tau.cpu().numpy())


