import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import config as cfg
import glob
from sklearn.preprocessing import StandardScaler
import pickle
import scipy.io as sio

class PennActionDataset(Dataset):
    """Sliding-window pose sequences built from Penn Action ``.mat`` files.

    Every ``*.mat`` file in ``data_root`` is read with ``scipy.io.loadmat``;
    its ``'x'`` and ``'y'`` arrays are concatenated along axis 1, so each row
    (axis 0 is treated as frames) holds the x values followed by the y
    values. Windows of ``input_frames + output_frames`` consecutive frames are
    cut every ``stride`` frames, standardized with a scaler, and served by
    ``__getitem__`` as flattened ``(input, target)`` float32 tensors.

    Args:
        data_root: Directory searched (non-recursively) for ``*.mat`` files.
        input_frames: Frames per window used as the model input.
        output_frames: Frames per window used as the prediction target.
        train: If True, use the first 80% of the sorted file list; otherwise
            use the remaining files.
        auto_save_scaler: When ``train`` is True, pickle the scaler to
            ``cfg.RESULTS_DIR/scaler.pkl`` after construction.
        scaler: Optional already-fitted scaler (anything with ``transform``).
            When given, it is used as-is without refitting, so a test split
            can share the training statistics. When None, a fresh
            ``StandardScaler`` is fitted on this split's own windows.
        stride: Step, in frames, between the starts of consecutive windows.

    Raises:
        ValueError: If ``stride`` < 1 or no file yields a usable window.
    """

    def __init__(self, data_root, input_frames=10, output_frames=10, train=True,
                 auto_save_scaler=True, *, scaler=None, stride=5):
        if stride < 1:
            raise ValueError(f"stride must be >= 1, got {stride}")
        self.data_root = data_root
        self.input_frames = input_frames
        self.output_frames = output_frames
        self.sequence_length = input_frames + output_frames
        # Duplicates of the frame counts; __getitem__ splits on input_length.
        self.input_length = input_frames
        self.output_length = output_frames
        self.stride = stride
        self.action_files = self._get_action_files(train)
        self.sequences = self._load_sequences()
        if scaler is None:
            self.scaler = StandardScaler()
            self._normalize_data()
        else:
            # Reuse externally fitted statistics (e.g. the training scaler)
            # so this split is expressed on the same scale.
            self.scaler = scaler
            self._normalize_data(fit=False)
        if train and auto_save_scaler:
            # save_scaler reports success/failure itself; no extra print here,
            # which previously claimed success even when saving had failed.
            self.save_scaler(os.path.join(cfg.RESULTS_DIR, 'scaler.pkl'))

    def _get_action_files(self, train):
        """Return this split's ``.mat`` paths: first 80% of the sorted list for
        train, the rest otherwise (a deterministic, file-level split)."""
        all_files = sorted(glob.glob(os.path.join(self.data_root, '*.mat')))
        train_split = int(len(all_files) * 0.8)
        return all_files[:train_split] if train else all_files[train_split:]

    def _load_sequences(self):
        """Load every file of this split and cut it into fixed-length windows.

        Files that cannot be read, lack ``'x'``/``'y'``, or have fewer frames
        than ``sequence_length`` are reported and skipped.

        Returns:
            List of arrays, each with ``sequence_length`` rows.

        Raises:
            ValueError: If no window could be produced from any file.
        """
        all_sequences = []
        success_count = 0
        error_count = 0
        for file_path in self.action_files:
            try:
                mat_data = sio.loadmat(file_path)
                pose_data = np.concatenate([mat_data['x'], mat_data['y']], axis=1)
            except Exception as e:
                # Best effort: one unreadable or malformed file must not abort
                # the whole load; it is reported and counted instead.
                error_count += 1
                print(f"Error loading {file_path}: {e}")
                continue
            if pose_data.shape[0] < self.sequence_length:
                print(f"Skipping {file_path}: Too few frames ({pose_data.shape[0]} < {self.sequence_length})")
                error_count += 1
                continue
            last_start = pose_data.shape[0] - self.sequence_length
            all_sequences.extend(
                pose_data[start:start + self.sequence_length]
                for start in range(0, last_start + 1, self.stride)
            )
            success_count += 1
        print(f"Successfully loaded {success_count} files, errors in {error_count} files.")
        if not all_sequences:
            raise ValueError("No valid sequences could be loaded. Check your data format and paths.")
        return all_sequences

    def _normalize_data(self, fit=True):
        """Standardize every window in place with ``self.scaler``.

        Args:
            fit: If True, fit the scaler on all rows of all windows first;
                if False, assume ``self.scaler`` is already fitted.
        """
        if not self.sequences:
            raise ValueError("No valid sequences were loaded. Check your data files format.")
        all_data = np.vstack(self.sequences)
        print("Data statistics before normalization:")
        print(f"  Shape: {all_data.shape}")
        print(f"  Mean: {np.mean(all_data, axis=0)[:10]}...")
        print(f"  Std: {np.std(all_data, axis=0)[:10]}...")
        if fit:
            self.scaler.fit(all_data)
        self.sequences = [self.scaler.transform(seq) for seq in self.sequences]

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, index):
        """Return ``(input, target)`` as flattened float32 tensors: the first
        ``input_length`` rows of the window and the remaining rows."""
        sequence = self.sequences[index]
        input_seq = sequence[:self.input_length].reshape(-1)
        target_seq = sequence[self.input_length:].reshape(-1)
        # torch.tensor always copies, so callers can modify the result
        # in place without touching the dataset's stored arrays.
        return (torch.tensor(input_seq, dtype=torch.float32),
                torch.tensor(target_seq, dtype=torch.float32))

    def save_scaler(self, path):
        """Pickle ``self.scaler`` to ``path``, creating parent dirs if needed.

        Best effort: failures are printed with a traceback, not raised.
        """
        try:
            directory = os.path.dirname(path)
            # A bare filename has no directory part; os.makedirs('') raises.
            if directory:
                os.makedirs(directory, exist_ok=True)
            with open(path, 'wb') as f:
                pickle.dump(self.scaler, f)
            print(f"Scaler saved to {path}")
            print(f"File exists: {os.path.exists(path)}")
        except Exception as e:
            print(f"Error saving scaler: {e}")
            import traceback
            traceback.print_exc()

    def load_scaler(self, path):
        """Replace ``self.scaler`` with the object pickled at ``path``.

        Only the attribute is replaced; ``self.sequences``, normalized during
        construction, are NOT re-transformed. To normalize with a specific
        scaler, pass it to the constructor via ``scaler=`` instead.

        Security: ``pickle.load`` can execute arbitrary code; only load files
        this code wrote itself.
        """
        with open(path, 'rb') as f:
            self.scaler = pickle.load(f)

def get_dataloaders(data_root, batch_size=32, input_frames=10, output_frames=10):
    """Build train/test ``DataLoader``s over ``data_root``.

    The test split is normalized with the scaler fitted on the training split
    rather than refitted on test data, so both loaders share one scale and
    the saved ``cfg.RESULTS_DIR/scaler.pkl`` matches the test transform.

    Returns:
        ``(train_loader, test_loader)``; only the train loader shuffles.
    """
    train_dataset = PennActionDataset(
        data_root=data_root,
        input_frames=input_frames,
        output_frames=output_frames,
        train=True,
        # Saved explicitly below; avoid writing the same file twice.
        auto_save_scaler=False,
    )
    test_dataset = PennActionDataset(
        data_root=data_root,
        input_frames=input_frames,
        output_frames=output_frames,
        train=False,
        scaler=train_dataset.scaler,
    )
    os.makedirs(cfg.RESULTS_DIR, exist_ok=True)
    train_dataset.save_scaler(os.path.join(cfg.RESULTS_DIR, 'scaler.pkl'))

    loader_kwargs = dict(batch_size=batch_size, num_workers=0, pin_memory=False)
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)
    return train_loader, test_loader
