
import importlib
import json
import os 
import subprocess
from helpers import load_defence_params_json
import defended_model as dfnd


def load_codec_cfg(cfg_path):
    """
    Read a codec JSON config and split it into the codec module name and extra codec kwargs.

    The ``'module'`` entry names the module holding the codec class; every other
    top-level entry is passed through unchanged as an additional codec parameter.

    Args:
        cfg_path (str): path to json config

    Returns:
        tuple(str, dict): filename of module with codec class, dictionary with params for the codec.

    Raises:
        KeyError: if the config has no ``'module'`` entry.
    """
    with open(cfg_path) as cfg_file:
        raw_cfg = json.load(cfg_file)
    module_name = raw_cfg['module']
    extra_params = {key: value for key, value in raw_cfg.items() if key != 'module'}
    return module_name, extra_params


def setup_codec(run_cfg):
    """Sets up the codec model based on the provided run configuration.

    Reads the codec config (``config.json`` in the working directory, falling
    back to ``src/config.json`` only if the former does not exist), dynamically
    imports the module it names, and instantiates the codec model. For JPEG AI
    codecs (``'jpegai'`` appears in ``run_cfg["codec"]``) a second instance is
    created with ``is_main=True`` as the 'main_codec'. The codec's declared
    input/output conventions are also copied into ``run_cfg``.

    Args:
        run_cfg (dict): The run configuration dictionary. Must contain
            ``'codec'`` and ``'device'``.

    Returns:
        dict: The same (mutated) run configuration dictionary with
        ``'is_jpegai'``, ``'undef_model'``, ``'undef_main_codec'`` (``None``
        for non-JPEG AI codecs), ``'input_range'``, ``'output_range'`` and
        ``'output_cspace'`` set.

    Raises:
        FileNotFoundError: if neither config file exists.
        KeyError / json.JSONDecodeError: if the config that was found is invalid.
    """
    is_jpegai = 'jpegai' in run_cfg["codec"]
    run_cfg['is_jpegai'] = is_jpegai

    # Only fall back when config.json is absent; a present-but-broken config
    # must surface its error instead of silently loading a different file.
    try:
        module_name, codec_additional_params = load_codec_cfg('config.json')
    except FileNotFoundError:
        module_name, codec_additional_params = load_codec_cfg('src/config.json')

    # JPEG AI modules are imported by bare name; other codecs live under the src package.
    module_path = module_name if is_jpegai else f'src.{module_name}'
    codec_module = importlib.import_module(module_path)

    model = codec_module.CodecModel(run_cfg["device"], **codec_additional_params)
    main_codec = None
    if is_jpegai:
        main_codec = codec_module.CodecModel(run_cfg["device"], is_main=True, **codec_additional_params)
    run_cfg['undef_model'] = model
    run_cfg['undef_main_codec'] = main_codec

    # Codecs may declare their own I/O conventions; otherwise assume [0, 1] ranges
    # and RGB output. (Original notes: JPEG AI uses 255 input range and YCbCr output.)
    input_range = getattr(model, 'input_range', 1)
    output_range = getattr(model, 'output_range', 1)
    output_cspace = getattr(model, 'output_cspace', 'rgb')
    print(f'INPUT RANGE: {input_range}, OUTPUT RANGE: {output_range}, OUTPUT COLOR SPACE: {output_cspace}')
    run_cfg['input_range'] = input_range
    run_cfg['output_range'] = output_range
    run_cfg['output_cspace'] = output_cspace

    return run_cfg

def setup_defence(run_cfg):
    """Sets up and applies a defence mechanism to the codec models.

    Runs ``defence/setup.sh`` if present, imports ``defence.defence``, loads
    defence parameters via ``load_defence_params_json`` when
    ``defence/defence_presets.json`` exists (otherwise ``Defense`` is built with
    no arguments), and wraps the undefended codec model(s) with
    ``defended_model.CodecModel``. Each wrapped codec gets its own ``Defense``
    instance, and every wrapped model is switched to eval mode.

    Args:
        run_cfg (dict): The run configuration dictionary, which must already
                        contain ``'undef_model'``, ``'undef_main_codec'``,
                        ``'is_jpegai'`` and ``'device'`` (e.g. from
                        ``setup_codec``), plus ``'defence_preset'`` when a
                        presets file exists.

    Returns:
        dict: The updated run configuration dictionary with 'def_model',
              'def_main_codec' (``None`` unless JPEG AI), and 'defence_name' added.

    Raises:
        subprocess.CalledProcessError: if ``defence/setup.sh`` exits non-zero.
    """
    if "setup.sh" in os.listdir('defence'):
        # argv list instead of a shell string: same command, no shell parsing.
        subprocess.run(['bash', 'defence/setup.sh'], check=True)
    dfnce = importlib.import_module('defence.defence')

    if os.path.exists('defence/defence_presets.json'):
        defence_args = load_defence_params_json(run_cfg["defence_preset"])
        print(f'Defence args:{defence_args}')
    else:
        # Defense(**{}) is identical to Defense(), so one construction path serves both cases.
        defence_args = {}
        print('No defence args')

    defended_model = dfnd.CodecModel(run_cfg['undef_model'], dfnce.Defense(**defence_args), run_cfg["device"])
    defended_main_codec = None
    if run_cfg['is_jpegai']:
        defended_main_codec = dfnd.CodecModel(run_cfg['undef_main_codec'], dfnce.Defense(**defence_args), run_cfg["device"])
    defence_name = defended_model.defence.defence_name

    # Both wrappers are the same class; put both in eval mode so the main codec
    # does not run with different (training-mode) behaviour than the model.
    defended_model.eval()
    if defended_main_codec is not None:
        defended_main_codec.eval()

    run_cfg['def_model'] = defended_model
    run_cfg['def_main_codec'] = defended_main_codec
    run_cfg['defence_name'] = defence_name
    return run_cfg


def setup_files(run_cfg):
    """Performs initial file and environment setup for the run.

    Sets the ``LOSS_NAME`` environment variable, writes
    ``LOSS_NAME="<loss_name>"`` to ``loss_f.txt`` in the current working
    directory (overwriting it), and ensures the dump directory exists.

    Args:
        run_cfg (dict): The run configuration dictionary containing keys
                        'loss_name' and 'dump_path'.
    """
    loss_name = str(run_cfg["loss_name"])
    os.environ['LOSS_NAME'] = loss_name
    with open('loss_f.txt', 'w', encoding='utf-8') as loss_file:
        loss_file.write(f'LOSS_NAME="{loss_name}"')
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(run_cfg["dump_path"], exist_ok=True)


def setup_attack_presets(run_cfg, num_presets=3):
    """Build the list of attack preset indices to run.

    Precedence: 'only_default_preset' (returns ``[-1]``, the default-params
    run) overrides 'run_all_presets' (returns ``[0, ..., num_presets - 1]``),
    which overrides 'attack_preset' (returns ``[attack_preset]``).

    Args:
        run_cfg (dict): The run configuration dictionary containing keys
            'attack_preset', 'only_default_preset' and 'run_all_presets'.
        num_presets (int): Number of presets run when 'run_all_presets' is set.
            Defaults to 3, i.e. presets 0, 1 and 2.

    Returns:
        list: a list of preset indices to run.
    """
    if run_cfg["only_default_preset"]:
        print("[Warning] Ignoring attack preset and run-all-presets arguments, running attack only with default params.")
        return [-1]
    if run_cfg["run_all_presets"]:
        # Message reports the inclusive range actually run (previously claimed "0 to 3"
        # while only 0..2 were run).
        print(f"[Warning] Ignoring attack preset argument, running all presets from 0 to {num_presets - 1}.")
        return list(range(num_presets))
    return [run_cfg["attack_preset"]]