import argparse
from pathlib import Path

import numpy as np
import tensorflow as tf
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.axes as axes
from typing import List
from losses import LossMode, DataMode, get_base_loss
import json
import hashlib

from train import ExperimentData, run_experiment


def flatten_dict(data):
    """Return a copy of ``data`` with one level of dict nesting flattened.

    A value that is itself a dict is replaced by one entry per inner item,
    stored under the key ``"<outer>/<inner>"``. Every other value is copied
    unchanged. Only the first level is expanded: a dict nested deeper is kept
    as a value. ``data`` itself is not modified.
    """
    flat = {}
    for outer_key, value in data.items():
        if not isinstance(value, dict):
            flat[outer_key] = value
            continue
        for inner_key, inner_value in value.items():
            flat[outer_key + "/" + inner_key] = inner_value
    return flat


def run_and_cache_experiment(cache: dict, data, train_mode: str, loss_mode: str, normalized: bool, base_loss: str,
                             l2_reg: float, index: int, num_epochs: int, pretraining: int, learning_rate: float):
    """Run one experiment unless its configuration is already in ``cache``.

    The cache key is the SHA-256 hex digest of the JSON-serialized
    configuration (train mode, loss mode, base loss, normalization flag,
    L2 regularization, run index and learning rate). ``num_epochs`` and
    ``pretraining`` are passed on to the experiment but are NOT part of the
    key, so changing only those reuses an existing entry.

    On a cache miss, the experiment result gets the configuration attached
    under ``"config"``, is flattened with ``flatten_dict`` and stored in
    ``cache`` (mutated in place). Always returns ``None``.
    """
    # Keyword order fixes the JSON key order and therefore the digest.
    config = dict(
        data=train_mode,
        mode=loss_mode,
        loss=base_loss,
        normalized=normalized,
        l2_reg=l2_reg,
        index=index,
        lr=learning_rate,
    )

    serialized = json.dumps(config)
    identifier = hashlib.sha256(serialized.encode()).hexdigest()

    if identifier not in cache:
        results = run_experiment(
            data=data,
            train_mode=train_mode,
            loss_mode=loss_mode,
            base_loss=base_loss,
            normalized=normalized,
            l2_reg=l2_reg,
            num_epochs=num_epochs,
            pretraining=pretraining,
            learning_rate=learning_rate,
            sparse=False,
            batch_size=64,
            check_k=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10),
            hidden_layer=-1,
            shuffle_buffer=1000000,
        )
        results["config"] = config
        cache[identifier] = flatten_dict(results)


def make_menon_fake(num_samples):
    """Generate a synthetic two-cluster, 10-label dataset.

    The first ``num_samples // 2`` points are drawn from a unit-variance
    normal centered at (1, 1). Each gets, with equal probability, either
    labels 0-3 set or only label 4 set. The second ``num_samples // 2``
    points are centered at (-1, -1) and get either labels 6-9 or only
    label 5.

    Returns:
        ``(x, y)`` where ``x`` has shape ``(2 * (num_samples // 2), 2)`` and
        ``y`` has shape ``(2 * (num_samples // 2), 10)`` with float 0/1
        entries. For odd ``num_samples`` one sample is dropped.
    """
    half = num_samples // 2

    # Float patterns keep the labels float64, like the zero-initialized
    # buffers they replace.
    first_many = np.array([1, 1, 1, 1, 0, 0, 0, 0, 0, 0], dtype=float)
    first_single = np.array([0, 0, 0, 0, 1, 0, 0, 0, 0, 0], dtype=float)
    second_many = np.array([0, 0, 0, 0, 0, 0, 1, 1, 1, 1], dtype=float)
    second_single = np.array([0, 0, 0, 0, 0, 1, 0, 0, 0, 0], dtype=float)

    # The order and sizes of the random draws are kept as they were (the
    # selectors draw num_samples values although only the first half is
    # used) so that results under a fixed NumPy seed do not change.
    x_1 = np.random.normal((1, 1), 1.0, (half, 2))
    select = np.random.randint(0, 2, num_samples)
    y_1 = np.where(select[:half, None] == 0, first_many, first_single)

    x_2 = np.random.normal((-1, -1), 1.0, (half, 2))
    select = np.random.randint(0, 2, num_samples)
    # Bug fix: this half used to be allocated with num_samples rows, which
    # left y with num_samples // 2 extra all-zero rows and more rows than x.
    y_2 = np.where(select[:half, None] == 0, second_many, second_single)

    x = np.concatenate([x_1, x_2], axis=0)
    y = np.concatenate([y_1, y_2], axis=0)
    return x, y


def to_tf_pipeline(features, labels, propensities):
    """Build a ``tf.data.Dataset`` of ``(features, labels)`` pairs.

    If ``propensities`` is not ``None``, the labels are first multiplied by a
    random 0/1 mask: an entry is kept where a uniform draw is less than the
    corresponding propensity (``propensities`` is broadcast along the leading
    axis of ``labels``). The mask is computed here, before the dataset is
    created, not inside a dataset transformation.
    """
    feature_ds = tf.data.Dataset.from_tensor_slices(features)

    if propensities is not None:
        draws = tf.random.uniform(shape=tf.shape(labels))
        keep = tf.less(draws, propensities[None, :])
        labels = labels * tf.cast(keep, tf.float32)

    label_ds = tf.data.Dataset.from_tensor_slices(labels)
    return tf.data.Dataset.zip((feature_ds, label_ds))


def make_dataset(num_train, num_val, num_test, propensities):
    """Create train/val/test splits in a clean and a noisy variant.

    Each split is generated independently with ``make_menon_fake``. The clean
    variant uses the labels as generated; the noisy variant passes
    ``propensities`` to ``to_tf_pipeline`` so labels are randomly dropped.
    Both training datasets are repeated 10 times.

    Returns:
        An ``ExperimentData`` holding the propensities and the six datasets.
    """
    # Splits are generated in train, val, test order.
    train, val, test = (make_menon_fake(size) for size in (num_train, num_val, num_test))

    def pipeline(split, with_noise):
        inputs, targets = split
        return to_tf_pipeline(inputs, targets, propensities if with_noise else None)

    return ExperimentData(
        propensities=propensities,
        clean_train=pipeline(train, False).repeat(10),
        clean_val=pipeline(val, False),
        clean_test=pipeline(test, False),
        noisy_train=pipeline(train, True).repeat(10),
        noisy_val=pipeline(val, True),
        noisy_test=pipeline(test, True),
    )


def run_all_settings(index: int, losses: List[str]):
    """Run every enabled experiment setting for each loss in ``losses``.

    A fresh dataset (10000 train / 1000 val / 1000 test samples) is built with
    ten propensities ``1 / linspace(2, 20, 10)``. Each loss is run with and
    without normalization; ``index`` is stored in the experiment config.

    Relies on the module-level globals ``result_cache`` and ``result_path``:
    results are stored in ``result_cache``, which is written to
    ``result_path`` as JSON after every setting that is not skipped.
    """
    data = make_dataset(10000, 1000, 1000, 1.0 / np.linspace(2.0, 20.0, num=10))

    def rce(train_mode: str, loss_mode: str, normalized: bool, base_loss: str):
        kind, _ = get_base_loss(base_loss)
        # multiclass bound == multiclass unbiased, so no need to run twice
        redundant = kind == "multiclass" and loss_mode == LossMode.BOUND and not normalized
        if not redundant:
            run_and_cache_experiment(result_cache, data, train_mode, loss_mode, normalized, base_loss, 0.0, index,
                                     num_epochs=20, pretraining=0, learning_rate=1e-2)
            # Persist after each setting so finished runs survive an interruption.
            result_path.write_text(json.dumps(result_cache, indent=2))

    for loss in losses:
        for normalized in (False, True):
            rce(train_mode=DataMode.CLEAN, loss_mode=LossMode.VANILLA, base_loss=loss, normalized=normalized)
            # Currently disabled settings:
            #rce(train_mode=DataMode.NOISY, loss_mode=LossMode.VANILLA, base_loss=loss, normalized=normalized)
            #rce(train_mode=DataMode.NOISY, loss_mode=LossMode.UNBIASED, base_loss=loss, normalized=normalized)
            #rce(train_mode=DataMode.NOISY, loss_mode=LossMode.BOUND, base_loss=loss, normalized=normalized)


def plot_result(data, p_ax: axes.Axes, r_ax: axes.Axes, label):
    """Plot the clean-test P@k and R@k curves of one result entry.

    ``data`` is indexed with the keys ``"clean-test/P@<k>"`` and
    ``"clean-test/R@<k>"`` for k = 1..9. The P@k values are drawn on ``p_ax``
    and the R@k values on ``r_ax``, both labelled with ``label``.
    """
    ks = np.arange(1, 10)
    precision = [data[f"clean-test/P@{k}"] for k in ks]
    recall = [data[f"clean-test/R@{k}"] for k in ks]

    p_ax.plot(ks, precision, label=label)
    r_ax.plot(ks, recall, label=label)


if __name__ == "__main__":
    parser = argparse.ArgumentParser("runner")
    parser.add_argument("--result-path", type=str, default="fake-menon.json")
    parser.add_argument("--losses", type=str, default="cce")
    args = parser.parse_args()

    # --losses is a comma-separated list of base-loss names.
    losses = [name.strip() for name in args.losses.split(",")]

    # result_path and result_cache are module globals read by run_all_settings.
    result_path = Path(args.result_path)
    result_cache = json.loads(result_path.read_text()) if result_path.exists() else {}

    for i in range(1):
        run_all_settings(index=i, losses=losses)

    _, (p_ax, r_ax) = plt.subplots(1, 2)
    # Hard-coded cache identifiers of the two runs to plot; a KeyError is
    # raised if either one is missing from the cache.
    plotted_runs = (
        ("920fdbae6ea19de85743c76d65062a2969099e1bc24793c379f8b674db020834", "cce"),
        ("cb97a63c4ff558f6d9dadd6c6b33c21b594f13eb0ca3be27a14406529d43dcd2", "cce-n"),
    )
    for identifier, label in plotted_runs:
        data = result_cache[identifier]
        plot_result(data, p_ax, r_ax, label=label)
    p_ax.legend()
    plt.show()



