import argparse
import torch
import numpy as np
import torch.nn.functional as F
from data_utils import get_dataset
from eval import evaluate_ged3
from utils import seed_everything, get_enc_cls_opt
from models import *
from torch.nn import BCEWithLogitsLoss


def contrastive_loss(z1, z2, temperature=0.1):
    """Return the InfoNCE loss that matches row ``i`` of ``z1`` to row ``i`` of ``z2``.

    Both embedding matrices are L2-normalised row-wise. Their pairwise cosine
    similarities are divided by ``temperature``, and cross-entropy is taken with
    the diagonal entry as the target class for each row of ``z1``. Only the
    ``z1 -> z2`` direction is scored; the loss is not symmetrised.

    Args:
        z1: 2-D tensor of anchor embeddings, shape (N, D).
        z2: 2-D tensor of candidate embeddings with at least N rows; row ``i``
            is the positive for ``z1[i]`` and every other row is a negative.
        temperature: softmax temperature; smaller values sharpen the distribution.

    Returns:
        Scalar tensor: mean cross-entropy over the N anchors.
    """
    anchors = F.normalize(z1, p=2, dim=1)
    candidates = F.normalize(z2, p=2, dim=1)
    logits = anchors @ candidates.t() / temperature
    # The positive for anchor i is candidate i, i.e. the diagonal of ``logits``.
    targets = torch.arange(anchors.shape[0], device=anchors.device)
    return F.cross_entropy(logits, targets)


def _estimate_ipw_weights(data, S, args, epochs=100, lr=0.01):
    """Fit a propensity model of P(S=1 | X) and return per-node IPW weights.

    The propensity model is trained full-batch on every node's features with a
    BCE-with-logits objective. Each node is then scored, and weighted by 1/p
    when S == 1 and by 1/(1-p) otherwise. Probabilities are clamped at 1e-3
    before inversion. The weights are rescaled to mean 1 and finally clipped
    to [0.1, 10].

    Args:
        data: graph data object; only ``data.x`` is read.
        S: sensitive attribute as a column tensor, one row per node.
        args: namespace providing ``num_features`` and ``device``.
        epochs: number of full-batch Adam steps for the propensity model.
        lr: Adam learning rate for the propensity model.

    Returns:
        1-D tensor of per-node weights.
    """
    propensity_model = PropensityModel(args.num_features).to(args.device)
    optimizer_prop = torch.optim.Adam(propensity_model.parameters(), lr=lr)
    criterion_prop = BCEWithLogitsLoss()

    propensity_model.train()
    for _ in range(epochs):
        optimizer_prop.zero_grad()
        loss_prop = criterion_prop(propensity_model(data.x), S.float())
        loss_prop.backward()
        optimizer_prop.step()

    propensity_model.eval()
    with torch.no_grad():
        # The model is trained with BCEWithLogitsLoss, so its outputs are logits;
        # map them through a sigmoid before treating them as P(S=1 | X).
        # NOTE(review): assumes PropensityModel.forward returns raw logits
        # (no internal sigmoid), as the training criterion implies -- confirm.
        prop_scores = torch.sigmoid(propensity_model(data.x)).squeeze()

    # IPW: for S=1, w = 1 / p; for S=0, w = 1 / (1 - p).
    w = torch.where(S.squeeze() == 1,
                    1.0 / prop_scores.clamp(min=1e-3),
                    1.0 / (1 - prop_scores).clamp(min=1e-3))
    w = w / w.mean()
    return w.clamp(min=0.1, max=10.0)


def run(data, args):
    """Train the three-view encoders + projector + classifier for ``args.runs`` seeds.

    Each run re-seeds with ``args.seed + run_index``. It then re-initialises
    all trainable modules and clears optimizer state, so runs do not leak into
    one another. Training proceeds for ``args.epochs`` full-batch steps. Each
    epoch is evaluated with ``evaluate_ged3``. For each run, the test metrics
    are kept from the epoch with the best validation tradeoff
    ``auc + f1 + acc - alpha * (parity + equality)``.

    Three encoder views are used:
      * ``encoder_A`` -- structure only (node features replaced by ones),
      * ``encoder_X`` -- raw node features,
      * ``encoder_D`` -- IPW-weighted features diffused with APPNP.

    Args:
        data: graph data object with ``x``, ``edge_index``, ``sens``, ``y``
            and ``train_mask``; it is moved to ``args.device``.
        args: namespace of hyper-parameters (``runs``, ``epochs``, ``seed``,
            ``device``, ``num_features``, ``hidden``, ``e_lr``, ``e_wd``, ``k``,
            ``beta``, ``beta_c``, ``alpha``, ``zero_grad``, plus whatever
            ``get_enc_cls_opt`` reads).

    Returns:
        Tuple ``(acc, f1, auc_roc, parity, equality)`` of numpy arrays of length
        ``args.runs`` holding each run's selected test metrics. An entry is NaN
        if no epoch of that run produced a usable validation tradeoff.
    """
    acc, f1, auc_roc, parity, equality = (np.zeros(args.runs) for _ in range(5))

    data = data.to(args.device)
    # Structure-only view: features replaced by ones, so encoder_A sees only the graph.
    data_A = data.clone()
    data_A.x = torch.ones_like(data.x)
    S = data.sens.unsqueeze(dim=1)  # sensitive attribute as a column vector

    encoder_A = Encoder_A(args.num_features, args.hidden, args.hidden).to(args.device)
    encoder_X = Encoder_X(args.num_features, args.hidden, args.hidden).to(args.device)
    encoder_D = Encoder_D(args.num_features, args.hidden, args.hidden).to(args.device)

    _, classifier, _, optimizer_c = get_enc_cls_opt(args)
    # Projects the concatenation of the three views (3 * hidden) back to hidden.
    projector = Projector(3 * args.hidden, args.hidden).to(args.device)
    optimizer_projector = torch.optim.Adam(projector.parameters(), lr=args.e_lr, weight_decay=args.e_wd)
    optimizer_encoders = torch.optim.Adam(
        list(encoder_A.parameters()) + list(encoder_X.parameters()) + list(encoder_D.parameters()),
        lr=args.e_lr,
        weight_decay=args.e_wd
    )

    diffusion_prop = APPNP(K=args.k, alpha=0.15).to(args.device)

    # Pre-train a propensity model once and derive the IPW weights (shared by all runs).
    w = _estimate_ipw_weights(data, S, args)
    # IPW-weighted features are invariant across runs and epochs; compute them once.
    X_weighted = data.x * w.unsqueeze(1)

    all_optimizers = (optimizer_encoders, optimizer_c, optimizer_projector)

    for count in range(args.runs):
        seed_everything(count + args.seed)
        encoder_A.reset_parameters()
        encoder_X.reset_parameters()
        encoder_D.reset_parameters()
        projector.reset_parameters()
        classifier.reset_parameters()
        # reset_parameters() re-initialises weights but not optimizer state or
        # .grad buffers. Clear Adam moments/step counts and leftover gradients
        # (the projector's accumulate when --zero_grad is off) so each run
        # starts fresh.
        for optimizer in all_optimizers:
            optimizer.state.clear()
            optimizer.zero_grad()

        # Start from -inf (not 0) so a run whose tradeoff never exceeds 0 still
        # records its best epoch instead of leaving the test metrics unbound or
        # reusing the previous run's values.
        best_val_tradeoff = -float('inf')
        test_acc = test_auc_roc = test_f1 = test_parity = test_equality = float('nan')

        for epoch in range(args.epochs):
            encoder_A.train()
            encoder_X.train()
            encoder_D.train()
            classifier.train()
            projector.train()

            optimizer_encoders.zero_grad()
            optimizer_c.zero_grad()
            # By default the projector's gradients are NOT cleared, so they
            # accumulate across epochs; --zero_grad opts into clearing them.
            if args.zero_grad:
                optimizer_projector.zero_grad()

            diffused_x = diffusion_prop(X_weighted, data.edge_index)
            # Each encoder returns (mu, log_std); sample with the reparameterisation trick.
            mu_a, log_std_a = encoder_A(data_A.x, data.edge_index)
            mu_x, log_std_x = encoder_X(data.x)
            mu_d, log_std_d = encoder_D(diffused_x)

            z_A = mu_a + torch.randn_like(log_std_a) * torch.exp(log_std_a)
            z_X = mu_x + torch.randn_like(log_std_x) * torch.exp(log_std_x)
            z_D = mu_d + torch.randn_like(log_std_d) * torch.exp(log_std_d)

            z = torch.cat([z_A, z_X, z_D], dim=-1)
            z_p = projector(z)
            # The classifier also receives the sensitive attribute as an extra input column.
            z_S = torch.cat([z_p, S], dim=-1)

            output = classifier(z_S)

            train_mask = data.train_mask
            # Pairwise contrastive alignment of the three views on training nodes.
            loss_con_ax = contrastive_loss(z_A[train_mask], z_X[train_mask])
            loss_con_ad = contrastive_loss(z_A[train_mask], z_D[train_mask])
            loss_con_xd = contrastive_loss(z_X[train_mask], z_D[train_mask])
            loss_con = (loss_con_ax + loss_con_ad + loss_con_xd) / 3

            # BCE targets must match the logits' dtype (labels may be stored as ints).
            target = data.y[train_mask].unsqueeze(1).to(device=output.device, dtype=output.dtype)
            loss_c = F.binary_cross_entropy_with_logits(output[train_mask], target)
            kl_a = encoder_A.kl_loss(mu_a[train_mask], log_std_a[train_mask])
            kl_x = encoder_X.kl_loss(mu_x[train_mask], log_std_x[train_mask])
            kl_d = encoder_D.kl_loss(mu_d[train_mask], log_std_d[train_mask])

            total_loss = loss_c + args.beta * (kl_a + kl_x + kl_d) + args.beta_c * loss_con
            total_loss.backward()
            optimizer_encoders.step()
            optimizer_c.step()
            optimizer_projector.step()

            accs, auc_rocs, F1s, tmp_parity, tmp_equality = evaluate_ged3(classifier, projector, encoder_A, encoder_X,
                                                                          encoder_D, data, diffused_x)

            if epoch % 10 == 0:
                print(
                    "RUN: {}/{}, Epoch: {:04}/{:04} | Val Acc: {:.4f}, Test Acc: {:.4f}, Test AUC: {:.4f}, Test F1: {:.4f}, Test SP: {:.4f}, Test EO: {:.4f}".format(
                        count + 1, args.runs, epoch, args.epochs, accs['val'], accs['test'], auc_rocs['test'],
                        F1s['test'], tmp_parity['test'], tmp_equality['test']
                    ))

            # Model selection: utility on validation minus alpha-weighted unfairness.
            val_tradeoff = auc_rocs['val'] + F1s['val'] + accs['val'] - args.alpha * (
                tmp_parity['val'] + tmp_equality['val'])
            if val_tradeoff > best_val_tradeoff:
                best_val_tradeoff = val_tradeoff
                test_acc = accs['test']
                test_auc_roc = auc_rocs['test']
                test_f1 = F1s['test']
                test_parity, test_equality = tmp_parity['test'], tmp_equality['test']

                print(
                    "\033[0;30;41m RUN: {}/{}, Epoch: {:04}/{:04} | Val Acc: {:.4f}, Test Acc: {:.4f}, Test AUC: {:.4f}, Test F1: {:.4f}, Test SP: {:.4f}, Test EO: {:.4f}\033[0m".format(
                        count + 1, args.runs, epoch, args.epochs, accs['val'], accs['test'], auc_rocs['test'],
                        F1s['test'], tmp_parity['test'], tmp_equality['test']
                    ))

        acc[count] = test_acc
        f1[count] = test_f1
        auc_roc[count] = test_auc_roc
        parity[count] = test_parity
        equality[count] = test_equality

    return acc, f1, auc_roc, parity, equality


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='german')
    parser.add_argument('--runs', type=int, default=10)
    parser.add_argument('--epochs', type=int, default=1000)
    parser.add_argument('--c_lr', type=float, default=0.01)
    parser.add_argument('--c_wd', type=float, default=0)
    parser.add_argument('--e_lr', type=float, default=0.01)
    parser.add_argument('--e_wd', type=float, default=0)
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--hidden', type=int, default=16)
    parser.add_argument('--hidden2', type=int, default=17)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--encoder', type=str, default='GCN')
    # Weight on (parity + equality) in the validation model-selection tradeoff.
    parser.add_argument('--alpha', type=float, default=1)
    parser.add_argument('--gpu_num', type=int, default=0)
    # NOTE(review): --warmup, --eta, --gamma and --hidden2 are not read in this
    # file; presumably consumed by helpers in utils/eval -- verify before removing.
    parser.add_argument('--warmup', type=int, default=5)
    parser.add_argument('--eta', type=float, default=0.5)
    parser.add_argument('--gamma', type=float, default=0.5)
    # Weight on the summed KL terms of the three encoders.
    parser.add_argument('--beta', type=float, default=1e-3)
    # Weight on the averaged pairwise contrastive loss between views.
    parser.add_argument('--beta_c', type=float, default=1e-3)
    # Number of APPNP propagation steps for the diffused view.
    parser.add_argument('--k', type=int, default=3)
    parser.add_argument('--zero_grad', action='store_true',
                        help="zero the projector's gradients every epoch; "
                             "without this flag they accumulate across epochs")

    args = parser.parse_args()
    args.device = torch.device('cuda:{}'.format(args.gpu_num) if torch.cuda.is_available() else 'cpu')
    data, args.sens_idx, args.x_min, args.x_max = get_dataset(args.dataset)
    # Binary task: a single logit output.
    args.num_features, args.num_classes = data.x.shape[1], 1

    acc, f1, auc_roc, parity, equality = run(data, args)
    print('======' + args.dataset + args.encoder + '======')
    print('auc_roc: {:.2f} +- {:.2f}'.format(np.mean(auc_roc) * 100, np.std(auc_roc) * 100))
    print('Acc: {:.2f} +- {:.2f}'.format(np.mean(acc) * 100, np.std(acc) * 100))
    print('f1: {:.2f} +- {:.2f}'.format(np.mean(f1) * 100, np.std(f1) * 100))
    print('parity: {:.2f} +- {:.2f}'.format(np.mean(parity) * 100, np.std(parity) * 100))
    print('equality: {:.2f} +- {:.2f}'.format(np.mean(equality) * 100, np.std(equality) * 100))