import time

import torch
import trainer.baseline
from trainer.dataset import GraphDataset, default_graph_types, generate_dataset
import problems.GC as GC
import problems.MVC as MVC
try:
    import neptune.new as neptune
except ImportError:
    pass
from problems.GC.color_utils import get_heuristic_coloring
from problems.MVC.mvc_utils import get_heuristic_cover
import numpy as np
from trainer.runner import Runner


def train_multi_size(train_n: list, val_n: list, train_config, train_samples=10000, val_samples=2000, initialization="degree", graph_types=None, problem=GC, train_seed=146293, val_seed=238, neptune_run=None, rng_seed=123, model=None):
    """Seed the RNGs, build a default model if needed, and train it on several graph sizes.

    Args:
        train_n: graph sizes to generate training datasets for.
        val_n: graph sizes to generate validation datasets for.
        train_config: training configuration forwarded to ``train_model_multi_size``.
        train_samples: number of training graphs generated per size.
        val_samples: number of validation graphs generated per size.
        initialization: initial node-feature scheme passed to the datasets.
        graph_types: graph families to sample; ``None`` selects the defaults.
        problem: problem module (e.g. ``GC`` or ``MVC``).
        train_seed: seed for training-data generation.
        val_seed: seed for validation-data generation.
        neptune_run: optional Neptune run used for logging; ``None`` disables it.
        rng_seed: seed applied to ``torch`` and ``numpy`` global RNGs before the
            model is built, so default-model initialization is reproducible.
        model: model to train; when ``None`` a default model for ``problem``
            is created via ``trainer.model.make_default_model``.

    Returns:
        The trained model, as returned by ``train_model_multi_size``.
    """
    torch.manual_seed(rng_seed)
    np.random.seed(rng_seed)

    if model is None:
        # ``trainer.model`` is not imported at module level; importing it here makes
        # the attribute lookup independent of whether another module loaded it first.
        import trainer.model
        model = trainer.model.make_default_model(problem)

    if neptune_run is not None:
        # NOTE(review): encoder depth is hard-coded rather than read from the model;
        # confirm it matches the model actually being trained.
        neptune_run["parameters/encoder_depth"] = 3
        neptune_run["parameters/n_heads"] = model.n_heads
        neptune_run["parameters/rng_seed"] = rng_seed

    return train_model_multi_size(model, train_n, val_n, train_config, train_samples, val_samples, initialization, graph_types, problem, train_seed, val_seed, neptune_run)


def train_model_multi_size(model, train_n: list, val_n: list, train_config, train_samples=10000, val_samples=2000, initialization="degree", graph_types=None, problem=GC, train_seed=146293, val_seed=238, neptune_run=None):
    """Train ``model`` on datasets of several graph sizes and return it.

    Builds an Adam optimizer (learning rate ``train_config.lr``) and a LambdaLR
    scheduler, generates one validation and one training dataset per requested
    size, and hands everything to a ``Runner`` with a ``RolloutBaseline``.

    Args:
        model: model to train; moved to ``train_config.device``.
        train_n: graph sizes for the training datasets.
        val_n: graph sizes for the validation datasets.
        train_config: configuration object; this function reads ``lr``,
            ``device`` and ``n_epochs`` and calls ``log(neptune_run)``.
        train_samples: number of training graphs generated per size.
        val_samples: number of validation graphs generated per size.
        initialization: initial node-feature scheme passed to
            ``set_initial_features``.
        graph_types: graph families to sample; ``None`` selects
            ``default_graph_types()``.
        problem: problem module (e.g. ``GC`` or ``MVC``), also used to pick the
            heuristic baselines that are logged.
        train_seed: seed for training-data generation (the same seed is reused
            for every size).
        val_seed: seed for validation-data generation (reused for every size).
        neptune_run: optional Neptune run used for logging; ``None`` disables it.

    Returns:
        The same ``model`` object after ``runner.train()`` has run.
    """
    # Multiplicative per-epoch LR factor; 1.0 keeps the learning rate constant.
    lr_decay = 1.0

    optimizer = torch.optim.Adam([{'params': model.parameters(), 'lr': train_config.lr}])
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: lr_decay ** epoch)

    if graph_types is None:
        graph_types = default_graph_types()

    model.to(train_config.device)

    if neptune_run is not None:
        neptune_run["parameters/lr"] = train_config.lr

        neptune_run["parameters/initialization"] = initialization
        neptune_run["parameters/graph_types"] = str(graph_types)
        neptune_run["parameters/train_seed"] = train_seed
        neptune_run["parameters/val_seed"] = val_seed
        neptune_run["parameters/train_samples"] = train_samples
        neptune_run["parameters/val_samples"] = val_samples
        neptune_run["parameters/train/n_epochs"] = train_config.n_epochs
        neptune_run["parameters/train/train_n"] = str(train_n)
        # Stringified like train_n: a raw list is not a plain Neptune field value.
        neptune_run["parameters/train/val_n"] = str(val_n)
        neptune_run["parameters/problem"] = str(problem)

        train_config.log(neptune_run)

    # Materialize once so the sizes can be both counted and iterated.
    val_sizes = list(val_n)

    validation_data = []
    for n in val_sizes:
        data_val = generate_dataset(n, graph_types, val_samples, val_seed)
        data_val.set_initial_features(embed_dim=model.embed_dim, initialization=initialization)
        validation_data.append(data_val)

        # Heuristic baselines are only computed for graphs of at most 1000 nodes
        # (presumably to bound their cost).
        if n <= 1000:
            # With several validation sizes each size gets its own namespace;
            # otherwise every size would overwrite the same "val/greedy/..." means
            # and append to one mixed timing series. A single size keeps "val".
            prefix = "val" if len(val_sizes) == 1 else "val/" + str(n)
            log_baseline(neptune_run, data_val, prefix, problem)

    data = []
    for n in train_n:
        datum = generate_dataset(n, graph_types, train_samples, train_seed)
        datum.set_initial_features(embed_dim=model.embed_dim, initialization=initialization)
        data.append(datum)

    baseline = trainer.baseline.RolloutBaseline(model, validation_data, device=train_config.device)

    runner = Runner(model, data, validation_data, baseline=baseline, optimizer=optimizer,
                    lr_scheduler=lr_scheduler, train_config=train_config, neptune_run=neptune_run)

    runner.train()

    return model


def log_baseline(neptune_run, data, prefix, problem):
    """Log classical heuristic reference results for ``data`` to Neptune.

    Does nothing when ``neptune_run`` is None.

    For ``problem == GC`` it runs the 'smallest_last', 'largest_first' and
    'DSATUR' strategies of ``get_heuristic_coloring`` on every graph. For each
    strategy it assigns the mean of the returned values to
    ``<prefix>/greedy/<strategy>``, and logs the per-graph wall-clock times in
    milliseconds as a series under ``<prefix>/greedy/<strategy>_time_ms``.

    For ``problem == MVC`` it assigns the mean of ``get_heuristic_cover`` results
    for the 'iterative' and 'list_right' strategies (no timing). Any other
    ``problem`` is silently ignored.

    Args:
        neptune_run: Neptune run handle, or None to skip logging entirely.
        data: iterable of graphs; it is iterated once per strategy.
        prefix: namespace prefix for the logged keys (e.g. "val").
        problem: problem module selecting which heuristics to run.
    """
    if neptune_run is None:
        return

    if problem == GC:
        for strategy in ('smallest_last', 'largest_first', 'DSATUR'):
            key = prefix + "/greedy/" + strategy
            results = []
            elapsed_ms = []
            for graph in data:
                start = time.perf_counter_ns()
                outcome = get_heuristic_coloring(graph, strategy=strategy)
                stop = time.perf_counter_ns()
                results.append(outcome)
                # Nanoseconds -> milliseconds.
                elapsed_ms.append((stop - start) / 1000000.0)
            neptune_run[key] = np.asarray(results).mean()
            neptune_run[key + "_time_ms"].log(elapsed_ms)
    elif problem == MVC:
        for strategy in ('iterative', 'list_right'):
            covers = [get_heuristic_cover(graph, strategy=strategy) for graph in data]
            neptune_run[prefix + "/greedy/" + strategy] = np.asarray(covers).mean()
