import os
import argparse
import random
import copy

import torch
from pathlib import Path

from utilities import *
from strategies.selftrain import run_selftrain_GC
from strategies.fedavg import run_fedavg
from strategies.fedprox import run_fedprox
from strategies.GFCL import run_gcfl
from strategies.GFCLPlus import run_gcflplus
import warnings
import numpy as np
import pandas as pd
warnings.filterwarnings("ignore", category=UserWarning)


def process_selftrain(clients, server, local_epoch):
    """Run the self-training baseline and save its accuracies as a CSV file.

    ``run_selftrain_GC`` is expected to return a mapping; each key becomes a
    row of the result table and its value fills the ``train_acc``,
    ``val_acc`` and ``test_acc`` columns.

    Relies on the module-level names ``args``, ``outpath`` and ``suffix``
    to decide where the file is written.

    Args:
        clients: client objects forwarded to ``run_selftrain_GC``.
        server: server object forwarded to ``run_selftrain_GC``.
        local_epoch: number of local epochs forwarded to ``run_selftrain_GC``.
    """
    print("Self-training ...")
    results = pd.DataFrame()
    accuracies = run_selftrain_GC(clients, server, local_epoch)
    metric_columns = ['train_acc', 'val_acc', 'test_acc']
    for key, scores in accuracies.items():
        # Assigning through .loc grows the (initially empty) frame row by row.
        results.loc[key, metric_columns] = scores
    print(results)

    filename = f'accuracy_selftrain_GC{suffix}.csv'
    if args.repeat is not None:
        # Repeated runs live in a sub-folder and are prefixed by their index.
        outfile = os.path.join(outpath, "repeats", f'{args.repeat}_{filename}')
    else:
        outfile = os.path.join(outpath, filename)
    results.to_csv(outfile)
    print(f"Wrote to file: {outfile}")


def process_fedavg(clients, server):
    """Run FedAvg and write its three result tables to CSV files.

    ``run_fedavg`` is called with the round and local-epoch counts from the
    module-level ``args`` and returns three objects, each of which is saved
    with ``to_csv``:

    * ``accuracy_fedavg_GC{suffix}.csv``    -- first returned object
    * ``accuracy_fedavg_GC{suffix}_r.csv``  -- second returned object
    * ``accuracy_fedavg_GC{suffix}_t.csv``  -- third returned object

    When ``args.repeat`` is set, the files go to ``<outpath>/repeats`` and
    are prefixed with ``{args.repeat}_``.

    All three files now share one stem. The non-repeat ``_t`` file used to be
    written as ``accutacy_...`` (a typo), which did not match the repeat
    branch or the other strategies.

    Relies on the module-level names ``args``, ``outpath`` and ``suffix``.

    Args:
        clients: client objects forwarded to ``run_fedavg``.
        server: server object forwarded to ``run_fedavg``.
    """
    print("\nDone setting up FedAvg devices.")

    print("Running FedAvg ...")
    frame, logs, ti = run_fedavg(clients, server, args.num_rounds, args.local_epoch, samp=None)

    if args.repeat is None:
        out_dir = outpath
        stem = f'accuracy_fedavg_GC{suffix}'
    else:
        out_dir = os.path.join(outpath, "repeats")
        stem = f'{args.repeat}_accuracy_fedavg_GC{suffix}'

    # Build every filename from the same stem so the names cannot drift apart.
    outfile = os.path.join(out_dir, f'{stem}.csv')
    outfile_r = os.path.join(out_dir, f'{stem}_r.csv')
    outfile_t = os.path.join(out_dir, f'{stem}_t.csv')

    frame.to_csv(outfile)
    logs.to_csv(outfile_r)
    ti.to_csv(outfile_t)
    print(f"Wrote to file: {outfile}")


def process_fedprox(clients, server, mu):
    """Run FedProx and write its three result tables to CSV files.

    The three objects returned by ``run_fedprox`` are saved with ``to_csv``
    under the stem ``accuracy_fedprox_mu{mu}_GC{suffix}`` with the endings
    ``.csv``, ``_r.csv`` and ``_t.csv``. When ``args.repeat`` is set, the
    files go to ``<outpath>/repeats`` and are prefixed with the repeat index.

    Relies on the module-level names ``args``, ``outpath`` and ``suffix``.

    Args:
        clients: client objects forwarded to ``run_fedprox``.
        server: server object forwarded to ``run_fedprox``.
        mu: value forwarded to ``run_fedprox`` and embedded in the filenames.
    """
    print("\nDone setting up FedProx devices.")

    print("Running FedProx ...")
    results = run_fedprox(clients, server, args.num_rounds, args.local_epoch, mu, samp=None)
    frame, logs, ti = results

    if args.repeat is not None:
        out_dir = os.path.join(outpath, "repeats")
        stem = f'{args.repeat}_accuracy_fedprox_mu{mu}_GC{suffix}'
    else:
        out_dir = outpath
        stem = f'accuracy_fedprox_mu{mu}_GC{suffix}'

    outfile = os.path.join(out_dir, f'{stem}.csv')
    # Main table first, then the "_r" and "_t" companions.
    for table, target in (
        (frame, outfile),
        (logs, os.path.join(out_dir, f'{stem}_r.csv')),
        (ti, os.path.join(out_dir, f'{stem}_t.csv')),
    ):
        table.to_csv(target)
    print(f"Wrote to file: {outfile}")


def process_gcfl(clients, server):
    """Run GCFL and write its three result tables to CSV files.

    Output paths are resolved before training starts. The three objects
    returned by ``run_gcfl`` are then saved with ``to_csv`` under the stem
    ``accuracy_gcfl_GC{suffix}`` with the endings ``.csv``, ``_r.csv`` and
    ``_t.csv``. When ``args.repeat`` is set, the files go to
    ``<outpath>/repeats`` and are prefixed with the repeat index.

    Relies on the module-level names ``args``, ``outpath``, ``suffix``,
    ``EPS_1`` and ``EPS_2``.

    Args:
        clients: client objects forwarded to ``run_gcfl``.
        server: server object forwarded to ``run_gcfl``.
    """
    print("\nDone setting up GCFL devices.")
    print("Running GCFL ...")

    if args.repeat is not None:
        out_dir = os.path.join(outpath, "repeats")
        stem = f'{args.repeat}_accuracy_gcfl_GC{suffix}'
    else:
        out_dir = outpath
        stem = f'accuracy_gcfl_GC{suffix}'
    outfile = os.path.join(out_dir, f'{stem}.csv')
    outfile_r = os.path.join(out_dir, f'{stem}_r.csv')
    outfile_t = os.path.join(out_dir, f'{stem}_t.csv')

    results = run_gcfl(clients, server, args.num_rounds, args.local_epoch, EPS_1, EPS_2)
    frame, logs, ti = results
    for table, target in ((frame, outfile), (logs, outfile_r), (ti, outfile_t)):
        table.to_csv(target)
    print(f"Wrote to file: {outfile}")


def process_gcflplus(clients, server):
    """Run GCFL+ and write its three result tables to CSV files.

    Output paths are resolved before training starts. The three objects
    returned by ``run_gcflplus`` are then saved with ``to_csv`` under the
    stem ``accuracy_gcflplus_GC{suffix}`` with the endings ``.csv``,
    ``_r.csv`` and ``_t.csv``. When ``args.repeat`` is set, the files go to
    ``<outpath>/repeats`` and are prefixed with the repeat index.

    Relies on the module-level names ``args``, ``outpath``, ``suffix``,
    ``EPS_1`` and ``EPS_2``.

    Args:
        clients: client objects forwarded to ``run_gcflplus``.
        server: server object forwarded to ``run_gcflplus``.
    """
    print("\nDone setting up GCFL devices.")
    print("Running GCFL plus ...")

    if args.repeat is not None:
        out_dir = os.path.join(outpath, "repeats")
        stem = f'{args.repeat}_accuracy_gcflplus_GC{suffix}'
    else:
        out_dir = outpath
        stem = f'accuracy_gcflplus_GC{suffix}'
    outfile = os.path.join(out_dir, f'{stem}.csv')
    outfile_r = os.path.join(out_dir, f'{stem}_r.csv')
    outfile_t = os.path.join(out_dir, f'{stem}_t.csv')

    results = run_gcflplus(
        clients,
        server,
        args.num_rounds,
        args.local_epoch,
        EPS_1,
        EPS_2,
        args.seq_length,
        args.standardize,
    )
    frame, logs, ti = results
    for table, target in ((frame, outfile), (logs, outfile_r), (ti, outfile_t)):
        table.to_csv(target)

    print(f"Wrote to file: {outfile}")


def str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so
    every non-empty value -- including ``"False"`` and ``"0"`` -- becomes
    ``True``. This converter interprets the usual spellings instead.

    Args:
        value: a ``bool`` (returned unchanged) or a string such as
            ``"True"``, ``"false"``, ``"1"``, ``"no"`` (case-insensitive).

    Returns:
        The boolean the token stands for.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays False, as it was under ``type=bool``.
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == '__main__':
    # ------------------------------------------------------------------
    # Command-line interface
    # ------------------------------------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', type=str, default='cpu',
                        help='CPU / GPU device.')
    parser.add_argument('--num_repeat', type=int, default=5,
                        help='number of repeating rounds to simulate;')
    parser.add_argument('--num_rounds', type=int, default=200,
                        help='number of rounds to simulate;')
    parser.add_argument('--local_epoch', type=int, default=1,
                        help='number of local epochs;')
    parser.add_argument('--lr', type=float, default=0.001,
                        help='learning rate for inner solver;')
    parser.add_argument('--weight_decay', type=float, default=5e-4,
                        help='Weight decay (L2 loss on parameters).')
    parser.add_argument('--nlayer', type=int, default=3,
                        help='Number of GINconv layers')
    parser.add_argument('--hidden', type=int, default=64,
                        help='Number of hidden units.')
    parser.add_argument('--dropout', type=float, default=0.5,
                        help='Dropout rate (1 - keep probability).')
    parser.add_argument('--batch_size', type=int, default=128,
                        help='Batch size for node classification.')
    parser.add_argument('--seed', help='seed for randomness;',
                        type=int, default=123)

    parser.add_argument('--datapath', type=str, default='./data',
                        help='The input path of data.')
    parser.add_argument('--outbase', type=str, default='./outputs',
                        help='The base path for outputting.')
    parser.add_argument('--repeat', help='index of repeating;',
                        type=int, default=None)
    parser.add_argument('--data_group', help='specify the group of datasets',
                        type=str, default='PROTEINS')

    # Boolean flags go through str2bool so that "--overlap False" really
    # yields False (type=bool would turn any non-empty string into True).
    parser.add_argument('--convert_x', help='whether to convert original node features to one-hot degree features',
                        type=str2bool, default=False)
    parser.add_argument('--overlap', help='whether clients have overlapped data',
                        type=str2bool, default=False)
    parser.add_argument('--standardize', help='whether to standardize the distance matrix',
                        type=str2bool, default=False)
    parser.add_argument('--seq_length', help='the length of the gradient norm sequence',
                        type=int, default=10)
    parser.add_argument('--epsilon1', help='the threshold epsilon1 for GCFL',
                        type=float, default=0.01)
    parser.add_argument('--epsilon2', help='the threshold epsilon2 for GCFL',
                        type=float, default=0.1)
    # --cr and --dp stay strings ('True' / 'False'): they are compared as
    # strings below and --cr is forwarded unchanged to the data helpers.
    parser.add_argument('--cr', help='Coarsening', type=str, default='False')
    parser.add_argument('--cr_ratio', help='cr_ratio', type=float, default=1)

    parser.add_argument('--dp', help='DP', type=str, default='False')
    parser.add_argument('--priv_budget', help='priv_budget', type=float, default=0)
    # Restricting the choices turns an unknown strategy into a clear usage
    # error instead of a NameError on `outpath` further down.
    parser.add_argument('--strategy', help='strategy', type=str, default='SDMC',
                        choices=['SDMC', 'MDMC'])
    parser.add_argument('--num_clients', help='number of clients',
                        type=int, default=10)
    try:
        args = parser.parse_args()
    except IOError as msg:
        parser.error(str(msg))

    # Fixed seed for the data split, independent of --seed.
    seed_dataSplit = 124

    # set seeds
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # NOTE(review): this overrides whatever was passed via --device; CUDA is
    # used whenever it is available. Kept as-is to preserve behaviour.
    args.device = "cuda" if torch.cuda.is_available() else "cpu"
    # args.device='cpu'

    EPS_1 = args.epsilon1
    EPS_2 = args.epsilon2

    # ------------------------------------------------------------------
    # Output directory layout
    # ------------------------------------------------------------------
    # TODO: change the data input path and output path
    outbase = os.path.join(args.outbase, f'seqLen{args.seq_length}')
    if args.cr == 'True':
        outbase = os.path.join(outbase, 'Coarsen')
        outbase = os.path.join(outbase, f'{args.cr_ratio}')
    elif args.dp == 'True':
        outbase = os.path.join(outbase, 'DP')
        outbase = os.path.join(outbase, f'{args.priv_budget}')
    else:
        outbase = os.path.join(outbase, 'Standard')

    # MDMC = Multi Data Multi Client, SDMC = Single Data Multi Client.
    multi_dataset = args.strategy == 'MDMC'
    setting = ('multiDS' if multi_dataset else 'oneDS') + ('-overlap' if args.overlap else '-nonOverlap')
    if args.standardize:
        setting = f"standardizedDTW/{setting}"
    if multi_dataset:
        group_dir = args.data_group
    else:
        group_dir = f'{args.data_group}-{args.num_clients}clients'

    outpath = os.path.join(outbase, setting, group_dir, f'eps_{EPS_1}_{EPS_2}')
    Path(outpath).mkdir(parents=True, exist_ok=True)
    print(f"Output Path: {outpath}")

    # ------------------------------------------------------------------
    # Data preparation
    # ------------------------------------------------------------------
    if not args.convert_x:
        # using original features
        suffix = ""
        print("Preparing data (original features) ...")
    else:
        # using node degree features
        suffix = "_degrs"
        print("Preparing data (one-hot degree features) ...")

    if args.repeat is not None:
        Path(os.path.join(outpath, 'repeats')).mkdir(parents=True, exist_ok=True)

    if multi_dataset:
        splitedData, df_stats = prepareData_multiDS(args.datapath, args.data_group, args.batch_size,
                                                    convert_x=args.convert_x, seed=seed_dataSplit,
                                                    cr=args.cr, cr_ratio=args.cr_ratio)
        print("Done")
    else:
        # distribute one dataset to multiple clients
        splitedData, df_stats = prepareData_oneDS(args.datapath, args.data_group, num_client=args.num_clients,
                                                  batchSize=args.batch_size, convert_x=args.convert_x,
                                                  seed=seed_dataSplit, overlap=args.overlap,
                                                  cr=args.cr, cr_ratio=args.cr_ratio)

    # save statistics of data on clients
    if args.repeat is None:
        outf = os.path.join(outpath, f'stats_trainData{suffix}.csv')
    else:
        outf = os.path.join(outpath, "repeats", f'{args.repeat}_stats_trainData{suffix}.csv')
    df_stats.to_csv(outf)
    print(f"Wrote to {outf}")

    # ------------------------------------------------------------------
    # Experiments
    # ------------------------------------------------------------------
    init_clients, init_server, init_idx_clients = setup_devices(splitedData, args)
    print("\nDone setting up devices.")

    # Each strategy receives its own deep copy of the initial clients and server.
    process_selftrain(clients=copy.deepcopy(init_clients), server=copy.deepcopy(init_server), local_epoch=100)
    process_fedavg(clients=copy.deepcopy(init_clients), server=copy.deepcopy(init_server))
    process_fedprox(clients=copy.deepcopy(init_clients), server=copy.deepcopy(init_server), mu=0.01)
    process_gcfl(clients=copy.deepcopy(init_clients), server=copy.deepcopy(init_server))
    process_gcflplus(clients=copy.deepcopy(init_clients), server=copy.deepcopy(init_server))
