from absl import flags, logging

# --- Cross-batch flags -------------------------------------------------------
# All flags in this group are read by get_specified_cross_batch_config() and
# copied into a CrossBatchConfig.

# Single-layer and multi-layer variants of the same setting.
# get_specified_cross_batch_config() rejects setting both at once and
# normalises whichever one is set into a list.
flags.DEFINE_integer("cross_batch_layer", None, "Tmp cross batch layer")
flags.DEFINE_multi_integer("cross_batch_layers", [], "Tmp cross batch layer")
# Defaults to None when the flag is not given.
flags.DEFINE_integer("cross_batch", None, "Tmp cross batch range")
flags.DEFINE_integer("cross_batch_step_inc", 0, "How much to add to cross batch step")
flags.DEFINE_integer(
    "cross_batch_range_inc", 0, "How much to add to cross batch visible range"
)
flags.DEFINE_boolean(
    "local_positionals_in_cross_batch", False, "Tmp cross batch positionals"
)
flags.DEFINE_boolean(
    "encode_other_as_first_pos_cross_batch",
    False,
    "Whether to use positional encoding of first element for all keys outside of the local context",
)
flags.DEFINE_boolean("legacy_cross_batch", False, "Whether to use old cross_batch")
# NOTE(review): the help text below reads as the opposite of the flag name
# ("no_local_..." vs. "whether to use local_context") - confirm the intended
# polarity with the code that consumes this flag.
flags.DEFINE_boolean(
    "no_local_cross_batch", False, "Whether to use local_context in cross_batch"
)

# --- Evaluation-time XL flags ------------------------------------------------
# These flags have empty help strings and are not read anywhere in the visible
# part of this module.
# NOTE(review): presumably consumed by evaluation code elsewhere - confirm and
# consider adding help text.
flags.DEFINE_integer("eval_xl_cache_windows", 0, "")
flags.DEFINE_integer("eval_xl_top_k", 0, "")
flags.DEFINE_boolean("eval_xl_clear", False, "")

# Module-level handle to the absl flag registry used by the functions below.
FLAGS = flags.FLAGS


class CrossBatchConfig:
    """Plain container for the cross-batch settings.

    The constructor stores every argument unchanged on an attribute of the
    same name; it performs no validation or conversion. The parameter names
    match the command-line flags defined in this module, from which
    ``get_specified_cross_batch_config`` fills an instance.

    Args:
        cross_batch: Stored as ``self.cross_batch``; ``None`` in the config
            produced by ``get_empty_cross_batch_config``.
        cross_batch_step_inc: Stored as ``self.cross_batch_step_inc``.
        cross_batch_range_inc: Stored as ``self.cross_batch_range_inc``.
        cross_batch_layer: Stored as ``self.cross_batch_layer``. Despite the
            singular name, both factory functions in this module pass a list
            here.
        local_positionals_in_cross_batch: Stored as-is.
        encode_other_as_first_pos_cross_batch: Stored as-is.
        legacy_cross_batch: Stored as-is.
        no_local_cross_batch: Stored as-is.
    """

    def __init__(
        self,
        cross_batch,
        cross_batch_step_inc,
        cross_batch_range_inc,
        cross_batch_layer,
        local_positionals_in_cross_batch,
        encode_other_as_first_pos_cross_batch,
        legacy_cross_batch,
        no_local_cross_batch,
    ) -> None:
        self.cross_batch = cross_batch
        self.cross_batch_step_inc = cross_batch_step_inc
        self.cross_batch_range_inc = cross_batch_range_inc
        self.cross_batch_layer = cross_batch_layer
        self.local_positionals_in_cross_batch = local_positionals_in_cross_batch
        self.encode_other_as_first_pos_cross_batch = (
            encode_other_as_first_pos_cross_batch
        )
        self.legacy_cross_batch = legacy_cross_batch
        self.no_local_cross_batch = no_local_cross_batch

    def get_updated_from_dict(self, d: dict):
        """Return a new config with selected fields overridden from ``d``.

        Only ``cross_batch``, ``cross_batch_step_inc`` and
        ``cross_batch_range_inc`` may be overridden. Every other field is
        copied from ``self``, which is left unmodified.

        Args:
            d: Mapping from field name to replacement value. A key that is
                present replaces the field even when its value is ``None``.

        Returns:
            A new ``CrossBatchConfig``.

        Raises:
            AssertionError: If ``d`` contains a key that is not overridable.
        """
        allowed_keys = ("cross_batch", "cross_batch_step_inc", "cross_batch_range_inc")
        unknown_keys = [k for k in d if k not in allowed_keys]
        if unknown_keys:
            # Raised explicitly rather than via `assert` so the check is not
            # stripped under `python -O`. AssertionError is kept so existing
            # callers observe the same exception type.
            raise AssertionError(
                f"Unsupported cross-batch override key(s) {unknown_keys}; "
                f"allowed keys are {list(allowed_keys)}"
            )

        return CrossBatchConfig(
            cross_batch=d.get("cross_batch", self.cross_batch),
            cross_batch_step_inc=d.get(
                "cross_batch_step_inc", self.cross_batch_step_inc
            ),
            cross_batch_range_inc=d.get(
                "cross_batch_range_inc", self.cross_batch_range_inc
            ),
            cross_batch_layer=self.cross_batch_layer,
            local_positionals_in_cross_batch=self.local_positionals_in_cross_batch,
            encode_other_as_first_pos_cross_batch=self.encode_other_as_first_pos_cross_batch,
            legacy_cross_batch=self.legacy_cross_batch,
            no_local_cross_batch=self.no_local_cross_batch,
        )


def get_specified_cross_batch_config():
    """Build a ``CrossBatchConfig`` from the parsed command-line flags.

    ``--cross_batch_layer`` (single value) and ``--cross_batch_layers``
    (repeatable) are alternative ways to set the same field. Whichever one is
    set is passed on as a list via the ``cross_batch_layer`` argument.

    Returns:
        A ``CrossBatchConfig`` populated from ``FLAGS``.

    Raises:
        AssertionError: If both ``--cross_batch_layer`` and
            ``--cross_batch_layers`` are set.
    """
    single_layer = FLAGS.cross_batch_layer
    layer_list = FLAGS.cross_batch_layers

    if single_layer is not None and layer_list != []:
        # Raised explicitly rather than via `assert` so the check is not
        # stripped under `python -O`, which would silently drop
        # --cross_batch_layers. AssertionError is kept so existing callers
        # observe the same exception type.
        raise AssertionError(
            "--cross_batch_layer and --cross_batch_layers are mutually "
            f"exclusive; got cross_batch_layer={single_layer!r} and "
            f"cross_batch_layers={layer_list!r}"
        )

    if single_layer is not None:
        cb_layers = [single_layer]
    else:
        cb_layers = layer_list

    return CrossBatchConfig(
        cross_batch=FLAGS.cross_batch,
        cross_batch_step_inc=FLAGS.cross_batch_step_inc,
        cross_batch_range_inc=FLAGS.cross_batch_range_inc,
        cross_batch_layer=cb_layers,
        local_positionals_in_cross_batch=FLAGS.local_positionals_in_cross_batch,
        encode_other_as_first_pos_cross_batch=FLAGS.encode_other_as_first_pos_cross_batch,
        legacy_cross_batch=FLAGS.legacy_cross_batch,
        no_local_cross_batch=FLAGS.no_local_cross_batch,
    )


def get_empty_cross_batch_config():
    """Return a ``CrossBatchConfig`` with every field at its "off" value.

    ``cross_batch`` is ``None``, both increments are ``0``, the layer list is
    empty and all boolean options are ``False``. A fresh instance (with a
    fresh empty list) is created on every call.
    """
    blank_fields = {
        "cross_batch": None,
        "cross_batch_step_inc": 0,
        "cross_batch_range_inc": 0,
        "cross_batch_layer": [],
        "local_positionals_in_cross_batch": False,
        "encode_other_as_first_pos_cross_batch": False,
        "legacy_cross_batch": False,
        "no_local_cross_batch": False,
    }
    return CrossBatchConfig(**blank_fields)


class CommonCfgs:
    """Bundle of settings: packing, XL-cache options, cross-batch config, mode.

    The constructor stores every argument unchanged on an attribute of the
    same name and prints a one-line summary to stdout.

    Args:
        dataset_packing: Stored as ``self.dataset_packing``.
        xl_cache_windows: Stored as ``self.xl_cache_windows``.
        xl_top_k: Stored as ``self.xl_top_k``.
        cb_cfg: The ``CrossBatchConfig`` stored as ``self.cb_cfg``.
        mode: Stored as ``self.mode``. It is carried over unchanged by
            ``get_without_cross_batch`` and ``get_updated_from_dict``.
    """

    def __init__(self, dataset_packing, xl_cache_windows, xl_top_k, cb_cfg, mode):
        # The summary is printed on every construction, including the copies
        # made by the two `get_*` methods below.
        print(
            f"Creating CommonCfgs with dataset_packing={dataset_packing}, xl_cache_windows={xl_cache_windows}, mode={mode}"
        )
        self.dataset_packing = dataset_packing
        self.xl_cache_windows = xl_cache_windows
        self.xl_top_k = xl_top_k
        self.cb_cfg: CrossBatchConfig = cb_cfg
        self.mode = mode

    def get_without_cross_batch(self):
        """Return a copy whose ``cb_cfg`` is the empty cross-batch config.

        All other fields are copied from ``self``, which is left unmodified.
        """
        return CommonCfgs(
            dataset_packing=self.dataset_packing,
            xl_cache_windows=self.xl_cache_windows,
            xl_top_k=self.xl_top_k,
            cb_cfg=get_empty_cross_batch_config(),
            mode=self.mode,
        )

    def get_updated_from_dict(self, d: dict):
        """Return a new ``CommonCfgs`` with selected fields overridden from ``d``.

        Args:
            d: Mapping with any of the keys ``dataset_packing``,
                ``xl_cache_windows``, ``xl_top_k`` and ``cb_cfg``. The first
                three replace the field of the same name. The value under
                ``cb_cfg`` is forwarded to
                ``self.cb_cfg.get_updated_from_dict`` (an empty dict is
                forwarded when the key is absent), so the result always
                carries a newly built cross-batch config.

        Returns:
            A new ``CommonCfgs`` with the same ``mode`` as ``self``.

        Raises:
            AssertionError: If ``d`` contains a key outside the allowed set.
        """
        logging.info(f"Updating common cfg with mode {self.mode} using {d}")
        allowed_keys = ("dataset_packing", "xl_cache_windows", "xl_top_k", "cb_cfg")
        unknown_keys = [k for k in d if k not in allowed_keys]
        if unknown_keys:
            # Raised explicitly rather than via `assert` so the check is not
            # stripped under `python -O`. AssertionError is kept so existing
            # callers observe the same exception type.
            raise AssertionError(
                f"Unsupported common cfg override key(s) {unknown_keys}; "
                f"allowed keys are {list(allowed_keys)}"
            )

        return CommonCfgs(
            dataset_packing=d.get("dataset_packing", self.dataset_packing),
            xl_cache_windows=d.get("xl_cache_windows", self.xl_cache_windows),
            xl_top_k=d.get("xl_top_k", self.xl_top_k),
            cb_cfg=self.cb_cfg.get_updated_from_dict(d.get("cb_cfg", {})),
            mode=self.mode,
        )
