import importlib
import os
import sys
import time

import torch
import torch.nn as nn
from torchvision.ops import MLP

import datasets

from args import parse_args

from util.general_util import (
    parse_configs_file,
    setup_seed,
    FairMetrics,
    Logger
)
from models.cvae import cvae
from intervention import Intervention, IIDIntervention
from causaldifference import LogisticCausalkDiff

def inference_with_causal_iid(model, data, test_mask, intervention):
    """Estimate and print the causal difference of ``model`` under an IID intervention.

    The classifier is run once on every row of ``data.x`` with gradients
    disabled. Its raw outputs are then passed to ``intervention.do`` twice,
    once with a sensitive-attribute tensor of all ``+1`` and once of all
    ``-1``, and the two results are compared by ``intervention.cal_causal``.

    Args:
        model: Classifier invoked as ``model(data.x)``.
        data: Object exposing the feature tensor ``x`` and the sensitive
            attribute tensor ``s``.
        test_mask: Kept for signature compatibility with the other inference
            helpers. The estimate is computed over all rows of ``data`` and
            this mask is not read.
        intervention: Object providing ``do(s, y_pred)`` and
            ``cal_causal(do_pos, do_neg, verbose=...)``.

    Returns:
        The value produced by ``intervention.cal_causal``, so callers can log
        or compare it instead of parsing stdout.
    """
    model.eval()

    with torch.no_grad():
        y_pred_all = model(data.x)

    # Counterfactual sensitive attributes: everyone +1 versus everyone -1.
    do_pos_y = intervention.do(torch.ones_like(data.s), y_pred_all)
    do_neg_y = intervention.do(torch.zeros_like(data.s) - 1, y_pred_all)

    cd = intervention.cal_causal(do_pos_y, do_neg_y, verbose=False)
    print(f'iidcd:{cd}')
    return cd


def train_with_causal(epoch, model, optimizer, criterion, data, intervention:Intervention,args):
    """Perform one full-batch optimisation step with a causal-fairness penalty.

    The objective is ``criterion`` evaluated on the training rows plus
    ``args.lf`` times ``intervention.cal_loss`` computed from the two
    interventional outputs (sensitive attribute set to ``+1`` and to ``-1``).

    Args:
        epoch: Epoch index, used only in the progress message.
        model: Classifier invoked as ``model(data.x)``.
        optimizer: Optimiser stepping the classifier parameters.
        criterion: Loss applied to the training-row outputs and labels.
        data: Object exposing ``x``, ``y``, ``s`` and ``train_mask``.
        intervention: Provides ``do(data, s)`` and
            ``cal_loss(do_pos, do_neg, model)``.
        args: Namespace supplying the fairness weight ``lf``.

    Returns:
        None. The combined loss value is printed.
    """
    model.train()

    outputs = model(data.x)
    train_rows = data.train_mask

    # Interventional outputs, restricted to the training rows.
    do_pos = intervention.do(data, torch.ones_like(data.s))[train_rows, :]
    do_neg = intervention.do(data, torch.zeros_like(data.s) - 1)[train_rows, :]

    train_outputs = outputs[train_rows, :]
    train_targets = data.y[train_rows].view_as(train_outputs)
    ce_term = criterion(train_outputs, train_targets)
    fair_term = intervention.cal_loss(do_pos, do_neg, model)

    total = ce_term + args.lf * fair_term
    loss_val = total.item()

    optimizer.zero_grad()
    total.backward()
    optimizer.step()

    print(f'Epoch:{epoch}, loss:{loss_val}')
    
def inference_with_causal(model, criterion, data, test_mask, intervention:Intervention, args):
    """Evaluate ``model`` on the rows selected by ``test_mask`` and report fairness statistics.

    Prints the statistics returned by ``FairMetrics.count_stats``, the loss,
    the accuracy, and the causal difference from ``intervention.cal_causal``.

    Args:
        model: Classifier invoked as ``model(data.x)``; outputs above ``0.0``
            are labelled ``1.0``, all others ``0.0``.
        criterion: Loss applied to the selected outputs and labels.
        data: Object exposing ``x``, ``y`` and ``s``.
        test_mask: Row selector applied to the outputs, labels and ``data.s``.
        intervention: Provides ``do(data, s)`` and
            ``cal_causal(do_pos, do_neg, model, verbose=...)``.
        args: Namespace supplying ``lf``; when it equals ``0.0`` the causal
            penalty is left out of the returned score.

    Returns:
        A pair ``(loss_val, score)`` where ``score`` is the accuracy minus
        ``0.5 * abs(cd)`` (or the plain accuracy when ``args.lf == 0.0``).
    """
    model.eval()
    metric = FairMetrics()
    num_rows = data.s[test_mask, :].squeeze(1).shape[0]

    with torch.no_grad():
        outputs = model(data.x)[test_mask]
        targets = data.y[test_mask].view_as(outputs)
        y_pred_label = torch.where(outputs > 0.0, 1.0, 0.0)

        hits = y_pred_label.eq(targets).cpu().sum()
        loss_val = criterion(outputs, targets).item()

        metric.update(y_pred_label, data.s[test_mask].view_as(outputs), targets)

    accuracy = hits / num_rows

    rd, rd1, rd0 = metric.count_stats(verbose=True)
    print(f'y+^hat | s+:{rd1}')
    print(f'y+^hat | s-:{rd0}')

    # Interventional outputs with the sensitive attribute forced to +1 / -1.
    do_pos = intervention.do(data, torch.ones_like(data.s))
    do_neg = intervention.do(data, torch.zeros_like(data.s) - 1)
    print(f'loss:{loss_val}, acc:{accuracy}, rd:{abs(rd)}')
    cd = intervention.cal_causal(do_pos, do_neg, model, verbose=True)
    print(f'cd: {cd:.4f}')

    if args.lf == 0.0:
        coef_p = 0.0
    else:
        coef_p = 0.5

    return loss_val, accuracy - coef_p * abs(cd)



class LinearClassfier(nn.Module):
    """Two-layer fully connected network: Linear -> ReLU -> Linear.

    The output of the second linear layer is returned as-is, with no final
    activation applied.
    """

    def __init__(self, input_size, hidden_size, output_size):
        """Create the two linear layers and the ReLU between them.

        Args:
            input_size: Number of input features.
            hidden_size: Width of the hidden layer.
            output_size: Number of output units.
        """
        super().__init__()
        self.fc1 = nn.Linear(in_features=input_size, out_features=hidden_size)
        self.fc2 = nn.Linear(in_features=hidden_size, out_features=output_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``fc2(relu(fc1(x)))``."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)
    
class GNNWithMLP(nn.Module):
    """Wrap a graph module and an MLP applied to ``[data.z, gnn_output]``.

    The graph module receives the rescaled sensitive attribute together with
    the edge index; its output is concatenated with ``data.z`` along dim 1
    and passed through the MLP.
    """

    def __init__(self, args, gnn):
        """Store ``gnn`` and build the MLP.

        Args:
            args: Namespace supplying ``A`` and ``input_size``; the MLP takes
                ``A + input_size`` inputs and has a single layer of width
                ``input_size``.
            gnn: Module invoked as ``gnn(s, edge_index)``.
        """
        super().__init__()
        self.gnn = gnn
        mlp_in_width = args.A + args.input_size
        self.MLP = MLP(mlp_in_width, [args.input_size])

    def forward(self, data):
        """Return the MLP output for ``data`` (uses ``s``, ``edge_index`` and ``z``)."""
        # Affine rescale of s; presumably maps {0, 1} to {-1, +1} — confirm
        # against the dataset's encoding of the sensitive attribute.
        rescaled_s = 2 * data.s - 1
        A = self.gnn(rescaled_s, data.edge_index)
        return self.MLP(torch.cat((data.z, A), dim=1))

def main():
    """Train the baseline classifier and report fairness / causal diagnostics.

    Steps:
      1. Parse arguments (optionally merged with a config file) and pin the
         fairness weight ``args.lf`` to zero.
      2. Redirect stdout to a ``Logger`` writing under ``logs_syn_credit``.
      3. Load the dataset, build the classifier and restore the CVAE from
         ``weights/cvae_training.pt``.
      4. Train for a fixed number of epochs, checkpointing to
         ``weights/normal.pt`` whenever the evaluation score does not drop.
      5. Reload the best checkpoint and print the evaluation, the IID causal
         difference, and the ground-truth interventional statistics.

    NOTE(review): checkpoint selection and the final report both use
    ``dataset.test_mask`` even though the log messages say "valid"; confirm
    whether a separate validation mask should be used for selection.
    """
    args = parse_args()
    if args.configs is not None:
        parse_configs_file(args)

    # Baseline run: the fairness weight is overridden regardless of CLI/config.
    args.lf = 0.00

    result_main_dir = os.path.join('logs_syn_credit', f'seed{args.gtseed}_A{args.A}_layers{args.num_layers}_k{args.k}')
    result_sub_dir = os.path.join(result_main_dir, f'cd_lf{args.lf:.4f}_seed{args.seed}.log')
    sys.stdout = Logger(result_sub_dir)
    print(args)

    setup_seed(args.seed)

    # Prefer the second GPU as before, but fall back to the first one on
    # single-GPU hosts instead of failing on a non-existent device.
    if not torch.cuda.is_available():
        device = "cpu"
    elif torch.cuda.device_count() > 1:
        device = "cuda:1"
    else:
        device = "cuda:0"

    D = datasets.__dict__[args.dataset](args)
    dataset = D.data_loaders(mapping_function = args.mapping_function)
    dataset = dataset.to(device)
    classifier = LinearClassfier(args.input_size, 16, args.num_class)
    classifier = classifier.to(device)
    CVAE = cvae(args.sensitive_size, args.hidden_channels, args.num_layers, args.A, args.input_size, args.latent_dim,args.z_size).to(device)
    # map_location lets a checkpoint saved on another device (e.g. cuda:1)
    # be restored on whatever device this run actually uses.
    CVAE.load_state_dict(torch.load('weights/cvae_training.pt', map_location=device))

    intervention = Intervention(args, CVAE)

    num_epochs = 100
    classifier_ckpt = 'weights/normal.pt'

    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(classifier.parameters(), lr= 0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs)

    # The selection score is accuracy minus a causal penalty and may be
    # negative, so start below any reachable value to guarantee that a
    # checkpoint exists before it is reloaded.
    best_acc = float("-inf")
    for epoch in range(num_epochs):
        train_with_causal(epoch, classifier, optimizer, criterion, dataset, intervention,args)
        loss, acc = inference_with_causal(classifier, criterion, dataset, dataset.test_mask, intervention,args)
        if acc >= best_acc:
            torch.save(classifier.state_dict(), classifier_ckpt)
            best_acc = acc

        scheduler.step()

    classifier.load_state_dict(torch.load(classifier_ckpt, map_location=device))
    print('Best validate model')
    print('Test on valid')
    loss, acc = inference_with_causal(classifier, criterion, dataset, dataset.test_mask, intervention, args)

    iidintervention = IIDIntervention(args, dataset)
    inference_with_causal_iid(classifier, dataset, dataset.test_mask, iidintervention)

    print(f'GT do s+: {torch.mean(dataset.do_pos_y)}')
    print(f'GT do s-: {torch.mean(dataset.do_neg_y)}')

    # Classifier predictions on the dataset's ground-truth interventional
    # features, thresholded at zero like the regular evaluation.
    with torch.no_grad():
        classifier.eval()

        gt_pos_y = classifier(dataset.do_pos_x)
        gt_pos_y_label = torch.where(gt_pos_y > 0,1.0,0.0)

        gt_neg_y = classifier(dataset.do_neg_x)
        gt_neg_y_label = torch.where(gt_neg_y > 0,1.0,0.0)

    print(f'GT do s+ predict y: {torch.mean(gt_pos_y_label)}')
    print(f'GT do s- predict y: {torch.mean(gt_neg_y_label)}')
        
        
    
    
    
    
# Run the training pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()