import importlib
import os
import sys
import time

import torch
import torch.nn as nn

import datasets

from args import parse_args

from util.general_util import (
    parse_configs_file,
    setup_seed,
    FairMetrics,
    Logger
)
from models.cvae import cvae
from intervention import Intervention, IIDIntervention
from riskdifference import LogisticRiskDiff
from baselines.FairGNN import FairGNN



class InterventionForFairGNN(Intervention):
    """Intervention whose causal effect is scored by a ``(edge_index, x)`` classifier.

    ``cal_causal`` calls the classifier as ``classifier(edge_index, x)`` and uses
    the first element of the returned pair as the prediction score.
    """

    def __init__(self, args, CVAE):
        # Arguments are forwarded unchanged to ``Intervention.__init__``.
        super().__init__(args, CVAE)

    def cal_causal(self, intervened_x_pos, intervened_x_neg, edge_index, classifier, verbose=False):
        """Return the gap in positive-prediction rate between two interventions.

        Args:
            intervened_x_pos: Features passed to the classifier for the
                ``do s+`` intervention.
            intervened_x_neg: Features passed to the classifier for the
                ``do s-`` intervention.
            edge_index: Graph connectivity passed as the classifier's first argument.
            classifier: Callable returning a pair; the first element is the
                score, and a score ``> 0`` counts as a positive prediction.
            verbose: When true, print the positive counts, totals and rates.

        Returns:
            float: positive rate under ``do s+`` minus positive rate under
            ``do s-``. Each rate is the number of positive predictions divided
            by the size of the first dimension of the classifier output.
        """
        # Only thresholded predictions are used below, so no autograd graph is
        # needed for these forward passes.
        with torch.no_grad():
            causal_pos, _ = classifier(edge_index, intervened_x_pos)
            causal_neg, _ = classifier(edge_index, intervened_x_neg)

        causal_pos_label = torch.where(causal_pos > 0, 1.0, 0.0)
        causal_neg_label = torch.where(causal_neg > 0, 1.0, 0.0)

        pos_count = torch.sum(causal_pos_label)
        neg_count = torch.sum(causal_neg_label)
        pos_total = causal_pos_label.shape[0]
        neg_total = causal_neg_label.shape[0]

        # Compute each rate once and reuse it for both the log line and the result.
        pos_rate = (pos_count / pos_total).item()
        neg_rate = (neg_count / neg_total).item()

        if verbose:
            print(f'y+^hat|do s+: {pos_count} / s+: {pos_total}, {pos_rate}')
            print(f'y+^hat|do s-: {neg_count} / s-: {neg_total}, {neg_rate}')

        return pos_rate - neg_rate


def inference_with_causal_iid(model, data, test_mask, intervention):
    """Print and return the causal effect estimated by an IID-style intervention.

    The model is switched to eval mode and run once on the full graph; its
    predictions are then handed to ``intervention.do`` together with an
    all-``+1`` and an all-``-1`` sensitive-attribute tensor shaped like
    ``data.s``.

    Args:
        model: Callable as ``model(edge_index, x)`` returning a pair whose
            first element is the prediction; must provide ``eval()``.
        data: Object exposing ``edge_index``, ``x`` and ``s``.
        test_mask: Currently unused; kept so existing callers keep working.
        intervention: Object providing ``do(s, y_pred)`` and
            ``cal_causal(pos, neg, verbose=...)``.

    Returns:
        Whatever ``intervention.cal_causal`` returns (also printed as
        ``iidcd:<value>``).
    """
    model.eval()

    # Plain inference: no gradients are needed for the model's predictions.
    with torch.no_grad():
        y_pred_all, _ = model(data.edge_index, data.x)

    do_pos_y = intervention.do(torch.ones_like(data.s), y_pred_all)
    do_neg_y = intervention.do(torch.zeros_like(data.s) - 1, y_pred_all)

    cd = intervention.cal_causal(do_pos_y, do_neg_y, verbose=False)
    print(f'iidcd:{cd}')
    return cd
    

def eval(model, intervention, data):
    """Print and return the causal effect of the sensitive attribute on ``model``.

    NOTE: this function's name shadows the builtin ``eval`` inside this module;
    it is kept because callers reference it by this name.

    Args:
        model: Classifier handed to ``intervention.cal_causal``; must provide
            ``eval()``.
        intervention: Object providing ``do(data, s)`` and
            ``cal_causal(pos, neg, edge_index, classifier, verbose=...)``.
        data: Object exposing ``s`` and ``edge_index``.

    Returns:
        Whatever ``intervention.cal_causal`` returns.
    """
    # Evaluate in inference mode, as inference_with_causal_iid does, so the
    # result does not depend on the mode the model was left in by training.
    model.eval()

    # Intervene with an all-(+1) and an all-(-1) sensitive attribute.
    do_pos = intervention.do(data, torch.ones_like(data.s))
    do_neg = intervention.do(data, torch.zeros_like(data.s) - 1)
    return intervention.cal_causal(do_pos, do_neg, data.edge_index, model, verbose=True)



def main():
    """Train the FairGNN baseline and print its causal-effect estimates.

    Steps, in order: parse arguments (plus an optional configs file), redirect
    ``sys.stdout`` to a ``Logger`` writing under ``logs_syn_credit``, seed the
    run, load the dataset, fit FairGNN, load the CVAE weights from
    ``weights/cvae_training.pt``, then print three estimates: the CVAE-based
    intervention, the IID intervention, and the rates obtained from the
    dataset's own ``do_pos_x`` / ``do_neg_x`` features (labelled ``GT``).
    """
    args = parse_args()
    if args.configs is not None:
        parse_configs_file(args)

    # Hard-coded settings for this baseline run. ``args.lf`` is overridden
    # regardless of what was parsed.
    args.lf = 0.9
    alpha = 600
    beta = 0.5

    result_main_dir = os.path.join('logs_syn_credit', f'seed{args.gtseed}_A{args.A}_layers{args.num_layers}_k{args.k}')
    result_sub_dir = os.path.join(result_main_dir, f'fairgnn_a{alpha}_seed{args.seed}.log')
    sys.stdout = Logger(result_sub_dir)
    print(args)

    setup_seed(args.seed)
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    D = datasets.__dict__[args.dataset](args)
    dataset = D.data_loaders(mapping_function=args.mapping_function)
    dataset = dataset.to(device)
    print(dataset)

    model = FairGNN(args.input_size, epoch=1000, subgraph_size=dataset.x.shape[0], acc=0.9, alpha=alpha, beta=beta)
    # NOTE(review): the test mask is passed twice and the train mask is reused
    # as the last argument; confirm against FairGNN.fit's signature.
    model.fit(dataset.edge_index, dataset.x, dataset.y, dataset.train_mask, dataset.test_mask, dataset.test_mask, dataset.s.squeeze(1), dataset.train_mask)
    print(model.predict(dataset.test_mask))

    CVAE = cvae(args.sensitive_size, args.hidden_channels, args.num_layers, args.A, args.input_size, args.latent_dim, args.z_size).to(device)
    # map_location keeps the CPU fallback working when the checkpoint was
    # saved from a CUDA device.
    CVAE.load_state_dict(torch.load('weights/cvae_training.pt', map_location=device))
    intervention = InterventionForFairGNN(args, CVAE)

    eval(model, intervention, dataset)

    iidintervention = IIDIntervention(args, dataset)
    inference_with_causal_iid(model, dataset, dataset.test_mask, iidintervention)

    with torch.no_grad():
        model.eval()

        gt_pos_y, _ = model(dataset.edge_index, dataset.do_pos_x)
        gt_pos_y_label = torch.where(gt_pos_y > 0, 1.0, 0.0)

        gt_neg_y, _ = model(dataset.edge_index, dataset.do_neg_x)
        gt_neg_y_label = torch.where(gt_neg_y > 0, 1.0, 0.0)

    # Compute each mean once and reuse it for the individual and combined lines.
    gt_pos_rate = torch.mean(gt_pos_y_label)
    gt_neg_rate = torch.mean(gt_neg_y_label)
    print(f'GT do s+ predict y: {gt_pos_rate}')
    print(f'GT do s- predict y: {gt_neg_rate}')
    print(f'cd:{abs(gt_pos_rate - gt_neg_rate)}')
    
    
    
# Script entry point: run the experiment only when this file is executed directly.
if __name__ == "__main__":
    main()