"""Utilities related to the networks I'm using."""
import dataclasses
import h5py
import json
import os

import numpy as np
import tensorflow as tf

from xoid.util import basics
from xoid.util import hdf5_util

# Short local aliases for xoid.util helpers used by this module.
_np = basics.to_np
set_h5_ds = hdf5_util.set_h5_ds


@dataclasses.dataclass()
class ModelParams:
    """Parameters of a shallow (one-hidden-layer) ReLU network.

    The network computes ``relu(x @ w + b) @ v + c`` for inputs ``x`` with
    ``d`` features and ``m`` hidden units (this is the model built by
    `to_keras_model`). ``w`` is always laid out as ``(d, m)``; the shapes of
    the other arrays depend on which format the instance is in:

    * keras format  (`to_keras_format`):  b (m,),   v (m, 1), c (1,)
    * walker format (`to_walker_format`): b (1, m), v (m,),   c ()

    NOTE(review): the dataclass-generated ``__eq__`` compares the numpy
    fields with ``==``, so comparing two instances that hold distinct
    multi-element arrays raises ValueError instead of returning a bool.
    """
    w: np.ndarray  # Hidden-layer weights, shape (d, m).
    b: np.ndarray  # Hidden-layer biases (m values; shape depends on format).
    v: np.ndarray  # Output weights (m values; shape depends on format).
    c: np.ndarray  # Output bias (shape depends on format).

    @property
    def d(self) -> int:
        """Dimension of the input (first axis of ``w``)."""
        d, _ = self.w.shape
        return d

    @property
    def m(self) -> int:
        """Number of hidden units (second axis of ``w``)."""
        _, m = self.w.shape
        return m

    def astype(self, dtype):
        """Return a copy with every parameter array cast to ``dtype``."""
        # np.array copies by default, so the result never aliases self.
        return self.__class__(
            w=np.array(self.w, dtype=dtype),
            b=np.array(self.b, dtype=dtype),
            v=np.array(self.v, dtype=dtype),
            c=np.array(self.c, dtype=dtype),
        )

    def to_keras_format(self):
        """Return a copy reshaped to keras format: w (d, m), b (m,), v (m, 1), c (1,)."""
        d, m = self.d, self.m
        return self.__class__(
            w=np.copy(self.w).reshape([d, m]),
            b=np.copy(self.b).reshape([m]),
            v=np.copy(self.v).reshape([m, 1]),
            c=np.copy(self.c).reshape([1]),
        )

    def to_walker_format(self, *, normalize_v=False):
        """Return a copy in walker format: w (d, m), b (1, m), v (m,), c ().

        Args:
          normalize_v: If True, rescale the parameters so every output weight
            is +1 or -1 while the network computes the same function. Because
            ReLU is positively homogeneous,
            ``v_i * relu(z_i) == sign(v_i) * relu(|v_i| * z_i)``, so ``|v_i|``
            is folded into column i of ``w`` and entry i of ``b``. Output
            weights equal to 0 become +1; their hidden unit is zeroed out, so
            it still contributes nothing.
        """
        d, m = self.d, self.m
        w = np.copy(self.w).reshape([d, m])
        b = np.copy(self.b).reshape([1, m])
        v = np.copy(self.v).reshape([m])
        c = np.copy(self.c).reshape([])

        if normalize_v:
            # Scale by |v|, not v: multiplying a pre-activation by a negative
            # number flips which side of the ReLU is active and would change
            # the network's output. In-place ops keep the dtypes of w and b.
            scale = np.abs(v)
            w *= scale
            b *= scale
            v = np.sign(v)
            # Default 0 to +1.
            v[v == 0.0] = 1.0

        return self.__class__(w=w, b=b, v=v, c=c)

    def to_keras_model(self, activation='relu', dtype=np.float32, kernel_regularizer=None):
        """Build a ``tf.keras.Sequential`` [Dense(m), Dense(1)] holding these parameters.

        Args:
          activation: Activation of the hidden Dense layer.
          dtype: A numpy dtype; the parameters are cast to it before being
            assigned to the layer variables.
          kernel_regularizer: Regularizer passed to both Dense layers.

        Returns:
          The built model. It is built by calling it on a float32 symbolic
          input, independent of ``dtype``.
        """
        p = self.to_keras_format().astype(dtype)
        model = tf.keras.Sequential([
            tf.keras.layers.Dense(p.m, kernel_regularizer=kernel_regularizer, activation=activation),
            tf.keras.layers.Dense(1, kernel_regularizer=kernel_regularizer),
        ])
        # Calling the model once creates the layer variables so they can be assigned.
        model(tf.keras.Input([p.d], dtype=tf.float32))
        l1, l2 = model.layers

        l1.kernel.assign(p.w)
        l1.bias.assign(p.b)
        l2.kernel.assign(p.v)
        l2.bias.assign(p.c)

        return model

    @classmethod
    def from_keras(cls, model: tf.keras.Sequential):
        """Extract parameters from a two-layer Sequential model.

        The arrays are returned exactly as Keras stores them (kernels are
        ``(inputs, units)``); they are not reshaped to either format.
        """
        l1, l2 = model.layers

        # TODO: Should I reshape these to match my conventions?
        w = l1.kernel.numpy()
        b = l1.bias.numpy()
        v = l2.kernel.numpy()
        c = l2.bias.numpy()

        return cls(w=w, b=b, v=v, c=c)


def network_from_parameters(w, b, v, c) -> tf.keras.Sequential:
    """Build a one-hidden-layer ReLU Keras model from raw parameters.

    Args:
      w: Hidden-layer kernel; must already have shape (d, m).
      b: Hidden-layer bias; reshaped to (m,).
      v: Output kernel, either (m,) for a single output or (m, o).
      c: Output bias; reshaped to (o,).

    Returns:
      A ``tf.keras.Sequential`` of [Dense(m, relu), Dense(o)] whose variables
      hold the given parameters cast to float32.
    """
    in_dim, hidden = w.shape
    if len(v.shape) != 1:
        _, out_dim = v.shape
    else:
        assert v.shape == (hidden,)
        out_dim = 1

    # w is used as given; the remaining parameters are reshaped explicitly.
    b = tf.reshape(b, [hidden])
    v = tf.reshape(v, [hidden, out_dim])
    c = tf.reshape(c, [out_dim])

    hidden_layer = tf.keras.layers.Dense(hidden, activation='relu')
    output_layer = tf.keras.layers.Dense(out_dim, activation=None)
    model = tf.keras.Sequential([hidden_layer, output_layer])
    # Calling the model on a symbolic input creates the layer variables.
    model(tf.keras.Input([in_dim], dtype=tf.float32))

    targets = (
        (hidden_layer.kernel, w),
        (hidden_layer.bias, b),
        (output_layer.kernel, v),
        (output_layer.bias, c),
    )
    for variable, value in targets:
        variable.assign(tf.cast(value, tf.float32))

    return model


def save_to_hdf5(filepath, metadata, params):
    """Write ``metadata`` and the w/b/v/c arrays of ``params`` to an HDF5 file.

    File layout:
      * ``metadata``: scalar variable-length string dataset holding
        ``json.dumps(metadata)``.
      * ``weights/{w,b,v,c}``: one dataset per parameter, with the array's
        own shape and dtype.

    The file is opened with mode "w", so an existing file is overwritten.
    """
    path = os.path.expanduser(filepath)

    # Convert every parameter before the output file is opened.
    arrays = {name: _np(getattr(params, name)) for name in ('w', 'b', 'v', 'c')}

    with h5py.File(path, "w") as f:
        string_dtype = h5py.special_dtype(vlen=str)
        f.create_dataset('metadata', data=json.dumps(metadata), dtype=string_dtype)

        weights = f.create_group('weights')
        for name, array in arrays.items():
            dataset = weights.create_dataset(name, array.shape, dtype=array.dtype)
            set_h5_ds(dataset, array)


def load_from_hdf5(filepath, dtype=None):
    """Read metadata and network parameters written by ``save_to_hdf5``.

    Args:
      filepath: Path to the HDF5 file; ``~`` is expanded.
      dtype: Optional numpy dtype. If not None, all four parameter arrays are
        cast to it.

    Returns:
      A ``(metadata, params)`` tuple: the JSON-decoded metadata and a
      ``ModelParams`` holding the arrays from the ``weights`` group.
    """
    filepath = os.path.expanduser(filepath)

    with h5py.File(filepath, "r") as f:
        # Read the scalar string value itself. Depending on the h5py version,
        # variable-length strings come back as str or as bytes.
        raw = f['metadata'][()]
        weights = {name: np.array(f['weights/' + name]) for name in ('w', 'b', 'v', 'c')}

    if isinstance(raw, np.ndarray):
        raw = raw.item()
    if isinstance(raw, bytes):
        # Decode instead of str(): str(bytes) yields a repr whose escapes
        # (\\, \", \', \uXXXX) would corrupt the JSON text.
        raw = raw.decode('utf-8')
    metadata = str(raw)
    # Legacy workaround ("annoying bug on longleaf"): strip a literal b'...'
    # wrapper if one is still present. Valid JSON never starts with "b'", so
    # this cannot alter well-formed metadata.
    if metadata.startswith("b'") and metadata.endswith("'"):
        metadata = metadata[2:-1]
    metadata = json.loads(metadata)

    if dtype is not None:
        weights = {name: array.astype(dtype) for name, array in weights.items()}

    params = ModelParams(**weights)

    return metadata, params


class BiasOnlyLayer(tf.keras.layers.Layer):
    """Keras layer that adds a trainable bias vector to its input.

    The bias has one entry per feature on the last input axis and is
    initialized to zeros.
    """

    def build(self, shape):
        # One bias entry per feature on the last axis.
        num_features = shape[-1]
        initial_value = tf.zeros([num_features])
        self.bias = tf.Variable(initial_value, dtype=tf.float32)

    def call(self, x):
        # The bias broadcasts across all leading axes of x.
        shifted = x + self.bias
        return shifted
