import jax
import jax.random as random
import jax.numpy as jnp
from sklearn.metrics import mean_absolute_error
from typing import Optional, Mapping, Tuple, Sequence, Union, Any, Callable, Iterable, Literal
import einops
import equinox as eqx
from jaxtyping import Array, PRNGKeyArray, Float, Scalar, Bool, PyTree
import optax
from diffusion_crf import AbstractBatchableObject, auto_vmap
import numpy as np
from tqdm.auto import tqdm

###############################################################################
# Utility Functions
###############################################################################
def extract_time(data):
  """Return the per-sample sequence lengths and the longest length.

  Args:
    - data: sequence of time series; every element must expose ``.shape``
      with the time axis first (e.g. an array of shape
      [n_samples, T, features] or a list of [T_i, features] arrays)

  Returns:
    - time: integer array holding the sequence length of each sample
    - max_seq_len: maximum sequence length (0 when ``data`` is empty)
  """
  lengths = [sample.shape[0] for sample in data]

  # ``default=0`` keeps empty input from raising ValueError inside max().
  max_seq_len = max(lengths, default=0)

  # Explicit integer dtype so an empty list does not silently become float.
  return jnp.array(lengths, dtype=int), max_seq_len
def batch_generator(data, time, batch_size, dim, key=None):
  """Vectorized mini-batch generator for the one-step-ahead predictor.

  A random subset of samples is drawn without replacement. Inputs are the
  first ``dim - 1`` features at steps ``0 .. L-2`` and targets are the last
  feature at steps ``1 .. L-1``, where ``L`` is each sample's sequence length.
  Positions beyond a sample's own length are zeroed and flagged in ``mask``.

  Args:
    - data: time-series data with shape [n_samples, T, features]
    - time: sequence lengths array of shape [n_samples]
    - batch_size: the number of samples in each batch; if it exceeds
      ``len(data)`` the batch simply contains every sample
    - dim: dimensionality of data (number of features)
    - key: random key for shuffling (defaults to ``PRNGKey(0)``)

  Returns:
    - X_batch: features (all except last) with shape [B, max_len, dim-1]
    - T_batch: number of prediction steps per sample (length - 1), shape [B]
    - Y_batch: target values with shape [B, max_len, 1]
    - mask: boolean mask with shape [B, max_len] (True for valid positions)

    where ``B = min(batch_size, len(data))`` and ``max_len`` is the largest
    entry of ``T_batch``.
  """
  if key is None:
    key = random.PRNGKey(0)

  # Draw the batch: shuffle all sample indices and keep the first batch_size
  shuffled = random.permutation(key, jnp.arange(len(data)))
  batch_idx = shuffled[:batch_size]

  batch_data = jnp.take(data, batch_idx, axis=0)
  batch_seq_lengths = jnp.take(time, batch_idx)

  # Predicting the next step consumes one timestep, so a sequence of length L
  # yields L - 1 (input, target) pairs.
  T_batch = batch_seq_lengths - 1

  # The padded length is used as an array shape, so make it a Python int.
  max_seq_len = int(jnp.max(T_batch))

  # Inputs come from steps [0, max_seq_len), targets from one step later
  input_indices = jnp.arange(max_seq_len)
  target_indices = input_indices + 1

  # A position is valid while it is inside that sample's own sequence
  mask = input_indices[None, :] < T_batch[:, None]

  # Clamp indices to the stored time axis so gathers never go out of bounds
  last_valid_step = batch_data.shape[1] - 1
  safe_in_indices = jnp.clip(input_indices, 0, last_valid_step)
  safe_tgt_indices = jnp.clip(target_indices, 0, last_valid_step)

  # All features except the last are inputs; padded positions are zeroed
  X_batch = batch_data[:, safe_in_indices, :(dim-1)]
  X_batch = X_batch * mask[:, :, None]

  # The last feature, shifted by one step, is the target
  Y_batch = batch_data[:, safe_tgt_indices, dim-1:dim]
  Y_batch = Y_batch * mask[:, :, None]

  return X_batch, T_batch, Y_batch, mask

###############################################################################
# Main Predictive Function
###############################################################################
def predictive_score_metrics(
    ori_data: Float[Array, 'N T D'],
    generated_data: Float[Array, 'N T D'],
    key: Optional[PRNGKeyArray] = None,
    hidden_dim: Optional[int] = None,
    iterations: int = 5000,
    batch_size: int = 128
) -> float:
  """Report the performance of Post-hoc RNN one-step ahead prediction.

  NOTE: this function is currently disabled. Its first statement raises
  ``NotImplementedError``, so none of the code after that statement runs.

  The unreachable body is written to train a GRU-based predictor on
  ``generated_data`` (inputs: all features but the last; target: the last
  feature one step ahead) and to report its mean absolute error on
  ``ori_data``.

  Args:
    - ori_data: original data with shape (N, T, D)
    - generated_data: generated synthetic data with shape (N, T, D)
    - key: random key for initialization (optional)
    - hidden_dim: hidden dimensions for the GRU (optional)
    - iterations: number of training iterations
    - batch_size: batch size for training

  Returns:
    - predictive_score: MAE of the predictions on the original data
      (never produced while the function is disabled)

  Raises:
    - NotImplementedError: always, before any argument is inspected
  """
  # Unconditional raise: every statement below this line is unreachable.
  raise NotImplementedError
  # Check data before running prediction
  if len(ori_data) == 0 or len(generated_data) == 0:
    print("Warning: Empty data provided!")
    return float('nan')

  # -----------------------------
  # 1. Basic parameters
  # -----------------------------
  # NOTE(review): seq_len is not referenced again in this function.
  no, seq_len, dim = jnp.asarray(ori_data).shape

  if hidden_dim is None:
    hidden_dim = max(dim // 2, 1)  # Ensure hidden_dim is at least 1

  if key is None:
    key = random.PRNGKey(0)

  # Set maximum sequence length and each sequence length
  # NOTE(review): max_seq_len is computed here but not used afterwards.
  ori_time, ori_max_seq_len = extract_time(ori_data)
  generated_time, generated_max_seq_len = extract_time(generated_data)
  max_seq_len = max(ori_max_seq_len, generated_max_seq_len)

  # -----------------------------
  # 2. Build the Predictor
  # -----------------------------
  class Predictor(AbstractBatchableObject):
    # Recurrent cell applied once per timestep
    gru: eqx.nn.GRUCell
    # Maps the hidden state to a single scalar prediction
    dense: eqx.nn.Linear

    @property
    def batch_size(self):
      # Batch dimensions are read off the dense weight's leading axes:
      # 2-D weight -> None, 3-D -> size of the leading axis,
      # more than 3-D -> tuple of all leading axes.
      if self.dense.weight.ndim == 2:
        return None
      elif self.dense.weight.ndim == 3:
        return self.dense.weight.shape[0]
      elif self.dense.weight.ndim > 3:
        return self.dense.weight.shape[:-2]
      else:
        # A weight with fewer than 2 dimensions is not expected
        assert 0

    def __init__(self, input_dim: int, hidden_dim: int, key: PRNGKeyArray):
      # Independent keys for the two sub-modules
      k1, k2 = random.split(key)
      self.gru = eqx.nn.GRUCell(input_dim, hidden_dim, key=k1)
      self.dense = eqx.nn.Linear(hidden_dim, 1, key=k2)

    # NOTE(review): auto_vmap is imported from diffusion_crf; presumably it
    # maps this method over leading batch axes — confirm against that package.
    @auto_vmap
    def __call__(self, xs: Float[Array, 'T D']):
      """Process sequence and predict next timestep.

      Args:
        xs: Input sequence with shape [T, D] (excluding last feature)

      Returns:
        Predictions with shape [T, 1] (one value per timestep), each passed
        through a sigmoid
      """
      # Process each timestep with the GRU and make predictions
      def scan_fn(hidden, x):
        new_hidden = self.gru(x, hidden)
        prediction = self.dense(new_hidden)
        # Apply sigmoid to map output to [0, 1]
        prediction = jax.nn.sigmoid(prediction)
        return new_hidden, prediction

      # Initialize hidden state with zeros
      init_hidden = jnp.zeros(self.gru.hidden_size)

      # Scan through the sequence; the final hidden state is discarded
      _, predictions = jax.lax.scan(scan_fn, init_hidden, xs)

      return predictions

  # -----------------------------
  # 3. Training and Prediction
  # -----------------------------

  # Create the predictor
  model_key, training_key = random.split(key)
  model = Predictor(dim-1, hidden_dim, key=model_key)
  optimizer = optax.adam(learning_rate=1e-3)

  # Create optimizer state (only the array leaves are optimized)
  params, static = eqx.partition(model, eqx.is_array)
  opt_state = optimizer.init(params)

  # Define training step
  @jax.jit
  def train_step(model, opt_state, x_batch, y_batch):

    # Split the model into trainable arrays and everything else
    params, static = eqx.partition(model, eqx.is_array)

    def loss_fn(params):
      # Combine params with static parts
      cur_model = eqx.combine(params, static)
      # Get predictions
      # NOTE(review): __call__ is already decorated with auto_vmap; whether
      # this explicit jax.vmap is also required depends on auto_vmap — confirm.
      preds = jax.vmap(cur_model)(x_batch)
      # Calculate MAE loss
      # NOTE(review): the mean runs over every position, including padded
      # ones; no mask is applied here.
      loss = jnp.mean(jnp.abs(preds - y_batch))
      return loss, preds

    # Calculate gradients and update parameters
    (loss, preds), grads = jax.value_and_grad(loss_fn, has_aux=True)(params)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)

    # Return updated params and optimizer state
    model = eqx.combine(params, static)
    return model, opt_state, loss

  # Training loop
  for i in tqdm(range(iterations), desc='Training predictor'):
    # Generate batch from synthetic data
    batch_key, training_key = random.split(training_key)
    # NOTE(review): batch_generator in this file returns four values
    # (X_batch, T_batch, Y_batch, mask); unpacking three would raise
    # ValueError if this line were reachable.
    X_batch, T_batch, Y_batch = batch_generator(generated_data, generated_time, batch_size, dim, key=batch_key)

    # Process batch data for JAX
    # Convert lists to arrays with padding
    # NOTE(review): batch_generator already returns padded arrays, so this
    # per-row padding looks like a leftover from a list-based generator.
    max_batch_len = max(len(x) for x in X_batch)

    # Pad sequences for batch processing
    X_padded = jnp.zeros((batch_size, max_batch_len, dim-1))
    Y_padded = jnp.zeros((batch_size, max_batch_len, 1))

    for j, (x, y) in enumerate(zip(X_batch, Y_batch)):
      X_padded = X_padded.at[j, :len(x)].set(x)
      Y_padded = Y_padded.at[j, :len(y)].set(y)

    # Update model with batch
    model, opt_state, loss = train_step(model, opt_state, X_padded, Y_padded)

  # -----------------------------
  # 4. Test on original data
  # -----------------------------
  # Test the trained model on the original data
  test_key = random.PRNGKey(1)  # Fixed seed for consistent evaluation

  # Get all original data for testing (batch size equals the sample count)
  # NOTE(review): same three-vs-four value unpacking mismatch as in training.
  X_test, T_test, Y_test = batch_generator(ori_data, ori_time, no, dim, key=test_key)

  # Convert to numpy for sklearn compatibility
  Y_true = [np.array(y) for y in Y_test]

  # Make predictions on each sequence
  Y_pred = []
  for x in X_test:
    if len(x) > 0:  # Skip empty sequences
      # Convert to JAX array
      x_jax = jnp.array(x)
      # Get predictions
      preds = model(x_jax)
      # Convert to numpy for sklearn compatibility
      Y_pred.append(np.array(preds))
    else:
      # Keep positions aligned with Y_true by appending an empty placeholder
      Y_pred.append(np.array([]))

  # Compute MAE for each sequence
  mae_scores = []
  for i in range(len(Y_true)):
    if len(Y_true[i]) > 0 and len(Y_pred[i]) > 0:
      # Ensure dimensions match
      y_true = Y_true[i].reshape(-1)
      y_pred = Y_pred[i].reshape(-1)[:len(y_true)]  # Trim if needed
      # Calculate MAE
      mae = mean_absolute_error(y_true, y_pred)
      mae_scores.append(mae)

  # Return average MAE (unweighted mean of the per-sequence scores)
  # NOTE(review): this is the result of jnp.mean, not a Python float as the
  # return annotation states.
  predictive_score = jnp.mean(jnp.array(mae_scores))

  return predictive_score

###############################################################################
# Testing Code
###############################################################################
def run_test_cases():
  """Run test cases for predictive score metrics.

  Builds six (original, generated) dataset pairs — identical data, three
  noise/unrelated variants of i.i.d. Gaussian data, and two noisy variants of
  sine-patterned data — all rescaled or clipped to [0, 1]. Each pair is passed
  to ``predictive_score_metrics`` and the score is printed together with a
  coarse interpretation.

  Returns:
    - results: dict mapping test-case name to the score returned by
      ``predictive_score_metrics``

  Any exception raised by ``predictive_score_metrics`` propagates to the
  caller.
  """
  def _min_max_scale(x):
    # Rescale with the global min/max so every value lies in [0, 1]
    lowest = jnp.min(x)
    return (x - lowest) / (jnp.max(x) - lowest)

  # Set random seed for reproducibility
  master_key = random.PRNGKey(42)

  # Basic parameters
  N, T, D = 100, 24, 3  # Number of samples, timesteps, features

  # Use fewer iterations for testing
  test_iterations = 500

  test_cases = {}

  # Case 1: Identical data (should give very low MAE)
  key, master_key = random.split(master_key)
  # Scale to [0,1] for sigmoid compatibility
  original_data = _min_max_scale(random.normal(key, (N, T, D)))

  test_cases["identical"] = {
    "original": original_data,
    "generated": original_data,
    "expected": "~0.0-0.05 (very low)"
  }

  # Case 2: Slight noise (should give low MAE)
  key, master_key = random.split(master_key)
  noise_small = 0.1 * random.normal(key, (N, T, D))
  test_cases["small_noise"] = {
    "original": original_data,
    "generated": jnp.clip(original_data + noise_small, 0, 1),  # Keep in [0,1] range
    "expected": "~0.05-0.15 (low)"
  }

  # Case 3: Moderate noise (should give moderate MAE)
  key, master_key = random.split(master_key)
  noise_medium = 0.3 * random.normal(key, (N, T, D))
  test_cases["medium_noise"] = {
    "original": original_data,
    "generated": jnp.clip(original_data + noise_medium, 0, 1),
    "expected": "~0.15-0.25 (moderate)"
  }

  # Case 4: Unrelated data (should give high MAE)
  key, master_key = random.split(master_key)
  unrelated_data = _min_max_scale(random.normal(key, (N, T, D)))
  test_cases["unrelated"] = {
    "original": original_data,
    "generated": unrelated_data,
    "expected": "~0.25+ (high)"
  }

  # Create a time series with patterns for more realistic tests
  key, master_key = random.split(master_key)
  t = jnp.linspace(0, 4*jnp.pi, T)
  pattern_data = jnp.zeros((N, T, D))

  # Create patterns with sine waves and phase shifts.
  # The per-sample loop (rather than one vectorized draw) keeps the exact
  # sequence of key splits, and therefore the exact generated values.
  for i in range(N):
    key, subkey = random.split(key)
    pattern_data = pattern_data.at[i, :, -1].set(jnp.sin(t) + 0.1*random.normal(subkey, (T,)))

    # Add correlations between dimensions
    for d in range(D-1):
      key, subkey = random.split(key)
      pattern_data = pattern_data.at[i, :, d].set(0.5*jnp.sin(t + d*0.5*jnp.pi) + 0.1*random.normal(subkey, (T,)))

  # Scale to [0,1] range
  pattern_data = _min_max_scale(pattern_data)

  # Case 5: Patterned data with small noise
  key, master_key = random.split(master_key)
  pattern_noise_small = 0.1 * random.normal(key, (N, T, D))
  test_cases["pattern_small_noise"] = {
    "original": pattern_data,
    "generated": jnp.clip(pattern_data + pattern_noise_small, 0, 1),
    "expected": "~0.05-0.15 (low)"
  }

  # Case 6: Patterned data with larger noise
  key, master_key = random.split(master_key)
  pattern_noise_large = 0.3 * random.normal(key, (N, T, D))
  test_cases["pattern_large_noise"] = {
    "original": pattern_data,
    "generated": jnp.clip(pattern_data + pattern_noise_large, 0, 1),
    "expected": "~0.15-0.25 (moderate)"
  }

  # Run all test cases
  results = {}
  print("Running predictive score test cases...")

  for name, case in test_cases.items():
    print(f"\nTest case: {name}")
    print(f"Expected outcome: {case['expected']}")

    key, master_key = random.split(master_key)

    # Run the predictive score metrics
    score = predictive_score_metrics(
      case["original"],
      case["generated"],
      key=key,
      iterations=test_iterations,
      batch_size=min(32, case["original"].shape[0])
    )

    results[name] = score
    print(f"Predictive score (MAE): {score:.4f}")

    # Interpret results
    if score < 0.05:
      interpretation = "Excellent prediction (very similar data)"
    elif score < 0.15:
      interpretation = "Good prediction (similar data)"
    elif score < 0.25:
      interpretation = "Moderate prediction (somewhat similar data)"
    else:
      interpretation = "Poor prediction (different data patterns)"

    print(f"Interpretation: {interpretation}")

  print("\n==== Summary of Results ====")
  for name, score in results.items():
    print(f"{name}: MAE={score:.4f}")

  print("\nInterpretation Guide:")
  print("- Lower scores indicate better quality of generated data")
  print("- The predictor was trained on generated data and tested on real data")
  print("- Low MAE means generated data captures patterns that generalize to real data")

  return results

if __name__ == '__main__':
  # Execute every scenario; the mapping is test-case name -> score
  results = run_test_cases()

  # Draw a bar chart of the scores when matplotlib can be imported;
  # an ImportError anywhere in this section falls back to a plain message.
  try:
    import matplotlib.pyplot as plt

    # Dict order is insertion order, so labels and heights stay aligned
    names = list(results)
    scores = list(results.values())

    plt.figure(figsize=(10, 6))
    plt.bar(names, scores)
    plt.ylabel('Predictive Score (MAE)')
    plt.title('Predictive Scores for Different Test Cases')
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    plt.show()

    print("\nRemember: LOWER scores mean BETTER generation quality!")

  except ImportError:
    print("Matplotlib not available for visualization")