import numpy as np
import random
from file_paths_and_consts import *
import sys
import pandas as pd
import os
import time
# Seed both the stdlib `random` module and NumPy's global generator so that
# repeated runs produce reproducible results.
random.seed(42)
np.random.seed(42)

def append_record(record):
    """Append ``record`` as a single JSON document on its own line.

    The record is serialized with ``json.dump`` and appended to the file
    'results-v2.json' (relative to the current working directory), followed
    by a newline, so the file accumulates one JSON value per line.

    Parameters:
        record: any object ``json.dump`` can serialize.

    Raises:
        TypeError: if ``record`` is not JSON serializable (from ``json.dump``).
        OSError: if the file cannot be opened for appending.
    """
    # Imported here because no `import json` is visible at the top of this file.
    import json

    with open('results-v2.json', 'a') as f:
        json.dump(record, f)
        # A text-mode file already translates '\n' into the platform line
        # ending; writing os.linesep instead would produce '\r\r\n' on Windows.
        f.write('\n')

def get_vertice_num(graph_file):
    """Scan an edge-list file and return the largest vertex id it mentions.

    Every line must consist of exactly two whitespace-separated integers
    (the two endpoints of an edge); anything else trips the assertion.
    The return value is -1 when the file contains no lines.

    read_graph uses this value as the number of vertices, which matches
    the way it shifts ids down by one when filling the adjacency list.
    """
    largest = -1
    with open(graph_file) as handle:
        for row in handle:
            endpoints = list(map(int, row.split()))
            assert len(endpoints)==2 , "wrong file format"
            for vertex in (endpoints[0], endpoints[1]):
                if vertex > largest:
                    largest = vertex
    return largest

def read_graph(graph_file):
    """Read an undirected edge list and build its adjacency list.

    Each line of ``graph_file`` holds the two endpoint ids of one edge.
    Ids are treated as 1-based: id ``k`` is stored at list index ``k-1``.
    Every edge is recorded in both directions.

    Parameters:
        graph_file: path of the edge-list file.

    Returns:
        (G, n) where ``n`` is the value returned by get_vertice_num (the
        largest id in the file) and ``G`` is a list of ``n`` lists;
        ``G[v]`` holds the 0-based neighbours of 0-based vertex ``v``.

    Raises:
        ValueError: if an endpoint id is smaller than 1. Without this check
            such an id would be shifted to a negative index and silently
            attach the edge to a vertex counted from the end of the list.
    """
    n = get_vertice_num(graph_file)
    G = [[] for _ in range(n)]
    with open(graph_file) as f:
        for line in f:
            # The two-integers-per-line format was already checked by
            # get_vertice_num, so it is not re-validated here.
            edge = [int(i) for i in line.split()]
            u, w = edge[0], edge[1]
            if u < 1 or w < 1:
                raise ValueError(
                    f"vertex ids must be >= 1 (1-based), got edge ({u}, {w})"
                )
            G[u-1].append(w-1)
            G[w-1].append(u-1)
    return G,n

def init_Tau(G,n):
    """Draw a random integer threshold for each of the first ``n`` vertices.

    For vertex ``v`` with degree ``deg = len(G[v])`` the threshold is drawn
    with ``np.random.randint`` from the integers 0 .. deg+1 (inclusive).
    Thresholds are returned as a list indexed by vertex.
    """
    return [np.random.randint(0, len(G[v]) + 2, size=1)[0] for v in range(n)]

def read_Tau(G,file):
    """Load per-vertex thresholds from a text file.

    Every line must hold exactly two whitespace-separated integers; the
    second one is taken as the threshold and the first one is ignored.
    Thresholds are returned in file order, and their count must equal
    ``len(G)``.
    """
    def _second_columns(rows):
        # Parse and validate line by line so a malformed line fails as soon
        # as it is reached.
        for row in rows:
            fields = [int(token) for token in row.split()]
            assert len(fields)==2 , "wrong file format"
            yield fields[1]

    with open(file) as handle:
        thresholds = list(_second_columns(handle))
    assert len(thresholds)==len(G), "number of vertices do not match in threshold list and Graph list"
    return thresholds

def gen_init_config_dataset(G,n,num_samples,P=0.5):
    """Generate a ``num_samples`` x ``n`` array of random 0/1 configurations.

    Each row is one initial configuration of the ``n`` vertices. An entry
    is 1.0 when its uniform draw from ``np.random.rand`` is strictly
    greater than ``P`` and 0.0 otherwise, so a vertex is 1 with
    probability ``1 - P``. ``G`` is not used.
    """
    draws = np.random.rand(num_samples, n)
    return np.where(draws > P, 1.0, 0.0)


def gen_next_config(G,n,s,Tau):
    """Compute the successor configuration of ``s`` under thresholds ``Tau``.

    ``G`` is an adjacency list (``G[v]`` lists the neighbours of ``v``) and
    ``Tau[v]`` is the threshold of vertex ``v``. Vertex ``v`` becomes 1 when
    the sum of ``s`` over its neighbours is at least ``Tau[v]``, else 0.
    The vertex's own state is not part of the sum. Returns a NumPy array
    with one entry per vertex.
    """
    def _fires(v):
        neighbourhood = np.zeros(n, dtype=bool)
        if len(G[v])>0:
            neighbourhood[np.array(G[v])] = True
        # A boolean mask means a neighbour listed twice still counts once.
        return int(np.sum(s, where=neighbourhood) >= Tau[v])

    return np.array([_fires(v) for v in range(n)])

def gen_next_config_dataset(G,n,Tau,S,eta=0):
    """Map every configuration in ``S`` to its (optionally noisy) successor.

    For each row of ``S`` the successor is computed with gen_next_config,
    then each entry is flipped independently when its uniform draw is
    below ``eta``. With ``eta == 0`` the exact successor is returned.
    Per the original note, ``eta`` should be < 1/3 and stay fixed for a
    whole experiment. The result stacks the successors row by row.
    """
    def _noisy_successor(state):
        clean = gen_next_config(G,n,state,Tau)
        coin = np.random.rand(n)
        assert coin.shape[0]==clean.shape[0], "next state shape and flip shape do not match"
        # Keep the bit where the draw is >= eta, otherwise invert it.
        return np.where(coin>=eta, clean, 1-clean)

    return np.vstack([_noisy_successor(state) for state in S])

def gen_mask(G,n):
    """Build one boolean neighbourhood mask per vertex (helper for algorithm A).

    Row ``v`` of the result has length ``n`` and is True exactly at the
    indices listed in ``G[v]``. One row is produced for every entry of
    ``G`` and the rows are stacked into a 2-D array.
    """
    def _row(neighbours):
        row = np.zeros(n, dtype=bool)
        if len(neighbours)>0:
            row[np.array(neighbours)] = True
        return row

    return np.vstack([_row(neighbours) for neighbours in G])
    
def gen_next_config_single_v(MASKS,v,s,tau_v):
    """Return the next state (1 or 0) of vertex ``v`` alone.

    ``MASKS[v]`` selects the neighbours of ``v``; the vertex becomes 1 when
    the sum of ``s`` over that mask reaches ``tau_v``. The vertex's own
    state is not considered.
    """
    active_neighbours = np.sum(s, where=MASKS[v])
    if active_neighbours >= tau_v:
        return 1
    return 0

def evaluate_tau(G,n,test_data,predicted_Tau,Actual_Tau):
    """Return the fraction of test configurations whose successor is mispredicted.

    Noise-free successors of every row of ``test_data`` are computed once
    with ``Actual_Tau`` and once with ``predicted_Tau``. A sample counts as
    an error when the two successors differ in at least one vertex. The
    error count is divided by the number of rows in ``test_data``.
    """
    expected = gen_next_config_dataset(G,n,Actual_Tau,test_data,eta=0.0)
    guessed = gen_next_config_dataset(G,n,predicted_Tau,test_data,eta=0.0)
    wrong = sum(1 for want, got in zip(expected, guessed) if (want!=got).any())
    return float(wrong)/test_data.shape[0]

def score(state,v,MASKS):
    """Score of vertex ``v`` for algorithm 2.

    The score is the sum of ``state`` over the entries selected by the
    boolean mask ``MASKS[v]``.
    """
    neighbour_mask = MASKS[v]
    return np.sum(state, where=neighbour_mask)

def learn_thresh_old(G,n,live_lambda_counter,live_zero_counter,Q, MAX_DEGREE,eta):
    """Older threshold-learning rule for algorithm 2.

    ``live_lambda_counter[v][tau]`` is read as the number of observations of
    vertex ``v`` at score ``tau`` and ``live_zero_counter[v][tau]`` as how
    many of those had next state 0. Scores 0 .. deg(v)+1 are scanned in
    increasing order, skipping any score with fewer than ``tolerance``
    observations:

    * zeros in the majority -> the threshold moves to ``tau + 1``;
    * ones in the majority  -> the scan for this vertex stops;
    * exact tie             -> a coin flip via ``random.randint`` picks one
      of the two outcomes above.

    ``Q``, ``MAX_DEGREE`` and ``eta`` are accepted but not used. Returns the
    list of learned thresholds, one per vertex.
    """
    # NOTE: the original author flagged this value as temporary
    # ("Make sure to change back!!!!").
    tolerance = 5
    min_samples = max(1,tolerance)
    learned = []
    for v in range(n):
        thresh = 0
        for tau in range(len(G[v]) + 2):
            total = live_lambda_counter[v][tau]
            if total < min_samples:
                continue
            zeros = live_zero_counter[v][tau]
            ones = total - zeros
            if zeros > ones:
                thresh = tau+1
            elif zeros < ones:
                break
            elif random.randint(0,1)==0:
                thresh = tau+1
            else:
                break
        learned.append(thresh)
    return learned

def learn_thresh(G,n,live_lambda_counter,live_zero_counter,Q, MAX_DEGREE,eta):
    """Threshold learning for algorithm 2 with a relative-margin test.

    ``live_lambda_counter[v][tau]`` is read as the number of observations of
    vertex ``v`` at score ``tau`` and ``live_zero_counter[v][tau]`` as how
    many of those had next state 0. For every vertex the scores
    0 .. deg(v)+1 are scanned in increasing order. A score only counts as
    evidence when it has more than ``tolerance`` observations and the
    majority/minority gap, relative to the total, is at least
    ``THRESH_DIFF``:

    * zeros in the majority -> the threshold moves to ``tau + 1``;
    * ones in the majority  -> the scan for this vertex stops.

    Scores without observations or without a clear majority are skipped.

    Parameters:
        G: adjacency list; only ``len(G[v])`` is used.
        n: number of vertices to process.
        live_lambda_counter, live_zero_counter: per-vertex, per-score counts.
        Q, MAX_DEGREE, eta: accepted for interface compatibility, not used.

    Returns:
        List with one learned threshold per vertex.
    """
    cur_tau_list = []
    tolerance = 1
    # Alternative margin kept by the original author: 1.0 - 2.0*eta
    THRESH_DIFF = 0.15
    # Clamp the margin: negative values fall back to 0.1, values above 1 are
    # capped at 1. The original stored this in a misspelled name
    # (THRSH_DIFF), so the clamp never reached the comparison below.
    THRESH_DIFF = (0.1 if THRESH_DIFF<0 else (1 if THRESH_DIFF>1 else THRESH_DIFF))
    for v in range(0,n):
        thresh = 0
        deg = len(G[v])
        for tau in range(0,deg+2):
            tot_score = live_lambda_counter[v][tau]
            a_score = live_zero_counter[v][tau]   # observations with next state 0
            b_score = tot_score - a_score         # observations with next state 1
            if tot_score==0:
                continue
            diff = abs(a_score-b_score)/tot_score
            if diff>=THRESH_DIFF and tot_score>tolerance:
                if a_score > b_score:
                    thresh = tau+1
                else:
                    break
        cur_tau_list.append(thresh)
    return cur_tau_list


def algorithm_B_efficient(init_state_data,next_state_data,G,n,test_data,Actual_Tau,N,D,BIAS,NOISE,eval_every=50,save_every=1000,check_vertex=-1):
    """Run algorithm B over the training pairs and track its test error.

    For every (initial state, next state) pair and every vertex ``v`` the
    score (sum of the initial state over the neighbours of ``v``) is
    computed. Two per-vertex, per-score tallies are updated: how often the
    score was seen, and how often the vertex's next state was 0 at that
    score.

    After every ``eval_every`` samples thresholds are re-learned with
    learn_thresh and evaluated with evaluate_tau against ``Actual_Tau`` on
    ``test_data``; the result is appended to the history. After every
    ``save_every`` samples the history so far is written to a CSV file
    under ``OUTPUT_DIR`` (not defined in this file's visible code;
    presumably supplied by the star import, to be confirmed).

    ``N``, ``D``, ``BIAS`` and ``NOISE`` label the CSV file name and
    columns; ``NOISE`` is also forwarded to learn_thresh. ``check_vertex``
    is accepted but not used.

    Returns the history: a list of dicts with keys 'train-size' and
    'err_prob'.
    """
    seen_by_score = [[0] * (len(nbrs) + 2) for nbrs in G]
    zeros_by_score = [[0] * (len(nbrs) + 2) for nbrs in G]
    history = []
    MAX_DEGREE = max(len(nbrs) for nbrs in G)
    MASKS = gen_mask(G,n)
    sample_pairs = zip(init_state_data, next_state_data)
    for train_data_seen, (init_state, next_state) in enumerate(sample_pairs, start=1):
        for v in range(n):
            active = int(np.sum(init_state, where=MASKS[v]))
            seen_by_score[v][active] += 1
            if next_state[v]==0:
                zeros_by_score[v][active] += 1

        if train_data_seen % eval_every == 0:
            print("evaluated",train_data_seen,'training samples',flush=True)
            cur_tau_list = learn_thresh(G,n,seen_by_score,zeros_by_score,train_data_seen,MAX_DEGREE,NOISE)
            err_p = evaluate_tau(G,n,test_data,cur_tau_list,Actual_Tau)
            history.append({'train-size':train_data_seen,'err_prob':err_p})

        if train_data_seen % save_every == 0:
            # Periodic checkpoint of the history gathered so far.
            out_file_name = f'results-G-{N}-{D}-{BIAS}-{NOISE}-algorithm-B.csv'
            df = pd.DataFrame.from_dict(history)
            for column, value in (('Avg_Deg', D), ('Noise', NOISE), ('Bias', BIAS), ('N', N)):
                df[column] = value
            df.to_csv(OUTPUT_DIR+out_file_name,index=False)
    return history



# ---------------------------------------------------------------------------
# Driver script. Command-line arguments (positional):
#   argv[1] = N      (parsed via float first, so values like "1e3" are accepted)
#   argv[2] = D      (integer)
#   argv[3] = NOISE  (float, passed as eta when generating training successors)
# ---------------------------------------------------------------------------
# Parameters
N = int(float(sys.argv[1]))
D = int(sys.argv[2])
graph_file_name = f'Gnp/G-{N}-deg-{D}.edges'
thresh_file_name = f'Gnp/G-{N}-deg-{D}.thresholds'
# BIAS is fixed at 0.5; it is passed as P to gen_init_config_dataset.
BIAS = 0.5#float(sys.argv[3])
NOISE = float(sys.argv[3])
#train_data_size = 20000
# Use at least 5000 training samples, or 2*N when that is larger.
train_data_size = max(5000,N*2)

## Read graph and true thresholds from file
# SYNTH_GRAP_PATH is not defined in this file's visible code; presumably it
# comes from the star import of file_paths_and_consts -- TODO confirm.
graph_file = SYNTH_GRAP_PATH+graph_file_name
threshold_file = SYNTH_GRAP_PATH+thresh_file_name
G,n = read_graph(graph_file)
Actual_Tau = read_Tau(G,threshold_file)

## Generate training data: random initial states and their noisy successors
train_size = train_data_size
train_data = gen_init_config_dataset(G,n,train_size,P=BIAS)
next_state_data = gen_next_config_dataset(G,n,Actual_Tau,train_data,eta=NOISE)

## Generate test data
# train_size is at least 5000, so this always evaluates to 1000.
test_size = min(1000,train_size//5)
test_data = gen_init_config_dataset(G,n,test_size,P=BIAS)
#print(test_data.shape)

# Run the algorithm and time it (wall-clock seconds)
start = time.time()
history = algorithm_B_efficient(train_data,next_state_data,G,n,test_data,Actual_Tau,N,D,BIAS,NOISE,eval_every=50,check_vertex=-1)
end = time.time()

# Write the final history to CSV. The file name is the same as the one used
# for the periodic saves inside algorithm_B_efficient, so this overwrites
# the last checkpoint and adds the 'time' column.
out_file_name = f'results-G-{N}-{D}-{BIAS}-{NOISE}-algorithm-B.csv'
df = pd.DataFrame.from_dict(history)
df['Avg_Deg'] = D
df['Noise'] = NOISE
df['Bias'] = BIAS
df['N'] = N
df['time'] = round(end-start,2)
# OUTPUT_DIR is likewise not defined in the visible code -- presumably from
# file_paths_and_consts.
df.to_csv(OUTPUT_DIR+out_file_name,index=False)
print(N,D,BIAS,NOISE,end-start)
