"""
Dataset classes for BSNP training with caching support.
"""

import torch
from torch.utils.data import Dataset
import numpy as np
from typing import Tuple, Optional, Callable, List, Dict
from pathlib import Path
import hashlib
import json
import pickle
from data.pde_solver import NonlinearPoissonSolver1D


def collate_variable_length(batch: List[Dict]) -> Dict:
    """
    Collate samples whose context/target sets differ in length.

    Every sample's point sets are zero-padded up to the longest set in the
    batch, and boolean masks record which rows hold real points.

    Args:
        batch: List of sample dictionaries. Each must provide 'x_context',
            'y_context', 'x_target' and 'y_target' as 2-D tensors
            (points, features). 'lambda_params', 'solution' and 'x_grid'
            are optional and are detected on the first sample only.

    Returns:
        Dictionary with the four padded tensors, 'context_mask' and
        'target_mask' (True where a point is real), plus stacked
        'lambda_params' / 'solution' and the first sample's 'x_grid' when
        those keys are present. All newly created tensors live on CPU.
    """
    # Lengths are always read from the x tensors; the y tensors are assumed
    # to have matching lengths.
    ctx_lengths = [item['x_context'].shape[0] for item in batch]
    tgt_lengths = [item['x_target'].shape[0] for item in batch]
    longest_ctx = max(ctx_lengths)
    longest_tgt = max(tgt_lengths)

    # Feature widths and dtype are dictated by the first sample.
    first = batch[0]
    x_width = first['x_context'].shape[1]
    y_width = first['y_context'].shape[1]
    pad_dtype = first['x_context'].dtype
    n_items = len(batch)

    def _padded(key, lengths, longest, width):
        # Zero-filled CPU tensor with each sample copied into its leading rows.
        out = torch.zeros(n_items, longest, width, dtype=pad_dtype)
        for row, (item, length) in enumerate(zip(batch, lengths)):
            out[row, :length] = item[key]
        return out

    def _valid_mask(lengths, longest):
        # True for real points, False for padding.
        mask = torch.zeros(n_items, longest, dtype=torch.bool)
        for row, length in enumerate(lengths):
            mask[row, :length] = True
        return mask

    collated = {
        'x_context': _padded('x_context', ctx_lengths, longest_ctx, x_width),
        'y_context': _padded('y_context', ctx_lengths, longest_ctx, y_width),
        'x_target': _padded('x_target', tgt_lengths, longest_tgt, x_width),
        'y_target': _padded('y_target', tgt_lengths, longest_tgt, y_width),
        'context_mask': _valid_mask(ctx_lengths, longest_ctx),
        'target_mask': _valid_mask(tgt_lengths, longest_tgt),
    }

    # Fixed-shape extras are stacked along a new leading batch dimension.
    for extra_key in ('lambda_params', 'solution'):
        if extra_key in first:
            collated[extra_key] = torch.stack([item[extra_key] for item in batch])

    # The grid is shared by every sample, so the first one is passed through.
    if 'x_grid' in first:
        collated['x_grid'] = first['x_grid']

    return collated


class NonlinearPoissonDataset(Dataset):
    """
    Dataset for nonlinear Poisson equation with physics-informed constraints.
    Supports data caching to avoid recomputing PDE solutions.

    Each item pairs one PDE solution with a freshly drawn random split of the
    grid into noisy context observations and clean target observations, so
    repeated access to the same index yields different context/target sets.
    """
    
    def __init__(
        self,
        num_samples: int,
        n_grid_points: int = 128,
        n_chebyshev: int = 5,
        n_context_range: Tuple[int, int] = (10, 50),
        n_target_range: Tuple[int, int] = (10, 50),
        noise_std: float = 0.01,
        x_range: Tuple[float, float] = (-1.0, 1.0),
        w_range: Tuple[float, float] = (0.5, 2.0),
        device: str = 'cpu',
        precompute: bool = True,
        seed: Optional[int] = None,
        cache_dir: Optional[str] = None,
        force_regenerate: bool = False
    ):
        """
        Initialize dataset with optional caching.
        
        Args:
            num_samples: Number of PDE solutions to generate
            n_grid_points: Number of grid points for PDE solver
            n_chebyshev: Number of Chebyshev coefficients
            n_context_range: Range of context points (min, max), inclusive
            n_target_range: Range of target points (min, max), inclusive
            noise_std: Standard deviation of observation noise
            x_range: Spatial domain range
            w_range: Range for nonlinearity parameter w
            device: Device parameter (kept for compatibility, data stored on CPU)
            precompute: Whether to precompute all solutions
            seed: Random seed for reproducibility
            cache_dir: Directory to cache computed solutions
            force_regenerate: Force regeneration even if cache exists

        Side effects:
            Creates the cache directory if it does not exist and, when
            ``seed`` is given, reseeds the *global* NumPy and PyTorch RNGs.
        """
        self.num_samples = num_samples
        self.n_grid_points = n_grid_points
        self.n_chebyshev = n_chebyshev
        self.n_context_range = n_context_range
        self.n_target_range = n_target_range
        self.noise_std = noise_std
        self.x_range = x_range
        self.w_range = w_range
        # Always store data on CPU to avoid multiprocessing issues
        self.device = 'cpu'
        self.precompute = precompute
        self.seed = seed
        
        # Setup cache directory (defaults to a 'cache' folder next to this file)
        if cache_dir is None:
            cache_dir = Path(__file__).parent / 'cache'
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        
        # Set random seed (global state: affects every later NumPy/PyTorch draw)
        if seed is not None:
            np.random.seed(seed)
            torch.manual_seed(seed)
        
        # Generate cache filename based on parameters
        self.cache_filename = self._generate_cache_filename()
        self.cache_path = self.cache_dir / self.cache_filename
        
        # Create grid on CPU
        self.x_grid = torch.linspace(x_range[0], x_range[1], n_grid_points)
        
        # Try to load from cache or generate new data
        if not force_regenerate and self.cache_path.exists() and precompute:
            print(f"📦 Loading cached data from {self.cache_path.name}")
            self._load_from_cache()
        else:
            if force_regenerate and self.cache_path.exists():
                print("🔨 Force regenerate mode - ignoring existing cache")
            
            if precompute:
                self._generate_solutions()
                self._save_to_cache()
            else:
                # On-the-fly mode: __getitem__ solves a new PDE per access.
                self.solutions = None
                self.lambda_params = None
    
    def _generate_cache_filename(self) -> str:
        """
        Generate a unique cache filename based on dataset parameters.

        Only parameters that influence the stored solutions are hashed;
        the context/target ranges and noise level are applied at access
        time and therefore do not affect the cache key.

        Returns:
            File name of the form
            ``nonlinear_poisson_n<N>_g<G>_c<C>_s<seed>_<hash>.pkl``.
        """
        params = {
            'num_samples': self.num_samples,
            'n_grid_points': self.n_grid_points,
            'n_chebyshev': self.n_chebyshev,
            'x_range': self.x_range,
            'w_range': self.w_range,
            'seed': self.seed,
            'precompute': self.precompute
        }
        
        # Create hash of parameters (sort_keys keeps the digest stable)
        param_str = json.dumps(params, sort_keys=True)
        param_hash = hashlib.md5(param_str.encode()).hexdigest()[:12]
        
        # Create descriptive filename
        filename = (f"nonlinear_poisson_"
                   f"n{self.num_samples}_"
                   f"g{self.n_grid_points}_"
                   f"c{self.n_chebyshev}_"
                   f"s{self.seed}_"
                   f"{param_hash}.pkl")
        
        return filename
    
    def _save_to_cache(self):
        """
        Save dataset to cache file.

        Writes a pickle holding the parameters, solutions, lambda
        parameters and grid, plus a sibling ``.json`` file containing only
        the parameters for human inspection.
        """
        print(f"💾 Saving dataset to cache: {self.cache_path.name}")
        
        cache_data = {
            'parameters': {
                'num_samples': self.num_samples,
                'n_grid_points': self.n_grid_points,
                'n_chebyshev': self.n_chebyshev,
                'n_context_range': self.n_context_range,
                'n_target_range': self.n_target_range,
                'noise_std': self.noise_std,
                'x_range': self.x_range,
                'w_range': self.w_range,
                'seed': self.seed,
                'precompute': self.precompute
            },
            'solutions': self.solutions,
            'lambda_params': self.lambda_params,
            'x_grid': self.x_grid
        }
        
        with open(self.cache_path, 'wb') as f:
            pickle.dump(cache_data, f, protocol=pickle.HIGHEST_PROTOCOL)
        
        # Save metadata as JSON for human readability
        metadata_path = self.cache_path.with_suffix('.json')
        with open(metadata_path, 'w') as f:
            json.dump(cache_data['parameters'], f, indent=2)
        
        file_size_mb = self.cache_path.stat().st_size / (1024 * 1024)
        print(f"✅ Cache saved successfully ({file_size_mb:.2f} MB)")
    
    def _load_from_cache(self):
        """
        Load dataset from cache file.

        Any failure while reading the cache (missing keys, truncated or
        corrupt file, ...) is reported and answered by regenerating the
        solutions and rewriting the cache.
        """
        try:
            # SECURITY NOTE: pickle.load can execute arbitrary code. Only
            # point cache_dir at directories whose contents you trust.
            with open(self.cache_path, 'rb') as f:
                cache_data = pickle.load(f)
            
            # Load data
            self.solutions = cache_data['solutions']
            self.lambda_params = cache_data['lambda_params']
            self.x_grid = cache_data['x_grid']
            
            file_size_mb = self.cache_path.stat().st_size / (1024 * 1024)
            print(f"✅ Successfully loaded {self.num_samples} samples from cache ({file_size_mb:.2f} MB)")
            
        except Exception as e:
            # Deliberately broad: an unusable cache should never be fatal,
            # it is simply rebuilt from scratch.
            print(f"❌ Error loading cache: {e}")
            print("   Regenerating dataset...")
            self._generate_solutions()
            self._save_to_cache()
    
    def _generate_solutions(self):
        """
        Generate all PDE solutions.

        Draws standard-normal coefficients ``xi`` and uniform ``w`` values
        from the global NumPy RNG, solves the batch with
        ``NonlinearPoissonSolver1D`` and stores ``self.solutions`` and
        ``self.lambda_params`` as float32 CPU tensors.
        """
        print(f"🔨 Generating {self.num_samples} PDE parameter samples...")
        
        # Generate random parameters
        xi_samples = np.random.randn(self.num_samples, self.n_chebyshev)
        w_samples = np.random.uniform(
            self.w_range[0], self.w_range[1], self.num_samples
        )
        
        # Solve PDEs
        print("⚙️  Solving PDEs...")
        solver = NonlinearPoissonSolver1D(n_points=self.n_grid_points)
        solutions_np, info_list = solver.solve_batch(xi_samples, w_samples)
        
        # Check convergence (reported only; non-converged solutions are kept)
        converged = sum(1 for info in info_list if info['converged'])
        print(f"   Converged: {converged}/{self.num_samples}")
        
        # Store solutions as torch tensors on CPU
        self.solutions = torch.from_numpy(solutions_np).float()
        
        # Store lambda parameters (xi and w concatenated -> n_chebyshev + 1 columns)
        lambda_params = np.concatenate([
            xi_samples,
            w_samples[:, np.newaxis]
        ], axis=1)
        self.lambda_params = torch.from_numpy(lambda_params).float()
        
        print(f"✅ Generated {self.num_samples} solutions")
    
    def __len__(self) -> int:
        """Return the nominal number of samples."""
        return self.num_samples
    
    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        Get a single sample (returns CPU tensors).

        Args:
            idx: Sample index in ``[-num_samples, num_samples)``. In
                on-the-fly mode the index only bounds iteration; a new
                random PDE is solved on every access.

        Returns:
            Dictionary with 'x_context', 'y_context' (noisy), 'x_target',
            'y_target' (noise-free), 'lambda_params', the full 'solution'
            and the shared 'x_grid'.

        Raises:
            IndexError: If ``idx`` is outside the dataset.
        """
        # Without this check the on-the-fly branch never raises IndexError,
        # so plain iteration (`for s in dataset`, `list(dataset)`) would
        # never terminate.
        if not -self.num_samples <= idx < self.num_samples:
            raise IndexError(
                f"index {idx} is out of range for dataset with "
                f"{self.num_samples} samples"
            )
        
        # Get solution
        if self.solutions is None:
            # Generate on the fly
            xi = np.random.randn(self.n_chebyshev)
            w = np.random.uniform(self.w_range[0], self.w_range[1])
            
            solver = NonlinearPoissonSolver1D(n_points=self.n_grid_points)
            solution_np, _ = solver.solve(xi, w)
            solution = torch.from_numpy(solution_np).float()
            
            lambda_param = torch.tensor(
                np.concatenate([xi, [w]]), dtype=torch.float32
            )
        else:
            solution = self.solutions[idx]
            lambda_param = self.lambda_params[idx]
        
        # Sample context and target set sizes (upper bounds are inclusive)
        n_context = np.random.randint(self.n_context_range[0], self.n_context_range[1] + 1)
        n_target = np.random.randint(self.n_target_range[0], self.n_target_range[1] + 1)
        
        # Sample from grid: one permutation, sliced twice, keeps the context
        # and target sets disjoint. NOTE: if n_context + n_target exceeds
        # n_grid_points the target slice is silently shortened (possibly to
        # zero points), so keep the configured ranges within the grid size.
        all_indices = torch.randperm(self.n_grid_points)
        context_indices = all_indices[:n_context]
        target_indices = all_indices[n_context:n_context + n_target]
        
        x_context = self.x_grid[context_indices].unsqueeze(-1)
        y_context = solution[context_indices].unsqueeze(-1)
        
        # Add noise to observations (context only; targets stay clean)
        y_context = y_context + torch.randn_like(y_context) * self.noise_std
        
        x_target = self.x_grid[target_indices].unsqueeze(-1)
        y_target = solution[target_indices].unsqueeze(-1)
        
        return {
            'x_context': x_context,
            'y_context': y_context,
            'x_target': x_target,
            'y_target': y_target,
            'lambda_params': lambda_param,
            'solution': solution,
            'x_grid': self.x_grid
        }
    
    @staticmethod
    def get_collate_fn():
        """Get custom collate function for DataLoader."""
        return collate_variable_length


# Cache management utilities
def get_cache_info(cache_dir: Optional[str] = None) -> list:
    """
    Describe every cached dataset found in a cache directory.

    Args:
        cache_dir: Directory to inspect. Defaults to the 'cache' folder
            located next to this file.

    Returns:
        One dictionary per ``*.pkl`` file with 'filename', 'path' and
        'size_mb', plus 'parameters' when a sibling ``.json`` metadata file
        exists. An empty list is returned when the directory is missing.
    """
    root = Path(__file__).parent / 'cache' if cache_dir is None else Path(cache_dir)

    if not root.exists():
        return []

    entries = []
    for pkl_path in root.glob("*.pkl"):
        entry = {
            'filename': pkl_path.name,
            'path': str(pkl_path),
            'size_mb': pkl_path.stat().st_size / (1024 * 1024),
        }

        # Metadata lives in a JSON file with the same stem as the pickle.
        sidecar = pkl_path.with_suffix('.json')
        if sidecar.exists():
            with open(sidecar, 'r') as handle:
                entry['parameters'] = json.load(handle)

        entries.append(entry)

    return entries


def clear_cache(cache_dir: Optional[str] = None, confirm: bool = True):
    """
    Delete every cached dataset file in a cache directory.

    All ``*.pkl`` and ``*.json`` files directly inside the directory are
    removed; other files and sub-directories are left untouched.

    Args:
        cache_dir: Directory to clean. Defaults to the 'cache' folder
            located next to this file.
        confirm: When True, ask on stdin before deleting and abort unless
            the answer is 'y' (case-insensitive).
    """
    root = Path(cache_dir) if cache_dir is not None else Path(__file__).parent / 'cache'

    if not root.exists():
        print("📂 No cache directory found")
        return

    doomed = [*root.glob("*.pkl"), *root.glob("*.json")]
    if not doomed:
        print("📂 No cache files found")
        return

    n_files = len(doomed)
    size_mb = sum(path.stat().st_size for path in doomed) / (1024 * 1024)
    print(f"📦 Found {n_files} cache files ({size_mb:.2f} MB)")

    # The prompt is only shown when confirmation was requested.
    if confirm and input("🗑️  Delete all cache files? [y/N]: ").lower() != 'y':
        print("❌ Cancelled")
        return

    for path in doomed:
        path.unlink()

    print(f"✅ Deleted {n_files} cache files ({size_mb:.2f} MB)")


class SyntheticFunctionDataset(Dataset):
    """
    Simple synthetic function dataset for testing.

    Every access draws fresh random inputs uniformly from ``x_range``,
    evaluates ``func`` on them and (optionally) adds Gaussian noise, so the
    index only bounds the dataset length and does not select a fixed sample.
    """
    
    def __init__(
        self,
        num_samples: int,
        func: Callable,
        n_context_range: Tuple[int, int] = (10, 50),
        n_target_range: Tuple[int, int] = (10, 50),
        x_range: Tuple[float, float] = (-1.0, 1.0),
        noise_std: float = 0.01,
        device: str = 'cpu'
    ):
        """
        Initialize the dataset.

        Args:
            num_samples: Nominal dataset length
            func: Callable applied to an (n, 1) tensor of inputs to produce
                the corresponding outputs
            n_context_range: Range of context points (min, max), inclusive
            n_target_range: Range of target points (min, max), inclusive
            x_range: Interval the inputs are drawn from
            noise_std: Standard deviation of the noise added to outputs
            device: Device parameter (kept for compatibility, data stays on CPU)
        """
        self.num_samples = num_samples
        self.func = func
        self.n_context_range = n_context_range
        self.n_target_range = n_target_range
        self.x_range = x_range
        self.noise_std = noise_std
        self.device = 'cpu'
    
    def __len__(self) -> int:
        """Return the nominal number of samples."""
        return self.num_samples
    
    def __getitem__(self, idx: int) -> dict:
        """
        Draw one random sample.

        Args:
            idx: Sample index in ``[-num_samples, num_samples)``.

        Returns:
            Dictionary with 'x_context', 'y_context', 'x_target' and
            'y_target'; the x tensors have shape (n, 1).

        Raises:
            IndexError: If ``idx`` is outside the dataset.
        """
        # Samples are generated rather than looked up, so nothing else would
        # ever raise IndexError; without it, direct iteration over the
        # dataset (`for s in dataset`, `list(dataset)`) never terminates.
        if not -self.num_samples <= idx < self.num_samples:
            raise IndexError(
                f"index {idx} is out of range for dataset with "
                f"{self.num_samples} samples"
            )
        
        # Set sizes (upper bounds are inclusive)
        n_context = np.random.randint(self.n_context_range[0], self.n_context_range[1] + 1)
        n_target = np.random.randint(self.n_target_range[0], self.n_target_range[1] + 1)
        
        # Uniform inputs rescaled from [0, 1) to x_range
        x_context = torch.rand(n_context, 1) * (self.x_range[1] - self.x_range[0]) + self.x_range[0]
        x_target = torch.rand(n_target, 1) * (self.x_range[1] - self.x_range[0]) + self.x_range[0]
        
        y_context = self.func(x_context)
        y_target = self.func(x_target)
        
        # Noise is applied to both context and target outputs
        if self.noise_std > 0:
            y_context = y_context + torch.randn_like(y_context) * self.noise_std
            y_target = y_target + torch.randn_like(y_target) * self.noise_std
        
        return {
            'x_context': x_context,
            'y_context': y_context,
            'x_target': x_target,
            'y_target': y_target
        }
    
    @staticmethod
    def get_collate_fn():
        """Get custom collate function for DataLoader."""
        return collate_variable_length