import os
import time

import numpy as np

from env_graph import Graph
from recursion_graph import Recursion
from cost_allocation_graph import Cost_Allocation

def _coalition_values(env, rec_solver, coalitions, agents, jq):
    """Score every coalition with ``recursion_1_c`` under the given ``jq``.

    For each coalition ``W`` in ``coalitions``, the members of ``agents`` that
    are *not* in ``W`` are passed to ``rec_solver.recursion_1_c`` together
    with ``jq``; the result is scored with ``env.expected_performance``.

    Returns:
        dict mapping ``str(W)`` to the expected performance computed for W.
    """
    values = {}
    for W in coalitions:
        W_pr = [a for a in agents if a not in W]
        values[str(W)] = env.expected_performance(rec_solver.recursion_1_c(W_pr, jq))
    return values


def _bounded_coalition_values(env, rec_solver, coalitions, agents):
    """Score every coalition with ``recursion_3_a`` and ``recursion_3_b``.

    For each coalition ``W``, the members of ``agents`` not in ``W`` are
    passed to both recursions (3_a first, then 3_b, as in one shared loop so
    the call order per coalition is fixed) and each result is scored with
    ``env.expected_performance``.

    Returns:
        ``(J_L, J_U)``: two dicts keyed by ``str(W)``, holding the
        ``recursion_3_a`` and ``recursion_3_b`` scores respectively.
    """
    J_L = {}
    J_U = {}
    for W in coalitions:
        W_pr = [a for a in agents if a not in W]
        J_L[str(W)] = env.expected_performance(rec_solver.recursion_3_a(W_pr))
        J_U[str(W)] = env.expected_performance(rec_solver.recursion_3_b(W_pr))
    return J_L, J_U


def _time_summary_lines(est_err_lst, timings):
    """Format the mean/std execution time of each approach per estimation error.

    Args:
        est_err_lst: maximum estimation errors; column ``j`` of every timing
            array corresponds to ``est_err_lst[j]``.
        timings: ``(label, times)`` pairs, ``times`` being a 2-D array whose
            rows are seeds and whose columns are estimation errors.

    Returns:
        One newline-terminated line per (error, approach), grouped by error
        and, within an error, in the order of ``timings``. Mean and standard
        deviation are taken over seeds (column-wise).
    """
    lines = []
    for j, est_err in enumerate(est_err_lst):
        for label, times in timings:
            column = times[:, j]
            lines.append(
                label + ', e_max = ' + str(est_err) + ': '
                + 'mean=' + str(np.mean(column)) + ' std=' + str(np.std(column)) + '\n'
            )
    return lines


def _run_exp2(start):
    """Exp2 (coordination): allocations of every method on the true model, per case."""
    # Number of agents
    n = 4
    # Thresholds
    cases = [1, 7, 9, 10]

    ba_sv = np.zeros(shape=(len(cases), n))  # Shapley Value
    ba_ap = np.zeros(shape=(len(cases), n))  # Average Participation
    ba_bi = np.zeros(shape=(len(cases), n))  # Banzhaf Index
    ba_mc = np.zeros(shape=(len(cases), n))  # Marginal Contribution
    ba_mr = np.zeros(shape=(len(cases), n))  # Max-Efficient Rationality

    print("Entering Exp2 (total time elapsed ", time.process_time() - start, "sec)")

    # Set of agents
    N = list(range(n))

    for c, version in enumerate(cases):
        env = Graph(n_agents=n, version=version)
        rec_solver = Recursion(env)
        ca_solver = Cost_Allocation(env)
        J_tr = _coalition_values(env, rec_solver, ca_solver.powerset(N), N, env.jq_tr)

        # The methods take two value functions (J_L, J_U in Exp3); with the
        # true model the same values are passed for both.
        ba_sv[c] = ca_solver.SV(J_tr, J_tr)
        ba_ap[c] = ca_solver.AP(J_tr, J_tr, ba_sv[c])
        ba_bi[c] = ca_solver.BI(J_tr, J_tr)
        ba_mc[c] = ca_solver.MC(J_tr, J_tr)
        ba_mr[c] = ca_solver.MR(J_tr, J_tr)

    for name, data in (('sv', ba_sv), ('ap', ba_ap), ('bi', ba_bi),
                       ('mc', ba_mc), ('mr', ba_mr)):
        np.savetxt('graph/data/exp2/' + name + '.csv', data, delimiter=',')

    print("Exiting Exp2 (total time elapsed ", time.process_time() - start, "sec)")


def _run_exp3a(start):
    """Exp3a: compare four Shapley-value approaches and record their run times."""
    # Number of agents
    n = 4
    # Seeds
    num_seeds = 10
    assert num_seeds > 1, "number of seeds should be greater than 1"
    seeds = 100 * np.arange(num_seeds)
    # Estimation errors
    est_err_lst = [.01, .05, .1]

    grid = (len(seeds), len(est_err_lst))
    # SV approaches: (seed, estimation error, agent)
    ba_sv_tr = np.zeros(shape=grid + (n,))
    ba_sv_es = np.zeros(shape=grid + (n,))
    ba_sv_va = np.zeros(shape=grid + (n,))
    ba_sv_bc = np.zeros(shape=grid + (n,))
    # Execution time (CPU seconds): (seed, estimation error)
    time_tr = np.zeros(shape=grid)
    time_es = np.zeros(shape=grid)
    time_va = np.zeros(shape=grid)
    time_bc = np.zeros(shape=grid)

    print("Entering Exp3a (total time elapsed ", time.process_time() - start, "sec)")

    for i, seed in enumerate(seeds):
        for j, est_err in enumerate(est_err_lst):

            print(
                "seed ", seed, "maximum estimation error ", est_err,
                "(total time elapsed ", time.process_time() - start, "sec)",
            )

            # Re-seed before building the environment so every approach and
            # error level is evaluated with the same seed-dependent draws.
            np.random.seed(seed)

            env = Graph(n_agents=n, est_err=est_err)
            rec_solver = Recursion(env)
            ca_solver = Cost_Allocation(env)

            N = list(range(env.n_agents))

            # True (timed including powerset construction)
            t0 = time.process_time()
            S = ca_solver.powerset(N)
            J_tr = _coalition_values(env, rec_solver, S, N, env.jq_tr)
            ba_sv_tr[i, j] = ca_solver.SV(J_tr, J_tr)
            time_tr[i, j] = time.process_time() - t0

            # Estimated
            t0 = time.process_time()
            S = ca_solver.powerset(N)
            J_es = _coalition_values(env, rec_solver, S, N, env.jq_es)
            ba_sv_es[i, j] = ca_solver.SV(J_es, J_es)
            time_es[i, j] = time.process_time() - t0

            # Valid
            t0 = time.process_time()
            S = ca_solver.powerset(N)
            _, jq_va = rec_solver.recursion_2_a()
            J_va = _coalition_values(env, rec_solver, S, N, jq_va)
            ba_sv_va[i, j] = ca_solver.SV(J_va, J_va)
            time_va[i, j] = time.process_time() - t0

            # Blackstone consistent
            t0 = time.process_time()
            S = ca_solver.powerset(N)
            J_L, J_U = _bounded_coalition_values(env, rec_solver, S, N)
            ba_sv_bc[i, j] = ca_solver.SV(J_L, J_U)
            time_bc[i, j] = time.process_time() - t0

    # Only agent 0's blame is saved.
    for name, data in (('sv_tr', ba_sv_tr), ('sv_es', ba_sv_es),
                       ('sv_va', ba_sv_va), ('sv_bc', ba_sv_bc)):
        np.savetxt('graph/data/exp3a/' + name + '.csv', data[:, :, 0], delimiter=',')

    # `with` guarantees the file is closed even if formatting/writing fails.
    with open('graph/time/time.txt', 'w') as f:
        f.writelines(_time_summary_lines(est_err_lst, (
            ('targeted value', time_tr),
            ('point estimate', time_es),
            ('valid', time_va),
            ('consistent', time_bc),
        )))

    print("Exiting Exp3a (total time elapsed ", time.process_time() - start, "sec)")


def _run_exp3bc(start):
    """Exp3b (distance to the zero-error allocation) and Exp3c (total blame)."""
    # Number of agents
    n = 4
    # Seeds
    num_seeds = 10
    assert num_seeds > 1, "number of seeds should be greater than 1"
    seeds = 100 * np.arange(num_seeds)
    # Estimation errors, always begin with zero (index 0 is the reference run)
    est_err_lst = [0, .025, .05, .075, .1, .125, .15, .175, .2]
    assert est_err_lst[0] == 0, "begin estimation errors with 0"

    grid = (len(seeds), len(est_err_lst))
    # Allocations: (seed, estimation error, agent)
    ba_sv_bc = np.zeros(shape=grid + (n,))
    ba_ap_bc = np.zeros(shape=grid + (n,))
    ba_bi_bc = np.zeros(shape=grid + (n,))
    ba_mc_bc = np.zeros(shape=grid + (n,))
    ba_mr_bc = np.zeros(shape=grid + (n,))
    # Distance to the zero-error allocation (L1, except MR -- see below)
    dist_sv = np.zeros(shape=grid)
    dist_ap = np.zeros(shape=grid)
    dist_bi = np.zeros(shape=grid)
    dist_mc = np.zeros(shape=grid)
    dist_mr = np.zeros(shape=grid)
    # Total blame
    tot_ba_sv = np.zeros(shape=grid)
    tot_ba_ap = np.zeros(shape=grid)
    tot_ba_bi = np.zeros(shape=grid)
    tot_ba_mc = np.zeros(shape=grid)
    tot_ba_mr = np.zeros(shape=grid)

    print("Entering Exp3b and Exp3c (total time elapsed ", time.process_time() - start, "sec)")

    for i, seed in enumerate(seeds):
        for j, est_err in enumerate(est_err_lst):

            print(
                "seed ", seed, "maximum estimation error ", est_err,
                "(total time elapsed ", time.process_time() - start, "sec)",
            )

            np.random.seed(seed)

            env = Graph(n_agents=n, est_err=est_err)
            rec_solver = Recursion(env)
            ca_solver = Cost_Allocation(env)

            N = list(range(env.n_agents))

            J_L, J_U = _bounded_coalition_values(env, rec_solver, ca_solver.powerset(N), N)
            ba_sv_bc[i, j] = ca_solver.SV(J_L, J_U)
            ba_ap_bc[i, j] = ca_solver.AP(J_L, J_U, ba_sv_bc[i, j])
            ba_bi_bc[i, j] = ca_solver.BI(J_L, J_U)
            ba_mc_bc[i, j] = ca_solver.MC(J_L, J_U)
            ba_mr_bc[i, j] = ca_solver.MR(J_L, J_U)

            # Exp3c: total blame assigned to all agents.
            tot_ba_sv[i, j] = np.sum(ba_sv_bc[i, j])
            tot_ba_ap[i, j] = np.sum(ba_ap_bc[i, j])
            tot_ba_bi[i, j] = np.sum(ba_bi_bc[i, j])
            tot_ba_mc[i, j] = np.sum(ba_mc_bc[i, j])
            tot_ba_mr[i, j] = np.sum(ba_mr_bc[i, j])

            # Exp3b: compare against the zero-error run of the SAME seed
            # (j == 0, already filled because j ascends and est_err_lst[0] == 0),
            # so seed-dependent randomness is not counted as estimation error.
            dist_sv[i, j] = np.sum(np.abs(ba_sv_bc[i, 0] - ba_sv_bc[i, j]))
            dist_ap[i, j] = np.sum(np.abs(ba_ap_bc[i, 0] - ba_ap_bc[i, j]))
            dist_bi[i, j] = np.sum(np.abs(ba_bi_bc[i, 0] - ba_bi_bc[i, j]))
            dist_mc[i, j] = np.sum(np.abs(ba_mc_bc[i, 0] - ba_mc_bc[i, j]))
            # NOTE(review): MR uses the signed difference of total blame rather
            # than the L1 distance used above; kept as-is -- confirm intended.
            dist_mr[i, j] = tot_ba_mr[i, 0] - tot_ba_mr[i, j]

    for name, data in (('sv', dist_sv), ('ap', dist_ap), ('bi', dist_bi),
                       ('mc', dist_mc), ('mr', dist_mr)):
        np.savetxt('graph/data/exp3b/' + name + '.csv', data, delimiter=',')

    for name, data in (('sv', tot_ba_sv), ('ap', tot_ba_ap), ('bi', tot_ba_bi),
                       ('mc', tot_ba_mc), ('mr', tot_ba_mr)):
        np.savetxt('graph/data/exp3c/' + name + '.csv', data, delimiter=',')

    print("Exiting Exp3b and Exp3c (total time elapsed ", time.process_time() - start, "sec)")


def main():
    """Run the graph experiments: Exp2, Exp3a, then Exp3b/Exp3c.

    Results are written as CSV files under ``graph/data/`` and the Exp3a
    timing summary to ``graph/time/time.txt``. Progress messages report the
    CPU time (``time.process_time``) elapsed since the start of the run.
    """
    print("Beginning of Experiment (Graph)")

    start = time.process_time()

    # np.savetxt/open do not create missing folders; create them up front so
    # the run cannot fail at save time after all the computation is done.
    for directory in ('graph/data/exp2', 'graph/data/exp3a', 'graph/data/exp3b',
                      'graph/data/exp3c', 'graph/time'):
        os.makedirs(directory, exist_ok=True)

    _run_exp2(start)
    _run_exp3a(start)
    _run_exp3bc(start)

    print("End of experiment: total time elapsed ", time.process_time() - start)


if __name__ == '__main__':
    # Script entry point: run every graph experiment in sequence.
    # Local usage: python example.py
    main()