import numpy as np
import casadi as ca
import math
import matplotlib.pyplot as plt
from dataclasses import dataclass

# -----------------------------
# 1) Dynamics
# -----------------------------
class UnicycleDynamics:
    """Discrete-time unicycle model advanced with a fixed step ``dt``."""

    def __init__(self, dt: float):
        # Integration step, coerced to a plain float.
        self.dt = float(dt)

    def step(self, x, u):
        """
        Advance the state by one step of length ``self.dt``.

        x = [x, y, theta], u = [v, omega]
        update rule:
            x_next  = x  + dt * v * cos(theta)
            y_next  = y  + dt * v * sin(theta)
            th_next = th + dt * omega

        The three components are stacked with ``ca.vertcat`` and returned.
        """
        pos_x, pos_y, heading = x[0], x[1], x[2]
        speed, turn_rate = u[0], u[1]
        ts = self.dt

        return ca.vertcat(
            pos_x + ts * speed * ca.cos(heading),
            pos_y + ts * speed * ca.sin(heading),
            heading + ts * turn_rate,
        )

# -----------------------------
# 2) Obstacles (circles)
# -----------------------------
@dataclass
class CircleObstacle:
    """Circular obstacle described by its centre (cx, cy) and radius r."""

    cx: float  # centre x-coordinate
    cy: float  # centre y-coordinate
    r: float  # radius (excluding robot radius)

# -----------------------------
# 3) MPC Controller
# -----------------------------
class MPCController:
    """
    Nonlinear MPC for a unicycle, built once as a CasADi ``Opti`` problem
    and solved with IPOPT.

    The objective penalises squared position error to the goal, control
    effort, control rate of change and obstacle slack. Obstacles are circles
    handled as soft constraints: one non-negative slack per obstacle per
    stage. The heading weights ("Q_th", "Qf_th") are accepted but the current
    objective does not use them.

    After a successful solve the solution is stored in ``last_X`` /
    ``last_U`` and reused (shifted by one stage) as the next initial guess.
    """

    # Default cost weights; user-supplied weights override these key by key.
    DEFAULT_WEIGHTS = {
        "Q_pos": 8.0,        # stage pos error
        "Q_th": 0.05,        # stage heading error
        "R_v": 0.05,         # control effort v
        "R_w": 0.02,         # control effort w
        "Rd_v": 0.10,        # delta-u (smooth) v
        "Rd_w": 0.05,        # delta-u (smooth) w
        "Qf_pos": 120.0,     # terminal pos error
        "Qf_th": 0.2,        # terminal heading error
    }

    def __init__(
        self,
        dyn: "UnicycleDynamics",
        horizon: int = 25,
        v_bounds=(0.0, 1.2),
        w_bounds=(-1.2, 1.2),
        robot_radius: float = 0.15,
        obstacles: list = None,
        weights=None,
        slack_weights=(100.0, 500.0),  # (L1, L2) for slacks - reduce weights to avoid numerical issues
    ):
        """
        Args:
            dyn: dynamics object exposing ``step(x, u)`` on CasADi symbols.
            horizon: number of control stages N (must be >= 1).
            v_bounds: (min, max) bounds on the first control input.
            w_bounds: (min, max) bounds on the second control input.
            robot_radius: added to every obstacle radius for clearance.
            obstacles: iterable of objects with ``cx``, ``cy``, ``r``
                attributes; ``None`` means no obstacles.
            weights: optional dict overriding entries of ``DEFAULT_WEIGHTS``.
            slack_weights: (linear, quadratic) penalty on obstacle slacks.

        Raises:
            ValueError: if ``horizon`` is smaller than 1.
        """
        self.dyn = dyn
        self.N = int(horizon)
        if self.N < 1:
            raise ValueError(f"horizon must be >= 1, got {horizon!r}")
        self.vmin, self.vmax = v_bounds
        self.wmin, self.wmax = w_bounds
        self.robot_radius = float(robot_radius)
        # Copy so that later mutation of the caller's list cannot get out of
        # sync with the slack variables sized in _build_opti().
        self.obstacles = [] if obstacles is None else list(obstacles)

        # Cost weights (missing keys fall back to the defaults)
        self.W = self._resolve_weights(weights)
        self.slack_L1, self.slack_L2 = slack_weights

        # warm-start storage
        self.last_U = None
        self.last_X = None

        # Build the casadi Opti model once; later only update parameters and initial values
        self._build_opti()

    @classmethod
    def _resolve_weights(cls, weights):
        """Return a new dict: ``DEFAULT_WEIGHTS`` overlaid with ``weights``."""
        merged = dict(cls.DEFAULT_WEIGHTS)
        if weights is not None:
            merged.update(weights)
        return merged

    @staticmethod
    def _angle_wrap(expr):
        # wrap to [-pi, pi] using atan2(sin, cos) — works in CasADi symbolic graph
        return ca.atan2(ca.sin(expr), ca.cos(expr))

    def _slack_penalty(self, s_col):
        """Linear + quadratic penalty on one stage's column of slacks."""
        return self.slack_L1 * ca.sum1(s_col) + self.slack_L2 * ca.sum1(s_col ** 2)

    def _build_opti(self):
        """Assemble variables, parameters, constraints, cost and solver."""
        N = self.N
        n_state = 3
        n_ctrl = 2
        n_obs = len(self.obstacles)

        opti = ca.Opti()
        X = opti.variable(n_state, N + 1)   # states
        U = opti.variable(n_ctrl, N)        # controls
        S = opti.variable(max(1, n_obs), N + 1)  # slacks per obstacle per stage (if no obs, 1xN+1 dummy)

        # parameters
        x0 = opti.parameter(n_state)        # current state
        goal = opti.parameter(2)            # goal (xg, yg)

        # initial condition
        opti.subject_to(X[:, 0] == x0)

        # bounds on inputs
        opti.subject_to(opti.bounded(self.vmin, U[0, :], self.vmax))
        opti.subject_to(opti.bounded(self.wmin, U[1, :], self.wmax))

        # slacks nonnegative (vectorized)
        opti.subject_to(ca.vec(S) >= 0)

        # dynamics constraints
        for k in range(N):
            x_next = self.dyn.step(X[:, k], U[:, k])
            opti.subject_to(X[:, k + 1] == x_next)

        # obstacle constraints (softened by slack S)
        # d^2 - R^2 + s >= 0  (=> d^2 >= R^2 - s)
        for k in range(N + 1):
            for i, obs in enumerate(self.obstacles):
                dx = X[0, k] - obs.cx
                dy = X[1, k] - obs.cy
                d2 = dx * dx + dy * dy
                R = obs.r + self.robot_radius
                opti.subject_to(d2 - (R * R) + S[i, k] >= 0)

        # objective - position and control terms only (no angle cost)
        J = 0
        for k in range(N):
            # position error
            ex = X[0, k] - goal[0]
            ey = X[1, k] - goal[1]
            epos2 = ex * ex + ey * ey

            J += self.W["Q_pos"] * epos2
            J += self.W["R_v"] * (U[0, k] ** 2) + self.W["R_w"] * (U[1, k] ** 2)

            # control smoothness
            if k > 0:
                J += self.W["Rd_v"] * ((U[0, k] - U[0, k - 1]) ** 2)
                J += self.W["Rd_w"] * ((U[1, k] - U[1, k - 1]) ** 2)

            # slacks penalty at stage k (the dummy slack row is not penalised)
            if n_obs > 0:
                J += self._slack_penalty(S[:, k])

        # terminal cost - position only
        eN_x = X[0, N] - goal[0]
        eN_y = X[1, N] - goal[1]
        eN_pos2 = eN_x * eN_x + eN_y * eN_y
        J += self.W["Qf_pos"] * eN_pos2

        if n_obs > 0:
            J += self._slack_penalty(S[:, N])

        opti.minimize(J)

        # IPOPT options - maximize numerical stability
        p_opts = {"expand": True}
        s_opts = {
            "max_iter": 100,
            "print_level": 0,  # turn off detailed output
            "tol": 1e-2,  # further relax tolerance
            "acceptable_tol": 5e-2,
            "mu_init": 1e-2,  # larger initial barrier parameter
            "acceptable_iter": 3,  # faster acceptance of solutions
            "nlp_scaling_method": "none",  # turn off scaling to avoid numerical issues
            "linear_solver": "mumps",
            "bound_relax_factor": 1e-6,  # boundary relaxation
            "honor_original_bounds": "no",  # allow slight violation of bounds
            "check_derivatives_for_naninf": "yes",  # check for NaN/Inf in derivatives
            "derivative_test": "none",  # skip derivative test
            "fast_step_computation": "yes",  # fast step computation
        }
        opti.solver("ipopt", p_opts, s_opts)

        # save members
        self.opti = opti
        self.X = X
        self.U = U
        self.S = S
        self.p_x0 = x0
        self.p_goal = goal

    def _initial_guess(self, x0_val, goal_val):
        """
        Compute initial guesses as plain NumPy arrays.

        Returns:
            (X_guess, U_guess, S_init) with shapes (3, N+1), (2, N) and
            (max(1, n_obs), N+1).
        """
        N = self.N
        x0_val = np.asarray(x0_val, dtype=float).ravel()
        goal_val = np.asarray(goal_val, dtype=float).ravel()

        if self.last_X is not None and self.last_U is not None:
            # warm start by shifting previous plan (last column repeated)
            X_guess = np.hstack([self.last_X[:, 1:], self.last_X[:, -1:]])
            U_guess = np.hstack([self.last_U[:, 1:], self.last_U[:, -1:]])
            # Pin the first column to the measured state so the guess agrees
            # with the initial-condition constraint even if the previous
            # plan is stale.
            X_guess[:, 0] = x0_val
        else:
            X_guess = np.zeros((3, N + 1))
            U_guess = np.zeros((2, N))
            X_guess[:, 0] = x0_val

            # linear interpolation to goal
            alpha = np.arange(1, N + 1) / N
            X_guess[0, 1:] = (1 - alpha) * x0_val[0] + alpha * goal_val[0]
            X_guess[1, 1:] = (1 - alpha) * x0_val[1] + alpha * goal_val[1]

            # heading guess: bearing to goal, or current heading when the
            # goal coincides with the start (bearing undefined)
            dx = goal_val[0] - x0_val[0]
            dy = goal_val[1] - x0_val[1]
            if abs(dx) > 1e-6 or abs(dy) > 1e-6:
                X_guess[2, 1:] = math.atan2(dy, dx)
            else:
                X_guess[2, 1:] = x0_val[2]

            # conservative control initial guess
            U_guess[0, :] = np.clip(0.5, self.vmin, self.vmax)
            U_guess[1, :] = 0.0

        # Slack guess: small positive default, enlarged wherever the guessed
        # trajectory violates an obstacle constraint.
        S_init = np.full((max(1, len(self.obstacles)), N + 1), 0.1)
        for i, obs in enumerate(self.obstacles):
            dx = X_guess[0, :] - obs.cx
            dy = X_guess[1, :] - obs.cy
            d2 = dx * dx + dy * dy
            R = obs.r + self.robot_radius
            violation = np.maximum(0.0, R * R - d2)
            S_init[i, :] = np.maximum(0.01, violation + 0.1)

        return X_guess, U_guess, S_init

    def _set_initial_guess(self, x0_val, goal_val):
        """Push the guesses from ``_initial_guess`` into the Opti problem."""
        X_guess, U_guess, S_init = self._initial_guess(x0_val, goal_val)
        self.opti.set_initial(self.X, X_guess)
        self.opti.set_initial(self.U, U_guess)
        self.opti.set_initial(self.S, S_init)

    def solve(self, x0_val: np.ndarray, goal_val: np.ndarray):
        """
        Solve one MPC step.

        Args:
            x0_val: current state [x, y, theta].
            goal_val: goal position [xg, yg].

        Returns:
            (ok, u0, pred): ``ok`` is True on success; ``u0`` is the first
            control [v, omega]; ``pred`` is the predicted state trajectory
            with shape (N+1, 3). On any failure returns (False, None, None).
        """
        x0_val = np.asarray(x0_val, dtype=float).ravel()
        goal_val = np.asarray(goal_val, dtype=float).ravel()

        # numerical safety check
        if not np.all(np.isfinite(x0_val)):
            print("[MPC] Invalid x0_val contains NaN or Inf")
            return False, None, None
        if not np.all(np.isfinite(goal_val)):
            print("[MPC] Invalid goal_val contains NaN or Inf")
            return False, None, None

        self.opti.set_value(self.p_x0, x0_val)
        self.opti.set_value(self.p_goal, goal_val)
        self._set_initial_guess(x0_val, goal_val)

        try:
            sol = self.opti.solve()

            # check solve status
            status = sol.stats()['return_status']
            if status not in ('Solve_Succeeded', 'Solved_To_Acceptable_Level'):
                print(f"[MPC] Solver warning: {status}")
                # still try to use the solution, as it might be acceptable

            # Reshape explicitly: for a single-column variable (N == 1) the
            # returned value is not guaranteed to be 2-D.
            X_sol = np.asarray(sol.value(self.X), dtype=float).reshape(3, self.N + 1)
            U_sol = np.asarray(sol.value(self.U), dtype=float).reshape(2, self.N)

            # check if solution contains NaN or infinity
            if not (np.all(np.isfinite(X_sol)) and np.all(np.isfinite(U_sol))):
                print("[MPC] Solution contains NaN or Inf values")
                return False, None, None

            self.last_X = X_sol.copy()
            self.last_U = U_sol.copy()

            u0 = U_sol[:, 0].copy()
            pred = X_sol.T  # (N+1, 3)
            return True, u0, pred
        except Exception as e:
            # If IPOPT fails, return failure (upper level can use fallback strategy)
            print(f"[MPC] Solver failure: {e}")
            # best-effort dump of the last iterate for diagnosis
            try:
                X_debug = self.opti.debug.value(self.X)
                U_debug = self.opti.debug.value(self.U)
                print(f"[MPC] Debug - X shape: {X_debug.shape}, U shape: {U_debug.shape}")
                print(f"[MPC] Debug - X has NaN: {np.any(np.isnan(X_debug))}, U has NaN: {np.any(np.isnan(U_debug))}")
            except Exception as debug_err:
                # Diagnostics are optional; report and carry on rather than
                # masking interrupts with a bare except.
                print(f"[MPC] Debug values unavailable: {debug_err}")
            return False, None, None

# -----------------------------
# 4) Simple fallback controller (go-to-goal)
# -----------------------------
def fallback_go_to_goal(x, goal, v_max, w_max):
    """
    Proportional go-to-goal law used when the MPC solve is unavailable.

    x = [x, y, theta], goal = [xg, yg]. Forward speed is proportional to
    the distance to the goal and clipped to [0, v_max]; turn rate is
    proportional to the wrapped heading error and clipped to
    [-w_max, w_max]. Returns np.array([v, w]).
    """
    gain_lin, gain_ang = 0.8, 1.8

    off_x = goal[0] - x[0]
    off_y = goal[1] - x[1]

    # bearing to the goal, then heading error wrapped via atan2(sin, cos)
    bearing = math.atan2(off_y, off_x)
    raw_err = bearing - x[2]
    heading_err = math.atan2(math.sin(raw_err), math.cos(raw_err))

    speed = np.clip(gain_lin * math.hypot(off_x, off_y), 0.0, v_max)
    turn = np.clip(gain_ang * heading_err, -w_max, w_max)
    return np.array([speed, turn])

# -----------------------------
# 5) Simulation & plotting
# -----------------------------
def run_simulation(
    *,
    dt: float = 0.1,
    horizon: int = 25,
    t_max: float = 60.0,
    goal_tol: float = 0.12,
    save_results: bool = True,
    show_plot: bool = True,
):
    """
    Run the closed-loop MPC simulation for the built-in scenario.

    The robot starts at (0, 0, 0) and drives to (3, 2) around one circular
    obstacle. At every step the MPC is solved; if it fails, the
    proportional fallback controller supplies the control instead.

    Args:
        dt: simulation / controller time step (must be > 0).
        horizon: MPC horizon length.
        t_max: maximum simulated duration.
        goal_tol: distance to the goal at which the run stops.
        save_results: when True, write ``traj.npy`` and ``u_hist.npy`` to
            the current working directory.
        show_plot: when True, display the result with ``plot_world``.

    Returns:
        (traj, u_hist): visited states, one row per step including the
        start, and the applied controls, one row per step.

    Raises:
        ValueError: if ``dt`` is not positive.
    """
    if dt <= 0:
        raise ValueError(f"dt must be positive, got {dt!r}")

    # --- parameters ---
    vmax = 1.0
    wmax = 1.2
    robot_radius = 0.15

    # start & goal
    x0 = np.array([0.0, 0.0, 0.0])
    goal = np.array([3.0, 2.0])

    # obstacles (cx, cy, r)
    obstacles = [
        CircleObstacle(2, 1, 0.3),
    ]

    dyn = UnicycleDynamics(dt)
    mpc = MPCController(
        dyn,
        horizon=horizon,
        v_bounds=(0.0, vmax),
        w_bounds=(-wmax, wmax),
        robot_radius=robot_radius,
        obstacles=obstacles,
        slack_weights=(100.0, 500.0),
    )

    # --- closed-loop run ---
    # round() guards against float division landing just below an integer
    # (e.g. 0.3 / 0.1), which int() alone would truncate by one step.
    steps = int(round(t_max / dt))
    traj = [x0.copy()]
    u_hist = []
    last_pred = None

    x = x0.copy()
    for t in range(steps):
        # stop condition
        if np.linalg.norm(x[:2] - goal) < goal_tol:
            print(f"Reached goal at step {t}.")
            break

        ok, u_mpc, pred = mpc.solve(x, goal)
        if not ok or np.any(np.isnan(u_mpc)):
            # MPC unavailable for this step: use the fallback law
            u = fallback_go_to_goal(x, goal, vmax, wmax)
        else:
            u = u_mpc
            last_pred = pred  # keep last preview for plotting

        # apply first control and step; same update rule as
        # UnicycleDynamics.step, evaluated numerically
        heading = x[2]
        x = np.array(
            [
                x[0] + dt * u[0] * math.cos(heading),
                x[1] + dt * u[0] * math.sin(heading),
                heading + dt * u[1],
            ],
            dtype=float,
        )

        traj.append(x.copy())
        u_hist.append(u.copy())

    traj = np.array(traj)
    u_hist = np.array(u_hist)
    if save_results:
        np.save("traj.npy", traj)
        np.save('u_hist.npy', u_hist)

    if show_plot:
        plot_world(traj, goal, obstacles, robot_radius, last_pred)
    return traj, u_hist

def plot_world(traj, goal, obstacles, robot_r, pred=None):
    """
    Plot the driven trajectory, start, goal and obstacles, then show it.

    Args:
        traj: array whose columns 0 and 1 are the x/y positions over time.
        goal: goal position, indexed as goal[0], goal[1].
        obstacles: iterable of objects with ``cx``, ``cy``, ``r``.
        robot_r: robot radius; drawn as a dashed safety ring around each
            obstacle at radius ``r + robot_r``.
        pred: optional predicted trajectory (columns 0 and 1 are x/y);
            drawn dashed when non-empty.
    """
    fig, ax = plt.subplots(figsize=(6, 6))
    ax.plot(traj[:, 0], traj[:, 1], '-', lw=2, label='trajectory')
    ax.plot(traj[0, 0], traj[0, 1], 'go', ms=8, label='start')
    ax.plot(goal[0], goal[1], 'r*', ms=12, label='goal')

    # predicted horizon
    if pred is not None and hasattr(pred, 'size') and pred.size > 0:
        ax.plot(pred[:, 0], pred[:, 1], '--', lw=1.5, label='predicted (last step)')

    # obstacles with safety radius
    for i, obs in enumerate(obstacles):
        circle = plt.Circle((obs.cx, obs.cy), obs.r, fill=True, alpha=0.25, color='tab:orange')
        safety = plt.Circle((obs.cx, obs.cy), obs.r + robot_r, fill=False, ls='--', color='tab:red')
        # add_patch (unlike add_artist) includes the patch in the axes'
        # data limits, so obstacles stay in view when autoscaling.
        ax.add_patch(circle)
        ax.add_patch(safety)
        ax.text(obs.cx, obs.cy, f"O{i+1}", ha='center', va='center')
    ax.autoscale_view()

    ax.set_aspect('equal', 'box')
    ax.grid(True, ls=':')
    ax.set_xlabel('x [m]')
    ax.set_ylabel('y [m]')
    ax.set_title('MPC for Unicycle')
    ax.legend(loc='best')
    plt.tight_layout()
    plt.show()

# -----------------------------
# main
# -----------------------------
# Run the simulation only when this file is executed as a script.
if __name__ == "__main__":
    run_simulation()