""" Config class for search/augment """
import argparse
import os
import genotypes as gt
from functools import partial
import torch


def get_parser(name):
    """Create an argument parser that uses the defaults-showing help formatter.

    Args:
        name: Program name handed to ``argparse.ArgumentParser`` as ``prog``.

    Returns:
        An ``argparse.ArgumentParser`` whose ``add_argument`` has been wrapped
        so that ``help=' '`` is supplied whenever the caller gives no help
        text of its own.
    """
    default_parser = argparse.ArgumentParser(
        prog=name,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # A placeholder help string is injected for arguments declared without
    # one; the intent is for the formatter to still print their default value.
    bound_add_argument = default_parser.add_argument
    default_parser.add_argument = partial(bound_add_argument, help=' ')
    return default_parser


def parse_gpus(gpus):
    """Turn a GPU specification string into a list of device indices.

    Args:
        gpus: Either the literal ``'all'`` or comma-separated integer ids
            such as ``'0,1,2'``.

    Returns:
        A list of ints. For ``'all'`` it is ``0 .. device_count - 1`` as
        reported by ``torch.cuda.device_count()``.

    Raises:
        ValueError: If any comma-separated entry is not a valid integer.
    """
    if gpus != 'all':
        return list(map(int, gpus.split(',')))
    return list(range(torch.cuda.device_count()))


class BaseConfig(argparse.Namespace):
    """Namespace with helpers for dumping its attributes."""

    def print_params(self, prtf=print):
        """Emit every attribute as ``NAME=value``, sorted by attribute name.

        Args:
            prtf: Callable taking one string per output line (defaults to
                ``print``). An empty line is emitted before the
                ``Parameters:`` header and after the last attribute.
        """
        prtf("")
        prtf("Parameters:")
        for key, val in sorted(vars(self).items()):
            prtf(f"{key.upper()}={val}")
        prtf("")

    def as_markdown(self):
        """Render the attributes as a two-column markdown table.

        Returns:
            A string starting with the ``|name|value|`` header and separator
            rows, followed by one row per attribute sorted by name.
        """
        header = "|name|value|  \n|-|-|  \n"
        rows = [f"|{key}|{val}|  \n" for key, val in sorted(vars(self).items())]
        return header + "".join(rows)


class SearchConfig(BaseConfig):
    """Configuration for the search stage, populated from the command line.

    Instantiating the class parses ``sys.argv`` (via ``parse_args()``), copies
    every parsed option onto the instance, and then derives two attributes:

    * ``gpus`` - computed from ``--device``; the parsed ``--gpus`` value is
      always replaced by this derived list.
    * ``data_path`` - chosen from ``--dataset``.
    """

    def build_parser(self):
        """Build the argument parser for the search stage.

        Returns:
            The ``argparse.ArgumentParser`` with all search options registered.
        """
        parser = get_parser("Search config")
        parser.add_argument('--name', default="cifar10", required=False)
        parser.add_argument('--batch_size', type=int, default=4, help='batch size')
        parser.add_argument('--w_lr', type=float, default=0.025, help='lr for weights')
        parser.add_argument('--w_lr_min', type=float, default=0.001, help='minimum lr for weights')
        parser.add_argument('--w_momentum', type=float, default=0.9, help='momentum for weights')
        parser.add_argument('--w_weight_decay', type=float, default=3e-4,
                            help='weight decay for weights')
        parser.add_argument('--w_grad_clip', type=float, default=5.,
                            help='gradient clipping for weights')
        parser.add_argument('--print_freq', type=int, default=50, help='print frequency')
        parser.add_argument('--init_channels', type=int, default=16)
        parser.add_argument('--layers', type=int, default=8, help='# of layers')
        parser.add_argument('--seed', type=int, default=2, help='random seed')
        parser.add_argument('--workers', type=int, default=4, help='# of workers')
        parser.add_argument('--alpha_lr', type=float, default=3e-4, help='lr for alpha')
        parser.add_argument('--alpha_weight_decay', type=float, default=1e-3,
                            help='weight decay for alpha')

        parser.add_argument('--dataset', default="cifar10", required=False, help='CIFAR10 / MNIST / FashionMNIST')
        parser.add_argument("--log_dir", default="/cache/exps/boss/log")
        parser.add_argument("--shared_dir", default="/cache/exps/boss/shared")
        parser.add_argument("--code_dir", default="/home/ma-user/work/pytorch_boss_darts")
        parser.add_argument("--device", type=str, help="cuda or cpu", default="cuda:0")
        # Normally set to 50.
        parser.add_argument("--epochs", type=int, help="epochs of iteration", default=2)
        parser.add_argument("--prev_best_acc", default=0.0, type=float)
        parser.add_argument('--gpus', default='0', help='gpu device ids separated by comma. '
                            '`all` indicates use all gpus.')

        return parser

    @staticmethod
    def _gpus_from_device(device):
        """Derive the ``gpus`` list from a ``--device`` string.

        Args:
            device: ``"cpu"`` or a string of the form ``"cuda:<index>"``.

        Returns:
            ``["cpu"]`` for ``"cpu"``; otherwise the result of ``parse_gpus``
            applied to everything after the first ``":"``.
        """
        if device == "cpu":
            return ["cpu"]
        # Use the whole suffix after the colon rather than only the last
        # character, so multi-digit indices such as "cuda:10" are kept intact.
        _, _, index = device.partition(":")
        # A device without an explicit index (e.g. "cuda") falls back to 0.
        return parse_gpus(index or "0")

    def __init__(self):
        """Parse the command line and derive ``gpus`` and ``data_path``.

        Raises:
            RuntimeError: If ``--dataset`` is neither ``cifar10`` nor
                ``cifar100``.
            ValueError: If the index part of ``--device`` is not accepted by
                ``parse_gpus``.
        """
        parser = self.build_parser()
        args = parser.parse_args()
        super().__init__(**vars(args))

        # NOTE: this overwrites whatever was passed through --gpus.
        self.gpus = self._gpus_from_device(self.device)

        if self.dataset == "cifar10":
            # NOTE(review): this looks like a developer-local path; the
            # previous value was "/cache/data/CIFAR10" - confirm which one
            # the target environment needs.
            self.data_path = "/Users/liyujun/Programs/data/CIFAR10"
        elif self.dataset == "cifar100":
            self.data_path = "/cache/data/CIFAR100"
        else:
            raise RuntimeError(f"wrong self.dataset {self.dataset}")


class AugmentConfig(BaseConfig):
    """Configuration for the augment stage, populated from the command line.

    Instantiating the class parses ``sys.argv`` (via ``parse_args()``), copies
    every parsed option onto the instance, and then post-processes
    ``genotype``, ``gpus`` and ``data_path``.
    """

    # Dataset roots used when --data_path is not given on the command line.
    _DEFAULT_DATA_PATHS = {
        "cifar10": "/cache/data/CIFAR10",
        "cifar100": "/cache/data/CIFAR100",
        "imagenet": "/cache/data/ILSVRC/Data/CLS-LOC",
    }

    def build_parser(self):
        """Build the argument parser for the augment stage.

        Returns:
            The ``argparse.ArgumentParser`` with all augment options registered.
        """
        parser = get_parser("Augment config")
        parser.add_argument('--name', default="cifar10", required=False)
        parser.add_argument('--batch_size', type=int, default=4, help='batch size')
        parser.add_argument('--lr', type=float, default=0.025, help='lr for weights')
        parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
        parser.add_argument('--weight_decay', type=float, default=3e-4, help='weight decay')
        parser.add_argument('--grad_clip', type=float, default=5.,
                            help='gradient clipping for weights')
        parser.add_argument('--print_freq', type=int, default=200, help='print frequency')
        parser.add_argument('--gpus', default='0', help='gpu device ids separated by comma. '
                            '`all` indicates use all gpus.')
        parser.add_argument('--init_channels', type=int, default=36)
        parser.add_argument('--layers', type=int, default=20, help='# of layers')
        parser.add_argument('--seed', type=int, default=2, help='random seed')
        parser.add_argument('--workers', type=int, default=4, help='# of workers')
        parser.add_argument('--aux_weight', type=float, default=0.4, help='auxiliary loss weight')
        parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
        parser.add_argument('--drop_path_prob', type=float, default=0.2, help='drop path prob')

        parser.add_argument('--genotype', required=True, help='Cell genotype')

        parser.add_argument('--dataset', default="cifar10")
        # No fixed default: an omitted --data_path is filled in from --dataset
        # in __init__, while an explicit value is kept instead of being
        # silently overwritten.
        parser.add_argument('--data_path', default=None, type=str,
                            help='dataset root; derived from --dataset when omitted')
        parser.add_argument('--epochs', type=int, default=600, help='# of training epochs')
        parser.add_argument("--log_dir", default="/cache/exps/boss/log")
        parser.add_argument("--code_dir", default="/home/ma-user/work/pytorch_boss_darts")

        return parser

    def __init__(self):
        """Parse the command line and post-process the parsed values.

        ``genotype`` is replaced by the result of ``gt.from_str`` and ``gpus``
        by the list returned from ``parse_gpus``. ``data_path`` keeps an
        explicitly supplied value and otherwise takes the default root for
        the selected dataset.

        Raises:
            RuntimeError: If ``--dataset`` is not one of ``cifar10``,
                ``cifar100`` or ``imagenet``.
        """
        parser = self.build_parser()
        args = parser.parse_args()
        super().__init__(**vars(args))

        self.genotype = gt.from_str(self.genotype)
        self.gpus = parse_gpus(self.gpus)

        # Unsupported datasets are rejected even when --data_path is given,
        # matching the previous validation behaviour.
        try:
            default_path = self._DEFAULT_DATA_PATHS[self.dataset]
        except KeyError:
            raise RuntimeError(f"wrong self.dataset {self.dataset}") from None

        if self.data_path is None:
            self.data_path = default_path
