import torch
from sklearn.linear_model import LogisticRegression

class Intervention():
    """Generate counterfactual node features under an intervention on s using a CVAE.

    The CVAE is put into eval mode at construction. ``do`` re-encodes the
    observed features together with the intervened signal and decodes them.
    ``cal_causal`` (a hard metric) and ``cal_loss`` (a differentiable
    surrogate) compare a classifier's outputs on features produced under two
    interventions (do s+ vs. do s-).
    """

    def __init__(self, args, CVAE):
        """
        Args:
            args: Config namespace; ``cal_loss`` reads ``args.protect``.
            CVAE: Model exposing ``gnn``, ``encoder``, ``reparameterize`` and
                ``decoder``; switched to eval mode here.
        """
        self.args = args
        self.CVAE = CVAE

        self.CVAE.eval()

    def do(self, data, s):
        """Return decoded features for every node under the intervention ``s``.

        Args:
            data: Graph object with ``x``, ``z`` and ``edge_index`` attributes.
                ``x`` and ``z`` are concatenated along dim 1, so they are
                presumably 2-D (num_nodes, features) -- confirm at call site.
            s: Intervened sensitive-attribute values, passed to ``CVAE.gnn``.

        Returns:
            The ``CVAE.decoder`` output, computed without gradient tracking.
        """
        with torch.no_grad():
            # Embed the intervened s with the CVAE's GNN over data.edge_index.
            a = self.CVAE.gnn(s, data.edge_index)
            xa = torch.cat((data.x, a, data.z), dim=1)
            mu, logvar = self.CVAE.encoder(xa)
            z = self.CVAE.reparameterize(mu, logvar)
            # The decoder is conditioned on the same intervened signal and data.z.
            intervened_x = self.CVAE.decoder(torch.cat((z, a, data.z), dim=1))

        return intervened_x

    def cal_causal(self, intervened_x_pos, intervened_x_neg, classifier, verbose=False):
        """Difference in predicted-positive rates between do(s+) and do(s-).

        A prediction counts as positive when its logit is strictly > 0. Each
        rate is the number of positive entries divided by the number of rows
        (i.e. the mean when the classifier emits a single column).

        Args:
            intervened_x_pos: Features produced under do(s+).
            intervened_x_neg: Features produced under do(s-).
            classifier: Callable mapping features to logits.
            verbose: If true, print the counts and rates for both interventions.

        Returns:
            float: rate(do s+) - rate(do s-).
        """
        # The result is reduced to a Python float, so no autograd graph is needed.
        with torch.no_grad():
            causal_pos = classifier(intervened_x_pos)
            causal_neg = classifier(intervened_x_neg)

        causal_pos_label = torch.where(causal_pos > 0, 1.0, 0.0)
        causal_neg_label = torch.where(causal_neg > 0, 1.0, 0.0)

        pos_rate = (torch.sum(causal_pos_label) / causal_pos_label.shape[0]).item()
        neg_rate = (torch.sum(causal_neg_label) / causal_neg_label.shape[0]).item()

        if verbose:
            print(f'y+^hat|do s+: {torch.sum(causal_pos_label)} / s+: {causal_pos_label.shape[0]}, {pos_rate}')
            print(f'y+^hat|do s-: {torch.sum(causal_neg_label)} / s-: {causal_neg_label.shape[0]}, {neg_rate}')

        return pos_rate - neg_rate

    def cal_loss(self, intervened_x_pos, intervened_x_neg, classifier):
        """Differentiable surrogate for the gap between do(s+) and do(s-) logits.

        Logits under do(s+) and negated logits under do(s-) are both multiplied
        by ``1 - 2 * args.protect`` (presumably protect is 0 or 1, giving +1 /
        -1). They are concatenated along dim 1, passed through a softplus-style
        surrogate, mapped to ``2 * surrogate - 1`` and averaged.

        Returns:
            Scalar tensor with gradients flowing into ``classifier``.
        """
        causal_pos = classifier(intervened_x_pos)
        causal_neg = classifier(intervened_x_neg)
        _s = 1 - 2 * self.args.protect
        causal = torch.cat((causal_pos * _s, -causal_neg * _s), dim=1)

        def _ApproxHuber(z):
            # softplus(z) = log(1 + e^z) for z < 10. For z >= 10 the exp
            # argument is masked to 0 (log(1 + e^0) = ln 2 ~ 0.6931), and the
            # linear term (z - 0.6931) brings the total back to ~z. This
            # avoids exp overflow on large inputs.
            mask = torch.where(z < 10, 1.0, 0.0)
            return torch.log(1 + torch.exp(z * mask)) + (z - 0.6931) * (1 - mask)

        return torch.mean(torch.div(_ApproxHuber(causal), 0.5) - 1)
    
class InterventionByMLP(Intervention):
    """Intervention variant that maps (z, GNN(s)) straight to features.

    Instead of the encoder/decoder path, ``do`` feeds the concatenation of
    ``data.z`` and the GNN output for ``s`` into ``CVAE.MLP``. The constructor
    is inherited unchanged from ``Intervention``.
    """

    def do(self, data, s):
        """Return ``CVAE.MLP(cat[data.z, CVAE.gnn(s, data.edge_index)])``.

        The computation runs under ``torch.no_grad()``, so the result carries
        no autograd history.
        """
        with torch.no_grad():
            s_embed = self.CVAE.gnn(s, data.edge_index)
            return self.CVAE.MLP(torch.cat([data.z, s_embed], dim=1))
    
class IIDIntervention():
    """Estimate interventional outcomes by reweighting classifier probabilities.

    A logistic regression fits P(s | z) on ``data.z``. The per-node weights
    ``mean(data.s) / P(s+ | z)`` and ``(1 - mean(data.s)) / P(s- | z)``
    (stabilized inverse-propensity weights) reweight sigmoid classifier
    outputs to estimate E[y_hat | do(s+)] and E[y_hat | do(s-)].
    """

    def __init__(self, args, data):
        """
        Args:
            args: Config namespace; ``cal_loss`` reads ``args.protect``.
            data: Object with ``z``, ``s`` and ``x`` tensors. ``s`` is squeezed
                on dim 1, so it is presumably (num_nodes, 1). It is assumed
                0/1-coded, since its mean is used as P(s=1) -- confirm.
        """
        self.args = args
        z_np = data.z.cpu().numpy()
        self.reg = LogisticRegression(max_iter=1000).fit(z_np, data.s.squeeze(1).cpu().numpy())
        # predict_proba columns follow self.reg.classes_ (sorted ascending),
        # so column 1 is the larger s label and column 0 the smaller.
        prob_sz = torch.tensor(self.reg.predict_proba(z_np)).to(data.x.device)

        prob_sz_pos = prob_sz[:, 1]
        prob_sz_neg = prob_sz[:, 0]

        prob_s_pos = torch.mean(data.s)
        prob_s_neg = 1 - prob_s_pos

        # Per-node weights: marginal P(s) over the fitted P(s | z).
        self.do_pos = prob_s_pos / prob_sz_pos
        self.do_neg = prob_s_neg / prob_sz_neg

    def do(self, s, y_pred):
        """Reweighted mean of sigmoid(y_pred) under the intervention selected by ``s``.

        Args:
            s: Intervention indicator; a mean of 1 selects do(s+) and a mean
               of -1 selects do(s-).
            y_pred: Classifier logits, squeezed on dim 1.

        Returns:
            Scalar tensor ``mean(weight * sigmoid(y_pred))``.

        Raises:
            ValueError: if ``mean(s)`` is neither 1 nor -1. Previously this case
                silently returned ``None``.
        """
        prob_yx = torch.sigmoid(y_pred.squeeze(1))
        s_mean = torch.mean(s)
        if s_mean == 1:
            return torch.mean(self.do_pos * prob_yx)  # do_pos_y
        if s_mean == -1:
            return torch.mean(self.do_neg * prob_yx)  # do_neg_y
        raise ValueError(f'IIDIntervention.do expects s with mean 1 (do s+) or -1 (do s-); got mean {s_mean.item()}')

    def cal_causal(self, do_pos_y, do_neg_y, verbose=True):
        """Return ``do_pos_y - do_neg_y``, optionally printing both terms."""
        if verbose:
            print(f'y+^hat|do s+: {do_pos_y}')
            print(f'y+^hat|do s-: {do_neg_y}')

        return do_pos_y - do_neg_y

    def cal_loss(self, y_pred):
        """Differentiable estimate of the do(s+) minus do(s-) gap.

        The gap is sign-flipped by ``1 - 2 * args.protect``. It is computed as
        ``mean(_s * (do_pos - do_neg) * sigmoid(y_pred))``.
        """
        _s = 1 - 2 * self.args.protect

        prob_yx = torch.sigmoid(y_pred.squeeze(1))

        return torch.mean(_s * (self.do_pos - self.do_neg) * prob_yx)
    
class InterventionForGNN(Intervention):
    """Intervention whose downstream classifier also consumes the graph.

    Identical to ``Intervention`` except that ``cal_causal`` calls the
    classifier as ``classifier(x, edge_index)``.
    """

    def cal_causal(self, intervened_x_pos, intervened_x_neg, edge_index, classifier, verbose=False):
        """Difference in predicted-positive rates between do(s+) and do(s-).

        Delegates to ``Intervention.cal_causal`` with the classifier bound to
        ``edge_index``. Thresholding, verbose output and the returned float are
        therefore defined in one place instead of being duplicated here.

        Args:
            intervened_x_pos: Features produced under do(s+).
            intervened_x_neg: Features produced under do(s-).
            edge_index: Graph connectivity passed as the classifier's second
                positional argument.
            classifier: Callable ``classifier(x, edge_index)`` returning logits.
            verbose: If true, print the counts and rates for both interventions.

        Returns:
            float: rate(logit > 0 | do s+) - rate(logit > 0 | do s-).
        """
        def graph_classifier(x):
            # Preserve the original positional call form classifier(x, edge_index).
            return classifier(x, edge_index)

        return super().cal_causal(intervened_x_pos, intervened_x_neg, graph_classifier, verbose=verbose)
        
        
        
    
        
    