import os.path as osp
import torch
import torch_geometric.transforms as T
import numpy as np
from Utils.args import arg_parse
import os
from model.train import *

import copy
from torch_geometric.data import Data

import torch.nn.functional as F
from Utils.load_data import *
from pygod.utils import load_data

import warnings

warnings.filterwarnings("ignore")
import itertools
from model.encoder import *
from model.decoder import *
from model.model import *

def seed_torch(seed=2025):
    """Seed every random number generator the experiment touches.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch's CPU and
    CUDA generators, and configures cuDNN to use deterministic kernels without
    autotuning, so repeated runs with the same seed are reproducible.

    Args:
        seed: Integer seed applied to every generator (default 2025).
    """
    # Imported locally: this module never imports ``random`` itself and used to
    # rely on the name leaking in through one of the ``from ... import *`` lines.
    import random

    random.seed(seed)
    # NOTE: the interpreter reads PYTHONHASHSEED at startup, so this only
    # affects subprocesses started after this call, not the current process.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # also seed every visible GPU
    # Disable cuDNN's benchmark autotuner (it may choose different algorithms
    # between runs) and ask for deterministic cuDNN kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

seed_torch()


def _build_model(num_nodes, num_features, hidden_channels, embedding_channels,
                 out_channels, num_layers, device):
    """Return a freshly initialised GAE (GCN encoder + inner-product decoder) on ``device``.

    The decoder is constructed before the encoder, the same order the script
    has always used, so parameter-initialisation RNG consumption is unchanged.
    """
    decoder = InnerProductDecoder(embedding_channels, hidden_channels, out_channels, num_layers=2)
    encoder = Rewrite_GCNEncoder(num_nodes, num_features, hidden_channels, embedding_channels,
                                 num_layers=num_layers, aggr='add')
    return GAE(encoder=encoder, decoder=decoder).to(device)


def _build_optimizer(model, learning_rate):
    """Return an Adam optimizer over ``model``'s parameters, without weight decay."""
    return torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0.)


if __name__ == '__main__':
    args = arg_parse()
    # Use the dataset chosen on the command line; fall back to 'disney' only
    # when none was supplied (previously 'disney' always overrode args.DS).
    DS = args.DS = getattr(args, 'DS', None) or 'disney'

    if DS in ['cora', 'email']:
        data, anomaly_flag = load_real_world(args)
    elif DS in ['disney', 'books', 'weibo', 'reddit', 'enron']:
        data = load_data(name=DS)
        anomaly_flag = data.y
    else:
        data, features, anomaly_flag, _, _ = load_dataset(DS)

    # Adjacency matrix passed to the trainer; computed before `data` is moved
    # to the device, as before.
    real_adj = get_real_adj(data)

    # hyper-parameters
    out_channels = data.num_features
    num_features = data.num_features
    epochs = args.epochs
    alpha = args.alpha
    beta = args.beta
    gamma = args.gamma  # regularization coefficient for the diversity loss
    embedding_channels = args.embedding_channels
    hidden_channels = args.hidden_channels
    num_layer = args.num_layers
    num_nodes = real_adj.shape[0]

    # Move inputs to the GPU (if available) once, instead of once per repetition.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    data = data.to(device)
    # clone().detach() makes the same copy torch.tensor(real_adj) made, without
    # the copy-construct-from-tensor warning torch.tensor emits for tensor input.
    real_adj = torch.as_tensor(real_adj).clone().detach().to(device)

    # alpha / beta values swept by the training-mode grid search
    para_set = [0.001, 0.01, 0.1, 1, 10, 100, 1000]

    repNum = args.repNum
    auclist = np.zeros([repNum, 1])
    for rep in range(repNum):
        if not args.eval:
            # Grid search over (alpha, beta). Every combination trains a freshly
            # initialised model with its own optimizer; reusing one pair made
            # each combination continue from the previous combination's weights.
            for alpha_i, beta_i in itertools.product(para_set, para_set):
                model = _build_model(num_nodes, num_features, hidden_channels, embedding_channels,
                                     out_channels, num_layer, device)
                optimizer = _build_optimizer(model, args.learning_rate)
                trainer = GOD_Trainer(args, model, optimizer, alpha_i, beta_i, gamma, device)
                auc = trainer.train(data, real_adj, anomaly_flag, epochs)
                print(f'[rep {rep}] alpha={alpha_i}, beta={beta_i}: AUC={auc}')
        else:
            model = _build_model(num_nodes, num_features, hidden_channels, embedding_channels,
                                 out_channels, num_layer, device)
            optimizer = _build_optimizer(model, args.learning_rate)
            trainer = GOD_Trainer(args, model, optimizer, alpha, beta, gamma, device)
            # map_location lets GPU-saved checkpoints load on CPU-only machines.
            checkpoint = torch.load("./weights/" + DS + ".pth", map_location=device)
            model.load_state_dict(checkpoint['model'])
            auc, _, _, _ = trainer.test(data, real_adj, anomaly_flag)

        # In training mode this records the AUC of the last grid point
        # (alpha = beta = 1000), as the script always has.
        auclist[rep] = auc

    AUCmean_std = np.around([np.mean(auclist), np.std(auclist)], decimals=4)
    print("Testing Statistic Results:" + str(AUCmean_std))
    mean_pct, std_pct = AUCmean_std * 100
    pm = r'$\pm$'  # raw string: '\p' is an invalid escape in a normal literal
    os.makedirs('./results', exist_ok=True)
    with open('./results/' + DS + '_result.txt', 'a', encoding='utf-8') as f:
        f.write(f'AUC:{mean_pct:.2f}{pm}{std_pct:.2f}\n')



