# training
import sys
sys.path.append("..")

import argparse
import os
import torch
from copy import deepcopy
import numpy as np
from torch import optim, nn
from torch.utils.data import DataLoader
import sys
import pickle
import threading
from tqdm import tqdm
import json
from utils.io import Tee, to_csv
from utils.eval import accuracy, accuracies, losses
from utils.aggregate import aggregate, aggregate_lr, zero_model, aggregate_momentum
from algs.individual_train import individual_train
from utils.concurrency import multithreads
from models.models import resnet18, CNN, CNN_FEMNIST, RNN_Shakespeare, RNN_StackOverflow
from utils.print import print_acc, round_list
from utils.save import save_acc_loss
from utils.stat import mean_std
from scipy.stats import gmean

# Repo root -- change this to the repo folder. expanduser() is required because
# os.path.join/os.makedirs treat '~' literally and would otherwise create a
# './~/fairfl' tree under the current working directory.
root = os.path.expanduser('~/fairfl')

parser = argparse.ArgumentParser(description='training')
# value written to CUDA_VISIBLE_DEVICES below (e.g. '0' or '0,1')
parser.add_argument('--device', type=str, default='5')
# split name; data is read from <root>/data/<dataset>/<data_dir>
parser.add_argument('--data_dir', type=str, default='iid-12')
parser.add_argument('--dataset', type=str, default="MNIST")
parser.add_argument('--num_clients', type=int, default=12)
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--num_workers', type=int, default=0, help='for data loader')
# global epochs; the server aggregates client models once per global epoch
parser.add_argument('--num_epochs', type=int, default=10)
# local epochs each client trains per global epoch
parser.add_argument('--num_local_epochs', type=int, default=1)
parser.add_argument('--learning_rate', type=float, default=0.01)
# base B of the PropFair loss -log(1 - CE/B)
parser.add_argument('--base', type=float, default=5.0)
# checkpoint every `save_epoch` global epochs (and after the last one)
parser.add_argument('--save_epoch', type=int, default=10)
# NOTE(review): --seed and --trial are only used to name output paths; this
# script never seeds torch/numpy, so runs are not reproducible -- confirm intent.
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--trial', type=int, default=0)

args = parser.parse_args()
if args.save_epoch < 1:
    # save_epoch is used as a modulus in the training loop; 0 would crash there
    parser.error('--save_epoch must be >= 1')
output_dir = os.path.join(root, 'results', args.dataset, args.data_dir, 'Prop_Fair', f'base_{args.base}', f'seed_{args.seed}', f'trial_{args.trial}')
args.data_dir = os.path.join(root, 'data', args.dataset, args.data_dir)
os.makedirs(output_dir, exist_ok=True)
print(args)
print('output_dir: ', output_dir)

# Mask the visible GPUs to the requested one(s). This must happen before CUDA
# is initialised; the first visible GPU is then addressed as cuda:0.
os.environ['CUDA_VISIBLE_DEVICES'] = args.device
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# Persist the run configuration (with data_dir already expanded) next to the results.
args_json_path = os.path.join(output_dir, 'args.json')
with open(args_json_path, 'w') as fp:
    json.dump(vars(args), fp)

def _load_pickle(path):
    """Return the object stored in the pickle file at *path*."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Per-client train ("in") and test ("out") splits, indexed by client id
# 0..num_clients-1. NOTE: pickle can execute arbitrary code on load -- only
# use files produced by your own data-preparation step.
in_file = os.path.join(args.data_dir, 'in.pickle')
out_file = os.path.join(args.data_dir, 'out.pickle')
in_data = _load_pickle(in_file)
out_data = _load_pickle(out_file)

# Number of samples held by each client.
train_counts = np.array([len(in_data[c]) for c in range(args.num_clients)])
weights_test = np.array([len(out_data[c]) for c in range(args.num_clients)])

n_train = np.sum(train_counts)
n_test = np.sum(weights_test)
print('total train samples: {}'.format(n_train))
print('total test samples: {}'.format(n_test))
print('total samples: {}'.format(n_train + n_test))
print('samples: ', train_counts)

# Aggregation weights: each client's fraction of all training samples.
weights = list(train_counts / n_train)
    
# data loaders
def _client_loaders(datasets):
    """Build one shuffled DataLoader per client from an indexable of datasets."""
    loaders = []
    for client in range(args.num_clients):
        loaders.append(DataLoader(
            dataset=datasets[client],
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            drop_last=False,
            pin_memory=True,
            shuffle=True,
        ))
    return loaders


# NOTE(review): test loaders are shuffled too; harmless for accuracy but not needed.
train_loaders = _client_loaders(in_data)
test_loaders = _client_loaders(out_data)


# One freshly initialised model per client; the architecture depends on the dataset.
_MODEL_FACTORIES = {
    'MNIST': CNN,
    'CIFAR10': lambda: resnet18(num_classes=10),
    'CIFAR100': lambda: resnet18(num_classes=100),
    'CINIC10': lambda: resnet18(num_classes=10),
    'FEMNIST': CNN_FEMNIST,
    'Shakespeare': RNN_Shakespeare,
    'StackOverflow': RNN_StackOverflow,
}
if args.dataset not in _MODEL_FACTORIES:
    # Previously an unknown dataset left `models` undefined and failed later
    # with a NameError; fail here with an actionable message instead.
    raise ValueError(
        f"unsupported dataset {args.dataset!r}; expected one of {sorted(_MODEL_FACTORIES)}"
    )
models = [_MODEL_FACTORIES[args.dataset]() for _ in range(args.num_clients)]
    
    
# Cross-entropy (also used for the reported training losses), the per-epoch
# JSON log path, and one plain SGD optimizer (no momentum) per client model.
loss_func = nn.CrossEntropyLoss()
json_file = os.path.join(output_dir, 'log.json')

optimizers = []
for client_model in models:
    optimizers.append(
        optim.SGD(client_model.parameters(), lr=args.learning_rate, momentum=0.0)
    )

# checkpoint: resume from the last saved global epoch if a checkpoint exists.
model_path = os.path.join(output_dir, 'model_last.pth')
if os.path.exists(model_path):
    # Load once (not once per client) and onto the CPU, so a checkpoint saved
    # from a GPU can also be resumed on a CPU-only machine; load_state_dict
    # then copies the weights into each client's model.
    checkpoint = torch.load(model_path, map_location='cpu')
    start_epoch = checkpoint['epoch']
    for model in models:
        model.load_state_dict(checkpoint['state_dict'])
else:
    start_epoch = -1
    with open(json_file, 'w') as f:  # clear up the json file, start from the beginning
        f.write('')

# `>=` rather than `==`: if the checkpoint is already past num_epochs (e.g. a
# rerun with a smaller --num_epochs) the loop would not run and the final
# report would fail on an undefined `accs`.
if start_epoch + 1 >= args.num_epochs:
    print('already finished')
    sys.exit()

# Base B of the PropFair loss -log(1 - CE/B), built once on the training
# device instead of re-creating and copying a tensor on every batch.
base = torch.tensor(args.base, device=device)


def log_loss(output, target, base=base):
    """PropFair surrogate loss for one batch.

    Computes the cross-entropy ``ce`` via ``loss_func`` and returns
    ``-log(1 - ce / base)``. When ``ce`` comes within 0.2 of ``base`` the log
    term blows up (and is undefined once ce >= base), so such badly performing
    batches fall back to the linear ``ce / base`` to avoid divergence.

    Args:
        output: model outputs, passed straight to ``loss_func``.
        target: targets, passed straight to ``loss_func``.
        base: the base B; defaults to ``args.base`` on ``device``. A plain
            number is also accepted.

    Returns:
        A scalar loss tensor.
    """
    ce_loss = loss_func(output, target)
    # No-op (no copy) for the default tensor already on `device`.
    base = torch.as_tensor(base, device=device)
    if base - ce_loss < 0.2:
        return ce_loss / base
    return -torch.log(1 - ce_loss / base)


# Mean test accuracy across clients for each global epoch run in THIS process.
# NOTE(review): on resume, epochs finished before the restart are not included.
mean_accs = []
for t in range(start_epoch + 1, args.num_epochs):
    # Local training: every client trains its own model for num_local_epochs.
    for i, (train_loader, optimizer, client_model, test_loader) in enumerate(
            zip(train_loaders, optimizers, models, test_loaders)):
        individual_train(train_loader, log_loss, optimizer, client_model, test_loader,
                         device=device, client_id=i, epochs=args.num_local_epochs,
                         output_dir=output_dir, show=False, save=False)

    # Server step, weighted by each client's share of training samples.
    # Presumably writes the aggregate back into every model -- confirm in
    # utils.aggregate (only models[0] is checkpointed below).
    aggregate(models, weights=weights)
    if t % args.save_epoch == 0 or t == args.num_epochs - 1:
        torch.save({'epoch': t, 'state_dict': models[0].state_dict()},
                   os.path.join(output_dir, 'model_last.pth'))

    # Evaluation: per-client test accuracy and (cross-entropy) training loss.
    accs = accuracies(models, test_loaders, device)
    losses_ = losses(models, train_loaders, loss_func, device)
    print(f'global epoch: {t}')
    mean, _ = mean_std(accs)
    mean_accs.append(mean)
    print(f'losses: {round_list(losses_)}')
    save_acc_loss(json_file, t, accs, losses_)
     
# Summary of the last global epoch's per-client test accuracies.
mean, std = mean_std(accs)
print('mean: ', mean, 'std: ', std)
rounded_accs = [round(acc, 3) for acc in accs]
print(f'accs: {rounded_accs}')

# Per-epoch mean-accuracy trace, written relative to the current working
# directory (not output_dir).
acc_file = "propfair_{}_seed{}_trial{}_mean_acc.pkl".format(args.dataset, args.seed, args.trial)
with open(acc_file, 'wb') as handle:
    pickle.dump(mean_accs, handle)