import torch
import sys
import os

from find_high_accuracy_path_v2.find_parameters import ParameterMove, ParameterTrain, ParameterRebuildNorm, ParameterGeneral
from find_high_accuracy_path_v2.runtime_parameters import RuntimeParameters

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from py_src.ml_setup import MlSetup

# Model identifier this configuration module serves; each getter below checks
# ``ml_setup.model_name`` against this value before returning parameters.
model_name: str = "mobilenet_v2"

def get_parameter_general(runtime_parameter: RuntimeParameters, ml_setup: MlSetup):
    """Return the run-wide general settings for this model.

    Args:
        runtime_parameter: Runtime state; not read by this function.
        ml_setup: ML setup whose ``model_name`` must equal ``model_name``.

    Returns:
        A ``ParameterGeneral`` with ``max_tick``, ``dataloader_worker`` and
        ``test_dataset_use_whole`` populated.

    Raises:
        NotImplementedError: If ``ml_setup.model_name`` is not ``model_name``.
    """
    if ml_setup.model_name != model_name:
        # NotImplementedError, not the ``NotImplemented`` sentinel: raising the
        # sentinel is itself an error (TypeError).
        raise NotImplementedError(f"unsupported model: {ml_setup.model_name}")
    output = ParameterGeneral()
    output.max_tick = 500 * 24
    output.dataloader_worker = 8
    output.test_dataset_use_whole = True
    return output

def get_parameter_move(runtime_parameter: RuntimeParameters, ml_setup: MlSetup):
    """Return the move settings for the phase that starts at the current tick.

    Ticks are grouped into phases of ``phase_time`` (500) ticks. A
    ``ParameterMove`` is produced only when ``current_tick`` equals
    ``phase_time * k`` for some phase ``k`` in ``0..19``; any other tick
    returns ``None``. Every phase uses ``step_size = 0``, the same adoptive and
    ratio step sizes, and ``merge_bias_with_weights = False``; only the skip
    lists change from phase to phase:

    * phase 0: ``layer_skip_move`` holds ``conv1.weight``/``bn1.weight``/
      ``bn1.bias`` and the keyword list holds ``"bottlenecks"`` and ``"fc"``;
    * phase k (1..17): same ``layer_skip_move``; the keyword list names only
      bottleneck blocks ``k..16`` plus ``"fc"``, so blocks ``0..k-1`` are no
      longer listed;
    * phase 18: ``layer_skip_move`` is empty; the keyword list is ``["fc"]``;
    * phase 19: ``layer_skip_move`` is empty and the keyword list contains
      only the ``always_ignore_list`` entries.

    The entries of ``always_ignore_list`` are appended to the keyword list in
    every phase.

    Args:
        runtime_parameter: Runtime state; only ``current_tick`` is read.
        ml_setup: ML setup whose ``model_name`` must equal ``model_name``.

    Returns:
        A populated ``ParameterMove`` on a phase boundary, otherwise ``None``.

    Raises:
        NotImplementedError: If ``ml_setup.model_name`` is not ``model_name``.
    """
    if ml_setup.model_name != model_name:
        raise NotImplementedError(f"unsupported model: {ml_setup.model_name}")

    always_ignore_list = ["running_mean", "running_var", "num_batches_tracked"]
    phase_time = 500
    adoptive_step_size = 0.002
    ratio_step_size = 0.002
    stem_layers = ["conv1.weight", "bn1.weight", "bn1.bias"]
    bottleneck_count = 17  # keywords cover bottlenecks.0 .. bottlenecks.16
    stem_release_phase = bottleneck_count + 1  # phase 18: stem leaves skip list
    last_phase = stem_release_phase + 1  # phase 19: only always-ignore remains

    # Exact equality against each phase boundary (same comparison the
    # per-phase branches used); every other tick yields no new parameters.
    tick = runtime_parameter.current_tick
    phase = next(
        (k for k in range(last_phase + 1) if tick == phase_time * k), None
    )
    if phase is None:
        return None

    if phase == 0:
        skip_keyword = ["bottlenecks", "fc"]
    elif phase < last_phase:
        # Trailing "." presumably keeps "bottlenecks.1." from also matching
        # "bottlenecks.10." etc. if the matcher uses substring tests — confirm.
        skip_keyword = [
            f"bottlenecks.{i}." for i in range(phase, bottleneck_count)
        ] + ["fc"]
    else:
        skip_keyword = []

    output = ParameterMove()
    output.step_size = 0
    output.adoptive_step_size = adoptive_step_size
    output.ratio_step_size = ratio_step_size
    output.layer_skip_move = list(stem_layers) if phase < stem_release_phase else []
    output.layer_skip_move_keyword = skip_keyword + always_ignore_list
    output.merge_bias_with_weights = False
    return output


def get_parameter_train(runtime_parameter: RuntimeParameters, ml_setup: MlSetup):
    """Return the training settings, which are only defined at tick 0.

    Args:
        runtime_parameter: Runtime state; only ``current_tick`` is read.
        ml_setup: ML setup whose ``model_name`` must equal ``model_name``.

    Returns:
        A populated ``ParameterTrain`` when ``current_tick == 0``, otherwise
        ``None``.

    Raises:
        NotImplementedError: If ``ml_setup.model_name`` is not ``model_name``.
    """
    if ml_setup.model_name != model_name:
        raise NotImplementedError(f"unsupported model: {ml_setup.model_name}")
    if runtime_parameter.current_tick != 0:
        return None
    output = ParameterTrain()
    output.train_for_max_rounds = 1000
    output.train_for_min_rounds = 100
    output.train_until_loss = 0.1
    output.pretrain_optimizer = True
    output.load_existing_optimizer = False
    return output

def get_optimizer_train(runtime_parameter: RuntimeParameters, ml_setup: MlSetup, model_parameter):
    """Return the training optimizer, which is only created at tick 0.

    Args:
        runtime_parameter: Runtime state; only ``current_tick`` is read.
        ml_setup: ML setup whose ``model_name`` must equal ``model_name``.
        model_parameter: Parameters (or parameter groups) handed directly to
            ``torch.optim.SGD``.

    Returns:
        ``torch.optim.SGD`` with ``lr=0.001`` when ``current_tick == 0``,
        otherwise ``None``.

    Raises:
        NotImplementedError: If ``ml_setup.model_name`` is not ``model_name``.
    """
    if ml_setup.model_name != model_name:
        raise NotImplementedError(f"unsupported model: {ml_setup.model_name}")
    if runtime_parameter.current_tick != 0:
        return None
    return torch.optim.SGD(model_parameter, lr=0.001)

def get_parameter_rebuild_norm(runtime_parameter: RuntimeParameters, ml_setup: MlSetup):
    """Return the norm-rebuild settings, which are only defined at tick 0.

    All round/loss limits are set to 0 and the layer keyword list is
    ``['bn']``; the explicit layer list is empty.

    Args:
        runtime_parameter: Runtime state; only ``current_tick`` is read.
        ml_setup: ML setup whose ``model_name`` must equal ``model_name``.

    Returns:
        A populated ``ParameterRebuildNorm`` when ``current_tick == 0``,
        otherwise ``None``.

    Raises:
        NotImplementedError: If ``ml_setup.model_name`` is not ``model_name``.
    """
    if ml_setup.model_name != model_name:
        raise NotImplementedError(f"unsupported model: {ml_setup.model_name}")
    if runtime_parameter.current_tick != 0:
        return None
    output = ParameterRebuildNorm()
    output.rebuild_norm_for_max_rounds = 0
    output.rebuild_norm_for_min_rounds = 0
    output.rebuild_norm_until_loss = 0
    output.rebuild_norm_layer = []
    output.rebuild_norm_layer_keyword = ['bn']
    return output

def get_optimizer_rebuild_norm(runtime_parameter: RuntimeParameters, ml_setup: MlSetup, model_parameter):
    """Return the norm-rebuild optimizer, which is only created at tick 0.

    Args:
        runtime_parameter: Runtime state; only ``current_tick`` is read.
        ml_setup: ML setup whose ``model_name`` must equal ``model_name``.
        model_parameter: Parameters (or parameter groups) handed directly to
            ``torch.optim.SGD``.

    Returns:
        ``torch.optim.SGD`` with ``lr=0.001``, ``momentum=0.9`` and
        ``weight_decay=5e-4`` when ``current_tick == 0``, otherwise ``None``.

    Raises:
        NotImplementedError: If ``ml_setup.model_name`` is not ``model_name``.
    """
    if ml_setup.model_name != model_name:
        raise NotImplementedError(f"unsupported model: {ml_setup.model_name}")
    if runtime_parameter.current_tick != 0:
        return None
    return torch.optim.SGD(model_parameter, lr=0.001, momentum=0.9, weight_decay=5e-4)