import sys
import subprocess as sp
from itertools import product, chain
import numpy as np
from typing import List, Dict

def dict_product(dicts):
    """Lazily expand a mapping of option lists into one dict per combination.

    ``dicts`` maps each option name to an iterable of candidate values. The
    returned generator yields, for every element of the Cartesian product of
    those iterables, a dict mapping each option name (in ``dicts`` order) to
    the chosen value.
    """
    combos = product(*dicts.values())
    return ({name: value for name, value in zip(dicts, combo)} for combo in combos)

# Default Arguments
# -------------------------------------------------------------------------------
# Module-level defaults shared by every run. Each default below is rendered as
# a "--<name> <value>" flag in `base_cmd` at the end of this section.

# wandb logging
use_wandb = 0
wandb_project_name = 'name'
wandb_entity = 'entity'
wandb_job_type = 'job_type'

seed = 7777

# data
dataset_config_path = 'configs/synthetic.yaml'
sample_modes = 'sine,tanh,gaussian,relu'

# schedule and batch sizes
num_epochs = 16000
batch_size = 64
validation_batch_size = 20
num_vis = 8

# validation / logging intervals
validate_every = 2000
log_interval = 200
log_image_interval = 1000
validation_repeat = 10

# INR model
dim_hidden = 128
num_layers = 5
ff_dim = dim_hidden
inr_type = 'composer'
sigma = 0.0

# optimisation
meta_sgd_lr_max = 5
outer_lr = 1e-4

inr_ckpt_idx = 0
model_type = 'multimodal'

use_meta_test_set = True
grad_encoder_mm_attn_type = 'spatial'
context_encoder_pos_embed_type = 'concat'

loss_weight_mode = 'uncertainty'
logvar_lr_ratio = 1

# encoder settings
meta_target = 'query'
grad_encoder_type = 'transformer'
grad_encoder_pos_embed_type = 'learned'
context_pooler_pos_embed_type = 'learned'
latent_spatial_shapes = '8'

Ra, Rb, Ma, Mb = 100000, 1, 1, 1
Rmin, Rmax = 0.1, 1

# Passed on the command line as a single token: the values joined with '-'
# (here '0.01-0.02-0.05-0.1').
Rrange_lists = '-'.join(map(str, [0.01, 0.02, 0.05, 0.1]))

run_script = 'main.py'

# Command prefix shared by every run: the interpreter call followed by one
# "--flag value" pair per default, in this fixed order.
base_cmd = ' '.join(
    ['python {}'.format(run_script)]
    + ['--{} {}'.format(flag, value) for flag, value in (
        ('use_wandb', use_wandb),
        ('wandb_project_name', wandb_project_name),
        ('wandb_entity', wandb_entity),
        ('wandb_job_type', wandb_job_type),
        ('dataset_config_path', dataset_config_path),
        ('model_type', model_type),
        ('dim_hidden', dim_hidden),
        ('num_layers', num_layers),
        ('batch_size', batch_size),
        ('validation_batch_size', validation_batch_size),
        ('validate_every', validate_every),
        ('log_interval', log_interval),
        ('log_image_interval', log_image_interval),
        ('num_vis', num_vis),
        ('ff_dim', ff_dim),
        ('num_epochs', num_epochs),
        ('validation_repeat', validation_repeat),
        ('sample_modes', sample_modes),
        ('meta_sgd_lr_max', meta_sgd_lr_max),
        ('outer_lr', outer_lr),
        ('inr_ckpt_idx', inr_ckpt_idx),
        ('grad_encoder_mm_attn_type', grad_encoder_mm_attn_type),
        ('context_encoder_pos_embed_type', context_encoder_pos_embed_type),
        ('context_pooler_pos_embed_type', context_pooler_pos_embed_type),
        ('grad_encoder_pos_embed_type', grad_encoder_pos_embed_type),
        ('meta_target', meta_target),
        ('grad_encoder_type', grad_encoder_type),
        ('Rrange_lists', Rrange_lists),
        ('use_meta_test_set', use_meta_test_set),
        ('seed', seed),
        ('loss_weight_mode', loss_weight_mode),
        ('logvar_lr_ratio', logvar_lr_ratio),
        ('inr_type', inr_type),
        ('sigma', sigma),
        ('latent_spatial_shapes', latent_spatial_shapes),
        ('Rmin', Rmin),
        ('Rmax', Rmax),
        ('Ra', Ra),
        ('Rb', Rb),
        ('Ma', Ma),
        ('Mb', Mb),
    )]
)

def _single_run_grid(wandb_tag, inner_steps=3, use_meta_sgd='False',
                     grad_encoder_depth=(0, 0), grad_encoder_use_alfa='False',
                     grad_encoder_use_fuser='False',
                     context_encoder_depth=(0, 0)):
    """Return one search grid in which every option holds exactly one value.

    Options common to all ablations below (modes, encoder capacity and the
    meta-SGD learning-rate settings) are fixed here; the arguments carry only
    what differs between ablations. Boolean-like options are the strings
    'True'/'False' because they are formatted verbatim into the command line.
    Each ``*_depth`` is an ``(um_depth, mm_depth)`` pair.
    """
    return dict(
        wandb_tags=[wandb_tag],
        modes=[sample_modes],
        encoder_capa=['tiny'],
        inner_steps=[inner_steps],
        use_meta_sgd=[use_meta_sgd],
        meta_sgd_lr_init=[1.00],
        meta_sgd_lr_max=[5.0],
        # GE
        grad_encoder_depth=[list(grad_encoder_depth)],
        grad_encoder_use_alfa=[grad_encoder_use_alfa],
        grad_encoder_use_fuser=[grad_encoder_use_fuser],
        # CE
        context_encoder_depth=[list(context_encoder_depth)],
    )


# One grid per ablation. generate_cmd rejects any configuration in which both
# the gradient encoder and the context encoder have non-zero depth.
search_dicts = [
    _single_run_grid('Composer-CAVIA'),
    _single_run_grid('Composer-MetaSGD', use_meta_sgd='True'),
    _single_run_grid('Composer-ALFA', use_meta_sgd='True',
                     grad_encoder_use_alfa='True'),
    _single_run_grid('Composer-Encoder', inner_steps=0,
                     context_encoder_depth=(1, 1)),
    _single_run_grid('Composer-SFT', grad_encoder_depth=(1, 1),
                     grad_encoder_use_fuser='True'),
]


# Flatten every grid into concrete per-run parameter dicts, grid by grid.
params_dict_list = [
    run_params
    for grid in search_dicts
    for run_params in dict_product(grid)
]

def generate_cmd(params_dict):
    """Build the launch command and experiment name for one sweep point.

    Each key of ``params_dict`` is first written into this module's globals,
    overriding any module-level default of the same name. The bare names read
    below resolve to those globals, and the run-specific flags are appended
    to ``base_cmd``.

    Args:
        params_dict: One configuration, e.g. as produced by ``dict_product``.
            The names read below (``wandb_tags``, ``modes``, ``encoder_capa``,
            ``inner_steps``, ``use_meta_sgd``, ``meta_sgd_lr_init``,
            ``meta_sgd_lr_max``, ``grad_encoder_depth``,
            ``grad_encoder_use_alfa``, ``grad_encoder_use_fuser``,
            ``context_encoder_depth``) must come from it or already exist as
            module globals. Both ``*_depth`` values are
            ``(um_depth, mm_depth)`` pairs.

    Returns:
        A ``(cmd, exp_name)`` tuple: the full space-separated command line and
        the experiment name that is also passed via ``--name``.

    Raises:
        ValueError: if both the context encoder and the gradient encoder have
            a non-zero depth, or if ``encoder_capa`` is not one of 'tiny',
            'small' or 'base'.
    """
    # Publish the run's parameters as module globals; the bare names used
    # below are looked up there.
    for k, v in params_dict.items():
        globals()[k] = v

    # At most one encoder may be active; the inactive one keeps zero depth.
    context_encoder_um_depth, context_encoder_mm_depth = 0, 0
    grad_encoder_um_depth, grad_encoder_mm_depth = 0, 0
    use_context_encoder = sum(context_encoder_depth) > 0
    use_grad_encoder = sum(grad_encoder_depth) > 0
    if use_context_encoder and use_grad_encoder:
        raise ValueError("Both context and grad encoder are used")
    if use_context_encoder:
        context_encoder_um_depth, context_encoder_mm_depth = context_encoder_depth
    elif use_grad_encoder:
        grad_encoder_um_depth, grad_encoder_mm_depth = grad_encoder_depth

    # outer_steps and validation_inner_steps mirror inner_steps.
    outer_steps = str(inner_steps)
    validation_inner_steps = str(inner_steps)

    # encoder_capa -> (dim, heads), used for both the context and grad encoder.
    encoder_sizes = {
        'tiny': (192, 3),
        'small': (384, 6),
        'base': (768, 12),
    }
    if encoder_capa not in encoder_sizes:
        # Fail loudly instead of hitting an unbound `encoder_dim` later.
        raise ValueError('Unknown encoder_capa {!r}; expected one of: {}'.format(
            encoder_capa, ', '.join(encoder_sizes)))
    encoder_dim, encoder_heads = encoder_sizes[encoder_capa]

    context_encoder_dim = grad_encoder_dim = encoder_dim
    context_encoder_heads = grad_encoder_heads = encoder_heads

    # Kept as a join so further name components can be appended.
    exp_name = '_'.join([
        'tag:{}'.format(wandb_tags),
    ])

    cmd = ' '.join([
        base_cmd,
        '--name {}'.format(exp_name),

        '--context_encoder_dim {}'.format(context_encoder_dim),
        '--grad_encoder_dim {}'.format(grad_encoder_dim),
        '--context_encoder_heads {}'.format(context_encoder_heads),
        '--grad_encoder_heads {}'.format(grad_encoder_heads),
        '--context_encoder_um_depth {}'.format(context_encoder_um_depth),
        '--context_encoder_mm_depth {}'.format(context_encoder_mm_depth),
        '--grad_encoder_um_depth {}'.format(grad_encoder_um_depth),
        '--grad_encoder_mm_depth {}'.format(grad_encoder_mm_depth),

        '--wandb_tags {}'.format(wandb_tags),
        '--modes {}'.format(modes),

        '--meta_sgd_lr_init {}'.format(meta_sgd_lr_init),
        '--use_meta_sgd {}'.format(use_meta_sgd),
        '--meta_sgd_lr_max {}'.format(meta_sgd_lr_max),

        '--inner_steps {}'.format(inner_steps),
        '--outer_steps {}'.format(outer_steps),
        '--validation_inner_steps {}'.format(validation_inner_steps),
        '--grad_encoder_use_alfa {}'.format(grad_encoder_use_alfa),
        '--grad_encoder_use_fuser {}'.format(grad_encoder_use_fuser),
    ])

    return cmd, exp_name

if __name__ == '__main__':
    # Usage: python <this script> RUN_INDEX, where RUN_INDEX is 1-based
    # (1..len(params_dict_list)); presumably supplied by a job-array launcher.
    if len(sys.argv) < 2:
        sys.exit('usage: {} RUN_INDEX (1..{})'.format(
            sys.argv[0], len(params_dict_list)))
    run_index = int(sys.argv[1])
    # Reject 0 and negatives explicitly: `run_index - 1` would otherwise wrap
    # around to the end of the list and silently launch the wrong config.
    if not 1 <= run_index <= len(params_dict_list):
        sys.exit('RUN_INDEX must be in 1..{}, got {}'.format(
            len(params_dict_list), run_index))

    cmd, exp_name = generate_cmd(params_dict_list[run_index - 1])
    # Flush so the name precedes the child's output when stdout is redirected.
    print(exp_name, flush=True)
    result = sp.run(cmd.split())
    # Propagate the training run's exit status to the caller.
    sys.exit(result.returncode)
