
import time
import torch
from datasets import Data
from nodes import Node
from args import args_parser
from utils import *
import numpy as np
from tensorboardX import SummaryWriter
import os
import torch.nn as nn
import copy
import torch.optim as optim
import torch.nn.functional as F
import math
from pyhessian import hessian
from server_funct import *
import wandb
from client_funct import *

def _compute_size_weights(train_loaders, node_num):
    """Return per-client aggregation weights proportional to ``len(train_loaders[i])``.

    Args:
        train_loaders: indexable collection of per-client training loaders.
        node_num: number of clients; indices ``0 .. node_num - 1`` are used.

    Returns:
        List of ``node_num`` floats that sum to 1.

    Raises:
        ValueError: if every client loader has length 0 (weights are undefined).
    """
    # NOTE(review): if these are torch DataLoaders, len() is the number of
    # batches, not samples; if true sample counts are intended, use
    # len(loader.dataset) instead — confirm against Data.
    sample_size = [len(train_loaders[i]) for i in range(node_num)]
    total = sum(sample_size)  # computed once instead of once per client
    if total == 0:
        raise ValueError('all client train loaders are empty; cannot compute size-based weights')
    return [size / total for size in sample_size]


def _wandb_log_best_effort(payload, step=None):
    """Log ``payload`` to wandb, reporting (not raising) any logging failure."""
    try:
        if step is None:
            wandb.log(payload)
        else:
            wandb.log(payload, step=step)
    except Exception as err:  # best-effort: a logging failure must not abort training
        print('wandb logging failed: {}'.format(err))


def main():
    """Run ``args.T`` rounds of client update + server aggregation and save test accuracy.

    Each round: adjust client learning rates, train and validate clients,
    aggregate into the central node, evaluate it, and log to wandb. The list of
    per-round global test accuracies is saved with ``torch.save`` at the end.
    """
    args = args_parser()

    # Restrict visible GPUs before seeding / wandb / data loading get a chance
    # to initialize CUDA; the variable has no effect once CUDA is initialized.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.device
    setup_seed(args.random_seed)

    setting_name = args.exp_name
    save_path_root = './'
    output_path = './'

    wandb.init(
        config=args,
        project='FedAWO',
        name=setting_name,
        # wandb.init expects a sequence of tag strings, not a bare string.
        tags=[setting_name],
    )

    data = Data(args)

    size_weights = _compute_size_weights(data.train_loader, args.node_num)
    print('size-based weights', size_weights)

    # num_id == -1 stands for the central node; it is built on test_loader[0].
    central_node = Node(-1, data.test_loader[0], data.test_set, args)

    # One client node per training split, keyed by client index.
    client_nodes = {
        i: Node(i, data.train_loader[i], data.train_set, args)
        for i in range(args.node_num)
    }

    # Only the accuracies of the last 10 rounds are fed into this recorder.
    final_test_acc_recorder = RunningAverage()
    test_acc_recorder = []
    for round_idx in range(args.T):
        print('===============Stage 1 The {:d}-th round==============='.format(round_idx + 1))
        print(setting_name)
        lr_scheduler(round_idx, client_nodes, args)

        client_nodes, train_loss = Client_update(args, client_nodes, central_node)
        avg_client_acc, select_list = Client_validate(args, client_nodes)
        print(args.server_method + args.client_method + ', averaged clients acc is ', avg_client_acc)

        central_node = Server_update(args, central_node, client_nodes, select_list, size_weights)
        acc = validate(args, central_node, which_dataset='local')
        print(args.server_method + args.client_method + ', global model test acc is ', acc)
        test_acc_recorder.append(acc)

        if round_idx >= args.T - 10:
            final_test_acc_recorder.update(acc)

        _wandb_log_best_effort({'trainloss': train_loss, 'testacc': acc}, step=round_idx)

    # value() is evaluated inside the try so that an empty recorder (e.g. T == 0)
    # is reported rather than crashing, matching the original best-effort intent.
    try:
        wandb.log({'final_testacc': final_test_acc_recorder.value()})
    except Exception as err:
        print('wandb logging failed: {}'.format(err))

    torch.save(test_acc_recorder, os.path.join(save_path_root, output_path, setting_name + '_recorder.pth'))


if __name__ == '__main__':
    main()
