import os
from parser_1 import Parser
from datetime import datetime

from misc.utils import *
from modules.multiprocs import ParentProcess
from data.loader import get_data
import global_var as gvr

import emoji
import wandb
from pathlib import Path


import os

def main(args):
    """Resolve the requested federated model and launch the training processes.

    Args:
        args: Parsed command-line namespace. ``args.model`` names the
            ``models.<model>`` package whose ``server.Server`` and
            ``client.Client`` classes are handed to ``ParentProcess``.

    Raises:
        SystemExit: If ``args.model`` is not one of the supported model names.
            The message goes to stderr and the process exits with a non-zero
            status.
    """
    import importlib

    supported_models = ('fedavg', 'lochyp', 'loceuc', 'flatland')

    # Reject a bad model name before set_config() loads every client partition.
    # SystemExit (unlike os._exit) flushes buffered output, runs exit handlers
    # and reports failure to the calling shell.
    if args.model not in supported_models:
        raise SystemExit('incorrect model was given: {}'.format(args.model))

    args = set_config(args)

    # Every supported model package exposes the same server/client layout.
    server_module = importlib.import_module(f'models.{args.model}.server')
    client_module = importlib.import_module(f'models.{args.model}.client')

    pp = ParentProcess(args, server_module.Server, client_module.Client)
    pp.start()

def set_wandb(args):
    """Start a Weights & Biases run that tracks this script.

    The run is logged to the "FlatLand" project, grouped by ``args.mode``, and
    its files are written under
    ``../results/FedHyp/<mode>/<dataset>_<n_clients>`` (created if missing).

    Args:
        args: Parsed command-line namespace; the hyperparameters listed in
            ``config`` below are read from it. ``args`` is not modified.
    """
    # This function is called before set_config(), which is where a missing
    # n_clients is replaced by 10. Resolve the same default here so the run
    # name, directory and logged config describe the run that actually happens.
    n_clients = 10 if args.n_clients is None else args.n_clients

    run_dir = Path("../results") / "FedHyp" / args.mode / f'{args.dataset}_{n_clients}'
    # exist_ok avoids a failure when concurrent runs create the same directory.
    run_dir.mkdir(parents=True, exist_ok=True)

    wandb.init(
        # wandb project where this run will be logged
        project="FlatLand",
        # hyperparameters and run metadata to track
        config={
            'dataset': args.dataset,
            'mode': args.mode,
            'model': args.model,
            'n-clients': n_clients,
            'n-rnds': args.n_rnds,
            'n-eps': args.n_eps,
            'n_dims': args.n_dims,
            'lr': args.lr,
            'optimizer': args.optimizer,
            'classifier': args.classifier,
            'rescale': args.rescale,
            'loc-l2': args.loc_l2,
        },
        name=f"{args.dataset}_c{n_clients}_{args.model}_seed{args.seed}",
        group=args.mode,
        dir=str(run_dir),
    )

    

# Hard-coded per-client curvature averages used instead of recomputing them
# when the dataset is 'ogbn-arxiv' and there are exactly 10 clients.
_OGBN_ARXIV_10_CLIENT_FRC = [
    123.86939696871977, 121.95510033219003, 345.26774114043275,
    546.5181477213508, 202.7338562166614, 46.012028283759896,
    71.34583824685535, 15.774210931485758, 389.92735010201284, 36.454768039174894,
]


def _summary_path(args, is_hyp):
    """Build the summary CSV path for this run inside ``args.bs_dir``."""
    fname_tag = '' if args.fname is None else f'_{args.fname}'
    if is_hyp:
        # Only the HYP_METHODS models read loc_l2 and bc.
        suffix = f'_{args.loc_l2}_{args.lr}_bc{args.bc}'
    else:
        suffix = f'_lr{args.lr}'
    return (
        f'{args.bs_dir}/{args.dataset}_{args.n_clients}_{args.model}_{args.n_eps}'
        f'_summary{fname_tag}{suffix}.csv'
    )


def _mean_forman_curvature(partition, client_id):
    """Return the mean Forman-Ricci curvature over the edges of one client graph."""
    # Imported lazily so the package is only required when curvature is computed.
    from GraphRicciCurvature.FormanRicci import FormanRicci

    graph = convert_to_networkx(partition)
    edges = graph.edges()

    frc = FormanRicci(graph)
    frc.compute_ricci_curvature()

    curvatures = [frc.G[u][v]["formanCurvature"] for u, v in edges]
    if not curvatures:
        raise ValueError(
            f'client_{client_id} partition has no edges; '
            'cannot average its Forman-Ricci curvature'
        )
    return sum(curvatures) / len(curvatures)


def set_config(args):
    """Fill ``args`` with training defaults, output paths and dataset statistics.

    Besides mutating ``args`` this prints a banner and a feature/class summary,
    creates the CSV and summary directories when ``args.csv`` is truthy, and
    loads every client partition through ``get_data``.

    Attributes written: ``base_lr``, ``min_lr``, ``momentum_opt``,
    ``weight_decay``, ``warmup_epochs``, ``base_momentum``, ``final_momentum``,
    ``grad_k``, ``n_clients`` (10 when missing), ``data_path``, ``base_path``,
    ``checkpt_path``, ``log_path``, ``n_feat``, ``n_clss``; with ``args.csv``
    also ``csv_path`` and ``bs_dir`` (plus ``summary_path`` when
    ``args.summary``); for models in ``gvr.HYP_METHODS`` also
    ``norm_frc_list``.

    Args:
        args: Parsed command-line namespace.

    Returns:
        The same ``args`` object, updated in place.

    Raises:
        ValueError: If ``args.n_clients`` is below 1, or a client graph has no
            edges when its curvature has to be computed.
    """
    if args.n_clients is None:
        args.n_clients = 10
    # Without at least one client there is no partition to read the feature
    # and class counts from, so fail early with a clear message.
    if args.n_clients < 1:
        raise ValueError(f'n_clients must be at least 1, got {args.n_clients}')

    print("============", emoji.emojize(":cherry_blossom: :cherry_blossom:"), args.dataset, emoji.emojize(":cherry_blossom: :cherry_blossom:"), "=============")

    # The command-line learning rate, when given, replaces the 0.01 default.
    args.base_lr = 0.01 if args.lr is None else args.lr
    args.min_lr = 1e-3
    args.momentum_opt = 0.9
    args.weight_decay = 1e-6
    args.warmup_epochs = 10
    args.base_momentum = 0.99
    args.final_momentum = 1.0

    args.grad_k = True

    now = datetime.now().strftime("%Y%m%d_%H%M%S")
    trial = f'{args.dataset}_{args.mode}/clients_{args.n_clients}/{now}_{args.model}'

    args.data_path = os.path.join(os.getcwd(), 'datasets')

    args.base_path = os.path.join(os.getcwd(), 'logs')
    args.checkpt_path = f'{args.base_path}/checkpoints/{trial}'
    args.log_path = f'{args.base_path}/logs/{trial}'

    is_hyp = args.model in gvr.HYP_METHODS

    if args.csv:
        # exist_ok keeps concurrent runs from failing on the same directory.
        os.makedirs(
            f'{args.base_path}/csvs/{args.dataset}_{args.mode}/clients_{args.n_clients}/',
            exist_ok=True,
        )
        args.csv_path = f'{args.base_path}/csvs/{trial}.csv'

        if args.fname is None:
            summary_root = f'{args.base_path}/summary'
        else:
            summary_root = f'{args.base_path}/summary/{args.fname}'
        args.bs_dir = f'{summary_root}/{args.dataset}/{args.n_clients}'
        os.makedirs(args.bs_dir, exist_ok=True)

        # NOTE(review): summary_path is only set when args.csv is also truthy;
        # confirm that coupling is intended.
        if args.summary:
            args.summary_path = _summary_path(args, is_hyp)

    # Exact comparison kept on purpose: the parser's type for --debug is not
    # visible here, so plain truthiness could change behavior.
    if args.debug == True:  # noqa: E712
        args.checkpt_path = f'{args.base_path}/debug/checkpoints/{trial}'
        args.log_path = f'{args.base_path}/debug/logs/{trial}'

    use_cached_frc = is_hyp and args.dataset == 'ogbn-arxiv' and args.n_clients == 10
    frc_curs_avg = list(_OGBN_ARXIV_10_CLIENT_FRC) if use_cached_frc else []

    labels = set()  # distinct labels over all clients, used for n_clss
    for i in range(args.n_clients):
        # The first element of the partition exposes the .x and .y tensors.
        partition = get_data(args, client_id=i)
        labels.update(partition[0].y.data.numpy())

        if is_hyp and not use_cached_frc:
            # compute the Forman-Ricci curvature of this client's graph
            frc_curs_avg.append(_mean_forman_curvature(partition, i))
            print(emoji.emojize(":seedling:"),  f"Forman Average for client_{i}:", frc_curs_avg[i])

    if is_hyp:
        import numpy as np
        from misc.utils_hyp import norm_normalization

        frc_list = abs(np.array(frc_curs_avg))
        # NOTE(review): the 0.7 offset is not explained anywhere in this file.
        norm_frc_list = norm_normalization(frc_list) + 0.7
        args.norm_frc_list = norm_frc_list.tolist()

    # The feature count is read from the last loaded partition.
    args.n_feat = len(partition[0].x[0])
    args.n_clss = len(labels)
    print(emoji.emojize(':herb: :herb:'), f' #Features: {args.n_feat}, #Classes: {args.n_clss}')

    return args

if __name__ == '__main__':
    # Script entry point: parse the command line, optionally open a
    # Weights & Biases run, train, then close the run.
    args = Parser().parse()
    # NOTE: set_wandb() runs before main() -> set_config(), so it sees args
    # exactly as parsed, without the values set_config() fills in later.
    if args.wandb:
        set_wandb(args)
        
    main(args)
        
    # [optional] finish the wandb run, necessary in notebooks
    if args.wandb:
        wandb.finish()










