import torch
import numpy as np
from typing import Optional, List, Tuple, Dict, Any
from torch.optim import Optimizer
from torch.utils.data import DataLoader, TensorDataset
import time
from torch import optim
from utils import get_fdprs, get_fdprs_uniform
from CMLR import CMLR

def _as_float32_batch(x_batch: Any) -> torch.Tensor:
    """Convert one mini-batch to a float32 tensor and drop a leading size-1 dim.

    Shape semantics match ``torch.tensor(np.array(x_batch, dtype=np.float32)).squeeze(0)``:
    a list/tuple of tensors is stacked along a new leading axis (so the
    one-element list produced by a ``TensorDataset`` batch becomes
    ``[1, B, D]`` and is squeezed back to ``[B, D]``), while a bare tensor is
    only cast and squeezed.
    """
    if isinstance(x_batch, torch.Tensor):
        # Clone so the batch never aliases the caller's storage.
        batch = x_batch.detach().clone().to(torch.float32)
    elif (
        isinstance(x_batch, (list, tuple))
        and x_batch
        and all(isinstance(item, torch.Tensor) for item in x_batch)
    ):
        # Stay in torch: avoids a tensor -> NumPy -> tensor round-trip, which
        # copies twice and fails for tensors that are not plain CPU tensors.
        batch = torch.stack([item.detach().to(torch.float32) for item in x_batch])
    else:
        # Generic array-likes (nested lists, NumPy arrays, ...) go through NumPy.
        batch = torch.tensor(np.array(x_batch, dtype=np.float32))
    return batch.squeeze(0)


def train_CMLR(
    model: torch.nn.Module,
    optimizer: Optimizer,
    data_loader: DataLoader,
    num_epochs: Optional[int] = None,
    num_samples: Optional[int] = None,
    eps_samples: Optional[torch.Tensor] = None,
    tol: Optional[float] = None,
    patience: Optional[int] = None
) -> Tuple[List[float], List[float]]:
    """
    Train the CMLR model by maximizing the ELBO via stochastic variational inference.

    Each epoch iterates once over ``data_loader``; for every mini-batch the
    model is called with keyword arguments ``x_batch``, ``indices``,
    ``num_samples`` and ``eps_samples``, and one optimizer step is taken on
    the negative of the returned ELBO.

    Args:
        model: CMLR instance whose forward() returns the (scalar) ELBO.
        optimizer: torch Optimizer (e.g., Adam).
        data_loader: yields (indices, x_batch) pairs for mini-batches.
        num_epochs: maximum training epochs (default 10000).
        num_samples: Monte Carlo samples per ELBO estimate (passed to the model).
        eps_samples: pre-drawn noise, passed unchanged to every model call.
        tol: minimum ELBO improvement to reset patience (default 1e-4).
        patience: epochs to wait without improvement before early stopping (default 100).

    Returns:
        loss_history: list of -ELBO from final batch of each epoch.
        elbo_history: list of average ELBO per epoch.

    Raises:
        ValueError: if ``data_loader`` yields no batches in an epoch.
    """
    # --- Set default values for optional parameters ---
    if tol is None:
        tol = 1e-4
    if patience is None:
        patience = 100
    if num_epochs is None:
        num_epochs = 10000

    # --- Initialize history and early-stopping trackers ---
    loss_history: List[float] = []
    elbo_history: List[float] = []
    best_elbo = -float('inf')  # We maximize the ELBO, so start from -inf.
    patience_counter = 0

    # --- Main training loop over epochs ---
    for epoch in range(num_epochs):
        epoch_elbo = 0.0
        num_batches = 0
        last_loss = 0.0

        for indices, x_batch in data_loader:
            x_batch = _as_float32_batch(x_batch)

            optimizer.zero_grad()
            # Forward pass: compute the ELBO for this batch.
            elbo = model(
                x_batch=x_batch,
                indices=indices,
                num_samples=num_samples,
                eps_samples=eps_samples
            )
            # Optimizers minimize, so descend on the negative ELBO.
            loss = -elbo
            loss.backward()
            optimizer.step()

            batch_elbo = elbo.item()
            epoch_elbo += batch_elbo
            last_loss = -batch_elbo  # same value as loss.item()
            num_batches += 1

        if num_batches == 0:
            # Without this guard the average below fails with a bare
            # ZeroDivisionError that hides the real cause.
            raise ValueError(
                "data_loader yielded no batches; cannot train on an empty "
                "dataset (check dataset size, batch_size and drop_last)."
            )

        # --- End of epoch: compute average ELBO ---
        current_elbo = epoch_elbo / num_batches
        loss_history.append(last_loss)
        elbo_history.append(current_elbo)

        # --- Early-stopping logic ---
        # Only an improvement larger than `tol` over the best ELBO resets patience.
        if current_elbo > best_elbo + tol:
            best_elbo = current_elbo
            patience_counter = 0
        else:
            patience_counter += 1

        # --- Periodic logging ---
        if epoch % 200 == 0:
            print(
                f"Epoch {epoch:4d} | "
                f"Avg ELBO: {current_elbo:.6f} | "
                f"Best ELBO: {best_elbo:.6f} | "
                f"Patience: {patience_counter}/{patience}"
            )

        # Early stopping: if no improvement for 'patience' consecutive epochs, break.
        if patience_counter >= patience:
            print(
                f"Early stopping at epoch {epoch}: "
                f"no ELBO improvement ≥ {tol} for {patience} epochs."
            )
            break

    return loss_history, elbo_history


def fit_CMLR_model(
    xsamples: torch.Tensor,
    labels: torch.Tensor,
    CMLR_fixed_params: Dict[str, Any]
) -> Tuple[torch.nn.Module, list, float]:
    """
    Fit the Continuous MLR (CMLR) model to simulated data.

    This function:
      1. Builds Fourier-domain bases for the provided labels and a uniform grid.
      2. Wraps the dataset in an indexed DataLoader for mini-batch SVI.
      3. Initializes the CMLR model and trains it with early stopping.
      4. Returns the trained model, ELBO history, and total training time.

    Args
    ----
    xsamples : torch.Tensor, shape [L, D]
        Simulated neural responses (features) for L samples and D neurons.
    labels : torch.Tensor, shape [L] or [L, 1]
        Corresponding labels for each sample (could be indices or real values).
    CMLR_fixed_params : dict
        Training parameters and model configuration, including:
          - 'num_samples_MC'      : int, Monte Carlo samples per ELBO eval.
          - 'learing_rate'        : float, learning rate for Adam. The
                                    correctly spelled 'learning_rate' is also
                                    accepted and takes precedence if present.
          - 'label_params'        : dict, label-space config for basis.
          - 'T'                   : int, grid size for uniform basis.
          - 'batch_size_SVI'      : int, mini-batch size (capped at L).
          - 'maximum_num_epochs'  : int, max training epochs.
          - 'tolerance'           : float, ELBO improvement threshold.
          - 'patience'            : int, early-stopping patience.

    Returns
    -------
    model : torch.nn.Module
        Trained CMLR model.
    elbo_history : list of float
        ELBO values recorded at each epoch.
    training_time : float
        Total training time in seconds.

    Raises
    ------
    KeyError
        If a required entry is missing from ``CMLR_fixed_params``.
    """

    # --- 1) Unpack training parameters ---
    num_samples_MC = CMLR_fixed_params['num_samples_MC']
    # 'learing_rate' is the historical (misspelled) key; keep honoring it so
    # existing configs work, but prefer the correct spelling when supplied.
    if 'learning_rate' in CMLR_fixed_params:
        learning_rate = CMLR_fixed_params['learning_rate']
    else:
        learning_rate = CMLR_fixed_params['learing_rate']
    label_params = CMLR_fixed_params['label_params']
    T = CMLR_fixed_params['T']
    # Ensure batch size does not exceed number of samples.
    batch_size_SVI = min(CMLR_fixed_params['batch_size_SVI'], xsamples.shape[0])
    max_num_epochs = CMLR_fixed_params['maximum_num_epochs']
    tolerance = CMLR_fixed_params['tolerance']
    patience = CMLR_fixed_params['patience']

    # --- 2) Build Fourier-domain parameters for labels and uniform grid ---
    fdprs_labels = get_fdprs(labels, label_params)
    fdprs_uniform_grid = get_fdprs_uniform(T, label_params)
    # The second axis of 'Bmat' is reported as the number of Fourier coefficients.
    Bdim = fdprs_labels['Bmat'].shape[1]
    print("Number of fourier coefficients:", Bdim)

    # --- 3) Prepare an indexed dataset for mini-batch SVI ---
    class IndexedTensorDataset(TensorDataset):
        """TensorDataset that also returns the sample's index with its data."""

        def __getitem__(self, index):
            # `data` is the tuple TensorDataset returns (one entry per wrapped tensor).
            data = super().__getitem__(index)
            return index, data

    indexed_dataset = IndexedTensorDataset(xsamples)
    # Shuffle every epoch; drop_last=True discards an incomplete final batch,
    # so every batch has exactly batch_size_SVI rows.
    data_loader = DataLoader(
        indexed_dataset,
        batch_size=batch_size_SVI,
        shuffle=True,
        drop_last=True,
    )

    # --- 4) Initialize model and optimizer ---
    CMLR_model = CMLR(
        xsamples=xsamples,
        fdprs_labels=fdprs_labels,
        fdprs_uniform_grid=fdprs_uniform_grid,
        label_params=label_params,
    )
    optimizer = optim.Adam(CMLR_model.parameters(), lr=learning_rate)

    # --- 5) Pre-generate Monte Carlo noise ---
    # Drawn once and reused for every batch and epoch of this run. No seed is
    # set here, so the draws differ between runs unless the caller seeds torch.
    D = xsamples.shape[1]  # Number of features
    fixed_random_samples_MC = torch.randn(num_samples_MC, Bdim, D)

    print("Training the model using Adam optimizer...")
    # perf_counter is monotonic, so the elapsed time is not distorted by
    # system clock adjustments.
    start_time = time.perf_counter()
    print("Starting training...")
    # Only the per-epoch ELBO history is returned; the loss history is discarded.
    _, elbo_history = train_CMLR(
        CMLR_model,
        optimizer,
        data_loader,
        num_epochs=max_num_epochs,
        num_samples=num_samples_MC,
        eps_samples=fixed_random_samples_MC,
        tol=tolerance,
        patience=patience,
    )
    training_time = time.perf_counter() - start_time
    print(f"Training completed in {training_time:.2f} seconds.")

    return CMLR_model, elbo_history, training_time