# training
import sys
sys.path.append("..")

import argparse
import os
import torch
from copy import deepcopy
import numpy as np
from torch import optim, nn
from torch.utils.data import DataLoader
import pickle
import threading
from tqdm import tqdm
import json
from utils.io import Tee, to_csv
from utils.eval import accuracy, accuracies, losses
from utils.aggregate import aggregate, aggregate_lr, zero_model, aggregate_momentum
from algs.individual_train import individual_train
from utils.concurrency import multithreads
from models.models import resnet18, CNN, CNN_FEMNIST, RNN_Shakespeare, RNN_StackOverflow
from utils.print import print_acc, round_list
from utils.save import save_acc_loss
from utils.stat import mean_std


# Repository root, relative to the directory the script is launched from.
root = '..'

# Command-line interface, one row per option: (flag, type, default, extra kwargs).
# NOTE(review): in the code visible here --seed only selects the output
# directory name; no RNG is seeded with it -- confirm seeding happens elsewhere.
_cli_options = [
    ('--device', str, '1', {}),
    ('--data_dir', str, 'iid-4', {}),
    ('--dataset', str, 'MNIST', {}),
    ('--num_clients', int, 4, {}),
    ('--batch_size', int, 64, {}),
    ('--num_workers', int, 0, {'help': 'for data loader'}),
    ('--num_epochs', int, 10, {}),
    ('--num_local_epochs', int, 1, {}),
    ('--learning_rate', float, 0.1, {}),
    ('--reg_lambda', float, 0.1, {}),
    ('--save_epoch', int, 5, {}),
    ('--seed', int, 0, {}),
]
parser = argparse.ArgumentParser(description='training')
for _flag, _kind, _default, _extra in _cli_options:
    parser.add_argument(_flag, type=_kind, default=_default, **_extra)

args = parser.parse_args()

# Results go to <root>/results/<dataset>/<data_dir>/GiFair/seed_<seed>.
# This has to be derived before args.data_dir is rewritten into a full path.
output_dir = os.path.join(
    root, 'results', args.dataset, args.data_dir, 'GiFair', f'seed_{args.seed}')
args.data_dir = os.path.join(root, 'data', args.dataset, args.data_dir)
os.makedirs(output_dir, exist_ok=True)
print(args)
print('output_dir: ', output_dir)

# Keep a record of the exact configuration next to the results.
args_path = os.path.join(output_dir, 'args.json')
with open(args_path, 'w') as fp:
    json.dump(vars(args), fp)

# Restrict the visible GPUs to the requested one(s); the first visible GPU is
# then addressed as cuda:0. Fall back to the CPU when CUDA is unavailable.
os.environ['CUDA_VISIBLE_DEVICES'] = args.device
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Per-client training ("in") and test ("out") splits, pickled under args.data_dir.
# NOTE: pickle.load can execute arbitrary code embedded in the file; only point
# --data_dir at data you generated yourself or otherwise trust.
in_file = os.path.join(args.data_dir, 'in.pickle')
out_file = os.path.join(args.data_dir, 'out.pickle')

with open(in_file, 'rb') as f_in:
    in_data = pickle.load(f_in)
with open(out_file, 'rb') as f_out:
    out_data = pickle.load(f_out)

# Number of samples held by each client, for the train and test splits.
weights = np.array([len(in_data[i]) for i in range(args.num_clients)])
weights_test = np.array([len(out_data[i]) for i in range(args.num_clients)])

print('total train samples: {}'.format(np.sum(weights)))
print('total test samples: {}'.format(np.sum(weights_test)))
print('total samples: {}'.format(np.sum(weights)+np.sum(weights_test)))

print('samples: ', weights)
# Normalise the train counts into aggregation weights that sum to one.
weights = list(weights / np.sum(weights))

# The rank offsets used during aggregation range over +/-(num_clients - 1),
# so reg_lambda must not exceed min(weights) / (num_clients - 1) or the
# lowest-ranked client could receive a negative aggregation weight.
# With a single client there is no offset and therefore no bound.
rank_span = args.num_clients - 1
max_lambda = np.min(weights) / rank_span if rank_span else float('inf')
if args.reg_lambda > max_lambda:
    print('lambda (the reg weight) needs to be smaller! (smaller than {})'.format(max_lambda))
    # Exit with a non-zero status so wrapper scripts can detect the failure.
    sys.exit(1)

# data loaders
# Both splits use the same loader settings (including shuffling), so the
# keyword arguments are defined once and shared.
_loader_kwargs = dict(
    batch_size=args.batch_size,
    num_workers=args.num_workers,
    drop_last=False,
    pin_memory=True,
    shuffle=True,
)

# One loader per client over its training split.
train_loaders = [
    DataLoader(dataset=in_data[client], **_loader_kwargs)
    for client in range(args.num_clients)
]

# One loader per client over its test split.
test_loaders = [
    DataLoader(dataset=out_data[client], **_loader_kwargs)
    for client in range(args.num_clients)
]

# Model constructor for every supported --dataset value. Each client gets its
# own freshly constructed instance of the selected architecture.
_model_factories = {
    'MNIST': CNN,
    'CIFAR10': lambda: resnet18(num_classes=10),
    'CIFAR100': lambda: resnet18(num_classes=100),
    'CINIC10': lambda: resnet18(num_classes=10),
    'FEMNIST': CNN_FEMNIST,
    'Shakespeare': RNN_Shakespeare,
    'StackOverflow': RNN_StackOverflow,
}

# Fail immediately on an unknown dataset instead of leaving `models`
# undefined and crashing later with an unrelated NameError.
if args.dataset not in _model_factories:
    raise ValueError(
        'unsupported dataset {!r}; expected one of {}'.format(
            args.dataset, sorted(_model_factories)))

models = [_model_factories[args.dataset]() for _ in range(args.num_clients)]

# loss functions, optimizer
loss_func = nn.CrossEntropyLoss()
#loss_func = nn.MSELoss()
# One plain SGD optimizer (no momentum) per client model.
optimizers = [optim.SGD(model.parameters(), lr=args.learning_rate, momentum=0.0)
              for model in models]

# checkpoint
# Resume from the last saved checkpoint if one exists. The file is read once
# and the same state dict is copied into every client model. map_location
# keeps the load working on hosts without the GPU the checkpoint was saved
# from; load_state_dict copies the values into each model's own parameters.
model_path = os.path.join(output_dir, 'model_last.pth')
if os.path.exists(model_path):
    checkpoint = torch.load(model_path, map_location='cpu')
    start_epoch = checkpoint['epoch']
    for model in models:
        model.load_state_dict(checkpoint['state_dict'])
else:
    start_epoch = 0

# Per-epoch accuracy/loss log. It is truncated at the start of every run.
# NOTE(review): this also discards the earlier log when resuming from a
# checkpoint -- confirm that is intended.
json_file = os.path.join(output_dir, 'log.json')
with open(json_file, 'w') as f:
    f.write('')

mean_accs = []
# Evenly spaced rank offsets from -(n-1) to (n-1). They are symmetric around
# zero, so adding reg_lambda * offset leaves the total aggregation weight unchanged.
r = np.linspace(-(args.num_clients-1), (args.num_clients-1), num=args.num_clients)
GiFair_weights = np.zeros(args.num_clients)
local_losses = [0] * args.num_clients

# Global epochs are 1-indexed: a fresh run starts at 1 and a resumed run at
# start_epoch + 1. The upper bound is inclusive so that exactly
# args.num_epochs epochs are trained in total.
for t in range(start_epoch + 1, args.num_epochs + 1):
    # Rank clients by the local loss of the previous round: the client with
    # the smallest loss receives the most negative offset, the largest loss
    # the most positive one.
    # NOTE(review): in the first round all losses are the placeholder 0, so
    # the ranking follows argsort's tie order rather than any real loss.
    GiFair_weights[np.argsort(local_losses)] = r
    local_losses = []
    for i in range(args.num_clients):
        _, local_loss = individual_train(train_loaders[i], loss_func, optimizers[i], models[i], test_loaders[i],
                                         device=device, client_id=i, epochs=args.num_local_epochs,
                                         output_dir=output_dir, show=False, save=False)
        local_losses.append(local_loss)

    # Sample-proportional weight plus the rank-based adjustment.
    agg_weights = [weights[i] + (args.reg_lambda*GiFair_weights[i]) for i in range(args.num_clients)]
    aggregate(models, weights=agg_weights)

    # Evaluate every client model on its own test and train loaders.
    accs = accuracies(models, test_loaders, device)
    losses_ = losses(models, train_loaders, loss_func, device)
    print(f'global epoch: {t}')
    mean, std = mean_std(accs)
    mean_accs.append(mean)
    print(f'losses: {round_list(losses_)}')
    save_acc_loss(json_file, t, accs, losses_)

    # Periodic checkpoint; only the first client's model is stored
    # (presumably all models are identical after aggregation -- verify).
    if t % args.save_epoch == 0:
        torch.save({'epoch': t, 'state_dict': models[0].state_dict()},
                   os.path.join(output_dir, 'model_last.pth'))
    
# Final summary. `accs` only exists if the training loop ran at least once
# (mean_accs gets one entry per completed epoch), so guard on that instead of
# crashing with a NameError, e.g. when resuming an already finished run.
if mean_accs:
    mean, std = mean_std(accs)
    print('mean: ', mean, 'std: ', std)
    print(f'accs: {[round(i, 3) for i in accs]}')

    # Persist the per-epoch mean accuracies collected during this run.
    acc_file = os.path.join(output_dir, 'mean_acc.pkl')
    with open(acc_file, 'wb') as f_out:
        pickle.dump(mean_accs, f_out)
else:
    print('no global epochs were run; nothing to report')