import numpy as np
import random
import itertools
from collections import defaultdict
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter1d


def update_regret_and_record(iteration_reward, best_single_round_reward, cumulative_regret, regret_list):
    """Fold one round's regret into the running total and log it.

    Args:
        iteration_reward: Reward obtained in this round.
        best_single_round_reward: Benchmark reward for this round.
        cumulative_regret: Running regret total before this round.
        regret_list: History list; the new total is appended to it in place.

    Returns:
        Tuple ``(new_cumulative_regret, regret_list)`` where ``regret_list``
        is the same list object that was passed in.
    """
    running_total = cumulative_regret + (best_single_round_reward - iteration_reward)
    regret_list.append(running_total)
    return running_total, regret_list

def apply_smoothing(regret_list, sigma=3):
    """Smooth a 1-D curve with scipy's ``gaussian_filter1d``.

    Args:
        regret_list: Sequence of values to smooth.
        sigma: Standard deviation of the Gaussian kernel.

    Returns:
        The filtered values as returned by ``gaussian_filter1d``.
    """
    smoothed_curve = gaussian_filter1d(regret_list, sigma=sigma)
    return smoothed_curve

def create_substitutable_preferences(N, K):
    """Build a random choice table over player groups for each of K arms.

    For every arm a random family of "acceptable" player groups is sampled
    from all non-empty subsets of the N players. A candidate group is then
    accepted in full when it is contained in at least one sampled group and
    rejected in full (mapped to the empty tuple) otherwise.

    Args:
        N: Number of players; players are labelled ``0 .. N-1``.
        K: Number of arms; arms are labelled ``0 .. K-1``.

    Returns:
        Dict mapping each arm to a dict that maps every non-empty sorted
        player tuple to the tuple of players the arm keeps from that group.
    """
    # The subset enumeration does not depend on the arm, so build it once.
    # itertools.combinations over range(N) already yields sorted tuples.
    all_subsets = [
        subset
        for size in range(1, N + 1)
        for subset in itertools.combinations(range(N), size)
    ]
    upper = len(all_subsets) // 2

    arm_preferences = {}
    for a in range(K):
        if N < upper:
            num_preferences = np.random.randint(N, upper)
        else:
            # For small N (N <= 3) the interval [N, upper) is empty and
            # np.random.randint would raise; fall back to N groups, capped
            # by how many subsets actually exist.
            num_preferences = min(N, len(all_subsets))

        # Convert the sampled groups to sets once instead of once per lookup.
        accepted_groups = [set(pref) for pref in random.sample(all_subsets, num_preferences)]

        preference_dict = {}
        for subset in all_subsets:
            members = set(subset)
            if any(members <= group for group in accepted_groups):
                preference_dict[subset] = subset
            else:
                preference_dict[subset] = ()

        arm_preferences[a] = preference_dict

    return arm_preferences


def generate_reward(mean, scale=1.0):
    """Draw one reward from a normal distribution.

    Args:
        mean: Centre of the distribution.
        scale: Standard deviation of the distribution.

    Returns:
        A single sample from ``np.random.normal``.
    """
    sample = np.random.normal(loc=mean, scale=scale)
    return sample

def check_market_stability(matched, player_preferences, arm_preferences):
    """Return 1 when some player/arm pair would deviate from ``matched``, else 0.

    For each player, the arms ranked ahead of the player's current arm are
    tested (every arm when the player is unmatched or the current arm is not
    in the player's ranking). An arm triggers instability when its choice
    table has an entry for "current players + candidate", that entry keeps
    the candidate, and either the candidate is currently unmatched or the
    entry drops at least one of the arm's current players.

    NOTE(review): a matched player who would be accepted by a better arm
    without displacing anyone is NOT counted as blocking. This mirrors the
    existing rule; confirm it is the intended stability notion.

    Args:
        matched: Sequence where ``matched[p]`` is the arm of player ``p`` or
            ``-1`` when the player is unmatched.
        player_preferences: Per-player list of arms, most preferred first.
        arm_preferences: Mapping arm -> {sorted player tuple: chosen players}.

    Returns:
        ``1`` if an instability is found, ``0`` otherwise.
    """
    # Players currently held by each arm.
    arm_matched = defaultdict(set)
    for p, a in enumerate(matched):
        if a != -1:
            arm_matched[a].add(p)

    for p, ranking in enumerate(player_preferences):
        current_arm = matched[p]
        try:
            if current_arm == -1:
                preferred_arms = ranking
            else:
                preferred_arms = ranking[:ranking.index(current_arm)]
        except ValueError:
            # Current arm is not ranked at all: every ranked arm is an upgrade.
            preferred_arms = ranking

        for arm in preferred_arms:
            # .get avoids inserting empty entries into the defaultdict.
            holders = arm_matched.get(arm, set())
            potential_tuple = tuple(sorted(holders | {p}))

            chosen = arm_preferences[arm].get(potential_tuple)
            if chosen is None or p not in chosen:
                continue

            # The arm would take p; it counts only if p is unmatched or one
            # of the arm's current players would be displaced.
            if current_arm == -1 or not holders.issubset(chosen):
                return 1

    return 0

def gale_shapley_optimal(player_preferences, arm_preferences, player_mean):
    """Run player-proposing deferred acceptance with set-valued arm choices.

    In each round every unmatched player with arms left on its list proposes
    to the next one. Each arm that received proposals re-chooses among its
    current players plus the new proposers using its choice table; players
    it no longer keeps become unmatched again. The loop stops when nobody
    can propose or after ``N * K * 10`` rounds.

    Args:
        player_preferences: Per-player list of arms, in proposal order.
        arm_preferences: Mapping arm -> {sorted player tuple: chosen players}.
        player_mean: 2-D array indexed as ``player_mean[player, arm]``.

    Returns:
        Tuple ``(matching, rewards)``: ``matching[p]`` is the arm of player
        ``p`` (``-1`` if unmatched) and ``rewards[p]`` is ``player_mean`` at
        that pair (``0`` for unmatched players).
    """
    num_players = len(player_preferences)
    num_arms = len(arm_preferences)

    next_rank = [0 for _ in range(num_players)]       # next list position to propose to
    held_by_arm = [set() for _ in range(num_arms)]    # players each arm currently keeps
    assignment = [-1 for _ in range(num_players)]

    for _ in range(num_players * num_arms * 10):
        proposers = [
            p for p in range(num_players)
            if assignment[p] == -1 and next_rank[p] < len(player_preferences[p])
        ]
        if len(proposers) == 0:
            break

        offers = defaultdict(set)
        for p in proposers:
            target = player_preferences[p][next_rank[p]]
            next_rank[p] += 1
            offers[target].add(p)

        for arm, newcomers in offers.items():
            pool = held_by_arm[arm] | newcomers
            choice_table = arm_preferences[arm]
            pool_key = tuple(sorted(pool))

            if pool_key in choice_table:
                keep = set(choice_table[pool_key])
            else:
                # No exact entry: take the largest choice among entries whose
                # key fits inside the pool (first one wins on ties).
                keep = set()
                for group in choice_table:
                    if not set(group).issubset(pool):
                        continue
                    candidate = set(choice_table[group])
                    if len(candidate) > len(keep):
                        keep = candidate

            for p in held_by_arm[arm] - keep:
                assignment[p] = -1
            for p in newcomers & keep:
                assignment[p] = arm
            held_by_arm[arm] = keep

    optimal_rewards = np.zeros(num_players)
    for p, arm in enumerate(assignment):
        if arm != -1:
            optimal_rewards[p] = player_mean[p, arm]

    return assignment, optimal_rewards

def initialize_matching(N, K, arm_preferences, player_preferences, player_mean, optimal_rewards, noise_scale=1.0):
    """Randomly propose until every player has been accepted once.

    Each round, every still-active player proposes to a uniformly random arm
    and draws a noisy reward for it. Players kept by an arm's choice table
    receive the next global index and become inactive. Cumulative regret and
    cumulative instability are recorded once per round. The loop ends when
    no player is active or after 1000 rounds.

    Args:
        N: Number of players.
        K: Number of arms.
        arm_preferences: Mapping arm -> {sorted player tuple: chosen players}.
        player_preferences: Per-player arm rankings (used for the stability check).
        player_mean: 2-D array indexed as ``player_mean[player, arm]``.
        optimal_rewards: Per-player benchmark reward.
        noise_scale: Standard deviation passed to ``generate_reward``.

    Returns:
        Tuple ``(players, regret_list, unstable_list, iteration)`` where
        ``players`` is a list of ``{"index", "active"}`` dicts and
        ``iteration`` is the number of rounds played.
    """
    players = [dict(index=None, active=True) for _ in range(N)]
    current_matching = [-1] * N

    regret_list = []
    unstable_list = []
    cumulative_regret = 0.0
    cumulative_unstable = 0.0
    next_index = 0
    iteration = 0

    while iteration < 1000:
        active_players = [i for i, state in enumerate(players) if state["active"]]
        if not active_players:
            break

        best_reward = sum(optimal_rewards[p] for p in active_players)
        iteration_reward = 0.0

        # Every active player picks a random arm and observes a reward for it,
        # whether or not the arm ends up accepting the proposal.
        arm_proposals = defaultdict(list)
        for p in active_players:
            arm = np.random.randint(K)
            arm_proposals[arm].append(p)
            current_matching[p] = arm
            iteration_reward += generate_reward(player_mean[p, arm], noise_scale)

        accepted = set()
        for arm, props in arm_proposals.items():
            choice_table = arm_preferences[arm]
            key = tuple(sorted(props))
            if key in choice_table:
                accepted.update(choice_table[key])

        # Accepted players are indexed in the set's iteration order.
        for p in accepted:
            state = players[p]
            if not state["active"]:
                continue
            state["index"] = next_index
            state["active"] = False
            next_index += 1

        cumulative_unstable += check_market_stability(current_matching, player_preferences, arm_preferences)
        unstable_list.append(cumulative_unstable)

        cumulative_regret += best_reward - iteration_reward
        regret_list.append(cumulative_regret)

        iteration += 1

    return players, regret_list, unstable_list, iteration



def precheck_single_proposal(players, arm_preferences, player_preferences, player_mean, optimal_rewards, prior_regret, prior_unstable, noise_scale=1.0):
    """Find, per player, the arms that accept that player proposing alone.

    Players are visited in increasing ``index`` order; players whose index is
    still ``None`` are skipped. For each arm whose choice table keeps the
    single-player group ``(p,)``, the arm is recorded as feasible, one noisy
    reward is drawn, and cumulative regret / instability are extended by one
    step.

    Args:
        players: List of ``{"index", "active"}`` dicts, one per player.
        arm_preferences: Mapping arm -> {sorted player tuple: chosen players}.
        player_preferences: Per-player arm rankings (used for the stability check).
        player_mean: 2-D array indexed as ``player_mean[player, arm]``.
        optimal_rewards: Per-player benchmark reward.
        prior_regret: Cumulative regret carried in from the previous phase.
        prior_unstable: Cumulative instability carried in from the previous phase.
        noise_scale: Standard deviation passed to ``generate_reward``.

    Returns:
        Tuple ``(feasible_arms, regret_list, unstable_list, iteration)``.
        Both lists start with the carried-in prior value; ``iteration`` is
        the number of (player, arm) steps recorded.
    """
    N, K = len(players), len(arm_preferences)

    # Drop un-indexed players BEFORE sorting: None cannot be ordered against
    # ints, so sorting the raw indices would raise TypeError.
    indexed_players = sorted(
        (i for i, state in enumerate(players) if state["index"] is not None),
        key=lambda i: players[i]["index"],
    )
    feasible_arms = [[] for _ in range(N)]

    regret_list = [prior_regret]
    unstable_list = [prior_unstable]
    cumulative_regret = prior_regret
    cumulative_unstable = prior_unstable

    current_matching = [-1] * N
    iteration = 0

    for p in indexed_players:
        best_reward = optimal_rewards[p]
        single_player_tuple = (p,)

        for a in range(K):
            if p not in arm_preferences[a].get(single_player_tuple, ()):
                continue

            feasible_arms[p].append(a)
            current_matching[p] = a
            reward = generate_reward(player_mean[p, a], noise_scale)

            cumulative_unstable += check_market_stability(current_matching, player_preferences, arm_preferences)
            unstable_list.append(cumulative_unstable)

            cumulative_regret += best_reward - reward
            regret_list.append(cumulative_regret)
            iteration += 1

    return feasible_arms, regret_list, unstable_list, iteration



def preference_learning_ucb(players, feasible_arms, player_mean, player_preferences, arm_preferences, optimal_matching, optimal_rewards, prior_regret, prior_unstable, delta=0.3, horizon=20000, min_explore=10, noise_scale=1.0):
    """Learn per-player arm means with forced exploration followed by UCB.

    Each round, every indexed and not-yet-converged player pulls one of its
    feasible arms: the first arm with fewer than ``min_explore`` pulls if one
    exists, otherwise the arm with the highest UCB score
    ``mu + sqrt(2 * (1 / delta) * log(t + 1) / pulls)``. After round 500 a
    player is marked converged once all its feasible arms have at least
    ``min_explore`` pulls and its empirically best arm equals
    ``optimal_matching[p]``. When no player is active, the regret and
    instability curves are padded flat up to ``horizon`` and the loop stops.

    Args:
        players: List of ``{"index", ...}`` dicts; players with index None are ignored.
        feasible_arms: Per-player list of arms the player may pull.
        player_mean: 2-D array of true means, ``player_mean[player, arm]``.
        player_preferences: Per-player arm rankings (used for the stability check).
        arm_preferences: Mapping arm -> {sorted player tuple: chosen players}.
        optimal_matching: Per-player benchmark arm used by the convergence test.
        optimal_rewards: Per-player benchmark reward used for regret.
        prior_regret: Cumulative regret carried in from the previous phase.
        prior_unstable: Cumulative instability carried in from the previous phase.
        delta: Gap parameter; the exploration bonus scales with ``1 / delta``.
        horizon: Number of rounds.
        min_explore: Minimum pulls per feasible arm before UCB selection.
        noise_scale: Standard deviation passed to ``generate_reward``.

    Returns:
        Tuple ``(mu, regret_list, unstable_list, converged)``: the estimated
        means, the two cumulative curves (each starting with the carried-in
        prior value) and a per-player list of convergence flags.
    """
    N, K = player_mean.shape
    mu = np.zeros((N, K))   # running mean reward per (player, arm)
    T = np.zeros((N, K))    # pull counts per (player, arm)

    regret_list = [prior_regret]
    unstable_list = [prior_unstable]
    cumulative_regret = prior_regret
    cumulative_unstable = prior_unstable

    current_matching = [-1] * N
    converged = [False] * N

    for t in range(horizon):
        active_players = [p for p in range(N) if players[p]["index"] is not None and not converged[p]]
        if not active_players:
            # Nothing left to learn: hold both curves flat for the rest of the horizon.
            remaining = horizon - t
            regret_list.extend([cumulative_regret] * remaining)
            unstable_list.extend([cumulative_unstable] * remaining)
            break

        best_reward = sum(optimal_rewards[p] for p in active_players)
        iteration_reward = 0

        for p in active_players:
            arms = feasible_arms[p]
            if not arms:
                continue

            # Forced exploration first; fall back to UCB once every feasible
            # arm has been pulled min_explore times.
            chosen = next((a for a in arms if T[p, a] < min_explore), None)
            if chosen is None:
                bonus_scale = 2 * (1 / delta) * np.log(t + 1)
                scores = [mu[p, a] + np.sqrt(bonus_scale / max(T[p, a], 1)) for a in arms]
                # index(max(...)) keeps the first arm on ties.
                chosen = arms[scores.index(max(scores))]

            current_matching[p] = chosen

            reward = generate_reward(player_mean[p, chosen], noise_scale)
            iteration_reward += reward
            T[p, chosen] += 1
            mu[p, chosen] += (reward - mu[p, chosen]) / T[p, chosen]  # incremental mean

            if t > 500 and min(T[p, a] for a in arms) >= min_explore:
                est_best = max(arms, key=lambda a: mu[p, a])
                if est_best == optimal_matching[p]:
                    converged[p] = True

        cumulative_unstable += check_market_stability(current_matching, player_preferences, arm_preferences)
        unstable_list.append(cumulative_unstable)

        cumulative_regret += best_reward - iteration_reward
        regret_list.append(cumulative_regret)

    return mu, regret_list, unstable_list, converged


def final_gale_shapley(N, mu, arm_preferences, player_preferences, player_mean, optimal_matching, optimal_rewards, players, prior_regret, prior_unstable, delta=0.3, noise_scale=1.0):
    """Run deferred acceptance using the learned mean estimates ``mu``.

    Players rank arms by decreasing ``mu`` and propose down that ranking for
    at most 100 rounds. Each arm that receives proposals re-chooses among its
    current players plus the proposers via its choice table. Newly accepted
    players draw one noisy reward; cumulative regret and instability are
    extended once per round. A summary of the final matching is printed.

    Args:
        N: Number of players.
        mu: 2-D array of estimated means, ``mu[player, arm]``.
        arm_preferences: Mapping arm -> {sorted player tuple: chosen players}.
        player_preferences: Per-player arm rankings (used for the stability check).
        player_mean: 2-D array of true means, ``player_mean[player, arm]``.
        optimal_matching: Per-player benchmark arm, used only for reporting.
        optimal_rewards: Per-player benchmark reward used for regret.
        players: Not used by this function.
        prior_regret: Cumulative regret carried in from the previous phase.
        prior_unstable: Cumulative instability carried in from the previous phase.
        delta: Not used by this function.
        noise_scale: Standard deviation passed to ``generate_reward``.

    Returns:
        Tuple ``(matched, regret_list, unstable_list)``; both lists start
        with the carried-in prior value.
    """
    K = len(arm_preferences)

    player_rank = [list(np.argsort(-mu[p])) for p in range(N)]
    propose_order = [0] * N               # next rank position; -1 once the list is exhausted
    matched = [-1] * N
    arm_matched = [[] for _ in range(K)]  # players each arm currently keeps

    cumulative_regret = prior_regret
    cumulative_unstable = prior_unstable
    regret_list = [prior_regret]
    unstable_list = [prior_unstable]

    for _ in range(100):
        unmatched = [p for p in range(N) if matched[p] == -1 and propose_order[p] != -1]
        if not unmatched:
            break

        best_reward = sum(optimal_rewards[p] for p in unmatched)
        iteration_reward = 0.0

        proposals = defaultdict(list)
        for p in unmatched:
            position = propose_order[p]
            if position >= len(player_rank[p]):
                propose_order[p] = -1
                continue
            propose_order[p] = position + 1
            proposals[player_rank[p][position]].append(p)

        for a, props in proposals.items():
            holders = arm_matched[a]
            all_players = set(holders + props)
            choice_table = arm_preferences[a]
            prop_tuple = tuple(sorted(all_players))

            if prop_tuple in choice_table:
                selected = choice_table[prop_tuple]
            else:
                # No exact entry: pool the choices of every entry whose key
                # fits inside the candidate set.
                selected = []
                for subset in choice_table:
                    if set(subset).issubset(all_players):
                        selected.extend(q for q in choice_table[subset] if q in all_players)

            for p in holders:
                if p not in selected:
                    matched[p] = -1

            for p in props:
                if p in selected:
                    matched[p] = a
                    iteration_reward += generate_reward(player_mean[p, a], noise_scale)

            arm_matched[a] = [p for p in all_players if p in selected]

        cumulative_unstable += check_market_stability(matched, player_preferences, arm_preferences)
        unstable_list.append(cumulative_unstable)

        cumulative_regret += best_reward - iteration_reward
        regret_list.append(cumulative_regret)

    match_diff = sum(1 for p in range(N) if matched[p] != optimal_matching[p])
    print(f"Number of differences from the best match: {match_diff}/{N}")

    final_rewards = sum(player_mean[p, matched[p]] for p in range(N) if matched[p] != -1)
    optimal_total = sum(optimal_rewards)
    print(f"Final total revenue: {final_rewards:.4f} vs Optimal total return: {optimal_total:.4f}")

    final_unstable = check_market_stability(matched, player_preferences, arm_preferences)
    print(f"The instability of the final match: {final_unstable}")

    return matched, regret_list, unstable_list


def run_experiment(delta=0.3, seed=42, horizon=20000):
    """Run one seeded end-to-end simulation with 10 players and 10 arms.

    The run draws random true means in [0.1, 0.9), derives player rankings
    from them, builds arm choice tables, computes the benchmark matching and
    then chains four phases: random initialisation, single-proposal
    pre-check, UCB learning and a final deferred-acceptance pass. Each phase
    continues the cumulative curves of the previous one.

    Args:
        delta: Gap parameter forwarded to the learning and final phases.
        seed: Seed applied to both ``np.random`` and ``random``.
        horizon: Number of rounds of the UCB learning phase.

    Returns:
        Tuple ``(full_regret, full_unstable, matched, optimal_matching,
        optimal_rewards, phase_lengths)``.
    """
    np.random.seed(seed)
    random.seed(seed)

    N, K = 10, 10
    noise_scale = 1.0  # 1-subgaussian
    min_explore = 10

    # One scalar draw per (player, arm), row by row.
    player_mean = np.array(
        [[0.1 + 0.8 * np.random.random() for _ in range(K)] for _ in range(N)]
    )
    player_preferences = [list(np.argsort(-player_mean[p])) for p in range(N)]

    arm_preferences = create_substitutable_preferences(N, K)
    optimal_matching, optimal_rewards = gale_shapley_optimal(player_preferences, arm_preferences, player_mean)
    print(f"optimal matching: {optimal_matching}")
    print(f"total revenue: {optimal_rewards}")

    players, reg1, unstable1, _ = initialize_matching(
        N, K, arm_preferences, player_preferences, player_mean, optimal_rewards, noise_scale
    )

    feasible, reg1_5, unstable1_5, _ = precheck_single_proposal(
        players, arm_preferences, player_preferences, player_mean, optimal_rewards,
        reg1[-1], unstable1[-1], noise_scale
    )

    mu, reg2, unstable2, converged = preference_learning_ucb(
        players, feasible, player_mean, player_preferences, arm_preferences,
        optimal_matching, optimal_rewards, reg1_5[-1], unstable1_5[-1],
        delta, horizon, min_explore, noise_scale
    )

    matched, reg3, unstable3 = final_gale_shapley(
        N, mu, arm_preferences, player_preferences, player_mean, optimal_matching,
        optimal_rewards, players, reg2[-1], unstable2[-1], delta, noise_scale
    )

    # Later phases start with the carried-in value, so drop their first entry.
    later_phases = ((reg1_5, unstable1_5), (reg2, unstable2), (reg3, unstable3))
    phase_lengths = [len(reg1)] + [len(regrets) - 1 for regrets, _ in later_phases]

    full_regret = list(reg1)
    full_unstable = list(unstable1)
    for regrets, unstables in later_phases:
        full_regret.extend(regrets[1:])
        full_unstable.extend(unstables[1:])

    return full_regret, full_unstable, matched, optimal_matching, optimal_rewards, phase_lengths


def run_multiple_experiments(delta, num_runs=5, horizon=20000):
    """Average the experiment curves over several seeded runs.

    Run ``r`` (0-based) uses seed ``42 + r``. Curves are averaged point-wise
    over the prefix shared by all runs.

    Args:
        delta: Gap parameter forwarded to ``run_experiment``.
        num_runs: Number of runs to average.
        horizon: Learning horizon forwarded to ``run_experiment``.

    Returns:
        Tuple ``(avg_regret, avg_unstable, avg_match_accuracy, min_length)``
        where ``min_length`` is the shortest instability curve length.
    """
    def _pointwise_mean(curves):
        # Average over the common prefix; also report that prefix length.
        shortest = min(len(curve) for curve in curves)
        averaged = [sum(curve[i] for curve in curves) / num_runs for i in range(shortest)]
        return averaged, shortest

    all_regrets, all_unstable, all_match_accuracy = [], [], []

    for run in range(num_runs):
        print(f"run delta={delta} 's {run+1}/{num_runs} experiment")
        regret, unstable, final_match, optimal_match, optimal_rewards, _ = run_experiment(
            delta=delta, seed=42+run, horizon=horizon
        )
        all_regrets.append(regret)
        all_unstable.append(unstable)

        # Fraction of players whose final arm equals the benchmark arm.
        hits = sum(1 for p in range(len(final_match)) if final_match[p] == optimal_match[p])
        all_match_accuracy.append(hits / len(final_match))

    avg_regret, _ = _pointwise_mean(all_regrets)
    avg_unstable, min_length = _pointwise_mean(all_unstable)
    avg_match_accuracy = sum(all_match_accuracy) / num_runs

    return avg_regret, avg_unstable, avg_match_accuracy, min_length



def plot_result(num_runs=3, horizon=20000):
    """Run the experiments for several delta values and plot the averages.

    Draws two stacked panels sharing the x axis: averaged cumulative regret
    on top and averaged cumulative instability below, one line per delta.
    Each curve is down-sampled to roughly 500 points. The figure is saved to
    ``regret_and_unstability_analysis_avg.png`` and shown, then the average
    matching accuracy per delta is printed.

    Args:
        num_runs: Number of runs averaged per delta.
        horizon: Learning horizon forwarded to ``run_multiple_experiments``.
    """
    deltas = [0.2, 0.3, 0.4, 0.5]
    styles = [':', '--', '-.', '-']
    colors = ['red', 'blue', 'green', 'purple']

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10), sharex=True)

    all_regrets, all_unstables, all_match_accuracies, min_lengths = [], [], [], []
    for d in deltas:
        print(f"\n\n==== run delta={d} experiment（{num_runs}average）====")
        avg_regret, avg_unstable, avg_accuracy, min_length = run_multiple_experiments(delta=d, num_runs=num_runs, horizon=horizon)
        all_regrets.append(avg_regret)
        all_unstables.append(avg_unstable)
        all_match_accuracies.append(avg_accuracy)
        min_lengths.append(min_length)

    # (axis, curves, x label or None, y label): top panel first, then bottom.
    panels = (
        (ax1, all_regrets, None, "Maximum Cumulative Stable Regret"),
        (ax2, all_unstables, "Round t", "Cumulative Market Unstability"),
    )
    for axis, curves, x_label, y_label in panels:
        for i, d in enumerate(deltas):
            sample_step = max(1, min_lengths[i]//500)
            x_values = range(0, min_lengths[i], sample_step)
            sampled = [curves[i][j] for j in x_values]
            axis.plot(x_values, sampled, styles[i], color=colors[i],
                      label=f'Δ = {d}')

        if x_label is not None:
            axis.set_xlabel(x_label)
        axis.set_ylabel(y_label)
        axis.set_title(f"Average over {num_runs} runs: Different preference gaps, substitutable preferences, N = 10, K = 10")
        axis.grid(True)
        axis.legend()

    plt.tight_layout()
    plt.savefig('regret_and_unstability_analysis_avg.png')
    plt.show()

    for i, d in enumerate(deltas):
        print(f"Delta = {d}: Accuracy = {all_match_accuracies[i]:.2f}")

# Script entry point: one run per delta value with a horizon of 50000 rounds.
if __name__ == "__main__":
    plot_result(num_runs=1, horizon=50000)