import math, numpy as np
import matplotlib.pyplot as plt
import os
import json
from datetime import datetime

class Circle:
    """Circular obstacle given by its centre (cx, cy) and radius r."""

    def __init__(self, cx, cy, r):
        # Centre of the obstacle.
        self.cx = cx
        self.cy = cy
        # Obstacle radius.
        self.r = r

class UnicycleEnv:
    """2-D unicycle navigation environment with circular obstacles.

    The state ``x = [px, py, theta]`` is integrated with forward Euler. Each
    action is squashed with ``tanh`` into a forward speed in ``[0, vmax]`` and
    a turn rate in ``[-wmax, wmax]``. States, actions and rewards of the
    current episode are recorded so they can be saved or summarised.
    """

    def __init__(self, dt=0.1, vmax=1.0, wmax=1.0, max_steps=200):
        """
        Args:
            dt: integration time step.
            vmax: maximum forward speed.
            wmax: maximum absolute turn rate.
            max_steps: number of steps after which an episode times out.
        """
        self.dt = dt
        self.vmax = vmax
        self.wmax = wmax
        self.max_steps = max_steps
        self.start = np.array([0., 0., 0.], dtype=np.float32)
        self.goal = np.array([3., 2.], dtype=np.float32)
        self.obstacles = [Circle(2, 1, 0.3)]
        self.robot_r = 0.15
        self.goal_tol = 0.2

        # Per-episode recording buffers, (re)initialised by reset() and
        # appended to by step().
        self.trajectory = []
        self.actions = []
        self.rewards = []
        self.episode_info = {}

        self.reset()

    def reset(self):
        """Put the robot back at the start pose, clear the recording buffers
        and return the initial observation."""
        self.x = self.start.copy()
        self.t = 0

        # Fresh trajectory record; the start pose is its first entry.
        self.trajectory = [self.x.copy()]
        self.actions = []
        self.rewards = []
        self.episode_info = {
            'start_time': datetime.now().isoformat(),
            'start_pos': self.start.copy(),
            'goal_pos': self.goal.copy(),
            'obstacles': [(o.cx, o.cy, o.r) for o in self.obstacles],
            'robot_radius': self.robot_r,
            'dt': self.dt
        }

        return self._get_obs()

    def _get_obs(self):
        """Return ``[px, py, theta, gx, gy]`` followed by ``(cx, cy, r)`` for
        every obstacle, as a float32 array."""
        obs = [self.x[0], self.x[1], self.x[2], self.goal[0], self.goal[1]]
        for o in self.obstacles:
            obs.extend([o.cx, o.cy, o.r])
        return np.array(obs, dtype=np.float32)

    def _collision(self):
        """Return True if the robot disc touches or overlaps any obstacle."""
        for o in self.obstacles:
            if np.linalg.norm(self.x[:2] - np.array([o.cx, o.cy])) <= (o.r + self.robot_r):
                return True
        return False

    def step(self, a):
        """Advance the simulation by one time step.

        Args:
            a: action with at least two components. ``a[0]`` is mapped to a
                forward speed in ``[0, vmax]`` and ``a[1]`` to a turn rate in
                ``[-wmax, wmax]`` via tanh squashing.

        Returns:
            tuple: ``(observation, reward, done, info)``. ``info`` may contain
            the keys "success", "collision" and "timeout".
        """
        self.t += 1
        v = (np.tanh(a[0]) + 1) / 2 * self.vmax
        w = np.tanh(a[1]) * self.wmax
        p_prev = self.x[:2].copy()
        d_prev = np.linalg.norm(p_prev - self.goal)

        # Forward-Euler unicycle kinematics. NOTE: the heading is not wrapped,
        # so x[2] (and therefore obs[2]) can grow without bound.
        self.x[0] += self.dt * v * math.cos(self.x[2])
        self.x[1] += self.dt * v * math.sin(self.x[2])
        self.x[2] += self.dt * w
        d_now = np.linalg.norm(self.x[:2] - self.goal)

        # Reward shaping: progress towards the goal is rewarded with weight
        # 2.0; moving away is penalised with the larger weight 3.0.
        progress = d_prev - d_now
        r = 2.0 * progress if progress > 0 else 3.0 * progress
        # Linear penalty for entering a 0.2 margin around each obstacle's
        # collision boundary.
        for o in self.obstacles:
            gap = (o.r + self.robot_r + 0.2) - np.linalg.norm(self.x[:2] - np.array([o.cx, o.cy]))
            if gap > 0:
                r -= 3.0 * gap

        # Terminal conditions; several can fire in the same step.
        done = False
        info = {}
        if d_now < self.goal_tol:
            r += 20
            done = True
            info["success"] = True
        if self._collision():
            r -= 20
            done = True
            info["collision"] = True
        if self.t >= self.max_steps:
            done = True
            info["timeout"] = True

        # Record this step. np.array() copies the action, whether it arrives
        # as an ndarray, list or tuple.
        self.trajectory.append(self.x.copy())
        self.actions.append(np.array(a))
        self.rewards.append(float(r))

        # Summarise the finished episode.
        if done:
            self.episode_info.update({
                'end_time': datetime.now().isoformat(),
                'total_steps': self.t,
                'final_distance': float(d_now),
                'total_reward': sum(self.rewards),
                'success': info.get("success", False),
                'collision': info.get("collision", False),
                'timeout': info.get("timeout", False)
            })

        return self._get_obs(), float(r), done, info

    def render(self, traj=None, pred=None, title="Unicycle Navigation"):
        """Plot the environment. The plotting style follows main_mpc.py.

        Args:
            traj: optional sequence of states whose first two columns are x, y.
                When None or empty, the current position and the start pose
                are drawn instead.
            pred: optional sequence of predicted states (first two columns x, y),
                given as an array, list or anything ``np.asarray`` accepts.
            title: figure title.
        """
        _, ax = plt.subplots(figsize=(6, 6))

        # Executed trajectory (or current position if none is given).
        traj = None if traj is None else np.asarray(traj)
        if traj is not None and traj.size > 0:
            ax.plot(traj[:, 0], traj[:, 1], '-', lw=2, label='trajectory')
            ax.plot(traj[0, 0], traj[0, 1], 'go', ms=8, label='start')
        else:
            ax.plot(self.x[0], self.x[1], 'bo', ms=8, label='current')
            ax.plot(self.start[0], self.start[1], 'go', ms=8, label='start')

        # Goal.
        ax.plot(self.goal[0], self.goal[1], 'r*', ms=12, label='goal')

        # Predicted trajectory, if any.
        if pred is not None:
            pred = np.asarray(pred)
            if pred.size > 0:
                ax.plot(pred[:, 0], pred[:, 1], '--', lw=1.5, label='predicted')

        # Obstacles and their safety boundaries. add_patch (unlike add_artist)
        # updates the data limits, so obstacles stay inside the autoscaled view.
        for i, obs in enumerate(self.obstacles):
            body = plt.Circle((obs.cx, obs.cy), obs.r, fill=True, alpha=0.25, color='tab:orange')
            ax.add_patch(body)
            # Safety boundary = obstacle radius + robot radius.
            safety = plt.Circle((obs.cx, obs.cy), obs.r + self.robot_r, fill=False, ls='--', color='tab:red')
            ax.add_patch(safety)
            ax.text(obs.cx, obs.cy, f"O{i+1}", ha='center', va='center')

        ax.set_aspect('equal', 'box')
        ax.grid(True, ls=':')
        ax.set_xlabel('x [m]')
        ax.set_ylabel('y [m]')
        ax.set_title(title)
        ax.legend(loc='best')
        plt.tight_layout()
        plt.show()

    def _default_stem(self):
        """Auto-generated file stem: ``traj_<YYYYmmdd_HHMMSS>_<success|fail>``."""
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        success_str = "success" if self.episode_info.get('success', False) else "fail"
        return f"traj_{timestamp}_{success_str}"

    @staticmethod
    def _unused_path(directory, stem, ext):
        """Return ``directory/stem+ext``, appending ``_1``, ``_2``, ... to the
        stem until the path does not exist, so auto-named saves made within
        the same second never overwrite each other."""
        candidate = os.path.join(directory, f"{stem}{ext}")
        counter = 1
        while os.path.exists(candidate):
            candidate = os.path.join(directory, f"{stem}_{counter}{ext}")
            counter += 1
        return candidate

    @classmethod
    def _to_jsonable(cls, value):
        """Recursively convert numpy arrays/scalars, tuples and nested
        containers into JSON-native Python types."""
        if isinstance(value, dict):
            return {k: cls._to_jsonable(v) for k, v in value.items()}
        if isinstance(value, (list, tuple)):
            return [cls._to_jsonable(v) for v in value]
        if hasattr(value, 'tolist'):
            # np.ndarray -> nested list; numpy scalar -> Python scalar.
            return value.tolist()
        return value

    def save_trajectory(self, filename=None, directory="trajectories"):
        """
        Save the current episode's trajectory to an ``.npz`` file.

        Args:
            filename: file name; auto-generated when None. ".npz" is appended
                if missing, matching what ``np.savez`` does on disk.
            directory: output directory (created if needed).

        Returns:
            str: path of the file actually written.
        """
        os.makedirs(directory, exist_ok=True)

        if filename is None:
            filepath = self._unused_path(directory, self._default_stem(), ".npz")
        else:
            filepath = os.path.join(directory, filename)
            # np.savez silently appends ".npz"; mirror it so the returned
            # path points at the real file.
            if not filepath.endswith(".npz"):
                filepath += ".npz"

        # episode_info is a dict, so it is stored as a pickled 0-d object
        # array: load it with np.load(..., allow_pickle=True)['episode_info'].item().
        np.savez(
            filepath,
            trajectory=np.array(self.trajectory),
            actions=np.array(self.actions),
            rewards=np.array(self.rewards),
            episode_info=self.episode_info
        )

        print(f"轨迹已保存到: {filepath}")
        print(f"轨迹长度: {len(self.trajectory)} 步")
        print(f"总奖励: {sum(self.rewards):.2f}")
        print(f"成功到达: {'是' if self.episode_info.get('success', False) else '否'}")

        return filepath

    def save_trajectory_json(self, filename=None, directory="trajectories"):
        """
        Save the current episode's trajectory as human-readable JSON.

        Args:
            filename: file name; auto-generated when None.
            directory: output directory (created if needed).

        Returns:
            str: path of the file written.
        """
        os.makedirs(directory, exist_ok=True)

        if filename is None:
            filepath = self._unused_path(directory, self._default_stem(), ".json")
        else:
            filepath = os.path.join(directory, filename)

        # Recursive conversion so numpy values nested anywhere (e.g. inside the
        # obstacle tuples) do not make json.dump raise TypeError.
        data = {
            'trajectory': self._to_jsonable(self.trajectory),
            'actions': self._to_jsonable(self.actions),
            'rewards': self._to_jsonable(self.rewards),
            'episode_info': self._to_jsonable(self.episode_info),
        }

        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2, ensure_ascii=False)

        print(f"轨迹JSON已保存到: {filepath}")

        return filepath

    def get_trajectory_stats(self):
        """Return summary statistics of the recorded trajectory, or None when
        nothing has been recorded."""
        if not self.trajectory:
            return None

        positions = np.asarray(self.trajectory)[:, :2]
        path_length = 0.0
        if len(positions) > 1:
            # Sum of Euclidean lengths of consecutive position segments.
            path_length = float(np.linalg.norm(np.diff(positions, axis=0), axis=1).sum())

        return {
            'trajectory_length': len(self.trajectory),
            'total_reward': sum(self.rewards) if self.rewards else 0,
            'average_reward': float(np.mean(self.rewards)) if self.rewards else 0,
            'final_distance_to_goal': float(np.linalg.norm(positions[-1] - self.goal)),
            'path_length': path_length,
            'success': self.episode_info.get('success', False),
            'collision': self.episode_info.get('collision', False),
            'timeout': self.episode_info.get('timeout', False)
        }