from dataset import process_dataset, get_dataset
from models import *
from utils import *
from runner import *
from config import args
import numpy as np
from tqdm import tqdm

def build_train_group_masks(train_mask, labels, sens_labels):
    """Return the non-empty (label, sensitive) training-group masks.

    Args:
        train_mask: Boolean tensor selecting the training nodes.
        labels: Tensor of binary class labels (compared against 0 and 1).
        sens_labels: Tensor of binary sensitive attributes (compared
            against 0 and 1).

    Returns:
        A dict mapping ``(label, sens)`` to the boolean mask
        ``train_mask & (labels == label) & (sens_labels == sens)``, in the
        order (0, 0), (0, 1), (1, 0), (1, 1). Groups that select no
        element are left out, because a mean-reduced loss over an empty
        selection is NaN.
    """
    group_masks = {}
    for label_value, sens_value in ((0, 0), (0, 1), (1, 0), (1, 1)):
        mask = train_mask & (labels == label_value) & (sens_labels == sens_value)
        if bool(mask.any()):
            group_masks[(label_value, sens_value)] = mask
    return group_masks


def sum_group_losses(criterion, logits, labels, group_masks):
    """Sum ``criterion`` evaluated separately on each group of nodes.

    Args:
        criterion: Callable taking ``(flat_logits, float_targets)``.
        logits: Per-node logits; the selection of each group is flattened
            with ``view(-1)`` before being passed to ``criterion``.
        labels: Per-node labels, cast to float for ``criterion``.
        group_masks: Iterable of boolean masks, one per group.

    Returns:
        The sum of the per-group losses, accumulated in iteration order.

    Raises:
        ValueError: If ``group_masks`` yields no mask.
    """
    total = None
    for mask in group_masks:
        term = criterion(logits[mask].view(-1), labels[mask].float())
        total = term if total is None else total + term
    if total is None:
        raise ValueError("no group masks were provided; cannot compute a loss")
    return total


if __name__ == '__main__':
    # NOTE(review): seeding is disabled, so runs are presumably not
    # reproducible - confirm whether this is intentional.
    # seed_everything(args.seed)
    data = get_dataset(args, args.inid)
    target_data = get_dataset(args, args.outid)
    print("\n\n\n********************process source data********************")
    process_dataset(args, data)
    print("\n\n\n********************process target data********************")
    process_dataset(args, target_data)
    data = data.to(args.device)
    target_data = target_data.to(args.device)
    encoder = Encoder(args, args.pre_train_encoder).to(args.device)
    optimizer = torch.optim.Adam(params=encoder.parameters(), lr=args.lr, weight_decay=args.lr2_reg)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.6)
    criterion = nn.BCEWithLogitsLoss()

    # The training loss is the sum of one BCE term per (label, sensitive)
    # group. The masks never change, so build them once here instead of on
    # every epoch, and drop empty groups so they cannot inject NaN.
    group_masks = build_train_group_masks(data.train_mask, data.y, data.sens_labels)
    if not group_masks:
        raise ValueError("no training nodes fall into any (label, sensitive) group")
    if len(group_masks) < 4:
        print(f"warning: only these (label, sensitive) training groups are non-empty: {sorted(group_masks)}")
    train_group_masks = list(group_masks.values())

    # The checkpoint location does not depend on the epoch.
    pre_train_path = args.pre_train_path.format(args.pre_train_encoder, args.dataset, args.inid)

    best_score = 0.0
    for epoch in tqdm(range(0, args.pre_train_epochs), desc='Pretrain'):
        mprint(f"\n=======epoch: {epoch}=======")
        encoder.train()
        optimizer.zero_grad()
        embeddings, logits = encoder(data.x, data.edge_index)
        loss = sum_group_losses(criterion, logits, data.y, train_group_masks)
        loss.backward()
        optimizer.step()
        scheduler.step()

        # Evaluate and possibly checkpoint every 50 epochs.
        if epoch % 50 == 0:
            accs, auc_rocs, parity, equality = evaluate_per_class(args, data, encoder)
            tar_accs, tar_auc_rocs, tar_parity, tar_equality = evaluate_per_class(args, target_data, encoder)

            # NOTE(review): the selection score includes *test* metrics, so
            # the saved checkpoint is chosen with knowledge of the test
            # split - confirm this is intended.
            score = accs["val"] + accs["test"] + auc_rocs["val"] + auc_rocs["test"]
            if best_score < score:
                best_score = score
                print(f"source : epoch {epoch}: acc: {accs}; auc: {auc_rocs}; parity: {parity}; equality: {equality}")
                print(f"target : epoch {epoch}: acc: {tar_accs}; auc: {tar_auc_rocs}; parity: {tar_parity}; equality: {tar_equality}")
                torch.save({
                    'encoder': encoder.state_dict(),
                    }, pre_train_path)
                mprint('save successful')
