import json
import hashlib
from typing import List

import numpy as np
from pathlib import Path
import pickle
import argparse

import mlds
from train import generate_datasets, run_experiment, to_tf_pipeline, ExperimentData
from losses import LossMode, DataMode, get_base_loss


def flatten_dict(data):
    """Flatten one level of nesting in ``data``.

    Every value that is itself a dict is replaced by its entries, stored under
    ``"<outer key>/<inner key>"``. All other values are copied unchanged. Only
    the first nesting level is expanded: a dict found inside a nested dict is
    kept as a value.

    :param data: mapping to flatten; keys of nested dicts must support
        concatenation with ``str``.
    :return: a new dict; ``data`` itself is not modified.
    """
    flat = {}
    for name, value in data.items():
        if not isinstance(value, dict):
            flat[name] = value
            continue
        # Expand the nested mapping in place, prefixing each inner key.
        flat.update({name + "/" + sub: item for sub, item in value.items()})
    return flat


def run_and_cache_experiment(cache: dict, data, train_mode: str, loss_mode: str, normalized: bool, base_loss: str,
                             l2_reg: float, index: int, num_epochs: int, pretraining: int, **kwargs):
    """Run a single experiment unless its configuration is already in ``cache``.

    The cache key is the SHA-256 hex digest of the JSON-encoded configuration
    (train mode, loss mode, base loss, normalization flag, rounded ``l2_reg``
    and ``index``). ``num_epochs``, ``pretraining`` and ``kwargs`` are passed
    on to ``run_experiment`` but are not part of the key.

    :param cache: dict mapping identifiers to flattened results; updated in place.
    :param data: experiment data, passed through to ``run_experiment``.
    :param index: only used to distinguish cache entries.
    :return: ``None``; the result is stored in ``cache``.
    """
    # TODO next chance, add learning_rate to config
    # Key order matters: the dict is serialized as-is to derive the cache key.
    config = dict(
        data=train_mode,
        mode=loss_mode,
        loss=base_loss,
        normalized=normalized,
        l2_reg=round(l2_reg, 20),
        index=index,
    )

    serialized = json.dumps(config).encode()
    cache_key = hashlib.sha256(serialized).hexdigest()

    if cache_key not in cache:
        outcome = run_experiment(data=data, train_mode=train_mode, loss_mode=loss_mode, base_loss=base_loss,
                                 normalized=normalized, l2_reg=l2_reg, num_epochs=num_epochs,
                                 pretraining=pretraining, **kwargs)
        # Keep the configuration next to the results it produced.
        outcome["config"] = config
        cache[cache_key] = flatten_dict(outcome)


def load_dataset(train_file: str, test_file: str, propensities: np.ndarray):
    """Read the train and test files and wrap them in an ``ExperimentData``.

    The training file is split with ``mlds.split_dataset(..., 0.7)`` into a
    training and a validation part (presumably 70% / 30% -- confirm against
    ``mlds.split_dataset``). All three parts are converted with
    ``mlds.to_sparse`` and then ``to_tf_pipeline``. Only the ``noisy_*``
    fields are populated; the ``clean_*`` fields are set to ``None``.

    :param train_file: path of the training data file.
    :param test_file: path of the test data file.
    :param propensities: stored unchanged on the returned object.
    :return: the populated ``ExperimentData``.
    """
    raw_train = mlds.read_data(train_file)
    raw_test = mlds.read_data(test_file)

    fit_part, held_out_part = mlds.split_dataset(raw_train, 0.7)

    # Convert in the order train, validation, test.
    sparse_train, sparse_val, sparse_test = (
        mlds.to_sparse(part) for part in (fit_part, held_out_part, raw_test)
    )

    noisy_train = to_tf_pipeline(sparse_train)
    noisy_val = to_tf_pipeline(sparse_val)
    noisy_test = to_tf_pipeline(sparse_test)

    return ExperimentData(
        propensities=propensities,
        clean_train=None,
        clean_val=None,
        clean_test=None,
        noisy_train=noisy_train,
        noisy_val=noisy_val,
        noisy_test=noisy_test,
    )


def compute_prop(data_path, prop_path, a=0.55, b=1.5):
    """Return per-label propensity scores, computing and caching them if needed.

    If ``prop_path`` exists, the array stored there is loaded and returned.
    Otherwise label frequencies are counted from ``data_path`` and turned into
    inverse propensities via::

        c = (log(num_samples) - 1) * (b + 1) ** a
        inv_prop = 1 + c * (count + b) ** -a

    The reciprocal ``1 / inv_prop`` is written to ``prop_path`` and returned,
    so a fresh computation and a later cached load yield the same values.
    (The formula appears to follow the empirical propensity model of
    Jain et al., 2016 -- confirm.)

    Expected file layout, as parsed here: the first line is a space-separated
    header whose first field is the number of samples and whose third field is
    the number of labels. Each following line starts with a comma-separated
    list of integer label ids, terminated by the first space. Lines with an
    empty label list are skipped.

    :param data_path: path of the training data file.
    :param prop_path: ``pathlib.Path`` of the cache file (``.exists()`` is called on it).
    :param a: exponent of the propensity model.
    :param b: offset of the propensity model.
    :return: 1-D array with one propensity score per label.
    """
    if prop_path.exists():
        print("Loading propensity scores")
        return np.load(prop_path)

    # The context manager closes the file even if parsing raises.
    with open(data_path, "r") as f:
        header = f.readline().split(" ")
        num_samples = int(header[0])
        num_labels = int(header[2])

        label_counts = np.zeros(num_labels)
        print("Computing propensity scores")
        for _ in range(num_samples):
            sample = f.readline().rstrip('\n')
            labels = sample.split(" ", 1)[0]
            if labels == "":
                continue
            labels = [int(label) for label in labels.split(",")]
            label_counts[labels] += 1.0

    c = (np.log(num_samples) - 1) * np.power(b + 1, a)
    inv_prop = 1 + c * np.power(label_counts + b, -a)

    # Return exactly what is cached, so the first run (cache miss) and later
    # runs (cache hit) hand the caller the same quantity.
    propensities = 1.0 / inv_prop
    np.save(prop_path, propensities)

    return propensities


def run_settings(base_path: Path, result_path: Path, regularization_range: dict, step_size: float = 1, **kwargs):
    """Run the fixed grid of loss settings over a sweep of L2 regularization strengths.

    Data is read from ``base_path / "train.txt"`` and ``base_path / "test.txt"``;
    propensities are computed from the training file or loaded from
    ``base_path / "prop.npy"``. Results are kept in a JSON cache at
    ``result_path``, which is loaded if present and rewritten after every
    experiment call, so already-computed configurations are skipped on
    subsequent runs.

    :param base_path: directory containing the dataset files.
    :param result_path: JSON file used as the result cache.
    :param regularization_range: maps ``(base_loss, normalized)`` to a
        ``(start, end)`` pair of base-10 exponents; ``end`` is exclusive.
    :param step_size: spacing between consecutive exponents.
    :param kwargs: forwarded to ``run_and_cache_experiment`` (must provide
        ``num_epochs``, which that function requires).
    """
    train_txt = base_path / "train.txt"
    test_txt = base_path / "test.txt"
    propensities = compute_prop(train_txt, base_path / "prop.npy")

    data = load_dataset(str(train_txt), str(test_txt), propensities)
    result_cache = json.loads(result_path.read_text()) if result_path.exists() else {}

    def sweep(loss_mode: str, normalized: bool, base_loss: str):
        # Exponent bounds are looked up per (base loss, normalization) combination.
        low, high = regularization_range[(base_loss, normalized)]
        exponents = np.arange(low, high, step=step_size)
        for l2_reg in 10.0 ** exponents:
            run_and_cache_experiment(result_cache, data, DataMode.NOISY, loss_mode, normalized, base_loss, l2_reg, 1,
                                     pretraining=0, **kwargs)
            # Persist after every run so progress survives an interruption.
            result_path.write_text(json.dumps(result_cache, indent=2))

    # (loss mode, base loss, normalized), executed in this order.
    settings = (
        (LossMode.VANILLA, "bce", False),
        (LossMode.UNBIASED, "bce", False),
        (LossMode.BOUND, "bce", False),
        (LossMode.VANILLA, "bce", True),
        (LossMode.BOUND, "bce", True),
        (LossMode.VANILLA, "cce", False),
        (LossMode.UNBIASED, "cce", False),
        (LossMode.VANILLA, "cce", True),
        (LossMode.BOUND, "cce", True),
    )
    for loss_mode, base_loss, normalized in settings:
        sweep(loss_mode=loss_mode, base_loss=base_loss, normalized=normalized)
