"""Data loading and preprocessing module"""
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import config as cfg
import glob
from sklearn.preprocessing import StandardScaler
import pickle

class H36MDataset(Dataset):
    """Sliding-window dataset over Human3.6M action files.

    Every ``*.txt`` file found under ``<data_root>/<subject>/`` is parsed as a
    comma-separated table with ``NUM_FEATURES`` columns, one row per frame.
    Each file is cut into windows of ``input_frames + output_frames``
    consecutive rows, taken every ``WINDOW_STRIDE`` rows, and all windows are
    standardized with ``self.scaler``.

    Indexing returns ``(input, target)``: the flattened first ``input_frames``
    rows of a window and the flattened remaining rows, both float32 tensors.
    """

    # Number of comma-separated values expected on every row of an action file.
    NUM_FEATURES = 99
    # Offset, in rows, between the starts of two consecutive windows.
    WINDOW_STRIDE = 5

    def __init__(self, data_root, input_frames=10, output_frames=10, train=True,
                 subjects=None, auto_save_scaler=True, scaler=None):
        """
        Load, window and normalize the data of the selected subjects.

        Args:
            data_root: dataset root directory containing one folder per subject
            input_frames: number of rows of each window returned as the input
            output_frames: number of rows of each window returned as the target
            train: selects the default subject list and enables scaler saving
            subjects: list of subject folder names to use; None selects
                S11/S9/S6/S7/S8 when ``train`` is true and S5/S1 otherwise
            auto_save_scaler: when true (and ``train`` is true) the scaler is
                pickled to ``<cfg.RESULTS_DIR>/scaler.pkl`` after normalization
            scaler: optional already-fitted object exposing ``transform``.
                When given, it is used as-is and no new statistics are fitted;
                when None, a fresh ``StandardScaler`` is fitted on the data
                loaded by this instance.

        Raises:
            ValueError: if no window could be loaded from any file.
        """
        super().__init__()
        self.data_root = data_root
        self.input_frames = input_frames
        self.output_frames = output_frames
        self.sequence_length = input_frames + output_frames

        # Second spelling of the same two values; both sets of names are kept
        # so existing readers of either attribute keep working.
        self.input_length = input_frames
        self.output_length = output_frames

        if subjects is None:
            subjects = ["S11", "S9", "S6", "S7", "S8"] if train else ["S5", "S1"]
        self.subjects = subjects

        self.action_files = self._get_action_files()
        self.sequences = self._load_sequences()

        if scaler is None:
            self.scaler = StandardScaler()
            self._normalize_data()
        else:
            # Reuse externally fitted statistics instead of refitting here.
            self.scaler = scaler
            self._normalize_data(fit=False)

        if train and auto_save_scaler:
            scaler_path = os.path.join(cfg.RESULTS_DIR, 'scaler.pkl')
            # save_scaler reports failures itself; only claim success when the
            # file was really written.
            if self.save_scaler(scaler_path):
                print(f"Scaler saved to {scaler_path}")

    def _get_action_files(self):
        """Return the ``*.txt`` paths of all selected subjects.

        Subjects whose folder does not exist are skipped silently. Paths are
        sorted within each subject because ``glob`` returns them in an
        OS-dependent order, which would make sample indices non-reproducible.
        """
        action_files = []
        for subject in self.subjects:
            subject_path = os.path.join(self.data_root, subject)
            if not os.path.exists(subject_path):
                continue
            action_files.extend(sorted(glob.glob(os.path.join(subject_path, "*.txt"))))
        return action_files

    def _load_sequences(self):
        """Parse every action file and cut it into fixed-length windows.

        Files with the wrong column count, too few rows, or that fail to parse
        are reported on stdout and skipped. Cells that are not numeric become
        NaN and are filled forward, then backward, then with 0.0.

        Returns:
            list of arrays, each of shape (sequence_length, NUM_FEATURES).

        Raises:
            ValueError: if not a single window could be produced.
        """
        import pandas as pd  # local import, as pandas is only needed for parsing

        window = self.sequence_length
        all_sequences = []
        success_count = 0
        error_count = 0

        for file_path in self.action_files:
            # Loading is best-effort per file: one unreadable file must not
            # abort the whole dataset, so any failure is logged and skipped.
            try:
                data_df = pd.read_csv(file_path, header=None, sep=',',
                                      na_values=['-', 'nan', ''], dtype=str)
                data_df = data_df.apply(pd.to_numeric, errors='coerce')

                if data_df.shape[1] != self.NUM_FEATURES:
                    print(f"Warning: {file_path} has {data_df.shape[1]} columns, expected {self.NUM_FEATURES}. Skipping.")
                    error_count += 1
                    continue

                if data_df.isnull().any().any():
                    # Forward fill first, backward fill for leading gaps, and
                    # 0.0 for columns that contain no valid value at all.
                    data_df = data_df.ffill().bfill().fillna(0.0)

                data = data_df.to_numpy(dtype=np.float64)

                if data.shape[0] < window:
                    print(f"Skipping {file_path}: Too few frames ({data.shape[0]} < {window})")
                    error_count += 1
                    continue

                all_sequences.extend(
                    data[start:start + window]
                    for start in range(0, data.shape[0] - window + 1, self.WINDOW_STRIDE)
                )
                success_count += 1

            except Exception as e:
                error_count += 1
                print(f"Error loading {file_path}: {e}")
                continue

        print(f"Successfully loaded {success_count} files, errors in {error_count} files.")

        if not all_sequences:
            raise ValueError("No valid sequences could be loaded. Check your data format and paths.")

        return all_sequences

    def _normalize_data(self, fit=True):
        """Standardize every window in ``self.sequences`` with ``self.scaler``.

        Args:
            fit: when true, ``self.scaler`` is first fitted on the rows of all
                windows stacked together; when false, it is assumed to be
                fitted already and is only applied.

        Raises:
            ValueError: if ``self.sequences`` is empty.
        """
        if not self.sequences:
            raise ValueError("No valid sequences were loaded. Check your data files format.")

        # One row per frame of every window (overlapping windows repeat rows).
        all_data = np.vstack([seq.reshape(-1, seq.shape[-1]) for seq in self.sequences])

        print("Data statistics before normalization:")
        print(f"  Shape: {all_data.shape}")
        print(f"  Mean: {np.mean(all_data, axis=0)[:10]}...")
        print(f"  Std: {np.std(all_data, axis=0)[:10]}...")

        if fit:
            self.scaler.fit(all_data)

        for i, seq in enumerate(self.sequences):
            flattened = seq.reshape(-1, seq.shape[-1])
            self.sequences[i] = self.scaler.transform(flattened).reshape(seq.shape)

    def __len__(self):
        """Return the number of windows."""
        return len(self.sequences)

    def __getitem__(self, index):
        """Return the ``(input, target)`` pair of window ``index``.

        Both are 1-D float32 tensors: the first ``input_length`` rows of the
        window flattened, and the remaining rows flattened.
        """
        sequence = self.sequences[index]
        input_seq = sequence[:self.input_length].reshape(-1)
        target_seq = sequence[self.input_length:].reshape(-1)

        # torch.tensor always copies, so callers can modify the returned
        # tensors without touching the stored windows.
        input_tensor = torch.tensor(input_seq, dtype=torch.float32)
        target_tensor = torch.tensor(target_seq, dtype=torch.float32)

        return input_tensor, target_tensor

    def save_scaler(self, path):
        """Pickle ``self.scaler`` to ``path``, creating parent folders.

        Failures are reported on stdout (with traceback) instead of raised.

        Returns:
            True if the file was written, False otherwise.
        """
        try:
            directory = os.path.dirname(path)
            # A bare filename has an empty dirname, and os.makedirs('') raises.
            if directory:
                os.makedirs(directory, exist_ok=True)
            with open(path, 'wb') as f:
                pickle.dump(self.scaler, f)
        except Exception as e:
            print(f"Error saving scaler: {e}")
            import traceback
            traceback.print_exc()
            return False

        print(f"Scaler successfully saved to {path}")
        print(f"File exists: {os.path.exists(path)}")
        return True

    def load_scaler(self, path):
        """Replace ``self.scaler`` with the object pickled at ``path``.

        The stored windows are not touched; they keep the normalization that
        was applied when the dataset was built.
        """
        # SECURITY: unpickling can execute arbitrary code; only load scaler
        # files that come from a trusted source.
        with open(path, 'rb') as f:
            self.scaler = pickle.load(f)

def get_dataloaders(data_root, batch_size=32, input_frames=10, output_frames=10):
    """Create the training and testing dataloaders.

    Builds one ``H36MDataset`` with ``train=True`` and one with
    ``train=False`` (each using its default subject list), pickles the
    training scaler to ``<cfg.RESULTS_DIR>/scaler.pkl`` and wraps both
    datasets in single-process ``DataLoader`` objects.

    Args:
        data_root: dataset root directory passed to both datasets
        batch_size: batch size of both loaders
        input_frames: input sequence length
        output_frames: output sequence length

    Returns:
        ``(train_loader, test_loader)``; only the training loader shuffles.
    """
    # auto_save_scaler is switched off because the scaler is saved explicitly
    # below; leaving it on would pickle the same object twice.
    train_dataset = H36MDataset(
        data_root=data_root,
        input_frames=input_frames,
        output_frames=output_frames,
        train=True,
        auto_save_scaler=False,
    )

    os.makedirs(cfg.RESULTS_DIR, exist_ok=True)
    train_dataset.save_scaler(os.path.join(cfg.RESULTS_DIR, 'scaler.pkl'))

    # NOTE(review): this dataset is built without a shared scaler, so it is
    # standardized with statistics fitted on its own (test) subjects rather
    # than with the training statistics saved above. Confirm this is intended.
    test_dataset = H36MDataset(
        data_root=data_root,
        input_frames=input_frames,
        output_frames=output_frames,
        train=False,
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        pin_memory=False,
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=False,
    )

    return train_loader, test_loader
