import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.ticker import MaxNLocator

def plot_training_history(agent):
    """Plot the training curves recorded in ``agent.train_info``.

    The set of panels depends on ``agent.algorithm``:

    * ``"SAC"`` - critic losses, actor losses, alpha values (1 x 3 grid)
    * ``"TD3"`` - critic losses, actor losses (1 x 2 grid)
    * ``"PPO"`` - total, value, action and entropy losses (2 x 2 grid)

    The figure is saved as ``"<algorithm>_training_history.png"`` in the
    current working directory and then shown.

    Args:
        agent: Object exposing ``algorithm`` (one of the names above) and
            ``train_info``, a mapping from the series names used below to
            plottable sequences.

    Raises:
        KeyError: If ``agent.train_info`` lacks a series required by the
            selected algorithm.
    """
    # Per algorithm: subplot grid plus one (train_info key, title, y-label)
    # triple per panel, in subplot order.
    layouts = {
        "SAC": ((1, 3), (
            ('critic_losses', 'Critic Losses', 'loss'),
            ('actor_losses', 'Actor Loss', 'loss'),
            ('alphas', 'Alpha Change', 'Alpha value'),
        )),
        "TD3": ((1, 2), (
            ('critic_losses', 'Critic Losses', 'loss'),
            ('actor_losses', 'Actor Loss', 'loss'),
        )),
        "PPO": ((2, 2), (
            ('total_losses', 'Total losses', 'loss'),
            ('value_losses', 'Value Loss', 'loss'),
            ('action_losses', 'Loss of action', 'loss'),
            ('entropy_losses', 'Entropy loss', 'loss'),
        )),
    }

    layout = layouts.get(agent.algorithm)
    if layout is None:
        # Previously an unknown algorithm silently produced a blank image;
        # report it instead of writing a meaningless file.
        import warnings
        warnings.warn(
            f"plot_training_history: unsupported algorithm "
            f"{agent.algorithm!r}; nothing plotted",
            stacklevel=2,
        )
        return

    (n_rows, n_cols), panels = layout
    fig = plt.figure(figsize=(15, 5))

    for position, (series_key, panel_title, y_label) in enumerate(panels, start=1):
        ax = fig.add_subplot(n_rows, n_cols, position)
        ax.plot(agent.train_info[series_key])
        ax.set_title(panel_title)
        ax.set_xlabel('Update times')
        ax.set_ylabel(y_label)
        ax.grid(True)

    fig.tight_layout()
    fig.savefig(f"{agent.algorithm}_training_history.png")
    plt.show()

    # With a non-interactive backend show() does not release the figure, so
    # repeated calls would accumulate open figures. In interactive mode the
    # window is left open for the user.
    if not plt.isinteractive():
        plt.close(fig)


def plot_evaluation_results(eval_results, algorithm_name):
    """Plot average RMSE and success rate against the training iteration.

    Draws two side-by-side line charts, saves them as
    ``"<algorithm_name>_eval_results.png"`` in the current working directory
    and then shows the figure.

    Args:
        eval_results: Sequence of entries where ``entry[0]`` is the training
            iteration and ``entry[1]`` is a mapping providing ``'avg_rmse'``
            and ``'success_rate'``.
        algorithm_name: Label used in the panel titles and the output
            file name.

    Raises:
        KeyError: If an entry's metrics lack ``'avg_rmse'`` or
            ``'success_rate'``.
    """
    # Extract every series up front so malformed input fails before any
    # figure is created.
    iterations = [entry[0] for entry in eval_results]
    rmse_values = [entry[1]['avg_rmse'] for entry in eval_results]
    success_rates = [entry[1]['success_rate'] for entry in eval_results]

    # (y-values, title suffix, y-label, extra plot kwargs) per panel.
    panels = (
        (rmse_values, 'RMSE Trends', 'Average RMSE', {}),
        (success_rates, 'Success Rate Trends', 'Success Rate', {'color': 'green'}),
    )

    fig, axes = plt.subplots(1, 2, figsize=(15, 5))

    for ax, (y_values, title_suffix, y_label, plot_kwargs) in zip(axes, panels):
        ax.plot(iterations, y_values, marker='o', linestyle='-', **plot_kwargs)
        ax.set_title(f'{algorithm_name} - {title_suffix}')
        ax.set_xlabel('Number of training iterations')
        ax.set_ylabel(y_label)
        ax.grid(True)
        # Iterations are counts: keep tick labels on whole numbers. Each axis
        # gets its own locator instance.
        ax.xaxis.set_major_locator(MaxNLocator(integer=True))

    fig.tight_layout()
    fig.savefig(f"{algorithm_name}_eval_results.png")
    plt.show()

    # With a non-interactive backend show() does not release the figure, so
    # repeated calls would accumulate open figures. In interactive mode the
    # window is left open for the user.
    if not plt.isinteractive():
        plt.close(fig)


def compare_algorithms(results_dict):
    """Compare the final evaluation metrics of several algorithms as bar charts.

    One panel is drawn per metric (average RMSE, success rate, treatment
    similarity, average steps used) on a 2 x 2 grid. The success rate is
    multiplied by 100 and shown as a percentage. The figure is saved as
    ``"algorithms_comparison.png"`` in the current working directory and
    then shown.

    Args:
        results_dict: Mapping from algorithm name to a mapping that provides
            ``'avg_rmse'``, ``'success_rate'``, ``'avg_treatment_similarity'``
            and ``'avg_steps_used'``.

    Raises:
        KeyError: If any algorithm's results lack one of the metrics above.
    """
    algorithms = list(results_dict)
    # (results key, display name) per panel, in subplot order.
    panels = (
        ('avg_rmse', 'Average RMSE'),
        ('success_rate', 'Success Rate'),
        ('avg_treatment_similarity', 'Treatment Similarity'),
        ('avg_steps_used', 'Average moves'),
    )

    fig = plt.figure(figsize=(12, 10))

    for position, (metric, display_name) in enumerate(panels, start=1):
        ax = fig.add_subplot(2, 2, position)
        is_rate = metric == 'success_rate'

        values = [results_dict[alg][metric] for alg in algorithms]
        if is_rate:
            values = [v * 100 for v in values]  # convert to percentage

        bars = ax.bar(algorithms, values, color=['skyblue', 'lightgreen', 'salmon'])

        # Label each bar with its value. The gap above the bar is given in
        # points rather than data units so it looks the same on every panel,
        # whatever the metric's scale.
        for bar in bars:
            height = bar.get_height()
            label = f'{height:.1f}%' if is_rate else f'{height:.3f}'
            ax.annotate(
                label,
                xy=(bar.get_x() + bar.get_width() / 2., height),
                xytext=(0, 3),
                textcoords='offset points',
                ha='center', va='bottom', fontsize=9,
            )

        ax.set_title(display_name)
        ax.set_xlabel('Algorithm')
        ax.set_ylabel('Probability (%)' if is_rate else display_name)
        ax.grid(axis='y', linestyle='--', alpha=0.7)

    fig.tight_layout()
    fig.savefig("algorithms_comparison.png")
    plt.show()

    # With a non-interactive backend show() does not release the figure, so
    # repeated calls would accumulate open figures. In interactive mode the
    # window is left open for the user.
    if not plt.isinteractive():
        plt.close(fig)


def _to_numpy(values):
    """Return *values* as a NumPy array, detaching tensor-like inputs and moving them to CPU."""
    if hasattr(values, 'cpu'):
        # Tensor-like object: detach first so tensors that track gradients
        # can be converted.
        if hasattr(values, 'detach'):
            values = values.detach()
        return values.cpu().numpy()
    return np.asarray(values)


def plot_treatment_plan(agent, dataset_collection, history_dict, goal, title=None):
    """Generate a treatment plan with *agent* and visualize it.

    The top panel shows the historical and forecast output (component 0)
    together with the target; the bottom panel shows historical and planned
    treatments. The figure is saved as ``"<algorithm>_treatment_plan.png"``
    in the current working directory and shown, and the final-step error
    metrics are printed.

    Args:
        agent: Trained model exposing ``generate_treatment_plan``,
            ``future_length`` and ``algorithm``.
        dataset_collection: Passed through to ``agent.generate_treatment_plan``.
        history_dict: Mapping with ``'current_treatments'`` and, optionally,
            ``'outputs'`` or ``'prev_outputs'``. Entry 0 along the first axis
            is plotted (presumably the batch axis - confirm against caller).
        goal: Target state; tensor-like objects (anything with ``.cpu()``) are
            converted to NumPy, other values are used as given.
        title: Title for the top panel; defaults to
            ``'State change trajectory'``.

    Returns:
        Tuple ``(treatments, outputs, mse, steps_used)`` where ``treatments``,
        ``outputs`` and ``steps_used`` come from
        ``agent.generate_treatment_plan`` and ``mse`` is the mean squared
        error between the last output and the goal.
    """
    # Generate the treatment plan.
    treatments, outputs, steps_used = agent.generate_treatment_plan(
        history_dict,
        goal,
        dataset_collection,
        future_length=agent.future_length,
        early_stop=True
    )

    goal_value = _to_numpy(goal) if hasattr(goal, 'cpu') else goal
    # Only output component 0 is plotted, so the target line uses the goal's
    # first component; axhline needs a single scalar.
    # NOTE(review): assumes goal component 0 corresponds to output component 0.
    target_level = float(np.ravel(goal_value)[0])

    # Historical outputs: first entry, all time steps, output component 0.
    if 'outputs' in history_dict:
        history_outputs = _to_numpy(history_dict['outputs'][0, :, 0])
    elif 'prev_outputs' in history_dict:
        history_outputs = _to_numpy(history_dict['prev_outputs'][0, :, 0])
    else:
        history_outputs = np.asarray([])

    history_treatments = _to_numpy(history_dict['current_treatments'][0])

    # Time axes. The forecast starts where the output history ends. The
    # treatment history gets its own axis so it still plots when no output
    # history is available or the two histories differ in length.
    history_time = np.arange(len(history_outputs))
    treatment_time = np.arange(len(history_treatments))
    future_time = np.arange(len(history_outputs), len(history_outputs) + len(outputs))

    fig, (state_ax, treatment_ax) = plt.subplots(2, 1, figsize=(15, 8))

    # Top panel: state trajectory versus target.
    state_ax.plot(history_time, history_outputs, 'o-', label='Historical Output')
    state_ax.plot(future_time, [o[0] for o in outputs], 'o-', label='Forecast Output')
    state_ax.axhline(y=target_level, color='r', linestyle='--', label='Target')
    state_ax.set_xlabel('Time step:')
    state_ax.set_ylabel('State value')
    state_ax.set_title('State change trajectory' if title is None else title)
    state_ax.legend()
    state_ax.grid(True)

    # Bottom panel: one line per treatment dimension, history then plan.
    for i in range(history_treatments.shape[1]):
        treatment_ax.plot(treatment_time, history_treatments[:, i], 'o-',
                          label=f'Historical Healing {i +1}')

    for i in range(treatments.shape[1]):
        treatment_ax.plot(future_time, treatments[:, i], 'o-',
                          label=f'Predictive treatment {i +1}')

    treatment_ax.set_xlabel('Time step:')
    treatment_ax.set_ylabel('Treatment intensity')
    # The previous title contained leftover <g id="Bold">...</g> markup that
    # matplotlib rendered literally.
    treatment_ax.set_title('Treatment Planning:')
    treatment_ax.legend()
    treatment_ax.grid(True)

    fig.tight_layout()
    fig.savefig(f"{agent.algorithm}_treatment_plan.png")
    plt.show()

    # With a non-interactive backend show() does not release the figure, so
    # repeated calls would accumulate open figures. In interactive mode the
    # window is left open for the user.
    if not plt.isinteractive():
        plt.close(fig)

    # Error of the final forecast step against the goal.
    mse = ((outputs[-1] - goal_value) ** 2).mean()
    rmse = np.sqrt(mse)

    print(f"Goal: {goal_value}")
    print(f"Final forecast: {outputs[-1]}")
    print(f"MSE: {mse:.6f}")
    print(f"RMSE: {rmse:.6f}")
    print(f"Steps: {steps_used} of {agent.future_length}")

    return treatments, outputs, mse, steps_used