# Code modified from https://github.com/Nicolas-Pinon/uad_ocsvm_guided_repr_learning/blob/main/README.md
# Requires more computational resources than the CLOE method; please use GPU acceleration to run this code.
from og import OCSVMguidedAutoencoder
import torch
from torch import optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import os
import copy
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, average_precision_score
import time



# Directory that is scanned for dataset files.
path = "datasets"

# Data-loading and optimisation settings shared by every dataset.
NUM_WORKER = 0
NUM_EPOCH = 200
LEARNING_RATE = 1e-3
RANDOM_SEED = 49

# Prefer the first CUDA device when one is available, otherwise run on CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
if device != 'cpu':
    print('Using Cuda')


# For every dataset file in `path`: train an OCSVM-guided autoencoder on a
# subset of the normal (y == 0) samples, score the whole dataset with the
# best-training-loss weights, and save AUC ROC / average precision to disk.
for file in os.listdir(path):
    # Dataset name is the file name without its extension.
    data_name = os.path.splitext(file)[0]
    print(data_name)
    time_begin = time.time()
    # NOTE(review): allow_pickle=True can execute arbitrary code while loading;
    # only run this script on trusted dataset files.
    data = np.load(os.path.join(path, file), allow_pickle=True)
    X, y = data['X'], data['y']
    x = torch.from_numpy(X).to(device)

    # Split only the normal samples, directly on the NumPy array (no device
    # round-trip needed). At most ~5000 normal samples are kept for
    # training/validation; `<=` avoids test_size == 0.0, which
    # train_test_split rejects, when there are exactly 5000 of them.
    normal = X[y == 0].astype(np.float64)
    n_normal = normal.shape[0]
    test_size = 0.1 if n_normal <= 5000 else 1 - 5000 / n_normal
    X_train_valid, _ = train_test_split(
        normal, test_size=test_size, random_state=RANDOM_SEED
    )
    # The 20% validation part is split off but not used by this script.
    X_train, _ = train_test_split(
        X_train_valid, test_size=0.2, random_state=RANDOM_SEED
    )

    # Smaller training batches for the "15_Hepatitis" dataset only.
    BATCH_SIZE_train = 20 if data_name == "15_Hepatitis" else 100
    # NOTE(review): taken before the odd-length truncation applied at
    # evaluation time below; confirm the model tolerates a batch one sample
    # shorter than this value.
    BATCH_SIZE_valid = x.shape[0]

    # train_test_split returns NumPy arrays, so convert back to a float32
    # tensor before handing the data to the DataLoader.
    train_tensor = torch.from_numpy(X_train).to(dtype=torch.float32)
    # drop_last=True discards a final incomplete batch, so every training
    # batch has exactly BATCH_SIZE_train rows.
    train_loader = DataLoader(
        train_tensor,
        batch_size=BATCH_SIZE_train,
        num_workers=NUM_WORKER,
        drop_last=True,
    )
    og_model = OCSVMguidedAutoencoder(
        BATCH_SIZE_train, BATCH_SIZE_valid, X_train.shape[1]
    ).to(device)
    optimizer = optim.SGD(og_model.parameters(), lr=LEARNING_RATE)

    best_loss = float('inf')
    # Start from the initial weights so a checkpoint always exists, even if
    # no epoch ever improves on `inf` (e.g. the loss is NaN throughout).
    best_model_wts = copy.deepcopy(og_model.state_dict())

    for epoch in range(NUM_EPOCH):
        overall_loss = 0
        for x_train in train_loader:
            # Keep the batch on the same device as the model.
            x_train = x_train.to(device)
            # Clear gradients from the previous step; otherwise they accumulate.
            optimizer.zero_grad()
            total_loss, _, _ = og_model.compute_loss(x_train)
            total_loss.backward()
            optimizer.step()
            overall_loss += total_loss.item()

        print(f'Epoch: {epoch}, train loss: {overall_loss}')

        # Checkpoint the weights with the lowest summed training loss.
        if overall_loss < best_loss:
            best_loss = overall_loss
            best_model_wts = copy.deepcopy(og_model.state_dict())

    og_model.load_state_dict(best_model_wts)
    og_model.eval()

    # Drop the last sample when the count is odd, and take the labels of the
    # second half of the (even-length) data as ground truth.
    # NOTE(review): presumably the model only scores the second half of the
    # batch it is given - confirm against OCSVMguidedAutoencoder.
    n_even = x.shape[0] - x.shape[0] % 2
    x = x[:n_even]
    y_true = y[n_even // 2:n_even]

    x_hat, z = og_model.forward(x.to(device, dtype=torch.float32))
    alpha, K_sv = og_model.solve_ocsvm(z, training=False)
    ocsvm_obj = og_model.ocsvm_objective(alpha, z, K_sv)
    # Binary anomaly label: 1 when the objective value is not strictly positive.
    score = [0 if value.item() > 0 else 1 for value in ocsvm_obj[0]]
    aucroc = roc_auc_score(y_true=y_true, y_score=score)
    aucap = average_precision_score(y_true=y_true, y_score=score, pos_label=1)

    print(f'AUC ROC for Christoffel score: {aucroc}')
    print(f'AP AUC for Christoffel score: {aucap}')
    print(f'time: {time.time()-time_begin}')

    # Results are stored per dataset and per seed (the original referenced an
    # undefined `args` object here).
    result_path = f"CLOE/results/{data_name}/{RANDOM_SEED}/Og/"
    os.makedirs(result_path, exist_ok=True)
    np.save(
        os.path.join(result_path, "result.npy"),
        {
            "AUC ROC": aucroc,
            "AP AUC": aucap,
            "F1 Score": 0,
        },
    )



