"""
Training Procedure for Cat-Mouse Chase Environment with Termination Probability Commitment

Cat has 8 direction options (including diagonals).
Mouse moves every step, turns with probability p_turn (only away from cat).

Instead of fixed duration, agent commits to a TERMINATION PROBABILITY at option start.
Each step, option terminates with this probability (geometric distribution).

Key hypothesis: Higher p_turn → higher optimal termination probability
- p_turn=0: Mouse predictable → low term prob (stay in option longer)
- p_turn=1: Mouse unpredictable → high term prob (re-aim frequently)
"""

import numpy as np
import torch
from options.Option import Option
from options.IntraOptionPolicies.ConstantIntraOptionPolicy import ConstantIntraOptionPolicy
from options.Buffer.Buffer import Buffer
from MetaPolicies.NNMetaPolicyChase import (
    NNMetaPolicyChase, TERM_PROBS, POSSIBLE_ACTIONS
)
from env.open_grid.open_grid_chase import OpenGridChase
class DummyTermination:
    """Placeholder termination (not used in duration mode)."""
    def get_termination_probability(self, state, target):
        return 0.0

class TrainingProcedureChase:
    """Training driver for the cat-mouse chase with committed termination probabilities.

    At every decision point the cat picks one of the direction options together
    with a termination probability (an entry of ``TERM_PROBS``). The environment
    then runs that option until it terminates, the episode ends, or the step
    budget is exhausted. After each episode the meta-policy is updated from the
    transitions collected in the buffer.

    Attributes populated during training:
        episode_returns: Sum of buffered rewards for each finished episode.
        avg_term_prob_history: Mean of the termination probabilities chosen
            within each finished episode.
    """

    def __init__(
        self,
        p_turn: float = 0.0,
        switch_reward: float = -0.1,
        step_reward: float = -0.1,
        terminal_reward: float = 10.0,
        gamma: float = 0.99,
        epsilon: float = 0.1,
        num_episodes: int = 1000,
        max_episode_steps: int = 50,
        hidden_size: int = 64,
        lr_policy: float = 0.001,
        verbose: bool = True,
        print_every: int = 100,
    ):
        """Store hyper-parameters; heavy components are built later by ``reset()``.

        Args:
            p_turn: Forwarded to ``OpenGridChase`` (mouse turning probability).
            switch_reward: Forwarded to ``OpenGridChase``.
            step_reward: Forwarded to ``OpenGridChase``.
            terminal_reward: Forwarded to ``OpenGridChase``.
            gamma: Discount factor, given to both the environment and the
                meta-policy update.
            epsilon: Probability of picking a uniformly random option and
                termination probability instead of asking the meta-policy.
            num_episodes: Number of training episodes run by ``train()``.
            max_episode_steps: Step budget per episode.
            hidden_size: Hidden layer size handed to ``NNMetaPolicyChase``.
            lr_policy: Learning rate handed to ``NNMetaPolicyChase``.
            verbose: Whether ``train()`` prints periodic progress lines.
            print_every: Progress is printed every this many episodes; ``0``
                disables the periodic progress lines.
        """
        self.p_turn = p_turn
        self.switch_reward = switch_reward
        self.step_reward = step_reward
        self.terminal_reward = terminal_reward
        self.gamma = gamma
        self.epsilon = epsilon
        self.num_episodes = num_episodes
        self.max_episode_steps = max_episode_steps
        self.hidden_size = hidden_size
        self.lr_policy = lr_policy
        self.verbose = verbose
        self.print_every = print_every

        # Built by reset(); None until then.
        self.env = None
        self.buffer = None
        self.options = None
        self.meta_policy = None

        self.episode_returns = []
        # Same attribute name that reset()/train()/get_results() use, so the
        # history is readable on a fresh instance.
        self.avg_term_prob_history = []
        # Legacy attribute name kept for backward compatibility; nothing in
        # this class writes to it.
        self.avg_duration_history = []

    def reset(self):
        """Initialize all components."""
        self.buffer = Buffer()
        self.env = OpenGridChase(
            p_turn=self.p_turn,
            gamma=self.gamma,
            switch_reward=self.switch_reward,
            step_reward=self.step_reward,
            terminal_reward=self.terminal_reward,
        )

        # One constant-action option per entry of POSSIBLE_ACTIONS.
        self.options = [
            Option(i, ConstantIntraOptionPolicy(action), DummyTermination())
            for i, action in enumerate(POSSIBLE_ACTIONS)
        ]

        self.meta_policy = NNMetaPolicyChase(
            self.options, hidden_size=self.hidden_size, learning_rate=self.lr_policy
        )

        self.episode_returns = []
        self.avg_term_prob_history = []

    def _select_option(self, cat_state, mouse_state, mouse_dir):
        """Epsilon-greedy choice of ``(option, term_prob, term_prob_idx)``."""
        if np.random.random() < self.epsilon:
            # Explore: uniform over options, then uniform over TERM_PROBS.
            option = self.options[np.random.randint(len(self.options))]
            term_prob_idx = np.random.randint(len(TERM_PROBS))
            return option, TERM_PROBS[term_prob_idx], term_prob_idx

        option, term_prob, term_prob_idx = self.meta_policy.choose_option_and_term_prob(
            cat_state, mouse_state, mouse_dir
        )
        return option, term_prob, term_prob_idx

    def train(self):
        """Run the training loop and return the dictionary built by ``get_results()``.

        ``reset()`` is called first, so every call starts from freshly built
        components and empty histories.
        """
        self.reset()

        for episode in range(self.num_episodes):
            self.buffer.clear()
            cat_state, mouse_state = self.env._reset()
            mouse_dir = self.env.mouse_direction
            done = False
            episode_steps = 0
            episode_term_probs = []

            while not done and episode_steps < self.max_episode_steps:
                option, term_prob, term_prob_idx = self._select_option(
                    cat_state, mouse_state, mouse_dir
                )
                episode_term_probs.append(term_prob)

                # Never let a single option run past the episode step budget.
                remaining = self.max_episode_steps - episode_steps

                steps_done, done = self.env.execute_option_with_term_prob(
                    option, term_prob, term_prob_idx, self.buffer, max_steps=remaining
                )
                episode_steps += steps_done

                cat_state = self.env.state
                mouse_state = self.env.mouse_state
                mouse_dir = self.env.mouse_direction

            episode_return = sum(exp.reward for exp in self.buffer.buffer)
            self.episode_returns.append(episode_return)

            # The episode hit the step budget without the environment reporting done.
            truncated = not done and episode_steps >= self.max_episode_steps
            last_exp = self.buffer.buffer[-1] if self.buffer.buffer else None

            self.meta_policy.update(
                self.buffer, gamma=self.gamma, epsilon=self.epsilon,
                truncated=truncated,
                last_cat_state=last_exp.next_state if last_exp else None,
                last_mouse_state=last_exp.next_target_state if last_exp else None,
                last_mouse_dir=self.env.mouse_direction
            )

            avg_term_prob = np.mean(episode_term_probs) if episode_term_probs else 0
            self.avg_term_prob_history.append(avg_term_prob)

            # print_every == 0 means "no periodic output" instead of a
            # ZeroDivisionError from the modulo.
            if self.verbose and self.print_every and (episode + 1) % self.print_every == 0:
                recent_return = np.mean(self.episode_returns[-100:])
                recent_term_prob = np.mean(self.avg_term_prob_history[-100:])
                global_term_prob = self.meta_policy.get_average_term_prob()

                print(
                    f"[p_turn={self.p_turn}, sw={self.switch_reward}] "
                    f"Episode {episode+1}/{self.num_episodes}: "
                    f"Return={episode_return:.1f}, Avg100={recent_return:.1f}, "
                    f"TermProb={recent_term_prob:.2f}, GlobalTermProb={global_term_prob:.2f}"
                )

        return self.get_results()

    def get_results(self):
        """Print a summary of the chosen termination probabilities and return the results.

        The summary is printed regardless of ``verbose``. Requires the
        meta-policy to exist, i.e. ``reset()`` or ``train()`` must have run.

        Returns:
            dict with keys ``p_turn``, ``switch_reward``, ``episode_returns``
            (array), ``avg_term_prob_history`` (array), ``final_avg_term_prob``
            and ``term_prob_distribution``.
        """
        term_prob_dist = self.meta_policy.get_term_prob_distribution()
        global_avg_term_prob = self.meta_policy.get_average_term_prob()

        print(f"\n*** TERM PROB DISTRIBUTION (p_turn={self.p_turn}, switch={self.switch_reward}) ***")
        for p, frac in sorted(term_prob_dist.items()):
            print(f"    p={p:.1f}: {frac*100:.1f}%")
        print(f"\n*** GLOBAL AVERAGE TERM PROB: {global_avg_term_prob:.3f} ***")
        print("    (Higher p_turn → higher term prob expected)\n")

        return {
            "p_turn": self.p_turn,
            "switch_reward": self.switch_reward,
            "episode_returns": np.array(self.episode_returns),
            "avg_term_prob_history": np.array(self.avg_term_prob_history),
            "final_avg_term_prob": global_avg_term_prob,
            "term_prob_distribution": term_prob_dist,
        }

