from clu import metrics
import jax
import flax
from jax._src.basearray import Array as ndarray
import numpy as np
from scipy.special import softmax
import jax.numpy as jnp
from flax.training import train_state  # Useful dataclass to keep train state
from flax import struct                # Flax dataclasses
import optax                           # Common loss functions and optimizers

@struct.dataclass
class Metrics(metrics.Collection):
    """CLU metric collection carried by `TrainState`.

    Every field is a clu `metrics.Average` bound via `from_output` to one key
    of the model-output values handed to the collection. The string keys
    (e.g. 'tpr_top_level') must match the keyword names the train/eval step
    supplies when it builds metrics from its outputs.
    """
    # Disabled field, kept for reference: accuracy: metrics.Accuracy
    loss: metrics.Average.from_output('loss')
    # NOTE(review): TPR/TNR presumably mean true-positive / true-negative rate,
    # reported at three levels (top/mid/bottom) -- confirm against the step
    # function that emits these keys.
    TPR_top_level: metrics.Average.from_output('tpr_top_level')
    TPR_mid_level: metrics.Average.from_output('tpr_mid_level')
    TPR_bottom_level: metrics.Average.from_output('tpr_bottom_level')
    TNR_top_level: metrics.Average.from_output('tnr_top_level')
    TNR_mid_level: metrics.Average.from_output('tnr_mid_level')
    TNR_bottom_level: metrics.Average.from_output('tnr_bottom_level')
    # NOTE(review): meaning of the `_cont` suffix is not visible in this file
    # (possibly a continuous/soft variant of the rates above) -- confirm.
    TPR_top_level_cont: metrics.Average.from_output('tpr_top_level_cont')
    TPR_mid_level_cont: metrics.Average.from_output('tpr_mid_level_cont')
    TPR_bottom_level_cont: metrics.Average.from_output('tpr_bottom_level_cont')
    TNR_top_level_cont: metrics.Average.from_output('tnr_top_level_cont')
    TNR_mid_level_cont: metrics.Average.from_output('tnr_mid_level_cont')
    TNR_bottom_level_cont: metrics.Average.from_output('tnr_bottom_level_cont')
    # NOTE(review): "OCS" is not defined in this file -- document what distance
    # this is once confirmed with the producing code.
    mean_distance_OCS: metrics.Average.from_output('mean_distance_OCS')

      
class TrainState(train_state.TrainState):
    """Flax `TrainState` extended with a `Metrics` collection.

    The extra `metrics` field lets metric accumulators travel together with
    the parameters and optimizer state; `create_train_state` initializes it
    with `Metrics.empty()`.
    """
    metrics: Metrics

def create_train_state(module, rng, learning_rate, momentum=None, input_size=(1, 8),
                       optimizer="sgd"):
    """Create an initial `TrainState` for `module`.

    Args:
        module: Flax module. Its `init` is called on a dummy input to build the
            parameters, and its `apply` becomes the state's `apply_fn`.
        rng: PRNG key passed to `module.init`.
        learning_rate: Learning rate handed to the optax optimizer.
        momentum: SGD momentum; `None` means plain SGD. Ignored for Adam.
        input_size: Shape of the all-ones dummy input used only to trace the
            module during initialization.
        optimizer: ``"sgd"`` (default, the previous hard-coded behaviour) or
            ``"adam"``.

    Returns:
        A `TrainState` holding the initialized parameters, the optimizer and an
        empty `Metrics` collection.

    Raises:
        ValueError: If `optimizer` is not a supported name.
    """
    # Pick the optimizer first so an invalid name fails before the (possibly
    # expensive) parameter initialization runs.
    if optimizer == "sgd":
        tx = optax.sgd(learning_rate, momentum=momentum)
    elif optimizer == "adam":
        tx = optax.adam(learning_rate=learning_rate)
    else:
        raise ValueError(
            f"Unsupported optimizer {optimizer!r}; expected 'sgd' or 'adam'."
        )
    # Only the 'params' collection of the initialized variables is kept.
    params = module.init(rng, jnp.ones(input_size))['params']
    return TrainState.create(
        apply_fn=module.apply, params=params, tx=tx,
        metrics=Metrics.empty())


# CelebA-specific metrics, train state and train-state factory.
@struct.dataclass
class Metrics_celeba(metrics.Collection):
    """CLU metric collection carried by `TrainStateCeleba`.

    Tracks only two values, each a clu `metrics.Average` bound via
    `from_output` to the named key of the model-output values passed to the
    collection: the loss and the mean OCS distance.
    """
    # Disabled field, kept for reference: accuracy: metrics.Accuracy
    loss: metrics.Average.from_output('loss')
    # NOTE(review): "OCS" is not defined in this file -- document what distance
    # this is once confirmed with the producing code.
    mean_distance_OCS: metrics.Average.from_output('mean_distance_OCS')

class TrainStateCeleba(train_state.TrainState):
    """Flax `TrainState` extended with a `Metrics_celeba` collection.

    `create_train_state_celeba` initializes the `metrics` field with
    `Metrics_celeba.empty()`.
    """
    metrics: Metrics_celeba

def create_train_state_celeba(module, rng, learning_rate, momentum=None, input_size=(1, 8),
                              optimizer="sgd"):
    """Create an initial `TrainStateCeleba` for `module`.

    Args:
        module: Flax module. Its `init` is called on a dummy input to build the
            parameters, and its `apply` becomes the state's `apply_fn`.
        rng: PRNG key passed to `module.init`.
        learning_rate: Learning rate handed to the optax optimizer.
        momentum: SGD momentum; `None` means plain SGD. Ignored for Adam.
        input_size: Shape of the all-ones dummy input used only to trace the
            module during initialization.
        optimizer: ``"sgd"`` (default) or ``"adam"``.

    Returns:
        A `TrainStateCeleba` holding the initialized parameters, the optimizer
        and an empty `Metrics_celeba` collection.

    Raises:
        ValueError: If `optimizer` is not a supported name. Previously an
            unknown name left `tx` unbound and crashed with UnboundLocalError.
    """
    # Validate/build the optimizer before initializing parameters so a typo in
    # `optimizer` fails fast with a clear message.
    if optimizer == "sgd":
        tx = optax.sgd(learning_rate, momentum=momentum)
    elif optimizer == "adam":
        tx = optax.adam(learning_rate=learning_rate)
    else:
        raise ValueError(
            f"Unsupported optimizer {optimizer!r}; expected 'sgd' or 'adam'."
        )
    # Only the 'params' collection of the initialized variables is kept.
    params = module.init(rng, jnp.ones(input_size))['params']
    return TrainStateCeleba.create(
        apply_fn=module.apply, params=params, tx=tx,
        metrics=Metrics_celeba.empty())