import jax
from jax import numpy as jnp
from jax import nn
import numpy as np
from flax import linen as nn
from flax.linen import initializers as nni
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import random as jr
from util import *
import optax
import pickle
import argparse

from model import Transformer
from tasks import Task, ModPSeq2SeqTask

class InContextMarkovChain:
    """In-context first-order Markov chain task.

    Every call to `sample` draws a fresh transition matrix whose rows are
    i.i.d. Dirichlet(alpha, ..., alpha), rolls out a token sequence from it,
    and pairs the sequence with the empirical next-token distributions
    computed by `IH`.
    """

    def __init__(self, vocab_size, alpha):
        self.vocab_size = vocab_size
        self.alpha = alpha

    # should the target be the true next token, or the bayes (or the desired induction head behavior?)
    def sample(self, length, key):
        """Draw one sequence and its per-position targets.

        Args:
            length: number of tokens to generate.
            key: JAX PRNG key; split once for the transition matrix and once
                for the rollout.

        Returns:
            (x, y): x is the int token sequence of shape (length,), y is
            `self.IH(x)` of shape (length, vocab_size).
        """
        pi_key, seq_key = jr.split(key, 2)
        prior = self.alpha * jnp.ones(self.vocab_size)
        # one Dirichlet row per "previous token": pi[a] is the next-token law after a
        pi = jr.dirichlet(pi_key, prior, [self.vocab_size])
        # The first token has no predecessor. The original code referenced an
        # undefined `mu` here; use a uniform initial distribution, which is
        # what InContextkGram does for tokens lacking a full context.
        mu = jnp.ones((self.vocab_size,)) / self.vocab_size
        x = jnp.zeros((length,), dtype=int)

        def step(i, carry):
            x, k = carry
            k, subkey = jr.split(k)
            # at i == 0 the lookup x[i - 1] wraps to the last slot; it is
            # computed but discarded by the where
            p = jnp.where(i == 0, mu, pi[x[i - 1]])
            x = x.at[i].set(jr.choice(subkey, self.vocab_size, p=p))
            return x, k

        x, _ = jax.lax.fori_loop(0, length, step, (x, seq_key))
        return x, self.IH(x)

    def IH(self, seq):  # return answer for all lengths
        """Empirical next-token distribution after the current token, per position.

        Row i is the normalised count of tokens that followed `seq[i]` among
        the transitions seen up to and including position i; if that row has
        no counts yet it is uniform (1 / vocab_size).

        NOTE(review): position 0 is paired with itself as its own
        "previous token", so a synthetic seq[0] -> seq[0] transition is
        always counted. Confirm this is intended.
        """
        onehot = jax.nn.one_hot(seq, self.vocab_size)
        # previous-token one-hots, with the first token duplicated in front
        prev = jnp.concatenate([onehot[:1, :], onehot[:-1, :]])
        # tensor[i, a, b] == 1 iff transition i goes a -> b
        tensor = jnp.einsum("ij,ik->ijk", prev, onehot)
        counts = jnp.cumsum(tensor, axis=0)
        total_counts = jnp.sum(counts, axis=2)[..., jnp.newaxis]
        avgs = jnp.where(total_counts > 0, counts/total_counts, 1./self.vocab_size)
        ans = avgs[jnp.arange(len(seq)), seq, :]
        return ans
    
    
    
class InContextkGram:
    """In-context k-gram task.

    Each sample draws a fresh table with one Dirichlet(alpha, ..., alpha)
    next-token distribution per length-k context, rolls out a sequence from
    it, and pairs the sequence with the empirical distribution of tokens
    that followed earlier occurrences of the final k tokens (`IHk`).
    """

    def __init__(self, vocab_size, alpha, k):
        self.vocab_size = vocab_size
        self.alpha = alpha
        self.k = k

        # place values of a k-digit base-vocab_size number, most significant first
        place = jnp.arange(k - 1, -1, -1, dtype=jnp.int32)
        self._powers = jnp.power(jnp.int32(vocab_size), place).astype(jnp.int32)

    def _ctx_index(self, ctx):
        """Map a length-k context (shape (k,)) to its scalar int32 row index in [0, V**k)."""
        digits = ctx.astype(jnp.int32)
        return jnp.dot(digits, self._powers)

    # target is desired induction head behavior
    def sample(self, length, key):
        """Draw one sequence and its final-position target.

        Args:
            length: number of tokens to generate.
            key: JAX PRNG key.

        Returns:
            (x, y): x is the int token sequence of shape (length,), y is
            `self.IHk(x)` of shape (vocab_size,).
        """
        table_key, rollout_key = jr.split(key, 2)
        V, k = self.vocab_size, self.k
        concentration = self.alpha * jnp.ones(V)
        pi = jr.dirichlet(table_key, concentration, [V ** k])
        tokens = jnp.zeros((length,), dtype=int)

        def uniform_probs(_):
            # the first k tokens have no full context: sample uniformly
            return jnp.ones((V,)) / V

        def roll(t, state):
            seq, rng = state
            rng, draw_key = jr.split(rng)

            def context_probs(_):
                # the k tokens immediately preceding position t
                window = lax.dynamic_slice(seq, (t - k,), (k,))
                return jnp.take(pi, self._ctx_index(window), axis=0)  # (V,)

            probs = lax.cond(t < k, uniform_probs, context_probs, operand=None)
            seq = seq.at[t].set(jr.choice(draw_key, V, p=probs))
            return (seq, rng)

        tokens, leftover_key = lax.fori_loop(0, length, roll, (tokens, rollout_key))

        # Copy the last k tokens to a random offset so the final context is
        # meant to appear at least once more.
        # NOTE(review): the offset is drawn from [0, length - k] inclusive.
        # At offset == length - k the write is a self-copy, and for offsets
        # above length - 2k the write overlaps (and alters) the tail, so the
        # guarantee does not strictly hold in those cases. Confirm intent.
        tail = lax.dynamic_slice(tokens, (length - k,), (k,))  # (k,)
        offset = jr.randint(
            leftover_key, shape=(), minval=0, maxval=length - k + 1, dtype=jnp.int32
        )  # maxval is exclusive
        tokens = lax.dynamic_update_slice(tokens, tail, (offset,))

        return tokens, self.IHk(tokens)

    def IHk(self, seq):
        """Empirical next-token distribution for the final k-token context.

        Scans every position t, compares the k tokens before it (the sequence
        is left-padded with k copies of seq[0]) to the last k tokens of `seq`,
        and tallies seq[t] on a match. Returns the normalised tallies, or a
        uniform vector when nothing matched.
        """
        k = self.k
        query = self._ctx_index(seq[-k:])
        left_pad = jnp.full((k,), seq[0], dtype=seq.dtype)
        padded = jnp.concatenate([left_pad, seq], axis=0)  # shape: T + k

        def tally(t, hist):
            # padded[t:t+k] holds the k tokens preceding seq[t]
            here = self._ctx_index(lax.dynamic_slice(padded, (t,), (k,)))
            return lax.cond(
                here == query,
                lambda h: h.at[seq[t]].add(1),
                lambda h: h,
                hist,
            )

        hist = lax.fori_loop(0, len(seq), tally, jnp.zeros((self.vocab_size)))
        mass = hist.sum()
        return jnp.where(mass > 0, hist / mass, 1. / self.vocab_size)
    
class MLP(nn.Module):
    """Two-layer ReLU MLP: Dense(width, with bias) -> ReLU -> Dense(d, no bias)."""
    d: int
    width: int

    @nn.compact
    def __call__(self, x):
        # muP init: kernel std is 1/sqrt(d) for the first layer and
        # 1/sqrt(width) for the second.
        # assumes x has last dimension d, so these are fan-in scalings — TODO confirm
        first_init = nni.normal(jnp.sqrt(1/self.d))
        second_init = nni.normal(jnp.sqrt(1/self.width))

        hidden = nn.Dense(self.width, kernel_init=first_init, use_bias=True, name='layer1')(x)
        activated = nn.relu(hidden)
        return nn.Dense(self.d, kernel_init=second_init, use_bias=False, name='layer2')(activated)
    
class RPE(nn.Module):
    """Learned per-head relative positional bias over a window of `tau` offsets."""
    tau: int
    heads: int

    def map_output(self, rpe, T):
        """Expand one head's bias vector into a (T, T) matrix.

        Entry (i, j) is rpe[i - j] when 0 <= i - j < tau, and 0 otherwise.
        """
        positions = jnp.arange(T)
        # offset[i, j] = i - j
        offset = positions[:, None] - positions[None, :]
        in_window = (offset >= 0) & (offset < self.tau)
        # entries gathered with an offset outside [0, tau) are masked out here
        return jnp.where(in_window, rpe[offset], 0.0)

    @nn.compact
    def __call__(self, x):
        # one zero-initialised bias vector of length tau per head
        bias = self.param('rpe', nni.zeros, (self.heads, self.tau))
        seq_len = x.shape[-2]
        # map over the heads axis; the sequence length is shared
        expand = jax.vmap(self.map_output, in_axes=(0, None), out_axes=0)
        return expand(bias, seq_len)

class SelfAttn(nn.Module):
    """Multi-head causal self-attention with a learned relative positional bias.

    Each head scores positions with x Q K^T x^T / d plus log(T) times its RPE
    bias, applies a causal softmax, and averages the *inputs* x with those
    weights. Head outputs are concatenated and projected by V then O.
    """
    d: int
    heads: int
    tau: int

    def attn(self, x, Q, K, rpe):
        """Single-head attention output for one (Q, K, rpe) triple.

        x has shape (..., T, d); Q and K are (d, d); rpe is the (T, T) bias
        for this head. Returns an array shaped like x.
        """
        T = x.shape[-2]
        scores = jnp.einsum("...ij,jm,km,...lk -> ...il", x, Q, K, x)
        scores = scores / self.d
        scores = scores + jnp.log(T) * rpe
        # causal mask: position i may only attend to positions j <= i
        scores = jnp.where(jnp.tri(T, dtype=bool), scores, -jnp.inf)
        weights = nn.softmax(scores)
        return jnp.einsum("...ij,...jk->...ik", weights, x)

    @nn.compact
    def __call__(self, x):
        Q = self.param('Q', nni.normal(1./jnp.sqrt(self.d)), (self.heads, self.d, self.d))
        K = self.param('K', nni.normal(1./jnp.sqrt(self.d)), (self.heads, self.d, self.d))
        O = self.param('O', nni.normal(1./jnp.sqrt(self.d)), (self.d, self.d))
        V = self.param('V', nni.normal(1./jnp.sqrt(self.d)), (self.heads*self.d, self.d))
        rpe = RPE(self.tau, self.heads)(x)  # (heads, T, T)

        # Map over the heads axis of Q, K and rpe while sharing x; the head
        # axis is placed second-to-last so the reshape below concatenates
        # the per-head outputs feature-wise.
        per_head = jax.vmap(
            lambda a, b, c, r: self.attn(a, b, c, r), (None, 0, 0, 0), -2
        )(x, Q, K, rpe)
        merged = per_head.reshape(*per_head.shape[:-2], -1)  # (..., T, heads * d)
        return merged@V@O

class Transformer(nn.Module):
    """Two-layer attention-only transformer with token embedding and unembedding.

    NOTE: `max_length`, `output_size` and `width` are accepted but not read
    by this class as written (the MLP blocks are disabled below).
    """
    vocab_size: int
    max_length: int
    output_size: int
    d: int
    heads: int
    width: int # width of MLP
    tau: int = 4 # relative-position window handed to each SelfAttn layer

    def embed(self, x, wte):
        """Look up the embedding row for every token id in x."""
        return wte[x]

    @nn.compact
    def __call__(self, x):
        """Map integer token ids to logits of shape x.shape + (vocab_size,)."""
        wte = self.param('wte', nni.normal(1.), (self.vocab_size, self.d))
        unembed = self.param('unembed', nni.normal(1.), (self.d, self.vocab_size))

        x = self.embed(x, wte)

        # residual attention layers; both use the same RPE window
        x = x + SelfAttn(self.d, self.heads, self.tau)(x)

        # x = x + MLP(self.d, self.width)(x)

        x = x + SelfAttn(self.d, self.heads, self.tau)(x)

        # x = x + MLP(self.d, self.width)(x)

        # project to the vocabulary, scaled by 1/d
        return x@unembed/self.d
    
def run(vocab_size, k, seed):
    """Train one model per training length and measure length generalisation.

    For each training length a copy of the same initial parameters is trained
    with muP-scaled Adam on freshly sampled InContextkGram sequences. Every
    trained model is then evaluated on a range of longer test lengths.

    Args:
        vocab_size: number of distinct tokens.
        k: context length of the k-gram task; also used as the number of
            attention heads.
        seed: integer seed passed to `RNG`.

    Returns:
        np.ndarray of shape (number of test lengths, number of train lengths)
        holding the test loss of each model at each test length.

    NOTE: `RNG`, `vmap`, `jit`, `partial`, `tree_map` and `trange` are not
    imported explicitly in this file; presumably they come from
    `from util import *` — verify.
    """
    rng = RNG(seed)
    # problem = InContextMarkovChain(vocab_size, 1)
    problem = InContextkGram(vocab_size, 1., k=k)

    max_length = 80

    d_model = 16
    heads = k
    width = 128

    # initialize model
    model = Transformer(vocab_size, max_length, 1, d_model, heads, width)
    p0 = model.init(rng.next(), vmap(lambda key: problem.sample(max_length, key))(rng.next(2))[0])
    p0 = flatten_dict(p0["params"], sep=".")

    def criterion(pred, target):
        # squared error, scaled by vocab_size
        return vocab_size*jnp.mean((target-pred)**2)

    @partial(jit, static_argnames="mutable")
    def f(p, *args, **kwargs):
        p = dict(params=unflatten_dict(p, sep="."))
        return model.apply(p, *args, **kwargs)

    def final_target(y):
        # InContextMarkovChain yields per-position targets (batch, T, V), so
        # the last position must be selected. InContextkGram already yields
        # the final-position distribution (batch, V); slicing it with
        # y[:, -1] would pick the last *vocabulary entry* and broadcast that
        # scalar against every logit.
        return y[:, -1] if y.ndim == 3 else y

    @jit
    def loss_fn(p, batch):
        x, y = batch
        # only compute prediction on last token
        return vmap(criterion)(f(p, x)[:, -1], final_target(y)).mean()

    # training and test losses are the same quantity
    test_loss_fn = loss_fn

    # training configuration (identical for every training length)
    lr = 3e-2
    steps = 2**15
    batch_size = 2**10
    epoch_len = 2**10

    # muP scaling for Adam: embedding-like parameters use lr, the rest lr/d_model
    mapping = dict.fromkeys(p0, 'hidden')
    mapping['wte'] = 'embed'
    mapping['unembed'] = 'embed'
    mapping['SelfAttn_0.RPE_0.rpe'] = 'embed'
    mapping['SelfAttn_1.RPE_0.rpe'] = 'embed'
    opt = optax.multi_transform(
        {
            'embed': optax.adam(learning_rate=lr),
            'hidden': optax.adam(learning_rate=lr/d_model),
        },
        mapping,
    )

    @jit
    def step_fn(p, batch, opt_state):
        # the returned loss is evaluated at the incoming (pre-update) p
        loss, g = jax.value_and_grad(loss_fn)(p, batch)
        updates, opt_state = opt.update(g, opt_state, p)
        p = optax.apply_updates(p, updates)
        return p, opt_state, loss

    def batch_iterator(length, key):
        # sample a whole epoch at once, then hand it out batch by batch
        draw = vmap(lambda sample_key: problem.sample(length, sample_key))
        while True:
            key, subkey = jr.split(key)
            batches = draw(jr.split(subkey, epoch_len * batch_size))
            for i in range(epoch_len):
                yield tree_map(lambda a: a[batch_size * i : batch_size * (i + 1)], batches)

    models = []

    # maximum training lengths (plain ints so they can be used as static shapes)
    train_lengths = list(range(5, 26, 5))

    # loop over train_lengths
    for train_length in train_lengths:
        iterator = batch_iterator(train_length, rng.next())

        p = p0
        opt_state = opt.init(p0)
        loss = None
        for _ in trange(steps):
            batch = next(iterator)
            new_p, new_opt_state, loss = step_fn(p, batch, opt_state)
            # Stop on convergence or divergence (inf *or* NaN) and, as
            # before, discard the update computed from the stopping batch.
            if loss < 1e-5 or not jnp.isfinite(loss):
                break
            p, opt_state = new_p, new_opt_state

        print("final loss = ", loss)

        models.append(p)

    # evaluate models on test_lengths

    test_samples = 2**10
    test_rng = rng.next()
    # the same keys are reused for every test length
    test_keys = jr.split(test_rng, test_samples)
    test_lengths = list(range(20, 201, 10))

    all_losses = []
    for test_length in test_lengths:
        testx, testy = vmap(lambda key: problem.sample(test_length, key))(test_keys)
        all_losses.append([test_loss_fn(p, (testx, testy)) for p in models])

    # return all_losses, models
    return np.array(all_losses)

if __name__ == "__main__":
    # Command-line entry point: parse the three required integer options,
    # run the experiment, and pickle the resulting loss array.
    parser = argparse.ArgumentParser(description='Process some arguments.')
    for flag, help_text in (('--seed', 'seed'), ('--vocab', 'vocab'), ('--k', 'k')):
        parser.add_argument(flag, type=int, required=True, help=help_text)
    args = parser.parse_args()

    seed, vocab_size, k = args.seed, args.vocab, args.k

    print("Running for vocab = {}, k = {}, seed = {}".format(vocab_size, k, seed))
    losses = run(vocab_size, k, seed)

    # output file name encodes the run configuration
    filename = 'IH_exact_vocab={}_k={}_seed={}.pkl'.format(vocab_size, k, seed)
    with open(filename, 'wb') as f:
        pickle.dump(losses, f)

    
