import numpy as np
import random
from file_paths_and_consts import *
import sys
import pandas as pd
import json
import os
import time
# Seed both global RNGs (stdlib `random` and NumPy's legacy global state)
# so repeated runs of the experiment produce reproducible results.
SEED = 56
random.seed(SEED)
np.random.seed(SEED)

def append_record(record):
    """Append ``record`` as one JSON line to ``results-v2.json`` in the CWD.

    The file is opened in append mode, so successive calls build a
    JSON-Lines file (one JSON document per line).
    """
    with open('results-v2.json', 'a', encoding='utf-8') as f:
        json.dump(record, f)
        # Text mode already translates '\n' to the platform line ending;
        # writing os.linesep here would produce '\r\r\n' on Windows.
        f.write('\n')

def get_vertice_num(graph_file):
    """Return the number of vertices in an edge-list file.

    Each non-blank line must hold exactly two integer vertex ids ``u v``.
    Vertex ids in the file are 1-based (``read_graph`` stores vertex ``k``
    at index ``k - 1``), so the largest id seen equals the vertex count.
    Blank lines are ignored. A file with no edges yields 0.

    Raises:
        ValueError: if a non-blank line does not contain exactly two ints.
    """
    mx = 0
    with open(graph_file) as f:
        for lineno, line in enumerate(f, start=1):
            fields = line.split()
            if not fields:
                continue  # tolerate blank / trailing empty lines
            if len(fields) != 2:
                # Raise instead of assert: asserts vanish under `python -O`.
                raise ValueError(f"wrong file format at line {lineno}: {line!r}")
            u, v = (int(x) for x in fields)
            mx = max(mx, u, v)
    return mx

def read_graph(graph_file):
    """Read an undirected edge-list file into an adjacency list.

    Vertex ids in the file are 1-based; vertex ``k`` is stored at index
    ``k - 1``. Each edge ``u v`` is added in both directions, so a self
    loop ``v v`` puts ``v - 1`` into its own neighbour list twice (a
    ``'self loop'`` message is printed for it). Blank lines are skipped.

    Returns:
        (G, n): ``G[i]`` is the list of 0-based neighbours of vertex ``i``
        and ``n`` is the vertex count reported by ``get_vertice_num``.

    Raises:
        ValueError: if a vertex id is < 1 (e.g. a 0-based file), which
            would otherwise silently wrap around to ``G[-1]``.
    """
    n = get_vertice_num(graph_file)
    G = [[] for _ in range(n)]
    with open(graph_file) as f:
        for line in f:
            fields = line.split()
            if not fields:
                continue  # skip blank lines
            # Line format (two ints) is validated by get_vertice_num above.
            u, v = (int(x) for x in fields)
            if u < 1 or v < 1:
                raise ValueError(f"vertex ids must be >= 1 (1-based), got edge {u} {v}")
            if u == v:
                print('self loop')
            G[u - 1].append(v - 1)
            G[v - 1].append(u - 1)
    return G, n

def init_Tau(G,n):
    """Draw a random threshold for each of the first ``n`` vertices.

    Vertex ``v`` with degree ``d = len(G[v])`` gets a threshold drawn
    uniformly from ``{0, 1, ..., d + 1}`` with NumPy's global RNG, one
    draw per vertex in index order. Returns a list of length ``n``.
    """
    return [np.random.randint(0, len(G[v]) + 2, size=1)[0] for v in range(n)]

def read_Tau(G,file):
    """Read per-vertex thresholds from a two-column integer text file.

    Every line must contain exactly two integers; only the second column
    is kept, in file order (the first column is ignored).
    NOTE(review): this assumes the file lists vertices in index order —
    confirm the threshold files are sorted by vertex id.

    Raises AssertionError on a malformed line, or when the number of
    thresholds read differs from ``len(G)``.
    """
    def _second_column(lines):
        # Lazy so each line is validated before the next one is parsed.
        for raw in lines:
            row = [int(tok) for tok in raw.split()]
            assert len(row)==2 , "wrong file format"
            yield row[1]

    with open(file) as handle:
        thresholds = list(_second_column(handle))
    assert len(thresholds)==len(G), "number of vertices do not match in threshold list and Graph list"
    return thresholds

def gen_init_config_dataset(G,n,num_samples,P=0.5):
    """Sample ``num_samples`` random initial 0/1 states over ``n`` vertices.

    Returns a float array of shape ``(num_samples, n)`` holding 1.0 where a
    uniform draw from NumPy's global RNG exceeds ``P`` and 0.0 otherwise,
    so each entry is 1 with probability ``1 - P``. ``G`` is not used.
    """
    draws = np.random.rand(num_samples, n)
    return np.where(draws > P, 1.0, 0.0)

def gen_next_config(G,n,s,Tau):
    """Advance state ``s`` by one synchronous threshold-dynamics step.

    ``G`` is an adjacency list (``G[v]`` lists v's neighbours) and ``Tau``
    a threshold list (``Tau[v]`` is v's threshold). Vertex ``v`` becomes 1
    when the sum of ``s`` over its distinct neighbours is at least
    ``Tau[v]``, else 0. Returns a NumPy array of length ``n``.
    """
    def _next_value(v):
        # Boolean mask: duplicate neighbour entries count only once.
        neighbourhood = np.zeros(n, dtype=bool)
        if len(G[v]) > 0:
            neighbourhood[np.array(G[v])] = True
        return 1 if np.sum(s, where=neighbourhood) >= Tau[v] else 0

    return np.array([_next_value(v) for v in range(n)])

def gen_next_config_dataset(G,n,Tau,S,eta=0):
    """Compute the (optionally noisy) successor of every state in ``S``.

    Each row of ``S`` is advanced one step with ``gen_next_config``; then
    every entry is independently flipped (``x -> 1 - x``) when a uniform
    draw from NumPy's global RNG is below ``eta``. With ``eta == 0`` the
    clean successor is returned. ``eta`` is expected to be < 1/3 and fixed
    throughout a single experiment.

    Returns the successors stacked row-wise with ``np.vstack``.
    """
    noisy_rows = []
    for state in S:
        clean = gen_next_config(G, n, state, Tau)
        coins = np.random.rand(n)
        assert coins.shape[0]==clean.shape[0], "next state shape and flip shape do not match"
        keep = coins >= eta
        noisy_rows.append(np.where(keep, clean, 1 - clean))
    return np.vstack(noisy_rows)

def gen_mask(G,n):
    """Build the boolean neighbour-mask matrix for adjacency list ``G``.

    Returns a ``(len(G), n)`` bool array whose row ``v`` is True exactly at
    the indices listed in ``G[v]``. ``score`` and ``algorithm_C`` use
    row ``v`` to sum a state vector over v's neighbours.
    """
    # Preallocate instead of np.vstack-ing per-row masks: one allocation,
    # and an empty graph yields a (0, n) array instead of raising.
    MASKS = np.zeros((len(G), n), dtype=bool)
    for v, neighbours in enumerate(G):
        if len(neighbours) > 0:
            MASKS[v, np.array(neighbours)] = True
    return MASKS
    
def gen_next_config_single_v(MASKS,v,s,tau_v):
    """Return vertex ``v``'s next state (1 or 0) given threshold ``tau_v``.

    The influence on ``v`` is the sum of ``s`` over the True entries of
    ``MASKS[v]``; the vertex becomes 1 once that influence reaches ``tau_v``.
    """
    influence = np.sum(s, where=MASKS[v])
    if influence >= tau_v:
        return 1
    return 0

def evaluate_tau(G,n,test_data,predicted_Tau,Actual_Tau):
    """Return the fraction of test states whose successor is mispredicted.

    Every row of ``test_data`` is advanced one noise-free step under both
    ``Actual_Tau`` and ``predicted_Tau``; a row counts as an error when the
    two successor states differ at one or more vertices.
    """
    truth = gen_next_config_dataset(G,n,Actual_Tau,test_data,eta=0.0)
    guess = gen_next_config_dataset(G,n,predicted_Tau,test_data,eta=0.0)
    wrong_rows = np.any(truth != guess, axis=1)
    return np.count_nonzero(wrong_rows) / test_data.shape[0]
    

def score(state,v,MASKS):
    """Return the sum of ``state`` over vertex ``v``'s neighbours.

    The neighbours are the True entries of ``MASKS[v]``; for a 0/1 state
    this is the number of active neighbours.
    """
    neighbour_mask = MASKS[v]
    return np.sum(state, where=neighbour_mask)

def learn_thresh(G,n,live_lambda_counter,live_zero_counter,Q, MAX_DEGREE,eta, thresh_diff=0.15):
    """Estimate a threshold for each of the first ``n`` vertices from counters.

    For vertex ``v`` with degree ``d = len(G[v])`` and score ``s`` in
    ``0..d+1``:
      * ``live_lambda_counter[v][s]`` is the number of samples in which
        v's neighbourhood score was ``s``;
      * ``live_zero_counter[v][s]`` is how many of those had next state 0.

    The estimate starts at 0. For every score interval ``[s1, s2]`` with
    more than one sample, let ``a`` be its 0-count and ``b = total - a``
    its 1-count; when ``a > b`` and ``|a - b| / total >= thresh_diff``,
    the estimate is raised to at least ``s1 + 1``.

    Args:
        Q, MAX_DEGREE, eta: unused; kept for call compatibility.
        thresh_diff: minimum relative imbalance for an interval to count
            (default 0.15).

    Returns:
        list of ``n`` integer thresholds.
    """
    tau_list = []
    for v in range(0, n):
        width = len(G[v]) + 2
        lam = live_lambda_counter[v]
        zero = live_zero_counter[v]
        # Prefix sums (cum[k] = sum over scores 0..k-1) turn every interval
        # total into one subtraction: O(d^2) per vertex instead of O(d^3).
        cum_tot = [0] * (width + 1)
        cum_zero = [0] * (width + 1)
        for s in range(width):
            cum_tot[s + 1] = cum_tot[s] + lam[s]
            cum_zero[s + 1] = cum_zero[s] + zero[s]

        thresh = 0
        for s1 in range(width):
            for s2 in range(s1, width):
                tot = cum_tot[s2 + 1] - cum_tot[s1]
                # Require more than one sample; counts are non-negative, so
                # this also skips empty intervals (no division by zero).
                if tot <= 1:
                    continue
                a = cum_zero[s2 + 1] - cum_zero[s1]
                b = tot - a
                if a > b and abs(a - b) / tot >= thresh_diff:
                    # NOTE(review): this update rule was flagged as
                    # "recheck" by the original authors.
                    thresh = max(thresh, s1 + 1)
        tau_list.append(thresh)
    return tau_list

def _save_history_csv(history, N, D, BIAS, NOISE):
    """Write the learning curve collected so far to OUTPUT_DIR (overwrites)."""
    out_file_name = f'results-G-{N}-{D}-{BIAS}-{NOISE}-algorithm-C.csv'
    df = pd.DataFrame.from_dict(history)
    df['Avg_Deg'] = D
    df['Noise'] = NOISE
    df['Bias'] = BIAS
    df['N'] = N
    df.to_csv(OUTPUT_DIR+out_file_name,index=False)


def algorithm_C(init_state_data,next_state_data,G,n,test_data,Actual_Tau,N,D,BIAS,NOISE,eval_every=50,save_every=1000,check_vertex=-1):
    """Learn thresholds online from (state, next_state) pairs, tracking error.

    For each training pair and each vertex ``v``, the neighbourhood score
    ``s`` (sum of the initial state over v's neighbours) is bucketed:
    ``l_score[v][s]`` counts occurrences and ``a_score[v][s]`` counts those
    whose next state for ``v`` was 0. Every ``eval_every`` samples the
    thresholds are re-estimated with ``learn_thresh`` and scored on
    ``test_data`` with ``evaluate_tau``; every ``save_every`` samples the
    history is written as CSV under ``OUTPUT_DIR``.

    ``check_vertex`` is unused (kept for call compatibility).

    Returns:
        (history, tau): ``history`` is a list of
        ``{'train-size', 'err_prob'}`` dicts; ``tau`` is the threshold
        list learned from all the training data.
    """
    # The 1-counts equal l_score - a_score, so they are not stored separately.
    l_score = [[0] * (len(nbrs) + 2) for nbrs in G]
    a_score = [[0] * (len(nbrs) + 2) for nbrs in G]
    train_data_seen = 0
    history = []
    MAX_DEGREE = max(len(nbrs) for nbrs in G)
    MASKS = gen_mask(G,n)

    for init_state,next_state in zip(init_state_data,next_state_data):
        # All neighbourhood scores in one product (row v of MASKS selects v's
        # neighbours); the states are 0/1, so the float sums are exact.
        scores = MASKS @ init_state
        for v in range(0, n):
            s = int(scores[v])
            l_score[v][s] += 1
            if next_state[v]==0:
                a_score[v][s] += 1
        train_data_seen += 1
        if train_data_seen % eval_every == 0:
            print("evaluated",train_data_seen,'training samples',flush=True)
            cur_tau_list = learn_thresh(G,n,l_score,a_score,train_data_seen,MAX_DEGREE,NOISE)
            err_p = evaluate_tau(G,n,test_data,cur_tau_list,Actual_Tau)
            history.append({'train-size':train_data_seen,'err_prob':err_p})
        if train_data_seen % save_every == 0:
            _save_history_csv(history, N, D, BIAS, NOISE)
    return history,learn_thresh(G,n,l_score,a_score,train_data_seen,MAX_DEGREE,NOISE)

if __name__ == "__main__":
    # Run the experiment only when executed as a script, so importing this
    # module to reuse its functions does not start a training run.
    if len(sys.argv) < 4:
        sys.exit(f"usage: {sys.argv[0]} N D NOISE")

    #parameters
    N = int(sys.argv[1])
    D = int(sys.argv[2])
    graph_file_name = f'Gnp/G-{N}-deg-{D}.edges'
    thresh_file_name = f'Gnp/G-{N}-deg-{D}.thresholds'
    BIAS = 0.5  # fixed; argv[3] carries NOISE
    NOISE = float(sys.argv[3])
    train_data_size = max(5000,N*2)

    ## read from file
    graph_file = SYNTH_GRAP_PATH+graph_file_name
    threshold_file = SYNTH_GRAP_PATH+thresh_file_name
    G,n = read_graph(graph_file)
    Actual_Tau = read_Tau(G,threshold_file)

    ## generate training data
    train_size = train_data_size
    train_data = gen_init_config_dataset(G,n,train_size,P=BIAS)
    next_state_data = gen_next_config_dataset(G,n,Actual_Tau,train_data,eta=NOISE)

    ## gen test data
    test_size = min(1000,train_size//5)
    test_data = gen_init_config_dataset(G,n,test_size,P=BIAS)

    ## learning (perf_counter is monotonic, so wall-clock adjustments
    ## cannot distort the measured duration)
    start = time.perf_counter()
    history,cur_tau_list = algorithm_C(train_data,next_state_data,G,n,test_data,Actual_Tau,N,D,BIAS,NOISE,eval_every=50,check_vertex=-1)
    end = time.perf_counter()

    #write to output
    out_file_name = f'results-G-{N}-{D}-{BIAS}-{NOISE}-algorithm-C.csv'
    df = pd.DataFrame.from_dict(history)
    df['Avg_Deg'] = D
    df['Noise'] = NOISE
    df['Bias'] = BIAS
    df['N'] = N
    df['time'] = round(end-start,2)
    df.to_csv(OUTPUT_DIR+out_file_name,index=False)
    print(N,D,BIAS,NOISE,end-start)
