from models import get_model
from utils.dataloader_3d import OptimizedCFDDataset, build_dataloader
from utils import Config
import torch
import numpy as np
import argparse
import yaml
import matplotlib.pyplot as plt


def _trajectory_rows(indices, traj_id):
    """Return the row numbers of ``indices`` that belong to ``traj_id``, ordered by time id.

    ``indices`` is indexed as a 2-D array whose column 0 is the trajectory id and
    column 1 the time id. The sort is stable, so rows sharing a time id keep their
    original dataset order (same ordering as ``list.sort`` on ``(row, time_id)``).
    """
    rows = np.flatnonzero(indices[:, 0] == traj_id)
    return rows[np.argsort(indices[rows, 1], kind="stable")]


def _step_mse(prediction, target):
    """Return ``(overall, velocity, pressure, density)`` MSE for one rollout step.

    Velocity is the slice ``[:, :3]``, pressure ``[:, 3]`` and density ``[:, 4]``.
    NOTE(review): this assumes the physical fields live on dim 1 of the per-sample
    tensors (e.g. ``(points, fields)``) — confirm against OptimizedCFDDataset.
    """
    sq_err = (prediction - target) ** 2
    return (
        torch.mean(sq_err).item(),
        torch.mean(sq_err[:, :3]).item(),
        torch.mean(sq_err[:, 3]).item(),
        torch.mean(sq_err[:, 4]).item(),
    )


def rollout(config: Config, k: int):
    """Autoregressively roll the model out over the validation trajectories and report MSE.

    For each validation trajectory, the input of its time-0 sample is fed to the
    model, and every prediction is then fed back in for ``trajectory_length``
    steps. Each prediction is compared with the dataset target. Per-step MSEs are
    averaged over the steps of a trajectory, then over all trajectories.

    Args:
        config: Run configuration; reads ``device``, ``model.name``,
            ``model.params`` and ``checkpoint_path``.
        k: Passed as the second positional argument of the model's forward call.

    Returns:
        dict with keys ``"mse"`` (all fields), ``"velocity_mse"``,
        ``"pressure_mse"`` and ``"density_mse"``, each averaged over all
        trajectories. The three per-field averages are also printed.
    """
    device = f"cuda:{config.device}" if torch.cuda.is_available() else "cpu"
    print(f"Using {device}.")

    # Build the model and load its weights.
    model_class = get_model(config.model.name)
    model = model_class(**config.model.params.to_dict()).to(device)
    # NOTE(review): torch.load unpickles the file; only point this at trusted
    # checkpoints (weights_only=True is safer on torch versions that support it).
    model.load_state_dict(torch.load(config.checkpoint_path, map_location=device))
    model.eval()

    # Fixed set of validation trajectory ids.
    val_indices = np.array([476, 105, 389, 1, 558, 80, 205, 34, 508, 427, 454, 366, 91, 339, 345,
                            241, 13, 315, 387, 273, 166, 594, 484, 585, 504, 243, 562, 189, 475,
                            510, 58, 474, 560, 252, 21, 313, 459, 160, 276, 191, 385, 413, 491,
                            343, 308, 130, 99, 372, 87, 458, 330, 214, 466, 121, 20, 71, 106,
                            270, 435, 102], dtype=np.int32)

    dataset = OptimizedCFDDataset(val_indices)
    # Converted once so per-trajectory lookups are vectorized numpy operations
    # instead of a Python loop over every dataset row for every trajectory.
    indices = np.asarray(dataset.indices)

    trajectory_length = 19  # Since we have 20 timesteps (0-19)

    # One entry per trajectory: (overall, velocity, pressure, density) MSE,
    # each already averaged over that trajectory's rollout steps.
    per_trajectory = []

    with torch.no_grad():
        for traj_id in val_indices:
            rows = _trajectory_rows(indices, traj_id)

            # Initial condition: the input of this trajectory's time-0 sample.
            initial_idx = np.where((indices[:, 0] == traj_id) & (indices[:, 1] == 0))[0][0]
            initial_data, _ = dataset[initial_idx]
            current_state = initial_data.to(device)

            step_errors = []
            for t in range(trajectory_length):
                # Feed the model its own previous prediction.
                current_state = model(current_state.unsqueeze(0), k).squeeze(0)

                # NOTE(review): ground truth is the *target* of the sample whose
                # time id is t + 1. If dataset[i] yields (state_t, state_{t+1}),
                # that target is state_{t+2}, one step ahead of this prediction.
                # Confirm the alignment against OptimizedCFDDataset.
                _, ground_truth = dataset[int(rows[t + 1])]
                step_errors.append(_step_mse(current_state.cpu(), ground_truth))

            per_trajectory.append(np.mean(step_errors, axis=0))

    # Average the per-trajectory means over all trajectories.
    (overall_avg_mse, overall_avg_vel_mse,
     overall_avg_pressure_mse, overall_avg_density_mse) = np.mean(per_trajectory, axis=0)

    print(f"\nOverall Average Velocity MSE across all trajectories: {overall_avg_vel_mse:.6f}")
    print(f"\nOverall Average Pressure MSE across all trajectories: {overall_avg_pressure_mse:.6f}")
    print(f"\nOverall Average Density MSE across all trajectories: {overall_avg_density_mse:.6f}")

    return {
        "mse": float(overall_avg_mse),
        "velocity_mse": float(overall_avg_vel_mse),
        "pressure_mse": float(overall_avg_pressure_mse),
        "density_mse": float(overall_avg_density_mse),
    }
    

def main(config):
    """Run the validation rollout once per k, for k = 1, 2, ..., 16 in order.

    Each run prints a ``k: <value>`` header before delegating to ``rollout``.
    """
    for k in range(1, 17):
        print(f"k: {k}")
        rollout(config, k)


if __name__ == "__main__":
    # Command-line interface: a YAML config, an optional checkpoint override,
    # and any number of KEY=VAL config overrides.
    parser = argparse.ArgumentParser()
    # Required: path to the YAML configuration file (enforced by argparse so a
    # missing flag yields a usage error instead of a crash inside Config).
    parser.add_argument("--config", type=str, required=True, help="Path to config file.")
    parser.add_argument("--checkpoint_path", type=str, help="Path to model checkpoint.")
    parser.add_argument("--set", metavar="KEY=VAL", action="append",
                        help="Override any config entry, e.g. --set model.params.activation=relu")

    args = parser.parse_args()

    # Load configuration from YAML file
    config = Config.from_yaml(args.config)

    # A checkpoint path given on the command line overrides the config's value.
    if args.checkpoint_path:
        config.checkpoint_path = args.checkpoint_path

    # Apply --set overrides. Values go through yaml.safe_load so "3", "true",
    # "[1, 2]" become typed values, without arbitrary object construction.
    for item in args.set or []:
        key, sep, raw = item.partition("=")
        if not sep or not key:
            parser.error(f"--set expects KEY=VAL, got {item!r}")
        config.set(key, yaml.safe_load(raw))

    # Start the rollout process with the configured settings
    main(config)