#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import numpy as np
import pandas as pd
from parameters_chexpert import para
import torch

from hinge_torch import Hinge_Partial_AUC
from logistic_torch_val import Logistic_Partial_AUC

# ---- Experiment configuration ----------------------------------------------
# alpha and beta are handed unchanged to the classifier constructor further
# down (presumably the bounds of the partial-AUC range -- TODO confirm against
# Logistic_Partial_AUC).
alpha, beta = 0.05, 0.5

# train_size is the fraction of rows used for training; the validation part
# is simply the remaining rows, so val_size is informational only and is not
# read anywhere else in this script.
train_size, val_size = 0.9, 0.1

# NOTE: the classifier and the output file name use para.dataname, not this
# local value. Alternative dataset name previously tried here: "a9a".
dataname = "chexpert"

# Spare seed value; only takes effect if the seeding call below is switched
# from para.seed to sseed.
sseed = 2

# Seed torch's global RNG so that the random train/validation permutation
# drawn later is reproducible for a given para.seed.
torch.manual_seed(para.seed)

# Load the pre-extracted feature matrix and the label matrix from disk and
# wrap them as torch tensors (torch.from_numpy shares memory with the arrays).
features_np = np.load('./dataset/chexpert/CheXpert_train_hidden_features_all.npy')
labels_np = np.load('./dataset/chexpert/CheXpert_train_labels_all.npy')
X = torch.from_numpy(features_np)
y = torch.from_numpy(labels_np)

# Collapse every non-positive label to 0, in place, so that only strictly
# positive entries remain non-zero (presumably this folds "uncertain"/negative
# codes into the negative class -- TODO confirm against the label encoding).
y.masked_fill_(y <= 0, 0)

# Split the data into training and validation parts.
# The first int(n * train_size) entries of a random row permutation form the
# training set; all remaining rows form the validation set.
n = X.shape[0]
n_train = int(n * train_size)
n_val = n - n_train  # size of the validation part (kept for reference)

# A single permutation guarantees the two index sets are disjoint.
index = torch.randperm(n)
index_train, index_val = index[:n_train], index[n_train:]

# Row-select features and labels with the same indices so they stay aligned.
X_train, y_train = X[index_train, :], y[index_train, :]
X_val, y_val = X[index_val, :], y[index_val, :]

Model_name = "Logistic"

# num_dis selects which label column to train on and is 1-based: column
# num_dis - 1 of y is used. (Set num_dis = 1 here to override the parameter
# file while debugging.)
num_dis = para.num_dis

# Guard against an out-of-range selector. Without this check a value of 0 (or
# a negative value) would not fail: the resulting negative column index wraps
# around and silently selects a label column counted from the end, so the run
# would train on the wrong task and still write a plausibly named result file.
num_labels = y.shape[1]
if not 1 <= num_dis <= num_labels:
    raise ValueError(
        "para.num_dis must be in [1, " + str(num_labels) + "], got " + str(num_dis)
    )
label_col = num_dis - 1

# All optimisation hyper-parameters come from the parameter module; alpha and
# beta are the script-level constants defined at the top of this file.
classifier = Logistic_Partial_AUC(
    lr_0=para.lr_0,
    lr_outer=para.lr_outer,
    mu=para.mu,
    num_iter=para.num_iter,
    T0=para.T0,
    batch=para.batch,
    c=para.c,
    alpha=alpha,
    beta=beta,
    seed=para.seed,
    dataname=para.dataname,
    Model_name=Model_name,
)

# Fit on the training split, passing the validation split and, as the third
# data pair, the complete (unsplit) feature matrix with its labels. Every
# label argument is the single selected column.
classifier.fit(
    X_train, y_train[:, label_col],
    X_val, y_val[:, label_col],
    X, y[:, label_col],
)

# The learned weights were previously inspected via
# classifier.best_w.cpu().numpy() (taken from older commented-out code;
# attribute name not verified here).
auc_list = classifier.pauc_list()

# Build the result file name: "<dataname>_<num_dis>_<Model_name>" followed by
# one "_<tag>=<value>" fragment per hyper-parameter, then the ".csv" suffix.
name_head = "_".join(str(part) for part in (para.dataname, num_dis, Model_name))
tagged_values = (
    ("_lr0=", para.lr_0),
    ("_lr_out=", para.lr_outer),
    ("_mu=", para.mu),
    ("_T0=", para.T0),
    ("_numStages=", para.num_iter),
    ("_c=", para.c),
    ("_seed=", para.seed),
)
s = name_head + "".join(tag + str(value) for tag, value in tagged_values) + ".csv"

# Store the values returned by the classifier as a single 'pauc' column and
# write them to the current working directory (row index included).
df = pd.DataFrame(auc_list, columns=['pauc'])
df.to_csv(s)
