import torch
import numpy as np
from train_bilevel import BilevelEnv
from sac import Actor   # Use the same Actor definition as in training
import sys
import os
import matplotlib.pyplot as plt
# Add path for e2e environment
sys.path.append(os.path.join(os.path.dirname(__file__), 'train_e2e'))
sys.path.append(os.path.join(os.path.dirname(__file__), 'train_bilevel_no_penalty'))
from env import UnicycleEnv
# Import MPC components
from main_mpc import MPCController, UnicycleDynamics, CircleObstacle
import math
from train_bilevel_no_penalty import BilevelEnv_no_penalty
from time import time


def test_bilevel():
    """Roll out one episode of the trained bilevel policy and report metrics.

    Loads actor weights from ``trained_model/actor_bilevel_ep300.pt``, steps a
    fresh ``BilevelEnv`` until it returns ``done``, prints the episode return,
    the final ``info`` dict and the evaluation metrics, and plots the executed
    path, the latest ``hat_p`` target point recorded after each step, and the
    obstacles with their safety circles (radius ``obs.r + env.robot_r``).

    Returns:
        tuple: ``(travel_time, path_length, energy_consumption,
        computation_time)`` where

        * ``travel_time`` is ``env.t * env.dt``;
        * ``path_length`` is the summed Euclidean distance between consecutive
          recorded ``[x, y]`` positions;
        * ``energy_consumption`` is ``env.dt * sum(a[0]**2 + a[1]**2)`` over
          the executed actions (printed with unit "J", but it is a
          control-effort proxy rather than a physical energy);
        * ``computation_time`` is the mean wall-clock seconds per control step
          (policy inference plus ``env.step``).
    """
    # Initialize environment
    env = BilevelEnv()
    s_dim = len(env._get_obs())
    a_dim = 2

    # Load trained policy
    actor = Actor(s_dim, a_dim)
    actor.load_state_dict(torch.load("trained_model/actor_bilevel_ep300.pt"))
    actor.eval()

    # Roll out a single episode until the environment signals termination.
    s = env.reset()
    traj = [env.x[:2].copy()]  # only the [x, y] position is stored
    hat_p_traj = []
    actions = []  # executed actions, used for the energy metric
    done, info = False, {}
    ep_ret = 0
    start_time = time()
    while not done:
        with torch.no_grad():
            # deterministic=False presumably samples from the policy
            # distribution, so repeated runs can differ — TODO confirm intent.
            a, _ = actor(torch.FloatTensor(s).unsqueeze(0), deterministic=False)
            a = a[0].cpu().numpy()
        actions.append(a.copy())
        s, r, done, info = env.step(a)
        ep_ret += r
        traj.append(env.x[:2].copy())
        if env.hat_p_history:
            hat_p_traj.append(env.hat_p_history[-1].copy())
    end_time = time()

    # The episode has ended; report why (kept outside the timed span).
    if info.get("success", False):
        print(f"Goal reached at step {env.t}")
    elif info.get("collision", False):
        print(f"Collision at step {env.t}")

    # Evaluation metrics. The loop body runs at least once (done starts
    # False), so len(actions) >= 1.
    computation_time = (end_time - start_time) / len(actions)
    travel_time = env.t * env.dt
    traj = np.asarray(traj)
    path_length = float(np.linalg.norm(np.diff(traj, axis=0), axis=1).sum())
    energy_consumption = env.dt * sum(u[0] ** 2 + u[1] ** 2 for u in actions)

    print("Episode return:", ep_ret)
    print("Episode info:", info)
    print("\n=== Evaluation Metrics ===")
    print(f"Travel Time: {travel_time:.2f} seconds")
    print(f"Path Length: {path_length:.2f} meters")
    print(f"Energy Consumption: {energy_consumption:.2f} J")

    hat_p_traj = np.array(hat_p_traj) if hat_p_traj else None

    plt.figure(figsize=(8, 6))
    plt.plot(traj[:, 0], traj[:, 1], 'b-', lw=2, label='Actual trajectory')
    plt.plot(traj[0, 0], traj[0, 1], 'go', ms=8, label='Start')
    plt.plot(env.goal[0], env.goal[1], 'r*', ms=12, label='Goal')

    if hat_p_traj is not None:
        plt.plot(hat_p_traj[:, 0], hat_p_traj[:, 1], 'r--', alpha=0.7, label='Target trajectory points')
        plt.scatter(hat_p_traj[:, 0], hat_p_traj[:, 1], c='red', s=15, alpha=0.6)

    # Obstacles: filled body plus dashed circle inflated by the robot radius.
    for i, obs in enumerate(env.obstacles):
        circle = plt.Circle((obs.cx, obs.cy), obs.r, fill=True, alpha=0.25, color='orange')
        plt.gca().add_artist(circle)
        safety = plt.Circle((obs.cx, obs.cy), obs.r + env.robot_r, fill=False, ls='--', color='red')
        plt.gca().add_artist(safety)
        plt.text(obs.cx, obs.cy, f"O{i+1}", ha='center', va='center')

    plt.axis('equal')
    plt.grid(True, alpha=0.3)
    plt.xlabel('x [m]')
    plt.ylabel('y [m]')

    # Add evaluation metrics to the plot title
    title = 'Bilevel Control Test Results\n'
    title += f'Time: {travel_time:.2f}s | Path: {path_length:.2f}m | Energy: {energy_consumption:.2f}J | Computation Time: {computation_time:.4f}s'
    plt.title(title, fontsize=12)

    plt.legend()

    plt.tight_layout()
    plt.show()

    return travel_time, path_length, energy_consumption, computation_time


def test_bilevel_no_penalty():
    """Roll out one episode of the bilevel (no-penalty) policy and report metrics.

    Loads actor weights from
    ``trained_model/actor_bilevel_no_penalty_ep500.pt``, steps a fresh
    ``BilevelEnv_no_penalty`` until it returns ``done``, prints the episode
    return, the final ``info`` dict and the evaluation metrics, and plots the
    executed path, the latest ``hat_p`` target point recorded after each
    step, and the obstacles with their safety circles
    (radius ``obs.r + env.robot_r``).

    Returns:
        tuple: ``(travel_time, path_length, energy_consumption,
        computation_time)`` where

        * ``travel_time`` is ``env.t * env.dt``;
        * ``path_length`` is the summed Euclidean distance between consecutive
          recorded ``[x, y]`` positions;
        * ``energy_consumption`` is ``env.dt * sum(a[0]**2 + a[1]**2)`` over
          the executed actions (printed with unit "J", but it is a
          control-effort proxy rather than a physical energy);
        * ``computation_time`` is the mean wall-clock seconds per control step
          (policy inference plus ``env.step``).
    """
    # Initialize environment
    env = BilevelEnv_no_penalty()
    s_dim = len(env._get_obs())
    a_dim = 2

    # Load trained policy (no penalty version)
    actor = Actor(s_dim, a_dim)
    actor.load_state_dict(torch.load("trained_model/actor_bilevel_no_penalty_ep500.pt"))
    actor.eval()

    # Roll out a single episode until the environment signals termination.
    s = env.reset()
    traj = [env.x[:2].copy()]  # only the [x, y] position is stored
    hat_p_traj = []
    actions = []  # executed actions, used for the energy metric
    done, info = False, {}
    ep_ret = 0
    start_time = time()
    while not done:
        with torch.no_grad():
            # deterministic=False presumably samples from the policy
            # distribution, so repeated runs can differ — TODO confirm intent.
            a, _ = actor(torch.FloatTensor(s).unsqueeze(0), deterministic=False)
            a = a[0].cpu().numpy()
        actions.append(a.copy())
        s, r, done, info = env.step(a)
        ep_ret += r
        traj.append(env.x[:2].copy())
        if env.hat_p_history:
            hat_p_traj.append(env.hat_p_history[-1].copy())
    end_time = time()

    # The episode has ended; report why (kept outside the timed span).
    if info.get("success", False):
        print(f"Goal reached at step {env.t}")
    elif info.get("collision", False):
        print(f"Collision at step {env.t}")

    # Evaluation metrics. The loop body runs at least once (done starts
    # False), so len(actions) >= 1.
    computation_time = (end_time - start_time) / len(actions)
    travel_time = env.t * env.dt
    traj = np.asarray(traj)
    path_length = float(np.linalg.norm(np.diff(traj, axis=0), axis=1).sum())
    energy_consumption = env.dt * sum(u[0] ** 2 + u[1] ** 2 for u in actions)

    print("Episode return:", ep_ret)
    print("Episode info:", info)
    print("\n=== Evaluation Metrics ===")
    print(f"Travel Time: {travel_time:.2f} seconds")
    print(f"Path Length: {path_length:.2f} meters")
    print(f"Energy Consumption: {energy_consumption:.2f} J")

    plt.figure(figsize=(8, 6))
    plt.plot(traj[:, 0], traj[:, 1], 'b-', lw=2, label='Actual trajectory')
    plt.plot(traj[0, 0], traj[0, 1], 'go', ms=8, label='Start')
    plt.plot(env.goal[0], env.goal[1], 'r*', ms=12, label='Goal')

    if hat_p_traj:
        hat_p_traj = np.array(hat_p_traj)
        plt.plot(hat_p_traj[:, 0], hat_p_traj[:, 1], 'r--', alpha=0.7, label='Target trajectory points')
        plt.scatter(hat_p_traj[:, 0], hat_p_traj[:, 1], c='red', s=15, alpha=0.6)

    # Obstacles: filled body plus dashed circle inflated by the robot radius.
    for i, obs in enumerate(env.obstacles):
        circle = plt.Circle((obs.cx, obs.cy), obs.r, fill=True, alpha=0.25, color='orange')
        plt.gca().add_artist(circle)
        safety = plt.Circle((obs.cx, obs.cy), obs.r + env.robot_r, fill=False, ls='--', color='red')
        plt.gca().add_artist(safety)
        plt.text(obs.cx, obs.cy, f"O{i+1}", ha='center', va='center')

    plt.axis('equal')
    plt.grid(True, alpha=0.3)
    plt.xlabel('x [m]')
    plt.ylabel('y [m]')

    # Add evaluation metrics to the plot title
    title = 'Bilevel Control (No Penalty) Test Results\n'
    title += f'Time: {travel_time:.2f}s | Path: {path_length:.2f}m | Energy: {energy_consumption:.2f}J | Computation Time: {computation_time:.4f}s'
    plt.title(title, fontsize=12)

    plt.legend()

    plt.tight_layout()
    plt.show()

    return travel_time, path_length, energy_consumption, computation_time


def test_e2e():
    """Roll out one episode of the end-to-end SAC policy and report metrics.

    Loads actor weights from ``trained_model/actor_sac_e2e.pt``, steps a fresh
    ``UnicycleEnv`` until it returns ``done``, prints the episode return, the
    final ``info`` dict and the evaluation metrics, and plots the executed
    path and the obstacles with their safety circles
    (radius ``obs.r + env.robot_r``).

    Returns:
        tuple: ``(travel_time, path_length, energy_consumption,
        computation_time)`` where

        * ``travel_time`` is ``env.t * env.dt``;
        * ``path_length`` is the summed Euclidean distance between consecutive
          recorded ``[x, y]`` positions;
        * ``energy_consumption`` is ``env.dt * sum(a[0]**2 + a[1]**2)`` over
          the executed actions (printed with unit "J", but it is a
          control-effort proxy rather than a physical energy);
        * ``computation_time`` is the mean wall-clock seconds per control step
          (policy inference plus ``env.step``).
    """
    # Initialize environment
    env = UnicycleEnv()
    s_dim = len(env._get_obs())
    a_dim = 2

    # Load trained policy
    actor = Actor(s_dim, a_dim)
    actor.load_state_dict(torch.load("trained_model/actor_sac_e2e.pt"))
    actor.eval()

    # Roll out a single episode until the environment signals termination.
    s = env.reset()
    traj = [env.x[:2].copy()]  # only the [x, y] position is stored
    actions = []  # executed actions, used for the energy metric
    done, info = False, {}
    ep_ret = 0
    start_time = time()
    while not done:
        with torch.no_grad():
            # deterministic=False presumably samples from the policy
            # distribution, so repeated runs can differ — TODO confirm intent.
            a, _ = actor(torch.FloatTensor(s).unsqueeze(0), deterministic=False)
            a = a[0].cpu().numpy()
        actions.append(a.copy())
        s, r, done, info = env.step(a)
        ep_ret += r
        traj.append(env.x[:2].copy())
    end_time = time()

    # The episode has ended; report why (kept outside the timed span).
    if info.get("success", False):
        print(f"Goal reached at step {env.t}")
    elif info.get("collision", False):
        print(f"Collision at step {env.t}")

    # Normalize to seconds per control step so this value is comparable with
    # the per-step timing the other methods report. The loop body runs at
    # least once (done starts False), so len(actions) >= 1.
    computation_time = (end_time - start_time) / len(actions)
    travel_time = env.t * env.dt
    traj = np.asarray(traj)
    path_length = float(np.linalg.norm(np.diff(traj, axis=0), axis=1).sum())
    energy_consumption = env.dt * sum(u[0] ** 2 + u[1] ** 2 for u in actions)

    print("Episode return:", ep_ret)
    print("Episode info:", info)
    print("\n=== Evaluation Metrics ===")
    print(f"Travel Time: {travel_time:.2f} seconds")
    print(f"Path Length: {path_length:.2f} meters")
    print(f"Energy Consumption: {energy_consumption:.2f} J")

    plt.figure(figsize=(8, 6))
    plt.plot(traj[:, 0], traj[:, 1], 'b-', lw=2, label='Actual trajectory')
    plt.plot(traj[0, 0], traj[0, 1], 'go', ms=8, label='Start')
    plt.plot(env.goal[0], env.goal[1], 'r*', ms=12, label='Goal')

    # Obstacles: filled body plus dashed circle inflated by the robot radius.
    for i, obs in enumerate(env.obstacles):
        circle = plt.Circle((obs.cx, obs.cy), obs.r, fill=True, alpha=0.25, color='orange')
        plt.gca().add_artist(circle)
        safety = plt.Circle((obs.cx, obs.cy), obs.r + env.robot_r, fill=False, ls='--', color='red')
        plt.gca().add_artist(safety)
        plt.text(obs.cx, obs.cy, f"O{i+1}", ha='center', va='center')

    plt.axis('equal')
    plt.grid(True, alpha=0.3)
    plt.xlabel('x [m]')
    plt.ylabel('y [m]')

    # Add evaluation metrics to the plot title
    title = 'End-to-End Control Test Results\n'
    title += f'Time: {travel_time:.2f}s | Path: {path_length:.2f}m | Energy: {energy_consumption:.2f}J | Computation Time: {computation_time:.4f}s'
    plt.title(title, fontsize=12)

    plt.legend()
    plt.tight_layout()
    plt.show()

    return travel_time, path_length, energy_consumption, computation_time


def test_mpc():
    """Simulate the MPC baseline on a fixed start/goal/obstacle scenario.

    A unicycle starting at ``(0, 0, 0)`` is driven toward ``(3, 2)`` by an
    ``MPCController`` (horizon 35, ``dt = 0.1``). When the solver reports
    failure or returns NaN controls, a proportional go-to-goal heading
    controller is used for that step instead. The simulation stops when the
    position is within ``goal_tolerance`` of the goal, a collision is
    detected, or ``T_max / dt`` control steps have been applied. The state
    reached after the last applied control is also checked.

    Returns:
        tuple: ``(travel_time, path_length, energy_consumption,
        computation_time)`` where

        * ``travel_time`` is ``number_of_control_steps * dt``;
        * ``path_length`` is the summed Euclidean distance between consecutive
          recorded ``[x, y]`` positions;
        * ``energy_consumption`` is ``dt * sum(u[0]**2 + u[1]**2)`` over the
          applied controls (printed with unit "J", but it is a control-effort
          proxy rather than a physical energy);
        * ``computation_time`` is the mean wall-clock seconds per control step,
          or 0.0 if no control step was taken.
    """
    # MPC parameters
    dt = 0.1
    horizon = 35
    vmax = 1.0
    wmax = 1.2
    robot_radius = 0.15
    goal_tolerance = 0.2

    # Start & goal
    x0 = np.array([0.0, 0.0, 0.0])
    goal = np.array([3.0, 2.0])

    # Obstacles (same as other tests for consistency)
    obstacles = [CircleObstacle(2, 1, 0.3)]

    # Initialize MPC controller
    dyn = UnicycleDynamics(dt)
    mpc = MPCController(
        dyn,
        horizon=horizon,
        v_bounds=(0.0, vmax),
        w_bounds=(-wmax, wmax),
        robot_radius=robot_radius,
        obstacles=obstacles,
        slack_weights=(100.0, 500.0),
    )

    # Simulation parameters
    T_max = 60.0
    steps = int(T_max / dt)
    traj = [x0[:2].copy()]  # only the [x, y] position is stored
    actions = []  # applied controls, used for the energy metric

    def reached_goal(state):
        """Return True when the [x, y] position is within goal_tolerance."""
        return np.linalg.norm(state[:2] - goal) < goal_tolerance

    def in_collision(state):
        """Return True when the robot overlaps any inflated obstacle."""
        # The 0.01 subtracted from the inflated radius is kept from the
        # original check (presumably to tolerate grazing the boundary).
        return any(
            np.linalg.norm(state[:2] - np.array([obs.cx, obs.cy]))
            < (obs.r + robot_radius - 0.01)
            for obs in obstacles
        )

    def fallback_control(state):
        """Proportional go-to-goal controller used when the MPC solve fails."""
        dx = goal[0] - state[0]
        dy = goal[1] - state[1]
        dist = math.hypot(dx, dy)
        if dist <= 0.01:
            return np.array([0.0, 0.0])
        e_th = math.atan2(dy, dx) - state[2]
        # Wrap the heading error into (-pi, pi] in O(1). The heading state is
        # never normalized, so a loop-based wrap could iterate many times.
        e_th = math.atan2(math.sin(e_th), math.cos(e_th))
        v = min(vmax, dist)
        w = np.clip(2.0 * e_th, -wmax, wmax)
        return np.array([v, w])

    x = x0.copy()
    success = False
    collision = False
    start_time = time()
    # Iterate one extra time so the state reached after the final control is
    # also checked for goal/collision before the step budget runs out.
    for t in range(steps + 1):
        if reached_goal(x):
            print(f"Goal reached at step {t}")
            success = True
            break
        if in_collision(x):
            print(f"Collision at step {t}")
            collision = True
            break
        if t == steps:
            break

        # Solve MPC; `not ok` short-circuits before inspecting u_mpc.
        ok, u_mpc, _ = mpc.solve(x, goal)
        if ok and not np.any(np.isnan(u_mpc)):
            u = u_mpc
        else:
            u = fallback_control(x)
        actions.append(u.copy())

        # Forward-Euler unicycle step with the applied control.
        x = np.array([
            x[0] + dt * u[0] * math.cos(x[2]),
            x[1] + dt * u[0] * math.sin(x[2]),
            x[2] + dt * u[1],
        ])
        traj.append(x[:2].copy())
    end_time = time()

    # Evaluation metrics. Guard the division: if the start already satisfies
    # the goal/collision check, no control step is taken.
    n_steps = len(actions)
    computation_time = (end_time - start_time) / n_steps if n_steps else 0.0
    travel_time = n_steps * dt
    traj = np.asarray(traj)
    path_length = float(np.linalg.norm(np.diff(traj, axis=0), axis=1).sum())
    energy_consumption = dt * sum(u[0] ** 2 + u[1] ** 2 for u in actions)

    # Create info dictionary for consistency
    info = {
        "success": success,
        "collision": collision,
        "steps": n_steps
    }

    print("Episode info:", info)
    print("\n=== Evaluation Metrics ===")
    print(f"Travel Time: {travel_time:.2f} seconds")
    print(f"Path Length: {path_length:.2f} meters")
    print(f"Energy Consumption: {energy_consumption:.2f} J")

    # Visualize trajectory
    plt.figure(figsize=(8, 6))
    plt.plot(traj[:, 0], traj[:, 1], 'b-', lw=2, label='Actual trajectory')
    plt.plot(traj[0, 0], traj[0, 1], 'go', ms=8, label='Start')
    plt.plot(goal[0], goal[1], 'r*', ms=12, label='Goal')

    # Obstacles: filled body plus dashed circle inflated by the robot radius.
    for i, obs in enumerate(obstacles):
        circle = plt.Circle((obs.cx, obs.cy), obs.r, fill=True, alpha=0.25, color='orange')
        plt.gca().add_artist(circle)
        safety = plt.Circle((obs.cx, obs.cy), obs.r + robot_radius, fill=False, ls='--', color='red')
        plt.gca().add_artist(safety)
        plt.text(obs.cx, obs.cy, f"O{i+1}", ha='center', va='center')

    plt.axis('equal')
    plt.grid(True, alpha=0.3)
    plt.xlabel('x [m]')
    plt.ylabel('y [m]')

    # Add evaluation metrics to the plot title
    title = 'MPC Control Test Results\n'
    title += f'Time: {travel_time:.2f}s | Path: {path_length:.2f}m | Energy: {energy_consumption:.2f}J | Computation Time: {computation_time:.4f}s'
    plt.title(title, fontsize=12)

    plt.legend()
    plt.tight_layout()
    plt.show()

    return travel_time, path_length, energy_consumption, computation_time


if __name__ == "__main__":
    # Run every method n_runs times and summarize mean ± std per metric.
    n_runs = 5
    method_names = ["Bilevel Control", "Bilevel (No Penalty)", "End-to-End Control", "MPC Control"]

    # metrics[run][method] = (travel_time, path_length, energy, computation_time)
    metrics = []
    for _ in range(n_runs):
        metrics.append([
            test_bilevel(),
            test_bilevel_no_penalty(),
            test_e2e(),
            test_mpc(),
        ])

    metrics = np.asarray(metrics, dtype=float)  # (n_runs, n_methods, 4)
    means = metrics.mean(axis=0)
    stds = metrics.std(axis=0)

    # Column titles and decimal places; per-step computation times are
    # typically well below 0.01 s, so they need more precision than .2f.
    columns = ["Travel Time (s)", "Path Length (m)", "Energy (J)", "Computation Time (s)"]
    precisions = [2, 2, 2, 4]
    method_width = 25
    cell_width = 22
    line_len = method_width + len(columns) * (cell_width + 1)

    # Create comparison table
    print("\n" + "=" * line_len)
    print("PERFORMANCE COMPARISON TABLE")
    print("=" * line_len)
    print(f"{'Method':<{method_width}}" + "".join(f" {c:<{cell_width}}" for c in columns))
    print("-" * line_len)
    for method, mean_row, std_row in zip(method_names, means, stds):
        cells = [f"{m:.{p}f} ± {s:.{p}f}" for m, s, p in zip(mean_row, std_row, precisions)]
        print(f"{method:<{method_width}}" + "".join(f" {c:<{cell_width}}" for c in cells))