import numpy as np
import pandas as pd
import os
from datetime import datetime, timedelta
from typing import Dict, List, Tuple, Optional, Union
from dataclasses import dataclass
import random
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import warnings


def generate_weather_data(
    dates: pd.DatetimeIndex,
    num_stations: int = 10,
    base_temp: float = 15.0,
    temp_amplitude: float = 15.0
) -> Dict[str, pd.DataFrame]:
    """Generate synthetic hourly weather data for several stations.

    For every station the following columns are produced, one row per
    timestamp in ``dates``:

    * ``temperature`` -- station base temperature plus a yearly sine term
      (``sin(2*pi*(day_of_year - 15)/365)``) and a daily sine term whose
      4-6 C amplitude is redrawn every hour.
    * ``humidity`` -- a station base (50-70 %) reduced by half the squared
      deviation of temperature from the station base, plus Gaussian noise
      (sigma 10), clipped to [20, 100].
    * ``wind_speed`` -- a persistent random "wind regime" (0-15 m/s, with a
      5 % chance per step of switching), a daily sine term and Gaussian
      noise (sigma 1.5), floored at 0.
    * ``solar_irradiance`` -- 0 outside 06:00-18:00; during the day a
      half-sine over the hours scaled by a seasonal factor and a cloud
      proxy derived from the same row's humidity.

    Randomness is drawn from both the ``random`` module and NumPy's global
    ``np.random`` state; seed both for reproducible output.

    NOTE(review): station locations and a distance-based correlation matrix
    are generated, but the matrix is not yet applied to any series, so the
    stations' series are currently independent of each other.

    Args:
        dates: DatetimeIndex covering the simulation period (intended to be
            hourly).
        num_stations: Number of weather stations to simulate.
        base_temp: Annual average temperature in Celsius (each station is
            offset by up to +/- 2.5 C).
        temp_amplitude: Annual temperature swing amplitude (each station
            scales it by 0.8-1.2).

    Returns:
        Dict mapping ``'Station_<k>'`` (k starting at 1) to a DataFrame
        indexed by ``dates`` with columns ``temperature``, ``humidity``,
        ``wind_speed`` and ``solar_irradiance``.
    """
    # --- Station layout in an abstract unit square --------------------------
    station_locations = []
    for k in range(num_stations):
        if k < num_stations * 0.7:  # first ~70% of stations: central cluster
            x = 0.5 + 0.3 * (random.random() - 0.5)
            y = 0.5 + 0.3 * (random.random() - 0.5)
        else:  # remaining stations: anywhere in the square
            x = random.random()
            y = random.random()
        station_locations.append((x, y))

    # Correlation decays exponentially with distance (exp(0) = 1 on the
    # diagonal). NOTE(review): not applied to any series below yet.
    coords = np.asarray(station_locations, dtype=float).reshape(-1, 2)
    pairwise_distance = np.sqrt(
        ((coords[:, None, :] - coords[None, :, :]) ** 2).sum(axis=-1)
    )
    correlation_matrix = np.exp(-3 * pairwise_distance)

    # --- Time-only terms shared by all stations -----------------------------
    hours = np.asarray(dates.hour, dtype=float)
    day_of_year = np.asarray(dates.dayofyear, dtype=float)
    n_hours = len(dates)

    # NOTE(review): these phases put the yearly maximum near mid-April, the
    # daily temperature maximum at 09:00 and the wind maximum at 07:00; check
    # them against the intended "warm summer / warm afternoon / windy
    # afternoon" behaviour before relying on the data.
    yearly_wave = np.sin(2 * np.pi * (day_of_year - 15) / 365)
    temp_daily_wave = np.sin(2 * np.pi * (hours - 3) / 24)
    wind_hour_effect = 2 * np.sin(2 * np.pi * (hours - 1) / 24)

    # Solar: seasonal strength in [0.1, 1.3]; half-sine from 06:00 to 18:00,
    # exactly 0 at night.
    solar_seasonal = 0.7 + 0.6 * yearly_wave
    is_daylight = (hours >= 6) & (hours <= 18)
    solar_hour_factor = np.where(is_daylight, np.sin(np.pi * (hours - 6) / 12), 0.0)

    weather_data = {}
    for station_idx in range(num_stations):
        station_id = f'Station_{station_idx + 1}'

        # Station-specific parameters: base temp +/- 2.5 C, amplitude 80-120 %.
        station_base_temp = base_temp + 5 * (random.random() - 0.5)
        station_temp_amplitude = temp_amplitude * (0.8 + 0.4 * random.random())

        # Daily swing of 4-6 C, redrawn independently for every hour (one
        # random.random() per hour, in timestamp order).
        daily_amplitude = 4 + 2 * np.array(
            [random.random() for _ in range(n_hours)], dtype=float
        )
        temps = (station_base_temp
                 + station_temp_amplitude * yearly_wave
                 + daily_amplitude * temp_daily_wave)

        # Humidity is highest when temperature is near the station base and
        # falls off quadratically toward the extremes.
        humidity_base = 60 + 20 * (random.random() - 0.5)
        humidity = humidity_base - 0.5 * (temps - station_base_temp) ** 2
        humidity = humidity + 10 * np.random.randn(n_hours)
        humidity = np.clip(humidity, 20, 100)

        # Wind: a regime that persists between steps. Kept as a loop because
        # each step depends on the previous state.
        wind_speeds = []
        wind_state = random.random()
        for hour_effect in wind_hour_effect:
            if random.random() < 0.05:  # 5% chance of a regime change each step
                wind_state = random.random()
            wind = 15 * wind_state + hour_effect + 1.5 * np.random.randn()
            wind_speeds.append(max(0.0, wind))  # no negative wind speeds

        # Cloud proxy uses the humidity of the *same* row (the previous loop
        # version accidentally reused the last row's humidity for every hour).
        cloud_factor = np.maximum(0.1, 1 - (humidity - 40) / 100)
        solar = np.maximum(
            0.0, 1000 * solar_seasonal * solar_hour_factor * cloud_factor
        )

        weather_data[station_id] = pd.DataFrame(
            {
                'temperature': temps,
                'humidity': humidity,
                'wind_speed': wind_speeds,
                'solar_irradiance': solar,
            },
            index=dates,
        )

    return weather_data


def generate_pca_weather_factors(weather_data: Dict[str, pd.DataFrame],
                                 n_components: int = 5) -> pd.DataFrame:
    """Reduce multi-station weather data to a few PCA factors.

    Every (station, variable) pair becomes one feature column. The columns
    are standardized with ``StandardScaler`` and projected onto the leading
    principal components. The explained-variance ratios are printed.

    Args:
        weather_data: Mapping of station id to a DataFrame with columns
            ``temperature``, ``humidity``, ``wind_speed`` and
            ``solar_irradiance``. NOTE(review): rows are combined
            positionally, so all frames are assumed to share the same index
            in the same order.
        n_components: Maximum number of PCA components to extract; clamped
            to ``min(n_samples, n_features)`` because PCA cannot return more.

    Returns:
        DataFrame indexed like the first station's frame, with columns
        ``weather_factor_1`` ... ``weather_factor_<k>``.

    Raises:
        ValueError: If ``weather_data`` is empty.
    """
    if not weather_data:
        raise ValueError("weather_data must contain at least one station")

    variables = ['temperature', 'humidity', 'wind_speed', 'solar_irradiance']

    # Output index comes from the first station's frame.
    dates_index = next(iter(weather_data.values())).index

    # Feature matrix, shape (n_samples, n_stations * len(variables)), with
    # columns ordered station-major: st1_temperature, st1_humidity, ...
    X = np.column_stack([
        station_df[column].to_numpy()
        for station_df in weather_data.values()
        for column in variables
    ])

    X_scaled = StandardScaler().fit_transform(X)

    # PCA accepts at most min(n_samples, n_features) components; clamping only
    # to n_features crashed on short time ranges.
    n_components = min(n_components, *X_scaled.shape)
    pca = PCA(n_components=n_components)
    pca_factors = pca.fit_transform(X_scaled)

    pca_df = pd.DataFrame(
        pca_factors,
        index=dates_index,
        columns=[f'weather_factor_{i + 1}' for i in range(n_components)],
    )

    # Print explained variance to show how much information is captured.
    explained_variance = pca.explained_variance_ratio_
    cumulative_variance = np.cumsum(explained_variance)
    print(f"PCA Explained variance ratio: {explained_variance}")
    print(f"Cumulative explained variance: {cumulative_variance}")

    return pca_df


def generate_calendar_features(dates: pd.DatetimeIndex) -> pd.DataFrame:
    """Generate calendar-related features for the given timestamps.

    Columns, one row per entry of ``dates``, in this order:

    * ``hour``, ``day_of_week`` (Monday=0), ``month``, ``day_of_year``,
      ``year``
    * ``is_weekend`` -- 1 on Saturday and Sunday
    * ``is_holiday`` -- 1 for every timestamp whose local calendar date is in
      the built-in US holiday list (covers 2020-2022 only)
    * ``dst_transition`` -- for tz-aware ``dates``: +1 if the UTC offset is
      larger than it was 24 hours earlier (clocks sprang forward within the
      last day), -1 if smaller (clocks fell back), else 0. Always 0 for
      tz-naive ``dates``, which carry no DST information.
    * ``is_school_period`` -- 0 in June-August, 1 otherwise
    * ``hour_sin/cos``, ``day_of_week_sin/cos``, ``day_of_year_sin/cos`` --
      cyclical sine/cosine encodings

    Args:
        dates: DatetimeIndex covering the simulation period (naive or
            tz-aware).

    Returns:
        DataFrame of calendar features indexed by ``dates``.
    """
    calendar_df = pd.DataFrame(index=dates)

    # US holidays (simplified set for the example)
    us_holidays = {
        # 2020 holidays
        "2020-01-01": "New Year's Day",
        "2020-01-20": "Martin Luther King Jr. Day",
        "2020-05-25": "Memorial Day",
        "2020-07-04": "Independence Day",
        "2020-09-07": "Labor Day",
        "2020-11-11": "Veterans Day",
        "2020-11-26": "Thanksgiving",
        "2020-12-25": "Christmas",

        # 2021 holidays
        "2021-01-01": "New Year's Day",
        "2021-01-18": "Martin Luther King Jr. Day",
        "2021-05-31": "Memorial Day",
        "2021-07-04": "Independence Day",
        "2021-09-06": "Labor Day",
        "2021-11-11": "Veterans Day",
        "2021-11-25": "Thanksgiving",
        "2021-12-25": "Christmas",

        # 2022 holidays
        "2022-01-01": "New Year's Day",
        "2022-01-17": "Martin Luther King Jr. Day",
        "2022-05-30": "Memorial Day",
        "2022-07-04": "Independence Day",
        "2022-09-05": "Labor Day",
        "2022-11-11": "Veterans Day",
        "2022-11-24": "Thanksgiving",
        "2022-12-25": "Christmas"
    }

    def _utc_offsets(index: pd.DatetimeIndex) -> pd.TimedeltaIndex:
        """Return each tz-aware timestamp's UTC offset (local wall time - UTC)."""
        return index.tz_localize(None) - index.tz_convert("UTC").tz_localize(None)

    is_tz_aware = dates.tz is not None
    # Naive wall-clock view of the index, so calendar-date comparisons against
    # the naive holiday dates also work for tz-aware input.
    local_dates = dates.tz_localize(None) if is_tz_aware else dates

    # Basic time features (local wall-clock values for tz-aware input)
    calendar_df['hour'] = dates.hour
    calendar_df['day_of_week'] = dates.dayofweek
    calendar_df['month'] = dates.month
    calendar_df['day_of_year'] = dates.dayofyear
    calendar_df['year'] = dates.year

    calendar_df['is_weekend'] = (calendar_df['day_of_week'] >= 5).astype(int)

    # Every hour of a holiday's local calendar date is flagged.
    holiday_dates = pd.to_datetime(list(us_holidays))
    calendar_df['is_holiday'] = np.asarray(
        local_dates.normalize().isin(holiday_dates), dtype=int
    )

    # DST transitions: compare the UTC offset now with the offset exactly 24
    # hours earlier. A positive change means clocks sprang forward (+1), a
    # negative one means they fell back (-1).
    if is_tz_aware:
        offset_now = _utc_offsets(dates)
        offset_day_before = _utc_offsets(dates - pd.Timedelta(hours=24))
        shift_seconds = np.asarray((offset_now - offset_day_before).total_seconds())
        calendar_df['dst_transition'] = np.sign(shift_seconds).astype(int)
    else:
        calendar_df['dst_transition'] = 0

    # Special periods: school is out (approximately) in June-August.
    calendar_df['is_school_period'] = (~calendar_df['month'].between(6, 8)).astype(int)

    # Hour of day - encoded as sin/cos for its cyclical nature
    hours_in_day = 24
    calendar_df['hour_sin'] = np.sin(2 * np.pi * calendar_df['hour'] / hours_in_day)
    calendar_df['hour_cos'] = np.cos(2 * np.pi * calendar_df['hour'] / hours_in_day)

    # Day of week - encoded as sin/cos
    days_in_week = 7
    calendar_df['day_of_week_sin'] = np.sin(2 * np.pi * calendar_df['day_of_week'] / days_in_week)
    calendar_df['day_of_week_cos'] = np.cos(2 * np.pi * calendar_df['day_of_week'] / days_in_week)

    # Day of year - encoded as sin/cos
    days_in_year = 365.25
    calendar_df['day_of_year_sin'] = np.sin(2 * np.pi * calendar_df['day_of_year'] / days_in_year)
    calendar_df['day_of_year_cos'] = np.cos(2 * np.pi * calendar_df['day_of_year'] / days_in_year)

    return calendar_df