# This project uses PyTorch for deep learning functionalities.
# PyTorch is an open source machine learning library developed by Facebook's AI Research lab.
# More information about PyTorch can be found at https://pytorch.org
## License
# This project is licensed under the Creative Commons Attribution 4.0 International License.

import numpy as np # numpy lib
import torch # Pytorch

Nqubit = 40 # Qubit number; the loaded matrices must be Nqubit x Nqubit
Npath = 20 # Entanglement workers running in parallel (findQVmax's initial assignment fills Nqubit // 2 worker slots)
rat = 2e5 # Entanglement trial rate: per frame a busy worker succeeds with probability R[a, b] / rat

X = np.loadtxt('X40std9.txt', delimiter=',') # Prior fidelity matrix (swap the file to try another device)
R = np.loadtxt('R40.txt', delimiter=',') # Prior success-rate matrix (swap the file to try another device)
if X.shape != (Nqubit, Nqubit) or R.shape != (Nqubit, Nqubit):
    # Fail early instead of deep inside the simulation when a data file does not match Nqubit.
    raise ValueError(f"Expected {Nqubit}x{Nqubit} fidelity and success-rate matrices, got {X.shape} and {R.shape}")
R[R == 0] = 1 # Replace zero rates (e.g. the diagonal) with 1 so a later 1./R never divides by zero
X = X.astype(np.float32) # X to float32
R = R.astype(np.float32) # R to float32
E0 = np.random.rand(Nqubit, Nqubit) * 0.1 # Random pairwise scheduling-error matrix, entries in [0, 0.1)
Et = torch.from_numpy(E0).float() # Scheduling error matrix as a float32 tensor
Xt = torch.from_numpy(X).float() # Fidelity matrix as a float32 tensor
Rt = torch.from_numpy(R).float() # Success-rate matrix as a float32 tensor

class UnionFind:
    """Disjoint-set forest tracking which qubits belong to the same connected cluster.

    Uses union by size plus path compression. ``parent`` and ``size`` are
    ``torch.long`` tensors. ``size[r]`` is the true cluster size only when ``r`` is
    a root; for any other node it is the size its subtree had when it was attached,
    which never exceeds its cluster's size, so ``findLargestSetSize`` stays exact.
    """

    def __init__(self, size):  # size: number of nodes (qubits)
        self.parent = torch.arange(size, dtype=torch.long)  # every node starts as its own root
        self.size = torch.ones(size, dtype=torch.long)  # every cluster starts with one node

    def find(self, x):
        """Return the root of the set containing ``x`` as a 0-d long tensor.

        ``x`` may be a Python int or a 0-d integer tensor. Every node on the path is
        re-pointed directly at the root. The returned tensor is a copy, not a view
        into ``self.parent``, so later unions cannot silently change its value.
        """
        root = int(x)
        while int(self.parent[root]) != root:  # only a root satisfies parent[r] == r
            root = int(self.parent[root])
        node = int(x)
        while node != root:  # second pass: path compression
            nxt = int(self.parent[node])
            self.parent[node] = root
            node = nxt
        return self.parent[root].clone()

    def union(self, x, y):
        """Merge the sets containing ``x`` and ``y``; no-op if they are already joined.

        The smaller tree is attached under the larger one (ties attach ``y``'s root
        under ``x``'s root).
        """
        # Plain ints: indexing with a tensor view here would alias parent[] entries
        # and corrupt the size bookkeeping after the parent write below.
        rootX = int(self.find(x))
        rootY = int(self.find(y))
        if rootX == rootY:
            return
        if self.size[rootX] < self.size[rootY]:
            rootX, rootY = rootY, rootX  # make rootX the larger tree
        self.parent[rootY] = rootX
        self.size[rootX] += self.size[rootY]

    def findLargestSetSize(self):
        """Return the size of the largest set as a 0-d long tensor."""
        return torch.max(self.size)

    def isConnected(self, x, y):
        """Return a 0-d bool tensor: True when ``x`` and ``y`` share a root."""
        return self.find(x) == self.find(y)

def _max_log2_qv(history, count, fidelity=None, decay=5e4):
    """Return max over frames of min(largest cluster size, 1 / accumulated error).

    Args:
        history: (K, 4) integer tensor with one row per successful entanglement, in
            the order it happened: [frame, qubit_a, qubit_b, largest cluster size].
        count: number of simulated frames; the timeline covers frames 0 .. count-1.
        fidelity: pairwise fidelity matrix indexed [qubit_a, qubit_b]; defaults to the
            module-level ``Xt`` when omitted.
        decay: time constant, in frames, of the exponential fidelity decay.

    Returns:
        0-d float tensor holding the best log2(QV) reached during the run.
    """
    if fidelity is None:
        fidelity = Xt
    e0 = torch.zeros(count, dtype=torch.float) + 0.00001  # error floor present in every frame
    Nmax = torch.zeros(count, dtype=torch.int32)  # largest cluster size at each frame
    frames = torch.arange(count)
    rows = [[int(v) for v in row] for row in history]
    for j, (t, qa, qb, size) in enumerate(rows):
        # Every frame strictly after the link was built accumulates that link's decayed infidelity.
        later = frames[t + 1:]
        e0[t + 1:] = e0[t + 1:] + 1 - fidelity[qa, qb] * torch.exp(-(later - t) / decay)
        # The recorded cluster size holds until the next success, or to the end of the run.
        end = rows[j + 1][0] if j + 1 < len(rows) else count
        Nmax[t:end] = size
    return torch.max(torch.minimum(Nmax, 1. / e0))


def findQVmax(target_size=30, err_threshold=0.02):
    """Simulate one greedy entanglement-scheduling run and return the negated best log2(QV).

    Each frame, every busy worker i attempts its qubit pair and succeeds with
    probability r_ent[i] / rat (r_ent comes from Rt). Freed qubits are re-paired
    greedily by the lowest entry of Et, skipping pairs already in one cluster. After
    every scheduling round, busy tasks whose Et entry exceeds ``err_threshold`` are
    abandoned. The run stops after the first scheduling round in which the largest
    cluster has reached ``target_size``.

    Args:
        target_size: largest-cluster size (log2 QV target) that ends the run.
        err_threshold: error threshold (Ath) above which a task is abandoned.

    Returns:
        0-d float tensor, ``-max log2(QV)``, usable as a reward.
    """
    G0 = torch.zeros((Nqubit, Nqubit), dtype=torch.int32)  # adjacency of successfully entangled pairs
    # Flattened task table: worker i works on (delete[2*i], delete[2*i+1]); -1 marks an idle worker.
    delete = torch.tensor([], dtype=torch.int32)
    aval = torch.tensor([], dtype=torch.int32)  # free qubits waiting for a new task
    r_ent = torch.zeros(Npath, dtype=Rt.dtype)  # per-worker success rate (0 when idle)
    maxSize = 0  # largest cluster size so far
    count = 0  # simulation frame counter
    uf = UnionFind(Nqubit)  # cluster membership of the qubits
    # One row per successful entanglement: [frame, qubit 1, qubit 2, largest cluster size].
    TResult = torch.empty(0, 4, dtype=torch.int32)

    # Initial assignment: repeatedly take the lowest-error pair of unused qubits until every
    # qubit has a task. Used rows/columns are blanked with the sentinel value 1.
    new_err = Et.clone()
    new_err.fill_diagonal_(1)  # never pair a qubit with itself
    while len(delete) < Nqubit:
        minVals, _ = torch.min(new_err, dim=0)
        colInd = torch.argmin(minVals).unsqueeze(0)
        rowInd = torch.argmin(new_err[:, colInd]).unsqueeze(0)
        delete = torch.cat((delete, rowInd, colInd), dim=0)
        new_err[delete, :] = 1
        new_err[:, delete] = 1

    for i in range(Npath):
        r_ent[i] = Rt[delete[2 * i], delete[2 * i + 1]]

    while maxSize < target_size:
        # Advance frame by frame until the free qubits contain a pair that is not directly linked.
        while True:
            count += 1
            for i in range(Npath):
                # Monte Carlo trial; idle workers have r_ent == 0 and never succeed.
                if torch.rand(1) < r_ent[i] / rat:
                    r_ent[i].zero_()
                    aval = torch.cat((aval, delete[2 * i:2 * i + 2]))  # both qubits become free
                    G0[delete[2 * i], delete[2 * i + 1]] = 1
                    G0[delete[2 * i + 1], delete[2 * i]] = 1
                    uf.union(delete[2 * i], delete[2 * i + 1])
                    maxSize = uf.findLargestSetSize().item()
                    iter_data = torch.cat((
                        torch.tensor([count], dtype=torch.int32),
                        delete[2 * i].unsqueeze(0),
                        delete[2 * i + 1].unsqueeze(0),
                        torch.tensor([maxSize], dtype=torch.int32),
                    )).unsqueeze(0)
                    TResult = torch.cat((TResult, iter_data), dim=0)
                    delete[2 * i:2 * i + 2] = -1  # worker i is now idle
            # The G0 sub-graph of free qubits (diagonal forced to 1) is all ones only when every
            # free pair is already directly linked.
            # NOTE(review): scheduling below skips pairs by cluster membership (uf), which is
            # looser, so a round may assign nothing.
            indices = aval.int()
            sub_G0 = G0[indices][:, indices] + torch.eye(len(aval), dtype=torch.int32)
            if not torch.all(sub_G0 == 1).item():
                break

        # Greedy re-scheduling restricted to the free qubits. new_err is all 1 here, so only
        # the free x free entries become candidates.
        err = Et.clone()
        rows, cols = torch.meshgrid(aval, aval, indexing='ij')
        new_err[rows, cols] = err[rows, cols]
        while not torch.all(new_err == 1):
            minVals, _ = torch.min(new_err, dim=0)
            colInd = torch.argmin(minVals)
            rowInd = torch.argmin(new_err[:, colInd])
            if uf.isConnected(colInd, rowInd):
                # Already in one cluster (includes the diagonal): never schedule this pair.
                new_err[rowInd, colInd] = 1
                new_err[colInd, rowInd] = 1
            else:
                index = torch.min(torch.where(delete == -1)[0]).item()  # first idle worker slot
                delete[index] = rowInd
                delete[index + 1] = colInd
                taken = torch.tensor([rowInd.item(), colInd.item()], dtype=aval.dtype)
                aval = aval[~torch.isin(aval, taken)]  # the pair is no longer free
                new_err[rowInd, :] = 1
                new_err[:, rowInd] = 1
                new_err[colInd, :] = 1
                new_err[:, colInd] = 1

        # Refresh rates of busy workers and abandon tasks whose error exceeds the threshold.
        for i in range(Npath):
            if delete[2 * i] > -1:
                r_ent[i] = Rt[delete[2 * i], delete[2 * i + 1]]
                if err[delete[2 * i], delete[2 * i + 1]] > err_threshold:
                    r_ent[i] = 0
                    aval = torch.cat((aval, delete[2 * i:2 * i + 2]))  # free both qubits again
                    delete[2 * i:2 * i + 2] = -1  # worker i is now idle

    return -_max_log2_qv(TResult, count)

## Run 100 independent greedy-scheduling simulations and print each run's result
for i in range(100): # Each call is one fresh Monte Carlo run; nothing is trained here
    print(QV := findQVmax().item()) # Negated best log2(QV) of this run
