import os
from functools import partial

import jax
import numpy as np
import optax
from flax import nnx
from jax import numpy as jnp, random

from .quadratures import *


def get_data(filepath):
    """Load the ``x``, ``x_grid``, ``y`` and ``y_grid`` entries stored in ``filepath``.

    The object returned by ``jnp.load`` is used as a context manager so the
    underlying file handle is released as soon as the four entries have been
    read, instead of staying open until garbage collection.

    Args:
        filepath: Location of a keyed archive readable by ``jnp.load``
            (presumably an ``.npz`` file, since entries are looked up by name).

    Returns:
        Tuple ``(x, x_grid, y, y_grid)``, in that order.
    """
    with jnp.load(filepath) as data:
        return data["x"], data["x_grid"], data["y"], data["y_grid"]


####################################################################################################
### utils for jnp functions / nn modules ###########################################################
####################################################################################################


def create_lifted_module(
    module: nnx.Module, lift_size: int, *, rngs: nnx.Rngs
) -> nnx.Module:
    """Construct ``module`` under ``nnx.vmap`` with ``rngs`` split ``lift_size`` ways.

    Args:
        module: Called as ``module(rngs=...)``, i.e. it is used as a factory
            (a module class or a partially-applied constructor) even though it
            is annotated as ``nnx.Module``.
        lift_size: Number of splits requested from ``nnx.split_rngs``.
        rngs: Random streams; split for the duration of the construction and
            restored from the backups before returning.

    Returns:
        Whatever the vmapped constructor call produces.
    """
    backups = nnx.split_rngs(rngs, splits=lift_size)
    try:
        lifted_module = nnx.vmap(lambda split_rngs: module(rngs=split_rngs))(rngs)
    finally:
        # restore even if construction raises, so the caller never keeps split rngs
        nnx.restore_rngs(backups)
    return lifted_module


### init function for flax trainable params
def perturbed_ones(key, shape, dtype=jnp.float32):
    """Return an array of ``shape`` equal to ``1 + 0.01 * normal noise``, cast to ``dtype``.

    The noise is drawn with the first of the two keys obtained by splitting ``key``.
    """
    subkey = random.split(key)[0]
    noise = random.normal(subkey, shape=shape)
    return (0.01 * noise + 1).astype(dtype)


### adds an operation to a function, use as add_op(fn, op)
def add_op(fn, op):
    """Return a function that calls ``fn`` with the given arguments and applies ``op`` to its result."""

    def new_fn(*args, **kwargs):
        result = fn(*args, **kwargs)
        return op(result)

    return new_fn


### perturbed_ones passed through log(expm1(x)), i.e. an inverse softplus; precision is not a concern
### here, as this is only used to help initialize kernel params
inverse_softplus_perturbed_ones = add_op(
    perturbed_ones,
    lambda sample: jnp.log(jnp.expm1(sample)),
)


######################################################################################################
### optimizer stuff ##################################################################################
######################################################################################################


### this is a bit manual, but alas
def get_kernel_layer_paths(config):
    """List the key paths of the kernel layers.

    The first path is ``("interpolation_module",)``; it is followed by one
    ``("integration_modules", i)`` path for every ``i`` in
    ``range(config.model.depth + 1)``. Every path element is a
    ``jax.tree_util.DictKey``.
    """
    make_key = jax.tree_util.DictKey
    interpolation_path = (make_key(key="interpolation_module"),)
    integration_paths = [
        (make_key(key="integration_modules"), make_key(key=layer))
        for layer in range(config.model.depth + 1)
    ]
    return [interpolation_path, *integration_paths]


def exponential_decay(num_steps, init_value, decay_rate=0.01):
    """Build an ``optax.exponential_decay`` schedule.

    The schedule starts from ``init_value`` and uses ``num_steps`` as its
    ``transition_steps``. It is continuous (no staircase), starts decaying at
    step 0 and has no end value.
    """
    schedule_options = dict(
        transition_steps=num_steps,
        decay_rate=decay_rate,
        transition_begin=0,
        staircase=False,
        end_value=None,
    )
    return optax.exponential_decay(init_value, **schedule_options)


### just add a 'cycle' with a constant learning rate to be end_value
def cosine_annealing_plus_constant(
    total_steps,
    init_value=3e-5,
    warmup_frac=0.3,
    peak_value=5e-4,
    end_value=1e-5,
    num_cosine_cycles=5,
    gamma=0.7,
    num_constant_cycles=1,
):
    """Chain warmup-cosine cycles and finish with a constant ``end_value`` schedule.

    Args:
        total_steps: Step budget; every cycle lasts
            ``total_steps / (num_cosine_cycles + num_constant_cycles)`` steps.
        init_value: Starting value of the first cycle. Later cycles start
            from ``end_value``.
        warmup_frac: Fraction of each cycle spent warming up.
        peak_value: Peak of the first cycle.
        end_value: Value each cosine cycle decays to, and the value of the
            constant tail.
        num_cosine_cycles: Number of warmup-cosine cycles.
        gamma: Factor applied to the peak after every cycle.
        num_constant_cycles: Only enters the cycle-length computation; the
            constant schedule is the last one joined, so it applies from the
            final boundary onwards.

    Returns:
        The schedule produced by ``optax.join_schedules``.
    """
    # NOTE: true division, so cycle lengths and boundaries may be non-integer floats
    decay_steps = total_steps / (num_cosine_cycles + num_constant_cycles)
    schedules = []
    boundaries = []
    boundary = 0
    cycle_init = init_value
    cycle_peak = peak_value
    for _ in range(num_cosine_cycles):
        schedules.append(
            optax.warmup_cosine_decay_schedule(
                init_value=cycle_init,
                warmup_steps=decay_steps * warmup_frac,
                peak_value=cycle_peak,
                decay_steps=decay_steps,
                end_value=end_value,
                exponent=2,
            )
        )
        boundary = decay_steps + boundary
        boundaries.append(boundary)
        # each following cycle restarts from where this one ended, with a damped peak
        cycle_init = end_value
        cycle_peak = cycle_peak * gamma

    # use end_value explicitly so the tail is correct even with zero cosine cycles
    schedules.append(optax.constant_schedule(end_value))
    return optax.join_schedules(schedules=schedules, boundaries=boundaries)


def cosine_annealing(
    total_steps,
    init_value=1e-4,
    warmup_frac=0.3,
    peak_value=3e-4,
    end_value=1e-4,
    num_cycles=6,
    gamma=0.9,
):
    """Chain ``num_cycles`` warmup-cosine cycles of equal length.

    Args:
        total_steps: Step budget; every cycle lasts ``total_steps / num_cycles`` steps.
        init_value: Starting value of the first cycle. Later cycles start
            from ``end_value``.
        warmup_frac: Fraction of each cycle spent warming up.
        peak_value: Peak of the first cycle.
        end_value: Value each cycle decays to.
        num_cycles: Number of cycles.
        gamma: Factor applied to the peak after every cycle.

    Returns:
        The schedule produced by ``optax.join_schedules``.
    """
    # NOTE: true division, so cycle lengths and boundaries may be non-integer floats
    decay_steps = total_steps / num_cycles
    schedules = []
    boundaries = []
    boundary = 0
    cycle_init = init_value
    cycle_peak = peak_value
    for _ in range(num_cycles):
        schedules.append(
            optax.warmup_cosine_decay_schedule(
                init_value=cycle_init,
                warmup_steps=decay_steps * warmup_frac,
                peak_value=cycle_peak,
                decay_steps=decay_steps,
                end_value=end_value,
                exponent=2,
            )
        )
        boundary = decay_steps + boundary
        boundaries.append(boundary)
        # each following cycle restarts from where this one ended, with a damped peak
        cycle_init = end_value
        cycle_peak = cycle_peak * gamma
    # join_schedules expects one boundary fewer than schedules: the boundary after
    # the last cycle has no schedule to switch to, so it is dropped
    return optax.join_schedules(schedules=schedules, boundaries=boundaries[:-1])


######################################################################################################
###### normalization, pointwise gaussian #############################################################
######################################################################################################


class UnitGaussianNormalizer:
    """Standardize data with the mean and standard deviation taken over axis 0.

    ``encode`` maps ``x`` to ``(x - mean) / (std + eps)`` and ``decode`` applies
    the inverse, ``x * (std + eps) + mean``.
    """

    def __init__(self, x, eps=0.00001):
        # statistics are reduced over the leading axis only; the original note
        # says x may be ntrain*n, ntrain*T*n or ntrain*n*T
        self.eps = eps
        self.std = jnp.std(x, axis=0)
        self.mean = jnp.mean(x, axis=0)

    # `self` is a static jit argument, so mean / std / eps are read while tracing
    @partial(jax.jit, static_argnums=(0,))
    def encode(self, x):
        """Return ``x`` shifted by the stored mean and divided by ``std + eps``."""
        return (x - self.mean) / (self.std + self.eps)

    @partial(jax.jit, static_argnums=(0,))
    def decode(self, x):
        """Undo ``encode``: multiply by ``std + eps`` and add the stored mean."""
        scale = self.std + self.eps
        shift = self.mean
        return (x * scale) + shift


######################################################################################################
###### domain helpers ################################################################################
######################################################################################################


### returns triangle mesh (as vertex coordinates) according to first breaking up a unit_square with a given
### number of 1d subdivisions into smaller squares, (i.e. 1 subdivision, 1 square; 2 subdivisions, 4 squares)
### and then making 2 triangles from each of those smaller squares
def unit_square_triangular_mesh(subdivisions):
    """Triangulate the unit square and return the triangles' vertex coordinates.

    The square is cut into ``subdivisions x subdivisions`` cells and every cell
    contributes two triangles that share the bottom-left / top-right diagonal.
    Cells are visited with the x index outermost. The result is a numpy array
    of shape ``(2 * subdivisions**2, 3, 2)``.
    """
    step = 1 / subdivisions
    triangles = []
    for col in range(subdivisions):
        x_lo, x_hi = col * step, (col + 1) * step
        for row in range(subdivisions):
            y_lo, y_hi = row * step, (row + 1) * step
            # lower-right triangle first, then the upper-left one
            lower = [[x_lo, y_lo], [x_hi, y_lo], [x_hi, y_hi]]
            upper = [[x_lo, y_lo], [x_hi, y_hi], [x_lo, y_hi]]
            triangles.extend((lower, upper))
    return np.array(triangles)


def get_triangular_notch_quadrature_rule(
    n,
    quadrature_fn,
):
    """Build a quadrature rule on a triangle mesh of the hard-coded notched-triangle vertices.

    The seven vertices are Delaunay-triangulated, the first five simplices are
    kept as the mesh, and ``n``, ``quadrature_fn`` and that mesh are handed to
    ``triangle_mesh_quad_rule``. Its two results are returned as a pair.
    """
    from scipy.spatial import Delaunay

    vertices = np.array(
        [
            [0, 0],
            [1, 0],
            [0.49, 0],
            [0.49, 0.4],
            [0.51, 0],
            [0.51, 0.4],
            [0.5, np.sqrt(3) / 2],
        ]
    )
    # NOTE(review): keeping simplices[:5] presumably drops the triangle(s) inside
    # the notch and so relies on the ordering Delaunay returns -- confirm
    kept_simplices = Delaunay(vertices).simplices[:5]
    mesh = vertices[kept_simplices]
    nodes, weights = triangle_mesh_quad_rule(n, quadrature_fn, mesh)
    return nodes, weights


####################################################################################################
### pick the hyperparameters worth logging to Weights and Biases ###################################
####################################################################################################


def get_tracking_hypers(config):
    """Flatten the tracked hyperparameters of ``config`` into a single dict.

    The ``model``, ``optim`` and ``kernels`` sections are merged in that order
    (later sections win on key clashes). A handful of individual training,
    dataset and data settings are then stored under fixed keys.
    """
    hypers = {}
    for section in (config.model, config.optim, config.kernels):
        hypers.update(section)

    hypers.update(
        {
            "alpha": config.training.alpha,
            "dataset": config.dataset,
            "batch_size": config.training.batch_size,
            "freeze_epochs_per_layer": config.training.freeze_epochs_per_layer,
            "freeze_train": config.training.freeze_train,
            "normalize": config.data.normalize,
        }
    )
    return hypers


####################################################################################################
### general utilities ##############################################################################
####################################################################################################


def calculate_model_size(params):
    """Return the total number of bytes taken by the array leaves of ``params``.

    Args:
        params: Pytree whose leaves expose ``shape`` and ``dtype``.

    Returns:
        A Python ``int``: the sum over leaves of element count times
        ``dtype.itemsize``. An empty tree gives ``0``.
    """
    total_size = 0
    # Flatten the parameter tree to get all leaf nodes (parameter arrays)
    for param in jax.tree_util.tree_leaves(params):
        # np.prod(()) is the float 1.0, so force an integer product (scalar
        # leaves included) and convert to int to keep the total an exact int
        num_elements = int(np.prod(param.shape, dtype=np.int64))
        total_size += num_elements * param.dtype.itemsize
    return total_size


def create_path(path, verbose=True):
    """Create directory ``path`` (parents included) unless something already exists there.

    A success message is printed when the directory is created and ``verbose``
    is true. An ``OSError`` raised along the way is reported with ``print``
    instead of being propagated.
    """
    try:
        if os.path.exists(path):
            # nothing to do (and nothing to report) when the path is already present
            return
        os.makedirs(path, exist_ok=True)
        if verbose:
            print("Directory '%s' created successfully" % (path))
    except OSError:
        print("Directory '%s' can not be created" % (path))


import logging


def get_logger(logpath, displaying=True, saving=True, debug=False, append=False):
    """Configure the root logger and return it.

    Calling this more than once no longer stacks duplicate handlers: a file
    handler already attached for the same ``logpath`` is closed and replaced,
    and the console handler created by an earlier call is reused.

    Args:
        logpath: File the log is written to; only used when ``saving`` is true.
        displaying: Also attach a console (``StreamHandler``) handler.
        saving: Attach a ``FileHandler`` writing to ``logpath``.
        debug: Use ``logging.DEBUG`` instead of ``logging.INFO`` for the logger
            and the handlers configured here.
        append: Open ``logpath`` in ``"a"`` mode instead of ``"w+"``.

    Returns:
        The root logger.
    """
    logger = logging.getLogger()
    level = logging.DEBUG if debug else logging.INFO
    logger.setLevel(level)

    if saving:
        # FileHandler keeps its target as an absolute path in ``baseFilename``
        target = os.path.abspath(os.fspath(logpath))
        for handler in list(logger.handlers):
            if (
                isinstance(handler, logging.FileHandler)
                and handler.baseFilename == target
            ):
                # replace rather than duplicate: avoids repeated lines and a leaked handle
                logger.removeHandler(handler)
                handler.close()
        info_file_handler = logging.FileHandler(logpath, mode="a" if append else "w+")
        info_file_handler.setLevel(level)
        logger.addHandler(info_file_handler)

    if displaying:
        console_handler = next(
            (
                handler
                for handler in logger.handlers
                if getattr(handler, "_get_logger_console", False)
            ),
            None,
        )
        if console_handler is None:
            console_handler = logging.StreamHandler()
            # mark the handler so later calls can recognise and reuse it
            console_handler._get_logger_console = True
            logger.addHandler(console_handler)
        console_handler.setLevel(level)

    return logger
