import torch
import random
import time
import json
import numpy as np

from functools import wraps
from inspect import signature


def export(dic, confs=None):
    """Return a decorator that registers a function in 'dic' under its name.

    dic: dict
        Registry mapping function names to functions. If the name is
        already present, the existing entry is kept and is what the
        decorator returns.

    confs: dict
        Optional dictionary that receives, under the function's name, a
        dict of every parameter that has a default value (positional or
        keyword only). Useful for sacred configs.
        Adds a 'type' config entry (the function name) if none exists yet.
    """

    def export_decorator(f):
        """Add an entry to 'dic' which points to f."""
        if confs is not None:
            confs.setdefault("type", f.__name__)
            # Identity test, not '!=': defaults such as numpy arrays
            # overload comparison and would make the check ambiguous.
            confs[f.__name__] = {
                p.name: p.default
                for p in signature(f).parameters.values()
                if p.default is not p.empty
            }

        registered = dic.setdefault(f.__name__, f)
        if registered is f:
            # Wrapping f with itself would set f.__wrapped__ = f; that
            # cycle makes a later inspect.signature(f) raise ValueError.
            return f
        # A function with this name was registered earlier: keep it, but
        # copy f's metadata onto it as the original code did.
        return wraps(f)(registered)

    return export_decorator


def get_loss(X, Z, type, addit_args=None):
    """Return a scalar discrepancy between the matrices X and Z.

    type: str
        "l2"        torch.linalg.norm(X - Z) with the default order.
        "linf"      torch.linalg.norm(X - Z, 2).
        "l1"        torch.linalg.norm(X - Z, 1).
        "cut_thres" sum over the eigenvalues of X - Z of
                    max(|eigenvalue| - addit_args[0], 0); returned as a
                    NumPy scalar, not a tensor.
        "sq_one_l2" torch.linalg.norm(X @ X - I); Z is not used.

    addit_args: sequence
        Only read by "cut_thres", where addit_args[0] is the threshold.

    Raises ValueError for any other 'type'.
    """
    # NOTE(review): the "linf"/"l1" names map to matrix norms of order 2
    # and 1 respectively -- confirm this naming is intentional.

    if type == "l2":
        return torch.linalg.norm(X - Z)
    elif type == "linf":
        return torch.linalg.norm(X - Z, 2)
    elif type == "l1":
        return torch.linalg.norm(X - Z, 1)
    elif type == "cut_thres":
        S = torch.linalg.eig(X - Z)[0].cpu().detach().numpy()
        return np.sum(np.maximum(np.abs(S) - addit_args[0], 0))
    elif type == "sq_one_l2":
        # Build the identity next to X; a default (CPU, float32) identity
        # cannot be subtracted from a tensor living on another device.
        I = torch.eye(X.size()[0], dtype=X.dtype, device=X.device)
        return torch.linalg.norm(X @ X - I)
    else:
        raise ValueError("Not implemented!")


def get_loss_vec(x, z, type, addit_args=None):
    """Return a scalar discrepancy between the tensors x and z.

    type: str
        "l2"        norm(x - z) / norm(z)  (relative error).
        "linf"      largest absolute entry of x - z.
        "l1"        torch.linalg.norm(x - z, 1).
        "cut_thres" sum of max(|x - z| - addit_args[0], 0) over entries.
        "sq_one_l2" norm(x * x - 1) elementwise; z is not used.

    addit_args: sequence
        Only read by "cut_thres", where addit_args[0] is the threshold.

    Raises ValueError for any other 'type'.
    """
    if type == "l2":
        return torch.linalg.norm(x - z) / torch.linalg.norm(z)
    elif type == "linf":
        # torch.abs keeps the computation on the tensor's own device.
        return torch.max(torch.abs(x - z))
    elif type == "l1":
        return torch.linalg.norm(x - z, 1)
    elif type == "cut_thres":
        # torch.maximum needs a tensor as second operand; clamp accepts
        # the plain scalar lower bound.
        return torch.sum(torch.clamp(torch.abs(x - z) - addit_args[0], min=0))
    elif type == "sq_one_l2":
        # ones_like matches x's shape, dtype and device.
        return torch.linalg.norm(torch.multiply(x, x) - torch.ones_like(x))
    else:
        raise ValueError("Not implemented!")


def SVD_zeropower(A):
    """Compute U @ Vt from the SVD of A and time the computation.

    Returns a tuple (Ans, elapsed) where Ans = U @ Vt for
    U, S, Vt = torch.linalg.svd(A) and elapsed is the wall-clock duration
    of the SVD plus the product, in seconds.
    """
    # CUDA synchronisation is only possible (and only needed) when a CUDA
    # runtime is present; calling it unconditionally fails on CPU-only
    # machines.
    use_cuda = torch.cuda.is_available()

    print("\n====== Warming up GPU ======\n")
    for _ in range(5):
        # Allocate on A's device so the warm-up touches the device that
        # is actually being timed.
        _ = torch.ones((1000, 1000), device=A.device)

    if use_cuda:
        torch.cuda.synchronize()
    # perf_counter is monotonic, unlike time.time().
    start_time = time.perf_counter()

    U, S, Vt = torch.linalg.svd(A)
    Ans = U @ Vt

    # Wait for queued GPU work so the measured interval covers it.
    if use_cuda:
        torch.cuda.synchronize()
    end_time = time.perf_counter()

    return Ans, end_time - start_time


def PD_zeropower_chol(A):
    """Orthonormalise the columns of A via a Cholesky factor of A^T A.

    Returns a tuple (U, elapsed) where U = A @ inv(L^T) with
    L @ L^T = A^T @ A (L lower triangular), so that U^T @ U = I, and
    elapsed is the wall-clock duration of the computation in seconds.

    torch.linalg.cholesky raises if A^T A is not positive definite.
    """
    # NOTE(review): U is the orthonormal factor of a Cholesky-based QR of
    # A. In general it is not the same matrix as the U @ Vt product built
    # from the SVD -- confirm which factor callers expect.

    # Synchronise only when a CUDA runtime exists; the unconditional call
    # fails on CPU-only machines.
    use_cuda = torch.cuda.is_available()

    if use_cuda:
        torch.cuda.synchronize()
    start_time = time.perf_counter()

    AtA = torch.matmul(A.T, A)

    # Lower-triangular Cholesky factor: L @ L^T = AtA. This is not the
    # symmetric square root of AtA.
    L = torch.linalg.cholesky(AtA)

    # Multiply by inv(L^T), not inv(L): then
    # U^T U = inv(L) (L L^T) inv(L^T) = I.
    U = torch.matmul(A, torch.linalg.inv(L.T))

    if use_cuda:
        torch.cuda.synchronize()
    end_time = time.perf_counter()

    return U, end_time - start_time


def get_label(jsonfile):
    """Return the entry's 'Label' value, or its 'Name' if no label is set."""
    key = "Label" if "Label" in jsonfile.keys() else "Name"
    return jsonfile[key]


def load_alg(label, name, learnable, params, param_needed, algs_dir):
    """Append an algorithm entry to the JSON list stored at 'algs_dir'.

    The file is read as a JSON list (a missing or undecodable file counts
    as an empty list). If no existing entry carries 'label', a new entry
    with the keys Label / Name / Learnable / Params / Param_needed is
    appended; otherwise a message is printed. The list is written back to
    'algs_dir' in either case.
    """
    try:
        with open(algs_dir, "r") as handle:
            registry = json.load(handle)
    except (FileNotFoundError, json.JSONDecodeError):
        registry = []

    # Look for an entry that already uses this label.
    label_taken = False
    for entry in registry:
        if entry.get("Label") == label:
            label_taken = True
            break

    if label_taken:
        print("Use different label!!!")
    else:
        new_entry = {
            "Label": label,
            "Name": name,
            "Learnable": learnable,
            "Params": params,
            "Param_needed": param_needed,
        }
        registry.append(new_entry)

    with open(algs_dir, "w") as handle:
        json.dump(registry, handle, indent=4)


def read_params(label, category, algs_dir):
    """Return field 'category' of the first entry labelled 'label'.

    Reads the JSON list stored at 'algs_dir'. If no entry has the given
    label, prints a message and returns None.
    """
    with open(algs_dir, "r") as handle:
        entries = json.load(handle)

    # First entry whose label matches, if any.
    hit = next((entry for entry in entries if entry["Label"] == label), None)
    if hit is None:
        print("Label doesn't exist!!!")
        return None

    return hit[category]


def modify_params(label, category, category_val_after, algs_dir):
    """Set field 'category' to 'category_val_after' on entries labelled 'label'.

    Reads the JSON list at 'algs_dir', updates every entry whose 'Label'
    equals 'label', and writes the list back to the same path.
    """
    with open(algs_dir, "r") as handle:
        entries = json.load(handle)

    targets = [entry for entry in entries if entry["Label"] == label]
    for entry in targets:
        entry[category] = category_val_after

    with open(algs_dir, "w") as handle:
        json.dump(entries, handle, indent=4)


def modify_label(label, label_after, algs_dir):
    """Rename entries labelled 'label' to 'label_after'.

    Reads the JSON list at 'algs_dir', rewrites the 'Label' of every
    matching entry, and writes the list back to the same path.
    """
    with open(algs_dir, "r") as handle:
        entries = json.load(handle)

    to_rename = [entry for entry in entries if entry["Label"] == label]
    for entry in to_rename:
        entry["Label"] = label_after

    with open(algs_dir, "w") as handle:
        json.dump(entries, handle, indent=4)


def get_labels_list(algs_dir):
    """Print and return the 'Label' of every entry in the JSON list at 'algs_dir'.

    Returns the labels as a list, in file order. The list is still
    printed as before, so callers that relied on the output are unaffected.
    """
    with open(algs_dir, "r") as file:
        data = json.load(file)

    labels = [d["Label"] for d in data]

    print(labels)
    # Previously nothing was returned despite the 'get' name.
    return labels


def cubevec(a):
    """Return a * a * a (elementwise for array-like operands)."""
    squared = a * a
    return squared * a


import networkx as nx
import matplotlib.pyplot as plt


def visualize_tree_with_labels(tree, node_labels, edge_labels, t):
    """Draw 'tree' in a layered top-down layout and save it as '<t>.png'.

    Nodes are grouped by breadth-first depth starting from node 0 (the
    root is hard-coded), each depth is laid out on its own horizontal
    row, and the graph is drawn with the supplied custom labels.

    tree: graph object providing neighbors(node), passed to networkx
        drawing functions.
    node_labels: passed as 'labels' to nx.draw_networkx_labels.
    edge_labels: passed as 'edge_labels' to nx.draw_networkx_edge_labels.
    t: converted with str() and used as the output file stem.
    """
    from collections import deque

    # Breadth-first traversal from the root. A visited set and a deque
    # replace re-scanning all levels and list.pop(0) for every node.
    root = 0
    levels = {}
    visited = {root}
    queue = deque([(root, 0)])

    while queue:
        node, level = queue.popleft()
        levels.setdefault(level, []).append(node)

        # Unvisited neighbours are this node's children.
        for child in tree.neighbors(node):
            if child not in visited:
                visited.add(child)
                queue.append((child, level + 1))

    # Create positions for pyramid-like visualization
    pos = {}
    for level, nodes in levels.items():
        y = -level  # Each level goes one step down
        x_start = -len(nodes) / 2
        for i, node in enumerate(nodes):
            pos[node] = (x_start + i, y)

    # Draw the graph
    nx.draw(
        tree,
        pos,
        with_labels=False,
        node_size=2000,
        node_color="lightblue",
        font_size=5,
    )

    # Draw the custom node labels
    nx.draw_networkx_labels(
        tree, pos, labels=node_labels, font_size=5, font_color="black"
    )

    # Draw the custom edge labels
    nx.draw_networkx_edge_labels(
        tree, pos, edge_labels=edge_labels, font_size=5, font_color="red"
    )
    # NOTE(review): the figure is neither created nor cleared here, so
    # repeated calls presumably draw onto the same figure -- confirm the
    # caller manages figures.
    plt.savefig(str(t) + ".png")


def power_iteration(A, num_iters=10):
    """
    Estimate the dominant eigenvalue of the square matrix A by power
    iteration from a random start vector.

    Returns the Rayleigh quotient of the final iterate as a tensor of
    shape (1, 1), not a Python scalar.
    """
    # NOTE(review): the original docstring equated this with the operator
    # norm; that presumably assumes A is symmetric positive semi-definite
    # -- confirm against callers.

    # Create the start vector with A's dtype and device; a default
    # (CPU, float32) vector cannot be multiplied with e.g. a float64 or
    # GPU matrix.
    b = torch.randn(A.shape[1], 1, dtype=A.dtype, device=A.device)
    b /= torch.linalg.norm(b)  # Normalize initial vector

    for _ in range(num_iters):
        b = A @ b
        b /= torch.linalg.norm(b)  # Normalize in every iteration

    # Compute dominant eigenvalue using Rayleigh quotient
    eigenvalue = (b.T @ A @ b) / (b.T @ b)
    return eigenvalue


def rand_unif(d, eps):
    """Return a random d x d matrix Q @ D @ Q.T.

    D is diagonal with entries drawn from torch.rand(d) * 2 - 1; entries
    whose magnitude is below 'eps' are replaced by sign * eps. Q comes
    from the QR decomposition of a d x d standard normal matrix.
    """
    spectrum = torch.diag(torch.rand(d) * 2 - 1)
    below_eps = torch.abs(spectrum) < eps
    spectrum = torch.where(below_eps, torch.sign(spectrum) * eps, spectrum)

    # Orthogonal factor from the QR decomposition of a Gaussian matrix.
    gaussian = torch.randn(d, d)
    Q, _ = torch.linalg.qr(gaussian)
    # Negate the first column when the determinant is negative.
    if torch.det(Q) < 0:
        Q[:, 0] = Q[:, 0].neg()

    return Q @ spectrum @ Q.T


def rand_Wishart(c, d, eps):
    """Return X.T @ X / 3 / d + eps * I for a Gaussian X.

    X has int(c * d) rows and d columns of standard normal entries; the
    result is a d x d matrix.
    """
    samples = torch.randn(int(c * d), d)
    gram = samples.T @ samples
    scaled = gram / 3 / d
    return scaled + eps * torch.eye(d)


def rand_Wishart_unif(c, d, eps):
    """Return X.T @ X / d + eps * I for X with entries torch.rand * 2 - 1.

    X has int(c * d) rows and d columns; the result is a d x d matrix.
    """
    samples = torch.rand(int(c * d), d) * 2 - 1
    gram = samples.T @ samples
    scaled = gram / d
    return scaled + eps * torch.eye(d)


def rand_rectangular(c, d, eps):
    """Return an int(c * d) x d Gaussian matrix divided by sqrt(4 * d).

    'eps' is accepted but not used.
    """
    rows = int(c * d)
    scale = np.sqrt(4 * d)
    return torch.randn(rows, d) / scale


# %%
def set_all_seeds(seed):
    """Seed Python, PyTorch and NumPy RNGs; also configure CUDA if present."""
    for seeder in (random.seed, torch.manual_seed, np.random.seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    # Seed the CUDA generators and switch cuDNN to deterministic mode.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# %%
def print_rng_states():
    """Print the RNG states of Python, NumPy, Torch and Torch CUDA."""
    # Each state is fetched right before its line is printed.
    reports = (
        ("Python random state", random.getstate),
        ("NumPy state:", lambda: np.random.get_state()[1]),
        ("Torch state:", torch.get_rng_state),
        ("Torch CUDA state:", torch.cuda.get_rng_state_all),
    )
    for caption, fetch in reports:
        print(caption, fetch())
