import numpy as np
import matplotlib.pyplot as plt
from ODA import Decentralized
import sys
import os
from collections import defaultdict
import random


sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from Ours import run_multiple_experiments

# Seed both RNGs so repeated runs of this script are reproducible.
np.random.seed(42)
random.seed(42)

# Problem size and experiment configuration.
num_players = 10
num_arms = 10
deltas = [0.3, 0.4, 0.5]
start_base = num_arms
horizon = 60000
trials = 1

# Every player shares the same preference order over arms:
# arm 0 is the favourite, arm num_arms - 1 the least preferred.
# Each player gets its own (non-aliased) list.
player_ranking = [list(range(num_arms)) for _ in range(num_players)]

print("Player ranking:")
for player, prefs in enumerate(player_ranking):
    print(f"Player {player}: {prefs}")

# Per-arm choice tables passed to Decentralized as `arm_ranking`. Ten dicts are
# appended (list indices 0-9), in this order:
#   indices 0-2 -> table over accepted set {0,1,2}
#   indices 3-5 -> table over accepted set {3,4,5}
#   indices 6-7 -> table over accepted set {6,7}
#   indices 8-9 -> table over accepted set {8,9}
# Every key is a sorted tuple of indices, and every value is that key restricted
# to the table's accepted set (all entries below follow this rule).
# NOTE(review): presumably list index i is arm i, key = players proposing to the
# arm and value = players it keeps (a substitutable choice function) -- confirm
# against ODA.Decentralized. Only the subsets listed are present; any other
# subset would raise KeyError if the consumer indexes these dicts directly.
# NOTE(review): `a_idx` (including its reassignment in later loops) is never
# read; each iteration simply appends a fresh copy of the same literal.
# NOTE(review): the tables hard-code indices 0-9 rather than deriving them from
# num_players / num_arms.
arm_ranking = []

# Indices 0, 1, 2: accepted set {0,1,2}.
for a_idx in range(3):
    # Full set, then every nonempty subset of {0,1,2} (kept whole).
    arm_ranking.append( {(0,1,2,3,4,5,6,7,8,9):(0,1,2), (0,1,2):(0,1,2) , (0,):(0,), (1,):(1,),(2,):(2,), (0,1):(0,1), (0,2):(0,2), (1,2):(1,2),
    # Full set minus some members of {0,1,2}.
    (1,2,3,4,5,6,7,8,9):(1,2),(0,2,3,4,5,6,7,8,9):(0,2),(0,1,3,4,5,6,7,8,9):(0,1), 
    (2,3,4,5,6,7,8,9):(2,),(1,3,4,5,6,7,8,9):(1,), (0,3,4,5,6,7,8,9):(0,),
    (3,4,5,6,7,8,9):(),
    # A nonempty subset of {0,1,2} plus one outsider j in 3..9: outsider dropped.
    (0,1,2,3):(0,1,2) , (0,3):(0,), (1,3):(1,),(2,3):(2,), (0,1,3):(0,1), (0,2,3):(0,2), (1,2,3):(1,2),
    (0,1,2,4):(0,1,2) , (0,4):(0,), (1,4):(1,),(2,4):(2,), (0,1,4):(0,1), (0,2,4):(0,2), (1,2,4):(1,2),
    (0,1,2,5):(0,1,2) , (0,5):(0,), (1,5):(1,),(2,5):(2,), (0,1,5):(0,1), (0,2,5):(0,2), (1,2,5):(1,2),
    (0,1,2,6):(0,1,2) , (0,6):(0,), (1,6):(1,),(2,6):(2,), (0,1,6):(0,1), (0,2,6):(0,2), (1,2,6):(1,2),
    (0,1,2,7):(0,1,2) , (0,7):(0,), (1,7):(1,),(2,7):(2,), (0,1,7):(0,1), (0,2,7):(0,2), (1,2,7):(1,2),
    (0,1,2,8):(0,1,2) , (0,8):(0,), (1,8):(1,),(2,8):(2,), (0,1,8):(0,1), (0,2,8):(0,2), (1,2,8):(1,2),
    (0,1,2,9):(0,1,2) , (0,9):(0,), (1,9):(1,),(2,9):(2,), (0,1,9):(0,1), (0,2,9):(0,2), (1,2,9):(1,2),
    # A lone outsider: nobody kept.
    (3,):(), (4,):(), (5,):(), (6,):(), (7,):(), (8,):(), (9,):()
    } )

# Indices 3, 4, 5: accepted set {3,4,5}.
for a_idx in range(3):
    a_idx = 3+a_idx
    # Full set, then every nonempty subset of {3,4,5} (kept whole).
    arm_ranking.append( {(0,1,2,3,4,5,6,7,8,9):(3,4,5), (3,4,5):(3,4,5) , (3,):(3,), (4,):(4,),(5,):(5,), (3,4):(3,4), (4,5):(4,5),(3,5):(3,5), 
    # Full set minus some members of {3,4,5}.
    (0,1,2,4,5,6,7,8,9):(4,5),(0,1,2,3,5,6,7,8,9):(3,5),(0,1,2,3,4,6,7,8,9):(3,4), 
    (0,1,2,5,6,7,8,9):(5,),(0,1,2,4,6,7,8,9):(4,), (0,1,2,3,6,7,8,9):(3,),
    (0,1,2,6,7,8,9):(),
    # A nonempty subset of {3,4,5} plus one outsider j in 6..9 (no entries pair
    # a subset with 0, 1 or 2 alone).
    (3,4,5,6):(3,4,5) , (3,6):(3,), (4,6):(4,),(5,6):(5,), (3,4,6):(3,4), (4,5,6):(4,5),(3,5,6):(3,5), 
    (3,4,5,7):(3,4,5) , (3,7):(3,), (4,7):(4,),(5,7):(5,), (3,4,7):(3,4), (4,5,7):(4,5),(3,5,7):(3,5), 
    (3,4,5,8):(3,4,5) , (3,8):(3,), (4,8):(4,),(5,8):(5,), (3,4,8):(3,4), (4,5,8):(4,5),(3,5,8):(3,5), 
    (3,4,5,9):(3,4,5) , (3,9):(3,), (4,9):(4,),(5,9):(5,), (3,4,9):(3,4), (4,5,9):(4,5),(3,5,9):(3,5), 
    # A lone outsider in 6..9: nobody kept.
    (6,):(), (7,):(), (8,):(), (9,):()
    } )

# Indices 6, 7: accepted set {6,7}.
for a_idx in range(2):
    a_idx = 6+a_idx
    # Full set / subsets of {6,7}, then full set minus members of {6,7}.
    arm_ranking.append( {(0,1,2,3,4,5,6,7,8,9):(6,7), (6,7):(6,7) , (6,):(6,),(7,):(7,),
    (0,1,2,3,4,5,6,8,9):(6,),(0,1,2,3,4,5,7,8,9):(7,),(0,1,2,3,4,5,8,9):(), 
    # A nonempty subset of {6,7} plus outsider 8 or 9.
    (6,7,8):(6,7) , (6,8):(6,),(7,8):(7,),
    (6,7,9):(6,7) , (6,9):(6,),(7,9):(7,),
    # A lone outsider: nobody kept.
    (8,):(), (9,):()
    } )

# Indices 8, 9: accepted set {8,9}; no outsider-only entries.
for a_idx in range(2):
    a_idx = 8+a_idx
    arm_ranking.append( {(0,1,2,3,4,5,6,7,8,9):(8,9), (8,9):(8,9) , (8,):(8,), (9,):(9,),
    (0,1,2,3,4,5,6,7,9):(9,),(0,1,2,3,4,5,6,7,8):(8,),(0,1,2,3,4,5,6,7):(), 
    } )

results = {
    'decen_regrets': [],
    'decen_unstable': [],
    'test_regrets': [],
    'test_unstable': []
}

# Capacities handed to ODA's Decentralized. Loop-invariant, so set once.
# NOTE(review): only 3 entries for num_arms (= 10) arms -- confirm this is what
# Decentralized expects (the substitute variant may not use it at all).
arm_capacity = [1, 2, 3]


def _player_means(delta):
    """Return each player's mean reward per arm for preference gap ``delta``.

    The arm a player ranks at position ``r`` (0 = favourite) gets mean
    ``delta * start_base - delta * r``, so consecutive ranks differ by
    ``delta``. Returns one float array of length ``num_arms`` per entry of
    ``player_ranking``, indexed by arm id.
    """
    rank_values = delta * start_base - delta * np.arange(num_arms)
    return [
        np.array([rank_values[ranking.index(arm)] for arm in range(num_arms)])
        for ranking in player_ranking
    ]


def _failed_curve():
    """Return the placeholder curve stored for a run that raised.

    All-NaN rather than zeros, so a crash is not saved and plotted as a
    perfect zero-regret / zero-instability result; matplotlib leaves NaN
    points undrawn and the saved .npz shows the gap explicitly.
    """
    return np.full(horizon, np.nan)


for delta in deltas:
    print(f"\n===== run delta = {delta}  =====")
    player_mean = _player_means(delta)

    print("run ODA method...")
    try:
        decen = Decentralized(
            horizon=horizon,
            trial=trials,
            num_player=num_players,
            num_arm=num_arms,
            player_ranking=player_ranking,
            arm_ranking=arm_ranking,
            player_mean=player_mean,
            arm_capacity=arm_capacity
        )
        mean_regret, mean_unstable = decen.decen_elimination_substitue(
            N=num_players,
            C=1,
            delta=delta
        )
    except Exception as e:
        # Deliberately best-effort: one method failing must not abort the sweep.
        import traceback
        print(f"run Decentralized goes wrong: {e}")
        traceback.print_exc()
        mean_regret, mean_unstable = _failed_curve(), _failed_curve()
    results['decen_regrets'].append(mean_regret)
    results['decen_unstable'].append(mean_unstable)

    print("run our method...")
    try:
        # Same number of repetitions as ODA (`trials`); was a literal 1.
        avg_regret, avg_unstable, _, _ = run_multiple_experiments(
            delta=delta,
            num_runs=trials,
            horizon=horizon
        )
    except Exception as e:
        import traceback
        print(f"run ours goes wrong: {e}")
        traceback.print_exc()
        avg_regret, avg_unstable = _failed_curve(), _failed_curve()
    results['test_regrets'].append(avg_regret)
    results['test_unstable'].append(avg_unstable)


# Persist every curve plus the swept deltas. Each curve entry is an object
# array (one element per delta), so reading the file back needs
# np.load('ODA_and_ours.npz', allow_pickle=True).
_curve_names = ('decen_regrets', 'decen_unstable', 'test_regrets', 'test_unstable')
arrays_to_save = {name: np.array(results[name], dtype=object) for name in _curve_names}
arrays_to_save['deltas'] = np.array(deltas)
np.savez('ODA_and_ours.npz', **arrays_to_save)


colors = ['r', 'g', 'b', 'purple']
line_styles = ['-', '--', ':', '-.']


def _plot_comparison(ax, oda_curve, ours_curve, ylabel):
    """Draw the ODA (blue) and Ours (red) curves on ``ax``.

    Both curves are truncated to the shorter length so they share an x-axis.
    Nothing is drawn, and the axes stay unlabeled, when either curve is empty.
    About ten '+' markers are placed along each curve.
    """
    length = min(len(oda_curve), len(ours_curve))
    if length == 0:
        return
    x_values = range(length)
    # Was a fixed 6000 (= horizon / 10); derived so marker density survives a
    # change of horizon.
    mark_every = max(1, length // 10)
    ax.plot(x_values, oda_curve[:length], color='blue', linestyle='-', marker='+',
            markevery=mark_every, label='ODA')
    ax.plot(x_values, ours_curve[:length], color='red', linestyle='-', marker='+',
            markevery=mark_every, label='Ours')
    ax.set_xlabel('Round t')
    ax.set_ylabel(ylabel)
    ax.legend()
    ax.grid(True)


# One figure per delta: regret on the left, instability on the right.
for i, delta in enumerate(deltas):
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

    # f-string: the previous plain string rendered the literal text "{delta}".
    fig.suptitle(
        "Different algorithms comparison, substitutee preferences, optimal regret, "
        f"N={num_players}, K={num_arms}, Delta={delta}",
        fontsize=14, y=0.96,
    )

    if i < len(results['decen_regrets']) and i < len(results['test_regrets']):
        _plot_comparison(ax1, results['decen_regrets'][i], results['test_regrets'][i],
                         'Maximum Cumulative Stable Regret')
    if i < len(results['decen_unstable']) and i < len(results['test_unstable']):
        _plot_comparison(ax2, results['decen_unstable'][i], results['test_unstable'][i],
                         'Cumulative Market Unstability')

    plt.tight_layout()
    plt.savefig(f'ODA_and_ours_{delta}.png')
    plt.show()
    # Release each per-delta figure; otherwise they all stay open in memory.
    plt.close(fig)

# Summary figure: all deltas overlaid. Solid lines are ODA and dashed lines are
# Ours; the colour identifies delta.
fig = plt.figure(figsize=(15, 10))
fig.suptitle("Different algorithms comparison, substitutee preferences, optimal regret, N=10, K=10",
             fontsize=14, y=0.98)

# (ODA results key, Ours results key, y-axis label) for each stacked panel.
panels = [
    ('decen_regrets', 'test_regrets', 'Maximum Cumulative Stable Regret'),
    ('decen_unstable', 'test_unstable', 'Cumulative Market Unstability'),
]

for row, (oda_key, ours_key, ylabel) in enumerate(panels, start=1):
    ax = fig.add_subplot(2, 1, row)
    # zip stops at the shortest list, so deltas without results are skipped.
    for i, (delta, oda_curve, ours_curve) in enumerate(
            zip(deltas, results[oda_key], results[ours_key])):
        shared = min(len(oda_curve), len(ours_curve))
        if not shared:
            continue
        xs = range(shared)
        shade = colors[i % len(colors)]
        ax.plot(xs, oda_curve[:shared], color=shade, linestyle='-',
                label=f'ODA (delta={delta})')
        ax.plot(xs, ours_curve[:shared], color=shade, linestyle='--',
                label=f'Ours (delta={delta})')
    ax.set_xlabel('Round t')
    ax.set_ylabel(ylabel)
    ax.legend()
    ax.grid(True)

fig.tight_layout()
fig.savefig('ODA_and_ours.png')
plt.show()
