#!/usr/bin/env python3
"""
Bilevel Control System Training - Using SAC Algorithm
Upper level: SAC policy generates target trajectory points hat_p
Lower level: QP solver tracks trajectory points
"""

import numpy as np
import torch
import matplotlib.pyplot as plt
import math
import casadi as ca
from datetime import datetime
import os
from sac import SAC, Replay

class Circle:
    """Circular obstacle with centre (cx, cy) and radius r."""

    def __init__(self, cx, cy, r):
        self.cx = cx
        self.cy = cy
        self.r = r


class BilevelEnv:
    """
    Bilevel control environment for a unicycle robot:
    - Upper level RL: outputs an offset [dx, dy] that defines the next target
      trajectory point hat_p (offset length is capped at max_traj_dist).
    - Lower level QP: single-step MPC solves for the control input [v, ω] that
      tracks hat_p while keeping clear of the circular obstacles.
    """
    def __init__(self, dt=0.1, vmax=1.0, wmax=1.0, max_steps=200, max_traj_dist=0.1):
        self.dt = dt
        self.vmax = vmax
        self.wmax = wmax
        self.max_steps = max_steps
        self.max_traj_dist = max_traj_dist  # maximum distance between trajectory points

        # Environment setup
        self.start = np.array([0., 0., 0.], dtype=np.float32)  # [x, y, θ]
        self.goal = np.array([3., 2.], dtype=np.float32)
        self.obstacles = [Circle(2, 1, 0.3)]
        self.robot_r = 0.15
        self.goal_tol = 0.2

        # QP solver parameters
        self.cp = 1.0    # position tracking weight
        self.cu = 0.01   # control input weight
        self.d_safe = 0.1  # extra clearance used in the QP obstacle constraint

        # Trajectory recording
        self.trajectory = []
        self.hat_p_history = []  # target trajectory points history
        self.actions = []
        self.rewards = []
        self.episode_info = {}

        self.reset()

    def reset(self):
        """Reset the robot to the start state, clear the records and return the first observation."""
        self.x = self.start.copy()
        self.t = 0

        # Clear trajectory records
        self.trajectory = [self.x.copy()]
        self.hat_p_history = []
        self.actions = []
        self.rewards = []
        self.episode_info = {
            'start_time': datetime.now().isoformat(),
            'start_pos': self.start.copy(),
            'goal_pos': self.goal.copy(),
            'obstacles': [(o.cx, o.cy, o.r) for o in self.obstacles],
            'robot_radius': self.robot_r,
            'dt': self.dt
        }

        return self._get_obs()

    def _get_obs(self):
        """Get observation vector: [x, y, θ, goal_x, goal_y, obs1_x, obs1_y, obs1_r, ...]"""
        obs = [self.x[0], self.x[1], self.x[2], self.goal[0], self.goal[1]]
        for obstacle in self.obstacles:
            obs.extend([obstacle.cx, obstacle.cy, obstacle.r])
        return np.array(obs, dtype=np.float32)

    def _collision_check(self, pos):
        """Return True if the robot disc at `pos` touches or overlaps any obstacle."""
        for obs in self.obstacles:
            dist = np.linalg.norm(pos - np.array([obs.cx, obs.cy]))
            if dist <= (obs.r + self.robot_r):
                return True
        return False

    @staticmethod
    def _wrap_angle(angle):
        """Wrap an angle in radians to the interval [-π, π]."""
        return float(np.arctan2(np.sin(angle), np.cos(angle)))

    def _solve_qp(self, current_state, hat_p):
        """
        Solve the single-step MPC problem for the control input.

        The objective sums four kinds of terms:
        - position tracking of the predicted next position to hat_p;
        - heading alignment towards hat_p;
        - control-effort regularisation;
        - a hinge term t1 >= max(0, alpha[0]·u - beta[0]), plus a heavy penalty
          on the obstacle-constraint slacks.
        It is subject to 0 <= v <= vmax and |ω| <= wmax.

        If IPOPT fails to converge, the last iterate is used instead of
        raising. The result is clipped to the actuator bounds.

        Args:
            current_state: current state [x, y, θ]
            hat_p: target trajectory point [x_target, y_target]

        Returns:
            u: control input [v, ω] as a length-2 numpy array
        """
        opti = ca.Opti()
        u = opti.variable(2)   # [v, ω]
        t1 = opti.variable(1)  # epigraph variable for the hinge term

        # Current state
        x, y, theta = current_state
        hat_p = np.asarray(hat_p, dtype=float)
        # NOTE(review): re-read from disk on every call; cache these if
        # alpha.npy / beta.npy are not modified while training runs.
        alpha = np.load('alpha.npy')
        beta = np.load('beta.npy')

        # Predict next step state (one explicit Euler step of the unicycle model)
        x_next = x + self.dt * u[0] * ca.cos(theta)
        y_next = y + self.dt * u[0] * ca.sin(theta)
        theta_next = theta + self.dt * u[1]

        p_next = ca.vertcat(x_next, y_next)
        hat_p_ca = ca.vertcat(hat_p[0], hat_p[1])

        # The desired heading is expressed as theta plus a wrapped heading error.
        # theta is never wrapped by step(), so comparing it against a raw atan2
        # value would create spurious ±2π errors in the quadratic heading cost.
        direction = hat_p - np.array([x, y])
        if np.linalg.norm(direction) > 1e-6:
            heading_err = self._wrap_angle(np.arctan2(direction[1], direction[0]) - theta)
        else:
            heading_err = 0.0  # target coincides with the robot: no heading preference
        desired_theta = theta + heading_err

        # Objective: track trajectory point + heading alignment + control regularization + hinge
        position_cost = self.cp * ca.sumsqr(p_next - hat_p_ca)
        heading_cost = 0.5 * (theta_next - desired_theta) ** 2
        control_cost = self.cu * ca.sumsqr(u)
        cost = position_cost + control_cost + heading_cost + t1

        # Control constraints
        opti.subject_to(u[0] >= 0.0)          # v >= 0
        opti.subject_to(u[0] <= self.vmax)    # v <= vmax
        opti.subject_to(u[1] >= -self.wmax)   # ω >= -wmax
        opti.subject_to(u[1] <= self.wmax)    # ω <= wmax

        # Soft obstacle constraints: each slack variable (>= 0) is heavily
        # penalised in the objective, so it is only used when the hard
        # constraint cannot be met.
        for obs in self.obstacles:
            obs_pos = ca.vertcat(obs.cx, obs.cy)
            slack = opti.variable()
            opti.subject_to(slack >= 0)
            opti.subject_to(
                ca.sumsqr(p_next - obs_pos) + slack
                >= (obs.r + self.robot_r + self.d_safe) ** 2
            )
            cost += 1000 * slack

        # Hinge term: t1 >= max(0, alpha[0, :] · u - beta[0])
        opti.subject_to(t1 >= 0)
        opti.subject_to(t1 >= alpha[0, 0] * u[0] + alpha[0, 1] * u[1] - beta[0])

        # Set the objective only after every cost term (including the slack
        # penalties) has been accumulated; Opti does not track later edits.
        opti.minimize(cost)

        # Solver settings
        opts = {
            'ipopt.print_level': 0,
            'print_time': 0,
            'ipopt.max_iter': 100,
            'ipopt.tol': 1e-4
        }
        opti.solver('ipopt', opts)

        # Initial guess: the angular velocity that would close the heading error in one step
        initial_omega = float(np.clip(heading_err / self.dt, -self.wmax, self.wmax))
        opti.set_initial(u, [0.5, initial_omega])

        try:
            sol = opti.solve()
            u_opt = sol.value(u)
        except RuntimeError:
            # IPOPT did not converge (e.g. iteration limit). Use the last
            # iterate rather than aborting the whole training run.
            u_opt = opti.debug.value(u)

        v = float(np.clip(u_opt[0], 0.0, self.vmax))
        w = float(np.clip(u_opt[1], -self.wmax, self.wmax))
        return np.array([v, w])

    def step(self, hat_p_action):
        """
        Execute one step.

        Args:
            hat_p_action: target trajectory point offset output by upper-level RL [dx, dy];
                          rescaled to length max_traj_dist if it is longer

        Returns:
            next_state, reward, done, info
        """
        self.t += 1

        # Convert action to target trajectory point
        current_pos = self.x[:2]

        # Limit trajectory point distance
        hat_p_offset = np.array(hat_p_action[:2])
        offset_norm = np.linalg.norm(hat_p_offset)
        if offset_norm > self.max_traj_dist:
            hat_p_offset = hat_p_offset / offset_norm * self.max_traj_dist

        hat_p = current_pos + hat_p_offset

        # Solve control input using QP
        u = self._solve_qp(self.x, hat_p)

        # Execute dynamics
        prev_pos = self.x[:2].copy()
        prev_dist = np.linalg.norm(prev_pos - self.goal)

        self.x[0] += self.dt * u[0] * math.cos(self.x[2])
        self.x[1] += self.dt * u[0] * math.sin(self.x[2])
        self.x[2] += self.dt * u[1]

        # Progress and success are measured from the robot's actual position,
        # not from the commanded target point hat_p.
        current_dist = np.linalg.norm(self.x[:2] - self.goal)

        # Compute reward
        reward = self._compute_reward(prev_pos, self.x[:2], hat_p, u, prev_dist, current_dist)

        # Check termination conditions
        done = False
        info = {}

        if current_dist < self.goal_tol:
            reward += 20.0
            done = True
            info["success"] = True

        if self._collision_check(self.x[:2]):
            reward -= 20.0
            done = True
            info["collision"] = True

        if self.t >= self.max_steps:
            done = True
            info["timeout"] = True

        # Record trajectory
        self.trajectory.append(self.x.copy())
        self.hat_p_history.append(hat_p.copy())
        self.actions.append(u.copy())
        self.rewards.append(float(reward))

        # Update episode information
        if done:
            self.episode_info.update({
                'end_time': datetime.now().isoformat(),
                'total_steps': self.t,
                'final_distance': float(current_dist),
                'total_reward': sum(self.rewards),
                'success': info.get("success", False),
                'collision': info.get("collision", False),
                'timeout': info.get("timeout", False)
            })

        return self._get_obs(), float(reward), done, info

    def _compute_reward(self, prev_pos, current_pos, hat_p, u, prev_dist, current_dist):
        """
        Compute the shaped step reward, excluding terminal bonuses.

        The reward has two parts:
        - progress towards the goal: 2x when the robot gets closer, 3x penalty
          when it moves away;
        - a linear penalty for entering a 0.2 margin around each obstacle.
        prev_pos, hat_p and u are accepted for interface stability but are unused.
        """
        # 1. Goal approach reward
        progress = prev_dist - current_dist
        if progress > 0:
            goal_reward = 2.0 * progress
        else:
            goal_reward = 3.0 * progress  # penalty for moving away from goal

        # 2. Obstacle proximity penalty
        obstacle_penalty = 0.0
        for obs in self.obstacles:
            obs_pos = np.array([obs.cx, obs.cy])
            dist_to_obs = np.linalg.norm(current_pos - obs_pos)
            safety_dist = obs.r + self.robot_r + 0.2
            if dist_to_obs < safety_dist:
                obstacle_penalty -= 3.0 * (safety_dist - dist_to_obs)

        total_reward = goal_reward + obstacle_penalty
        return total_reward


def train(epochs=1000, save_dir="trained_model"):
    """
    Train the upper-level SAC policy of the bilevel control system.

    Args:
        epochs: number of training episodes.
        save_dir: directory for actor checkpoints. It is created if missing.
                  A checkpoint is written every 100 episodes, plus a final one.

    Returns:
        The trained SAC agent.
    """
    os.makedirs(save_dir, exist_ok=True)

    env = BilevelEnv(max_traj_dist=0.1)
    s_dim = len(env._get_obs())
    a_dim = 2  # action: [dx, dy] offset to the next target trajectory point
    agent = SAC(s_dim, a_dim)
    buf = Replay(s_dim, a_dim)

    for ep in range(epochs):
        s = env.reset()
        ep_ret = 0.0
        done = False
        info = {}

        while not done:
            a = agent.act(s)
            s2, r, done, info = env.step(a)
            ep_ret += r
            # Only success/collision end the task. A time-limit cut-off is a
            # truncation, so the critic must still bootstrap from s2.
            terminal = bool(info.get("success") or info.get("collision"))
            buf.store(s, a, r, s2, float(terminal))
            s = s2
            # Start gradient updates once the buffer holds enough samples
            if buf.size > 400:
                agent.update(buf)

        print(f"Epoch {ep}, Return {ep_ret:.2f}, Done info {info}")

        # Periodically save model
        if (ep + 1) % 100 == 0:
            torch.save(
                agent.actor.state_dict(),
                os.path.join(save_dir, f"actor_bilevel_ep{ep+1}.pt"),
            )

    torch.save(agent.actor.state_dict(), os.path.join(save_dir, "actor_bilevel.pt"))
    return agent




if __name__ == "__main__":
    # Script entry point: run training with its default settings.
    train()
