import importlib
import os
import sys
import time

import torch
import torch.nn as nn
from torchvision.ops import MLP

import datasets

from args import parse_args

from util.general_util import (
    parse_configs_file,
    setup_seed,
    FairMetrics,
    Logger
)
from models.cvae import cvae
from intervention import Intervention, IIDIntervention
from torch_geometric.nn import GCN



def train_with_causal_iid(epoch, model, optimizer, criterion, data, intervention:IIDIntervention,args):
    """Run one full-graph optimisation step with the IID fairness penalty.

    The objective is ``criterion`` evaluated on the rows selected by
    ``data.train_mask`` plus ``args.lf`` times ``intervention.cal_loss``
    evaluated on the predictions for every row.  The scalar loss is
    printed; nothing is returned.

    Args:
        epoch: Epoch index, used only in the printed message.
        model: Module called as ``model(data.x, data.edge_index)``.
        optimizer: Optimiser whose gradients are cleared and stepped once.
        criterion: Loss callable applied as ``criterion(prediction, target)``.
        data: Object providing ``x``, ``edge_index``, ``y`` and ``train_mask``.
        intervention: Provides ``cal_loss(predictions)`` for the fairness term.
        args: Namespace; only ``args.lf`` (weight of the fairness term) is read.
    """
    model.train()

    logits = model(data.x, data.edge_index)
    train_logits = logits[data.train_mask, :]
    train_targets = data.y[data.train_mask].view_as(train_logits)

    supervised_term = criterion(train_logits, train_targets)
    fairness_term = intervention.cal_loss(logits)
    total = supervised_term + args.lf * fairness_term

    # Capture the scalar value before the parameters are updated.
    reported = 0.0 + total.item()

    optimizer.zero_grad()
    total.backward()
    optimizer.step()

    print(f'Epoch:{epoch}, loss:{reported}')
    
def inference_with_causal_iid(model, criterion, data, test_mask, intervention, args):
    """Evaluate ``model`` on the rows selected by ``test_mask`` and print fairness statistics.

    Args:
        model: Module called as ``model(data.x, data.edge_index)``.
        criterion: Loss callable applied as ``criterion(prediction, target)``.
        data: Object providing ``x``, ``edge_index``, ``y`` and ``s``.
        test_mask: Index/mask selecting the rows to evaluate.
        intervention: Pair ``(intervention, iid_intervention)``; the first is
            used for ``do(data, s)`` / ``cal_causal(pos, neg, model, verbose=True)``,
            the second for ``do(s, predictions)`` / ``cal_causal(pos, neg)``.
        args: Namespace; only ``args.lf`` is read.

    Returns:
        ``(loss_val, score)`` where ``loss_val`` is the criterion value on the
        selected rows and ``score`` is the accuracy minus ``0.5 * abs(cd)``
        (``cd`` being the IID ``cal_causal`` result); the penalty is dropped
        when ``args.lf == 0.0``.
    """
    graph_intervention, iid_intervention = intervention
    model.eval()
    fair_metric = FairMetrics()
    s = data.s[test_mask, :].squeeze(1)
    n_eval = s.shape[0]

    with torch.no_grad():
        all_logits = model(data.x, data.edge_index)
        logits = all_logits[test_mask]
        targets = data.y[test_mask].view_as(logits)

        # Logits above zero are mapped to label 1.0, everything else to 0.0.
        hard_labels = torch.where(logits > 0.0, 1.0, 0.0)
        n_correct = 0.0 + hard_labels.eq(targets).cpu().sum()
        loss_val = 0.0 + criterion(logits, targets).item()

        fair_metric.update(hard_labels, data.s[test_mask].view_as(logits), targets)

    rd, rd_pos, rd_neg = fair_metric.count_stats(verbose=True)
    print(f'y+^hat | s+:{rd_pos}')
    print(f'y+^hat | s-:{rd_neg}')

    # Interventions use +1 / -1 as the two values of the sensitive attribute.
    iid_pos = iid_intervention.do(torch.ones_like(data.s), all_logits)
    iid_neg = iid_intervention.do(torch.zeros_like(data.s) - 1, all_logits)

    graph_pos = graph_intervention.do(data, torch.ones_like(data.s))
    graph_neg = graph_intervention.do(data, torch.zeros_like(data.s) - 1)
    noniid_effect = graph_intervention.cal_causal(graph_pos, graph_neg, model, verbose=True)

    accuracy = n_correct / n_eval
    print(f'loss:{loss_val}, acc:{accuracy}, rd:{abs(rd)}')
    print(f'noniid:{noniid_effect}')
    iid_effect = iid_intervention.cal_causal(iid_pos, iid_neg)
    print(f'iidcd:{iid_effect}')

    penalty_weight = 0.5 if args.lf != 0.0 else 0.0
    return loss_val, accuracy - penalty_weight * abs(iid_effect)



class LinearClassfier(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear."""

    def __init__(self, input_size, hidden_size, output_size):
        """Create the two linear layers and the activation.

        Args:
            input_size: Number of input features.
            hidden_size: Width of the hidden layer.
            output_size: Number of output features.
        """
        super().__init__()
        # Layers are created in this order so seeded initialisation and
        # state_dict key order stay stable.
        self.fc1 = nn.Linear(in_features=input_size, out_features=hidden_size)
        self.fc2 = nn.Linear(in_features=hidden_size, out_features=output_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``fc2(relu(fc1(x)))``."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)
    
class GNNWithMLP(nn.Module):
    """Wrap a GNN applied to the sensitive attribute, followed by an MLP.

    The GNN output is concatenated to ``data.z`` along dimension 1 and the
    result is passed through the MLP.
    """

    def __init__(self, args, gnn):
        """Store ``gnn`` and build the MLP.

        Args:
            args: Namespace; ``args.A`` and ``args.input_size`` are read.
                The MLP takes ``args.A + args.input_size`` input channels
                and has a single layer of ``args.input_size`` channels.
            gnn: Callable invoked as ``gnn(x, edge_index)``.
        """
        super().__init__()
        self.gnn = gnn
        mlp_in_channels = args.A + args.input_size
        mlp_layer_sizes = [args.input_size]
        self.MLP = MLP(mlp_in_channels, mlp_layer_sizes)

    def forward(self, data):
        """Return ``MLP(cat(data.z, gnn(2 * data.s - 1, data.edge_index)))``."""
        # 2*s - 1 is the affine rescaling fed to the GNN (maps 0 -> -1 and 1 -> +1).
        rescaled_s = 2 * data.s - 1
        propagated = self.gnn(rescaled_s, data.edge_index)
        return self.MLP(torch.cat((data.z, propagated), dim=1))

def main():
    """Train a GCN classifier with the IID fairness penalty and report causal metrics.

    Steps, in order:
      1. Parse arguments (optionally merged with a config file) and redirect
         ``sys.stdout`` to a ``Logger`` writing under ``logs_german/``.
      2. Seed the run, build the dataset, the GCN classifier and a CVAE whose
         weights are loaded from ``weights/cvae_training.pt``.
      3. Train for 100 epochs; after each epoch evaluate on
         ``dataset.test_mask`` and keep the best-scoring weights in
         ``weights/normal.pt``.
      4. Reload the best weights, evaluate once more, and print the mean
         ground-truth and predicted outcomes on ``dataset.do_pos_*`` /
         ``dataset.do_neg_*``.

    NOTE(review): checkpoint selection and the final report both use
    ``dataset.test_mask`` even though the messages say "validate"/"valid";
    confirm whether a separate validation mask was intended.
    """
    args = parse_args()
    if args.configs is not None:
        parse_configs_file(args)

    # The fairness weight is pinned for this experiment and overrides any
    # value coming from the command line or the config file.
    args.lf = 9

    result_main_dir = os.path.join('logs_german', f'seed{args.gtseed}_A{args.A}_layers{args.num_layers}_k{args.k}')
    result_sub_dir = os.path.join(result_main_dir, f'iid_lf{args.lf:.4f}_seed{args.seed}.log')
    sys.stdout = Logger(result_sub_dir)
    print(args)

    setup_seed(args.seed)

    # Keep using the second GPU when there is one, but do not crash with an
    # invalid device ordinal on single-GPU hosts.
    if not torch.cuda.is_available():
        device = "cpu"
    elif torch.cuda.device_count() > 1:
        device = "cuda:1"
    else:
        device = "cuda:0"

    D = datasets.__dict__[args.dataset](args)
    dataset = D.data_loaders(mapping_function=args.mapping_function)
    dataset = dataset.to(device)

    # NOTE(review): torch_geometric's GCN signature is
    # (in_channels, hidden_channels, num_layers, out_channels=None), so
    # args.num_class is interpreted as the number of layers here and the
    # output width is the hidden width (16). Confirm this is intended.
    classifier = GCN(args.input_size, 16, args.num_class)
    classifier = classifier.to(device)

    CVAE = cvae(args.sensitive_size, args.hidden_channels, args.num_layers, args.A, args.input_size, args.latent_dim, args.z_size).to(device)
    # map_location lets a checkpoint saved on another device load here.
    CVAE.load_state_dict(torch.load('weights/cvae_training.pt', map_location=device))

    iidintervention = IIDIntervention(args, dataset)
    intervention = Intervention(args, CVAE)

    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(classifier.parameters(), lr=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 100)

    # The selection score is accuracy minus a fairness penalty and can be
    # negative, so start from -inf to guarantee a checkpoint is always written.
    best_acc = float('-inf')
    for epoch in range(100):
        train_with_causal_iid(epoch, classifier, optimizer, criterion, dataset, iidintervention, args)
        loss, acc = inference_with_causal_iid(classifier, criterion, dataset, dataset.test_mask, (intervention, iidintervention), args)
        if acc >= best_acc:
            torch.save(classifier.state_dict(), 'weights/normal.pt')
            best_acc = acc

        scheduler.step()

    classifier.load_state_dict(torch.load('weights/normal.pt', map_location=device))
    print('Best validate model')
    print('Test on valid')
    loss, acc = inference_with_causal_iid(classifier, criterion, dataset, dataset.test_mask, (intervention, iidintervention), args)

    print(f'GT do s+: {torch.mean(dataset.do_pos_y)}')
    print(f'GT do s-: {torch.mean(dataset.do_neg_y)}')

    with torch.no_grad():
        classifier.eval()

        # GCN.forward requires the graph connectivity as well as the node
        # features; the counterfactual features are evaluated on the
        # dataset's own edge_index.
        gt_pos_y = classifier(dataset.do_pos_x, dataset.edge_index)
        gt_pos_y_label = torch.where(gt_pos_y > 0, 1.0, 0.0)

        gt_neg_y = classifier(dataset.do_neg_x, dataset.edge_index)
        gt_neg_y_label = torch.where(gt_neg_y > 0, 1.0, 0.0)

    print(f'GT do s+ predict y: {torch.mean(gt_pos_y_label)}')
    print(f'GT do s- predict y: {torch.mean(gt_neg_y_label)}')
        
        
    
    
    
    
# Run the experiment only when this file is executed as a script.
if __name__ == '__main__':
    main()