import os
import matplotlib.pyplot as plt
import numpy as np
import pickle
import glob
from torch.utils.data import Dataset
import torch.nn.functional as F
from utils.normalization import normalize_min_max, unnormalize_min_max, normalize_sqrt, unnormalize_sqrt
import torch
from utils.utils import rotate_trajs_x_direction

def seq_collate_rat(batch):
    """Collate dataset items (8 tensors each) into one batch dictionary.

    Every item must provide, in order: past trajectory, future trajectory,
    unnormalized past trajectory, unnormalized future trajectory, future
    velocity, history features, history cue and future cue. Each field is
    stacked along a new leading batch dimension.

    Returns:
        dict with the stacked tensors, the batch size as a 0-d tensor, and
        ``traj_mask``: a ``[B*V, B*V]`` block-diagonal matrix of ones, where
        ``V`` is dim 1 of the stacked past trajectory, so that rows/columns
        belonging to the same sample are marked with 1 and all others with 0.
    """
    stacked = [torch.stack(field, dim=0) for field in zip(*batch)]
    (past, future, past_raw, future_raw,
     future_vel, feats, cue_hist, cue_fut) = stacked

    n_samples, n_rows = past.shape[0], past.shape[1]
    # One all-ones [V, V] block per sample on the diagonal.
    same_sample_mask = torch.block_diag(
        *(torch.ones(n_rows, n_rows) for _ in range(n_samples))
    )

    return {
        'batch_size': torch.tensor(n_samples),
        'past_traj': past,
        'fut_traj': future,
        'past_traj_original_scale': past_raw,
        'fut_traj_original_scale': future_raw,
        'traj_mask': same_sample_mask,
        'fut_traj_vel': future_vel,
        'hist_feats': feats,
        'hist_cond_cue': cue_hist,
        "fut_cond_cue": cue_fut,
    }

def seq_collate_imle_train_rat(batch):
    """Collate IMLE-training items (10 tensors each) into one batch dictionary.

    Every item must provide, in order: past trajectory, future trajectory,
    unnormalized past trajectory, unnormalized future trajectory, future
    velocity, history features, history cue, future cue, ``y_t`` and
    ``y_pred_data``. Each field is stacked along a new leading batch dimension.

    Returns:
        dict with the stacked tensors, the batch size as a 0-d tensor, and
        ``traj_mask``: a ``[B*V, B*V]`` block-diagonal matrix of ones, where
        ``V`` is dim 1 of the stacked past trajectory.
    """
    stacked = [torch.stack(field, dim=0) for field in zip(*batch)]
    (past, future, past_raw, future_raw, future_vel,
     feats, cue_hist, cue_fut, noisy_target, teacher_pred) = stacked

    n_samples, n_rows = past.shape[0], past.shape[1]
    # One all-ones [V, V] block per sample on the diagonal.
    same_sample_mask = torch.block_diag(
        *(torch.ones(n_rows, n_rows) for _ in range(n_samples))
    )

    return {
        'batch_size': torch.tensor(n_samples),
        'past_traj': past,
        'fut_traj': future,
        'past_traj_original_scale': past_raw,
        'fut_traj_original_scale': future_raw,
        'fut_traj_vel': future_vel,
        'traj_mask': same_sample_mask,
        'hist_feats': feats,
        'hist_cond_cue': cue_hist,
        "fut_cond_cue": cue_fut,
        'y_t': noisy_target,
        'y_pred_data': teacher_pred,
    }


class RatDatasetMinMax(Dataset):
    """Rat keypoint-trajectory dataset with stimulation-cue features.

    Loads ``rat_pose_{train,test}.npy`` and ``rat_stim_{train,test}.npy`` from
    ``data_dir``, splits every sequence into an observed part (the first
    ``obs_len`` steps) and a future part (the remaining steps), expresses both
    relative to the last observed position and optionally normalizes them.

    NOTE: despite the class name and the ``'min_max'`` option, the
    normalization applied here is a z-score (``(x - mean) / std``). Robust
    min/max percentiles are computed and stored in ``cfg.stats`` but are not
    used by this class.

    Items are lists of tensors (see ``__getitem__``).
    """

    def __init__(self,
                 obs_len=12, pred_len=18, training=True,
                 num_scenes=None, test_scenes=None,
                 overfit=False, imle=False, cfg=None, rotate=False,
                 data_dir='data/rat', data_file='hist10pred20/rat_train.npy',
                 data_norm='min_max'):
        """Load the arrays, build relative trajectories and normalize them.

        Args:
            obs_len: number of leading time steps treated as history.
            pred_len: nominal number of future steps (only stored; the future
                part is everything after ``obs_len`` in the loaded array).
            training: if True, load the train files, compute normalization
                statistics and store them in ``cfg.stats``.
            num_scenes: keep only the first ``num_scenes`` sequences when
                ``training`` or ``overfit`` is set (``None`` keeps all).
            test_scenes: same, for the test split.
            overfit: if True, always load the train files and use
                ``num_scenes``, regardless of ``training``.
            imle: if True, additionally load distillation targets from
                ``<data_dir>/imle/*train*.pkl``.
            cfg: config object; receives ``stats`` (and four ``*_traj_min/max``
                attributes set to ``None``) when ``training`` is True, and is
                read for ``stats`` when ``data_norm == 'min_max'``.
            rotate: if True, pass relative past, future and past velocity
                through ``rotate_trajs_x_direction``.
            data_dir: directory containing the ``.npy`` files.
            data_file: unused; kept for interface compatibility.
            data_norm: ``'min_max'`` enables z-score normalization; any other
                value leaves the trajectories unnormalized.

        Raises:
            ValueError: if ``cfg`` is required but missing, or if normalization
                is requested and ``cfg.stats`` has not been populated.
            FileNotFoundError: if ``imle`` is True and no IMLE pickle is found.
        """
        super().__init__()
        if cfg is None and (training or data_norm == 'min_max'):
            raise ValueError(
                "cfg must be provided when training=True or data_norm='min_max': "
                "normalization statistics are written to / read from cfg.stats"
            )

        self.obs_len = obs_len
        self.pred_len = pred_len
        self.seq_len = obs_len + pred_len
        self.imle = imle
        self.data_norm = data_norm

        # Passed to compute_hist_feats() from __getitem__.
        self.dt = 1/18
        self.head_idx = 5
        self.neck_idx = 6

        # Overfit mode always reads the training files and the training subset.
        use_train_split = training or overfit
        split = 'train' if use_train_split else 'test'
        data_root = os.path.join(data_dir, f'rat_pose_{split}.npy')
        cmd_root = os.path.join(data_dir, f'rat_stim_{split}.npy')

        self.trajs_raw = np.load(data_root)  # expected layout [N, T, V, 2]
        self.cmd_raw = np.load(cmd_root)     # indexed as [N, T_cmd, (id, strength)] in __getitem__

        limit = num_scenes if use_train_split else test_scenes
        self.trajs = self.trajs_raw[:limit]
        self.cmd = self.cmd_raw[:limit]

        self.data_len = len(self.trajs)
        self.traj_abs = torch.from_numpy(self.trajs).float()  # [N, T, V, 2]

        self.traj_abs = self.traj_abs.permute(0, 2, 1, 3).contiguous()  # [N, V, T, 2]
        self.actor_num = self.traj_abs.shape[1]

        pre = self.traj_abs[:, :, :self.obs_len, :]     # [N, V, T_h, 2]
        fut = self.traj_abs[:, :, self.obs_len:, :]     # [N, V, T_p, 2]
        init = pre[:, :, -1:].clone()                   # [N, V, 1, 2] last observed position

        fut_traj = (fut - init).contiguous()
        past_abs = pre.contiguous()
        past_rel = (pre - init).contiguous()
        # Forward differences; the last step has no successor and is zero-filled.
        past_vel = torch.cat([past_rel[:, :, 1:] - past_rel[:, :, :-1],
                              torch.zeros_like(past_rel[:, :, -1:])], dim=2)
        past_traj = torch.cat([past_abs, past_rel, past_vel], dim=-1)  # [N, V, T_h, 6]
        self.fut_traj_vel = torch.cat([fut_traj[:, :, 1:] - fut_traj[:, :, :-1],
                                       torch.zeros_like(fut_traj[:, :, -1:])], dim=2)

        if rotate:
            # NOTE(review): past_traj and self.fut_traj_vel were built above from
            # the un-rotated tensors, so only the normalized outputs, the
            # statistics and fut_traj_original_scale see the rotation — confirm
            # this is intended.
            past_rel, fut_traj, past_vel = rotate_trajs_x_direction(past_rel, fut_traj, past_vel)

        if training:
            stats = {}
            for name, values in (('abs', past_abs), ('rel', past_rel),
                                 ('vel', past_vel), ('fut', fut_traj)):
                flat_xy = values.reshape(-1, 2)
                stats[f'{name}_mean'] = flat_xy.mean(0)
                stats[f'{name}_std'] = flat_xy.std(0) + 1e-6  # avoid division by zero
                stats[f'{name}_min'], stats[f'{name}_max'] = self.robust_minmax(flat_xy)

            cfg.stats = stats
            cfg.fut_traj_max = None
            cfg.fut_traj_min = None
            cfg.past_traj_max = None
            cfg.past_traj_min = None
        self.past_traj_original_scale = past_traj.clone()
        self.fut_traj_original_scale = fut_traj.clone()

        if self.data_norm == 'min_max':
            stats = getattr(cfg, 'stats', None)
            if stats is None:
                raise ValueError(
                    "cfg.stats is not set; build the training dataset with the "
                    "same cfg before building a non-training dataset"
                )

            def _stat(key):
                # as_tensor avoids the copy-construct warning torch.tensor()
                # emits when the stored statistic is already a tensor.
                return torch.as_tensor(stats[key])

            abs_n = self.z(past_abs, _stat('abs_mean'), _stat('abs_std'))
            rel_n = self.z(past_rel, _stat('rel_mean'), _stat('rel_std'))
            vel_n = self.z(past_vel, _stat('vel_mean'), _stat('vel_std'))

            self.past_traj = torch.cat([abs_n, rel_n, vel_n], dim=-1)  # [N, V, T_h, 6]
            self.fut_traj = self.z(fut_traj, _stat('fut_mean'), _stat('fut_std'))
        else:
            self.past_traj = past_traj
            self.fut_traj = fut_traj

        self.data_len = self.traj_abs.shape[0]
        print(f"RatDataset: size {self.data_len} | mode={'train' if training else 'test'}")

        # Load distillation targets.
        if imle:
            os.makedirs(os.path.join(data_dir, 'imle'), exist_ok=True)
            pkl_ls = sorted(glob.glob(os.path.join(data_dir, 'imle/*train*.pkl')))
            if not pkl_ls:
                raise FileNotFoundError(
                    "imle=True but no '*train*.pkl' file was found in "
                    + os.path.join(data_dir, 'imle')
                )
            keys_ls = ['past_traj', 'fut_traj', 'past_traj_original_scale', 'fut_traj_original_scale', 'fut_traj_vel', 'y_t', 'y_pred_data']
            imle_data_dict = {key: [] for key in keys_ls}
            total_scenes_loaded_ = 0
            for i_pkl, cur_pkl in enumerate(pkl_ls):
                # SECURITY: pickle can execute arbitrary code while loading;
                # only point data_dir at trusted, locally generated files.
                with open(cur_pkl, 'rb') as pkl_file:
                    data = pickle.load(pkl_file)

                if i_pkl == 0:
                    self.imle_meta_data = data['meta_data']

                for key in keys_ls:
                    if key == 'y_t':
                        # Only the last entry along dim 1 of y_t is kept.
                        imle_data_dict[key].append(data[key][:, -1])
                    else:
                        imle_data_dict[key].append(data[key])

                total_scenes_loaded_ += data['past_traj'].shape[0]

                # Stop once enough scenes are available for this dataset.
                if total_scenes_loaded_ >= len(self.trajs):
                    break

            # Concatenate the chunks and trim to the dataset length.
            for key in keys_ls:
                imle_data_dict[key] = torch.from_numpy(np.concatenate(imle_data_dict[key], axis=0))[:len(self.trajs)]

            self.imle_data_dict = imle_data_dict

    @torch.no_grad()
    def compute_hist_feats(self,
                           past_abs_xy: torch.Tensor,
                           dt: float = 1 / 10,
                           head_idx: int = 0,
                           neck_idx: int = 1,
                           ) -> torch.Tensor:
        """Compute kinematic history features from absolute positions.

        Args:
            past_abs_xy: ``[V, T_h, 2]`` absolute positions.
            dt: time step used for the finite differences.
            head_idx: index along dim 0 of the head keypoint.
            neck_idx: index along dim 0 of the neck keypoint.

        Returns:
            ``hist_feats`` of shape ``[V, T_h, 7]`` with channel order
            ``[vel_x, vel_y, speed, acc, yaw, d_yaw, kappa]``. The time length
            always equals ``T_h``. ``yaw``, ``d_yaw`` and ``kappa`` are derived
            from the head/neck keypoints only and repeated for every ``V``.
        """
        V, T, _ = past_abs_xy.shape

        vel = _time_derivative(past_abs_xy, dt)  # [V, T, 2]
        acc_vec = _time_derivative(vel, dt)  # [V, T, 2]
        speed = vel.norm(dim=-1, keepdim=True)  # [V, T, 1]
        acc = acc_vec.norm(dim=-1, keepdim=True)  # [V, T, 1]

        head = past_abs_xy[head_idx]  # [T, 2]
        neck = past_abs_xy[neck_idx]  # [T, 2]
        hn = head - neck  # [T, 2] neck -> head direction
        yaw = torch.atan2(hn[..., 1], hn[..., 0])  # [T]

        yaw = (yaw + torch.pi) % (2 * torch.pi) - torch.pi
        d_yaw = _time_derivative(yaw, dt)  # [T]
        # NOTE(review): the wrap below is applied to the yaw *rate* (already
        # divided by dt), not to the yaw difference, so jumps across +/-pi are
        # not unwrapped correctly. Left unchanged so existing features (and any
        # model trained on them) stay identical — confirm before fixing.
        d_yaw = (d_yaw + torch.pi) % (2 * torch.pi) - torch.pi

        # Curvature of the head path: |x'y'' - y'x''| / (x'^2 + y'^2)^(3/2).
        v = _time_derivative(head, dt)  # [T, 2]
        a = _time_derivative(v, dt)  # [T, 2]
        cross = (v[..., 0] * a[..., 1] - v[..., 1] * a[..., 0]).abs()
        denom = (v.pow(2).sum(dim=-1).clamp_min(1e-6)).pow(1.5)  # clamped to avoid division by zero
        kappa = cross / denom  # [T]

        yaw_b = yaw.view(1, T, 1).repeat(V, 1, 1)
        dyaw_b = d_yaw.view(1, T, 1).repeat(V, 1, 1)
        kappa_b = kappa.view(1, T, 1).repeat(V, 1, 1)

        hist_feats = torch.cat([vel, speed, acc, yaw_b, dyaw_b, kappa_b], dim=-1)
        return hist_feats

    @torch.no_grad()
    def compute_cue_feats(
            self,
            instr_id: torch.Tensor,  # [T], 0=none, 1=fwd, 2=left, 3=right
            instr_strength: torch.Tensor,
            add_time_since: bool = True,
            use_strength_for_event: bool = True,
    ) -> torch.Tensor:
        """Encode a per-step instruction sequence as cue features.

        Args:
            instr_id: 1-D tensor of instruction ids in ``{0, 1, 2, 3}``.
            instr_strength: 1-D tensor of instruction strengths, same length.
            add_time_since: append the number of steps since the last event.
            use_strength_for_event: if True, a step counts as an event only
                when it has a command *and* a strictly positive strength.

        Returns:
            ``[T, 6]`` tensor (``[T, 7]`` with ``add_time_since``) with channels
            ``[onehot(4), strength, signed_strength, time_since]`` where
            ``signed_strength`` is ``-strength`` for id 2, ``+strength`` for
            id 3 and 0 otherwise. Before the first event ``time_since`` counts
            from 1 at step 0.
        """
        assert instr_id.ndim == 1 and instr_strength.ndim == 1
        T = instr_id.shape[0]
        device = instr_id.device

        onehot = F.one_hot(instr_id.long(), num_classes=4).float()  # [T,4]

        has_cmd = (instr_id > 0)
        # Strength is zeroed wherever there is no command.
        strength = torch.where(has_cmd, instr_strength, torch.zeros_like(instr_strength))
        strength = strength.view(T, 1)  # [T,1]

        sign = torch.zeros_like(instr_strength)
        sign = torch.where(instr_id == 2, -1.0, sign)  # left  -> -1
        sign = torch.where(instr_id == 3, 1.0, sign)  # right -> +1
        signed_strength = (sign * strength.view(-1)).view(T, 1)  # [T,1]

        feats = [onehot, strength, signed_strength]

        if add_time_since:
            if use_strength_for_event:
                event_mask = has_cmd & (instr_strength > 0)
            else:
                event_mask = has_cmd

            idx = torch.arange(T, device=device)
            # Index of the event at each step, -1 where there is none ...
            last_idx = torch.where(event_mask, idx, torch.full_like(idx, -1))

            # ... turned into "index of the most recent event so far".
            cum_last, _ = torch.cummax(last_idx, dim=0)
            time_since = (idx - cum_last).clamp(min=0).to(torch.float32).view(T, 1)
            feats.append(time_since)

        cue_feats = torch.cat(feats, dim=-1)
        return cue_feats

    @torch.no_grad()
    def compute_zc(
            self,
            hist_feats: torch.Tensor,  # [V, T_h, C_h]
            cue_feats: torch.Tensor,  # [T_h, C_c]
            stats: dict = None,
    ) -> torch.Tensor:
        """Assemble the 4-channel vector ``[speed, d_yaw, kappa, signed_strength]``.

        Args:
            hist_feats: ``[V, T_h, C_h]`` output of ``compute_hist_feats``
                (channels 2, 5 and 6 are read).
            cue_feats: ``[T_h, C_c]`` cue features or ``None``. The
                second-to-last channel is read as the signed strength, which
                matches ``compute_cue_feats`` with ``add_time_since=True``.
            stats: optional dict with ``'mean'`` and ``'std'`` (4 values each)
                used to standardize the result.

        Returns:
            ``z_c`` of shape ``[V, T_h, 4]``.
        """
        V, T, _ = hist_feats.shape
        speed = hist_feats[..., 2]  # [V,T]
        d_yaw = hist_feats[..., 5]  # [V,T]
        kappa = hist_feats[..., 6]  # [V,T]

        signed_strength = torch.zeros(T, device=hist_feats.device)
        if cue_feats is not None and cue_feats.shape[-1] >= 2:
            signed_strength = cue_feats[..., -2]
        signed_strength = signed_strength.view(1, T).expand(V, T)

        z = torch.stack([speed, d_yaw, kappa, signed_strength], dim=-1)  # [V,T,4]

        if stats is not None:
            mean = torch.as_tensor(stats['mean'], device=z.device).view(1, 1, 4)
            std = torch.as_tensor(stats['std'], device=z.device).view(1, 1, 4)
            z = (z - mean) / (std + 1e-6)

        return z

    def __len__(self):
        """Return the number of sequences in the dataset."""
        return self.data_len

    def __getitem__(self, index):
        """Return one sample as a list of tensors.

        Without IMLE the list is ``[past_traj, fut_traj,
        past_traj_original_scale, fut_traj_original_scale, fut_traj_vel,
        hist_feats, hist_cue, fut_cue]``; with IMLE the first five entries come
        from the loaded distillation data and ``y_t`` and ``y_pred_data`` are
        appended.
        """
        past = self.past_traj_original_scale[index][..., :2]  # [V,T_h,2] absolute positions

        instr_id = torch.from_numpy(self.cmd[index, :, 0])  # per-step instruction id
        instr_strength = torch.from_numpy(self.cmd[index, :, 1])  # per-step instruction strength

        hist_feats = self.compute_hist_feats(past, dt=self.dt, head_idx=self.head_idx, neck_idx=self.neck_idx)  # [V,T_h,7]
        cue_feats = self.compute_cue_feats(instr_id, instr_strength)  # [T_cmd,7]

        # NOTE(review): the history/future cue split is hard-coded at step 30
        # and does not follow obs_len — confirm against the layout of the
        # rat_stim_*.npy files.
        if self.imle:
            out = [
                    self.imle_data_dict['past_traj'][index],
                    self.imle_data_dict['fut_traj'][index],
                    self.imle_data_dict['past_traj_original_scale'][index],
                    self.imle_data_dict['fut_traj_original_scale'][index],
                    self.imle_data_dict['fut_traj_vel'][index],
                    hist_feats,
                    cue_feats[:30],
                    cue_feats[30:],
                    self.imle_data_dict['y_t'][index],
                    self.imle_data_dict['y_pred_data'][index]
                ]
        else:
            out = [
                self.past_traj[index],                 # [V,T_h,6]
                self.fut_traj[index],                  # [V,T_p,2]
                self.past_traj_original_scale[index],  # [V,T_h,6]
                self.fut_traj_original_scale[index],   # [V,T_p,2]
                self.fut_traj_vel[index],              # [V,T_p,2]

                hist_feats,
                cue_feats[:30],
                cue_feats[30:],
            ]
        return out

    def robust_minmax(self, x, q=(1, 99)):
        """Return the ``q[0]``-th and ``q[1]``-th percentiles of ``x`` along axis 0.

        Accepts a NumPy array or a torch tensor; the result is a pair of NumPy
        arrays ``(lo, hi)``.
        """
        if isinstance(x, torch.Tensor):
            # Convert explicitly instead of relying on implicit tensor->array
            # coercion, which fails for CUDA or grad-tracking tensors.
            x = x.detach().cpu().numpy()
        lo = np.percentile(x, q[0], axis=0)
        hi = np.percentile(x, q[1], axis=0)
        return lo, hi

    def z(self, x, mean, std):
        """Z-score ``x`` with the given ``mean`` and ``std`` (logs the shape)."""
        print("Normalize shape = {}".format(x.shape))
        return (x - mean) / std

    def iz(self, zx, mean, std):
        """Invert ``z``: map a z-scored value back to the original scale."""
        return zx * std + mean

def _central_diff(x, dt, pad_mode='replicate'):
    """
    x: [..., T, C] ; central difference in time
    return: same shape
    """
    xp = F.pad(x[..., 2:, :], (0,0,1,0), mode=pad_mode)   # t+1
    xm = F.pad(x[..., :-2, :], (0,0,0,1), mode=pad_mode)  # t-1
    return (xp - xm) / (2 * dt)

def _time_derivative(x: torch.Tensor, dt: float, time_dim: int = -2) -> torch.Tensor:
    orig = x
    dtype = x.dtype
    squeeze_mode = None

    if x.ndim == 1:
        # (T,) -> (1, T, 1)
        x = x.view(1, x.shape[0], 1)
        squeeze_mode = "T"
    elif x.ndim == 2:
        # (T, C) -> (1, T, C) 
        x = x.unsqueeze(0)
        squeeze_mode = "TC"
    else:
        x = x.movedim(time_dim, -2)

    T = x.size(-2)
    dx = torch.empty_like(x, dtype=dtype)

    if T == 1:
        dx.zero_()
    elif T == 2:
        d = (x[..., 1, :] - x[..., 0, :]) / dt
        dx[..., 0, :] = d
        dx[..., 1, :] = d
    else:
        dx[..., 0, :]  = (x[..., 1, :]   - x[..., 0, :])   / dt
        dx[..., 1:-1, :] = (x[..., 2:, :] - x[..., :-2, :]) / (2.0 * dt)
        dx[..., -1, :] = (x[..., -1, :]  - x[..., -2, :])  / dt

    if orig.ndim >= 3:
        dx = dx.movedim(-2, time_dim)  
    if squeeze_mode == "T":
        dx = dx.view(-1)               # (T,)
    elif squeeze_mode == "TC":
        dx = dx.squeeze(0)             # (T, C)

    return dx

def _safe_div(a, b, eps=1e-6):
    return a / (b.abs() + eps)

def _wrap_pi(a):
    # wrap angle to [-pi, pi]
    return (a + torch.pi) % (2 * torch.pi) - torch.pi

import copy
from data.dataset_registry import register_dataset
from torch.utils.data import DataLoader

@register_dataset("rat_dataset")
def build_rat_dataloader(cfg, args):
    """Build the train and test dataloaders for the rat dataset.

    The training dataset is always created first; it writes the normalization
    statistics into ``cfg``, which the test dataset then reads. When
    ``cfg.overfit`` is set, the test dataset is a deep copy of the training
    dataset instead of a separately loaded split.

    Args:
        cfg: config providing ``data_dir``, ``past_frames``, ``future_frames``,
            ``n_train``, ``n_test``, ``overfit``, ``data_norm``,
            ``train_batch_size``, ``test_batch_size`` and ``num_workers``.
        args: unused by this builder.

    Returns:
        ``(train_loader, test_loader)``; only the training loader shuffles.
    """

    def _dataset(**split_options):
        # Options shared by both splits; split-specific ones come from the caller.
        return RatDatasetMinMax(
            data_dir=cfg.data_dir,
            obs_len=cfg.past_frames,
            pred_len=cfg.future_frames,
            overfit=cfg.overfit,
            cfg=cfg,
            data_norm=cfg.data_norm,
            **split_options,
        )

    def _loader(dataset, batch_size, shuffle):
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=cfg.num_workers,
            collate_fn=seq_collate_rat,
            pin_memory=True,
        )

    train_dset = _dataset(training=True, num_scenes=cfg.n_train)
    train_loader = _loader(train_dset, cfg.train_batch_size, True)

    if not cfg.overfit:
        test_dset = _dataset(training=False, test_scenes=cfg.n_test)
    else:
        test_dset = copy.deepcopy(train_dset)
    test_loader = _loader(test_dset, cfg.test_batch_size, False)

    return train_loader, test_loader
