import os
import sys
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn


# ==============================================================================
# 0. Setup
# ==============================================================================
# Run on the GPU when one is visible to torch, otherwise on the CPU.
_cuda_ok = torch.cuda.is_available()
device = torch.device("cuda") if _cuda_ok else torch.device("cpu")
print(f"Using device: {device}")

# Define model structure (must match training)
class MLP(nn.Module):
    """Fully connected network: Linear -> ReLU -> Linear -> ReLU -> Linear.

    The stack is stored in ``self.layers`` (an ``nn.Sequential``), so the
    parameter names are ``layers.0.*``, ``layers.2.*`` and ``layers.4.*``.
    This layout must not change, otherwise checkpoints saved from the
    training-time definition will no longer load.

    Args:
        input_size: number of input features.
        output_size: number of outputs produced by the final layer.
        hidden_size: width of both hidden layers (default 64).
    """

    def __init__(self, input_size, output_size, hidden_size=64):
        super().__init__()
        dims = [input_size, hidden_size, hidden_size, output_size]
        stack = []
        for idx, (fan_in, fan_out) in enumerate(zip(dims, dims[1:])):
            # An activation sits between consecutive Linear layers, never
            # after the output layer.
            if idx > 0:
                stack.append(nn.ReLU())
            stack.append(nn.Linear(fan_in, fan_out))
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the layer stack to ``x`` (last dimension must be ``input_size``)."""
        return self.layers(x)

# ==============================================================================
# 1. Load Pre-trained Models
# ==============================================================================
print("\n--- Loading pre-trained neural network models ---")


def _load_pretrained_mlp(path, output_size, label):
    """Build a 4-input ``MLP`` on ``device`` and load its weights from ``path``.

    Args:
        path: checkpoint file containing a state_dict for ``MLP``.
        output_size: number of outputs the checkpointed network has.
        label: capitalized model name used in log messages
            (e.g. "Transition", "Revenue").

    Returns:
        The loaded model, switched to eval mode.

    If ``path`` does not exist, an error is printed to stderr and the script
    terminates with exit status 1, so shells/schedulers see the failure
    (a bare ``exit()`` would report status 0).
    """
    model = MLP(input_size=4, output_size=output_size).to(device)
    try:
        # weights_only=True limits unpickling to tensors and plain containers.
        state_dict = torch.load(path, map_location=device, weights_only=True)
    except FileNotFoundError:
        print(f"Error: {label} model file '{path}' not found.", file=sys.stderr)
        sys.exit(1)
    model.load_state_dict(state_dict)
    model.eval()
    print(f"Successfully loaded {label.lower()} model from '{path}'.")
    return model


# Load transition model (4 inputs -> 2 outputs)
TRANSITION_MODEL_PATH = 'checkpoints/transition_model_with_epoch_100.pth'
transition_model = _load_pretrained_mlp(TRANSITION_MODEL_PATH, output_size=2, label="Transition")

# Load revenue model (4 inputs -> 1 output)
REVENUE_MODEL_PATH = 'checkpoints/revenue_model_with_epoch_100.pth'
revenue_model = _load_pretrained_mlp(REVENUE_MODEL_PATH, output_size=1, label="Revenue")

# ==============================================================================
# 2. Load Data for Analysis
# ==============================================================================
print("\n--- Loading analysis data ---")
DATA_FILENAME = 'simulation.csv'
try:
    df_analysis = pd.read_csv(DATA_FILENAME)
except FileNotFoundError:
    print(f"Warning: Data file '{DATA_FILENAME}' not found.")
    print("Using randomly generated synthetic data for demonstration.")
    # Fallback demo data: uniform [0, 1) values for every column of the
    # expected schema.
    _synthetic_columns = ['orders', 'drivers', 'A', 'T', 'ordersNext', 'driversNext',
                          'revenue', 'simu_time', 'n']
    _n_synthetic = 20000
    df_analysis = pd.DataFrame(np.random.rand(_n_synthetic, len(_synthetic_columns)),
                               columns=_synthetic_columns)
    # Integer timesteps in [0, 20).
    df_analysis['T'] = np.random.randint(0, 20, size=_n_synthetic)
    # Rescale state counts and revenue to [0, 50).
    df_analysis[['orders', 'drivers', 'ordersNext', 'driversNext', 'revenue']] *= 50
else:
    print(f"Successfully loaded {len(df_analysis)} rows from '{DATA_FILENAME}'.")

# ==============================================================================
# 3. Core Analysis
# ==============================================================================
# Restrict to the initial timestep (T == 0).
is_t0 = df_analysis['T'] == 0
df_t0 = df_analysis.loc[is_t0].copy()
print(f"At T=0, there are {len(df_t0)} records.")

t0_orders = df_t0['orders']

print("\n--- 1. Descriptive Statistics ---")
order_stats_t0 = t0_orders.describe()
print(order_stats_t0)

print("\n--- 2. Frequency Distribution ---")
# Counts and normalized proportions of each distinct 'orders' value, sorted by value.
# NOTE(review): order_counts_t0 is not used anywhere else in this script.
order_counts_t0 = t0_orders.value_counts().sort_index()
order_proportions_t0 = t0_orders.value_counts(normalize=True).sort_index()

# Flatten to a two-column table. The column names are set explicitly rather
# than relying on whatever names reset_index() produces.
output_filename_distri = "initial_state_distribution.csv"
order_proportions_df = order_proportions_t0.reset_index()
order_proportions_df.columns = ['orders', 'probability']
order_proportions_df.to_csv(output_filename_distri, index=False)

print(f"\n--- Distribution data successfully saved to: {output_filename_distri} ---")

# ==============================================================================
# 3b. Residual Analysis (per-timestep model residuals)
# ==============================================================================
residual_analysis_results = []
# Sorted distinct timesteps (kept as a module-level name for downstream use).
timesteps = sorted(df_analysis['T'].unique())

# Model inputs and targets never change between timesteps, so define them once.
# Assumes both models were trained on these 4 columns in this order — TODO confirm
# against the training script.
features = ['orders', 'drivers', 'A', 'T']
revenue_targets = ['revenue']
transition_targets = ['ordersNext', 'driversNext']

print("\n--- Computing residuals for each timestep ---")


def _frame_to_tensor(frame):
    """Convert a DataFrame selection to a float32 tensor on ``device``."""
    return torch.tensor(frame.values, dtype=torch.float32).to(device)


def _residual_summary(t, df_t):
    """Summarize model residuals (true - predicted) for the rows of one timestep.

    Args:
        t: the timestep value shared by every row of ``df_t``.
        df_t: the rows of ``df_analysis`` whose ``T`` equals ``t``.

    Returns:
        dict with key ``'T'`` plus ``*_resid_mean`` / ``*_resid_var`` entries for
        revenue, ordersNext and driversNext. Variances use ``Tensor.var()``'s
        default unbiased (N - 1) estimator.
    """
    X_t = _frame_to_tensor(df_t[features])
    Y_revenue_true_t = _frame_to_tensor(df_t[revenue_targets])
    Y_transition_true_t = _frame_to_tensor(df_t[transition_targets])

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        Y_revenue_pred_t = revenue_model(X_t)
        Y_transition_pred_t = transition_model(X_t)

    residuals_revenue = Y_revenue_true_t - Y_revenue_pred_t
    residuals_transition = Y_transition_true_t - Y_transition_pred_t
    orders_resid = residuals_transition[:, 0]
    drivers_resid = residuals_transition[:, 1]

    return {
        'T': t,
        'revenue_resid_mean': residuals_revenue.mean().item(),
        'revenue_resid_var': residuals_revenue.var().item(),
        'ordersNext_resid_mean': orders_resid.mean().item(),
        'ordersNext_resid_var': orders_resid.var().item(),
        'driversNext_resid_mean': drivers_resid.mean().item(),
        'driversNext_resid_var': drivers_resid.var().item(),
    }


# A single groupby pass replaces one full boolean-mask scan per timestep.
# Groups are visited in sorted key order and keep their original row order,
# so results match per-timestep filtering. Rows with NaN T are dropped by
# groupby, and they could never match `== t` anyway.
for t, df_t in df_analysis.groupby('T', sort=True):
    if len(df_t) < 2:
        continue  # an unbiased variance needs at least two samples
    residual_analysis_results.append(_residual_summary(t, df_t))

print("Residual analysis completed.")

# ==============================================================================
# 4. Reporting
# ==============================================================================
# One row per timestep, one column per residual statistic.
residuals_df_nn = pd.DataFrame(residual_analysis_results)

_rule = "=" * 50
print("\n" + _rule)
print("--- Residual Means and Variances at Each Timestep ---")
print(_rule)
# to_string() prints every row instead of pandas' truncated repr.
print(residuals_df_nn.to_string())

output_filename = "nn_residual_analysis.csv"
residuals_df_nn.to_csv(output_filename, index=False)
print(f"\nResidual analysis results successfully saved to: {output_filename}")