import os
import time

import numpy as np

from env_gridworld import Gridworld
from recursion_gridworld import Recursion
from cost_allocation_gridworld import Cost_Allocation

# Dictionary keys naming the four coalitions of the two-agent game. The
# backslashes are part of the key text; the coalition-value dicts built with
# these keys are passed to Cost_Allocation (SV/AP/MC), which presumably looks
# them up by this exact spelling. Raw strings produce the same text as the
# plain literal "\{1\}" without its invalid-escape warning (a SyntaxWarning
# since Python 3.12).
KEY_EMPTY = "empty_set"
KEY_1 = r"\{1\}"
KEY_2 = r"\{2\}"
KEY_12 = r"\{1,2\}"


def ensure_parent_dir(path):
    """Create the directory that will contain `path`, if it is missing."""
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)


def save_csv(path, array):
    """Write `array` to `path` as comma-separated text, creating parent dirs.

    Creating the directory up front avoids losing a long experiment run to a
    missing output folder at save time.
    """
    ensure_parent_dir(path)
    np.savetxt(path, array, delimiter=',')


def coalition_values(env, rec_solver, q_1, q_2):
    """Return the expected performance of each coalition of the two agents.

    Parameters
    ----------
    env : Gridworld
        Environment providing `expected_performance`.
    rec_solver : Recursion
        Solver built on `env`.
    q_1, q_2 :
        Values handed to the recursions for agent 1 and agent 2
        (e.g. `env.q_1_tr`/`env.q_2_tr`, or the output of `recursion_2_a`).

    Returns
    -------
    dict
        KEY_EMPTY -> performance of `recursion_1_a(q_1, q_2)`,
        KEY_1     -> performance of `recursion_1_c_ag1(q_2)`,
        KEY_2     -> performance of `recursion_1_c_ag2(q_1)`,
        KEY_12    -> performance of `recursion_1_b()`.
    """
    # A dict display evaluates its entries left to right, so the recursions
    # run in the order listed here.
    return {
        KEY_EMPTY: env.expected_performance(rec_solver.recursion_1_a(q_1, q_2)),
        KEY_1: env.expected_performance(rec_solver.recursion_1_c_ag1(q_2)),
        KEY_2: env.expected_performance(rec_solver.recursion_1_c_ag2(q_1)),
        KEY_12: env.expected_performance(rec_solver.recursion_1_b()),
    }


def bc_values_agent1(env, rec_solver, q_1, q_2):
    """Coalition values used for agent 1's Blackstone-consistent blame.

    The {1} value is clipped from below by the empty-coalition value and the
    {2} value from above by the grand-coalition value.
    """
    J_0 = env.expected_performance(rec_solver.recursion_1_a(q_1, q_2))
    J_1 = max(J_0, env.expected_performance(rec_solver.recursion_3_a_ag1()))
    J_12 = env.expected_performance(rec_solver.recursion_1_b())
    J_2 = min(J_12, env.expected_performance(rec_solver.recursion_3_b_ag2()))
    return {KEY_EMPTY: J_0, KEY_1: J_1, KEY_2: J_2, KEY_12: J_12}


def bc_values_agent2(env, rec_solver, q_1, q_2):
    """Coalition values used for agent 2's Blackstone-consistent blame.

    Mirror image of `bc_values_agent1`: the {2} value is clipped from below by
    the empty-coalition value and the {1} value from above by the
    grand-coalition value.
    """
    J_0 = env.expected_performance(rec_solver.recursion_1_a(q_1, q_2))
    J_2 = max(J_0, env.expected_performance(rec_solver.recursion_3_a_ag2()))
    J_12 = env.expected_performance(rec_solver.recursion_1_b())
    J_1 = min(J_12, env.expected_performance(rec_solver.recursion_3_b_ag1()))
    return {KEY_EMPTY: J_0, KEY_1: J_1, KEY_2: J_2, KEY_12: J_12}


def mr_constraints(J):
    """Return (C_1, C_2, C_12): each coalition's value minus the empty set's.

    These are the constraint arguments passed to `Cost_Allocation.MR`.
    """
    base = J[KEY_EMPTY]
    return J[KEY_1] - base, J[KEY_2] - base, J[KEY_12] - base


def mr_distance(tot_ba_mr):
    """Drop in total MR blame relative to each seed's zero-error run.

    Rows are seeds and columns are estimation errors, with column 0 holding
    the zero-error run. The result is `tot[i, 0] - tot[i, j]`, so every seed
    is compared with its own baseline rather than with seed 0's.
    """
    tot_ba_mr = np.asarray(tot_ba_mr)
    return tot_ba_mr[:, :1] - tot_ba_mr


def seeded_solvers(seed, alpha, alpha_pr, est_err_1, est_err_2):
    """Reseed NumPy's global RNG, then build the environment and its solvers.

    Reseeding right before constructing Gridworld gives each (seed, error)
    cell the same starting RNG state. Presumably Gridworld draws its
    estimation noise from np.random — TODO confirm.

    Returns
    -------
    (env, rec_solver, ca_solver)
    """
    np.random.seed(seed)
    env = Gridworld(alpha=alpha,
                    alpha_pr=alpha_pr,
                    est_err_1=est_err_1,
                    est_err_2=est_err_2)
    return env, Recursion(env), Cost_Allocation(env)


def run_exp1(start):
    """Run Exp1 (performance monotonicity) and save the blame grids.

    Sweeps alpha and alpha' over 0, .1, ..., 1.0 (an 11 x 11 grid). For each
    pair, blame is allocated to both agents with four methods computed from
    the true coalition values. Results go to gridworld/data/exp1/.

    Parameters
    ----------
    start : float
        `time.process_time()` reading taken at the start of the run; used
        only for progress messages.
    """
    n_steps = 11
    shape = (n_steps, n_steps)
    # sv: Shapley Value (= Banzhaf Index for 2 agents), ap: Average
    # Participation, mc: Marginal Contribution, mr: Max-Efficient Rationality.
    names = ('sv_tr_1', 'sv_tr_2', 'ap_tr_1', 'ap_tr_2',
             'mc_tr_1', 'mc_tr_2', 'mr_tr_1', 'mr_tr_2')
    ba = {name: np.zeros(shape=shape) for name in names}

    print("Entering Exp1 (total time elapsed ", time.process_time() - start, "sec)")

    for i in range(n_steps):
        alpha = i * .1
        for j in range(n_steps):
            alpha_pr = j * .1

            env = Gridworld(alpha=alpha, alpha_pr=alpha_pr)
            rec_solver = Recursion(env)
            ca_solver = Cost_Allocation(env)

            J_tr = coalition_values(env, rec_solver, env.q_1_tr, env.q_2_tr)

            ba['sv_tr_1'][i, j], ba['sv_tr_2'][i, j] = ca_solver.SV(J_tr)
            ba['ap_tr_1'][i, j], ba['ap_tr_2'][i, j] = ca_solver.AP_tr(
                J_tr, ba['sv_tr_1'][i, j], ba['sv_tr_2'][i, j])
            ba['mc_tr_1'][i, j], ba['mc_tr_2'][i, j] = ca_solver.MC(J_tr)
            ba['mr_tr_1'][i, j], ba['mr_tr_2'][i, j] = ca_solver.MR(*mr_constraints(J_tr))

    for name in names:
        save_csv('gridworld/data/exp1/' + name + '.csv', ba[name])

    print("Exiting Exp1 (total time elapsed ", time.process_time() - start, "sec)")


def run_exp3a(start):
    """Run Exp3 part a: compare four ways of computing the Shapley Value.

    For every seed and every maximum estimation error of agent 1, both
    agents' Shapley Values are computed from:
      tr - the true coalition values,
      es - the point-estimate coalition values,
      va - the "valid" values built from `recursion_2_a`,
      bc - the Blackstone-consistent values.
    Each approach is timed with `time.process_time()`. Blame is saved to
    gridworld/data/exp3a/, and the per-approach timing mean/std to
    gridworld/time/time.txt.

    Parameters
    ----------
    start : float
        `time.process_time()` reading taken at the start of the run.
    """
    num_seeds = 10
    assert num_seeds > 1, "number of seeds should be greater than 1"
    seeds = 100 * np.arange(num_seeds)
    alpha = .2
    alpha_pr = .5
    est_err_1_lst = [.05, .1, .15, .2]
    est_err_2 = 0

    shape = (len(seeds), len(est_err_1_lst))
    # Approach code -> label used in the timing report (also the save order).
    approaches = (('tr', 'targeted value'), ('es', 'point estimate'),
                  ('va', 'valid'), ('bc', 'consistent'))
    elapsed = {code: np.zeros(shape=shape) for code, _ in approaches}
    sv = {code + '_' + agent: np.zeros(shape=shape)
          for code, _ in approaches for agent in ('1', '2')}

    print("Entering Exp3a (total time elapsed ", time.process_time() - start, "sec)")

    for i, seed in enumerate(seeds):
        for j, est_err_1 in enumerate(est_err_1_lst):

            print(
                "seed ", seed, "maximum estimation error ", est_err_1,
                "(total time elapsed ", time.process_time() - start, "sec)"
            )

            env, rec_solver, ca_solver = seeded_solvers(
                seed, alpha, alpha_pr, est_err_1, est_err_2)

            # True
            t0 = time.process_time()
            J_tr = coalition_values(env, rec_solver, env.q_1_tr, env.q_2_tr)
            sv['tr_1'][i, j], sv['tr_2'][i, j] = ca_solver.SV(J_tr)
            elapsed['tr'][i, j] = time.process_time() - t0

            # Estimated
            t0 = time.process_time()
            J_es = coalition_values(env, rec_solver, env.q_1_es, env.q_2_es)
            sv['es_1'][i, j], sv['es_2'][i, j] = ca_solver.SV(J_es)
            elapsed['es'][i, j] = time.process_time() - t0

            # Valid
            t0 = time.process_time()
            _, q_1, q_2 = rec_solver.recursion_2_a()
            J_va = coalition_values(env, rec_solver, q_1, q_2)
            sv['va_1'][i, j], sv['va_2'][i, j] = ca_solver.SV(J_va)
            elapsed['va'][i, j] = time.process_time() - t0

            # Blackstone consistent: agent 1's blame comes from J_bc1 and
            # agent 2's from J_bc2.
            # NOTE(review): q_1/q_2 are reused from the "valid" step, so the
            # cost of recursion_2_a is not counted in the bc timing — confirm
            # this is intended.
            t0 = time.process_time()
            J_bc1 = bc_values_agent1(env, rec_solver, q_1, q_2)
            sv['bc_1'][i, j], _ = ca_solver.SV(J_bc1)
            J_bc2 = bc_values_agent2(env, rec_solver, q_1, q_2)
            _, sv['bc_2'][i, j] = ca_solver.SV(J_bc2)
            elapsed['bc'][i, j] = time.process_time() - t0

    for key, values in sv.items():
        save_csv('gridworld/data/exp3a/sv_' + key + '.csv', values)

    # Context manager guarantees the report file is closed even if a write fails.
    time_path = 'gridworld/time/time.txt'
    ensure_parent_dir(time_path)
    with open(time_path, 'w') as f:
        for j, est_err_1 in enumerate(est_err_1_lst):
            for code, label in approaches:
                column = elapsed[code][:, j]
                f.write(
                    label + ', e_max = ' + str(est_err_1) + ': '
                    + 'mean=' + str(np.mean(column)) + ' std=' + str(np.std(column)) + '\n'
                )

    print("Exiting Exp3a (total time elapsed ", time.process_time() - start, "sec)")


def run_exp3bc(start):
    """Run Exp3 parts b and c: distance to the true blame and total blame.

    For every seed and estimation error, blame is computed with four methods
    from both the true coalition values (tr) and the Blackstone-consistent
    values (bc). Results are saved as follows:
      exp3b: for sv/ap/mc, the L1 distance between tr and bc blame; for mr,
             the drop in total bc blame relative to the same seed's
             zero-error run (see `mr_distance`).
      exp3c: total bc blame (agent 1 + agent 2) for every method.

    Parameters
    ----------
    start : float
        `time.process_time()` reading taken at the start of the run.
    """
    num_seeds = 10
    assert num_seeds > 1, "number of seeds should be greater than 1"
    seeds = 100 * np.arange(num_seeds)
    alpha = .2
    alpha_pr = .5
    # Estimation errors always begin with zero: column 0 is the baseline.
    est_err_1_lst = [0, .05, .1, .15, .2, .25, .3, .35, .4]
    assert est_err_1_lst[0] == 0, "begin estimation errors with 0"
    est_err_2 = 0

    shape = (len(seeds), len(est_err_1_lst))
    methods = ('sv', 'ap', 'mc', 'mr')
    # Per-cell blame, keyed '<method>_<tr|bc>_<agent>'. Values are stored in
    # float64 arrays and read back before being passed on to the solvers.
    ba = {method + '_' + source + '_' + agent: np.zeros(shape=shape)
          for method in methods for source in ('tr', 'bc') for agent in ('1', '2')}
    dist = {method: np.zeros(shape=shape) for method in methods}
    tot = {method: np.zeros(shape=shape) for method in methods}

    print("Entering Exp3b and Exp3c (total time elapsed ", time.process_time() - start, "sec)")

    for i, seed in enumerate(seeds):
        for j, est_err_1 in enumerate(est_err_1_lst):

            print(
                "seed ", seed, "maximum estimation error ", est_err_1,
                "(total time elapsed ", time.process_time() - start, "sec)"
            )

            env, rec_solver, ca_solver = seeded_solvers(
                seed, alpha, alpha_pr, est_err_1, est_err_2)

            # Shapley Value (= Banzhaf Index)
            J_tr = coalition_values(env, rec_solver, env.q_1_tr, env.q_2_tr)
            ba['sv_tr_1'][i, j], ba['sv_tr_2'][i, j] = ca_solver.SV(J_tr)

            _, q_1, q_2 = rec_solver.recursion_2_a()
            J_bc1 = bc_values_agent1(env, rec_solver, q_1, q_2)
            ba['sv_bc_1'][i, j], _ = ca_solver.SV(J_bc1)
            J_bc2 = bc_values_agent2(env, rec_solver, q_1, q_2)
            _, ba['sv_bc_2'][i, j] = ca_solver.SV(J_bc2)

            # Average Participation
            ba['ap_tr_1'][i, j], ba['ap_tr_2'][i, j] = ca_solver.AP_tr(
                J_tr, ba['sv_tr_1'][i, j], ba['sv_tr_2'][i, j])
            ba['ap_bc_1'][i, j], _ = ca_solver.AP_bc(
                J_bc1, ba['sv_bc_1'][i, j], ba['sv_bc_2'][i, j])
            _, ba['ap_bc_2'][i, j] = ca_solver.AP_bc(
                J_bc2, ba['sv_bc_1'][i, j], ba['sv_bc_2'][i, j])

            # Marginal Contribution
            ba['mc_tr_1'][i, j], ba['mc_tr_2'][i, j] = ca_solver.MC(J_tr)
            ba['mc_bc_1'][i, j], _ = ca_solver.MC(J_bc1)
            _, ba['mc_bc_2'][i, j] = ca_solver.MC(J_bc2)

            # Max-Efficiency Rationality. For bc, each agent's own constraint
            # comes from that agent's consistent value dict.
            ba['mr_tr_1'][i, j], ba['mr_tr_2'][i, j] = ca_solver.MR(*mr_constraints(J_tr))
            C_1 = J_bc1[KEY_1] - J_bc1[KEY_EMPTY]
            C_2 = J_bc2[KEY_2] - J_bc2[KEY_EMPTY]
            C_12 = J_bc1[KEY_12] - J_bc1[KEY_EMPTY]
            ba['mr_bc_1'][i, j], ba['mr_bc_2'][i, j] = ca_solver.MR(C_1, C_2, C_12)

            # L1 distance (tr vs bc) and total bc blame.
            for m in ('sv', 'ap', 'mc'):
                tr_1, tr_2 = ba[m + '_tr_1'][i, j], ba[m + '_tr_2'][i, j]
                bc_1, bc_2 = ba[m + '_bc_1'][i, j], ba[m + '_bc_2'][i, j]
                dist[m][i, j] = abs(tr_1 - bc_1) + abs(tr_2 - bc_2)
                tot[m][i, j] = bc_1 + bc_2
            tot['mr'][i, j] = ba['mr_bc_1'][i, j] + ba['mr_bc_2'][i, j]

    # Baseline is each seed's own zero-error run (column 0), not seed 0's.
    dist['mr'] = mr_distance(tot['mr'])

    for m in methods:
        save_csv('gridworld/data/exp3b/' + m + '.csv', dist[m])
    for m in methods:
        save_csv('gridworld/data/exp3c/' + m + '.csv', tot[m])

    print("Exiting Exp3b and Exp3c (total time elapsed ", time.process_time() - start, "sec)")


def main():
    """Run all Gridworld experiments (Exp1, Exp3a, Exp3b/c) in sequence.

    Progress messages report CPU time (`time.process_time`) elapsed since the
    run began; results are written under gridworld/.
    """
    print("Beginning of Experiment (Gridworld)")

    start = time.process_time()

    run_exp1(start)
    run_exp3a(start)
    run_exp3bc(start)

    print("End of experiment: total time elapsed ", time.process_time() - start)

if __name__ == "__main__":
    # Script entry point: runs every experiment, e.g. locally via
    #   python example.py
    main()