import copy
import functools
import logging
from typing import Tuple, Type

from jaxpruner import algorithms
from jaxpruner import base_updater
from jaxpruner import sparsity_distributions
from jaxpruner import sparsity_schedules
from jaxpruner import sparsity_types
import ml_collections
import optax


# Maps the `algorithm` name used in sparsity configs to the updater class that
# implements it. register_algorithm() adds entries to this dict at runtime.
ALGORITHM_REGISTRY = dict(
    no_prune=base_updater.NoPruning,
    magnitude=algorithms.MagnitudePruning,
    random=algorithms.RandomPruning,
    saliency=algorithms.SaliencyPruning,
    magnitude_ste=algorithms.SteMagnitudePruning,
    random_ste=algorithms.SteRandomPruning,
    global_magnitude=algorithms.GlobalMagnitudePruning,
    global_saliency=algorithms.GlobalSaliencyPruning,
    static_sparse=algorithms.StaticRandomSparse,
    rigl=algorithms.RigL,
    set=algorithms.SET,
)
# Snapshot of the names registered at import time. Entries added later via
# register_algorithm() are not reflected here; all_algorithm_names() reads the
# registry on every call.
ALGORITHMS = tuple(ALGORITHM_REGISTRY)


def all_algorithm_names():
  """Returns the names currently in ALGORITHM_REGISTRY as a tuple.

  The registry is read on every call, so the result includes algorithms added
  after import, in insertion order.
  """
  return tuple(ALGORITHM_REGISTRY)


def register_algorithm(algorithm_name, algorithm):
  """Adds an algorithm to ALGORITHM_REGISTRY under the given name.

  An existing entry with the same name is overwritten. The module-level
  ALGORITHMS tuple is not refreshed by this call.

  Args:
    algorithm_name: Key under which the algorithm is stored.
    algorithm: Object stored for that key; create_updater_from_config calls it
      with keyword arguments to build the updater.
  """
  ALGORITHM_REGISTRY.update({algorithm_name: algorithm})


def _parse_sparsity_type(s_type):
  """Converts a sparsity type spec string into a `sparsity_types` object.

  Args:
    s_type: Spec of the form 'nm_<n>,<m>', 'block_<n>,<m>' or 'channel_<axis>'.

  Returns:
    sparsity_types.NByM(n, m), sparsity_types.Block(block_shape=(n, m)) or
    sparsity_types.Channel(axis=axis), depending on the prefix.

  Raises:
    ValueError: If `s_type` is not a string with a supported prefix, or if the
      part after the first underscore cannot be parsed into integers.
  """
  if not isinstance(s_type, str):
    raise ValueError(f'Sparsity type {s_type} is not supported.')

  if s_type.startswith(('nm', 'block')):
    # Examples: nm_2,4 and block_4,4
    try:
      n, m = (int(v) for v in s_type.split('_')[1].strip().split(','))
    except (IndexError, ValueError) as err:
      raise ValueError(
          f'Sparsity type {s_type} is malformed; expected e.g. `nm_2,4` or'
          ' `block_4,4`.'
      ) from err
    if s_type.startswith('nm'):
      return sparsity_types.NByM(n, m)
    return sparsity_types.Block(block_shape=(n, m))

  if s_type.startswith('channel'):
    # Example: channel_0
    try:
      axis = int(s_type.split('_')[1])
    except (IndexError, ValueError) as err:
      raise ValueError(
          f'Sparsity type {s_type} is malformed; expected e.g. `channel_0`.'
      ) from err
    return sparsity_types.Channel(axis=axis)

  raise ValueError(f'Sparsity type {s_type} is not supported.')


def _make_scheduler(config):
  """Builds the sparsity schedule selected by the update-step fields.

  Args:
    config: Config providing `update_start_step`, `update_end_step`,
      `update_freq` and optionally `schedule_power`.

  Returns:
    NoUpdateSchedule when `update_start_step` is missing or None,
    OneShotSchedule when start and end steps are equal, otherwise a
    PolynomialSchedule.
  """
  start_step = config.get('update_start_step', None)
  if start_step is None:
    return sparsity_schedules.NoUpdateSchedule()
  if config.update_end_step == start_step:
    return sparsity_schedules.OneShotSchedule(
        target_step=config.update_end_step
    )

  schedule_power = config.get('schedule_power', None)
  schedule_kwargs = dict(
      update_freq=config.update_freq,
      update_start_step=start_step,
      update_end_step=config.update_end_step,
  )
  try:
    return sparsity_schedules.PolynomialSchedule(
        power=1 if schedule_power is None else schedule_power,
        **schedule_kwargs,
    )
  except TypeError:
    # Older PolynomialSchedule versions do not take `power`; retry without it
    # so they keep working. Any unrelated TypeError is raised again by the
    # retry below.
    if schedule_power is not None:
      logging.warning(
          'PolynomialSchedule does not accept `power`; ignoring'
          ' schedule_power=%s.',
          schedule_power,
      )
    return sparsity_schedules.PolynomialSchedule(**schedule_kwargs)


def create_updater_from_config(sparsity_config):
  """Gets a sparsity updater based on the given sparsity config.

  The config is deep-copied before processing, so the caller's object is not
  modified.

  Example:
    sparsity_config = ml_collections.ConfigDict()
    # Required
    sparsity_config.algorithm = 'magnitude'
    sparsity_config.dist_type = 'erk'
    sparsity_config.sparsity = 0.8

    # Optional
    sparsity_config.update_freq = 10
    sparsity_config.update_end_step = 1000
    sparsity_config.update_start_step = 200
    sparsity_config.schedule_power = 1.0  # polynomial schedule power
    sparsity_config.sparsity_type = 'nm_2,4'

    updater = create_updater_from_config(sparsity_config)

  Fields:
    - algorithm: str, one of all_algorithm_names(). 'no_prune' returns
      base_updater.NoPruning() immediately and ignores every other field.
    - dist_type: 'erk' or 'uniform'.
    - sparsity: target sparsity (e.g. 0.8). Bound to the distribution function
      for non-global algorithms; forwarded to the algorithm constructor for
      'global_*' algorithms.
    - filter_fn, custom_sparsity_map: optional. When set, bound to the
      distribution function for non-global algorithms; otherwise left in the
      config and forwarded to the algorithm constructor.
    - update_start_step: if missing or None -> NoUpdateSchedule; if equal to
      update_end_step -> OneShotSchedule; otherwise -> PolynomialSchedule.
    - update_end_step: int. Also used as the decay length of the cosine
      drop-fraction schedule for 'rigl' and 'set'.
    - update_freq: int, passed to PolynomialSchedule.
    - schedule_power: optional exponent for PolynomialSchedule. 1 is used when
      it is not set.
    - drop_fraction: optional initial drop fraction for 'rigl' and 'set'
      (default 0.1).
    - sparsity_type: 'nm_2,4', 'block_4,4', 'channel_<axis>', etc.
    - Any extra fields are passed to the algorithm constructor.

  Args:
    sparsity_config: Config (e.g. ml_collections.ConfigDict) with the fields
      described above.

  Returns:
    An instance of the updater registered under `sparsity_config.algorithm`.

  Raises:
    ValueError: If `dist_type`, `sparsity_type` or `algorithm` is not
      supported.
  """
  logging.info('Creating  updater for %s', sparsity_config.algorithm)
  if sparsity_config.algorithm == 'no_prune':
    return base_updater.NoPruning()

  # Work on an unlocked private copy: fields are deleted as they are consumed,
  # and whatever remains is forwarded to the algorithm constructor.
  config = copy.deepcopy(sparsity_config).unlock()

  # Distribution fn
  if config.dist_type == 'uniform':
    dist_fn = sparsity_distributions.uniform
  elif config.dist_type == 'erk':
    dist_fn = sparsity_distributions.erk
  else:
    raise ValueError(
        f'dist_type: {config.dist_type} is not supported. Use `erk` or `uniform`'
    )
  del config.dist_type

  if config.algorithm.startswith('global_'):
    # Distribution function is not used. `sparsity`, `filter_fn` and
    # `custom_sparsity_map` stay in the config for the algorithm constructor.
    if 'sparsity_distribution_fn' in config:
      del config.sparsity_distribution_fn
  else:
    dist_kwargs = {'sparsity': config.sparsity}
    del config.sparsity
    for optional_name in ('filter_fn', 'custom_sparsity_map'):
      # Only truthy values are bound; a falsy value stays in the config.
      if config.get(optional_name, None):
        dist_kwargs[optional_name] = config[optional_name]
        del config[optional_name]
    config.sparsity_distribution_fn = functools.partial(dist_fn, **dist_kwargs)

  # Sparsity type parsing
  if config.get('sparsity_type', None):
    parsed_type = _parse_sparsity_type(config.sparsity_type)
    # Delete the string entry before storing the parsed object (presumably
    # needed because ConfigDict type-checks overwrites of existing fields).
    del config.sparsity_type
    config.sparsity_type = parsed_type

  # Algorithm
  algorithm = config.algorithm
  if algorithm not in ALGORITHM_REGISTRY:
    raise ValueError(
        f"Sparsity algorithm not supported. Choose from {all_algorithm_names()}"
    )
  updater_type = ALGORITHM_REGISTRY[algorithm]
  if algorithm in ('rigl', 'set'):
    config.drop_fraction_fn = optax.cosine_decay_schedule(
        config.get('drop_fraction', 0.1), config.update_end_step
    )
  del config.algorithm

  # Scheduler (with optional polynomial power)
  config.scheduler = _make_scheduler(config)

  # Clean up consumed fields so they are not passed to the constructor.
  for field_name in (
      'update_freq',
      'update_start_step',
      'update_end_step',
      'drop_fraction',
      'schedule_power',
  ):
    if hasattr(config, field_name):
      delattr(config, field_name)

  return updater_type(**config)
