from jax.tree_util import tree_flatten
from flax.training.train_state import TrainState
import jax.numpy as jnp
import os
import jax

def schedule_to_lr(iteration, epoch_schedule, total_iters, log_param=False):
    """Piecewise-linearly interpolate a learning rate from per-epoch knot values.

    ``epoch_schedule`` holds K >= 2 knot values. Training is split into K - 1
    equal segments ("epochs") of ``total_iters // (K - 1)`` iterations each.
    Within segment ``e`` the learning rate moves linearly from
    ``epoch_schedule[e]`` (at the segment start) to ``epoch_schedule[e + 1]``
    (at the segment end), so the schedule is continuous across segment
    boundaries. Iterations at or beyond the end of the last full segment
    return the final knot value.

    Only ``jnp`` operations are used on ``iteration``, so it may be a traced
    value under ``jax.jit``, and the result is differentiable with respect to
    ``epoch_schedule``.

    Args:
        iteration: Current 0-based iteration (Python int or integer JAX scalar).
        epoch_schedule: Sequence or 1-D array of K knot values.
        total_iters: Total number of training iterations. Assumed to be a
            static Python int with ``total_iters >= K - 1``; otherwise the
            segment length is 0 and the division below is undefined.
        log_param: If True, ``epoch_schedule`` holds log learning rates and is
            exponentiated before interpolation.

    Returns:
        The interpolated learning rate as a JAX scalar.

    Raises:
        ValueError: If ``epoch_schedule`` has fewer than two knot values.
    """
    if log_param:
        epoch_schedule = jnp.exp(epoch_schedule)
    # Convert so indexing with a (possibly traced) JAX integer works for lists too.
    epoch_schedule = jnp.asarray(epoch_schedule)

    num_segments = len(epoch_schedule) - 1
    if num_segments < 1:
        raise ValueError(
            f'epoch_schedule needs at least 2 values, got {len(epoch_schedule)}'
        )
    iters_per_epoch = total_iters // num_segments

    # Clamp the segment index so the remainder iterations (when total_iters is
    # not divisible) and anything past the end stay within the last segment
    # instead of indexing beyond the schedule.
    current_epoch = jnp.clip(iteration // iters_per_epoch, 0, num_segments - 1)
    # Fraction of the way through the current segment; clamped so iterations
    # past the end hold the final knot value.
    epoch_progress = jnp.clip(
        (iteration - current_epoch * iters_per_epoch) / iters_per_epoch, 0.0, 1.0
    )
    # Standard lerp: weight (1 - p) on the segment's start knot, p on its end knot.
    return (
        epoch_schedule[current_epoch] * (1 - epoch_progress)
        + epoch_schedule[current_epoch + 1] * epoch_progress
    )

# def set_seed(use_tf32):
#     # if not use_tf32:
#     # jax.config.update('jax_default_matmul_precision', 'float32')

#     jax.config.update('jax_enable_x64', True)

#     os.environ['TF_CUDNN_DETERMINISTIC'] = '1'
#     os.environ['TF_DETERMINISTIC_OPS'] = '1'
#     os.environ['XLA_FLAGS'] = '--xla_gpu_deterministic_ops=true'
#     pass

def set_dtype(dty, determinism):
    """Configure JAX's default matmul precision and optionally request determinism.

    Args:
        dty: ``'float32'`` sets ``jax_default_matmul_precision`` to
            ``'float32'``; ``'tf32'`` sets it to ``'tensorfloat32'``.
            ``'float64'`` is recognised but rejected.
        determinism: If True, set ``TF_CUDNN_DETERMINISTIC`` and
            ``TF_DETERMINISTIC_OPS`` to ``'1'``, and add
            ``--xla_gpu_deterministic_ops=true`` to ``XLA_FLAGS``. Existing
            ``XLA_FLAGS`` entries are preserved, and the flag is not duplicated.

    Raises:
        ValueError: If ``dty`` is ``'float64'`` (not supported) or is not a
            recognised value. The error is raised before any environment
            variable or ``jax.config`` option is modified.
    """
    # Validate up front so a rejected dtype leaves global state untouched
    # (previously 'float64' enabled jax_enable_x64 and then raised).
    if dty == 'float64':
        raise ValueError('float64 not supported')
    precision_by_dtype = {'float32': 'float32', 'tf32': 'tensorfloat32'}
    if dty not in precision_by_dtype:
        raise ValueError(f'Unknown dtype {dty}')

    if determinism:
        os.environ['TF_CUDNN_DETERMINISTIC'] = '1'
        os.environ['TF_DETERMINISTIC_OPS'] = '1'
        # Append rather than overwrite so XLA flags set elsewhere survive.
        # NOTE(review): XLA presumably reads XLA_FLAGS when the backend is
        # initialised, so this likely must run before the first JAX computation.
        det_flag = '--xla_gpu_deterministic_ops=true'
        existing_flags = os.environ.get('XLA_FLAGS', '')
        if det_flag not in existing_flags.split():
            os.environ['XLA_FLAGS'] = f'{existing_flags} {det_flag}'.strip()

    print(f'>> Setting dtype={dty}, determinism={determinism}')

    jax.config.update('jax_default_matmul_precision', precision_by_dtype[dty])

def vec_params(state: TrainState):
    """Flatten every leaf of ``state.params`` into a single 1-D JAX array.

    Leaves are taken in ``tree_flatten`` order. Each leaf is flattened to 1-D,
    and the results are concatenated end to end. ``jnp.concatenate`` raises on
    an empty list, so a params tree with no leaves is an error.
    """
    leaves, _treedef = tree_flatten(state.params)
    flattened_leaves = [jnp.ravel(leaf) for leaf in leaves]
    return jnp.concatenate(flattened_leaves)

# def cos_distance(vec1, vec2):
#     assert vec1.dtype == vec2.dtype
#     assert vec1.dtype == jnp.float64
#     # Compute the cosine similarity
#     dot_product = jnp.dot(vec1, vec2)
#     norm1 = jnp.linalg.norm(vec1)
#     norm2 = jnp.linalg.norm(vec2)
#     cosine_sim = dot_product / (norm1 * norm2)
#     return 1 - cosine_sim

# def param_diff(state1, state2):
#     params1 = vec_params(state1)
#     params2 = vec_params(state2)
#     return params1 - params2