import argparse
import augmax
import flax
import jax
import jax.numpy as jnp
import numpy as np
import optax
import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
from flax import linen as nn
from flax.training.train_state import TrainState
from jax import jit, random, value_and_grad, vmap
import jax.tree_util as tree_util
from tqdm import tqdm
from flax.traverse_util import flatten_dict, unflatten_dict

from src.utils import flatten_params, rngmix, timeblock
from .mnist_vit_train import ViTModel, load_datasets, make_stuff, TransformerEncoderLayer


# Define MoEBlock: Mixture of Experts block with multiple expert MLPs
class MoEBlock(nn.Module):
    """Dense (soft) mixture-of-experts MLP.

    Every expert is a two-layer MLP (Dense -> GELU -> Dense) applied to the
    whole input; the expert outputs are combined with per-position softmax
    weights produced by a linear gating layer.

    Parameter naming matters: the ``nn.Dense`` submodules are auto-named in
    creation order, so the gate is ``Dense_0`` and expert ``i`` owns
    ``Dense_{2i+1}`` (first layer) and ``Dense_{2i+2}`` (second layer).
    ``ViTModelMoE.get_moe_params`` reads parameters using this layout, so the
    creation order below must not change.

    Attributes:
        num_experts: number of expert MLPs (and width of the gating output).
        embedding_dim: output width of each expert and of the block.
        hidden_dim: hidden width of each expert MLP.
    """
    num_experts: int
    embedding_dim: int
    hidden_dim: int

    @nn.compact
    def __call__(self, x, return_moe_outputs=False):
        """Apply the mixture of experts along the last axis of ``x``.

        Args:
            x: array of shape ``[..., features]``. Any number of leading axes
                is accepted (e.g. ``[batch, tokens, features]``).
            return_moe_outputs: when True, also return the gating logits and
                the stacked per-expert outputs.

        Returns:
            ``output`` of shape ``[..., embedding_dim]``, or the tuple
            ``(output, gate_logits, expert_outputs)`` where ``gate_logits`` is
            ``[..., num_experts]`` (pre-softmax) and ``expert_outputs`` is
            ``[..., embedding_dim, num_experts]``.
        """
        # Gating network; created first so it is auto-named Dense_0.
        gate_logits = nn.Dense(self.num_experts)(x)
        gate = nn.softmax(gate_logits, axis=-1)

        # Each expert: Dense(hidden_dim) -> GELU -> Dense(embedding_dim).
        expert_outputs = []
        for _ in range(self.num_experts):
            hidden = nn.gelu(nn.Dense(self.hidden_dim)(x))
            expert_outputs.append(nn.Dense(self.embedding_dim)(hidden))
        # Experts are stacked on a new trailing axis: [..., embedding_dim, num_experts].
        expert_outputs = jnp.stack(expert_outputs, axis=-1)

        # Gate-weighted sum over the expert axis. The ellipsis lets the
        # contraction work for any number of leading axes, not only
        # [batch, tokens].
        output = jnp.einsum('...n,...en->...e', gate, expert_outputs)

        if return_moe_outputs:
            return output, gate_logits, expert_outputs
        return output

# Define MoETransformerEncoderLayer: Transformer layer with MoE instead of MLP
class MoETransformerEncoderLayer(nn.Module):
    """Pre-norm transformer encoder layer whose feed-forward MLP is an ``MoEBlock``.

    Computes ``x + attention(LN(x))`` and then ``x + moe(LN(x))``, where both
    layer norms have neither scale nor bias. Submodules carry explicit names
    (``LayerNorm_0``, ``MultiHeadDotProductAttention_0``, ``LayerNorm_1``,
    ``MoEBlock_0``) so that parameter keys are stable; ``main`` copies
    ``MultiHeadDotProductAttention_0`` from a pre-trained
    ``TransformerEncoderLayer_0``.

    Attributes:
        embedding_dim: token width; also used as the attention qkv width.
        num_heads: number of attention heads.
        hidden_dim: hidden width of each expert MLP.
        num_experts: number of experts in the MoE block.
    """
    embedding_dim: int
    num_heads: int
    hidden_dim: int
    num_experts: int

    @nn.compact
    def __call__(self, x, return_moe_outputs=False):
        """Run the attention and MoE sub-layers on ``x``.

        Returns:
            The updated ``x``, or ``(x, gate_logits, expert_outputs)`` with the
            MoE block's extra outputs when ``return_moe_outputs`` is True.
        """
        # Self-attention sub-layer with residual connection.
        normed = nn.LayerNorm(use_scale=False, use_bias=False, name='LayerNorm_0')(x)
        x = x + nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads,
            qkv_features=self.embedding_dim,
            name='MultiHeadDotProductAttention_0',
        )(normed, normed)

        # Mixture-of-experts sub-layer with residual connection.
        normed = nn.LayerNorm(use_scale=False, use_bias=False, name='LayerNorm_1')(x)
        moe = MoEBlock(
            num_experts=self.num_experts,
            embedding_dim=self.embedding_dim,
            hidden_dim=self.hidden_dim,
            name='MoEBlock_0',
        )
        if not return_moe_outputs:
            return x + moe(normed)

        moe_output, gate_logits, expert_outputs = moe(normed, return_moe_outputs=True)
        return x + moe_output, gate_logits, expert_outputs

# Define ViTModelMoE: ViT model with MoE in the first transformer layer
class ViTModelMoE(nn.Module):
    """Vision Transformer classifier whose first encoder layer uses a mixture of experts.

    Pipeline: patch-embedding conv (``Conv_0``) -> prepend a learned CLS token
    -> add a learned positional embedding -> ``MoETransformerEncoderLayer_0``
    -> ``TransformerEncoderLayer_1 .. TransformerEncoderLayer_{num_layers-1}``
    -> linear classifier (``Dense_0``) applied to the CLS token.

    ``hidden_dim`` is the MLP width of the plain encoder layers, while
    ``hidden_dim_expert`` is the hidden width of every expert in the MoE layer.
    """
    patch_size: int = 7
    embedding_dim: int = 32
    num_heads: int = 1
    num_layers: int = 2
    hidden_dim: int = 32
    hidden_dim_expert: int = 32
    num_classes: int = 10
    num_experts: int = 2

    @nn.compact
    def __call__(self, x, return_moe_outputs=False):
        """Return class logits for the image batch ``x``.

        ``x`` is presumably ``[batch, height, width, channels]`` (``main``
        initialises with ``(1, 28, 28, 1)``) — confirm for other callers.
        When ``return_moe_outputs`` is True the result is
        ``(logits, gate_logits, expert_outputs)`` from the MoE layer.
        """
        window = (self.patch_size, self.patch_size)

        # Non-overlapping patch embedding, flattened into a token sequence.
        patches = nn.Conv(
            features=self.embedding_dim,
            kernel_size=window,
            strides=window,
            padding="VALID",
            name='Conv_0',
        )(x)
        batch_size = patches.shape[0]
        tokens = patches.reshape((batch_size, -1, self.embedding_dim))

        # Learned CLS token shared by every example, placed at position 0.
        cls_token = self.param('cls_token', nn.initializers.zeros, (1, 1, self.embedding_dim))
        tokens = jnp.concatenate([jnp.tile(cls_token, (batch_size, 1, 1)), tokens], axis=1)

        # Learned positional embedding: one vector per token, CLS included.
        tokens = tokens + self.param(
            'pos_embedding',
            nn.initializers.normal(stddev=0.02),
            (1, tokens.shape[1], self.embedding_dim),
        )

        # Layer 0 is the mixture-of-experts encoder layer.
        moe_layer = MoETransformerEncoderLayer(
            embedding_dim=self.embedding_dim,
            num_heads=self.num_heads,
            hidden_dim=self.hidden_dim_expert,
            num_experts=self.num_experts,
            name='MoETransformerEncoderLayer_0',
        )
        if return_moe_outputs:
            tokens, gate_logits, expert_outputs = moe_layer(tokens, return_moe_outputs=True)
        else:
            tokens = moe_layer(tokens)

        # Remaining layers are plain encoder layers.
        for layer_idx in range(1, self.num_layers):
            tokens = TransformerEncoderLayer(
                embedding_dim=self.embedding_dim,
                num_heads=self.num_heads,
                hidden_dim=self.hidden_dim,
                name=f'TransformerEncoderLayer_{layer_idx}',
            )(tokens)

        # Classify from the CLS token's final representation.
        logits = nn.Dense(self.num_classes, name='Dense_0')(tokens[:, 0, :])

        if not return_moe_outputs:
            return logits
        return logits, gate_logits, expert_outputs

    def get_moe_outputs(self, params, x):
        """Run the model on ``x`` and return ``(gate_logits, expert_outputs)`` of the MoE layer."""
        _logits, gate_logits, expert_outputs = self.apply(
            {'params': params}, x, return_moe_outputs=True
        )
        return gate_logits, expert_outputs

    def get_moe_params(self, params):
        """Collect the MoE gating and expert weights as NumPy arrays.

        Keys of the returned dict, in insertion order:

        - ``'gating_kernel'``: shape ``(embedding_dim, num_experts)``
        - ``'gating_bias'``: shape ``(num_experts,)``
        - then, for each expert ``i`` in ``0 .. num_experts - 1``:
            - ``f'expert_{i}_layer1_kernel'``: shape ``(embedding_dim, hidden_dim)``
            - ``f'expert_{i}_layer1_bias'``: shape ``(hidden_dim,)``
            - ``f'expert_{i}_layer2_kernel'``: shape ``(hidden_dim, embedding_dim)``
            - ``f'expert_{i}_layer2_bias'``: shape ``(embedding_dim,)``

        That is ``4 * num_experts + 2`` entries. Here ``hidden_dim`` means the
        expert hidden width. The lookup relies on the ``MoEBlock`` auto-naming:
        ``Dense_0`` is the gate and expert ``i`` uses ``Dense_{2i+1}`` /
        ``Dense_{2i+2}``.
        """
        moe_block = params['MoETransformerEncoderLayer_0']['MoEBlock_0']

        def as_numpy(dense_name, leaf):
            # Copy one kernel/bias leaf out of the parameter tree as a NumPy array.
            return np.array(moe_block[dense_name][leaf])

        collected = {
            'gating_kernel': as_numpy('Dense_0', 'kernel'),
            'gating_bias': as_numpy('Dense_0', 'bias'),
        }
        for expert in range(self.num_experts):
            first_dense = f'Dense_{2 * expert + 1}'
            second_dense = f'Dense_{2 * expert + 2}'
            collected[f'expert_{expert}_layer1_kernel'] = as_numpy(first_dense, 'kernel')
            collected[f'expert_{expert}_layer1_bias'] = as_numpy(first_dense, 'bias')
            collected[f'expert_{expert}_layer2_kernel'] = as_numpy(second_dense, 'kernel')
            collected[f'expert_{expert}_layer2_bias'] = as_numpy(second_dense, 'bias')
        return collected

# Helpers for the fine-tuning entry point


def _merge_pretrained_params(new_params, pretrained_params, num_layers):
    """Return a mutable copy of ``new_params`` with pre-trained weights copied in.

    Only the MoE block inside ``MoETransformerEncoderLayer_0`` keeps its fresh
    initialisation. Everything else comes from ``pretrained_params``:
    ``Conv_0``, ``cls_token``, ``pos_embedding``, ``Dense_0``, encoder layers
    ``1 .. num_layers-1``, and the attention weights of the pre-trained
    ``TransformerEncoderLayer_0``.
    """
    from flax.core import unfreeze

    # unfreeze() returns a fresh nested plain dict for both FrozenDict and dict
    # inputs, so the item assignments below never hit an immutable mapping and
    # never mutate the caller's trees.
    pretrained = unfreeze(pretrained_params)
    combined = unfreeze(new_params)

    for i in range(1, num_layers):
        layer_key = f'TransformerEncoderLayer_{i}'
        combined[layer_key] = pretrained[layer_key]

    combined['MoETransformerEncoderLayer_0']['MultiHeadDotProductAttention_0'] = (
        pretrained['TransformerEncoderLayer_0']['MultiHeadDotProductAttention_0']
    )

    for key in ('Conv_0', 'cls_token', 'pos_embedding', 'Dense_0'):
        combined[key] = pretrained[key]
    return combined


def _make_base_optimizer(config, steps_per_epoch):
    """Build the optax optimizer that will be applied to the trainable (MoE) parameters."""
    if config.optimizer == "sgd":
        lr_schedule = optax.warmup_cosine_decay_schedule(
            init_value=1e-6,
            peak_value=config.learning_rate,
            warmup_steps=10,
            decay_steps=config.num_epochs * steps_per_epoch,
        )
        return optax.sgd(lr_schedule, momentum=0.9)
    if config.optimizer == "adam":
        return optax.adam(config.learning_rate)
    # argparse restricts --optimizer to sgd/adam/adamw, so this is adamw.
    return optax.adamw(config.learning_rate, weight_decay=1e-3)


def _label_params(params):
    """Label every leaf 'trainable' if it lies under MoETransformerEncoderLayer_0/MoEBlock_0, else 'frozen'."""
    def label_fn(path, _):
        in_moe_block = (
            len(path) >= 2
            and path[0].key == 'MoETransformerEncoderLayer_0'
            and path[1].key == 'MoEBlock_0'
        )
        return 'trainable' if in_moe_block else 'frozen'

    return jax.tree_util.tree_map_with_path(label_fn, params)


def _save_curve_plot(train_values, test_values, metric, path):
    """Plot per-epoch train/test curves of ``metric`` and save the figure to ``path``."""
    plt.figure()
    plt.plot(train_values, label=f'Train {metric}')
    plt.plot(test_values, label=f'Test {metric}')
    plt.xlabel('Epoch')
    plt.ylabel(metric)
    plt.title(f'Training and Test {metric} over Epochs')
    plt.legend()
    plt.tight_layout()
    plt.savefig(path)
    plt.close()


# Main function for fine-tuning
def main():
    """Fine-tune a pre-trained MNIST ViT whose first encoder MLP is replaced by an MoE block.

    Only the MoE block (gating network and experts) is trained. All other
    weights are loaded from ``--model-path`` and frozen through
    ``optax.set_to_zero``. Training runs for up to ``num_epochs`` epochs and
    stops early after ``patience`` epochs without test-loss improvement.
    Loss/accuracy plots, a metrics entry appended to ``ckpt.txt``, and the
    best-epoch parameters are written to ``--ckpt-path``.

    Raises:
        ValueError: if the batch size does not divide the training-set size.
        RuntimeError: if no epoch produced a test loss below +inf (e.g. NaN
            every epoch), so there is no best checkpoint to save.
    """
    from types import SimpleNamespace

    # Parse command-line arguments
    parser = argparse.ArgumentParser(description="Fine-tune ViT with MoE on MNIST")
    parser.add_argument("--seed", type=int, required=True, help="Random seed for fine-tuning")
    parser.add_argument("--model-path", type=str, required=True, help="Path to pre-trained ViT model checkpoint")
    parser.add_argument("--optimizer", choices=["sgd", "adam", "adamw"], required=True, help="Optimizer to use")
    parser.add_argument("--learning-rate", type=float, required=True, help="Learning rate")
    parser.add_argument("--num-layers", type=int, default=1, help="Number of transformer layers")
    parser.add_argument("--num-experts", type=int, default=2, help="Number of experts in MoE block")
    parser.add_argument("--ckpt-path", type=str, default="/", help="Path to ckpt directory")
    args = parser.parse_args()

    config = SimpleNamespace(
        seed=args.seed,
        model_path=args.model_path,
        optimizer=args.optimizer,
        learning_rate=args.learning_rate,
        num_epochs=100,
        batch_size=500,
    )

    print("JAX devices:", jax.devices())

    # The pre-trained ViT and the MoE model share the same number of layers.
    num_layers = args.num_layers
    model = ViTModel(num_layers=num_layers)

    # Load pre-trained parameters; a freshly initialised tree supplies the
    # target structure for deserialisation.
    with open(config.model_path, "rb") as f:
        serialized_params = f.read()
    dummy_params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 28, 28, 1)))['params']
    pretrained_params = flax.serialization.from_bytes(dummy_params, serialized_params)

    # Initialise the fine-tuning model with the new seed, then copy in all
    # pre-trained weights except the MoE block.
    finetune_model = ViTModelMoE(num_layers=num_layers, num_experts=args.num_experts)
    rng = jax.random.PRNGKey(config.seed)
    new_params = finetune_model.init(rng, jnp.ones((1, 28, 28, 1)))['params']
    combined_params = _merge_pretrained_params(new_params, pretrained_params, num_layers)

    # Datasets are loaded before building the optimizer so the LR schedule can
    # use the real number of steps per epoch.
    with timeblock("load_datasets"):
        train_ds, test_ds = load_datasets()
        num_train_examples = train_ds["images_u8"].shape[0]
    if num_train_examples % config.batch_size != 0:
        raise ValueError(
            f"batch size {config.batch_size} must divide the number of "
            f"training examples ({num_train_examples})"
        )
    steps_per_epoch = num_train_examples // config.batch_size

    # Optimizer: base_tx updates only the MoE block; everything else is frozen.
    base_tx = _make_base_optimizer(config, steps_per_epoch)
    labels = _label_params(combined_params)

    # Sanity check: list which parameter paths will actually be trained.
    trainable_paths = [
        '/'.join(str(p.key) for p in path)
        for path, label in tree_util.tree_flatten_with_path(labels)[0]
        if label == 'trainable'
    ]
    print("Trainable parts of the model:")
    for path in trainable_paths:
        print(f"  {path}")

    tx = optax.multi_transform(
        {'trainable': base_tx, 'frozen': optax.set_to_zero()},
        labels,
    )

    train_state = TrainState.create(
        apply_fn=finetune_model.apply,
        params=combined_params,
        tx=tx,
    )
    stuff = make_stuff(finetune_model)

    train_losses, train_accuracies = [], []
    test_losses, test_accuracies = [], []

    # Early stopping on test loss.
    patience = 10
    best_test_loss = float("inf")
    best_metric = None
    best_params = None
    best_epoch = -1
    epochs_since_improvement = 0

    for epoch in tqdm(range(config.num_epochs), desc="Epochs"):
        infos = []
        with timeblock(f"Epoch {epoch}"):
            # Fresh shuffle per epoch, reshaped to rows of batch indices.
            batch_ix = random.permutation(
                rngmix(rng, f"epoch-{epoch}"),
                num_train_examples,
            ).reshape((-1, config.batch_size))
            for p in batch_ix:
                train_state, info = stuff["step"](
                    train_state,
                    train_ds["images_u8"][p, :, :, :],
                    train_ds["labels"][p],
                )
                infos.append(info)

        # batch_loss is presumably a per-batch mean, hence the batch-size weighting.
        train_loss = sum(config.batch_size * x["batch_loss"] for x in infos) / num_train_examples
        train_accuracy = sum(x["num_correct"] for x in infos) / num_train_examples

        # The third argument (10_000) is presumably the evaluation batch size.
        with timeblock("Test set eval"):
            test_loss, test_accuracy = stuff["dataset_loss_and_accuracy"](
                train_state.params,
                test_ds,
                10_000,
            )

        train_losses.append(train_loss)
        train_accuracies.append(train_accuracy)
        test_losses.append(test_loss)
        test_accuracies.append(test_accuracy)

        if test_loss < best_test_loss:
            best_test_loss = test_loss
            best_params = train_state.params
            best_epoch = epoch
            best_metric = {
                "train_loss": f"{train_loss:.4f}",
                "train_accuracy": f"{train_accuracy:.4f}",
                "test_loss": f"{test_loss:.4f}",
                "test_accuracy": f"{test_accuracy:.4f}",
            }
            epochs_since_improvement = 0
        else:
            epochs_since_improvement += 1
            if epochs_since_improvement >= patience:
                print(f"Early stopping triggered at epoch {epoch}")
                break

    _save_curve_plot(
        train_losses, test_losses, 'Loss',
        f"{args.ckpt_path}/{config.seed}loss_plot.png",
    )
    _save_curve_plot(
        train_accuracies, test_accuracies, 'Accuracy',
        f"{args.ckpt_path}/{config.seed}accuracy_plot.png",
    )

    # A NaN/inf test loss never compares below +inf, which would leave nothing to save.
    if best_params is None:
        raise RuntimeError(
            "No epoch produced a finite test loss; refusing to write a checkpoint."
        )

    # Save the best fine-tuned parameters and append their metrics to ckpt.txt.
    with timeblock("model serialization"):
        weights_file = (
            f"mnist_vit_finetune_moe_seed{config.seed}_"
            f"opt_{config.optimizer}_lr_{config.learning_rate}_num_layers_{finetune_model.num_layers}_num_experts_{finetune_model.num_experts}_hidden_dim_{finetune_model.hidden_dim}_hidden_dim_expert_{finetune_model.hidden_dim_expert}_embedding_dim_{finetune_model.embedding_dim}_num_heads_{finetune_model.num_heads}_best_epoch{best_epoch}"
        )

        with open(f'{args.ckpt_path}/ckpt.txt', "a") as f:
            f.write(f"\n{weights_file}\n")
            for metric, value in best_metric.items():
                f.write(f"{metric}: {value}\n")

        with open(f'{args.ckpt_path}/{weights_file}', "wb") as f:
            f.write(flax.serialization.to_bytes(best_params))


if __name__ == "__main__":
    main()