# training
import sys
sys.path.append("..")

import argparse
import os
import torch
from copy import deepcopy
import numpy as np
from torch import optim, nn
from torch.utils.data import DataLoader
import pickle
import threading
from tqdm import tqdm
import json
from utils.io import Tee, to_csv
from utils.eval import accuracy, accuracies, losses
from utils.aggregate import aggregate, aggregate_lr, zero_model, aggregate_momentum, average_update, average_loss, global_delta, assign_models, global_update, global_delta_semi
from algs.individual_train_varred import individual_train
from utils.concurrency import multithreads
from models.models import resnet18, CNN, CNN_FEMNIST, RNN_Shakespeare, RNN_StackOverflow
from utils.print import print_acc, round_list
from utils.save import save_acc_loss
from utils.stat import mean_std


# Paths are built relative to the current working directory; presumably the
# script is run from its own directory (cf. sys.path.append("..") above).
root = '..'

# Results sub-directory used by each variance-reduction variant (--version).
VERSION_DIRS = {'fullvar': 'VarRed', 'semivar': 'SemiVarRed'}

parser = argparse.ArgumentParser(description='training')
parser.add_argument('--device', type=str, default='1')
parser.add_argument('--data_dir', type=str, default='iid-4')
parser.add_argument('--dataset', type=str, default='MNIST')
parser.add_argument('--num_clients', type=int, default=4)
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--num_workers', type=int, default=0, help='for data loader')
parser.add_argument('--num_epochs', type=int, default=10)
parser.add_argument('--num_local_epochs', type=int, default=1)
parser.add_argument('--learning_rate', type=float, default=0.1)
parser.add_argument('--save_epoch', type=int, default=5)
parser.add_argument('--beta', type=float, default=0.1, help='the regularization weight for var reduction')
parser.add_argument('--seed', type=int, default=0)
# `choices` rejects unknown versions at parse time; previously they left
# output_dir undefined and the script died later with a NameError.
parser.add_argument('--version', type=str, default='fullvar', choices=list(VERSION_DIRS),
                    help='version of the variance reduction algorithm: fullvar/semivar ')

args = parser.parse_args()
output_dir = os.path.join(root, 'results', args.dataset, args.data_dir,
                          VERSION_DIRS[args.version], f'beta_{args.beta}', f'seed_{args.seed}')

# From here on args.data_dir is the full on-disk data directory; output_dir
# above deliberately uses the bare split name.
args.data_dir = os.path.join(root, 'data', args.dataset, args.data_dir)
os.makedirs(output_dir, exist_ok=True)
print(args)
print('output_dir: ', output_dir)

# Record the run configuration next to its results.
with open(os.path.join(output_dir, 'args.json'), 'w') as fp:
    json.dump(vars(args), fp)

# Expose only the requested GPU(s) to this process. CUDA honours this variable
# only if it is set before CUDA is first initialised.
os.environ['CUDA_VISIBLE_DEVICES'] = args.device
# After masking, the first visible GPU is index 0; otherwise run on CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Per-client datasets, indexed by client id 0..num_clients-1:
# in.pickle holds the training splits, out.pickle the test splits.
in_file = os.path.join(args.data_dir, 'in.pickle')
out_file = os.path.join(args.data_dir, 'out.pickle')

# NOTE(review): pickle.load can execute arbitrary code; these files are
# assumed to be locally generated and trusted.
with open(in_file, 'rb') as handle:
    in_data = pickle.load(handle)
with open(out_file, 'rb') as handle:
    out_data = pickle.load(handle)

client_ids = range(args.num_clients)
weights = np.array([len(in_data[c]) for c in client_ids])
weights_test = np.array([len(out_data[c]) for c in client_ids])

n_train = np.sum(weights)
n_test = np.sum(weights_test)
print(f'total train samples: {n_train}')
print(f'total test samples: {n_test}')
print(f'total samples: {n_train + n_test}')

print('samples: ', weights)
# Per-client aggregation weights, proportional to training-set size.
weights = list(weights / n_train)

# data loaders
train_loaders = [DataLoader(
    dataset=in_data[i],
    batch_size=args.batch_size,
    num_workers=args.num_workers, drop_last=False, pin_memory=True, shuffle=True)
    for i in range(args.num_clients)]

test_loaders = [DataLoader(
    dataset=out_data[i],
    batch_size=args.batch_size,
    num_workers=args.num_workers, drop_last=False, pin_memory=True, shuffle=True)
    for i in range(args.num_clients)]

# Model factory per supported dataset; each call builds a freshly initialised
# model with the right number of output classes.
_MODEL_FACTORIES = {
    'MNIST': CNN,
    'CIFAR10': lambda: resnet18(num_classes=10),
    'CIFAR100': lambda: resnet18(num_classes=100),
    'CINIC10': lambda: resnet18(num_classes=10),
    'FEMNIST': CNN_FEMNIST,
    'Shakespeare': RNN_Shakespeare,
    'StackOverflow': RNN_StackOverflow,
}
# Fail loudly on an unsupported dataset instead of leaving `models` undefined.
if args.dataset not in _MODEL_FACTORIES:
    raise ValueError(f'unsupported dataset: {args.dataset!r}; '
                     f'expected one of {sorted(_MODEL_FACTORIES)}')
_make_model = _MODEL_FACTORIES[args.dataset]
# One local model per client (left where it was constructed), then the global
# model on `device`. The order is unchanged, so RNG consumption is identical.
models = [_make_model() for _ in range(args.num_clients)]
model_global = _make_model().to(device)

# Loss function shared by all clients, and one plain SGD optimizer (momentum
# disabled) per local model.
loss_func = nn.CrossEntropyLoss()
_sgd_kwargs = {'lr': args.learning_rate, 'momentum': 0.0}
optimizers = [optim.SGD(m.parameters(), **_sgd_kwargs) for m in models]

# checkpoint: resume from the last saved global round if one exists.
model_path = os.path.join(output_dir, 'model_last.pth')
if os.path.exists(model_path):
    # Read the file once, not once per client. Load to CPU so a GPU-saved
    # checkpoint also opens on a CPU-only machine; load_state_dict copies the
    # tensors onto each model's own device.
    checkpoint = torch.load(model_path, map_location='cpu')
    start_epoch = checkpoint['epoch']
    state = checkpoint['state_dict']
    # The checkpoint stores models[0] right after assign_models(models,
    # model_global), i.e. presumably the global weights. Restore the global
    # model too; otherwise resumed training would apply global updates to a
    # freshly initialised global model.
    model_global.load_state_dict(state)
    for model in models:
        model.load_state_dict(state)
else:
    start_epoch = 0

# Start an empty log for a fresh run. When resuming, keep the entries of
# earlier rounds ('a' only ensures the file exists). This assumes
# save_acc_loss appends per round — verify.
json_file = os.path.join(output_dir, 'log.json')
with open(json_file, 'w' if start_epoch == 0 else 'a') as f:
    f.write('')

# Mean test accuracy across clients for every global round run by this
# process (a resumed run records only the rounds it executes).
mean_accs = []

# Choose the global-step rule up front, so an unsupported version fails before
# any training instead of with a NameError after the first round.
_GLOBAL_STEP_FNS = {'fullvar': global_delta, 'semivar': global_delta_semi}
if args.version not in _GLOBAL_STEP_FNS:
    raise ValueError(f'unknown version: {args.version!r}; '
                     f'expected one of {sorted(_GLOBAL_STEP_FNS)}')
compute_global_step = _GLOBAL_STEP_FNS[args.version]

# Parameter names do not change during training; look them up once.
param_keys = model_global.state_dict().keys()

# NOTE(review): in round 1 each local model still has its own random init
# (the models are built independently); from round 2 on, every client starts
# from the weights assigned by assign_models. Consider syncing before round 1.

# Rounds are numbered 1..num_epochs, and a checkpoint saved at round t resumes
# at t + 1. The old upper bound `num_epochs` ran one round too few.
for t in range(start_epoch + 1, args.num_epochs + 1):
    # Local phase: each client trains and reports its update and its loss.
    local_deltas = []
    local_losses = []
    for i in range(args.num_clients):
        local_delta, local_loss = individual_train(
            train_loaders[i], loss_func, optimizers[i], models[i], test_loaders[i],
            device=device, client_id=i, epochs=args.num_local_epochs,
            output_dir=output_dir, show=False, save=False)
        local_deltas.append(local_delta)
        local_losses.append(local_loss)

    # Server phase: average updates and losses with the per-client weights,
    # form the variance-reduced global step, apply it, and broadcast.
    avg_delta = average_update(local_deltas, weights, param_keys)
    avg_loss = average_loss(local_losses, weights)
    global_step = compute_global_step(args.beta, local_deltas, local_losses,
                                      avg_delta, avg_loss, weights, param_keys)
    global_update(model_global, global_step, param_keys)
    assign_models(models, model_global)

    # Evaluation: per-client test accuracy and training loss.
    accs = accuracies(models, test_loaders, device)
    losses_ = losses(models, train_loaders, loss_func, device)
    print(f'global epoch: {t}')
    mean, _ = mean_std(accs)
    print('mean acc: {}'.format(mean))
    mean_accs.append(mean)
    print(f'losses: {round_list(losses_)}')
    save_acc_loss(json_file, t, accs, losses_)
    if t % args.save_epoch == 0:
        torch.save({'epoch': t, 'state_dict': models[0].state_dict()}, model_path)

if mean_accs:
    # Summary of the last evaluated round.
    mean, std = mean_std(accs)
    print('mean: ', mean, 'std: ', std)
    print(f'accs: {[round(i, 3) for i in accs]}')

    # Persist the per-round mean accuracies.
    acc_file = os.path.join(output_dir, 'mean_acc.pkl')
    with open(acc_file, 'wb') as f_out:
        pickle.dump(mean_accs, f_out)
else:
    # The checkpoint is already at or past num_epochs: nothing ran. Do not
    # overwrite an earlier mean_acc.pkl with an empty list.
    print('no global rounds were run')