import os

from absl import app
from absl import flags
from absl import logging

import jax
from jax import random, grad, jit, vmap
from jax import core
import jax.numpy as np
from jax.config import config
from jax.nn import relu, gelu, sigmoid
from jax.nn.initializers import glorot_normal, normal, constant
from jax.example_libraries import optimizers

import flax
from flax.core import freeze, unfreeze
from flax import linen as nn

import numpy as onp
from functools import partial
from typing import Any, Callable, Sequence, Optional, Sequence, Union
from tqdm.notebook import tqdm as tqdm
from tqdm import trange

import imageio
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm as tqdm

import wandb

# Default image resolution (pixels per side).
# NOTE(review): not referenced anywhere else in the visible part of this file;
# make_dataset carries its own `res=512` default instead - confirm before removing.
RES = 512


def load_dataset(filename, id):
    """Load the train/test image arrays from an ``.npz`` archive, scaled by 1/255.

    Args:
        filename: Path of the ``.npz`` archive. It must contain the arrays
            ``train_data`` and ``test_data``.
        id: Unused by this function; kept so existing callers that pass it
            keep working.

    Returns:
        A dict with two entries: ``"data_grid_search"`` holds
        ``train_data / 255.`` and ``"data_test"`` holds ``test_data / 255.``.
    """
    # Open the archive as a context manager so its file handle is released as
    # soon as both arrays have been read and rescaled.
    with np.load(filename) as npz_data:
        return {
            "data_grid_search": npz_data['train_data'] / 255.,
            "data_test": npz_data['test_data'] / 255.,
        }


def make_dataset(dataset, img_id, res=512):
    """Build coordinate/colour pairs for one image at full and half resolution.

    Args:
        dataset: Mapping with a ``'data_test'`` array indexed as
            ``[image, row, col, channel]``.
        img_id: Index of the image to use.
        res: Number of grid points per axis for the test coordinates; the
            training grid uses ``res // 2`` points per axis.
            NOTE(review): the image's spatial size is not checked against
            ``res`` - presumably it is ``res x res``; confirm with the data.

    Returns:
        ``(x_train, y_train, x_test, y_test)``. The ``x_*`` arrays are
        flattened to ``(-1, 2)`` coordinates in ``[0, 1)``, and the ``y_*``
        arrays are flattened to ``(-1, 3)`` pixel values. ``y_train`` keeps
        every second pixel of the test image along both spatial axes.
    """
    def _unit_grid(points_per_axis):
        # Evenly spaced ticks in [0, 1): the right endpoint is dropped.
        ticks = np.linspace(0, 1, points_per_axis + 1)[:-1]
        mesh = np.stack(np.meshgrid(ticks, ticks), axis=-1)
        return mesh.reshape(-1, 2)

    full_image = dataset['data_test'][img_id, :, :, :]
    coarse_image = full_image[::2, ::2, :]

    return (
        _unit_grid(res // 2),
        coarse_image.reshape(-1, 3),
        _unit_grid(res),
        full_image.reshape(-1, 3),
    )


def weight_fact_glorot_normal(mean=1.0, stddev=0.01):
    """Return an initializer producing a factorized Glorot-normal kernel.

    The returned ``init(key, shape)`` draws a Glorot-normal matrix ``w`` and a
    positive scale ``g = exp(mean + stddev-scaled normal noise)`` with one
    entry per index of the last axis of ``shape``. It returns ``(g, v)`` with
    ``v = w / g``, so that ``g * v`` reproduces ``w``.
    """
    # use mean=2.0, stddev=0.01 for natural dataset
    def init(key, shape):
        kernel_key, scale_key = random.split(key)
        dense = glorot_normal()(kernel_key, shape)
        log_scale = mean + normal(stddev)(scale_key, (shape[-1],))
        scale = np.exp(log_scale)
        return scale, dense / scale

    return init


def weight_norm_glorot_normal():
    """Return an initializer that splits a Glorot-normal kernel into norm and direction.

    The returned ``init(key, shape)`` draws a Glorot-normal matrix ``w`` and
    returns ``(g, v)``, where ``g`` is the 2-norm of ``w`` taken along axis 0
    and ``v = w / g``.
    """
    def init(key, shape):
        dense = glorot_normal()(key, shape)
        norms = np.linalg.norm(dense, 2, axis=0)
        return norms, dense / norms

    return init


class PosEmbs(nn.Module):
    """Embed 2-D inputs as cos/sin features of the projection ``2*pi*x @ kernel``.

    ``setup`` builds ``self.kernel`` according to ``embed_type``:

    * ``'posenc'``: a deterministic kernel of shape ``(2, 2 * embed_dim)``
      whose non-zero entries are ``2 ** linspace(0, embed_scale, embed_dim) - 1``,
      applied to each of the two input coordinates separately.
    * ``'gaussian'``: a ``(2, embed_dim)`` kernel drawn with
      ``normal(embed_scale)`` and registered as the parameter ``'kernel'``.

    Raises:
        ValueError: if ``embed_type`` is neither ``'posenc'`` nor ``'gaussian'``.
    """
    # Which kernel to build: 'posenc' or 'gaussian'.
    embed_type: str
    # Largest exponent for 'posenc'; passed to `normal(...)` for 'gaussian'.
    embed_scale: float
    # Number of frequencies ('posenc': per input coordinate; 'gaussian': total).
    embed_dim: int

    def setup(self):
        if self.embed_type == 'posenc':
            freqs = 2. ** np.linspace(0, self.embed_scale, self.embed_dim) - 1
            # (embed_dim, 2): frequencies acting on the first coordinate only.
            kernel = np.stack([freqs, np.zeros_like(freqs)], -1)
            # Append the column-swapped copy so the second coordinate gets the
            # same frequencies, then transpose to (2, 2 * embed_dim).
            self.kernel = np.concatenate([kernel, np.roll(kernel, 1, axis=-1)], 0).T

        elif self.embed_type == 'gaussian':
            kernel = self.param('kernel', normal(self.embed_scale), (2, self.embed_dim))
            # The kernel is stored as a parameter, but stop_gradient blocks any
            # gradient from flowing back into it.
            self.kernel = jax.lax.stop_gradient(kernel)

        else:
            # Without this branch `self.kernel` would be left unset and the
            # failure would only surface later, inside `__call__`.
            raise ValueError(
                f"Unknown embed_type {self.embed_type!r}; "
                "expected 'posenc' or 'gaussian'."
            )

    @nn.compact
    def __call__(self, x):
        """Return ``concat([cos(p), sin(p)], axis=-1)`` with ``p = 2*pi*x @ kernel``."""
        # Compute the projection once and reuse it for both halves.
        projected = np.dot(2 * np.pi * x, self.kernel)
        return np.concatenate([np.cos(projected), np.sin(projected)], axis=-1)


class FactorizedDense(nn.Module):
    """Dense layer ``y = x @ kernel + bias`` with selectable kernel parameterization.

    Supported ``mode`` values:

    * ``None`` or the string ``'None'``: plain Glorot-normal kernel.
    * ``'adaptive_activation'``: plain kernel; the output is multiplied by a
      per-feature parameter ``'activation_params'`` initialised to 1.
    * ``'weight_norm'``: the ``'kernel'`` parameter is a pair ``(g, v)`` and
      the effective kernel is ``g * v / ||v||`` (2-norm along axis 0).
    * ``'weight_fact'``: the ``'kernel'`` parameter is a pair ``(g, v)`` and
      the effective kernel is ``g * v``.

    Raises:
        ValueError: if ``mode`` is none of the values above.
    """
    # Number of output features.
    features: int
    mode: Optional = None  # [None, 'None', 'adaptive_activation', 'weight_norm', 'weight_fact']

    @nn.compact
    def __call__(self, x):
        kernel_shape = (x.shape[-1], self.features)
        # Only set in 'adaptive_activation' mode; scales the layer output.
        output_scale = None

        # Accept both the dataclass default (None) and the string 'None' used
        # by the sweep configuration; previously only the string matched, so
        # the default left `kernel` unbound.
        if self.mode is None or self.mode == 'None':
            kernel = self.param('kernel', glorot_normal(), kernel_shape)

        elif self.mode == 'adaptive_activation':
            kernel = self.param('kernel', glorot_normal(), kernel_shape)
            n = 1.0
            a = self.param('activation_params', constant(1.0 / n), (self.features,))
            output_scale = n * a

        elif self.mode == 'weight_norm':
            g, v = self.param('kernel', weight_norm_glorot_normal(), kernel_shape)
            kernel = g * v / np.linalg.norm(v, 2, axis=0)

        elif self.mode == 'weight_fact':
            g, v = self.param('kernel', weight_fact_glorot_normal(), kernel_shape)
            kernel = g * v

        else:
            raise ValueError(
                f"Unknown mode {self.mode!r}; expected None, 'None', "
                "'adaptive_activation', 'weight_norm' or 'weight_fact'."
            )

        bias = self.param('bias', nn.initializers.zeros, (self.features,))
        y = np.dot(x, kernel) + bias

        if output_scale is not None:
            y = output_scale * y

        return y


class Mlp(nn.Module):
    """Multi-layer perceptron with an optional input embedding and a sigmoid output.

    The input is optionally passed through ``PosEmbs``. It then goes through
    ``num_layers`` blocks of ``FactorizedDense(layer_size)`` followed by
    ``activation``, and finally through ``FactorizedDense(out_dim)`` and a
    sigmoid.
    """
    # Number of hidden FactorizedDense + activation blocks.
    num_layers: int = 3
    # Width of each hidden layer; also passed to PosEmbs as its embed_dim.
    layer_size: int = 256
    # Width of the final layer.
    out_dim: int = 3
    # Activation applied after every hidden layer.
    activation: Callable = relu
    embedding: Optional = None  # [None, ('posenc', scale), ('gaussian', scale)]
    # Forwarded unchanged to every FactorizedDense layer.
    mode: Optional = None  # [None, 'weight_fact', 'weight_norm']

    @nn.compact
    def __call__(self, x):
        # Embedding (skipped entirely when none is configured)
        if self.embedding is not None:
            embed_type, embed_scale = self.embedding
            x = PosEmbs(embed_type, embed_scale, self.layer_size)(x)

        # hidden layers
        for _ in range(self.num_layers):
            x = FactorizedDense(self.layer_size, mode=self.mode)(x)
            x = self.activation(x)

        # last layer
        x = FactorizedDense(self.out_dim, mode=self.mode)(x)
        return sigmoid(x)


# Registry mapping an activation's configuration name to the callable that
# implements it; get_activation looks names up here.
activation_fn = dict(
    relu=nn.relu,
    gelu=nn.gelu,
    swish=nn.swish,
    tanh=np.tanh,
    sin=np.sin,
)


def get_activation(str):
    """Return the activation callable registered under the given name.

    Args:
        str: Key into the module-level ``activation_fn`` mapping. The
            parameter name shadows the builtin, but it is kept so that
            keyword callers keep working.

    Returns:
        The callable stored in ``activation_fn`` for that name.

    Raises:
        ValueError: if the name is not a key of ``activation_fn``. Previously
            the function fell through and returned ``None``, so the mistake
            only surfaced later when the result was called.
    """
    try:
        return activation_fn[str]
    except KeyError:
        raise ValueError(
            f"Unknown activation {str!r}; expected one of {sorted(activation_fn)}."
        ) from None


def with_warmup(step_size, num_warmup_steps):
    """Return a step-size schedule with a linear warm-up.

    The returned ``schedule(i)`` evaluates to
    ``step_size * i / num_warmup_steps`` while ``i <= num_warmup_steps`` and
    to ``step_size`` afterwards.
    """
    def schedule(i):
        still_warming_up = i <= num_warmup_steps
        ramped = step_size * i / num_warmup_steps
        return np.where(still_warming_up, ramped, step_size)

    return schedule


def train(key, lr, num_layers, layer_size, activation, embedding, mode, x_train, y_train, x_test, y_test,
          iters=2000, num_warmup_steps=200, log_every=10):
    """Fit an ``Mlp`` to ``(x_train, y_train)`` with Adam and log PSNR to wandb.

    Args:
        key: PRNG key used to initialise the model parameters.
        lr: Peak step size, reached after the linear warm-up.
        num_layers: Number of hidden layers of the ``Mlp``.
        layer_size: Width of the hidden layers.
        activation: Activation name, resolved through ``get_activation``.
        embedding: ``None`` or an ``(embed_type, embed_scale)`` tuple.
        mode: Kernel parameterization forwarded to the ``Mlp``.
        x_train, y_train: Training inputs and targets (full batch every step).
        x_test, y_test: Held-out inputs and targets, used for evaluation only.
        iters: Number of optimizer steps.
        num_warmup_steps: Length of the linear step-size warm-up.
        log_every: Metrics are computed and logged whenever
            ``it % log_every == 0``.

    Returns:
        A dict with ``'params'`` (parameters after the last optimizer step),
        ``'y_pred'`` (model output on ``x_test`` with those parameters), and
        ``'train_psnrs'`` / ``'test_psnrs'`` (one entry per logging step).
    """
    # When called from a function that already started a run (e.g. a sweep
    # agent), reuse it instead of initialising wandb a second time.
    if wandb.run is None:
        wandb.init(project='image_sweep')

    # init
    act = get_activation(activation)
    out_dim = 3
    arch = Mlp(num_layers, layer_size, out_dim, act, embedding, mode=mode)

    # model loss: half the mean squared error, so 2 * loss is the MSE and
    # the PSNR below equals -10 * log10(MSE).
    model_loss = jit(lambda params, x, y: .5 * np.mean((arch.apply(params, x) - y) ** 2))
    model_psnr = jit(lambda params, x, y: -10 * np.log10(2. * model_loss(params, x, y)))
    model_grad_loss = jit(jax.grad(model_loss))

    # parameters initialization (a single dummy 2-D coordinate fixes the shapes)
    params = arch.init(key, np.ones((2,)))

    # optimizer initialization
    schedule = with_warmup(lr, num_warmup_steps=num_warmup_steps)
    opt_init, opt_update, get_params = optimizers.adam(schedule)
    opt_state = opt_init(params)
    opt_update = jit(opt_update)

    # loggers
    train_psnrs = []
    test_psnrs = []

    # train loop
    pbar = trange(iters)
    for it in pbar:
        grads = model_grad_loss(get_params(opt_state), x_train, y_train)
        opt_state = opt_update(it, grads, opt_state)

        if it % log_every == 0:
            params = get_params(opt_state)

            train_psnr = model_psnr(params, x_train, y_train)
            test_psnr = model_psnr(params, x_test, y_test)
            train_psnrs.append(train_psnr)
            test_psnrs.append(test_psnr)

            pbar.set_postfix({'train psnr': '{:.3f}'.format(train_psnr),
                              'test psnr': '{:.3f}'.format(test_psnr)})

            wandb.log({'train_psnr': train_psnr, 'test_psnr': test_psnr}, it)

    # Refresh the parameters from the optimizer state: the copy taken inside
    # the loop dates from the last logging step and would miss the updates
    # applied after it.
    params = get_params(opt_state)

    # model prediction after training
    y_pred = arch.apply(params, x_test)

    return {
        'params': params,
        'y_pred': y_pred,
        'train_psnrs': train_psnrs,
        'test_psnrs': test_psnrs,
    }


# Handle on absl's flag container; main reads FLAGS.dataset_type from it.
FLAGS = flags.FLAGS

# Command-line flag selecting which dataset main loads.
flags.DEFINE_string(
    name='dataset_type',
    default='natural',
    help='determine the type of the dataset for image regression',
)


def main(argv):
    """Load the selected dataset and run a wandb grid sweep over the MLP settings.

    Args:
        argv: Positional command-line arguments passed by ``app.run``; unused.

    Raises:
        ValueError: if ``FLAGS.dataset_type`` is not ``'natural'`` or ``'text'``.
    """
    dataset_type = FLAGS.dataset_type

    # Arguments passed to load_dataset for each supported dataset type.
    dataset_sources = {
        'natural': ('data_div2k.npz', '1TtwlEDArhOMoH18aUyjIMSZ3WODFmUab'),
        'text': ('data_2d_text.npz', '1V-RQJcMuk9GD4JCUn70o7nwQE0hEzHoT'),
    }
    if dataset_type not in dataset_sources:
        # Fail here rather than with a NameError on `dataset` once the sweep
        # agent starts its first run.
        raise ValueError(
            f"Unknown dataset_type {dataset_type!r}; "
            f"expected one of {sorted(dataset_sources)}."
        )
    dataset = load_dataset(*dataset_sources[dataset_type])

    project_name = 'image_' + dataset_type + '_sweep'

    sweep_config = {
        'method': 'grid',
        'name': 'sweep',
        'metric': {
            'goal': 'maximize',
            'name': 'test_psnr',
        },
        'parameters': {
            'seed': {'values': [2, 3, 5, 7, 11]},
            'learning_rate': {'values': [1e-3, 1e-4]},
            'num_layers': {'values': [2, 3, 4]},
            'layer_size': {'values': [64, 128, 256]},
            'activation': {'values': ['relu', 'gelu', 'sin']},
            'gaussian_scale': {'values': [1, 5, 10]},
            'mode': {'values': ['None', 'adaptive_activation', 'weight_norm', 'weight_fact']},
        },
    }

    def train_sweep():
        """Train one model using the hyper-parameters in ``wandb.config``."""
        wandb.init(project=project_name)

        # Named differently from the outer `sweep_config` so the sweep
        # definition is not shadowed by the per-run values.
        run_config = wandb.config

        key = random.PRNGKey(run_config.seed)
        embedding = ('gaussian', run_config.gaussian_scale)

        x_train, y_train, x_test, y_test = make_dataset(dataset, img_id=0)

        train(key,
              run_config.learning_rate,
              run_config.num_layers,
              run_config.layer_size,
              run_config.activation,
              embedding,
              run_config.mode,
              x_train, y_train, x_test, y_test)

    sweep_id = wandb.sweep(
        sweep_config,
        project=project_name,
    )

    wandb.agent(sweep_id, function=train_sweep)


if __name__ == "__main__":
    # Register the required-flag check first, then hand control to absl,
    # which is given `main` as the function to run.
    # NOTE(review): 'dataset_type' is defined with a non-None default
    # ('natural'), so this check presumably never fails when the flag is
    # omitted - confirm against the absl flags documentation.
    flags.mark_flags_as_required(['dataset_type'])
    app.run(main)