from normal_form_game_agents import *
from normal_form_games import *
import matplotlib.pyplot as plt
import numpy as np


def initialize_game_setting(current_eta, game, manipulated_utilities, agent_type='MW'):
    """Build the agent pairs for the manipulated game and for the true game.

    Args:
        current_eta: Value passed as ``eta`` to every agent constructor.
        game: The true game. Each truthful agent receives its own deep copy.
        manipulated_utilities: Eight payoffs ``(A1, A2, B1, B2, C1, C2, D1, D2)``
            of the manipulated 2x2 game, in the order taken by ``create_game``.
        agent_type: ``'MW'`` selects ``MultiplicativeWeightsAgent``,
            ``'FTPL'`` selects ``FTPLAgent``.

    Returns:
        ``(game, agents_manipulated, agents_truthful)``. ``game`` is the
        argument, returned unchanged; each list holds two agents constructed
        with first argument 0 and 1 (presumably the player index).

    Raises:
        ValueError: If ``agent_type`` is neither ``'MW'`` nor ``'FTPL'``.
    """
    # Fail fast on an unknown agent type. Previously such a value skipped both
    # branches and only surfaced as an UnboundLocalError at the return.
    if agent_type not in ('MW', 'FTPL'):
        raise ValueError(
            "Unsupported agent_type %r; expected 'MW' or 'FTPL'." % (agent_type,)
        )
    agent_cls = MultiplicativeWeightsAgent if agent_type == 'MW' else FTPLAgent

    A1_manip, A2, B1, B2, C1, C2_manip, D1, D2 = manipulated_utilities
    manipulated_game = create_game(A1_manip, A2, B1, B2, C1, C2_manip, D1, D2)

    # Every agent gets a private deep copy so that no game state is shared
    # between agents or with the caller's objects.
    agents_manipulated = [
        agent_cls(player, deepcopy(manipulated_game), eta=current_eta)
        for player in (0, 1)
    ]
    agents_truthful = [
        agent_cls(player, deepcopy(game), eta=current_eta)
        for player in (0, 1)
    ]

    return game, agents_manipulated, agents_truthful

def run_simulation(eta, game, manipulated_utilities, sample_size=1, agent_type='MW', total_time=50000, start_count_t=1):
    """Run repeated simulations of the manipulated game and average the payoffs.

    For each of ``sample_size`` repetitions a fresh set of agents is created,
    the manipulated-game agents are run through ``run_T_rounds`` for
    ``total_time`` rounds, and the payoff history from index ``start_count_t``
    onwards is averaged along axis 0.

    Args:
        eta: Passed to the agent constructors via ``initialize_game_setting``.
        game: The game handed to ``initialize_game_setting`` and (deep-copied)
            to ``run_T_rounds``.
        manipulated_utilities: Eight payoffs of the manipulated game; entries
            0 and 5 are printed as ``c`` and ``d``.
        sample_size: Number of independent repetitions.
        agent_type: Agent kind forwarded to ``initialize_game_setting``.
        total_time: Number of rounds forwarded to ``run_T_rounds``.
        start_count_t: First payoff-history index included in the mean.

    Returns:
        A 4-tuple ``(actual_user_payoffs_by_c, manipulated_baseline_payoffs_by_c,
        game_final_manipulated, agents_final_manipulated)``. The first two are
        one-element lists wrapping the per-repetition lists; the last two come
        from the final repetition, or are ``None`` when no repetition ran.
    """
    actual_user_payoffs_by_c = []
    manipulated_baseline_payoffs_by_c = []
    print('c =', manipulated_utilities[0], '\td =', manipulated_utilities[5], '\teta = ', eta)

    # NOTE(review): nothing is ever appended to this list and the truthful
    # agents are never simulated, so the first return value is always [[]].
    # Confirm whether a truthful-agent run was intended here.
    actual_user_payoffs_c_d = []
    payoffs_manipulated_c_d = []

    # Pre-bind the "final" results so that sample_size <= 0 returns cleanly
    # instead of raising UnboundLocalError at the return statement.
    game_final_manipulated = None
    agents_final_manipulated = None

    for j in range(sample_size):
        game, agents_manipulated, _agents_truthful = initialize_game_setting(
            eta, game, manipulated_utilities, agent_type=agent_type)
        print('iteration:\t', j+1, '/', sample_size)
        # Deep copies keep each repetition independent of the objects built above.
        game_final_manipulated, agents_final_manipulated = run_T_rounds(
            deepcopy(game), deepcopy(agents_manipulated), total_time)
        payoff_history = game_final_manipulated.get_payoff_history()
        payoffs_manipulated_c_d.append(np.mean(payoff_history[start_count_t:], axis=0))

    actual_user_payoffs_by_c.append(actual_user_payoffs_c_d)
    manipulated_baseline_payoffs_by_c.append(payoffs_manipulated_c_d)
    return actual_user_payoffs_by_c, manipulated_baseline_payoffs_by_c, game_final_manipulated, agents_final_manipulated

def plot_NE_utilities_old(true_utilities, manipulated_utilities, color=[0,0.8,0], marker='o',label='True NE'):
    """Plot the true payoffs at the NE of the manipulated game.

    When ``label`` equals ``'Users NE'`` the payoffs at the NE of the true
    game are plotted first with ``marker``, and the manipulated-game point is
    then drawn with an ``'x'`` marker instead.

    Args:
        true_utilities: Eight payoffs ``(A1, A2, B1, B2, C1, C2, D1, D2)``.
        manipulated_utilities: Eight payoffs in the same order; only the
            first (``A1``) and sixth (``C2``) entries are used.
        color: Colour forwarded to ``plt.plot``.
        marker: Format string forwarded to ``plt.plot``.
        label: Legend label forwarded to ``plt.plot``.
    """
    A1, A2, B1, B2, C1, C2, D1, D2 = true_utilities
    A1_manip, _, _, _, _, C2_manip, _, _ = manipulated_utilities

    def payoffs_under_true_game(p, q):
        # Expected payoffs of both players, always evaluated with the TRUE
        # utilities, for mixing probabilities p and q.
        u1 = p * (q * A1 + (1 - q) * B1) + (1 - p) * (q * C1 + (1 - q) * D1)
        u2 = q * (p * A2 + (1 - p) * C2) + (1 - q) * (p * B2 + (1 - p) * D2)
        return u1, u2

    if label == 'Users NE':
        # Point for the NE of the true game.
        p_true = 1 / (1 + (A2 - B2) / (D2 - C2))
        q_true = 1 / (1 + (A1 - C1) / (D1 - B1))
        u1_true, u2_true = payoffs_under_true_game(p_true, q_true)
        plt.plot(u1_true, u2_true, marker, color=color, label=label)
        # The manipulated-game point below is then marked with a cross.
        marker = 'x'

    # Mixing probabilities come from the manipulated entries (A1, C2); the
    # payoffs are still scored with the true utilities.
    q_manip = 1 / (1 + (A1_manip - C1) / (D1 - B1))
    p_manip = 1 / (1 + (A2 - B2) / (D2 - C2_manip))
    u1_manip, u2_manip = payoffs_under_true_game(p_manip, q_manip)
    plt.plot(u1_manip, u2_manip, marker, color=color, label=label, alpha=0.7)

def plot_NE_utilities(utilities, utilities_manipulated, color=[0,0.8,0], marker='o',label='True NE', alpha=1):
    """Plot the true payoffs at the NE of the manipulated game.

    Args:
        utilities: Eight true payoffs ``(A1, A2, B1, B2, C1, C2, D1, D2)``.
        utilities_manipulated: Eight payoffs in the same order; only the
            first (``A1``) and sixth (``C2``) entries are used.
        color: Colour forwarded to ``plt.plot``.
        marker: Format string forwarded to ``plt.plot``.
        label: Legend label; omitted from the plot call when it is ``None``,
            ``''`` or ``'none'``.
        alpha: NOTE(review): this argument is never read; the point is always
            drawn with a fixed alpha of 0.8. Confirm the intended behaviour.
    """
    A1, A2, B1, B2, C1, C2, D1, D2 = utilities
    A1_manip, _, _, _, _, C2_manip, _, _ = utilities_manipulated

    # Mixing probabilities come from the manipulated entries (A1, C2); the
    # payoffs are scored with the true utilities.
    q = 1 / (1 + (A1_manip - C1) / (D1 - B1))
    p = 1 / (1 + (A2 - B2) / (D2 - C2_manip))
    u1 = p * (q * A1 + (1 - q) * B1) + (1 - p) * (q * C1 + (1 - q) * D1)
    u2 = q * (p * A2 + (1 - p) * C2) + (1 - q) * (p * B2 + (1 - p) * D2)

    plot_kwargs = {'color': color, 'alpha': 0.8}
    # Only attach a legend entry when a meaningful label was supplied.
    if label not in (None, '', 'none'):
        plot_kwargs['label'] = label
    plt.plot(u1, u2, marker, **plot_kwargs)

def plot_empirical_payoffs(payoff_pair_list, label, color, marker='o'):
    """Plot an empirical payoff pair as unfilled markers with no connecting line.

    Args:
        payoff_pair_list: Two-element iterable ``(u1, u2)``; ``u1`` is used as
            the x data and ``u2`` as the y data.
        label: Legend label forwarded to ``plt.plot``.
        color: Colour forwarded to ``plt.plot``.
        marker: Marker style forwarded to ``plt.plot``.
    """
    x_payoffs, y_payoffs = payoff_pair_list
    marker_style = {
        'marker': marker,
        'linestyle': '',      # markers only, no line between points
        'fillstyle': 'none',  # hollow markers
    }
    plt.plot(x_payoffs, y_payoffs, color=color, label=label, **marker_style)

def create_2x2_game_matrix(A1,A2,B1,B2,C1,C2,D1,D2):
    """Return the nested-list payoff matrix of a 2x2 game.

    Cell layout (each cell is a two-element list of payoffs):

        #####################
        # A1,A2   #   B1,B2 #
        #####################
        # C1,C2   #   D1,D2 #
        #####################

    Returns:
        ``[[[A1, A2], [B1, B2]], [[C1, C2], [D1, D2]]]``.
    """
    top_row = [[A1, A2], [B1, B2]]
    bottom_row = [[C1, C2], [D1, D2]]
    return [top_row, bottom_row]

def create_game(a,b,c,d,e,f,g,h):
    """Build a ``NormalFormGame`` from the eight payoffs of a 2x2 game.

    The arguments are forwarded positionally to ``create_2x2_game_matrix``
    (i.e. in the order A1, A2, B1, B2, C1, C2, D1, D2) and the resulting
    matrix is passed to the ``NormalFormGame`` constructor.
    """
    payoff_matrix = create_2x2_game_matrix(a, b, c, d, e, f, g, h)
    game = NormalFormGame(payoff_matrix)
    return game
