# A wrapper prior: every hyperparameter given as a ConfigSpace hyperparameter is sampled, and the sampled values are passed on to the wrapped prior's get_batch.
# from fblearner.flow.projects.ae.benchmarks.pfn.ConfigSpace import hyperparameters as CSH
# import fblearner.flow.projects.ae.benchmarks.pfn.ConfigSpace as CS
from copy import deepcopy

import torch

from .prior import Batch


def list_all_hps_in_nested(config):
    """Collect every ConfigSpace hyperparameter contained in ``config``.

    ``config`` may be a single ``CSH.Hyperparameter``, a dict (searched
    recursively through its values, in iteration order), or anything else
    (which contributes nothing).

    Returns:
        A list of the hyperparameter objects found, in depth-first order.
    """
    # NOTE(review): CSH is not imported in this file (its import is commented
    # out at the top) — confirm it is provided before this runs.
    if isinstance(config, CSH.Hyperparameter):
        return [config]
    if not isinstance(config, dict):
        return []
    found = []
    for child in config.values():
        found.extend(list_all_hps_in_nested(child))
    return found


def create_configspace_from_hierarchical(config):
    """Build a flat ``CS.ConfigurationSpace`` from a (possibly nested) config dict.

    Every ``CSH.Hyperparameter`` found anywhere in ``config`` (see
    ``list_all_hps_in_nested``) is registered on a fresh configuration space.

    Returns:
        The populated ``CS.ConfigurationSpace``.
    """
    configspace = CS.ConfigurationSpace()
    collected = list_all_hps_in_nested(config)
    for hyperparameter in collected:
        configspace.add_hyperparameter(hyperparameter)
    return configspace


def fill_in_configsample(config, configsample):
    """Substitute sampled values into a copy of a hierarchical config dict.

    Args:
        config: dict describing the config distribution; values may be
            ``CSH.Hyperparameter`` objects, nested dicts, or plain values.
        configsample: a sampled ``CS.Configuration``; each hyperparameter's
            value is looked up by ``hyperparameter.name``.

    Returns:
        A new dict with the same keys as ``config``. Hyperparameters are
        replaced by their sampled values, nested dicts are filled in
        recursively, and every other value is a deep copy of the original.
    """
    # Deep-copy up front so that untouched leaves in the result never alias
    # objects held by ``config``.
    filled = deepcopy(config)
    for key, entry in config.items():
        if isinstance(entry, CSH.Hyperparameter):
            filled[key] = configsample[entry.name]
            continue
        if isinstance(entry, dict):
            filled[key] = fill_in_configsample(entry, configsample)
    return filled


def sample_configspace_hyperparameters(hyperparameters):
    """Draw one configuration and return ``hyperparameters`` with it filled in.

    Builds a configuration space from every ConfigSpace hyperparameter nested
    in ``hyperparameters``, samples a single configuration from it, and
    returns a copy of ``hyperparameters`` in which each hyperparameter is
    replaced by its sampled value.
    """
    space = create_configspace_from_hierarchical(hyperparameters)
    drawn = space.sample_configuration()
    return fill_in_configsample(hyperparameters, drawn)


def get_batch_old(batch_size, *args, hyperparameters, get_batch, **kwargs):
    """Build a batch by sampling ConfigSpace hyperparameters per sub-batch.

    Older variant of the hyperparameter-sampling batch builder (presumably
    kept for backward compatibility — confirm whether it is still used).

    The batch is split into ``num_models`` equally sized sub-batches. For each
    one, a configuration is sampled from the ConfigSpace hyperparameters found
    (at any nesting depth) in ``hyperparameters``. The wrapped ``get_batch``
    is then called with the filled-in hyperparameters.

    Args:
        batch_size: total batch size; split evenly across sub-batches.
        *args: forwarded unchanged to the wrapped ``get_batch``.
        hyperparameters: dict that may contain ``CSH.Hyperparameter`` values,
            including inside nested dicts. The optional key
            ``"num_hyperparameter_samples_per_batch"`` (default 1) sets how
            many independent configurations are drawn. It is capped at
            ``batch_size``, and -1 means one configuration per batch element.
        get_batch: the wrapped prior's batch function. It is called as
            ``get_batch(sub_batch_size, *args, hyperparameters=..., **kwargs)``
            and must return a ``Batch`` that fills only x, y and target_y.
        **kwargs: forwarded unchanged to the wrapped ``get_batch``.

    Returns:
        A ``Batch`` whose x, y and target_y are the sub-batches' tensors
        concatenated along dim 1.

    Raises:
        ValueError: if the requested number of configurations does not
            resolve to a positive integer.
        AssertionError: if ``batch_size`` is not a multiple of the number of
            configurations, or a sub-batch fills attributes other than
            x, y and target_y.
    """
    num_models = min(
        hyperparameters.get("num_hyperparameter_samples_per_batch", 1), batch_size
    )
    if num_models == -1:
        # -1 requests a fresh configuration for every batch element.
        num_models = batch_size
    if num_models < 1:
        # 0 would divide by zero below. Other negatives would produce no
        # sub-batches at all and fail later inside torch.cat.
        raise ValueError(
            "num_hyperparameter_samples_per_batch must resolve to a positive "
            f"number of configurations, got {num_models} (batch_size={batch_size})"
        )
    assert batch_size % num_models == 0, "batch_size must be a multiple of num_models"
    sub_batch_size = batch_size // num_models

    cs = create_configspace_from_hierarchical(hyperparameters)
    sub_batches = []
    for _ in range(num_models):
        hyperparameters_sample = fill_in_configsample(
            hyperparameters, cs.sample_configuration()
        )
        # ``get_batch`` here is the wrapped prior passed in by the caller,
        # not a module-level function.
        sub_batches.append(
            get_batch(
                sub_batch_size,
                *args,
                hyperparameters=hyperparameters_sample,
                **kwargs,
            )
        )

    # Only x, y and target_y are merged, so refuse sub-batches that carry
    # anything else rather than silently dropping it.
    assert all(
        not b.other_filled_attributes(set_of_attributes=("x", "y", "target_y"))
        for b in sub_batches
    )
    # Concatenating on dim 1 presumably stacks along the batch dimension,
    # i.e. tensors are laid out as (seq, batch, ...). TODO confirm.
    return Batch(
        x=torch.cat([b.x for b in sub_batches], dim=1),
        y=torch.cat([b.y for b in sub_batches], dim=1),
        target_y=torch.cat([b.target_y for b in sub_batches], dim=1),
    )


def get_batch(batch_size, *args, hyperparameters, get_batch, **kwargs):
    """Build a batch by sampling ConfigSpace hyperparameters per sub-batch.

    The batch is split into ``num_models`` equally sized sub-batches. For each
    one, a configuration is sampled from the ConfigSpace hyperparameters found
    (at any nesting depth) in ``hyperparameters``. The wrapped ``get_batch``
    is then called with the filled-in hyperparameters.

    Args:
        batch_size: total batch size; split evenly across sub-batches.
        *args: forwarded unchanged to the wrapped ``get_batch``.
        hyperparameters: dict that may contain ``CSH.Hyperparameter`` values,
            including inside nested dicts. The optional key
            ``"num_hyperparameter_samples_per_batch"`` (default 1) sets how
            many independent configurations are drawn. It is capped at
            ``batch_size``, and -1 means one configuration per batch element.
        get_batch: the wrapped prior's batch function. It is called as
            ``get_batch(sub_batch_size, *args, hyperparameters=..., **kwargs)``
            and must return a ``Batch`` that fills only x, y and target_y.
        **kwargs: forwarded unchanged to the wrapped ``get_batch``.

    Returns:
        A ``Batch`` whose x, y and target_y are the sub-batches' tensors
        concatenated along dim 1.

    Raises:
        ValueError: if the requested number of configurations does not
            resolve to a positive integer.
        AssertionError: if ``batch_size`` is not a multiple of the number of
            configurations, or a sub-batch fills attributes other than
            x, y and target_y.
    """
    num_models = min(
        hyperparameters.get("num_hyperparameter_samples_per_batch", 1), batch_size
    )
    if num_models == -1:
        # -1 requests a fresh configuration for every batch element.
        num_models = batch_size
    if num_models < 1:
        # 0 would divide by zero below. Other negatives would produce no
        # sub-batches at all and fail later inside torch.cat.
        raise ValueError(
            "num_hyperparameter_samples_per_batch must resolve to a positive "
            f"number of configurations, got {num_models} (batch_size={batch_size})"
        )
    assert batch_size % num_models == 0, "batch_size must be a multiple of num_models"
    sub_batch_size = batch_size // num_models

    cs = create_configspace_from_hierarchical(hyperparameters)
    sub_batches = []
    for _ in range(num_models):
        hyperparameters_sample = fill_in_configsample(
            hyperparameters, cs.sample_configuration()
        )
        # ``get_batch`` here is the wrapped prior passed in by the caller; the
        # parameter shadows this module-level function of the same name.
        sub_batches.append(
            get_batch(
                sub_batch_size,
                *args,
                hyperparameters=hyperparameters_sample,
                **kwargs,
            )
        )

    # Only x, y and target_y are merged, so refuse sub-batches that carry
    # anything else rather than silently dropping it.
    assert all(
        not b.other_filled_attributes(set_of_attributes=("x", "y", "target_y"))
        for b in sub_batches
    )
    # Concatenating on dim 1 presumably stacks along the batch dimension,
    # i.e. tensors are laid out as (seq, batch, ...). TODO confirm.
    return Batch(
        x=torch.cat([b.x for b in sub_batches], dim=1),
        y=torch.cat([b.y for b in sub_batches], dim=1),
        target_y=torch.cat([b.target_y for b in sub_batches], dim=1),
    )
