import argparse
import json
import os
from collections.abc import Callable
from typing import Any

import distrax
import equinox as eqx
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

from flows import flowee, nn
from flows.distribution_embedding import FlowEmbedding


def mask(ndim: int, i: int) -> jax.Array:
    """Return the uint8 coupling mask for layer index ``i`` over ``ndim`` features.

    Even ``i`` gives ``flowee.checkerboard_mask``; odd ``i`` gives
    ``flowee.create_mask`` called with ``(3,)`` as its second argument
    (NOTE(review): meaning of ``(3,)`` is defined in flowee — confirm).
    Both masks have shape ``(ndim,)`` and dtype ``jnp.uint8``.
    """
    shape = (ndim,)
    if i % 2:
        # Odd layer index.
        return flowee.create_mask(shape, (3,), dtype=jnp.uint8)
    # Even layer index.
    return flowee.checkerboard_mask(shape, dtype=jnp.uint8)


def build_discrete_embedding_net(args: argparse.Namespace, key: jax.Array) -> FlowEmbedding:
    """Build a FlowEmbedding whose flow starts with a uint8 dequantization layer.

    The returned object pairs:
      * an ``eqx.nn.MLP`` embedding net from ``args.ndim`` to ``args.embedding_size``;
      * a ``flowee.Sequential`` flow made of a ``flowee.Dequantize`` layer
        (``in_dtype=jnp.uint8``, ``max_val=args.max_val``), then
        ``args.ncoupling_layers`` ``ParameterizedNLSq`` coupling layers with
        masks from :func:`mask`, then a final ``Sigmoid(1e-5)`` or ``Identity``
        depending on ``args.flow_squash_sigmoid``.

    The prior, of shape ``(args.ndim,)``, is ``Uniform(-1e-4, 1 + 1e-4)`` when
    ``args.uniform_prior`` is set, otherwise a standard ``Normal``.

    Args:
        args: Parsed CLI/config namespace providing the attributes read below.
        key: PRNG key; split deterministically for every sub-network.

    Returns:
        ``FlowEmbedding(emb_net, flow_net)``.
    """
    ndim = args.ndim

    # Fixed architecture constants (not configurable through ``args``).
    dequant_depth = 2
    dequant_params_per_dim = 2
    coupling_params_per_dim = 5

    # --- embedding network ---------------------------------------------------
    key, emb_key = jax.random.split(key, 2)
    emb_net = eqx.nn.MLP(
        ndim,
        args.embedding_size,
        args.embedding_hidden_size,
        args.embedding_nlayers,
        key=emb_key,
    )

    # --- dequantization layer ------------------------------------------------
    key, dequant_key = jax.random.split(key, 2)
    dequant_mask = mask(ndim, 0)
    dequant_conditioner = nn.MultiMLP(
        (ndim, ndim),
        ndim * dequant_params_per_dim,
        args.flow_hidden_size,
        dequant_depth,
        key=dequant_key,
    )
    layers = [
        flowee.Dequantize(
            max_val=args.max_val,
            in_dtype=jnp.uint8,
            var_flow=flowee.Coupling(dequant_mask, dequant_conditioner, dual=True),
        )
    ]

    # --- coupling stack --------------------------------------------------------
    _, *coupling_keys = jax.random.split(key, 1 + args.ncoupling_layers)
    for layer_idx, layer_key in enumerate(coupling_keys):
        layer_mask = mask(ndim, layer_idx)
        # First argument is presumably the conditioner input size (data + embedding) — confirm in nn.MLP.
        conditioner = nn.MLP(
            ndim + args.embedding_size,
            ndim * coupling_params_per_dim,
            args.flow_hidden_size,
            args.flow_nlayers,
            key=layer_key,
        )
        layers.append(
            flowee.Coupling(
                layer_mask,
                conditioner,
                flowee.ParameterizedNLSq((ndim,)),
                dual=True,
            )
        )

    # --- optional output squashing --------------------------------------------
    layers.append(flowee.Sigmoid(1e-5) if args.flow_squash_sigmoid else flowee.Identity())
    flow_net = flowee.Sequential(layers)

    # NOTE(review): the uniform prior is padded by 1e-4 on both sides — presumably
    # so boundary values keep finite log-density; confirm intent.
    if args.uniform_prior:
        prior = distrax.Uniform(-1e-4, 1 + 1e-4)
    else:
        prior = distrax.Normal(0.0, 1.0)
    flow_net.add_prior(prior, (ndim,))

    return FlowEmbedding(emb_net, flow_net)


def build_continuous_embedding_net(args: argparse.Namespace, key: jax.Array) -> FlowEmbedding:
    """Build a FlowEmbedding for continuous-valued data.

    The returned object pairs:
      * an ``eqx.nn.MLP`` embedding net from ``args.ndim`` to ``args.embedding_size``;
      * a ``flowee.Sequential`` flow of ``args.ncoupling_layers``
        ``ParameterizedNLSq`` coupling layers whose masks alternate between a
        checkerboard mask (even layers) and its complement ``1 - mask`` (odd
        layers), with a standard ``Normal`` prior of shape ``(args.ndim,)``.

    Each coupling conditioner is an ``nn.MLP`` producing
    ``args.ndim * flow_num_params`` outputs; its first argument is
    ``args.ndim + args.embedding_size`` (presumably the conditioner input size:
    data plus embedding — confirm in ``nn.MLP``).

    Args:
        args: Parsed CLI/config namespace providing the attributes read below.
        key: PRNG key; split deterministically for every sub-network.

    Returns:
        ``FlowEmbedding(emb_net, flow_net)``.
    """
    flow_num_params = 5
    # Named ``base_mask`` rather than ``mask`` so it does not shadow the
    # module-level ``mask`` helper function.
    base_mask = flowee.checkerboard_mask((args.ndim,))

    key, emb_key = jax.random.split(key, 2)
    emb_net = eqx.nn.MLP(
        args.ndim,
        args.embedding_size,
        args.embedding_hidden_size,
        args.embedding_nlayers,
        key=emb_key,
    )

    # The leading split result is not needed after this point.
    _, *flow_keys = jax.random.split(key, 1 + args.ncoupling_layers)
    flow_net = flowee.Sequential(
        [
            flowee.Coupling(
                base_mask if i % 2 == 0 else 1 - base_mask,
                nn.MLP(
                    args.ndim + args.embedding_size,
                    args.ndim * flow_num_params,
                    args.flow_hidden_size,
                    args.flow_nlayers,
                    key=flow_keys[i],
                ),
                # NOTE: flowee.ParameterizedAffine((args.ndim,)) was previously
                # tried here as an alternative elementwise transform.
                flowee.ParameterizedNLSq((args.ndim,)),
                dual=True,
            )
            for i in range(args.ncoupling_layers)
        ]
    )
    flow_net.add_prior(distrax.Normal(0.0, 1.0), (args.ndim,))

    return FlowEmbedding(emb_net, flow_net)


def save_ckpt(model: eqx.Module, args: dict[str, Any], step: int) -> str:
    """Serialise ``model`` to ``<model_dir>/<run_name>/<seed>/<step>.eqx``.

    File layout: one line holding ``args`` as JSON, followed by the equinox
    leaf serialisation of ``model``. This is the format ``load_ckpt`` reads.

    ``step`` is zero-padded to the number of digits in ``args['nsteps']``, so
    that, within one run, lexicographic file order matches step order.

    Args:
        model: Equinox module to serialise.
        args: Run configuration. It must contain ``model_dir``, ``run_name``,
            ``seed`` and ``nsteps``, and must be JSON-serialisable.
        step: Training step used to name the file.

    Returns:
        Path of the written checkpoint file.
    """
    ckpt_dir = os.path.join(args['model_dir'], args['run_name'], str(args['seed']))
    # Create the run directory on first save; without this, open() raises
    # FileNotFoundError for a fresh run/seed.
    os.makedirs(ckpt_dir, exist_ok=True)

    width = len(str(args['nsteps']))
    ckpt_name = os.path.join(ckpt_dir, f'{step:0{width}}.eqx')

    with open(ckpt_name, 'wb') as f:
        # json.dumps without ``indent`` never emits a raw newline, so the header
        # can be recovered with a single readline().
        f.write((json.dumps(args) + '\n').encode())
        eqx.tree_serialise_leaves(f, model)

    return ckpt_name


def load_ckpt(
    ckpt_path: str, model_builder: Callable, key: jax.Array
) -> tuple[eqx.Module, argparse.Namespace]:
    """Load a checkpoint in the format written by ``save_ckpt``.

    The first line of the file is parsed as JSON into an ``argparse.Namespace``.
    ``model_builder(namespace, key)`` then builds a template model. The
    remainder of the file is deserialised into that template with
    ``eqx.tree_deserialise_leaves``.

    Args:
        ckpt_path: Path to the ``.eqx`` checkpoint file.
        model_builder: Callable ``(args, key) -> eqx.Module`` that rebuilds the
            model architecture (e.g. one of the ``build_*_embedding_net`` functions).
        key: PRNG key passed to ``model_builder``.

    Returns:
        ``(model, args)``: the restored model and the stored run configuration.
    """
    with open(ckpt_path, 'rb') as fh:
        header = fh.readline().decode()
        run_args = argparse.Namespace(**json.loads(header))
        # tree_deserialise_leaves needs a like-structured template to fill in.
        template = model_builder(run_args, key)
        restored = eqx.tree_deserialise_leaves(fh, template)
    return restored, run_args


def select_last_ckpt(directory: str) -> str:
    """Return the path of the latest ``.eqx`` checkpoint in ``directory``.

    If every checkpoint stem is a decimal step number (as written by
    ``save_ckpt``), files are compared numerically. This stays correct even
    when files with different zero-pad widths share a directory (e.g.
    ``1000.eqx`` is later than ``999.eqx``). Otherwise this falls back to
    plain lexicographic order on the file names.

    Args:
        directory: Directory to search (non-recursive).

    Returns:
        ``os.path.join(directory, <latest file name>)``.

    Raises:
        IndexError: If ``directory`` contains no ``.eqx`` files. This is the
            same exception type as before, now with a descriptive message.
    """
    suffix = '.eqx'
    files = [f for f in os.listdir(directory) if f.endswith(suffix)]
    if not files:
        raise IndexError(f'no {suffix} checkpoints found in {directory!r}')

    def _stem(name: str) -> str:
        # Strip the '.eqx' suffix.
        return name[: -len(suffix)]

    if all(_stem(f).isdecimal() for f in files):
        # Tie-break on the name so that equal step values (e.g. '0100' vs '100') are deterministic.
        last = max(files, key=lambda f: (int(_stem(f)), f))
    else:
        last = max(files)
    return os.path.join(directory, last)


def plot_legend(directory: str, ncol: int = 7) -> None:
    """Save the legend of the current axes as a standalone ``legend.pdf``.

    The labelled artists of ``plt.gca()`` are redrawn as a legend in a new
    helper figure. That figure is resized to the legend's bounding box, written
    to ``<directory>/legend.pdf``, and then closed.

    Args:
        directory: Existing directory to write ``legend.pdf`` into.
        ncol: Number of legend columns. The default of 7 is the previously
            hard-coded value.
    """
    handles, labels = plt.gca().get_legend_handles_labels()

    fig_legend = plt.figure(figsize=(3, 2))
    try:
        ax_legend = fig_legend.add_subplot(111)

        legend = ax_legend.legend(
            handles, labels, loc='center', ncol=ncol, fancybox=True, shadow=False
        )
        ax_legend.axis('off')

        # The legend's window extent is only known after a draw.
        fig_legend.canvas.draw()
        bbox = legend.get_window_extent().transformed(fig_legend.dpi_scale_trans.inverted())
        fig_legend.set_size_inches(bbox.width, bbox.height)

        fig_legend.savefig(os.path.join(directory, 'legend.pdf'), bbox_inches='tight')
    finally:
        # Without this, the helper figure leaks and stays pyplot's current
        # figure, so later plt.gca()/plt.savefig() calls would target it.
        plt.close(fig_legend)
