from pathlib import Path

import fff
import torch
import yaml
from jinja2.nodes import Break
from yaml import safe_load


def load_fff_model_and_nll(args, fff_directory="../../FFF/", version="version_0"):
    """Load the ``last.ckpt`` checkpoint of a trained FreeFormFlow model.

    The checkpoint is read from
    ``<fff_directory>/lightning_logs/<args.dataset>/<version>/checkpoints/last.ckpt``.

    Args:
        args: Namespace-like object; only ``args.dataset`` is read.
        fff_directory: Root directory of the FFF checkout, given either as a
            ``str`` or as a ``pathlib.Path``.
        version: Name of the Lightning log version sub-directory.

    Returns:
        The object returned by ``fff.FreeFormFlow.load_from_checkpoint``.
        Despite the function name, no NLL is computed or returned here.
    """
    # Join with pathlib instead of string "+": this accepts a Path as well as
    # a str (define_fff_config treats fff_directory as a Path) and avoids the
    # doubled separator that "../../FFF/" + "/lightning_logs" produced.
    # str() keeps the original f-string behaviour for non-str dataset/version.
    checkpoint_path = (
        Path(fff_directory)
        / "lightning_logs"
        / str(args.dataset)
        / str(version)
        / "checkpoints"
        / "last.ckpt"
    )
    return fff.FreeFormFlow.load_from_checkpoint(str(checkpoint_path))


def fff_sample_and_logprob(model, n_samples=1000, samples=None):
    """Draw samples from ``model`` (unless given) and evaluate their log-probability.

    Args:
        model: Object exposing ``sample_all(sample_shape=...)`` and
            ``exact_log_prob(samples)``. The latter must return an object
            with a ``log_prob`` attribute.
        n_samples: Number of samples to draw when ``samples`` is None.
        samples: Optional pre-computed samples. When given, ``n_samples`` is
            ignored and nothing is sampled.

    Returns:
        Tuple ``(samples, log_prob)``, where ``log_prob`` is
        ``model.exact_log_prob(samples).log_prob``.
        NOTE(review): the value is returned as-is (not negated), although it
        was previously named ``nll``; confirm the sign callers expect.
    """
    if samples is None:
        # Relies on the module-level ``import torch``.
        samples = model.sample_all(sample_shape=torch.Size([n_samples]))
    log_prob = model.exact_log_prob(samples).log_prob
    return samples, log_prob


def define_fff_config(args, fff_directory, model="fff", latent_dim=1):
    """Build a FreeFormFlow training config for ``args.dataset`` and write it as YAML.

    The file is written to
    ``<fff_directory>/configs/<model>/margflow-<args.dataset>.yaml``. The parent
    directory is created if it does not exist yet, and an existing file is
    overwritten.

    Args:
        args: Namespace-like object. Reads ``dataset``, ``x_dim``,
            ``n_samples_dataset``, ``n_samples_dataset_test``, ``n_mog``,
            ``mog_sigma``, ``bounds``, ``epsilon``, ``device``, ``seed`` and
            ``script_path``.
            NOTE(review): values are serialized with ``yaml.safe_dump``,
            which rejects non-plain types (e.g. tuples or library objects).
            Confirm that ``bounds`` and ``device`` are plain lists/str/numbers.
        fff_directory: Root directory of the FFF checkout, given either as a
            ``str`` or as a ``pathlib.Path``.
        model: ``"fff"`` (``fff.FreeFormFlow``) or ``"fif"``
            (``fff.FreeFormInjectiveFlow``).
        latent_dim: Latent dimension written for the injective ``"fif"``
            model. Ignored for ``"fff"``. Defaults to 1, the value that was
            previously hard-coded.

    Raises:
        AssertionError: If ``model`` is not ``"fff"`` or ``"fif"`` (the check
            is an ``assert`` and is skipped under ``python -O``).
    """
    model_dict = {"fff": "fff.FreeFormFlow", "fif": "fff.FreeFormInjectiveFlow"}
    assert model in model_dict, "model must be either 'fff' or 'fif'"

    resnet = {
        "name": "fff.model.ResNet",
        # Six residual blocks of two 128-unit layers each. A comprehension
        # builds distinct list objects: repeating one list object ([..] * 6)
        # would make the YAML dumper emit &id anchors and *aliases.
        "layers_spec": [[128, 128] for _ in range(6)],
    }
    # If the selected model is injective, it also needs a latent dimension.
    if model == "fif":
        resnet["latent_dim"] = latent_dim

    config_dict = {
        # Training info (values still to be tuned).
        # Optional keys not set here: "check_val_every_n_epoch",
        # "model_checkpoint" ({"monitor", "save_last", "every_n_epochs"}) and
        # "skip_val_nll". An alternative backbone is
        # "fff.model.FullyConnectedNetwork" with a "layer_spec" list.
        "model": model_dict[model],
        "max_epochs": 2000,
        "batch_size": 64,
        "loss_weights": {"nll": 1, "noisy_reconstruction": 100},
        "lr_scheduler": "onecyclelr",
        "optimizer": {"name": "adam", "lr": 0.0002, "weight_decay": 0.0001},
        # Model info.
        "models": [resnet],
        # Dataset info, forwarded from the experiment arguments.
        "data_set": {
            "name": args.dataset,
            "args": {
                "n_dim": args.x_dim,
                "n_samples_dataset": args.n_samples_dataset,
                "n_samples_dataset_test": args.n_samples_dataset_test,
                "n_mog": args.n_mog,
                "mog_sigma": args.mog_sigma,
                "means": None,
                "bounds": args.bounds,
                "epsilon": args.epsilon,
                "device": args.device,
                "seed": args.seed,
                "dataset": args.dataset,
                "script_path": str(args.script_path),
            },
        },
    }

    # Path() accepts both str and Path input for fff_directory.
    config_filename = Path(fff_directory) / "configs" / model / f"margflow-{args.dataset}.yaml"
    config_filename.parent.mkdir(parents=True, exist_ok=True)
    with open(config_filename, "w", encoding="utf-8") as yamlfile:
        # safe_dump writes straight to the stream (and returns None).
        # sort_keys=False keeps the insertion order above.
        yaml.safe_dump(config_dict, yamlfile, sort_keys=False)
    print(f"Wrote config file for free form flow on {args.dataset} dataset")
