# -*- coding: utf-8 -*-
import os, sys
import ray
import csv
import time
import torch
import random
import argparse
import numpy as np
import matplotlib.pyplot as plt
from utils.Model import get_model
from utils.Dataset import init_dataset
from utils.Data_setup import init_dataset_new
from utils.utils import matrix_normal

from methods.WPM_HN import Client, all_train, all_get_p, all_grad_step, all_test, all_HN_update, consensus_distance, all_aggregation, all_set_model

# Command-line interface of the training script.
parser = argparse.ArgumentParser()

# Method, model and dataset selection
parser.add_argument('--method', default='Megaminx', type=str)
parser.add_argument('--model', default='LR', choices=('LR','LeNet'))
parser.add_argument('--dataset', default='mnist',choices=('mnist','cifar10'))
parser.add_argument('--iid', default=0, type=int)
parser.add_argument('--agg_mode', default='HN', type=str, choices=('HN', 'Avg', 'DPSGD'))

# Workers, device and topology
parser.add_argument('--world-size', default=10, type=int, help='number of workers')
parser.add_argument('--device', default='cpu', choices=('cpu','gpu'))
parser.add_argument('--topology', default='random', choices=('random', 'reduce'))
# Both switches are tested for truthiness later on, so they must be parsed as
# integers: without `type=int`, passing `0` on the command line produced the
# non-empty (hence truthy) string '0' and silently selected the wrong branch.
# 0-> Split dataset by class, 1-> Split dataset by class and number
parser.add_argument('--partion-method', default=0, type=int,
                    help='0 -> split dataset by class, 1 -> split dataset by class and number')
# 0-> Use global testset, 1-> Use local testset
parser.add_argument('--test-method', default=0, type=int,
                    help='0 -> use global testset, 1 -> use local testset')

parser.add_argument('--p', default='15', type=str)
parser.add_argument('--data-overlap', type=float, default=0.1, choices=(0.1, 1, 10))
parser.add_argument('--alpha', type=float, default=0.5, choices=(0.5, 1, 10))

# Optimisation
parser.add_argument('--lr', default=0.01, type=float, help='learning rate')
parser.add_argument('--batch-size', default=128, type=int)

# Schedule
parser.add_argument('--change-interval', default=1, type=int, help='change frequency of topoloies')
parser.add_argument('--rounds', default=300, type=int)

parser.add_argument('--iter-method', default='iteration', choices=('epoch','iteration'))
parser.add_argument('--density', default=0.5, type=float, help='density of topologies')

# Hypernetwork (HN) hyper-parameters
parser.add_argument("--hn_lr", type=float, default=0.01)
parser.add_argument("--embedding_dim", type=int, default=100)
parser.add_argument("--hidden_dim", type=int, default=100)

# Locations and reproducibility
parser.add_argument('--stdout', default='stdout', help='stdout log dir for subprocess')
parser.add_argument('--data-dir', default='dataset', help='data store location')
parser.add_argument('--cache-dir', default='cache', help='cache location')
parser.add_argument('--seed', default=2022, type=int)

# NOTE: the option name '--deacy-step' (sic) is kept as-is for CLI compatibility.
parser.add_argument('--deacy-step', default=10, type=int, help='when to deacy beta')
parser.add_argument('--gamma', default=0.9, type=float, help='deacy factor of gain')
# Parse the command line once at import time; `args` and `path` are module-level
# values used by the rest of the script.
args = parser.parse_args()

# Output directory for the per-run CSV log.
path = f'{args.stdout}/{args.model}-{args.dataset}/partion_{args.partion_method}-test_{args.test_method}/'
# exist_ok=True avoids the check-then-create race when several runs that share
# the same output directory are started at the same time.
os.makedirs(path, exist_ok=True)

def set_random_seed(args):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        args: Namespace-like object; ``args.seed`` is the seed value and
            ``args.device`` selects whether the CUDA/cuDNN settings are
            applied as well (only when it equals ``'gpu'``).
    """
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.device == 'gpu':
        torch.cuda.manual_seed_all(args.seed)
        torch.backends.cudnn.deterministic = True
        # cuDNN benchmarking auto-tunes and may select different convolution
        # algorithms from run to run, which defeats the deterministic setting
        # requested above; it has to be disabled for reproducible results.
        torch.backends.cudnn.benchmark = False


if __name__ == '__main__':
    # Script entry point: builds the model, datasets and Ray clients, then runs
    # `args.rounds` training rounds, logging accuracy and loss per round to CSV.

    # The iteration mode is derived from the model and overrides --iter-method.
    args.iter_method = 'epoch' if args.model == 'LeNet' else 'iteration'

    print(f'method: {args.method} | P:{args.p}')
    print(f'Model: {args.model} | Datasets: {args.dataset} | iid: {args.iid} | Round: {args.rounds}')
    print(f'World Size: {args.world_size} | Density: {args.density} | lr: {args.lr}')

    # Random Seeds
    set_random_seed(args)

    # Init Log
    # The context manager guarantees the log file is closed even if a round fails.
    log_path = os.path.join(path, f'{args.agg_mode}_{args.model}-{args.dataset}_{args.p}.csv')
    with open(log_path, "w", newline='') as file:
        writer = csv.writer(file)
        # NOTE(review): the header announces a 'Consensus Dis' column, but the
        # per-round rows below only carry three values (consensus_distance is
        # imported but never called here) - confirm whether it should be logged.
        writer.writerow(['Round', 'Acc', 'Loss', 'Consensus Dis'])
        file.flush()

        # Init Ray
        ray.init(ignore_reinit_error=True)
        print('==> Ray successfully created ==>')

        # Init Model
        model = get_model(args)
        layer_names = list(model.get_weights().keys())
        print('== Model Initialization Done ==')

        # Init Datasets
        if args.partion_method:
            train_loader_list, test_loader_list, global_test_loader = init_dataset_new(args)
        else:
            train_loader_list, test_loader = init_dataset(args)
            test_loader_list = [test_loader for _ in range(args.world_size)]
            global_test_loader = test_loader

        # Time-varying topologies
        topology_list = np.load(f'topology/topology_list_num{args.world_size}_den{args.density}.npy')
        weight_matrix = topology_list[0]

        # Init Clients
        # Each client evaluates on its local test loader when --test-method is
        # set, otherwise on the shared global test loader.
        clients = [
            Client.remote(
                i, args, model, train_loader_list[i],
                test_loader_list[i] if args.test_method else global_test_loader,
            )
            for i in range(args.world_size)
        ]
        print('==> Create Clients and Allocate Dataset ==>\n')

        begin_time = time.time()
        for current_round in range(args.rounds):

            # Time-varying topology
            if current_round % args.change_interval == 0:
                weight_matrix = matrix_normal(topology_list[current_round])

            # Get p from HN
            p_loc_list = all_get_p(clients)

            # Local Update And Mirror Map
            all_models, all_gradients = all_train(clients, p_loc_list)

            # Aggregation
            all_agged_list, _ = all_aggregation(clients, all_models, weight_matrix)

            # Mirror Gradient Descent And Inverse Map
            all_grad_step(layer_names, clients, all_agged_list, all_gradients, p_loc_list)

            # Update HN
            # The previous test compared agg_mode against 'Megaminx', which is
            # not among the values --agg_mode accepts ('HN', 'Avg', 'DPSGD'),
            # so this update could never run.
            # NOTE(review): 'HN' is assumed to be the intended mode - confirm.
            if args.agg_mode == 'HN':
                all_HN_update(clients)

            # Test
            acc, loss = all_test(clients, weight_matrix)
            print(f'Round[{current_round}]: Acc={acc}, Loss={loss}')

            # Log
            writer.writerow([current_round, acc, loss])
            file.flush()

        end_time = time.time()

    # Fixed-point formatting: a bare '.2' precision means two significant
    # digits and renders long runs in scientific notation (e.g. '1.2e+02s').
    print(f'total_time: {end_time - begin_time:.2f}s')