import importlib
import os
import sys
import time

import torch
import torch.nn as nn

import datasets

from args import parse_args

from util.general_util import (
    parse_configs_file,
    setup_seed,
    FairMetrics,
    Logger
)
from models.cvae import cvae
from intervention import Intervention, IIDIntervention
from riskdifference import LogisticRiskDiff
from baselines.GEAR import GEAR


def inference_with_causal_iid(model, data, test_mask, intervention):
    """Compute, print and return the causal discrepancy under an IID intervention.

    The model is run once (without gradients) on ``[s, x]`` over the graph to get
    classifier outputs for every node. ``intervention.do`` then produces outcomes
    with the sensitive attribute set to +1 (``ones_like``) and to -1
    (``zeros_like - 1``). ``intervention.cal_causal`` reduces the two to one value.

    Args:
        model: trained model called as ``model(features, edge_index)`` and
            exposing a ``classifier`` head.
        data: graph data object providing ``s``, ``x`` and ``edge_index``.
        test_mask: currently unused; predictions for all nodes are passed to
            the intervention. Kept so existing callers keep working.
        intervention: object exposing ``do(s, y_pred)`` and
            ``cal_causal(do_pos, do_neg, verbose=...)`` (e.g. ``IIDIntervention``).

    Returns:
        Whatever ``intervention.cal_causal`` returns (also printed as ``iidcd``).
    """
    model.eval()

    with torch.no_grad():
        features = torch.cat((data.s, data.x), dim=1)
        hidden = model(features, data.edge_index)
        y_pred_all = model.classifier(hidden)

    # Intervene with the sensitive attribute forced to +1 and to -1.
    do_pos_y = intervention.do(torch.ones_like(data.s), y_pred_all)
    do_neg_y = intervention.do(torch.zeros_like(data.s) - 1, y_pred_all)

    cd = intervention.cal_causal(do_pos_y, do_neg_y, verbose=False)
    print(f'iidcd:{cd}')
    return cd


class InterventionForGEAR(Intervention):
    """``Intervention`` whose ``cal_causal`` queries a GEAR-style graph model.

    ``cal_causal`` here receives already-intervened model inputs together with
    the graph connectivity and the model itself, runs the model on both, and
    compares the resulting positive-prediction rates.
    """

    def __init__(self, args, CVAE):
        super().__init__(args, CVAE)

    def cal_causal(self, intervened_x_pos, intervened_x_neg, edge_index, model, verbose=False):
        """Return rate(y_hat=1 | do s+) - rate(y_hat=1 | do s-).

        Args:
            intervened_x_pos: model input for the positive-intervention arm.
            intervened_x_neg: model input for the negative-intervention arm.
            edge_index: graph connectivity passed to ``model``.
            model: callable ``model(x, edge_index)`` with a ``classifier`` head;
                classifier outputs > 0 count as a positive prediction.
            verbose: if True, print the positive count, row count and rate
                for each arm.

        Returns:
            Python float: positive-prediction rate of the positive arm minus
            that of the negative arm, each averaged over the classifier-output rows.
        """
        # Only thresholded rates leave this method (as Python floats), so no
        # autograd graph is needed for the two forward passes.
        with torch.no_grad():
            z_pos = model(intervened_x_pos, edge_index)
            z_neg = model(intervened_x_neg, edge_index)

            causal_pos = model.classifier(z_pos)
            causal_neg = model.classifier(z_neg)

        causal_pos_label = torch.where(causal_pos > 0, 1.0, 0.0)
        causal_neg_label = torch.where(causal_neg > 0, 1.0, 0.0)

        pos_rate = (torch.sum(causal_pos_label) / causal_pos_label.shape[0]).item()
        neg_rate = (torch.sum(causal_neg_label) / causal_neg_label.shape[0]).item()

        if verbose:
            print(f'y+^hat|do s+: {torch.sum(causal_pos_label)} / s+: {causal_pos_label.shape[0]}, {pos_rate}')
            print(f'y+^hat|do s-: {torch.sum(causal_neg_label)} / s-: {causal_neg_label.shape[0]}, {neg_rate}')

        return pos_rate - neg_rate



    
def inference_with_causal(model, criterion, data, test_mask, intervention:Intervention, args):
    """Evaluate ``model`` on the masked nodes and print accuracy, fairness and causal metrics.

    Printed, in order:
    - the verbose ``FairMetrics`` statistics;
    - the positive rate for each sensitive group;
    - loss, accuracy and absolute risk difference;
    - the verbose output of ``intervention.cal_causal`` under do(s=+1) versus do(s=-1).

    NOTE(review): ``model`` is called as ``model(data.x)`` without graph
    structure, and ``args`` is not used here.

    Returns:
        ``(loss_val, accuracy)``: the criterion loss as a Python float and the
        accuracy (correct / number of masked nodes) as a tensor.
    """
    model.eval()
    metric = FairMetrics()
    loss_val, correct = 0.0, 0.0
    n_eval = data.s[test_mask, :].squeeze(1).shape[0]

    with torch.no_grad():
        logits = model(data.x)[test_mask]
        labels = torch.where(logits > 0.0, 1.0, 0.0)
        targets = data.y[test_mask].view_as(logits)
        correct += labels.eq(targets).cpu().sum()
        loss_val += criterion(logits, targets).item()
        metric.update(labels, data.s[test_mask].view_as(logits), targets)

    rd, rd1, rd0 = metric.count_stats(verbose=True)
    print(f'y+^hat | s+:{rd1}')
    print(f'y+^hat | s-:{rd0}')

    do_pos = intervention.do(data, torch.ones_like(data.s))
    do_neg = intervention.do(data, torch.zeros_like(data.s) - 1)

    accuracy = correct / n_eval
    print(f'loss:{loss_val}, acc:{accuracy}, rd:{abs(rd)}')
    intervention.cal_causal(do_pos, do_neg, model, verbose=True)

    return loss_val, accuracy

def eval(model, intervention, data):
    """Print (verbosely) and return the GEAR causal metric under do(s=+1) vs do(s=-1).

    NOTE: this function shadows the ``eval`` builtin within this module; the
    name is kept because ``main`` calls it.

    Args:
        model: GEAR-style model passed through to ``intervention.cal_causal``.
        intervention: object with ``do(data, s)`` and the
            ``cal_causal(x_pos, x_neg, edge_index, model, verbose=...)`` API of
            ``InterventionForGEAR``.
        data: graph data object providing ``s`` and ``edge_index``.

    Returns:
        Whatever ``intervention.cal_causal`` returns.
    """
    do_pos = intervention.do(data, torch.ones_like(data.s))
    do_neg = intervention.do(data, torch.zeros_like(data.s) - 1)

    # Model input is [s, intervened features]. The intervention above uses -1
    # for the negative arm while the model input uses 0.
    # NOTE(review): confirm this encoding difference is intended.
    do_pos_x = torch.cat((torch.ones_like(data.s), do_pos), dim=1)
    do_neg_x = torch.cat((torch.zeros_like(data.s), do_neg), dim=1)
    return intervention.cal_causal(do_pos_x, do_neg_x, data.edge_index, model, verbose=True)



def _build_adjacency(edge_index, num_nodes):
    """Return a coalesced sparse (num_nodes x num_nodes) 0/1 float adjacency from ``edge_index``.

    Built directly in COO form rather than via a dense N x N matrix, so memory
    scales with the number of edges instead of N^2.
    """
    indices = edge_index.long()
    values = torch.ones(indices.shape[1], device=indices.device)
    adj = torch.sparse_coo_tensor(indices, values, (num_nodes, num_nodes)).coalesce()
    # coalesce() sums duplicate edges; reset every stored entry to 1 so the
    # matrix stays binary, exactly like a dense ``adj[i, j] = 1`` assignment.
    return torch.sparse_coo_tensor(adj.indices(), torch.ones_like(adj.values()), adj.shape).coalesce()


def _report_ground_truth_cd(model, dataset):
    """Print the model's positive-prediction rates on ``dataset.do_pos_x`` /
    ``dataset.do_neg_x`` (labelled "GT" in the output) and their absolute gap."""
    with torch.no_grad():
        model.eval()
        do_pos_x = torch.cat((torch.ones_like(dataset.s), dataset.do_pos_x), dim=1)
        do_neg_x = torch.cat((torch.zeros_like(dataset.s), dataset.do_neg_x), dim=1)
        z_pos = model(do_pos_x, dataset.edge_index)
        z_neg = model(do_neg_x, dataset.edge_index)

        gt_pos_y_label = torch.where(model.classifier(z_pos) > 0, 1.0, 0.0)
        gt_neg_y_label = torch.where(model.classifier(z_neg) > 0, 1.0, 0.0)

    pos_rate = torch.mean(gt_pos_y_label)
    neg_rate = torch.mean(gt_neg_y_label)
    print(f'GT do s+ predict y: {pos_rate}')
    print(f'GT do s- predict y: {neg_rate}')
    print(f'cd:{abs(pos_rate - neg_rate)}')


def main():
    """Train the GEAR baseline and report its causal-fairness metrics.

    Steps:
    1. Parse CLI/config arguments.
    2. Redirect stdout to a per-run log file.
    3. Load the dataset and train GEAR.
    4. Print causal metrics under the CVAE-based intervention, under the IID
       intervention, and on the dataset's ``do_pos_x`` / ``do_neg_x`` features.
    """
    args = parse_args()
    if args.configs is not None:
        parse_configs_file(args)

    result_main_dir = os.path.join('logs_german', f'seed{args.gtseed}_A{args.A}_layers{args.num_layers}_k{args.k}')
    result_sub_dir = os.path.join(result_main_dir, f'gear_lf{args.lf:.4f}_seed{args.seed}.log')
    # Ensure the log directory exists before the logger opens its file.
    os.makedirs(result_main_dir, exist_ok=True)
    sys.stdout = Logger(result_sub_dir)

    setup_seed(args.seed)
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    D = datasets.__dict__[args.dataset](args)
    dataset = D.data_loaders(mapping_function=args.mapping_function)

    adj = _build_adjacency(dataset.edge_index, dataset.x.shape[0])
    # GEAR input features: sensitive attribute followed by the node features.
    x = torch.cat((dataset.s, dataset.x), dim=1)

    train_idx = torch.nonzero(dataset.train_mask).squeeze(1)
    test_idx = torch.nonzero(dataset.test_mask).squeeze(1)
    # NOTE(review): the test split is also passed as GEAR's validation split;
    # confirm this is intended.
    model = GEAR(adj, x, dataset.y, train_idx, test_idx, test_idx, dataset.s.squeeze(1), 0, hidden_size=16, encoder_hidden_size=16)

    model.fit(batch_size=x.shape[0], epochs=200, sim_coeff=args.lf)
    model.predict(batch_size=x.shape[0])

    CVAE = cvae(args.sensitive_size, args.hidden_channels, args.num_layers, args.A, args.input_size, args.latent_dim, args.z_size).to(device)
    # map_location lets weights saved on a GPU load on a CPU-only machine.
    CVAE.load_state_dict(torch.load('weights/cvae_training.pt', map_location=device))

    intervention = InterventionForGEAR(args, CVAE)
    dataset = dataset.to(device)
    eval(model, intervention, dataset)

    iidintervention = IIDIntervention(args, dataset)
    inference_with_causal_iid(model, dataset, dataset.test_mask, iidintervention)

    _report_ground_truth_cd(model, dataset)


if __name__ == '__main__':
    main()
