# Copyright 2024 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformer Decoder-only model."""

# pylint: disable=g-importing-member
# pylint: disable=invalid-name
from typing import Any, Callable, Iterable, Optional, Sequence, Tuple, Union
import dataclasses
from functools import partial

from flax import linen as nn
import jax
import jax.numpy as jnp
from .mu_task_base import MuTask
from learned_optimization.tasks import base
from typing import Any, Callable, Dict, Optional, Tuple, Union, Mapping
from jax.random import PRNGKey

# Loose type aliases used only in the annotations below.
Params = Any
ModelState = Any
Batch = Any
import functools
# Short alias for the device-mesh type used when building MoE meshes.
Mesh = jax.sharding.Mesh
from .mu_moe_mlp import MoeBlock, nd_dense_init
# Placeholder alias so DoConfig can annotate `moe_config`; the concrete
# `MoeConfig` class defined further down in this module rebinds this name.
MoeConfig = Any

@dataclasses.dataclass
class DoConfig:
  """Hyper-parameters for Transformer decoder-only.

  The single-letter fields are the dimension names used in the tensor-name
  suffixes throughout this module (e.g. `x_BxLxD`).
  """
  D: int  # model/embed dim  = qkv dim
  H: int  # num attention heads
  L: int  # max context/sequence length (move out of config?)
  N: int  # number of transformer block layers
  V: int  # vocab size
  F: int  # FF inner dimension
  # When True, CausalAttn applies separate LayerNorms to queries and keys and
  # scales logits by 1/sqrt(Dh); when False, logits are scaled by 1/Dh.
  use_kv_norm: bool = False
  # FFN variant used by TBlock: 'regular_ffn', 'swiglu' or 'moe'.
  ffn_type: str = 'regular_ffn'
  # Only read when ffn_type == 'moe'.
  moe_config: Optional[MoeConfig] = None
  # NOTE(review): kernel_init / embed_init are not read by the modules in this
  # file (they hard-code variance_scaling initializers) — confirm before
  # relying on them.
  kernel_init: nn.initializers.Initializer = nn.initializers.xavier_uniform()
  embed_init: nn.initializers.Initializer = nn.initializers.variance_scaling(
      1.0, 'fan_in', 'normal', out_axis=0)
  dtype: jnp.dtype = jnp.float32
  # Not read by the modules in this file.
  fsdp_enabled: bool = False

  # Transformer block rematerialization / gradient checkpointing to save memory.
  remat: bool = False

  # When True, TransformerDo computes logits with the token embedding matrix
  # (`embed.attend`) instead of a separate `output_proj` Dense layer.
  tie_weights: bool = False


class TransformerDo(nn.Module):
  """Transformer decoder-only language model.

  Architecture:
    * token embeddings plus learned positional embeddings;
    * `docfg.N` `TBlock`s;
    * a final bias-free LayerNorm;
    * a readout that is either tied to the token embedding (`embed.attend`)
      or a separate zero-initialized `output_proj` Dense layer.

  Per-layer and total MoE load-balance losses are sown into the
  "intermediates" collection as `layer_{i}_load_balance_loss` and
  `load_balance_loss`.
  """
  docfg: DoConfig

  def get_mup_lrs(self, params, device):
    """Returns the MuP learning-rate multipliers that match the parameter structure.

    Every leaf is a scalar float32 array created on `device`:
      * `embed`, `pos_embed`, `output_proj` and all LayerNorms: 1.0;
      * attention `query` / `key` / `value` kernels: 1 / kernel.shape[0];
      * attention `attn_out_proj` kernel:
        1 / (kernel.shape[0] * kernel.shape[1]);
      * MLP Dense layers and MoE gate / expert weights: 1 / fan_in;
      * anything not matched below keeps 1.0.

    The method does not use `self`.

    Args:
      params: Variables as returned by `init`, with a top-level "params" key.
      device: Device on which the multiplier scalars are created.

    Returns:
      {"params": lr_tree}, where lr_tree mirrors the nesting of params["params"].
    """
    def const(value):
      return jnp.array(value, dtype=jnp.float32, device=device)

    def fill(tree, value):
      # Replace every leaf of `tree` with the scalar multiplier `value`.
      return jax.tree_util.tree_map(lambda _: const(value), tree)

    def get_dense(v, fan_in):
      # Hidden-layer Dense: every leaf gets 1/fan_in.
      return fill(v, 1.0 / fan_in if 'kernel' in v else 1.0)

    def get_layer_norm(v):
      # Layer norm parameters always have lr 1.0.
      return fill(v, 1.0)

    def get_attention(v):
      result = {}
      for k, val in v.items():
        if k == 'attn_out_proj':
          # Kernel is (H, Dh, D), so fan_in is H * Dh.
          fan_in = val['kernel'].shape[0] * val['kernel'].shape[1]
          result[k] = {'kernel': const(1.0 / fan_in)}
        elif k in ('k_layer_norm', 'q_layer_norm'):
          # TODO(Ben) check if this is correct
          result[k] = get_layer_norm(val)
        else:
          # query / key / value: kernel is (D, H, Dh), so fan_in is D.
          fan_in = val['kernel'].shape[0]
          result[k] = {'kernel': const(1.0 / fan_in)}
      return result

    def get_moe(v):
      # MoE weights are accessed through `.value` (boxed parameters).
      result = {}
      if 'gate' in v:
        gate_kernel = v['gate']['kernel']
        result['gate'] = {
            'kernel': fill(gate_kernel, 1.0 / gate_kernel.value.shape[0])}
      for key in ('wi_0', 'wi_1', 'wo'):
        if key in v:
          # Assumes expert weights are (experts, fan_in, fan_out) — TODO
          # confirm against MoeBlock.
          result[key] = fill(v[key], 1.0 / v[key].value.shape[1])
      return result

    # Start with all learning rates at 1.0, then override per group.
    lr_tree = fill(params['params'], 1.0)

    for k, v in params['params'].items():
      if k in ('embed', 'pos_embed', 'output_proj'):
        # Embeddings and readout always have lr 1.0.
        lr_tree[k] = fill(v, 1.0)
      elif k.startswith('blocks_'):
        for block_k, block_v in v.items():
          if 'CausalAttn' in block_k:
            lr_tree[k][block_k] = get_attention(block_v)
          elif 'LayerNorm' in block_k:
            lr_tree[k][block_k] = get_layer_norm(block_v)
          elif 'MoeBlock' in block_k:
            lr_tree[k][block_k] = get_moe(block_v)
          elif 'Mlp' in block_k:
            # Matches both `Mlp_*` and `MlpSwiGLU_*` submodules.
            lr_tree[k][block_k] = {
                mlp_k: get_dense(mlp_v, mlp_v['kernel'].shape[0])
                for mlp_k, mlp_v in block_v.items()
                if 'Dense' in mlp_k
            }
          elif 'Dense' in block_k:
            lr_tree[k][block_k] = get_dense(block_v, block_v['kernel'].shape[0])
      elif k == 'out_ln':
        lr_tree[k] = get_layer_norm(v)

    return {"params": lr_tree}

  def setup(self):
    cfg = self.docfg
    self.embed = nn.Embed(
        num_embeddings=cfg.V,
        features=cfg.D,
        embedding_init=nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal"),
    )
    self.pos_embed = nn.Embed(
        num_embeddings=cfg.L,
        features=cfg.D,
        embedding_init=nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal"),
    )

    # Optional output projection for untied weights.
    self.tie_weights = getattr(cfg, 'tie_weights', False)
    if not self.tie_weights:
      self.output_proj = nn.Dense(
          features=cfg.V,
          use_bias=False,
          kernel_init=nn.initializers.zeros,
          dtype=cfg.dtype,
          name='output_proj'
      )

    block = nn.remat(TBlock) if cfg.remat else TBlock
    self.blocks = [block(cfg) for _ in range(cfg.N)]
    self.out_ln = nn.LayerNorm(dtype=cfg.dtype, use_bias=False)

  def __call__(self, y_BxL: jax.Array):
    """Computes next-token logits.

    Args:
      y_BxL: Integer token ids of shape (B, L), where L <= docfg.L.

    Returns:
      Logits of shape (B, L, V).

    Raises:
      ValueError: if the sequence is longer than `docfg.L`.
    """
    seq_len = y_BxL.shape[1]
    if seq_len > self.docfg.L:
      # `pos_embed` only has docfg.L rows. JAX does not raise on out-of-bounds
      # indices, so longer inputs would silently get wrong embeddings.
      raise ValueError(
          f'Sequence length {seq_len} exceeds max context length '
          f'{self.docfg.L}')

    # For training on concatenated examples.
    y_BxLxD = self.embed(y_BxL)
    y_BxLxD += self.pos_embed(jnp.arange(0, seq_len)[None, ...])

    # Track load balance losses (0.0 for non-MoE blocks).
    load_balance_losses = []
    for i, block in enumerate(self.blocks):
      y_BxLxD, lbl_loss = block(y_BxLxD)
      load_balance_losses.append(lbl_loss)
      self.sow("intermediates", f"layer_{i}_load_balance_loss", lbl_loss)

    total_load_balance_loss = sum(load_balance_losses)
    self.sow("intermediates", "load_balance_loss", total_load_balance_loss)

    y_BxLxD = self.out_ln(y_BxLxD)

    # Use either tied weights (embedding.attend) or separate output projection.
    if self.tie_weights:
      logits_BxLxV = self.embed.attend(y_BxLxD.astype(jnp.float32))
    else:
      logits_BxLxV = self.output_proj(y_BxLxD)

    return logits_BxLxV


class Mlp(nn.Module):
  """Bias-free two-layer feed-forward network: D -> F (GELU) -> D."""
  cfg: DoConfig

  @nn.compact
  def __call__(self, x_BxLxD: jax.Array):
    config = self.cfg
    init_fn = nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal")

    def dense(width):
      # Created lazily so Flax auto-names the layers Dense_0, Dense_1 in order.
      return nn.Dense(width, kernel_init=init_fn, use_bias=False,
                      dtype=config.dtype)

    hidden_BxLxF = jax.nn.gelu(dense(config.F)(x_BxLxD))
    return dense(config.D)(hidden_BxLxF)


class MlpSwiGLU(nn.Module):
  """Multilayer perceptron with SwiGLU activation.

  Computes W_out((x W_in) * swish(x W_gate)), where
  swish(g) = g * sigmoid(g). All projections are bias-free.
  """
  cfg: DoConfig

  @nn.compact
  def __call__(self, x_BxLxD: jax.Array):
    cfg = self.cfg
    linear = partial(
        nn.Dense,
        kernel_init=nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal"),
        use_bias=False,
        dtype=cfg.dtype,
    )
    # Two parallel D -> F projections: the value branch (Dense_0) and the
    # gate branch (Dense_1).
    x_BxLxF = linear(cfg.F)(x_BxLxD)
    gate_BxLxF = linear(cfg.F)(x_BxLxD)

    # SwiGLU: value * swish(gate). Multiplying by sigmoid(gate) alone would
    # be plain GLU, not SwiGLU.
    x_BxLxF = x_BxLxF * jax.nn.swish(gate_BxLxF)

    # Project back to the model dimension (Dense_2).
    return linear(cfg.D)(x_BxLxF)


class TBlock(nn.Module):
  """Pre-LayerNorm transformer block: causal attention + FFN, each residual.

  Returns a tuple (output_BxLxD, load_balance_loss). The loss is 0.0 unless
  the FFN is a MoE layer that reports one.
  """
  docfg: DoConfig

  @nn.compact
  def __call__(self, in_BxLxD: jax.Array):
    cfg = self.docfg

    # Attention sub-layer: "pre-layernorm" then residual.
    x_BxLxD = nn.LayerNorm(dtype=cfg.dtype, use_bias=False)(in_BxLxD)
    x_BxLxD = CausalAttn(cfg)(x_BxLxD)
    x_BxLxD += in_BxLxD

    # FFN sub-layer. z_BxLxD first holds the normalized input, then the FFN
    # output that is added to the residual stream.
    z_BxLxD = nn.LayerNorm(dtype=cfg.dtype, use_bias=False)(x_BxLxD)

    loss = 0.0
    if cfg.ffn_type == 'swiglu':
      z_BxLxD = MlpSwiGLU(cfg)(z_BxLxD)
    elif cfg.ffn_type == 'moe':
      moe_layer = MoeBlock(
        config=cfg.moe_config,
        num_experts=cfg.moe_config.num_experts,
        num_experts_per_tok=cfg.moe_config.num_experts_per_tok,
        mesh=cfg.moe_config.mesh,
        kernel_init=nd_dense_init(1.0, "fan_in", "truncated_normal"),
        kernel_axes=("embed", "experts"),
        name="MoeBlock"
      )
      # The MoE output must replace z_BxLxD. Otherwise the residual would add
      # the normalized input and the expert computation would be dropped.
      z_BxLxD, moe_loss = moe_layer(z_BxLxD)
      # NOTE(review): guards against MoeBlock reporting no loss (None); confirm
      # against mu_moe_mlp.
      if moe_loss is not None:
        loss = moe_loss
    elif cfg.ffn_type == 'regular_ffn':
      z_BxLxD = Mlp(cfg)(z_BxLxD)
    else:
      raise ValueError(f"Invalid FFN type: {cfg.ffn_type}")

    return x_BxLxD + z_BxLxD, loss


class CausalAttn(nn.Module):
  """Multi-head causal self-attention with a bias-free output projection.

  With `cfg.use_kv_norm`, queries and keys are LayerNorm-ed and logits are
  scaled by 1/sqrt(Dh). Otherwise logits are scaled by 1/Dh.
  """
  cfg: DoConfig

  @nn.compact
  def __call__(self, x_BxLxD: jax.Array):
    cfg = self.cfg

    assert cfg.D % cfg.H == 0, f'D {cfg.D} not divisible by H {cfg.H}'
    Dh = cfg.D // cfg.H

    # Maps D -> (H, Dh)
    multilinear = partial(
        nn.DenseGeneral,
        axis=-1,
        features=(cfg.H, Dh),
        kernel_init=nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal"),
        use_bias=False,
        dtype=cfg.dtype,
    )

    q_BxLxHxDh, k_BxLxHxDh, v_BxLxHxDh = (
        multilinear(name='query')(x_BxLxD),
        multilinear(name='key')(x_BxLxD),
        multilinear(name='value')(x_BxLxD),
    )

    if cfg.use_kv_norm:
      # Despite the flag name, this normalizes the keys and queries (QK-norm)
      # with separate bias-free LayerNorms over the last (Dh) axis. Values are
      # left untouched.
      k_layer_norm = nn.LayerNorm(epsilon=1e-9/cfg.H, use_bias=False, dtype=cfg.dtype, name='k_layer_norm')
      q_layer_norm = nn.LayerNorm(epsilon=1e-9/cfg.H, use_bias=False, dtype=cfg.dtype, name='q_layer_norm')
      k_BxLxHxDh = k_layer_norm(k_BxLxHxDh)
      q_BxLxHxDh = q_layer_norm(q_BxLxHxDh)
      # Standard 1/sqrt(Dh) logit scaling.
      q_BxLxHxDh /= Dh ** 0.5  # todo(Ben) check if this is correct
    else:
      # 1/Dh (muP-style) logit scaling instead of 1/sqrt(Dh).
      q_BxLxHxDh /= Dh

    att_BxHxLxL = jnp.einsum('...qhd,...khd->...hqk', q_BxLxHxDh, k_BxLxHxDh)
    # cast to fp32 for softmax
    att_BxHxLxL = att_BxHxLxL.astype(jnp.float32)

    # Causal mask: query position q may attend to key positions k <= q. The
    # diagonal is always kept, so no softmax row is fully masked.
    L = x_BxLxD.shape[1]
    mask_1x1xLxL = jnp.tril(jnp.ones((1, 1, L, L), dtype=jnp.bool_))

    # Use the fill value of the array actually being masked (fp32 here) rather
    # than cfg.dtype, so the value is always representable.
    _NEG_INF = jnp.finfo(att_BxHxLxL.dtype).min
    att_BxHxLxL = jnp.where(mask_1x1xLxL, att_BxHxLxL, _NEG_INF)
    att_BxHxLxL = jax.nn.softmax(att_BxHxLxL, axis=-1)
    att_BxHxLxL = att_BxHxLxL.astype(cfg.dtype)
    out_BxLxHxDh = jnp.einsum('...hqk,...khd->...qhd', att_BxHxLxL, v_BxLxHxDh)
    # Output projection contracting (H, Dh) back to D.
    out_BxLxD = nn.DenseGeneral(
        features=cfg.D,
        name='attn_out_proj',
        axis=(-2, -1),
        kernel_init=nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal"),
        use_bias=False,
        dtype=cfg.dtype,
    )(out_BxLxHxDh)
    return out_BxLxD



class MoeConfig:
  """Mutable container of MoE hyper-parameters handed to `MoeBlock`.

  Every attribute starts at the default below. Callers overwrite individual
  attributes after construction. Constructing an instance also builds a
  device mesh over `jax.devices()` with a single "data" axis.
  """

  def __init__(self):
    # Numeric precision.
    self.dtype = jnp.float32
    self.weight_dtype = jnp.float32
    # Layer sizes.
    self.emb_dim = 128
    self.mlp_dim = 256
    # Expert routing.
    self.num_experts = 8
    self.num_experts_per_tok = 2
    self.capacity_factor = 1.0
    # Execution.
    self.matmul_precision = jax.lax.Precision.DEFAULT
    self.model_call_mode = "train"
    # Auxiliary load-balancing loss weight.
    self.load_balance_loss_weight = 0.01
    self.use_moe_linears = True
    # Expert-parallelism degrees.
    self.ici_expert_parallelism = 1
    self.dcn_expert_parallelism = 1
    # Kernel and activation choices.
    self.activations_in_float32 = True
    self.megablox = False
    self.activation = nn.swish
    self.mlp_activations = ["swish"]
    # Device mesh.
    self.mesh = Mesh(jax.devices(), ("data",))

class MuTransformerMoETask(base.Task, MuTask):
  """Transformer-based next token prediction task with MoE layers.

  Wraps a `TransformerDo` built from `cfg`. The optimized loss is the masked
  next-token cross-entropy plus the total MoE load-balance loss, when the
  model sows one.
  """

  def __init__(self, datasets, name, cfg, mup_multipliers=None):
    """Builds the MoE and transformer configs and the Flax module.

    Args:
      datasets: Dataset bundle exposing `extra_info['vocab_size']` and an
        `abstract_batch` whose "image" entry holds token ids.
      name: Task name, stored as `task_name`.
      cfg: Dict of hyper-parameters. Missing keys fall back to the defaults
        below. NOTE: updated in place with 'vocab_size', 'input_mult' and
        'output_mult'.
      mup_multipliers: Dict with 'input_mult', 'output_mult' and
        'hidden_lr_mult'. Each defaults to 1.0.
    """
    # Avoid a shared mutable default argument.
    if mup_multipliers is None:
      mup_multipliers = dict(input_mult=1.0,
                             output_mult=1.0,
                             hidden_lr_mult=1.0)
    cfg['vocab_size'] = datasets.extra_info['vocab_size']
    # NOTE(review): input_mult / output_mult are written into cfg but are not
    # passed to DoConfig / TransformerDo here — confirm they are used elsewhere.
    cfg['input_mult'] = mup_multipliers['input_mult']
    cfg['output_mult'] = mup_multipliers['output_mult']
    self.hidden_lr_mult = mup_multipliers['hidden_lr_mult']
    self.task_name = name

    # Model sizes are read once so the MoE and transformer configs agree,
    # including when the keys are absent from cfg.
    model_dim = cfg.get('model_dim', 128)
    ffn_dim = cfg.get('ffn_dim', 256)

    # Create mesh for MoE
    devices = jax.devices()
    self.mesh = Mesh(devices, ("data",))

    # Configure MoE parameters
    self.num_experts = cfg.get('num_experts', 8)
    self.num_experts_per_tok = cfg.get('num_experts_per_tok', 2)
    self.capacity_factor = cfg.get('capacity_factor', 1.5)
    self.load_balance_loss_weight = cfg.get('load_balance_loss_weight', 0.01)

    # Create MoE config
    moe_config = MoeConfig()
    moe_config.num_experts = self.num_experts
    moe_config.num_experts_per_tok = self.num_experts_per_tok
    moe_config.capacity_factor = self.capacity_factor
    moe_config.load_balance_loss_weight = self.load_balance_loss_weight
    moe_config.mesh = self.mesh
    moe_config.emb_dim = model_dim
    moe_config.mlp_dim = ffn_dim

    import pprint
    pprint.pprint(cfg)
    # Create transformer config
    self.transformer_config = DoConfig(
        D=model_dim,
        H=cfg.get('num_heads', 4),
        L=cfg.get('max_seq_len', 64),
        N=cfg.get('num_layers', 2),
        V=cfg.get('vocab_size', 1000),
        F=ffn_dim,
        use_kv_norm=cfg.get('use_kv_norm', True),
        ffn_type=cfg.get('ffn_type', 'moe'),
        moe_config=moe_config,
        dtype=jnp.float32,
        fsdp_enabled=False,
        remat=cfg.get('remat', False)
    )
    pprint.pprint(self.transformer_config)

    # Create the Flax module
    self.flax_module = TransformerDo(docfg=self.transformer_config)

    self.datasets = datasets
    # Stored only; not passed to the model in this class.
    self.dropout_rate = cfg.get('dropout_rate', 0.0)
    self.mup_lrs = None
    self.mup_state = None

    # Epsilon multiplier passed to get_mup_state; scales as 1 / model_dim.
    self.eps_mult = 1 / model_dim
    self.init_mup_state()

  def init(self, key: PRNGKey):
    """Initializes model variables from an all-ones batch of the dataset shape."""
    batch = jax.tree_util.tree_map(lambda x: jnp.ones(x.shape, x.dtype),
                                   self.datasets.abstract_batch)

    return self.flax_module.init({"params": key}, batch["image"])

  def init_with_state(self, key: PRNGKey) -> Tuple[Params, ModelState]:
    """Initializes params and the MuP state, caching the MuP lr multipliers."""
    params = self.init(key)
    if self.mup_lrs is None:
      self.mup_lrs = self.flax_module.get_mup_lrs(params, jax.devices()[jax.process_index()])
    state = {'flax_mup_lrs': self.mup_lrs}
    return params, self.get_mup_state(state, eps_mult=self.eps_mult)

  def get_loss(self, logits, data):
    """Compute cross entropy loss for next token prediction.

    Positions whose input token (data["image"]) equals 0 are excluded from the
    mean. Presumably 0 is the padding id — TODO confirm against the dataset.
    """
    targets = data["label"]
    mask = (data["image"] != 0)
    loss = base.softmax_cross_entropy(
        logits=logits, labels=jax.nn.one_hot(targets, self.transformer_config.V))
    # Clamp the denominator so an all-padding batch gives 0 instead of NaN.
    return jnp.sum(loss * mask) / jnp.maximum(jnp.sum(mask), 1)

  def _apply_and_losses(self, params, data):
    """Runs the model and returns (task_loss, total_loss, sown intermediates).

    total_loss adds the sown `load_balance_loss` when the model produced one.
    """
    logits, intermediates = self.flax_module.apply(
        params,
        data["image"],
        mutable=['intermediates']
    )
    task_loss = self.get_loss(logits, data)
    sown = intermediates.get('intermediates', {})
    total_loss = task_loss
    if 'load_balance_loss' in sown:
      total_loss = task_loss + sown['load_balance_loss'][0]
    return task_loss, total_loss, sown

  @functools.partial(jax.jit, static_argnums=(0,))
  def loss(self, params: Any, key: PRNGKey, data: Any):
    """Returns next-token loss plus the MoE load-balance loss (if any)."""
    _, total_loss, _ = self._apply_and_losses(params, data)
    return total_loss

  @functools.partial(jax.jit, static_argnums=(0,))
  def loss_with_state(self, params: Any, state: Any, key: PRNGKey, data: Any):
    """Like `loss`, also returning the updated MuP state."""
    _, total_loss, _ = self._apply_and_losses(params, data)
    return total_loss, self.get_mup_state(state)

  @functools.partial(jax.jit, static_argnums=(0,))
  def loss_with_state_and_aux(
      self, params: Params, state: ModelState, key: PRNGKey,
      data: Batch) -> Tuple[jnp.ndarray, ModelState, Mapping[str, jnp.ndarray]]:
    """Like `loss_with_state`, also reporting task / load-balance losses in aux."""
    task_loss, total_loss, sown = self._apply_and_losses(params, data)

    aux = {'train loss': task_loss}
    if 'load_balance_loss' in sown:
      aux['load_balance_loss'] = sown['load_balance_loss'][0]
      # Per-layer losses are sown as `layer_{i}_load_balance_loss`.
      for k, v in sown.items():
        if k.endswith('_load_balance_loss'):
          aux[k] = v[0]

    return total_loss, {'state': self.get_mup_state(state), 'aux': aux}


if __name__ == "__main__":
  # Smoke test: build a small dense-FFN model, print the parameter shapes and
  # MuP learning-rate multipliers, then run one forward pass.
  import pprint

  # Create a small transformer configuration
  config = DoConfig(
      D=128,  # model dimension
      H=4,    # number of attention heads
      L=64,   # max sequence length
      N=2,    # number of transformer blocks
      V=1000, # vocabulary size
      F=256,  # feed-forward dimension
      use_kv_norm=True,
      ffn_type='regular_ffn',
      moe_config=MoeConfig()
  )

  model = TransformerDo(docfg=config)

  batch_size = 2
  seq_length = 32
  # Independent keys for the synthetic data and the parameter init.
  data_rng, init_rng = jax.random.split(jax.random.PRNGKey(0))
  input_tokens = jax.random.randint(
      data_rng, shape=(batch_size, seq_length), minval=0, maxval=config.V
  )

  params = model.init(init_rng, input_tokens)
  pprint.pprint(jax.tree_util.tree_map(lambda x: x.shape, params))
  pprint.pprint(model.get_mup_lrs(params, jax.devices()[jax.process_index()]))

  # Run a forward pass
  logits, intermediates = model.apply(params,
                                      input_tokens,
                                      mutable=['intermediates'])

  print(f"Input shape: {input_tokens.shape}")
  print(f"Output logits shape: {logits.shape}")
  print(f"Model parameter count: {sum(x.size for x in jax.tree_util.tree_leaves(params))}")

  # Test a prediction
  predicted_tokens = jnp.argmax(logits, axis=-1)
  print(f"Predicted tokens shape: {predicted_tokens.shape}")
  print("Test successful!")

  # Extract and print the load balance losses from intermediates
  if intermediates and 'intermediates' in intermediates:
    if 'load_balance_loss' in intermediates['intermediates']:
      total_load_balance_loss = intermediates['intermediates']['load_balance_loss']
      print(f"Total load balance loss: {total_load_balance_loss}")
