from email.policy import default
from test_tube import HyperOptArgumentParser
from test_tube.hpc import SlurmCluster
import json
import hashlib

from train_lightning import train_and_eval

# Options for which a command-line value takes precedence over the JSON config
# file, mapped to the "neutral" value that means "not supplied on the command
# line".  Each neutral value equals the corresponding parser default.
CLI_NEUTRAL_VALUES = {
    'batch_size': 0,
    'test_graph_size': 0,
    'load_model': 'none',
    'inference_mode': False,
    'start_selection': 'none',
    'conv': 'none',
    'model': 'none',
    'num_rounds': 0,
    'start_mode': 'none',
    'run_number': 0,
    'use_wandb': False,
    'experiment_run': 'debug',
}


def str2bool(value) -> bool:
    """Convert a command-line token into a real boolean.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because any non-empty string is truthy.  This converter accepts the usual
    spellings (case-insensitive, surrounding whitespace ignored).

    Args:
        value: A ``bool`` (returned unchanged) or a string such as
            ``"true"``/``"false"``, ``"yes"``/``"no"``, ``"1"``/``"0"``.

    Returns:
        The parsed boolean.

    Raises:
        ValueError: If the token is not a recognised boolean spelling.
            argparse turns this into a regular usage error.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', 'off'):
        return False
    raise ValueError('expected a boolean value, got {!r}'.format(value))


def merge_config(cli_values: dict, conf: dict, neutral_values: dict) -> dict:
    """Combine config-file settings with command-line values that must win.

    Every key of ``conf`` is taken over, except that an option listed in
    ``neutral_values`` keeps its command-line value whenever that value differs
    from its neutral value.  A command-line value equal to the neutral value is
    indistinguishable from "not given", so the config file wins in that case.

    Args:
        cli_values: Mapping of option name to the value parsed from the command
            line; must contain every key of ``neutral_values``.
        conf: Settings loaded from the JSON config file.
        neutral_values: Mapping of option name to its "not supplied" value.

    Returns:
        A new dict of attribute updates: the keys of ``conf`` in their original
        order, followed by any pinned command-line options absent from
        ``conf``.  The inputs are not modified.
    """
    pinned = {
        name: cli_values[name]
        for name, neutral in neutral_values.items()
        if cli_values[name] != neutral
    }
    return {**conf, **pinned}


def build_parser():
    """Create the argument parser describing all run and search options."""
    parser = HyperOptArgumentParser(strategy='grid_search')
    # run specific
    parser.add_argument('--device', type=int, default=0)
    parser.opt_list('--run_number', type=int, default=0, tunable=True, options=[1])

    # MLP
    parser.opt_list('--hidden_dimension', type=int, default=8, tunable=True, options=[32])
    parser.opt_list('--hidden_state_factor', type=float, default=2, tunable=True, options=[2])
    # NOTE(review): parsed as float, so `--mlp_depth 2` yields 2.0 -- confirm
    # that the consumer accepts a float depth.
    parser.opt_list('--mlp_depth', type=float, default=2, tunable=True, options=[2])
    parser.opt_list('--dropout', type=float, default=0.1, tunable=True, options=[0.0])
    parser.opt_list('--normalization', type=str, default='None', tunable=True, options=['LayerNorm'])
    parser.opt_list('--activation', type=str, default='ReLU', tunable=True, options=['ReLU'])

    # GNN architecture
    parser.opt_list('--conv', type=str, default='none', tunable=True, options=['gru'])
    parser.opt_list('--aggregation', type=str, default='add', tunable=True, options=['add'])
    parser.opt_list('--skip_previous', type=str2bool, default=False, tunable=True, options=[False])
    parser.opt_list('--skip_input', type=str2bool, default=False, tunable=True, options=[False])
    parser.opt_list('--in_channels', type=int, default=2, tunable=True, options=[1])
    parser.opt_list('--out_dim', type=int, default=2, tunable=True, options=[2])

    # Flood / Echo params
    parser.opt_list('--num_rounds', type=int, default=0, tunable=True, options=[2])
    parser.opt_list('--num_k', type=int, default=1, tunable=True, options=[2])

    # optimizer / training
    parser.opt_list('--lr', type=float, default=0.0004, tunable=True, options=[0.0002])
    parser.add_argument('--epochs', type=int, default=200)
    parser.add_argument('--randomize_between_epochs', type=str2bool, default=True)
    parser.opt_list('--batch_size', type=int, default=0, tunable=True, options=[32])
    # float rather than int so fractional decay such as 1e-4 can be passed
    parser.opt_list('--weight_decay', type=float, default=0, tunable=True, options=[0])
    parser.opt_list('--scheduler', type=str, default='None', tunable=True, options=['Plateau'])

    parser.opt_list('--dataset', type=str, default='rome', tunable=True, options=['rome'])

    parser.add_argument('--store_models', type=str2bool, default=True)
    parser.add_argument('--model_name', type=str, default=None)
    parser.add_argument('--use_tensorboard', type=str2bool, default=False)
    parser.add_argument('--use_wandb', type=str2bool, default=False)
    # NOTE(review): store_true combined with default=True means args.slurm is
    # always True; kept unchanged so existing behaviour is preserved.
    parser.add_argument('--slurm', action='store_true', default=True)
    parser.add_argument('--verbose', action='store_true', default=False)

    parser.add_argument('--config', type=str, default=None)
    parser.add_argument('--wandb_project_name', type=str, default="refactored")

    parser.opt_list('--experiment_run', type=str, default='debug', tunable=True)
    parser.opt_list('--start_mode', type=str, default='none', tunable=True)
    parser.opt_list('--start_selection', type=str, default='none', tunable=True)
    parser.opt_list('--model', type=str, default='none', tunable=True)

    parser.opt_list('--test_graph_size', type=int, default=0, tunable=True, options=[32])
    parser.opt_list('--load_model', type=str, default='none', tunable=True)
    parser.add_argument('--inference_mode', type=str2bool, default=False)
    return parser


def main() -> None:
    """Parse arguments, overlay the optional JSON config, and start training."""
    args = build_parser().parse_args()

    if args.config is not None:
        with open(args.config, 'r', encoding='utf-8') as f:
            conf = json.load(f)
        # Config values replace the parsed ones, except for the options in
        # CLI_NEUTRAL_VALUES that were explicitly changed on the command line.
        for key, value in merge_config(vars(args), conf, CLI_NEUTRAL_VALUES).items():
            setattr(args, key, value)

    if args.model_name is None:
        # No explicit name: use the SHA-256 of the JSON dump of the final
        # settings as the model name.
        hash_object = hashlib.sha256(json.dumps(args.__getstate__()).encode())
        args.model_name = hash_object.hexdigest()

    train_and_eval(args)


if __name__ == "__main__":
    main()
