import numpy as np
import matplotlib.pyplot as plt


def get_policy(config, priority, alpha):
    """
    Build the per-state activation probabilities of a priority policy.

    States are visited in the order given by ``priority`` while their
    (normalized) mass ``config[state]`` is accumulated. A state gets 1 as
    long as the running total stays within ``alpha``; the state at which
    the total crosses ``alpha`` gets the fraction of its own mass that
    still fits; every state after that gets 0.

    Returns an array of length ``len(priority)`` indexed by state.
    """
    activation = np.zeros(len(priority))
    cumulative = 0
    for s in priority:
        mass = config[s]
        cumulative += mass
        if cumulative <= alpha:
            # The whole state still fits inside the budget.
            activation[s] = 1
        elif cumulative <= mass + alpha:
            # The budget runs out inside this state: activate only the
            # share of its mass that is left over.
            leftover = alpha - (cumulative - mass)
            activation[s] = leftover / mass
        elif cumulative > mass + alpha:
            # The budget was already exhausted before reaching this state.
            activation[s] = 0

    return activation


def get_global_reward(config, priority, alpha, Reward_1, Reward_0):
    """
    Get the global reward of a configuration under the priority policy.

    For every state in ``priority`` the per-state reward is the mix of
    ``Reward_1`` and ``Reward_0`` weighted by the activation probability
    returned by ``get_policy``; the result is the sum of these per-state
    rewards weighted by ``config[state]``.
    """
    activation = get_policy(config, priority, alpha)
    total = 0
    for s in priority:
        p_active = activation[s]
        expected = Reward_1[s] * p_active + Reward_0[s] * (1 - p_active)
        total += expected * config[s]

    return total


def compute_value_function(P: np.ndarray, R: np.ndarray, g: float) -> np.ndarray:
    """
    Compute the Value Function given the Transition Matrix and Reward Matrix.

    Solves the linear system ``(I - g * P) V = R`` for ``V``.

    Args:
        P: Square transition matrix of shape ``(n, n)``.
        R: Reward vector of shape ``(n,)`` or matrix of shape ``(n, k)``.
        g: Scalar multiplying ``P`` (the discount factor).

    Returns:
        ``V`` with the same shape as ``R``.

    Raises:
        numpy.linalg.LinAlgError: If ``I - g * P`` is singular.
    """
    system = np.eye(P.shape[0]) - g * P
    # Solve the system directly rather than forming the explicit inverse:
    # it is cheaper and numerically more accurate.
    return np.linalg.solve(system, R)



def _sample_next_states(next_config, count, transition_row, d):
    """Draw next states for ``count`` agents and add them to ``next_config`` in place."""
    if count <= 0:
        return
    destinations = np.random.choice(d, size=count, p=transition_row)
    # np.add.at accumulates repeated indices (fancy-index ``+=`` would not).
    np.add.at(next_config, destinations, 1)


def MC_once(env, priority, trajectory_length):
    """
    Simulate one trajectory and return its discounted global reward.

    Starting from ``env.config``, every step visits the states in
    ``priority`` order and moves each agent to a sampled next state.
    Agents that fall within the activation budget ``env.alpha * env.N``
    are sampled from ``env.Transition_1[state]``, the remaining ones from
    ``env.Transition_0[state]``. The return value is the reward of the
    initial configuration plus ``env.gamma ** t`` times the reward of the
    configuration after step ``t`` for ``t = 1 .. trajectory_length``.

    Args:
        env: Object providing ``config``, ``N``, ``d``, ``alpha``,
            ``gamma``, ``Reward_1``, ``Reward_0``, ``Transition_1`` and
            ``Transition_0``.
        priority: Iterable of state indices in activation-priority order.
        trajectory_length: Number of simulated transitions.

    Returns:
        The discounted sum of global rewards along the trajectory.
    """
    curr_config = env.config.copy()
    Value_MC = get_global_reward(curr_config/env.N, priority, env.alpha, env.Reward_1, env.Reward_0) * env.N
    discount = env.gamma
    # Activation budget expressed in number of agents (loop-invariant).
    budget = env.alpha * env.N
    for _ in range(trajectory_length):
        next_config = np.zeros(env.d)
        current_frac = 0
        for state in priority:
            occupancy = curr_config[state]
            current_frac += occupancy
            n_agents = int(occupancy)
            if current_frac <= budget:
                n_active = n_agents
            elif current_frac <= occupancy + budget:
                # The budget runs out inside this state.
                residual = budget - (current_frac - occupancy)
                n_active = min(max(int(residual), 0), n_agents)
            else:
                n_active = 0
            # Derive the passive count from the active one so that the two
            # always add up to the state's occupancy. Truncating both parts
            # independently drops an agent whenever the budget is not an
            # exact integer (including floating-point round-off).
            n_passive = n_agents - n_active
            _sample_next_states(next_config, n_active, env.Transition_1[state], env.d)
            _sample_next_states(next_config, n_passive, env.Transition_0[state], env.d)
        curr_config = next_config
        Value_MC += discount * get_global_reward(curr_config/env.N, priority, env.alpha, env.Reward_1, env.Reward_0) * env.N
        discount *= env.gamma

    return Value_MC


def plot_results(mean_loss, std_loss):
    """
    Plot the mean loss with error bars against the number of agents.

    A zero is prepended to both series so the curve starts at the origin,
    and the i-th point is drawn at ``x = 150 * i``. A reference curve
    ``C * sqrt(N)`` is overlaid, with ``C`` chosen so that the curve passes
    through the first real data point (at ``N = 150``). The figure is saved
    to ``results_circular/loss_vs_N.png``.

    The input sequences are not modified.

    Args:
        mean_loss: Sequence of loss values, one per population size.
        std_loss: Sequence of error-bar sizes (passed as ``yerr``), same
            length as ``mean_loss``.
    """
    # Work on copies: inserting into the caller's lists would add another
    # leading zero on every call and fails for tuples / arrays.
    mean_loss = [0, *mean_loss]
    std_loss = [0, *std_loss]
    plt.rcParams['font.family'] = ['serif']
    plt.rcParams['font.serif'] = ['Times New Roman'] + plt.rcParams['font.serif']
    plt.rcParams.update({'font.size': 15})

    x = [150*i for i in range(len(mean_loss))]
    plt.figure(figsize=(8, 6))
    plt.errorbar(x, mean_loss, yerr=std_loss, fmt = 'o',capsize = 5, markeredgewidth = 1)
    plt.plot(x, mean_loss, '-', label='Local TD')
    plt.xlabel('Number of Agents')
    x_ticks = [150*i for i in range(0, len(mean_loss), 2)]
    plt.xticks(x_ticks)

    # Calculate the constant C based on the first data point
    # (index 1, because index 0 is the prepended origin).
    C = mean_loss[1] / np.sqrt(150)

    # Generate a smooth curve for the theoretical relationship
    x_smooth = np.linspace(0, 2000, 1000)
    y_theory_smooth = C * np.sqrt(x_smooth)

    # Set y-axis to display with one decimal place
    plt.gca().yaxis.set_major_formatter(plt.matplotlib.ticker.FormatStrFormatter('%.1f'))
    # Plot the theoretical curve; raw string so that ``\s`` is not parsed
    # as an (invalid) escape sequence.
    plt.plot(x_smooth, y_theory_smooth, '-', label=r'$O(\sqrt{N})$')

    # plt.ylabel(r'$\|Q^\pi-\hat{Q}\|_\mu$')
    plt.title('Error Weighted by Stationary Distribution')
    plt.legend(loc='upper left', fontsize=14)
    plt.grid(True)
    plt.savefig('results_circular/loss_vs_N.png')
    

def plot_relative_loss(mean_loss, std_loss, Q_value):
    """
    Plot the loss relative to ``Q_value`` against the number of agents.

    Each entry of ``mean_loss`` and ``std_loss`` is divided by the entry of
    ``Q_value`` at the same index, and the i-th point is drawn at
    ``x = 150 * (i + 1)``. A reference curve ``C / sqrt(N)`` is overlaid,
    with ``C`` chosen so that the curve passes through the first data
    point. The y-axis is shown in percent and limited to ``[0, 0.05]``.
    The figure is saved to ``results_circular/loss_vs_N_relative.png``.

    Args:
        mean_loss: Sequence of loss values, one per population size.
        std_loss: Sequence of error-bar sizes (passed as ``yerr``).
        Q_value: Sequence of normalizers, indexable at every position of
            ``mean_loss`` and ``std_loss``.
    """
    rel_mean = [value / Q_value[i] for i, value in enumerate(mean_loss)]
    rel_std = [value / Q_value[i] for i, value in enumerate(std_loss)]
    plt.rcParams['font.family'] = ['serif']
    plt.rcParams['font.serif'] = ['Times New Roman'] + plt.rcParams['font.serif']
    plt.rcParams.update({'font.size': 15})

    x = [150*(i+1) for i in range(len(rel_mean))]
    plt.figure(figsize=(8, 6))
    plt.errorbar(x, rel_mean, yerr=rel_std, fmt = 'o',capsize = 5, markeredgewidth = 1)
    plt.plot(x, rel_mean, '-', label='Local TD')
    plt.xlabel('Number of Agents')
    x_ticks = [150*i for i in range(0, len(rel_mean)+1, 2)]
    plt.xticks(x_ticks)

    # Calculate the constant C based on the first data point
    C = rel_mean[0] * np.sqrt(150)

    # Generate a smooth curve for the theoretical relationship. The sample
    # at N = 0 is dropped: C / sqrt(0) is a division by zero that only
    # produces a RuntimeWarning and a non-finite value.
    x_smooth = np.linspace(0, 2000, 1000)[1:]
    y_theory_smooth = C / np.sqrt(x_smooth)

    # Set y-axis to display as percentages with one decimal place
    plt.gca().yaxis.set_major_formatter(plt.matplotlib.ticker.PercentFormatter(xmax=1.0, decimals=1))
    # Set y-axis ticks with intervals of 0.01 below 0.05
    y_ticks = np.arange(0, 0.05, 0.01)
    plt.yticks(y_ticks)
    # Set y-axis limits to truncate at 0.05
    plt.ylim(0, 0.05)

    # Plot the theoretical curve; raw string so that ``\s`` is not parsed
    # as an (invalid) escape sequence.
    plt.plot(x_smooth, y_theory_smooth, '-', label=r'$O(1/\sqrt{N})$')

    # plt.ylabel(r'$\|Q^\pi-\hat{Q}\|_\mu$')
    plt.title('Relative Error')
    plt.legend(loc='upper left', fontsize=14)
    plt.grid(True)
    plt.savefig('results_circular/loss_vs_N_relative.png')


def plot_results_entanglement(mean_ent, std_ent):
    """
    Plot the halved entanglement measure against the number of agents.

    A zero is prepended to both series so the curve starts at the origin,
    every value is divided by 2, and the i-th point is drawn at
    ``x = 150 * i``. A reference curve ``C * sqrt(N)`` is overlaid, with
    ``C`` chosen so that the curve passes through the first real data
    point (at ``N = 150``). The figure is saved to
    ``results_circular/ent_vs_N.png``.

    The input sequences are not modified.

    Args:
        mean_ent: Sequence of entanglement values, one per population size.
        std_ent: Sequence of error-bar sizes (passed as ``yerr`` after
            halving), same length as ``mean_ent``.
    """
    # Build new lists instead of inserting into the caller's ones, which
    # would add another leading zero on every call and fails for
    # tuples / arrays.
    mean_ent = [value/2 for value in [0, *mean_ent]]
    std_ent = [value/2 for value in [0, *std_ent]]
    plt.rcParams['font.family'] = ['serif']
    plt.rcParams['font.serif'] = ['Times New Roman'] + plt.rcParams['font.serif']
    plt.rcParams.update({'font.size': 15})

    x = [150*i for i in range(len(mean_ent))]
    plt.figure(figsize=(8, 6))
    plt.errorbar(x, mean_ent, yerr=std_ent, fmt = 'o',capsize = 5, markeredgewidth = 1)
    plt.plot(x, mean_ent, '-', label='Monte-Carlo')
    plt.xlabel('Number of Agents')
    x_ticks = [150*i for i in range(0, len(mean_ent), 2)]
    plt.xticks(x_ticks)

    # Calculate the constant C based on the first data point
    # (index 1, because index 0 is the prepended origin).
    C = mean_ent[1] / np.sqrt(150)

    # Generate a smooth curve for the theoretical relationship
    x_smooth = np.linspace(0, 2000, 1000)
    y_theory_smooth = C * np.sqrt(x_smooth)

    # Set y-axis to display with one decimal place
    plt.gca().yaxis.set_major_formatter(plt.matplotlib.ticker.FormatStrFormatter('%.1f'))
    # Plot the theoretical curve; raw string so that ``\s`` is not parsed
    # as an (invalid) escape sequence.
    plt.plot(x_smooth, y_theory_smooth, '-', label=r'$O(\sqrt{N})$')

    # plt.ylabel(r'$\|Q^\pi-\hat{Q}\|_\mu$')
    plt.title('Empirical Measure of Entanglement')
    plt.legend(loc='upper left', fontsize=14)
    plt.grid(True)
    plt.savefig('results_circular/ent_vs_N.png')



def plot_learning_curve(est_entangle, decomp_loss, opt_entangle):
    """
    Plot the estimated entanglement and the decomposition error over time.

    Only the first 270 entries of ``est_entangle`` and ``decomp_loss`` are
    used; the i-th entry is drawn at ``x = 10 * (i + 1)``. The
    decomposition error is shown both raw (faded) and smoothed with a
    moving average of width 5, and ``opt_entangle`` is drawn as a
    horizontal dashed reference line. The figure is saved to
    ``results_circular/learning_curve.png``.

    Args:
        est_entangle: Sequence of estimated entanglement values.
        decomp_loss: Sequence of decomposition errors; after truncation it
            must have the same length as the truncated ``est_entangle``
            because both are plotted against the same x values.
        opt_entangle: Scalar y position of the reference line.
    """
    est_entangle = est_entangle[:270]
    decomp_loss = decomp_loss[:270]

    # Apply moving average to smooth decomp_loss
    window_size = 5  # You can adjust this value to control smoothness
    decomp_loss_smooth = np.convolve(decomp_loss, np.ones(window_size)/window_size, mode='valid')

    plt.rcParams['font.family'] = ['serif']
    plt.rcParams['font.serif'] = ['Times New Roman'] + plt.rcParams['font.serif']
    plt.rcParams.update({'font.size': 15})

    x = [10*(i+1) for i in range(len(est_entangle))]
    # In 'valid' mode, smoothed value k averages raw samples
    # k .. k + window_size - 1. Place it at the centre of that window so the
    # smoothed curve lines up with the raw one instead of leading it.
    center_offset = (window_size - 1) // 2
    x_smooth = [10*(i+1+center_offset) for i in range(len(decomp_loss_smooth))]

    plt.figure(figsize=(8, 6))
    plt.plot(x, est_entangle, '-', label='Estimated Entanglement Measure')

    # Plot both original and smoothed decomp_loss
    plt.plot(x, decomp_loss, '-', alpha=0.3, label='Decomposition Error (Raw)')
    plt.plot(x_smooth, decomp_loss_smooth, '-', label='Decomposition Error (Smoothed)')

    plt.xlabel('Sample Trajectory Length')
    x_ticks = [i for i in range(0, 3000, 300)]
    plt.xticks(x_ticks)
    plt.axhline(y=opt_entangle, color='r', linestyle='--', label='Entanglement Measure')
    plt.gca().yaxis.set_major_formatter(plt.matplotlib.ticker.FormatStrFormatter('%.1f'))

    y_ticks = np.arange(0, 50, 10)
    plt.yticks(y_ticks)
    plt.ylim(0, 50)

    plt.title('Estimation Error and Decomposition Error')
    plt.legend(loc='upper right', fontsize=14)
    plt.grid(True)
    plt.savefig('results_circular/learning_curve.png')