import numpy as np
import os
import json
from datetime import datetime
import matplotlib.pyplot as plt


class DQNTrainer:
    def __init__(self, agent, env, evaluator, save_dir="results"):
        """Store the collaborators used by the training loop.

        Args:
            agent: Object driven by ``train`` (acting, learning, saving).
            env: Environment the agent interacts with during training.
            evaluator: Object used for the periodic evaluations.
            save_dir: Directory under which training output is written.
        """
        # Output location first, then the three collaborators.
        self.save_dir = save_dir
        self.evaluator = evaluator
        self.env = env
        self.agent = agent

    def train(self, episodes=1000, max_steps=200, save_interval=100, eval_seeds=[42, 520, 1231, 114, 886]):
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.save_dir = os.path.join(self.save_dir, f"dqn_lunarlander_{timestamp}")
        os.makedirs(self.save_dir, exist_ok=True)

        rewards_history = []
        episode_lengths = []
        moving_avg_rewards = []
        evaluation_history = []

        for episode in range(1, episodes + 1):
            state, _ = self.env.reset()
            total_reward = 0
            done = False
            steps = 0

            while not done and steps < max_steps:
                action = self.agent.act(state)
                next_state, reward, terminated, truncated, _ = self.env.step(action)
                done = terminated or truncated

                self.agent.memory.push(state, action, reward, next_state if not done else None, done)
                self.agent.learn()

                state = next_state
                total_reward += reward
                steps += 1

            rewards_history.append(total_reward)
            episode_lengths.append(steps)

            # Calculate moving average
            if episode >= 100:
                avg_reward = np.mean(rewards_history[-100:])
                moving_avg_rewards.append(avg_reward)
            else:
                moving_avg_rewards.append(np.mean(rewards_history))

            print(
                f"Episode {episode}/{episodes}, Reward: {total_reward:.2f}, Steps: {steps}, Epsilon: {self.agent.epsilon:.2f}")

            # Periodic evaluation
            if episode % save_interval == 0 or episode == episodes:
                eval_results = self.evaluator.evaluate(self.agent, eval_seeds, max_steps)
                evaluation_history.append({
                    'episode': episode,
                    'metrics': eval_results
                })

                print(f"\nEvaluation at Episode {episode}:")
                print(f"  Mean Reward: {eval_results['mean_reward']:.2f}")
                print(f"  Mean Fuel: {eval_results['mean_fuel']:.2f}")
                print(f"  Success Rate: {eval_results['success_rate']:.2f}")
                print(f"  NWS: {eval_results['nws']:.2f}")

                # Save evaluation image
                img_data = self.evaluator.save_eval_image(eval_results['worst_case_image'],
                                                          self.save_dir,
                                                          f"eval_ep{episode}_worst.png")

                # Save model
                model_path = os.path.join(self.save_dir, f"dqn_model_ep{episode}.pth")
                self.agent.save(model_path)
                print(f"Model saved to {model_path}\n")

        # Save final model and training data
        self._save_training_results(rewards_history, episode_lengths,
                                    moving_avg_rewards, evaluation_history, eval_seeds)

        return self.save_dir

    def _save_training_results(self, rewards_history, episode_lengths,
                               moving_avg_rewards, evaluation_history, eval_seeds):
        final_model_path = os.path.join(self.save_dir, "dqn_model_final.pth")
        self.agent.save(final_model_path)

        training_data = {
            'rewards': rewards_history,
            'episode_lengths': episode_lengths,
            'moving_avg_rewards': moving_avg_rewards,
            'evaluation_history': evaluation_history,
            'env_seeds': eval_seeds
        }
        with open(os.path.join(self.save_dir, 'training_history.json'), 'w') as f:
            json.dump(training_data, f)

        # Plot training results
        self._plot_training_results(rewards_history, moving_avg_rewards, evaluation_history)
        self._plot_evaluation_metrics(evaluation_history)
        self._plot_best_nws_progress(evaluation_history)

    def _plot_training_results(self, rewards_history, moving_avg_rewards, evaluation_history):
        """Plot per-episode and moving-average rewards with evaluation markers.

        The figure is written to ``training_plot.png`` in ``self.save_dir``.

        Args:
            rewards_history: Total reward per training episode.
            moving_avg_rewards: Moving-average reward per training episode.
            evaluation_history: List of ``{'episode': ..., 'metrics': ...}``
                entries; only the 1-based ``'episode'`` numbers are used here.
        """
        plt.figure(figsize=(12, 6))

        # Plot against 1-based episode numbers. Without explicit x values the
        # curves would sit on 0-based indices while the evaluation markers use
        # episode numbers, shifting every marker one unit off its curve point.
        reward_episodes = range(1, len(rewards_history) + 1)
        avg_episodes = range(1, len(moving_avg_rewards) + 1)
        plt.plot(reward_episodes, rewards_history, alpha=0.6, label='Episode Reward')
        plt.plot(avg_episodes, moving_avg_rewards, 'r-', linewidth=2, label='Moving Avg (100 episodes)')

        # Mark evaluation points on the moving-average curve
        eval_episodes = [e['episode'] for e in evaluation_history]
        plt.scatter(eval_episodes, [moving_avg_rewards[e - 1] for e in eval_episodes],
                    c='green', marker='o', label='Evaluation Points')

        plt.xlabel('Episode')
        plt.ylabel('Reward')
        plt.title('DQN Training Performance with Periodic Evaluation')
        plt.legend()
        plt.grid()
        plt.savefig(os.path.join(self.save_dir, 'training_plot.png'))
        plt.close()

    def _plot_evaluation_metrics(self, evaluation_history):
        """Render a 2x2 grid with one evaluation metric per panel.

        Each panel plots one entry of ``e['metrics']`` against ``e['episode']``
        for every item in ``evaluation_history``. The figure is saved as
        ``evaluation_metrics.png`` in ``self.save_dir``.
        """
        # (metrics key, y-axis label, panel title), in panel order 1..4.
        panels = (
            ('mean_reward', 'Mean Reward', 'Evaluation Mean Reward'),
            ('mean_fuel', 'Mean Fuel', 'Evaluation Mean Fuel'),
            ('success_rate', 'Success Rate', 'Evaluation Success Rate'),
            ('nws', 'NWS', 'Normalized Weighted Score'),
        )

        plt.figure(figsize=(12, 8))

        for slot, (metric_key, y_label, heading) in enumerate(panels, start=1):
            plt.subplot(2, 2, slot)
            xs = [entry['episode'] for entry in evaluation_history]
            ys = [entry['metrics'][metric_key] for entry in evaluation_history]
            plt.plot(xs, ys, 'o-')
            plt.xlabel('Episode')
            plt.ylabel(y_label)
            plt.title(heading)
            plt.grid()

        plt.tight_layout()
        out_path = os.path.join(self.save_dir, 'evaluation_metrics.png')
        plt.savefig(out_path)
        plt.close()

    def _plot_best_nws_progress(self, evaluation_history):
        """Plot the running best NWS across the evaluation points.

        For every entry in ``evaluation_history`` the best ``'nws'`` metric
        seen so far is plotted against the entry's ``'episode'``. The first
        point that reaches the overall best value is annotated. The figure is
        saved as ``best_nws_progression.png`` in ``self.save_dir``.

        An empty ``evaluation_history`` produces an empty, un-annotated plot
        instead of raising.
        """
        best_nws = -float('inf')
        best_nws_history = []
        episodes = []

        for eval_point in evaluation_history:
            current_nws = eval_point['metrics']['nws']
            if current_nws > best_nws:
                best_nws = current_nws
            # The running best is recorded for every evaluation point, whether
            # or not this point improved on it.
            best_nws_history.append(best_nws)
            episodes.append(eval_point['episode'])

        plt.figure(figsize=(10, 6))
        plt.plot(episodes, best_nws_history, 'g-o', linewidth=2, markersize=8)

        # Annotate the best point. np.argmax raises ValueError on an empty
        # sequence, so skip the annotation when no evaluation was recorded.
        if best_nws_history:
            max_idx = int(np.argmax(best_nws_history))
            plt.annotate(f'Best NWS: {best_nws_history[max_idx]:.3f}',
                         xy=(episodes[max_idx], best_nws_history[max_idx]),
                         xytext=(10, 10), textcoords='offset points',
                         bbox=dict(boxstyle='round,pad=0.5', fc='yellow', alpha=0.5),
                         arrowprops=dict(arrowstyle='->'))

        plt.xlabel('Episode')
        plt.ylabel('Best NWS')
        plt.title('Progression of Best Normalized Weighted Score')
        plt.grid(True)

        # Save the plot
        plot_path = os.path.join(self.save_dir, 'best_nws_progression.png')
        plt.savefig(plot_path)
        plt.close()