import os
import matplotlib.pyplot as plt
import numpy as np
import pickle
import glob
from torch.utils.data import Dataset
import torch.nn.functional as F
from utils.normalization import normalize_min_max, unnormalize_min_max, normalize_sqrt, unnormalize_sqrt
import torch
from utils.utils import rotate_trajs_x_direction

def seq_collate_rat(batch):
    """Collate dataset samples into one batched dictionary.

    Each sample is an 8-tuple
    ``(past, fut, past_orig, fut_orig, fut_vel, hist_feats, hist_cue, fut_cue)``
    of tensors.  Every field is stacked along a new leading batch axis.

    Returns:
        dict with the stacked tensors, a 0-dim ``batch_size`` tensor and a
        ``traj_mask`` of shape ``[B*V, B*V]`` (``V`` = size of axis 1 of the
        stacked past trajectories) holding 1 inside each sample's own
        ``V x V`` diagonal block and 0 everywhere else.
    """
    (pasts, futs, pasts_raw, futs_raw,
     vels, feats, hist_cues, fut_cues) = zip(*batch)

    def _stack(items):
        return torch.stack(items, dim=0)

    past_b = _stack(pasts)
    fut_b = _stack(futs)
    past_raw_b = _stack(pasts_raw)
    fut_raw_b = _stack(futs_raw)
    vel_b = _stack(vels)
    hist_cue_b = _stack(hist_cues)
    fut_cue_b = _stack(fut_cues)
    feats_b = _stack(feats)

    n_samples, n_rows = past_b.shape[0], past_b.shape[1]
    # Block-diagonal mask: identity over samples expanded by an all-ones block.
    same_sample = torch.kron(torch.eye(n_samples), torch.ones(n_rows, n_rows))

    return {
        'batch_size': torch.tensor(n_samples),
        'past_traj': past_b,
        'fut_traj': fut_b,
        'past_traj_original_scale': past_raw_b,
        'fut_traj_original_scale': fut_raw_b,
        'traj_mask': same_sample,
        'fut_traj_vel': vel_b,
        'hist_feats': feats_b,
        'hist_cond_cue': hist_cue_b,
        "fut_cond_cue": fut_cue_b,
    }

def seq_collate_imle_train(batch):
    """Collate IMLE-training samples into one batched dictionary.

    Each sample is a 7-tuple
    ``(past, fut, past_orig, fut_orig, fut_vel, y_t, y_pred_data)`` of
    tensors.  Every field is stacked along a new leading batch axis.

    Returns:
        dict with the stacked tensors, a 0-dim ``batch_size`` tensor and a
        ``traj_mask`` of shape ``[B*V, B*V]`` (``V`` = size of axis 1 of the
        stacked past trajectories) holding 1 inside each sample's own
        ``V x V`` diagonal block and 0 everywhere else.
    """
    stacked = [torch.stack(field, dim=0) for field in zip(*batch)]
    (past_b, fut_b, past_raw_b, fut_raw_b,
     vel_b, y_t_b, y_pred_b) = stacked

    n_samples, n_rows = past_b.shape[0], past_b.shape[1]
    # Block-diagonal mask: identity over samples expanded by an all-ones block.
    same_sample = torch.kron(torch.eye(n_samples), torch.ones(n_rows, n_rows))

    return {
        'batch_size': torch.tensor(n_samples),
        'past_traj': past_b,
        'fut_traj': fut_b,
        'past_traj_original_scale': past_raw_b,
        'fut_traj_original_scale': fut_raw_b,
        'fut_traj_vel': vel_b,
        'traj_mask': same_sample,
        'y_t': y_t_b,
        'y_pred_data': y_pred_b
    }


class BabelDatasetMinMax(Dataset):
    """Trajectory dataset loaded from ``babel_{train,test}.npy`` files.

    The raw trajectory array is assumed to be laid out as ``[N, T, V, D]``
    (scenes, time, actors/joints, coordinates) -- TODO confirm against the
    data files.  It is permuted to ``[N, V, T, D]`` and split into an observed
    part (first ``obs_len`` steps) and a future part (the rest).

    NOTE(review): despite the class name and the ``'min_max'`` option, the
    normalisation applied is a per-coordinate z-score (mean / std).  The
    percentile-based min/max values are computed and stored in the stats but
    are not used inside this class.
    """

    def __init__(self,
                 obs_len=12, pred_len=18, training=True,
                 num_scenes=None, test_scenes=None,
                 overfit=False, imle=False, cfg=None,
                 data_dir='data/babel', data_file='rat_train.npy',
                 data_norm='min_max'):
        """Load the data, build past/future features and normalise them.

        Args:
            obs_len: number of observed (history) time steps.
            pred_len: number of predicted time steps (only used for
                ``seq_len``; the future is "everything after ``obs_len``").
            training: if True, normalisation statistics are computed from
                this split and published on ``cfg``; otherwise they are read
                from ``cfg.stats``.
            num_scenes: scene limit for the training / overfit split.
            test_scenes: scene limit for the test split.
            overfit: if True, the *test* files are loaded and truncated to
                ``num_scenes`` regardless of ``training``.
            imle: stored on the instance; not otherwise used here.
            cfg: config object; receives ``stats`` and four ``*_traj_*``
                attributes (set to None) when ``training`` is True.
            data_dir: directory containing the ``.npy`` files.
            data_file: unused; kept for interface compatibility.
            data_norm: ``'min_max'`` enables z-score normalisation, any
                other value keeps the raw features.

        Raises:
            ValueError: if normalisation is requested for a non-training
                split but ``cfg`` carries no ``stats``.
        """
        super().__init__()
        self.obs_len  = obs_len
        self.pred_len = pred_len
        self.seq_len  = obs_len + pred_len
        self.imle     = imle
        self.data_norm = data_norm

        self.dt = 1/18
        self.head_idx = 5
        self.neck_idx = 6

        # Overfit mode always reads the test files.
        split = 'train' if (training and not overfit) else 'test'
        data_root = os.path.join(data_dir, f'babel_{split}.npy')
        cmd_root = os.path.join(data_dir, f'babel_{split}_cmd.npy')

        self.trajs_raw = np.load(data_root)
        self.cmd_raw = np.load(cmd_root)

        # Overfit mode truncates with ``num_scenes`` even when training=False.
        limit = num_scenes if (training or overfit) else test_scenes
        self.trajs = self.trajs_raw[:limit]
        self.cmd = self.cmd_raw[:limit]

        self.data_len = len(self.trajs)
        self.traj_abs = torch.from_numpy(self.trajs).float()  # [N, T, V, D]

        self.traj_abs = self.traj_abs.permute(0, 2, 1, 3).contiguous()  # [N, V, T, D]
        self.actor_num = self.traj_abs.shape[1]
        # Coordinate dimensionality is taken from the data instead of being
        # hard-coded, so both 2-D and 3-D trajectories are handled.
        coord_dim = self.traj_abs.shape[-1]

        pre = self.traj_abs[:, :, :self.obs_len, :]     # [N, V, T_h, D]
        fut = self.traj_abs[:, :, self.obs_len:, :]     # [N, V, T_p, D]
        init = pre[:, :, -1:].clone()                   # [N, V, 1, D] last observed position

        fut_traj = (fut - init).contiguous()
        past_abs = pre.contiguous()
        past_rel = (pre - init).contiguous()
        # Forward differences; the last step is zero-padded to keep the length.
        past_vel = torch.cat([past_rel[:, :, 1:] - past_rel[:, :, :-1],
                              torch.zeros_like(past_rel[:, :, -1:])], dim=2)
        past_traj = torch.cat([past_abs, past_rel, past_vel], dim=-1)  # [N, V, T_h, 3*D]
        self.fut_traj_vel = torch.cat([fut_traj[:, :, 1:] - fut_traj[:, :, :-1],
                                       torch.zeros_like(fut_traj[:, :, -1:])], dim=2)

        stats = None
        if training:
            stats = {}
            named_feats = (('abs', past_abs), ('rel', past_rel),
                           ('vel', past_vel), ('fut', fut_traj))
            for key, feat in named_feats:
                flat = feat.reshape(-1, coord_dim)  # one row per coordinate vector
                stats[f'{key}_mean'] = flat.mean(0)
                stats[f'{key}_std'] = flat.std(0) + 1e-6
                stats[f'{key}_min'], stats[f'{key}_max'] = self.robust_minmax(flat)

            if cfg is not None:
                cfg.stats = stats
                cfg.fut_traj_max  = None
                cfg.fut_traj_min  = None
                cfg.past_traj_max = None
                cfg.past_traj_min = None

        self.past_traj_original_scale = past_traj.clone()
        self.fut_traj_original_scale  = fut_traj.clone()

        if self.data_norm == 'min_max':
            if stats is None:
                stats = getattr(cfg, 'stats', None)
                if stats is None:
                    raise ValueError(
                        "Normalisation statistics are missing: build the training "
                        "dataset with the same cfg before the test dataset.")

            def _standardise(x, key):
                # as_tensor avoids the copy-construct warning for tensor stats
                # and still accepts numpy arrays / lists.
                mean = torch.as_tensor(stats[f'{key}_mean'], dtype=x.dtype)
                std = torch.as_tensor(stats[f'{key}_std'], dtype=x.dtype)
                return self.z(x, mean, std)

            abs_n = _standardise(past_abs, 'abs')
            rel_n = _standardise(past_rel, 'rel')
            vel_n = _standardise(past_vel, 'vel')

            self.past_traj = torch.cat([abs_n, rel_n, vel_n], dim=-1)  # [N, V, T_h, 3*D]
            self.fut_traj  = _standardise(fut_traj, 'fut')
        else:
            self.past_traj = past_traj
            self.fut_traj  = fut_traj

        # Statistics used by this instance (None if none were needed/available).
        self.stats = stats

        self.data_len = self.traj_abs.shape[0]
        print(f"Babel Dataset: size {self.data_len} | mode={'train' if training else 'test'}")

    @torch.no_grad()
    def compute_hist_feats(self,
                           past_abs_xy: torch.Tensor,
                           dt: float = 1 / 10,
                           head_idx: int = 0,
                           neck_idx: int = 1,
                           ) -> torch.Tensor:
        """Build kinematic history features from absolute positions.

        Args:
            past_abs_xy: tensor of shape ``[V, T, 2]``.
            dt: time step used for the finite differences.
            head_idx: row of ``past_abs_xy`` treated as the head.
            neck_idx: row of ``past_abs_xy`` treated as the neck.

        Returns:
            Tensor ``[V, T, 7]`` with channels
            ``[vel_x, vel_y, speed, |acc|, yaw, d_yaw, kappa]``.  ``yaw``,
            ``d_yaw`` and ``kappa`` are derived from the head/neck rows only
            and repeated for every row.
        """
        V, T, _ = past_abs_xy.shape

        vel = _time_derivative(past_abs_xy, dt)  # [V, T, 2]
        acc_vec = _time_derivative(vel, dt)  # [V, T, 2]
        speed = vel.norm(dim=-1, keepdim=True)  # [V, T, 1]
        acc = acc_vec.norm(dim=-1, keepdim=True)  # [V, T, 1]

        head = past_abs_xy[head_idx]  # [T, 2]
        neck = past_abs_xy[neck_idx]  # [T, 2]
        hn = head - neck  # [T, 2] neck -> head vector
        yaw = torch.atan2(hn[..., 1], hn[..., 0])  # [T]
        yaw = (yaw + torch.pi) % (2 * torch.pi) - torch.pi
        d_yaw = _time_derivative(yaw, dt)  # [T]
        # NOTE(review): this wraps a *rate* (already divided by dt) into
        # [-pi, pi); wrapping the angle difference before dividing looks like
        # the intent.  Left unchanged because it alters model inputs -- confirm.
        d_yaw = (d_yaw + torch.pi) % (2 * torch.pi) - torch.pi

        # Curvature of the head path: |x'y'' - y'x''| / (x'^2 + y'^2)^(3/2)
        v = _time_derivative(head, dt)  # [T, 2]
        a = _time_derivative(v, dt)  # [T, 2]
        cross = (v[..., 0] * a[..., 1] - v[..., 1] * a[..., 0]).abs()
        denom = (v.pow(2).sum(dim=-1).clamp_min(1e-6)).pow(1.5)  # clamped to avoid 0-division
        kappa = cross / denom  # [T]

        yaw_b = yaw.view(1, T, 1).repeat(V, 1, 1)
        dyaw_b = d_yaw.view(1, T, 1).repeat(V, 1, 1)
        kappa_b = kappa.view(1, T, 1).repeat(V, 1, 1)

        hist_feats = torch.cat([vel, speed, acc, yaw_b, dyaw_b, kappa_b], dim=-1)
        return hist_feats

    @torch.no_grad()
    def compute_zc(
            self,
            hist_feats: torch.Tensor,  # [V, T_h, C_h]
            cue_feats: torch.Tensor,  # [T_h, C_c]
            stats: dict = None,
    ) -> torch.Tensor:
        """Assemble the 4-channel vector ``[speed, d_yaw, kappa, signed_strength]``.

        Args:
            hist_feats: ``[V, T, C]`` features; channels 2, 5 and 6 are read
                (speed, d_yaw, kappa in the ``compute_hist_feats`` layout).
            cue_feats: optional cue tensor; when it has at least two channels
                its second-to-last channel is used as ``signed_strength``,
                otherwise zeros are used.
            stats: optional dict with ``'mean'`` and ``'std'`` (4 values each)
                for standardising the result.

        Returns:
            Tensor ``[V, T, 4]``.
        """
        V, T, C = hist_feats.shape
        speed = hist_feats[..., 2]  # [V,T]
        d_yaw = hist_feats[..., 5]  # [V,T]
        kappa = hist_feats[..., 6]  # [V,T]

        signed_strength = torch.zeros(T, device=hist_feats.device)
        if cue_feats is not None and cue_feats.shape[-1] >= 2:
            signed_strength = cue_feats[..., -2]
        signed_strength = signed_strength.view(1, T).expand(V, T)

        z = torch.stack([speed, d_yaw, kappa, signed_strength], dim=-1)  # [V,T,4]

        if stats is not None:
            mean = torch.as_tensor(stats['mean'], device=z.device).view(1, 1, 4)
            std = torch.as_tensor(stats['std'], device=z.device).view(1, 1, 4)
            z = (z - mean) / (std + 1e-6)

        return z

    def __len__(self):
        """Return the number of scenes in the dataset."""
        return self.data_len

    def __getitem__(self, index):
        """Return one sample as an 8-element list.

        Order: normalised past, normalised future, raw past, raw future,
        future velocity, history features, history cue, future cue.
        """
        # First two channels of the raw past = absolute x/y positions.
        past = self.past_traj_original_scale[index][..., :2]  # [V,T_h,2]

        hist_feats = self.compute_hist_feats(past, dt=self.dt, head_idx=self.head_idx, neck_idx=self.neck_idx)  # [V,T_h,7]
        cue_feats = torch.from_numpy(self.cmd[index, :, :])

        return [
            self.past_traj[index],
            self.fut_traj[index],
            self.past_traj_original_scale[index],
            self.fut_traj_original_scale[index],
            self.fut_traj_vel[index],

            hist_feats,
            # NOTE(review): the cue split point is hard-coded to 30 rather than
            # derived from obs_len / seq_len -- confirm this is intended.
            cue_feats[:30],
            cue_feats[30:],
        ]

    def robust_minmax(self, x, q=(1, 99)):
        """Return the ``q[0]``-th and ``q[1]``-th percentiles of ``x`` along axis 0."""
        lo = np.percentile(x, q[0], axis=0)
        hi = np.percentile(x, q[1], axis=0)
        return lo, hi

    def z(self, x, mean, std):
        """Standardise ``x``: ``(x - mean) / std``."""
        return (x - mean) / std

    def iz(self, zx, mean, std):
        """Invert ``z``: ``zx * std + mean``."""
        return zx * std + mean

def _central_diff(x, dt, pad_mode='replicate'):
    """
    x: [..., T, C] ; central difference in time
    return: same shape
    """
    xp = F.pad(x[..., 2:, :], (0,0,1,0), mode=pad_mode)   # t+1
    xm = F.pad(x[..., :-2, :], (0,0,0,1), mode=pad_mode)  # t-1
    return (xp - xm) / (2 * dt)

def _time_derivative(x: torch.Tensor, dt: float, time_dim: int = -2) -> torch.Tensor:
    orig = x
    dtype = x.dtype
    squeeze_mode = None

    if x.ndim == 1:
        # (T,) -> (1, T, 1)
        x = x.view(1, x.shape[0], 1)
        squeeze_mode = "T"
    elif x.ndim == 2:
        # (T, C) -> (1, T, C) （时间维=0）
        x = x.unsqueeze(0)
        squeeze_mode = "TC"
    else:
        x = x.movedim(time_dim, -2)

    T = x.size(-2)
    dx = torch.empty_like(x, dtype=dtype)

    if T == 1:
        dx.zero_()
    elif T == 2:
        d = (x[..., 1, :] - x[..., 0, :]) / dt
        dx[..., 0, :] = d
        dx[..., 1, :] = d
    else:
        dx[..., 0, :]  = (x[..., 1, :]   - x[..., 0, :])   / dt
        dx[..., 1:-1, :] = (x[..., 2:, :] - x[..., :-2, :]) / (2.0 * dt)
        dx[..., -1, :] = (x[..., -1, :]  - x[..., -2, :])  / dt

    if orig.ndim >= 3:
        dx = dx.movedim(-2, time_dim)  
    if squeeze_mode == "T":
        dx = dx.view(-1)              
    elif squeeze_mode == "TC":
        dx = dx.squeeze(0)             

    return dx

def _safe_div(a, b, eps=1e-6):
    return a / (b.abs() + eps)

def _wrap_pi(a):
    # wrap angle to [-pi, pi]
    return (a + torch.pi) % (2 * torch.pi) - torch.pi

import copy
from data.dataset_registry import register_dataset
from torch.utils.data import DataLoader

@register_dataset("babel_dataset")
def build_babel_dataloader(cfg, args):
	"""
	Build the train and test data loaders for the Babel dataset.

	The training dataset is constructed first; in overfit mode the test
	dataset is a deep copy of it, otherwise a separate test split is built
	with the same ``cfg``.  ``args`` is accepted but not used here.

	Returns:
		(train_loader, test_loader)
	"""
	common = dict(
		data_dir=cfg.data_dir,
		obs_len=cfg.past_frames,
		pred_len=cfg.future_frames,
		overfit=cfg.overfit,
		cfg=cfg,
		data_norm=cfg.data_norm,
	)

	def _loader(dataset, batch_size, shuffle):
		# Settings shared by both loaders.
		return DataLoader(
			dataset,
			batch_size=batch_size,
			shuffle=shuffle,
			num_workers=4,
			collate_fn=seq_collate_rat,
			pin_memory=True)

	train_dset = BabelDatasetMinMax(training=True, num_scenes=cfg.n_train, **common)
	train_loader = _loader(train_dset, cfg.train_batch_size, True)

	if cfg.overfit:
		test_dset = copy.deepcopy(train_dset)
	else:
		test_dset = BabelDatasetMinMax(training=False, test_scenes=cfg.n_test, **common)
	test_loader = _loader(test_dset, cfg.test_batch_size, False)

	return train_loader, test_loader
