import operator
import re
import time
from contextlib import contextmanager

import jax.numpy as jnp
import numpy as np
from flax import traverse_util
from flax.core import freeze, unfreeze
from jax import random
from jax.tree_util import tree_reduce, tree_map

rngmix = lambda rng, x: random.fold_in(rng, hash(x) % 2**16)

@contextmanager
def timeblock(name):
  """Context manager that prints how long the enclosed block took.

  Args:
    name: Label used as the prefix of the printed message.

  The message is printed even if the block raises, because the report happens
  in a `finally` clause; the exception then propagates unchanged.
  """
  # perf_counter is monotonic and high-resolution, so the measured interval
  # cannot be skewed (or go negative) when the system clock is adjusted.
  start = time.perf_counter()
  try:
    yield
  finally:
    elapsed = time.perf_counter() - start
    print(f"{name} took {elapsed:.5f} seconds")

class RngPooper:
  """Stateful convenience wrapper over JAX's stateless PRNG keys.

  Holds a key in `self.rng` and hands out a fresh subkey on every `poop()`
  call, advancing the stored key each time.
  """

  def __init__(self, init_rng):
    """Store `init_rng` as the initial internal key."""
    self.rng = init_rng

  def poop(self):
    """Split the internal key, keep the first half, and return the second."""
    carried, fresh = random.split(self.rng, 2)
    self.rng = carried
    return fresh

def l1prox(x, alpha):
  """Soft-threshold `x` by `alpha`: shrink magnitudes by `alpha`, floor at 0.

  Computes `sign(x) * max(0, |x| - alpha)` elementwise, i.e. the proximal
  operator of `alpha * ||x||_1`.
  """
  magnitude = jnp.abs(x) - alpha
  shrunk = jnp.maximum(0, magnitude)
  return jnp.sign(x) * shrunk

def ec2_get_instance_type():
  """Return the stripped contents of the DMI `product_name` sysfs entry.

  On EC2 hosts this is presumably the instance type (see the linked answer);
  that is not verified here.

  Raises:
    OSError: If the sysfs file does not exist or cannot be read.
  """
  # See also https://stackoverflow.com/questions/51486405/aws-ec2-command-line-display-instance-type/51486782
  # Use a context manager so the file handle is closed deterministically
  # rather than whenever the garbage collector gets to it.
  with open("/sys/devices/virtual/dmi/id/product_name") as f:
    return f.read().strip()

# Utilities for dealing with flax model parameters
def partition(pred, iterable):
  """Split `iterable` into two lists according to `pred`.

  Returns:
    `(trues, falses)`: items for which `pred(item)` is truthy, and the rest.
    Relative order is preserved in both lists; `pred` is called once per item.
  """
  accepted, rejected = [], []
  for element in iterable:
    bucket = accepted if pred(element) else rejected
    bucket.append(element)
  return accepted, rejected

def partition_dict(pred, d):
  """Split mapping `d` into two dicts according to `pred` applied to each key.

  Returns:
    `(trues, falses)`: entries whose key satisfies `pred`, and the remaining
    entries. Insertion order of `d` is preserved within each dict.
  """
  matching, remaining = {}, {}
  for key, value in d.items():
    target = matching if pred(key) else remaining
    target[key] = value
  return matching, remaining

def flatten_params(params):
  """Flatten a nested parameter tree into a dict keyed by "/"-joined paths.

  `params` is unfrozen, flattened with `traverse_util.flatten_dict` (which
  yields tuple keys), and each tuple key is joined with "/".
  """
  nested = unfreeze(params)
  by_path = traverse_util.flatten_dict(nested)
  return {"/".join(path): leaf for path, leaf in by_path.items()}

def unflatten_params(flat_params):
  """Rebuild a frozen nested parameter tree from a "/"-keyed flat dict.

  Each key is split on "/" into a tuple path, the result is nested with
  `traverse_util.unflatten_dict`, and the nested dict is frozen.
  """
  by_path = {}
  for name, leaf in flat_params.items():
    by_path[tuple(name.split("/"))] = leaf
  nested = traverse_util.unflatten_dict(by_path)
  return freeze(nested)

def merge_params(a, b):
  """Merge two flat ("/"-keyed) parameter dicts and unflatten the result.

  Entries from `b` override entries from `a` when a key appears in both.
  """
  combined = dict(a)
  combined.update(b)
  return unflatten_params(combined)

def kmatch(pattern, key):
  """Match `key` against a glob-like `pattern` over "/"-separated paths.

  Wildcards:
    `*`  matches (and captures) any run of characters not containing "/".
    `**` matches (and captures) any run of characters, including "/".

  Every other character of `pattern` is passed through to the regular
  expression unchanged, so regex metacharacters in `pattern` keep their regex
  meaning (e.g. "." matches any character).

  Args:
    pattern: Glob-like pattern string.
    key: String to match in full.

  Returns:
    An `re.Match` whose groups are the wildcard captures, in order of
    appearance, or `None` if `key` does not match the whole pattern.
  """

  def wildcard_to_regex(m):
    # Raw string: "\/" in a normal string literal is an invalid escape
    # sequence, and "/" needs no escaping inside a character class anyway.
    return "(.*)" if m.group(0) == "**" else r"([^/]*)"

  # The alternation tries "**" first at each position, so "**" is consumed as
  # one wildcard and only a lone "*" becomes the single-segment form.
  regex = re.sub(r"\*\*|\*", wildcard_to_regex, pattern)
  # fullmatch already anchors at both ends; explicit "^"/"$" are unnecessary.
  return re.fullmatch(regex, key)

# Import-time sanity checks for kmatch. They document the wildcard semantics:
# `*` captures exactly one path segment (no "/"), while `**` captures any run
# of characters and may span several segments. Captures are numbered in the
# order the wildcards appear in the pattern.
assert kmatch("*", "a") is not None
assert kmatch("*", "a").group(0) == "a"
assert kmatch("*", "a").group(1) == "a"
assert kmatch("abc", "def") is None
assert kmatch("abc/*/ghi", "abc/def/ghi").group(1) == "def"
assert kmatch("abc/**/jkl", "abc/def/ghi/jkl").group(1) == "def/ghi"
assert kmatch("abc/*/jkl", "abc/def/ghi/jkl") is None
assert kmatch("**/*", "abc/def/ghi/jkl").group(1) == "abc/def/ghi"
assert kmatch("**/*", "abc/def/ghi/jkl").group(2) == "jkl"

def lerp(lam, t1, t2):
  """Linearly interpolate between two pytrees, leaf by leaf.

  Each output leaf is `(1 - lam) * a + lam * b`, so `lam=0` gives `t1` and
  `lam=1` gives `t2`. `t1` and `t2` must have the same tree structure.
  """

  def blend(a, b):
    return (1 - lam) * a + lam * b

  return tree_map(blend, t1, t2)

def tree_norm(t):
  """Return the Euclidean (L2) norm of all leaves of pytree `t` taken together.

  Sums the squared entries of every leaf, adds the per-leaf sums, and takes
  the square root.
  """
  per_leaf_sq = tree_map(lambda leaf: jnp.sum(leaf**2), t)
  total_sq = tree_reduce(operator.add, per_leaf_sq)
  return jnp.sqrt(total_sq)

def tree_l2(t1, t2):
  """Return the L2 distance between two pytrees with matching structure.

  Subtracts `t2` from `t1` leaf by leaf and takes `tree_norm` of the result.
  """
  delta = tree_map(operator.sub, t1, t2)
  return tree_norm(delta)

def slerp(lam, t1, t2):
  """Spherical linear interpolation between two pytrees.

  Both trees are treated as single flattened vectors. With `om` the angle
  between them, each output leaf is
  `sin((1 - lam) * om) / sin(om) * x + sin(lam * om) / sin(om) * y`.
  See https://en.wikipedia.org/wiki/Slerp

  Args:
    lam: Interpolation parameter; 0 yields `t1`, 1 yields `t2`.
    t1: First pytree.
    t2: Second pytree with the same structure as `t1`.

  Returns:
    A pytree with the same structure as the inputs.

  When the angle is numerically 0 or pi, `sin(om)` vanishes and the formula
  degenerates to 0/0; in that case the weights fall back to linear
  interpolation `(1 - lam, lam)` instead of producing NaN/inf. Zero-norm
  inputs still yield NaN because the angle itself is undefined.
  """

  def inner(u, v):
    # Inner product of two pytrees viewed as flat vectors.
    return tree_reduce(operator.add, tree_map(lambda x, y: jnp.sum(x * y), u, v))

  norm1 = jnp.sqrt(inner(t1, t1))
  norm2 = jnp.sqrt(inner(t2, t2))
  # Rounding can push the cosine slightly outside [-1, 1], where arccos is NaN.
  cos_om = jnp.clip(inner(t1, t2) / (norm1 * norm2), -1.0, 1.0)
  om = jnp.arccos(cos_om)
  sinom = jnp.sin(om)

  degenerate = jnp.abs(sinom) < jnp.finfo(sinom.dtype).eps
  # Substitute a harmless divisor so the unused branch of `where` stays finite.
  safe_sinom = jnp.where(degenerate, 1.0, sinom)
  w1 = jnp.where(degenerate, 1 - lam, jnp.sin((1 - lam) * om) / safe_sinom)
  w2 = jnp.where(degenerate, lam, jnp.sin(lam * om) / safe_sinom)
  return tree_map(lambda x, y: w1 * x + w2 * y, t1, t2)

def compute_weights_cost_matrices(model, params_a, params_b):
  """Compute weight-based cost matrices between the experts of two MoE models.

  Args:
    model: Object exposing `get_moe_params(params)` and a `num_experts`
      attribute. `get_moe_params` must return a mapping with the keys
      `gating_kernel`, `gating_bias` and, for each expert index `i`,
      `expert_{i}_layer1_kernel`, `expert_{i}_layer1_bias`,
      `expert_{i}_layer2_kernel`, `expert_{i}_layer2_bias`.
    params_a: Parameters of the first model.
    params_b: Parameters of the second model.

  Returns:
    A tuple `(D, S)`:
      D: `[num_experts, num_experts]` array. `D[i, j]` combines the Frobenius
        norms of the differences between the layer-1 and layer-2 Gram
        matrices of expert `i` in model A and expert `j` in model B:
        `sqrt(||G1_a - G1_b||_F**2 + ||G2_a - G2_b||_F**2)`.
      S: Pairwise Euclidean distances between the centered gating vectors of
        model A (rows) and model B (columns).
  """
  moe_params_a = model.get_moe_params(params_a)
  moe_params_b = model.get_moe_params(params_b)
  num_experts = model.num_experts

  # Gating-based distance matrix S: broadcast to all (a, b) pairs, then take
  # the Euclidean norm over the feature axis.
  gating_vectors_a = _centered_gating_vectors(moe_params_a)
  gating_vectors_b = _centered_gating_vectors(moe_params_b)
  diff_vectors = gating_vectors_a[:, np.newaxis, :] - gating_vectors_b[np.newaxis, :, :]
  S = np.sqrt(np.sum(diff_vectors ** 2, axis=2))

  # Expert-parameter distance matrix D. Gram matrices depend on one expert
  # only, so build them once per expert instead of once per (i, j) pair.
  grams_a = [_expert_gram_pair(moe_params_a, i) for i in range(num_experts)]
  grams_b = [_expert_gram_pair(moe_params_b, j) for j in range(num_experts)]

  D = np.zeros((num_experts, num_experts))
  for i, (gram1_a, gram1p_a) in enumerate(grams_a):
    for j, (gram1_b, gram1p_b) in enumerate(grams_b):
      norm_diff1 = np.linalg.norm(gram1_a - gram1_b, 'fro')
      norm_diff1p = np.linalg.norm(gram1p_a - gram1p_b, 'fro')
      D[i, j] = np.sqrt(norm_diff1**2 + norm_diff1p**2)

  return D, S


def _centered_gating_vectors(moe_params):
  """Stack the centered gating kernel (transposed) and bias into row vectors."""
  kernel = moe_params['gating_kernel']
  bias = moe_params['gating_bias']
  # Kernel is centered along axis 0, bias by its overall mean.
  # Presumably kernel is [embedding_dim, num_experts] and bias is
  # [num_experts], giving one row per expert -- TODO confirm against the model.
  centered_kernel = kernel - np.mean(kernel, axis=0)
  centered_bias = bias - np.mean(bias)
  return np.hstack([centered_kernel.T, centered_bias[:, np.newaxis]])


def _expert_gram_pair(moe_params, index):
  """Return the (layer-1, layer-2) Gram matrices of expert `index`.

  Each layer's bias is appended to its kernel as an extra row. The layer-1
  Gram is `W.T @ W` and the layer-2 Gram is `W @ W.T` of the augmented
  matrices.
  """
  w1 = np.vstack([
      moe_params[f'expert_{index}_layer1_kernel'],
      moe_params[f'expert_{index}_layer1_bias'][np.newaxis, :],
  ])
  w2 = np.vstack([
      moe_params[f'expert_{index}_layer2_kernel'],
      moe_params[f'expert_{index}_layer2_bias'][np.newaxis, :],
  ])
  return w1.T @ w1, w2 @ w2.T
