# -*- coding: utf-8 -*-

import numpy as np
from scipy.integrate import solve_ivp
import matplotlib.pyplot as plt
from scipy.stats import beta
import copy

# Initial beliefs about strategy H.
# Each agent's initial belief is drawn from a Beta distribution:
#   population 1 (row):    Beta(alpha_row, beta_row)
#   population 2 (column): Beta(alpha_col, beta_col)
# Both populations use the same prior here.
alpha_row = 2 * 140
beta_row = 2 * 60
alpha_col, beta_col = alpha_row, beta_row

# Simulation parameters.
num_agents = 1000  # agents per population
time_steps = 200  # number of recorded steps per run
num_sim = 100  # number of independent runs
sigma = 10  # initial sum of weights (lambda in the main paper)
# NOTE(review): this module-level `beta` rebinds the name imported by
# `from scipy.stats import beta` at the top of the file, so scipy's Beta
# distribution object is no longer reachable under that name. cal_pol reads
# this scalar, so rename both together if the scipy object is ever needed.
beta = 10  # scale applied to the whole logit in cal_pol
A = -3  # coefficient on the belief inside cal_pol's logit
B = 2  # constant term inside cal_pol's logit
bimodal = False  # the initial-settings code only handles the False case


def cal_pol(belief):
    """Map belief(s) about H to the probability of playing H.

    Logistic policy computed from the module-level constants ``A``, ``B``
    and ``beta``:

        policy = 1 / (1 + exp(A * beta * belief + B * beta))

    ``np.exp`` is applied element-wise, so *belief* may be a scalar or a
    NumPy array of beliefs.
    """
    logit = A * beta * belief + B * beta
    return 1 / (1 + np.exp(logit))

def beta_Stats(alpha, beta):
    """Return ``[mean, variance]`` of a Beta(alpha, beta) distribution.

    mean     = alpha / (alpha + beta)
    variance = alpha * beta / ((alpha + beta)**2 * (alpha + beta + 1))
    """
    total = alpha + beta
    return [alpha / total, alpha * beta / (total**2 * (total + 1))]

def agent_based_simulations(init_row=None, init_col=None):
    """Run one agent-based simulation of the two-population belief dynamics.

    At every step each population's policy is computed from its beliefs with
    ``cal_pol``. Every agent then updates its belief toward the *mean* policy
    of the other population with a weighted running average:

        b[t+1] = (b[t] * (sigma + t) + mean_policy_other[t]) / (sigma + t + 1)

    Parameters
    ----------
    init_row, init_col : array-like, optional
        Initial beliefs of population 1 (row) and population 2 (column).
        They are assigned into a ``(num_agents, time_steps)`` array, so they
        must have ``num_agents`` entries (or be broadcastable to it). They
        default to the module-level ``init_beliefs_row`` /
        ``init_beliefs_col``.

    Returns
    -------
    mean_row, mean_col : ndarray, shape (1, time_steps)
        Mean policy (probability of playing H) of each population at every
        step, including the final one.
    mean_belief_row, mean_belief_col : ndarray, shape (time_steps,)
        Mean belief of each population at every step.
    """
    if init_row is None:
        init_row = init_beliefs_row
    if init_col is None:
        init_col = init_beliefs_col

    beliefs_row = np.zeros((num_agents, time_steps))
    beliefs_col = np.zeros((num_agents, time_steps))
    mean_row = np.zeros((1, time_steps))
    mean_col = np.zeros((1, time_steps))
    beliefs_row[:, 0] = init_row
    beliefs_col[:, 0] = init_col

    for t in range(time_steps):
        # Mean probability of playing H in each population at step t.
        mean_pol_row = np.mean(cal_pol(beliefs_row[:, t]))
        mean_pol_col = np.mean(cal_pol(beliefs_col[:, t]))
        mean_row[0, t] = mean_pol_row
        mean_col[0, t] = mean_pol_col
        if t == time_steps - 1:
            # Record the final step's policy too; there is no t + 1 belief
            # to compute, so stop here instead of leaving this column at 0.
            break
        print(t)
        # Past belief carries weight sigma + t, the other population's
        # current mean policy carries weight 1.
        beliefs_row[:, t + 1] = (beliefs_row[:, t] * (sigma + t) + mean_pol_col) / (sigma + t + 1)
        beliefs_col[:, t + 1] = (beliefs_col[:, t] * (sigma + t) + mean_pol_row) / (sigma + t + 1)

    mean_belief_row = np.mean(beliefs_row, axis=0)
    mean_belief_col = np.mean(beliefs_col, axis=0)
    print("agent-based simulations end")

    return mean_row, mean_col, mean_belief_row, mean_belief_col


# Initial settings: draw each agent's initial belief from the Beta priors.
if not bimodal:
    init_beliefs_row = np.random.beta(alpha_row, beta_row, num_agents)
    init_beliefs_col = np.random.beta(alpha_col, beta_col, num_agents)
    # Theoretical mean/variance of the priors. They are not used in the
    # visible part of this file; presumably kept for reference.
    [init_mean_bel_row, init_var_bel_row] = beta_Stats(alpha_row, beta_row)
    [init_mean_bel_col, init_var_bel_col] = beta_Stats(alpha_col, beta_col)
else:
    # Without an explicit error, init_beliefs_row/col would be undefined and
    # the simulation would fail later with an unhelpful NameError.
    raise NotImplementedError(
        "bimodal initial beliefs are not implemented; set bimodal = False"
    )


# Run agent-based simulations: row i of each array holds the results of run i.
sim_mean_belief_row, sim_mean_policy_row, sim_mean_belief_col, sim_mean_policy_col = (
    np.zeros((num_sim, time_steps)) for _ in range(4)
)
for i in range(num_sim):
    print(i)
    # sim_results stays bound after the loop (the plot below uses the last run).
    sim_results = agent_based_simulations()
    (
        sim_mean_policy_row[i, :],
        sim_mean_policy_col[i, :],
        sim_mean_belief_row[i, :],
        sim_mean_belief_col[i, :],
    ) = sim_results

# Plot the mean beliefs of the last simulated run (sim_results holds the
# output of the final agent_based_simulations() call above).
plt.figure()
# One x value per recorded step, t = 0 .. time_steps - 1. The previous
# linspace(0, time_steps, time_steps) stretched the axis to non-integer t.
t_space = np.arange(time_steps)
# Policy means have shape (1, time_steps); ravel() them before plotting.
#plt.plot(t_space, sim_results[0].ravel(), linewidth = 1, label = 'Mean Prob. Playing H (Population 1)')
#plt.plot(t_space, sim_results[1].ravel(), linewidth = 1, label = 'Mean Prob. Playing H (Population 2)')
plt.plot(t_space, sim_results[2], linewidth=1, label='Mean Belief about Playing H (Population 1)')
plt.plot(t_space, sim_results[3], linewidth=1, label='Mean Belief about Playing H (Population 2)')
plt.title('Simulation Results')
plt.xlabel('Time t')
plt.legend()
plt.show()



# Save simulation results.
def _write_rows(path, rows):
    """Write each row of *rows* to *path* as one comma-separated line.

    Every value is written as ``str(value) + ', '`` (so each line ends with a
    trailing ``', '``). Each row ends with a newline, and one extra newline
    follows the last row. The file is closed even if a write fails.
    """
    with open(path, "w", encoding="utf-8") as fh:
        for row in rows:
            fh.write("".join(str(value) + ", " for value in row) + "\n")
        fh.write("\n")


for _path, _rows in (
    ("sim_mean_belief_row.txt", sim_mean_belief_row),
    ("sim_mean_policy_row.txt", sim_mean_policy_row),
    ("sim_mean_belief_col.txt", sim_mean_belief_col),
    ("sim_mean_policy_col.txt", sim_mean_policy_col),
):
    _write_rows(_path, _rows)




