import argparse
import os

import numpy as np
import pandas as pd
import torch

from util_real import *
from Optimizer import *
from Competitor import *
## Parse command-line options
parser = argparse.ArgumentParser(
    description='Run GradientTrack experiments with median and trimmed-mean aggregation.'
)
parser.add_argument(
    '--network_typ', type=str, default='circle',
    help="Network topology passed to generate_network, e.g. 'circle' or 'er' "
         "(default: %(default)s).",
)
parser.add_argument(
    '--attack_typ', type=str, default='featureatt',
    help="Attack type passed to attack_and_save, e.g. 'labelatt' or 'featureatt' "
         "(default: %(default)s).",
)
args = parser.parse_args()

## Set hyperparameters
n_workers = 50
# The network topology (e.g. 'circle' or 'er') and attack type (e.g. 'labelatt'
# or 'featureatt') come from the command line; q_degree values (passed to
# generate_network as `q`) are chosen per topology.
q_degrees = [0.06, 0.6] if args.network_typ == 'circle' else [0.2, 0.6]
byz_ratios = [0.15, 0.25, 0.35]
network_typs = [args.network_typ]
attack_typs = [args.attack_typ]
random_state = 2025
data_name = 'mnist'

## Other parameters
# Fall back to CPU so the script still runs on machines without a CUDA device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
n_classes = 10
save2hd = False  # Whether to save Byzantine data to disk
output_dir = f'output_{data_name}'  # Output directory for accuracy and other metrics
os.makedirs(output_dir, exist_ok=True)
model_type = 'logistic' if data_name == 'cifar' else 'lenet5'
print_freq = 1

## Load data
# np.load on an .npz archive returns an NpzFile holding an open file handle;
# using it as a context manager closes the handle once the arrays are read.
with np.load(f'{data_name}_train.npz') as data:
    X_train, y_train = data['images'], data['labels']
with np.load(f'{data_name}_test.npz') as data:
    X_test, y_test = data['images'], data['labels']
# The test set is moved to `device` once up front; the training arrays stay as
# NumPy arrays and are handed to attack_and_save.
X_test = torch.from_numpy(X_test).to(device)
y_test = torch.from_numpy(y_test).to(device)
# Product of all non-leading dimensions (assumes axis 0 indexes samples).
input_dim = np.prod(X_train.shape[1:])

## GradientTrack experiment configuration
# Learning rate given to the GradientTrack constructor and used by the initial fit().
INITIAL_LR = 0.01
# Number of epochs for the initial fit() call.
WARMUP_EPOCHS = 100
# NOTE(review): input_dim is hard-coded to 512 rather than using the `input_dim`
# computed from the training data; confirm how GradientTrack uses it for `model_type`.
MODEL_INPUT_DIM = 512
# Each entry: (agg_method, output file tag, refit stages). Each refit stage is an
# (epochs, learning rate) pair, applied in order after the initial fit().
GT_VARIANTS = [
    ('median', 'gtmed',
     [(500, 0.002), (1000, 0.0005), (3400, 0.0001), (4000, 0.00001)]),
    ('trimmed_mean', 'gttrim',
     [(900, 0.001), (500, 0.0005), (1000, 0.0002), (6500, 0.0001)]),
]


def _train_gradient_track(neighbors, agg_method, refit_stages, byz_ratio,
                          Xs_all, ys_all, X_test, y_test):
    """Build a GradientTrack model and train it with a staged learning-rate schedule.

    The model is first fit for WARMUP_EPOCHS at INITIAL_LR. Then, for each
    (epochs, lr) stage in ``refit_stages``, it sets ``epochs``, updates both the
    per-worker learning rates and ``lr_constant`` to ``lr``, and calls refit().

    Model settings (model_type, n_workers, device, random_state, print_freq) are
    read from the module-level configuration above.

    Returns:
        The trained GradientTrack instance.
    """
    model = GradientTrack(
        neighbors,
        lr_constant=INITIAL_LR,
        model_type=model_type,
        n_workers=n_workers,
        input_dim=MODEL_INPUT_DIM,
        device=device,
        random_state=random_state,
        pretrained=False,
        custom_init=True,
        agg_method=agg_method,
        byz_ratio=byz_ratio,
    )

    model.epochs = WARMUP_EPOCHS
    model.fit(Xs_all, ys_all, X_test, y_test, print_freq=print_freq)

    for epochs, lr in refit_stages:
        model.epochs = epochs
        model.set_learning_rates([lr] * n_workers)
        model.lr_constant = lr
        model.refit(Xs_all, ys_all, X_test, y_test, print_freq=print_freq)
    return model


for byz_ratio in byz_ratios:
    for attack_typ in attack_typs:
        data_path = f'data_{data_name}_{byz_ratio}_{attack_typ}'

        ## Generate corrupted data
        Xs_all, ys_all, X_oracle_tensor, y_oracle_tensor, byz_labels = attack_and_save(
            save2hd=save2hd,
            n_workers=n_workers,
            byz_ratio=byz_ratio,
            attack_typ=attack_typ,
            random_state=random_state,
            X_train=X_train,
            y_train=y_train,
            data_path=data_path,
            data_name=data_name,
            device=device
        )

        for network_typ in network_typs:
            for q_degree in q_degrees:
                print(f'\n byz_ratio:{byz_ratio} attack_typ:{attack_typ} network_typ:{network_typ} q_degree:{q_degree}')

                W = generate_network(n_workers, typ=network_typ, q=q_degree, seed=random_state)
                neighbors = get_neighbors(W, include_diag=False)
                run_tag = f'{byz_ratio}_{attack_typ}_{network_typ}_{q_degree}'

                ## Train, then save parameters and metric history, for each aggregation variant
                for agg_method, file_tag, refit_stages in GT_VARIANTS:
                    model = _train_gradient_track(
                        neighbors, agg_method, refit_stages, byz_ratio,
                        Xs_all, ys_all, X_test, y_test,
                    )
                    param_dfl = model.get_parameters()
                    torch.save(param_dfl, f'{output_dir}/param_{file_tag}_{run_tag}.pth')
                    model.save_history(save_path=f'{output_dir}/metric_{file_tag}_{run_tag}')
                    # Drop references before training the next variant to lower peak memory.
                    del model, param_dfl


