import time
import torch
from datasets import Data
from nodes import Node
from args import args_parser
from utils import *
import numpy as np
from tensorboardX import SummaryWriter
import os
import torch.nn as nn
import copy
import torch.optim as optim
import torch.nn.functional as F
import math
from pyhessian import hessian
from server_funct import *
import wandb
from client_funct import *


if __name__ == '__main__':
    # Script entry point: runs args.T federated rounds. Each round trains and
    # validates the clients, optimizes the aggregation weights on the server,
    # builds the global model, evaluates it, and logs metrics to wandb.
    # A recorder dict with the per-round history is saved at the end.

    args = args_parser()
    setup_seed(args.random_seed)

    setting_name = args.exp_name
    save_path_root = './'
    output_path = './'
    wandb.init(
        config=args,
        project='FedAWO',
        name=setting_name,
        notes=args.exp_name,
    )

    # NOTE(review): CUDA_VISIBLE_DEVICES only takes effect if it is set before
    # CUDA is initialized in this process -- confirm that neither setup_seed
    # nor wandb.init initializes CUDA.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.device

    data = Data(args)

    # Per-client size taken as len() of each client's train loader.
    # NOTE(review): if these are torch DataLoaders, len() is the number of
    # batches rather than the number of samples -- TODO confirm intent.
    sample_size = [len(data.train_loader[i]) for i in range(args.node_num)]
    # Hoist the total so it is computed once instead of once per client.
    total_size = sum(sample_size)
    size_weights = [size / total_size for size in sample_size]
    print('size-based weights',size_weights)

    # Initialize the central node; a node id of -1 stands for the central node.
    central_node = Node(-1, data.test_loader[0], data.test_set, args)

    # Initialize the client nodes, keyed by their integer id.
    client_nodes = {}
    for i in range(args.node_num):
        client_nodes[i] = Node(i, data.train_loader[i], data.train_set, args)

    final_test_acc_recorder = RunningAverage()
    weight_recorder = []
    test_acc_recorder = []
    # Created before the loop so each round's selection is actually recorded;
    # previously this list was created after the loop and always saved empty.
    select_list_recorder = []

    for rounds in range(args.T):
        print('===============Stage 1 The {:d}-th round==============='.format(rounds + 1))
        print(setting_name)
        lr_scheduler(rounds, client_nodes, args)

        client_nodes, train_loss = Client_update(args, client_nodes, central_node)
        avg_client_acc, select_list = Client_validate(args, client_nodes)
        # Store a shallow copy so later in-place changes to select_list (if
        # any) cannot alter the recorded history.
        select_list_recorder.append(copy.copy(select_list))
        print('fedawo, averaged clients acc is ', avg_client_acc)

        agg_weights, client_params = receive_client_models(args, client_nodes, select_list, size_weights)

        gamma, optimized_weights = FedAWO_optimization(args, agg_weights, client_params, central_node)

        # Optionally combine FedAWO with another server-side method, chosen by
        # substring match on args.server_method.
        if 'feddf' in args.server_method:
            central_node = fedawo_generate_global_model(gamma, optimized_weights, client_params, central_node)
            central_node = feddf(args, central_node, client_nodes, select_list)
        elif 'feddyn' in args.server_method:
            # feddyn receives the optimized weights scaled by gamma.
            agg_weights = [gamma * w for w in optimized_weights]
            central_node = feddyn(args, central_node, agg_weights, client_nodes, select_list)
        else:
            central_node = fedawo_generate_global_model(gamma, optimized_weights, client_params, central_node)

        acc = validate(args, central_node, which_dataset = 'local')
        print('gamma ', gamma)
        print('optmized_weights', optimized_weights)
        print('fedawo, global model test acc is ', acc)
        test_acc_recorder.append(acc)
        weight_recorder.append((gamma, optimized_weights))

        # Only the last 10 rounds contribute to the final averaged accuracy.
        if rounds >= args.T - 10:
            final_test_acc_recorder.update(acc)

        # Logging is best-effort: a wandb failure must not abort training, but
        # it is reported instead of being silently swallowed.
        try:
            wandb.log({'trainloss': train_loss}, step = rounds)
            wandb.log({'testacc': acc}, step = rounds)
            wandb.log({'gamma': gamma}, step = rounds)
        except Exception as exc:
            print('warning: wandb logging failed at round {:d}: {}'.format(rounds, exc))

    try:
        wandb.log({'final_testacc': final_test_acc_recorder.value()})
    except Exception as exc:
        print('warning: wandb logging of final_testacc failed: {}'.format(exc))

    # Persist the run history next to the script under <exp_name>_recorder.pth.
    recorder = {
        'data_proportion': data.proportion,
        'size_weights': size_weights,
        'select_list_recorder': select_list_recorder,
        'weight_recorder': weight_recorder,
        'test_acc_recorder': test_acc_recorder,
    }
    torch.save(recorder, os.path.join(save_path_root, output_path, setting_name+'_recorder.pth'))