import jax
import jax.random as random
import jax.numpy as jnp
from sklearn.metrics import accuracy_score
from typing import Optional, Mapping, Tuple, Sequence, Union, Any, Callable, Iterable, Literal
import einops
import equinox as eqx
from jaxtyping import Array, PRNGKeyArray, Float, Scalar, Bool, PyTree
import optax
from diffusion_crf import AbstractBatchableObject, auto_vmap

###############################################################################
# Utility Functions
###############################################################################
def train_test_divide(original_data: Float[Array, 'N T D'],
                      generated_data: Float[Array, 'N T D'],
                      train_rate=0.8,
                      key: Optional[PRNGKeyArray] = None):
  """Split paired real/generated data into training and test sets.

  A single random permutation of the sample indices is drawn and applied to
  both arrays, so row ``i`` of a returned real split and row ``i`` of the
  matching generated split come from the same source index.

  Args:
    original_data: real data; the leading axis indexes samples.
    generated_data: synthetic data; must have at least as many samples as
      ``original_data``. Samples beyond that count are never selected.
    train_rate: fraction of samples placed in the training split; the
      training size is ``int(N * train_rate)``.
    key: PRNG key used for the shuffle. Defaults to ``random.PRNGKey(0)``.

  Returns:
    Tuple ``(train_x, train_x_hat, test_x, test_x_hat)``.

  Raises:
    ValueError: if ``generated_data`` has fewer samples than ``original_data``.
  """
  if key is None:
    key = random.PRNGKey(0)

  # Dataset size is taken from the real data
  N = original_data.shape[0]

  # JAX clamps out-of-range gather indices instead of raising, so a shorter
  # generated set would silently repeat its last sample. Fail loudly instead.
  num_generated = generated_data.shape[0]
  if num_generated < N:
    raise ValueError(
      f"generated_data has {num_generated} samples but original_data has {N}; "
      "generated_data must contain at least as many samples."
    )

  # Shuffle the indices (without replacement)
  shuffled_indices = random.permutation(key, jnp.arange(N))

  # Split the shuffled indices into train and test parts
  train_size = int(N * train_rate)
  train_idx = shuffled_indices[:train_size]
  test_idx = shuffled_indices[train_size:]

  # Apply the same indices to both datasets to keep them aligned
  train_x = original_data[train_idx]
  train_x_hat = generated_data[train_idx]
  test_x = original_data[test_idx]
  test_x_hat = generated_data[test_idx]

  return train_x, train_x_hat, test_x, test_x_hat

###############################################################################
# Main Discriminator Function
###############################################################################
def discriminative_score_metrics(original_data: Float[Array, 'N T D'],
                                 generated_data: Float[Array, 'N T D'],
                                 key: Optional[PRNGKeyArray] = None,
                                 hidden_dim: Optional[int] = None,
                                 train_iterations: int = 2000,
                                 batch_size: int = 128) -> Tuple[Float, Float, Float]:
  """
  Calculate discriminative score between original and generated data.

  A GRU-based binary classifier is trained on 80% of the samples to tell real
  sequences (label 1) from generated ones (label 0). The score reports how far
  its accuracy on the held-out 20% is from chance.

  Args:
    - original_data: original data with shape (N, T, D)
    - generated_data: generated synthetic data with shape (N, T, D)
    - key: random key for the split, model init and batching
      (optional, defaults to PRNGKey(0))
    - hidden_dim: hidden dimensions for the GRU (optional, defaults to max(D // 2, 1))
    - train_iterations: number of training iterations
    - batch_size: batch size for training (indices are sampled with replacement)

  Returns:
    Tuple of (discriminative_score, fake_accuracy, real_accuracy)

    - discriminative_score: |accuracy - 0.5| * 2 on the held-out set, so it lies
      in [0, 1] (0 = indistinguishable, 1 = completely distinct). Because of the
      absolute value, below-chance accuracy also yields a high score.
    - fake_accuracy: fraction of held-out generated sequences classified as generated
    - real_accuracy: fraction of held-out real sequences classified as real

    All three values are NaN (after a printed warning) when either input is
    empty or when the train/test split leaves one side without samples.
  """
  # Check data before running prediction
  if len(original_data) == 0 or len(generated_data) == 0:
    print("Warning: Empty data provided!")
    return float('nan'), float('nan'), float('nan')

  # -----------------------------
  # 1. Basic parameters
  # -----------------------------
  # Convert once so every later step works on JAX arrays, not just the shape read.
  original_data = jnp.asarray(original_data)
  generated_data = jnp.asarray(generated_data)

  # Unpacking three axes also rejects inputs that are not rank 3.
  _, _, dim = original_data.shape

  if hidden_dim is None:
    hidden_dim = max(dim // 2, 1)  # Ensure hidden_dim is at least 1

  if key is None:
    key = random.PRNGKey(0)

  # Split seeds for different random operations
  split_key, model_key, train_key = random.split(key, 3)

  # -----------------------------
  # 2. Train/Test Split
  # -----------------------------
  train_x, train_x_hat, test_x, test_x_hat = train_test_divide(
    original_data, generated_data, train_rate=0.8, key=split_key
  )

  # With very few samples one side of the split can be empty; batch sampling
  # and accuracy are undefined then, so follow the empty-input convention.
  if train_x.shape[0] == 0 or test_x.shape[0] == 0:
    print("Warning: Not enough samples for a train/test split!")
    return float('nan'), float('nan'), float('nan')

  # -----------------------------
  # 3. Build the Discriminator
  # -----------------------------
  class Discriminator(AbstractBatchableObject):
    # Recurrent cell applied along the time axis
    gru: eqx.nn.GRUCell
    # Maps the final hidden state to a single logit
    dense: eqx.nn.Linear

    @property
    def batch_size(self):
      # Derived from leading axes of the dense weight: a plain (out, in) weight
      # means unbatched; extra leading axes are reported as the batch shape.
      weight_ndim = self.dense.weight.ndim
      if weight_ndim == 2:
        return None
      elif weight_ndim == 3:
        return self.dense.weight.shape[0]
      elif weight_ndim > 3:
        return self.dense.weight.shape[:-2]
      else:
        # Explicit raise so the check also survives `python -O`.
        raise ValueError(f"Unexpected dense weight rank: {weight_ndim}")

    def __init__(self, input_dim: int, hidden_dim: int, key: PRNGKeyArray):
      k1, k2 = random.split(key)
      self.gru = eqx.nn.GRUCell(input_dim, hidden_dim, key=k1)
      self.dense = eqx.nn.Linear(hidden_dim, 1, key=k2)

    # NOTE(review): auto_vmap comes from diffusion_crf; its batching semantics
    # are not visible in this file.
    @auto_vmap
    def __call__(self, xs: Float[Array, 'T D']):
      def step(state, x_t):
        # Only the final hidden state is used, so emit no per-step output.
        return self.gru(x_t, state), None

      init_state = jnp.zeros(self.gru.hidden_size)
      final_state, _ = jax.lax.scan(step, init_state, xs)
      logits = self.dense(final_state)
      return logits

  model = Discriminator(dim, hidden_dim, key=model_key)
  optimizer = optax.adam(learning_rate=1e-3)

  # -----------------------------
  # 4. Loss Function and Training
  # -----------------------------
  def loss_fn(model, x_real_batch, x_fake_batch):
    # Real sequences carry label 1
    logits_real = jax.vmap(model)(x_real_batch)
    loss_real = optax.sigmoid_binary_cross_entropy(
      logits=logits_real,
      labels=jnp.ones_like(logits_real, dtype=jnp.float32)
    )

    # Generated sequences carry label 0
    logits_fake = jax.vmap(model)(x_fake_batch)
    loss_fake = optax.sigmoid_binary_cross_entropy(
      logits=logits_fake,
      labels=jnp.zeros_like(logits_fake, dtype=jnp.float32)
    )

    d_loss = loss_real + loss_fake
    return d_loss.mean()

  @jax.jit
  def train_step(model, opt_state, x_real_batch, x_fake_batch):
    # Differentiates with respect to the first argument (the model).
    loss_value, grads = jax.value_and_grad(loss_fn)(model, x_real_batch, x_fake_batch)
    updates, opt_state = optimizer.update(grads, opt_state, eqx.filter(model, eqx.is_array))
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss_value

  opt_state = optimizer.init(eqx.filter(model, eqx.is_array))

  # Training loop with fresh batches each iteration
  batch_key = train_key
  for _ in range(train_iterations):
    batch_key, sample_key = random.split(batch_key)

    # Sample random batch indices (with replacement)
    batch_indices = random.choice(
      sample_key,
      jnp.arange(train_x.shape[0]),
      shape=(batch_size,),
      replace=True
    )

    # The same indices select the real and the generated batch
    x_real_batch = train_x[batch_indices]
    x_fake_batch = train_x_hat[batch_indices]

    # Update model
    model, opt_state, _ = train_step(model, opt_state, x_real_batch, x_fake_batch)

  # -----------------------------
  # 5. Testing and Accuracy
  # -----------------------------
  logits_real_test = jax.vmap(model)(test_x)
  logits_fake_test = jax.vmap(model)(test_x_hat)

  # Drop the trailing singleton logit axis after the sigmoid
  y_pred_real = jnp.squeeze(jax.nn.sigmoid(logits_real_test), axis=-1)
  y_pred_fake = jnp.squeeze(jax.nn.sigmoid(logits_fake_test), axis=-1)

  y_true_real = jnp.ones_like(y_pred_real, dtype=jnp.int32)
  y_true_fake = jnp.zeros_like(y_pred_fake, dtype=jnp.int32)

  y_pred_all = jnp.concatenate([y_pred_real, y_pred_fake], axis=0)
  y_true_all = jnp.concatenate([y_true_real, y_true_fake], axis=0)

  # Convert to binary predictions
  y_pred_binary = (y_pred_all > 0.5).astype(jnp.int32)

  acc = accuracy_score(y_true_all, y_pred_binary)
  fake_acc = accuracy_score(y_true_fake, (y_pred_fake > 0.5).astype(jnp.int32))
  real_acc = accuracy_score(y_true_real, (y_pred_real > 0.5).astype(jnp.int32))

  # Discriminative score: distance of accuracy from chance, rescaled to [0, 1]
  # (0 = indistinguishable, 1 = completely distinguishable)
  discriminative_score = jnp.abs(acc - 0.5) * 2

  return discriminative_score, fake_acc, real_acc

###############################################################################
# Test Cases
###############################################################################
def run_test_cases():
    """Run the demo scenarios through ``discriminative_score_metrics``.

    Builds several (original, generated) dataset pairs of shape (100, 24, 5),
    trains a discriminator for 500 iterations on each, and prints the score
    and per-class accuracies next to the expected outcome stored with each
    case. Nothing is asserted; the expected ranges are informational only.

    Returns:
        Dict mapping case name to
        ``{"score": ..., "fake_acc": ..., "real_acc": ...}``.
    """
    # Set a master seed for reproducibility
    master_key = random.PRNGKey(42)

    # Set dimensions for test data
    N, T, D = 100, 24, 5  # 100 sequences, 24 timesteps, 5 dimensions

    # Use fewer iterations than the library default for testing speed
    test_iterations = 500

    # Create test cases
    test_cases = {}

    # Case 1: Identical data - Should be indistinguishable (score near 0)
    key, master_key = random.split(master_key)
    original_data = random.normal(key, (N, T, D))
    test_cases["identical"] = {
        "original": original_data,
        "generated": original_data,  # Exactly the same
        "expected": "~0.0 (very low score = indistinguishable)"
    }

    # Case 2: Similar data (small noise) - Should be somewhat distinguishable
    key, master_key = random.split(master_key)
    noise_small = 0.1 * random.normal(key, (N, T, D))
    test_cases["similar"] = {
        "original": original_data,
        "generated": original_data + noise_small,
        "expected": "~0.2-0.4 (low score = somewhat distinguishable)"
    }

    # Case 3: Moderately different data
    key, master_key = random.split(master_key)
    noise_medium = 0.5 * random.normal(key, (N, T, D))
    # Add some trend differences too
    trend = jnp.expand_dims(jnp.linspace(0, 1, T), (0, 2)) * 0.3
    test_cases["moderately_different"] = {
        "original": original_data,
        "generated": original_data + noise_medium + trend,
        "expected": "~0.5-0.7 (moderate score = quite distinguishable)"
    }

    # Case 4: Completely different data - Should be highly distinguishable.
    # A fresh standard-normal draw would follow the same distribution as
    # original_data, so the data is shifted and scaled to make it genuinely
    # different.
    key, master_key = random.split(master_key)
    unrelated_data = 2.0 + 3.0 * random.normal(key, (N, T, D))
    test_cases["completely_different"] = {
        "original": original_data,
        "generated": unrelated_data,
        "expected": "~0.8-1.0 (high score = very distinguishable)"
    }

    # Case 5: Pattern data with different levels of noise
    key, master_key = random.split(master_key)
    t = jnp.linspace(0, 4*jnp.pi, T)

    # Create pattern data with sine waves and phase shifts. Channels are
    # collected and stacked once instead of copying the full array on every
    # `.at[].set`; keys are split in the same order as before, so the values
    # are unchanged.
    samples = []
    for _ in range(N):
        key, subkey = random.split(key)
        channels = [jnp.sin(t) + 0.1*random.normal(subkey, (T,))]

        # Add different patterns in other dimensions
        for d in range(1, D):
            key, subkey = random.split(key)
            channels.append(
                0.5*jnp.sin(t + d*0.5*jnp.pi) + 0.1*random.normal(subkey, (T,))
            )
        samples.append(jnp.stack(channels, axis=-1))  # (T, D)
    pattern_data = jnp.stack(samples)  # (N, T, D)

    # Normalize pattern data to [0, 1] using the global min and max
    pattern_min = jnp.min(pattern_data)
    pattern_max = jnp.max(pattern_data)
    pattern_data = (pattern_data - pattern_min) / (pattern_max - pattern_min)

    # Create different noise levels for pattern data
    key, master_key = random.split(master_key)
    pattern_noise_small = 0.1 * random.normal(key, (N, T, D))
    key, master_key = random.split(master_key)
    pattern_noise_large = 0.5 * random.normal(key, (N, T, D))

    # Add test cases for pattern data
    test_cases["pattern_small_noise"] = {
        "original": pattern_data,
        "generated": pattern_data + pattern_noise_small,
        "expected": "~0.2-0.4 (low score = somewhat distinguishable)"
    }

    test_cases["pattern_large_noise"] = {
        "original": pattern_data,
        "generated": pattern_data + pattern_noise_large,
        "expected": "~0.6-0.8 (high score = quite distinguishable)"
    }

    # Case 6: Structure-preserving transformation (scaled data)
    # This tests if the discriminator can detect changes in scale while preserving patterns
    test_cases["scaled_data"] = {
        "original": pattern_data,
        "generated": pattern_data * 2.0,  # Same pattern, different scale
        "expected": "~0.4-0.6 (moderate score = distinguishable but related)"
    }

    # Run all test cases
    results = {}
    print("Running discriminative metric test cases...")

    for name, case in test_cases.items():
        print(f"\nTest case: {name}")
        print(f"Expected outcome: {case['expected']}")

        key, master_key = random.split(master_key)

        # Run the discriminative score metrics
        disc_score, fake_acc, real_acc = discriminative_score_metrics(
            case["original"],
            case["generated"],
            key=key,
            train_iterations=test_iterations,
            batch_size=min(32, case["original"].shape[0])  # Adjust batch size for small samples
        )

        results[name] = {"score": disc_score, "fake_acc": fake_acc, "real_acc": real_acc}

        print(f"Discriminative score: {disc_score:.4f}")
        print(f"Fake accuracy: {fake_acc:.4f}, Real accuracy: {real_acc:.4f}")
        print(f"Interpretation: {'Easy' if disc_score > 0.8 else 'Moderate' if disc_score > 0.4 else 'Hard'} to distinguish")

    print("\n==== Summary of Results ====")
    for name, metrics in results.items():
        print(f"{name}: Score={metrics['score']:.4f}, Fake Acc={metrics['fake_acc']:.4f}, Real Acc={metrics['real_acc']:.4f}")

    return results

################################################################################################################

if __name__ == '__main__':
    # Run the test cases exactly once; the plot below is optional.
    results = run_test_cases()

    try:
        # Imported inside the try so a missing plotting dependency only
        # disables the chart instead of aborting the script.
        import matplotlib.pyplot as plt
        import numpy as np
    except ImportError as e:
        print(f"Visualization error: {e}")
    else:
        # Plot the results as grouped bars, one group per test case
        plt.figure(figsize=(10, 6))
        names = list(results.keys())
        # Convert to plain floats so matplotlib receives ordinary numbers
        scores = [float(results[name]["score"]) for name in names]
        fake_accs = [float(results[name]["fake_acc"]) for name in names]
        real_accs = [float(results[name]["real_acc"]) for name in names]

        x = np.arange(len(names))
        width = 0.25

        plt.bar(x - width, scores, width, label='Discriminative Score')
        plt.bar(x, fake_accs, width, label='Fake Accuracy')
        plt.bar(x + width, real_accs, width, label='Real Accuracy')

        plt.ylabel('Score')
        plt.title('Discriminative Score Metrics')
        plt.xticks(x, names, rotation=45, ha='right')
        plt.legend()
        plt.tight_layout()
        plt.show()

    # The guide does not depend on plotting, so it is printed either way.
    print("\nInterpretation Guide:")
    print("- Discriminative score ranges from 0 to 1:")
    print("  - 0.0-0.2: Very similar/indistinguishable data")
    print("  - 0.2-0.4: Somewhat distinguishable")
    print("  - 0.4-0.6: Moderately distinguishable")
    print("  - 0.6-0.8: Quite distinguishable")
    print("  - 0.8-1.0: Very different/easily distinguishable")
    print("\n- When evaluating generated time series:")
    print("  - Lower scores indicate better generation quality (real vs. generated is harder to distinguish)")
    print("  - Higher scores indicate the model is not capturing the real data distribution well")

    # Example of how to use with real data
    print("\nTo use with your own data:")
    print("disc_score, fake_acc, real_acc = discriminative_score_metrics(original_data, generated_data)")
    print("print(f'Discriminative Score: {disc_score:.4f}')")