import numpy as np
import pandas as pd
import os
from datetime import datetime

from collections import Counter

# Base directory of the instance data. LoadInstance and LoadOptimal append
# "/<name>/<file>" to it. The path is relative, so it is resolved against the
# current working directory of the running process, not against this file.
datapath = './../../Data'

def LoadInstance(name, N, i):
    """Load the N x N matrix of instance ``i`` of the data set ``name``.

    The instance is looked up at ``<datapath>/<name>/<name>_<i+1>``; the
    index used in the file name is one-based.

    * ``N <= 2000``: the file is plain text with one ``row col value``
      triple per line (whitespace separated, one-based indices). The values
      are written into a zero matrix, which is then added to its transpose.
    * ``N > 2000``: ``<file>.csv.gz`` is read with ``pandas.read_csv`` and
      copied into the top-left ``len(wp) x len(wp)`` corner of a zero
      matrix; the remaining entries stay zero.

    Parameters
    ----------
    name : str
        Data set name (used as directory name and as file name prefix).
    N : int
        Size of the returned square matrix.
    i : int
        Zero-based instance index.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(N, N)``.

    Raises
    ------
    OSError
        If the instance file cannot be opened.
    ValueError
        If a text line has fewer than three fields or a non-numeric field.
    """
    pathfile = datapath + "/%s/%s_%d" % (name, name, i + 1)

    if N <= 2000:
        w = np.zeros((N, N))

        # The context manager closes the handle even if parsing fails.
        with open(pathfile, 'r') as handle:
            for line in handle:
                # split() without an argument tolerates repeated blanks,
                # tabs and the trailing newline.
                fields = line.split()
                if not fields:
                    continue  # skip blank lines
                row, col, value = (float(token) for token in fields[:3])
                # Indices in the file are one-based.
                w[int(row) - 1, int(col) - 1] = value

        return w + w.T

    w = np.zeros([N, N])
    wp = pd.read_csv(pathfile + ".csv.gz").to_numpy()
    w[0:len(wp), 0:len(wp)] = wp

    # NOTE(review): the original code evaluated `w + w.transpose()` here and
    # discarded the result, so the matrix was returned exactly as read. That
    # behavior is kept. If the CSV only stores one triangle, this should be
    # `w = w + w.T` - confirm against the data format before changing it.
    return w


def LoadOptimal(name, N):
    """Load the table of reference values of the data set ``name``.

    Reads ``<datapath>/<name>/<name>_REF`` when ``N < 2000`` and
    ``<datapath>/<name>/<name>_SOL`` otherwise. Every non-blank line is
    parsed as whitespace-separated floats and becomes one row of the result.

    NOTE(review): LoadInstance switches file formats at ``N <= 2000`` while
    this function switches at ``N < 2000``, so ``N == 2000`` reads ``_SOL``
    here. The threshold is kept as it was - confirm which one is intended.

    Parameters
    ----------
    name : str
        Data set name (used as directory name and as file name prefix).
    N : int
        Problem size; only used to choose between the two file suffixes.

    Returns
    -------
    numpy.ndarray
        One row per non-blank line of the file.

    Raises
    ------
    OSError
        If the file cannot be opened.
    ValueError
        If a field is not a valid float.
    """
    suffix = "REF" if N < 2000 else "SOL"
    pathfile = datapath + "/%s/%s_%s" % (name, name, suffix)

    # `with` closes the handle; split() without an argument copes with
    # trailing or repeated blanks that made float('') fail before.
    with open(pathfile, 'r') as handle:
        rows = [
            [float(token) for token in line.split()]
            for line in handle
            if line.strip()
        ]

    return np.array(rows)

def load_wishart(N, alphatxt, i):
    """Load one Wishart instance together with its reference entry.

    The data set name is built as ``WISHART_<N>_<alphatxt>``, where
    ``alphatxt`` is the textual form used in the file names (e.g. '0.80').

    Returns
    -------
    tuple
        ``(w, eps0, H0)``: the matrix returned by ``LoadInstance``, the mean
        of its absolute entries, and row ``i`` of the table returned by
        ``LoadOptimal``.
    """
    dataset = 'WISHART_%d_%s' % (N, alphatxt)

    couplings = LoadInstance(dataset, N, i)
    reference = LoadOptimal(dataset, N)[i]

    # Mean magnitude of the matrix entries.
    mean_abs = np.abs(couplings).mean()

    return couplings, mean_abs, reference
    
from itertools import product
    
# calculate partition function
def partitionf(w, N, beta, vtype):
    """Return the Boltzmann probability of every distinct energy level.

    All ``2**N`` configurations ``x`` are enumerated, each with energy
    ``H(x) = -0.5 * sum(x * (w @ x))`` and weight ``exp(-beta * H(x))``.
    Configurations with exactly equal energy are merged into one level whose
    probability is the sum of their weights divided by the total weight.

    Parameters
    ----------
    w : array-like
        Matrix used in the energy; must be compatible with vectors of
        length ``N``.
    N : int
        Number of variables. The enumeration is exponential in ``N``.
    beta : float
        Multiplier in the weight ``exp(-beta * H)``.
    vtype : int
        ``0`` selects variables in ``{0, 1}``; any other value selects
        ``{-1, 1}``.

    Returns
    -------
    Pth : numpy.ndarray
        Probability of each level, aligned with ``Ho``.
    Ho : numpy.ndarray
        Sorted distinct energies.
    """
    var = [0, 1] if vtype == 0 else [-1, 1]

    # Only the energies are kept; the configurations themselves are never
    # needed afterwards, so they are not stored.
    energies = np.array([
        -0.5 * np.sum(x * np.dot(w, x))
        for x in map(np.array, product(var, repeat=N))
    ])

    # Shift the exponents so the largest one is 0. The probabilities are
    # ratios of weights, so the common factor cancels, but exp() can no
    # longer overflow (or underflow everywhere) for large |beta * H|.
    exponents = -beta * energies
    weights = np.exp(exponents - exponents.max())

    # `level[k]` is the position in Ho of the energy of configuration k.
    Ho, level = np.unique(energies, return_inverse=True)
    level_weight = np.bincount(level.ravel(), weights=weights,
                               minlength=len(Ho))

    Pth = level_weight / weights.sum()

    return Pth, Ho

def load_from_file(folder_name, N, alphatxt, T):
    """Read a stored result from ``<folder_name>/wishart_<N>_<alphatxt>_<T>.txt``.

    The first line of the file holds a single float and the second line a
    whitespace-separated list of floats.

    Returns
    -------
    tuple
        ``(f_eval, param_out)``: the float from the first line and the list
        of floats from the second line.
    """
    data_file_path = os.path.join(
        folder_name, f"wishart_{N}_{alphatxt}_{T}.txt"
    )

    with open(data_file_path, 'r') as handle:
        f_eval = float(handle.readline().strip())
        param_out = [float(token) for token in handle.readline().split()]

    print(f"Data loaded from file: {data_file_path}")
    return f_eval, param_out


def create_timestamped_folder(solvertype):
    """Create (if needed) a folder named ``<YYYY-mm-dd_HH-MM-SS>_<solvertype>``.

    The timestamp is the local time at the moment of the call and the folder
    is created relative to the current working directory. A message is
    printed saying whether the folder was created or was already there.

    Returns
    -------
    str
        The folder name.
    """
    stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    folder_name = f"{stamp}_{solvertype}"

    # Nothing to do when something with that name is already present.
    if os.path.exists(folder_name):
        print(f"Folder already exists: {folder_name}")
        return folder_name

    os.makedirs(folder_name)
    print(f"Folder created: {folder_name}")
    return folder_name


def save_to_file(folder_name, file_name, f_eval, evalist, param_out):
    """Write a result to ``<folder_name>/<file_name>`` as three text lines.

    Line 1 holds ``f_eval``, line 2 the element-wise exponential of
    ``param_out`` separated by blanks, and line 3 the entries of ``evalist``
    separated by blanks (without a trailing newline). An existing file is
    overwritten. A confirmation message is printed afterwards.
    """
    file_path = os.path.join(folder_name, file_name)

    with open(file_path, 'w') as out:
        print(f"{f_eval}", file=out)
        # The parameters are stored exponentiated, not as passed in.
        print(*np.exp(param_out), file=out)
        print(*evalist, end="", file=out)

    print(f"Data saved to file: {file_path}")


def read_file(folder_name, file_name):
    """Read a three-line result file from ``<folder_name>/<file_name>``.

    Line 1 is parsed as one float; lines 2 and 3 are parsed as
    whitespace-separated lists of floats.

    Returns
    -------
    tuple
        ``(p0, params, pvec)``. When the file does not exist, three empty
        lists are returned instead (so ``p0`` is a list in that case).
    """
    file_path = os.path.join(folder_name, file_name)

    # Missing file: hand back the empty placeholders.
    if not os.path.exists(file_path):
        return [], [], []

    with open(file_path, 'r') as handle:
        p0 = float(handle.readline().strip())
        params = [float(token) for token in handle.readline().split()]
        pvec = [float(token) for token in handle.readline().split()]

    return p0, params, pvec

def count_vector_occurrences(L, L0):
    """Count how often each vector of ``L0`` appears in ``L``.

    Vectors are compared after conversion to tuples, so two vectors match
    when their tuples are equal.

    Returns
    -------
    list
        One count per vector of ``L0``, in the order of ``L0``; vectors that
        do not occur in ``L`` get 0.
    """
    # Tally every vector of L once, keyed by its tuple form.
    tally = Counter(map(tuple, L))

    return [tally[key] for key in map(tuple, L0)]