# -*- encoding: utf-8 -*-
import numpy as np
import torch

class CausalEvaluator(object):
    """Accumulate batched torch tensors and score treatment-effect predictions.

    Batches are appended with :meth:`collect`; :meth:`evaluate` squeezes the
    accumulated tensors, computes every metric and then clears the buffers.

    Config keys read by this class: ``error_theta``, ``device``,
    ``theta_rate``, ``model`` and ``dataset``.
    """

    # Models whose predictions are squeezed on dim 1 and reduced with a mean
    # over the last dimension before scalar metrics are computed.
    _SAMPLED_MODELS = ('UITE_MMD', 'UITE_WASS')
    # Datasets for which policy risk and ATT are reported and for which the
    # factual outcome is used as the classification ground truth.
    _FACTUAL_DATASETS = ('Jobs',)
    # Names of the per-example buffers filled by collect().
    _BUFFERS = (
        'targets',
        'logits',
        'treatments',
        'factual_outcome',
        'indicator_random',
        'treat_pred',
        'control_pred',
    )

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.error_theta = config['error_theta']
        self.device = config['device']
        self.theta_rate = config['theta_rate']
        # Define every buffer up front (including treat_pred / control_pred)
        # so the instance has a complete attribute set before collect().
        self.reset()

    def reset(self):
        """Drop all accumulated tensors."""
        for name in self._BUFFERS:
            setattr(self, name, None)

    def collect(self, ground_truth, prediction, treatment, indicator_random,
                factual_outcome, treat_pred, control_pred):
        """Append one batch to the buffers (concatenating along dim 0).

        Args:
            ground_truth: target effect values for the batch.
            prediction: predicted effect, f(x,1) - f(x,0).
            treatment: true treatments.
            indicator_random: mask used by :meth:`rpol` and :meth:`att` to
                select examples (entries > 0 are kept).
            factual_outcome: observed outcomes.
            treat_pred: predicted outcome under treatment.
            control_pred: predicted outcome under control.
        """
        batch = (
            ('targets', ground_truth),
            ('logits', prediction),
            ('treatments', treatment),
            ('factual_outcome', factual_outcome),
            ('indicator_random', indicator_random),
            ('treat_pred', treat_pred),
            ('control_pred', control_pred),
        )
        first_batch = self.targets is None
        for name, value in batch:
            if first_batch:
                setattr(self, name, value)
            else:
                merged = torch.cat([getattr(self, name), value], dim=0)
                setattr(self, name, merged)

    def evaluate(self):
        """Compute all metrics over the collected data, then reset.

        Returns:
            dict with keys ``out_pehe``, ``out_ate``, ``out_rpol``,
            ``out_fn``, ``out_fp``, ``out_fpn`` and ``out_att``. ``out_rpol``
            and ``out_att`` are 0 unless the dataset is ``'Jobs'``.

        Raises:
            RuntimeError: if no batch has been collected.
        """
        if self.targets is None:
            raise RuntimeError(
                'evaluate() called before any batch was collected')

        self.targets = self.targets.squeeze(-1)
        # Sampled models squeeze dim 1; every other model squeezes the last.
        pred_dim = 1 if self._is_sampled_model() else -1
        self.logits = self.logits.squeeze(pred_dim)
        self.treat_pred = self.treat_pred.squeeze(pred_dim)
        self.control_pred = self.control_pred.squeeze(pred_dim)
        self.treatments = self.treatments.squeeze(-1)
        self.factual_outcome = self.factual_outcome.squeeze(-1)
        self.indicator_random = self.indicator_random.squeeze(-1)

        factual_only = self._is_factual_dataset()
        pehe = self.pehe()
        ate = self.ate()
        if factual_only:
            rpol = self.rpol(self.treatments, self.factual_outcome,
                             self.logits)
            att = self.att()
        else:
            rpol = 0
            att = 0
        fn, fp, fpn = self.fnp()
        results = {
            'out_pehe': pehe,
            'out_ate': ate,
            'out_rpol': rpol,
            'out_fn': fn,
            'out_fp': fp,
            'out_fpn': fpn,
            'out_att': att
        }
        self.reset()
        return results

    def _is_sampled_model(self):
        """Return True when the configured model is a sampled (UITE) model."""
        return self.config['model'] in self._SAMPLED_MODELS

    def _is_factual_dataset(self):
        """Return True when the configured dataset is ``'Jobs'``."""
        return self.config['dataset'] in self._FACTUAL_DATASETS

    def _point_effect(self, effect):
        """Reduce sampled-model predictions with a mean over the last dim."""
        if self._is_sampled_model():
            return torch.mean(effect, dim=-1)
        return effect

    def _vote_positive(self, samples):
        """Decide a positive label from one example's prediction samples.

        ``p`` is the share of entries at or above the samples' own mean; the
        label is positive when theta_rate * (1 - p) < (1 - theta_rate) * p.
        """
        above = (samples >= torch.mean(samples)).to(torch.float32)
        p = torch.sum(above) / len(samples)
        return self.theta_rate * (1 - p) < (1 - self.theta_rate) * p

    def fnp(self):
        """Return false-negative rate, false-positive rate and their mix.

        Ground truth is the factual outcome on ``'Jobs'`` and otherwise the
        targets thresholded at their mean. Predictions are binarised either
        by per-example sample voting (sampled models) or by thresholding at
        the mean prediction. ``self.error_theta`` is overwritten with the
        last threshold used.

        Returns:
            (fn, fp, fpn), where
            ``fpn = theta_rate * fp + (1 - theta_rate) * fn``.
        """
        sampled = self._is_sampled_model()
        if self._is_factual_dataset():
            ground_truth = self.factual_outcome
            preds = torch.zeros((len(ground_truth))).to(self.device)
            if sampled:
                for index in range(len(preds)):
                    # Use the head matching the treatment actually received.
                    if self.treatments[index] > 0:
                        samples = self.treat_pred[index]
                    else:
                        samples = self.control_pred[index]
                    if self._vote_positive(samples):
                        preds[index] = 1
            else:
                treated = self.treatments > 0
                untreated = self.treatments < 1
                preds[treated] = self.treat_pred[treated]
                preds[untreated] = self.control_pred[untreated]
                self.error_theta = torch.mean(preds)
                preds = (preds >= self.error_theta).float()
        else:
            self.error_theta = torch.mean(self.targets)
            ground_truth = (self.targets >= self.error_theta).float()
            if sampled:
                preds = torch.zeros((len(ground_truth))).to(self.device)
                for index in range(len(ground_truth)):
                    if self._vote_positive(self.logits[index]):
                        preds[index] = 1
            else:
                self.error_theta = torch.mean(self.logits)
                preds = (self.logits >= self.error_theta).float()

        test_size = len(preds)
        flat_preds = preds.reshape(-1)
        # Masks are built separately and moved to one device, so ground truth
        # and predictions living on different devices are still comparable.
        truth_pos = (ground_truth.reshape(-1) == 1).to(flat_preds.device)
        truth_neg = (ground_truth.reshape(-1) == 0).to(flat_preds.device)
        fn = float(torch.sum(truth_pos & (flat_preds == 0)).item()) / test_size
        fp = float(torch.sum(truth_neg & (flat_preds == 1)).item()) / test_size
        fpn = self.theta_rate * fp + (1 - self.theta_rate) * fn
        return fn, fp, fpn

    def att(self):
        """Return |observed ATT - mean predicted effect on the treated|.

        The observed ATT is the mean factual outcome of treated examples
        minus the mean factual outcome of control examples whose
        ``indicator_random`` entry is positive.
        """
        effect = self._point_effect(self.logits)
        treated = self.treatments > 0
        random_controls = (1 - self.treatments) * self.indicator_random > 0
        observed_att = (torch.mean(self.factual_outcome[treated])
                        - torch.mean(self.factual_outcome[random_controls]))
        predicted_att = torch.mean(effect[treated])
        return torch.abs(observed_att - predicted_att).item()

    def ate(self):
        """Return |mean predicted effect - mean target effect|."""
        effect = self._point_effect(self.logits)
        return torch.abs(torch.mean(effect) - torch.mean(self.targets)).item()

    def pehe(self,):
        """Return the root mean squared error between predictions and targets."""
        effect = self._point_effect(self.logits)
        return torch.sqrt(torch.mean((effect - self.targets) ** 2)).item()

    def rpol(self, t, yf, eff_pred):
        '''Return the policy risk (1 - policy value) of "treat iff effect > 0".

        Only examples whose ``indicator_random`` entry is positive are used.

        Args:
            t: treatment
            yf: factual results
            eff_pred: the difference between f(x,1) and f(x,0), i.e.
                f(x,1)-f(x,0)

        Returns:
            1 - policy value as a Python float, or ``np.nan`` when any
            selected prediction is NaN.
        '''
        eff_pred = self._point_effect(eff_pred)

        keep = self.indicator_random > 0
        t, yf, eff_pred = t[keep], yf[keep], eff_pred[keep]
        if torch.any(torch.isnan(eff_pred)):
            # A single scalar, so the return type matches the normal path.
            return np.nan

        policy = eff_pred > 0
        agrees = policy == t
        treat_overlap = agrees & (t > 0)
        control_overlap = agrees & (t < 1)

        # An arm with no example matching the policy contributes a value of 0.
        treat_value = 0
        if torch.sum(treat_overlap) != 0:
            treat_value = torch.mean(yf[treat_overlap])
        control_value = 0
        if torch.sum(control_overlap) != 0:
            control_value = torch.mean(yf[control_overlap])

        pit = torch.mean(policy.to(torch.float32))
        policy_value = pit * treat_value + (1 - pit) * control_value
        return 1 - policy_value.item()

class SKlearnEvaluator(object):
    """Score treatment-effect predictions held in NumPy arrays.

    Unlike a batch-accumulating evaluator, :meth:`collect` stores the arrays
    it is given and replaces anything stored before; :meth:`evaluate`
    computes every metric and then clears the stored arrays.

    Config keys read by this class: ``error_theta``, ``theta_rate`` and
    ``dataset``.
    """

    # Datasets for which policy risk and ATT are reported.
    _FACTUAL_DATASETS = ('Jobs',)

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.reset()
        self.error_theta = config['error_theta']
        self.theta_rate = config['theta_rate']

    def reset(self):
        """Drop all stored arrays."""
        self.targets = None
        self.logits = None
        self.treatments = None
        self.factual_outcome = None
        self.indicator_random = None

    def collect(self, ground_truth, prediction, treatment, indicator_random,
                factual_outcome):
        """Store one full set of arrays, overwriting any previous set.

        Args:
            ground_truth: target effect values.
            prediction: predicted effect, f(x,1) - f(x,0).
            treatment: true treatments.
            indicator_random: mask used by :meth:`rpol` and :meth:`att` to
                select examples (entries > 0 are kept).
            factual_outcome: observed outcomes.
        """
        self.targets = ground_truth
        self.logits = prediction
        self.treatments = treatment
        self.factual_outcome = factual_outcome
        self.indicator_random = indicator_random

    def evaluate(self):
        """Compute all metrics over the stored arrays, then reset.

        Returns:
            dict with keys ``out_pehe``, ``out_ate``, ``out_rpol``,
            ``out_fn``, ``out_fp``, ``out_fpn`` and ``out_att``. ``out_rpol``
            and ``out_att`` are 0 unless the dataset is ``'Jobs'``.

        Raises:
            RuntimeError: if nothing has been collected.
        """
        if self.targets is None or self.logits is None:
            raise RuntimeError(
                'evaluate() called before any data was collected')

        pehe = self.pehe()
        ate = self.ate()
        if self._is_factual_dataset():
            rpol = self.rpol(self.treatments, self.factual_outcome,
                             self.logits)
            att = self.att()
        else:
            rpol = 0
            att = 0
        fn, fp, fpn = self.fnp()
        results = {
            'out_pehe': pehe,
            'out_ate': ate,
            'out_rpol': rpol,
            'out_fn': fn,
            'out_fp': fp,
            'out_fpn': fpn,
            'out_att': att
        }
        self.reset()
        return results

    def _is_factual_dataset(self):
        """Return True when the configured dataset is ``'Jobs'``."""
        return self.config['dataset'] in self._FACTUAL_DATASETS

    def fnp(self):
        """Return false-negative rate, false-positive rate and their mix.

        Targets are binarised at the mean factual outcome on ``'Jobs'`` and
        at their own mean otherwise; predictions are binarised at their own
        mean. ``self.error_theta`` ends up holding the prediction threshold.

        Returns:
            (fn, fp, fpn), where
            ``fpn = theta_rate * fp + (1 - theta_rate) * fn``.
        """
        # NOTE(review): on 'Jobs' the targets are thresholded by the factual
        # outcome mean rather than the factual outcome being used directly;
        # confirm this is the intended ground truth.
        if self._is_factual_dataset():
            self.error_theta = np.mean(self.factual_outcome)
        else:
            self.error_theta = np.mean(self.targets)
        ground_truth = (self.targets >= self.error_theta).astype(float)

        self.error_theta = np.mean(self.logits)
        preds = (self.logits >= self.error_theta).astype(float)

        test_size = len(preds)
        # Flatten so (N,) and (N, 1) inputs are compared element by element
        # instead of being broadcast against each other.
        flat_truth = np.ravel(ground_truth)
        flat_preds = np.ravel(preds)
        fn = float(np.sum((flat_truth == 1.) & (flat_preds == 0))) / test_size
        fp = float(np.sum((flat_truth == 0.) & (flat_preds == 1))) / test_size
        fpn = self.theta_rate * fp + (1 - self.theta_rate) * fn
        return fn, fp, fpn

    def att(self):
        """Return |observed ATT - mean predicted effect on the treated|.

        The observed ATT is the mean factual outcome of treated examples
        minus the mean factual outcome of control examples whose
        ``indicator_random`` entry is positive.
        """
        treated = self.treatments > 0
        random_controls = (1 - self.treatments) * self.indicator_random > 0
        observed_att = (np.mean(self.factual_outcome[treated])
                        - np.mean(self.factual_outcome[random_controls]))
        predicted_att = np.mean(self.logits[treated])
        return np.abs(observed_att - predicted_att)

    def ate(self):
        """Return |mean predicted effect - mean target effect|."""
        return np.abs(np.mean(self.logits) - np.mean(self.targets))

    def pehe(self,):
        """Return the root mean squared error between predictions and targets."""
        return np.sqrt(np.mean((self.logits - self.targets) ** 2))

    def rpol(self, t, yf, eff_pred):
        '''Return the policy risk (1 - policy value) of "treat iff effect > 0".

        Only examples whose ``indicator_random`` entry is positive are used.

        Args:
            t: treatment
            yf: factual results
            eff_pred: the difference between f(x,1) and f(x,0), i.e.
                f(x,1)-f(x,0)

        Returns:
            1 - policy value, or ``np.nan`` when any selected prediction is
            NaN.
        '''
        keep = self.indicator_random > 0
        t, yf, eff_pred = t[keep], yf[keep], eff_pred[keep]
        if np.any(np.isnan(eff_pred)):
            # A single scalar, so the return type matches the normal path.
            return np.nan

        policy = eff_pred > 0
        agrees = policy == t
        treat_overlap = agrees * (t > 0)
        control_overlap = agrees * (t < 1)

        # An arm with no example matching the policy contributes a value of 0.
        treat_value = 0
        if np.sum(treat_overlap) != 0:
            treat_value = np.mean(yf[treat_overlap])
        control_value = 0
        if np.sum(control_overlap) != 0:
            control_value = np.mean(yf[control_overlap])

        pit = np.mean(policy.astype(float))
        policy_value = pit * treat_value + (1 - pit) * control_value
        return 1 - policy_value