import numpy as np # numpy lib
import torch # Pytorch
import torch.nn as nn # Pytorch neural network
from torch.nn import TransformerEncoder, TransformerEncoderLayer # Pytorch transformer lib
import torch.optim as optim # Pytorch optimizer lib

# ---------------------------------------------------------------------------
# Pre-computed link data (swap the file names to use another device layout).
# Both files are read as comma-separated square matrices.
# ---------------------------------------------------------------------------
X = np.loadtxt('X40std9.txt', delimiter=',')  # Fidelity matrix
R = np.loadtxt('R40std9.txt', delimiter=',')  # Success-rate matrix
R[R == 0] = 1  # Replace every zero rate (the diagonal, per the original note) so that 1./R below is finite
X = X.astype(np.float32)
R = R.astype(np.float32)
E0 = 1 - X * np.exp(-1. / R)  # Greedy error matrix (still a NumPy array here)
Et = torch.from_numpy(E0).float()  # Greedy error as a torch tensor
Xt = torch.from_numpy(X).float()  # Fidelity as a torch tensor
Rt = torch.from_numpy(R).float()  # Success rate as a torch tensor
MR = torch.from_numpy(np.exp(-1. / R))  # exp(-1/R) as a torch tensor

Nqubit = 40  # Qubit number (must match the size of the loaded matrices)
Npath = 20  # Number of entanglement workers that can run in parallel
rat = 2e5  # Entanglement trial rate

G0 = torch.zeros((Nqubit, Nqubit), dtype=torch.int32)  # Connected-qubit graph adjacency matrix
Gc = torch.zeros((Nqubit, Nqubit), dtype=torch.int32)  # Entanglement-trial graph adjacency matrix

# Position features: P1[i, j] = (j + 1) / Nqubit, P2[i, j] = (i + 1) / Nqubit.
P1 = torch.arange(1, Nqubit + 1).repeat(Nqubit, 1) / Nqubit
column_values = torch.arange(1, Nqubit + 1) / Nqubit
P2 = column_values.view(-1, 1).repeat(1, Nqubit)

# Flatten every Nqubit x Nqubit matrix into a [1, Nqubit*Nqubit, 1] column so that
# each qubit pair becomes one sequence element.
MX_tensor = Xt.view(1, Nqubit * Nqubit, 1)
MR_tensor = MR.view(1, Nqubit * Nqubit, 1)
MG0_tensor = G0.view(1, Nqubit * Nqubit, 1)
MGc_tensor = Gc.view(1, Nqubit * Nqubit, 1)  # Was built from G0 by mistake; the trial graph is Gc
ME_tensor = Et.view(1, Nqubit * Nqubit, 1)
P1_tensor = P1.view(1, Nqubit * Nqubit, 1)
P2_tensor = P2.view(1, Nqubit * Nqubit, 1)

# Model input: the seven feature columns concatenated on the last axis, giving
# shape [1, Nqubit*Nqubit, 7] = [batch, sequence length, features per pair].
input_data = torch.cat((MX_tensor, MR_tensor, MG0_tensor, MGc_tensor, ME_tensor, P1_tensor, P2_tensor), dim=2)

class UnionFind:
    """Disjoint-set structure tracking which qubits belong to the same connected cluster.

    ``parent`` and ``size`` are 1-D long tensors of length ``size``. Node
    arguments may be Python ints or one-element integer tensors.

    Roots are handled internally as plain Python ints. Indexing a tensor with
    a scalar returns a *view*, so keeping ``self.parent[x]`` as a root handle
    (as the previous implementation did) let a later write to ``parent``
    silently change that handle, and ``union`` then added a set's size to
    itself instead of adding the other set's size.
    """

    def __init__(self, size):
        """Create ``size`` singleton sets: each node is its own parent, each set has size 1."""
        self.parent = torch.arange(size, dtype=torch.long)
        self.size = torch.ones(size, dtype=torch.long)

    def _root(self, x):
        """Return the root of ``x`` as a Python int and compress the path walked."""
        node = int(x)
        root = node
        while int(self.parent[root]) != root:  # climb to the representative
            root = int(self.parent[root])
        while node != root:  # path compression: point every visited node at the root
            next_node = int(self.parent[node])
            self.parent[node] = root
            node = next_node
        return root

    def find(self, x):
        """Return the root of ``x`` as a fresh 0-dim long tensor (not a view of ``parent``)."""
        return torch.tensor(self._root(x), dtype=torch.long)

    def union(self, x, y):
        """Merge the sets containing ``x`` and ``y``; the smaller set is attached to the larger."""
        rootX = self._root(x)
        rootY = self._root(y)
        if rootX == rootY:  # already in the same cluster
            return
        if self.size[rootX] < self.size[rootY]:
            rootX, rootY = rootY, rootX  # make rootX the larger set (ties keep x's root)
        self.parent[rootY] = rootX
        self.size[rootX] += self.size[rootY]

    def findLargestSetSize(self):
        """Return the size of the largest set as a 0-dim tensor."""
        return torch.max(self.size)

    def isConnected(self, x, y):
        """Return a 0-dim bool tensor telling whether ``x`` and ``y`` share a root."""
        return self.find(x) == self.find(y)

class CustomTransformerModel(nn.Module):
    """Per-pair scoring network: linear embedding -> transformer encoder -> linear head.

    ``forward`` takes a tensor laid out as [batch, sequence, input_features]
    and returns [batch, sequence, output_features].
    """

    def __init__(self, input_features=7, embed_dim=32, output_features=1, seq_len=1600, nhead=1, num_encoder_layers=1):
        """Build the three stages; ``seq_len`` is only stored, it is not used by ``forward``."""
        super().__init__()
        self.seq_len = seq_len
        # Embedding: input_features -> embed_dim for every sequence element.
        self.linear_transform = nn.Linear(input_features, embed_dim)
        # Stack of identical encoder layers with a 64-wide feed-forward block.
        self.transformer_encoder = TransformerEncoder(
            TransformerEncoderLayer(d_model=embed_dim, nhead=nhead, dim_feedforward=64),
            num_layers=num_encoder_layers,
        )
        # Head: embed_dim -> output_features for every sequence element.
        self.output_layer = nn.Linear(embed_dim, output_features)

    def forward(self, x):
        """Embed, encode and project ``x``; the first two axes are swapped around the encoder."""
        embedded = self.linear_transform(x)
        # Swap axes 0 and 1 ([batch, seq, embed] -> [seq, batch, embed]) for the encoder, then swap back.
        encoded = self.transformer_encoder(embedded.permute(1, 0, 2)).permute(1, 0, 2)
        return self.output_layer(encoded)

# Policy network with its default hyper-parameters, trained with Adam at a learning rate of 0.01.
model = CustomTransformerModel()
optimizer = optim.Adam(params=model.parameters(), lr=0.01)

def _action_matrix(features, delta):
    """Return the Nqubit x Nqubit action (cost) matrix: greedy error plus scaled NN output.

    The model output for ``features`` is min-max normalised to [0, 1], scaled
    by 0.1, perturbed by ``0.1 * delta``, has its diagonal zeroed and is then
    added to the greedy error matrix ``Et``.
    """
    with torch.no_grad():  # Rollouts never back-propagate, so skip building the graph
        output = model(features)
    # NOTE(review): model.eval() is never called in this file; if the encoder
    # layers apply dropout, every rollout sees a noisy policy - confirm intended.
    min_val = output.min()
    span = output.max() - min_val
    if span > 0:
        norm_output = (output - min_val) / span
    else:
        norm_output = torch.zeros_like(output)  # Constant output: avoid 0/0 -> NaN costs
    Ent = norm_output.view(Nqubit, Nqubit) * 0.1 + delta * 0.1
    Ent[range(len(Ent)), range(len(Ent))] = 0  # A qubit is never paired with itself
    return Ent + Et


def findQVmax(delta): # Environment simulation
    """Run one Monte-Carlo rollout of cluster building and return the negated best log2(QV).

    Entanglement workers are assigned qubit pairs by repeatedly taking the
    cheapest entry of the action matrix. Each simulation frame every busy
    worker succeeds with probability ``r_ent[i] / rat``. A success links the
    pair (``G0`` and the union-find), frees both qubits and is logged in
    ``TResult``. The rollout ends once the largest cluster reaches 30 qubits.

    Args:
        delta: perturbation added (scaled by 0.1) to the normalised NN output;
            must broadcast against an Nqubit x Nqubit matrix.

    Returns:
        0-dim tensor ``-max_t min(Nmax[t], 1 / e0[t])``, where ``Nmax`` is the
        largest-cluster trace and ``e0`` the accumulated error trace.
    """
    G0 = torch.zeros((Nqubit, Nqubit), dtype=torch.int32) # Adjacency of successfully entangled pairs
    Gc = torch.zeros((Nqubit, Nqubit), dtype=torch.int32) # Adjacency of pairs that were ever assigned to a worker
    delete = torch.tensor([], dtype=torch.int32) # Worker table: slots 2i and 2i+1 hold worker i's pair, -1 marks an idle worker
    aval = torch.tensor([], dtype=torch.int32) # Qubits currently free for scheduling
    r_ent = torch.zeros(Npath, dtype=Rt.dtype) # Success rate of each worker's current task
    maxSize = 0 # Largest cluster size so far
    count = 0 # Simulation frame counter
    flag = False # Set when a worker succeeded during the current frame
    flag2 = False # Set when the free qubits allow another scheduling round
    uf = UnionFind(Nqubit) # Connected-cluster bookkeeping
    TResult = torch.empty(0, 4, dtype=torch.int32) # One row per success: [frame, qubit 1, qubit 2, largest cluster]

    MA = _action_matrix(input_data, delta) # Action matrix for the initial (empty) graphs
    new_err = MA.clone() # Working cost matrix; scheduled rows/columns are masked with 1
    # Initial greedy matching. It appends one pair per iteration until the table
    # holds Nqubit entries; the loops below read 2 * Npath of them.
    while len(delete)<Nqubit:
        minVals, _ = torch.min(new_err, dim=0) # Column-wise minima
        colInd = torch.argmin(minVals).unsqueeze(0) # Column of the global minimum
        rowInd = torch.argmin(new_err[:, colInd]).unsqueeze(0) # Row of the global minimum
        delete = torch.cat((delete, rowInd, colInd), dim=0) # Give the pair to the next worker
        Gc[rowInd, colInd] = 1
        Gc[colInd, rowInd] = 1
        new_err[delete, :] = 1 # Mask every qubit that is already busy
        new_err[:, delete] = 1

    for i in range(Npath): # Look up the success rate of each worker's pair
        r_ent[i] = Rt[delete[2 * i], delete[2 * i + 1]]

    while maxSize<30: # Stop condition of cluster building, log2QV<30
        while not flag2: # Keep simulating frames until another scheduling round is possible
            # NOTE(review): because of the `break` below this inner loop always runs
            # exactly one frame; the enclosing loop is what actually repeats. Kept
            # as in the original - confirm whether the break is intended.
            while not flag:
                count += 1
                for i in range(Npath):
                    if torch.rand(1) < r_ent[i] / rat: # Monte-Carlo draw for this worker
                        r_ent[i].zero_() # Task finished: the worker no longer succeeds
                        aval = torch.cat((aval, delete[2*i:2*i+2])) # Both qubits become available
                        G0[delete[2 * i], delete[2 * i + 1]] = 1
                        G0[delete[2 * i + 1], delete[2 * i]] = 1
                        uf.union(delete[2 * i], delete[2 * i + 1])
                        maxSize = uf.findLargestSetSize().item()
                        a0 = torch.tensor([count], dtype=torch.int32)
                        a1 = delete[2 * i].unsqueeze(0)
                        a2 = delete[2 * i + 1].unsqueeze(0)
                        a3 = torch.tensor([maxSize], dtype=torch.int32)
                        iter_data = torch.cat((a0, a1, a2, a3)).unsqueeze(0)
                        TResult = torch.cat((TResult, iter_data), dim=0) # Log the success for the QV calculation
                        delete[2 * i + 1] = -1 # Mark the worker idle
                        delete[2 * i] = -1
                        flag = True
                if not flag:
                    break

            indices = aval.int()
            sub_G0 = G0[indices][:, indices] # Direct links among the free qubits (a copy, G0 is untouched)
            eye = torch.eye(len(aval), dtype=torch.int32)
            sub_G0 += eye # Treat every qubit as linked to itself
            isFullyConnected = torch.all(sub_G0 == 1).item() # True (also for an empty list) when no unlinked free pair exists
            flag = False
            if not isFullyConnected: # Some free pair is not directly linked yet: schedule
                flag2 = True

        MG0_tensor = G0.view(1, -1, 1) # Current connected-graph feature
        MGc_tensor = Gc.view(1, -1, 1) # Current trial-graph feature (was taken from G0 by mistake)
        input_data2 = torch.cat((MX_tensor, MR_tensor, MG0_tensor, MGc_tensor, ME_tensor, P1_tensor, P2_tensor), dim=2)
        MA = _action_matrix(input_data2, delta) # Action matrix for the current state
        rows, cols = torch.meshgrid(aval, aval, indexing='ij') # Only pairs of free qubits are candidates
        err = MA # Read-only from here on, so no copy is needed
        new_err[rows, cols] = err[rows, cols] # Unmask the candidate pairs
        # Greedy assignment. Costs >= 1 are treated as "not schedulable"; the previous
        # test `not all(new_err == 1)` never terminated once a cost exceeded 1.
        while torch.any(new_err < 1):
            minVals, _ = torch.min(new_err, dim=0)
            colInd = torch.argmin(minVals)
            rowInd = torch.argmin(new_err[:, colInd])
            if uf.isConnected(colInd, rowInd): # Same cluster already: discard this candidate
                new_err[rowInd, colInd] = 1
                new_err[colInd, rowInd] = 1
            else:
                index = torch.where(delete == -1)[0] # Idle worker slots
                index = torch.min(index).item() # First idle slot (slots are freed in pairs)
                delete[index] = rowInd
                delete[index + 1] = colInd
                Gc[delete[index], delete[index+1]] = 1
                Gc[delete[index+1], delete[index]] = 1
                remove_indices = {rowInd.item(), colInd.item()}
                mask = ~torch.isin(aval, torch.tensor(list(remove_indices), dtype=aval.dtype))
                aval = aval[mask] # Both qubits are busy again
                new_err[rowInd, :] = 1 # Mask everything involving the two busy qubits
                new_err[:, rowInd] = 1
                new_err[colInd, :] = 1
                new_err[:, colInd] = 1

        flag2 = False
        flag = False
        for i in range(Npath): # Refresh success rates and drop tasks whose cost is too high
            if delete[2*i] > -1: # Worker has a task
                r_ent[i] = Rt[delete[2 * i], delete[2 * i + 1]]
                if err[delete[2 * i], delete[2 * i + 1]]>0.02: # Cost above the 0.02 threshold: cancel
                    r_ent[i] = 0
                    aval = torch.cat((aval, delete[2*i:2*i+2])) # Release both qubits
                    delete[2 * i + 1] = -1
                    delete[2 * i] = -1

    e0 = torch.zeros(count, dtype=torch.float) + 0.00001 # Error trace per frame (small floor avoids 1/0)
    Nmax = torch.zeros(count, dtype=torch.int32) # Largest-cluster trace per frame

    # Every logged link adds 1 - X * exp(-age / 5e4) to all frames after its success.
    # Vectorised over frames; the per-element arithmetic order matches the scalar form.
    frames = torch.arange(count)
    for t_succ, q1, q2, _ in TResult:
        age = frames - t_succ
        e0 = torch.where(age > 0, e0 + 1 - Xt[q1, q2] * torch.exp(-age / 5e4), e0)

    # NOTE(review): range(len(TResult) - 2) leaves the interval between the last two
    # successes (and everything after the last one) at 0 - possible off-by-one,
    # kept unchanged because it alters the reward; confirm the intended bound.
    for j in range(len(TResult) - 2):
        Nmax[TResult[j][0]:TResult[j+1][0]] = TResult[j][3]

    maxQV = torch.max(torch.minimum(Nmax, 1. / e0)) # Best min(cluster size, 1/error) over the rollout
    return -maxQV # Negated so that a lower value is better

# def funcdif(): # Fitting NN initialization
#     output = model(input_data)
#     MA = output.view(40, 40) # resize the output to a 40*40 tensor matrix
#     criterion = nn.MSELoss() # loss function as the error between output and target
#     return criterion(MA, Et) # Match the output matrix to the greedy search matrix

avgloss = 0
num_epochs = 1000 # Number of epochs

# Optional warm-start (disabled): fit the raw output to the greedy matrix first.
# for epoch in range(num_epochs): # fitting training process
#     loss = funcdif() # loss function for fitting
#     optimizer.zero_grad() # reset gradient
#     loss.backward() # NN back propagation
#     optimizer.step() # NN optimization update


# Perturbation-based policy improvement.
# Each epoch compares the mean rollout loss (-QV) of the current policy with
# that of the policy perturbed by a random delta, then regresses the network
# output one step towards whichever direction gave the lower loss.
rollouts_per_policy = 5 # Rollouts averaged for each of the two policies
criterion = nn.MSELoss() # Mean-squared error between output and target

for epoch in range(num_epochs):
    delta = torch.rand(Nqubit, Nqubit)*0.1 # Random perturbation
    baseline = [findQVmax(delta*0) for _ in range(rollouts_per_policy)] # Current policy
    avg1 = sum(baseline)/rollouts_per_policy
    print(*(rollout.item() for rollout in baseline))
    perturbed = [findQVmax(delta) for _ in range(rollouts_per_policy)] # Perturbed policy
    avg2 = sum(perturbed)/rollouts_per_policy

    output = model(input_data) # Forward pass with gradient tracking
    step = delta.view(1, -1, 1)
    # The target must be detached. With `target = output +/- step` left attached,
    # MSE(output, target) is just mean(step**2) and its gradient with respect to
    # the parameters is exactly zero, so the network would never change.
    if avg1.item()>avg2.item(): # Perturbed policy has the lower loss: move towards it
        target = (output + step).detach()
    else: # Otherwise move in the opposite direction
        target = (output - step).detach()

    loss = criterion(output, target)
    optimizer.zero_grad() # reset gradient
    loss.backward() # NN back propagation
    optimizer.step() # NN optimization update