from dataclasses import dataclass
from functools import partial
import pickle
import os

import jax
import jax.numpy as jnp
import flax
import flax.linen as nn
from flax.training import train_state
from flax import serialization
from flax.training.common_utils import shard

import optax
import wandb
import time
from tqdm import tqdm
from act_flax import colu, rcolu
import jax_smi
jax_smi.initialise_tracking()
import sys
import numpy as np

# Supported feed-forward activations, selected by the first CLI argument.
ACT = {'relu': nn.relu, 'colu': colu}
if len(sys.argv) < 2:
    sys.exit(f"usage: {sys.argv[0]} ACTIVATION [small|large]  (ACTIVATION in {sorted(ACT)})")
NAME = sys.argv[1]
if NAME not in ACT:
    # Unknown names fall back to ReLU, but NAME still labels the wandb run and is
    # what compute_flops keys on, so make the substitution visible.
    print(f"warning: unknown activation {NAME!r}; falling back to relu", file=sys.stderr)
act = ACT.get(NAME, nn.relu)
# Optional second CLI argument: model size preset (defaults to 'small').
SIZE = sys.argv[2] if len(sys.argv) > 2 else 'small'

@dataclass
class Config():
    """Run hyper-parameters, chosen from the ``SIZE`` preset when the class body runs.

    NOTE: none of the attributes below are type-annotated, so ``@dataclass``
    generates no fields for them. They are plain class attributes, evaluated once
    at class-definition time (``seed`` included).

    Raises:
        NotImplementedError: if ``SIZE`` is neither 'small' nor 'large'.
    """
    seed = int(time.time())  # wall-clock seed: differs on every run
    num_iterations = 50000
    block_size = 64  # sequence length fed to the model (also the positional-table size)
    learning_rate = 1e-4
    if SIZE == 'small':
        batch_size = 4096
        embed_size = 256 #768
        num_heads = 8 #12
        head_size = 32 #64
        num_layers = 6 #12
        dropout = 0.2 #0.1
    elif SIZE == 'large':
        batch_size = 512
        embed_size = 768
        num_heads = 12
        head_size = 64
        num_layers = 12
        dropout = 0.
    else:
        raise NotImplementedError(
            f"Unknown SIZE {SIZE!r}; expected 'small' or 'large'"
        )

config = Config()

with open("inputs/input.txt", "r", encoding="utf-8") as corpus_file:
    text = corpus_file.read()

# Character-level vocabulary: the distinct characters of the corpus in sorted
# order, so token ids are stable for a given input file.
chars = sorted(set(text))
vocab_size = len(chars)

# Lookup tables between characters and integer token ids.
stoi = {ch: idx for idx, ch in enumerate(chars)}
itos = dict(enumerate(chars))


def encode(s):
    """Encode a string as a list of integer token ids (KeyError on unseen characters)."""
    return [stoi[c] for c in s]


def decode(l):
    """Decode an iterable of integer token ids back into a string."""
    return "".join(itos[i] for i in l)


# Encode the whole corpus once; the first 90% is for training, the rest for evaluation.
data = jnp.array(encode(text))
n = int(0.9 * len(data))
train_data, eval_data = data[:n], data[n:]

# lax.dynamic_slice vectorised over the start indices only: the operand and the
# slice size are shared, one start index per batch row.
dynamic_slice_vmap = jax.vmap(jax.lax.dynamic_slice, in_axes=(None, 0, None))


@jax.jit
def get_batch(random_key, data):
    """Sample ``config.batch_size`` random windows of ``config.block_size`` tokens.

    ``data`` is a 1-D array of token ids (the slice size has a single entry).
    Returns ``(inputs, targets)``, each of shape (batch_size, block_size), where
    ``targets`` is ``inputs`` shifted one position to the right.
    """
    window = config.block_size
    starts = jax.random.randint(
        random_key,
        shape=(config.batch_size, 1),
        minval=0,
        maxval=len(data) - window,  # exclusive bound keeps the shifted window in range
    )
    inputs = dynamic_slice_vmap(data, starts, (window,))
    targets = dynamic_slice_vmap(data, starts + 1, (window,))
    return inputs, targets

class LayerNorm(nn.Module):
    """Layer normalisation with a learnable per-feature scale and bias.

    Statistics are taken over ``reduction_axes``. The variance is computed as
    E[x^2] - E[x]^2 and clamped at zero, because that difference can come out
    slightly negative in floating point. The scale/bias parameters are sized by
    the last axis of the input.
    """
    epsilon: float = 1e-6
    reduction_axes = -1  # unannotated: a class attribute, not a module field

    @nn.compact
    def __call__(self, x):
        """Applies layer normalization on the input."""
        axes = self.reduction_axes
        num_features = x.shape[-1]

        mean_of_squares = jnp.mean(jax.lax.square(x), axes, keepdims=True)
        mean = jnp.mean(x, axes, keepdims=True)
        variance = jnp.maximum(0., mean_of_squares - jax.lax.square(mean))
        normalized = (x - mean) * jax.lax.rsqrt(variance + self.epsilon)

        # Registered in this order: "scale" (init ones), then "bias" (init zeros).
        scale = self.param("scale", nn.initializers.ones, num_features)
        bias = self.param("bias", nn.initializers.zeros, num_features)
        return normalized * scale + bias

CONV = False  # if True, build Q/K/V with 1-D circular convolutions instead of Dense layers
DIM = 32      # length of the axis convolved over in the CONV path


class Attention(nn.Module):
    """A single causal (masked) scaled dot-product self-attention head.

    Assumes ``x`` is (batch, seq_len, features) -- TODO confirm for other callers.
    Returns an array whose last axis has size ``head_size``.
    """
    head_size: int

    @nn.compact
    def __call__(self, x, training: bool):
        seq_len = x.shape[-2]
        if CONV:
            # Split the feature axis into (DIM, features // DIM) and convolve along
            # the DIM axis. The group width is derived from the input rather than a
            # hard-coded 256, so embed sizes other than 256 (e.g. 768) also work.
            x = x.reshape(x.shape[:-1] + (DIM, x.shape[-1] // DIM))
            out_shape = x.shape[:-2] + (self.head_size,)
            key = nn.Conv(features=self.head_size//DIM, kernel_size=(3,), padding='CIRCULAR', use_bias=False)(x).reshape(out_shape)
            query = nn.Conv(features=self.head_size//DIM, kernel_size=(3,), padding='CIRCULAR', use_bias=False)(x).reshape(out_shape)
            value = nn.Conv(features=self.head_size//DIM, kernel_size=(3,), padding='CIRCULAR', use_bias=False)(x).reshape(out_shape)
        else:
            key = nn.Dense(self.head_size, use_bias=False)(x)
            query = nn.Dense(self.head_size, use_bias=False)(x)
            value = nn.Dense(self.head_size, use_bias=False)(x)

        # Scale by 1/sqrt(head_size) so score variance does not grow with head width
        # and saturate the softmax. moveaxis swaps only the last two axes (instead of
        # a fixed transpose) so any leading axes are left untouched.
        scores = query @ jnp.moveaxis(key, -1, -2) * self.head_size ** -0.5
        # Causal mask: position t may only attend to positions <= t.
        tril = jnp.tril(jnp.ones((seq_len, seq_len)))
        attention_weights = nn.softmax(jnp.where(tril == 0, -jnp.inf, scores), axis=-1)
        attention_weights = nn.Dropout(config.dropout)(attention_weights, deterministic=not training)
        return attention_weights @ value

class MultiHeadAttention(nn.Module):
    """``num_heads`` independent attention heads, concatenated and projected.

    The output projection has ``num_heads * head_size`` units and is followed by
    dropout (active only when ``training`` is True).
    """
    num_heads: int
    head_size: int

    @nn.compact
    def __call__(self, x, training: bool):
        head_outputs = []
        for _ in range(self.num_heads):
            head_outputs.append(Attention(self.head_size)(x, training))
        merged = jnp.concatenate(head_outputs, axis=-1)

        projected = nn.Dense(self.num_heads * self.head_size)(merged)
        return nn.Dropout(config.dropout)(projected, deterministic=not training)

class FeedFoward(nn.Module):
    """Position-wise MLP: expand to 4x embed_size, apply ``act``, project back, dropout.

    The activation is the module-level ``act`` chosen from the command line.
    """

    @nn.compact
    def __call__(self, x, training: bool):
        # Submodules are constructed in the same order as before (dropout, output
        # projection, then expansion), because Flax auto-names Dense layers in
        # construction order: Dense_0 is the output projection, Dense_1 the expansion.
        dropout = nn.Dropout(config.dropout)
        project = nn.Dense(config.embed_size)
        expand = nn.Dense(4 * config.embed_size)

        hidden = act(expand(x))
        return dropout(project(hidden), deterministic=not training)

class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP sub-layers, each with a residual add."""
    num_heads: int
    head_size: int

    @nn.compact
    def __call__(self, x, training: bool):
        # LayerNorm_0 normalises the attention input, LayerNorm_1 the MLP input.
        attention = MultiHeadAttention(self.num_heads, self.head_size)
        attention_norm = LayerNorm()
        x = x + attention(attention_norm(x), training)

        mlp = FeedFoward()
        mlp_norm = LayerNorm()
        return x + mlp(mlp_norm(x), training)

class Model(nn.Module):
    """Character-level GPT: token + position embeddings, ``num_layers`` Blocks, final norm, vocab head."""
    num_layers: int
    num_heads: int
    head_size: int

    @nn.compact
    def __call__(self, x, training: bool):
        """Map integer tokens of shape (B, T) to logits of shape (B, T, vocab_size).

        T should not exceed ``config.block_size``, the size of the positional table.
        """
        _, T = x.shape
        x = nn.Embed(num_embeddings=vocab_size, features=config.embed_size)(x) + \
            nn.Embed(num_embeddings=config.block_size, features=config.embed_size)(jnp.arange(T))
        for _ in range(self.num_layers):
            x = Block(self.num_heads, self.head_size)(x, training)
        x = nn.LayerNorm()(x)
        return nn.Dense(vocab_size)(x)

    def generate(self, random_key, params, context, length=50):
        """Eagerly sample ``length`` tokens after ``context`` (not jitted).

        Args:
            random_key: PRNG key, split once per generated token.
            params: model variables as returned by ``init``.
            context: integer tokens of shape (B, T); only the last
                ``config.block_size`` tokens are fed to the model each step.
            length: number of tokens to append.

        Returns:
            ``context`` with ``length`` sampled tokens appended along axis 1.
        """
        for _ in range(length):
            logits = self.apply(params, context[:, -config.block_size:], training=False)
            random_key, random_subkey = jax.random.split(random_key)
            # One sample per batch row, (B,) -> (B, 1), so any batch size works
            # (a fixed shape=(1, 1) only broadcasts when B == 1).
            new_token = jax.random.categorical(random_subkey, logits[:, -1, :], axis=-1)[:, None]
            context = jnp.concatenate([context, new_token], axis=1)
        return context

    @partial(jax.jit, static_argnames=("self", "length"))
    def generate_jit(self, random_key, params, length):
        """Jitted sampling of ``length`` tokens for a single sequence, via ``lax.scan``.

        Starts from a context of ``config.block_size`` copies of token id 0 and
        keeps a sliding window of the most recent ``block_size`` tokens.
        Returns the sampled tokens stacked into shape (length, 1, 1).
        """
        def scan_generate(carry, _):
            key, context = carry
            logits = self.apply(params, context, training=False)
            next_key, sample_key = jax.random.split(key)
            new_token = jax.random.categorical(sample_key, logits[:, -1, :], axis=-1, shape=(1, 1))
            # Slide the window: drop the oldest token, append the new one.
            context = jnp.concatenate([context[:, 1:], new_token], axis=1)
            return (next_key, context), new_token

        _, new_tokens = jax.lax.scan(
            scan_generate,
            (random_key, jnp.zeros((1, config.block_size), dtype=jnp.int32)),
            (),
            length=length,
        )
        return new_tokens

class TrainState(train_state.TrainState):
    """Flax ``TrainState`` extended with a PRNG key field."""
    # NOTE(review): train_step receives its dropout keys as an argument and never
    # reads this field.
    key: jax.Array


def create_train_state(random_key, config):
    """Build the model, initialise its parameters and attach an AdamW optimiser.

    Args:
        random_key: PRNG key used for parameter initialisation; also stored as
            ``state.key``.
        config: run configuration providing the model sizes and learning rate.

    Returns:
        An unreplicated ``TrainState``.
    """
    model = Model(num_layers=config.num_layers, num_heads=config.num_heads, head_size=config.head_size)
    # Parameter shapes come from the config sizes and the feature axis, never the
    # batch axis, so one dummy sequence initialises exactly the same parameters
    # without running a full batch_size-wide forward pass.
    dummy_tokens = jnp.ones((1, config.block_size), dtype=jnp.int32)
    params = model.init(random_key, dummy_tokens, training=False)
    tx = optax.adamw(config.learning_rate)
    return TrainState.create(
        apply_fn=model.apply, params=params, key=random_key, tx=tx)


# Data-parallel training step: one replica per local device under pmap.
@partial(jax.pmap, axis_name='batch')
def train_step(state, x, y, dropout_key):
    """Run one optimisation step on this device's shard of the batch.

    Args:
        state: replicated ``TrainState`` (this device's copy).
        x, y: this device's input and target token ids.
        dropout_key: this device's PRNG key for dropout.

    Returns:
        ``(new_state, loss)``, where ``loss`` is the mean cross-entropy averaged
        over all devices.
    """
    # Mix the step number into the key so dropout masks change every step.
    dropout_key = jax.random.fold_in(key=dropout_key, data=state.step)

    def loss_fn(params):
        logits = state.apply_fn(params, x, training=True, rngs={'dropout': dropout_key})
        one_hot_encoded_labels = jax.nn.one_hot(y, num_classes=vocab_size)
        return optax.softmax_cross_entropy(
            logits=logits, labels=one_hot_encoded_labels
        ).mean()

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    # Average gradients across devices so every replica applies the same update;
    # without this each replica steps on its own shard and the copies diverge.
    grads = jax.lax.pmean(grads, axis_name='batch')
    loss = jax.lax.pmean(loss, axis_name='batch')
    state = state.apply_gradients(grads=grads)
    return state, loss

@jax.jit
def eval_step(state, x, y):
    """Mean cross-entropy of the model on one batch, with dropout disabled.

    Expects an *unreplicated* state and an unsharded batch of token ids. The
    logits already line up with ``y``; no unreplicating is needed. Taking
    ``logits[0]`` here would score the first sequence's predictions against
    every row's targets.
    """
    logits = state.apply_fn(state.params, x, training=False)
    one_hot_encoded_labels = jax.nn.one_hot(y, num_classes=vocab_size)
    return optax.softmax_cross_entropy(
        logits=logits, labels=one_hot_encoded_labels
    ).mean()

# Derive the initialisation key from the run seed; random_key continues as the
# main PRNG stream for the rest of the script.
random_key, random_subkey = jax.random.split(jax.random.PRNGKey(config.seed))

# Build the train state once, then replicate it across local devices for pmap.
state = flax.jax_utils.replicate(create_train_state(random_subkey, config))
wandb_run = wandb.init(project="nanoGPT", name=f"{NAME}_{SIZE}")

def compute_flops(config, vocab_size=None, activation=None):
    """Estimate the cost of one forward pass over a single sequence.

    Convention: every matmul term counts multiply-accumulates, e.g. ``T*d*d`` for
    a ``d x d`` projection over ``T`` tokens. That is one "FLOP" per MAC, roughly
    half the 2-FLOPs-per-MAC convention. Embedding lookups are not counted.

    Args:
        config: object with ``block_size``, ``embed_size`` and ``num_layers``.
        vocab_size: size of the output vocabulary; defaults to ``len(chars)``.
        activation: activation name; ``'colu'`` is charged 7 ops per element
            instead of 1. Defaults to the module-level ``NAME``.

    Returns:
        The estimated operation count.
    """
    if vocab_size is None:
        vocab_size = len(chars)
    if activation is None:
        activation = NAME
    seq_len = config.block_size
    hidden_dim = config.embed_size

    # Self-attention: Q, K, V and output projections...
    attention_flops = 4 * seq_len * hidden_dim * hidden_dim
    # ...plus the QK^T score matrix and the attention-weighted sum of V.
    attention_flops += 2 * seq_len * seq_len * hidden_dim

    # MLP: d -> 4d and 4d -> d projections (hence the factor of 8).
    mlp_flops = 8 * seq_len * hidden_dim * hidden_dim
    # Activation applied elementwise to the 4d hidden layer.
    activation_flops = seq_len * (4 * hidden_dim)
    if activation == 'colu':
        activation_flops *= 7
    mlp_flops += activation_flops

    # Layer norm: mean, variance, normalize, scale, bias.
    ln_flops = 5 * seq_len * hidden_dim

    # Two layer norms per block.
    block_flops = attention_flops + mlp_flops + 2 * ln_flops

    # All blocks + final layer norm + vocabulary projection.
    return (config.num_layers * block_flops) + ln_flops + (seq_len * hidden_dim * vocab_size)

# Calculate and print FLOPs before training. compute_flops estimates one
# sequence's forward pass, so the TFLOP/s figures below cover the forward pass
# of a full batch only (no backward pass).
flops = compute_flops(config)
print(f"Approximate FLOPs per forward pass: {flops/1e9:.2f}B")

# pmap and flax's `shard` both work over this host's *local* devices.
num_local_devices = jax.local_device_count()

pbar = tqdm(range(config.num_iterations))
start_time = time.perf_counter()
for i in pbar:
    step_start = time.perf_counter()
    # Fresh, independent keys for the next iteration, batch sampling and dropout
    # (never re-split a key that get_batch has already consumed).
    random_key, batch_key, dropout_key = jax.random.split(random_key, 3)
    # Sample the global batch, then split its leading axis across local devices.
    x, y = get_batch(batch_key, train_data)
    x, y = shard(x), shard(y)
    dropout_keys = jax.random.split(dropout_key, num=num_local_devices)
    state, loss = train_step(state, x, y, dropout_keys)
    # JAX dispatches asynchronously: converting to a Python float blocks until the
    # step has finished, so step_time measures compute rather than dispatch.
    loss = float(flax.jax_utils.unreplicate(loss))

    step_time = time.perf_counter() - step_start
    steps_per_sec = 1.0 / step_time
    flops_per_sec = flops * config.batch_size * steps_per_sec

    pbar.set_postfix({
        "train_loss": f"{loss:.4f}",
        "steps/sec": f"{steps_per_sec:.2f}",
        "TFLOP/s": f"{flops_per_sec/1e12:.2f}"
    })

    if i % 100 == 0:
        random_key, eval_key = jax.random.split(random_key)
        eval_loss = float(eval_step(flax.jax_utils.unreplicate(state), *get_batch(eval_key, eval_data)))
        wandb.log({
            "train_loss": loss,
            "eval_loss": eval_loss,
            "steps_per_sec": steps_per_sec,
            "tflops_per_sec": flops_per_sec/1e12
        })

total_time = time.perf_counter() - start_time
print(f"\nTraining completed in {total_time/60:.2f} minutes")
print(f"Average steps/sec: {config.num_iterations/total_time:.2f}")
print(f"Average TFLOP/s: {(flops * config.batch_size * config.num_iterations)/(total_time*1e12):.2f}")

# Take a single copy of the replicated parameters and pickle its state dict.
unreplicated_params = flax.jax_utils.unreplicate(state.params)
params_state_dict = serialization.to_state_dict(unreplicated_params)

os.makedirs("./outputs", exist_ok=True)
with open("./outputs/params.pickle", "wb") as params_file:
    pickle.dump(params_state_dict, params_file)

# Let's now generate some text
model = Model(num_layers=config.num_layers, num_heads=config.num_heads, head_size=config.head_size)
# Separate keys for the template init and for sampling (avoid reusing one key twice).
random_key, init_key, sample_key = jax.random.split(random_key, 3)
# init only supplies the parameter *structure*; its values are overwritten from
# the checkpoint below, and that structure does not depend on the batch axis,
# so a single dummy sequence is enough.
params = model.init(
    init_key, jnp.ones((1, config.block_size), dtype=jnp.int32), training=False
)
# Trusted input: this pickle is the checkpoint this script writes to ./outputs.
# Never point pickle.load at files from an untrusted source.
with open("./outputs/params.pickle", "rb") as params_file:
    params_state_dict = pickle.load(params_file)
params = serialization.from_state_dict(params, params_state_dict)

text = model.generate_jit(sample_key, params, 1000)[:, 0, 0].tolist()
print(decode(text))
