import pandas as pd

from time import time
from typing import Literal, Union
from flax.training.train_state import TrainState
from jax._src.random import PRNGKey
from jax.experimental.shard_map import shard_map
from omegaconf import DictConfig
from functools import partial
import numpy as np
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P

from fair_dp_sgd.utils.jax_utils import get_inference_function
from fair_dp_sgd.training.loss import get_loss
from fair_dp_sgd.algorithm import get_update_rule, get_algorithm_artifacts
from jax import jit
import jax


def _block_until_ready(tree):
    """Wait until every JAX array leaf of ``tree`` has finished computing.

    Leaves without a ``block_until_ready`` method (e.g. Python scalars) are
    returned unchanged, so arbitrary pytrees of step outputs can be passed.
    """
    def _block(leaf):
        block = getattr(leaf, "block_until_ready", None)
        return leaf if block is None else block()

    return jax.tree_util.tree_map(_block, tree)


def train_and_evaluate(
        cfg: DictConfig,
        state: TrainState,
        train_stream,
        test_data,
        val_data,
        rng: PRNGKey,
        return_: Literal["state", "metrics"] = "metrics"
) -> "list[float]":
    """Run ``cfg.training_params.number_of_steps`` sharded update steps and time each one.

    The update rule selected by ``cfg`` is wrapped in ``shard_map`` over a
    one-dimensional ``"batch"`` mesh spanning all visible devices and then
    jit-compiled. Each step pulls one ``(batch, number_of_samples)`` pair from
    ``train_stream``, starting a fresh pass when the stream is exhausted.

    Args:
        cfg: Experiment configuration; ``cfg.training_params.number_of_steps``
            sets how many update steps are run.
        state: Initial Flax ``TrainState``.
        train_stream: Zero-argument callable returning an iterable of
            ``(batch, number_of_samples)`` pairs. It is called again whenever
            the previous iterator is exhausted.
        test_data: Currently unused.
        val_data: Currently unused.
        rng: PRNG key; split once for the algorithm artifacts and once per step.
        return_: Currently ignored -- the step times are always returned.
            NOTE(review): the name suggests it should choose between the final
            state and a metrics DataFrame; confirm the intended behaviour.

    Returns:
        Elapsed time in seconds of each update step, in step order. The first
        entry includes jit compilation of the sharded update rule.
    """
    # Local import keeps the file-level imports untouched; perf_counter is a
    # monotonic, high-resolution clock, unlike wall-clock time.time().
    from time import perf_counter

    base_update_rule = get_update_rule(cfg)
    per_sample_loss = get_loss(cfg, state)
    per_sample_training_loss = jax.jit(partial(per_sample_loss, train=True))
    inference_function = get_inference_function(state)

    key, rng = jax.random.split(rng)
    metadata = (cfg, per_sample_training_loss, inference_function)
    batches = iter(train_stream())
    artifacts = get_algorithm_artifacts(cfg, key)

    # 1-D mesh over all devices: the batch is split along its leading axis,
    # every other argument/output is replicated.
    # NOTE(review): presumably the batch's leading dimension must be divisible
    # by jax.device_count() -- confirm against the data pipeline.
    mesh = Mesh(np.array(jax.devices()), ("batch",))
    # Argument order: (state, batch, rng, step, artifacts, number_of_samples).
    in_specs = (P(), P("batch"), P(), P(), P(), P())
    # Output order: (state, train_metadata, artifacts).
    out_specs = (P(), P("batch"), P())

    update_step = jit(
        shard_map(partial(base_update_rule, metadata), mesh, in_specs, out_specs)
    )

    times = []
    for step in range(1, cfg.training_params.number_of_steps + 1):
        try:
            batch, number_of_samples = next(batches)
        except StopIteration:
            # Stream exhausted: begin a new pass over the training data.
            batches = iter(train_stream())
            batch, number_of_samples = next(batches)

        start = perf_counter()
        next_rng, rng = jax.random.split(rng, 2)
        outputs = update_step(
            state, batch, next_rng, step, artifacts, number_of_samples
        )
        # JAX dispatches asynchronously; wait on *all* outputs so the timing
        # covers the whole step, without assuming any particular parameter
        # tree layout (e.g. the presence of a "Dense_0" layer).
        _block_until_ready(outputs)
        times.append(perf_counter() - start)

        state, _, artifacts = outputs
    return times