import argparse
import os

import numpy as np
import pandas as pd
import torch
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

from data import reg_misclassification_detection, target_sln, wga

# --- Command-line interface ----------------------------------------------
# The script takes a single positional float: the label-noise rate.
parser = argparse.ArgumentParser(description='sets the noise level')
parser.add_argument(
    'noise',
    type=float,
    metavar='noise',
    help='set the noise level',
)
args = parser.parse_args()

# --- Experiment configuration --------------------------------------------
dataset = 'cmnist'
seed = 0
noise_level = args.noise

# Seed the global torch and numpy RNGs so runs are repeatable, to the extent
# that downstream code draws only from these global generators.
torch.manual_seed(seed)
np.random.seed(seed)

def run_misclassify_upweight(val_data, test_data, c_id, lam, c, *,
                             epochs=None, learning_rate=None, weight_decay=None):
    """Fit an L1 logistic regression that upweights detected misclassifications.

    A misclassification detector (``reg_misclassification_detection``) is run
    on a copy of ``val_data``. The rows it returns form an "error set" that
    gets sample weight ``lam``. Every other row of ``val_data`` keeps weight 1.
    A sparse (L1, liblinear) logistic regression is then fit on all rows.

    Args:
        val_data: DataFrame of feature columns plus ``'target'``, ``'group'``
            and ``'true_target'`` columns. All three are dropped before
            fitting, and ``'target'`` is used as the label.
        test_data: DataFrame passed to the detector and scored by ``wga``.
        c_id: Detector regularization value, forwarded as ``c_sel``.
        lam: Sample weight given to every row of the detected error set.
        c: Inverse regularization strength ``C`` of the final model.
        epochs, learning_rate, weight_decay: Detector training settings.
            ``None`` (the default) falls back to the module-level
            ``epochs_sel``, ``lr`` and ``w_decay_sel`` at call time, which
            matches the previous hard-wired behaviour.

    Returns:
        Tuple ``(val_wga, test_wga)`` as computed by ``wga``. Note that
        ``val_wga`` is measured only on the rows of ``val_data`` that are
        *not* in the error set.
    """
    # Resolve defaults lazily: the module-level settings are defined after
    # this function, so they must be read at call time, as before.
    if epochs is None:
        epochs = epochs_sel
    if learning_rate is None:
        learning_rate = lr
    if weight_decay is None:
        weight_decay = w_decay_sel

    data = val_data.copy()
    detector = reg_misclassification_detection(data, test_data)
    error_set = detector.run_detection_model(
        epochs=epochs, lr=learning_rate, weight_decay=weight_decay,
        c_sel=c_id, opt='AdamW', lr_scheduler='none',
    )

    # Split into retained rows (weight 1) and error-set rows (weight `lam`).
    data = data.drop(error_set.index)
    full_data = pd.concat([data, error_set])
    weights = np.concatenate((np.ones(len(data)), lam * np.ones(len(error_set))))

    features = full_data.drop(['true_target', 'target', 'group'], axis=1)
    model = LogisticRegression(penalty='l1', solver='liblinear', C=c).fit(
        features, full_data['target'], sample_weight=weights)
    return wga(model, data.drop(['true_target'], axis=1)), wga(model, test_data)
    

    

# Detector training hyperparameters. run_misclassify_upweight reads these as
# module-level globals.
epochs_sel = 6
lr = 1e-5
w_decay_sel = 1e-4

# Model-selection grid: detector regularization (c_id), error-set upweighting
# factor (lambda) and the final logistic regression's C.
C_SEL_VALUES = [33.598183]
LAMBDA_VALUES = [1, 3, 20, 30]
C_VALUES = [0.007848]

# Summary table (one row per invocation), appended to a CSV at the end.
final_results = pd.DataFrame(
    columns=[
        'dataset', 'noise',
        'wga_mean', 'wga_std',
        'c_id', 'lambda', 'c',
        'iters', 'lr', 'exp',
    ]
)

# The embeddings (extracted with the base model) are read from a directory named
# after the dataset, relative to the current working directory, i.e. './<dataset>/'
# (see base_path below). That directory must contain the validation and test
# embeddings, target labels and domain annotations (the code calls domain
# annotations "groups") as NumPy array files (.npy). The files are named
# '<dataset>_<split>_<kind>.npy', where split is 'val' or 'test' and kind is
# 'embeddings', 'labels' or 'groups'. For example, the celebA validation
# embeddings would be 'celebA/celebA_val_embeddings.npy'.

base_path = './' + dataset + '/'


def _load_split(split):
    """Load one split ('val' or 'test') of embeddings as a DataFrame.

    Reads ``<base_path><dataset>_<split>_{embeddings,labels,groups}.npy``.
    The result holds one column per embedding dimension (assumes a 2-D
    embeddings array -- TODO confirm), plus 'target' and 'group' columns.
    """
    prefix = base_path + dataset + '_' + split + '_'
    frame = pd.DataFrame(np.load(prefix + 'embeddings.npy'))
    frame['target'] = np.load(prefix + 'labels.npy')
    frame['group'] = np.load(prefix + 'groups.npy')
    return frame


original_val_data = _load_split('val')
final_test_data = _load_split('test')

print(dataset, noise_level, seed)

# --- Hyperparameter selection --------------------------------------------
# Split the validation set in half. The other half gets label noise via
# target_sln and is used for detection + training. This held-out half keeps
# its original labels and scores each configuration (reported as test_wga).
test_data = original_val_data.sample(frac=0.5, replace=False)

train_data = target_sln(original_val_data.drop(test_data.index).reset_index(drop=True), p=noise_level)

# NOTE(review): this value is never read; it is reassigned in the evaluation
# loop below. The call is kept because target_sln may draw from the global
# RNG, and removing it could change the per-run seeds drawn later. Confirm,
# then delete if reproducing earlier runs does not matter.
full_train_data = pd.concat([target_sln(test_data.reset_index(drop=True), p=noise_level), train_data], ignore_index=True)

results = pd.DataFrame(columns=['c_id', 'lambda', 'c', 'val_wga', 'test_wga', 'type'])

for c in C_VALUES:
    for lam in LAMBDA_VALUES:
        for c_sel in C_SEL_VALUES:
            rad_val, rad_test = run_misclassify_upweight(train_data, test_data, c_sel, lam, c)
            print(rad_test)
            results.loc[len(results)] = {'c_id': c_sel, 'lambda': lam, 'c': c, 'val_wga': rad_val, 'test_wga': rad_test, 'type': 'RAD'}

# (c_id, lambda, c) tuple with the best mean held-out WGA; its order matches
# run_misclassify_upweight's positional parameters (c_id, lam, c).
rad_avg_param = results[results['type'] == 'RAD'].groupby(['c_id', 'lambda', 'c'])['test_wga'].mean().idxmax()

print(rad_avg_param)

# --- Final evaluation over repeated noisy draws --------------------------
n_runs = 10
rad_wga = np.zeros(n_runs)

# Per-run seeds come from the globally seeded numpy RNG.
seeds = np.random.randint(200, size=n_runs)

# Use a distinct loop variable: reusing `seed` overwrote the experiment seed,
# so the 'exp' column recorded the last per-run seed instead.
for i, run_seed in enumerate(seeds):
    print(i)
    np.random.seed(run_seed)
    torch.manual_seed(run_seed)
    full_train_data = target_sln(original_val_data.reset_index(drop=True), p=noise_level)
    _, rad_wga[i] = run_misclassify_upweight(full_train_data, final_test_data, *rad_avg_param)
    print(rad_wga[i])

print("RAD (" + dataset + ")(" + str(noise_level) + "): ", rad_wga.mean(), rad_wga.std())
final_results.loc[len(final_results)] = {'dataset': dataset, 'noise': noise_level, 'wga_mean': rad_wga.mean(), 'wga_std': rad_wga.std(), 'c_id': rad_avg_param[0], 'lambda': rad_avg_param[1], 'iters': epochs_sel, 'lr': lr, 'c': rad_avg_param[2], 'exp': seed}

path = 'results/vanilla_RAD/final_RAD_' + dataset + '_' + str(noise_level * 100) + '.csv'
# Create the output directory up front so the run cannot fail at the very end.
os.makedirs(os.path.dirname(path), exist_ok=True)
# Repeated runs append to the same file. Write the header only when the file
# is new; header=True on every append inserted duplicate header rows.
final_results.to_csv(path, mode='a', header=not os.path.exists(path))
