import torchquantum.functional as tqf
import torchquantum as tq
from torchquantum.macro import C_DTYPE, F_DTYPE

import torch
import torch.nn as nn
import torch.nn.functional as F
# device = torch.device("cuda") if torch.cuda.is_available() else torch.device('cpu')


class U3(tq.QuantumModule):
    """Trainable single-qubit U3 gate.

    Holds three trainable angles (a ``[1, 3]`` tensor of dtype ``F_DTYPE``),
    initialised uniformly in ``[-pi, pi]``. On each forward pass the gate
    matrix is rebuilt from them with ``tqf.u3_matrix`` and applied to the
    device state.
    """

    def __init__(self):
        super(U3, self).__init__()
        self.num_params = 3
        self.params = self.build_params()
        self.n_wires = 1

    def build_params(self):
        """Create the trainable angle tensor of shape ``[1, num_params]``.

        Returns:
            nn.Parameter: angles drawn uniformly from ``[-pi, pi]``.
        """
        # nn.Parameter already defaults to requires_grad=True, so no explicit
        # flag is needed here.
        parameters = nn.Parameter(
            torch.empty([1, self.num_params], dtype=F_DTYPE))
        torch.nn.init.uniform_(parameters, -torch.pi, torch.pi)
        return parameters

    def forward(self, q_device: tq.QuantumDevice, wires=None):
        """Apply the gate to ``q_device``.

        Args:
            q_device: device whose ``states`` are updated.
            wires: wire index (int) or sequence of wire indices to act on.
        """
        # Accept a bare int by wrapping it into the list form used downstream.
        if isinstance(wires, int):
            wires = [wires]
        self.wires = wires
        matrix = tqf.u3_matrix(self.params)
        self.func(q_device, self.wires, matrix)

    def func(self, q_device: tq.QuantumDevice, wires, matrix):
        """Replace ``q_device.states`` with ``matrix`` applied on ``wires``."""
        state = q_device.states
        # Align the matrix with the state's device so a module and a
        # QuantumDevice placed on different devices do not clash.
        q_device.states = tqf.apply_unitary_bmm(
            state, matrix.to(state.device), wires)

class CU3(tq.QuantumModule):
    """Trainable two-qubit CU3 gate.

    Holds three trainable angles (a ``[1, 3]`` tensor of dtype ``F_DTYPE``),
    initialised uniformly in ``[-pi, pi]``. On each forward pass the gate
    matrix is rebuilt from them with ``tqf.cu3_matrix`` and applied to the
    device state.
    """

    def __init__(self):
        super(CU3, self).__init__()
        self.num_params = 3
        self.params = self.build_params()
        self.n_wires = 2

    def build_params(self):
        """Create the trainable angle tensor of shape ``[1, num_params]``.

        Returns:
            nn.Parameter: angles drawn uniformly from ``[-pi, pi]``.
        """
        # nn.Parameter already defaults to requires_grad=True, so no explicit
        # flag is needed here.
        parameters = nn.Parameter(
            torch.empty([1, self.num_params], dtype=F_DTYPE))
        torch.nn.init.uniform_(parameters, -torch.pi, torch.pi)
        return parameters

    def forward(self, q_device: tq.QuantumDevice, wires=None):
        """Apply the gate to ``q_device``.

        Args:
            q_device: device whose ``states`` are updated.
            wires: sequence of the two wire indices to act on.
        """
        self.wires = wires
        matrix = tqf.cu3_matrix(self.params)
        self.func(q_device, self.wires, matrix)

    def func(self, q_device: tq.QuantumDevice, wires, matrix):
        """Replace ``q_device.states`` with ``matrix`` applied on ``wires``."""
        state = q_device.states
        # Align the matrix with the state's device so a module and a
        # QuantumDevice placed on different devices do not clash.
        q_device.states = tqf.apply_unitary_bmm(
            state, matrix.to(state.device), wires)

class PauliX(tq.QuantumModule):
    """Fixed (parameter-free) single-qubit PauliX gate.

    The gate matrix is taken from ``tqf.mat_dict['paulix']``.
    """

    def __init__(self):
        super(PauliX, self).__init__()
        self.num_params = 0
        self.n_wires = 1
        # Plain tensor attribute, not a registered parameter/buffer, so moving
        # the module (``.to()``/``.cuda()``) does not move it; ``func`` moves
        # it to the state's device at application time instead.
        self.matrix = tqf.mat_dict['paulix']

    def forward(self, q_device: tq.QuantumDevice, wires=None):
        """Apply the gate to ``q_device``.

        Args:
            q_device: device whose ``states`` are updated.
            wires: wire index (int) or sequence of wire indices to act on.
        """
        # Accept a bare int by wrapping it into the list form used downstream.
        if isinstance(wires, int):
            wires = [wires]
        self.wires = wires
        self.func(q_device, self.wires, self.matrix)

    def func(self, q_device: tq.QuantumDevice, wires, matrix):
        """Replace ``q_device.states`` with ``matrix`` applied on ``wires``."""
        state = q_device.states
        # The shared matrix may live on a different device than the state.
        q_device.states = tqf.apply_unitary_bmm(
            state, matrix.to(state.device), wires)

class PauliY(tq.QuantumModule):
    """Fixed (parameter-free) single-qubit PauliY gate.

    The gate matrix is taken from ``tqf.mat_dict['pauliy']``.
    """

    def __init__(self):
        super(PauliY, self).__init__()
        self.num_params = 0
        self.n_wires = 1
        # Plain tensor attribute, not a registered parameter/buffer, so moving
        # the module (``.to()``/``.cuda()``) does not move it; ``func`` moves
        # it to the state's device at application time instead.
        self.matrix = tqf.mat_dict['pauliy']

    def forward(self, q_device: tq.QuantumDevice, wires=None):
        """Apply the gate to ``q_device``.

        Args:
            q_device: device whose ``states`` are updated.
            wires: wire index (int) or sequence of wire indices to act on.
        """
        # Accept a bare int by wrapping it into the list form used downstream.
        if isinstance(wires, int):
            wires = [wires]
        self.wires = wires
        self.func(q_device, self.wires, self.matrix)

    def func(self, q_device: tq.QuantumDevice, wires, matrix):
        """Replace ``q_device.states`` with ``matrix`` applied on ``wires``."""
        state = q_device.states
        # The shared matrix may live on a different device than the state.
        q_device.states = tqf.apply_unitary_bmm(
            state, matrix.to(state.device), wires)

class PauliZ(tq.QuantumModule):
    """Fixed (parameter-free) single-qubit PauliZ gate.

    The gate matrix is taken from ``tqf.mat_dict['pauliz']``.
    """

    def __init__(self):
        super(PauliZ, self).__init__()
        self.num_params = 0
        self.n_wires = 1
        # Plain tensor attribute, not a registered parameter/buffer, so moving
        # the module (``.to()``/``.cuda()``) does not move it; ``func`` moves
        # it to the state's device at application time instead.
        self.matrix = tqf.mat_dict['pauliz']

    def forward(self, q_device: tq.QuantumDevice, wires=None):
        """Apply the gate to ``q_device``.

        Args:
            q_device: device whose ``states`` are updated.
            wires: wire index (int) or sequence of wire indices to act on.
        """
        # Accept a bare int by wrapping it into the list form used downstream.
        if isinstance(wires, int):
            wires = [wires]
        self.wires = wires
        self.func(q_device, self.wires, self.matrix)

    def func(self, q_device: tq.QuantumDevice, wires, matrix):
        """Replace ``q_device.states`` with ``matrix`` applied on ``wires``."""
        state = q_device.states
        # The shared matrix may live on a different device than the state.
        q_device.states = tqf.apply_unitary_bmm(
            state, matrix.to(state.device), wires)