import torch
import sys
import os

from find_high_accuracy_path_v2.find_parameters import ParameterMove, ParameterTrain, ParameterRebuildNorm, ParameterGeneral
from find_high_accuracy_path_v2.runtime_parameters import RuntimeParameters

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from py_src.ml_setup import MlSetup

# Name of the only model this configuration module supports. Every getter in
# this file compares ``ml_setup.model_name`` against it before returning values.
model_name = 'densenet_cifar'

def get_parameter_general(runtime_parameter: RuntimeParameters, ml_setup: MlSetup):
    """Return the run-wide settings for the supported model.

    Args:
        runtime_parameter: Current runtime state; not read by this function.
        ml_setup: Setup object whose ``model_name`` selects the configuration.

    Returns:
        A ``ParameterGeneral`` with ``max_tick``, ``dataloader_worker`` and
        ``test_dataset_use_whole`` filled in.

    Raises:
        NotImplementedError: If ``ml_setup.model_name`` differs from the
            module-level ``model_name``.
    """
    if ml_setup.model_name != model_name:
        # ``NotImplemented`` is a comparison sentinel, not an exception class;
        # raising it produces a confusing TypeError instead of the intended error.
        raise NotImplementedError(
            "no general parameters defined for model {!r}".format(ml_setup.model_name)
        )

    output = ParameterGeneral()
    # 25 blocks of 200 ticks; 200 is also the phase length used by
    # get_parameter_move, whose last phase starts at tick 4000.
    output.max_tick = 200 * 25
    output.dataloader_worker = 8
    output.test_dataset_use_whole = True
    return output

def _dense_layer_keywords(block, first, last):
    """Return ``['dense<block>.<first>', ..., 'dense<block>.<last>']`` (inclusive)."""
    return ['dense{}.{}'.format(block, index) for index in range(first, last + 1)]


def get_parameter_move(runtime_parameter: RuntimeParameters, ml_setup: MlSetup):
    """Return the move settings that start at the current tick, if any.

    The schedule has 21 phases, one starting every 200 ticks from tick 0 to
    tick 4000. Each phase uses the same step sizes and differs only in which
    layers are listed in ``layer_skip_move`` / ``layer_skip_move_keyword``
    (presumably the layers excluded from moving -- confirm against the code
    that consumes ``ParameterMove``). The skip lists shrink phase by phase in
    network order: conv1, dense1, trans1, dense2, trans2, dense3, trans3,
    dense4, and finally bn + linear.

    Args:
        runtime_parameter: Runtime state; only ``current_tick`` is read.
        ml_setup: Setup object whose ``model_name`` selects the configuration.

    Returns:
        A populated ``ParameterMove`` when ``current_tick`` is exactly the
        first tick of a phase, otherwise ``None``.

    Raises:
        NotImplementedError: If ``ml_setup.model_name`` differs from the
            module-level ``model_name``.
    """
    if ml_setup.model_name != model_name:
        # ``NotImplemented`` is a comparison sentinel, not an exception class;
        # raising it produces a confusing TypeError instead of the intended error.
        raise NotImplementedError(
            "no move parameters defined for model {!r}".format(ml_setup.model_name)
        )

    # Appended to the keyword list of every phase.
    test_weights_keyword = ['running_mean', 'running_var', 'num_batches_tracked']
    phase_time = 200
    adoptive_step_size = 0.002
    ratio_step_size = 0.004

    skip_bn = ['bn.weight', 'bn.bias']
    after_dense1 = ['dense2', 'dense3', 'dense4', 'trans', 'linear']
    after_dense2 = ['dense3', 'dense4', 'trans2', 'trans3', 'linear']
    after_dense3 = ['dense4', 'trans3', 'linear']

    # One entry per phase, in order: (layer_skip_move, layer_skip_move_keyword
    # before test_weights_keyword is appended). The trailing comment names the
    # layers the phase newly stops skipping.
    phases = [
        (skip_bn, ['dense', 'trans', 'linear']),                       # conv1
        (skip_bn, _dense_layer_keywords(1, 3, 5) + after_dense1),      # dense1.0-dense1.2
        (skip_bn, after_dense1),                                       # dense1.3-dense1.5
        (skip_bn, ['dense2'] + after_dense2),                          # trans1
        (skip_bn, _dense_layer_keywords(2, 3, 11) + after_dense2),     # dense2.0-dense2.2
        (skip_bn, _dense_layer_keywords(2, 6, 11) + after_dense2),     # dense2.3-dense2.5
        (skip_bn, _dense_layer_keywords(2, 9, 11) + after_dense2),     # dense2.6-dense2.8
        (skip_bn, after_dense2),                                       # dense2.9-dense2.11
        (skip_bn, ['dense3'] + after_dense3),                          # trans2
        (skip_bn, _dense_layer_keywords(3, 4, 23) + after_dense3),     # dense3.0-dense3.3
        (skip_bn, _dense_layer_keywords(3, 8, 23) + after_dense3),     # dense3.4-dense3.7
        (skip_bn, _dense_layer_keywords(3, 12, 23) + after_dense3),    # dense3.8-dense3.11
        (skip_bn, _dense_layer_keywords(3, 16, 23) + after_dense3),    # dense3.12-dense3.15
        (skip_bn, _dense_layer_keywords(3, 20, 23) + after_dense3),    # dense3.16-dense3.19
        (skip_bn, after_dense3),                                       # dense3.20-dense3.23
        (skip_bn, ['dense4', 'linear']),                               # trans3
        (skip_bn, _dense_layer_keywords(4, 4, 15) + ['linear']),       # dense4.0-dense4.3
        # NOTE(review): this is the only intermediate phase with an empty
        # layer_skip_move; every neighbouring phase skips 'bn.weight'/'bn.bias'.
        # It looks like a copy-paste slip but is kept as-is to preserve the
        # existing experiment behaviour -- confirm whether it is intentional.
        ([], _dense_layer_keywords(4, 8, 15) + ['linear']),            # dense4.4-dense4.7
        (skip_bn, _dense_layer_keywords(4, 12, 15) + ['linear']),      # dense4.8-dense4.11
        (skip_bn, ['linear']),                                         # dense4.12-dense4.15
        ([], []),                                                      # bn, linear
    ]

    # Key the phases by their starting tick so that only an exact match on
    # current_tick yields parameters, as in the original if/elif chain.
    schedule = {phase_time * index: phase for index, phase in enumerate(phases)}
    phase = schedule.get(runtime_parameter.current_tick)
    if phase is None:
        return None

    skip_names, skip_keywords = phase
    output = ParameterMove()
    output.step_size = 0
    output.adoptive_step_size = adoptive_step_size
    output.ratio_step_size = ratio_step_size
    # Copy so the returned object never shares a list with the table above.
    output.layer_skip_move = list(skip_names)
    output.layer_skip_move_keyword = skip_keywords + test_weights_keyword
    output.merge_bias_with_weights = False
    return output


def get_parameter_train(runtime_parameter: RuntimeParameters, ml_setup: MlSetup):
    """Return the training settings that start at the current tick, if any.

    Args:
        runtime_parameter: Runtime state; only ``current_tick`` is read.
        ml_setup: Setup object whose ``model_name`` selects the configuration.

    Returns:
        A populated ``ParameterTrain`` at tick 0, otherwise ``None``.

    Raises:
        NotImplementedError: If ``ml_setup.model_name`` differs from the
            module-level ``model_name``.
    """
    if ml_setup.model_name != model_name:
        # ``NotImplemented`` is a comparison sentinel, not an exception class;
        # raising it produces a confusing TypeError instead of the intended error.
        raise NotImplementedError(
            "no train parameters defined for model {!r}".format(ml_setup.model_name)
        )

    # Training settings are only defined once, at the very first tick.
    if runtime_parameter.current_tick != 0:
        return None

    output = ParameterTrain()
    output.train_for_max_rounds = 1000
    output.train_for_min_rounds = 100
    output.train_until_loss = 0.005
    output.pretrain_optimizer = False
    output.load_existing_optimizer = False
    return output

def get_optimizer_train(runtime_parameter: RuntimeParameters, ml_setup: MlSetup, model_parameter):
    """Build the SGD optimizer used for training, if one starts at this tick.

    Args:
        runtime_parameter: Runtime state; only ``current_tick`` is read.
        ml_setup: Setup object whose ``model_name`` selects the configuration.
        model_parameter: Iterable of parameters to optimise; it is iterated
            once to build the parameter groups.

    Returns:
        A ``torch.optim.SGD`` instance at tick 0, otherwise ``None``.

    Raises:
        NotImplementedError: If ``ml_setup.model_name`` differs from the
            module-level ``model_name``.
    """
    if ml_setup.model_name != model_name:
        # ``NotImplemented`` is a comparison sentinel, not an exception class;
        # raising it produces a confusing TypeError instead of the intended error.
        raise NotImplementedError(
            "no train optimizer defined for model {!r}".format(ml_setup.model_name)
        )

    if runtime_parameter.current_tick != 0:
        return None

    base_lr = 0.001
    # Each parameter gets its own param group (instead of one shared group),
    # all starting at the same learning rate.
    param_groups = [{'params': param, 'lr': base_lr} for param in model_parameter]
    return torch.optim.SGD(param_groups)

def get_parameter_rebuild_norm(runtime_parameter: RuntimeParameters, ml_setup: MlSetup):
    """Return the rebuild-norm settings that start at the current tick, if any.

    At tick 0 all round counts and the loss target are 0 and both layer lists
    are empty.

    Args:
        runtime_parameter: Runtime state; only ``current_tick`` is read.
        ml_setup: Setup object whose ``model_name`` selects the configuration.

    Returns:
        A populated ``ParameterRebuildNorm`` at tick 0, otherwise ``None``.

    Raises:
        NotImplementedError: If ``ml_setup.model_name`` differs from the
            module-level ``model_name``.
    """
    if ml_setup.model_name != model_name:
        # ``NotImplemented`` is a comparison sentinel, not an exception class;
        # raising it produces a confusing TypeError instead of the intended error.
        raise NotImplementedError(
            "no rebuild-norm parameters defined for model {!r}".format(ml_setup.model_name)
        )

    if runtime_parameter.current_tick != 0:
        return None

    output = ParameterRebuildNorm()
    output.rebuild_norm_for_max_rounds = 0
    output.rebuild_norm_for_min_rounds = 0
    output.rebuild_norm_until_loss = 0
    output.rebuild_norm_layer = []
    output.rebuild_norm_layer_keyword = []
    return output

def get_optimizer_rebuild_norm(runtime_parameter: RuntimeParameters, ml_setup: MlSetup, model_parameter):
    """Build the SGD optimizer used for rebuilding norms, if one starts at this tick.

    Args:
        runtime_parameter: Runtime state; only ``current_tick`` is read.
        ml_setup: Setup object whose ``model_name`` selects the configuration.
        model_parameter: Iterable of parameters to optimise; it is iterated
            once to build the parameter groups.

    Returns:
        A ``torch.optim.SGD`` instance (momentum 0.9, weight decay 5e-4) at
        tick 0, otherwise ``None``.

    Raises:
        NotImplementedError: If ``ml_setup.model_name`` differs from the
            module-level ``model_name``.
    """
    if ml_setup.model_name != model_name:
        # ``NotImplemented`` is a comparison sentinel, not an exception class;
        # raising it produces a confusing TypeError instead of the intended error.
        raise NotImplementedError(
            "no rebuild-norm optimizer defined for model {!r}".format(ml_setup.model_name)
        )

    if runtime_parameter.current_tick != 0:
        return None

    base_lr = 0.001
    # Each parameter gets its own param group (instead of one shared group),
    # all starting at the same learning rate; momentum and weight decay are
    # optimizer-wide defaults.
    param_groups = [{'params': param, 'lr': base_lr} for param in model_parameter]
    return torch.optim.SGD(param_groups, momentum=0.9, weight_decay=5e-4)

