from util_real import * 
from Optimizer import *
import pandas as pd
from Competitor import *
import argparse
# Command-line interface: both options are plain strings that select which
# network topology and which attack type this run covers.
parser = argparse.ArgumentParser()
for _flag, _default in (
    ('--network_typ', 'circle'),
    ('--attack_typ', 'featureatt'),
):
    parser.add_argument(_flag, type=str, default=_default)
args = parser.parse_args()

def fit_oracle_with_strategy(opt, X, y, data_name='mnist'):
    """Train ``opt`` with the staged learning-rate schedule for a dataset.

    Every stage first sets a single learning rate through
    ``opt.set_learning_rates`` and then calls ``opt.fit`` for a fixed number
    of epochs. Stages run in the order listed in the schedule table.

    Args:
        opt: Object exposing ``set_learning_rates(lr)`` and
            ``fit(X, y, epochs=...)``.
        X: Training inputs, forwarded unchanged to ``opt.fit``.
        y: Training targets, forwarded unchanged to ``opt.fit``.
        data_name: Schedule selector; ``'mnist'`` or ``'cifar'``.

    Raises:
        ValueError: If ``data_name`` has no schedule. Without this check an
            unknown name would return without training anything and the
            caller would go on to evaluate an untrained model.
    """
    # (learning rate, epochs) pairs, applied top to bottom.
    schedules = {
        'mnist': ((0.1, 1000), (0.01, 3000), (0.001, 1000)),
        'cifar': ((0.5, 5000), (0.2, 3000), (0.01, 1000)),
    }
    if data_name not in schedules:
        raise ValueError(
            f"No oracle training schedule for data_name={data_name!r}; "
            f"expected one of {sorted(schedules)}"
        )
    for lr, n_epochs in schedules[data_name]:
        opt.set_learning_rates(lr)
        opt.fit(X, y, epochs=n_epochs)

## Set hyperparameters
n_workers = 50
# Two q values are swept per topology; every topology other than 'circle'
# uses the second pair.
q_degrees = [0.06,0.6] if args.network_typ == 'circle' else [0.2,0.6]
byz_ratios = [0.15,0.25,0.35]
# Topology and attack type come from the command line, so one run covers a
# single (network_typ, attack_typ) pair. Values that appeared in earlier
# hard-coded lists here: 'circle'/'er' and 'labelatt'/'featureatt'.
network_typs = [args.network_typ]
attack_typs = [args.attack_typ]
random_state = 2025
data_name = 'mnist'

## Other configurations
device = 'cuda'
n_classes = 10
save2hd = False # Whether to save Byzantine data to disk
output_dir = f'output_{data_name}' # Path to save accuracy and other output
os.makedirs(output_dir, exist_ok=True)
model_type = 'logistic' if data_name == 'cifar' else 'lenet5' # lenet5 for MNIST
print_freq = 1 # Forwarded to every fit/refit call below

## Load data
# np.load on an .npz archive returns a lazily-read NpzFile that keeps the
# file open; use it as a context manager so the handle is released as soon
# as the arrays have been read into memory.
with np.load(f'/database/share/shuyuan/RDFL/data/{data_name}_train.npz') as data:
    X_train, y_train = data['images'], data['labels']
with np.load(f'/database/share/shuyuan/RDFL/data/{data_name}_test.npz') as data:
    X_test, y_test = data['images'], data['labels']
# Only the test split is moved to the device here; the training arrays stay
# as NumPy arrays and are handed to attack_and_save later.
X_test = torch.from_numpy(X_test).to(device)
y_test = torch.from_numpy(y_test).to(device)
# Flattened per-sample feature count (product of all non-batch dimensions).
input_dim = np.prod(X_train.shape[1:])

# Sweep over Byzantine ratio x attack type x network topology x degree.
# For each combination: train an initial DFL model with uniform learning
# rates, then retrain from scratch with per-worker ("aDFL") learning rates
# read from a bestparam CSV that must already exist in output_dir.
for byz_ratio in byz_ratios:
    for attack_idx, attack_typ in enumerate(attack_typs):
        data_path = f'data_{data_name}_{byz_ratio}_{attack_typ}'

        ## Generate poisoned data
        Xs_all, ys_all, X_oracle_tensor, y_oracle_tensor, _ = attack_and_save(
            save2hd=save2hd,
            n_workers=n_workers,
            byz_ratio=byz_ratio,
            attack_typ=attack_typ,
            random_state=random_state,
            X_train=X_train,
            y_train=y_train,
            data_path=data_path,
            data_name=data_name,
            device=device,
        )

        ## Oracle estimation
        # Done only for the first attack type; the result file name depends
        # on byz_ratio alone.
        if attack_idx == 0:
            opt = Optimizer(
                model_type=model_type,
                num_classes=n_classes,
                lr=0.1,
                device=device,
                random_state=random_state,
                pretrained=False,
                custom_init=True,
                input_dim=input_dim,
            )
            fit_oracle_with_strategy(
                opt, X_oracle_tensor, y_oracle_tensor, data_name=data_name
            )
            acc_single, loss_single = evaluate(X_test, y_test, opt, device)
            df = pd.DataFrame([acc_single, loss_single], index=['acc', 'loss']).T
            df.to_csv(f'{output_dir}/metric_single_{byz_ratio}.csv', index=False)

        for network_typ in network_typs:
            for q_degree in q_degrees:
                print(f'\n byz_ratio:{byz_ratio} attack_typ:{attack_typ} network_typ:{network_typ} q_degree:{q_degree}')

                # Shared suffix for every artefact of this configuration.
                run_tag = f'{byz_ratio}_{attack_typ}_{network_typ}_{q_degree}'

                W = generate_network(n_workers, typ=network_typ, q=q_degree, seed=random_state)
                neighbors = get_neighbors(W, include_diag=False)

                ## Initial estimation
                adfl = DFLOptimizer(
                    neighbors,
                    lr_constant=1,
                    model_type=model_type,
                    n_workers=n_workers,
                    device=device,
                    random_state=random_state,
                    pretrained=False,
                    custom_init=True,
                    input_dim=input_dim,
                )

                # First stage uses whatever rates the constructor set up
                # (lr_constant=1); no explicit set_learning_rates call.
                adfl.epochs = 100
                adfl.fit(Xs_all, ys_all, X_test, y_test, print_freq=print_freq)
                # Remaining stages: (epochs, learning rate shared by all workers).
                for stage_epochs, stage_lr in ((4900, 0.05), (3000, 0.01), (1000, 0.005)):
                    adfl.epochs = stage_epochs
                    adfl.set_learning_rates([stage_lr] * n_workers)
                    adfl.refit(Xs_all, ys_all, X_test, y_test, print_freq=print_freq)

                param_dfl = adfl.get_parameters()
                torch.save(param_dfl, f'{output_dir}/param_init_{run_tag}.pth')
                acc, loss = adfl.save_history(save_path=f'{output_dir}/metric_init_{run_tag}')

                ## aDFL estimation
                file_path = f'{output_dir}/bestparam_{run_tag}.csv'
                best_param = pd.read_csv(file_path)
                cn = best_param['best_cn'][0]
                best_weights = np.array(best_param['best_weights'])

                # Start over from freshly initialised models and history.
                adfl._initialize_models()
                adfl._initialize_history()

                # Stages 1-2: per-worker rates are best_weights times a scale.
                for weight_scale, stage_epochs in ((1, 100), (0.05, 4900)):
                    adp_lr = weight_scale * best_weights
                    adfl.set_learning_rates(adp_lr)
                    adfl.epochs = stage_epochs
                    adfl.refit(Xs_all, ys_all, X_test, y_test, print_freq=print_freq)

                # Stages 3-4: rates are recomputed by the optimizer itself
                # after cn and lr_constant are set on it.
                adfl.cn = cn
                for stage_constant, stage_epochs in ((0.01, 3000), (0.005, 1000)):
                    adfl.lr_constant = stage_constant
                    adfl._compute_adaptive_lr(Xs_all, ys_all)
                    adfl.epochs = stage_epochs
                    adfl.refit(Xs_all, ys_all, X_test, y_test, print_freq=print_freq)

                param_dfl = adfl.get_parameters()
                torch.save(param_dfl, f'{output_dir}/param_adfl_{run_tag}.pth')
                acc, loss = adfl.save_history(save_path=f'{output_dir}/metric_adfl_{run_tag}')
