import argparse
import augmax
import flax
import jax
import jax.numpy as jnp
import numpy as np
import optax
import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
from flax import linen as nn
from flax.training.train_state import TrainState
from jax import jit, random, value_and_grad, vmap
import jax.tree_util as tree_util
from tqdm import tqdm
from flax.traverse_util import flatten_dict, unflatten_dict

from src.utils import flatten_params, rngmix, timeblock
from .mnist_vit_train import ViTModel, load_datasets, TransformerEncoderLayer

# Module-level switch read by MoEBlock.__call__ at call time: when True, every
# MoEBlock adds the output of an always-on SharedExpert (params stored under
# 'shared_expert') to its routed mixture. Because it decides whether those
# params exist, it must match the value used when a checkpoint was written.
shared = True

def make_stuff(model):
    """Build the jitted training / evaluation functions for an MoE ViT.

    Args:
        model: a Flax module exposing ``num_experts`` and ``topk`` attributes
            that, when applied with ``return_moe_outputs=True``, returns
            ``(logits, gate_logits, expert_outputs)`` (as ``ViTModelMoE`` does).

    Returns:
        A dict with:
            ``"normalize_transform"``: the augmax uint8 -> float transform.
            ``"batch_eval"``: ``(params, images_u8, labels, training=False)``
                -> ``(total_loss, info)``; ``training`` is a static argument.
                In training mode the loss is cross-entropy plus ``alpha`` times
                the load-balancing auxiliary loss.
            ``"step"``: ``(train_state, images_u8, labels)`` -> ``(new_state, info)``,
                one optimizer update.
            ``"dataset_loss_and_accuracy"``: ``(params, dataset, batch_size)``
                -> ``(mean cross-entropy, accuracy)`` over a whole dataset dict.
    """
    normalize_transform = augmax.ByteToFloat()
    num_experts = model.num_experts
    topk = model.topk
    alpha = 0.01  # Hyperparameter for auxiliary loss weight

    def topk_gate(gate_logits):
        """Softmax over each token's top-``topk`` router logits; others get ~0 weight.

        Logits tied with the k-th largest value are all kept, so more than
        ``topk`` experts can receive weight on exact ties.
        """
        top_k_values, _ = jax.lax.top_k(gate_logits, topk)
        kth_values = top_k_values[..., -1]
        masked = jnp.where(gate_logits >= kth_values[..., None], gate_logits, -1e9)
        return nn.softmax(masked, axis=-1)

    def compute_aux_loss(gate):
        """Squared coefficient of variation of the total gate mass per expert.

        Zero when every expert receives the same total routing weight; grows as
        routing becomes more imbalanced.
        """
        # Assumes gate is (batch, seq, num_experts) -- summing axes (0, 1)
        # leaves one total per expert.
        P_e = jnp.sum(gate, axis=(0, 1))
        P_mean = jnp.sum(P_e) / num_experts
        sigma = jnp.sqrt(jnp.sum((P_e - P_mean) ** 2) / num_experts)
        cv = sigma / P_mean
        return cv ** 2

    def _batch_eval(params, images_u8, labels, training=False):
        images_f32 = vmap(normalize_transform)(None, images_u8)
        variables = {"params": params}
        if training:
            # One forward pass yields both the classifier logits and the router
            # logits, so the model is not run a second time for the aux loss.
            logits, gate_logits, _ = model.apply(variables, images_f32, return_moe_outputs=True)
            aux_loss = compute_aux_loss(topk_gate(gate_logits))
        else:
            logits = model.apply(variables, images_f32)
            aux_loss = jnp.array(0.0)
        y_onehot = jax.nn.one_hot(labels, logits.shape[-1])
        ce_loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=y_onehot))
        total_loss = ce_loss + alpha * aux_loss if training else ce_loss
        num_correct = jnp.sum(jnp.argmax(logits, axis=-1) == labels)
        return total_loss, {"num_correct": num_correct, "ce_loss": ce_loss, "aux_loss": aux_loss}

    # ``training`` selects a different computation graph, so it must be static.
    batch_eval = jit(_batch_eval, static_argnames="training")

    @jit
    def step(train_state, images_u8, labels):
        (total_loss, info), grads = value_and_grad(batch_eval, has_aux=True)(
            train_state.params, images_u8, labels, training=True
        )
        return train_state.apply_gradients(grads=grads), {"batch_loss": total_loss, **info}

    def dataset_loss_and_accuracy(params, dataset, batch_size: int):
        """Return ``(mean cross-entropy, accuracy)`` over ``dataset``.

        Args:
            params: model parameters.
            dataset: dict with ``"images_u8"`` and ``"labels"`` arrays sharing
                the leading (example) axis.
            batch_size: evaluation batch size; must divide the example count.

        Raises:
            ValueError: if the dataset is empty or its size is not a multiple
                of ``batch_size``.
        """
        images_u8 = dataset["images_u8"]
        labels = dataset["labels"]
        num_examples = images_u8.shape[0]
        if num_examples == 0 or num_examples % batch_size != 0:
            raise ValueError(
                f"dataset size ({num_examples}) must be a positive multiple of "
                f"batch_size ({batch_size})"
            )
        infos = [
            batch_eval(
                params,
                images_u8[start:start + batch_size],
                labels[start:start + batch_size],
                training=False,
            )[1]
            for start in range(0, num_examples, batch_size)
        ]
        # Batches are equal-sized, so weighting each batch mean by batch_size
        # gives the exact dataset mean.
        total_loss = jnp.sum(batch_size * jnp.array([info["ce_loss"] for info in infos])) / num_examples
        total_accuracy = sum(info["num_correct"] for info in infos) / num_examples
        return total_loss, total_accuracy

    return {"normalize_transform": normalize_transform, "batch_eval": batch_eval, "step": step, "dataset_loss_and_accuracy": dataset_loss_and_accuracy}

class SharedExpert(nn.Module):
    """Always-active feed-forward expert: Dense(hidden_dim) -> GELU -> Dense(embedding_dim).

    Both Dense layers act on the last axis of the input, so every token is
    transformed independently.
    """
    hidden_dim: int  # width of the intermediate GELU layer
    embedding_dim: int  # output width (matches the residual stream)

    @nn.compact
    def __call__(self, x):
        # Dense_0 is created before Dense_1, matching the expand -> project order.
        hidden = nn.gelu(nn.Dense(self.hidden_dim)(x))
        return nn.Dense(self.embedding_dim)(hidden)

class MoEBlock(nn.Module):
    """Mixture-of-experts feed-forward block with top-k softmax routing.

    A linear router scores ``num_experts`` two-layer GELU MLP experts per
    token; only the ``topk`` highest-scoring experts receive non-zero weight.
    When the module-level flag ``shared`` is true, the output of an always-on
    ``SharedExpert`` (submodule ``'shared_expert'``) is added to the mixture.
    """
    num_experts: int
    embedding_dim: int
    hidden_dim: int
    topk: int

    @nn.compact
    def __call__(self, x, return_moe_outputs=False):
        """Apply the MoE block.

        Args:
            x: token activations; the einsum below requires rank 3,
                i.e. (batch, seq, features).
            return_moe_outputs: if True, also return the raw (unmasked) router
                logits and the stacked per-expert outputs.

        Returns:
            ``output`` of shape (batch, seq, embedding_dim), or
            ``(output, gate_logits, expert_outputs)`` where ``gate_logits`` is
            (batch, seq, num_experts) and ``expert_outputs`` is
            (batch, seq, embedding_dim, num_experts).
        """
        # Creation order fixes auto-names: Dense_0 is the router and expert i
        # uses Dense_{2i+1} / Dense_{2i+2} (ViTModelMoE.get_moe_params reads these).
        gate_logits = nn.Dense(self.num_experts)(x)
        top_k_values, _ = jax.lax.top_k(gate_logits, self.topk)
        kth_values = top_k_values[..., -1]
        # Logits tied with the k-th value are kept, so exact ties can admit
        # more than topk experts.
        gate_logits_masked = jnp.where(
            gate_logits >= kth_values[..., None],
            gate_logits,
            -1e9
        )
        # Softmax over the *masked* logits so non-selected experts get ~0
        # weight; using the raw logits here made routing dense and ignored topk.
        gate = nn.softmax(gate_logits_masked, axis=-1)

        expert_outputs = []
        for _ in range(self.num_experts):
            expert_output = nn.Dense(self.hidden_dim)(x)
            expert_output = nn.gelu(expert_output)
            expert_output = nn.Dense(self.embedding_dim)(expert_output)
            expert_outputs.append(expert_output)
        # Experts are stacked on the last axis: (batch, seq, embedding_dim, num_experts).
        expert_outputs = jnp.stack(expert_outputs, axis=-1)

        # Gate-weighted sum over experts for every token.
        output = jnp.einsum('bsn,bsen->bse', gate, expert_outputs)

        # Shared expert: applied to every token and added to the routed output.
        if shared:
            shared_expert = SharedExpert(self.hidden_dim, self.embedding_dim, name='shared_expert')
            output = output + shared_expert(x)

        if return_moe_outputs:
            return output, gate_logits, expert_outputs
        return output

class MoETransformerEncoderLayer(nn.Module):
    """Pre-norm encoder layer whose feed-forward sublayer is an ``MoEBlock``.

    Computes ``x + Attn(LN(x))`` followed by ``x + MoE(LN(x))``. The
    LayerNorms have neither scale nor bias, so they own no parameters.
    Submodules carry explicit names so the parameter tree is predictable
    (``main`` copies ``MultiHeadDotProductAttention_0`` into it from a
    pre-trained layer).
    """
    embedding_dim: int
    num_heads: int
    hidden_dim: int
    num_experts: int
    topk: int

    @nn.compact
    def __call__(self, x, return_moe_outputs=False):
        """Apply the layer; with ``return_moe_outputs`` also return the MoE router logits and expert outputs."""
        attn_in = nn.LayerNorm(use_scale=False, use_bias=False, name='LayerNorm_0')(x)
        self_attention = nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads,
            qkv_features=self.embedding_dim,
            name='MultiHeadDotProductAttention_0'
        )
        x = x + self_attention(attn_in, attn_in)

        moe_in = nn.LayerNorm(use_scale=False, use_bias=False, name='LayerNorm_1')(x)
        moe_block = MoEBlock(
            num_experts=self.num_experts,
            embedding_dim=self.embedding_dim,
            hidden_dim=self.hidden_dim,
            topk=self.topk,
            name='MoEBlock_0'
        )

        if not return_moe_outputs:
            return x + moe_block(moe_in)

        moe_output, gate_logits, expert_outputs = moe_block(moe_in, return_moe_outputs=True)
        return x + moe_output, gate_logits, expert_outputs

class ViTModelMoE(nn.Module):
    """Vision Transformer whose first encoder layer uses a mixture-of-experts MLP.

    Layout: patch-embedding conv -> prepend a learned class token -> add
    learned positional embeddings -> ``MoETransformerEncoderLayer_0`` ->
    ``num_layers - 1`` regular ``TransformerEncoderLayer``s -> linear head on
    the class token. ``hidden_dim`` sizes the regular layers' MLPs while
    ``hidden_dim_expert`` sizes the MoE (and shared) experts.
    """
    patch_size: int = 7
    embedding_dim: int = 32
    num_heads: int = 1
    num_layers: int = 1
    hidden_dim: int = 32
    hidden_dim_expert: int = 32
    num_classes: int = 10
    num_experts: int = 4
    topk: int = 2

    @nn.compact
    def __call__(self, x, return_moe_outputs=False):
        """Classify images ``x``.

        Assumes channel-last images, e.g. (batch, 28, 28, 1) as used in ``main``
        -- TODO confirm for other callers.

        Returns:
            class logits, or ``(logits, gate_logits, expert_outputs)`` from the
            MoE layer when ``return_moe_outputs`` is True.
        """
        # A conv whose stride equals its kernel turns each patch into one token.
        patch_embed = nn.Conv(
            features=self.embedding_dim,
            kernel_size=(self.patch_size, self.patch_size),
            strides=(self.patch_size, self.patch_size),
            padding="VALID",
            name='Conv_0'
        )
        tokens = patch_embed(x)
        batch = tokens.shape[0]
        tokens = tokens.reshape((batch, -1, self.embedding_dim))

        cls_token = self.param('cls_token', nn.initializers.zeros, (1, 1, self.embedding_dim))
        tokens = jnp.concatenate([jnp.tile(cls_token, (batch, 1, 1)), tokens], axis=1)

        pos_embedding = self.param(
            'pos_embedding',
            nn.initializers.normal(stddev=0.02),
            (1, tokens.shape[1], self.embedding_dim)
        )
        tokens = tokens + pos_embedding

        moe_layer = MoETransformerEncoderLayer(
            embedding_dim=self.embedding_dim,
            num_heads=self.num_heads,
            hidden_dim=self.hidden_dim_expert,
            num_experts=self.num_experts,
            topk=self.topk,
            name='MoETransformerEncoderLayer_0'
        )
        if return_moe_outputs:
            tokens, gate_logits, expert_outputs = moe_layer(tokens, return_moe_outputs=True)
        else:
            tokens = moe_layer(tokens)

        for layer_idx in range(1, self.num_layers):
            tokens = TransformerEncoderLayer(
                embedding_dim=self.embedding_dim,
                num_heads=self.num_heads,
                hidden_dim=self.hidden_dim,
                name=f'TransformerEncoderLayer_{layer_idx}'
            )(tokens)

        # Classify from the class token (position 0).
        logits = nn.Dense(self.num_classes, name='Dense_0')(tokens[:, 0, :])

        if not return_moe_outputs:
            return logits
        return logits, gate_logits, expert_outputs

    def get_moe_outputs(self, params, x):
        """Run the model on ``x`` and return ``(gate_logits, expert_outputs)`` of the MoE layer."""
        _, router_logits, per_expert = self.apply({'params': params}, x, return_moe_outputs=True)
        return router_logits, per_expert

    def get_moe_params(self, params):
        """Export the MoE router and expert weights as NumPy arrays.

        Keys are ``gating_kernel``/``gating_bias`` plus, for each expert ``i``,
        ``expert_{i}_layer{1,2}_{kernel,bias}``. Relies on MoEBlock creating
        its router as ``Dense_0`` and expert ``i`` as ``Dense_{2i+1}``/``Dense_{2i+2}``.
        Shared-expert params are not exported.
        """
        block = params['MoETransformerEncoderLayer_0']['MoEBlock_0']
        router = block['Dense_0']
        exported = {
            'gating_kernel': np.array(router['kernel']),
            'gating_bias': np.array(router['bias']),
        }
        for expert_idx in range(self.num_experts):
            first = block[f'Dense_{2 * expert_idx + 1}']
            second = block[f'Dense_{2 * expert_idx + 2}']
            first_kernel, first_bias = first['kernel'], first['bias']
            second_kernel, second_bias = second['kernel'], second['bias']
            exported[f'expert_{expert_idx}_layer1_kernel'] = np.array(first_kernel)
            exported[f'expert_{expert_idx}_layer1_bias'] = np.array(first_bias)
            exported[f'expert_{expert_idx}_layer2_kernel'] = np.array(second_kernel)
            exported[f'expert_{expert_idx}_layer2_bias'] = np.array(second_bias)
        return exported

def _plot_train_test_curve(train_values, test_values, metric_name, path):
    """Save a per-epoch train-vs-test line plot of one metric to ``path``.

    ``metric_name`` (e.g. ``"Loss"``) is used for the y-axis label, the legend
    entries and the title.
    """
    plt.figure()
    plt.plot(train_values, label=f'Train {metric_name}')
    plt.plot(test_values, label=f'Test {metric_name}')
    plt.xlabel('Epoch')
    plt.ylabel(metric_name)
    plt.title(f'Training and Test {metric_name} over Epochs')
    plt.legend()
    plt.tight_layout()
    plt.savefig(path)
    plt.close()


def main():
    """Fine-tune a pre-trained MNIST ViT whose first encoder layer becomes an MoE layer.

    Loads the pre-trained ``ViTModel`` checkpoint from ``--model-path`` and
    builds a ``ViTModelMoE`` of the same depth. It copies over the patch
    embedding, class token, positional embedding, classifier head, the first
    layer's attention and all remaining encoder layers. It then trains only the
    parameters under ``MoETransformerEncoderLayer_0/MoEBlock_0``; every other
    parameter receives zero updates.

    Training stops early after ``patience`` epochs without test-loss
    improvement. Loss and accuracy plots, a metrics summary (appended to
    ``ckpt.txt``) and the best parameters are written under ``--ckpt-path``.

    Raises:
        ValueError: if the training-set size is not a multiple of the batch size.
        RuntimeError: if no epoch produced a test loss below infinity (e.g. NaN).
    """
    from types import SimpleNamespace

    from flax.core import unfreeze

    parser = argparse.ArgumentParser(description="Fine-tune ViT with MoE on MNIST")
    parser.add_argument("--seed", type=int, required=True, help="Random seed for fine-tuning")
    parser.add_argument("--model-path", type=str, required=True, help="Path to pre-trained ViT model checkpoint")
    parser.add_argument("--optimizer", choices=["sgd", "adam", "adamw"], required=True, help="Optimizer to use")
    parser.add_argument("--learning-rate", type=float, required=True, help="Learning rate")
    parser.add_argument("--num-layers", type=int, default=1, help="Number of transformer layers")
    parser.add_argument("--num-experts", type=int, default=2, help="Number of experts in MoE block")
    parser.add_argument("--ckpt-path", type=str, default="/", help="Path to ckpt directory")
    args = parser.parse_args()

    config = SimpleNamespace(
        seed=args.seed,
        model_path=args.model_path,
        optimizer=args.optimizer,
        learning_rate=args.learning_rate,
        num_epochs=100,
        batch_size=500,
    )

    print("JAX devices:", jax.devices())

    # Load data first so the LR schedule can use the real number of steps.
    with timeblock("load_datasets"):
        train_ds, test_ds = load_datasets()
        num_train_examples = train_ds["images_u8"].shape[0]
        num_test_examples = test_ds["images_u8"].shape[0]
    if num_train_examples % config.batch_size != 0:
        raise ValueError(
            f"number of training examples ({num_train_examples}) must be a "
            f"multiple of the batch size ({config.batch_size})"
        )
    steps_per_epoch = num_train_examples // config.batch_size

    num_layers = args.num_layers
    model = ViTModel(num_layers=num_layers)

    with open(config.model_path, "rb") as f:
        serialized_params = f.read()
    dummy_params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 28, 28, 1)))['params']
    pretrained_params = flax.serialization.from_bytes(dummy_params, serialized_params)

    finetune_model = ViTModelMoE(num_layers=num_layers, num_experts=args.num_experts)

    rng = jax.random.PRNGKey(config.seed)
    new_params = finetune_model.init(rng, jnp.ones((1, 28, 28, 1)))['params']

    # unfreeze returns a fresh nested plain dict (and also accepts an older
    # Flax FrozenDict), so the nested assignments below cannot alias new_params.
    combined_params = unfreeze(new_params)
    for i in range(1, num_layers):
        combined_params[f'TransformerEncoderLayer_{i}'] = pretrained_params[f'TransformerEncoderLayer_{i}']
    # The MoE layer keeps the pre-trained first layer's attention; its MLP is
    # replaced by the freshly initialised MoE block.
    combined_params['MoETransformerEncoderLayer_0']['MultiHeadDotProductAttention_0'] = (
        pretrained_params['TransformerEncoderLayer_0']['MultiHeadDotProductAttention_0']
    )
    for key in ['Conv_0', 'cls_token', 'pos_embedding', 'Dense_0']:
        combined_params[key] = pretrained_params[key]

    if config.optimizer == "sgd":
        lr_schedule = optax.warmup_cosine_decay_schedule(
            init_value=1e-6,
            peak_value=config.learning_rate,
            warmup_steps=10,
            decay_steps=config.num_epochs * steps_per_epoch,
        )
        base_tx = optax.sgd(lr_schedule, momentum=0.9)
    elif config.optimizer == "adam":
        base_tx = optax.adam(config.learning_rate)
    else:
        base_tx = optax.adamw(config.learning_rate, weight_decay=1e-3)

    def label_fn(path, _):
        """Mark params under MoETransformerEncoderLayer_0/MoEBlock_0 trainable; freeze the rest."""
        if (len(path) >= 2 and
                path[0].key == 'MoETransformerEncoderLayer_0' and path[1].key == 'MoEBlock_0'):
            return 'trainable'
        return 'frozen'

    labels = jax.tree_util.tree_map_with_path(label_fn, combined_params)

    flat_labels, _ = tree_util.tree_flatten_with_path(labels)
    print("Trainable parts of the model:")
    for path, label in flat_labels:
        if label == 'trainable':
            print(f"  {'/'.join(str(p.key) for p in path)}")

    tx = optax.multi_transform(
        {'trainable': base_tx, 'frozen': optax.set_to_zero()},
        labels
    )

    train_state = TrainState.create(
        apply_fn=finetune_model.apply,
        params=combined_params,
        tx=tx
    )
    stuff = make_stuff(finetune_model)

    train_losses = []
    train_accuracies = []
    test_losses = []
    test_accuracies = []

    patience = 10
    best_test_loss = float("inf")
    best_metric = None
    best_params = None
    best_epoch = -1
    epochs_since_improvement = 0

    for epoch in tqdm(range(config.num_epochs), desc="Epochs"):
        infos = []
        with timeblock(f"Epoch {epoch}"):
            batch_ix = random.permutation(
                rngmix(rng, f"epoch-{epoch}"),
                num_train_examples
            ).reshape((-1, config.batch_size))
            for i in range(batch_ix.shape[0]):
                p = batch_ix[i, :]
                train_state, info = stuff["step"](
                    train_state,
                    train_ds["images_u8"][p, :, :, :],
                    train_ds["labels"][p]
                )
                infos.append(info)

        # Equal-sized batches: weighting each batch mean by batch_size gives
        # the exact epoch mean of the cross-entropy.
        train_loss = sum(config.batch_size * x["ce_loss"] for x in infos) / num_train_examples
        train_accuracy = sum(x["num_correct"] for x in infos) / num_train_examples

        with timeblock("Test set eval"):
            # The whole test set is evaluated as a single batch.
            test_loss, test_accuracy = stuff["dataset_loss_and_accuracy"](
                train_state.params,
                test_ds,
                num_test_examples
            )

        train_losses.append(train_loss)
        train_accuracies.append(train_accuracy)
        test_losses.append(test_loss)
        test_accuracies.append(test_accuracy)

        if test_loss < best_test_loss:
            best_test_loss = test_loss
            best_params = train_state.params
            best_epoch = epoch
            best_metric = {
                "train_loss": f"{train_loss:.4f}",
                "train_accuracy": f"{train_accuracy:.4f}",
                "test_loss": f"{test_loss:.4f}",
                "test_accuracy": f"{test_accuracy:.4f}",
            }
            epochs_since_improvement = 0
        else:
            epochs_since_improvement += 1
            if epochs_since_improvement >= patience:
                print(f"Early stopping triggered at epoch {epoch}")
                break

    _plot_train_test_curve(train_losses, test_losses, 'Loss',
                           f"{args.ckpt_path}/{config.seed}loss_plot.png")
    _plot_train_test_curve(train_accuracies, test_accuracies, 'Accuracy',
                           f"{args.ckpt_path}/{config.seed}accuracy_plot.png")

    if best_params is None:
        # Every test loss was NaN/inf (or no epoch ran): there is nothing
        # meaningful to serialize.
        raise RuntimeError("no epoch produced a finite test loss; not writing a checkpoint")

    with timeblock("model serialization"):
        weights_file = (
            f"mnist_vit_finetune_moe_seed{config.seed}_"
            f"opt_{config.optimizer}_lr_{config.learning_rate}_num_layers_{finetune_model.num_layers}_num_experts_{finetune_model.num_experts}_hidden_dim_{finetune_model.hidden_dim}_hidden_dim_expert_{finetune_model.hidden_dim_expert}_embedding_dim_{finetune_model.embedding_dim}_num_heads_{finetune_model.num_heads}_best_epoch{best_epoch}"
        )
        with open(f'{args.ckpt_path}/ckpt.txt', "a") as f:
            f.write(f"\n{weights_file}\n")
            for metric, value in best_metric.items():
                f.write(f"{metric}: {value}\n")
        with open(f'{args.ckpt_path}/{weights_file}', "wb") as f:
            f.write(flax.serialization.to_bytes(best_params))

# Entry point: run the fine-tuning CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()