import torch
import math
import time
import numpy as np
import pennylane as qml

torch_pi = torch.tensor([math.pi])  # pi as a 1-element default-dtype tensor, used to scale random phases

def rotation_mat(a):
    """
    Build the QSP signal-rotation matrix W(a).

    Input:  a scalar a (requires -1 <= a <= 1 for the off-diagonal to be real-valued
            before multiplication by i).
    Output: the 2x2 matrix [[a, i*sqrt(1 - a^2)], [i*sqrt(1 - a^2), a]] as a nested list.
    """
    sine_term = (1 - a ** 2) ** 0.5 * 1j  # i * sqrt(1 - a^2), shared by both off-diagonals
    return [[a, sine_term], [sine_term, a]]

def QSP_circ(phi, x):
    """
    Record the single-qubit quantum signal processing (QSP) circuit on wire 0.

    Input:  a sequence of phase angles phi (length n >= 1) and a scalar signal x
            (rotation_mat expects -1 <= x <= 1).
    Output: None. The gates are queued on the active PennyLane context. Callers in this
            file wrap the function with qml.matrix to obtain the circuit unitary. The real
            part of that unitary's (0, 0) entry is the polynomial in x that the
            circuit realizes.

    Gate order: H, then [RZ(phi[k]), W(x)] for k = 0 .. n-2, then RZ(phi[n-1]), then H.
    """
    signal_op = rotation_mat(x)  # W(x) is identical in every layer, so build it once
    qml.Hadamard(wires=0)  # prepare |+>
    for k in range(len(phi) - 1):
        qml.RZ(phi[k], wires=0)
        qml.QubitUnitary(signal_op, wires=0)
    qml.RZ(phi[-1], wires=0)  # closing phase rotation
    qml.Hadamard(wires=0)  # rotate back from the |+>, |-> basis

def step_function(x, K):
    """
    Quantize values in [0, 1] onto the grid {0, 1/K, ..., (K-1)/K}.

    The interval [0, 1] is split into K equal cells. An entry of x in the k-th cell
    (0-based) is mapped to k / K. Cells are tested in order with closed bounds, so an
    interior boundary value j/K takes the later cell's value j/K (up to float rounding of
    the cell edges). x == 1 maps to (K-1)/K, and entries outside [0, 1] stay 0.

    Args:
        x (torch.Tensor): Input values (must be a torch tensor; torch.zeros_like is used).
        K (int): Number of segments.

    Returns:
        torch.Tensor: Same shape and dtype as x, holding the step-function values.

    Raises:
        ValueError: If K <= 0.
    """
    if K <= 0:
        raise ValueError("Number of segments (K) must be a positive integer")

    edges = torch.linspace(0, 1, K + 1)
    result = torch.zeros_like(x)
    # Later cells overwrite earlier ones on shared boundaries.
    for k, (lower, upper) in enumerate(zip(edges[:-1], edges[1:])):
        result[(x >= lower) & (x <= upper)] = k / K
    return result
    
def random_step_function(x, K, generator=None):
    """
    Randomly perturbed variant of step_function on [0, 1].

    [0, 1] is split into K equal cells. Every entry of x in the k-th cell (0-based) is
    mapped to k / K + r_k / (3 * K), where r_k ~ U[0, 1) is drawn once per cell, so all
    entries of one cell share the same value. Cells are tested in order with closed
    bounds, so boundary values get the later cell's value. Entries outside [0, 1] stay 0.

    Args:
        x (torch.Tensor): Input values (must be a torch tensor; torch.zeros_like is used).
        K (int): Number of segments.
        generator (torch.Generator, optional): RNG used for the per-cell offsets. When
            None, the global torch RNG is used.

    Returns:
        torch.Tensor: Same shape and dtype as x, holding the perturbed step values.

    Raises:
        ValueError: If K <= 0.
    """
    if K <= 0:
        raise ValueError("Number of segments (K) must be a positive integer")

    segments = torch.linspace(0, 1, K + 1)  # Divide the interval into K equal parts
    output = torch.zeros_like(x)

    for k in range(K):
        mask = (x >= segments[k]) & (x <= segments[k + 1])
        # One draw per cell, even for empty cells, so the RNG always advances K times.
        output[mask] = k / K + torch.rand(1, generator=generator) / (3 * K)

    return output


class Discretization(torch.nn.Module):
    """
    Learnable QSP-based discretizer for the input X.

    A single QSP phase vector ``phi`` of length
    ``depth = ceil((1 / eps) * log2(K / eps))`` defines P(x) = Re[U_phi(x)[0, 0]], where
    U_phi is the unitary of ``QSP_circ(phi, x)``. ``forward`` maps each entry x to the
    cell index round(K * P(x)), clipped to [0, K - 1]. It then subtracts index / K from x,
    the point the Taylor expansion is taken around.
    """
    def __init__(self, K, eps, random_seed=None):
        super().__init__()
        self.K       = K
        self.depth   = math.ceil(1 / eps * math.log2(K / eps))
        if random_seed is None:
            self.phi = torch_pi * torch.rand(self.depth, requires_grad=True)
        else:
            gen = torch.Generator()
            gen.manual_seed(random_seed)
            self.phi = torch_pi * torch.rand(self.depth, requires_grad=True, generator=gen)

        self.phi = torch.nn.Parameter(self.phi)  # phases initialised uniformly in [0, pi)
        self.num_phi = self.depth

    def forward(self, X):
        '''
        input:  a matrix X (indexed as X[i, j]).
        output: (X_new, eta_matrix)
            X_new:      a detached copy of X with requires_grad=True, where each entry has
                        index / K subtracted (done under no_grad).
            eta_matrix: a tensor shaped like X holding the integer-valued cell indices
                        in [0, K - 1].
        '''
        X_new               = X.clone().detach().requires_grad_(True)
        eta_matrix          = torch.zeros_like(X_new)
        generate_qsp_mat    = qml.matrix(QSP_circ)
        for i in range(X.shape[0]):
            for j in range(X.shape[1]):
                u_qsp       = generate_qsp_mat(self.phi, X[i, j])
                truncated_index = torch.round(u_qsp[0, 0].real * self.K)
                # Clip the rounded index into the valid cell range [0, K - 1].
                if truncated_index < 0:
                    truncated_index = 0
                elif truncated_index > self.K-1:
                    truncated_index = self.K-1
                eta_matrix[i, j] = truncated_index
                with torch.no_grad():
                    X_new[i, j] -= truncated_index / self.K
        return X_new, eta_matrix

    def train(self, x=True, *, mode=None):
        '''
        Dual-purpose method that keeps torch.nn.Module's train/eval protocol working.

        torch.nn.Module.eval() calls ``self.train(False)``, and Module.train(mode) calls
        ``child.train(mode)`` on every submodule, so this override must accept a bool.
        - A bool (or ``mode=``) switches training mode via torch.nn.Module.train and
          returns self.
        - Anything else is treated as the 1-D input vector x. The method returns
          P(x) = Re[U_phi(x)[0, 0]] per entry (used to fit phi).
        '''
        if mode is not None:
            x = mode
        if isinstance(x, bool):
            return super().train(x)
        return self._qsp_values(x)

    def _qsp_values(self, x):
        '''Evaluate Re[U_phi(x_i)[0, 0]] for each entry of the vector x.'''
        generate_qsp_mat = qml.matrix(QSP_circ)
        func_outputs = torch.zeros_like(x)
        for i in range(x.shape[0]):
            u_qsp = generate_qsp_mat(self.phi, x[i])
            func_outputs[i] = u_qsp[0, 0].real
        return func_outputs
    

class QSP_Func_Poly(torch.nn.Module):
    """
    Piecewise model on X in R^d built from products of single-variable QSP polynomials.

    With R = (s + 1) * d**s terms, the prediction for sample i is
        2 * |sum_{k<R} eta[row_i, k] * prod_j P_{k,j}(X[i, j])| / R,
    where P_{k,j}(x) = Re[U_{phi[k, :, j]}(x)[0, 0]]. row_i encodes the per-dimension cell
    indices eta_matrix[i, :] in base K, so ``eta`` has one coefficient row per cell.
    """
    def __init__(self, s, depth_constant, d, K, random_seed=None):
        """
        Input: s: continuous coefficient. d: input dim. K: cells per input dimension.
        depth_constant: phase-vector length multiplier (used only when random_seed is None).
        random_seed: if given, parameters are drawn from a dedicated seeded generator.
        """
        super().__init__()
        self.K = K
        # NOTE(review): this counts (s + 1) phases per polynomial. That matches the seeded
        # branch, but not the unseeded one when depth_constant != 1. Confirm which is intended.
        self.num_phi = (s+1)*d**s * (s + 1) * d
        self.range   = int((s+1)*d**s)  # number of product terms R
        if random_seed is None:
            self.phi = torch_pi * torch.rand((s+1)*d**s, depth_constant * (s + 1), d, requires_grad=True)
            # One coefficient row per cell (K**d rows), one column per term, matching the
            # seeded branch and the eta[row, k] indexing in forward.
            self.eta = torch_pi * torch.rand(K**d, self.range, requires_grad=True)
        else:
            gen      = torch.Generator()
            gen.manual_seed(random_seed)
            self.phi = torch_pi * (torch.rand((s+1)*d**s, 1 * (s + 1), d, requires_grad=True, generator=gen) + 1)
            self.eta = 1 * (torch.rand(K**d, self.range, requires_grad=True, generator=gen) + 1)
        self.phi     = torch.nn.Parameter(self.phi)
        self.eta     = torch.nn.Parameter(self.eta)

    def forward(self, X, eta_matrix):
        """
        Input:  X: inputs indexed as X[i, j] (presumably (n, d) — TODO confirm).
                eta_matrix: per-entry cell indices shaped like X. They should be
                integer-valued, as produced by Discretization.forward.
        Output: y_pred, a 1-D tensor of length X.shape[0].
        """
        y_pred = torch.zeros(X.shape[0])
        generate_qsp_mat = qml.matrix(QSP_circ)

        for i in range(X.shape[0]):
            # Base-K encoding of the per-dimension cell indices into one row of eta.
            # Negative indices count as 0. Start from a Python int and convert with int()
            # (which truncates like .long()), so a row with no contributing term still
            # yields row 0 instead of crashing.
            taylor_index = 0
            for l in range(X.shape[1]):
                if eta_matrix[i, l] < 0:
                    continue
                taylor_index += eta_matrix[i, l] * self.K**l
            taylor_row = int(taylor_index)

            y_pred_k = torch.zeros(1)
            for k in range(self.range):
                outcomes = self.eta[taylor_row, k]
                for j in range(X.shape[1]):
                    u_qsp     = generate_qsp_mat(self.phi[k, :, j], X[i, j].detach())
                    P_a       = u_qsp[0, 0].real
                    outcomes  = torch.mul(outcomes, P_a)
                y_pred_k += outcomes
            y_pred[i] = 2 * torch.abs(y_pred_k) / self.range
        return y_pred
    

class QNN(torch.nn.Module):
    """
    Two-stage model: a learned QSP Discretization followed by a QSP_Func_Poly regressor.

    The discretizer supplies the per-entry cell indices (eta_matrix). The polynomial
    model is then evaluated on the raw inputs X. The shifted inputs returned by the
    discretizer are not used.
    """
    def __init__(self, s, depth_constant, K, eps, d, random_seed_1=None, random_seed_2=None):
        super().__init__()
        # Keep the discretizer first: with unseeded construction both sub-models draw from
        # the global RNG, so the order fixes which random numbers each one receives.
        self.disc_model  = Discretization(K, eps, random_seed_1)
        self.poly_model  = QSP_Func_Poly(s, depth_constant, d, K, random_seed_2)
        self.K = K

    def forward(self, X):
        _, eta_matrix = self.disc_model.forward(X)
        return self.poly_model.forward(X, eta_matrix)

    def load_discretization_model(self, model_path):
        # NOTE(review): torch.load unpickles the file; only load checkpoints from trusted
        # sources (consider weights_only=True on torch versions that support it).
        state_dict = torch.load(model_path)
        self.disc_model.load_state_dict(state_dict)

    def check(self, X):
        """Return only the discretizer's cell-index matrix for X."""
        return self.disc_model.forward(X)[1]

class QNN_with_perfect_discretization(torch.nn.Module):
    """
    Variant of QNN that replaces the learned Discretization with the exact step function.

    ``eps`` and ``random_seed_1`` are accepted only for signature compatibility with QNN
    and are ignored.
    """
    def __init__(self, s, depth_constant, K, eps, d, random_seed_1=None, random_seed_2=None):
        super(QNN_with_perfect_discretization, self).__init__()
        self.K = K
        self.poly_model  = QSP_Func_Poly(s, depth_constant, d, K, random_seed_2)

    def forward(self, X):
        # step_function returns the cell *value* k/K, but QSP_Func_Poly.forward expects
        # the integer cell index k, which Discretization.forward produces. Scale by K and
        # round away the float error of (k/K) * K.
        eta_matrix = torch.round(step_function(X, self.K) * self.K)
        output = self.poly_model.forward(X, eta_matrix)
        return output
        






