import pandas as pd
from autogluon.timeseries import TimeSeriesDataFrame
from residual_chronos.Regressor import TimeSeriesRegressor


def _linear_rows(item_id, dates, start, step):
    """Return one record per timestamp with target ``start + i * step``.

    Each record is a dict with the keys ``item_id``, ``timestamp`` and
    ``target``, where ``i`` is the zero-based position of the timestamp
    in ``dates``.
    """
    return [
        {'item_id': item_id, 'timestamp': date, 'target': start + i * step}
        for i, date in enumerate(dates)
    ]


def _fit_auto_arima(train_data, prediction_length):
    """Build an AutoARIMA ``TimeSeriesRegressor`` and fit it on ``train_data``.

    Returns whatever ``TimeSeriesRegressor.fit`` returns; the caller uses
    it as the fitted regressor (``predict_all`` / ``prediction_length``).

    For debugging with Chronos instead, use ``model_name="Chronos"``,
    ``hyperparameters={"Chronos": {'fine_tune': False}}`` and
    ``time_limit=60``.
    """
    return TimeSeriesRegressor(
        model_name="AutoARIMA",
        prediction_length=prediction_length,
        target="target",
    ).fit(
        train_data,
        hyperparameters={"AutoARIMA": {}},
        random_seed=42,
        time_limit=30,
    )


def _assert_prediction_count(result, expected):
    """Assert that ``result['mean']`` holds exactly ``expected`` non-null values.

    The failure message is derived from ``expected`` so that the reported
    expectation always matches the value actually being checked.
    """
    actual = result['mean'].count()
    assert actual == expected, f"Expected {expected} predictions, got {actual}"


def create_special_test_cases():
    """Create and test special time series cases to test robustness.

    Two scenarios are built from synthetic linear series, each fitted with
    AutoARIMA and checked through ``predict_all``:

    * Case 1: two items of equal length (10 daily points) covering
      different date ranges.
    * Case 2: two items sharing a start date but with different lengths
      (7 vs 12 daily points).

    Raises:
        AssertionError: if a ``predict_all`` result does not contain the
            expected number of non-null ``mean`` values, or (Case 1, full
            window) its index differs from the input index.
    """
    print("\n\n=============== SPECIAL TEST CASES ===============\n")

    # Case 1: Two items with same length (10) but different date ranges
    print("\n--- Case 1: Same length, different date ranges ---")
    dates1 = pd.date_range('2020-01-01', periods=10, freq='D')
    dates2 = pd.date_range('2020-02-01', periods=10, freq='D')

    # Two simple, distinct linear trends so the items are distinguishable.
    data1 = (
        _linear_rows('item_A', dates1, start=100, step=10)
        + _linear_rows('item_B', dates2, start=200, step=5)
    )
    test_tsdf1 = TimeSeriesDataFrame(pd.DataFrame(data1))

    # Short prediction length for the first scenario.
    pred1 = _fit_auto_arima(test_tsdf1, prediction_length=2)

    # Test with full window
    print("\n--- Full window predictions ---")
    fitted1 = pred1.predict_all(
        test_tsdf1, horizon=pred1.prediction_length, min_ctx=5, window_size=None
    )
    _assert_prediction_count(fitted1, 10)
    assert fitted1.index.equals(test_tsdf1.index), "Index of fitted1 and test_tsdf1 must be the same"

    print("\n--- Full window predictions with min_ctx=1 ---")
    fitted2 = pred1.predict_all(
        test_tsdf1, horizon=pred1.prediction_length, min_ctx=1, window_size=None
    )
    # The original message claimed "Expected 10" while checking for 18.
    _assert_prediction_count(fitted2, 18)
    assert fitted2.index.equals(test_tsdf1.index), "Index of fitted2 and test_tsdf1 must be the same"

    # Test with limited window
    print("\n--- Limited window predictions (window_size=3) ---")
    fitted1_limited = pred1.predict_all(
        test_tsdf1, horizon=pred1.prediction_length, min_ctx=5, window_size=3
    )
    _assert_prediction_count(fitted1_limited, 6)

    # Case 2: Two items with same start date but different lengths (7 vs 12)
    print("\n--- Case 2: Same start date, different lengths ---")
    dates_common = pd.date_range('2020-03-01', periods=7, freq='D')
    dates_longer = pd.date_range('2020-03-01', periods=12, freq='D')

    data2 = (
        _linear_rows('short_item', dates_common, start=50, step=8)
        + _linear_rows('long_item', dates_longer, start=150, step=7)
    )
    test_tsdf2 = TimeSeriesDataFrame(pd.DataFrame(data2))

    # Medium prediction length for the second scenario.
    pred2 = _fit_auto_arima(test_tsdf2, prediction_length=3)

    print("\n--- Testing with different horizon values ---")
    # Same window, first with the fitted horizon and then with horizon=1.
    result2 = pred2.predict_all(
        test_tsdf2, horizon=pred2.prediction_length, min_ctx=3, window_size=5
    )
    print(f"Horizon=3, Window=5 shape: {result2.shape}")
    _assert_prediction_count(result2, 10)

    result2_h1 = pred2.predict_all(test_tsdf2, horizon=1, min_ctx=3, window_size=5)
    print(f"Horizon=1, Window=5 shape: {result2_h1.shape}")
    _assert_prediction_count(result2_h1, 10)

# Run the special test cases.
# NOTE: this call is unconditional (no `if __name__ == "__main__":` guard),
# so the cases also execute whenever this module is imported.
create_special_test_cases()
