from yacs.config import CfgNode as CN
from .utils import log_msg


def dump_cfg(cfg, show: bool=False):
    """Gather the parts of ``cfg`` that are relevant to the active distiller.

    The result always carries the EXPERIMENT, DATASET, DISTILLER, SOLVER and
    LOG sections. The distiller-specific section is added when ``cfg`` has a
    key equal to the resolved distiller name. The AMD section is added for
    ``AMD_``-prefixed distiller types, and for VITKD types when
    ``cfg.VITKD.REF_AMD`` is truthy.

    Args:
        cfg: Config node to read sections from.
        show: When true, print the dumped result formatted by ``log_msg``.

    Returns:
        A new ``CN`` holding the selected sections.
    """
    selected = CN()
    for section in ("EXPERIMENT", "DATASET", "DISTILLER", "SOLVER", "LOG"):
        setattr(selected, section, getattr(cfg, section))

    raw_type = cfg.DISTILLER.TYPE
    if raw_type.startswith('AMD_'):
        # "AMD_<name>": keep the AMD section and resolve to the name that
        # follows the prefix.
        selected.AMD = cfg.AMD
        method = raw_type[len('AMD_'):]
    elif raw_type.startswith('VITKD'):
        # Any type beginning with "VITKD" shares the single VITKD section.
        method = 'VITKD'
    else:
        method = raw_type

    # Only distillers that own a section in cfg contribute one.
    if method in cfg:
        selected[method] = cfg.get(method)

    if method == 'VITKD':
        if cfg.VITKD.REF_AMD:
            selected['AMD'] = cfg.get('AMD')

    if show:
        rendered = selected.dump()
        print(log_msg("CONFIG:\n{}".format(rendered), "INFO"))
    return selected


# Root node that every default section below is attached to.
CFG = CN()

# Experiment: identification strings plus boolean switches (all off by default).
CFG.EXPERIMENT = CN({
    "PROJECT": "distill",
    "NAME": "",
    "TAG": "default",
    "DDP": False,
    "AMP": False,
    "TRACE_LOSS": False,
})

# Dataset defaults; TEST is a nested node with its own batch size.
CFG.DATASET = CN({
    "TYPE": "cifar100",
    "SUBSET": False,
    "NUM_WORKERS": 2,
    "INPUT_SIZE": [224, 224],
    "TEST": {
        "BATCH_SIZE": 64,
    },
})

# Distiller: which method to run and the teacher/student model names.
CFG.DISTILLER = CN({
    "TYPE": "NONE",  # Vanilla as default
    "TEACHER": "ResNet50",
    "STUDENT": "resnet32",
})

# Solver: top-level training values, then per-optimizer (SGD, ADAM) and
# per-schedule (MULTISTEP, COSINE) sub-nodes.
CFG.SOLVER = CN({
    "TRAINER": "base",
    "BATCH_SIZE": 64,
    "EPOCHS": 240,
    "LR": 0.05,
    "WEIGHT_DECAY": 0.0001,
    "TYPE": "SGD",
    "GRAD_CLIP": 0.0,
    "SGD": {
        "MOMENTUM": 0.9,
    },
    "ADAM": {
        "BETAS": [0.9, 0.999],
        "EPSILON": 1.0E-8,
    },
    "SCHEDULE": {
        "TYPE": "MULTISTEP",
        "MULTISTEP": {
            "STAGES": [150, 180, 210],
            "RATE": 0.1,
        },
        "COSINE": {
            "WARMUP": 5,
            "RATE": 1.0E-4,
        },
    },
})

# Log: frequencies, output prefix and the WANDB switch.
CFG.LOG = CN({
    "TENSORBOARD_FREQ": 500,
    "SAVE_CHECKPOINT_FREQ": 40,
    "PREFIX": "./output",
    "WANDB": False,
})

# Distillation Methods

# Artifact Manipulating Distillation
CFG.AMD = CN({
    "M_LAYERS": [5],  # manipulating layers
    "ALIGN_TYPE": 'mse',  # 'cosine', 'mse', 'both'
    "INPUT_SIZE": [224, 224],
    "AF": {
        "ENABLE": True,
        "ARTIFACT_NORM": False,
        "CRITERIA": {
            "TYPE": 'zscore',  # 'zscore', 'gaussian_std'
            "THRES": 5.5,
        },
        "RECON": {
            "TYPE": 'recon_mha',
        },
    },
    "LOSS": {
        "ALIGN_WEIGHT": 1.0,
        "RECON_WEIGHT": 1.0,
        "FEAT_WEIGHT": 100.0,
        "OUTLIER_WEIGHT": 0.01,
        "INFO_WEIGHT": 1.0,
        "ORTHO_WEIGHT": 1.0,
        "NULLSPACE_WEIGHT": 1.0,
        "DETACH_REFINER": False,
    },
    "SNER": {
        "RANK": 16,
        "NULL_THRES": 1.0E-3,
        "OUTLIER_Q": 0.95,
        # Options noted by the author: 'SNER', 'RANDOM'.
        # NOTE(review): the default is lowercase 'sner' - confirm how the
        # consumer handles case.
        "METHOD": 'sner',
    },
})


# ViTKD CFG
# REF_AMD is read by dump_cfg to decide whether the AMD section is dumped too.
CFG.VITKD = CN({
    "REF_AMD": True,
    "M_LAYERS": [5],
    "LAYERS": [3],
    "MASKING_RATIO": 0.5,
    "HPARAMS": {
        "ALPHA": 0.00003,
        "BETA": 0.000003,
    },
})

# KD CFG: temperature plus the two loss weights.
CFG.KD = CN({
    "TEMPERATURE": 4,
    "LOSS": {
        "CE_WEIGHT": 0.1,
        "KD_WEIGHT": 0.9,
    },
})

# AT CFG
CFG.AT = CN({
    "P": 2,
    "LOSS": {
        "CE_WEIGHT": 1.0,
        "FEAT_WEIGHT": 1000.0,
    },
})

# RKD CFG: distance/angle weights, loss weights and PDIST options.
CFG.RKD = CN({
    "DISTANCE_WEIGHT": 25,
    "ANGLE_WEIGHT": 50,
    "LOSS": {
        "CE_WEIGHT": 1.0,
        "FEAT_WEIGHT": 1.0,
    },
    "PDIST": {
        "EPSILON": 1e-12,
        "SQUARED": False,
    },
})

# FITNET CFG
CFG.FITNET = CN({
    # Original author note: "(0, 1, 2, 3, 4) fit".
    # NOTE(review): the default 5 is not in that list - confirm valid values.
    "HINT_LAYER": 5,
    "INPUT_SIZE": [32, 32],
    "LOSS": {
        "CE_WEIGHT": 1.0,
        "FEAT_WEIGHT": 10.0,
    },
    "AF": {
        "ENABLE": False,
        "CRITERIA": {
            "TYPE": 'zscore',
            "THRES": 5.5,
        },
    },
})

# FITNET MULTILAYER (FitViT) CFG
CFG.FITVIT = CN({
    "M_LAYERS": [5],
})

# KDSVD CFG
CFG.KDSVD = CN({
    "K": 1,
    "LOSS": {
        "CE_WEIGHT": 1.0,
        "FEAT_WEIGHT": 1.0,
    },
})

# OFD CFG: loss weights and the connector kernel size.
CFG.OFD = CN({
    "LOSS": {
        "CE_WEIGHT": 1.0,
        "FEAT_WEIGHT": 0.001,
    },
    "CONNECTOR": {
        "KERNEL_SIZE": 1,
    },
})

# NST CFG
CFG.NST = CN({
    "LOSS": {
        "CE_WEIGHT": 1.0,
        "FEAT_WEIGHT": 50.0,
    },
})

# PKT CFG
CFG.PKT = CN({
    "LOSS": {
        "CE_WEIGHT": 1.0,
        "FEAT_WEIGHT": 30000.0,
    },
})

# SP CFG
CFG.SP = CN({
    "LOSS": {
        "CE_WEIGHT": 1.0,
        "FEAT_WEIGHT": 3000.0,
    },
})

# VID CFG
CFG.VID = CN({
    "LOSS": {
        "CE_WEIGHT": 1.0,
        "FEAT_WEIGHT": 1.0,
    },
    "EPS": 1e-5,
    "INIT_PRED_VAR": 5.0,
    "INPUT_SIZE": [32, 32],
})

# CRD CFG: mode, feature dimensions, loss weights and NCE settings.
CFG.CRD = CN({
    "MODE": "exact",  # ("exact", "relax")
    "FEAT": {
        "DIM": 128,
        "STUDENT_DIM": 256,
        "TEACHER_DIM": 256,
    },
    "LOSS": {
        "CE_WEIGHT": 1.0,
        "FEAT_WEIGHT": 0.8,
    },
    "NCE": {
        "K": 16384,
        "MOMENTUM": 0.5,
        "TEMPERATURE": 0.07,
    },
})

# ReviewKD CFG (flat node: weights, warmup, and per-stage shape/channel lists)
CFG.REVIEWKD = CN({
    "CE_WEIGHT": 1.0,
    "REVIEWKD_WEIGHT": 1.0,
    "WARMUP_EPOCHS": 20,
    "SHAPES": [1, 8, 16, 32],
    "OUT_SHAPES": [1, 8, 16, 32],
    "IN_CHANNELS": [64, 128, 256, 256],
    "OUT_CHANNELS": [64, 128, 256, 256],
    "MAX_MID_CHANNEL": 512,
    "STU_PREACT": False,
})

# DKD (Decoupled Knowledge Distillation) CFG
CFG.DKD = CN({
    "CE_WEIGHT": 1.0,
    "ALPHA": 1.0,
    "BETA": 8.0,
    "T": 4.0,
    "WARMUP": 20,
})


# DOT CFG (stored under SOLVER rather than as a top-level section)
CFG.SOLVER.DOT = CN({
    "DELTA": 0.075,
})
