import torch
from trainer.model import Attention
import numpy as np
from os.path import join
import os.path
import trainer.dataset
import trainer.runner
import trainer.trainer
import problems.GC as GC
import problems.MVC as MVC
from build_memetracker import build_memetracker_graph
from trainer.trainconfig import TrainConfig
import pickle
import time
try:
    import neptune.new as neptune
except ImportError:
    pass


def load_dimacs_graphs(directory):
    """Build one graph dataset per entry of ``directory``.

    Every name returned by ``os.listdir(directory)`` is joined with the
    directory and passed to ``trainer.dataset.GraphDataset`` through its
    ``dimacs_filename`` argument.

    Args:
        directory: Path of the folder holding the DIMACS files.

    Returns:
        A ``(datasets, names)`` tuple: the list of created datasets and the
        list of directory entry names, both in ``os.listdir`` order.
    """
    names = os.listdir(directory)

    datasets = []
    for name in names:
        datasets.append(trainer.dataset.GraphDataset(dimacs_filename=join(directory, name)))

    return datasets, names


def test_model_file_on_synthetic_graphs(model_path, graph_size, graphs_per_size=2000, graph_types=None, problem=GC,
                                        initialization='positional', decoding_type='local', n_heads=4, throwback=1,
                                        shortcuts=True, normalize=True, encoder_layers=3,
                                        num_samples=0, seed=79146,
                                        neptune_run=None, log_interval=10,
                                        qualitative=True,
                                        log_baseline=False):
    """Evaluate a saved checkpoint on freshly generated synthetic graphs.

    A default model is built for ``problem`` with the given architecture
    options, the weights stored at ``model_path`` are loaded into it, and one
    dataset is generated per entry of ``graph_size``. The datasets are then
    evaluated with ``test_model_on_graphs``.

    Args:
        model_path: Checkpoint file readable by ``torch.load``.
        graph_size: Sized iterable of graph sizes; one dataset per entry.
        graphs_per_size: Passed to ``trainer.dataset.generate_dataset``.
        graph_types: Graph types to generate; ``None`` selects
            ``trainer.dataset.default_graph_types()``.
        problem: Problem module; its ``NAME`` is logged to neptune.
        initialization, decoding_type, n_heads, throwback, shortcuts,
            normalize, encoder_layers: Forwarded to
            ``trainer.model.make_default_model``.
        num_samples, seed, log_interval, qualitative: Forwarded to
            ``test_model_on_graphs``.
        neptune_run: Optional neptune run used for logging.
        log_baseline: When true, ``trainer.trainer.log_baseline`` is called
            once per generated dataset before the model is evaluated.

    Returns:
        The costs returned by ``test_model_on_graphs``.
    """
    model = trainer.model.make_default_model(
        problem,
        initialization=initialization,
        decoding_type=decoding_type,
        n_heads=n_heads,
        throwback=throwback,
        encoder_layers=encoder_layers,
        shortcuts=shortcuts,
        normalize=normalize,
    )
    state = torch.load(model_path)
    model.load_state_dict(state)

    if graph_types is None:
        graph_types = trainer.dataset.default_graph_types()

    # One dataset per requested size, with initial node features attached.
    test_data = []
    for n_nodes in graph_size:
        dataset = trainer.dataset.generate_dataset(n_nodes, graph_types, graphs_per_size, seed)
        test_data.append(dataset.set_initial_features(embed_dim=model.embed_dim,
                                                      initialization=initialization))

    if log_baseline:
        for idx in range(len(graph_size)):
            dataset = test_data[idx]
            trainer.trainer.log_baseline(neptune_run, dataset.data,
                                         prefix=str(dataset.graph_nodes), problem=problem)

    cost = test_model_on_graphs(model, test_data=test_data, num_samples=num_samples, seed=seed,
                                neptune_run=neptune_run, log_interval=log_interval,
                                qualitative=qualitative)

    if neptune_run is None:
        return cost

    neptune_run["parameters/checkpoint_path"] = model_path
    neptune_run["parameters/problem"] = problem.NAME
    neptune_run["parameters/num_samples"] = num_samples
    neptune_run["parameters/initialization"] = initialization
    neptune_run["parameters/graph_size"] = str(graph_size)

    return cost

def test_memetracker_multiple_models(model_directory, graph_path, problem=MVC,
                                      initialization='positional', decoding_type='local', n_heads=4, throwback=1,
                                      num_samples=0, seed=25930, neptune_run=None):
    """Run ``test_memetracker`` for every checkpoint in ``model_directory``.

    Each directory entry is treated as a checkpoint file. The cost tensor
    returned for it is converted to a NumPy array and reduced to its mean,
    giving one average cost per checkpoint.

    Args:
        model_directory: Folder whose entries are model checkpoints.
        graph_path: Graph file forwarded to ``test_memetracker``.
        problem, initialization, decoding_type, n_heads, throwback,
            num_samples, seed, neptune_run: Forwarded to ``test_memetracker``.

    Returns:
        ``(average_cost, mean, std)`` where ``average_cost`` is the array of
        per-checkpoint mean costs and ``mean`` / ``std`` summarise it.
    """
    per_model_costs = []
    for checkpoint in os.listdir(model_directory):
        cost = test_memetracker(join(model_directory, checkpoint), graph_path=graph_path,
                                initialization=initialization,
                                decoding_type=decoding_type, n_heads=n_heads, throwback=throwback,
                                num_samples=num_samples, seed=seed,
                                problem=problem,
                                neptune_run=neptune_run)
        per_model_costs.append(cost.detach().cpu().numpy())

    average_cost = np.asarray([np.mean(model_cost) for model_cost in per_model_costs])

    if neptune_run is not None:
        neptune_run["result/cost"] = average_cost.mean()
        neptune_run["result/std"] = average_cost.std()
        neptune_run["parameters/model_directory"] = model_directory
        neptune_run["parameters/graph_directory"] = graph_path

    return average_cost, average_cost.mean(), average_cost.std()


def test_memetracker(model_path, graph_path, problem=MVC,
                    initialization='positional', decoding_type='local', n_heads=4, throwback=1,
                    num_samples=0, seed=25930, neptune_run=None):
    """Evaluate one checkpoint on the graph built from ``graph_path``.

    The graph is produced by ``build_memetracker_graph(graph_path,
    'undirected')``, wrapped in a ``trainer.dataset.GraphDataset`` and
    evaluated with ``test_model_on_graphs`` using ``log_interval=1``.

    Args:
        model_path: Checkpoint file readable by ``torch.load``.
        graph_path: Input file for ``build_memetracker_graph``.
        problem, initialization, decoding_type, n_heads, throwback:
            Forwarded to ``trainer.model.make_default_model``.
        num_samples, seed, neptune_run: Forwarded to ``test_model_on_graphs``.

    Returns:
        The costs returned by ``test_model_on_graphs``.
    """
    model = trainer.model.make_default_model(problem, initialization=initialization,
                                             decoding_type=decoding_type, n_heads=n_heads,
                                             throwback=throwback)
    weights = torch.load(model_path)
    model.load_state_dict(weights)

    # The second value returned by the builder is not needed here.
    graph, _ = build_memetracker_graph(graph_path, 'undirected')

    test_data = trainer.dataset.GraphDataset(nx_graph=graph)
    test_data.set_initial_features(embed_dim=model.embed_dim, initialization=initialization)

    # NOTE(review): the other callers in this file hand test_model_on_graphs a
    # list of datasets, while this one passes a single dataset — confirm that
    # the runner accepts both forms.
    return test_model_on_graphs(model, test_data=test_data, num_samples=num_samples, seed=seed,
                                neptune_run=neptune_run, log_interval=1)


def test_dimacs_files_multiple_models(model_directory, graph_directory, problem=GC,
                                      initialization='positional', decoding_type='local', n_heads=4, throwback=1,
                                      num_samples=0, seed=25930, neptune_run=None):
    """Run ``test_dimacs_files`` for every checkpoint in ``model_directory``.

    Each directory entry is treated as a checkpoint file. The cost tensor
    returned for it is converted to a NumPy array and reduced to its mean,
    giving one average cost per checkpoint; the graph names returned
    alongside the costs are discarded.

    Args:
        model_directory: Folder whose entries are model checkpoints.
        graph_directory: Folder of DIMACS graphs, forwarded as ``directory``.
        problem, initialization, decoding_type, n_heads, throwback,
            num_samples, seed, neptune_run: Forwarded to ``test_dimacs_files``.

    Returns:
        ``(average_cost, mean, std)`` where ``average_cost`` is the array of
        per-checkpoint mean costs and ``mean`` / ``std`` summarise it.
    """
    outcomes = []
    for checkpoint in os.listdir(model_directory):
        outcomes.append(test_dimacs_files(join(model_directory, checkpoint), directory=graph_directory,
                                          initialization=initialization,
                                          decoding_type=decoding_type, n_heads=n_heads,
                                          throwback=throwback,
                                          num_samples=num_samples, seed=seed,
                                          problem=problem,
                                          neptune_run=neptune_run))

    # Keep only the cost tensor of each (cost, names) pair, as a NumPy array.
    cost_arrays = [cost.detach().cpu().numpy() for cost, _ in outcomes]

    average_cost = np.asarray([np.mean(cost) for cost in cost_arrays])

    if neptune_run is not None:
        neptune_run["result/cost"] = average_cost.mean()
        neptune_run["result/std"] = average_cost.std()
        neptune_run["parameters/model_directory"] = model_directory
        neptune_run["parameters/graph_directory"] = graph_directory

    return average_cost, average_cost.mean(), average_cost.std()


def test_dimacs_files(model_path, directory, problem=GC,
                      initialization="positional", decoding_type='local', n_heads=4, throwback=1,
                      num_samples=0, seed=234146, neptune_run=None):
    """Evaluate one checkpoint on every DIMACS graph found in ``directory``.

    Args:
        model_path: Checkpoint file readable by ``torch.load``.
        directory: Folder of DIMACS files, loaded with ``load_dimacs_graphs``.
        problem: Problem module; its ``NAME`` is logged to neptune.
        initialization, decoding_type, n_heads, throwback: Forwarded to
            ``trainer.model.make_default_model``; ``initialization`` is also
            used when setting the initial features of each dataset.
        num_samples, seed: Forwarded to ``test_model_on_graphs``; ``seed``
            additionally seeds NumPy before the model is built.
        neptune_run: Optional neptune run used for logging.

    Returns:
        ``(cost, names)``: the costs returned by ``test_model_on_graphs`` and
        the file names returned by ``load_dimacs_graphs``.
    """
    np.random.seed(seed)

    model = trainer.model.make_default_model(problem, initialization=initialization,
                                             decoding_type=decoding_type, n_heads=n_heads, throwback=throwback)

    model.load_state_dict(torch.load(model_path))

    # load_dimacs_graphs only takes the directory; the initialization scheme
    # is applied below through set_initial_features.
    test_data, names = load_dimacs_graphs(directory)
    for d in test_data:
        d.set_initial_features(embed_dim=model.embed_dim, initialization=initialization)

    cost = test_model_on_graphs(model, test_data=test_data, num_samples=num_samples, seed=seed,
                                neptune_run=neptune_run, log_interval=1)

    if neptune_run is not None:
        neptune_run["parameters/checkpoint_path"].log(model_path)
        neptune_run["parameters/problem"] = problem.NAME
        neptune_run["parameters/num_samples"] = num_samples
        neptune_run["parameters/initialization"] = initialization
        # zip() is lazy in Python 3, so it must be materialised before being
        # stringified; otherwise only the zip object's repr would be logged.
        cost_values = cost.detach().cpu().numpy().tolist()
        neptune_run["results/cost"] = str(list(zip(names, cost_values)))

    return cost, names


def test_pickled_graphs_multiple_models(model_directory, graph_directory, reference_directory,
                                        initialization="positional", num_samples=0, seed=234149,
                                        problem=MVC, neptune_run=None, log_interval=100,
                                        n_graphs=1000,
                                        qualitative=False,
                                        baseline_function=None):
    """Evaluate checkpoints, or a baseline, on a folder of pickled graph files.

    Graph files and reference files are listed in sorted order with names
    starting with ``'.'`` removed, and are paired by position. Each checkpoint
    in ``model_directory`` is evaluated on every graph file with
    ``test_pickled_graphs``. If ``baseline_function`` is given it is evaluated
    with ``test_pickled_graph_baseline`` instead, which requires that no
    checkpoints were listed.

    Args:
        model_directory: Folder of checkpoints, or ``None`` for no models.
        graph_directory: Folder of pickled graph files.
        reference_directory: Folder of pickled reference files.
        initialization, num_samples, seed, problem, log_interval, qualitative:
            Forwarded to ``test_pickled_graphs``.
        neptune_run: Optional neptune run used for logging.
        n_graphs: Forwarded to the per-file evaluation functions.
        baseline_function: Optional callable evaluated on each graph.

    Returns:
        ``(cost_mean, ratios_mean, cost_std, cost_std_overall)``. The first
        three are taken over axis 0 of the (evaluated entity x graph file)
        tables; the last is the standard deviation of the per-row mean cost.
    """
    models = os.listdir(model_directory) if model_directory is not None else []

    graph_files = [name for name in sorted(os.listdir(graph_directory)) if not name.startswith('.')]
    reference_files = [name for name in sorted(os.listdir(reference_directory)) if not name.startswith('.')]

    costs = []
    ratios = []
    for model_name in models:
        model_path = join(model_directory, model_name)

        # One row per model; the rows are filled in place below.
        model_costs = []
        model_ratios = []
        costs.append(model_costs)
        ratios.append(model_ratios)

        for idx, graph_file in enumerate(graph_files):
            cost, opt = test_pickled_graphs(model_path,
                                            join(graph_directory, graph_file),
                                            join(reference_directory, reference_files[idx]),
                                            initialization=initialization, num_samples=num_samples,
                                            seed=seed, problem=problem,
                                            neptune_run=neptune_run, log_interval=log_interval,
                                            n_graphs=n_graphs,
                                            qualitative=qualitative)
            model_costs.append(cost)
            model_ratios.append(opt)

    if baseline_function is not None:
        assert models == []
        baseline_costs = []
        baseline_ratios = []
        costs.append(baseline_costs)
        ratios.append(baseline_ratios)

        for idx, graph_file in enumerate(graph_files):
            cost, opt = test_pickled_graph_baseline(baseline_function,
                                                    join(graph_directory, graph_file),
                                                    join(reference_directory, reference_files[idx]),
                                                    neptune_run=neptune_run, n_graphs=n_graphs)
            baseline_costs.append(cost)
            baseline_ratios.append(opt)

    costs_np = np.array(costs)
    ratios_np = np.array(ratios)

    # Statistics across rows (axis 0), i.e. one value per graph file.
    cost_mean = costs_np.mean(axis=0)
    ratios_mean = ratios_np.mean(axis=0)
    cost_std = costs_np.std(axis=0)

    # Overall figures: the spread is taken over each row's mean cost.
    cost_mean_overall = cost_mean.mean()
    cost_std_overall = costs_np.mean(axis=1).std()
    ratios_mean_total = ratios_mean.mean()

    if neptune_run is not None:
        neptune_run["results/cost_mean_per_size"] = cost_mean.tolist()
        neptune_run["results/cost_std_per_size"] = cost_std.tolist()
        neptune_run["results/opt_ratio_mean_per_size"] = ratios_mean.tolist()
        neptune_run["results/cost_mean_overall"] = cost_mean_overall
        neptune_run["results/cost_std_overall"] = cost_std_overall
        neptune_run["results/opt_ratio_mean_total"] = ratios_mean_total

    return cost_mean, ratios_mean, cost_std, cost_std_overall


def test_pickled_graphs(model_path, graph_path, reference_path=None, initialization="positional",
                        num_samples=0, seed=234149, problem=MVC, neptune_run=None,
                        log_interval=100, n_graphs=1000, qualitative=True):
    """Evaluate one checkpoint on the graphs stored in a pickle file.

    The graphs are read with ``trainer.dataset.read_pickle_graphs``, each one
    is wrapped in its own ``trainer.dataset.GraphDataset``, and the resulting
    list is evaluated with ``test_model_on_graphs``. The costs are then
    compared with the reference file through
    ``_get_opt_ratio_from_reference``.

    Args:
        model_path: Checkpoint file readable by ``torch.load``.
        graph_path: Pickle file holding the graphs.
        reference_path: Optional pickle file with reference values.
        initialization: Forwarded to model construction and feature setup.
        num_samples, seed, log_interval, qualitative: Forwarded to
            ``test_model_on_graphs``.
        problem: Problem module; its ``NAME`` is logged to neptune.
        neptune_run: Optional neptune run used for logging.
        n_graphs: Forwarded to ``trainer.dataset.read_pickle_graphs``.

    Returns:
        ``(avg_cost, avg_opt)``: the mean cost and the mean cost/reference
        ratio.
    """
    model = trainer.model.make_default_model(problem, initialization=initialization)
    weights = torch.load(model_path)
    model.load_state_dict(weights)

    test_data = []
    for graph in trainer.dataset.read_pickle_graphs(graph_path, n_graphs):
        dataset = trainer.dataset.GraphDataset(nx_graph=graph)
        test_data.append(dataset.set_initial_features(model.embed_dim, initialization))

    cost = test_model_on_graphs(model, test_data=test_data, num_samples=num_samples, seed=seed,
                                neptune_run=neptune_run, log_interval=log_interval,
                                qualitative=qualitative)

    avg_cost = cost.numpy().mean()
    # Only the averaged ratio is used; the per-graph list is dropped.
    _, avg_opt = _get_opt_ratio_from_reference(reference_path, cost)

    if neptune_run is not None:
        neptune_run["parameters/model_directory"] = model_path
        neptune_run["parameters/graph_file"] = graph_path
        neptune_run["parameters/problem"] = problem.NAME
        neptune_run["parameters/initialization"] = initialization
        neptune_run["results/opt_ratio_mean"].log(avg_opt)

    return avg_cost, avg_opt


def _get_opt_ratio_from_reference(reference_path: str, cost: list):
    opt_ratio = []
    if reference_path is not None:
        print(reference_path)
        with open(reference_path, 'rb') as f:
            reference = pickle.load(f)

        opt_ratio = [i / j for (i, j) in zip(cost, reference['target'])]
    avg_opt = np.asarray(opt_ratio).mean()

    return opt_ratio, avg_opt


def test_pickled_graph_baseline(baseline_function, graph_path, reference_path=None, neptune_run=None, n_graphs=1000):
    """Evaluate a baseline callable on the graphs stored in a pickle file.

    ``baseline_function`` is called once per graph read by
    ``trainer.dataset.read_pickle_graphs`` and its return value is taken as
    that graph's cost. The costs are compared with the reference file through
    ``_get_opt_ratio_from_reference``.

    Args:
        baseline_function: Callable taking a graph and returning its cost.
        graph_path: Pickle file holding the graphs.
        reference_path: Optional pickle file with reference values.
        neptune_run: Optional neptune run used for logging.
        n_graphs: Forwarded to ``trainer.dataset.read_pickle_graphs``.

    Returns:
        ``(avg_cost, avg_opt)``: the mean cost and the mean cost/reference
        ratio.
    """
    graphs = trainer.dataset.read_pickle_graphs(graph_path, n_graphs)

    cost = []
    for graph in graphs:
        cost.append(baseline_function(graph))

    # Only the averaged ratio is used; the per-graph list is dropped.
    _, avg_opt = _get_opt_ratio_from_reference(reference_path, cost)
    avg_cost = np.asarray(cost).mean()

    if neptune_run is None:
        return avg_cost, avg_opt

    neptune_run["parameters/baseline_function"] = str(baseline_function)
    neptune_run["parameters/graph_file"] = graph_path
    neptune_run["val/cost"].log(cost)
    neptune_run["val/average_cost"].log(avg_cost)
    neptune_run["results/opt_ratio_mean"].log(avg_opt)

    return avg_cost, avg_opt


def test_model_on_graphs(model, test_data, num_samples=0, seed=79146, neptune_run=None, log_interval=10, qualitative=True):
    """Validate ``model`` on ``test_data`` and optionally log the results.

    Torch and NumPy are seeded with ``seed``, a ``trainer.runner.Runner`` is
    created with an empty training set and ``test_data`` as validation data,
    and ``Runner.validate`` is executed.

    Args:
        model: Model handed to the runner.
        test_data: Validation data handed to the runner; when logging is
            enabled, ``test_data[0].graph_nodes`` is read.
        num_samples: Forwarded to ``Runner.validate``.
        seed: Seed for ``torch.manual_seed`` and ``np.random.seed``.
        neptune_run: Optional neptune run used for logging.
        log_interval: Forwarded to ``Runner.qualitative_validation``.
        qualitative: When true, the qualitative validation is run as well.

    Returns:
        The costs returned by ``Runner.validate``.
    """
    logging_enabled = neptune_run is not None

    if logging_enabled:
        neptune_run["parameters/num_samples"] = num_samples

    torch.manual_seed(seed)
    np.random.seed(seed)

    runner = trainer.runner.Runner(model, [], test_data, None, None, None, TrainConfig(0), neptune_run)

    costs, runtime = runner.validate(num_samples=num_samples)

    if qualitative:
        runner.qualitative_validation(log_interval=log_interval)

    if not logging_enabled:
        return costs

    neptune_run["val/cost"].log(costs.numpy().tolist())
    neptune_run["val/average_cost"].log(costs.mean().item())
    neptune_run["parameters/seed"] = seed
    neptune_run["val/n"].log(test_data[0].graph_nodes)

    # NOTE(review): dividing by 1e6 and logging under "ms" suggests that
    # `runtime` is measured in nanoseconds — confirm against Runner.validate.
    runtimes = (runtime.numpy() / 1000000.0).tolist()
    for position, elapsed in enumerate(runtimes):
        neptune_run["performance/ms/" + str(position)].log(elapsed)

    return costs


