
import os
import glob
import numpy as np
import scipy.io as sio

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset

from . import register_dataset
from sklearn.model_selection import train_test_split
import pandas as pd  
from typing import Tuple
import pickle
import json
from typing import Iterable, Optional, Sequence
from torch.utils.data import TensorDataset
from typing import Any, Dict, Optional

class mmwave:
    """Base class shared by the mmWave dataset wrappers.

    It only records the constructor arguments; ``download_data`` and
    ``get_dataset`` must be provided by a subclass (here they simply raise
    ``NotImplementedError``).
    """

    # Per-sample input size as (C, H, W); subclasses override it. A (C, H, W)
    # sample may be viewed as (C*H, W) so that C*H serves as the sequence
    # axis for an RNN.
    input_shape = (None, None, None)
    # Number of target classes; subclasses fill it in.
    num_classes = None

    def __init__(self, data_dir, coding_schema, time_step):
        """Remember where the data lives and the encoding settings.

        Args:
            data_dir: dataset location, stored as ``self.data_dir``.
            coding_schema: encoding option, stored as ``self.encode``.
            time_step: time-step setting, stored as ``self.T``.
        """
        self.encode = coding_schema
        self.T = time_step
        self.data_dir = data_dir

    def download_data(self):
        """Load/prepare the data; must be overridden by subclasses."""
        raise NotImplementedError

    def get_dataset(self, train=True):
        """Return the train (``train=True``) or test split; must be overridden."""
        raise NotImplementedError




@register_dataset('aophand_aop')
class iAOPHandAoP(mmwave):
    """AOPHand AoP dataset backed by a single preprocessed pickle file.

    The pickle is expected to be a dict holding at least ``"data"`` (a rank-4
    array laid out as (N, F, P, D) -- presumably samples, frames, points,
    per-point features) and ``"labels"`` (one label per sample). Optional keys
    are ``"meta"`` (one entry per sample) and ``"coord_mean"`` /
    ``"coord_std"`` (normalization statistics broadcast over the last axis).

    ``get_dataset`` yields ``(x, label)`` pairs where ``x`` is a float32
    tensor of shape (T, C, H, W) == (time_step, 1, 64, 64): each frame's P*D
    values are flattened and zero-padded (or truncated) to C*H*W.
    """

    # File name used when the pickle cannot be inferred from ``data_dir``.
    DEFAULT_PKL_NAME = "AOPHand_aop_pre.pkl"

    def __init__(self, data_dir, coding_schema, time_step, pkl_name=None,
                 split_seed=None):
        """Resolve which pickle to load; no data is read yet.

        Args:
            data_dir: directory containing the pickle, or the pickle path itself.
            coding_schema: forwarded to the base class (``self.encode``).
            time_step: number of frames each sample is resampled to (``self.T``).
            pkl_name: explicit pickle file name inside ``data_dir``. If None,
                it is inferred (see ``_infer_pkl_name``).
            split_seed: optional seed for the 80/20 train/test split. ``None``
                keeps the original behaviour of using NumPy's global RNG.
        """
        super().__init__(data_dir, coding_schema, time_step)

        self.input_shape = (1, 64, 64)
        self.num_classes = None
        self.split_seed = split_seed

        if pkl_name is not None:
            self.pkl_name = pkl_name
        else:
            self.pkl_name = self._infer_pkl_name()

        print(f"[AOPHandAoP] Using pkl_name = {self.pkl_name}")

        # Filled in by download_data().
        self.coord_mean = None
        self.coord_std = None
        self.num_points = None   # P
        self.point_dim = None    # D

    def _infer_pkl_name(self):
        """Pick the pickle file name when none was given (may rewrite ``self.data_dir``)."""
        if os.path.isfile(self.data_dir):
            # data_dir points straight at the pickle: split into directory + name.
            name = os.path.basename(self.data_dir)
            self.data_dir = os.path.dirname(self.data_dir)
            return name

        if not os.path.isdir(self.data_dir):
            # Defer the failure to download_data(), which raises a descriptive
            # FileNotFoundError, instead of a bare one from os.listdir() here.
            return self.DEFAULT_PKL_NAME

        pkl_files = sorted(f for f in os.listdir(self.data_dir) if f.endswith(".pkl"))
        if len(pkl_files) == 1:
            return pkl_files[0]
        if pkl_files:
            print(
                f"[AOPHandAoP] Found {len(pkl_files)} .pkl files in {self.data_dir}, "
                f"falling back to {self.DEFAULT_PKL_NAME}"
            )
        return self.DEFAULT_PKL_NAME

    # ---------------------------------------------------------- #
    #                        download_data                       #
    # ---------------------------------------------------------- #
    def download_data(self):
        """Load the pickle, remap labels to 0..C-1 and build an 80/20 split.

        Sets ``coord_mean``/``coord_std``, ``num_classes``, ``num_points``,
        ``point_dim``, ``df_train`` and ``df_test``.

        Raises:
            FileNotFoundError: the pickle file does not exist.
            KeyError: the pickle lacks ``"data"`` or ``"labels"``.
            ValueError: ``data`` is not rank 4, is empty, or its length
                disagrees with ``labels`` / ``meta``.
        """
        pkl_path = os.path.join(self.data_dir, self.pkl_name)

        if not os.path.exists(pkl_path):
            raise FileNotFoundError(
                f"[AOPHandAoP] preprocessed file not found: {pkl_path}\n"
                f"Checked directory: {self.data_dir}\n"
                f"You may specify it using: -data_dir=/path/to/xxx.pkl "
                f"or --dataset aophand_aop --data_dir=/home/Firewall/data/mm_data_unified/AOPHand_aop"
            )

        print(f"[AOPHandAoP] Loading preprocessed data from {pkl_path} ...")
        # NOTE(security): pickle.load can execute arbitrary code -- only load
        # files produced by a trusted preprocessing step.
        with open(pkl_path, "rb") as f:
            obj = pickle.load(f)

        if "data" not in obj or "labels" not in obj:
            raise KeyError(
                f"[AOPHandAoP] Expected keys 'data' and 'labels' in {pkl_path}, "
                f"got keys: {list(obj.keys())}"
            )

        self._load_coord_stats(obj)

        # np.asarray also accepts lists/tuples, not only ndarrays.
        data = np.asarray(obj["data"], dtype=np.float32)      # (N, F, P, D)
        labels = np.asarray(obj["labels"]).astype(np.int64).reshape(-1)  # (N,)

        if data.ndim != 4:
            raise ValueError(
                f"[AOPHandAoP] 'data' should have shape (N,F,P,D), got {data.shape}"
            )

        N, n_frames, P, D = data.shape
        print(f"[AOPHandAoP] Raw shape: N={N}, F={n_frames}, P={P}, D={D}")

        if N == 0:
            raise ValueError(f"[AOPHandAoP] 'data' in {pkl_path} contains no samples")
        if len(labels) != N:
            raise ValueError(
                f"[AOPHandAoP] 'labels' has {len(labels)} entries but 'data' has {N} samples"
            )

        meta = obj.get("meta")
        if meta is None:
            # One independent dict per sample; `[{}] * N` would alias one dict.
            meta = [{} for _ in range(N)]
        meta = list(meta)
        if len(meta) != N:
            raise ValueError(
                f"[AOPHandAoP] 'meta' has {len(meta)} entries but 'data' has {N} samples"
            )

        self.num_points = P
        self.point_dim = D

        # Map sorted unique labels to contiguous ids 0..C-1.
        uniq, inverse = np.unique(labels, return_inverse=True)
        new_labels = inverse.reshape(-1).astype(np.int64)

        self.num_classes = len(uniq)
        print(
            f"[AOPHandAoP] Unique labels: {len(uniq)}, "
            f"old range [{uniq.min()},{uniq.max()}] -> new range [0,{self.num_classes-1}]"
        )

        df = pd.DataFrame({
            "data": list(data),          # one (F, P, D) array per row
            "label": new_labels,
            "meta": meta,
        })

        # Random (non-stratified) 80/20 split; reproducible only if split_seed is set.
        seed = getattr(self, "split_seed", None)
        rng = np.random if seed is None else np.random.default_rng(seed)
        idx_perm = rng.permutation(N)
        n_train = int(0.8 * N)
        train_idx = idx_perm[:n_train]
        test_idx = idx_perm[n_train:]

        self.df_train = df.iloc[train_idx].reset_index(drop=True)
        self.df_test = df.iloc[test_idx].reset_index(drop=True)

        print(
            f"[AOPHandAoP] Loaded N={N} samples, "
            f"num_classes={self.num_classes}, "
            f"train={len(self.df_train)}, test={len(self.df_test)}, "
            f"input_shape={self.input_shape}, time_step={self.T}"
        )

    def _load_coord_stats(self, obj):
        """Read the optional ``coord_mean`` / ``coord_std`` entries of the pickle dict."""
        coord_mean = obj.get("coord_mean", None)
        coord_std = obj.get("coord_std", None)
        if coord_mean is None or coord_std is None:
            self.coord_mean = None
            self.coord_std = None
            print("[AOPHandAoP] No coord_mean/std in pkl, skip normalization in Dataset.")
            return

        coord_mean = np.asarray(coord_mean, dtype=np.float32)
        # np.array copies, so the clamp below never mutates the array held in `obj`.
        coord_std = np.array(coord_std, dtype=np.float32)
        # Avoid dividing by ~0 for (near-)constant features.
        coord_std[coord_std < 1e-6] = 1.0
        self.coord_mean = coord_mean   # (D,)
        self.coord_std = coord_std     # (D,)
        print(f"[AOPHandAoP] Loaded coord_mean/std from pkl, shape={coord_mean.shape}")

    # ---------------------------------------------------------- #
    #                    per-split torch Dataset                 #
    # ---------------------------------------------------------- #
    class AOPHandAoPSeqDataset(Dataset):
        """Torch view of one split: returns ``(tensor (T, C, H, W), int label)``.

        Defined at class level (not inside ``get_dataset``) so instances are
        picklable, which DataLoader needs for ``num_workers > 0`` with the
        spawn/forkserver start methods.
        """

        def __init__(self, df, target_T, input_shape,
                     coord_mean=None, coord_std=None,
                     num_points=None, point_dim=None):
            self.df = df
            self.target_T = target_T
            self.C, self.H, self.W = input_shape

            self.coord_mean = coord_mean  # (D,) or None
            self.coord_std = coord_std    # (D,) or None

            if len(df) > 0:
                sample_data = df.iloc[0]["data"]  # (F, P, D)
                if sample_data.ndim != 3:
                    raise ValueError(
                        f"[AOPHandAoPSeqDataset] each 'data' should have shape (F,P,D), "
                        f"got {sample_data.shape}"
                    )
                _, self.P, self.D = sample_data.shape
            elif num_points is not None and point_dim is not None:
                # Empty split (e.g. tiny N): fall back to the caller-provided sizes.
                self.P, self.D = num_points, point_dim
            else:
                raise ValueError(
                    "[AOPHandAoPSeqDataset] empty split and no (num_points, point_dim) given"
                )

            self.vec_dim = self.P * self.D                 # values per frame
            self.img_dim = self.C * self.H * self.W        # 1*64*64 = 4096 by default
            if self.vec_dim > self.img_dim:
                print(
                    f"[AOPHandAoPSeqDataset] Warning: P*D={self.vec_dim} exceeds "
                    f"C*H*W={self.img_dim}; trailing features will be truncated."
                )

        def __len__(self):
            return len(self.df)

        def _resample(self, frames):
            """Return exactly ``target_T`` frames from a (F, P, D) array.

            Long sequences are subsampled at evenly spaced indices; short ones
            keep all frames and append randomly repeated frames; an empty
            sequence becomes all zeros.
            """
            n_frames = frames.shape[0]
            if n_frames >= self.target_T:
                idx = np.linspace(0, n_frames - 1, self.target_T).astype(int)
                return frames[idx]
            if n_frames == 0:
                return np.zeros((self.target_T, self.P, self.D), dtype=np.float32)
            need = self.target_T - n_frames
            pad_idx = np.random.choice(n_frames, need, replace=True)
            return np.concatenate([frames, frames[pad_idx]], axis=0)

        def __getitem__(self, index):
            row = self.df.iloc[index]

            frames = np.asarray(row["data"], dtype=np.float32)  # (F, P, D)
            label = int(row["label"])                           # 0..num_classes-1

            if self.coord_mean is not None and self.coord_std is not None:
                frames = (frames - self.coord_mean.reshape(1, 1, -1)) / \
                    self.coord_std.reshape(1, 1, -1)

            T = self.target_T
            # _resample always returns a fresh array, so df data is never aliased.
            x_vec = self._resample(frames).reshape(T, self.vec_dim)

            # Zero-pad (or truncate) each frame's flattened vector to C*H*W.
            x_flat = np.zeros((T, self.img_dim), dtype=np.float32)
            keep = min(self.vec_dim, self.img_dim)
            x_flat[:, :keep] = x_vec[:, :keep]

            x_img = x_flat.reshape(T, self.C, self.H, self.W)  # (T, 1, 64, 64) by default
            return torch.from_numpy(x_img), label

    # ---------------------------------------------------------- #
    #                         get_dataset                        #
    # ---------------------------------------------------------- #
    def get_dataset(self, train: bool = True):
        """Return the train (``train=True``) or test split as a torch Dataset.

        Raises:
            RuntimeError: ``download_data()`` has not been called yet.
        """
        if not hasattr(self, "df_train") or not hasattr(self, "df_test"):
            raise RuntimeError(
                "[AOPHandAoP] call download_data() before get_dataset()"
            )

        target_df = self.df_train if train else self.df_test
        return self.AOPHandAoPSeqDataset(
            target_df,
            self.T,
            self.input_shape,
            coord_mean=getattr(self, "coord_mean", None),
            coord_std=getattr(self, "coord_std", None),
            num_points=getattr(self, "num_points", None),
            point_dim=getattr(self, "point_dim", None),
        )

