
import torch
import torch.nn as nn
import pennylane as qml
import math

# --- Parameter-initialisation switches for the 'default' ansatz --- #
# QuantumLayer checks them in this order: identity-block first, Gaussian second.
use_identity_block = True
use_gaussian_init = False

# Names of the circuit ansatzes that QuantumLayer knows how to build.
ANSATZ_LIST = [
    "default",
    "hardware_efficient",
]

class QuantumLayer(nn.Module):
    """Variational quantum circuit wrapped as a torch module.

    The input is amplitude-embedded on ``n_qubits`` wires, the selected ansatz
    is applied once per repetition (last axis of ``weights``), and the Pauli-Z
    expectation value of every wire is measured.

    Supported ansatzes:

    * ``'default'`` -- all-to-all CRY gates followed by an RY on every wire.
      Its parameters are created according to the module-level flags
      ``use_identity_block`` and ``use_gaussian_init`` (checked in that order).
    * ``'hardware_efficient'`` -- an RY on every wire followed by a linear
      chain of CNOTs. It has no ``cond`` parameters (``self.cond`` is None).
    """

    SIMULATOR_default = 'default.qubit'
    SIMULATOR_gpu = 'lightning.gpu'
    SIMULATORS = [SIMULATOR_default, SIMULATOR_gpu]

    # Ansatz names this class can build (same names as the module-level ANSATZ_LIST).
    ANSATZ_default = 'default'
    ANSATZ_hardware_efficient = 'hardware_efficient'
    ANSATZES = [ANSATZ_default, ANSATZ_hardware_efficient]

    def __init__(self, 
                 n_qubits, 
                 rep, 
                 pad_with,
                 simulator = SIMULATOR_default,
                 ansatz = 'default',
                 device = 'cpu'):
        """Create the simulator device and the trainable circuit parameters.

        Args:
            n_qubits: number of wires of the circuit.
            rep: number of ansatz repetitions (size of the last parameter axis).
            pad_with: value forwarded to ``qml.AmplitudeEmbedding(pad_with=...)``.
            simulator: PennyLane device name, one of ``QuantumLayer.SIMULATORS``.
            ansatz: circuit layout, one of ``QuantumLayer.ANSATZES``.
            device: torch device on which the parameters are created.

        Raises:
            ValueError: if ``simulator`` or ``ansatz`` is not supported.
            RuntimeError: if ``ansatz`` is ``'default'`` but neither
                ``use_identity_block`` nor ``use_gaussian_init`` is enabled.
        """
        super().__init__()
        if simulator not in QuantumLayer.SIMULATORS:
            raise ValueError(f"Simulator {simulator} not supported. Supported simulators are: {QuantumLayer.SIMULATORS}")
        # Fail early: an unknown ansatz would otherwise leave the layer without
        # any parameters and only blow up later inside forward().
        if ansatz not in QuantumLayer.ANSATZES:
            raise ValueError(f"Ansatz {ansatz} not supported. Supported ansatzes are: {QuantumLayer.ANSATZES}")

        self.n_qubits = n_qubits
        self.sim_dev = qml.device(simulator, wires = n_qubits)
        self.ansatz = ansatz
        self.show_plot = True
        self.pad_with = pad_with
        self.device = device

        if self.ansatz == QuantumLayer.ANSATZ_default:
            if use_identity_block: # --- use Identity Block Initialization --- #

                # weights not involving conditional operations #
                # The random block is stored twice along dim 0: rows [:n_qubits]
                # feed the forward half of each repetition, rows [n_qubits:] feed
                # its adjoint half, so both halves start with identical angles.
                half_weights = torch.rand(n_qubits, rep, device = self.device) * 2 * math.pi
                self.weights = nn.Parameter(torch.cat((half_weights, half_weights), dim = 0))

                # weights involving conditional operations #
                half_cond = torch.rand(n_qubits, n_qubits, rep, device = self.device) * 2 * math.pi
                self.cond = nn.Parameter(torch.cat((half_cond, half_cond), dim = 0))

            elif use_gaussian_init: # --- use Gaussian Initialization --- #
                # not tested yet

                gaussian_std = 1.0 #train_cfg['gaussian_std']

                # weights not involving conditional operations #
                # Zero-mean normal samples, drawn directly on the target device.
                self.weights = nn.Parameter(
                    torch.randn(n_qubits, rep, device = self.device) * gaussian_std)

                # weights involving conditional operations #
                self.cond = nn.Parameter(
                    torch.randn(n_qubits, n_qubits, rep, device = self.device) * gaussian_std)

            else:
                # Without this guard the layer would be built with no parameters.
                raise RuntimeError(
                    "The 'default' ansatz needs use_identity_block or use_gaussian_init to be enabled.")

        else: # hardware_efficient (the only other validated ansatz)

            self.weights = nn.Parameter(torch.rand(n_qubits, rep, device = self.device) * 2 * math.pi)
            self.cond = None

    def amplitude_embedding(self, x):
        """Amplitude-embed ``x`` on all wires, normalising and padding with ``self.pad_with``."""
        qml.AmplitudeEmbedding(x, wires = range(self.n_qubits), pad_with = self.pad_with, normalize = True)

    def Y_rotations(self, params, var_qubits):
        """Apply ``RY(params[k])`` on wire ``var_qubits[k]`` for every entry of ``params``."""
        for idx, angle in enumerate(params):
            qml.RY(angle, wires = var_qubits[idx])

    def linear_CNOT_entangle(self, n_qubits):
        """Apply CNOTs along the chain 0 -> 1 -> ... -> ``n_qubits - 1``."""
        for control in range(n_qubits - 1):
            qml.CNOT(wires = [control, control + 1])

    def conditional_full_entangle(self, weights, entangle_qbs):
        """Apply a CRY between every ordered pair of distinct wires.

        Args:
            weights: 2-D indexable of angles; ``weights[i, j]`` is the angle of
                the gate controlled by ``entangle_qbs[i]`` acting on
                ``entangle_qbs[j]``. Diagonal entries are not used.
            entangle_qbs: wire labels to entangle.
        """
        for i, control in enumerate(entangle_qbs):
            for j, target in enumerate(entangle_qbs):
                if j != i:
                    # Use the wire labels themselves, not their positions in the list.
                    qml.CRY(phi = weights[i, j], wires = [control, target])

    def QNode(self, inputs, weights, cond):
        """Build the torch-interface QNode and evaluate it.

        Args:
            inputs: data passed to the amplitude embedding.
            weights: RY angles; the last axis indexes the repetition.
            cond: CRY angles for the 'default' ansatz (None otherwise).

        Returns:
            Whatever the QNode returns for the list of per-wire Pauli-Z
            expectation values.
        """

        @qml.qnode(self.sim_dev, interface = 'torch')
        def qnode(inputs, weights, cond):

            self.amplitude_embedding(inputs)

            wires = list(range(self.n_qubits))
            n_rep = weights.shape[-1]

            if self.ansatz == 'default':
                for i in range(n_rep):

                    if use_identity_block: #default

                        # Forward half: first n_qubits rows of the parameters.
                        self.conditional_full_entangle(cond[:self.n_qubits,:,i], wires)
                        self.Y_rotations(weights[:self.n_qubits,i], wires)

                        # Adjoint half, in reverse gate order, driven by the
                        # remaining rows of the parameters.
                        qml.adjoint(self.Y_rotations)(weights[self.n_qubits:,i], wires)
                        qml.adjoint(self.conditional_full_entangle)(cond[self.n_qubits:,:,i], wires)

                    elif use_gaussian_init:

                        self.conditional_full_entangle(cond[:self.n_qubits,:,i], wires)
                        self.Y_rotations(weights[:self.n_qubits,i], wires)

            elif self.ansatz == 'hardware_efficient':
                for i in range(n_rep):

                    self.Y_rotations(weights[:,i], wires)
                    self.linear_CNOT_entangle(self.n_qubits)

            return [qml.expval(qml.PauliZ(w)) for w in wires]

        return qnode(inputs, weights, cond)

    def forward(self, x):
        """Run the circuit on ``x`` and return the Pauli-Z expectation values.

        The per-wire results are concatenated, reshaped to ``(n_qubits, -1)``,
        transposed so that the wire axis comes last, and cast to float32.
        """
        expvals = self.QNode(x, self.weights, self.cond)
        q_out = torch.cat(expvals).reshape(self.n_qubits, -1).T.float()
        return q_out
  



