import os
import time
from copy import deepcopy
import git
import torch
from torch.utils.data import DataLoader

from gflownet.models import GPS
from gflownet.trainer import Trainer
from gflownet.utils import create_logger, set_main_process_seed, ModelProxy

# Root directory under which make_log_dir creates per-experiment log directories.
BASE_DIR = 'logs'

def make_log_dir(config) -> str:
    """Create a fresh experiment directory and save ``config`` into it.

    The directory name is ``BASE_DIR/<config.log_dir>``; if that name is
    already taken, the suffixes ``-01`` .. ``-99`` are tried in order until
    one can be created.

    Args:
        config: Experiment configuration. ``config.log_dir`` is read for the
            directory name, and the whole object is written to
            ``config.yaml`` inside the new directory via ``OmegaConf.save``.

    Returns:
        Path of the newly created directory.

    Raises:
        RuntimeError: If the base name and every numbered variant already
            exist, so no new directory could be created.
    """
    # Imported locally: at module level OmegaConf is only bound when this
    # file is executed as a script, so an importing caller would otherwise
    # hit a NameError here.
    from omegaconf import OmegaConf

    log_dir = os.path.join(BASE_DIR, config.log_dir)
    # os.mkdir below does not create parents; make sure the root exists.
    os.makedirs(BASE_DIR, exist_ok=True)

    suffixes = [''] + [f'-{i:02d}' for i in range(1, 100)]
    for suffix in suffixes:
        new_log_dir = log_dir + suffix
        try:
            # mkdir is the atomic "claim": FileExistsError means the name is
            # taken (possibly by a concurrently starting run).
            os.mkdir(new_log_dir)
        except FileExistsError:
            continue
        break
    else:
        # Never fall through to an existing directory: that would overwrite
        # another experiment's config and logs.
        raise RuntimeError(
            f'Could not create a new log directory: {log_dir} and all '
            f'numbered variants up to {log_dir}-99 already exist'
        )

    print("New directory for experiments:", new_log_dir)

    config_path = os.path.join(new_log_dir, 'config.yaml')
    with open(config_path, 'w') as f:
        OmegaConf.save(config, f)

    return new_log_dir


def init_every(config):
    """Build all components of a training run from ``config``.

    Selects the device, constructs the environment, featurizer and GPS
    network, then builds the model, optimizer, algorithm, trajectory dataset,
    dataloader and callbacks for the algorithm named by ``config.algo``
    (``'tb'``, ``'db'`` or ``'ppo'``).

    Args:
        config: Experiment configuration. Reads ``device``, ``env``,
            ``featurizer``, ``model``, ``algo``, ``scale_reward`` and the
            sub-config named after the chosen algorithm; the ``tb`` and
            ``db`` branches also read ``lr`` and ``eval_every``, and ``db``
            reads ``compute_automorphism``.

    Returns:
        Tuple ``(env, featurizer, model, dataloader, algo, callbacks,
        device)`` where ``callbacks`` maps a trainer event name to a list of
        callables.

    Raises:
        ValueError: If ``config.env.name`` or ``config.algo`` is not one of
            the supported values.
    """
    from gflownet.graphenv import GraphEnv, MolEnv
    from gflownet.algo import TrajectoryBalance, DetailedBalance, PPOAlgorithm
    from gflownet.sampler import TrajectoryDataset
    from gflownet.featurizer import GraphStateFeaturizer
    from gflownet.models import PolicyNet
    from gflownet.monitor import get_exact_eval_callback

    if config.device == 'auto':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device(config.device)

    if config.env.name == 'graphenv':
        env = GraphEnv(**config.env.graphenv)
    elif config.env.name == 'molenv':
        env = MolEnv(**config.env.molenv)
    else:
        raise ValueError(f'{config.env.name} is not supported')

    featurizer = GraphStateFeaturizer(env=env, **config.featurizer)
    gnn = GPS(
        node_dim=featurizer.node_dim,
        edge_dim=featurizer.edge_dim,
        **config.model
    )

    def load_dataloader(dataset, worker_init_fn, config):
        """Build a DataLoader from a dataloader sub-config.

        With ``num_workers == 0`` the multiprocessing-only options are
        overridden (``prefetch_factor=None``, ``persistent_workers=False``).
        The overrides go into a copy, so the caller's config is not mutated.
        """
        loader_kwargs = dict(config)
        if config.num_workers == 0:
            loader_kwargs['prefetch_factor'] = None
            loader_kwargs['persistent_workers'] = False
        return DataLoader(dataset, worker_init_fn=worker_init_fn, **loader_kwargs)

    if config.algo == 'tb':
        model = PolicyNet(env, gnn, out_graph=0, backward_policy=config.tb.algo.learn_backward)
        # logZ is attached before the optimizer is created so that it is part
        # of model.parameters() and gets optimized.
        model.logZ = torch.nn.Parameter(torch.tensor(1.0))
        optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
        algo = TrajectoryBalance(optimizer, **config.tb.algo)
        proxy = ModelProxy(model, config.tb.dataloader.num_workers)
        dataset = TrajectoryDataset(
            model=proxy.placeholder,
            env=env,
            featurizer=featurizer,
            # Exactly one of the two backward options is enabled, depending
            # on whether the backward policy is learned.
            compute_unif_backward_prob=not config.tb.algo.learn_backward,
            compute_backward_action_idxs=config.tb.algo.learn_backward,
            scale_reward=config.scale_reward,
            **config.tb.dataset
        )
        dataloader = load_dataloader(dataset, proxy.worker_init_fn, config.tb.dataloader)
        callback = get_exact_eval_callback(env, model, featurizer, eval_every=config.eval_every)
        callbacks = {
            'on_batch_end': [callback],
            'on_train_end': [callback]
        }

    elif config.algo == 'db':
        model = PolicyNet(env, gnn, out_graph=1, backward_policy=False)

        # A separate copy of the model is used for sampling only when
        # sampling_tau is configured.
        if config.db.algo.sampling_tau is not None:
            sampling_model = deepcopy(model)
            sampling_model.to(device)
        else:
            sampling_model = None

        optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
        algo = DetailedBalance(optimizer, **config.db.algo, sampling_model=sampling_model)
        # NOTE(review): when sampling_tau is None the proxy is built around
        # None rather than `model` — confirm ModelProxy handles this.
        proxy = ModelProxy(sampling_model, config.db.dataloader.num_workers)
        dataset = TrajectoryDataset(
            model=proxy.placeholder,
            env=env,
            featurizer=featurizer,
            compute_automorphism=config.compute_automorphism,
            compute_unif_backward_prob=True,
            compute_backward_action_idxs=False,
            scale_reward=config.scale_reward,
            transition_dataset=True,
            **config.db.dataset
        )
        dataloader = load_dataloader(dataset, proxy.worker_init_fn, config.db.dataloader)
        callback = get_exact_eval_callback(env, model, featurizer, eval_every=config.eval_every)
        callbacks = {
            'on_batch_end': [callback],
            'on_train_end': [callback]
        }

    elif config.algo == 'ppo':
        model = PolicyNet(env, gnn, out_graph=1)
        # NOTE(review): unlike the 'tb' and 'db' branches, lr=config.lr is
        # not passed here, so Adam's default learning rate is used — confirm
        # this is intentional.
        optimizer = torch.optim.Adam(model.parameters())
        algo = PPOAlgorithm(optimizer, **config.ppo.algo)
        proxy = ModelProxy(model, config.ppo.dataloader.num_workers)
        dataset = TrajectoryDataset(
            model=proxy.placeholder,
            env=env,
            featurizer=featurizer,
            scale_reward=config.scale_reward,
            **config.ppo.dataset
        )
        # Built once, through the same helper as the other branches, so the
        # num_workers == 0 overrides are applied here too.
        dataloader = load_dataloader(dataset, proxy.worker_init_fn, config.ppo.dataloader)
        callbacks = {}

    else:
        # Fail with a clear message instead of an UnboundLocalError at the
        # return statement below.
        raise ValueError(f'{config.algo} is not supported')

    return env, featurizer, model, dataloader, algo, callbacks, device


def main(config):
    """Run one training experiment described by ``config``.

    Creates the log directory, seeds the main process, builds all training
    components with ``init_every``, registers the callbacks on a ``Trainer``,
    logs the model parameter count and the short git hash, runs training for
    ``config.training_steps`` steps and saves the result to ``model.pt`` in
    the log directory.

    Args:
        config: Experiment configuration. Besides what ``make_log_dir`` and
            ``init_every`` read, this uses ``seed``, ``save_at_valid``,
            ``eval_every``, ``training_steps`` and ``print_every``.
    """
    run_dir = make_log_dir(config)

    set_main_process_seed(config.seed)
    # init_every also returns the env and featurizer; they are not used here.
    _env, _featurizer, model, dataloader, algo, callbacks, device = init_every(config)

    train_log_path = os.path.join(run_dir, 'train.log')
    logger = create_logger(logfile=train_log_path, streamHandle=True)
    logger.info(f'model parameters: {model.num_params()}')

    trainer = Trainer(model=model, dataloader=dataloader, algo=algo, logger=logger)

    # callbacks maps an event name to the list of handlers for that event.
    for event_name, handlers in callbacks.items():
        for handler in handlers:
            trainer.add_callback(event_name, handler)

    if config.save_at_valid:
        from gflownet.monitor import get_save_callback
        # Checkpointing uses the same period as evaluation (eval_every).
        save_handler = get_save_callback(run_dir, save_every=config.eval_every)
        trainer.add_callback('on_batch_end', save_handler)

    # Record which commit produced this run; fall back to "unknown" when this
    # file is not inside a git repository.
    try:
        repo = git.Repo(__file__, search_parent_directories=True)
        githash = repo.head.object.hexsha[:7]
    except git.InvalidGitRepositoryError:
        githash = "unknown"
    logger.info(f'githash: {githash}')

    trainer.run(config.training_steps, device=device, print_every=config.print_every)

    model_path = os.path.join(run_dir, 'model.pt')
    trainer.save(model_path)


# You can run the script as follows:
# python main.py [path to config.yaml]

if __name__ == "__main__":
    import sys
    from omegaconf import OmegaConf

    # The first command-line argument, when given, is the config file path;
    # otherwise the default config is used.
    config_path = sys.argv[1] if len(sys.argv) > 1 else 'config/default.yaml'

    config = OmegaConf.load(config_path)
    main(config)

