# -*- encoding: utf-8 -*-
import numpy as np
import torch

class OutCausalEvaluator(object):
    """Accumulate batched model outputs and score predicted treatment effects.

    Batches are appended with :meth:`collect`; :meth:`evaluate` then computes
    every metric from ``self.logits`` (the predicted effect
    ``f(x,1) - f(x,0)``) and clears the buffers.

    For the models listed in ``_SAMPLING_MODELS`` the predictions are treated
    as carrying an extra trailing axis that is averaged before point metrics
    are computed.
    NOTE(review): shapes are assumed to be ``(N, 1)`` for per-unit tensors and
    ``(N, 1, S)`` for sampling-model predictions -- confirm against the caller.
    """

    # Models whose predictions keep a trailing per-sample axis.
    _SAMPLING_MODELS = ('UITE_MMD', 'UITE_WASS')
    # Datasets for which the policy-risk and ATT metrics are computed.
    _POLICY_DATASETS = ('Jobs',)

    def __init__(self, config):
        """Read ``decision_boundary``, ``error_theta`` and ``device`` from *config*.

        ``config['model']`` and ``config['dataset']`` are looked up later, when
        metrics are computed.
        """
        super().__init__()

        self.config = config
        # evaluate() replaces this with the mean of the collected targets.
        self.decision_boundary = config['decision_boundary']
        self.error_theta = config['error_theta']
        self.device = config['device']
        # Define every buffer (including treat_pred / control_pred) up front.
        self.reset()

    def reset(self):
        """Drop all collected tensors."""
        self.targets = None
        self.logits = None
        self.treatments = None
        self.factual_outcome = None
        self.indicator_random = None
        self.treat_pred = None
        self.control_pred = None

    def collect(self, ground_truth, prediction, treatment, indicator_random,
                factual_outcome, treat_pred, control_pred):
        """Append one batch to the buffers (concatenated along dim 0).

        Args:
            ground_truth: stored as ``targets``.
            prediction: predicted effect ``f(x,1) - f(x,0)``, stored as ``logits``.
            treatment: true treatments.
            indicator_random: stored as-is; a positive entry selects the unit
                in :meth:`rpol` and for the control group of :meth:`att`.
            factual_outcome: observed outcomes.
            treat_pred: stored as ``treat_pred``; presumably ``f(x,1)``.
            control_pred: stored as ``control_pred``; presumably ``f(x,0)``.
        """
        batch = {
            'targets': ground_truth,
            'logits': prediction,
            'treatments': treatment,
            'factual_outcome': factual_outcome,
            'indicator_random': indicator_random,
            'treat_pred': treat_pred,
            'control_pred': control_pred,
        }
        first_batch = self.targets is None
        for name, value in batch.items():
            if not first_batch:
                value = torch.cat([getattr(self, name), value], dim=0)
            setattr(self, name, value)

    def evaluate(self):
        """Compute all metrics on the collected data, then reset the buffers.

        Returns:
            dict with keys ``out_pehe``, ``out_ate``, ``out_rpol``, ``out_fn``,
            ``out_fp``, ``out_fpn`` and ``out_att``. ``out_rpol`` and
            ``out_att`` are ``0`` unless the dataset is in ``_POLICY_DATASETS``.
        """
        self.targets = self.targets.squeeze(-1)
        self.decision_boundary = torch.mean(self.targets)

        # Sampling models drop dim 1 and keep the trailing sample axis.
        drop_dim = 1 if self._is_sampling_model() else -1
        self.logits = self.logits.squeeze(drop_dim)
        self.treat_pred = self.treat_pred.squeeze(drop_dim)
        self.control_pred = self.control_pred.squeeze(drop_dim)

        self.treatments = self.treatments.squeeze(-1)
        self.factual_outcome = self.factual_outcome.squeeze(-1)
        self.indicator_random = self.indicator_random.squeeze(-1)

        policy_dataset = self._is_policy_dataset()
        pehe = self.pehe()
        ate = self.ate()
        rpol = self.rpol(self.treatments, self.factual_outcome, self.logits) if policy_dataset else 0
        att = self.att() if policy_dataset else 0
        fn, fp, fpn = self.fnp()
        results = {
            'out_pehe': pehe,
            'out_ate': ate,
            'out_rpol': rpol,
            'out_fn': fn,
            'out_fp': fp,
            'out_fpn': fpn,
            'out_att': att
        }
        self.reset()
        return results

    def _is_sampling_model(self):
        """Return True when the configured model keeps a trailing sample axis."""
        return self.config['model'] in self._SAMPLING_MODELS

    def _is_policy_dataset(self):
        """Return True when the configured dataset is in ``_POLICY_DATASETS``."""
        return self.config['dataset'] in self._POLICY_DATASETS

    def _point_effect(self, effect):
        """Average *effect* over its last axis for sampling models, else pass through."""
        if self._is_sampling_model():
            return torch.mean(effect, dim=-1)
        return effect

    def _vote(self, samples):
        """Turn per-unit sample rows into 0/1 decisions.

        For each row, ``p`` is the fraction of samples at or above that row's
        own mean; the decision is 1 when
        ``error_theta * (1 - p) < (1 - error_theta) * p``.
        """
        row_boundary = samples.mean(dim=-1, keepdim=True)
        p = (samples >= row_boundary).to(torch.float32).mean(dim=-1)
        decide_positive = self.error_theta * (1 - p) < (1 - self.error_theta) * p
        return decide_positive.to(torch.float32)

    def fnp(self):
        """Return ``(fn, fp, fpn)``: false-negative rate, false-positive rate, weighted mix.

        Both rates are divided by the total number of units, and
        ``fpn = error_theta * fp + (1 - error_theta) * fn``.

        Ground truth is the factual outcome for ``_POLICY_DATASETS`` and
        ``targets >= decision_boundary`` otherwise. ``self.decision_boundary``
        is only read here, never modified.
        """
        sampling = self._is_sampling_model()
        if self._is_policy_dataset():
            ground_truth = self.factual_outcome
            if sampling:
                # Use the prediction for the arm each unit actually received.
                treated = (self.treatments > 0).unsqueeze(-1)
                preds = self._vote(torch.where(treated, self.treat_pred, self.control_pred))
            else:
                # NOTE(review): raw predictions are compared to exactly 0 / 1
                # below; presumably they are already binary -- confirm.
                preds = torch.zeros((len(ground_truth))).to(self.device)
                preds[self.treatments > 0] = self.treat_pred[self.treatments > 0]
                preds[self.treatments < 1] = self.control_pred[self.treatments < 1]
        else:
            ground_truth = self.targets >= self.decision_boundary
            if sampling:
                preds = self._vote(self.logits)
            else:
                preds = self.logits >= torch.mean(self.logits)

        test_size = len(preds)
        ground_truth = ground_truth.to(preds.device)
        missed = (ground_truth == 1) & (preds == 0)
        false_alarm = (ground_truth == 0) & (preds == 1)
        fn = torch.sum(missed).item() / test_size
        fp = torch.sum(false_alarm).item() / test_size

        fpn = self.error_theta * fp + (1 - self.error_theta) * fn
        return fn, fp, fpn

    def att(self):
        """Return the absolute error between observed and predicted effect on the treated.

        The observed value is the mean factual outcome of treated units minus
        the mean factual outcome of control units with a positive
        ``indicator_random``; the predicted value is the mean predicted effect
        over treated units.
        """
        logits = self._point_effect(self.logits)
        treated = self.treatments > 0
        treat_att = torch.mean(self.factual_outcome[treated])
        control_random_index = (1 - self.treatments) * self.indicator_random
        control_att = torch.mean(self.factual_outcome[control_random_index > 0])
        treat_ite = torch.mean(logits[treated])
        return torch.abs((treat_att - control_att) - treat_ite).item()

    def ate(self):
        """Return ``|mean(predicted effect) - mean(targets)|`` as a float."""
        logits = self._point_effect(self.logits)
        return torch.abs(torch.mean(logits) - torch.mean(self.targets)).item()

    def pehe(self,):
        """Return the root mean squared error between predicted effect and targets."""
        logits = self._point_effect(self.logits)
        return torch.sqrt(torch.mean((logits - self.targets) ** 2)).item()

    def rpol(self, t, yf, eff_pred):
        '''
        Return one minus the estimated value of the policy "treat iff eff_pred > 0".

        Only units with a positive ``self.indicator_random`` are used.

        Args:
            t: treatment
            yf: factual results
            eff_pred: the difference between f(x,1) and f(x,0), e.g. f(x,1)-f(x,0)

        Returns:
            float: ``1 - policy_value``, or ``nan`` when any selected effect
            prediction is NaN (always a scalar, never a tuple).
        '''
        eff_pred = self._point_effect(eff_pred)

        selected = self.indicator_random > 0
        t, yf, eff_pred = t[selected], yf[selected], eff_pred[selected]
        if torch.any(torch.isnan(eff_pred)):
            return np.nan

        policy = eff_pred > 0
        agrees = policy == t
        treat_overlap = agrees & (t > 0)
        control_overlap = agrees & (t < 1)

        # A group with no policy-consistent units contributes value 0.
        treat_value = torch.mean(yf[treat_overlap]) if torch.sum(treat_overlap) > 0 else 0
        control_value = torch.mean(yf[control_overlap]) if torch.sum(control_overlap) > 0 else 0

        pit = torch.mean(policy.to(torch.float32))
        policy_value = pit * treat_value + (1 - pit) * control_value
        return 1 - policy_value.item()


class InCausalEvaluator(object):
    """Accumulate batched model outputs and score effects derived from factual outcomes.

    Batches are appended with :meth:`collect`. :meth:`evaluate` builds
    ``self.in_logits`` -- ``factual_outcome - control_pred`` for treated units
    and ``treat_pred - factual_outcome`` for the others -- computes every
    metric from it and clears the buffers.

    For the models listed in ``_SAMPLING_MODELS`` the predictions are treated
    as carrying an extra trailing axis that is averaged before point metrics
    are computed.
    NOTE(review): shapes are assumed to be ``(N, 1)`` or ``(N,)`` for per-unit
    tensors and ``(N, 1, S)`` for sampling-model predictions -- confirm against
    the caller.
    """

    # Models whose predictions keep a trailing per-sample axis.
    _SAMPLING_MODELS = ('UITE_MMD', 'UITE_WASS')
    # Datasets for which the policy-risk and ATT metrics are computed.
    _POLICY_DATASETS = ('Jobs',)

    def __init__(self, config):
        """Read ``decision_boundary``, ``error_theta`` and ``device`` from *config*.

        ``config['model']`` and ``config['dataset']`` are looked up later, when
        metrics are computed.
        """
        super().__init__()

        self.config = config
        # evaluate() replaces this with the mean of the collected targets.
        self.decision_boundary = config['decision_boundary']
        self.error_theta = config['error_theta']
        self.device = config['device']
        # Define every buffer (including treat_pred / control_pred) up front.
        self.reset()

    def reset(self):
        """Drop all collected tensors."""
        self.targets = None
        self.logits = None
        self.treatments = None
        self.factual_outcome = None
        self.indicator_random = None
        self.treat_pred = None
        self.control_pred = None

    def collect(self, ground_truth, prediction, treatment, indicator_random,
                factual_outcome, treat_pred, control_pred):
        """Append one batch to the buffers (concatenated along dim 0).

        Args:
            ground_truth: stored as ``targets``.
            prediction: stored as ``logits``; the metrics of this class are
                computed from ``in_logits`` instead.
            treatment: true treatments.
            indicator_random: stored as-is; a positive entry selects the unit
                in :meth:`rpol` and for the control group of :meth:`att`.
            factual_outcome: observed outcomes.
            treat_pred: stored as ``treat_pred``; presumably ``f(x,1)``.
            control_pred: stored as ``control_pred``; presumably ``f(x,0)``.
        """
        batch = {
            'targets': ground_truth,
            'logits': prediction,
            'treatments': treatment,
            'factual_outcome': factual_outcome,
            'indicator_random': indicator_random,
            'treat_pred': treat_pred,
            'control_pred': control_pred,
        }
        first_batch = self.targets is None
        for name, value in batch.items():
            if not first_batch:
                value = torch.cat([getattr(self, name), value], dim=0)
            setattr(self, name, value)

    def evaluate(self):
        """Compute all metrics on the collected data, then reset the buffers.

        Returns:
            dict with keys ``in_pehe``, ``in_ate``, ``in_rpol``, ``in_fn``,
            ``in_fp``, ``in_fpn`` and ``in_att``. ``in_rpol`` and ``in_att``
            are ``0`` unless the dataset is in ``_POLICY_DATASETS``.
        """
        self.targets = self.targets.squeeze(-1)
        self.decision_boundary = torch.mean(self.targets)
        self.treatments = self.treatments.squeeze(-1)

        sampling = self._is_sampling_model()
        # One outcome per unit as an (N, 1) column, so it broadcasts across
        # the sample axis whether it arrived as (N,) or (N, 1).
        factual = self.factual_outcome.reshape(-1, 1)
        if sampling:
            self.logits = self.logits.squeeze(1)
            self.treat_pred = self.treat_pred.squeeze(1)
            self.control_pred = self.control_pred.squeeze(1)
            self.in_logits = torch.zeros(len(self.logits), len(self.treat_pred[0])).to(self.device)
        else:
            self.logits = self.logits.squeeze(-1)
            self.treat_pred = self.treat_pred.squeeze(-1)
            self.control_pred = self.control_pred.squeeze(-1)
            self.in_logits = torch.zeros(len(self.logits)).to(self.device)
            # Point predictions are 1-D, so the outcome must be 1-D as well;
            # otherwise (k, 1) - (k,) would broadcast to (k, k).
            factual = factual[:, 0]

        treated = self.treatments > 0
        untreated = self.treatments < 1
        self.in_logits[treated] = factual[treated] - self.control_pred[treated]
        self.in_logits[untreated] = self.treat_pred[untreated] - factual[untreated]

        self.factual_outcome = self.factual_outcome.squeeze(-1)
        self.indicator_random = self.indicator_random.squeeze(-1)

        policy_dataset = self._is_policy_dataset()
        pehe = self.pehe()
        ate = self.ate()
        att = self.att() if policy_dataset else 0
        rpol = self.rpol(self.treatments, self.factual_outcome, self.in_logits) if policy_dataset else 0
        fn, fp, fpn = self.fnp()
        results = {
            'in_pehe': pehe,
            'in_ate': ate,
            'in_rpol': rpol,
            'in_fn': fn,
            'in_fp': fp,
            'in_fpn': fpn,
            'in_att': att
        }
        self.reset()
        return results

    def _is_sampling_model(self):
        """Return True when the configured model keeps a trailing sample axis."""
        return self.config['model'] in self._SAMPLING_MODELS

    def _is_policy_dataset(self):
        """Return True when the configured dataset is in ``_POLICY_DATASETS``."""
        return self.config['dataset'] in self._POLICY_DATASETS

    def _point_effect(self, effect):
        """Average *effect* over its last axis for sampling models, else pass through."""
        if self._is_sampling_model():
            return torch.mean(effect, dim=-1)
        return effect

    def _vote(self, samples):
        """Turn per-unit sample rows into 0/1 decisions.

        For each row, ``p`` is the fraction of samples at or above
        ``self.decision_boundary``; the decision is 1 when
        ``error_theta * (1 - p) < (1 - error_theta) * p``.
        """
        p = (samples >= self.decision_boundary).to(torch.float32).mean(dim=-1)
        decide_positive = self.error_theta * (1 - p) < (1 - self.error_theta) * p
        return decide_positive.to(torch.float32)

    def fnp(self):
        """Return ``(fn, fp, fpn)``: false-negative rate, false-positive rate, weighted mix.

        Both rates are divided by the total number of units, and
        ``fpn = error_theta * fp + (1 - error_theta) * fn``.

        Ground truth is the factual outcome for ``_POLICY_DATASETS`` and
        ``targets >= decision_boundary`` otherwise. ``self.decision_boundary``
        is only read here, never modified.
        """
        sampling = self._is_sampling_model()
        if self._is_policy_dataset():
            ground_truth = self.factual_outcome
            if sampling:
                # Use the prediction for the arm each unit actually received.
                treated = (self.treatments > 0).unsqueeze(-1)
                preds = self._vote(torch.where(treated, self.treat_pred, self.control_pred))
            else:
                # NOTE(review): raw predictions are compared to exactly 0 / 1
                # below; presumably they are already binary -- confirm.
                preds = torch.zeros((len(ground_truth))).to(self.device)
                preds[self.treatments > 0] = self.treat_pred[self.treatments > 0]
                preds[self.treatments < 1] = self.control_pred[self.treatments < 1]
        else:
            ground_truth = self.targets >= self.decision_boundary
            if sampling:
                preds = self._vote(self.in_logits)
            else:
                preds = self.in_logits >= torch.mean(self.in_logits)

        test_size = len(preds)
        ground_truth = ground_truth.to(preds.device)
        missed = (ground_truth == 1) & (preds == 0)
        false_alarm = (ground_truth == 0) & (preds == 1)
        fn = torch.sum(missed).item() / test_size
        fp = torch.sum(false_alarm).item() / test_size

        fpn = self.error_theta * fp + (1 - self.error_theta) * fn
        return fn, fp, fpn

    def att(self):
        """Return the absolute error between observed and derived effect on the treated.

        The observed value is the mean factual outcome of treated units minus
        the mean factual outcome of control units with a positive
        ``indicator_random``; the derived value is the mean of ``in_logits``
        over treated units.
        """
        logits = self._point_effect(self.in_logits)
        treated = self.treatments > 0
        treat_att = torch.mean(self.factual_outcome[treated])
        control_random_index = (1 - self.treatments) * self.indicator_random
        control_att = torch.mean(self.factual_outcome[control_random_index > 0])
        treat_ite = torch.mean(logits[treated])
        return torch.abs((treat_att - control_att) - treat_ite).item()

    def ate(self):
        """Return ``|mean(in_logits) - mean(targets)|`` as a float."""
        logits = self._point_effect(self.in_logits)
        return torch.abs(torch.mean(logits) - torch.mean(self.targets)).item()

    def pehe(self,):
        """Return the root mean squared error between ``in_logits`` and targets."""
        logits = self._point_effect(self.in_logits)
        return torch.sqrt(torch.mean((logits - self.targets) ** 2)).item()

    def rpol(self, t, yf, eff_pred):
        '''
        Return one minus the estimated value of the policy "treat iff eff_pred > 0".

        Only units with a positive ``self.indicator_random`` are used.

        Args:
            t: treatment
            yf: factual results
            eff_pred: the difference between f(x,1) and f(x,0), e.g. f(x,1)-f(x,0)

        Returns:
            float: ``1 - policy_value``, or ``nan`` when any selected effect
            prediction is NaN (always a scalar, never a tuple).
        '''
        eff_pred = self._point_effect(eff_pred)

        selected = self.indicator_random > 0
        t, yf, eff_pred = t[selected], yf[selected], eff_pred[selected]
        if torch.any(torch.isnan(eff_pred)):
            return np.nan

        policy = eff_pred > 0
        agrees = policy == t
        treat_overlap = agrees & (t > 0)
        control_overlap = agrees & (t < 1)

        # A group with no policy-consistent units contributes value 0.
        treat_value = torch.mean(yf[treat_overlap]) if torch.sum(treat_overlap) > 0 else 0
        control_value = torch.mean(yf[control_overlap]) if torch.sum(control_overlap) > 0 else 0

        pit = torch.mean(policy.to(torch.float32))
        policy_value = pit * treat_value + (1 - pit) * control_value
        return 1 - policy_value.item()