from util_real import * 
from Optimizer import *
import pandas as pd
from matplotlib import pyplot as plt
import argparse
# Command-line options. Both are plain strings with a default, so the script
# can also be launched with no arguments at all.
parser = argparse.ArgumentParser()
for _flag, _default in (('--network_typ', 'circle'),
                        ('--attack_typ', 'featureatt')):
    parser.add_argument(_flag, type=str, default=_default)
args = parser.parse_args()

## Set hyperparameters
# Sweep axes taken from the command line; each is wrapped in a one-element
# list so the experiment loops can iterate over it.
# Alternative multi-value settings kept for reference:
#   network_typs = ['circle', 'er']
#   attack_typs = ['labelatt', 'featureatt']
network_typs = [args.network_typ]
attack_typs = [args.attack_typ]

# q values handed to generate_network: 'circle' has its own pair, every other
# network type uses the second pair.
# Alternative kept for reference: q_degrees = [0.1]
if args.network_typ == 'circle':
    q_degrees = [0.06, 0.6]
else:
    q_degrees = [0.2, 0.6]

# Byzantine ratios to sweep (alternative kept for reference: [0.45]).
byz_ratios = [0.15, 0.25, 0.35]

n_workers = 50
random_state = 2025
data_name = 'mnist'

## Other parameters
device = 'cuda'
n_classes = 10
lr_constant = 0.1
epochs = 5000

# Whether to save Byzantine data to disk
save2hd = False

# Directory to store metrics and other outputs; created up front so later
# writes cannot fail on a missing folder.
output_dir = f'output_{data_name}'
os.makedirs(output_dir, exist_ok=True)

## Grid search setup
# Per-dataset settings:
#   N        - value used as the training-set size in the grid bounds below
#   lr_value - learning rate applied to every worker before the refit
#   cn_keep  - which part of the 20-point 'cn' grid is actually searched
if data_name == 'mnist':
    N = 60000
    lr_value = 0.01
    cn_keep = slice(0, 4)
elif data_name == 'cifar':
    N = 50000
    lr_value = 0.2
    cn_keep = slice(2, 18)
else:
    # Fail here with a clear message instead of a NameError on N / lr_value
    # further down the script.
    raise ValueError(
        f"Unsupported data_name {data_name!r}; expected 'mnist' or 'cifar'")

# Samples per worker when N is split evenly across the workers.
n = int(N / n_workers)

# Candidate 'cn' values: n_grids log-spaced points from
# max(start * ln(N), 1) up to end * sqrt(n), restricted to the
# dataset-specific slice chosen above.
start = 0.5
end = 2
n_grids = 20
cn_low = np.max([np.log(N) * start, 1])
cn_high = end * np.sqrt(n)
cn_candidates = np.logspace(np.log10(cn_low), np.log10(cn_high), n_grids)
param_grid = {'cn': cn_candidates[cn_keep]}

## Load data
# Open each .npz archive in a `with` block so its file handle is released as
# soon as the arrays have been read. Indexing an NpzFile returns an in-memory
# ndarray, so the arrays remain valid after the archive is closed.
with np.load(f'{data_name}_train.npz') as data:
    X_train, y_train = data['images'], data['labels']
with np.load(f'{data_name}_test.npz') as data:
    X_test, y_test = data['images'], data['labels']

# The training split stays as NumPy arrays here; only the test split is
# converted to tensors on `device`, with each sample flattened so that
# X_test is 2-D: (number of samples, features).
X_test = torch.from_numpy(X_test).to(device)
y_test = torch.from_numpy(y_test).to(device)
X_test = X_test.view(X_test.size(0), -1)

## Experiment sweep
# For every (byz_ratio, attack_typ) pair, build the attacked data set once,
# then for every (network_typ, q_degree) pair train a DFLOptimizer on it,
# run the 'cn' grid search and write the result to a CSV file in output_dir.

# Flattened feature size of one training sample. It depends only on
# X_train's shape, so it is computed once rather than per configuration.
input_dim = np.prod(X_train.shape[1:])

for byz_ratio in byz_ratios:
    for attack_typ in attack_typs:
        data_path = f'data_{data_name}_{byz_ratio}_{attack_typ}'

        ## Generate Byzantine (malicious) data
        # The third and fourth return values are not needed here.
        # NOTE(review): byz_labels presumably marks which workers are
        # Byzantine - confirm against attack_and_save.
        Xs_all, ys_all, _, _, byz_labels = attack_and_save(
            save2hd=save2hd, n_workers=n_workers, byz_ratio=byz_ratio,
            attack_typ=attack_typ, random_state=random_state,
            X_train=X_train, y_train=y_train,
            data_path=data_path, data_name=data_name, device=device)

        # Flatten every feature tensor to 2-D (samples, features), then
        # split into training and validation parts.
        Xs_all = [x.view(x.size(0), -1) for x in Xs_all]
        X_train_all, y_train_all, X_valid_all, y_valid_all = \
            split_bucket_indices(Xs_all, ys_all)

        for network_typ in network_typs:
            for q_degree in q_degrees:
                # NOTE(review): 'byz_raio' below is a typo of 'byz_ratio';
                # the message is left unchanged in case logs are parsed.
                print(f'\n byz_raio:{byz_ratio} attack_typ:{attack_typ} network_typ:{network_typ} q_degree:{q_degree}')

                # Communication graph and, per node, its neighbours
                # (the diagonal of W is excluded).
                W = generate_network(n_workers, typ=network_typ,
                                     q=q_degree, seed=random_state)
                neighbors = get_neighbors(W, include_diag=False)

                adfl = DFLOptimizer(neighbors, lr_constant=lr_constant,
                                    model_type='logistic',
                                    input_dim=input_dim, n_workers=n_workers,
                                    epochs=epochs, device=device,
                                    random_state=random_state,
                                    pretrained=False, custom_init=True)

                # First fit with lr_constant, then give every worker the
                # learning rate lr_value and refit. The parameters after the
                # refit are handed to the grid search as init_parameters.
                adfl.fit(Xs_all, ys_all)
                adfl.set_learning_rates([lr_value] * n_workers)
                adfl.refit(Xs_all, ys_all)
                init_parameters = adfl.get_parameters()
                adfl.lr_constant = lr_value

                losses, best_params, best_weights, adfl = grid_search(
                    X_train_all, y_train_all, X_valid_all, y_valid_all,
                    adfl, init_parameters, param_grid)

                # One row per entry of best_weights; the selected 'cn' is
                # repeated on every row so the CSV is self-contained.
                df = pd.DataFrame({
                    'best_weights': best_weights,
                    'byz_labels': byz_labels,
                    'best_cn': [best_params['cn']] * len(best_weights)
                })

                # Save to CSV
                file_path = f'{output_dir}/bestparam_{byz_ratio}_{attack_typ}_{network_typ}_{q_degree}.csv'
                df.to_csv(file_path, index=False)
                print(f"\nData successfully saved to {file_path}")


