import os

from absl import app
from absl import flags
from absl import logging

import jax
from jax import random, grad, jit, vmap
from jax import core
import jax.numpy as np
from jax.config import config
from jax.scipy import ndimage
from jax.nn import relu, gelu, sigmoid
from jax.nn.initializers import glorot_normal, normal, constant
from jax.example_libraries import optimizers

import flax
from flax.core import freeze, unfreeze
from flax import linen as nn

import numpy as onp
from functools import partial
from typing import Any, Callable, Sequence, Optional, Sequence, Union

from livelossplot import PlotLosses
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm as tqdm
import os
import requests
from io import BytesIO

import cv2
import scipy.ndimage
from scipy.special import binom

from tqdm.notebook import tqdm as tqdm
from tqdm import trange

from phantominator import shepp_logan, ct_shepp_logan, ct_modified_shepp_logan_params_2d

import wandb


def weight_fact_glorot_normal(mean=1.0, stddev=0.1):
    """Build an initializer that returns a factorized Glorot-normal kernel.

    The returned ``init(key, shape)`` draws a Glorot-normal matrix ``w`` and a
    per-output-column scale ``g = exp(mean + stddev * noise)``, and returns the
    pair ``(g, v)`` with ``v = w / g`` so that ``g * v`` reproduces ``w``.

    Args:
      mean: mean of the Gaussian from which the log-scale is drawn.
      stddev: standard deviation of that Gaussian.

    Returns:
      A function ``init(key, shape) -> (g, v)``; ``g`` has shape
      ``(shape[-1],)`` and ``v`` has shape ``shape``.
    """

    def init(key, shape):
        # Independent randomness for the dense matrix and for the scales.
        kernel_key, scale_key = random.split(key)
        dense = glorot_normal()(kernel_key, shape)
        log_scale = mean + normal(stddev)(scale_key, (shape[-1],))
        scale = np.exp(log_scale)
        return scale, dense / scale

    return init


def weight_norm_glorot_normal():
    """Build an initializer that returns a weight-normalized Glorot kernel.

    The returned ``init(key, shape)`` draws a Glorot-normal matrix ``w`` and
    splits it into per-column 2-norms ``g`` and unit-norm columns ``v``.

    Returns:
      A function ``init(key, shape) -> (g, v)`` with ``g * v == w``.
    """

    def init(key, shape):
        dense = glorot_normal()(key, shape)
        # Norm taken along axis 0, i.e. one magnitude per column.
        column_norms = np.linalg.norm(dense, 2, axis=0)
        directions = dense / column_norms
        return column_norms, directions

    return init


class PosEmbs(nn.Module):
    """Fourier-feature embedding of 2-D input coordinates.

    The input is projected through a ``(2, K)`` frequency matrix and the
    cosine and sine of the projection are concatenated on the last axis.

    Attributes:
      embed_type: ``'posenc'`` for a fixed, axis-aligned frequency ladder
        (``K = 2 * embed_dim``) or ``'gaussian'`` for a random Gaussian
        projection (``K = embed_dim``).
      embed_scale: largest exponent of the frequency ladder for ``'posenc'``;
        standard deviation of the random projection for ``'gaussian'``.
      embed_dim: number of frequencies (see ``embed_type`` for how it maps
        to ``K``).

    Raises:
      ValueError: if ``embed_type`` is not one of the supported values.
    """
    embed_type: str
    embed_scale: float
    embed_dim: int

    def setup(self):
        if self.embed_type == 'posenc':
            freqs = 2. ** np.linspace(0, self.embed_scale, self.embed_dim) - 1
            # (embed_dim, 2): each frequency acts on the first coordinate only ...
            kernel = np.stack([freqs, np.zeros_like(freqs)], -1)
            # ... and the rolled copy applies the same ladder to the second one.
            self.kernel = np.concatenate([kernel, np.roll(kernel, 1, axis=-1)], 0).T

        elif self.embed_type == 'gaussian':
            # Stored as a parameter so it lives in the variable tree, but the
            # gradient is stopped so the projection receives no gradient.
            # (A previous variant drew it directly from a fixed PRNGKey(1234).)
            kernel = self.param('kernel', normal(self.embed_scale), (2, self.embed_dim))
            self.kernel = jax.lax.stop_gradient(kernel)

        else:
            # Without this, self.kernel would never be set and the failure
            # would only show up as an attribute error at call time.
            raise ValueError(
                "Unknown embed_type {!r}; expected 'posenc' or 'gaussian'".format(self.embed_type))

    @nn.compact
    def __call__(self, x):
        # Compute the projection once and reuse it for both cos and sin.
        projected = np.dot(2 * np.pi * x, self.kernel)
        return np.concatenate([np.cos(projected), np.sin(projected)], axis=-1)


class FactorizedDense(nn.Module):
    """Dense layer with optional kernel re-parameterization.

    Attributes:
      features: number of output features.
      mode: how the kernel is parameterized. ``None`` or the string ``'None'``
        gives a plain Glorot-normal kernel; ``'adaptive_activation'`` adds a
        learnable per-feature scale on the output; ``'weight_norm'`` stores
        the kernel as magnitudes and directions and re-normalizes the
        directions at every call; ``'weight_fact'`` stores it as a
        per-column scale times a matrix.

    Raises:
      ValueError: if ``mode`` is not one of the values listed above.
    """
    features: int
    mode: Optional[str] = None  # [None, 'None', 'adaptive_activation', 'weight_norm', 'weight_fact']

    @nn.compact
    def __call__(self, x):
        mode = self.mode
        kernel_shape = (x.shape[-1], self.features)

        # The declared default is the None object while sweep configs pass the
        # string 'None'; both mean "no re-parameterization".
        if mode is None or mode == 'None':
            kernel = self.param('kernel', glorot_normal(), kernel_shape)

        elif mode == 'adaptive_activation':
            kernel = self.param('kernel', glorot_normal(), kernel_shape)
            n = 1.0
            a = self.param('activation_params', constant(1.0 / n), (self.features,))

        elif mode == 'weight_norm':
            g, v = self.param('kernel', weight_norm_glorot_normal(), kernel_shape)
            # Re-normalize so g keeps meaning "column magnitude" during training.
            kernel = g * v / np.linalg.norm(v, 2, axis=0)

        elif mode == 'weight_fact':
            g, v = self.param('kernel', weight_fact_glorot_normal(), kernel_shape)
            kernel = g * v

        else:
            raise ValueError(
                "Unknown FactorizedDense mode {!r}; expected None, 'None', "
                "'adaptive_activation', 'weight_norm' or 'weight_fact'".format(mode))

        bias = self.param('bias', nn.initializers.zeros, (self.features,))
        y = np.dot(x, kernel) + bias

        if mode == 'adaptive_activation':
            y = n * a * y

        return y


class Mlp(nn.Module):
    """Coordinate MLP built from FactorizedDense layers with a sigmoid output.

    Attributes:
      num_layers: number of hidden layers.
      layer_size: width of every hidden layer; also passed to PosEmbs as its
        ``embed_dim`` when an embedding is used.
      out_dim: number of output features.
      activation: nonlinearity applied after each hidden layer.
      embedding: ``None`` to feed the raw input, or a pair
        ``(embed_type, embed_scale)`` forwarded to PosEmbs.
      mode: kernel parameterization forwarded to every FactorizedDense.
    """
    num_layers: int = 3
    layer_size: int = 256
    out_dim: int = 1
    activation: Callable = relu
    embedding: Optional = None  # [None, ('posenc', scale), ('gaussian', scale)]
    mode: Optional = None  # [None, 'weight_fact', 'weight_norm']

    @nn.compact
    def __call__(self, x):
        # Optional input embedding; identity comparison is the correct test
        # for "no embedding configured".
        if self.embedding is not None:
            embed_type, embed_scale = self.embedding
            x = PosEmbs(embed_type, embed_scale, self.layer_size)(x)

        # Hidden layers.
        for _ in range(self.num_layers):
            x = FactorizedDense(self.layer_size, mode=self.mode)(x)
            x = self.activation(x)

        # Output layer, squashed to (0, 1) by the sigmoid.
        x = FactorizedDense(self.out_dim, mode=self.mode)(x)
        return sigmoid(x)


# @title NP Area Resize Code
# from https://gist.github.com/shoyer/c0f1ddf409667650a076c058f9a17276

def _reflect_breaks(size: int) -> np.ndarray:
    """Calculate cell boundaries with reflecting boundary conditions.

    The outer boundaries sit on the centers of the first and last cell
    (``0`` and ``size - 1``); interior boundaries are the half-integers
    between them.

    Args:
      size: number of cells.

    Returns:
      Array of shape ``(size + 1,)`` with the cell boundaries.
    """
    # Wrap the end points in arrays: jax.numpy.concatenate rejects plain
    # Python lists as elements, whereas arrays work with every version.
    result = np.concatenate([
        np.array([0.0]),
        0.5 + np.arange(size - 1),
        np.array([size - 1.0]),
    ])
    assert len(result) == size + 1
    return result


def _interval_overlap(first_breaks: np.ndarray,
                      second_breaks: np.ndarray) -> np.ndarray:
    """Compute the pairwise overlap length between two sets of intervals.

    Consecutive entries of each ``*_breaks`` array delimit one interval.

    Args:
      first_breaks: non-decreasing boundaries of the first set of intervals,
        shape (N+1,).
      second_breaks: non-decreasing boundaries of the second set of
        intervals, shape (M+1,).

    Returns:
      Array of shape (N, M); entry ``[i, j]`` is the length of the
      intersection of interval ``i`` of the first set with interval ``j`` of
      the second set (zero when they are disjoint).
    """
    # A column of first-set edges against a row of second-set edges
    # broadcasts to the full (N, M) grid of pairs.
    lower_edge = np.maximum(first_breaks[:-1, np.newaxis],
                            second_breaks[np.newaxis, :-1])
    upper_edge = np.minimum(first_breaks[1:, np.newaxis],
                            second_breaks[np.newaxis, 1:])

    # Disjoint pairs give a negative difference; clamp those to zero.
    return np.maximum(upper_edge - lower_edge, 0)


def _resize_weights(
        old_size: int, new_size: int, reflect: bool = False) -> np.ndarray:
    """Build the local-mean resampling matrix for one axis.

    Args:
      old_size: number of cells before resizing.
      new_size: number of cells after resizing.
      reflect: if true, use reflecting boundary conditions (outer boundaries
        placed on the centers of the first and last cell).

    Returns:
      Array with shape (new_size, old_size) whose rows sum to 1; row ``i``
      holds the weights of the old cells contributing to new cell ``i``.
    """
    if reflect:
        old_breaks = _reflect_breaks(old_size)
        # Stretch the new grid so its outer boundaries coincide with the old ones.
        new_breaks = (old_size - 1) / (new_size - 1) * _reflect_breaks(new_size)
    else:
        old_breaks = np.linspace(0, old_size, num=old_size + 1)
        new_breaks = np.linspace(0, old_size, num=new_size + 1)

    overlap = _interval_overlap(new_breaks, old_breaks)
    # Normalize each row so every new cell is a weighted mean of old cells.
    weights = overlap / np.sum(overlap, axis=1, keepdims=True)
    assert weights.shape == (new_size, old_size)
    return weights


def resize(array: np.ndarray,
           shape: Sequence[int],
           reflect_axes: Sequence[int] = ()) -> np.ndarray:
    """Resize an array with the local mean / bilinear scaling.

    Works for both upsampling and downsampling in a fashion equivalent to
    block_mean and zoom, but allows for resizing by non-integer multiples. Prefer
    block_mean and zoom when possible, as this implementation is probably slower.

    Args:
      array: array to resize.
      shape: shape of the resized array; must have one entry per axis of
        ``array``.
      reflect_axes: iterable of axis numbers with reflecting boundary conditions,
        mirrored over the center of the first and last cell.

    Returns:
      Array resized to shape.

    Raises:
      ValueError: if ``shape`` does not have ``array.ndim`` entries, or if any
        values in reflect_axes fall outside the interval
        [-array.ndim, array.ndim).
    """
    shape = tuple(shape)
    # zip() below would silently ignore surplus or missing entries and return
    # an array that does not have the requested shape.
    if len(shape) != array.ndim:
        raise ValueError(
            'shape has {} entries but array has {} dimensions'.format(
                len(shape), array.ndim))

    reflect_axes_set = set()
    for axis in reflect_axes:
        if not -array.ndim <= axis < array.ndim:
            raise ValueError('invalid axis: {}'.format(axis))
        # Normalize negative axes so the membership test below works.
        reflect_axes_set.add(axis % array.ndim)

    output = array
    for axis, (old_size, new_size) in enumerate(zip(array.shape, shape)):
        weights = _resize_weights(old_size, new_size,
                                  reflect=axis in reflect_axes_set)
        # Contract the current axis with the weights' old-size axis; the
        # resized axis comes out last and is moved back into place.
        product = np.tensordot(output, weights, [[axis], [-1]])
        output = np.moveaxis(product, -1, axis)
    return output


# Side length, in pixels, of the square images used throughout this file:
# the phantoms and atlas volume are generated at this size and the model
# output is reshaped to (RES, RES).
RES = 512

# Shepp Data Gen
def get_shepp_dataset(rand_key, num_samples):
    """Generate randomly perturbed Shepp-Logan phantoms.

    Each sample adds Gaussian noise (divided by 20) to the modified
    Shepp-Logan parameters, renders a (RES, RES) phantom from them and clips
    the result to [0, 1].

    Args:
      rand_key: PRNG key; it is split once per sample.
      num_samples: number of phantoms to generate.

    Returns:
      Array with the phantoms stacked along axis 0.
    """
    base_params = np.array(ct_modified_shepp_logan_params_2d())

    phantoms = []
    key = rand_key
    for _ in range(num_samples):
        key, noise_key = random.split(key)
        jitter = random.normal(noise_key, shape=base_params.shape) / 20.0
        phantom = ct_shepp_logan((RES, RES), E=base_params + jitter)
        phantoms.append(np.clip(phantom, 0.0, 1.0))

    return np.stack(phantoms, axis=0)


# ATLAS Data Gen
def get_atlas_dataset(rand_key, num_samples):
    """Sample normalized 2-D slices from one volume of the ATLAS archive.

    Loads ``atlas_3d.npz`` from the working directory, rescales its ``data``
    entry by 1/255, resizes the scan at index 1 to (RES, RES, RES), drops 50
    slices at each end of the first axis, and draws ``num_samples`` random
    slices along that axis. Each slice is divided by its own maximum.

    Args:
      rand_key: PRNG key; it is split once per sample to pick the slice index.
      num_samples: number of slices to draw (with replacement).

    Returns:
      Array with the slices stacked along axis 0.
    """
    # The archive is not downloaded automatically. The original notebook
    # fetched it with: gdown --id 1SLejANPHTA_eSJhIjCk9WGeFsKSEMZMx
    filename = 'atlas_3d.npz'
    # Read inside a context manager so the archive's file handle is closed.
    with np.load(filename) as archive:
        data = archive['data'] / 255.0

    scan_id = 1  # fixed scan; random.randint(rand_key, [], 0, data.shape[0]) would pick one at random
    volume = resize(data[scan_id, ...], (RES, RES, RES))
    volume = volume[50:-50, :, :]

    rand_samples = []
    for _ in range(num_samples):
        rand_key, subkey = random.split(rand_key)
        slice_index = random.randint(subkey, [], 0, volume.shape[0])
        i_slice = volume[slice_index, :, :]
        peak = np.amax(i_slice)
        # An all-zero slice would otherwise turn into NaNs; leave it as zeros.
        i_slice = i_slice / np.where(peak != 0, peak, 1.0)
        rand_samples.append(i_slice)

    return np.stack(rand_samples, axis=0)






@jit
def ct_project(img, theta):
    """Project a 2-D image along its first axis after rotating it by ``theta``.

    A grid of coordinates centered on the image is rotated by ``theta``, the
    image is resampled on that grid with order-0 interpolation, and the
    resampled image is averaged over axis 0.

    Args:
      img: 2-D image array.
      theta: rotation angle in radians.

    Returns:
      Array of shape ``(img.shape[1], 1)`` holding the mean of each column of
      the resampled image.
    """
    n_rows, n_cols = img.shape[0], img.shape[1]

    # Normalized pixel coordinates in [-0.5, 0.5), so rotation is about the center.
    row_axis = np.arange(int(n_rows), dtype=np.float32) / int(n_rows) - 0.5
    col_axis = np.arange(int(n_cols), dtype=np.float32) / int(n_cols) - 0.5
    y, x = np.meshgrid(row_axis, col_axis, indexing='ij')

    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    x_rot = x * cos_t - y * sin_t
    y_rot = x * sin_t + y * cos_t

    # Back to pixel units for the resampler.
    x_rot = (x_rot + 0.5) * n_cols
    y_rot = (y_rot + 0.5) * n_rows

    coords = np.stack([y_rot, x_rot], axis=0)
    rotated = ndimage.map_coordinates(img, coords, 0).reshape(img.shape)
    return rotated.mean(axis=0)[:, None, ...]


# Batched projection: the image is shared across the batch (in_axes None)
# while theta is mapped over its leading axis; results are stacked on axis 0.
ct_project_batch = vmap(ct_project, (None, 0), 0)



def make_training_data(dataset, img_id, res=512, num_thetas=20):
    """Build the coordinate grid and CT projections for one dataset image.

    Args:
      dataset: array indexed as ``dataset[img_id, :, :]`` to obtain an image.
      img_id: index of the image to use.
      res: number of grid points per axis for the input coordinates.
      num_thetas: number of projection angles, evenly spaced in [0, pi).

    Returns:
      Tuple ``(x_train, y_train, thetas, img)``: the flattened ``(res*res, 2)``
      coordinate grid on [0, 1), the projections of ``img`` (one per angle),
      the angles, and the image itself.
    """
    # Drop the end point so the grid covers [0, 1) with `res` samples per axis.
    axis_coords = np.linspace(0, 1, res + 1)[:-1]
    x_train = np.stack(np.meshgrid(axis_coords, axis_coords), axis=-1).reshape(-1, 2)

    img = dataset[img_id, :, :]
    # Drop pi itself: the code samples the half-open interval [0, pi).
    thetas = np.linspace(0.0, np.pi, num_thetas + 1)[:-1]
    y_train = ct_project_batch(img, thetas)

    return x_train, y_train, thetas, img



# Maps the activation names accepted by get_activation to the corresponding
# callables.
activation_fn = {
    'relu': nn.relu,
    'gelu': nn.gelu,
    'swish': nn.swish,
    'tanh': np.tanh,
    'sin': np.sin
    }


def get_activation(str):
    """Return the activation callable registered under the given name.

    Args:
      str: key into ``activation_fn`` (e.g. ``'relu'``). The parameter name
        shadows the builtin; it is kept so keyword callers keep working.

    Returns:
      The callable stored in ``activation_fn`` for that name.

    Raises:
      ValueError: if the name is not registered. Previously this case
        returned None and only failed later when the activation was called.
    """
    try:
        return activation_fn[str]
    except KeyError:
        raise ValueError(
            f'Unknown activation {str!r}; expected one of {sorted(activation_fn)}'
        ) from None




def with_warmup(step_size, num_warmup_steps):
    """Build a step-size schedule with linear warmup.

    Args:
      step_size: step size reached at the end of the warmup and held after it.
      num_warmup_steps: number of steps over which the step size grows
        linearly from 0 to ``step_size``. Zero or a negative value disables
        the warmup.

    Returns:
      A function ``schedule(i)`` giving the step size at iteration ``i``.
    """
    if num_warmup_steps <= 0:
        # No warmup requested: avoid the division by zero of the linear ramp
        # and use the target step size from the first iteration on.
        def schedule(i):
            return np.asarray(step_size)

        return schedule

    def schedule(i):
        return np.where(i <= num_warmup_steps,
                        step_size * i / num_warmup_steps,
                        step_size)

    return schedule


def train(key, lr, num_layers, layer_size, activation, embedding, mode, training_data, iters=2000):
    """Fit an Mlp to the CT projections of one image and log PSNR to wandb.

    The network maps coordinates to an image of shape (RES, RES). It is
    trained with Adam on the mean squared error between the projections of
    its output and the target projections; image-space PSNR against the
    ground-truth image is recorded every 10 iterations.

    Args:
      key: PRNG key used to initialize the network parameters.
      lr: target Adam step size, reached after a linear warmup over the first
        tenth of the iterations.
      num_layers: number of hidden layers of the Mlp.
      layer_size: width of the hidden layers.
      activation: activation name, resolved with get_activation.
      embedding: ``None`` or ``(embed_type, embed_scale)``, forwarded to Mlp.
      mode: kernel parameterization, forwarded to Mlp.
      training_data: tuple ``(x, y, thetas, image)`` as returned by
        make_training_data.
      iters: number of optimization steps (defaults to the previously
        hard-coded 2000).

    Returns:
      Dict with keys ``'state'`` (final parameters), ``'train_psnrs'`` and
      ``'test_psnrs'`` (PSNR values recorded every 10 iterations), ``'xs'``
      (the matching iteration numbers) and ``'final_test'`` (model output on
      the training coordinates).
    """
    # NOTE(review): sweep callers already call wandb.init before invoking
    # train; this call is kept so that direct callers still get a run.
    wandb.init(project='ct_sweep')

    # Build the network first so the closures below refer to a bound name.
    arch = Mlp(num_layers, layer_size, 1, get_activation(activation), embedding, mode=mode)
    params = arch.init(key, np.ones((2,)))

    run_model = jit(lambda params, x: np.reshape(arch.apply(params, x), (RES, RES)))
    compute_projs = jit(lambda params, x, thetas: ct_project_batch(run_model(params, x), thetas))
    # Training objective: error in projection space.
    model_loss_proj = jit(lambda params, x, y, thetas: .5 * np.mean((compute_projs(params, x, thetas) - y) ** 2))
    # Evaluation metric: error in image space against the ground truth.
    model_loss = jit(lambda params, x, y, thetas, image: .5 * np.abs(
        np.mean((np.clip(run_model(params, x), 0.0, 1.0) - image) ** 2)))
    model_psnr = jit(lambda params, x, y, thetas, image: -10 * np.log10(2. * model_loss(params, x, y, thetas, image)))
    # Accepts (and ignores) the image so it can be called with *training_data.
    model_grad_loss = jit(lambda params, x, y, thetas, image: grad(model_loss_proj)(params, x, y, thetas))

    # max(..., 1) keeps the warmup length positive for very short runs.
    schedule = with_warmup(lr, num_warmup_steps=max(iters // 10, 1))
    opt_init, opt_update, get_params = optimizers.adam(schedule)
    opt_update = jit(opt_update)
    opt_state = opt_init(params)

    train_psnrs = []
    test_psnrs = []
    xs = []

    pbar = trange(iters)
    for it in pbar:
        grads = model_grad_loss(get_params(opt_state), *training_data)
        opt_state = opt_update(it, grads, opt_state)

        if it % 10 == 0:
            # Train and test PSNR are evaluated on the same data, so a single
            # evaluation serves both entries.
            psnr = model_psnr(get_params(opt_state), *training_data)
            train_psnr = psnr
            test_psnr = psnr
            train_psnrs.append(train_psnr)
            test_psnrs.append(test_psnr)
            xs.append(it)

            pbar.set_postfix({'train_psnr': '{:.3f}'.format(train_psnr),
                              'test_psnr': '{:.3f}'.format(test_psnr)})

            wandb.log({'train_psnr': train_psnr, 'test_psnr': test_psnr}, it)

    final_params = get_params(opt_state)
    return {
        'state': final_params,
        'train_psnrs': train_psnrs,
        'test_psnrs': test_psnrs,
        'xs': xs,
        'final_test': run_model(final_params, training_data[0]),
    }


# Global absl flag registry; values are available after app.run parses argv.
FLAGS = flags.FLAGS

# NOTE(review): main only has branches for 'shepp' and 'atlas'; the default
# 'natural' matches neither. The flag is marked as required in the entry
# guard, so the default is not used when the file runs as a script.
flags.DEFINE_string('dataset_type', 'natural', 'determine the type of the dataset for image regression')


def main(argv):
    """Create a wandb grid sweep over model hyperparameters and run it.

    Loads the dataset selected by ``--dataset_type``, registers a grid sweep
    in the project ``ct_<dataset_type>_sweep`` and starts an agent that trains
    one model per configuration on the first image of the dataset.

    Args:
      argv: remaining command-line arguments passed by absl (unused).

    Raises:
      ValueError: if ``--dataset_type`` is neither ``'shepp'`` nor ``'atlas'``.
    """
    del argv  # Unused; configuration comes from FLAGS.

    dataset_type = FLAGS.dataset_type
    project_name = 'ct_' + dataset_type + '_sweep'

    num_samples = 8
    if dataset_type == 'shepp':
        print('Loading Shepp Dataset')
        dataset = get_shepp_dataset(random.PRNGKey(0), num_samples)
        print('Shepp Dataset Loaded')
    elif dataset_type == 'atlas':
        print('Loading ATLAS Dataset')
        dataset = get_atlas_dataset(random.PRNGKey(0), num_samples)
        print('ATLAS Dataset Loaded')
    else:
        # Fail before a sweep is registered instead of hitting an unbound
        # `dataset` inside the agent.
        raise ValueError(
            "Unsupported dataset_type {!r}; expected 'shepp' or 'atlas'".format(dataset_type))

    # The training data does not depend on the swept hyperparameters, so it
    # is built once and shared by every run.
    training_data = make_training_data(dataset, img_id=0, num_thetas=40)

    sweep_config = {
        'method': 'grid',
        'name': 'sweep',
        'metric': {
            'goal': 'maximize',
            'name': 'test_psnr',
        },
        'parameters': {
            'seed': {'values': [2, 3, 5, 7, 11]},
            'learning_rate': {'values': [1e-3, 1e-4]},
            'num_layers': {'values': [2, 3, 4]},
            'layer_size': {'values': [64, 128, 256]},
            'activation': {'values': ['relu', 'gelu', 'sin']},
            'gaussian_scale': {'values': [1, 5, 10]},
            'mode': {'values': ['None', 'adaptive_activation', 'weight_norm', 'weight_fact']},
        },
    }

    def train_sweep():
        """Train one model with the hyperparameters of the current sweep run."""
        wandb.init(project=project_name)
        run_config = wandb.config

        key = random.PRNGKey(run_config.seed)
        embedding = ('gaussian', run_config.gaussian_scale)

        train(key,
              run_config.learning_rate,
              run_config.num_layers,
              run_config.layer_size,
              run_config.activation,
              embedding,
              run_config.mode,
              training_data)

    sweep_id = wandb.sweep(
        sweep_config,
        project=project_name,
        )

    wandb.agent(sweep_id, function=train_sweep)


if __name__ == "__main__":
    # Make --dataset_type mandatory on the command line.
    flags.mark_flags_as_required(['dataset_type'])
    # Hand control to absl, which is given main as the entry point.
    app.run(main)