import numpy as np
import time
from pgpr_type import *
from pgpr_chol import pgpr_chol  # Fine, but maybe clarify its type in docstrings


# Utility functions
def SQR(a):
    """Return ``a`` multiplied by itself."""
    squared = a * a
    return squared

def MAX(a, b):
    """Return the larger argument; ``a`` is returned unless ``b > a``."""
    return b if b > a else a

def MIN(a, b):
    """Return the smaller argument; ``a`` is returned unless ``b < a``."""
    return b if b < a else a

def SIGN(a, b):
    """
    Return the magnitude of ``a`` carrying the sign of ``b``.

    When ``b >= 0`` holds the result is ``abs(a)``; otherwise it is
    ``-abs(a)``.
    """
    if b >= 0:
        return abs(a)
    return -abs(a)

def SWAP(a, b):
    """Return the two arguments in reversed order as the tuple ``(b, a)``."""
    first, second = b, a
    return first, second

# File I/O functions
def get_lines(file_path):
    """
    Return how many lines the text file at ``file_path`` contains.

    Lines are counted as produced by iterating the open file, so a final
    line without a trailing newline still counts.

    :param file_path: Path of the file to read.
    :return: Number of lines as an int.
    """
    count = 0
    with open(file_path, 'r') as handle:
        for _line in handle:
            count += 1
    return count

def save_data(file_path, m_data):
    """
    Write a matrix to a text file, one matrix row per line.

    Each line starts with the row's last element (four decimals) and a
    space, followed by the remaining elements as space-separated
    ``column:value`` pairs, where ``column`` is the zero-based column index
    and ``value`` is printed with six decimals.

    :param file_path: Destination path; the file is opened in write mode.
    :param m_data: Matrix accessed through ``nrows()``, ``ncols()`` and
        ``m_data[i][j]`` indexing.
    """
    with open(file_path, 'w') as out:
        for i in range(m_data.nrows()):
            label = f"{m_data[i][-1]:.4f} "
            features = []
            for j in range(m_data.ncols() - 1):
                features.append(f"{j}:{m_data[i][j]:.6f}")
            out.write(label + " ".join(features) + "\n")

def load_data(file_path, m_data):
    """
    Load data from a file into a matrix.

    Each non-blank line holds an output value followed by ``index:value``
    tokens. Only the part after the colon is used; values are stored in
    token order in the matrix row, followed by the output value in the last
    column. The column count is taken from the first non-blank line.

    Whitespace-only lines are skipped. If the file contains no data lines
    the matrix is resized to 0 x 0.

    :param file_path: Path to the file.
    :param m_data: Matrix to store the data; it is resized in place.
    """
    with open(file_path, 'r') as file:
        # A whitespace-only line (often a stray one at end of file) carries
        # no sample; keeping it would raise IndexError on parts[0] below.
        lines = [line for line in file if line.strip()]

    if not lines:
        # Nothing to load: leave an empty matrix instead of failing on lines[0].
        m_data.resize(0, 0)
        return

    num_rows = len(lines)
    num_cols = len(lines[0].split()) - 1  # Exclude the output variable
    m_data.resize(num_rows, num_cols + 1)  # +1 for the output variable

    for i, line in enumerate(lines):
        parts = line.split()
        output = float(parts[0])  # First value is the output variable
        # Keep only the value after the colon; the index itself is ignored.
        inputs = [float(part.split(':')[1]) for part in parts[1:]]
        m_data[i][:num_cols] = inputs  # Store input variables
        m_data[i][num_cols] = output  # Store output variable

def load_data_gs(file_path, m_data):
    """
    Fill a square matrix from a file holding one element per line.

    The file is read in row-major order: the first ``n`` lines form row 0,
    the next ``n`` lines row 1, and so on, where ``n`` is
    ``int(line_count ** 0.5)``.

    :param file_path: Path to the file.
    :param m_data: Matrix to fill; it is resized to ``n x n`` in place.
    :raises ValueError: If the number of lines is not a perfect square.
    """
    with open(file_path, 'r') as handle:
        entries = handle.readlines()

    total = len(entries)
    side = int(total ** 0.5)  # Edge length of the square matrix
    if side * side != total:
        raise ValueError("The number of lines in the file is not square.")

    m_data.resize(side, side)

    for row in range(side):
        for col in range(side):
            # Line number row * side + col holds element (row, col).
            m_data[row][col] = float(entries[row * side + col].strip())

def load_data_gs_vector(file_path, v_data):
    """
    Fill a vector from a file holding one element per line.

    :param file_path: Path to the file.
    :param v_data: Vector to fill; it is resized to the number of lines and
        then assigned element by element.
    """
    with open(file_path, 'r') as handle:
        raw = handle.readlines()

    v_data.resize(len(raw))

    # The generator converts lazily, so each value is parsed right before
    # it is stored.
    numbers = (float(text.strip()) for text in raw)
    for idx, number in enumerate(numbers):
        v_data[idx] = number

def load_vector(file_path, v_data):
    """
    Read one float per line from a file into a vector.

    :param file_path: Path to the file.
    :param v_data: Vector to fill; it is resized to the number of lines
        before the values are stored.
    """
    with open(file_path, 'r') as handle:
        rows = list(handle)

    count = len(rows)
    v_data.resize(count)
    for pos in range(count):
        v_data[pos] = float(rows[pos].strip())

def save_hyper(file_path, h, d):
    """
    Write hyperparameters to a file as a single line.

    The line holds, separated by spaces and printed with four decimals:
    signal variance, noise, mean, the number of length-scales, and then the
    length-scales themselves.

    :param file_path: Destination path; the file is opened in write mode.
    :param h: Hyperparameter vector laid out as noise at index 0,
        length-scales at indices ``1..d``, signal variance at ``d + 1`` and
        the mean in the last element.
    :param d: Dimension of the length-scale vector.
    """
    noise = h[0]
    length_scales = h[1:d + 1]
    signal_var = h[d + 1]
    mean = h[h.size() - 1]

    with open(file_path, 'w') as out:
        # The three scalars come first, in this fixed order.
        for scalar in (signal_var, noise, mean):
            out.write(f"{scalar:.4f} ")
        out.write(f"{len(length_scales)} ")
        out.write(" ".join(f"{val:.4f}" for val in length_scales))
        out.write("\n")

def load_hyper(file_path, h, d):
    """
    Read hyperparameters from the first line of a file into a vector.

    The line is expected to hold whitespace-separated numbers in the order:
    signal variance, noise, mean, number of length-scales, then the
    length-scales. They are stored in ``h`` as noise at index 0, the
    length-scales next, then the signal variance, and the mean in the last
    element.

    :param file_path: Path to the file containing the hyperparameters.
    :param h: Vector that receives the values; it is resized to
        ``num_length_scales + 3`` elements.
    :param d: Not read by this function; the number of length-scales is
        taken from the file.
    """
    with open(file_path, 'r') as src:
        tokens = src.readline().strip().split()
        values = [float(tok) for tok in tokens]

    # The fourth number on the line is the length-scale count.
    num_lsc = int(values[3])

    # Room for noise, signal variance, the length-scales and the mean.
    h.resize(2 + num_lsc + 1)

    h[0] = values[1]  # Noise parameter
    for k in range(num_lsc):
        h[k + 1] = values[k + 4]  # Length-scales start at the fifth number
    h[num_lsc + 1] = values[0]  # Signal variance
    h[h.size() - 1] = values[2]  # Mean of the hyperparameters

# Debugging functions
def pmsg(level, current_level, message):
    """
    Print ``message`` when ``level <= current_level``; otherwise do nothing.

    :param level: Level attached to this message.
    :param current_level: Threshold the message level is compared against.
    :param message: Object handed to ``print``.
    """
    if not level <= current_level:
        return
    print(message)

# Useful routines
def bubble_sort(a, ai):
    """
    Sort ``a`` in ascending order in place and mirror every swap in ``ai``.

    Adjacent elements are exchanged only when the left one is strictly
    greater, so equal elements keep their relative order. ``ai`` receives
    exactly the same exchanges, which lets it track where each element of
    ``a`` came from.

    Sorting stops as soon as a full pass performs no exchange.

    :param a: Mutable sequence to sort; must support ``len`` and indexing.
    :param ai: Mutable sequence permuted alongside ``a`` (typically the
        original indices).
    """
    n = len(a)
    for i in range(n):
        swapped = False
        for j in range(n - i - 1):
            if a[j] > a[j + 1]:
                a[j], a[j + 1] = a[j + 1], a[j]
                ai[j], ai[j + 1] = ai[j + 1], ai[j]
                swapped = True
        if not swapped:
            # A swap-free pass means later passes would not swap either.
            break

def argmaxi(v):
    """
    Return the position of the largest element of ``v``.

    Delegates to ``np.argmax`` over the flattened input.

    :param v: Array-like of values.
    :return: Index of the maximum as reported by ``np.argmax``.
    """
    return np.argmax(v, axis=None)

def get_rmse(v1, v2):
    """
    Return the root mean square error between two equally sized vectors.

    :param v1: First vector; must provide ``size()`` and support ``v1 - v2``.
    :param v2: Second vector.
    :return: ``sqrt(mean((v1 - v2) ** 2))``.
    :raises ValueError: If the two vectors report different sizes.
    """
    if v1.size() != v2.size():
        raise ValueError("Vectors must have the same size.")
    residual = v1 - v2
    mean_square = np.mean(residual ** 2)
    return np.sqrt(mean_square)

def get_mnlp(v1, v2, v3):
    """
    Return the mean negative log probability over three aligned vectors.

    For every index ``i`` the term
    ``(v1[i] - v2[i]) ** 2 / v3[i] + log(2 * pi * v3[i])`` is formed; the
    result is half the mean of those terms. In that expression ``v3`` plays
    the role of a per-element variance.

    :param v1: First value vector.
    :param v2: Second value vector.
    :param v3: Vector used as the divisor and inside the logarithm.
    :raises ValueError: If the vectors do not all report the same size.
    """
    count = v1.size()
    if not (count == v2.size() and count == v3.size()):
        raise ValueError("Vectors must have the same size.")

    terms = []
    for i in range(count):
        residual = v1[i] - v2[i]
        variance = v3[i]
        terms.append(residual ** 2 / variance + np.log(2 * np.pi * variance))
    return 0.5 * np.mean(terms)

def A_invB_C(A, chol_b, C, D):
    """
    Compute the vector D = A B^{-1} C.

    :param A: m x n matrix (Mdoub), read through ``nrows()``, ``ncols()``
        and ``A[k][j]``.
    :param chol_b: Factorisation of the n x n matrix B (Cholesky decomposed
        by pgpr_chol); its ``solve(C)`` is used to obtain ``B^{-1} C``.
    :param C: n-size vector (Vdoub).
    :param D: m-size output vector (Vdoub); reset with
        ``D.assign(m, 0.0)`` and then filled in place.
    :return: ``SUCC``.
    """
    m = A.nrows()
    n = A.ncols()
    D.assign(m, 0.0)

    # Solve B * beta = C once, then accumulate D = A * beta row by row.
    beta = chol_b.solve(C)
    for k in range(m):
        for j in range(n):
            D[k] += A[k][j] * beta[j]
    return SUCC

def A_invB_C_matrix(A, chol_b, C, D):
    """
    Compute the matrix D = A B^{-1} C, one column of C at a time.

    :param A: m x n matrix (Mdoub).
    :param chol_b: Factorisation of the n x n matrix B (Cholesky decomposed
        by pgpr_chol); ``solve`` is called once per column of C.
    :param C: n x l matrix (Mdoub).
    :param D: m x l output matrix (Mdoub); reset with
        ``D.assign(m, l, 0.0)`` and then filled in place.
    :return: ``SUCC``.
    """
    m = A.nrows()
    n = A.ncols()
    l = C.ncols()
    Cn = pgpr_vector(n)  # Work vector holding the current column of C
    D.assign(m, l, 0.0)

    for i in range(l):
        # Copy column i of C into the work vector.
        for j in range(n):
            Cn[j] = C[j][i]
        beta = chol_b.solve(Cn)  # Solve B * beta = Cn
        # Column i of D is A * beta.
        for k in range(m):
            for j in range(n):
                D[k][i] += A[k][j] * beta[j]
    return SUCC

def A_invB_transC(A, chol_b, C, D):
    """
    Compute the matrix D = A B^{-1} C^T, one row of C at a time.

    :param A: m x n matrix (Mdoub).
    :param chol_b: Factorisation of the n x n matrix B (Cholesky decomposed
        by pgpr_chol); ``solve`` is called once per row of C.
    :param C: l x n matrix (Mdoub).
    :param D: m x l output matrix (Mdoub); reset with
        ``D.assign(m, l, 0.0)`` and then filled in place.
    :return: ``SUCC``.
    """
    m = A.nrows()
    n = A.ncols()
    l = C.nrows()
    Cn = pgpr_vector(n)  # Work vector holding the current row of C
    D.assign(m, l, 0.0)

    for i in range(l):
        # Row i of C is column i of C^T.
        for j in range(n):
            Cn[j] = C[i][j]
        beta = chol_b.solve(Cn)  # Solve B * beta = Cn
        # Column i of D is A * beta.
        for k in range(m):
            for j in range(n):
                D[k][i] += A[k][j] * beta[j]
    return SUCC

def A_invB_C_symmetric(A, chol_b, D):
    """
    Compute the symmetric matrix D = A B^{-1} A^T.

    For each row ``i`` of A the diagonal entry is formed as the squared
    norm of ``chol_b.elsolve(row)``, and the entries below the diagonal are
    formed from ``chol_b.solve(row)`` and mirrored above the diagonal.

    :param A: m x n matrix (Mdoub).
    :param chol_b: Factorisation of the n x n matrix B (Cholesky decomposed
        by pgpr_chol); both ``elsolve`` and ``solve`` are called once per
        row of A.
    :param D: m x m output matrix (Mdoub); reset with
        ``D.assign(m, m, 0.0)`` and then filled in place.
    :return: ``SUCC``.
    """
    m = A.nrows()
    n = A.ncols()
    Cn = pgpr_vector(n)  # Work vector holding the current row of A
    D.assign(m, m, 0.0)

    for i in range(m):
        for j in range(n):
            Cn[j] = A[i][j]
        # NOTE(review): Cn is passed to solve() below after elsolve(); this
        # assumes elsolve() leaves its argument unchanged - confirm.
        v = chol_b.elsolve(Cn)  # Solve L * v = Cn
        for j in range(n):
            D[i][i] += v[j] * v[j]  # Diagonal elements
        beta = chol_b.solve(Cn)  # Solve B * beta = Cn
        # Only rows below the diagonal are computed; the upper part is mirrored.
        for k in range(i + 1, m):
            for j in range(n):
                D[k][i] += A[k][j] * beta[j]
            D[i][k] = D[k][i]  # Symmetric assignment
    return SUCC

def trace_A_invB_C(A, chol_b, C, D):
    """
    Compute the diagonal of A B^{-1} C.

    ``D[i]`` receives the dot product of row ``i`` of A with
    ``chol_b.solve(column i of C)``, i.e. element ``(i, i)`` of
    A B^{-1} C. The result is the vector of diagonal entries, not their
    sum; adding up ``D`` gives the trace.

    :param A: m x n matrix (Mdoub).
    :param chol_b: Factorisation of the n x n matrix B (Cholesky decomposed
        by pgpr_chol); ``solve`` is called once per column of C.
    :param C: n x m matrix (Mdoub).
    :param D: m-size output vector (Vdoub); reset with
        ``D.assign(m, 0.0)`` and then filled in place.
    :return: ``SUCC``.
    :raises ValueError: If C does not have as many columns as A has rows.
    """
    m = A.nrows()
    n = A.ncols()
    l = C.ncols()
    # D and A are indexed by the column index of C below, so the product
    # must be square for every index to be valid and every entry filled.
    if m != l:
        raise ValueError("Matrix dimensions do not match for trace(A B^{-1} C)")
    Cn = pgpr_vector(n)  # Work vector holding the current column of C
    D.assign(m, 0.0)

    for i in range(l):
        for j in range(n):
            Cn[j] = C[j][i]
        beta = chol_b.solve(Cn)  # Solve B * beta = Cn
        for j in range(n):
            D[i] += A[i][j] * beta[j]
    return SUCC

def trace_A_invB_transC(A, chol_b, C, D):
    """
    Compute the diagonal of A B^{-1} C^T.

    ``D[i]`` receives the dot product of row ``i`` of A with
    ``chol_b.solve(row i of C)``, i.e. element ``(i, i)`` of
    A B^{-1} C^T. The result is the vector of diagonal entries, not their
    sum; adding up ``D`` gives the trace.

    :param A: m x n matrix (Mdoub).
    :param chol_b: Factorisation of the n x n matrix B (Cholesky decomposed
        by pgpr_chol); ``solve`` is called once per row of C.
    :param C: m x n matrix (Mdoub).
    :param D: m-size output vector (Vdoub); reset with
        ``D.assign(m, 0.0)`` and then filled in place.
    :return: ``SUCC``.
    :raises ValueError: If A and C have different numbers of rows.
    """
    m = A.nrows()
    n = A.ncols()
    l = C.nrows()
    if m != l:
        raise ValueError("Matrix dimensions do not match for trace(A B^{-1} C^T)")
    Cn = pgpr_vector(n)  # Work vector holding the current row of C
    D.assign(m, 0.0)

    for i in range(m):
        for j in range(n):
            Cn[j] = C[i][j]
        beta = chol_b.solve(Cn)  # Solve B * beta = Cn
        for j in range(n):
            D[i] += A[i][j] * beta[j]
    return SUCC

# Timer class
class pgpr_timer:
    """
    Stopwatch built on ``time.time()``.

    ``start_time`` and ``end_time`` hold the most recent readings taken by
    ``start()`` and ``end()``; both are ``None`` until first set.
    """

    def __init__(self):
        self.start_time = self.end_time = None

    def start(self):
        """Record the current time as the beginning of the interval."""
        now = time.time()
        self.start_time = now

    def end(self):
        """
        Record the current time as the end of the interval.

        ``start()`` must have been called first; otherwise ``start_time``
        is still ``None`` and the subtraction fails.

        :return: Seconds between the recorded start and this call.
        """
        self.end_time = time.time()
        elapsed = self.end_time - self.start_time
        return elapsed