from stable_baselines3 import DQN 
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.callbacks import EvalCallback, CheckpointCallback, StopTrainingOnRewardThreshold
import torch
import gymnasium as gym
import gymnasium_robotics
gym.register_envs(gymnasium_robotics)

import argparse
from agent import *
from envs import *
from warnings import filterwarnings
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import imageio
from typing import List, Dict
from gymnasium import spaces
import torch
import os
import imageio
import logging
import math
from stable_baselines3.common.monitor import Monitor
from dotenv import load_dotenv
load_dotenv()
from stable_baselines3.common.vec_env import VecNormalize, DummyVecEnv

def _as_env_action(action):
    """Convert a single-element prediction to a plain Python scalar.

    Multi-element actions (e.g. continuous action vectors) are returned
    unchanged, because calling ``.item()`` on them raises ``ValueError``.
    """
    if hasattr(action, 'item') and np.size(action) == 1:
        return action.item()
    return action


def test_model(model, env, base_dir="results", num_episodes=10, save_video=True, fps=30, grid_cols=None):
    """
    Roll out a trained model, logging per-step data and optionally a grid video.

    For every episode a CSV with one row per step is written to
    ``<base_dir>/test_results/data/episode_<i>.csv``; a per-episode summary is
    written to ``<base_dir>/test_results/episode_summaries.csv``. When frames
    were captured, they are handed to ``save_video_grid`` (not defined in this
    function; expected to be available at module level) to produce
    ``<base_dir>/test_results/videos/grid_video.mp4``.

    Args:
        model: Trained model exposing ``predict(obs, deterministic=True)``
        env: Environment instance following the Gymnasium API
            (``reset()`` -> ``(obs, info)``, ``step()`` -> 5-tuple)
        base_dir: Base directory for saving results
        num_episodes: Number of episodes to run
        save_video: Whether to save videos
        fps: Frames per second for video saving
        grid_cols: Number of columns in the video grid (default: automatically calculated)

    Returns:
        pandas.DataFrame with one row per episode (``episode``,
        ``total_reward``, ``steps``, ``final_x``, ``final_y``, ``max_distance``).
    """

    # Create directories
    results_dir = os.path.join(base_dir, "test_results")
    video_dir = os.path.join(results_dir, "videos")
    data_dir = os.path.join(results_dir, "data")
    print(f"logging to: {results_dir}") 
    for directory in (results_dir, video_dir, data_dir):
        os.makedirs(directory, exist_ok=True)

    # Storage for plotting
    all_positions = []
    episode_summaries = []
    all_episode_frames = []  # Store frames from all episodes for grid video
    all_successes = []
    for episode in range(num_episodes):
        obs = env.reset()[0]
        done = False
        total_reward = 0
        steps = 0
        frames = []
        episode_data = []
        # Defined up front so the success lookup after the loop never depends
        # on the loop body having bound it.
        info = {}

        # Track positions for this episode
        x_positions = []
        y_positions = []

        while not done:
            # Get action from model
            action, _states = model.predict(obs, deterministic=True)
            action = _as_env_action(action)

            # Record frame if video saving is enabled
            if save_video:
                try:
                    frame = env.render()
                except Exception as e:
                    # Rendering is best-effort; a failed frame must not abort the rollout.
                    print(f"Failed to render frame: {e}")
                else:
                    # render() can return None (e.g. when the environment is not
                    # configured to return image arrays); such frames are unusable
                    # for the video, so they are skipped.
                    if frame is not None:
                        frames.append(frame)

            # Execute action
            obs, reward, terminated, truncated, info = env.step(action)
            done = terminated or truncated

            x_pos = info.get('ant_x_position', 0)
            y_pos = info.get('ant_y_position', 0)

            # Collect step data; missing info keys default to 0
            episode_data.append({
                'episode': episode,
                'step': steps,
                'action': action,
                'reward': reward,
                'ant_x_position': x_pos,
                'ant_y_position': y_pos,
                'distance_from_origin': info.get('distance_from_origin', 0),
                'front_left_leg_height': info.get('front_left_leg_height', 0),
                'front_right_leg_height': info.get('front_right_leg_height', 0),
                'back_left_leg_height': info.get('back_left_leg_height', 0),
                'back_right_leg_height': info.get('back_right_leg_height', 0)
            })
            x_positions.append(x_pos)
            y_positions.append(y_pos)

            total_reward += reward
            steps += 1

        # Save episode data
        episode_df = pd.DataFrame(episode_data)
        episode_df.to_csv(os.path.join(data_dir, f"episode_{episode}.csv"), index=False)

        # Store frames for the grid video
        if save_video and frames:
            all_episode_frames.append(frames)

        # Store positions for overall plot
        all_positions.append({
            'episode': episode,
            'x_positions': x_positions,
            'y_positions': y_positions,
        })

        # Store episode summary
        episode_summaries.append({
            'episode': episode,
            'total_reward': total_reward,
            'steps': steps,
            'final_x': x_positions[-1] if x_positions else 0,
            'final_y': y_positions[-1] if y_positions else 0,
            # default=0 mirrors the empty-list guards used for final_x / final_y
            'max_distance': max(
                ((x**2 + y**2)**0.5 for x, y in zip(x_positions, y_positions)),
                default=0,
            )
        })

        print(f"Episode {episode + 1}: Reward = {total_reward}, Steps = {steps}")
        # Success is read from the info dict of the episode's final step
        all_successes.append(info.get('success', False))
        print(all_successes) 
    print(np.mean(all_successes), np.std(all_successes)) 
    # Save episode summaries
    summary_df = pd.DataFrame(episode_summaries)
    summary_df.to_csv(os.path.join(results_dir, "episode_summaries.csv"), index=False)

    # Save the grid video
    if save_video and all_episode_frames:
        grid_video_path = os.path.join(video_dir, "grid_video.mp4")
        save_video_grid(
            frames_list=all_episode_frames,
            output_path=grid_video_path,
            fps=fps,
            grid_cols=grid_cols
        )
        print(f"Grid video saved to {grid_video_path}")

    return summary_df


# Despite its ``test_`` prefix this is an evaluation routine, not a unit test;
# opt out so pytest does not try to collect it where it is imported.
test_model.__test__ = False


def train(args, build_agent, total_timesteps, eval_freq, tensorboard_path, model_path):
    """
    Build training/evaluation environments, train an agent, and save the results.

    Two environments are created via ``setup_hrl_environment(args)``: one for
    training and one (wrapped in ``Monitor``) for evaluation. Each is placed in
    a single-env ``DummyVecEnv`` and wrapped with ``VecNormalize`` with
    observation normalization disabled. After training, the model is saved to
    ``model_path`` and the training ``VecNormalize`` statistics are saved as
    ``vec_normalize_stats.pkl`` in the same directory. Both vector environments
    are closed before returning, including when training raises.

    Args:
        args: Configuration object passed to ``setup_hrl_environment``; must
            also provide ``hrl_norm_rewards`` (used as ``norm_reward``)
        build_agent: Callable invoked as
            ``build_agent(vec_env, eval_vec_env, eval_freq, tensorboard_path=...)``
            and returning ``(model, callbacks)``
        total_timesteps: Number of timesteps passed to ``model.learn``
        eval_freq: Evaluation frequency forwarded to ``build_agent``
        tensorboard_path: TensorBoard log location forwarded to ``build_agent``
        model_path: Destination path for the trained model

    Returns:
        None
    """
    # Setup environment
    env = setup_hrl_environment(args)
    eval_env = setup_hrl_environment(args)
    # if env has actions then copy them to eval_env
    if hasattr(env, 'actions'):
        eval_env.actions = env.actions
        print(f"copying actions from env to eval_env") 
    eval_env = Monitor(eval_env)

    # DummyVecEnv expects factories; each one hands back the already-built env.
    vec_env = DummyVecEnv([lambda: env])
    vec_env = VecNormalize(vec_env, norm_obs=False, norm_reward=args.hrl_norm_rewards)

    eval_vec_env = DummyVecEnv([lambda: eval_env])
    eval_vec_env = VecNormalize(eval_vec_env, norm_obs=False, norm_reward=args.hrl_norm_rewards)

    try:
        # Train MODEL
        print("Starting training...")
        model, callbacks = build_agent(vec_env, eval_vec_env, eval_freq, tensorboard_path=tensorboard_path)
        print(f"total_timesteps: {total_timesteps}") 
        # Learn with callbacks
        model.learn(total_timesteps=total_timesteps, callback=callbacks) 

        # Save the trained model
        model.save(model_path)

        # Persist normalization statistics next to the model so they can be
        # reloaded together with it.
        vec_env.save(os.path.join(os.path.dirname(model_path), "vec_normalize_stats.pkl"))
    finally:
        # Release environment resources even if training or saving failed;
        # the nested finally makes sure the eval env is closed as well.
        try:
            vec_env.close()
        finally:
            eval_vec_env.close()