import argparse
import inspect
import io
import os
import random  # Add random module

import matplotlib.pyplot as plt
import numpy as np
import pyarrow.parquet as pq
import torch
import wandb

from models import MODEL_REGISTRY  # Import model registry
from utils import *  # Import all utilities

def set_seed(seed):
    """Seed every random number generator used in training and pin cuDNN.

    Args:
        seed: Integer seed shared by Python's ``random``, NumPy and PyTorch
            (CPU, current GPU and all GPUs).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    # Request deterministic cuDNN kernels and disable benchmark-mode kernel
    # selection, which can differ from run to run.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

'''
Data
'''
def get_data(path):
    """Load the arrays stored in the ``'u'`` column of a parquet file.

    Each cell of the ``'u'`` column holds one serialized NumPy array. All cells
    are deserialized, stacked along a new leading axis and cast to float32.

    Args:
        path: Path to the parquet file.

    Returns:
        A float32 ``torch.Tensor`` whose first dimension indexes the rows of the
        file. Downstream code treats it as (n_trajectories, seq_len, nx) —
        TODO confirm against the dataset.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"Data file not found: {path}.")

    # Read only the column that is used instead of materializing the whole table.
    column = pq.read_table(path, columns=['u'])['u']
    # Each cell is an in-memory .npy blob; np.load deserializes it.
    arrays = [np.load(io.BytesIO(cell.as_buffer())) for cell in column]
    data = np.stack(arrays).astype(np.float32)

    # from_numpy wraps the freshly created array without the extra copy that
    # torch.tensor() would make.
    return torch.from_numpy(data)

def process_data(data, time_jump=1, residual=False, norm_res=1):
    """Build one-step (input, target) training pairs from trajectories.

    Every ``time_jump``-th frame along dim 1 is kept. Each pair of consecutive
    kept frames then becomes one independent sample.

    Args:
        data: Tensor indexed as ``data[:, t, :]``; presumably
            (n_trajectories, seq_len, nx) as produced by ``get_data`` — TODO confirm.
        time_jump: Stride used to subsample the time axis.
        residual: If True, the target is the increment between the two frames
            divided by ``norm_res``; otherwise it is the next frame itself.
        norm_res: Divisor applied to residual targets.

    Returns:
        ``(inputs, targets)``, each of shape (n_trajectories * (T' - 1), 1, nx),
        where T' is the number of kept frames. Sample ``b * (T' - 1) + t`` pairs
        kept frame ``t`` of trajectory ``b`` with kept frame ``t + 1``.
    """
    kept = list(range(0, data.shape[1], time_jump))
    # List (advanced) indexing copies, so the pairs below never alias `data`.
    subsampled = data[:, kept, :]

    # unfold -> (B, T'-1, nx, 2); move the window axis ahead of space -> (B, T'-1, 2, nx).
    windows = subsampled.unfold(1, 2, 1).permute(0, 1, 3, 2)
    pairs = windows.reshape(-1, 2, windows.shape[-1])

    current, following = pairs[:, :1, :], pairs[:, 1:, :]
    if residual:
        return current, (following - current) / norm_res
    return current, following

def process_traj(traj, time_jump=1):
    """Return every ``time_jump``-th frame of ``traj`` (along dim 1) as a new tensor.

    Used to build the ground-truth trajectories for rollout validation. The
    result never shares memory with ``traj``, so the caller's tensor is left
    untouched.
    """
    frame_ids = list(range(0, traj.shape[1], time_jump))
    # Indexing with a Python list is advanced indexing, which always returns a copy.
    return traj[:, frame_ids, :]

'''
Train
'''
def train(dataloader, model, loss_fn, optimizer, scheduler, max_grad_norm=1.0):
    """Run one training epoch over ``dataloader``.

    Args:
        dataloader: Iterable of ``(x, y)`` batches.
        model: Model to optimize; switched to training mode here.
        loss_fn: Callable applied as ``loss_fn(model(x), y)``. When ``None``,
            the model computes its own loss via ``model(x, y)``.
        optimizer: Optimizer over ``model.parameters()``.
        scheduler: Learning-rate scheduler, stepped once per batch (not per epoch).
        max_grad_norm: Maximum total gradient norm passed to
            ``torch.nn.utils.clip_grad_norm_``; ``None`` disables clipping.
            Defaults to 1.0.
    """
    # TODO: Add a scaler parameter and log the loss in the original scale when a scaler is used.
    model.train()

    for x, y in dataloader:
        # Some models compute their own training loss from (x, y).
        loss = model(x, y) if loss_fn is None else loss_fn(model(x), y)

        optimizer.zero_grad()
        loss.backward()
        if max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        scheduler.step()

        wandb.log({'learning rate': scheduler.get_last_lr()[0], 'train/loss': loss.item()})
        
'''
Valid
'''
def valid_onestep(dataloader, model, loss_fn):
    """Compute, log and print the one-step-ahead validation loss.

    Each batch's loss is weighted by its number of samples, so a smaller final
    batch does not skew the average. This assumes ``loss_fn`` averages over
    samples (e.g. ``torch.nn.MSELoss`` with its default ``'mean'`` reduction);
    in that case the result is the exact mean over the whole dataset.

    Args:
        dataloader: Iterable of ``(x, y)`` batches.
        model: Model mapping ``x`` to a prediction comparable with ``y``.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar tensor.

    Returns:
        The sample-weighted mean loss as a float (``nan`` if ``dataloader``
        yields no samples).
    """
    # TODO: Inverse-transform predictions and targets when a scaler is used.
    model.eval()

    total_loss = 0.0
    n_samples = 0
    with torch.inference_mode():
        for x, y in dataloader:
            pred = model(x)
            batch_size = y.shape[0]
            total_loss += loss_fn(pred, y).item() * batch_size
            n_samples += batch_size

    loss = total_loss / n_samples if n_samples else float('nan')

    wandb.log({'valid/onestep_loss': loss})
    print(f'valid/onestep_loss: {loss:.7f}')
    return loss

def batched_corrcoef(pred: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:
    """Row-wise Pearson correlation between two 2-D tensors.

    Args:
        pred: Tensor of shape (batch_size, n_features).
        targ: Tensor of the same shape as ``pred``.

    Returns:
        Tensor of shape (batch_size, 1). A row whose denominator is not strictly
        positive (e.g. a constant row, or NaN) gets a correlation of 0.
    """
    def _center(t):
        return t - t.mean(dim=1, keepdim=True)

    def _norm(t):
        return t.pow(2).sum(dim=1, keepdim=True).sqrt()

    dp, dt = _center(pred), _center(targ)
    cov = (dp * dt).sum(dim=1, keepdim=True)
    scale = _norm(dp) * _norm(dt)
    # Select 0 wherever the division would be undefined.
    return torch.where(scale > 0, cov / scale, torch.zeros_like(cov))

def find_corr_threshold_index(corr, threshold):
    """Index of the last leading frame whose correlation is at or above ``threshold``.

    For each trajectory (row), the leading run of frames with
    ``corr >= threshold`` is found, and the index of that run's last frame is
    returned.

    Args:
        corr: Correlation tensor of shape (n_trajectories, n_timesteps).
        threshold: Correlation threshold value.

    Returns:
        Float tensor of shape (n_trajectories,) holding, per trajectory:
        - ``k - 1`` if the correlation first drops below ``threshold`` at frame ``k``;
        - ``n_timesteps - 1`` (the last frame) if it never drops below;
        - 0 if it already starts below (indistinguishable from "only frame 0 is above").
        NaN correlations count as below the threshold.
    """
    # Frames at/above threshold; NaN compares False and therefore ends the run.
    above = corr >= threshold
    # The cumulative product stays 1 only through the leading run of True
    # values, so each row's sum is the length of that run.
    run_length = above.long().cumprod(dim=1).sum(dim=1)
    # Last index of the run; an empty run (starts below) is clamped to 0.
    return (run_length - 1).clamp_min(0).float()

def _autoregressive_rollout(traj, model, residual, scaler=None, norm_res=1):
    """Roll ``model`` forward from ``traj[:, 0]`` for ``traj.shape[1] - 1`` steps.

    Intended to run with gradients disabled (``valid_rollout`` calls it inside
    ``torch.inference_mode``).

    Returns:
        Tensor shaped like ``traj``. Frame 0 is copied from ``traj``; every later
        frame is produced from the model's own previous prediction.
    """
    pred = torch.zeros_like(traj)  # already on traj's device
    pred[:, 0:1, :] = traj[:, 0:1, :]

    for t in range(1, traj.shape[1]):
        # Feed back the previous prediction; clone so the input is not a view
        # of the buffer being written.
        step = model(pred[:, t - 1:t, :].clone())

        if scaler is not None:
            # Undo the target scaling applied during training.
            step = scaler.inverse_transform(step)

        if residual:
            # The model predicts the increment divided by norm_res.
            pred[:, t:t + 1, :] = pred[:, t - 1:t, :] + step * norm_res
        else:
            pred[:, t:t + 1, :] = step

    return pred


def _rollout_figures(traj, pred, max_plots=10):
    """Render side-by-side true/predicted images for up to ``max_plots`` trajectories.

    Returns:
        List of ``wandb.Image`` objects, one per plotted trajectory.
    """
    images = []
    for k in range(min(traj.shape[0], max_plots)):
        fig, axes = plt.subplots(1, 2, figsize=(12, 6))
        for ax, field, title in zip(axes, (traj[k], pred[k]), ('True', 'Pred')):
            im = ax.imshow(field.cpu().numpy().T, aspect='auto', cmap='Spectral')
            ax.set_title(title, fontfamily='monospace', fontsize=24)
            ax.tick_params(labelsize=20, labelfontfamily='monospace')
            cbar = plt.colorbar(im, ax=ax)
            cbar.ax.tick_params(labelsize=20, labelfontfamily='monospace')

        plt.tight_layout()
        images.append(wandb.Image(fig))
        # Close each figure so figures do not accumulate across calls.
        plt.close(fig)
    return images


def valid_rollout(traj, model, loss_fn, residual, time_jump=1, scaler=None, norm_res=1, dt=0.2):
    """Validate ``model`` by autoregressive rollout from each trajectory's first frame.

    Logs to wandb (and prints) the rollout loss and the mean
    correlation-threshold time at 0.8 and 0.9. It also logs images comparing
    true and predicted trajectories for up to 10 trajectories.

    Args:
        traj: Ground-truth trajectories indexed ``traj[:, t, :]``; presumably
            (n_trajectories, seq_len, nx) from ``process_traj`` — TODO confirm.
        model: Model mapping a (B, 1, nx) frame to the next frame (or increment).
        loss_fn: Loss applied to predicted vs. true frames 1..T-1.
        residual: Whether the model predicts increments rather than frames.
        time_jump: Temporal subsampling stride used to build ``traj``.
        scaler: Optional fitted scaler whose ``inverse_transform`` undoes the
            target scaling used during training.
        norm_res: Divisor that was applied to residual targets during training.
        dt: Time units per raw (unsubsampled) frame, used to convert frame
            indices to time. Defaults to 0.2 — NOTE(review): confirm against
            the dataset.

    Returns:
        The rollout loss as a Python float.
    """
    model.eval()

    with torch.inference_mode():
        pred = _autoregressive_rollout(traj, model, residual, scaler=scaler, norm_res=norm_res)

        # Frame 0 is copied from the ground truth, so it is excluded from the loss.
        loss = loss_fn(pred[:, 1:, :], traj[:, 1:, :]).item()
        wandb.log({'valid/rollout_loss': loss}, commit=False)
        print(f'valid/rollout_loss: {loss:>7f}')

        # Pearson correlation over the spatial axis for every (trajectory, frame) pair.
        corr = batched_corrcoef(
            pred.reshape(-1, pred.shape[-1]),
            traj.reshape(-1, traj.shape[-1]),
        ).reshape(traj.shape[0], traj.shape[1])

        # Mean threshold frame index converted to time; .item() yields plain
        # Python floats for logging and printing.
        rollout_corr_08 = (find_corr_threshold_index(corr, 0.8).mean() * dt * time_jump).item()
        rollout_corr_09 = (find_corr_threshold_index(corr, 0.9).mean() * dt * time_jump).item()
        wandb.log({'valid/rollout_corr_08': rollout_corr_08, 'valid/rollout_corr_09': rollout_corr_09}, commit=False)
        print(f'valid/rollout_corr_08: {rollout_corr_08:.7f}, valid/rollout_corr_09: {rollout_corr_09:.7f}')

        # This final log call commits the step, including the two logs above.
        wandb.log({'valid/rollout_pred': _rollout_figures(traj, pred)})

    return loss

'''
Main
'''
# Models listed here compute their own training loss as model(x, y), so train() gets loss_fn=None.
_SELF_LOSS_MODELS = frozenset({'PDERefiner', 'PDERefiner_2', 'ddpm', 'alpha_deblending'})


def _build_model(config, device):
    """Instantiate ``config.model`` from MODEL_REGISTRY and move it to ``device``.

    Only config entries whose names are constructor parameters are forwarded.

    Raises:
        ValueError: If ``config.model`` is not a registered model name.
    """
    if config.model not in MODEL_REGISTRY:
        raise ValueError(f"Model {config.model} not found in registry. Available models: {list(MODEL_REGISTRY.keys())}")

    model_class = MODEL_REGISTRY[config.model]
    # inspect.signature lists only the constructor's parameters. __code__.co_varnames
    # would also list its local variable names and could forward unexpected kwargs.
    accepted = set(inspect.signature(model_class.__init__).parameters) - {'self'}
    kwargs = {k: v for k, v in config.__dict__.items() if k in accepted}
    return model_class(**kwargs).to(device)


def _build_scheduler(optimizer, config, steps_per_epoch):
    """Create the learning-rate scheduler selected by ``config.scheduler``.

    Raises:
        ValueError: If ``config.scheduler`` is neither ``'constant'`` nor ``'cosine'``.
    """
    if config.scheduler == 'constant':
        # factor=1.0 keeps the initial learning rate unchanged.
        return torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0)
    if config.scheduler == 'cosine':
        # train() steps the scheduler once per batch, so one cosine cycle spans
        # every batch of every epoch.
        return torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=steps_per_epoch * config.epochs, eta_min=config.learning_rate_min)
    raise ValueError(f"Unknown scheduler {config.scheduler!r}; expected 'constant' or 'cosine'.")


def main(config):
    """Train, validate and save a model as described by ``config``.

    Steps:
    1. Start a wandb run and seed all random number generators.
    2. Load the data and build one-step pairs; optionally fit a MaxScaler on
       the training targets.
    3. Build the model, optimizer and scheduler.
    4. Train for ``config.epochs`` epochs, with one-step validation every epoch
       and a rollout validation every ``config.roll_every`` epochs (when
       validation data is given).
    5. Run a final rollout, then save the model weights (and the scaler).

    Args:
        config: Configuration object (``Config.from_yaml`` in the CLI entry
            point). Attributes are read directly; ``config.__dict__`` is logged
            to wandb and also supplies model constructor arguments.
    """
    wandb.login()
    wandb.init(
        project=config.project,
        config=config.__dict__,
        tags=config.tags,
        # None selects the default wandb entity; set `entity` in the config to override.
        entity=getattr(config, 'entity', None) or None,
    )

    set_seed(config.seed)

    device = f'cuda:{config.device}' if torch.cuda.is_available() else 'cpu'
    print(f'Using {device}.')

    # ---- Training data: (frame t, frame t+1 or scaled increment) pairs ----
    train_data = get_data(config.train_path)
    train_data_input, train_data_output = process_data(
        train_data, time_jump=config.time_jump, residual=config.residual, norm_res=config.norm_res)

    scaler = None
    if config.scale:
        # Fit on the training targets only; the same scaling is reused for validation.
        scaler = MaxScaler()
        scaler.fit(train_data_output)
        wandb.log({'scaler/scale': scaler.scale})
        train_data_output = scaler.transform(train_data_output)

    train_data_input, train_data_output = train_data_input.to(device), train_data_output.to(device)
    train_dataloader = torch.utils.data.DataLoader(
        Dataset(train_data_input, train_data_output), batch_size=config.batch_size, shuffle=True)

    # ---- Optional validation data ----
    valid_dataloader = None
    valid_traj = None
    if config.valid_path:
        valid_data = get_data(config.valid_path)
        valid_data_input, valid_data_output = process_data(
            valid_data, time_jump=config.time_jump, residual=config.residual, norm_res=config.norm_res)
        if scaler is not None:
            valid_data_output = scaler.transform(valid_data_output)

        valid_data_input, valid_data_output = valid_data_input.to(device), valid_data_output.to(device)
        valid_dataloader = torch.utils.data.DataLoader(
            Dataset(valid_data_input, valid_data_output), batch_size=config.batch_size, shuffle=False)

        # Full (subsampled) trajectories for autoregressive rollout validation.
        valid_traj = process_traj(valid_data, time_jump=config.time_jump).to(device)

    # TODO: test-set evaluation (config.test_path) is not implemented yet.

    # ---- Model, losses, optimizer, scheduler ----
    model = _build_model(config, device)
    total_params = sum(p.numel() for p in model.parameters())
    # The metric key is kept verbatim so existing wandb dashboards keep matching.
    wandb.log({'train/number_of_paramenter': total_params})
    print(f"Total number of parameters: {total_params}")

    wandb.watch(model)

    loss_fn = None if config.model in _SELF_LOSS_MODELS else torch.nn.MSELoss()
    loss_fn_val = torch.nn.MSELoss()

    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
    scheduler = _build_scheduler(optimizer, config, len(train_dataloader))

    # ---- Training loop ----
    print('Training...')
    for epoch in range(config.epochs):
        print(f'Epoch: {epoch + 1}/{config.epochs}')

        train(train_dataloader, model, loss_fn, optimizer, scheduler)

        if valid_dataloader is not None:
            valid_onestep(valid_dataloader, model, loss_fn_val)
            # Periodic rollouts need the validation trajectories, so they are gated on them too.
            if (epoch + 1) % config.roll_every == 0:
                valid_rollout(valid_traj, model, loss_fn_val, config.residual,
                              time_jump=config.time_jump, scaler=scaler, norm_res=config.norm_res)

    print('Training complete.')

    if valid_traj is not None:
        valid_rollout(valid_traj, model, loss_fn_val, config.residual,
                      time_jump=config.time_jump, scaler=scaler, norm_res=config.norm_res)

    # ---- Save model and scaler ----
    # An empty directory would make os.makedirs fail and place the files at the
    # filesystem root, so fall back to a 'checkpoints' folder next to this script.
    model_path = getattr(config, 'model_path', None) or os.path.join(
        os.path.dirname(os.path.abspath(__file__)), 'checkpoints')
    os.makedirs(model_path, exist_ok=True)

    model_file = os.path.join(model_path, 'model.pt')
    torch.save(model.state_dict(), model_file)
    if scaler is not None:
        scaler.save(os.path.join(model_path, 'scaler.json'))

    print(f'Model saved as {model_file}.')

    wandb.finish()

    

if __name__ == '__main__':
    # Command-line interface: the YAML config is mandatory; some values can be overridden.
    parser = argparse.ArgumentParser()
    # Enforce the config path up front instead of failing later inside Config.from_yaml(None).
    parser.add_argument('--config', type=str, required=True, help='Path to config file.')
    parser.add_argument('--device', type=int, help='Override GPU device number.')
    args = parser.parse_args()

    config = Config.from_yaml(args.config)
    # CLI values take precedence over the YAML file.
    if args.device is not None:
        config.device = args.device

    main(config)