import numpy as np
import random
from file_paths_and_consts import *
import sys
import pandas as pd
import os
import time
# seed both random number generators so that results are reproducible
random.seed(42)
np.random.seed(42)
     
'''Internally, vertex indices start from 0.
First pass over the graph file: find the largest vertex id it mentions.'''
def get_vertice_num(graph_file):
    """Return the largest vertex id that appears in an edge-list file.

    Every line of ``graph_file`` must consist of exactly two
    whitespace-separated integers (the two endpoints of one edge);
    otherwise the "wrong file format" assertion fails.

    Returns -1 when the file contains no lines.
    """
    largest = -1
    with open(graph_file) as handle:
        for row in handle:
            endpoints = list(map(int, row.split()))
            assert len(endpoints)==2 , "wrong file format"
            largest = max(largest, endpoints[0], endpoints[1])
    return largest

'''Read Graph and create adjacency list'''
def read_graph(graph_file,one_indexed=1):
    """Build an undirected adjacency list from an edge-list file.

    Parameters
    ----------
    graph_file : path of a text file with one edge ``u v`` per line.
    one_indexed : offset subtracted from every vertex id found in the
        file. Use 1 (the default) when ids in the file start at 1 and
        0 when they already start at 0.

    Returns
    -------
    (G, n) where ``G[v]`` is the list of neighbours of vertex ``v``
    (0-based ids) and ``n`` is the number of vertices, i.e. the largest
    id in the file converted to a count (0 for an empty file).
    Each edge is stored in both directions; repeated edges and
    self-loops are stored as often as they are listed.

    Raises
    ------
    ValueError
        If a vertex id is smaller than ``one_indexed``. Such an id would
        become a negative list index and silently alias another vertex.
    """
    # An empty file makes get_vertice_num return -1; never report a
    # negative vertex count.
    n = max(get_vertice_num(graph_file)+(1-one_indexed), 0)
    G = [[] for _ in range(n)]
    with open(graph_file) as f:
        for line in f:
            # The two-fields-per-line format was already checked by
            # get_vertice_num during the first pass.
            edge = [int(i)-one_indexed for i in line.split()]
            u, v = edge[0], edge[1]
            if u < 0 or v < 0:
                raise ValueError(
                    f"vertex id smaller than {one_indexed} in "
                    f"{graph_file!r}: {line.strip()!r}"
                )
            G[u].append(v)
            G[v].append(u)
    return G,n

'''create thresholds for each vertices'''
def init_Tau(G,n):
    """Draw a random integer threshold for each of the first ``n`` vertices.

    The threshold of vertex ``v`` is sampled with ``np.random.randint``
    from the range 0 .. len(G[v])+1 (both ends included), using NumPy's
    global random state. Returns the thresholds as a list indexed by
    vertex.
    """
    return [np.random.randint(0, len(G[v])+2, size=1)[0] for v in range(n)]

'''Read thresholds from file'''
def read_Tau(G,file):
    """Load the per-vertex thresholds stored in ``file``.

    Each line must hold exactly two whitespace-separated integers; only
    the second one (the threshold) is kept, in file order. The first
    field is parsed but otherwise ignored, so line k is taken to
    describe vertex k.

    Fails an assertion when a line does not have two fields or when the
    number of lines differs from ``len(G)``.
    """
    taus = []
    with open(file) as handle:
        for row in handle:
            fields = [int(token) for token in row.split()]
            assert len(fields)==2 , "wrong file format"
            taus.append(fields[1])
    assert len(taus)==len(G), "number of vertices do not match in threshold list and Graph list"
    return taus

'''Generate a num_samples x n matrix of initial states.
Each row is one initial state,
consisting of the 0/1 values of the n vertices'''
def gen_init_config_dataset(G,n,num_samples,P=0.5):
    """Sample ``num_samples`` random 0/1 configurations of ``n`` vertices.

    A uniform number in [0, 1) is drawn for every entry with
    ``np.random.rand`` (global random state); the entry is 1.0 when that
    number is greater than ``P`` and 0.0 otherwise, so each entry is 1
    with probability 1-P. ``G`` is not used.

    Returns a float array of shape (num_samples, n).
    """
    draws = np.random.rand(num_samples, n)
    return np.where(draws > P, 1.0, 0.0)


'''G is a adjacency list where G[v] has the list of v neighbor
Tau is a threshold list where Tau[v] is the threshold of vertex v
creates the next state, given the current state s'''
def gen_next_config(G,n,s,Tau):
    """Compute the successor of configuration ``s`` under thresholds ``Tau``.

    For every vertex v in 0..n-1 the influence is the sum of ``s`` over
    the distinct vertices listed in ``G[v]`` (``s[v]`` itself is only
    included if v appears in its own neighbour list). Vertex v is 1 in
    the next state when that influence is at least ``Tau[v]``, else 0.

    Returns the next state as a NumPy array of length n.
    """
    successor = []
    for v in range(n):
        neighbours = G[v]
        # boolean selector over all vertices: True exactly on v's neighbours
        selector = np.zeros(n, dtype=bool)
        if len(neighbours)>0:
            selector[np.array(neighbours)] = True
        influence = np.sum(s, where=selector)
        successor.append(int(influence>=Tau[v]))
    return np.array(successor)

'''Given a set of initial states, creates the set of next states
eta corresponds to the noise added, if eta==0 it will just be the
next state. eta should be < 1/3, fixed throughout single experiment'''
def gen_next_config_dataset(G,n,Tau,S,eta=0):
    """Map every configuration in ``S`` to its (optionally noisy) successor.

    For each row of ``S`` the successor is computed with
    ``gen_next_config``; then a uniform number in [0, 1) is drawn per
    vertex and the vertex's bit is flipped when that number is below
    ``eta``. With ``eta`` equal to 0 no bit is flipped, but the random
    numbers are still drawn from NumPy's global random state.

    Returns the successors stacked row-wise in the order of ``S``.
    """
    noisy_rows = []
    for state in S:
        clean = gen_next_config(G,n,state,Tau)
        noise_draw = np.random.rand(n)
        assert noise_draw.shape[0]==clean.shape[0], "next state shape and flip shape do not match"
        keep = noise_draw>=eta
        noisy_rows.append(np.where(keep, clean, 1-clean))
    return np.vstack(noisy_rows)

'''helper function for algorithm A'''
def gen_mask(G,n):
    """Return the neighbourhood masks of all vertices as one boolean matrix.

    Row v has length ``n`` and is True exactly at the positions listed
    in ``G[v]``; a vertex without neighbours gets an all-False row.
    The result has one row per entry of ``G``.
    """
    def _row(neighbours):
        # one boolean selector of length n for a single vertex
        row = np.zeros(n, dtype=bool)
        if len(neighbours)>0:
            row[np.array(neighbours)] = True
        return row

    return np.vstack([_row(adjacent) for adjacent in G])
    
def gen_next_config_single_v(MASKS,v,s,tau_v):
    """Return the next state (0 or 1) of vertex ``v`` alone.

    The influence on v is the sum of ``s`` over the positions where
    ``MASKS[v]`` is True; the result is 1 when the influence is at
    least ``tau_v`` and 0 otherwise.
    """
    influence = np.sum(s, where=MASKS[v])
    return int(influence>=tau_v)

'''algorithm A1, described in paper'''
def algorithm_A1(init_state_data,next_state_data,G,n):
    """Infer one threshold per vertex by minimising the training error.

    Parameters
    ----------
    init_state_data : array whose rows are initial configurations.
    next_state_data : the observed successor of each initial row.
    G : adjacency list, ``G[v]`` lists the neighbours of vertex v.
    n : number of vertices.

    Returns
    -------
    A list holding, for each vertex v, the threshold in
    0 .. len(G[v])+1 whose predictions disagree with the observed next
    state of v on the fewest training pairs. Ties go to the smallest
    threshold. The entry stays -1 when there are no training rows.
    """
    Tau_list = []
    MASKS = gen_mask(G,n)
    for v in range(0,n):
        if(v%100==0):
            print('processing v=',v,'now')
        deg = len(G[v])
        # The neighbour sum of a sample does not depend on the candidate
        # threshold, so compute it once per sample instead of once per
        # (sample, threshold) pair.
        observations = [(np.sum(s, where=MASKS[v]), t[v])
                        for s,t in zip(init_state_data,next_state_data)]
        # start above any reachable error count so the first candidate wins
        min_err = init_state_data.shape[0]*2
        min_thresh = -1
        for tau in range(0,deg+2):
            err_cnt = 0
            for influence,y_true in observations:
                y_pred = 1 if influence>=tau else 0
                if(y_pred!=y_true):
                    err_cnt = err_cnt + 1
            # strict comparison keeps the smallest threshold among ties
            if err_cnt<min_err:
                min_err = err_cnt
                min_thresh = tau
        Tau_list.append(min_thresh)
    return Tau_list

'''More efficient version, so that we get info for multiple samples'''
def evaluate_tau(G,n,test_data,predicted_Tau,Actual_Tau):
    """Return the fraction of test configurations whose successor is mispredicted.

    The noise-free successor of every row of ``test_data`` is computed
    once with ``Actual_Tau`` and once with ``predicted_Tau``; a row
    counts as an error when the two successors differ in at least one
    vertex. The count is divided by the number of test rows.
    """
    # keep this order: each call draws from NumPy's global random state
    reference = gen_next_config_dataset(G,n,Actual_Tau,test_data,eta=0.0)
    candidate = gen_next_config_dataset(G,n,predicted_Tau,test_data,eta=0.0)
    wrong_rows = sum(1 for expected,got in zip(reference,candidate) if (expected!=got).any())
    return float(wrong_rows)/test_data.shape[0]

def algorithm_A1_efficient(init_state_data,next_state_data,G,n,test_data,Actual_Tau,N,D,BIAS,NOISE,eval_every=50,save_every=1000):
    """Run algorithm A1 incrementally and record the test error as data grows.

    Training pairs are consumed one at a time. For every vertex v and
    every candidate threshold in 0 .. len(G[v])+1 a running count of
    mispredictions is kept. After every ``eval_every`` pairs the
    currently best thresholds (smallest threshold with the fewest
    errors) are scored with ``evaluate_tau`` on ``test_data``. After
    every ``save_every`` pairs the history collected so far is written
    to a CSV file.

    Parameters
    ----------
    init_state_data, next_state_data : paired training configurations.
    G, n : adjacency list and number of vertices.
    test_data : configurations used to estimate the error probability.
    Actual_Tau : thresholds used as ground truth in the evaluation.
    N, D, BIAS, NOISE : experiment labels; they only appear in the
        output file name and as constant columns of the CSV.
    eval_every, save_every : evaluation / checkpoint periods, counted
        in training pairs.

    Returns
    -------
    A list of ``{'train-size': ..., 'err_prob': ...}`` dicts, one per
    evaluation.
    """
    live_err_counter = [[0]*(len(adjacent)+2) for adjacent in G]
    train_data_seen = 0
    history = []
    MASKS = gen_mask(G,n)
    for init_state,next_state in zip(init_state_data,next_state_data):
        for v in range(0,n):
            # Neither the observed bit nor the neighbour sum depends on
            # the candidate threshold: compute both once per vertex.
            y_true = next_state[v]
            influence = np.sum(init_state, where=MASKS[v])
            errors_v = live_err_counter[v]
            for tau in range(0,len(G[v])+2):
                y_pred = 1 if influence>=tau else 0
                if(y_pred!=y_true):
                    errors_v[tau] = errors_v[tau]+1
        train_data_seen = train_data_seen+1
        if(train_data_seen%eval_every==0):
            print("evaluated",train_data_seen,'training samples',flush=True)
            # list.index returns the first minimum, i.e. the smallest
            # threshold among those with the fewest errors
            cur_tau_list = [tau_err_v.index(min(tau_err_v)) for tau_err_v in live_err_counter]
            err_p = evaluate_tau(G,n,test_data,cur_tau_list,Actual_Tau)
            history.append({'train-size':train_data_seen,'err_prob':err_p})
        if(train_data_seen%save_every==0):
            # periodic checkpoint of the results gathered so far
            # NOTE(review): OUTPUT_DIR is not defined in this file;
            # presumably it comes from file_paths_and_consts - confirm.
            out_file_name = f'results-G-{N}-{D}-{BIAS}-{NOISE}-algorithm-A.csv'
            df = pd.DataFrame.from_dict(history)
            df['Avg_Deg'] = D
            df['Noise'] = NOISE
            df['Bias'] = BIAS
            df['N'] = N
            df.to_csv(OUTPUT_DIR+out_file_name,index=False)
    return history

# ---- experiment driver: runs at import time, configured from the command line ----
# usage: <script> N D NOISE
#parameters
# int(float(...)) also accepts N written in float notation (e.g. "1e3")
N = int(float(sys.argv[1]))
D = int(sys.argv[2])
# input files are looked up by N and D under the Gnp/ sub-directory
graph_file_name = f'Gnp/G-{N}-deg-{D}.edges'
thresh_file_name = f'Gnp/G-{N}-deg-{D}.thresholds'
# BIAS is fixed at 0.5; the third command-line argument is used for NOISE
BIAS = 0.5#float(sys.argv[3])
NOISE = float(sys.argv[3])
# at least 5000 training samples, and at least 2 per vertex count N
train_data_size = max(5000,N*2)

## read from file
# NOTE(review): SYNTH_GRAP_PATH and OUTPUT_DIR are not defined in this file;
# presumably they come from file_paths_and_consts - confirm.
graph_file = SYNTH_GRAP_PATH+graph_file_name
threshold_file = SYNTH_GRAP_PATH+thresh_file_name
# default one_indexed=1: vertex ids in the edge file are shifted down by one
G,n = read_graph(graph_file)
Actual_Tau = read_Tau(G,threshold_file)

## generate training data
# training data is drawn before test data; both use the globally seeded RNG,
# so reordering these statements changes the generated datasets
train_size = train_data_size
train_data = gen_init_config_dataset(G,n,train_size,P=BIAS)
# observed successors, each bit flipped when its uniform draw is below NOISE
next_state_data = gen_next_config_dataset(G,n,Actual_Tau,train_data,eta=NOISE)

## gen test data
# at most 1000 test samples, and at most a fifth of the training size
test_size = min(1000,train_size//5)
test_data = gen_init_config_dataset(G,n,test_size,P=BIAS)
#print(test_data.shape)

#run the algorithm
# only the inference itself is timed (data generation is excluded)
start = time.time()
history = algorithm_A1_efficient(train_data,next_state_data,G,n,test_data,Actual_Tau,N,D,BIAS,NOISE)
end = time.time()

#write to output
# same file name as the periodic checkpoints written inside
# algorithm_A1_efficient, so this final write replaces them and adds 'time'
out_file_name = f'results-G-{N}-{D}-{BIAS}-{NOISE}-algorithm-A.csv'
df = pd.DataFrame.from_dict(history)
df['Avg_Deg'] = D
df['Noise'] = NOISE
df['Bias'] = BIAS
df['N'] = N
# elapsed wall-clock seconds, rounded to two decimals
df['time'] = round(end-start,2)
df.to_csv(OUTPUT_DIR+out_file_name,index=False)
print(N,D,BIAS,NOISE,end-start)