import sys,os
import torch
import torchquantum as tq
import torch.nn.functional as F

from new_gates import U3, CU3

class Encoder(tq.QuantumModule):
    '''
    Encode a flattened image onto qubits with one rotation gate per pixel.

    Flattened pixel ``i`` (row-major over ``pixels``) is passed as the
    parameter of the ``i``-th gate in ``self.queue``, applied to wire
    ``i // pixels[1]``; i.e. image row ``r`` is encoded on wire ``r``.
    The gate type at each position is chosen by ``gene``:
    0 -> RX, 1 -> RY, any other value -> RZ.

    :param n_wires: number of qubits; must equal ``pixels[0]``.
    :param pixels: two-element sequence or tensor ``(rows, cols)``.
    :param gene: iterable of gate selectors with at least ``rows * cols`` entries.
    :raises ValueError: if ``pixels`` does not have two entries, ``pixels[0]``
        differs from ``n_wires``, or ``gene`` is too short to cover every pixel.
    '''

    def __init__(self, n_wires, pixels, gene):
        super(Encoder, self).__init__()
        self.n_wires = n_wires
        self.pixels = pixels
        self.gene = gene
        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        if len(pixels) != 2 or pixels[0] != n_wires:
            raise ValueError(
                f"pixels must be (n_wires, cols) with n_wires={n_wires}, got {pixels}"
            )
        # Wire index of every flattened pixel: wire w repeated pixels[1] times.
        self.ind = [wire for wire in range(self.n_wires) for _ in range(self.pixels[1])]
        self.init_encoder()
        # Fail at construction rather than with an IndexError inside forward().
        n_pixels = self.pixels[0] * self.pixels[1]
        if len(self.queue) < n_pixels:
            raise ValueError(
                f"gene provides {len(self.queue)} encoding gates but "
                f"{int(n_pixels)} pixels must be encoded"
            )

    def forward(self, q_device, x):
        '''
        Apply the encoding gates to ``q_device``.

        :param q_device: quantum device the gates act on.
        :param x: [bsz, num_pixel]; column ``i`` parameterizes gate ``i``.
        '''
        n_pixels = self.pixels[0] * self.pixels[1]
        for i in range(n_pixels):
            self.queue[i](q_device, self.ind[i], x[:, i])

    def init_encoder(self):
        '''Build ``self.queue``: one rotation gate per entry of ``self.gene``.'''
        self.queue = tq.QuantumModuleList()
        for g in self.gene:
            # 0 -> RX, 1 -> RY, anything else -> RZ.
            if g == 0:
                self.queue.append(tq.RX())
            elif g == 1:
                self.queue.append(tq.RY())
            else:
                self.queue.append(tq.RZ())
        
class SuperLayer(tq.QuantumModule):
    '''
    Trainable circuit of U3 and ring-connected CU3 gates selected by a gene.

    ``gene[i]`` describes layer ``i`` and is read as ``2 * n_wires`` flags.
    Flag ``j`` (``j < n_wires``) selects a U3 on wire ``j``. Flag
    ``n_wires + j`` selects a CU3 on wires ``[j, (j + 1) % n_wires]``.
    A flag equal to 0 means the gate is present; any other value omits it.
    Within a layer, all U3 gates come before all CU3 gates.

    :param n_wires: number of qubits.
    :param n_layers: number of layers; must equal ``len(gene)``.
    :param gene: per-layer gate-selection flags as described above.
    :raises ValueError: if ``n_layers`` does not match ``len(gene)``.
    '''

    def __init__(self, n_wires, n_layers, gene):
        super(SuperLayer, self).__init__()
        self.n_wires = n_wires
        self.n_layers = n_layers
        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        if self.n_layers != len(gene):
            raise ValueError(
                f"n_layers ({n_layers}) must equal len(gene) ({len(gene)})"
            )
        self.gene = gene
        self.queue = tq.QuantumModuleList()
        # self.ind[k] holds the wires of self.queue[k]; both are filled in the
        # same per-layer order (U3s first, then CU3s).
        self.ind = []
        for i in range(self.n_layers):
            for j in range(self.n_wires):
                if self.gene[i][j] == 0:
                    self.ind.append([j])
            for j in range(self.n_wires):
                if self.gene[i][j + self.n_wires] == 0:
                    self.ind.append([j, (j + 1) % self.n_wires])
            self.queue += self.init_qnn(i)

    def forward(self, q_device, fault_dict=None):
        '''
        Apply every gate in order, optionally injecting faults.

        :param q_device: quantum device the gates act on.
        :param fault_dict: optional mapping ``gate_index -> (fault_op, wire)``;
            right after gate ``gate_index`` runs, ``fault_op(q_device, wire)``
            is applied. ``None`` (the default) injects nothing.
        '''
        if fault_dict is None:
            fault_dict = {}
        for i, op in enumerate(self.queue):
            op(q_device, self.ind[i])
            if i in fault_dict:
                fop, wire = fault_dict[i]
                fop(q_device, wire)

    def init_qnn(self, i):
        '''Return the gates of layer ``i``: its U3s, then its CU3s, in wire order.'''
        queue = tq.QuantumModuleList()
        for g in self.gene[i][:self.n_wires]:
            if g == 0:
                queue.append(U3())
        # Stop at 2 * n_wires so the gates line up one-to-one with the wires
        # recorded in ``self.ind``, which only reads that many flags per layer.
        for g in self.gene[i][self.n_wires:2 * self.n_wires]:
            if g == 0:
                queue.append(CU3())
        return queue
    
class ErrorSuperNet(tq.QuantumModule):
    '''
    Encoder + trainable circuit classifier with optional fault injection.

    The input is average-pooled, encoded by ``Encoder``, and processed by
    ``qnn``. Every wire is then measured with ``tq.MeasureAll(tq.PauliZ)``, and
    the ``n_wires`` measurement outputs are turned into log-probabilities with
    ``log_softmax``.

    :param n_wires: number of qubits (also the number of output scores).
    :param n_layers: number of layers for the default ``SuperLayer``.
    :param pixels: ``(rows, cols)`` of the pooled image; converted to a tensor.
    :param gene: object exposing ``encoder`` and ``qnn`` gene attributes.
    :param qnn: optional pre-built circuit used instead of a new ``SuperLayer``.
    '''

    def __init__(self, n_wires, n_layers, pixels, gene, qnn=None):
        super(ErrorSuperNet, self).__init__()

        self.n_wires = n_wires
        self.n_layers = n_layers
        if not isinstance(pixels, torch.Tensor):
            pixels = torch.tensor(pixels)
        self.pixels = pixels
        self.gene = gene
        self.encoder = Encoder(n_wires=self.n_wires, pixels=self.pixels, gene=self.gene.encoder)

        # ``is not None`` instead of truthiness: a module defining __len__
        # (e.g. an empty module list) would otherwise be silently replaced.
        if qnn is not None:
            self.qnn = qnn
        else:
            self.qnn = SuperLayer(n_wires=self.n_wires, n_layers=self.n_layers, gene=self.gene.qnn)

        self.measure = tq.MeasureAll(tq.PauliZ)

    def forward(self, x: torch.Tensor, fault_dict=None):
        '''
        Classify a batch of images.

        :param x: image batch. ``F.avg_pool2d(x, 6)`` must yield exactly
            ``rows * cols`` values per sample. NOTE(review): with the default
            stride this presumably expects 28x28 inputs (-> 4x4); confirm with
            callers.
        :param fault_dict: optional ``gate_index -> (fault_op, wire)`` mapping
            forwarded to ``qnn``; ``None`` (the default) injects no faults.
        :return: ``[bsz, n_wires]`` log-probabilities.
        '''
        if fault_dict is None:
            fault_dict = {}
        bsz = x.shape[0]
        qdev = tq.QuantumDevice(
            n_wires=self.n_wires, bsz=bsz, device=x.device
        )
        # One pooled value per encoded pixel (16 for the 4x4 configuration).
        n_pixels = int(self.pixels.prod())
        x = F.avg_pool2d(x, 6).view(bsz, n_pixels)

        self.encoder(qdev, x)

        self.qnn(qdev, fault_dict)

        x = self.measure(qdev)

        x = x.reshape(bsz, self.n_wires)
        x = F.log_softmax(x, dim=1)
        return x