import os
import sys
import time
import wandb
import argparse
from utils import set_random_seed, get_logger

# Per-dataset input shape, keyed by upper-case dataset name
# (presumably [channels, height, width] -- TODO confirm against the data loaders).
# Each dataset gets its own list object.
shape = {name: [3, 32, 32] for name in ('CIFAR10', 'CIFAR100', 'SVHN')}

# Per-dataset total number of classes, keyed by upper-case dataset name.
classes = dict(CIFAR10=10, CIFAR100=100, SVHN=10)

def over_write_args_from_file(args, yml):
    """
    Overwrite attributes of ``args`` according to a YAML config file.

    Every top-level key of the YAML document is set on ``args`` with
    ``setattr``, replacing any attribute of the same name.

    Args:
        args: object whose attributes are overwritten in place
            (e.g. an ``argparse.Namespace``).
        yml: path to the YAML file. An empty string (or ``None``) means
            "no config file" and leaves ``args`` untouched.

    Returns:
        None. ``args`` is modified in place.
    """
    if not yml:
        return
    # Imported lazily so ruamel.yaml is only required when a config file is used.
    import ruamel.yaml as yaml
    with open(yml, 'r', encoding='utf-8') as f:
        # NOTE(review): yaml.Loader is not a safe loader and may construct
        # arbitrary Python objects; only use this with trusted config files,
        # or switch to a safe loader if configs can come from elsewhere.
        dic = yaml.load(f.read(), Loader=yaml.Loader)
    # An empty YAML document loads as None: nothing to overwrite.
    if not dic:
        return
    for k in dic:
        setattr(args, k, dic[k])

def str2bool(v):
    """
    Convert a command-line token to a bool (usable as an argparse ``type=``).

    Real ``bool`` values are returned unchanged. Strings are compared
    case-insensitively: 'yes', 'true', 't', 'y', '1' give True and
    'no', 'false', 'f', 'n', '0' give False.

    Raises:
        argparse.ArgumentTypeError: if the string is none of the above.
    """
    if isinstance(v, bool):
        return v

    token = v.lower()
    truthy = ('yes', 'true', 't', 'y', '1')
    falsy = ('no', 'false', 'f', 'n', '0')

    if token in truthy:
        return True
    if token in falsy:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

from algorithm import SemiFL, FedLabel, FL2, TrainAloneOS
from algorithm import IOMatch, BDMatch, SSB, OPMatch, FedOpenMatch

# Registry mapping the value of the ``--algorithm`` option to the class that
# implements it; ``main`` looks the class up here and instantiates it with
# the parsed configuration.
name2algo = dict(
    semifl=SemiFL,
    fedlabel=FedLabel,
    fl2=FL2,
    iomatch=IOMatch,
    bdmatch=BDMatch,
    ssb=SSB,
    fedopenmatch=FedOpenMatch,
    alone_os=TrainAloneOS,
    openmatch=OPMatch,
)


def get_config(argv=None):
    """
    Parse command-line options and derive the run configuration.

    Besides the raw options, the returned namespace carries derived fields:
    ``dataset`` (upper-cased), ``data_shape`` and ``num_classes`` (looked up
    from the module-level ``shape`` / ``classes`` tables), ``num_seen_class``
    (defaults to ``num_classes`` when given as -1), ``batch_size`` (derived
    from the labelled-set size), ``exp_tag`` and a per-run ``save_dir``.

    Side effects: sets the ``CUDA_VISIBLE_DEVICES`` environment variable from
    ``--visible_gpu`` and creates the run directory on disk.

    Args:
        argv: optional list of argument strings; ``None`` (the default) makes
            argparse read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the parsed and derived settings.

    Raises:
        SystemExit: raised by argparse on invalid options or an unsupported
            dataset name.
    """
    parser = argparse.ArgumentParser()

    # Saving & loading of the model.
    parser.add_argument('--pj_name', type=str, default='open-set ssfl')
    parser.add_argument('--dsp', type=str, default='', help='Describe the intention of this experiment')
    parser.add_argument('--log_level', type=str, default='INFO')
    parser.add_argument('--save_dir', type=str, default='./results')
    parser.add_argument('--load_path', type=str, default='')

    # Optimizer configurations
    parser.add_argument('--optim', type=str, default='SGD')
    parser.add_argument('--lr', type=float, default=3e-2)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--weight_decay', type=float, default=5e-4)

    # Backbone net configurations
    parser.add_argument('--net', type=str, default='resnet18')

    # Data configurations
    ## standard setting configurations
    parser.add_argument('--data_dir', type=str, default='./data')
    parser.add_argument('--dataset', type=str, default='cifar10')
    parser.add_argument('--num_labels', type=int, default=10, help='number of labeled samples per class')
    parser.add_argument("--num_seen_class", type=int, default=-1, help="number of seen classes, for open-set SSL")
    parser.add_argument("--close_train", type=str2bool, default=False, help="close unlabeled set or not")
    parser.add_argument("--close_test", type=str2bool, default=False, help="close test set or not")

    # FL config
    parser.add_argument('--num_clients', type=int, default=100)
    parser.add_argument('--join_ratio', type=float, default=0.1, help='ratio of clients to join in each round')
    parser.add_argument('--split_type', type=str, default='dir_0.1', help='type of heterogeneity')
    parser.add_argument('--local_steps', type=int, default=5, help='number of local steps')
    parser.add_argument('--s_local_steps', type=int, default=5, help='number of server local steps')
    parser.add_argument('--agg', type=str, default='uniform', help='aggregation method: average, weighted_average, lsa')

    # System config
    parser.add_argument('--seed', type=int, default=100)
    parser.add_argument('--visible_gpu', type=str, default='0')

    # Algorithm configurations
    parser.add_argument('--algorithm', type=str, default='fedavg', help='ssl algorithm')
    parser.add_argument('--warmup_epochs', type=int, default=40, help='number of warmup epochs')
    parser.add_argument('--threshold', type=float, default=0.95, help='threshold for pseudo label')
    parser.add_argument('--clip_grad', type=float, default=1.0, help='clip_grad')
    parser.add_argument('--global_rounds', type=int, default=600)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--c_batch_size', type=int, default=32)

    # SemiFL
    parser.add_argument('--sBN', type=str2bool, default=True, help='use sBN or not')
    parser.add_argument('--mixup', type=str2bool, default=True, help='use mixup augmentation or not')

    parser.add_argument('--fix', type=str2bool, default=True, help='freeze global fc or not')
    parser.add_argument('--aux', type=str2bool, default=True, help='using auxilary fc or not')
    parser.add_argument('--shrink', type=str2bool, default=True, help='use shrink loss or not')
    parser.add_argument('--linear', type=str2bool, default=False, help='use FC layer as auxiliary head')

    # FedLabel
    parser.add_argument('--lamda', type=float, default=1, help='trade-off parameter of kl_div')

    # FL^2
    parser.add_argument("--sam", type=str, default='asam', help="option of sharpness-aware minimization: sam, asam")
    parser.add_argument("--rho", type=float, default=0.1, help="magnitude of pertubation in SAM, 0.1 for cifar10/svhn, 1 for cifar-1oo")

    # IOMatch
    parser.add_argument('--mb_loss_ratio', type=float, default=1.0, help='loss weight for ova classifier')
    parser.add_argument('--op_loss_ratio', type=float, default=1.0, help='loss weight for open set classifier')
    parser.add_argument('--open_threshold', type=float, default=0.5, help='threshold for open set classifier')

    # BDMatch
    parser.add_argument('--ema', type=float, default=0.999, help='ema decay rate')
    parser.add_argument('--p_cutoff_pos', type=float, default=0.99, help='positive threshold')
    parser.add_argument('--p_cutoff_neg', type=float, default=0.01, help='negative threshold')

    # FedOpenMatch
    parser.add_argument('--cr_weight', type=float, default=1)
    parser.add_argument('--tau', type=float, default=0.5, help='weight of logits adjustment')

    # alone_os ablation
    parser.add_argument("--train_ova", type=str2bool, default=False, help="train ova classifier or not")
    parser.add_argument("--mlp_head", type=str2bool, default=False, help="use mlp head for inlier classifier or not")

    args = parser.parse_args(argv)
    os.environ['CUDA_VISIBLE_DEVICES'] = args.visible_gpu

    # The lookup tables are keyed by upper-case dataset name.
    args.dataset = args.dataset.upper()
    if args.dataset not in shape:
        # Report an unsupported dataset as a usage error instead of a bare KeyError.
        parser.error(
            f"unsupported dataset '{args.dataset}', expected one of: {', '.join(shape)}"
        )
    args.data_shape = shape[args.dataset]
    args.num_classes = classes[args.dataset]
    if args.num_seen_class == -1:
        # -1 means "all classes are seen".
        args.num_seen_class = args.num_classes

    # NOTE(review): this unconditionally replaces any --batch_size given on the
    # command line -- confirm that is intended.
    if args.num_labels == -1 or args.num_labels * args.num_seen_class > 1000:
        args.batch_size = 128
    else:
        args.batch_size = 10

    # Experiment tag: algorithm, dataset, seen classes, labels per class,
    # number of clients (omitted for the "alone*" algorithms) and the seed.
    # num_seen_class has already been resolved above, so it is always included.
    tag_parts = [
        args.algorithm,
        args.dataset,
        str(args.num_seen_class),
        str(args.num_labels),
    ]
    if not args.algorithm.startswith('alone'):
        tag_parts.append(str(args.num_clients))
    args.exp_tag = '_'.join(tag_parts) + f'_seed={args.seed}'

    # Per-run output directory: <save_dir>/<algorithm>/<dataset>/<run name with timestamp>.
    stamp = time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime())
    run_name = f'{args.num_seen_class}_{args.num_labels}_{args.split_type}_{args.seed}_{stamp}'
    run_dir = os.path.join(args.save_dir, args.algorithm, args.dataset, run_name)
    # If a run with the same settings started within the same second, add a
    # numeric suffix rather than nesting a new hierarchy inside the existing run.
    candidate = run_dir
    attempt = 0
    while os.path.exists(candidate):
        attempt += 1
        candidate = f'{run_dir}_{attempt}'
    args.save_dir = candidate
    os.makedirs(args.save_dir, exist_ok=True)

    return args

def init_wandb(args):
    """
    Start a Weights & Biases run for this experiment.

    The run is named ``args.exp_tag``, filed under the project
    ``args.pj_name``, receives ``args.__dict__`` as its config and is
    started with ``mode="offline"``.

    Returns:
        Whatever ``wandb.init`` returns.
    """
    return wandb.init(
        name=args.exp_tag,
        config=args.__dict__,
        project=args.pj_name,
        mode="offline",
    )


def _cli_overrides(argv):
    """Pair command-line tokens into (key, value) tuples for ``--key value`` and ``--key=value``."""
    pairs = []
    tokens = iter(argv)
    for token in tokens:
        # Drop the leading "--" and split an inline "=value" if present.
        key, has_eq, value = token[2:].partition('=')
        if not has_eq:
            # Space-separated form: the value is the next token ('' if it is missing).
            value = next(tokens, '')
        pairs.append((key, value))
    return pairs


def main():
    """
    Entry point: build the configuration, run the selected algorithm and log the time taken.

    Steps: parse the configuration, seed the RNGs, attach a file logger
    (``cfg.printer``) and a wandb run (``cfg.logger``) to the configuration,
    instantiate the algorithm registered under ``cfg.algorithm``, log the
    options given on the command line, then call ``algorithm.run()``.

    Raises:
        KeyError: if ``cfg.algorithm`` is not a key of ``name2algo``.
    """
    cfg = get_config()
    # Fail early, with a readable message, before a logger or wandb run is created.
    if cfg.algorithm not in name2algo:
        raise KeyError(
            f"unknown algorithm '{cfg.algorithm}'; available: {', '.join(name2algo)}"
        )
    set_random_seed(cfg.seed)
    cfg.printer = get_logger('printer', cfg.save_dir, cfg.log_level)
    cfg.logger = init_wandb(cfg)
    algorithm = name2algo[cfg.algorithm](cfg)
    cfg.printer.info('configurations:')
    # Log only the options given on the command line. Pairing tokens here keeps
    # "--key=value" and an odd number of tokens from raising IndexError.
    for key, value in _cli_overrides(sys.argv[1:]):
        cfg.printer.info(f'{key}: {value}')
    cfg.printer.info('-' * 50)
    start = time.time()
    algorithm.run()
    end = time.time()
    cfg.printer.info(f'-----------------time cost:{(end-start)/3600:.1f}h-----------------')

# Run the experiment only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()