import numpy as np
import h5py
import torch
import torch.utils.data as data
import logging
import yaml
import os
from pathlib import Path
from typing import Optional

# Setup Logging
# NOTE: basicConfig runs at import time and configures the root logger at INFO
# level (it is a no-op when the root logger already has handlers configured).
logging.basicConfig(level=logging.INFO)
# Module-level logger shared by the helpers and classes defined in this file.
logger = logging.getLogger(__name__)

def _env_or_default(env_key: str, default: Optional[str]) -> Optional[str]:
    """Return the value of environment variable ``env_key``, or ``default``.

    An unset variable and a variable set to the empty string are treated the
    same way: both yield ``default``.
    """
    raw_value = os.environ.get(env_key)
    if raw_value is None:
        return default
    if raw_value == "":
        return default
    return raw_value

def get_training_config(config_path='config.yaml'):
    """
    Load the training section for the currently selected model from a YAML file.

    Lookup order for the file: ``config_path`` as given, then ``config.yaml``
    next to this module, and finally ``../config.yaml`` (used without checking
    that it exists). The returned value is ``trainings[selected_model]`` from
    the parsed file, where ``selected_model`` defaults to ``'HOM'``.

    Any failure (missing file, parse error, missing keys, ...) is logged as a
    warning and an empty dict is returned, so callers fall back to defaults.
    """
    if not os.path.exists(config_path):
        # Only evaluated when the requested path is missing.
        module_local = os.path.join(os.path.dirname(__file__), 'config.yaml')
        config_path = module_local if os.path.exists(module_local) else '../config.yaml'

    try:
        with open(config_path, 'r') as stream:
            parsed = yaml.safe_load(stream)

        model_name = parsed.get('selected_model', 'HOM')
        training_section = parsed['trainings'][model_name]
    except Exception as e:
        logger.warning(f"Could not load config from {config_path}. Using defaults. Error: {e}")
        return {}
    return training_section

class GLORYSDataLoader:
    """Singleton Data Manager.

    Keeps a class-level cache of opened per-year HDF5 ``'data'`` datasets so
    that several dataset objects requesting the same files share the handles.
    """
    _instance = None
    # (tuple(years), base_path, file_prefix, file_suffix) -> {year: 'data' dataset}
    _file_handles = {}

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    @staticmethod
    def _close_handle(handle):
        """Close the file owning ``handle`` on a best-effort basis.

        The cache stores the ``'data'`` dataset rather than the file object,
        and a dataset has no ``close()`` of its own, so the owning file is
        reached through the ``file`` attribute. Objects without that attribute
        are closed directly. Failures are logged instead of raised.
        """
        try:
            getattr(handle, 'file', handle).close()
        except Exception as e:
            logger.warning(f"Failed to close data handle {handle!r}: {e}")

    @classmethod
    def get_file_handles(cls, years, base_path, file_prefix, file_suffix='.h5'):
        """Return ``{year: dataset}`` for ``<file_prefix>_<year><file_suffix>`` files.

        Results are cached per ``(years, base_path, file_prefix, file_suffix)``.

        Args:
            years: Iterable of years to open (consumed once).
            base_path: Directory containing the files.
            file_prefix: File name prefix placed before ``_<year>``.
            file_suffix: File name suffix. When it contains ``'merged'``, a
                missing file is only logged and that year is skipped.

        Raises:
            FileNotFoundError: A non-merged file is missing.
            Exception: Any error raised while opening a file is logged and
                re-raised. Files opened earlier in the same call are closed and
                nothing is cached, so a later retry starts from scratch.
        """
        # Materialise once: `years` is needed for both the cache key and the loop.
        years = tuple(years)
        key = (years, base_path, file_prefix, file_suffix)

        cached = cls._file_handles.get(key)
        if cached is not None:
            return cached

        opened = {}
        for year in years:
            try:
                filename = f'{file_prefix}_{year}{file_suffix}'
                pc_file = Path(base_path) / filename

                if not pc_file.exists():
                    if 'merged' in file_suffix:
                        logger.warning(f"Merged data file not found: {pc_file}")
                        continue
                    raise FileNotFoundError(f"Data file not found: {pc_file}")

                h5_file = h5py.File(pc_file, 'r')
                try:
                    opened[year] = h5_file['data']
                except Exception:
                    # Do not leak the file when it lacks the expected dataset.
                    h5_file.close()
                    raise
                logger.info(f"Loaded {filename}")

            except Exception as e:
                logger.error(f"Failed to load {file_prefix} year {year}: {e}")
                # Roll back: release what this call opened and cache nothing.
                for handle in opened.values():
                    cls._close_handle(handle)
                raise

        cls._file_handles[key] = opened
        return opened

    @classmethod
    def close_handles(cls, years=None):
        """Close cached handles and drop them from the cache.

        Args:
            years: When ``None`` every cached entry is closed; otherwise only
                entries whose cached year tuple equals ``tuple(years)``.
        """
        wanted = None if years is None else tuple(years)
        matching_keys = [
            key for key in cls._file_handles
            if wanted is None or key[0] == wanted
        ]
        for key in matching_keys:
            handles = cls._file_handles.pop(key)
            for handle in handles.values():
                cls._close_handle(handle)

class GLORYSBaseDataset(data.Dataset):
    """
    Base Dataset class without Sea Ice variables.

    Every sample covers ``time_range + 1`` consecutive time steps starting at a
    given (year, day_of_year) and is returned as a tuple of 15 tensors:

    * 6 global fields in the order thetao, so, uo, vo, zos, A,
    * 3 climate-mean fields (thetao, so, zos), or three ``torch.zeros(1)``
      placeholders when ``mean_field`` is disabled in the training config,
    * 6 regional fields in the same variable order as the global ones.
    """

    # Usable days per year; the warning text in _generate_day_indices mentions it.
    _DAYS_PER_YEAR = 360

    # Channel selector (second axis of the 'data' array) for each variable.
    _CHANNEL_MAP = {
        'uo':     slice(0, 23, 2),
        'vo':     slice(23, 46, 2),
        'zos':    46,
        'thetao': slice(47, 70, 2),
        'so':     slice(70, 93, 2),
        'A':      slice(97, 100),
    }

    # Order in which variables are read and returned.
    _VARIABLES = ('thetao', 'so', 'uo', 'vo', 'zos', 'A')

    def __init__(self, years, day_interval=3, ds_factor=1,
                 lat_range=(0, 720), lon_range=(0, 1440), time_range=1, base_path=None):
        """
        Args:
            years: Iterable of years to draw samples from.
            day_interval: Step, in days, between the start days of samples.
            ds_factor: Stride applied to the global lat/lon slices.
            lat_range: ``(start, end)`` index range on the global latitude axis.
            lon_range: ``(start, end)`` index range on the global longitude axis.
            time_range: Number of steps following the start day in each sample.
            base_path: Directory of the global files; falls back to the
                ``HOM_GLORYS_05_H5_DIR`` environment variable.

        Raises:
            ValueError: A required path is not configured, ``mean_field`` is
                enabled without ``HOM_CLIMATE_MEAN_NPY``, or ``day_interval``
                is not positive.
        """
        super().__init__()

        self.ds_factor = ds_factor
        self.lat_start, self.lat_end = lat_range
        self.lon_start, self.lon_end = lon_range
        self.time_range = time_range
        self.day_interval = day_interval
        self.years = years

        # --- Global Data Path ---
        self.base_path = base_path or _env_or_default("HOM_GLORYS_05_H5_DIR", None)
        if self.base_path is None:
            raise ValueError(
                "Global GLORYS base_path is not set. Pass base_path=... or set env HOM_GLORYS_05_H5_DIR "
                "to the directory containing files like GLORYS_05_<YEAR>.h5."
            )
        self.global_prefix = 'GLORYS_05'

        # --- Regional Data Path ---
        self.regional_base_path = _env_or_default("HOM_GLORYS_REGIONAL_025_H5_DIR", None)
        if self.regional_base_path is None:
            raise ValueError(
                "Regional GLORYS base_path is not set. Set env HOM_GLORYS_REGIONAL_025_H5_DIR "
                "to the directory containing files like GLORYS_pc_025_<YEAR>.h5."
            )
        self.regional_prefix = 'GLORYS_pc_025'

        # --- Merged Forecast Data Path ---
        # Optional: merged forecast results generated by another pipeline.
        self.merged_base_path = _env_or_default("HOM_REGIONAL_MERGED_RESULTS_DIR", None)
        self.merged_prefix = 'GLORYS_regional'

        # Indices are capped to the first 360 days of each year.
        # Built before any file is opened so an invalid configuration fails early.
        self.day_indices = self._generate_day_indices()

        self.global_handles = GLORYSDataLoader.get_file_handles(
            self.years, self.base_path, self.global_prefix
        )
        self.regional_handles = GLORYSDataLoader.get_file_handles(
            self.years, self.regional_base_path, self.regional_prefix
        )
        self.merged_handles = {}
        if self.merged_base_path is not None:
            self.merged_handles = GLORYSDataLoader.get_file_handles(
                self.years, self.merged_base_path, self.merged_prefix, file_suffix='_merged.h5'
            )

        train_config = get_training_config()
        self.use_mean_field = train_config.get('mean_field', False)
        self.thetao_mean_full = None
        self.so_mean_full = None
        self.zos_mean_full = None

        if self.use_mean_field:
            self.mean_file_path = _env_or_default("HOM_CLIMATE_MEAN_NPY", None)
            if self.mean_file_path is None:
                raise ValueError(
                    "mean_field=True requires a climate mean npy file. "
                    "Set env HOM_CLIMATE_MEAN_NPY to the path of climate_mean_s_t_ssh.npy."
                )
            try:
                raw_mean_data = np.load(self.mean_file_path)
                # Drop the first row of the third axis and flip that axis.
                # NOTE(review): presumably aligns the mean file's latitude
                # orientation with the HDF5 data - confirm.
                self.thetao_mean_full = np.flip(raw_mean_data[:, 69:92:2, 1:, :], 2)
                self.so_mean_full = np.flip(raw_mean_data[:, 0:23:2, 1:, :], 2)
                self.zos_mean_full = np.flip(raw_mean_data[:, 92, 1:, :], 1)
                logger.info("Loaded Climate Mean Fields successfully.")
            except Exception as e:
                logger.error(f"Failed to load climate mean fields: {e}")
                raise
        else:
            logger.info("Skipping Climate Mean Fields loading (mean_field=False).")

        logger.info(f"Dataset Initialized | Years: {years} | Interval: {day_interval} | Time Steps: {time_range} | Samples: {len(self.day_indices)}")

    def _generate_day_indices(self):
        """
        Generate indices, limiting each year to the first 360 days.

        Returns:
            List of ``(year, day_of_year)`` start positions such that
            ``day_of_year + time_range`` stays inside the 360-day window.

        Raises:
            ValueError: ``day_interval`` is smaller than 1.
        """
        if self.day_interval < 1:
            raise ValueError(
                f"day_interval must be a positive integer, got {self.day_interval!r}"
            )

        # Hard cap for consistency with preprocessed files.
        days_in_year = self._DAYS_PER_YEAR

        # Ensure (day_of_year + time_range) stays within the capped window.
        last_valid_day = days_in_year - self.time_range

        day_indices = []
        for year in self.years:
            # Guard against invalid configuration.
            if last_valid_day <= 0:
                logger.warning(f"Time range {self.time_range} is too large for 360-day limit. Skipping year {year}.")
                continue

            day_indices.extend(
                (year, day_of_year)
                for day_of_year in range(0, last_valid_day, self.day_interval)
            )
        return day_indices

    def _read_from_handle(self, handle, year, t_start, t_end, variable, lat_slice, lon_slice, is_regional=False):
        """
        Read one variable for a time window and replace NaNs with 0.0.

        The same channel map is used for every handle; ``year`` and
        ``is_regional`` are accepted for call-site compatibility but do not
        influence the read.

        :param handle: Array-like indexed as [time, channel, lat, lon].
        :param t_start: First time index (inclusive).
        :param t_end: Last time index (exclusive).
        :param variable: One of 'uo', 'vo', 'zos', 'thetao', 'so', 'A'.
        :param lat_slice: Slice applied to the third axis.
        :param lon_slice: Slice applied to the fourth axis.
        :raises ValueError: If ``variable`` is not in the channel map.
        """
        if variable not in self._CHANNEL_MAP:
            raise ValueError(f"Unknown variable: {variable}")

        channel_idx = self._CHANNEL_MAP[variable]
        values = handle[t_start:t_end, channel_idx, lat_slice, lon_slice]

        return np.nan_to_num(values, nan=0.0)

    def __getitem__(self, index):
        """Return the sample at ``index`` as a tuple of 15 float32 tensors."""
        year, day_of_year = self.day_indices[index]
        t_start = day_of_year
        t_end = day_of_year + self.time_range + 1

        # 1. Load Global Data
        global_handle = self.global_handles[year]
        global_lat_slice = slice(self.lat_start, self.lat_end, self.ds_factor)
        global_lon_slice = slice(self.lon_start, self.lon_end, self.ds_factor)

        tensors = []
        for var in self._VARIABLES:
            data_np = self._read_from_handle(
                global_handle, year, t_start, t_end, var,
                global_lat_slice, global_lon_slice,
                is_regional=False
            )
            tensors.append(torch.tensor(data_np, dtype=torch.float32))

        # 2. Overwrite with Merged Forecast
        if year in self.merged_handles:
            merged_handle = self.merged_handles[year]

            # With the 360-day cap above, day_of_year < 360 is guaranteed.
            forecast_data = merged_handle[day_of_year]
            forecast_tensor = torch.tensor(forecast_data, dtype=torch.float32)

            # Step 0 keeps the value read from the global file; every later
            # step is replaced by the forecast channels for that variable.
            # The merged forecast contains only a limited set of depth levels; overwrite those only.
            tensors[0][1:, :] = forecast_tensor[0:self.time_range, 0:12]   # thetao
            tensors[1][1:, :] = forecast_tensor[0:self.time_range, 12:24]  # so
            tensors[2][1:, :] = forecast_tensor[0:self.time_range, 24:36]  # uo
            tensors[3][1:, :] = forecast_tensor[0:self.time_range, 36:48]  # vo
            tensors[4][1:]    = forecast_tensor[0:self.time_range, 48]     # zos

        # 3. Load Mean Fields
        if self.use_mean_field:
            # Mean fields are indexed modulo 365 along their first axis.
            m_idx = [(day_of_year + t) % 365 for t in range(self.time_range + 1)]
            tensors.append(torch.tensor(self.thetao_mean_full[m_idx, ..., global_lat_slice, global_lon_slice], dtype=torch.float32))
            tensors.append(torch.tensor(self.so_mean_full[m_idx, ..., global_lat_slice, global_lon_slice], dtype=torch.float32))
            tensors.append(torch.tensor(self.zos_mean_full[m_idx, global_lat_slice, global_lon_slice], dtype=torch.float32))
        else:
            # Placeholders keep the tuple length constant.
            tensors.extend([torch.zeros(1), torch.zeros(1), torch.zeros(1)])

        # 4. Load Regional Data
        regional_handle = self.regional_handles[year]
        for var in self._VARIABLES:
            data_np = self._read_from_handle(
                regional_handle, year, t_start, t_end, var,
                slice(None), slice(None),
                is_regional=True
            )
            tensors.append(torch.tensor(data_np, dtype=torch.float32))

        return tuple(tensors)

    def __len__(self):
        """Number of (year, day_of_year) start positions."""
        return len(self.day_indices)

# --- Datasets ---
class TrainDataset(GLORYSBaseDataset):
    """Training split: years 1993 through 2018, one sample start every 3 days."""

    def __init__(self, ds_factor=1, lat_range=(0, 360), lon_range=(0, 720)):
        # Window length comes from the training config key 'integral_interval' (default 1).
        steps = get_training_config().get('integral_interval', 1)
        super().__init__(
            years=range(1993, 2019),
            day_interval=3,
            ds_factor=ds_factor,
            lat_range=lat_range,
            lon_range=lon_range,
            time_range=steps,
        )

class ValDataset(GLORYSBaseDataset):
    """Validation split: year 2019 only, one sample start every 6 days."""

    def __init__(self, ds_factor=1, lat_range=(0, 360), lon_range=(0, 720)):
        # Window length comes from the training config key 'integral_interval' (default 1).
        steps = get_training_config().get('integral_interval', 1)
        super().__init__(
            years=range(2019, 2020),
            day_interval=6,
            ds_factor=ds_factor,
            lat_range=lat_range,
            lon_range=lon_range,
            time_range=steps,
        )
        
class TestDataset(GLORYSBaseDataset):
    """Test split: year 2020 only, one sample start every 3 days."""

    def __init__(self, ds_factor=1, lat_range=(0, 360), lon_range=(0, 720)):
        # Window length comes from the training config key 'integral_interval_test' (default 60).
        steps = get_training_config().get('integral_interval_test', 60)
        super().__init__(
            years=range(2020, 2021),
            day_interval=3,
            ds_factor=ds_factor,
            lat_range=lat_range,
            lon_range=lon_range,
            time_range=steps,
        )