import torch 
import torch.nn as nn

from datetime import datetime 
import json 
import pathlib 
import numpy as np 

from functools import partial
from itertools import chain 
import sys 

sys.path.append('../scripts') 

import argparse 

# Command-line interface of the experiment. The general-purpose options are
# declared in a small table and registered in order, so that every flag's
# settings can be read side by side.
parser = argparse.ArgumentParser(description='mvp')

_general_options = (
    ('--domain', dict(
        type=str,
        default='grid',
        choices=['grid', 'multisets', 'sequences', 'phylogenetics'],
        help='the domain',
    )),
    ('--seed', dict(type=int, default=42, help='seed for rng')),
    ('--num_clients', dict(
        type=int,
        default=4,
        help='number of clients for partitioning inference',
    )),
    # Free-form string; the default is 'cuda'.
    ('--device', dict(type=str, default='cuda', help='cpu or gpu')),
)
for _flag, _settings in _general_options:
    parser.add_argument(_flag, **_settings)

# Domain-specific defaults. Every (domain, feature) pair below is exposed on
# the command line as `--<domain>_<feature>`, typed after its default value.
hyperparameters = dict(
    grid=dict(width=16, height=16),
    multisets=dict(size=16, warehouse_size=8),
    sequences=dict(max_size=8, vocab_size=6),
    phylogenetics=dict(num_leaves=6, vocab_size=4),
)

for domain, _features in hyperparameters.items():
    for feature, _default in _features.items():
        parser.add_argument(
            f'--{domain}_{feature}',
            type=type(_default),
            default=_default,
            help=f'{feature} for {domain}',
        )

# Remaining options, registered in declaration order from a single table.
_run_options = {
    # Training
    '--epochs': dict(type=int, default=512, help='number of epochs for training'),
    '--lr': dict(type=float, default=3e-4, help='learning rate for CB and DB'),
    '--batch_size_train': dict(type=int, default=512, help='batch size for training'),
    '--hidden_dim': dict(type=int, default=128, help='size of the hidden layer for the policy'),
    '--emb_dim': dict(
        type=int,
        default=16,
        help='the size of the embedding dim for the dictionary-based methods',
    ),
    # Evaluation
    '--batch_size_eval': dict(type=int, default=512, help='batch size for inference'),
    '--num_batches_eval': dict(
        type=int,
        default=500,
        help='number of batches to sample for inference',
    ),
    # Assessment of different criteria
    '--epochs_per_step': dict(
        type=int,
        default=42,
        help='number of epochs between consecutive inferences',
    ),
    '--criterion': dict(
        type=str,
        default='tb',
        choices=['db', 'tb', 'cb'],
        help='criterion used to train the models',
    ),
}
for _flag, _settings in _run_options.items():
    parser.add_argument(_flag, **_settings)

# Parse the process command line into the `args` namespace
args = parser.parse_args()

# Import each component of the model according to the chosen domain, and
# collect that domain's `create_env_args` and `flow_args` dictionaries.
if args.domain == 'grid':
    from grids.envs import Grid2D, LogRewardDist
    from grids.flows import ForwardFlow, BackwardFlow
    from grids.variational import (
        VariationalApproximationGrid, VariationalProductGrid,
        create_var, create_var_prod, samples_from_state,
    )
    from grids.utils import create_gfn, create_env, create_state_flow, unique_smp
    create_env_args = dict(width=args.grid_width, height=args.grid_height)
    flow_args = dict(hidden_dim=args.hidden_dim, device=args.device)
elif args.domain == 'multisets':
    from sets.envs import Set, LogRewardLinear
    from sets.flows import ForwardFlow, BackwardFlow
    from sets.variational import (
        VariationalApproximationMultiset, VariationalProductMultiset,
        create_var, create_var_prod, samples_from_state,
    )
    from sets.utils import create_gfn, create_state_flow, create_env, unique_smp
    create_env_args = dict(
        set_size=args.multisets_size,
        warehouse_size=args.multisets_warehouse_size,
    )
    flow_args = dict(
        hidden_dim=args.hidden_dim,
        emb_dim=args.emb_dim,
        device=args.device,
    )
elif args.domain == 'sequences':
    from seqvarlen.envs import Sequence, LogRewardLinear
    from seqvarlen.flows import ForwardFlow, BackwardFlow
    from seqvarlen.variational import (
        VariationalApproximationSequence, VariationalProductSequence,
        create_var, create_var_prod, samples_from_state,
    )
    from seqvarlen.utils import create_gfn, create_state_flow, create_env, unique_smp
    create_env_args = dict(
        max_size=args.sequences_max_size,
        vocab_size=args.sequences_vocab_size,
    )
    flow_args = dict(
        hidden_dim=args.hidden_dim,
        emb_dim=args.emb_dim,
        device=args.device,
    )
elif args.domain == 'phylogenetics':
    from phylogenetics.envs import Trees, LikelihoodReward
    from phylogenetics.flows import ForwardFlow, BackwardFlow
    # There is no variational approximation to distributed inference over
    # trees, so no `*.variational` module is imported for this domain.
    from phylogenetics.utils import create_gfn, create_state_flow, create_env
    create_env_args = dict(
        num_leaves=args.phylogenetics_num_leaves,
        vocab_size=args.phylogenetics_vocab_size,
    )
    flow_args = dict(
        hidden_dim=args.hidden_dim,
        warmup=64,  # for sensibly tempering the target distribution
        device=args.device,
    )
    # No `unique_smp` helper is imported for this domain; bind the name to
    # None so that it exists for every domain.
    unique_smp = None
else:
    # Fallback for any domain value not handled above.
    raise ValueError

from gflownet import GFlowNet, GFlowNetEnsemble 
from utils import multiclient, federated_gflownets 

# Global torch defaults for everything created after this point: 64-bit
# floats as the default floating-point dtype, and tensors allocated on the
# device chosen on the command line.
torch.set_default_dtype(torch.float64)
torch.set_default_device(args.device)
