import argparse
import os
import jax
import jax.numpy as jnp
import numpy as np
import optax
import flax
from flax import linen as nn
from flax.training.train_state import TrainState
from jax import jit, random, value_and_grad
from tqdm import tqdm
import matplotlib.pyplot as plt
from collections import Counter
import time
from .datamodule import load_agnews_data, pad_sequences
from .utils import flatten_params, rngmix, timeblock

# --- Model Definitions ---

class TransformerEncoderLayer(nn.Module):
    """Pre-LayerNorm Transformer encoder layer.

    Two residual sub-blocks are applied in sequence: multi-head self-attention
    followed by a GELU feed-forward network. Each sub-block normalizes its
    input first and adds its output back onto the un-normalized input.
    """
    # Width of the incoming representations; also used as `qkv_features`.
    embedding_dim: int
    # Number of attention heads.
    num_heads: int
    # Width of the feed-forward hidden layer.
    hidden_dim: int
    # Dropout rate used inside attention and after the feed-forward activation.
    dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, x, mask=None, deterministic: bool = True):
        """Run the layer on `x`.

        Args:
            x: Input activations; BertModel passes (batch, seq_len + 1, embedding_dim).
            mask: Optional attention mask forwarded unchanged to the attention module.
            deterministic: If True, all dropout in this layer is disabled.

        Returns:
            The activations after the feed-forward residual connection.
        """
        # NOTE: submodules are constructed in a fixed order because Flax derives
        # their auto-generated names (LayerNorm_0, Dense_0, ...) from it.

        # ---- Self-attention sub-block ----
        normed_input = nn.LayerNorm(use_bias=True, use_scale=True)(x)
        # Expose the normalized attention input through the 'intermediates'
        # collection so callers can retrieve it with mutable=['intermediates'].
        self.sow('intermediates', 'mha_input', normed_input)
        self_attention = nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads,
            qkv_features=self.embedding_dim,
            dropout_rate=self.dropout_rate,
        )
        attended = self_attention(
            normed_input, normed_input, mask=mask, deterministic=deterministic
        )
        after_attention = x + attended

        # ---- Feed-forward sub-block ----
        normed_residual = nn.LayerNorm(use_bias=True, use_scale=True)(after_attention)
        expanded = nn.Dense(self.hidden_dim)(normed_residual)
        activated = nn.gelu(expanded)
        regularized = nn.Dropout(rate=self.dropout_rate)(
            activated, deterministic=deterministic
        )
        projected = nn.Dense(self.embedding_dim)(regularized)

        return after_attention + projected

class BertModel(nn.Module):
    """BERT-style sequence classifier.

    Token ids are embedded, a learned CLS vector is prepended, learned
    positional embeddings are added, the sequence is passed through
    `num_layers` TransformerEncoderLayer blocks, and the CLS position is
    projected to `num_classes` logits.
    """
    # Width of the token, CLS and positional embeddings.
    embedding_dim: int = 48
    # Attention heads handed to every encoder layer (not read when num_layers is 0).
    num_heads: int = 0
    # Number of encoder layers; with 0 the head sees the embedded CLS token directly.
    num_layers: int = 0
    # Width of each encoder layer's feed-forward hidden layer.
    hidden_dim: int = 192
    # Number of output logits.
    num_classes: int = 4
    # Number of rows in the token embedding table.
    vocab_size: int = 15000
    # Longest supported input (excluding the CLS token); sizes the positional table.
    max_seq_len: int = 100
    # Dropout rate for the embeddings and for every encoder layer.
    dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, x, pad_mask=None, deterministic: bool = True):
        """Compute class logits for a batch of token-id sequences.

        Args:
            x: Integer token ids of shape (batch_size, seq_len).
            pad_mask: Optional boolean mask of shape (batch_size, 1, 1, seq_len),
                True at positions that may be attended to. A True column for
                the CLS token is prepended internally.
            deterministic: If True, dropout is disabled.

        Returns:
            Logits of shape (batch_size, num_classes).

        Raises:
            ValueError: If seq_len exceeds `max_seq_len`.
        """
        # Shapes are static under jit, so this check runs at trace time. Without
        # it, an over-long input fails later with an opaque broadcasting error
        # when the (shorter) positional table is added.
        seq_len = x.shape[1]
        if seq_len > self.max_seq_len:
            raise ValueError(
                f"Input sequence length {seq_len} exceeds max_seq_len={self.max_seq_len}; "
                "truncate the inputs or build the model with a larger max_seq_len."
            )

        # (batch_size, seq_len) -> (batch_size, seq_len, embedding_dim)
        x = nn.Embed(self.vocab_size, self.embedding_dim)(x)

        # Prepend the learned CLS token to every sequence.
        cls_token = self.param('cls_token', nn.initializers.zeros, (1, 1, self.embedding_dim))
        cls_token = jnp.tile(cls_token, (x.shape[0], 1, 1))
        x = jnp.concatenate([cls_token, x], axis=1)

        # Add learned positional embeddings; the table has one extra row for CLS.
        pos_embedding = self.param(
            'pos_embedding',
            nn.initializers.normal(stddev=0.02),
            (1, self.max_seq_len + 1, self.embedding_dim)
        )
        x = x + pos_embedding[:, :x.shape[1], :]
        x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)

        # The sequence grew by one position (CLS), so the key mask must grow too:
        # prepend an always-True column so the CLS token is never masked out.
        attn_mask = None
        if pad_mask is not None:
            cls_mask = jnp.ones((pad_mask.shape[0], 1, 1, 1), dtype=bool)
            attn_mask = jnp.concatenate([cls_mask, pad_mask], axis=-1)

        # Transformer encoder stack.
        for _ in range(self.num_layers):
            x = TransformerEncoderLayer(
                embedding_dim=self.embedding_dim,
                num_heads=self.num_heads,
                hidden_dim=self.hidden_dim,
                dropout_rate=self.dropout_rate
            )(x, mask=attn_mask, deterministic=deterministic)

        # Classify from the CLS position only.
        x = x[:, 0, :]
        x = nn.Dense(self.num_classes)(x)
        return x

# --- Training and Evaluation Logic ---

def cross_entropy_loss(logits, labels):
    """Return the batch-mean softmax cross-entropy.

    Args:
        logits: Unnormalized scores whose last axis indexes the classes.
        labels: Class indices; they are one-hot encoded to the width of `logits`.

    Returns:
        Scalar mean of the per-example cross-entropy values.
    """
    num_classes = logits.shape[-1]
    targets = jax.nn.one_hot(labels, num_classes=num_classes)
    per_example = optax.softmax_cross_entropy(logits=logits, labels=targets)
    return jnp.mean(per_example)

def compute_metrics(logits, labels):
    """Return a dict with the mean cross-entropy ('loss') and top-1 'accuracy'."""
    loss = cross_entropy_loss(logits, labels)
    predictions = jnp.argmax(logits, axis=-1)
    is_correct = predictions == labels
    return {
        'loss': loss,
        'accuracy': jnp.mean(is_correct),
    }

@jit
def train_step(state, batch, rng):
    """Perform one optimization step and report metrics for the batch.

    Args:
        state: Train state providing `apply_fn`, `params`, `step` and
            `apply_gradients`.
        batch: Dict with 'inputs', 'pad_mask' and 'labels' entries.
        rng: PRNG key from which the dropout key for this step is derived.

    Returns:
        Tuple of (updated state, metrics dict from `compute_metrics`). The
        metrics are computed from the logits of the dropout-enabled forward pass.
    """
    labels = batch['labels']
    # Mix the step counter into the key so each step draws distinct dropout noise.
    step_dropout_key = random.fold_in(rng, state.step)

    def forward_loss(params):
        outputs = state.apply_fn(
            {'params': params},
            batch['inputs'],
            pad_mask=batch['pad_mask'],
            deterministic=False,
            rngs={'dropout': step_dropout_key},
        )
        # Logits ride along as auxiliary output so no second forward pass is needed.
        return cross_entropy_loss(outputs, labels), outputs

    (_, logits), grads = value_and_grad(forward_loss, has_aux=True)(state.params)
    updated_state = state.apply_gradients(grads=grads)

    return updated_state, compute_metrics(logits, labels)

@jit
def eval_step(state, batch):
    """Evaluate one batch with dropout disabled and return its metrics dict.

    Args:
        state: Train state providing `apply_fn` and `params`.
        batch: Dict with 'inputs', 'pad_mask' and 'labels' entries.
    """
    variables = {'params': state.params}
    predictions = state.apply_fn(
        variables,
        batch['inputs'],
        pad_mask=batch['pad_mask'],
        deterministic=True,
    )
    labels = batch['labels']
    return compute_metrics(predictions, labels)

def get_mha_inputs(model, params, dataset, rng, batch_size: int, vocab, pad_value):
    """Collect every encoder layer's attention input on a random batch.

    A batch of `batch_size` distinct examples is sampled from `dataset`, run
    through `model` with dropout disabled, and the 'mha_input' tensors that
    each TransformerEncoderLayer sows into the 'intermediates' collection are
    gathered.

    Args:
        model: Model exposing `apply` and `num_layers`.
        params: Parameters passed to `model.apply` under the 'params' key.
        dataset: Mapping whose "token_ids" entry is an indexable collection of
            token-id sequences.
        rng: PRNG key used to sample the batch (without replacement).
        batch_size: Number of examples to sample.
        vocab: Unused; kept so existing callers keep working.
        pad_value: Id used for padding and for building the attention pad mask.

    Returns:
        A list with one activation per encoder layer, ordered by layer index.

    Raises:
        KeyError: If a layer's 'mha_input' is missing from the intermediates.
    """
    num_examples = len(dataset["token_ids"])
    # Move the sampled indices to host memory once; converting a device array
    # element by element would trigger one transfer per index.
    sampled = np.asarray(
        random.choice(rng, num_examples, shape=(batch_size,), replace=False)
    )
    token_ids = [dataset["token_ids"][i] for i in sampled.tolist()]
    inputs = pad_sequences(token_ids, pad_value=pad_value)
    # Shape (batch, 1, 1, seq_len): the layout BertModel expects for pad_mask.
    pad_mask = (inputs != pad_value)[:, None, None, :]

    _, variables = model.apply(
        {"params": params},
        inputs,
        pad_mask=pad_mask,
        deterministic=True,
        mutable=['intermediates']
    )

    intermediate_vars = variables.get('intermediates', {})
    activations = []
    for i in range(model.num_layers):
        layer_key = f'TransformerEncoderLayer_{i}'
        layer_vars = intermediate_vars.get(layer_key, {})
        if 'mha_input' not in layer_vars:
            raise KeyError(f"Could not find 'mha_input' for layer {i} ('{layer_key}'). "
                           f"Available intermediates: {list(intermediate_vars.keys())}")
        # `sow` stores values in a tuple; the single forward pass wrote one entry.
        activations.append(layer_vars['mha_input'][0])

    # Every iteration above either appended or raised, so the list is complete.
    return activations

def _build_vocab(token_seqs, max_size):
    """Map the `max_size - 2` most frequent tokens to ids starting at 2.

    Ids 0 and 1 are reserved for "<PAD>" and "<UNK>".
    """
    counts = Counter(token for seq in token_seqs for token in seq)
    vocab = {"<PAD>": 0, "<UNK>": 1}
    for idx, (token, _) in enumerate(counts.most_common(max_size - 2), 2):
        vocab[token] = idx
    return vocab


def _make_optimizer(name, learning_rate):
    """Return the optax optimizer selected on the command line ("adam" by default)."""
    if name == "sgd":
        return optax.sgd(learning_rate, momentum=0.9)
    if name == "adamw":
        return optax.adamw(learning_rate, weight_decay=1e-3)
    return optax.adam(learning_rate)


def _make_batch(token_ids, labels, pad_value):
    """Pad `token_ids` and bundle inputs, attention pad mask and labels into a batch dict."""
    inputs = pad_sequences(token_ids, pad_value=pad_value)
    return {
        'inputs': inputs,
        # (batch, 1, 1, seq_len): True where the position holds a real token.
        'pad_mask': (inputs != pad_value)[:, None, None, :],
        'labels': labels,
    }


def _epoch_means(batch_metrics, batch_sizes):
    """Return (loss, accuracy) averaged over an epoch, weighted by batch size.

    Weighting keeps a short final batch from counting as much as a full one.
    Returns NaNs when no batch was processed.
    """
    if not batch_metrics:
        return float("nan"), float("nan")
    weights = np.asarray(batch_sizes, dtype=np.float64)
    losses = np.asarray([float(m['loss']) for m in batch_metrics])
    accuracies = np.asarray([float(m['accuracy']) for m in batch_metrics])
    return (
        float(np.average(losses, weights=weights)),
        float(np.average(accuracies, weights=weights)),
    )


def main():
    """Train the BERT-style classifier on AGNews.

    Parses command-line arguments, builds the vocabulary from the training
    split, trains for `--epochs` epochs while evaluating on the test split
    after each one, writes a parameter checkpoint per epoch into
    `--ckpt-path`, and finally saves loss/accuracy curves there as
    "metrics_plot.png".
    """
    parser = argparse.ArgumentParser(description="Train a fixed BERT-style model on AGNews")
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument("--optimizer", choices=["sgd", "adam", "adamw"], default="adam")
    parser.add_argument("--learning-rate", type=float, default=1e-3)
    parser.add_argument("--num-layers", type=int, required=True)
    parser.add_argument("--num-heads", type=int, required=True)
    parser.add_argument("--batch-size", type=int, default=512)
    parser.add_argument("--epochs", type=int, default=6)
    parser.add_argument("--ckpt-path", type=str, default="./checkpoints", help="Path to ckpt directory")
    parser.add_argument("--data-path", type=str, default="/data/agnews", help="Path to AGNews data")
    args = parser.parse_args()

    os.makedirs(args.ckpt_path, exist_ok=True)
    rng = random.PRNGKey(args.seed)

    # --- Data Loading and Preprocessing ---
    # NOTE(review): PAD is used below as the padding id while the vocabulary
    # reserves id 0 for "<PAD>" — presumably PAD == 0; confirm against
    # load_agnews_data.
    (x_train, y_train), (x_test, y_test), num_classes, PAD = load_agnews_data(args.data_path)

    # Single source of truth for the vocabulary size: it bounds the ids produced
    # here and sizes the model's embedding table.
    max_vocab_size = 15000
    unk_id = 1

    print("Building vocabulary...")
    vocab = _build_vocab(x_train, max_vocab_size)

    def to_ids(seq):
        return [vocab.get(word, unk_id) for word in seq]

    train_ds = {
        "token_ids": [to_ids(seq) for seq in x_train],
        "labels": np.array(y_train, dtype=np.int32),
    }
    test_ds = {
        "token_ids": [to_ids(seq) for seq in x_test],
        "labels": np.array(y_test, dtype=np.int32),
    }

    # --- Model and Optimizer Initialization ---
    model = BertModel(
        num_layers=args.num_layers,
        num_heads=args.num_heads,
        num_classes=num_classes,
        vocab_size=max_vocab_size,
        dropout_rate=0.1
    )
    tx = _make_optimizer(args.optimizer, args.learning_rate)

    rng, init_rng = random.split(rng)

    dummy_input = jnp.zeros((1, model.max_seq_len), dtype=jnp.int32)
    # The model is initialized with its default deterministic=True, so the
    # 'dropout' key is not consumed here and sharing init_rng is harmless.
    init_params = model.init({'params': init_rng, 'dropout': init_rng}, dummy_input)['params']
    # Parameter tree layout, as recorded for a run with 2 layers and 4 heads
    # (inspect with jax.tree_util.tree_map_with_path over init_params):
    #   Embed_0/embedding (15000, 48)
    #   cls_token (1, 1, 48), pos_embedding (1, 101, 48)
    #   TransformerEncoderLayer_{i}/LayerNorm_{0,1}/{scale,bias} (48,)
    #   TransformerEncoderLayer_{i}/MultiHeadDotProductAttention_0/
    #       {query,key,value}/kernel (48, 4, 12), bias (4, 12)
    #       out/kernel (4, 12, 48), bias (48,)
    #   TransformerEncoderLayer_{i}/Dense_0/kernel (48, 192), bias (192,)
    #   TransformerEncoderLayer_{i}/Dense_1/kernel (192, 48), bias (48,)
    #   Dense_0/kernel (48, 4), bias (4,)
    train_state = TrainState.create(apply_fn=model.apply, params=init_params, tx=tx)

    # --- Training Loop ---
    train_losses, train_accuracies = [], []
    test_losses, test_accuracies = [], []

    num_train_examples = len(train_ds["labels"])
    num_test_examples = len(test_ds["labels"])

    for epoch in range(args.epochs):
        print(f"\n--- Epoch {epoch+1}/{args.epochs} ---")
        rng, epoch_rng = random.split(rng)

        # Shuffle training data. The permutation is copied to host memory once
        # so per-example indexing below does not hit the device for every index.
        perm = np.asarray(jax.random.permutation(epoch_rng, num_train_examples))

        # Train
        batch_metrics, batch_sizes = [], []
        pbar = tqdm(range(0, num_train_examples, args.batch_size), desc="Training")
        for i in pbar:
            batch_perm = perm[i:i+args.batch_size]
            batch = _make_batch(
                [train_ds["token_ids"][j] for j in batch_perm.tolist()],
                train_ds["labels"][batch_perm],
                PAD,
            )

            rng, step_rng = random.split(rng)
            train_state, metrics = train_step(train_state, batch, step_rng)
            batch_metrics.append(metrics)
            batch_sizes.append(len(batch_perm))
            pbar.set_postfix({
                "loss": f"{metrics['loss']:.4f}",
                "acc": f"{metrics['accuracy']:.4f}"
            })

        epoch_train_loss, epoch_train_acc = _epoch_means(batch_metrics, batch_sizes)
        train_losses.append(epoch_train_loss)
        train_accuracies.append(epoch_train_acc)

        # Evaluate
        batch_metrics, batch_sizes = [], []
        pbar = tqdm(range(0, num_test_examples, args.batch_size), desc="Evaluating")
        for i in pbar:
            batch_labels = test_ds["labels"][i:i+args.batch_size]
            batch = _make_batch(
                test_ds["token_ids"][i:i+args.batch_size],
                batch_labels,
                PAD,
            )

            metrics = eval_step(train_state, batch)
            batch_metrics.append(metrics)
            batch_sizes.append(len(batch_labels))
            pbar.set_postfix({
                "loss": f"{metrics['loss']:.4f}",
                "acc": f"{metrics['accuracy']:.4f}"
            })

        epoch_test_loss, epoch_test_acc = _epoch_means(batch_metrics, batch_sizes)
        test_losses.append(epoch_test_loss)
        test_accuracies.append(epoch_test_acc)

        print(f"Epoch {epoch+1} Summary: "
              f"Train Loss={epoch_train_loss:.4f}, Train Acc={epoch_train_acc:.4f} | "
              f"Test Loss={epoch_test_loss:.4f}, Test Acc={epoch_test_acc:.4f}")

        # Save checkpoint (parameters only; optimizer state is not stored).
        ckpt_name = (
            f"agnews_bert_seed{args.seed}_"
            f"opt_{args.optimizer}_lr_{args.learning_rate}_L{args.num_layers}_H{args.num_heads}_"
            f"epoch{epoch+1}_trainloss_{epoch_train_loss:.4f}_testloss_{epoch_test_loss:.4f}_trainacc_{epoch_train_acc:.4f}_testacc_{epoch_test_acc:.4f}.flax"
        )
        with open(os.path.join(args.ckpt_path, ckpt_name), "wb") as f:
            f.write(flax.serialization.to_bytes(train_state.params))

    # --- Plot and Save Metrics ---
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(test_losses, label='Test Loss')
    plt.xlabel('Epoch'); plt.ylabel('Loss'); plt.title('Loss over Epochs'); plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(train_accuracies, label='Train Accuracy')
    plt.plot(test_accuracies, label='Test Accuracy')
    plt.xlabel('Epoch'); plt.ylabel('Accuracy'); plt.title('Accuracy over Epochs'); plt.legend()

    plt.tight_layout()
    plt.savefig(os.path.join(args.ckpt_path, "metrics_plot.png"))
    plt.close()

# Start training only when executed as a script, not when imported.
# NOTE: this file uses relative imports (`.datamodule`, `.utils`), so it has to
# be launched as a module of its package (`python -m <package>.<module>`).
if __name__ == "__main__":
    main()