"""
This script runs Granger causality (cfc_gc from timeawarepc) on the river runoff data and reports TPR, FPR, their difference and the run time against a reference graph
"""

import networkx as nx
import numpy as np
from numpy.random import default_rng
rng = default_rng(seed=111)
from timeawarepc.tpc import cfc_tpc, cfc_pc
from timeawarepc.gc import cfc_gc
from timeawarepc.simulate_data import *
from timeawarepc.find_cfc import *
import time
import nitime.analysis as nta
import nitime.timeseries as ts


def get_confusion_matrix(true_graph, estimated_graph):
    """Tally edge-level agreement between two adjacency matrices.

    Every ordered pair ``(i, j)`` with ``i != j`` and both indices below
    ``true_graph.shape[0]`` is inspected; diagonal entries are skipped.
    Entries are compared against the literal values 1 and 0, so a pair in
    which either matrix holds any other value is not counted anywhere.

    Parameters
    ----------
    true_graph
        Reference matrix; must expose ``.shape`` and ``[i, j]`` indexing.
    estimated_graph
        Matrix to score; indexed with the same ``[i, j]`` positions.

    Returns
    -------
    tuple
        ``(TP, FP, TN, FN)`` counts, in that order.
    """
    size = true_graph.shape[0]
    hits = false_alarms = rejections = misses = 0

    for row in range(size):
        for col in range(size):
            # Self-loops are excluded from the evaluation.
            if row == col:
                continue

            actual = true_graph[row, col]
            predicted = estimated_graph[row, col]

            if actual == 1:
                if predicted == 1:
                    hits += 1
                elif predicted == 0:
                    misses += 1
            elif actual == 0:
                if predicted == 1:
                    false_alarms += 1
                elif predicted == 0:
                    rejections += 1

    return hits, false_alarms, rejections, misses


# Load the space-delimited river-runoff series. The path is resolved relative
# to the current working directory. Forward slashes are used because they are
# accepted on both Windows and POSIX, whereas backslashes are treated as
# literal filename characters on POSIX and the file would not be found.
# NOTE(review): the file name suggests 12 variables x 4600 samples - confirm
# the row/column orientation expected by cfc_gc.
data = np.loadtxt('./river-runoff/river-runoff_N-12_T-4600/river-runoff_N-12_T-4600_0001.txt', delimiter=' ')


# Reference 12x12 0/1 matrix, written down by the coordinates of its 13
# non-zero entries; all remaining entries (including the diagonal) are 0.
# NOTE(review): which index is the source and which the target of an edge is
# not visible here - confirm against the dataset description.
true_ess_graph_t = np.zeros((12, 12), dtype=int)
true_ess_graph_t[
    # row index of each non-zero entry
    [0, 1, 3, 3, 4, 5, 9, 9, 9, 9, 9, 10, 10],
    # matching column index
    [1, 0, 1, 2, 5, 4, 3, 5, 6, 7, 10, 8, 11],
] = 1

# The evaluation below uses the transposed orientation.
true_ess_graph = np.transpose(true_ess_graph_t)
# Sweep the alpha value passed to cfc_gc and keep the one with the highest
# score CS = TPR - FPR measured against true_ess_graph.
# The running best starts at -inf (and alpha_best at None) so that the best
# *evaluated* alpha is always reported, even when every score is <= 0;
# starting from 0 would report "Best alpha = 0", a value that was never run.
CS_prev = -np.inf
alpha_best = None
alpha_arr = np.array([0.112])  # full sweep: np.arange(0.001, 0.5, 0.001)
for alpha_num in alpha_arr:
    # perf_counter is monotonic, so the measured interval is not affected by
    # system clock adjustments.
    start_time = time.perf_counter()
    adjmat = cfc_gc(data, maxdelay=1, alpha=alpha_num)
    end_time = time.perf_counter()
    elapsed_time = end_time - start_time

    # Optional: np.save("adjmat_granger_causality.npy", adjmat) to persist it.

    # NOTE(review): adjmat[0] is presumably the 0/1 adjacency matrix among
    # cfc_gc's outputs - confirm against the timeawarepc documentation.
    TP, FP, TN, FN = get_confusion_matrix(true_ess_graph, adjmat[0])
    TPR = TP/(TP+FN)
    FPR = FP/(TN+FP)
    CS = TPR - FPR

    print("--------------------------------")
    print(f"alpha = {alpha_num}")
    print(f"TPR = {TPR}")
    print(f"FPR = {FPR}")
    print(f"CS = {CS}")
    print(f"Total algorithm run time = {elapsed_time}")

    # Strict comparison: on ties the earliest alpha in the sweep is kept.
    if CS > CS_prev:
        alpha_best = alpha_num
        CS_prev = CS

print(f"Best alpha = {alpha_best}")