#!/usr/bin/env python3

import hydra
from hydra.utils import instantiate

import numpy as np

import time
import csv
import os
import pickle as pkl

import ot
import jax
import jax.numpy as jnp

from flax import linen as nn
from flax.training import train_state
import optax

from ott.core import quad_problems, linear_problems, sinkhorn
from ott.geometry import PointCloud

import matplotlib.pyplot as plt
plt.style.use('bmh')

import functools

from jax.config import config
config.update("jax_enable_x64", True)

from meta_ot import utils
from meta_ot import data


import sys
from IPython.core import ultratb
# Replace the default excepthook with IPython's FormattedTB handler using the
# 'Plain' mode and 'Neutral' color scheme; call_pdb=1 asks it to start pdb
# when an uncaught exception reaches the top level.
sys.excepthook = ultratb.FormattedTB(
    mode='Plain', color_scheme='Neutral', call_pdb=1)

# Directory containing this file, computed from its real (resolved) path.
DIR = os.path.dirname(os.path.realpath(__file__))

class Workspace(object):
    """Training state and routines for a model predicting OT dual potentials.

    The workspace holds the config, the potential model and its parameters,
    the PRNG key and the iteration counter.  ``run`` trains the model by
    minimizing the negated dual objective (``dual_obj_loss``) averaged over
    batches of measure pairs ``(a, b)`` drawn from the sampler selected by
    ``cfg.data``.

    Attributes set in ``__init__``: ``work_dir``, ``cfg``, ``train_iter``,
    ``output``, ``n_output``, ``key``, ``potential_model``, ``params``.
    ``geom`` is only assigned once ``run`` has created the training sampler;
    every method that reads ``self.geom`` requires it to be set.
    """

    def __init__(self, cfg):
        """Build the potential model and initialize its parameters.

        Args:
            cfg: config object; this method reads ``cfg.data``, ``cfg.seed``
                and ``cfg.potential_model``.

        Raises:
            ValueError: if ``cfg.data`` is not one of the supported settings.
        """
        self.work_dir = os.getcwd()
        print(f'workspace: {self.work_dir}')

        self.cfg = cfg
        self.train_iter = 0
        self.output = 0

        # Sizes of the two measures for each supported data setting.
        if self.cfg.data in ['mnist', 'random', 'doodles', 'usps28']:
            na = nb = 784
        elif self.cfg.data == 'world':
            na = 100
            nb = 10000
        else:
            # Raise explicitly: an `assert` would be stripped under `-O`.
            raise ValueError(f'unsupported cfg.data: {self.cfg.data!r}')
        # The model outputs one value per entry of `a`.
        self.n_output = na

        self.key = jax.random.PRNGKey(self.cfg.seed)
        self.potential_model = instantiate(
            self.cfg.potential_model, n_output=self.n_output)
        init_key, self.key = jax.random.split(self.key)
        # Zero vectors of the right sizes are enough to initialize the model.
        a_placeholder = jnp.zeros(na)
        b_placeholder = jnp.zeros(nb)
        self.params = self.potential_model.init(
            init_key, a_placeholder, b_placeholder)['params']

    def run(self):
        """Train until ``train_iter`` reaches ``cfg.num_train_iter``.

        Opens the CSV log, runs the training loop and always closes the log
        file afterwards, including when training raises.
        """
        logf, writer = self._init_logging()
        try:
            self._train(logf, writer)
        finally:
            logf.close()

    def _make_train_sampler(self):
        """Create the training pair sampler selected by ``cfg.data``.

        Raises:
            ValueError: if ``cfg.data`` is not one of the supported settings.
        """
        name = self.cfg.data
        batch_size = self.cfg.batch_size
        if name == 'mnist':
            return data.MNISTPairSampler(
                train=True, batch_size=batch_size, debug=False)
        if name == 'world':
            return data.WorldPairSampler(batch_size=batch_size, debug=False)
        if name == 'usps28':
            return data.USPSPairSampler(
                train=True, batch_size=batch_size, debug=False, reshape=True)
        if name == 'doodles':
            return data.DoodlePairSampler(
                train=True, batch_size=batch_size, debug=False)
        if name == 'random':
            return data.RandomSampler(batch_size=batch_size, debug=False)
        raise ValueError(f'unsupported cfg.data: {name!r}')

    def _train(self, logf, writer):
        """Run the optimization loop, logging and checkpointing periodically.

        Every 100 iterations this copies the current parameters to
        ``self.params``, evaluates the prediction error on a fresh batch,
        prints a status line, appends a CSV row and pickles the workspace.
        After the loop the final parameters are stored and saved once more.

        Args:
            logf: open log file, flushed after each CSV row.
            writer: ``csv.DictWriter`` bound to ``logf``.
        """
        opt = instantiate(self.cfg.optim)
        if self.cfg.max_grad_norm:
            # Clip the global gradient norm before the optimizer update.
            opt = optax.chain(
                optax.clip_by_global_norm(self.cfg.max_grad_norm), opt)
        state = train_state.TrainState.create(
            apply_fn=self.potential_model.apply, params=self.params, tx=opt)

        train_sampler = self._make_train_sampler()
        self.geom = train_sampler.geom

        def loss_batch(params, batch):
            # Mean of the per-pair loss over the batch.
            loss_fn = functools.partial(self.dual_obj_loss, params=params)
            loss = jax.vmap(loss_fn)(a=batch.a, b=batch.b)
            return jnp.mean(loss)

        @jax.jit
        def update(state, key):
            # Sample a batch and take one gradient step on it.
            batch = train_sampler(key)
            grad_fn = jax.value_and_grad(loss_batch)
            loss, grads = grad_fn(state.params, batch)
            return loss, state.apply_gradients(grads=grads)

        # Batch over (a, b); the parameters are shared across the batch.
        pred_error_vmap = jax.vmap(self.pred_error, in_axes=(0, 0, None))

        @jax.jit
        def compute_train_error(params, key):
            batch = train_sampler(key)
            err = pred_error_vmap(batch.a, batch.b, params)
            return jnp.mean(err)

        start_time = time.time()
        loss_meter = utils.RunningAverageMeter()
        while self.train_iter < self.cfg.num_train_iter:
            step_key, self.key = jax.random.split(self.key)
            loss, state = update(state, step_key)
            loss_meter.update(loss.item())
            if self.train_iter % 100 == 0:
                self.params = state.params
                eval_key, self.key = jax.random.split(self.key)
                train_err = compute_train_error(state.params, eval_key)
                # The status line shows the meter's `val`, while the CSV row
                # below records its `avg`.
                print(f'[{self.train_iter}] train_loss={loss_meter.val:.2e} train_err={train_err:.2e}')
                writer.writerow({
                    'iter': self.train_iter,
                    'train_loss': loss_meter.avg,
                    'train_err': train_err,
                    'time': time.time()-start_time,
                })
                logf.flush()
                self.save()
            self.train_iter += 1

        # Keep the updates made since the last periodic checkpoint: without
        # this, the final parameters are dropped whenever the last iteration
        # index is not a multiple of 100.
        self.params = state.params
        self.save()

    def pred_error(self, a, b, params):
        """Return the marginal error on ``b`` of the predicted potentials.

        Args:
            a: first measure of the pair.
            b: second measure of the pair.
            params: parameters passed to the potential model.

        Returns:
            The value of ``sinkhorn.marginal_error`` against target ``b``
            along ``axis=0``, with ``norm_error=[1]`` and ``lse_mode=True``.
        """
        f_pred = self.potential_model.apply(
            {'params': params}, a, b)
        # Apply one sinkhorn iteration updating both potentials
        # so the error reported here is the same as the first
        # error reported from OTT after the first iteration.
        g_pred = self.g_from_f(a, b, f_pred)
        f_pred = self.geom.update_potential(
            f_pred, g_pred,
            jnp.log(a), 0, axis=1)
        # Marginal error of the a measure is ~0, no need to add
        err = sinkhorn.marginal_error(
            f_pred, g_pred, target=b,
            geom=self.geom, axis=0,
            norm_error=[1],
            lse_mode=True
        )
        return err

    def g_from_f(self, a, b, f):
        """Compute the potential ``g`` from ``f``.

        Makes a single ``geom.update_potential`` call along ``axis=0``
        against ``log(b)``, starting from a zero ``g``.  ``a`` is accepted
        for signature symmetry but not used.
        """
        g = self.geom.update_potential(
            f, jnp.zeros_like(b),
            jnp.log(b), 0, axis=0)
        return g

    def pred_transport(self, a, b):
        """Return the transport computed from the predicted potentials.

        Uses the parameters currently stored in ``self.params``.
        """
        f_pred = self.potential_model.apply({'params': self.params}, a, b)
        g_pred = self.g_from_f(a, b, f_pred)
        P = self.geom.transport_from_potentials(f_pred, g_pred)
        return P

    def dual_obj_from_f(self, a, b, f):
        """Evaluate the dual objective at ``f``, with ``g`` derived from it.

        Args:
            a: first measure of the pair.
            b: second measure of the pair.
            f: potential associated with ``a``.

        Returns:
            ``div_a + div_b + epsilon * (sum(a) * sum(b) - total_sum)``, where
            ``total_sum`` is the sum of ``geom.marginal_from_potentials(f, g)``.
        """
        g = self.g_from_f(a, b, f)
        # Replace non-finite entries of g by zero.
        g = jnp.where(jnp.isfinite(g), g, 0.)

        # Only entries with positive mass contribute to the sums below.
        supp_a = a > 0
        supp_b = b > 0
        fa = supp_a * self.geom.potential_from_scaling(a)
        div_a = jnp.sum(jnp.where(supp_a, a * (f - fa), 0.0))

        gb = supp_b * self.geom.potential_from_scaling(b)
        div_b = jnp.sum(jnp.where(supp_b, b * (g - gb), 0.0))

        total_sum = jnp.sum(self.geom.marginal_from_potentials(f, g))
        dual_obj = div_a + div_b + self.geom.epsilon * (
            jnp.sum(a) * jnp.sum(b) - total_sum
        )

        return dual_obj

    def dual_obj_loss(self, a, b, params):
        """Return the negated dual objective of the model's prediction."""
        f_pred = self.potential_model.apply({'params': params}, a, b)
        dual_value = self.dual_obj_from_f(a, b, f_pred)
        loss = -dual_value
        return loss

    def save(self, tag='latest'):
        """Pickle the whole workspace to ``<work_dir>/<tag>.pkl``."""
        path = os.path.join(self.work_dir, f'{tag}.pkl')
        with open(path, 'wb') as f:
            pkl.dump(self, f)

    def _init_logging(self):
        """Open ``log.csv`` in the current directory for appending.

        The header row is written only when the file is empty, so repeated
        runs in the same directory keep appending to one log.

        Returns:
            ``(logf, writer)``: the open file and a ``csv.DictWriter`` on it.
            The caller is responsible for closing ``logf``.
        """
        # newline='' is what the csv module asks for on files it writes to.
        logf = open('log.csv', 'a', newline='')
        fieldnames = ['iter', 'time', 'train_loss', 'train_err']
        writer = csv.DictWriter(logf, fieldnames=fieldnames)
        if os.stat('log.csv').st_size == 0:
            writer.writeheader()
            logf.flush()
        return logf, writer


@hydra.main(config_path='conf', config_name='train_discrete')
def main(cfg):
    """Hydra entry point: build a ``Workspace`` from ``cfg`` and train it.

    Args:
        cfg: config composed by hydra from ``conf/train_discrete``.
    """
    # Import the class through the `train_discrete` module instead of using
    # the definition in the running script; presumably so pickles reference
    # `train_discrete.Workspace` rather than `__main__.Workspace`.
    from train_discrete import Workspace as W # For pickling
    # TODO: Fix resuming
    # fname = os.getcwd() + '/latest.pkl'
    # if os.path.exists(fname):
    #     print(f'Resuming from {fname}')
    #     with open(fname, 'rb') as f:
    #         workspace = pkl.load(f)
    # else:
    #     workspace = W(cfg)
    workspace = W(cfg)

    workspace.run()


# Start training only when this file is executed as a script.
if __name__ == '__main__':
    main()
