from torch import  nn
import torch
import pennylane as qml
# from qiskit.providers.fake_provider import GenericBackendV2
# from qiskit_aer.noise import NoiseModel
from qiskit_ibm_runtime import QiskitRuntimeService
from qiskit_ibm_runtime import QiskitRuntimeService
class QAN(nn.Module):
    """Hybrid regressor: Linear -> data re-uploading variational circuit -> Linear.

    ``clayer_1`` projects the last dimension of ``src`` to ``hidden_dim`` features,
    one per qubit. The circuit applies ``num_layers`` rounds of a trainable
    ``BasicEntanglerLayers`` block followed by a data-dependent ``RZ`` on every
    qubit (the projected features are re-uploaded each round). It then applies one
    final ``BasicEntanglerLayers`` block and measures ``PauliZ`` on each qubit.
    ``clayer_2`` maps those ``hidden_dim`` expectation values to ``output_dim``.

    Args:
        input_dim: Size of the last dimension of ``src``.
        output_dim: Size of the last dimension of the output.
        hidden_dim: Number of qubits (and width of the classical projection).
        num_layers: Number of entangle + re-upload rounds.
        use_remote: Run on an IBM backend through PennyLane's ``qiskit.remote``
            device instead of the local ``default.qubit`` simulator. Any error
            while connecting is reported and falls back to ``default.qubit``.
        ibm_channel: Channel passed to ``QiskitRuntimeService`` when an instance
            or token is supplied.
        ibm_instance: Optional IBM instance. When it and ``ibm_token`` are both
            unset, the saved default account is used.
        ibm_backend_name: Optional backend name. ``None`` selects the least busy
            operational non-simulator backend.
        ibm_min_qubits: Minimum backend size for least-busy selection. It is
            never allowed below ``hidden_dim``.
        shots: Shots per circuit, forwarded to the remote device.
        optimization_level: Optional value forwarded to the remote device as
            ``optimization_level``.
        resilience_level: Optional value forwarded to the remote device as
            ``resilience_level``.
        ibm_token: Optional API token. Supply it at runtime instead of embedding
            it in source.
    """

    def __init__(self, input_dim=1, output_dim=1, hidden_dim=10, num_layers=3,
                 use_remote=False, ibm_channel="ibm_quantum", ibm_instance=None,
                 ibm_backend_name=None, ibm_min_qubits=None, shots=2048,
                 optimization_level=None, resilience_level=None, ibm_token=None):
        super(QAN, self).__init__()
        n_qubits = hidden_dim

        dev = None
        if use_remote:
            try:
                dev = self._make_remote_device(
                    n_qubits, ibm_channel, ibm_instance, ibm_token, ibm_backend_name,
                    ibm_min_qubits, shots, optimization_level, resilience_level)
            except Exception as e:
                # Deliberate best-effort: auth/network/backend problems should not
                # prevent building the model, so report and use the local simulator.
                print(f"[QAN] IBM remote init failed, fallback to default.qubit. Reason: {e}")

        if dev is None:
            dev = qml.device("default.qubit", wires=n_qubits)
            # Exact statevector simulation: gradients flow through the torch ops.
            diff_method = "backprop"
        else:
            # Hardware cannot backpropagate; estimate gradients with shifted circuit runs.
            diff_method = "parameter-shift"

        def circuit(inputs, weights1, weights2):
            for i in range(num_layers):
                qml.BasicEntanglerLayers(weights1[i], wires=range(n_qubits))
                for j in range(n_qubits):
                    # `...` indexing accepts both batched (B, n_qubits) and single
                    # (n_qubits,) inputs.
                    qml.RZ(inputs[..., j], wires=j)
            qml.BasicEntanglerLayers(weights2, wires=range(n_qubits))
            return [qml.expval(qml.PauliZ(i)) for i in range(n_qubits)]

        self.clayer_1 = nn.Linear(input_dim, hidden_dim)
        self.clayer_2 = nn.Linear(hidden_dim, output_dim)

        weight_shape = {
            # One single-layer BasicEntanglerLayers block per re-upload round.
            "weights1": (num_layers, 1, n_qubits),
            # Final BasicEntanglerLayers block after the last RZ re-upload.
            "weights2": (1, n_qubits),
        }
        # Wrap the plain circuit function exactly once, with a diff method valid for `dev`.
        self.qnode = qml.QNode(circuit, dev, diff_method=diff_method, interface="torch")
        self.qlayer = qml.qnn.TorchLayer(self.qnode, weight_shape)

    @staticmethod
    def _make_remote_device(n_qubits, ibm_channel, ibm_instance, ibm_token,
                            ibm_backend_name, ibm_min_qubits, shots,
                            optimization_level, resilience_level):
        """Connect to IBM Quantum and return a PennyLane ``qiskit.remote`` device.

        Raises whatever ``QiskitRuntimeService`` / backend lookup raises. The
        caller decides whether to fall back.
        """
        if not ibm_instance and ibm_token is None:
            # No explicit credentials: use the account saved via
            # QiskitRuntimeService.save_account.
            service = QiskitRuntimeService()
        else:
            service = QiskitRuntimeService(channel=ibm_channel, instance=ibm_instance,
                                           token=ibm_token)

        if ibm_backend_name:
            backend = service.backend(ibm_backend_name)
        else:
            # The backend must hold every wire, whatever minimum the caller asked for.
            min_qubits = max(ibm_min_qubits or 0, n_qubits)
            backend = service.least_busy(operational=True, simulator=False,
                                         min_num_qubits=min_qubits)

        device_kwargs = {"shots": shots}
        if optimization_level is not None:
            device_kwargs["optimization_level"] = optimization_level
        if resilience_level is not None:
            device_kwargs["resilience_level"] = resilience_level
        return qml.device("qiskit.remote", wires=n_qubits, backend=backend, **device_kwargs)

    def forward(self, src):
        """Run the hybrid model on ``src``, whose last dimension is ``input_dim``.

        Returns a tensor whose last dimension is ``output_dim``.
        """
        x1 = self.clayer_1(src)   # last dim: hidden_dim
        x2 = self.qlayer(x1)      # last dim: one PauliZ expectation per qubit
        return self.clayer_2(x2)  # last dim: output_dim

class QAN1(nn.Module):
    """Hybrid model: Linear -> angle-embedded variational circuit -> Linear.

    The last dimension of ``src`` is projected to ``hidden_dim`` features. These
    are angle-embedded on ``hidden_dim`` qubits of a local ``default.qubit``
    device. A ``BasicEntanglerLayers`` block with ``num_layers`` layers follows,
    then one trainable ``RX`` rotation per qubit. The ``PauliZ`` expectation of
    every qubit is mapped to ``output_dim`` by a final Linear layer.
    """

    def __init__(self, input_dim=1, output_dim=1, hidden_dim=8, num_layers=3):
        super().__init__()
        qubits = range(hidden_dim)
        device = qml.device("default.qubit", wires=hidden_dim)

        @qml.qnode(device)
        def circuit(inputs, weights1, rx_angles):
            qml.AngleEmbedding(inputs, wires=qubits)
            qml.BasicEntanglerLayers(weights1, wires=qubits)
            # Trainable single-qubit rotation on every wire after the entangler.
            for wire in qubits:
                qml.RX(rx_angles[wire], wires=wire)
            return [qml.expval(qml.PauliZ(wires=wire)) for wire in qubits]

        self.clayer_1 = nn.Linear(input_dim, hidden_dim)
        self.clayer_2 = torch.nn.Linear(hidden_dim, output_dim)

        # One trainable weight per key; the names must match the circuit's arguments.
        weight_shapes = {
            "weights1": (num_layers, hidden_dim),
            "rx_angles": (hidden_dim,),
        }
        self.qlayer = qml.qnn.TorchLayer(circuit, weight_shapes)

    def forward(self, src):
        """Map ``src`` (last dim ``input_dim``) to an output with last dim ``output_dim``."""
        projected = self.clayer_1(src)
        expectations = self.qlayer(projected)
        return self.clayer_2(expectations)

if __name__ == "__main__":
    # Smoke test: build each model, push one random batch through it, and print shapes.
    input_dim = 96
    output_dim = 24
    hidden_dim = 8
    num_layers = 3

    # Random inputs have shape [batch_size, seq_len, input_dim].
    batch_size = 3
    seq_len = 7

    for model_cls in (QAN1, QAN):
        model = model_cls(input_dim=input_dim, output_dim=output_dim,
                          hidden_dim=hidden_dim, num_layers=num_layers)
        x = torch.randn(batch_size, seq_len, input_dim)
        y = model(x)

        print(f"{model_cls.__name__}:")
        print("Input shape:", x.shape)
        print("Output shape:", y.shape)