import torch

class LemkeTableauTorch:
    """Tableau-based implementation of Lemke's algorithm for LCP(M, q).

    Looks for ``z`` with ``w = M z + q``, ``w >= 0``, ``z >= 0`` and
    ``w^T z = 0``.

    The tableau ``T`` has shape ``(n, 2n + 2)`` and starts as
    ``[I, -M, -1, q]``, i.e. it encodes ``I w - M z - 1 y = q`` with an
    artificial variable ``y``. Pivots are performed by swapping columns so that:

    * columns ``0 .. n-1`` hold the current basic variables (identity block),
    * columns ``n .. 2n-1`` hold the non-basic variables,
    * column ``-2`` holds the entering (driving) variable,
    * column ``-1`` holds the right-hand side.

    ``Tind`` is a ``(2, 2n + 2)`` long tensor: row 0 is the kind of variable in
    each column (``W``, ``Z``, ``Y`` or ``Q``) and row 1 its index.
    ``wPos[i]`` / ``zPos[i]`` give the column currently holding ``w_i`` / ``z_i``
    (the driving column is stored as ``2n`` rather than ``-2``).
    """

    def __init__(self, M, q, maxIter=100, device=None):
        """Build the initial tableau.

        Args:
            M (torch.Tensor): (n, n) matrix.
            q (torch.Tensor): vector with n entries; (n,), (n, 1) and (1, n)
                shapes are accepted and flattened.
            maxIter (int): maximum number of pivot steps in ``lemkeAlgorithm``
                after the initial pivot.
            device: device for all tensors; defaults to ``M.device``.

        Raises:
            ValueError: if ``M`` is not of shape (n, n) with n = number of
                entries of ``q``.
        """
        if device is None:
            device = M.device

        self.device = device
        M = M.to(device)
        q = q.to(device).reshape(-1)

        n = q.shape[0]
        if M.shape != (n, n):
            raise ValueError(
                f"M must have shape ({n}, {n}) to match q, got {tuple(M.shape)}"
            )

        # One explicit floating dtype for the whole tableau: the same dtype
        # torch.cat's promotion of default-dtype eye/ones with M and q yields,
        # but without depending on it (and integer inputs become floating, so
        # the in-place row divisions below never truncate).
        dtype = torch.promote_types(
            torch.promote_types(M.dtype, q.dtype), torch.get_default_dtype()
        )
        M = M.to(dtype)
        q = q.to(dtype)

        I_n = torch.eye(n, dtype=dtype, device=device)
        ones_n_1 = torch.ones((n, 1), dtype=dtype, device=device)
        T_left = torch.cat([I_n, -M], dim=1)
        T_right = torch.cat([-ones_n_1, q.reshape(n, 1)], dim=1)
        self.T = torch.cat([T_left, T_right], dim=1)  # (n, 2n+2)

        self.n = n
        self.wPos = torch.arange(n, device=device)
        self.zPos = torch.arange(n, 2 * n, device=device)

        # Variable-kind codes stored in row 0 of Tind.
        self.W = 0
        self.Z = 1
        self.Y = 2
        self.Q = 3

        TbInd = torch.vstack([
            self.W * torch.ones(n, dtype=torch.long, device=device),
            torch.arange(n, dtype=torch.long, device=device),
        ])
        TnbInd = torch.vstack([
            self.Z * torch.ones(n, dtype=torch.long, device=device),
            torch.arange(n, dtype=torch.long, device=device),
        ])
        DriveInd = torch.tensor([[self.Y], [0]], dtype=torch.long, device=device)
        QInd = torch.tensor([[self.Q], [0]], dtype=torch.long, device=device)

        # (2 x (2n + 2))
        self.Tind = torch.hstack([TbInd, TnbInd, DriveInd, QInd])

        self.maxIter = maxIter

    def lemkeAlgorithm(self):
        """Run Lemke's algorithm.

        Returns:
            (z, exit_code, exit_string):
                z: 1-D tensor with the tableau's dtype on ``self.device`` when a
                    solution is found, otherwise ``None``.
                exit_code / exit_string:
                    0 / 'Solution Found'
                    1 / 'Secondary ray found'
                    2 / 'Max Iterations Exceeded'
        """
        if not self.initialize():
            # min(q) >= 0 (or empty problem): z = 0 already solves the LCP.
            z = torch.zeros(self.n, dtype=self.T.dtype, device=self.device)
            return z, 0, 'Solution Found'

        for _ in range(self.maxIter):
            stepped = self.step()
            # y sits in the driving column once it has left the basis, so the
            # remaining basis is complementary.
            if self.Tind[0, -2].item() == self.Y:
                return self.extractSolution(), 0, 'Solution Found'
            if not stepped:
                # No positive entry in the driving column: ray termination.
                return None, 1, 'Secondary ray found'

        return None, 2, 'Max Iterations Exceeded'

    def initialize(self):
        """Make the first pivot, bringing y into the basis at the most negative q.

        Returns:
            True if a pivot was made; False if the problem is empty or
            ``min(q) >= 0`` (no pivot is needed).
        """
        q = self.T[:, -1]
        if q.numel() == 0:
            return False
        ind = int(torch.argmin(q).item())
        if q[ind].item() < 0:
            self.clearDriverColumn(ind)
            self.pivot(ind)
            return True
        return False

    def step(self):
        """Do one pivot chosen by the minimum-ratio test on the driving column.

        Returns:
            True if a pivot was made, False if no entry of the driving column is
            positive.
        """
        # Pull both columns to Python floats once; per-element .item() calls
        # force a device synchronization on every access.
        q = self.T[:, -1].tolist()
        a = self.T[:, -2].tolist()

        minRatio = float('inf')
        chosen_index = None
        for i, (qi, ai) in enumerate(zip(q, a)):
            if ai > 0:
                ratio = qi / ai
                # Strict '<' keeps the lowest row index on ties.
                if ratio < minRatio:
                    minRatio = ratio
                    chosen_index = i

        if chosen_index is None:
            return False
        self.clearDriverColumn(chosen_index)
        self.pivot(chosen_index)
        return True

    def extractSolution(self):
        """Read z from the tableau.

        Each basic ``z_i`` (a Z-kind variable in columns ``0 .. n-1``) takes the
        right-hand-side value of its row; every other ``z_i`` is 0.
        """
        z = torch.zeros(self.n, dtype=self.T.dtype, device=self.device)
        basic_kind = self.Tind[0, :self.n]
        basic_idx = self.Tind[1, :self.n]
        is_z = basic_kind == self.Z
        z[basic_idx[is_z]] = self.T[:, -1][is_z]
        return z

    def partnerPos(self, pos):
        """Return the column of the complementary partner of the variable at ``pos``.

        ``w_i`` and ``z_i`` are partners; ``None`` is returned for ``y``/``q``.
        """
        v, ind = self.Tind[:, pos].tolist()
        if v == self.W:
            return self.zPos[ind].item()
        if v == self.Z:
            return self.wPos[ind].item()
        return None

    def pivot(self, pos):
        """Complete a pivot on row ``pos`` after ``clearDriverColumn(pos)``.

        The leaving variable at column ``pos`` swaps with its partner (which
        becomes the next driving variable), and the cleared driving column moves
        into basic column ``pos``. When the leaving variable has no partner
        (it is ``y``), it is swapped straight into the driving column.

        Returns:
            True if a partner existed, False otherwise.
        """
        ppos = self.partnerPos(pos)
        if ppos is not None:
            self.swapColumns(pos, ppos)
            self.swapColumns(pos, -2)
            return True
        self.swapColumns(pos, -2)
        return False

    def swapMatColumns(self, mat, i, j):
        """Swap columns ``i`` and ``j`` of ``mat`` in place and return ``mat``."""
        tmp = torch.clone(mat[:, i])
        mat[:, i] = mat[:, j]
        mat[:, j] = tmp
        return mat

    def swapPos(self, v, ind, newPos):
        """Record that variable (kind ``v``, index ``ind``) now lives at ``newPos``.

        Negative positions are normalized (e.g. -2 -> 2n). ``y``/``q`` are not
        tracked.
        """
        if v == self.W:
            self.wPos[ind] = newPos % (2 * self.n + 2)
        elif v == self.Z:
            self.zPos[ind] = newPos % (2 * self.n + 2)

    def swapColumns(self, i, j):
        """Swap tableau columns ``i`` and ``j`` and keep ``Tind``/``wPos``/``zPos`` in sync."""
        # Read the labels as Python ints before the columns are moved.
        v_i, idx_i = self.Tind[:, i].tolist()
        v_j, idx_j = self.Tind[:, j].tolist()

        self.swapPos(v_i, idx_i, j)
        self.swapPos(v_j, idx_j, i)

        self.Tind = self.swapMatColumns(self.Tind, i, j)
        self.T = self.swapMatColumns(self.T, i, j)

    def clearDriverColumn(self, ind):
        """Eliminate the driving column around row ``ind`` (Gauss-Jordan step).

        Row ``ind`` is scaled so its driving-column entry becomes 1, and a
        multiple of it is subtracted from every other row so their
        driving-column entries become 0.
        """
        # Pivot and elimination factors are taken as constants (no autograd
        # through them), exactly as when they were read via .item().
        pivot = self.T[ind, -2].item()
        pivot_row = self.T[ind] / pivot
        factors = self.T[:, -2].detach().clone()
        factors[ind] = 0
        self.T -= factors.unsqueeze(1) * pivot_row.unsqueeze(0)
        self.T[ind] = pivot_row

    def ind2str(self, indvec):
        """Return a printable label ('w<i>', 'z<i>', 'y', 'q' or '??') for a Tind column."""
        v, pos = indvec[0].item(), indvec[1].item()
        if v == self.W:
            return f'w{pos}'
        elif v == self.Z:
            return f'z{pos}'
        elif v == self.Y:
            return 'y'
        elif v == self.Q:
            return 'q'
        else:
            return '??'

    def indexStringArray(self):
        """Return the label of every tableau column, in column order."""
        return [self.ind2str(self.Tind[:, i]) for i in range(self.Tind.shape[1])]

    def indexedTableau(self):
        """Return ``(column_labels, T)``."""
        col_labels = self.indexStringArray()
        return col_labels, self.T

    def __repr__(self):
        col_labels, mat = self.indexedTableau()
        return f"LemkeTableauTorch(\nColumns: {col_labels}\nMatrix:\n{mat}\n)"

    def __str__(self):
        return self.__repr__()


def lemkelcp_torch(M, q, maxIter=100, device=None):
    """Solve the linear complementarity problem LCP(M, q) with Lemke's algorithm.

    Looks for z with w = M z + q, w >= 0, z >= 0 and w^T z = 0.

    Args:
        M (torch.Tensor): (n x n) matrix.
        q (torch.Tensor): (n,) vector.
        maxIter (int): maximum number of pivot steps after the initial pivot.
        device: device to run on; defaults to the device of ``M``.

    Returns:
        (z, exit_code, exit_string)
        z: 1-dim torch.Tensor holding the solution, or None if none was found.
        exit_code:
            0 -> solution found
            1 -> secondary ray found (terminated without a solution)
            2 -> maximum number of iterations exceeded
        exit_string: human-readable description of exit_code.
    """
    solver = LemkeTableauTorch(M, q, maxIter=maxIter, device=device)
    outcome = solver.lemkeAlgorithm()
    return outcome