import argparse
import augmax
import flax
import jax
import jax.numpy as jnp
import numpy as np
import optax
import tensorflow as tf
import matplotlib.pyplot as plt
from flax import linen as nn
from flax.training.train_state import TrainState
from jax import random, value_and_grad
from jax.tree_util import tree_map
import jax.tree_util as tree_util
from tqdm import tqdm
from flax.traverse_util import flatten_dict, unflatten_dict

from src.utils import flatten_params, rngmix, timeblock
from src.datasets import load_cifar10
from .cifar10_vit_train import ViTModel, TransformerEncoderLayer, make_stuff

# Define MoEBlock: Mixture of Experts block with multiple expert MLPs
class MoEBlock(nn.Module):
    """Dense (soft-routed) Mixture-of-Experts feed-forward block.

    Every expert is a two-layer GELU MLP (features -> hidden_dim -> embedding_dim).
    All experts run on every token, and their outputs are combined with per-token
    softmax gate weights. Sub-modules are auto-named in creation order:
    ``Dense_0`` is the gate, and expert ``i`` uses ``Dense_{2i+1}`` (first layer)
    and ``Dense_{2i+2}`` (second layer).
    """

    num_experts: int    # number of expert MLPs; with 0 there is nothing to stack
    embedding_dim: int  # output width of every expert (and of the block)
    hidden_dim: int     # hidden width of every expert MLP

    @nn.compact
    def __call__(self, x, return_moe_outputs=False):
        """Mix the expert MLP outputs for each token.

        Args:
            x: token array of rank 3, (batch, seq, features); the einsum below
                requires exactly three axes.
            return_moe_outputs: if truthy, also return the raw gate logits and
                the stacked per-expert outputs.

        Returns:
            The mixed output of shape (batch, seq, embedding_dim). When
            ``return_moe_outputs`` is set, returns the tuple
            ``(output, gate_logits, expert_outputs)``, where ``gate_logits``
            is (batch, seq, num_experts) before softmax and ``expert_outputs``
            is (batch, seq, embedding_dim, num_experts).
        """
        # The gate is built first so that it is auto-named Dense_0.
        gate_logits = nn.Dense(self.num_experts)(x)
        gate_weights = nn.softmax(gate_logits, axis=-1)

        def run_expert(tokens):
            # Two Dense modules per expert, created in this order.
            hidden = nn.gelu(nn.Dense(self.hidden_dim)(tokens))
            return nn.Dense(self.embedding_dim)(hidden)

        # Experts stacked on a trailing axis: (batch, seq, embedding_dim, num_experts).
        expert_outputs = jnp.stack(
            [run_expert(x) for _ in range(self.num_experts)], axis=-1
        )
        mixed = jnp.einsum('bsn,bsen->bse', gate_weights, expert_outputs)

        if not return_moe_outputs:
            return mixed
        return mixed, gate_logits, expert_outputs

# Define MoETransformerEncoderLayer: Transformer layer with MoE instead of MLP
class MoETransformerEncoderLayer(nn.Module):
    """Pre-norm transformer encoder layer whose feed-forward MLP is a ``MoEBlock``.

    Computation: ``x + Attn(LN_0(x))`` followed by ``x + MoE(LN_1(x))``.
    Sub-modules are named explicitly (``LayerNorm_0``,
    ``MultiHeadDotProductAttention_0``, ``LayerNorm_1``, ``MoEBlock_0``),
    presumably so the non-MoE weights line up with a pretrained standard encoder
    layer's parameter names.
    """

    embedding_dim: int  # token width; also used as the attention qkv feature size
    num_heads: int      # number of attention heads
    hidden_dim: int     # hidden width of each expert MLP
    num_experts: int    # number of experts in the MoE block

    @nn.compact
    def __call__(self, x, return_moe_outputs=False):
        """Apply attention and MoE sub-layers, each with a residual connection.

        Args:
            x: token array (batch, seq, embedding_dim).
            return_moe_outputs: if truthy, also return the MoE gate logits and
                the stacked expert outputs.

        Returns:
            The updated tokens, or ``(tokens, gate_logits, expert_outputs)``
            when ``return_moe_outputs`` is set.
        """
        # Self-attention sub-layer.
        attn_in = nn.LayerNorm(use_scale=True, use_bias=True, name='LayerNorm_0')(x)
        x = x + nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads,
            qkv_features=self.embedding_dim,
            name='MultiHeadDotProductAttention_0',
        )(attn_in, attn_in)

        # Mixture-of-Experts sub-layer (replaces the usual MLP).
        moe_in = nn.LayerNorm(use_scale=True, use_bias=True, name='LayerNorm_1')(x)
        moe_result = MoEBlock(
            num_experts=self.num_experts,
            embedding_dim=self.embedding_dim,
            hidden_dim=self.hidden_dim,
            name='MoEBlock_0',
        )(moe_in, return_moe_outputs=return_moe_outputs)

        if not return_moe_outputs:
            return x + moe_result
        moe_out, gate_logits, expert_outputs = moe_result
        return x + moe_out, gate_logits, expert_outputs

# Define ViTModelMoE: ViT whose first transformer layer uses an MoE block (the remaining layers are standard)
class ViTModelMoE(nn.Module):
    """Vision Transformer whose first encoder layer is a Mixture-of-Experts layer.

    Pipeline:
      1. Patch-embedding convolution ``Conv_0``.
      2. Prepend the learned ``cls_token``.
      3. Add the learned ``pos_embedding``.
      4. Run ``MoETransformerEncoderLayer_0``.
      5. Run the standard layers
         ``TransformerEncoderLayer_1 .. TransformerEncoderLayer_{num_layers-1}``.
      6. Apply the linear head ``Dense_0`` to the CLS token.

    The head returns raw logits; no softmax is applied. Names are explicit
    because ``main`` copies pretrained weights into this tree by key.
    """

    patch_size: int = 4           # side of the square, non-overlapping patches
    embedding_dim: int = 192      # token width
    num_heads: int = 12           # attention heads in every layer
    num_layers: int = 0           # total layers incl. the MoE layer; <= 1 means MoE layer only
    hidden_dim: int = 768         # MLP width of the standard encoder layers
    hidden_dim_expert: int = 768  # hidden width of each expert MLP
    num_classes: int = 10         # output logits
    num_experts: int = 0          # NOTE(review): must be >= 1 before calling; 0 leaves the MoE nothing to stack

    @nn.compact
    def __call__(self, x, return_moe_outputs=False):
        """Classify a batch of images.

        Args:
            x: image batch, presumably NHWC (``main`` initialises with shape
                (1, 32, 32, 3)).
            return_moe_outputs: if truthy, also return the MoE layer's gate
                logits and stacked expert outputs.

        Returns:
            ``logits`` of shape (batch, num_classes), or
            ``(logits, gate_logits, expert_outputs)`` when
            ``return_moe_outputs`` is set.
        """
        patches = nn.Conv(
            features=self.embedding_dim,
            kernel_size=(self.patch_size, self.patch_size),
            strides=(self.patch_size, self.patch_size),
            padding="VALID",
            name='Conv_0',
        )(x)
        batch_size = patches.shape[0]
        tokens = patches.reshape((batch_size, -1, self.embedding_dim))

        # cls_token is created before pos_embedding, matching the original parameter order.
        cls_token = self.param('cls_token', nn.initializers.zeros, (1, 1, self.embedding_dim))
        tokens = jnp.concatenate([jnp.tile(cls_token, (batch_size, 1, 1)), tokens], axis=1)

        # One positional vector per token, including the CLS slot.
        pos_embedding = self.param(
            'pos_embedding',
            nn.initializers.normal(stddev=0.02),
            (1, tokens.shape[1], self.embedding_dim),
        )
        tokens = tokens + pos_embedding

        moe_result = MoETransformerEncoderLayer(
            embedding_dim=self.embedding_dim,
            num_heads=self.num_heads,
            hidden_dim=self.hidden_dim_expert,
            num_experts=self.num_experts,
            name='MoETransformerEncoderLayer_0',
        )(tokens, return_moe_outputs=return_moe_outputs)
        if return_moe_outputs:
            tokens, gate_logits, expert_outputs = moe_result
        else:
            tokens = moe_result

        for layer_index in range(1, self.num_layers):
            tokens = TransformerEncoderLayer(
                embedding_dim=self.embedding_dim,
                num_heads=self.num_heads,
                hidden_dim=self.hidden_dim,
                name=f'TransformerEncoderLayer_{layer_index}',
            )(tokens)

        # Position 0 holds the CLS token.
        logits = nn.Dense(self.num_classes, name='Dense_0')(tokens[:, 0, :])

        if return_moe_outputs:
            return logits, gate_logits, expert_outputs
        return logits

    def get_moe_outputs(self, params, x):
        """Run the model on ``x`` and return ``(gate_logits, expert_outputs)`` of the MoE layer.

        ``gate_logits`` are pre-softmax. ``expert_outputs`` stacks the experts
        on the last axis.
        """
        outputs = self.apply({'params': params}, x, return_moe_outputs=True)
        return outputs[1], outputs[2]

    def get_moe_params(self, params):
        """
        Returns a dictionary containing the kernels and biases of the MoE experts and the gating network
        as NumPy arrays.

        The dictionary includes:
        - 'gating_kernel': 2D array of shape (embedding_dim, num_experts)
        - 'gating_bias': 1D array of shape (num_experts)
        - For each expert i (0 to num_experts-1):
            - f'expert_{i}_layer1_kernel': 2D array of shape (embedding_dim, hidden_dim)
            - f'expert_{i}_layer1_bias': 1D array of shape (hidden_dim)
            - f'expert_{i}_layer2_kernel': 2D array of shape (hidden_dim, embedding_dim)
            - f'expert_{i}_layer2_bias': 1D array of shape (embedding_dim)

        Total items: 4 * num_experts + 2
        """
        block = params['MoETransformerEncoderLayer_0']['MoEBlock_0']
        gate = block['Dense_0']  # the gate is the first Dense created in MoEBlock
        result = {
            'gating_kernel': np.array(gate['kernel']),
            'gating_bias': np.array(gate['bias']),
        }
        # Expert e owns the consecutive pair Dense_{2e+1}, Dense_{2e+2}.
        for expert in range(self.num_experts):
            first = block[f'Dense_{2 * expert + 1}']
            second = block[f'Dense_{2 * expert + 2}']
            result[f'expert_{expert}_layer1_kernel'] = np.array(first['kernel'])
            result[f'expert_{expert}_layer1_bias'] = np.array(first['bias'])
            result[f'expert_{expert}_layer2_kernel'] = np.array(second['kernel'])
            result[f'expert_{expert}_layer2_bias'] = np.array(second['bias'])
        return result

# Main function for fine-tuning
def main():
    """Fine-tune the MoE block of a pretrained ViT on CIFAR-10 and save the best checkpoint.

    Steps:
      1. Load a pretrained ``ViTModel`` checkpoint.
      2. Build ``ViTModelMoE`` and copy every matching pretrained weight into it.
      3. Train only ``MoETransformerEncoderLayer_0/MoEBlock_0``; every other
         parameter receives zero updates.
      4. Early-stop on test loss.
      5. Write the best parameters, a metrics entry in ``ckpt.txt``, and
         loss/accuracy plots to ``--ckpt-path``.
    """
    import os  # local import: only needed for checkpoint paths

    args = _parse_args()

    config = argparse.Namespace(
        test=args.test,  # NOTE(review): not read anywhere below; --test currently has no effect
        seed=args.seed,
        model_path=args.model_path,
        optimizer=args.optimizer,
        learning_rate=args.learning_rate,
        num_epochs=50,
        batch_size=100,
    )
    num_layers = args.num_layers

    # Pretrained weights and the freshly initialised MoE model.
    pretrained_params = _load_pretrained_params(config.model_path, num_layers)
    finetune_model = ViTModelMoE(num_layers=num_layers, num_experts=args.num_experts)
    rng = jax.random.PRNGKey(config.seed)
    new_params = finetune_model.init(rng, jnp.ones((1, 32, 32, 3)))['params']
    combined_params = _merge_pretrained_params(new_params, pretrained_params, num_layers)

    # Data is loaded before the optimizer so the LR schedule length uses the real dataset size.
    with timeblock("Load datasets"):
        train_ds, test_ds = load_cifar10()
        num_train_examples = train_ds["images_u8"].shape[0]
        if num_train_examples % config.batch_size != 0:
            raise ValueError(
                f"{num_train_examples} training examples is not a multiple of batch size {config.batch_size}"
            )
    steps_per_epoch = num_train_examples // config.batch_size

    # Optimizer: base_tx updates only the MoE block; all other params get zero updates.
    base_tx = _make_base_optimizer(
        config.optimizer, config.learning_rate, config.num_epochs * steps_per_epoch
    )
    param_labels = _trainable_labels(combined_params)
    _print_trainable_paths(param_labels)
    tx = optax.multi_transform(
        {'trainable': base_tx, 'frozen': optax.set_to_zero()},
        param_labels,
    )

    train_state = TrainState.create(
        apply_fn=finetune_model.apply,
        params=combined_params,
        tx=tx,
    )
    stuff = make_stuff(finetune_model)

    train_losses = []
    train_accuracies = []
    test_losses = []
    test_accuracies = []

    # Early stopping on test loss.
    # NOTE(review): model selection uses the test split, so metrics reported on it are optimistically biased.
    patience = 10
    # Presumably the evaluation batch size (CIFAR-10's test set has 10,000 images); confirm against make_stuff.
    eval_batch_size = 2000
    best_test_loss = float("inf")
    best_metric = None
    best_params = None
    best_epoch = -1
    last_metric = None
    epochs_since_improvement = 0

    for epoch in tqdm(range(config.num_epochs), desc="Epochs"):
        infos = []
        # Time the whole epoch of training steps, not just the shuffle.
        with timeblock(f"Epoch {epoch}"):
            batch_ix = random.permutation(rngmix(rng, f"epoch-{epoch}"), num_train_examples).reshape(
                (-1, config.batch_size)
            )
            batch_rngs = random.split(rngmix(rng, f"batch_rngs-{epoch}"), batch_ix.shape[0])
            for i in range(batch_ix.shape[0]):
                p = batch_ix[i, :]
                images_u8 = train_ds["images_u8"][p, :, :, :]
                batch_labels = train_ds["labels"][p]
                train_state, info = stuff["step"](batch_rngs[i], train_state, images_u8, batch_labels)
                infos.append(info)

        # batch_loss is a per-batch mean, so scale by batch size to get the dataset mean.
        train_loss = sum(config.batch_size * x["batch_loss"] for x in infos) / num_train_examples
        train_accuracy = sum(x["num_correct"] for x in infos) / num_train_examples

        with timeblock("Test set eval"):
            test_loss, test_accuracy = stuff["dataset_loss_and_accuracy"](
                train_state.params,
                test_ds,
                eval_batch_size,
            )

        train_losses.append(train_loss)
        train_accuracies.append(train_accuracy)
        test_losses.append(test_loss)
        test_accuracies.append(test_accuracy)

        print(f"Epoch {epoch}: Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}, Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}")

        last_metric = _format_metrics(train_loss, train_accuracy, test_loss, test_accuracy)
        if test_loss < best_test_loss:
            best_test_loss = test_loss
            best_params = train_state.params
            best_epoch = epoch
            best_metric = last_metric
            epochs_since_improvement = 0
        else:
            epochs_since_improvement += 1
            if epochs_since_improvement >= patience:
                print(f"Early stopping triggered at epoch {epoch}")
                break

    if best_params is None:
        # Only reachable if no test loss compared below +inf (e.g. every value was NaN).
        # Save the final state instead of crashing on best_metric.items().
        print("Warning: test loss never improved; saving the final parameters instead")
        best_params = train_state.params
        best_epoch = len(test_losses) - 1
        best_metric = last_metric

    os.makedirs(args.ckpt_path, exist_ok=True)

    # Save fine-tuned model
    with timeblock("Model serialization"):
        weights_file = (
            f"cifar10_vit_finetune_moe_seed{config.seed}_"
            f"opt_{config.optimizer}_lr_{config.learning_rate}"
            f"_num_layers_{finetune_model.num_layers}"
            f"_num_experts_{finetune_model.num_experts}"
            f"_hidden_dim_{finetune_model.hidden_dim}"
            f"_hidden_dim_expert_{finetune_model.hidden_dim_expert}"
            f"_embedding_dim_{finetune_model.embedding_dim}"
            f"_num_heads_{finetune_model.num_heads}"
            f"_best_epoch{best_epoch}"
        )
        # Append a metrics record for this run.
        with open(os.path.join(args.ckpt_path, 'ckpt.txt'), "a") as f:
            f.write(f"\n{weights_file}\n")
            for metric, value in best_metric.items():
                f.write(f"{metric}: {value}\n")

        with open(os.path.join(args.ckpt_path, weights_file), "wb") as f:
            f.write(flax.serialization.to_bytes(best_params))

    _save_curve_plot(
        train_losses, test_losses, 'Loss',
        os.path.join(args.ckpt_path, f"{config.seed}loss_plot.png"),
    )
    _save_curve_plot(
        train_accuracies, test_accuracies, 'Accuracy',
        os.path.join(args.ckpt_path, f"{config.seed}accuracy_plot.png"),
    )


def _parse_args():
    """Parse and validate the command-line options for MoE fine-tuning."""
    parser = argparse.ArgumentParser(description="Fine-tune ViT with MoE on CIFAR-10")
    parser.add_argument("--test", action="store_true", help="Run in smoke-test mode")
    parser.add_argument("--seed", type=int, required=True, help="Random seed for fine-tuning")
    parser.add_argument("--model-path", type=str, required=True, help="Path to pre-trained ViT model checkpoint")
    parser.add_argument("--optimizer", choices=["sgd", "adam", "adamw"], required=True, help="Optimizer to use")
    parser.add_argument("--learning-rate", type=float, required=True, help="Learning rate")
    parser.add_argument("--num-layers", type=int, default=1, help="Number of transformer layers")
    parser.add_argument("--num-experts", type=int, default=2, help="Number of experts in MoE block")
    parser.add_argument("--ckpt-path", type=str, default="/", help="Path to ckpt directory")
    args = parser.parse_args()
    # Layer 0 of the pretrained model is always read, and the MoE block needs at least one expert.
    if args.num_layers < 1:
        parser.error("--num-layers must be >= 1")
    if args.num_experts < 1:
        parser.error("--num-experts must be >= 1")
    return args


def _load_pretrained_params(model_path, num_layers):
    """Deserialize a ``ViTModel(num_layers=num_layers)`` checkpoint stored at ``model_path``.

    The model is initialised on a dummy (1, 32, 32, 3) batch only to obtain the
    parameter tree that ``flax.serialization.from_bytes`` restores into.
    """
    model = ViTModel(num_layers=num_layers)
    with open(model_path, "rb") as f:
        serialized_params = f.read()
    dummy_params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 32, 32, 3)))['params']
    return flax.serialization.from_bytes(dummy_params, serialized_params)


def _merge_pretrained_params(new_params, pretrained_params, num_layers):
    """Return a copy of ``new_params`` with every pretrained weight that has a slot in the MoE model.

    Copied from ``pretrained_params``:
      * ``TransformerEncoderLayer_1 .. TransformerEncoderLayer_{num_layers-1}``, verbatim;
      * from ``TransformerEncoderLayer_0``: the attention weights, plus
        ``LayerNorm_0`` / ``LayerNorm_1`` when the pretrained layer has them;
      * ``Conv_0``, ``cls_token`` and ``pos_embedding`` (before the transformer
        blocks) and the classifier head ``Dense_0``.

    The freshly initialised ``MoEBlock_0`` is kept. ``new_params`` itself is not mutated.
    """
    merged = dict(new_params)
    for i in range(1, num_layers):
        key = f'TransformerEncoderLayer_{i}'
        merged[key] = pretrained_params[key]

    first_layer = pretrained_params['TransformerEncoderLayer_0']
    moe_layer = dict(new_params['MoETransformerEncoderLayer_0'])
    moe_layer['MultiHeadDotProductAttention_0'] = first_layer['MultiHeadDotProductAttention_0']
    # The LayerNorms are frozen by the optimizer. Without copying they would stay at
    # their default init, and the copied attention would see differently normalised inputs.
    for norm_key in ('LayerNorm_0', 'LayerNorm_1'):
        if norm_key in first_layer:
            moe_layer[norm_key] = first_layer[norm_key]
        else:
            print(f"Warning: pretrained TransformerEncoderLayer_0 has no {norm_key}; keeping its fresh init")
    merged['MoETransformerEncoderLayer_0'] = moe_layer

    for key in ('Conv_0', 'cls_token', 'pos_embedding', 'Dense_0'):
        merged[key] = pretrained_params[key]
    return merged


def _make_base_optimizer(optimizer_name, learning_rate, total_steps):
    """Build the optax optimizer applied to the trainable (MoE) parameters.

    ``sgd`` uses momentum 0.9 with a 10-step warmup from 1e-6 and cosine decay
    over ``total_steps``. ``adam``/``adamw`` use a constant rate; ``adamw`` uses
    weight decay 1e-3.
    """
    if optimizer_name == "sgd":
        lr_schedule = optax.warmup_cosine_decay_schedule(
            init_value=1e-6,
            peak_value=learning_rate,
            warmup_steps=10,
            decay_steps=total_steps,
        )
        return optax.sgd(lr_schedule, momentum=0.9)
    if optimizer_name == "adam":
        return optax.adam(learning_rate)
    if optimizer_name == "adamw":
        return optax.adamw(learning_rate, weight_decay=1e-3)
    raise ValueError(f"Unknown optimizer: {optimizer_name!r}")


def _trainable_labels(params):
    """Label each leaf 'trainable' if it is inside ``MoETransformerEncoderLayer_0/MoEBlock_0``, else 'frozen'."""
    def label_fn(path, _):
        if (len(path) >= 3
                and path[0].key == 'MoETransformerEncoderLayer_0'
                and path[1].key == 'MoEBlock_0'):
            return 'trainable'
        return 'frozen'
    return jax.tree_util.tree_map_with_path(label_fn, params)


def _print_trainable_paths(param_labels):
    """Sanity check: print the '/'-joined path of every parameter labelled 'trainable'.

    With 2 experts the expected output is the bias/kernel of
    ``MoETransformerEncoderLayer_0/MoEBlock_0/Dense_0`` (the gate) through
    ``.../Dense_4`` (Dense_1..Dense_4 are the two experts). The classifier head
    ``Dense_0`` and everything outside the MoE block are frozen.
    """
    leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(param_labels)
    print("Trainable parts of the model:")
    for path, label in leaves_with_paths:
        if label == 'trainable':
            print("  " + "/".join(str(entry.key) for entry in path))


def _format_metrics(train_loss, train_accuracy, test_loss, test_accuracy):
    """Return the per-epoch metrics as 4-decimal strings, in the format written to ckpt.txt."""
    return {
        "train_loss": f"{train_loss:.4f}",
        "train_accuracy": f"{train_accuracy:.4f}",
        "test_loss": f"{test_loss:.4f}",
        "test_accuracy": f"{test_accuracy:.4f}",
    }


def _save_curve_plot(train_values, test_values, metric_name, out_path):
    """Plot per-epoch train/test curves of ``metric_name`` and save the figure to ``out_path``."""
    plt.figure()
    plt.plot(train_values, label=f'Train {metric_name}')
    plt.plot(test_values, label=f'Test {metric_name}')
    plt.xlabel('Epoch')
    plt.ylabel(metric_name)
    plt.title(f'Training and Test {metric_name} over Epochs')
    plt.legend()
    plt.tight_layout()
    plt.savefig(out_path)
    plt.close()


if __name__ == "__main__":
    main()