import argparse
import os
import jax
import jax.numpy as jnp
import numpy as np
import optax
import flax
from flax import linen as nn
from flax.training.train_state import TrainState
from jax import jit, random, value_and_grad
from tqdm import tqdm
import matplotlib.pyplot as plt
from collections import Counter
import time

# Import data loading functions to be compatible with agnews_bert_train.py
from .datamodule import load_agnews_data, pad_sequences
from .utils import flatten_params, rngmix, timeblock


# --- Model Definitions ---

def apply_rotary_embedding(x, base=10000.0):
    """Apply Rotary Position Embeddings (RoPE) along the sequence axis of ``x``.

    Args:
        x: Array laid out as ``(..., seq_len, num_heads, head_dim)``; the
            sequence axis is third from the end and ``head_dim`` must be even.
        base: Frequency base; feature pair ``i`` rotates by
            ``base ** (-2i / head_dim)`` radians per position.

    Returns:
        Array with the same shape and dtype as ``x``. At sequence position
        ``p`` each adjacent feature pair ``(2i, 2i + 1)`` is rotated by
        ``p * base ** (-2i / head_dim)`` radians; all heads share the angles.

    Raises:
        ValueError: If ``head_dim`` is odd.
    """
    seq_len, head_dim = x.shape[-3], x.shape[-1]
    if head_dim % 2 != 0:
        raise ValueError(f"head_dim must be even for RoPE, got {head_dim}")

    inv_freq = base ** (-jnp.arange(0, head_dim, 2, dtype=jnp.float32) / head_dim)
    positions = jnp.arange(seq_len, dtype=jnp.float32)
    # (seq_len, head_dim // 2) -> (seq_len, head_dim): each angle serves one pair.
    angles = jnp.repeat(jnp.outer(positions, inv_freq), 2, axis=-1)
    # (seq_len, 1, head_dim) broadcasts over the heads axis and any batch axes,
    # so the rotation depends on the sequence position only.
    cos = jnp.cos(angles)[:, None, :]
    sin = jnp.sin(angles)[:, None, :]

    x_even = x[..., ::2]
    x_odd = x[..., 1::2]
    # Interleave (-x_odd, x_even) so out[2i] = x[2i]*cos - x[2i+1]*sin and
    # out[2i+1] = x[2i+1]*cos + x[2i]*sin: a 2-D rotation of every pair.
    x_rotated = jnp.stack([-x_odd, x_even], axis=-1).reshape(x.shape)
    return (x * cos + x_rotated * sin).astype(x.dtype)


def rope_dot_product_attention(query, key, value,
                               bias=None, dropout_rng=None, dropout_rate=0.0,
                               deterministic=False, dtype=jnp.float32, precision=None,
                               mask=None, broadcast_dropout=True):
    """Dot-product attention with Rotary Position Embeddings applied to q and k.

    Intended as the ``attention_fn`` of ``nn.MultiHeadDotProductAttention``.

    Args:
        query: ``(batch..., q_len, num_heads, head_dim)`` — the layout
            ``nn.dot_product_attention`` takes. Positions come from the
            ``q_len`` axis (third from the end), never from the heads axis.
        key: ``(batch..., kv_len, num_heads, head_dim)``; rotated like ``query``.
        value: ``(batch..., kv_len, num_heads, v_dim)``; not rotated.
        bias, mask, broadcast_dropout, dropout_rng, dropout_rate,
        deterministic, dtype, precision: Forwarded unchanged to
            ``nn.dot_product_attention``.

    Returns:
        The result of ``nn.dot_product_attention`` on the rotated query/key.
    """
    # NOTE(review): MultiHeadDotProductAttention calls attention_fn with
    # `mask=` and `broadcast_dropout=`. Recent Flax releases drop kwargs the
    # function does not declare, so without `mask` here the padding mask would
    # be silently ignored (older releases raise TypeError instead). Confirm
    # against the pinned Flax version.
    query_rope = apply_rotary_embedding(query)
    key_rope = apply_rotary_embedding(key)
    return nn.dot_product_attention(
        query_rope, key_rope, value,
        bias=bias, mask=mask, broadcast_dropout=broadcast_dropout,
        dropout_rng=dropout_rng, dropout_rate=dropout_rate,
        deterministic=deterministic, dtype=dtype, precision=precision,
    )


class TransformerEncoderLayer(nn.Module):
    """One pre-LayerNorm Transformer encoder block with RoPE self-attention.

    Computes ``h = x + MHA(LN(x))`` and then
    ``h + Dense(Dropout(GELU(Dense(LN(h)))))``.

    Attributes:
        embedding_dim: ``qkv_features`` of the attention module and output
            width of the second feed-forward Dense layer.
        num_heads: Number of attention heads.
        hidden_dim: Width of the first feed-forward Dense layer.
        dropout_rate: Dropout rate passed to the attention module and used
            for the feed-forward hidden activations.
    """
    embedding_dim: int
    num_heads: int
    hidden_dim: int
    dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, x, mask=None, deterministic: bool = True):
        # Flax auto-names submodules by creation order (LayerNorm_0,
        # MultiHeadDotProductAttention_0, LayerNorm_1, Dense_0, Dropout_0,
        # Dense_1); the parameter tree and saved checkpoints depend on it.

        # --- self-attention sub-block ---
        attn_in = nn.LayerNorm(use_scale=True, use_bias=True)(x)
        # Recorded so callers can read it back with mutable=['intermediates'].
        self.sow('intermediates', 'mha_input', attn_in)
        self_attention = nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads,
            qkv_features=self.embedding_dim,
            dropout_rate=self.dropout_rate,
            attention_fn=rope_dot_product_attention,
        )
        h = x + self_attention(attn_in, attn_in, mask=mask,
                               deterministic=deterministic)

        # --- position-wise feed-forward sub-block ---
        ffn = nn.LayerNorm(use_scale=True, use_bias=True)(h)
        ffn = nn.gelu(nn.Dense(self.hidden_dim)(ffn))
        ffn = nn.Dropout(rate=self.dropout_rate)(ffn, deterministic=deterministic)
        ffn = nn.Dense(self.embedding_dim)(ffn)
        return h + ffn


class BertModel(nn.Module):
    """BERT-style classifier: token embedding, learned [CLS] slot, RoPE encoders.

    No absolute positional embedding is added; position enters only through
    the rotary embedding inside each layer's attention. The logits come from
    a Dense projection of the final hidden state at the prepended [CLS] slot.

    Attributes:
        embedding_dim: Token-embedding / model width.
        num_heads: Attention heads per encoder layer.
        num_layers: Number of stacked TransformerEncoderLayer blocks.
        hidden_dim: Feed-forward intermediate width in each layer.
        num_classes: Number of output logits.
        vocab_size: Rows of the token-embedding table.
        max_seq_len: Not read by ``__call__``. NOTE(review): in this file it
            only sizes the dummy input used for ``init``.
        dropout_rate: Dropout rate for the embeddings and every encoder layer.
    """
    embedding_dim: int = 48
    num_heads: int = 0
    num_layers: int = 0
    hidden_dim: int = 192
    num_classes: int = 4
    vocab_size: int = 15000
    max_seq_len: int = 100
    dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, x, pad_mask=None, deterministic: bool = True):
        """Classify a batch of token-id sequences.

        Args:
            x: Integer token ids, ``(batch, seq_len)``.
            pad_mask: Optional ``(batch, 1, 1, seq_len)`` mask over keys;
                callers in this file pass ``inputs != pad_id``. A column of
                True is prepended for the [CLS] slot.
            deterministic: Disable dropout when True.

        Returns:
            Logits of shape ``(batch, num_classes)``.
        """
        batch_size = x.shape[0]
        tokens = nn.Embed(self.vocab_size, self.embedding_dim)(x)

        # Learned [CLS] vector shared by every example, placed at position 0.
        cls = self.param('cls_token', nn.initializers.zeros, (1, 1, self.embedding_dim))
        cls = jnp.broadcast_to(cls, (batch_size, 1, self.embedding_dim))
        h = jnp.concatenate([cls, tokens], axis=1)
        h = nn.Dropout(rate=self.dropout_rate)(h, deterministic=deterministic)

        attn_mask = None
        if pad_mask is not None:
            # The [CLS] slot is always attendable.
            cls_column = jnp.ones((pad_mask.shape[0], 1, 1, 1), dtype=bool)
            attn_mask = jnp.concatenate([cls_column, pad_mask], axis=-1)

        # Auto-named TransformerEncoderLayer_0..N-1; get_mha_inputs relies on it.
        for _ in range(self.num_layers):
            encoder = TransformerEncoderLayer(
                embedding_dim=self.embedding_dim,
                num_heads=self.num_heads,
                hidden_dim=self.hidden_dim,
                dropout_rate=self.dropout_rate,
            )
            h = encoder(h, mask=attn_mask, deterministic=deterministic)

        cls_state = h[:, 0, :]
        return nn.Dense(self.num_classes)(cls_state)


# --- Training and Evaluation Logic ---

def cross_entropy_loss(logits, labels):
    """Return the batch-mean softmax cross-entropy for integer class labels.

    ``labels`` are one-hot encoded to ``logits.shape[-1]`` classes, scored
    with ``optax.softmax_cross_entropy``, and the per-example losses averaged.
    """
    num_classes = logits.shape[-1]
    targets = jax.nn.one_hot(labels, num_classes=num_classes)
    per_example = optax.softmax_cross_entropy(logits=logits, labels=targets)
    return per_example.mean()


def compute_metrics(logits, labels):
    """Summarize one batch of predictions.

    Returns a dict with ``'loss'`` (``cross_entropy_loss``) and ``'accuracy'``
    (fraction of rows whose argmax over the last axis equals the label).
    """
    predicted_class = jnp.argmax(logits, axis=-1)
    return {
        'loss': cross_entropy_loss(logits, labels),
        'accuracy': jnp.mean(predicted_class == labels),
    }


@jit
def train_step(state, batch, rng):
    """Run one jitted optimizer update on ``batch`` with dropout enabled.

    Args:
        state: TrainState holding ``params``, ``step`` and ``apply_fn``.
        batch: Dict with ``'inputs'``, ``'pad_mask'`` and ``'labels'``.
        rng: PRNG key; folded with ``state.step`` to form the dropout key.

    Returns:
        ``(new_state, metrics)`` where ``metrics`` is ``compute_metrics`` on
        the training-mode logits produced by the pre-update parameters.
    """
    dropout_key = random.fold_in(rng, state.step)
    labels = batch['labels']

    def objective(params):
        out = state.apply_fn(
            {'params': params},
            batch['inputs'],
            pad_mask=batch['pad_mask'],
            deterministic=False,
            rngs={'dropout': dropout_key},
        )
        return cross_entropy_loss(out, labels), out

    (_, train_logits), grads = value_and_grad(objective, has_aux=True)(state.params)
    new_state = state.apply_gradients(grads=grads)
    return new_state, compute_metrics(train_logits, labels)


@jit
def eval_step(state, batch):
    """Score ``batch`` with dropout disabled and return ``compute_metrics`` output."""
    variables = {'params': state.params}
    eval_logits = state.apply_fn(variables, batch['inputs'],
                                 pad_mask=batch['pad_mask'], deterministic=True)
    return compute_metrics(eval_logits, batch['labels'])


def get_mha_inputs(model, params, dataset, rng, batch_size: int, vocab, pad_value):
    """Collect every encoder layer's multi-head-attention input for a random batch.

    Samples ``batch_size`` distinct examples from ``dataset["token_ids"]``,
    pads them with ``pad_value``, runs ``model`` deterministically with the
    ``'intermediates'`` collection mutable, and reads back the ``'mha_input'``
    value sown by each ``TransformerEncoderLayer_{i}``.

    Args:
        model: Flax module exposing ``num_layers`` whose encoder layers sow
            ``'mha_input'`` into ``'intermediates'``.
        params: Parameter tree for ``model``.
        dataset: Mapping with a ``"token_ids"`` sequence of token-id lists.
        rng: PRNG key used to choose the examples.
        batch_size: Number of examples to sample, without replacement.
        vocab: Unused; kept so existing call sites keep working.
        pad_value: Padding id for ``pad_sequences``; positions equal to it are
            masked out of attention.

    Returns:
        List of ``model.num_layers`` arrays, in layer order.

    Raises:
        ValueError: If ``batch_size`` is not in ``[1, len(dataset["token_ids"])]``.
        KeyError: If some layer did not sow ``'mha_input'``.
    """
    num_examples = len(dataset["token_ids"])
    if not 1 <= batch_size <= num_examples:
        raise ValueError(
            f"batch_size must be in [1, {num_examples}] (examples are sampled "
            f"without replacement), got {batch_size}."
        )
    # Pull the sampled indices to the host once rather than once per element.
    indices = np.asarray(random.choice(rng, num_examples, shape=(batch_size,), replace=False))
    token_ids = [dataset["token_ids"][int(i)] for i in indices]
    inputs = pad_sequences(token_ids, pad_value=pad_value)
    pad_mask = (inputs != pad_value)[:, None, None, :]

    _, variables = model.apply(
        {"params": params},
        inputs,
        pad_mask=pad_mask,
        deterministic=True,
        mutable=['intermediates'],
    )
    intermediate_vars = variables.get('intermediates', {})

    activations = []
    for i in range(model.num_layers):
        layer_key = f'TransformerEncoderLayer_{i}'
        layer_vars = intermediate_vars.get(layer_key, {})
        if 'mha_input' not in layer_vars:
            raise KeyError(f"Could not find 'mha_input' for layer {i} ('{layer_key}'). "
                           f"Available intermediates: {list(intermediate_vars.keys())}")
        # sow() collects values in a tuple; each layer sows once per call here.
        activations.append(layer_vars['mha_input'][0])
    # Every iteration either appended or raised, so len == num_layers here.
    return activations


# Reserved ids in the vocabulary built by _build_vocab.
_PAD_ID = 0
_UNK_ID = 1


def _build_vocab(token_seqs, vocab_size):
    """Build a token -> id mapping with at most ``vocab_size`` entries.

    Ids 0 and 1 are reserved for ``"<PAD>"`` and ``"<UNK>"``; the
    ``vocab_size - 2`` most frequent tokens in ``token_seqs`` get ids 2, 3, ...
    in descending frequency order.
    """
    counts = Counter(tok for seq in token_seqs for tok in seq)
    vocab = {"<PAD>": _PAD_ID, "<UNK>": _UNK_ID}
    for idx, (tok, _) in enumerate(counts.most_common(vocab_size - 2), start=2):
        vocab[tok] = idx
    return vocab


def _encode(token_seqs, vocab):
    """Map each token to its vocabulary id; unknown tokens become the <UNK> id."""
    return [[vocab.get(tok, _UNK_ID) for tok in seq] for seq in token_seqs]


def _make_batch(token_ids, labels, pad_id):
    """Pad ``token_ids`` and package them in the dict train_step/eval_step read."""
    inputs = pad_sequences(token_ids, pad_value=pad_id)
    return {
        'inputs': inputs,
        # (batch, 1, 1, seq_len), True at non-padding positions.
        'pad_mask': (inputs != pad_id)[:, None, None, :],
        'labels': labels,
    }


def _progress_postfix(metrics):
    """Format one step's loss/accuracy for the tqdm progress bar."""
    return {
        "loss": f"{metrics['loss']:.4f}",
        "acc": f"{metrics['accuracy']:.4f}"
    }


def _example_weighted_means(batch_metrics, batch_sizes):
    """Return (loss, accuracy) averaged over examples rather than over batches.

    Each batch's metrics are already means over that batch, so an unweighted
    mean over batches over-counts a short final batch. Returns NaNs when there
    are no batches.
    """
    if not batch_metrics:
        return float("nan"), float("nan")
    losses = [float(m['loss']) for m in batch_metrics]
    accuracies = [float(m['accuracy']) for m in batch_metrics]
    return (float(np.average(losses, weights=batch_sizes)),
            float(np.average(accuracies, weights=batch_sizes)))


def _make_optimizer(name, learning_rate):
    """Return the optax optimizer selected by ``--optimizer``."""
    if name == "sgd":
        return optax.sgd(learning_rate, momentum=0.9)
    if name == "adamw":
        return optax.adamw(learning_rate, weight_decay=1e-3)
    return optax.adam(learning_rate)


def _plot_metrics(train_losses, test_losses, train_accuracies, test_accuracies, out_path):
    """Save side-by-side per-epoch loss and accuracy curves to ``out_path``."""
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(test_losses, label='Test Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Loss over Epochs')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(train_accuracies, label='Train Accuracy')
    plt.plot(test_accuracies, label='Test Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.title('Accuracy over Epochs')
    plt.legend()

    plt.tight_layout()
    plt.savefig(out_path)
    plt.close()


def main():
    """Train a RoPE BertModel on AGNews from the command line.

    Builds a vocabulary from the training split and trains for ``--epochs``
    epochs, evaluating on the test split after each one. It writes a params
    checkpoint per epoch and a final loss/accuracy plot into ``--ckpt-path``.
    """
    parser = argparse.ArgumentParser(description="Train a RoPE BERT-style model on AGNews")
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument("--optimizer", choices=["sgd", "adam", "adamw"], default="adam")
    parser.add_argument("--learning-rate", type=float, default=1e-3)
    parser.add_argument("--num-layers", type=int, required=True)
    parser.add_argument("--num-heads", type=int, required=True)
    parser.add_argument("--batch-size", type=int, default=512)
    parser.add_argument("--epochs", type=int, default=6)
    parser.add_argument("--vocab-size", type=int, default=15000,
                        help="Vocabulary size, including the <PAD> and <UNK> entries")
    parser.add_argument("--ckpt-path", type=str, default="./checkpoints", help="Path to ckpt directory")
    parser.add_argument("--data-path", type=str, default="/data/agnews", help="Path to AGNews data")
    args = parser.parse_args()
    if args.vocab_size < 3:
        parser.error("--vocab-size must be at least 3 (<PAD>, <UNK> and one token)")
    if args.batch_size < 1:
        parser.error("--batch-size must be at least 1")

    os.makedirs(args.ckpt_path, exist_ok=True)
    rng = random.PRNGKey(args.seed)

    # --- Data Loading and Preprocessing ---
    # The loader's pad value is deliberately unused. Sequences are re-mapped
    # through the vocabulary below, which reserves _PAD_ID (0) for padding, so
    # batches must be padded and masked with that id. The loader's value only
    # coincides when it is also 0.
    # NOTE(review): if the loader pre-pads sequences, those pad tokens are
    # counted as ordinary words here — confirm the loader returns unpadded data.
    (x_train, y_train), (x_test, y_test), num_classes, _ = load_agnews_data(args.data_path)

    print("Building vocabulary...")
    vocab = _build_vocab(x_train, args.vocab_size)
    train_ids = _encode(x_train, vocab)
    test_ids = _encode(x_test, vocab)
    train_labels = np.asarray(y_train, dtype=np.int32)
    test_labels = np.asarray(y_test, dtype=np.int32)

    # --- Model and Optimizer Initialization ---
    model = BertModel(
        num_layers=args.num_layers,
        num_heads=args.num_heads,
        num_classes=num_classes,
        vocab_size=len(vocab),  # actual size; may be below --vocab-size
        dropout_rate=0.1,
    )
    tx = _make_optimizer(args.optimizer, args.learning_rate)

    rng, init_rng = random.split(rng)
    dummy_input = jnp.zeros((1, model.max_seq_len), dtype=jnp.int32)
    init_params = model.init({'params': init_rng, 'dropout': init_rng}, dummy_input)['params']
    train_state = TrainState.create(apply_fn=model.apply, params=init_params, tx=tx)

    # --- Training Loop ---
    train_losses, train_accuracies = [], []
    test_losses, test_accuracies = [], []

    num_train_examples = len(train_labels)
    num_test_examples = len(test_labels)

    for epoch in range(args.epochs):
        print(f"\n--- Epoch {epoch+1}/{args.epochs} ---")
        rng, epoch_rng = random.split(rng)
        # Materialize the shuffle on the host once, so per-batch indexing of
        # Python lists and NumPy arrays does no per-element device transfers.
        perm = np.asarray(random.permutation(epoch_rng, num_train_examples))

        # Train
        batch_metrics, batch_sizes = [], []
        pbar = tqdm(range(0, num_train_examples, args.batch_size), desc="Training")
        for start in pbar:
            batch_idx = perm[start:start + args.batch_size]
            batch = _make_batch([train_ids[j] for j in batch_idx],
                                train_labels[batch_idx], _PAD_ID)
            rng, step_rng = random.split(rng)
            train_state, metrics = train_step(train_state, batch, step_rng)
            batch_metrics.append(metrics)
            batch_sizes.append(len(batch_idx))
            pbar.set_postfix(_progress_postfix(metrics))

        epoch_train_loss, epoch_train_acc = _example_weighted_means(batch_metrics, batch_sizes)
        train_losses.append(epoch_train_loss)
        train_accuracies.append(epoch_train_acc)

        # Evaluate
        batch_metrics, batch_sizes = [], []
        pbar = tqdm(range(0, num_test_examples, args.batch_size), desc="Evaluating")
        for start in pbar:
            stop = start + args.batch_size
            batch = _make_batch(test_ids[start:stop], test_labels[start:stop], _PAD_ID)
            metrics = eval_step(train_state, batch)
            batch_metrics.append(metrics)
            batch_sizes.append(len(batch['labels']))
            pbar.set_postfix(_progress_postfix(metrics))

        epoch_test_loss, epoch_test_acc = _example_weighted_means(batch_metrics, batch_sizes)
        test_losses.append(epoch_test_loss)
        test_accuracies.append(epoch_test_acc)

        print(f"Epoch {epoch+1} Summary: "
              f"Train Loss={epoch_train_loss:.4f}, Train Acc={epoch_train_acc:.4f} | "
              f"Test Loss={epoch_test_loss:.4f}, Test Acc={epoch_test_acc:.4f}")

        # Save checkpoint; "_rope" distinguishes these from non-RoPE runs.
        ckpt_name = (
            f"agnews_bert_rope_seed{args.seed}_"
            f"opt_{args.optimizer}_lr_{args.learning_rate}_L{args.num_layers}_H{args.num_heads}_"
            f"epoch{epoch+1}_trainloss_{epoch_train_loss:.4f}_testloss_{epoch_test_loss:.4f}_trainacc_{epoch_train_acc:.4f}_testacc_{epoch_test_acc:.4f}.flax"
        )
        with open(os.path.join(args.ckpt_path, ckpt_name), "wb") as f:
            f.write(flax.serialization.to_bytes(train_state.params))

    # --- Plot and Save Metrics ---
    _plot_metrics(train_losses, test_losses, train_accuracies, test_accuracies,
                  os.path.join(args.ckpt_path, "metrics_plot_rope.png"))


if __name__ == "__main__":
    main()