import os
import time

import numpy as np
import sympy
import yaml


def set_path(experiment=0, id="test", t_id="test"):
    """Create a fresh results directory ``results/<experiment>/<id>/<t_id>``.

    The path is relative to the current working directory. Components are
    converted with ``str`` so non-string identifiers work (the default
    ``experiment=0`` previously made ``os.path.join`` raise ``TypeError``).

    Args:
        experiment: experiment identifier (any value with a ``str`` form).
        id: run identifier.
        t_id: trial identifier.

    Returns:
        The relative path of the newly created directory.

    Raises:
        ValueError: if something already exists at that path.
    """
    path = os.path.join("results", str(experiment), str(id), str(t_id))
    try:
        # EAFP: create directly instead of exists()-then-makedirs, which races.
        os.makedirs(path)
    except FileExistsError:
        raise ValueError(
            f"There already exists a directory. Please remove '{path}' or change path.",
        ) from None
    print(f" * Set path {path}.")
    return path


def report(report, spent, name):
    """Print a timing line for ``name`` and store the rounded time in ``report``.

    ``spent`` is rendered with 10 decimal places. That same rounded value,
    parsed back to ``float``, is written to ``report[name]``, overwriting any
    previous entry.
    """
    rounded_text = f"{spent:.10f}"
    print(f" *** {name} time: {rounded_text} seconds.")
    report[name] = float(rounded_text)


def report_verificate(report, spent, name):
    """Accumulate ``spent`` into ``report["verification"]`` and print a timing line.

    The accumulated value is not rounded; only the printed line shows
    ``spent`` with 10 decimal places. If the ``"verification"`` key is
    missing, it starts at 0.0 (previously the first call raised ``KeyError``).

    Args:
        report: mutable mapping receiving the running total.
        spent: elapsed time to add.
        name: label used in the printed line.
    """
    report.setdefault("verification", 0.0)
    report["verification"] += spent
    print(f"   {name} time: {spent:.10f} seconds.")


def straggle(gap=1.0):
    """Keep the CPU busy for about ``gap`` seconds of wall-clock time.

    Each iteration multiplies a fresh 200x200 random matrix by itself and
    discards the product. The loop stops once ``time.time()`` reaches the
    start time plus ``gap``, so ``gap <= 0`` returns immediately.
    """
    deadline = time.time() + gap
    while time.time() < deadline:
        block = np.random.rand(200, 200)
        np.matmul(block, block)


def save_logs(vars, path):
    """Write ``vars`` as YAML to ``<path>/log.yaml``, replacing any existing file.

    Serialisation uses ``yaml.dump`` with an indent of 4. ``path`` must
    already exist as a directory.
    """
    log_file = os.path.join(path, "log.yaml")
    with open(log_file, "w") as stream:
        yaml.dump(vars, stream, indent=4)


def get_n_required(**kwargs):
    """Return ``kwargs["m"] * kwargs["n"]``.

    Any other keyword arguments are ignored. Presumably this is the number of
    results needed for decoding — confirm against callers.
    """
    return kwargs["m"] * kwargs["n"]


#
# codebook
#
def cheby_poly(input, order=2, **kwargs):  # orderth chebyshev
    if order == 1:
        return input
    elif order == 2:
        return 2 * pow(input, 2) - 1
    elif order == 3:
        return 4 * pow(input, 3) - 3 * input
    elif order == 4:
        return 8 * pow(input, 4) - 8 * pow(input, 2) + 1
    elif order == 5:
        return 16 * pow(input, 5) - 20 * pow(input, 3) + 5 * input
    elif order == 6:
        return 32 * pow(input, 6) - 48 * pow(input, 4) + 18 * pow(input, 2) - 1
    elif order == 7:
        return 64 * pow(input, 7) - 112 * pow(input, 5) + 56 * pow(input, 3) - 7 * input
    elif order == 8:
        return 128 * pow(input, 8) - 256 * pow(input, 6) + 160 * pow(input, 4) - 32 * pow(input, 2) + 1
    elif order == 9:
        return 256 * pow(input, 9) - 576 * pow(input, 7) + 432 * pow(input, 5) - 120 * pow(input, 3) + 9 * input
    elif order == 10:
        return 512 * pow(input, 10) - 1280 * pow(input, 8) + 1120 * pow(input, 6) - 400 * pow(input, 4) + 50 * pow(input, 2) - 1
    else:
        raise ValueError(f"not supported order {order} for chebyshev polynomial")


def codebook(N, m):
    """Build a table of preimages of the degree-``m`` Chebyshev polynomial.

    Let ``n_consts = N // m`` and ``c_i = (i + 1) / (n_consts + 5)`` for
    ``i = 0 .. n_consts - 1``, so every ``c_i`` lies in (0, 1). Row ``i`` of
    the result holds the ``m`` real solutions of ``T_m(x) = c_i``, sorted
    ascending.

    The roots come from the identity ``T_m(cos t) = cos(m t)``: for
    ``|c| < 1`` they are exactly ``cos((arccos(c) + 2*pi*k) / m)`` for
    ``k = 0 .. m - 1``. These are m distinct values in (-1, 1), so any ``m``
    works, with no symbolic root finding.

    Args:
        N: total size; the number of rows is ``N // m``.
        m: polynomial degree and number of roots per row (must be >= 1).

    Returns:
        float64 array of shape ``(N // m, m)``.
    """
    n_consts = N // m

    consts = (np.arange(n_consts) + 1) / (n_consts + 5)
    k = np.arange(m)
    angles = (np.arccos(consts)[:, None] + 2 * np.pi * k[None, :]) / m
    book = np.sort(np.cos(angles), axis=1)

    return book


#
# matrix related
#
def get_shape(dim1, dim2, dim3, **kwargs):
    """Return block shapes for splitting ``A`` (dim1 x dim2) and ``B`` (dim2 x dim3).

    ``A`` is cut into ``m`` row blocks and ``B`` into ``n`` column blocks,
    using floor division. The return value is three lists:
    ``[dim1 // m, dim2]``, ``[dim2, dim3 // n]`` and
    ``[dim1 // m, dim3 // n]``.
    """
    m, n = kwargs["m"], kwargs["n"]
    rows_per_block = dim1 // m
    cols_per_block = dim3 // n
    return (
        [rows_per_block, dim2],
        [dim2, cols_per_block],
        [rows_per_block, cols_per_block],
    )


def make_inputs(dim1, dim2, dim3):
    """Create random operands of shapes (dim1, dim2) and (dim2, dim3).

    Entries come from ``np.random.rand``, i.e. uniform on [0, 1). The left
    operand is drawn first, then the right. Both shapes are printed before
    returning ``(A, B)``.
    """
    left = np.random.rand(dim1, dim2)
    right = np.random.rand(dim2, dim3)
    print(f" * Matrices of shape {left.shape} and {right.shape} are created")
    return left, right


def split_matrix(A, B, **kwargs):
    """Split ``A`` into ``m`` row blocks and ``B`` into ``n`` column blocks.

    For 2-D inputs the results are stacked arrays:
      * ``Ap`` has shape (m, dim1 // m, dim2) — ``np.split`` along axis 0;
      * ``Bp`` has shape (n, dim2, dim3 // n) — ``np.hsplit``.
    ``np.split``/``np.hsplit`` raise ``ValueError`` if the division is uneven.
    """
    m, n = kwargs["m"], kwargs["n"]
    row_pieces = np.split(A, m)
    col_pieces = np.hsplit(B, n)
    return np.stack(row_pieces), np.stack(col_pieces)


def compute(a, b):
    """Return the matrix product ``np.matmul(a, b)``.

    ``a`` may be a vector, a matrix or a stack of matrices. ``b`` may be a
    vector, a matrix or a stack of matrices; leading dimensions broadcast
    under ``np.matmul`` rules.

    Raises:
        AssertionError: if ``a``'s last dimension differs from ``b``'s
            contraction dimension (``b.shape[0]`` for 1-D ``b``, otherwise
            ``b.shape[-2]``). The check is an ``assert`` and is skipped under
            ``python -O``; ``np.matmul`` then raises ``ValueError`` itself.
    """
    # matmul contracts a's last axis with b's second-to-last axis, or b's only
    # axis when b is a vector; for 2-D b this is b.shape[0] as before.
    inner = b.shape[0] if len(b.shape) == 1 else b.shape[-2]
    assert a.shape[-1] == inner, f"Matrix shape unmatched between {a.shape[-1]} and {inner}"
    return np.matmul(a, b)
