import random
from collections import OrderedDict
from collections.abc import Mapping
from typing import Any

import numpy as np
import jax
import jax.numpy as jnp
import optax
from flax.core.frozen_dict import freeze
from flax.training import train_state
from t5x import optimizers


def update_flax_module():
    """Monkey-patch Flax / t5x so parameters get sharding constraints and axis metadata.

    This assigns new functions to module/class attributes, so the effect is
    global and applies to every later call in the process:

    * ``nn.Module.param`` is replaced. Every parameter it creates is passed
      through ``nn.partitioning.with_sharding_constraint`` with a
      ``PartitionSpec`` of generated axis names ``('r0', 'r1', ...)`` (one per
      dimension). An ``AxisMetadata`` with the same names is sown into the
      ``'params_axes'`` collection under ``f'{name}_axes'``.
    * ``param_with_axes`` in both ``t5x.examples.t5.layers`` and
      ``nn.partitioning`` is replaced. The new version calls the
      ``nn.Module.param`` that was in place before this function ran. It only
      applies and records the caller-supplied ``axes`` when they are not ``None``.

    NOTE(review): not idempotent. ``old_param`` is captured at call time, so a
    second call wraps the already-patched ``param`` again. Call it only once.
    NOTE(review): how the generated ``r{i}`` names map onto mesh axes depends
    on logical-axis rules configured elsewhere. Confirm against the caller.
    """
    import flax.linen as nn
    from flax.linen.partitioning import AxisMetadata, _param_with_axes_sow_reduce_fn
    from jax.experimental import pjit
    import t5x.examples.t5.layers as layers

    # Keep a handle to the current implementation; both replacements delegate to it.
    old_param = nn.Module.param
    def param(self, name, init_fn, *init_args, module=None):
        """Replacement for ``nn.Module.param`` that constrains every dimension.

        NOTE(review): ``self`` is not used. The parameter is registered on
        ``module``, or on the innermost module of Flax's (private) module stack
        when ``module`` is None. In normal method-call use that is presumably
        ``self``; confirm.
        NOTE(review): keyword arguments other than ``module`` are not accepted.
        Verify this matches the installed Flax's ``Module.param`` signature.
        """
        if module is None:
            module = nn.module._context.module_stack[-1]  # pylint: disable=protected-access
            assert module is not None

        module_param = old_param(module, name, init_fn, *init_args)
        # One generated axis name per dimension: 'r0', 'r1', ...
        axes = tuple([f'r{i}' for i in range(len(module_param.shape))])
        # apply logical axis constraint immediately
        module_param = nn.partitioning.with_sharding_constraint(module_param,
                                                 pjit.PartitionSpec(*axes))
        # record logical axis constraint for global axis metadata
        module.sow(
            'params_axes', f'{name}_axes', AxisMetadata(axes),
            reduce_fn=_param_with_axes_sow_reduce_fn)
        return module_param

    def param_with_axes(name, init_fn, *init_args, axes=None, module=None):
        """Replacement for ``param_with_axes`` that calls the pre-patch ``Module.param``.

        Calling ``old_param`` directly bypasses the patched ``param`` above,
        so only the caller-supplied ``axes`` (when not None) are applied and recorded.
        """
        if module is None:
            module = nn.module._context.module_stack[-1]  # pylint: disable=protected-access
            assert module is not None

        module_param = old_param(module, name, init_fn, *init_args)

        if axes is not None:
            # apply logical axis constraint immediately
            module_param = nn.partitioning.with_sharding_constraint(module_param,
                                                     pjit.PartitionSpec(*axes))
            # record logical axis constraint for global axis metadata
            module.sow(
                'params_axes', f'{name}_axes', AxisMetadata(axes),
                reduce_fn=_param_with_axes_sow_reduce_fn)
        return module_param

    # Install the patches globally.
    nn.Module.param = param
    layers.param_with_axes = param_with_axes
    nn.partitioning.param_with_axes = param_with_axes


class TrainState(train_state.TrainState):
    """Flax ``TrainState`` that also carries the model's non-parameter variables.

    ``init_model_state`` stores every variable collection returned by
    ``model.init`` except ``'params'`` in ``model_state``.
    """

    # Non-'params' variable collections from model initialisation.
    model_state: Any

def seed_all(seed):
    """Seed Python's ``random`` module and NumPy's global RNG with ``seed``.

    JAX's key-based PRNG is not touched here; callers derive JAX keys separately.
    """
    for seeder in (random.seed, np.random.seed):
        seeder(seed)


def get_first_device(x):
    """Take element ``[0]`` of every leaf of the pytree ``x`` and fetch it to host.

    Presumably ``x`` is a replicated/pmapped tree whose leading axis indexes
    devices, so this returns the first device's copy. Confirm with the callers.

    Uses ``jax.tree_util.tree_map`` because the ``jax.tree_map`` alias is
    deprecated and removed in recent JAX releases.
    """
    first = jax.tree_util.tree_map(lambda leaf: leaf[0], x)
    return jax.device_get(first)


def get_all_devices(x):
    """Merge the first two axes of every leaf of ``x`` and fetch the result to host.

    Each leaf of shape ``(a, b, *rest)`` becomes ``(a * b, *rest)``. Leaves
    with fewer than two dimensions are flattened to 1-D. Presumably the first
    two axes are (device, per-device batch); confirm with the callers.

    Uses ``jax.tree_util.tree_map`` because the ``jax.tree_map`` alias is
    deprecated and removed in recent JAX releases.
    """
    merged = jax.tree_util.tree_map(
        lambda leaf: jnp.reshape(leaf, (-1,) + leaf.shape[2:]), x)
    return jax.device_get(merged)
 

def print_model_size(params, name=''):
    """Print and return the total number of scalar entries in ``params``.

    Args:
        params: Pytree whose leaves expose ``.size`` (NumPy/JAX arrays).
        name: Optional label. When empty, the output is exactly
            ``model parameter count: <N>``. Otherwise it is prefixed with
            ``"<name> "``. Previously this argument was accepted but ignored.

    Returns:
        The summed element count over all leaves (0 for an empty tree).
    """
    # jax.tree_util.tree_leaves replaces the deprecated jax.tree_map /
    # jax.tree_flatten pair, which recent JAX releases removed.
    total_params_size = sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))
    label = f'{name} ' if name else ''
    print(f'{label}model parameter count:', total_params_size)
    return total_params_size


def generate_weight_decay_mask(params, parent=None,
                               no_decay_keys=('bias', 'embedding', 'abs_embedding',
                                              'd_0', 'd_1', 'd_2')):
    """Build a boolean tree marking which parameters should receive weight decay.

    A leaf maps to ``False`` when the key directly holding it is in
    ``no_decay_keys``, and to ``True`` otherwise.

    Any ``collections.abc.Mapping`` is traversed, not only ``dict``. Flax's
    ``FrozenDict`` is not a ``dict`` subclass. Previously a frozen parameter
    tree collapsed to a single ``True``, which applied decay to biases and
    embeddings too.

    Args:
        params: Nested mapping of parameters (or a single leaf).
        parent: Key under which ``params`` was found; ``None`` at the top level.
        no_decay_keys: Leaf key names excluded from weight decay.

    Returns:
        A tree with the same keys as ``params`` and bool leaves. ``dict``
        inputs (including subclasses) produce a plain ``dict``, as before.
        Other mapping types are rebuilt with their own constructor so the
        mask's pytree node types match ``params`` (needed by e.g.
        ``optax.masked``).
    """
    if not isinstance(params, Mapping):
        # Leaf: decide by the name of the key that holds it.
        return parent not in no_decay_keys

    mask = {key: generate_weight_decay_mask(value, key, no_decay_keys)
            for key, value in params.items()}
    if isinstance(params, dict):
        return mask
    try:
        return type(params)(mask)
    except TypeError:
        # This mapping type cannot be built from a dict; fall back to a plain dict.
        return mask


def get_learning_rate_fn(config):
    """Build the learning-rate schedule selected by ``config.lr_schedule``.

    Both schedules ramp linearly from 0 to ``config.lr`` over
    ``config.warmup_steps`` steps. Afterwards:

    * ``'cosine'``: cosine-decay to 0, reaching it at ``config.total_steps``.
    * ``'constant'``: stay at ``config.lr``.

    Args:
        config: Provides ``lr_schedule``, ``lr``, ``warmup_steps`` and, for
            the cosine schedule, ``total_steps``.

    Returns:
        An optax schedule mapping a step count to a learning rate.

    Raises:
        ValueError: if ``config.lr_schedule`` is not a known schedule name.
    """
    if config.lr_schedule == 'cosine':
        # optax's ``decay_steps`` is the length of the WHOLE schedule, warmup
        # included; the cosine phase lasts ``decay_steps - warmup_steps``.
        # Passing ``total_steps - warmup_steps`` ended the decay
        # ``warmup_steps`` too early. It also left a non-positive cosine phase
        # when ``total_steps <= 2 * warmup_steps``.
        return optax.warmup_cosine_decay_schedule(
            init_value=0.,
            peak_value=config.lr,
            warmup_steps=config.warmup_steps,
            decay_steps=config.total_steps,
        )
    if config.lr_schedule == 'constant':
        return optax.join_schedules([
            optax.linear_schedule(
                init_value=0.,
                end_value=config.lr,
                transition_steps=config.warmup_steps
            ),
            optax.constant_schedule(config.lr)
        ], [config.warmup_steps])
    raise ValueError(f'Unknown schedule: {config.lr_schedule}')

    
def get_optimizer(config, param_shapes, use_optax=False):
    """Build the AdamW optimizer and its learning-rate schedule.

    When ``config.clip_grad_norm`` is truthy, gradients are first clipped to
    that global norm and then fed to AdamW with ``eps=1e-4``. Otherwise a
    bare AdamW with ``b1=0.9, b2=0.95`` is used.

    NOTE(review): the two paths use different Adam hyperparameters. The
    clipped path leaves the betas at the library defaults and sets eps=1e-4.
    The unclipped path sets b1/b2 and leaves eps at the library default.
    Confirm that toggling clipping is meant to change these.

    Args:
        config: Provides ``clip_grad_norm``, ``weight_decay`` and the fields
            read by ``get_learning_rate_fn``.
        param_shapes: Unused; kept so existing call sites keep working.
        use_optax: Build plain optax transformations when True. Otherwise use
            the t5x ``optimizers`` wrappers.

    Returns:
        ``(tx, learning_rate_fn)``.
    """
    del param_shapes  # unused; retained for signature compatibility
    schedule = get_learning_rate_fn(config)
    max_norm = config.clip_grad_norm

    if not max_norm:
        make_adamw = optax.adamw if use_optax else optimizers.adamw
        return make_adamw(learning_rate=schedule, b1=0.9, b2=0.95,
                          weight_decay=config.weight_decay), schedule

    stages = [
        optax.clip_by_global_norm(max_norm),
        optax.adamw(learning_rate=schedule, eps=1e-4,
                    weight_decay=config.weight_decay),
    ]
    # The t5x wrapper takes the transformations as one sequence, optax as varargs.
    tx = optax.chain(*stages) if use_optax else optimizers.chain(stages)
    return tx, schedule


def init_model_state(rng_key, model, sample, config):
    """Initialise ``model`` and wrap its variables in a ``TrainState``.

    Args:
        rng_key: PRNG key passed for the ``'params'`` stream and for every
            stream named in ``config.rng_keys``.
        model: Flax module; ``model.init`` builds the variables and
            ``model.apply`` becomes the state's ``apply_fn``.
        sample: Mapping of example inputs. Only the entries named in
            ``config.batch_keys`` are forwarded to ``model.init`` as keyword
            arguments.
        config: Provides ``rng_keys``, ``batch_keys`` and the fields
            read by ``get_optimizer``.

    Returns:
        ``(state, learning_rate_fn)``. ``state.params`` is the frozen
        ``'params'`` collection. ``state.model_state`` is a dict of all
        remaining collections. ``state.tx`` is a plain-optax AdamW optimizer.
    """
    # Some Flax releases return a plain dict from ``init``, and a plain dict
    # has no ``.unfreeze()`` method. The functional ``unfreeze`` accepts both
    # FrozenDict and dict.
    from flax.core.frozen_dict import unfreeze

    # NOTE(review): every rng stream receives the identical key. Confirm that
    # is intended rather than splitting ``rng_key`` per stream.
    variables = unfreeze(model.init(
        rngs={k: rng_key for k in ['params', *config.rng_keys]},
        **{k: sample[k] for k in config.batch_keys}
    ))
    params = freeze(variables.pop('params'))
    model_state = variables
    print_model_size(params)

    tx, learning_rate_fn = get_optimizer(config, params, use_optax=True)

    return TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=tx,
        model_state=model_state
    ), learning_rate_fn

class AverageMeter(object):
    """Keep the latest value of a metric together with its running weighted mean.

    ``str(meter)`` renders ``"<name> <val> (<avg>)"``, with both numbers
    formatted using the ``fmt`` spec (e.g. ``':6.3f'``).
    """

    def __init__(self, name, fmt=':f'):
        self.name, self.fmt = name, fmt
        self.reset()

    def reset(self):
        """Zero the latest value, the running mean, the running sum and the count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, counted as ``n`` samples."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        template = ''.join(['{name} {val', self.fmt, '} ({avg', self.fmt, '})'])
        return template.format(**vars(self))


class ProgressMeter(object):
    """Collect named ``AverageMeter``s and print one tab-separated progress line.

    ``display(i)`` prints ``prefix`` followed by the iteration as
    ``[  i/total]``, then each meter's string, all joined with tabs.
    """

    def __init__(self, total_iters, meter_names, prefix=""):
        self.iter_fmtstr = self._get_iter_fmtstr(total_iters)
        self.meters = OrderedDict(
            (meter_name, AverageMeter(meter_name, ':6.3f'))
            for meter_name in meter_names
        )
        self.prefix = prefix

    def update(self, n=1, **kwargs):
        """Feed each keyword value to the meter of the same name, weighted by ``n``."""
        for meter_name, value in kwargs.items():
            self.meters[meter_name].update(value, n=n)

    def display(self, iteration):
        """Print the progress line for ``iteration``."""
        header = self.prefix + self.iter_fmtstr.format(iteration)
        print('\t'.join([header, *map(str, self.meters.values())]))

    def _get_iter_fmtstr(self, total_iters):
        # Pad the current iteration to the width of the total, e.g. '[{:3d}/100]'.
        width = len(str(total_iters // 1))
        slot = f'{{:{width}d}}'
        return f'[{slot}/{slot.format(total_iters)}]'
