import numpy as np
import math
import csv
import sys
import networkx as nx
import scipy
import time
import heapq
import random

def readDataFromFileIJW(filePath, addNum = 0):
    """Read a weighted edge list ("i j w" per line) from a text file.

    Each non-blank line holds two integer vertex IDs and a float weight,
    separated by whitespace; any further fields are kept unchanged (as
    strings). ``addNum`` is added to both IDs, e.g. ``1`` to shift a 0-based
    file to 1-based IDs.

    Args:
        filePath: path of the edge-list file.
        addNum: offset added to every vertex ID.

    Returns:
        ``(vertexSet, edgeSet)``: ``vertexSet`` is ``[[1], [2], ..., [n]]``
        where ``n`` is the largest shifted vertex ID in the file (0 if there
        is none), and ``edgeSet`` lists ``[i, j, w, ...]`` in file order.
    """
    with open(filePath, 'r', newline='') as csv_file:
        # csv.reader keeps its default comma delimiter, so each row is normally
        # a single field containing the whole space-separated line.
        originalData = list(csv.reader(csv_file))

    edgeSet = []
    vertexNum = 0

    for lineData in originalData:
        # Skip blank lines (csv.reader yields [] for them); they used to raise IndexError.
        if not lineData or not lineData[0].strip():
            continue
        # split() with no argument tolerates repeated spaces and tabs.
        dataList = lineData[0].split()
        dataList[0] = int(dataList[0]) + addNum
        dataList[1] = int(dataList[1]) + addNum
        dataList[2] = float(dataList[2])

        edgeSet.append(dataList)
        # Running maximum over ALL edges; previously only the last edge was
        # considered, which truncated vertexSet.
        vertexNum = max(vertexNum, dataList[0], dataList[1])

    vertexSet = [[i + 1] for i in range(vertexNum)]

    return vertexSet, edgeSet
def getDegreeVector(vertexSet, edgeSet):
    """Return the weighted degree of every vertex as a list.

    The list has ``len(vertexSet)`` entries; vertex ID ``v`` is stored at
    position ``v - 1`` (IDs are 1-based). Each edge ``[i, j, w, ...]`` adds
    ``w`` to both endpoints, so a self-loop adds ``2 * w`` to its vertex.
    """
    degrees = [0] * len(vertexSet)

    for edge in edgeSet:
        weight = edge[2]
        for endpoint in (edge[0], edge[1]):
            degrees[endpoint - 1] += weight

    return degrees

def getClusterFromDic(clustersDic):
    """Return the non-empty entries of a per-root cluster table.

    Falsy entries (``None`` for roots that were absorbed by a merge, or empty
    lists) are dropped; the surviving member lists are returned in their
    original order (the same list objects, not copies).
    """
    return [members for members in clustersDic if members]

def getNCut(vertexSet, edgeSet, processedVertexList):
    """Return the normalized cut of a clustering.

    NCut = 0.5 * sum over clusters C of cut(C) / vol(C), where cut(C) is the
    total weight of edges with one endpoint in C and the other in the rest of
    ``vertexSet``, and vol(C) is the sum of the weighted degrees
    (``getDegreeVector``) of C's vertices.

    Args:
        vertexSet: vertex records ``[[id], ...]``; only ``record[0]`` is read.
        edgeSet: edges ``[i, j, w, ...]`` with 1-based vertex IDs.
        processedVertexList: iterable of clusters, each a list of vertex IDs.

    Returns:
        The normalized cut. A cluster with zero volume and no cut edges
        (e.g. an isolated vertex on its own, or an empty cluster) contributes
        0 instead of raising ZeroDivisionError.

    Raises:
        KeyError: a cluster lists a vertex ID that is not in ``vertexSet``,
            or lists the same ID twice.
    """
    degreeVector = getDegreeVector(vertexSet, edgeSet)
    # Built once, instead of being rebuilt from vertexSet for every cluster.
    allVertices = {vertex[0] for vertex in vertexSet}

    totalNum = 0

    for subList in processedVertexList:
        subSet = set()
        sumDegree = 0
        for vertexID in subList:
            if vertexID not in allVertices or vertexID in subSet:
                raise KeyError(vertexID)
            subSet.add(vertexID)
            sumDegree += degreeVector[vertexID - 1]

        notSubSet = allVertices - subSet

        cut = 0
        for edge in edgeSet:
            edgeI = edge[0]
            edgeJ = edge[1]
            if (edgeI in subSet and edgeJ in notSubSet) or (edgeJ in subSet and edgeI in notSubSet):
                cut += edge[2]

        if sumDegree == 0 and cut == 0:
            # Nothing to normalize (e.g. an isolated singleton); contributes 0.
            continue
        totalNum += cut / sumDegree

    totalNum *= 0.5
    return totalNum
    
    
class HeapEdge:
    """Heap entry that orders edges by their current priority value.

    ``edge`` is a two-item list ``[edgeIndex, value]``. Instances compare by
    ``value`` only, so ``heapq`` (a min-heap) pops the smallest value first.
    The list is kept mutable because the clustering code rewrites
    ``edge[1]`` in place before re-pushing an entry.
    """

    def __init__(self, edge):
        self.edge = edge

    def __repr__(self):
        return f'|edge index: {self.edge[0]}, value: {self.edge[1]}|'

    # Each comparison returns NotImplemented for non-HeapEdge operands so Python
    # can fall back to its defaults (e.g. ``entry == None`` is False instead of
    # raising AttributeError).
    def __lt__(self, other):
        if not isinstance(other, HeapEdge):
            return NotImplemented
        return self.edge[1] < other.edge[1]

    def __le__(self, other):
        if not isinstance(other, HeapEdge):
            return NotImplemented
        return self.edge[1] <= other.edge[1]

    def __eq__(self, other):
        if not isinstance(other, HeapEdge):
            return NotImplemented
        return self.edge[1] == other.edge[1]

    def __ge__(self, other):
        if not isinstance(other, HeapEdge):
            return NotImplemented
        return self.edge[1] >= other.edge[1]

    def __gt__(self, other):
        if not isinstance(other, HeapEdge):
            return NotImplemented
        return self.edge[1] > other.edge[1]

    # Defining __eq__ already disables the inherited hash; state it explicitly,
    # since entries are mutable and must not be used as dict/set keys.
    __hash__ = None


def fastHeapClustering(vertexSet, edgeSet, k, ran = False, ranNum = 1, ranMin = 0, ranMax = 1):
    """Cluster a weighted graph into ``k`` clusters by greedy heap-driven merging.

    Starting from singleton clusters, repeatedly merge the two clusters joined
    by the edge with the largest score ``w * (1/vol(A) + 1/vol(B))``, where
    ``vol`` is a cluster's summed weighted degree. Scores are updated lazily:
    a popped edge is re-scored against the current clusters and accepted only
    if it still ties or beats the heap top; otherwise it is pushed back.

    With ``ran=True`` each score ``s`` becomes ``log(u) / s`` with
    ``u = random.uniform(ranMin, ranMax)`` drawn once per edge per run. The
    procedure is then repeated ``ranNum`` times, and the clustering with the
    lowest ``getNCut`` is kept.

    Args:
        vertexSet: ``[[1], ..., [n]]``; its length is the number of vertices.
        edgeSet: edges ``[i, j, w, ...]`` with 1-based vertex IDs.
        k: target number of clusters.
        ran: enable randomized restarts.
        ranNum: number of runs (only the first is used when ``ran`` is False).
        ranMin, ranMax: sampling range for ``u``; must lie within [0, 1].

    Returns:
        A list of clusters (lists of vertex IDs). Returns None after printing
        'Wrong Ran range' when the sampling range is invalid. If no mergeable
        edge remains before ``k`` clusters are reached (the graph has more
        than ``k`` connected components), the clusters reached so far are
        returned. An edge whose endpoints have zero weighted degree raises
        ZeroDivisionError.
    """
    if ranMin < 0 or ranMax > 1:
        print('Wrong Ran range')
        return

    vertexNum = len(vertexSet)
    degreeList = getDegreeVector(vertexSet, edgeSet)

    bestEstimation = float('inf')
    finalClusters = []

    for _ in range(ranNum):
        selectionList = []  # indices of the edges used for merges, in merge order

        # Union-find over 0-based vertex positions, with union by depth.
        clusterID = list(range(vertexNum))
        clusterDepth = [1] * vertexNum
        clustersDic = [[v + 1] for v in range(vertexNum)]  # members per root; None once absorbed
        volumeList = list(degreeList)  # cluster volume, meaningful at root positions

        ranList = []
        if ran:
            ranList = [math.log(random.uniform(ranMin, ranMax)) for _ in range(len(edgeSet))]

        def findRoot(index):
            while clusterID[index] != index:
                index = clusterID[index]
            return index

        edgeValueList = [edge[2]*(1/(volumeList[edge[0] - 1]) + 1/(volumeList[edge[1] - 1])) for edge in edgeSet]
        if ran:
            edgeValueList = [r * 1/value for r, value in zip(ranList, edgeValueList)]

        # heapq is a min-heap, so scores are negated: the best edge pops first.
        edgeHeap = []
        for edgeIndex, value in enumerate(edgeValueList):
            heapq.heappush(edgeHeap, HeapEdge([edgeIndex, -value]))

        # Stop after vertexNum - k merges, or earlier when no candidate edge is
        # left (previously heappop / edgeHeap[0] raised IndexError here).
        while len(selectionList) < vertexNum - k and edgeHeap:
            candidate = heapq.heappop(edgeHeap)
            candidateEdgeIndex = candidate.edge[0]
            candidateEdge = edgeSet[candidateEdgeIndex]
            cw = candidateEdge[2]

            iRoot = findRoot(candidateEdge[0] - 1)
            jRoot = findRoot(candidateEdge[1] - 1)

            if iRoot == jRoot:
                # Endpoints already share a cluster; this edge can never merge anything.
                continue

            # Re-score against the current cluster volumes.
            candidate.edge[1] = -1 * cw * (1/(volumeList[iRoot]) + 1/(volumeList[jRoot]))
            if ran:
                candidate.edge[1] = 1 / candidate.edge[1] * ranList[candidateEdgeIndex]

            # With non-negative weights, scores only get worse as volumes grow,
            # so stored heap values are optimistic. A fresh score that ties or
            # beats the top's stored score is therefore the true best. An empty
            # heap means this candidate is the only one left.
            if edgeHeap:
                currentTop = edgeHeap[0]
                if not (candidate < currentTop or candidate == currentTop):
                    heapq.heappush(edgeHeap, candidate)
                    continue

            selectionList.append(candidateEdgeIndex)
            sumVolume = volumeList[iRoot] + volumeList[jRoot]

            # Union by depth: attach the shallower tree under the deeper root
            # (iRoot wins ties and then grows by one level).
            if clusterDepth[iRoot] < clusterDepth[jRoot]:
                iRoot, jRoot = jRoot, iRoot
            elif clusterDepth[iRoot] == clusterDepth[jRoot]:
                clusterDepth[iRoot] += 1

            clusterID[jRoot] = iRoot
            volumeList[iRoot] = sumVolume
            clustersDic[iRoot] += clustersDic[jRoot]
            clustersDic[jRoot] = None

        clusters = getClusterFromDic(clustersDic)
        if not ran:
            return clusters

        estimation = getNCut(vertexSet, edgeSet, clusters)
        if estimation < bestEstimation:
            bestEstimation = estimation
            finalClusters = clusters.copy()

    return finalClusters
  
def main(argv=None):
    """Command-line entry point: ``<script> EDGE_FILE K``.

    Reads the edge list in EDGE_FILE with every vertex ID shifted by +1 and
    clusters it into K clusters with ``fastHeapClustering`` (non-randomized).
    It then prints the resulting normalized cut.

    Args:
        argv: argument vector; defaults to ``sys.argv``.
    """
    if argv is None:
        argv = sys.argv
    if len(argv) < 3:
        sys.exit('usage: <script> EDGE_FILE K')

    fileName = argv[1]
    # Parse K before reading the (possibly large) file so bad input fails fast.
    try:
        k = int(argv[2])
    except ValueError:
        sys.exit(f'K must be an integer, got {argv[2]!r}')

    vertexSet, edgeSet = readDataFromFileIJW(fileName, 1)
    clusters = fastHeapClustering(vertexSet, edgeSet, k, ran=False)

    print('nCut :', getNCut(vertexSet, edgeSet, clusters))

# Run the command-line entry point only when executed as a script, not on import.
if __name__=="__main__":
    main()