import argparse
import os
import jax
import jax.numpy as jnp
import numpy as np
import optax
import flax
from flax import linen as nn
from flax.training.train_state import TrainState
from jax import jit, random, value_and_grad
from tqdm import tqdm
import matplotlib.pyplot as plt
from collections import Counter
import time
from datasets import Dataset
from flax.jax_utils import replicate, unreplicate
from jax import lax
from typing import List

def pad_sequences(sequences: List[np.ndarray], pad_value: int = 0) -> np.ndarray:
    """Right-pad sequences to a common length and stack them into one 2-D array.

    Args:
        sequences: sequences of possibly different lengths (lists or 1-D arrays).
        pad_value: value written into the tail of every shorter sequence.

    Returns:
        Array of shape ``(len(sequences), longest_length)``; its dtype is whatever
        ``np.stack`` infers from the padded rows.

    Raises:
        ValueError: if ``sequences`` is empty.
    """
    longest = max(map(len, sequences))
    padded_rows = []
    for row in sequences:
        shortfall = longest - len(row)
        padded_rows.append(np.pad(row, (0, shortfall), constant_values=pad_value))
    return np.stack(padded_rows)

# --- Model Definitions ---

def rope_dot_product_attention(query, key, value,
                              bias=None, dropout_rng=None, dropout_rate=0.0,
                              deterministic=False, dtype=jnp.float32, precision=None,
                              mask=None, broadcast_dropout=True):
    """Dot-product attention with Rotary Position Embeddings (RoPE) on query and key.

    Intended as the ``attention_fn`` of ``nn.MultiHeadDotProductAttention``, which
    passes query/key/value laid out as ``[batch..., length, num_heads, head_dim]``.
    Each adjacent feature pair ``(2j, 2j+1)`` of a token at position ``p`` (its index
    along the length axis, axis -3) is rotated by the angle
    ``p * 10000**(-2j / head_dim)``. The q·k scores then depend only on relative
    position.

    Args:
        query: ``[batch..., q_length, num_heads, head_dim]``; ``head_dim`` must be even.
        key: ``[batch..., kv_length, num_heads, head_dim]``.
        value: ``[batch..., kv_length, num_heads, v_dim]``; not rotated.
        bias, dropout_rng, dropout_rate, deterministic, dtype, precision: forwarded
            unchanged to ``nn.dot_product_attention``.
        mask: optional boolean attention mask (False = masked out), forwarded to
            ``nn.dot_product_attention``.
        broadcast_dropout: forwarded to ``nn.dot_product_attention``.

    Returns:
        The output of ``nn.dot_product_attention`` on the rotated query/key.

    Raises:
        ValueError: if ``head_dim`` is odd.
    """
    # `mask` and `broadcast_dropout` must be declared parameters. Recent Flax
    # versions inspect attention_fn's signature and silently drop keyword arguments
    # it does not declare. Without `mask` here, padding masking would be disabled.
    head_dim = query.shape[-1]
    if head_dim % 2 != 0:
        raise ValueError(f"head_dim must be even for RoPE, got {head_dim}")

    # theta_j = 10000**(-2j/d) for j = 0 .. d/2 - 1
    inv_freq = 10000.0 ** (-jnp.arange(0, head_dim, 2, dtype=jnp.float32) / head_dim)

    def _apply_rope(x):
        # Positions index the length axis (-3), not the heads axis.
        seq_len = x.shape[-3]
        positions = jnp.arange(seq_len, dtype=jnp.float32)
        angles = jnp.outer(positions, inv_freq)       # (seq_len, head_dim/2)
        angles = jnp.repeat(angles, 2, axis=-1)       # (seq_len, head_dim); each pair shares an angle
        # (seq_len, 1, head_dim) broadcasts over the heads axis and any leading batch axes.
        cos = jnp.cos(angles)[:, None, :]
        sin = jnp.sin(angles)[:, None, :]
        x_even = x[..., ::2]
        x_odd = x[..., 1::2]
        # (x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...)
        x_rotated = jnp.stack([-x_odd, x_even], axis=-1).reshape(x.shape)
        return (x * cos + x_rotated * sin).astype(x.dtype)

    return nn.dot_product_attention(
        _apply_rope(query), _apply_rope(key), value,
        bias=bias, mask=mask, broadcast_dropout=broadcast_dropout,
        dropout_rng=dropout_rng, dropout_rate=dropout_rate,
        deterministic=deterministic, dtype=dtype, precision=precision,
    )

class TransformerEncoderLayer(nn.Module):
    """One pre-LayerNorm Transformer encoder block.

    Attention path: LayerNorm -> MultiHeadDotProductAttention (with
    ``rope_dot_product_attention`` as attention_fn) -> residual add.
    Feed-forward path: LayerNorm -> Dense(hidden_dim) -> GELU -> Dropout ->
    Dense(embedding_dim) -> residual add.

    Both LayerNorms are created without a learnable scale or bias. Submodules are
    auto-named in creation order (e.g. ``LayerNorm_0``, ``MultiHeadDotProductAttention_0``,
    ``LayerNorm_1``, ``Dense_0``, ...). Reordering them renames parameters and
    invalidates saved checkpoints.
    """
    embedding_dim: int  # model width; also passed to attention as qkv_features
    num_heads: int  # number of attention heads; qkv_features is split across them
    hidden_dim: int  # inner width of the feed-forward block
    dropout_rate: float = 0.1  # passed to MultiHeadDotProductAttention and to the FFN Dropout

    @nn.compact
    def __call__(self, x, mask=None, deterministic: bool = True):
        """Apply the encoder block.

        Args:
            x: input activations; presumably (batch, seq_len, embedding_dim) as
                produced by BertModel — TODO confirm for other callers.
            mask: optional attention mask forwarded to MultiHeadDotProductAttention;
                BertModel passes a boolean (batch, 1, 1, seq_len) mask.
            deterministic: when True, dropout is disabled.

        Returns:
            Array with the same shape as ``x`` (required by the residual additions).
        """
        # Attention block
        x_norm = nn.LayerNorm(use_scale=False, use_bias=False)(x)
        # Record the normalized attention input in the 'intermediates' collection;
        # make_stuff().get_mha_inputs reads it back as
        # intermediates['TransformerEncoderLayer_<i>']['mha_input'][0].
        self.sow('intermediates', 'mha_input', x_norm)
        # NOTE(review): recent Flax versions forward `mask` only if attention_fn
        # declares a `mask` parameter — confirm rope_dot_product_attention does.
        attn_output = nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads,
            qkv_features=self.embedding_dim,
            dropout_rate=self.dropout_rate,
            attention_fn=rope_dot_product_attention
        )(x_norm, x_norm, mask=mask, deterministic=deterministic)

        # Residual connection around the attention sub-block
        post_attention = x + attn_output

        # Feed-forward block (pre-norm)
        ffn_norm = nn.LayerNorm(use_scale=False, use_bias=False)(post_attention)
        ffn_hidden = nn.Dense(self.hidden_dim)(ffn_norm)
        ffn_post_activation = nn.gelu(ffn_hidden)
        # Dropout on the FFN hidden activations (no-op when deterministic=True)
        ffn_post_activation = nn.Dropout(rate=self.dropout_rate)(
            ffn_post_activation, deterministic=deterministic
        )
        ffn_output = nn.Dense(self.embedding_dim)(ffn_post_activation)

        # Residual connection around the feed-forward sub-block
        post_ffn = post_attention + ffn_output
        return post_ffn

class BertModel(nn.Module):
    """BERT-style sequence classifier.

    Pipeline: token embedding, then a learned [CLS] vector prepended at position 0,
    then embedding dropout, then ``num_layers`` TransformerEncoderLayer blocks, then a
    Dense head applied to the [CLS] position. No position embedding is added in this
    module; TransformerEncoderLayer uses ``rope_dot_product_attention`` (RoPE)
    as its attention function.

    Submodules are auto-named in creation order (``Embed_0``, ``TransformerEncoderLayer_<i>``,
    ``Dense_0``). ``make_stuff().get_mha_inputs`` relies on the
    ``TransformerEncoderLayer_<i>`` names.
    """
    embedding_dim: int = 64  # width of token embeddings, the [CLS] vector and every encoder layer
    num_heads: int = 0  # attention heads per layer; 0 is only a placeholder — main() always passes --num-heads
    num_layers: int = 0  # number of encoder layers; main() always passes --num-layers
    hidden_dim: int = 256  # feed-forward inner width of each encoder layer
    num_classes: int = 219  # number of output logits
    vocab_size: int = 30522  # rows in the token embedding table (presumably the tokenizer's vocab size — confirm)
    max_seq_len: int = 256  # not read in __call__; main() only uses it to shape the init dummy input
    dropout_rate: float = 0.1  # embedding dropout rate, also passed to every encoder layer

    @nn.compact
    def __call__(self, x, pad_mask=None, deterministic: bool = True):
        """Compute class logits for a batch of token-id sequences.

        Args:
            x: integer token ids of shape (batch_size, seq_len).
            pad_mask: optional boolean mask of shape (batch_size, 1, 1, seq_len), True
                for real tokens. An always-True column is prepended for the [CLS]
                position, and the result is passed to every encoder layer as ``mask``.
            deterministic: when True, dropout is disabled. When False, ``apply`` must
                be given a 'dropout' PRNG stream (train_step passes one).

        Returns:
            Logits of shape (batch_size, num_classes), read from the [CLS] position.
        """
        # Input x: (batch_size, seq_len) -> (batch_size, seq_len, embedding_dim)
        x = nn.Embed(self.vocab_size, self.embedding_dim)(x)
        
        # Prepend a learned [CLS] vector (zero-initialized), shared across the batch.
        cls_token = self.param('cls_token', nn.initializers.zeros, (1, 1, self.embedding_dim))
        cls_token = jnp.tile(cls_token, (x.shape[0], 1, 1))
        x = jnp.concatenate([cls_token, x], axis=1)
        
        x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)

        # --- Attention masking: extend the padding mask to cover the [CLS] position ---
        attn_mask = None
        if pad_mask is not None:
            # `pad_mask` is a boolean mask: (batch, 1, 1, seq_len), True for valid tokens.
            # [CLS] is always a valid key, so every query row has at least one
            # unmasked key, even for sequences made entirely of padding.
            cls_mask = jnp.ones((pad_mask.shape[0], 1, 1, 1), dtype=bool)
            # Result: (batch, 1, 1, seq_len + 1), aligned with the [CLS]-prefixed sequence.
            attn_mask = jnp.concatenate([cls_mask, pad_mask], axis=-1)
            # Keep the mask boolean; it is passed as `mask=` to MultiHeadDotProductAttention.
        
        # Transformer Encoder Stack
        for _ in range(self.num_layers):
            x = TransformerEncoderLayer(
                embedding_dim=self.embedding_dim,
                num_heads=self.num_heads,
                hidden_dim=self.hidden_dim,
                dropout_rate=self.dropout_rate
            )(x, mask=attn_mask, deterministic=deterministic)
            
        # Extract CLS token representation for classification
        x = x[:, 0, :]
        x = nn.Dense(self.num_classes)(x)
        return x

def make_stuff(model):
    """Build analysis helpers bound to ``model``.

    Args:
        model: a BertModel-like Flax module that exposes ``num_layers``. Its encoder
            layers must be auto-named ``TransformerEncoderLayer_{i}`` and must sow
            ``'mha_input'`` into the ``'intermediates'`` collection.

    Returns:
        ``{"get_mha_inputs": get_mha_inputs}``.
    """

    def get_mha_inputs(params, dataset, rng, batch_size: int):
        """Sample ``batch_size`` examples and return every layer's sown attention input.

        Args:
            params: parameter tree for ``model`` (applied as ``{"params": params}``).
            dataset: mapping whose ``"token_ids"`` entry holds variable-length token
                sequences; token id 0 is treated as padding.
            rng: JAX PRNG key used to pick examples without replacement.
            batch_size: number of examples to sample; must not exceed the dataset size.

        Returns:
            A list of ``model.num_layers`` arrays. Entry ``i`` is the value layer ``i``
            sowed as ``'mha_input'``, computed with dropout disabled.

        Raises:
            KeyError: if a layer's ``'mha_input'`` intermediate was not recorded.
        """
        # Read the column once. For a datasets.Dataset (rather than a dict), each
        # dataset["token_ids"] access may materialize the whole column.
        all_token_ids = dataset["token_ids"]
        indices = random.choice(rng, len(all_token_ids), shape=(batch_size,), replace=False)
        # One device->host transfer instead of one per sampled index.
        token_ids = [all_token_ids[int(i)] for i in np.asarray(indices)]

        inputs = jnp.array(pad_sequences(token_ids, pad_value=0))
        pad_mask = (inputs != 0)[:, None, None, :]

        _, variables = model.apply(
            {"params": params}, inputs, pad_mask=pad_mask,
            mutable=['intermediates'], deterministic=True,
        )
        intermediate_vars = variables.get('intermediates', {})

        activations = []
        # Submodules created in a loop are auto-named 'TransformerEncoderLayer_0', ...;
        # sow() stores each value in a tuple, hence the trailing [0].
        for i in range(model.num_layers):
            layer_key = f'TransformerEncoderLayer_{i}'
            try:
                activations.append(intermediate_vars[layer_key]['mha_input'][0])
            except KeyError as err:
                raise KeyError(f"Could not find 'mha_input' for layer {i} ('{layer_key}'). "
                               f"Available intermediates: {list(intermediate_vars.keys())}") from err
        # Every iteration either appends or raises, so len(activations) == model.num_layers.
        return activations

    return {
        "get_mha_inputs": get_mha_inputs,
    }

# Modified train_step for pmap compatibility
def train_step(state, batch, rng_key):
    """Run one data-parallel optimization step (wrap with ``jax.pmap(..., axis_name='batch')``).

    Args:
        state: this device's replica of the TrainState.
        batch: this device's shard, with keys 'inputs', 'pad_mask', 'labels' and
            'valid_mask' (1.0 for real examples, 0.0 for filler rows).
        rng_key: per-device PRNG key used for the 'dropout' stream.

    Returns:
        ``(new_state, metrics)``. ``metrics['loss']`` and ``metrics['accuracy']`` are
        averaged over the valid examples of the whole global batch, summed across the
        'batch' axis.
    """
    weights = batch['valid_mask']
    targets = batch['labels']

    def masked_global_mean(values):
        # Weighted sum on this device, then across devices, over the global valid count.
        total = lax.psum(jnp.sum(values * weights), axis_name='batch')
        count = lax.psum(jnp.sum(weights), axis_name='batch')
        return total / count

    def objective(params):
        predictions = state.apply_fn(
            {'params': params}, batch['inputs'], batch['pad_mask'],
            rngs={'dropout': rng_key}, deterministic=False,
        )
        targets_1h = jax.nn.one_hot(targets, num_classes=predictions.shape[-1])
        nll = -jnp.sum(targets_1h * jax.nn.log_softmax(predictions), axis=-1)
        return masked_global_mean(nll), predictions

    (loss, logits), grads = value_and_grad(objective, has_aux=True)(state.params)
    new_state = state.apply_gradients(grads=lax.pmean(grads, axis_name='batch'))

    hits = (jnp.argmax(logits, -1) == targets).astype(jnp.float32)
    return new_state, {'loss': loss, 'accuracy': masked_global_mean(hits)}

# Modified eval_step for pmap compatibility
def eval_step(state, batch):
    """Evaluate one data-parallel batch (wrap with ``jax.pmap(..., axis_name='batch')``).

    Args:
        state: this device's replica of the TrainState.
        batch: this device's shard, with keys 'inputs', 'pad_mask', 'labels' and
            'valid_mask' (1.0 for real examples, 0.0 for filler rows).

    Returns:
        ``{'loss': ..., 'accuracy': ...}``, each averaged over the valid examples of the
        whole global batch, summed across the 'batch' axis. Dropout is disabled.
    """
    weights = batch['valid_mask']
    targets = batch['labels']

    def masked_global_mean(values):
        # Weighted sum on this device, then across devices, over the global valid count.
        total = lax.psum(jnp.sum(values * weights), axis_name='batch')
        count = lax.psum(jnp.sum(weights), axis_name='batch')
        return total / count

    predictions = state.apply_fn({'params': state.params}, batch['inputs'], batch['pad_mask'],
                                 deterministic=True)
    targets_1h = jax.nn.one_hot(targets, num_classes=predictions.shape[-1])
    nll = -jnp.sum(targets_1h * jax.nn.log_softmax(predictions), axis=-1)
    hits = (jnp.argmax(predictions, -1) == targets).astype(jnp.float32)
    return {'loss': masked_global_mean(nll), 'accuracy': masked_global_mean(hits)}

def _make_sharded_batch(token_ids, labels, batch_size, num_devices, pad_id):
    """Pad one (possibly short) batch to ``batch_size`` rows and split it across devices.

    Args:
        token_ids: variable-length token-id sequences, at most ``batch_size`` of them.
        labels: integer class labels aligned with ``token_ids``.
        batch_size: global batch size; must be divisible by ``num_devices``.
        num_devices: number of local devices the batch is split across.
        pad_id: token id used for sequence padding and for filler rows.

    Returns:
        ``(batch, num_valid)``. ``batch`` maps 'inputs', 'pad_mask', 'labels' and
        'valid_mask' to arrays with leading shape
        ``(num_devices, batch_size // num_devices)``. ``num_valid`` is the number of
        real (non-filler) examples.

    Raises:
        ValueError: if more than ``batch_size`` sequences are given.
    """
    num_valid = len(token_ids)
    num_fill = batch_size - num_valid
    if num_fill < 0:
        raise ValueError(f"got {num_valid} examples for a batch of size {batch_size}")

    # Filler rows keep the global batch shape fixed. valid_mask zeroes their
    # contribution to loss and accuracy inside train_step / eval_step.
    rows = list(token_ids) + [[pad_id]] * num_fill
    label_arr = np.concatenate([np.asarray(labels, dtype=np.int32),
                                np.zeros(num_fill, dtype=np.int32)])
    valid = np.concatenate([np.ones(num_valid, dtype=np.float32),
                            np.zeros(num_fill, dtype=np.float32)])

    # NOTE(review): rows are padded to the longest sequence in this batch, so every
    # distinct length is a new input shape and triggers a fresh pmap compilation.
    x_batch = jnp.asarray(pad_sequences(rows, pad_value=pad_id))
    pad_mask = (x_batch != pad_id)[:, None, None, :]

    per_device = batch_size // num_devices

    def shard(array):
        return jnp.reshape(array, (num_devices, per_device) + array.shape[1:])

    batch = {
        'inputs': shard(x_batch),
        'pad_mask': shard(pad_mask),
        'labels': shard(jnp.asarray(label_arr)),
        'valid_mask': shard(jnp.asarray(valid)),
    }
    return batch, num_valid


def _weighted_mean(step_records, key):
    """Average metric ``key`` over steps, weighting each step by its count of real examples.

    Each per-step metric is already a mean over that step's valid examples.
    Weighting by the valid count stops the short final batch from counting as much
    as a full one. For evaluation this gives the exact per-example average.
    Returns NaN when no steps ran.
    """
    total = sum(count for _, count in step_records)
    if total == 0:
        return float("nan")
    return sum(float(metrics[key]) * count for metrics, count in step_records) / total


def _save_metrics_plot(out_dir, train_losses, test_losses, train_accuracies, test_accuracies):
    """Save per-epoch loss and accuracy curves side by side to ``out_dir/metrics_plot.png``."""
    fig, axs = plt.subplots(1, 2, figsize=(12, 5))
    # Loss subplot
    axs[0].plot(train_losses, label='Train Loss')
    axs[0].plot(test_losses, label='Test Loss')
    axs[0].set_xlabel('Epoch')
    axs[0].set_ylabel('Loss')
    axs[0].set_title('Training and Test Loss over Epochs')
    axs[0].legend()
    # Accuracy subplot
    axs[1].plot(train_accuracies, label='Train Accuracy')
    axs[1].plot(test_accuracies, label='Test Accuracy')
    axs[1].set_xlabel('Epoch')
    axs[1].set_ylabel('Accuracy')
    axs[1].set_title('Training and Test Accuracy over Epochs')
    axs[1].legend()
    plt.tight_layout()
    plt.savefig(os.path.join(out_dir, "metrics_plot.png"))
    plt.close(fig)


def main():
    """Train a BERT-style classifier on DBpedia parquet files with data-parallel pmap.

    After every epoch, evaluates on the full test set and writes the serialized
    parameters to ``--ckpt-path``, with the epoch metrics encoded in the file name.
    It also refreshes ``metrics_plot.png`` in that directory.
    """
    parser = argparse.ArgumentParser(description="Train a BERT-style model on DBpedia")
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument("--optimizer", choices=["sgd", "adam", "adamw"], default="adamw")
    parser.add_argument("--learning-rate", type=float, default=5e-4)  # alternative: 2e-5
    parser.add_argument("--weight-decay", type=float, default=0.01)
    parser.add_argument("--num-layers", type=int, required=True)
    parser.add_argument("--num-heads", type=int, required=True)
    parser.add_argument("--embedding-dim", type=int, default=48)
    parser.add_argument("--hidden-dim", type=int, default=192)
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--epochs", type=int, default=5)
    parser.add_argument("--ckpt-path", type=str, default="./checkpoints", help="Path to ckpt directory")
    parser.add_argument("--train-dataset-path", type=str, required=True, help="Path to train parquet")
    parser.add_argument("--test-dataset-path", type=str, required=True, help="Path to test parquet")
    args = parser.parse_args()

    # Fail fast with a CLI error instead of an assertion deep inside tracing.
    if args.num_heads <= 0:
        parser.error("--num-heads must be a positive integer")
    if args.embedding_dim % args.num_heads != 0:
        parser.error("--embedding-dim must be divisible by --num-heads")
    if (args.embedding_dim // args.num_heads) % 2 != 0:
        parser.error("--embedding-dim / --num-heads must be even (required by RoPE)")
    if args.batch_size <= 0:
        parser.error("--batch-size must be a positive integer")

    num_devices = jax.local_device_count()
    print(f"Using {num_devices} devices")
    if args.batch_size % num_devices != 0:
        parser.error("Batch size must be divisible by the number of devices")

    os.makedirs(args.ckpt_path, exist_ok=True)
    rng = random.PRNGKey(args.seed)

    # --- Data Loading and Preprocessing ---
    train_dataset = Dataset.from_parquet(args.train_dataset_path)
    test_dataset = Dataset.from_parquet(args.test_dataset_path)
    x_train = train_dataset['input_ids']
    y_train = train_dataset['label']
    x_test = test_dataset['input_ids']
    y_test = test_dataset['label']
    num_classes = 219
    PAD = 0
    train_ds = {"token_ids": x_train, "labels": np.array(y_train, dtype=np.int32)}
    test_ds = {"token_ids": x_test, "labels": np.array(y_test, dtype=np.int32)}

    # --- Model and Optimizer Initialization ---
    model = BertModel(
        embedding_dim=args.embedding_dim,
        hidden_dim=args.hidden_dim,
        num_layers=args.num_layers,
        num_heads=args.num_heads,
        num_classes=num_classes,
        vocab_size=30522,
        max_seq_len=256,
        dropout_rate=0.1
    )

    num_train_examples = len(train_ds["labels"])
    num_test_examples = len(test_ds["labels"])
    steps_per_epoch = (num_train_examples + args.batch_size - 1) // args.batch_size
    num_train_steps = steps_per_epoch * args.epochs
    lr_schedule_fn = optax.linear_schedule(
        init_value=args.learning_rate, end_value=0, transition_steps=num_train_steps
    )

    if args.optimizer == "sgd":
        tx = optax.sgd(learning_rate=lr_schedule_fn, momentum=0.9)
    elif args.optimizer == "adamw":
        tx = optax.adamw(learning_rate=lr_schedule_fn, weight_decay=args.weight_decay)
    else:  # adam
        tx = optax.adam(learning_rate=lr_schedule_fn)

    rng, init_rng = random.split(rng)

    dummy_input = jnp.zeros((1, model.max_seq_len), dtype=jnp.int32)
    init_params = model.init({'params': init_rng, 'dropout': init_rng}, dummy_input)['params']
    train_state = TrainState.create(apply_fn=model.apply, params=init_params, tx=tx)

    # Replicate across devices
    train_state = replicate(train_state)

    # Pmap the steps
    p_train_step = jax.pmap(train_step, axis_name='batch')
    p_eval_step = jax.pmap(eval_step, axis_name='batch')

    # --- Training Loop ---
    train_losses, train_accuracies = [], []
    test_losses, test_accuracies = [], []
    train_tokens = train_ds["token_ids"]
    test_tokens = test_ds["token_ids"]

    for epoch in range(args.epochs):
        print(f"\n--- Epoch {epoch+1}/{args.epochs} ---")
        rng, epoch_rng = random.split(rng)

        # Shuffle on device, then pull the permutation to the host once. Indexing with
        # device scalars would otherwise force a device->host sync per example.
        perm = np.asarray(jax.random.permutation(epoch_rng, num_train_examples))

        # Train
        step_records = []
        pbar = tqdm(range(0, num_train_examples, args.batch_size), desc="Training")
        for start in pbar:
            idx = perm[start:start + args.batch_size]
            batch, num_valid = _make_sharded_batch(
                [train_tokens[int(j)] for j in idx], train_ds["labels"][idx],
                args.batch_size, num_devices, PAD,
            )
            rng, step_rng = random.split(rng)
            step_rngs = random.split(step_rng, num_devices)
            train_state, metrics = p_train_step(train_state, batch, step_rngs)
            metrics = unreplicate(metrics)
            step_records.append((metrics, num_valid))
            pbar.set_postfix({
                "loss": f"{float(metrics['loss']):.4f}",
                "acc": f"{float(metrics['accuracy']):.4f}",
            })

        epoch_train_loss = _weighted_mean(step_records, 'loss')
        epoch_train_acc = _weighted_mean(step_records, 'accuracy')
        train_losses.append(epoch_train_loss)
        train_accuracies.append(epoch_train_acc)

        # Evaluate
        step_records = []
        pbar = tqdm(range(0, num_test_examples, args.batch_size), desc="Evaluating")
        for start in pbar:
            end = min(start + args.batch_size, num_test_examples)
            batch, num_valid = _make_sharded_batch(
                list(test_tokens[start:end]), test_ds["labels"][start:end],
                args.batch_size, num_devices, PAD,
            )
            metrics = unreplicate(p_eval_step(train_state, batch))
            step_records.append((metrics, num_valid))
            pbar.set_postfix({
                "loss": f"{float(metrics['loss']):.4f}",
                "acc": f"{float(metrics['accuracy']):.4f}",
            })

        epoch_test_loss = _weighted_mean(step_records, 'loss')
        epoch_test_acc = _weighted_mean(step_records, 'accuracy')
        test_losses.append(epoch_test_loss)
        test_accuracies.append(epoch_test_acc)

        print(f"Epoch {epoch+1} Summary: "
              f"Train Loss={epoch_train_loss:.4f}, Train Acc={epoch_train_acc:.4f} | "
              f"Test Loss={epoch_test_loss:.4f}, Test Acc={epoch_test_acc:.4f}")

        # Save checkpoint (parameters only, taken from the first device's replica)
        ckpt_name = (
            f"dbpedia_bert_seed{args.seed}_"
            f"opt_{args.optimizer}_lr_{args.learning_rate}_num_layers_{model.num_layers}_num_heads_{model.num_heads}_hidden_dim_{model.hidden_dim}_embedding_dim_{model.embedding_dim}_"
            f"epoch{epoch+1}_trainloss_{epoch_train_loss:.4f}_testloss_{epoch_test_loss:.4f}_trainacc_{epoch_train_acc:.4f}_testacc_{epoch_test_acc:.4f}"
        )
        with open(os.path.join(args.ckpt_path, ckpt_name), "wb") as f:
            f.write(flax.serialization.to_bytes(unreplicate(train_state).params))

        # Plot and save metrics
        _save_metrics_plot(args.ckpt_path, train_losses, test_losses,
                           train_accuracies, test_accuracies)

if __name__ == "__main__":
    main()