import numpy as np
import pandas as pd
import os
from datetime import datetime, timedelta
from typing import Dict, List, Tuple, Optional, Union
from dataclasses import dataclass
import random
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import warnings

from data.utis import generate_weather_data, generate_pca_weather_factors, generate_calendar_features

# For reproducibility: seed NumPy's and Python's global random generators
# when this module is imported. The generators below draw from both
# (`random.*` for parameters/outages, `np.random.normal` for noise), so the
# same call sequence after import yields the same synthetic data.
np.random.seed(42)
random.seed(42)

@dataclass(frozen=True)
class ElectricityLoadParams:
    """Immutable set of parameters describing one grid region's load and price behaviour.

    The dataclass is frozen, so fields cannot be reassigned after
    construction. The list and dict objects stored in ``daily_pattern``,
    ``weekly_pattern`` and ``temp_sensitivity`` are still ordinary mutable
    containers.

    All "effect" and "sensitivity" values are fractions of the base load
    (e.g. ``-0.10`` lowers load by 10%), not percentages.

    Attributes:
        base_load: Base electricity consumption (MW).
        daily_pattern: Multipliers indexed by hour of day (24 entries).
        weekly_pattern: Multipliers indexed by day of week (7 entries,
            Monday first).
        yearly_pattern: Amplitude of the sinusoidal yearly seasonality.

        temp_sensitivity: Non-linear temperature response parameters, keyed by
            'optimal_temp', 'cooling_threshold', 'heating_threshold',
            'cooling_slope' and 'heating_slope'.
        humidity_sensitivity: Fractional load change per unit of approximated
            humidity relative to a reference of 60.
        wind_sensitivity: Fractional load change per unit of approximated
            wind speed.
        solar_sensitivity: Fractional load change scaled by approximated
            solar irradiance / 800.

        weekend_effect: Fractional load adjustment applied on weekends.
        holiday_effect: Fractional load adjustment applied on holidays.
        dst_transition_effect: Fractional load adjustment applied on DST
            transitions.

        planned_outage_effect: Fractional load adjustment applied while a
            planned outage is flagged.

        renewable_substitution: Renewable generation substitution effect
            (not read by the load/price generator in this part of the file).
        renewable_volatility: Impact of renewable volatility on price
            (not read by the load/price generator in this part of the file).

        price_base: Base price level.
        price_peak_multiplier: Multiplier on ``price_base`` for the scarcity
            price component.
        price_volatility: Standard deviation of the price shock, as a
            fraction of ``price_base``.

        capacity_constraint: Capacity threshold (fraction of maximum load)
            above which scarcity pricing applies.
        capacity_price_exponent: Exponent of the non-linear price response
            above the capacity threshold.

        load_noise: Standard deviation of load noise, as a fraction of the
            base load at each timestamp.
        price_noise: Standard deviation of price noise (not read by the
            load/price generator in this part of the file).

        load_trend: Annual load growth rate, compounded over elapsed years.
    """
    # Base load parameters
    base_load: float
    daily_pattern: List[float]   # 24 hourly multipliers
    weekly_pattern: List[float]  # 7 daily multipliers, Monday first
    yearly_pattern: float
    
    # Weather sensitivity
    temp_sensitivity: Dict[str, float]  # Thresholds and slopes of the non-linear response
    humidity_sensitivity: float
    wind_sensitivity: float
    solar_sensitivity: float
    
    # Calendar effects (fractions of base load)
    weekend_effect: float
    holiday_effect: float
    dst_transition_effect: float
    
    # Infrastructure effects (fraction of base load)
    planned_outage_effect: float
    
    # Renewable integration
    renewable_substitution: float
    renewable_volatility: float
    
    # Price sensitivity
    price_base: float
    price_peak_multiplier: float
    price_volatility: float
    
    # Non-linear effects
    capacity_constraint: float
    capacity_price_exponent: float
    
    # Random components
    load_noise: float
    price_noise: float
    
    # Trend components
    load_trend: float


def generate_region_params(num_regions: int = 5) -> Dict[str, ElectricityLoadParams]:
    """Build a randomised parameter set for each simulated grid region.

    Every region gets a randomly chosen daily and weekly load template
    (perturbed and clamped), plus scalar parameters drawn uniformly from
    fixed ranges. All draws come from the module-level ``random`` generator,
    so the output depends on its current state.

    Args:
        num_regions: Number of regions to generate parameters for

    Returns:
        Dictionary mapping 'Region_1' .. 'Region_<num_regions>' to their
        load/price pattern parameters (empty when ``num_regions`` < 1)
    """
    # Hourly load-shape templates: 24 multipliers each, hour 0 first.
    hourly_templates = [
        # Residential-heavy: evening peak
        [0.65, 0.60, 0.58, 0.56, 0.57, 0.62, 0.75, 0.85, 0.95, 0.93, 0.94, 0.95,
         0.97, 0.98, 0.97, 0.98, 1.05, 1.15, 1.25, 1.20, 1.10, 0.95, 0.80, 0.70],

        # Commercial-heavy: midday peak
        [0.60, 0.55, 0.53, 0.52, 0.54, 0.65, 0.80, 0.95, 1.10, 1.20, 1.25, 1.26,
         1.25, 1.24, 1.25, 1.20, 1.15, 1.05, 0.95, 0.85, 0.80, 0.75, 0.68, 0.63],

        # Industrial-heavy: nearly flat, mild daytime rise
        [0.85, 0.83, 0.82, 0.82, 0.83, 0.85, 0.90, 0.95, 1.00, 1.05, 1.08, 1.10,
         1.10, 1.10, 1.08, 1.05, 1.00, 0.98, 0.95, 0.92, 0.90, 0.88, 0.87, 0.86],

        # Mixed-use: morning and evening peaks
        [0.70, 0.65, 0.60, 0.58, 0.60, 0.70, 0.85, 1.00, 1.10, 1.15, 1.10, 1.05,
         1.05, 1.10, 1.05, 1.00, 1.10, 1.20, 1.15, 1.05, 0.95, 0.85, 0.80, 0.75],

        # Tourism/entertainment: late-evening focus
        [0.75, 0.65, 0.55, 0.50, 0.48, 0.50, 0.60, 0.70, 0.80, 0.85, 0.90, 0.95,
         1.00, 0.95, 0.90, 0.95, 1.05, 1.15, 1.25, 1.30, 1.35, 1.25, 1.10, 0.90],
    ]

    # Day-of-week templates: 7 multipliers each, Monday first.
    weekday_templates = [
        [1.0, 1.0, 1.0, 1.0, 0.95, 0.85, 0.80],      # standard business week
        [0.85, 0.85, 0.90, 0.95, 1.05, 1.15, 1.10],  # tourist area, busier weekends
        [1.05, 1.05, 1.05, 1.05, 1.05, 0.70, 0.65],  # industrial, reduced weekends
        [1.0, 1.0, 1.0, 1.0, 1.0, 0.90, 0.85],       # balanced region
        [1.05, 1.10, 1.05, 1.05, 0.95, 0.75, 0.70],  # university/education area
    ]

    base_load_bounds = (500, 5000)           # MW
    seasonal_amplitude_bounds = (0.15, 0.35)  # amplitude of yearly seasonality

    # Bounds for the non-linear temperature response.
    temperature_bounds = {
        'optimal_temp': (18.0, 23.0),       # minimal HVAC usage
        'cooling_threshold': (22.0, 27.0),  # cooling starts above this
        'heating_threshold': (12.0, 18.0),  # heating starts below this
        'cooling_slope': (0.02, 0.06),      # load rise per degree above cooling threshold
        'heating_slope': (0.01, 0.04),      # load rise per degree below heating threshold
    }

    # Bounds for every remaining scalar field; the keys match the dataclass
    # field names so the drawn values can be passed straight through.
    scalar_bounds = {
        # Weather sensitivities
        'humidity_sensitivity': (0.001, 0.005),
        'wind_sensitivity': (-0.005, 0.002),
        'solar_sensitivity': (-0.02, -0.005),

        # Calendar effects
        'weekend_effect': (-0.25, -0.05),
        'holiday_effect': (-0.3, -0.1),
        'dst_transition_effect': (-0.05, 0.05),

        # Infrastructure effects
        'planned_outage_effect': (-0.2, -0.05),

        # Renewable integration
        'renewable_substitution': (0.2, 0.8),
        'renewable_volatility': (0.01, 0.05),

        # Price parameters
        'price_base': (20, 60),  # base price per MWh
        'price_peak_multiplier': (1.5, 4.0),
        'price_volatility': (0.05, 0.2),

        # Non-linear effects
        'capacity_constraint': (0.8, 0.95),  # 80-95% of max capacity
        'capacity_price_exponent': (1.5, 3.0),

        # Random components
        'load_noise': (0.01, 0.05),
        'price_noise': (0.05, 0.15),

        # Trend components
        'load_trend': (-0.02, 0.04),  # -2% to 4% annual growth
    }

    def _jittered(template, low, spread, floor, ceiling):
        """Scale each entry by a random factor in [low, low + spread) and clamp it."""
        return [
            max(floor, min(ceiling, value * (low + spread * random.random())))
            for value in template
        ]

    generated = {}
    for region_number in range(1, num_regions + 1):
        # Pick both templates first, then perturb them (keeps the draw order fixed).
        hourly_choice = random.randint(0, len(hourly_templates) - 1)
        weekday_choice = random.randint(0, len(weekday_templates) - 1)
        hourly_shape = _jittered(hourly_templates[hourly_choice], 0.9, 0.2, 0.4, 1.5)
        weekday_shape = _jittered(weekday_templates[weekday_choice], 0.95, 0.1, 0.6, 1.2)

        raw_base_load = random.uniform(*base_load_bounds)
        seasonal_amplitude = random.uniform(*seasonal_amplitude_bounds)

        thermal = {
            key: round(random.uniform(low, high), 2)
            for key, (low, high) in temperature_bounds.items()
        }

        # Keep heating threshold < optimal temperature < cooling threshold.
        comfort_temp = thermal['optimal_temp']
        if thermal['cooling_threshold'] <= comfort_temp:
            thermal['cooling_threshold'] = comfort_temp + 2.0
        if thermal['heating_threshold'] >= comfort_temp:
            thermal['heating_threshold'] = comfort_temp - 2.0

        scalars = {
            key: round(random.uniform(low, high), 4)
            for key, (low, high) in scalar_bounds.items()
        }

        generated[f'Region_{region_number}'] = ElectricityLoadParams(
            base_load=round(raw_base_load, 1),
            daily_pattern=hourly_shape,
            weekly_pattern=weekday_shape,
            yearly_pattern=seasonal_amplitude,
            temp_sensitivity=thermal,
            **scalars,
        )

    return generated


def generate_dates(start_date: str = '2020-01-01', 
                 end_date: str = '2021-12-31', 
                 freq: str = 'H') -> pd.DatetimeIndex:
    """Generate DatetimeIndex for the simulation period.

    Both endpoints are inclusive. A date-only ``end_date`` is interpreted as
    midnight of that day, so with an hourly frequency the final day
    contributes a single timestamp (00:00).

    Args:
        start_date: Start date string in 'YYYY-MM-DD' format
        end_date: End date string in 'YYYY-MM-DD' format
        freq: Frequency of the time series ('H' or 'h' for hourly)

    Returns:
        DatetimeIndex covering the specified period
    """
    # pandas 2.2 deprecated the uppercase hour alias ('H', '6H', ...) in favour
    # of lowercase 'h'. Lowercase is accepted by older releases as well, so
    # rewrite the legacy spelling instead of passing it through.
    if isinstance(freq, str):
        multiple, unit = freq[:-1], freq[-1:]
        if unit == 'H' and (multiple == '' or multiple.isdigit()):
            freq = f'{multiple}h'

    return pd.date_range(start=start_date, end=end_date, freq=freq)


def generate_planned_outages(dates: pd.DatetimeIndex, 
                           num_regions: int = 5,
                           annual_outage_prob: float = 0.04) -> Dict[str, pd.Series]:
    """Build a 0/1 planned-outage flag series for every region.

    For each region and each calendar day in ``dates`` an outage may start
    with probability ``annual_outage_prob / 365``. A started outage lasts a
    random 1-5 days and flags every timestamp from midnight of the start day
    up to (but excluding) midnight after the last day. Draws come from the
    module-level ``random`` generator.

    Args:
        dates: DatetimeIndex for the simulation period
        num_regions: Number of regions
        annual_outage_prob: Annual outage probability, spread evenly over days

    Returns:
        Dictionary mapping 'Region_<i>' to a Series indexed by ``dates``
        holding 1 during an outage and 0 otherwise
    """
    start_prob_per_day = annual_outage_prob / 365

    # Distinct calendar days, so each day gets exactly one start draw.
    calendar_days = pd.Series(dates.date).unique()

    schedule = {}
    for region_number in range(1, num_regions + 1):
        flags = pd.Series(0, index=dates)

        for calendar_day in calendar_days:
            # No outage begins on this day: move on without drawing a duration.
            if not random.random() < start_prob_per_day:
                continue

            length_days = random.randint(1, 5)
            window_open = pd.Timestamp(calendar_day)
            window_close = window_open + pd.Timedelta(days=length_days)

            # Half-open window [window_open, window_close) over the full index.
            in_window = (dates >= window_open) & (dates < window_close)
            flags.loc[in_window] = 1

        schedule[f'Region_{region_number}'] = flags

    return schedule


def calculate_temperature_effect(temperature: float, params: Dict[str, float]) -> float:
    """Calculate non-linear temperature effect on load.

    The result is an additive fraction of load, not a multiplier: it is 0.0
    while the temperature lies between the heating and cooling thresholds
    (inclusive) and grows super-linearly outside that band.

    Args:
        temperature: Temperature in Celsius
        params: Temperature sensitivity parameters. Must contain
            'cooling_threshold', 'heating_threshold', 'cooling_slope' and
            'heating_slope'; any other keys (e.g. 'optimal_temp') are ignored.

    Returns:
        Fractional load increase caused by cooling and/or heating demand

    Raises:
        KeyError: If one of the four required keys is missing from ``params``
    """
    # Only the keys the formula actually uses are looked up.
    cooling_threshold = params['cooling_threshold']
    heating_threshold = params['heating_threshold']
    cooling_slope = params['cooling_slope']
    heating_slope = params['heating_slope']

    effect = 0.0

    # Cooling demand: grows with degrees above the cooling threshold (power 1.5)
    if temperature > cooling_threshold:
        effect += cooling_slope * (temperature - cooling_threshold) ** 1.5

    # Heating demand: grows with degrees below the heating threshold (power 1.2)
    if temperature < heating_threshold:
        effect += heating_slope * (heating_threshold - temperature) ** 1.2

    return effect


def generate_electricity_load(
    region_params: Dict[str, ElectricityLoadParams],
    dates: pd.DatetimeIndex,
    weather_factors: pd.DataFrame,
    calendar_features: pd.DataFrame,
    planned_outages: Dict[str, pd.Series],
    renewable_forecasts: Optional[Dict[str, pd.DataFrame]] = None
) -> Dict[str, pd.DataFrame]:
    """Generate synthetic electricity load and price data based on parameters and covariates.

    For every region and timestamp the load is built as
    ``base_load + calendar_effect + weather_effect + outage_effect`` plus
    Gaussian noise, clamped at zero. The price is derived from that noisy
    load, with an autocorrelated random component on top. Noise is drawn from
    NumPy's global random state, two draws per timestamp (load, then price).

    Args:
        region_params: Parameters for each region
        dates: DatetimeIndex for the time series
        weather_factors: PCA factors of weather data, indexed by timestamp,
            with columns 'weather_factor_1' .. 'weather_factor_4'
        calendar_features: Calendar-related features, indexed by timestamp,
            with columns 'hour', 'day_of_week', 'month', 'is_weekend',
            'is_holiday' and 'dst_transition'
        planned_outages: Planned outage schedules keyed by region id
        renewable_forecasts: Renewable generation forecasts (currently not
            used by this function)

    Returns:
        Dictionary of DataFrames with electricity load and price for each
        region. ``total_load`` is the noisy, non-negative load that the price
        was computed from; the component columns hold the deterministic parts.

    Raises:
        KeyError: If a region has no entry in ``planned_outages``, a required
            column is missing, or a timestamp in ``dates`` is absent from one
            of the inputs
    """
    results = {}

    # Elapsed time since the first timestamp, in years (whole days / 365),
    # used as the exponent of the compounding load trend.
    if len(dates) > 0:
        start_date = dates[0]
        years_since_start = [(date - start_date).days / 365 for date in dates]
    else:
        years_since_start = []

    # Region-independent inputs are extracted once, aligned on `dates` by
    # label, instead of one DataFrame lookup per timestamp and region.
    is_weekend_values = calendar_features.loc[dates, 'is_weekend'].to_numpy()
    is_holiday_values = calendar_features.loc[dates, 'is_holiday'].to_numpy()
    dst_transition_values = calendar_features.loc[dates, 'dst_transition'].to_numpy()
    factor_1_values = weather_factors.loc[dates, 'weather_factor_1'].to_numpy()
    factor_2_values = weather_factors.loc[dates, 'weather_factor_2'].to_numpy()
    factor_3_values = weather_factors.loc[dates, 'weather_factor_3'].to_numpy()
    factor_4_values = weather_factors.loc[dates, 'weather_factor_4'].to_numpy()

    for region_id, params in region_params.items():
        # Initialize DataFrame for this region
        region_df = pd.DataFrame(index=dates)

        # Get planned outages for this region
        outages = planned_outages[region_id]
        outage_flags = outages.loc[dates].to_numpy()

        # Per-timestamp results
        base_loads = []
        calendar_effects = []
        weather_effects = []
        outage_effects = []
        total_loads = []
        price_values = []

        # Carries the autocorrelated price shock from one timestamp to the next
        price_momentum = 0.0

        for i, date in enumerate(dates):
            hour = date.hour
            day_of_week = date.dayofweek
            day_of_year = date.dayofyear
            is_weekend = is_weekend_values[i]
            is_holiday = is_holiday_values[i]
            dst_transition = dst_transition_values[i]

            # 1. BASE LOAD CALCULATION
            # -----------------------
            daily_factor = params.daily_pattern[hour]
            weekly_factor = params.weekly_pattern[day_of_week]

            # Sinusoidal yearly seasonality, phase-shifted by 15 days
            yearly_factor = 1.0 + params.yearly_pattern * np.sin(2 * np.pi * (day_of_year - 15) / 365)

            # Compounding long-term trend
            trend_factor = (1.0 + params.load_trend) ** years_since_start[i]

            base_load = params.base_load * daily_factor * weekly_factor * yearly_factor * trend_factor

            # 2. CALENDAR EFFECTS
            # ------------------
            weekend_effect = base_load * params.weekend_effect * is_weekend
            holiday_effect = base_load * params.holiday_effect * is_holiday
            dst_effect = base_load * params.dst_transition_effect * dst_transition
            calendar_effect = weekend_effect + holiday_effect + dst_effect

            # 3. WEATHER EFFECTS
            # ----------------
            # Map the weather factors back to approximate physical quantities;
            # factor 1 stands in for temperature.
            approx_temperature = 20 + 10 * factor_1_values[i]
            approx_humidity = 60 + 15 * factor_2_values[i]
            approx_wind = 5 + 5 * factor_3_values[i]
            approx_solar = max(0, 500 + 300 * factor_4_values[i])

            # Non-linear temperature effect
            temp_effect = calculate_temperature_effect(approx_temperature, params.temp_sensitivity)
            temp_effect *= base_load

            # Other weather effects
            humidity_effect = base_load * params.humidity_sensitivity * (approx_humidity - 60)
            wind_effect = base_load * params.wind_sensitivity * approx_wind
            solar_effect = base_load * params.solar_sensitivity * (approx_solar / 800)

            weather_effect = temp_effect + humidity_effect + wind_effect + solar_effect

            # 4. INFRASTRUCTURE EFFECTS
            # -----------------------
            outage_effect = base_load * params.planned_outage_effect * outage_flags[i]

            # 5. CALCULATE TOTAL LOAD
            # ---------------------
            deterministic_load = base_load + calendar_effect + weather_effect + outage_effect

            # Noise scales with the base load; the result cannot go negative
            random_noise = np.random.normal(0, params.load_noise * base_load)
            total_load = max(0, deterministic_load + random_noise)

            # 6. CALCULATE ELECTRICITY PRICE
            # ----------------------------
            # System is assumed to be sized at 150% of base load.
            max_possible_load = params.base_load * 1.5
            # NOTE(review): load_factor is already normalised by
            # capacity_constraint * max_possible_load and is then compared with
            # capacity_constraint again below -- confirm this is intended.
            load_factor = total_load / (params.capacity_constraint * max_possible_load)

            # Non-linear price response when approaching capacity
            if load_factor >= params.capacity_constraint:
                capacity_scarcity = ((load_factor - params.capacity_constraint) / 
                                   (1 - params.capacity_constraint)) ** params.capacity_price_exponent
            else:
                capacity_scarcity = 0

            # Base price rises linearly with the load factor around 0.6
            base_price = params.price_base * (1 + (load_factor - 0.6) * 0.5)

            # Scarcity premium (non-linear)
            scarcity_price = params.price_base * params.price_peak_multiplier * capacity_scarcity

            # AR(1)-style shock: 70% of the previous shock plus 30% fresh noise
            price_momentum = 0.7 * price_momentum + 0.3 * np.random.normal(0, params.price_volatility * params.price_base)

            price = base_price + scarcity_price + price_momentum
            price = max(1, price)  # Floor the price at 1

            base_loads.append(base_load)
            calendar_effects.append(calendar_effect)
            weather_effects.append(weather_effect)
            outage_effects.append(outage_effect)
            total_loads.append(total_load)
            price_values.append(price)

        # Populate region DataFrame
        region_df['base_load'] = base_loads
        region_df['calendar_effect'] = calendar_effects
        region_df['weather_effect'] = weather_effects
        region_df['outage_effect'] = outage_effects
        # Store the load the price was computed from (noise included, clamped
        # at zero) rather than re-summing the deterministic components.
        region_df['total_load'] = total_loads
        region_df['price'] = price_values

        # Add weather approximations as features
        region_df['approx_temperature'] = 20 + 10 * weather_factors['weather_factor_1']
        region_df['approx_humidity'] = 60 + 15 * weather_factors['weather_factor_2']
        region_df['approx_wind'] = 5 + 5 * weather_factors['weather_factor_3']
        region_df['approx_solar'] = np.maximum(0, 500 + 300 * weather_factors['weather_factor_4'])

        # Add calendar features
        for col in ['hour', 'day_of_week', 'month', 'is_weekend', 'is_holiday', 'dst_transition']:
            region_df[col] = calendar_features[col]

        # Add planned outage indicator
        region_df['planned_outage'] = outages

        # Store the complete dataset for this region
        results[region_id] = region_df

    return results


def save_data(
    electricity_data: Dict[str, pd.DataFrame],
    weather_factors: pd.DataFrame,
    calendar_features: pd.DataFrame,
    data_dir: str = './data/electricity'
) -> None:
    """Write all generated tables to CSV files in one directory.

    Produces one '<region_id>_data.csv' per region, followed by
    'weather_factors.csv' and 'calendar_features.csv'. The directory is
    created when missing and existing files are overwritten.

    Args:
        electricity_data: Dictionary of region DataFrames
        weather_factors: PCA weather factors
        calendar_features: Calendar features
        data_dir: Directory to save files
    """
    os.makedirs(data_dir, exist_ok=True)

    # (file name, table) pairs in write order: regions first, shared tables last.
    outputs = [
        (f"{region_id}_data.csv", region_df)
        for region_id, region_df in electricity_data.items()
    ]
    outputs.append(("weather_factors.csv", weather_factors))
    outputs.append(("calendar_features.csv", calendar_features))

    for file_name, table in outputs:
        table.to_csv(os.path.join(data_dir, file_name))

    print(f"Data saved to directory: {data_dir}")


def plot_electricity_data(
    electricity_data: Dict[str, pd.DataFrame],
    weather_factors: pd.DataFrame,
    plot_dir: str = './plots/electricity',
    num_regions_to_plot: int = 2,
    sample_weeks: Optional[List[str]] = None
) -> None:
    """Plot electricity load data and its components.

    Saves four kinds of PNG files into ``plot_dir``: the daily-mean load per
    region, a weekly load decomposition per region and sample week, a
    price-vs-load scatter with a linear trend, and a July 2020 weather
    overview for the first region. Every figure is closed after it has been
    saved.

    Plotting is best effort: a missing import or any other error is reported
    with ``print`` instead of being raised. The function updates matplotlib's
    global style and rcParams and does not restore them.

    Weekly decomposition files are named by season, so two sample weeks in
    the same season write to the same file and the later one wins.

    Args:
        electricity_data: Dictionary of region DataFrames
        weather_factors: PCA weather factors (currently not used by this function)
        plot_dir: Directory to save plots
        num_regions_to_plot: Number of regions to include in plots
        sample_weeks: List of start dates for sample weeks to plot; defaults
            to one winter and one summer week of 2020
    """
    try:
        import matplotlib.pyplot as plt
        from matplotlib.dates import DateFormatter
        
        # Create directory if it doesn't exist
        os.makedirs(plot_dir, exist_ok=True)
        
        # Set up plotting style
        plt.style.use('seaborn-v0_8-whitegrid')
        plt.rcParams.update({
            'font.family': 'serif',
            'font.size': 10,
            'axes.labelsize': 12,
            'axes.titlesize': 14,
            'xtick.labelsize': 10,
            'ytick.labelsize': 10,
            'axes.grid': True,
            'grid.alpha': 0.3,
            'figure.figsize': (10, 6),
            'axes.spines.top': False,
            'axes.spines.right': False
        })
        
        # Define a colorblind-friendly palette
        colors = ['#0173B2', '#DE8F05', '#029E73', '#D55E00', '#CC78BC', 
                  '#CA9161', '#FBAFE4', '#949494', '#ECE133', '#56B4E9']
        
        # Sample weeks to plot (if not provided)
        if sample_weeks is None:
            sample_weeks = ['2020-01-13', '2020-07-13']  # Winter and summer week
        
        # Select regions to plot
        regions_to_plot = list(electricity_data.keys())[:num_regions_to_plot]
        
        # 1. Plot annual load profiles
        # ---------------------------
        annual_fig = plt.figure(figsize=(10, 6))
        
        for i, region_id in enumerate(regions_to_plot):
            region_df = electricity_data[region_id]
            
            # Resample to daily for better visualization
            daily_load = region_df['total_load'].resample('D').mean()
            
            plt.plot(daily_load.index, daily_load, 
                     label=region_id, color=colors[i % len(colors)], linewidth=1.5)
        
        plt.title('Annual Electricity Load Profile by Region')
        plt.xlabel('Date')
        plt.ylabel('Load (MW)')
        plt.legend()
        plt.tight_layout()
        
        annual_path = os.path.join(plot_dir, 'annual_load_profile.png')
        plt.savefig(annual_path, dpi=300)
        # Release the figure; pyplot keeps every unclosed figure alive.
        plt.close(annual_fig)
        
        # 2. Plot load components for each region
        # -------------------------------------
        for region_id in regions_to_plot:
            region_df = electricity_data[region_id]
            
            for week_start in sample_weeks:
                # Extract one week of data
                start_date = pd.Timestamp(week_start)
                end_date = start_date + pd.Timedelta(days=7)
                mask = (region_df.index >= start_date) & (region_df.index < end_date)
                week_data = region_df[mask]
                
                # Nothing to draw when the week lies outside the data
                if len(week_data) == 0:
                    continue
                
                # Determine season for the title
                month = start_date.month
                if 3 <= month <= 5:
                    season = "Spring"
                elif 6 <= month <= 8:
                    season = "Summer"
                elif 9 <= month <= 11:
                    season = "Fall"
                else:
                    season = "Winter"
                
                fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10), sharex=True, 
                                             gridspec_kw={'height_ratios': [3, 1]})
                
                # Plot load components on first axis
                ax1.plot(week_data.index, week_data['base_load'], 
                       label='Base Load', color=colors[0], linewidth=2)
                
                ax1.plot(week_data.index, week_data['total_load'], 
                       label='Total Load', color=colors[1], linewidth=2)
                
                # Add shaded areas for effects
                ax1.fill_between(
                    week_data.index, 
                    week_data['base_load'], 
                    week_data['base_load'] + week_data['weather_effect'],
                    color=colors[2], alpha=0.4, label='Weather Effect'
                )
                
                ax1.fill_between(
                    week_data.index, 
                    week_data['base_load'] + week_data['weather_effect'], 
                    week_data['total_load'],
                    color=colors[3], alpha=0.4, label='Other Effects'
                )
                
                # Add outage indicators if any
                if week_data['planned_outage'].sum() > 0:
                    outage_idx = week_data.index[week_data['planned_outage'] > 0]
                    ax1.scatter(outage_idx, week_data.loc[outage_idx, 'total_load'], 
                              color='red', marker='x', s=80, label='Planned Outage')
                
                # Format the plot
                ax1.set_title(f'{region_id} - {season} Week Load Decomposition ({week_start})')
                ax1.set_ylabel('Load (MW)')
                ax1.legend(loc='upper left')
                ax1.xaxis.set_major_formatter(DateFormatter('%a %H:%M'))
                
                # Plot price on second axis
                ax2.plot(week_data.index, week_data['price'], 
                       color=colors[4], linewidth=2)
                ax2.set_ylabel('Price ($/MWh)')
                ax2.set_xlabel('Date')
                
                plt.tight_layout()
                component_path = os.path.join(plot_dir, f'{region_id}_{season.lower()}_week_components.png')
                plt.savefig(component_path, dpi=300)
                plt.close(fig)
        
        # 3. Plot relationship between load and price
        # -----------------------------------------
        # Imported once here rather than on every loop iteration.
        from scipy.stats import linregress
        
        scatter_fig = plt.figure(figsize=(8, 6))
        
        for i, region_id in enumerate(regions_to_plot):
            region_df = electricity_data[region_id]
            
            # Plot load vs price
            plt.scatter(region_df['total_load'], region_df['price'], 
                      alpha=0.3, color=colors[i % len(colors)], 
                      label=region_id, s=10)
            
            # Add trend line
            slope, intercept, r_value, _, _ = linregress(region_df['total_load'], region_df['price'])
            x_range = np.linspace(region_df['total_load'].min(), region_df['total_load'].max(), 100)
            plt.plot(x_range, intercept + slope * x_range, 
                   '--', color=colors[i % len(colors)], 
                   label=f'{region_id} trend (r²={r_value**2:.2f})')
        
        plt.title('Electricity Price vs. Load')
        plt.xlabel('Load (MW)')
        plt.ylabel('Price ($/MWh)')
        plt.legend()
        plt.tight_layout()
        
        price_load_path = os.path.join(plot_dir, 'price_vs_load.png')
        plt.savefig(price_load_path, dpi=300)
        plt.close(scatter_fig)
        
        # 4. Plot weather factors influence
        # -------------------------------
        # Pick the region before creating the figure, so an empty selection
        # fails without leaving a figure open.
        region_id = regions_to_plot[0]  # Use first region for this plot
        region_df = electricity_data[region_id]
        
        fig, axes = plt.subplots(2, 2, figsize=(12, 10), sharex=True)
        axes = axes.flatten()
        
        # Extract a month of data for clearer visualization
        month_start = pd.Timestamp('2020-07-01')
        month_end = pd.Timestamp('2020-07-31')
        mask = (region_df.index >= month_start) & (region_df.index <= month_end)
        month_data = region_df[mask]
        
        # Resample to 6-hour intervals for clearer plots ('6h' is the
        # non-deprecated spelling of the former '6H' alias)
        month_data_6h = month_data.resample('6h').mean()
        
        # Plot load vs temperature
        axes[0].scatter(month_data['approx_temperature'], month_data['total_load'], 
                      alpha=0.5, color=colors[0], s=30)
        axes[0].set_title('Load vs. Temperature')
        axes[0].set_xlabel('Temperature (°C)')
        axes[0].set_ylabel('Load (MW)')
        
        # Plot price vs load
        axes[1].scatter(month_data['total_load'], month_data['price'], 
                      alpha=0.5, color=colors[1], s=30)
        axes[1].set_title('Price vs. Load')
        axes[1].set_xlabel('Load (MW)')
        axes[1].set_ylabel('Price ($/MWh)')
        
        # Plot temperature and load time series
        ax3 = axes[2]
        ax3.plot(month_data_6h.index, month_data_6h['total_load'], 
               color=colors[0], label='Load (MW)')
        
        ax3_twin = ax3.twinx()
        ax3_twin.plot(month_data_6h.index, month_data_6h['approx_temperature'], 
                    color=colors[2], label='Temperature (°C)', linestyle='--')
        
        ax3.set_title('Load and Temperature Over Time')
        ax3.set_xlabel('Date')
        ax3.set_ylabel('Load (MW)', color=colors[0])
        ax3_twin.set_ylabel('Temperature (°C)', color=colors[2])
        
        # Add both legends
        lines1, labels1 = ax3.get_legend_handles_labels()
        lines2, labels2 = ax3_twin.get_legend_handles_labels()
        ax3.legend(lines1 + lines2, labels1 + labels2, loc='upper left')
        
        # Plot price time series
        ax4 = axes[3]
        ax4.plot(month_data_6h.index, month_data_6h['price'], 
               color=colors[1], label='Price ($/MWh)')
        
        ax4.set_title('Price Over Time')
        ax4.set_xlabel('Date')
        ax4.set_ylabel('Price ($/MWh)')
        
        plt.tight_layout()
        weather_influence_path = os.path.join(plot_dir, 'weather_influence.png')
        plt.savefig(weather_influence_path, dpi=300)
        plt.close(fig)
        
        print(f"Plots saved to directory: {plot_dir}")
        
    except ImportError:
        print("Matplotlib not available for visualization")
    except Exception as e:
        # Deliberate best effort: plotting problems must not abort the caller.
        print(f"Error during plotting: {str(e)}")


def convert_to_timeseries_df(electricity_data: Dict[str, pd.DataFrame]) -> pd.DataFrame:
    """Stack per-region DataFrames into one long-format time series table.

    Each region frame has its index turned into a ``timestamp`` column and
    gets a ``series_id`` column holding the region key. The stacked result
    is ordered as: ``timestamp``, ``series_id``, the targets
    (``total_load``, ``price``), then every remaining column.

    Args:
        electricity_data: Mapping of region id to that region's DataFrame,
            indexed by timestamp. Each frame must contain ``total_load``
            and ``price`` columns. The input frames are not modified.

    Returns:
        An AutoGluon ``TimeSeriesDataFrame`` when ``autogluon.timeseries``
        can be imported; otherwise a plain pandas DataFrame sorted by
        ``series_id`` and ``timestamp``.

    Raises:
        ValueError: If ``electricity_data`` is empty.
        KeyError: If a target column is missing from the combined data.
    """
    if not electricity_data:
        # pd.concat would raise an opaque "No objects to concatenate" error.
        raise ValueError("electricity_data is empty; nothing to convert")

    region_frames = []
    for region_id, region_df in electricity_data.items():
        # Name the index axis explicitly so reset_index() always yields a
        # 'timestamp' column, whatever the index was called before (an
        # unnamed index would become 'index', a named one keeps its name).
        # rename_axis/reset_index return new objects, so the caller's
        # DataFrame is left untouched.
        frame = region_df.rename_axis('timestamp').reset_index()
        frame['series_id'] = region_id
        region_frames.append(frame)

    # Concatenate all regions into a single DataFrame
    combined_df = pd.concat(region_frames, ignore_index=True)

    # Reorder columns: timestamp, series_id, targets, features
    cols = ['timestamp', 'series_id']
    target_cols = ['total_load', 'price']
    leading = cols + target_cols
    feature_cols = [col for col in combined_df.columns if col not in leading]
    combined_df = combined_df[leading + feature_cols]

    # Convert to TimeSeriesDataFrame if autogluon.timeseries is available
    try:
        from autogluon.timeseries import TimeSeriesDataFrame
        ts_df = TimeSeriesDataFrame.from_data_frame(
            df=combined_df,
            id_column='series_id',
            timestamp_column='timestamp'
        )
        print(f"Converted to TimeSeriesDataFrame with {len(ts_df.item_ids)} series")
        return ts_df
    except ImportError:
        # ModuleNotFoundError is a subclass of ImportError, so this covers both.
        # Fall back to regular pandas DataFrame in the right format
        print("AutoGluon not available, returning standard pandas DataFrame in time series format")
        combined_df = combined_df.sort_values(['series_id', 'timestamp'])
        return combined_df


def summarize_covariates(timeseries_df: pd.DataFrame) -> None:
    """Print a grouped overview of the columns in a time series frame.

    Columns are matched by name against fixed lists of targets, metadata,
    real-valued, categorical and load-component covariates; any column that
    matches none of these lists is reported under "Other Covariates".

    Args:
        timeseries_df: Frame whose ``columns`` are summarized. Only the
            column labels are inspected, not the data.
    """
    columns = timeseries_df.columns

    def _present(candidates):
        # Keep only the candidate names that are actual columns.
        return [name for name in candidates if name in columns]

    def _print_group(heading, names):
        # One heading line followed by one bullet per column name.
        print(heading)
        for name in names:
            print(f"  - {name}")

    targets = ['total_load', 'price']

    metadata = ['timestamp', 'series_id']
    if 'item_id' in columns:
        metadata.append('item_id')

    # Load-component columns; they also appear in the real-valued list below.
    components = _present([
        'base_load', 'calendar_effect', 'weather_effect', 'outage_effect'
    ])

    # Continuous covariates
    real_valued = _present([
        'approx_temperature', 'approx_humidity', 'approx_wind', 'approx_solar',
        'weather_factor_1', 'weather_factor_2', 'weather_factor_3', 'weather_factor_4',
        'weather_factor_5', 'base_load', 'calendar_effect', 'weather_effect',
        'outage_effect', 'hour_sin', 'hour_cos', 'day_of_week_sin', 'day_of_week_cos',
        'day_of_year_sin', 'day_of_year_cos'
    ])

    # Discrete / flag covariates
    categorical = _present([
        'hour', 'day_of_week', 'month', 'day_of_year', 'year',
        'is_weekend', 'is_holiday', 'dst_transition', 'is_school_period',
        'planned_outage'
    ])

    # Components are reported in their own section, so they are left out of
    # the real-valued listing and count.
    real_only = [name for name in real_valued if name not in components]
    present_targets = _present(targets)

    # Everything not claimed by one of the groups above.
    classified = set(targets + metadata + real_valued + categorical + components)
    leftovers = [name for name in columns if name not in classified]

    rule = "=" * 80
    print("\n" + rule)
    print("COVARIATES SUMMARY")
    print(rule)

    _print_group("\nTarget Variables:", present_targets)
    _print_group("\nReal-Valued (Continuous) Covariates:", real_only)
    _print_group("\nCategorical Covariates:", categorical)
    _print_group("\nComponent Covariates:", components)
    if leftovers:
        _print_group("\nOther Covariates:", leftovers)

    # Counts per group; "Total features" is the number of columns in the frame.
    print("\nDimensionality Summary:")
    print(f"  - Total features: {len(columns)}")
    print(f"  - Real-valued covariates: {len(real_only)}")
    print(f"  - Categorical covariates: {len(categorical)}")
    print(f"  - Component covariates: {len(components)}")
    print(f"  - Target variables: {len(present_targets)}")

    print("\nFuture-Known vs. Past-Only Covariates:")
    print("  - Future-known: weather factors, calendar features, planned outages")
    print("  - Past-only: None (all covariates in this synthetic dataset are future-known)")
    print(rule)


def main(
    num_regions: int = 5,
    num_weather_stations: int = 10,
    start_date: str = '2020-01-01',
    end_date: str = '2021-12-31',
    data_dir: str = './data/electricity',
    plot_dir: str = './plots/electricity'
) -> None:
    """Generate synthetic electricity load/price data conditioned on covariates.

    Runs the full pipeline: timestamps, region parameters, weather data and
    PCA factors, calendar features, planned outages, load/price generation,
    saving, plotting, long-format conversion and a covariate summary.
    Progress is reported with ``print``.

    Args:
        num_regions: Number of regions to generate
        num_weather_stations: Number of weather stations to simulate
        start_date: Start date for simulation
        end_date: End date for simulation
        data_dir: Directory to save data files; created if it does not exist
        plot_dir: Directory to save plots
    """
    print(f"Generating synthetic electricity data from {start_date} to {end_date}")

    # 1. Generate dates for the simulation
    dates = generate_dates(start_date, end_date)
    print(f"Generated {len(dates)} timestamps")

    # 2. Generate region parameters
    region_params = generate_region_params(num_regions)
    print(f"Generated parameters for {len(region_params)} regions")

    # 3. Generate weather station data
    weather_data = generate_weather_data(dates, num_stations=num_weather_stations)
    print(f"Generated weather data for {len(weather_data)} stations")

    # 4. Generate PCA weather factors
    weather_factors = generate_pca_weather_factors(weather_data)
    print(f"Extracted {weather_factors.shape[1]} weather PCA factors")

    # 5. Generate calendar features
    calendar_features = generate_calendar_features(dates)
    print(f"Generated {calendar_features.shape[1]} calendar features")

    # 6. Generate planned outages
    planned_outages = generate_planned_outages(dates, num_regions)
    print(f"Generated planned outage schedules for {len(planned_outages)} regions")

    # 7. Generate electricity load and price data
    electricity_data = generate_electricity_load(
        region_params, dates, weather_factors, calendar_features, planned_outages
    )
    print(f"Generated electricity load and price data for {len(electricity_data)} regions")

    # 8. Save the data
    save_data(electricity_data, weather_factors, calendar_features, data_dir)

    # 9. Plot the data
    plot_electricity_data(electricity_data, weather_factors, plot_dir)

    # 10. Convert to TimeSeriesDataFrame
    timeseries_df = convert_to_timeseries_df(electricity_data)

    # 11. Save the timeseries dataframe
    # Do not rely on an earlier step having created data_dir; makedirs with
    # exist_ok=True is a no-op when the directory is already there. An empty
    # data_dir means "current directory" for os.path.join, so skip it.
    if data_dir:
        os.makedirs(data_dir, exist_ok=True)
    timeseries_path = os.path.join(data_dir, "electricity_timeseries.csv")
    # Both possible results (pandas DataFrame or AutoGluon TimeSeriesDataFrame,
    # a DataFrame subclass) provide to_csv, so no fallback branch is needed.
    timeseries_df.to_csv(timeseries_path)
    print(f"Saved time series data to {timeseries_path}")

    # 12. Print covariate summary
    summarize_covariates(timeseries_df)

    print("Done!")


if __name__ == "__main__":
    import argparse

    # Command-line entry point: every flag corresponds to one keyword
    # argument of main().
    cli = argparse.ArgumentParser(
        description='Generate synthetic electricity load/price data'
    )
    cli.add_argument(
        '--num-regions',
        type=int,
        default=5,
        help='Number of regions to generate',
    )
    cli.add_argument(
        '--num-weather-stations',
        type=int,
        default=10,
        help='Number of weather stations to simulate',
    )
    cli.add_argument(
        '--start-date',
        type=str,
        default='2020-01-01',
        help='Start date for simulation (YYYY-MM-DD)',
    )
    cli.add_argument(
        '--end-date',
        type=str,
        default='2021-12-31',
        help='End date for simulation (YYYY-MM-DD)',
    )
    # NOTE(review): these two CLI defaults differ from the defaults declared
    # in main()'s signature — confirm which location is intended.
    cli.add_argument(
        '--data-dir',
        type=str,
        default='./data/electricity/data',
        help='Directory to save data files',
    )
    cli.add_argument(
        '--plot-dir',
        type=str,
        default='./data/electricity/plots',
        help='Directory to save plots',
    )

    options = cli.parse_args()

    # argparse stores '--num-regions' as the attribute 'num_regions' (dashes
    # become underscores), so the namespace maps one-to-one onto main()'s
    # keyword arguments.
    main(**vars(options))