from functools import partial
import itertools

# Maps experiment name -> experiment callable. Filled either through the
# `register_exp` decorator or by assigning entries directly.
experiment_registry = {}


def register_exp(exp_name):
  """Decorator factory that records a function in `experiment_registry`.

  Args:
    exp_name: Key under which the decorated function is stored. An existing
      entry with the same key is overwritten.

  Returns:
    A decorator that stores its argument and returns it unchanged, so the
    decorated function remains directly callable.
  """

  def _add_to_registry(fn):
    experiment_registry.update({exp_name: fn})
    return fn

  return _add_to_registry


def waterbirds(
    context_class_size: int,
    selected_settings: list[str],
    model: str = 'incontext_learner_gptj_80m',
    encoding_extractor: str = 'dinov2_vitb14',
    reverse_task: bool = False,
    modified: bool = False,
    n_repeats: int = 5,
    ):
  """Experiments for waterbirds.

  Builds one list of Hydra `key=value` override strings per training job.
  The jobs form the cross product of the following, with seed varying
  fastest:
    - the selected settings;
    - the input augmentations;
    - the `ask_context_prob` values;
    - the learning rates;
    - `n_repeats` seeds.
  Each job's command line is also printed.

  Args:
    context_class_size: Value for `datamodule.context_class_size`; must be 64
      or 256. With 256 the batch size drops to 24 and the `val` split is left
      out of `datamodule.val_sets`.
    selected_settings: Names of the settings defined below to run. Jobs
      follow the definition order of the settings, not the order of this
      list.
    model: Value for the `model` override.
    encoding_extractor: Value for the `encoding_extractor` override.
    reverse_task: Value for `datamodule.reverse_task`.
    modified: Value for `datamodule.modified`.
    n_repeats: Number of seeds (0 .. n_repeats - 1) per configuration.

  Returns:
    A list with one entry per job, each a list of override strings.

  Raises:
    ValueError: If `context_class_size` is unsupported or `selected_settings`
      contains a name that is not a defined setting.
  """
  if context_class_size not in (64, 256):
    raise ValueError(
        f'context_class_size must be 64 or 256, got {context_class_size}.')

  settings = {
      # group-agnostic: truly ERM-like
      'group-agnostic': {
          'context_group_proportions': '[0.45, 0.05, 0.05, 0.45]',
          'query_group_proportions': '[0.45, 0.05, 0.05, 0.45]',
          'spurious_setting': 'wb_erm',
          'use_context_as_intermediate_queries': True,
          },
      # V1 behavior: for demonstrating that it is important to have
      # intermediate queries follow the final query distribution
      'v1': {
          'context_group_proportions': '[0.45, 0.05, 0.05, 0.45]',
          'query_group_proportions': '[0.25, 0.25, 0.25, 0.25]',
          'spurious_setting': 'wb_erm',
          'use_context_as_intermediate_queries': True,
          },
      # V2 behavior: to show that robustness is not learned in context.
      'v2_erm': {
          'context_group_proportions': '[0.45, 0.05, 0.05, 0.45]',
          'query_group_proportions': '[0.25, 0.25, 0.25, 0.25]',
          'spurious_setting': 'wb_erm',
          'use_context_as_intermediate_queries': False,
      },
      'v2_dro': {
          'context_group_proportions': '[0.45, 0.05, 0.05, 0.45]',
          'query_group_proportions': '[0.25, 0.25, 0.25, 0.25]',
          'spurious_setting': 'wb_dro',
          'use_context_as_intermediate_queries': False,
      },
  }

  # Reject typos instead of silently generating no jobs for them.
  unknown = [name for name in selected_settings if name not in settings]
  if unknown:
    raise ValueError(f'Unknown waterbirds settings: {unknown}')
  settings = {k: v for k, v in settings.items() if k in selected_settings}

  sp_token_generation_mode = 'opposite'
  datamodule = 'waterbirds_emb_contexts'
  optimizer = 'adam'
  batch_size = 24 if context_class_size == 256 else 32

  augmentations = [
      {'rotate_encodings': False, 'randomly_swap_labels': False,
       'permute_input_dim': False},
      {'rotate_encodings': False, 'randomly_swap_labels': False,
       'permute_input_dim': True},
  ]

  n_rotation_matrices = 10000

  # 'null' is written verbatim into the override string.
  ask_context_probs = ['null', 0.25]

  lrs = [0.00003, 0.00006, 0.0001]

  val_sets = '[train, train_val, train_test, val]'
  if context_class_size >= 256:
    val_sets = '[train, train_val, train_test]'

  precision = '16-mixed'

  input_layer_norm = False

  args_per_job = []

  # itertools.product yields tuples in the same order as the equivalent
  # nested for-loops (rightmost iterable, the seed, varies fastest).
  job_grid = itertools.product(
      settings.values(), enumerate(augmentations), ask_context_probs, lrs,
      range(n_repeats))
  for setting_dict, (aug_idx, aug), ask_context_prob, lr, seed in job_grid:
    train_len = 2000000
    val_check_interval = 1000
    if aug_idx == 0:
      # The first (non-permuted) augmentation uses 10x fewer training
      # examples and a 4x smaller validation interval.
      train_len = 200000
      val_check_interval = 250

    context_group_proportions = setting_dict['context_group_proportions']
    query_group_proportions = setting_dict['query_group_proportions']
    spurious_setting = setting_dict['spurious_setting']
    use_context_as_intermediate_queries = setting_dict['use_context_as_intermediate_queries']  # pylint: disable=line-too-long

    rotate_encodings = aug['rotate_encodings']
    randomly_swap_labels = aug['randomly_swap_labels']
    permute_input_dim = aug['permute_input_dim']

    args = [
        f'spurious_setting={spurious_setting}',
        f'sp_token_generation_mode={sp_token_generation_mode}',
        f'use_context_as_intermediate_queries={use_context_as_intermediate_queries}',  # pylint: disable=line-too-long
        f'datamodule={datamodule}',
        f'datamodule.reverse_task={reverse_task}',
        f'datamodule.modified={modified}',
        f'datamodule.context_group_proportions={context_group_proportions}',
        f'datamodule.query_group_proportions={query_group_proportions}',
        f'datamodule.context_class_size={context_class_size}',
        f'datamodule.train_len={train_len}',
        f'datamodule.batch_size={batch_size}',
        f'datamodule.rotate_encodings={rotate_encodings}',
        f'datamodule.n_rotation_matrices={n_rotation_matrices}',
        f'datamodule.randomly_swap_labels={randomly_swap_labels}',
        f'datamodule.permute_input_dim={permute_input_dim}',
        f'datamodule.ask_context_prob={ask_context_prob}',
        f'datamodule.val_sets={val_sets}',
        f'encoding_extractor={encoding_extractor}',
        f'model={model}',
        f'model.input_layer_norm={input_layer_norm}',
        f'optimizer={optimizer}',
        f'optimizer.lr={lr}',
        f'trainer.val_check_interval={val_check_interval}',
        f'trainer.precision={precision}',
        f'seed={seed}',
    ]

    args_per_job.append(args)
    print('HYDRA_FULL_ERROR=1 python run.py ', ' '.join(args))

  return args_per_job


def inaturalist(
    context_class_size: int,
    selected_settings: list[str],
    model: str = 'incontext_learner_gptj_80m',
    encoding_extractor: str = 'dinov2_vitb14',
    n_repeats: int = 5,
    ):
  """Experiments for iNaturalist 2017.

  Builds one list of Hydra `key=value` override strings per training job.
  The jobs form the cross product of the following, with seed varying
  fastest:
    - the selected settings;
    - the input augmentations;
    - the `ask_context_prob` values;
    - the learning rates;
    - `n_repeats` seeds.
  Settings whose name contains 'swap' get three extra `datamodule.swapping_*` /
  `points_to_swap_range` overrides appended. Each job's command line is also
  printed.

  Args:
    context_class_size: Value for `datamodule.context_class_size`; must be 64
      or 200.
    selected_settings: Names of the settings defined below to run. Jobs
      follow the definition order of the settings, not the order of this
      list.
    model: Value for the `model` override.
    encoding_extractor: Value for the `encoding_extractor` override.
    n_repeats: Number of seeds (0 .. n_repeats - 1) per configuration.

  Returns:
    A list with one entry per job, each a list of override strings.

  Raises:
    ValueError: If `context_class_size` is unsupported or `selected_settings`
      contains a name that is not a defined setting.
  """
  if context_class_size not in (64, 200):
    raise ValueError(
        f'context_class_size must be 64 or 200, got {context_class_size}.')

  settings = {
      'no_spurious': {
          'context_minority_group_proportion': 0.0,  # doesn't matter
          'query_minority_group_proportion': 0.0,  # doesn't matter
          'spurious_setting': 'inat_no_spurious',
          'use_context_as_intermediate_queries': True,
      },
      # group-agnostic: truly ERM-like
      'group-agnostic': {
          'context_minority_group_proportion': 0.1,
          'query_minority_group_proportion': 0.1,
          'spurious_setting': 'inat_sum_erm',
          'use_context_as_intermediate_queries': True,
      },
      # V2 behavior
      'v2_erm': {
          'context_minority_group_proportion': 0.1,
          'query_minority_group_proportion': 0.5,
          'spurious_setting': 'inat_sum_erm',
          'use_context_as_intermediate_queries': False,
      },
      'v2_dro': {
          'context_minority_group_proportion': 0.1,
          'query_minority_group_proportion': 0.5,
          'spurious_setting': 'inat_sum_dro',
          'use_context_as_intermediate_queries': False,
      },
      # Swap group-agnostic
      'swap-group-agnostic': {
          'context_minority_group_proportion': 0.0,
          'query_minority_group_proportion': 0.0,
          'spurious_setting': 'swap_erm',
          'use_context_as_intermediate_queries': True,
          'swapping_minority_proportion_context': 0.1,
          'swapping_minority_proportion_query': 0.1,
          'points_to_swap_range': '[0, 200]',
      },
      # Swap V2 behavior
      'swap_v2_erm': {
          'context_minority_group_proportion': 0.0,
          'query_minority_group_proportion': 0.0,
          'spurious_setting': 'swap_erm',
          'use_context_as_intermediate_queries': False,
          'swapping_minority_proportion_context': 0.1,
          'swapping_minority_proportion_query': 0.5,
          'points_to_swap_range': '[0, 200]',
      },
      'swap_v2_dro': {
          'context_minority_group_proportion': 0.0,
          'query_minority_group_proportion': 0.0,
          'spurious_setting': 'swap_dro',
          'use_context_as_intermediate_queries': False,
          'swapping_minority_proportion_context': 0.1,
          'swapping_minority_proportion_query': 0.5,
          'points_to_swap_range': '[0, 200]',
      },
  }

  # Reject typos instead of silently generating no jobs for them.
  unknown = [name for name in selected_settings if name not in settings]
  if unknown:
    raise ValueError(f'Unknown iNaturalist settings: {unknown}')
  settings = {k: v for k, v in settings.items() if k in selected_settings}

  sp_token_generation_mode = 'opposite'
  datamodule = 'inaturalist_emb_contexts'
  optimizer = 'adam'
  batch_size = 32

  augmentations = [
      {'rotate_encodings': False, 'permute_input_dim': False},
      {'rotate_encodings': False, 'permute_input_dim': True},
  ]

  n_rotation_matrices = 10000

  # 'null' is written verbatim into the override string.
  ask_context_probs = ['null', 0.25]

  lrs = [0.00003, 0.00006, 0.0001]

  val_sets = '[inner, inner_outer, outer]'

  precision = '16-mixed'

  input_layer_norm = True  # helpful for evaluating on Waterbirds

  # Identical for every job, unlike waterbirds where they depend on the
  # augmentation.
  train_len = 4000000  # 2x compared to waterbirds
  val_check_interval = 1000

  args_per_job = []

  # itertools.product yields tuples in the same order as the equivalent
  # nested for-loops (rightmost iterable, the seed, varies fastest).
  job_grid = itertools.product(
      settings.items(), augmentations, ask_context_probs, lrs,
      range(n_repeats))
  for (setting_name, setting_dict), aug, ask_context_prob, lr, seed in job_grid:
    context_minority_group_proportion = setting_dict['context_minority_group_proportion']  # pylint: disable=line-too-long
    query_minority_group_proportion = setting_dict['query_minority_group_proportion']  # pylint: disable=line-too-long
    spurious_setting = setting_dict['spurious_setting']
    use_context_as_intermediate_queries = setting_dict['use_context_as_intermediate_queries']  # pylint: disable=line-too-long

    rotate_encodings = aug['rotate_encodings']
    permute_input_dim = aug['permute_input_dim']

    args = [
        f'spurious_setting={spurious_setting}',
        f'sp_token_generation_mode={sp_token_generation_mode}',
        f'use_context_as_intermediate_queries={use_context_as_intermediate_queries}',  # pylint: disable=line-too-long
        f'datamodule={datamodule}',
        f'datamodule.context_minority_group_proportion={context_minority_group_proportion}',  # pylint: disable=line-too-long
        f'datamodule.query_minority_group_proportion={query_minority_group_proportion}',  # pylint: disable=line-too-long
        f'datamodule.context_class_size={context_class_size}',
        f'datamodule.train_len={train_len}',
        f'datamodule.batch_size={batch_size}',
        f'datamodule.rotate_encodings={rotate_encodings}',
        f'datamodule.n_rotation_matrices={n_rotation_matrices}',
        f'datamodule.permute_input_dim={permute_input_dim}',
        f'datamodule.ask_context_prob={ask_context_prob}',
        f'datamodule.val_sets={val_sets}',
        f'encoding_extractor={encoding_extractor}',
        f'model={model}',
        f'model.input_layer_norm={input_layer_norm}',
        f'optimizer={optimizer}',
        f'optimizer.lr={lr}',
        f'trainer.val_check_interval={val_check_interval}',
        f'trainer.precision={precision}',
        f'seed={seed}',
    ]

    # Only the swap settings define the swapping-related keys.
    if 'swap' in setting_name:
      swapping_minority_proportion_context = setting_dict['swapping_minority_proportion_context']  # pylint: disable=line-too-long
      swapping_minority_proportion_query = setting_dict['swapping_minority_proportion_query']  # pylint: disable=line-too-long
      points_to_swap_range = setting_dict['points_to_swap_range']
      args.extend([
          f'datamodule.swapping_minority_proportion_context={swapping_minority_proportion_context}',  # pylint: disable=line-too-long
          f'datamodule.swapping_minority_proportion_query={swapping_minority_proportion_query}',  # pylint: disable=line-too-long
          f'datamodule.points_to_swap_range={points_to_swap_range}',
      ])

    args_per_job.append(args)
    print('HYDRA_FULL_ERROR=1 python run.py ', ' '.join(args))

  return args_per_job


## Waterbirds experiments
# Both required arguments (context_class_size, selected_settings) are bound,
# so every registered entry can be invoked with no arguments.
experiment_registry.update({
    'final_waterbirds_64': partial(
        waterbirds, context_class_size=64,
        selected_settings=['group-agnostic', 'v1', 'v2_erm', 'v2_dro']),
    'final_waterbirds_256': partial(
        waterbirds, context_class_size=256,
        selected_settings=['group-agnostic', 'v2_erm', 'v2_dro']),
    'final_mod_waterbirds_64': partial(
        waterbirds, context_class_size=64,
        selected_settings=['group-agnostic', 'v1', 'v2_erm', 'v2_dro'],
        modified=True),
    'final_mod_waterbirds_256': partial(
        waterbirds, context_class_size=256,
        selected_settings=['group-agnostic', 'v2_erm', 'v2_dro'],
        modified=True),
})


## iNaturalist experiments
experiment_registry.update({
    'final_inat_no_spurious_64': partial(
        inaturalist, context_class_size=64,
        selected_settings=['no_spurious']),
    'final_inat_no_spurious_200': partial(
        inaturalist, context_class_size=200,
        selected_settings=['no_spurious']),
    'final_inat_swap_64': partial(
        inaturalist, context_class_size=64,
        selected_settings=['swap-group-agnostic', 'swap_v2_erm',
                           'swap_v2_dro']),
    'final_inat_swap_200': partial(
        inaturalist, context_class_size=200,
        selected_settings=['swap-group-agnostic', 'swap_v2_erm',
                           'swap_v2_dro']),
})
