import argparse
import os

from utils.func import *
from utils.logger import *


class Parser:
    """Command-line front end that builds the run configuration namespace.

    Options are registered on an ``argparse.ArgumentParser`` in ``__init__``.
    ``parse`` then reads the command line and enriches the resulting
    namespace with the device, dataset presets, output paths and a logger.
    """

    # Per-dataset hyper-parameter presets applied by ``data_specific_args``.
    # Key order matches the attribute order the namespace is populated in,
    # which is also the order the options are later printed by ``logging``.
    _DATASET_PRESETS = {
        'CM': {
            'img_size': 28,
            'img_dim': 3,
            'num_labels': 10,
            'lr': 0.02,
            'weight_decay': 0.001,
            'momentum': 0.9,
            'batch': 128,
            'epoch': 100,
            'arch': 'conv0',
            'd_option': 0.0001,
            'lr_decay_step': 40,
            'lr_decay': 0.1,
        },
        'WM': {
            'img_size': 56,
            'img_dim': 1,
            'num_labels': 10,
            'lr': 0.02,
            'weight_decay': 0.001,
            'momentum': 0.9,
            'batch': 128,
            'epoch': 100,
            'arch': 'conv0',
            'd_option': 8,
            'lr_decay_step': 40,
            'lr_decay': 0.1,
        },
    }

    def __init__(self):
        self.parser = argparse.ArgumentParser()
        self.argument()

    def add(self, name, type, help='None', default=None):
        """Register the long option ``--<name>``.

        Args:
            name: Option name without the leading dashes.
            type: Callable argparse uses to convert the raw string.
                (Shadows the builtin; kept because it is part of the
                existing signature.)
            help: Help text shown by ``--help``.
            default: Value used when the option is omitted. When it is
                ``None`` the option is registered as required instead.
        """
        options = {'type': type, 'help': help}
        # Identity check: a default whose ``__eq__`` is permissive or
        # array-like must not be mistaken for "no default given".
        if default is None:
            options['required'] = True
        else:
            options['default'] = default
        self.parser.add_argument('--' + name, **options)

    def argument(self):
        """Declare every command-line option understood by the program."""
        # Global options (without default value, hence required)
        self.add('gpu', str, 'GPU IDXs')
        self.add('run', str, 'Run ID (ex: run0)')
        self.add('job', str, 'Choices = {noise, bias, final} Example: train_noise_bias_final')
        self.add('data', str, 'Dataset, Choices = [CM, WM, CT, BAR, IN, CL]')
        self.add('alg', str, 'Algset = {vanilla, ours, repair, rebias, aflite, rubi, mixin, lff}')

        # Network options
        self.add('arch', str, 'architecture List = {conv0, conv1, resnet18, resnet34, resnet50, resnet101, resnet152}', 'conv0')
        # NOTE: parsed as the string 'True', not as a boolean.
        self.add('save', str, 'Model save or not', 'True')

        # Data options
        self.add('noise', float, 'Portion of clean label', 0.90)
        self.add('bias', float, 'Portion of major label', 0.995)

        # Ours Algorithm parameters (alpha, beta)
        self.add('alpha', float, 'Balance parameter between Magnitude and Directional score', 1.0)

        # Training options (with default value)
        self.add('log', str, 'Debug level, Options = [debug, info, warning, error]', 'info')
        self.add('seed', int, 'Seed', 123)
        # NOTE: parsed as the string 'True', not as a boolean.
        self.add('denoise', str, 'Do denoise?', 'True')

    def parse(self):
        """Parse the command line and return the fully populated namespace.

        Returns:
            The argparse namespace after the GPU, dataset, task, path and
            logging steps have each added their attributes.

        Raises:
            SystemExit: If the command line contains unrecognised arguments
                (argparse itself also exits on missing required options).
        """
        args, unknown = self.parser.parse_known_args()
        if unknown:
            raise SystemExit('Unknown arguments: {}'.format(unknown))

        # Order matters: paths depend on parsed values, and the logger is
        # created last so that it reports every attribute set before it.
        args = self.gpu_setting(args)
        args = self.data_specific_args(args)
        args = self.task_specific_args(args)
        args = self.gen_path(args)
        args = self.logging(args)

        return args

    def gpu_setting(self, args):
        """Export the CUDA visibility variables and store ``args.device``.

        ``t`` is not defined in this file's explicit imports; it is
        presumably the torch module re-exported by the star imports
        (TODO confirm).
        """
        # Environment is set before the device is queried below.
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
        args.device = t.device('cuda' if t.cuda.is_available() else 'cpu')

        return args

    def logging(self, args):
        """Replace ``args.log`` with a logger and dump all options through it.

        ``logger`` comes from the star imports; it receives the namespace
        while ``args.log`` still holds the level string from the command
        line. The dump below therefore lists the logger object itself
        under the ``log`` key.
        """
        args.log = logger(args)
        border = '+' + '=' * 94 + '+'
        args.log.info(border)
        for key, value in vars(args).items():
            args.log.info('||\t%15s\t | \t %60s \t ||' % (key, value))
        args.log.info(border)
        return args

    def gen_path(self, args):
        """Derive the result directories from the options and create them.

        Sets ``out_path``, ``log_path``, ``output_path`` and ``ckpt_path``
        on ``args``. The module-level ``gen_path`` helper (from the star
        imports, not this method) is called on each of them; it presumably
        creates the directory if missing (TODO confirm).
        """
        args.out_path = f'./result/{args.data}-nr_{args.noise}-br_{args.bias}/'
        run_root = f'{args.out_path}{args.alg}/{args.run}/'
        args.log_path = run_root
        args.output_path = f'{run_root}outputs/'
        args.ckpt_path = f'{run_root}ckpt/'

        for directory in (args.out_path, args.log_path,
                          args.output_path, args.ckpt_path):
            gen_path(directory)

        return args

    def task_specific_args(self, args):
        """Hook for task-dependent options; currently returns ``args`` unchanged."""
        return args

    def data_specific_args(self, args):
        """Apply the hyper-parameter preset of ``args.data`` to ``args``.

        Only 'CM' and 'WM' have presets. Any other dataset name leaves the
        namespace untouched. A preset overwrites ``args.arch`` with
        'conv0', regardless of what was passed on the command line.
        """
        preset = self._DATASET_PRESETS.get(args.data, {})
        for key, value in preset.items():
            setattr(args, key, value)

        return args

