import os
# The log-level variable is set before TensorFlow is imported so it is already
# in the environment when TF loads (presumably '3' silences TF's native log
# output — confirm against the TensorFlow documentation).
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
# tf.config.run_functions_eagerly(True)
from typing import Optional
import jax
import jax.numpy as jnp
import numpy as np
import flax
import flax.linen as nn
import optax
import tensorflow_datasets as tfds
from flax.training.train_state import TrainState
from flax.training.common_utils import shard
# import sys
# sys.path.append('.')
from tqdm.auto import tqdm
import time
import functools
# import wandb
from jax_smi import initialise_tracking
# NOTE(review): module-level side effect executed on import; presumably starts
# the jax-smi device-memory tracking — confirm against the jax_smi package.
initialise_tracking()

@jax.jit
def loglu(x: jnp.ndarray) -> jnp.ndarray:
    """Logarithmic linear unit.

    Returns ``x`` where ``x > 0`` and ``-log(1 - x)`` elsewhere.

    The previous ``jnp.maximum(x, -jnp.log(-x + 1.))`` formulation evaluated
    the logarithm on the whole input, so it produced ``inf`` at ``x == 1`` and
    ``NaN`` for ``x > 1`` (``jnp.maximum`` propagates NaN), and for
    ``0 < x < 1`` it returned the logarithmic branch instead of the identity.

    Args:
        x: Input array (any shape).

    Returns:
        Array of the same shape as ``x``.
    """
    # Clamp to the non-positive range before taking the log so the branch that
    # `where` discards never sees a negative log argument; this keeps both the
    # values and the gradients finite for every input.
    negative_branch = -jnp.log1p(-jnp.minimum(x, 0.))
    return jnp.where(x > 0, x, negative_branch)

@functools.partial(jax.jit, static_argnames=['channel_axis','variant','eps','num_groups','dim','share_axis'])
def colu(input: jnp.ndarray,
         channel_axis: int = -1,
         variant: str = "hard",
         eps: float = 1e-7,
         num_groups: Optional[int] = None,
         dim: Optional[int] = 4,
         share_axis: bool = False
         ):
    """Cone-based activation: project the cone sections onto the axes.

    The channel axis is split into "axes" components ``y`` (the first
    ``num_groups`` channels, or a single channel when ``share_axis`` is set)
    and "cone section" components ``x`` (the remaining channels, grouped into
    ``num_groups`` groups of ``dim - 1``). Each group of ``x`` is multiplied by
    a mask derived from ``y / (||x|| + eps)``; ``y`` passes through unchanged.

    Notation used in the shape comments: G = number of cones (``num_groups``),
    S = dimension of a cone (``dim``). ``jnp.moveaxis`` is avoided to optimize
    speed on TPU, which is why ``channel_axis`` must be negative.

    Args:
        input: Input array. Scalars are returned unchanged.
        channel_axis: Negative index of the channel axis.
        variant: Mask function: ``"sqrt"``, ``"log"``, ``"soft"`` or ``"hard"``.
        eps: Added to the norm to avoid division by zero.
        num_groups: Number of cones. Exactly one of ``num_groups`` and ``dim``
            must be given (pass ``dim=None`` when giving ``num_groups``).
        dim: Dimension of each cone. ``dim == 2`` is the pointwise case and
            falls back to SiLU (``variant == "soft"``) or ReLU (otherwise).
        share_axis: If true, a single axis channel is shared by all cones.

    Returns:
        Array with the same shape as ``input``.

    Raises:
        AssertionError: On inconsistent ``dim`` / ``num_groups`` / channel
            count, or a non-negative ``channel_axis``.
        NotImplementedError: If ``variant`` is not one of the supported names.
    """
    shape = input.shape
    if len(shape) == 0:
        return input # edge case
    assert (dim is not None) ^ (num_groups is not None) # specify one of both, infer the other

    channels = shape[channel_axis]
    if share_axis:
        # One shared axis channel + G * (S - 1) cone-section channels.
        if dim is None:
            assert (channels - 1) % num_groups == 0
            dim = (channels - 1) // num_groups + 1
        if num_groups is None:
            assert (channels - 1) % (dim - 1) == 0
            num_groups = (channels - 1) // (dim - 1)
    else:
        # G axis channels + G * (S - 1) cone-section channels = G * S.
        if dim is None:
            assert channels % num_groups == 0
            dim = channels // num_groups
        if num_groups is None:
            assert channels % dim == 0
            num_groups = channels // dim

    if dim == 2: # pointwise case
        return jax.nn.silu(input) if variant == "soft" else jax.nn.relu(input)

    # y = axes, x = cone sections
    num_axes = 1 if share_axis else num_groups
    y, x = jnp.split(input, [num_axes], axis=channel_axis)

    assert channel_axis < 0, "channel_axis must be negative" # Comply with broadcasting on first dimensions
    x_old_shape = x.shape
    # Dimensions after the channel axis (e.g. HW when channel_axis = -3).
    trailing = x.shape[(channel_axis + 1):] if channel_axis < -1 else ()
    x = x.reshape(x.shape[:channel_axis] + (num_groups, dim - 1) + trailing) # NG(S-1)[HW]
    y_grouped = y.reshape(y.shape[:channel_axis] + (num_axes, 1) + trailing) # NG1[HW] or N11[HW]

    xn = jnp.linalg.norm(x, axis=channel_axis, keepdims=True) # NG1[HW]

    mask = y_grouped / (xn + eps) # NG1[HW]
    if variant == "sqrt":
        mask = jnp.sqrt(mask)
    elif variant == "log":
        # Element-wise maximum; `jnp.max(mask, 0)` would reduce over axis 0.
        mask = jnp.log(jnp.maximum(mask, 0) + 1)
    elif variant == "soft":
        mask = jax.nn.sigmoid(mask - .5)
    elif variant == "hard":
        mask = jnp.clip(mask, 0, 1)
    else:
        raise NotImplementedError("variant must be one of 'sqrt', 'log', 'soft' or 'hard'.")

    x = (mask * x).reshape(x_old_shape)
    # The axes components are passed through untouched.
    return jnp.concatenate([y, x], axis=channel_axis)

@functools.partial(jax.jit, static_argnames=['scaling','eps'])
def rcolu_(x, scaling="constant",eps=1e-8):
    """Scale the part of ``x`` orthogonal to the all-ones direction.

    Decomposes ``x = w + v`` along the last axis, where ``v`` is parallel to
    the unit vector ``e = ones(C) / sqrt(C)`` and ``w`` is the remainder, then
    returns ``v + m * w``.

    Args:
        x: Input array; the decomposition is along the last axis.
        scaling: ``'constant'`` uses ``m = min(max(<x, e>, 0) / (||w|| + eps), 1)``;
            any other value uses ``m = sigmoid(<x, e> - 0.5)``.
        eps: Added to ``||w||`` to avoid division by zero.

    Returns:
        Array with the same shape as ``x``.
    """
    C = x.shape[-1]
    root_c = jnp.sqrt(C)
    # e = jnp.ones(C) / jnp.sqrt(C)
    vn = jnp.sum(x, axis=-1, keepdims=True) / root_c # dot(x, e)
    v = vn / root_c # outer(vn, e); broadcasts over the last axis
    w = x - v
    # The norm must be taken over the last axis (`axis=`, not `x=`).
    wn = jnp.linalg.norm(w, axis=-1, keepdims=True)
    if scaling == 'constant':
        m = jnp.maximum(vn, 0.) / (wn + eps)
        m = jnp.minimum(m, 1.)
    else:
        # NOTE(review): unlike the 'constant' branch this does not divide by
        # the norm of w — confirm this is intended.
        m = jax.nn.sigmoid(vn - .5)
    w_ = w * m # project onto cone
    return v + w_

@functools.partial(jax.jit, static_argnames=['scaling','eps'])
def rcolu_normal(x: jnp.ndarray, scaling: str = "constant", eps: float = 1e-8) -> jnp.ndarray:
    """Updated version using normal projection (experimental).

    Decomposes ``x = w + v`` along the last axis with ``v`` parallel to
    ``e = ones(C) / sqrt(C)``. With ``m = <x, e> / (||w|| + eps)`` and
    ``m_ = clip(m, -1, 1)``, the result is ``(1 + (m - m_) / 2) * (v + m_ * w)``.

    Args:
        x: Input array; the decomposition is along the last axis.
        scaling: Only ``'constant'`` is supported.
        eps: Added to ``||w||`` to avoid division by zero.

    Returns:
        Array with the same shape as ``x``.

    Raises:
        NotImplementedError: If ``scaling`` is not ``'constant'``.
    """
    if scaling != 'constant':
        # Fail before doing any work; the message previously claimed the
        # opposite ("is implemented").
        raise NotImplementedError("Soft scaling is not implemented.")

    C = x.shape[-1]
    root_c = jnp.sqrt(C)
    vn = jnp.sum(x, axis=-1, keepdims=True) / root_c  # dot(x, e)
    v = jnp.broadcast_to(vn, (*vn.shape[:-1], C)) / root_c  # outer(vn, e)
    w = x - v
    wn = jnp.linalg.norm(w, axis=-1, keepdims=True)
    m = vn / (wn + eps)

    m_ = jnp.clip(m, -1, 1)
    r = m - m_  # amount by which m exceeds the [-1, 1] range

    w_ = w * m_  # project onto cone
    r1 = 1 + r / 2.  # relative increment
    return r1 * (v + w_)

@functools.partial(jax.jit, static_argnames=['dim','num_groups','axis','scaling','normal','eps'])
def rcolu(x,
          dim=4,
          num_groups=None,
          scaling='constant',
          normal=False,
          axis=-1,
          eps=1e-7
          ):
    """Apply ``rcolu_`` / ``rcolu_normal`` group-wise along ``axis``.

    The channels on ``axis`` are split into ``num_groups`` groups of ``dim``
    channels (dim=S, num_groups=G) and each group is transformed
    independently.

    Args:
        x: Input array. Scalars are returned unchanged.
        dim: Size of each group. Exactly one of ``dim`` and ``num_groups``
            must be given (pass ``dim=None`` when giving ``num_groups``).
        num_groups: Number of groups.
        scaling: Forwarded to the per-group function.
        normal: If true use ``rcolu_normal``, otherwise ``rcolu_``.
        axis: Channel axis.
        eps: Forwarded to the per-group function.

    Returns:
        Array with the same shape as ``x``.

    Raises:
        AssertionError: If both or neither of ``dim`` / ``num_groups`` are
            given, or the channel count is not divisible accordingly.
    """
    if len(x.shape) == 0:
        return x
    assert (dim is not None) ^ (num_groups is not None) # specify one of both

    # Move the channel axis last *before* reading any sizes, so the group
    # arithmetic and the reshape back both refer to the moved layout.
    if axis != -1:
        x = jnp.moveaxis(x, axis, -1)
    moved_shape = x.shape
    channels = moved_shape[-1]

    if dim is None:
        assert channels % num_groups == 0
        dim = channels // num_groups
    if num_groups is None:
        assert channels % dim == 0
        num_groups = channels // dim

    x = x.reshape(moved_shape[:-1] + (num_groups, dim))
    per_group = rcolu_normal if normal else rcolu_
    x = per_group(x, scaling=scaling, eps=eps)
    x = x.reshape(moved_shape)

    if axis != -1:
        x = jnp.moveaxis(x, -1, axis)
    return x

# # some test
# x = jnp.zeros(6).at[0].set(1)
# y = rcolu(x,dim=3)
# y.shape
# # some assertion
# y1 = jnp.sum(y,axis=-1,keepdims=True) / jnp.sqrt(3) # dot(y, e)
# jnp.linalg.norm(y) / y1


def cross_entropy_loss(logits, labels):
    """Mean softmax cross-entropy between ``logits`` and integer ``labels``.

    The number of classes is taken from the last axis of ``logits`` instead of
    being fixed at 10, so the loss works for any classifier width.

    Args:
        logits: Unnormalized scores; the last axis indexes classes.
        labels: Integer class ids with the shape of ``logits`` minus its last axis.

    Returns:
        Scalar mean loss.
    """
    num_classes = logits.shape[-1]
    one_hot_labels = jax.nn.one_hot(labels, num_classes=num_classes)
    # Same formula as optax.softmax_cross_entropy: -sum(labels * log_softmax).
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    per_example = -jnp.sum(one_hot_labels * log_probs, axis=-1)
    return per_example.mean()

def accuracy(logits, labels):
    """Return the fraction of entries whose arg-max logit equals the label."""
    is_correct = jnp.argmax(logits, axis=-1) == labels
    return jnp.mean(is_correct)

@functools.partial(jax.pmap, axis_name='batch')
def train_step(state, batch):
    """Run one data-parallel optimization step.

    Each device computes the loss and gradients on its shard of ``batch``;
    the gradients are then averaged across devices before the update so that
    every replica of ``state`` applies the same step and the replicas stay in
    sync. (Without this averaging each device would train its own diverging
    copy of the model.)

    Args:
        state: Replicated train state (leading axis = device).
        batch: Sharded dict with ``'image'`` and ``'label'`` entries.

    Returns:
        ``(state, loss, acc, grad_norm)`` where ``loss`` and ``acc`` are
        averaged over all devices and ``grad_norm`` is the global L2 norm of
        the averaged gradients over *all* parameters.
    """
    def loss_fn(params):
        logits = state.apply_fn({'params': params}, batch['image'])
        loss = cross_entropy_loss(logits, batch['label'])
        return loss, logits

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(state.params)

    # Average gradients over the device axis so all replicas update identically.
    grads = jax.lax.pmean(grads, axis_name='batch')

    # Global norm over every leaf of the gradient tree (the previous code only
    # measured the first leaf returned by tree_flatten).
    grad_norm = optax.global_norm(grads)

    state = state.apply_gradients(grads=grads)

    # Report device-averaged metrics so any single replica holds the global value.
    loss = jax.lax.pmean(loss, axis_name='batch')
    acc = jax.lax.pmean(accuracy(logits, batch['label']), axis_name='batch')

    return state, loss, acc, grad_norm

@jax.pmap
def eval_step(state, batch):
    """Compute per-device loss and accuracy on a sharded batch (no update)."""
    images, labels = batch['image'], batch['label']
    variables = {'params': state.params}
    logits = state.apply_fn(variables, images)
    return cross_entropy_loss(logits, labels), accuracy(logits, labels)

def prepare_data(data_name, train_batch_size=512, test_batch_size=10000):
    """Download a TFDS dataset and return batched NumPy iterables for it.

    Args:
        data_name: Name of the dataset as understood by ``tfds.builder``.
        train_batch_size: Batch size of the ``'train'`` split.
        test_batch_size: Batch size of the ``'test'`` split.

    Returns:
        ``(train_ds, test_ds, num_iters, num_test_iters)``: the two NumPy
        iterables and their lengths (number of batches per pass).
    """
    ds_builder = tfds.builder(data_name)
    ds_builder.download_and_prepare()

    # Define normalization function using TensorFlow operations
    def normalize_img(example):
        # Convert to float and scale to [0, 1]
        example['image'] = tf.cast(example['image'], tf.float32) / 255.0
        return example

    def load_split(split, batch_size, shuffle_files):
        """Load one split, normalize it and convert it to NumPy for JAX."""
        ds = tfds.load(
            data_name, split=split, batch_size=batch_size, shuffle_files=shuffle_files
        ).map(normalize_img).prefetch(tf.data.experimental.AUTOTUNE)
        return tfds.as_numpy(ds)

    # Prepare training and testing datasets
    train_ds = load_split('train', train_batch_size, shuffle_files=True)
    test_ds = load_split('test', test_batch_size, shuffle_files=False)

    # Get number of iterations for training and testing sets
    num_iters = len(train_ds)
    num_test_iters = len(test_ds)
    return train_ds, test_ds, num_iters, num_test_iters

def create_train_state(model, rng, learning_rate, weight_decay=.01):
    """Initialize model parameters and an AdamW train state.

    Weight decay is applied to every parameter except those whose leaf name
    contains ``'bias'`` or ``'scale'``.

    Args:
        model: Flax module; initialized with a ``[1, 28, 28, 1]`` dummy input.
        rng: RNG (or dict of RNGs) passed to ``model.init``.
        learning_rate: Learning rate for ``optax.adamw``.
        weight_decay: Weight-decay coefficient.

    Returns:
        A ``TrainState`` holding ``model.apply``, the parameters and the optimizer.
    """
    params = model.init(rng, jnp.ones([1, 28, 28, 1]))['params']

    def weight_decay_mask(tree):
        # Decide per *leaf*: the top-level keys of the parameter tree are module
        # names (e.g. 'Dense_0'), so testing those never matched 'bias'/'scale'
        # and the mask ended up enabling decay for every parameter.
        def decays(path, _leaf):
            leaf_name = str(getattr(path[-1], 'key', '')) if path else ''
            return 'bias' not in leaf_name and 'scale' not in leaf_name
        return jax.tree_util.tree_map_with_path(decays, tree)

    tx = optax.adamw(learning_rate=learning_rate, weight_decay=weight_decay, mask=weight_decay_mask(params))
    return TrainState.create(apply_fn=model.apply, params=params, tx=tx)

def train_and_evaluate(model, num_epochs, learning_rate):
    """Train ``model`` on MNIST and evaluate it after every epoch.

    Args:
        model: Flax module to train.
        num_epochs: Number of passes over the training set.
        learning_rate: Learning rate passed to ``create_train_state``.

    Returns:
        ``(best_test_acc, train_loss)``: the highest per-epoch test accuracy
        and the training loss of the last batch of the last epoch.
    """
    # Time- and process-dependent seed: every run / host starts differently.
    rng = jax.random.PRNGKey(int(time.time()+jax.process_index()))
    rngs = {'params': rng}
    train_ds, test_ds, _, num_test_iters = prepare_data('mnist')
    state = create_train_state(model, rngs, learning_rate)

    # Replicate state across devices
    # (jax.device_put_replicated only works on a single-host machine)
    state = flax.jax_utils.replicate(state, jax.local_devices())

    epoch_test_accs = []
    bar = tqdm(range(num_epochs), leave=False, desc="Epochs")
    try:
        for _ in bar:
            # Training loop: keep the metrics replicated on device and only
            # pull them to the host once per epoch, since only the values of
            # the last batch are used.
            for batch in train_ds:
                state, rep_loss, rep_acc, rep_grad_norm = train_step(state, shard(batch))
            train_loss, train_acc, grad_norm = flax.jax_utils.unreplicate(
                (rep_loss, rep_acc, rep_grad_norm))

            # Evaluation loop
            test_loss, test_acc = 0, 0
            for batch in test_ds:
                loss, acc = eval_step(state, shard(batch))
                test_loss += loss.mean()
                test_acc += acc.mean()

            test_loss /= num_test_iters
            test_acc /= num_test_iters
            epoch_test_accs.append(test_acc)

            # Plain floats keep the progress bar readable (no Array(...) reprs).
            bar.set_postfix({'Test Loss':float(test_loss),'Test Acc':float(test_acc),'grad_norm':float(grad_norm)})
    finally:
        # Always release the progress bar, even if a step raises.
        bar.close()

    best_test_acc = max(epoch_test_accs)
    return best_test_acc, train_loss

# Number of positions used by the 'extrapolate*' and 'split*' input layers.
DIM = 49
# validate different activation functions
def validate_activation(fn,C=512,method='default',num_epochs=50,name='relu',num_exp = 1):
    """Train a one-hidden-layer MNIST classifier using ``fn`` as activation.

    Args:
        fn: Activation function applied to the hidden features.
        C: Number of hidden features.
        method: Input-layer variant: ``'extrapolate-aggregate'``,
            ``'extrapolate'``, ``'split'``, ``'split-dense'``, ``'id'``,
            ``'nobias'``; anything else selects a plain dense layer.
        num_epochs: Epochs per experiment.
        name: Label stored in the result record.
        num_exp: Number of repeated experiments to aggregate.

    Returns:
        Dict with ``name``, mean/std of the best test accuracy (``Acc``,
        ``StdAcc``), mean/std of the final train loss (``Loss``, ``StdLoss``),
        ``C`` and ``method``.
    """
    class MNISTModel(nn.Module):
        @nn.compact
        def __call__(self, x):
            if method == 'extrapolate-aggregate':
                x = jnp.repeat(x.reshape(-1, 1, 28 * 28),DIM,axis=1)
                x = nn.Conv(kernel_size=(3,),padding="CIRCULAR",features=C)(x)
                x = jnp.mean(x,axis=1,keepdims=False)
            elif method == 'extrapolate':
                x = jnp.repeat(x.reshape(-1, 1, 28 * 28),DIM,axis=1)
                x = nn.Conv(kernel_size=(3,),padding="CIRCULAR",features=C//DIM)(x)
                # Flatten per sample: the layer yields DIM * (C // DIM) features,
                # which differs from C when C is not a multiple of DIM, and
                # reshape(-1, C) would then mix samples or fail.
                x = x.reshape(x.shape[0], -1)
            elif method == 'split':
                x = x.reshape(-1, DIM, 28 * 28//DIM)
                x = nn.Conv(kernel_size=(3,),padding="CIRCULAR",features=C//DIM)(x)
                x = x.reshape(x.shape[0], -1)
            elif method == 'split-dense':
                x = x.reshape(-1, DIM, 28 * 28//DIM)
                x = nn.Dense(features=C//DIM)(x)
                x = x.reshape(x.shape[0], -1)
            elif method == "id":
                x = x.reshape(-1, 28 * 28)
            elif method == "nobias":
                x = x.reshape(-1, 28 * 28)
                x = nn.Dense(features=C,use_bias=False)(x)
            else:
                x = x.reshape(-1, 28 * 28)  # Flatten the input
                x = nn.Dense(features=C)(x)
            x = fn(x)
            x = nn.Dense(features=10)(x)
            return x

    model = MNISTModel()
    best_test_acc = []
    train_loss = []

    for _ in range(num_exp): # replaced by parallelization
        acc, loss = train_and_evaluate(model, num_epochs=num_epochs, learning_rate=1e-3)
        best_test_acc.append(acc)
        train_loss.append(loss)

    ma = np.mean(best_test_acc).item()
    sa = np.std(best_test_acc).item()
    ml = np.mean(train_loss).item()
    sl = np.std(train_loss).item()
    print(f'Test Accuracy: {ma:.4f} ± {sa:.4f}')
    print(f'Train Loss: {ml:.4f} ± {sl:.4f}')
    # Include C and method so rows of different runs can be told apart.
    return {'name': name, 'Acc': ma, 'StdAcc': sa, 'Loss': ml, 'StdLoss': sl, 'C': C, 'method': method}


def run():
    """Benchmark several activations over a range of hidden sizes.

    For every hidden size ``C`` each activation is trained via
    ``validate_activation``; all result records are written to
    ``results.csv``.
    """
    # Imported here (pandas is not a module-level import of this file) and
    # before any training starts, so a missing dependency fails immediately
    # instead of after all experiments have finished.
    import pandas as pd

    colu2 = functools.partial(colu,dim=2)
    activations = [
        ('colu', colu),
        ('relu', nn.relu),
        ('silu', nn.silu),
        ('gelu', nn.gelu),
        ('colu2', colu2),
    ]
    results = []  # Accumulate results for all experiments
    candidates = [8,64,512,4096]#[4,8,16,32]#[32, 128, 512, 2048, 8192]
    num_epochs = 100
    method = 'default'
    num_exp = 5
    for C in candidates:
        print(f"######### C={C} ##########")
        for name, fn in activations:
            print(f"######### {name} ##########")
            results.append(validate_activation(fn, C=C, name=name, method=method, num_epochs=num_epochs, num_exp = num_exp))

    # Create DataFrame and save to CSV
    df = pd.DataFrame(results)
    df.to_csv("results.csv", index=False)

def plot():
    """Plot accuracy vs. hidden size from ``results.csv`` into ``accuracy_plot.pdf``.

    Only process 0 does any work; other processes return immediately. One line
    (with a +/- one standard deviation band) is drawn per activation name.
    """
    if jax.process_index() != 0:
        return

    import pandas as pd
    import matplotlib.pyplot as plt

    def pretty_plt():
        plt.rcParams.update({
            "text.usetex": False,
            "font.family": "serif",
            # The family is 'serif', so the preferred font has to be listed
            # under 'font.serif' ('font.sans-serif' is ignored for this family).
            "font.serif": ["Computer Modern Roman"]})
        plt.rcParams['axes.spines.right'] = False
        plt.rcParams['axes.spines.top'] = False
        plt.rcParams['axes.spines.left'] = True
        plt.rcParams['axes.spines.bottom'] = True
        plt.rcParams['axes.grid'] = True
        plt.rcParams['grid.alpha'] = 0.5
        plt.rcParams['font.size'] = 17
        plt.rcParams['legend.framealpha'] = 0.7
        plt.rcParams['xtick.labelsize'] = 14
        plt.rcParams['ytick.labelsize'] = 14
        plt.rcParams['legend.fontsize'] = 'small'

    pretty_plt()

    df = pd.read_csv("results.csv")
    # Generate and save the PDF figure
    fig, ax = plt.subplots(figsize=(4,3))

    for activation in df['name'].unique():  # Iterate through activation functions
        subset = df[df['name'] == activation]
        ax.plot(subset['C'], subset['Acc'], label=activation)
        ax.fill_between(
            subset['C'],
            subset['Acc']+subset['StdAcc'],
            subset['Acc']-subset['StdAcc'],
            alpha=0.3
        )

    ax.set_xlabel("Num Hidden Neurons $C$")
    ax.set_ylabel("Accuracy")
    ax.legend()
    fig.tight_layout()
    fig.savefig("accuracy_plot.pdf")
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

if __name__ == '__main__':
    # Script entry point. The training sweep (run) is currently disabled, so
    # only the plot is produced; plot() reads results.csv, which run() writes,
    # so re-enable run() when that file does not exist yet.
    # run()
    plot()