import numpy as np
import os
import argparse
from pathlib import Path
import random
import shutil
import sys

def get_data_dir():
    """Return the dataset root directory taken from the environment.

    ``DPPO_DATA_DIR`` is consulted first. If it is unset, the fallbacks
    ``DATA_DIR``, ``DATASET_DIR`` and ``ROBOMIMIC_DATA_DIR`` are tried in that
    order and the first one present is used. A variable that is set but
    empty still counts as present.

    Returns:
        The directory string, or ``None`` when none of the variables is set.
    """
    preferred = os.environ.get('DPPO_DATA_DIR')
    if preferred is not None:
        return preferred

    print("Warning: DPPO_DATA_DIR not found in environment variables.")

    # Fall back to other commonly used names; the first match wins.
    for candidate in ('DATA_DIR', 'DATASET_DIR', 'ROBOMIMIC_DATA_DIR'):
        if candidate in os.environ:
            found = os.environ[candidate]
            print(f"Using {candidate} as data directory: {found}")
            return found

    print("No data directory environment variable found. Please specify manually.")
    return None

def _print_episode_length_stats(lengths):
    """Print mean/min/max of per-episode lengths, or a note when there are none."""
    if len(lengths) == 0:
        # np.min/np.max raise ValueError on empty arrays (and np.mean warns).
        print("No episodes found: episode length statistics unavailable")
        return
    print(f"Average episode length: {np.mean(lengths):.2f}")
    print(f"Min episode length: {np.min(lengths)}")
    print(f"Max episode length: {np.max(lengths)}")


def analyze_npz_dataset(npz_path):
    """Print a summary of an NPZ episode dataset and report its structure.

    Two episode-boundary layouts are recognised:
      * "robomimic": an ``episode_ends`` array of cumulative end indices.
      * "stitched": a ``traj_lengths`` array of per-episode lengths.

    Every array in the archive is listed with its shape and dtype. An entry
    that cannot be loaded is reported as "(complex structure)" rather than
    aborting the listing.

    Args:
        npz_path: Path to the ``.npz`` file.

    Returns:
        ``(keys, num_episodes, format_type)``:
          * ``keys`` is the list of array names in the archive.
          * ``num_episodes`` is the episode count, or ``None`` if no boundary
            array was found.
          * ``format_type`` is ``"robomimic"``, ``"stitched"`` or ``"unknown"``.
        Returns ``(None, None, None)`` when the file does not exist.
    """
    print(f"\nAnalyzing dataset: {npz_path}")

    if not os.path.exists(npz_path):
        print(f"Error: Dataset not found at {npz_path}")
        return None, None, None

    with np.load(npz_path, allow_pickle=True) as data:
        keys = list(data.keys())

        print(f"\nDataset information:")
        print(f"Number of keys in dataset: {len(keys)}")

        format_type = "unknown"
        num_episodes = None

        if 'episode_ends' in data:
            format_type = "robomimic"
            episode_ends = np.asarray(data['episode_ends'][()])
            num_episodes = len(episode_ends)
            total_timesteps = episode_ends[-1] if num_episodes > 0 else 0

            print(f"Format: RoboMimic (episode_ends)")
            print(f"Total episodes: {num_episodes}")
            print(f"Total timesteps: {total_timesteps}")

            # Episode i spans [episode_ends[i-1], episode_ends[i]) with an
            # implicit leading 0, so its length is the first difference.
            _print_episode_length_stats(np.diff(episode_ends, prepend=0))

        elif 'traj_lengths' in data:
            format_type = "stitched"
            traj_lengths = np.asarray(data['traj_lengths'][()])
            num_episodes = len(traj_lengths)
            total_timesteps = np.sum(traj_lengths)

            print(f"Format: Stitched (traj_lengths)")
            print(f"Total episodes: {num_episodes}")
            print(f"Total timesteps: {total_timesteps}")

            _print_episode_length_stats(traj_lengths)

        else:
            print("Warning: Unknown format - neither 'episode_ends' nor 'traj_lengths' found")

        print(f"\nKeys in the dataset:")
        for key in keys:
            try:
                # Load once: every NpzFile lookup reads the member from the
                # archive again, so reading .shape and .dtype separately
                # would load the array twice.
                array = data[key]
            except Exception:
                # Best-effort listing: one unreadable entry must not abort
                # the summary (but don't swallow KeyboardInterrupt/SystemExit).
                print(f"  - {key}: (complex structure)")
                continue
            print(f"  - {key}: shape={array.shape}, dtype={array.dtype}")

    return keys, num_episodes, format_type

def convert_episode_format(data, from_format, to_format):
    """Swap the episode-boundary representation of a dataset mapping.

    "robomimic" stores cumulative ``episode_ends``; "stitched" stores
    per-episode ``traj_lengths``. Converting between them replaces one key
    with the other, which is appended last; every other entry is carried
    over unchanged.

    Args:
        data: Mapping of array name to array (a dict or an opened NpzFile).
        from_format: Format of ``data``.
        to_format: Desired format.

    Returns:
        A new dict for a supported conversion. Otherwise ``data`` itself is
        returned: when the formats match or the pair is not recognised.
    """
    if from_format == to_format:
        return data

    direction = (from_format, to_format)

    if direction == ("robomimic", "stitched"):
        ends = data['episode_ends'][()]
        # Lengths are the gaps between consecutive ends, with an implicit 0 start.
        lengths = ends - np.concatenate([[0], ends[:-1]])
        converted = {name: value for name, value in data.items() if name != 'episode_ends'}
        converted['traj_lengths'] = lengths
        return converted

    if direction == ("stitched", "robomimic"):
        lengths = data['traj_lengths'][()]
        # Cumulative sum turns per-episode lengths into exclusive end indices.
        ends = np.cumsum(lengths)
        converted = {name: value for name, value in data.items() if name != 'traj_lengths'}
        converted['episode_ends'] = ends
        return converted

    return data

def extract_random_episodes(npz_path, num_episodes, output_path, output_format="auto", seed=None):
    """Write a new NPZ containing a random subset of whole episodes.

    Arrays in the source archive are handled by their leading dimension:
      * equal to the total number of timesteps: sliced to the selected
        episodes' timesteps;
      * equal to the number of episodes: indexed by the selected episodes;
      * anything else, including 0-d arrays: copied unchanged.
    Selected episodes keep their original relative order.

    Args:
        npz_path: Source ``.npz`` file with either ``episode_ends``
            ("robomimic") or ``traj_lengths`` ("stitched") boundaries.
        num_episodes: Number of episodes to keep; clamped to the number
            available.
        output_path: Destination ``.npz`` file. Missing parent directories
            are created.
        output_format: ``"auto"`` (same as the input), ``"robomimic"`` or
            ``"stitched"``.
        seed: Optional seed for a private ``random.Random`` so the selection
            is reproducible. ``None`` uses the global ``random`` module state.

    Returns:
        True on success. False if the source format is not recognised or
        ``num_episodes`` is negative.
    """
    print(f"\nExtracting {num_episodes} random episodes...")

    with np.load(npz_path, allow_pickle=True) as data:
        if 'episode_ends' in data:
            format_type = "robomimic"
            episode_ends = np.asarray(data['episode_ends'][()])
            episode_lengths = np.diff(episode_ends, prepend=0)
        elif 'traj_lengths' in data:
            format_type = "stitched"
            episode_lengths = np.asarray(data['traj_lengths'][()])
            episode_ends = np.cumsum(episode_lengths)
        else:
            print("Error: Unknown dataset format")
            return False

        # Episode i covers timesteps [episode_starts[i], episode_ends[i]).
        episode_starts = episode_ends - episode_lengths
        num_total_episodes = len(episode_lengths)
        # Guard against an empty dataset (episode_ends[-1] would IndexError).
        total_timesteps = int(episode_ends[-1]) if num_total_episodes else 0

        if num_episodes < 0:
            print(f"Error: Number of episodes must be non-negative, got {num_episodes}")
            return False
        if num_episodes > num_total_episodes:
            print(f"Warning: Requested {num_episodes} episodes but only {num_total_episodes} available.")
            num_episodes = num_total_episodes

        # The random module exposes sample() directly, so it can stand in for
        # a seeded Random instance.
        sampler = random if seed is None else random.Random(seed)
        selected = sorted(sampler.sample(range(num_total_episodes), num_episodes))
        if len(selected) > 10:
            print(f"Selected episode indices: {selected[:10]}...")
        else:
            print(f"Selected episode indices: {selected}")

        # Flat timestep indices of the chosen episodes, built without a
        # per-timestep Python loop.
        if selected:
            timestep_indices = np.concatenate(
                [np.arange(episode_starts[i], episode_ends[i]) for i in selected]
            )
        else:
            timestep_indices = np.empty(0, dtype=np.int64)
        new_episode_lengths = episode_lengths[selected]

        new_data = {}
        for key in data.keys():
            if key in ('episode_ends', 'traj_lengths'):
                continue  # rebuilt below in the requested output format
            array = data[key]
            if getattr(array, 'ndim', 0) == 0:
                new_data[key] = array  # scalar / 0-d entry
            elif len(array) == total_timesteps:
                new_data[key] = array[timestep_indices]  # per-timestep
            elif len(array) == num_total_episodes:
                new_data[key] = array[selected]  # per-episode
            else:
                new_data[key] = array  # unrelated shape: keep as is

    if output_format == "auto":
        output_format = format_type

    if output_format == "robomimic":
        new_data['episode_ends'] = np.cumsum(new_episode_lengths)
    else:  # stitched
        new_data['traj_lengths'] = np.array(new_episode_lengths)

    output_dir = os.path.dirname(output_path)
    # dirname() is "" for a bare filename, and os.makedirs("") raises.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    np.savez_compressed(output_path, **new_data)
    print(f"Saved extracted dataset to: {output_path}")
    print(f"Output format: {output_format}")

    return True

def copy_normalization_file(source_dir, dest_dir):
    """Copy ``normalization.npz`` from ``source_dir`` into ``dest_dir``.

    The copy uses ``shutil.copy2``, so file metadata is preserved. A missing
    source file only produces a warning; nothing is raised and nothing is
    returned.
    """
    filename = "normalization.npz"
    src = os.path.join(source_dir, filename)

    if not os.path.exists(src):
        print(f"Warning: {filename} not found in source directory")
        return

    shutil.copy2(src, os.path.join(dest_dir, filename))
    print(f"Copied {filename} to destination folder")

def interactive_extract():
    """Prompt for an environment, episode count and output format, then extract.

    The data directory comes from ``get_data_dir()`` or, failing that, from a
    prompt. Every subdirectory of ``<data_dir>/robomimic`` that contains a
    ``train.npz`` is offered as an environment. The subset is written to
    ``<data_dir>/robomimic/<env>_<N>ep/train.npz``, and the source
    ``normalization.npz`` is copied alongside it when present.

    Returns ``None``; every problem is reported with a printed message.
    """
    data_dir = get_data_dir()
    if data_dir is None:
        data_dir = input("Please enter the data directory path: ").strip()

    robomimic_dir = os.path.join(data_dir, "robomimic")
    if not os.path.exists(robomimic_dir):
        print(f"Error: robomimic directory not found at {robomimic_dir}")
        return

    print("\nAvailable environments:")
    envs = []
    # os.listdir() returns entries in arbitrary order; sort so the menu is
    # stable from run to run.
    for env in sorted(os.listdir(robomimic_dir)):
        env_path = os.path.join(robomimic_dir, env)
        if os.path.isdir(env_path) and os.path.exists(os.path.join(env_path, "train.npz")):
            envs.append(env)
            print(f"  - {env}")

    if not envs:
        print("No valid environments found with train.npz files")
        return

    env = input("\nEnter environment name: ").strip()
    if env not in envs:
        print(f"Error: Environment '{env}' not found")
        return

    source_dir = os.path.join(robomimic_dir, env)
    dataset_path = os.path.join(source_dir, "train.npz")
    _, num_episodes, format_type = analyze_npz_dataset(dataset_path)

    if num_episodes is None:
        print("Cannot proceed without episode information")
        return

    print(f"\nTotal episodes available: {num_episodes}")
    while True:
        try:
            num_extract = int(input("How many episodes to extract? "))
        except ValueError:
            print("Please enter a valid number")
            continue
        if 0 < num_extract <= num_episodes:
            break
        print(f"Please enter a number between 1 and {num_episodes}")

    print(f"\nCurrent format: {format_type}")
    print("Output format options:")
    print("  1. Same as input (auto)")
    print("  2. RoboMimic format (episode_ends)")
    print("  3. Stitched format (traj_lengths)")

    format_choices = {"1": "auto", "2": "robomimic", "3": "stitched"}
    while True:
        format_choice = input("Choose output format (1/2/3): ").strip()
        if format_choice in format_choices:
            output_format = format_choices[format_choice]
            break
        print("Please enter 1, 2, or 3")

    output_dir = os.path.join(robomimic_dir, f"{env}_{num_extract}ep")
    output_path = os.path.join(output_dir, "train.npz")

    print(f"\nOutput will be saved to: {output_dir}")

    if extract_random_episodes(dataset_path, num_extract, output_path, output_format):
        copy_normalization_file(source_dir, output_dir)
        print("\nExtraction complete!")

def main():
    """Command-line entry point.

    Runs the interactive prompt when ``--interactive`` is given, or when
    neither ``--env`` nor ``--num_episodes`` is supplied. Otherwise it
    extracts ``--num_episodes`` random episodes from
    ``<data_dir>/robomimic/<env>/train.npz`` into
    ``<data_dir>/robomimic/<env>_<N>ep/train.npz``. The source
    ``normalization.npz`` is copied alongside the output when present.

    In command-line mode, any failure ends the process through
    ``sys.exit(1)`` so that shell scripts can detect it. Previously every
    error path returned normally and the process exited with status 0.
    """
    parser = argparse.ArgumentParser(description='Extract random episodes from NPZ dataset')
    parser.add_argument('--interactive', '-i', action='store_true',
                        help='Run in interactive mode')
    parser.add_argument('--env', type=str, help='Environment name (e.g., lift, square)')
    parser.add_argument('--num_episodes', type=int, help='Number of episodes to extract')
    parser.add_argument('--data_dir', type=str, help='Data directory (default: from env variable)')
    parser.add_argument('--output_format', type=str, choices=['auto', 'robomimic', 'stitched'],
                        default='auto', help='Output format type')

    args = parser.parse_args()

    if args.interactive or (args.env is None and args.num_episodes is None):
        interactive_extract()
        return

    if args.env is None or args.num_episodes is None:
        print("Error: --env and --num_episodes are required in non-interactive mode")
        sys.exit(1)
    if args.num_episodes <= 0:
        print("Error: --num_episodes must be a positive integer")
        sys.exit(1)

    data_dir = args.data_dir or get_data_dir()
    if data_dir is None:
        print("Error: No data directory specified")
        sys.exit(1)

    source_dir = os.path.join(data_dir, "robomimic", args.env)
    dataset_path = os.path.join(source_dir, "train.npz")
    output_dir = os.path.join(data_dir, "robomimic", f"{args.env}_{args.num_episodes}ep")
    output_path = os.path.join(output_dir, "train.npz")

    # analyze_npz_dataset prints its own error when the file is missing or
    # has no recognised episode boundaries.
    _, num_episodes, _ = analyze_npz_dataset(dataset_path)
    if num_episodes is None:
        sys.exit(1)

    if not extract_random_episodes(dataset_path, args.num_episodes, output_path, args.output_format):
        sys.exit(1)

    copy_normalization_file(source_dir, output_dir)
    print("\nExtraction complete!")

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()
