from pathlib import Path
from os.path import dirname, join, abspath
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt
from typing import Dict, List, Tuple, Optional

# Add parent directory to path to find teacher_student package
sys.path.insert(0, dirname(dirname(abspath(__file__))))

from teacher_student.dataloader import TaskDataModule, get_dataset
from teacher_student.task_utils import update_tasks_mnist

# Directory two levels above this file's resolved (symlink-free) location,
# i.e. the parent of the folder that contains this script. Stored as a string.
ROOT_DIR = str(Path(__file__).resolve().parent.parent)


def analyze_batch_statistics(batch_data: torch.Tensor) -> Dict:
    """
    Compute summary statistics over every element of a batch of image data.

    All statistics are global (taken over the whole tensor rather than per
    sample), so the input is flattened to one dimension and any shape is
    accepted.

    Args:
        batch_data: Tensor of any shape, e.g. (N, C, H, W) or (N, H, W).
            It may be non-contiguous, may track gradients, and may live on
            a non-CPU device; it is detached and copied to the CPU as needed.

    Returns:
        Dictionary with the keys 'min_value', 'max_value', 'mean_value',
        'std_value', 'median_value', 'percentile_1', 'percentile_99',
        'has_negative', 'negative_count', 'total_pixels',
        'negative_percentage' (0-100), 'zero_count' and 'positive_count'.
        All values are plain Python floats, ints or bools.

    Raises:
        ValueError: If batch_data contains no elements.
    """
    if batch_data.numel() == 0:
        # Fail with a clear message instead of NumPy's opaque
        # "zero-size array to reduction operation" error.
        raise ValueError("Cannot compute statistics of an empty tensor")

    # detach()/cpu() make the NumPy conversion valid for tensors that require
    # grad or sit on an accelerator; reshape() (unlike view()) also works for
    # non-contiguous inputs such as permuted or sliced tensors.
    data_np = batch_data.detach().cpu().reshape(-1).numpy()

    total_pixels = int(data_np.size)
    # Count each sign class once and reuse the results below.
    negative_count = int(np.count_nonzero(data_np < 0))
    zero_count = int(np.count_nonzero(data_np == 0))
    positive_count = int(np.count_nonzero(data_np > 0))

    stats = {
        'min_value': float(np.min(data_np)),
        'max_value': float(np.max(data_np)),
        'mean_value': float(np.mean(data_np)),
        'std_value': float(np.std(data_np)),
        'median_value': float(np.median(data_np)),
        'percentile_1': float(np.percentile(data_np, 1)),
        'percentile_99': float(np.percentile(data_np, 99)),
        'has_negative': negative_count > 0,
        'negative_count': negative_count,
        'total_pixels': total_pixels,
        'negative_percentage': float(negative_count / total_pixels * 100),
        'zero_count': zero_count,
        'positive_count': positive_count,
    }

    return stats


def collect_sample_data(dataloader, sample_size: int = 1000) -> torch.Tensor:
    """
    Collect up to ``sample_size`` images from the start of a dataloader.

    Iteration stops as soon as enough samples have been gathered, so no
    batch beyond the one that completes the sample is fetched.

    Args:
        dataloader: Iterable yielding ``(images, labels)`` pairs, where
            ``images`` is a tensor whose first dimension is the batch size
            (e.g. a PyTorch DataLoader).
        sample_size: Maximum number of samples to collect. Values <= 0
            yield an empty result.

    Returns:
        Tensor of the collected images concatenated along dimension 0, or
        ``torch.empty(0)`` when nothing was collected.
    """
    if sample_size <= 0:
        # A non-positive budget would otherwise become a negative slice
        # (images[:-k]) and silently return data.
        return torch.empty(0)

    collected = []
    remaining = sample_size

    for batch in dataloader:
        images, _ = batch

        # Trim the final batch so the total never exceeds sample_size.
        if images.shape[0] > remaining:
            images = images[:remaining]

        collected.append(images)
        remaining -= images.shape[0]

        if remaining <= 0:
            # Stop immediately; pulling another batch would be wasted work.
            break

    if not collected:
        return torch.empty(0)
    return torch.cat(collected, dim=0)


def _draw_value_histogram(ax, values, bins: int, panel_title: str) -> None:
    """Draw one styled pixel-value histogram with a marker line at zero."""
    ax.hist(values, bins=bins, alpha=0.7, edgecolor='black')
    ax.set_xlabel('Pixel Value')
    ax.set_ylabel('Frequency')
    ax.set_title(panel_title)
    ax.grid(True, alpha=0.3)
    ax.axvline(x=0, color='red', linestyle='--', alpha=0.8, label='Zero line')
    ax.legend()


def visualize_value_distribution(data: torch.Tensor, title: str = "Value Distribution", 
                               save_path: Optional[str] = None) -> None:
    """
    Create histogram visualization of data value distribution.

    Two panels are drawn side by side: a 100-bin histogram over the full
    value range and a 50-bin histogram restricted to values in [-0.5, 0.5].

    Args:
        data: Tensor containing the data to visualize (any shape; it is
            detached, moved to the CPU and flattened)
        title: Title prefix for both panels
        save_path: Optional path to save the plot; missing parent
            directories are created
    """
    # detach()/cpu() keep the NumPy conversion valid for tensors that track
    # gradients or live on a non-CPU device.
    flat_data = data.detach().cpu().flatten().numpy()

    fig, axes = plt.subplots(1, 2, figsize=(15, 5))

    # Full histogram
    _draw_value_histogram(axes[0], flat_data, 100, f'{title} - Full Range')

    # Zoomed histogram around zero
    zero_range_data = flat_data[(flat_data >= -0.5) & (flat_data <= 0.5)]
    if len(zero_range_data) > 0:
        _draw_value_histogram(axes[1], zero_range_data, 50, f'{title} - Around Zero')
    else:
        axes[1].text(0.5, 0.5, 'No data in [-0.5, 0.5] range', 
                    transform=axes[1].transAxes, ha='center', va='center')
        axes[1].set_title(f'{title} - Around Zero (No Data)')

    fig.tight_layout()

    if save_path:
        # Ensure the target directory exists so savefig does not fail on a
        # fresh output location.
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)
        # Save through the figure handle rather than pyplot's implicit
        # "current figure" state.
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Plot saved to: {save_path}")

    plt.show()

    # With a blocking or headless backend the figure is finished once show()
    # returns; close it so repeated calls do not accumulate open figures.
    # In interactive mode show() is non-blocking, so the window is left open.
    if not plt.isinteractive():
        plt.close(fig)


def _print_split_report(sample_data: torch.Tensor, stats: Dict) -> None:
    """Print the per-split statistics summary for one analyzed split."""
    total = stats['total_pixels']
    print(f"  Samples analyzed: {sample_data.shape[0]}")
    print(f"  Image shape: {sample_data.shape[1:]}")
    print(f"  Value range: [{stats['min_value']:.4f}, {stats['max_value']:.4f}]")
    print(f"  Mean ± Std: {stats['mean_value']:.4f} ± {stats['std_value']:.4f}")
    print(f"  Has negative values: {stats['has_negative']}")
    if stats['has_negative']:
        print(f"  Negative pixels: {stats['negative_count']:,} / {total:,} ({stats['negative_percentage']:.2f}%)")
    print(f"  Zero pixels: {stats['zero_count']:,} ({stats['zero_count'] / total * 100:.2f}%)")
    print(f"  Positive pixels: {stats['positive_count']:,} ({stats['positive_count'] / total * 100:.2f}%)")
    print()


def _print_overall_report(sample_count: int, stats: Dict) -> None:
    """Print the combined summary, negative-value analysis and interpretation."""
    print("OVERALL DATASET SUMMARY:")
    print("-" * 40)
    print(f"Total samples analyzed: {sample_count}")
    print(f"Value range: [{stats['min_value']:.4f}, {stats['max_value']:.4f}]")
    print(f"Mean ± Std: {stats['mean_value']:.4f} ± {stats['std_value']:.4f}")
    print(f"Median: {stats['median_value']:.4f}")
    print(f"1st-99th percentile: [{stats['percentile_1']:.4f}, {stats['percentile_99']:.4f}]")
    print()
    print("NEGATIVE VALUE ANALYSIS:")
    print(f"  Contains negative values: {stats['has_negative']}")
    if stats['has_negative']:
        print(f"  Negative pixels: {stats['negative_count']:,} / {stats['total_pixels']:,}")
        print(f"  Percentage negative: {stats['negative_percentage']:.2f}%")
    print()

    # Interpretation
    print("INTERPRETATION:")
    if stats['has_negative']:
        print("  ✓ Dataset CONTAINS negative values")
        print("  ✓ This is expected with normalization: transforms.Normalize((0.5,), (0.5,))")
        print("  ✓ Original [0,1] range is transformed to [-1,1] range")
    else:
        print("  ✗ Dataset does NOT contain negative values")
        print("  ✗ This might indicate missing normalization or different preprocessing")
    print("="*60)


def check_dataset_values(data_module: TaskDataModule, sample_size: int = 1000, 
                        verbose: bool = True, visualize: bool = False,
                        save_plots: bool = False, save_dir: Optional[str] = None) -> Dict:
    """
    Comprehensive analysis of dataset values, specifically checking for negative values.

    Samples are drawn from the train and test dataloaders of the data
    module, analyzed per split, and then analyzed again as one combined set.

    Args:
        data_module: TaskDataModule instance (must be setup already)
        sample_size: Number of samples to analyze from each split
        verbose: Whether to print detailed results
        visualize: Whether to create histogram plots
        save_plots: Whether to save plots to disk; only takes effect when
            visualize is True and save_dir is given
        save_dir: Directory to save plots (if save_plots=True); it is
            created if it does not exist

    Returns:
        Dictionary with the keys 'train', 'val', 'test' and 'overall'.
        'train' and 'test' hold the statistics of their split, 'overall'
        the statistics of both combined; a key maps to an empty dict when
        no data was available for it. 'val' is always empty because no
        validation dataloader is queried here.
    """

    if verbose:
        print("="*60)
        print("DATASET VALUE ANALYSIS")
        print("="*60)
        print(f"Analyzing up to {sample_size} samples from each data split...")
        print()

    results = {
        'train': {},
        'val': {},
        'test': {},
        'overall': {}
    }

    # Analyze each data split
    splits = {
        'train': data_module.train_dataloader(),
        'test': data_module.test_dataloader()
    }

    # Plots are written only when plotting, saving and a target directory
    # are all requested; create the directory up front so saving cannot
    # fail on a fresh output location.
    plot_dir = save_dir if (visualize and save_plots and save_dir) else None
    if plot_dir:
        Path(plot_dir).mkdir(parents=True, exist_ok=True)

    all_data = []

    for split_name, dataloader in splits.items():
        if verbose:
            print(f"Analyzing {split_name} split...")

        # Collect sample data
        sample_data = collect_sample_data(dataloader, sample_size)

        if sample_data.numel() == 0:
            if verbose:
                print(f"  No data found in {split_name} split!")
            continue

        all_data.append(sample_data)

        # Analyze statistics
        stats = analyze_batch_statistics(sample_data)
        results[split_name] = stats

        if verbose:
            _print_split_report(sample_data, stats)

        # Create visualization if requested
        if visualize:
            plot_title = f"{split_name.capitalize()} Split"
            save_path = None
            if plot_dir:
                save_path = join(plot_dir, f"value_distribution_{split_name}.png")
            visualize_value_distribution(sample_data, plot_title, save_path)

    # Overall analysis
    if all_data:
        overall_data = torch.cat(all_data, dim=0)
        overall_stats = analyze_batch_statistics(overall_data)
        results['overall'] = overall_stats

        if verbose:
            _print_overall_report(overall_data.shape[0], overall_stats)

    return results


def compare_normalized_vs_unnormalized(data_dir: str, selected_labels: Optional[List[int]] = None,
                                     sample_size: int = 500) -> None:
    """
    Compare dataset values with and without normalization to demonstrate the effect.

    Two MNIST data modules are built with identical settings: one whose
    ``transform`` attribute is replaced by a ToTensor-only pipeline, and one
    using the TaskDataModule default. Each is analyzed with
    check_dataset_values and the overall statistics are printed side by side.

    Args:
        data_dir: Path to data directory
        selected_labels: List of labels to include (None selects [0, 1, 2, 3, 4])
        sample_size: Number of samples to analyze
    """
    print("="*80)
    print("COMPARISON: NORMALIZED vs UNNORMALIZED DATA")
    print("="*80)

    # Create data modules with and without normalization
    from torchvision import transforms

    # Without normalization (only ToTensor)
    # NOTE(review): this replaces self.transform after the base __init__ has
    # run; it only has an effect if TaskDataModule reads self.transform later
    # (e.g. in setup()) — confirm against the TaskDataModule implementation.
    class TaskDataModuleNoNorm(TaskDataModule):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.transform = transforms.Compose([transforms.ToTensor()])

    if selected_labels is None:
        selected_labels = [0, 1, 2, 3, 4]

    def build_module(module_cls):
        """Instantiate and set up a data module with the shared settings."""
        module = module_cls(
            dataset_class=get_dataset("MNIST"),
            batch_size=64,
            selected_labels=selected_labels,
            data_dir=data_dir,
            seed=42,
            num_workers=1
        )
        module.setup()
        return module

    # Unnormalized data
    dm_unnorm = build_module(TaskDataModuleNoNorm)

    # Normalized data (default)
    dm_norm = build_module(TaskDataModule)

    print("1. UNNORMALIZED DATA (ToTensor only):")
    print("-" * 50)
    results_unnorm = check_dataset_values(dm_unnorm, sample_size=sample_size, verbose=True, visualize=False)

    print("\n2. NORMALIZED DATA (with Normalize((0.5,), (0.5,))):")
    print("-" * 50)
    results_norm = check_dataset_values(dm_norm, sample_size=sample_size, verbose=True, visualize=False)

    print("\n3. COMPARISON SUMMARY:")
    print("-" * 50)
    unnorm_stats = results_unnorm['overall']
    norm_stats = results_norm['overall']

    # 'overall' is an empty dict when no samples could be collected; report
    # that instead of failing with a KeyError on the lookups below.
    if not unnorm_stats or not norm_stats:
        print("Comparison unavailable: no samples were collected for at least one configuration.")
        print("="*80)
        return

    print(f"Unnormalized range: [{unnorm_stats['min_value']:.4f}, {unnorm_stats['max_value']:.4f}]")
    print(f"Normalized range:   [{norm_stats['min_value']:.4f}, {norm_stats['max_value']:.4f}]")
    print(f"Unnormalized mean:  {unnorm_stats['mean_value']:.4f}")
    print(f"Normalized mean:    {norm_stats['mean_value']:.4f}")
    print(f"Unnormalized has negatives: {unnorm_stats['has_negative']}")
    print(f"Normalized has negatives:   {norm_stats['has_negative']}")
    print("="*80)


def main():
    """
    Run the demonstration checks for the dataset analysis helpers.

    Three scenarios are exercised against MNIST data stored under
    ROOT_DIR/data: a basic value analysis, a normalized-vs-unnormalized
    comparison, and a short summary for the "half" and "equal" task
    configurations.
    """
    # Data location, normalized through Path and passed around as a string
    data_root = str(Path(join(ROOT_DIR, "data")))

    def make_module(labels):
        """Create and set up an MNIST TaskDataModule restricted to `labels`."""
        module = TaskDataModule(
            dataset_class=get_dataset("MNIST"),
            batch_size=64,
            selected_labels=labels,
            data_dir=data_root,
            seed=42,
            num_workers=1
        )
        module.setup()
        return module

    print("Testing dataset_checker.py functionality...")
    print()

    # Scenario 1: verbose analysis of digits 0-4
    print("TEST 1: Basic MNIST Dataset Analysis")
    print("-" * 50)

    task_labels = [0, 1, 2, 3, 4]
    dm = make_module(task_labels)
    # Only the printed report matters here; the returned dict is not used.
    check_dataset_values(dm, sample_size=1000, verbose=True, visualize=False)

    # Scenario 2: effect of normalization on the same labels
    print("\nTEST 2: Normalized vs Unnormalized Comparison")
    print("-" * 50)
    compare_normalized_vs_unnormalized(data_root, task_labels, sample_size=500)

    # Scenario 3: quiet analysis of the first task of each configuration
    print("\nTEST 3: Different Task Configurations")
    print("-" * 50)

    for control_tasks in ("half", "equal"):
        print(f"\nAnalyzing '{control_tasks}' task configuration:")
        # Only the first task's labels are analyzed; the second is ignored.
        first_task_labels, _ = update_tasks_mnist(control_tasks)

        dm_task = make_module(first_task_labels)

        summary = check_dataset_values(dm_task, sample_size=500, verbose=False)['overall']
        print(f"  Labels: {first_task_labels}")
        print(f"  Range: [{summary['min_value']:.4f}, {summary['max_value']:.4f}]")
        print(f"  Has negatives: {summary['has_negative']} ({summary['negative_percentage']:.1f}%)")

    print("\nAll tests completed successfully!")

if __name__ == "__main__":
    main()
