import importlib
import os
import sys
import time

import torch
import torch.nn as nn

import datasets

from args import parse_args

from util.general_util import (
    parse_configs_file,
    setup_seed,
    FairMetrics,
    Logger
)
from models.cvae import cvae
from intervention import Intervention, IIDIntervention, InterventionForGNN
from riskdifference import LogisticRiskDiff
from torch_geometric.nn import GCN



def inference_with_causal_iid(model, data, test_mask, intervention):
    """Estimate the causal effect of the sensitive attribute with an IID intervention.

    The model is evaluated once on the whole graph; ``intervention`` is then
    asked for the outcomes obtained when the sensitive attribute is forced to
    ``+1`` and to ``-1`` for every node, and the two are compared.

    Args:
        model: Called as ``model(data.x, data.edge_index)``; switched to eval mode.
        data: Object exposing ``x``, ``edge_index`` and ``s``.
        test_mask: Not used by this function; kept so existing call sites
            remain valid.
        intervention: Object providing ``do(s, y_pred)`` and
            ``cal_causal(do_pos, do_neg, verbose=...)``.

    Returns:
        The value returned by ``intervention.cal_causal`` (it is also printed
        with the ``iidcd:`` prefix).
    """
    model.eval()

    # Predictions for all nodes, computed without tracking gradients.
    with torch.no_grad():
        y_pred_all = model(data.x, data.edge_index)

    # Same shape as data.s, filled with +1 and with -1 (zeros - 1) respectively.
    do_pos_y = intervention.do(torch.ones_like(data.s), y_pred_all)
    do_neg_y = intervention.do(torch.zeros_like(data.s) - 1, y_pred_all)

    cd = intervention.cal_causal(do_pos_y, do_neg_y, verbose=False)
    print(f'iidcd:{cd}')
    return cd
    

def train(epoch, model, optimizer, criterion, data, args):
    """Perform one full-batch optimisation step and print the loss.

    Args:
        epoch: Epoch index; only used in the printed message.
        model: Called as ``model(data.x, data.edge_index)``; switched to train mode.
        optimizer: Optimizer whose ``zero_grad`` and ``step`` are each called once.
        criterion: Pair ``(task_criterion, fair_criterion)``.
        data: Object exposing ``x``, ``edge_index``, ``y``, ``s`` and ``train_mask``.
        args: Namespace providing ``lf``, the weight applied to the fairness term.
    """
    task_criterion, fairness_criterion = criterion
    model.train()

    mask = data.train_mask
    logits = model(data.x, data.edge_index)
    train_logits = logits[mask, :]

    # Labels and sensitive attribute are reshaped to match the masked logits.
    task_term = task_criterion(train_logits, data.y[mask].view_as(train_logits))
    fair_term = fairness_criterion(train_logits, data.s[mask].view_as(train_logits))
    objective = task_term + args.lf * fair_term

    # Read the scalar loss before the parameters are updated.
    reported_loss = 0.0
    reported_loss += objective.item()

    optimizer.zero_grad()
    objective.backward()
    optimizer.step()

    print(f'Epoch:{epoch}, loss:{reported_loss}')
    
def inference_with_causal(model, criterion, data, test_mask, intervention:Intervention, args):
    """Evaluate the model on ``test_mask`` and report accuracy, risk difference and causal effect.

    Args:
        model: Called as ``model(data.x, data.edge_index)``; switched to eval mode.
        criterion: Loss applied to the masked logits and labels.
        data: Object exposing ``x``, ``edge_index``, ``y`` and ``s``.
        test_mask: Index/mask selecting the nodes to evaluate.
        intervention: Provides ``do(data, s)`` and
            ``cal_causal(do_pos, do_neg, edge_index, model, verbose=...)``.
        args: Not used by this function.

    Returns:
        Tuple ``(loss_val, accuracy)`` for the selected nodes.
    """
    model.eval()
    fairness = FairMetrics()

    # Number of evaluated nodes, taken from the masked sensitive attribute.
    num_eval = data.s[test_mask, :].squeeze(1).shape[0]

    with torch.no_grad():
        logits = model(data.x, data.edge_index)[test_mask]
        targets = data.y[test_mask]

        # A logit above zero is counted as the positive class.
        hard_pred = torch.where(logits > 0.0, 1.0, 0.0)
        correct = 0.0 + hard_pred.eq(targets.view_as(hard_pred)).cpu().sum()
        loss_val = 0.0 + criterion(logits, targets.view_as(logits)).item()

        fairness.update(
            hard_pred,
            data.s[test_mask].view_as(logits),
            targets.view_as(logits),
        )

    rd, rd1, rd0 = fairness.count_stats(verbose=True)
    print(f'y+^hat | s+:{rd1}')
    print(f'y+^hat | s-:{rd0}')

    # Interventions with the sensitive attribute set to +1 and to -1 everywhere.
    do_pos = intervention.do(data, torch.ones_like(data.s))
    do_neg = intervention.do(data, torch.zeros_like(data.s) - 1)

    accuracy = correct / num_eval
    print(f'loss:{loss_val}, acc:{accuracy}, rd:{abs(rd)}')

    cd = intervention.cal_causal(do_pos, do_neg, data.edge_index, model, verbose=True)
    print(f'cd: {abs(cd):.4f}')
    return loss_val, accuracy

def inference(model, criterion, data, test_mask, args):
    """Evaluate the model on ``test_mask`` and return a fairness-penalised score.

    Args:
        model: Called as ``model(data.x, data.edge_index)``; switched to eval mode.
        criterion: Loss applied to the masked logits and labels.
        data: Object exposing ``x``, ``edge_index``, ``y`` and ``s``.
        test_mask: Index/mask selecting the nodes to evaluate.
        args: Namespace providing ``lf``; when it equals ``0.0`` no fairness
            penalty is subtracted from the score.

    Returns:
        Tuple ``(loss_val, score)`` where ``score`` is accuracy minus
        ``0.5 * |rd|`` (or plain accuracy when ``args.lf == 0.0``).
    """
    model.eval()
    fairness = FairMetrics()

    # Number of evaluated nodes, taken from the masked sensitive attribute.
    num_eval = data.s[test_mask, :].squeeze(1).shape[0]

    with torch.no_grad():
        logits = model(data.x, data.edge_index)[test_mask]
        targets = data.y[test_mask]

        # A logit above zero is counted as the positive class.
        hard_pred = torch.where(logits > 0.0, 1.0, 0.0)
        correct = 0.0 + hard_pred.eq(targets.view_as(hard_pred)).cpu().sum()
        loss_val = 0.0 + criterion(logits, targets.view_as(logits)).item()

        fairness.update(
            hard_pred,
            data.s[test_mask].view_as(logits),
            targets.view_as(logits),
        )

    rd, _, _ = fairness.count_stats(verbose=True)
    accuracy = correct / num_eval
    print(f'loss:{loss_val}, acc:{accuracy}, rd: {abs(rd)}')

    # The risk-difference penalty only applies when fairness training is on.
    if args.lf == 0.0:
        penalty_weight = 0.0
    else:
        penalty_weight = 0.5

    return loss_val, accuracy - penalty_weight * abs(rd)



def main():
    """Train a GCN classifier, keep the best checkpoint, and report fairness/causal metrics.

    Steps:
        1. Parse arguments (optionally overridden by a config file) and
           redirect ``sys.stdout`` to the project ``Logger``.
        2. Build the dataset, the GCN classifier and a pre-trained CVAE used by
           ``InterventionForGNN``.
        3. Train for a fixed number of epochs, saving the state dict with the
           best score returned by ``inference`` to ``weights/normal.pt``.
        4. Reload the best state dict and print the causal metrics plus the
           predictions on the ``do_pos_x`` / ``do_neg_x`` inputs of the dataset.
    """
    args = parse_args()
    if args.configs is not None:
        parse_configs_file(args)

    # NOTE(review): the fairness weight is pinned here and overrides whatever
    # the CLI/config provided. The previous code assigned 1.5, 0.375 and 0.0
    # back to back; only the last value ever took effect, so it is kept.
    args.lf = 0.0

    result_main_dir = os.path.join('logs', f'seed{args.gtseed}_A{args.A}_layers{args.num_layers}_k{args.k}')
    result_sub_dir = os.path.join(result_main_dir, f'gcn_lf{args.lf:.4f}_seed{args.seed}.log')

    # Everything printed from here on goes through the project Logger.
    sys.stdout = Logger(result_sub_dir)
    print(args)

    setup_seed(args.seed)

    # Keep preferring the second GPU, but fall back to the first one on
    # single-GPU hosts instead of failing on an invalid device ordinal.
    if torch.cuda.is_available():
        device = "cuda:1" if torch.cuda.device_count() > 1 else "cuda:0"
    else:
        device = "cpu"

    D = datasets.__dict__[args.dataset](args)
    dataset = D.data_loaders(mapping_function=args.mapping_function)
    dataset = dataset.to(device)

    # Positional arguments after the input size are 16 and 2 — presumably the
    # hidden width and number of layers; confirm against the GCN signature.
    classifier = GCN(args.input_size, 16, 2, args.num_class)
    classifier = classifier.to(device)

    cvae_model = cvae(
        args.sensitive_size,
        args.hidden_channels,
        args.num_layers,
        args.A,
        args.input_size,
        args.latent_dim,
        args.z_size,
    ).to(device)
    # map_location lets a checkpoint saved on another device load here.
    cvae_model.load_state_dict(torch.load('weights/cvae_training.pt', map_location=device))
    intervention = InterventionForGNN(args, cvae_model)

    criterion = nn.BCEWithLogitsLoss()
    fair_criterion = LogisticRiskDiff(args.protect)
    optimizer = torch.optim.Adam(classifier.parameters(), lr=0.01)

    # One constant drives both the training loop and the cosine schedule.
    num_epochs = 1000
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs)

    best_acc = 0.0
    for epoch in range(num_epochs):
        train(epoch, classifier, optimizer, (criterion, fair_criterion), dataset, args)
        # NOTE(review): checkpoint selection is done on dataset.test_mask even
        # though the messages below mention validation — confirm this is intended.
        _, acc = inference(classifier, criterion, dataset, dataset.test_mask, args)
        if acc >= best_acc:
            torch.save(classifier.state_dict(), 'weights/normal.pt')
            best_acc = acc

        scheduler.step()

    classifier.load_state_dict(torch.load('weights/normal.pt', map_location=device))
    print('Best validate model')
    print('Test on valid')
    inference_with_causal(classifier, criterion, dataset, dataset.test_mask, intervention, args)

    iidintervention = IIDIntervention(args, dataset)
    inference_with_causal_iid(classifier, dataset, dataset.test_mask, iidintervention)

    print(f'GT do s+: {torch.mean(dataset.do_pos_y)}')
    print(f'GT do s-: {torch.mean(dataset.do_neg_y)}')

    # Classifier predictions on the dataset's do_pos_x / do_neg_x features.
    with torch.no_grad():
        classifier.eval()

        gt_pos_y = classifier(dataset.do_pos_x, dataset.edge_index)
        gt_pos_y_label = torch.where(gt_pos_y > 0, 1.0, 0.0)

        gt_neg_y = classifier(dataset.do_neg_x, dataset.edge_index)
        gt_neg_y_label = torch.where(gt_neg_y > 0, 1.0, 0.0)

    print(f'GT do s+ predict y: {torch.mean(gt_pos_y_label)}')
    print(f'GT do s- predict y: {torch.mean(gt_neg_y_label)}')
    
    
    
# Run the experiment only when this file is executed as a script.
if __name__ == '__main__':
    main()