import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
from other import *


def build_equation(lib, coef, eq, round_eq):
    """Builds an equation string from library terms and their coefficients.

    Every term whose coefficient is nonzero is rendered as
    ``<coefficient><term name>``. The rendered terms are joined with " + " and
    appended to ``eq``. Terms with a zero coefficient are skipped.

    Args:
        lib: A list of strings of each term in the SINDy library. Should
            be what is returned from "equation_sindy_library" in
            model_utils.py. ``lib[i]`` names the term multiplied by
            ``coef[i]``.
        coef: The coefficients (numpy array of shape (library_dim,)) of
            each term in the library.
        eq: A string of the start of the equation to be created. For example,
            if eq = "dx = ", the rendered terms are appended to the right side
            of that string.
        round_eq: If True, rounds each coefficient to 2 decimal places. A
            nonzero coefficient that rounds to 0 is instead shown in
            scientific notation with at most 1 fractional digit
            (e.g. "1.2e-03").

    Returns:
        A string of the created equation. If every coefficient is zero,
        ``eq`` is returned unchanged.
    """
    terms = []
    for i, curr_coef in enumerate(coef):
        if curr_coef == 0:
            continue
        if round_eq:
            rounded_coef = np.round(curr_coef, 2)
            if rounded_coef == 0:
                # Too small for 2 decimal places: show it in scientific
                # notation rather than printing a misleading zero.
                rounded_coef = np.format_float_scientific(curr_coef, 1)
            coef_str = str(rounded_coef)
        else:
            coef_str = str(curr_coef)
        terms.append(coef_str + lib[i])
    # Joining never leaves a trailing " + " to strip. This avoids inspecting
    # the caller's prefix, which raised IndexError for short prefixes.
    return eq + " + ".join(terms)

def update_equation_list(equations, library, coefs, starts, x_dim, round_eq):
    """Appends one equation per spatial dimension to ``equations``.

    For each dimension ``dim`` in ``range(x_dim)``, builds the equation whose
    coefficients are column ``dim`` of ``coefs`` and whose left-hand side is
    ``starts[dim]``. The equations are appended to ``equations`` in dimension
    order.

    Args:
        equations: The list of strings to append equations to (modified in
            place).
        library: A list of strings of each term in the SINDy library. Should
            be what is returned from "equation_sindy_library" in
            model_utils.py.
        coefs: The coefficients (numpy array of shape library_dim x x_dim) of
            each term in the library.
        starts: A list (of length at least x_dim) of strings denoting the
            start of each equation. ``starts[i]`` is the left-hand side for
            dimension i, where 0 <= i < x_dim.
        x_dim: The number (int) of spatial dimensions in the data.
        round_eq: If True, rounds the coefficients to 2 decimal places (see
            build_equation).

    Returns:
        None
    """
    for dim in range(x_dim):
        column = coefs[:, dim]
        lhs = starts[dim]
        equations.append(build_equation(library, column, lhs, round_eq))

def get_equations(net, library, model_type, device, round_eq=True, seed=None):
    """Gets the equations learned by the network.

    For HyperSINDy and ESINDy, the masked coefficients returned by the network
    are reduced to a mean and a standard deviation over their first (batch)
    axis, and one equation per dimension is built for each.

    Args:
        net: The network (torch.nn.Module) to get the equations for. Must
            provide ``get_masked_coefficients(device=...)``.
        library: The SINDy library object (from src.utils.library_utils).
            Its ``n`` attribute is used as the number of equations and its
            ``get_feature_names()`` as the term names.
        model_type: The str name of the model ({"HyperSINDy", "ESINDy",
            "SINDy"}). Equivalent to the model_type arguments in parse_args
            from cmd_line.py.
        device: The cpu or gpu device to get the equations with. To use cpu,
            device must be "cpu". To use a gpu, specify which one as an
            integer (i.e.: 0 or 1 or 2 or 3).
        round_eq: If True, rounds the coefficients to 2 decimal places.
            Default: True.
        seed: The seed to use for reproducible randomization through
            set_random_seed from other.py. The default is None (no seeding).

    Returns:
        The equations as a list of strings. For HyperSINDy and ESINDy,
        returns a list in the format:
            ["MEAN",
                equation_1,
                ...,
                equation_n,
                "STD",
                equation_1,
                ...,
                equation_n]
        where n = library.n. Left-hand sides are "dx = ", "dy = ", "dz = "
        when n <= 3, and "dx1 = ", "dx2 = ", ... otherwise.

        For any other model_type (including "SINDy") this function currently
        returns an empty list. NOTE(review): no SINDy branch exists here even
        though "SINDy" is an accepted model_type; add one if callers need it.
    """
    if seed is not None:
        set_random_seed(seed)
    # Up to 3 dimensions read as dx/dy/dz; beyond that, number them dx1, dx2, ...
    if library.n > 3:
        starts = ["dx" + str(i + 1) + " = " for i in range(library.n)]
    else:
        starts = ["dx = ", "dy = ", "dz = "]
    equations = []
    if model_type in ("HyperSINDy", "ESINDy"):
        mean_coeffs, std_coeffs = sindy_coeffs_stats(
            net.get_masked_coefficients(device=device))
        # The same term names label both the mean and the std equations.
        feature_names = library.get_feature_names()
        equations.append("MEAN")
        update_equation_list(equations, feature_names, mean_coeffs, starts,
                             library.n, round_eq)
        equations.append("STD")
        update_equation_list(equations, feature_names, std_coeffs, starts,
                             library.n, round_eq)
    return equations


def sindy_coeffs_stats(sindy_coeffs):
    """Computes the per-entry mean and standard deviation over the batch axis.

    The tensor is detached from the autograd graph and copied to the CPU as a
    numpy array before the statistics are taken along axis 0.

    Args:
        sindy_coeffs: The sindy coefficients as a torch.Tensor of shape
            (batch_size x library_dim x x_dim).

    Returns:
        A tuple ``(mean, std)`` of numpy arrays, each of shape
        (library_dim x x_dim). ``mean`` is the mean of the coefficients across
        axis 0. ``std`` is their population standard deviation (numpy's
        default ddof=0) across the same axis.
    """
    as_array = sindy_coeffs.detach().cpu().numpy()
    mean = as_array.mean(axis=0)
    std = as_array.std(axis=0)
    return mean, std

def log_equations(equations, board, model, epoch, x_dim):
    """Logs the equations into Tensorboard.

    When model == "HyperSINDy", writes the mean equations under the tag
    "Equations/mean" and the standard-deviation equations under
    "Equations/std", one equation per line. Other model names are not logged
    by this function.

    Args:
        equations: The equations as a list of strings to log (result of
            get_equations), laid out as
            ["MEAN", eq_1, ..., eq_n, "STD", eq_1, ..., eq_n] with n = x_dim.
        board: The tensorboard writer to use. Must provide ``add_text``.
        model: A string of the name of the model.
        epoch: An int for the current epoch in training (used as the global
            step).
        x_dim: An int for the spatial dimension (number of equations in each
            of the MEAN and STD sections).

    Returns:
        None.

    Raises:
        IndexError: If ``equations`` has fewer than 2 * x_dim + 2 entries.
    """
    if model == "HyperSINDy":
        # Layout: index 0 is "MEAN", 1..x_dim are the mean equations,
        # x_dim + 1 is "STD", and x_dim + 2..2 * x_dim + 1 are the std ones.
        expected_len = 2 * x_dim + 2
        if len(equations) < expected_len:
            raise IndexError(
                f"expected at least {expected_len} entries in equations for "
                f"x_dim={x_dim}, got {len(equations)}")
        # Two spaces before "\n" are presumably there to force a markdown
        # line break in TensorBoard's text view.
        eq_mean = "  \n".join(str(line) for line in equations[1:x_dim + 1])
        eq_std = "  \n".join(str(line) for line in equations[x_dim + 2:expected_len])
        board.add_text(tag="Equations/mean", text_string=eq_mean, global_step=epoch, walltime=None)
        board.add_text(tag="Equations/std", text_string=eq_std, global_step=epoch, walltime=None)