# training
import sys
sys.path.append("..")

import argparse
import os
import torch
from copy import deepcopy
import numpy as np
from torch import optim, nn
from torch.utils.data import DataLoader
import sys
import pickle
import threading
from tqdm import tqdm
import json
from utils.io import Tee, to_csv
from utils.eval import accuracy, accuracies, losses
from utils.aggregate import aggregate, aggregate_lr, zero_model, aggregate_momentum
from algs.individual_train import individual_train
from models.models import resnet18, CNN, CNN_FEMNIST, RNN_Shakespeare, RNN_StackOverflow
from utils.print import print_acc, round_list
from utils.save import save_acc_loss
from utils.project import project
from utils.stat import mean_std

# Repository root -- change this root directory to the repo folder.
# '~' has to be expanded explicitly: os.path.join / os.makedirs / open do not
# expand it, so an unexpanded '~/fairfl' would create a literal './~' folder
# under the current working directory. expanduser is a no-op for absolute paths.
root = os.path.expanduser('~/fairfl')

parser = argparse.ArgumentParser(description='training')
# Value exported as CUDA_VISIBLE_DEVICES before the device is chosen.
parser.add_argument('--device', type=str, default='7')
# Name of the data split folder under <root>/data/<dataset>/.
parser.add_argument('--data_dir', type=str, default='iid-12')
parser.add_argument('--dataset', type=str, default="MNIST")
parser.add_argument('--num_clients', type=int, default=12)
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--num_workers', type=int, default=0, help='for data loader')
# Number of global (communication) rounds.
parser.add_argument('--num_epochs', type=int, default=10)
# Local epochs each client trains per global round.
parser.add_argument('--num_local_epochs', type=int, default=1)
parser.add_argument('--learning_rate', type=float, default=0.01)
# Step size of the update applied to the client mixture weights lambda_.
parser.add_argument('--step_size_lambda', type=float, default=0.1)
# Checkpoint interval, in global rounds.
parser.add_argument('--save_epoch', type=int, default=10)
# NOTE(review): --seed and --trial are only used to name output paths/files in
# this script; no RNG (torch / numpy) is seeded from them here -- confirm
# whether seeding is expected to happen elsewhere.
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--trial', type=int, default=0)

args = parser.parse_args()
# Results are grouped as results/<dataset>/<split>/AFL/seed_<s>/trial_<t>.
# This uses the split *name* before args.data_dir is rewritten below.
output_dir = os.path.join(root, 'results', args.dataset, args.data_dir, 'AFL', f'seed_{args.seed}', f'trial_{args.trial}')
# From here on args.data_dir is the full path of the client data split.
args.data_dir = os.path.join(root, 'data', args.dataset, args.data_dir)
os.makedirs(output_dir, exist_ok=True)
print(args)
print('output_dir: ', output_dir)


# Expose only the requested GPU(s) to torch. This is set before torch.cuda is
# queried below, so the first visible GPU is addressed as cuda:0.
os.environ['CUDA_VISIBLE_DEVICES'] = args.device
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Record the run configuration next to the results.
with open(os.path.join(output_dir, 'args.json'), 'w') as fp:
    json.dump(vars(args), fp)

# Per-client datasets: in.pickle feeds the training loaders and out.pickle the
# test loaders; both are indexed by integer client id 0..num_clients-1.
in_file = os.path.join(args.data_dir, 'in.pickle')
out_file = os.path.join(args.data_dir, 'out.pickle')

with open(in_file, 'rb') as f_in, open(out_file, 'rb') as f_out:
    in_data = pickle.load(f_in)
    out_data = pickle.load(f_out)

# Sample counts per client.
weights = np.array([len(in_data[c]) for c in range(args.num_clients)])
weights_test = np.array([len(out_data[c]) for c in range(args.num_clients)])

_n_train = np.sum(weights)
_n_test = np.sum(weights_test)
print(f'total train samples: {_n_train}')
print(f'total test samples: {_n_test}')
print(f'total samples: {_n_train + _n_test}')

print('samples: ', weights)
# Normalise the training counts into sample-proportional client weights
# (they sum to 1).
weights = list(weights / _n_train)
    
# data loaders
train_loaders = [DataLoader(
    dataset=in_data[i],
    batch_size=args.batch_size,
    num_workers=args.num_workers, drop_last=False, pin_memory=True, shuffle=True)
    for i in range(args.num_clients)]

test_loaders = [DataLoader(
    dataset=out_data[i],
    batch_size=args.batch_size,
    num_workers=args.num_workers, drop_last=False, pin_memory=True, shuffle=True)
    for i in range(args.num_clients)]


# One model per client for the selected dataset. The dispatch table keeps the
# dataset -> architecture mapping in one place. An unsupported --dataset now
# fails here with a clear message, rather than leaving `models` undefined.
_model_factories = {
    'MNIST': CNN,
    'CIFAR10': lambda: resnet18(num_classes=10),
    'CIFAR100': lambda: resnet18(num_classes=100),
    'CINIC10': lambda: resnet18(num_classes=10),
    'FEMNIST': CNN_FEMNIST,
    'Shakespeare': RNN_Shakespeare,
    'StackOverflow': RNN_StackOverflow,
}
if args.dataset not in _model_factories:
    raise ValueError(
        f"unsupported dataset {args.dataset!r}; "
        f"expected one of {sorted(_model_factories)}"
    )
models = [_model_factories[args.dataset]() for _ in range(args.num_clients)]
    
    
# loss functions, optimizer
loss_func = nn.CrossEntropyLoss()
# Per-round metrics are appended by save_acc_loss; start from an empty file.
# NOTE(review): this also truncates the log when resuming from a checkpoint,
# so rows from rounds before the checkpoint are lost -- confirm intended.
json_file = os.path.join(output_dir, 'log.json')
with open(json_file, 'w') as f:
    f.write('')

# Client mixture weights for aggregation. They start at the sample-proportional
# weights and are updated after every round.
lambda_ = weights
# One plain SGD optimizer per client model.
optimizers = [optim.SGD(model.parameters(), lr=args.learning_rate) for model in models]

# checkpoint: resume from the last saved round if one exists. Only the round
# index and a single model's state_dict are stored, so lambda_ restarts from
# `weights` on resume.
model_path = os.path.join(output_dir, 'model_last.pth')
if os.path.exists(model_path):
    # Read the file once (not twice per client). Map tensors to the CPU so a
    # checkpoint written on a GPU also loads on a CPU-only host;
    # load_state_dict copies the values into each model's own parameters.
    checkpoint = torch.load(model_path, map_location='cpu')
    start_epoch = checkpoint['epoch']
    for model in models:
        model.load_state_dict(checkpoint['state_dict'])
else:
    start_epoch = 0

# Federated rounds. Each round has three steps:
#   1. every client trains locally;
#   2. the client models are aggregated with weights lambda_;
#   3. lambda_ is moved along the per-client training losses and passed
#      through `project` (presumably onto the probability simplex -- confirm).
mean_accs = []  # mean client test accuracy after each round
# Rounds are numbered 1..num_epochs, and t counts the rounds completed. A
# resumed run continues right after the checkpointed round. The upper bound
# is inclusive so that exactly num_epochs rounds run from scratch.
for t in range(start_epoch + 1, args.num_epochs + 1):
    for i in range(args.num_clients):
        # local training of client i for num_local_epochs epochs
        individual_train(train_loaders[i], loss_func, optimizers[i], models[i], test_loaders[i],
                         device=device, client_id=i, epochs=args.num_local_epochs,
                         output_dir=output_dir, show=False, save=False)

    aggregate(models, weights=lambda_)
    # Checkpoint periodically and after the final round, so model_last.pth
    # holds the last model. Only models[0] is saved; this assumes aggregate
    # leaves every client with the same weights -- TODO confirm.
    if t % args.save_epoch == 0 or t == args.num_epochs:
        torch.save({'epoch': t, 'state_dict': models[0].state_dict()}, model_path)
    accs = accuracies(models, test_loaders, device)
    losses_ = losses(models, train_loaders, loss_func, device)
    lambda_ = project(np.array(lambda_) + args.step_size_lambda * np.array(losses_))
    print(f'global epoch: {t}')
    mean, _ = mean_std(accs)
    mean_accs.append(mean)
    print(f'losses: {round_list(losses_)}')
    print(f'lambda: {round_list(lambda_)}')
    save_acc_loss(json_file, t, accs, losses_)

# Report the per-client test accuracies from the final round. If no round ran
# (e.g. resuming from a checkpoint that already reached the last round),
# `accs` was never assigned. In that case skip the report instead of raising
# NameError.
if mean_accs:
    mean, std = mean_std(accs)
    print('mean: ', mean, 'std: ', std)
    print(f'accs: {[round(i, 3) for i in accs]}')
else:
    print('no training rounds were run; no final accuracies to report')


# Per-round mean accuracies. Note: this file is written to the current working
# directory, not to output_dir.
acc_file = "afl_{}_seed{}_trial{}_mean_acc.pkl".format(args.dataset, args.seed, args.trial)
with open(acc_file, 'wb') as f_out:
    pickle.dump(mean_accs, f_out)