#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import numpy as np
import pandas as pd
from parameters_chexpert import para
import torch

from logistic_torch import Logistic_Partial_AUC

# ---- Experiment configuration ------------------------------------------
# Forwarded to Logistic_Partial_AUC as k1_value / k2_value.
kvalue1, kvalue2 = 10000, 100000

# NOTE(review): the four settings below are assigned but never read by any
# live statement in this script (only a commented-out constructor call
# mentions `sseed` and `dataname`) -- confirm whether they are still needed.
train_size, val_size = 0.5, 0.25
dataname = "chexpert"  # previously tried alternative: "a9a"
sseed = 2

# ---- Data ---------------------------------------------------------------
# Feature and label arrays are read from .npy files and wrapped as tensors.
feature_file = './dataset/chexpert/CheXpert_train_hidden_features_all.npy'
label_file = './dataset/chexpert/CheXpert_train_labels_all.npy'

feature_array = np.load(feature_file)
X = torch.from_numpy(feature_array)

label_array = np.load(label_file)
y = torch.from_numpy(label_array)

# Every label that is zero or negative is overwritten with 0 (in place).
non_positive = y <= 0
y[non_positive] = 0

# Model tag: handed to the classifier and embedded in the output file name.
Model_name = "Logistic"
# Only used when building the output file name. Hard-coded here;
# `para.num_dis` was the earlier source of this value.
num_dis = 5

# Hyper-parameters are read from `para`; the two k values come from the
# kvalue1 / kvalue2 constants defined earlier in this script.
# An earlier hand-tuned configuration, kept for reference:
#   lr_0=5, lr_outer=2, mu=1000, num_iter=15, T0=50, batch=100,
#   k1_value=kvalue1, k2_value=kvalue2, seed=sseed, dataname=dataname
classifier_settings = dict(
    lr_0=para.lr_0,
    lr_outer=para.lr_outer,
    mu=para.mu,
    num_iter=para.num_iter,
    T0=para.T0,
    batch=para.batch,
    c=para.c,
    k1_value=kvalue1,
    k2_value=kvalue2,
    seed=para.seed,
    dataname=para.dataname,
    Model_name=Model_name,
)
classifier = Logistic_Partial_AUC(**classifier_settings)

# The same (X, y) pair is supplied three times in a row.
# NOTE(review): presumably these are the train / validation / test slots of
# fit() -- confirm against Logistic_Partial_AUC.fit.
fit_args = (X, y) * 3
classifier.fit(*fit_args)

# ---- Collect results from the fitted classifier -------------------------
# Per-run traces recorded on the classifier, paired with their CSV column
# names in the order they appear in the output table.
trace_columns = (
    ('data_pass', classifier.data_pass),
    ('loss', classifier.loss_list),
    ('pAUC', classifier.pauc_list),
    ('time', classifier.time_list),
)

# The two weight tensors are moved to the CPU, converted to NumPy and
# transposed before becoming single-column frames.
# NOTE(review): assumes each transposed weight array fits exactly one
# column -- confirm the shapes of w_1 / w_2.
weight_frames = [
    pd.DataFrame(weights.cpu().numpy().T, columns=[label])
    for label, weights in (('w1', classifier.w_1), ('w2', classifier.w_2))
]
trace_frames = [
    pd.DataFrame(values, columns=[label]) for label, values in trace_columns
]

# Column order: w1, w2, data_pass, loss, pAUC, time.
# DataFrame.join is a left join on the index by default, so the row index
# of the final table is the one of the w1 frame.
results = weight_frames[0]
for frame in weight_frames[1:] + trace_frames:
    results = results.join(frame)

# ---- Write the CSV --------------------------------------------------------
# The file name encodes the data set, num_dis, the model tag and the
# hyper-parameters. (An older scheme without num_dis / Model_name, using
# `margin` instead of `c` and ending in "_random.csv", is no longer used.)
hyperparams = (
    ('lr0', para.lr_0),
    ('lr_out', para.lr_outer),
    ('mu', para.mu),
    ('T0', para.T0),
    ('numStages', para.num_iter),
    ('c', para.c),
    ('seed', para.seed),
)
name_parts = [str(para.dataname), str(num_dis), str(Model_name)]
name_parts += [key + "=" + str(value) for key, value in hyperparams]
name_parts.append("random_baseline.csv")
output_csv = "_".join(name_parts)

results.to_csv(output_csv)


