from utils import cartesian_product
from models import *

shared_params = {
    'time_dim': [1],
    'lr': [3e-4],
    'wd': [1e-7],
    'gnn_act': ['tanh'],
    'sampler_size': [5],
    'num_gnn_layers': [5, 3, 1],
}

def get_TGN_conf(num_nodes, edge_dim, node_dim, out_dim, init_time, mean_delta_t, std_delta_t):
    confs = shared_params
    confs['embedding_dim'] = [19, 9]
    

    for params in cartesian_product(confs):
        yield {
            'model_params':{
                'num_nodes': num_nodes,
                'edge_dim': edge_dim,
                'node_dim': node_dim,
                'memory_dim': params['embedding_dim'],
                'time_dim': params['time_dim'],
                'gnn_hidden_dim': [params['embedding_dim']] * params['num_gnn_layers'],
                'gnn_act': params['gnn_act'],
                'readout_hidden': max(1, params['embedding_dim'] // 2),
                'readout_out': out_dim,
                'mean_delta_t': mean_delta_t,
                'std_delta_t': std_delta_t,
                'init_time': init_time
            },
            'optim_params':{
                'lr': params['lr'],
                'wd': params['wd']
            },
            'sampler': {'size': params['sampler_size']},
        }


def get_DyRep_conf(num_nodes, edge_dim, node_dim, out_dim, init_time, mean_delta_t, std_delta_t):
    confs = {
        'embedding_dim': [53, 26],
        'lr': shared_params['lr'],
        'wd': shared_params['wd'],
        'non_linearity': shared_params['gnn_act'],
        'sampler_size': shared_params['sampler_size'],
        'conv_type': shared_params['conv_type']
    }

    for params in cartesian_product(confs):
        yield {
            'model_params':{
                'num_nodes': num_nodes,
                'edge_dim': edge_dim,
                'node_dim': node_dim,
                'memory_dim': params['embedding_dim'],
                'non_linearity': params['non_linearity'],
                'readout_hidden': max(1, params['embedding_dim'] // 2),
                'readout_out': out_dim,
                'mean_delta_t': mean_delta_t,
                'std_delta_t': std_delta_t,
                'init_time': init_time
            },
            'optim_params':{
                'lr': params['lr'],
                'wd': params['wd']
            },
            'sampler': {'size': params['sampler_size']},
        }


def get_JODIE_conf(num_nodes, edge_dim, node_dim, out_dim, init_time, mean_delta_t, std_delta_t):
    confs = {
        'embedding_dim': [69, 34],
        'time_dim':shared_params['time_dim'],
        'lr': shared_params['lr'],
        'wd': shared_params['wd'],
        'non_linearity': shared_params['gnn_act'],
        'sampler_size': shared_params['sampler_size']
    }

    for params in cartesian_product(confs):
        yield {
            'model_params':{
                'num_nodes': num_nodes,
                'edge_dim': edge_dim,
                'node_dim': node_dim,
                'memory_dim': params['embedding_dim'],
                'time_dim': params['time_dim'],
                'non_linearity': params['non_linearity'],
                'readout_hidden': max(1, params['embedding_dim'] // 2),
                'readout_out': out_dim,
                'mean_delta_t': mean_delta_t,
                'std_delta_t': std_delta_t,
                'init_time': init_time
            },
            'optim_params':{
                'lr': params['lr'],
                'wd': params['wd']
            },
            'sampler': {'size': params['sampler_size']},
        }


def get_TGAT_conf(num_nodes, edge_dim, node_dim, out_dim, init_time, mean_delta_t, std_delta_t):
    confs = shared_params
    confs['embedding_dim'] = [24, 12] 

    for params in cartesian_product(confs):
        yield {
            'model_params':{
                'num_nodes': num_nodes,
                'edge_dim': edge_dim,
                'node_dim': node_dim,
                'time_dim': params['time_dim'],
                'gnn_hidden_dim': [params['embedding_dim']] * params['num_gnn_layers'],
                'gnn_act': params['gnn_act'],
                'readout_hidden': max(1, params['embedding_dim'] // 2),
                'readout_out': out_dim,
                'mean_delta_t': mean_delta_t,
                'std_delta_t': std_delta_t,
                'init_time': init_time
            },
            'optim_params':{
                'lr': params['lr'],
                'wd': params['wd']
            },
            'sampler': {'size': params['sampler_size']},
        }


def get_CTAN_conf(num_nodes, edge_dim, node_dim, out_dim, init_time, mean_delta_t, std_delta_t):
    confs = {
        'embedding_dim': [53, 26],
        'time_dim': shared_params['time_dim'],
        'lr': shared_params['lr'],
        'wd': shared_params['wd'],
        'gnn_act': shared_params['gnn_act'],
        'num_iters': shared_params['num_gnn_layers'],
        'epsilon': [0.5, 0.1, 0.01, 0.001],
        'gamma': [0.5, 0.1, 0.01, 0.001],
        'sampler_size': shared_params['sampler_size'],
        'final_act': [None, 'tanh'],
        'conv_type': shared_params['conv_type']
    }

    for params in cartesian_product(confs):
        yield {
            'model_params':{
                'num_nodes': num_nodes,
                'edge_dim': edge_dim,
                'node_dim': node_dim,
                'time_dim': params['time_dim'],
                'memory_dim': params['embedding_dim'],
                'gnn_act': params['gnn_act'],
                'num_iters': params['num_iters'],
                'epsilon': params['epsilon'],
                'gamma': params['gamma'],
                'readout_hidden': max(1, params['embedding_dim'] // 2),
                'readout_out': out_dim,
                'mean_delta_t': mean_delta_t,
                'std_delta_t': std_delta_t,
                'init_time': init_time,
                'final_act': params['final_act']
            },
            'optim_params':{
                'lr': params['lr'],
                'wd': params['wd']
            },
            'sampler': {'size': params['sampler_size']},
        }


_tgn_fun = lambda num_nodes, edge_dim, node_dim, out_dim, init_time, mean_delta_t, std_delta_t: get_TGN_conf(num_nodes, edge_dim, node_dim, out_dim, init_time, mean_delta_t, std_delta_t)
_dyrep_fun = lambda num_nodes, edge_dim, node_dim, out_dim, init_time, mean_delta_t, std_delta_t: get_DyRep_conf(num_nodes, edge_dim, node_dim, out_dim, init_time, mean_delta_t, std_delta_t)
_jodie_fun = lambda num_nodes, edge_dim, node_dim, out_dim, init_time, mean_delta_t, std_delta_t: get_JODIE_conf(num_nodes, edge_dim, node_dim, out_dim, init_time, mean_delta_t, std_delta_t)
_tgat_fun = lambda num_nodes, edge_dim, node_dim, out_dim, init_time, mean_delta_t, std_delta_t: get_TGAT_conf(num_nodes, edge_dim, node_dim, out_dim, init_time, mean_delta_t, std_delta_t)
_ctan_fun = lambda num_nodes, edge_dim, node_dim, out_dim, init_time, mean_delta_t, std_delta_t: get_CTAN_conf(num_nodes, edge_dim, node_dim, out_dim, init_time, mean_delta_t, std_delta_t)

tgn, dyrep, jodie, tgat = 'TGN', 'DyRep', 'JODIE', 'TGAT'
adgn = 'CTAN'

MODEL_CONFS = {
    tgn: (TGN, _tgn_fun),
    dyrep: (DyRep, _dyrep_fun),
    jodie: (JODIE, _jodie_fun),
    tgat: (TGAT, _tgat_fun),
    adgn: (CTAN, _ctan_fun)
}
