import os
import time
from datetime import datetime
import math
import torch.utils.data
from trainer.utils import move_to, get_inner_model, collate_bl, collate
import torch
import trainer.baseline
from torch.nn import DataParallel
from torch.nn.parallel import DistributedDataParallel
import torch_geometric.utils
import torch_geometric.transforms
import torch_geometric.loader
import trainer.visualize
import matplotlib.pyplot as plt

class Runner:
    """Training / validation driver for a bucketed graph model.

    The runner owns the model, its baseline, the optimizer and the learning
    rate scheduler, and exposes `train`, `validate` and
    `qualitative_validation`. Training data is organised in "buckets": a list
    of datasets, each exposing a `graph_nodes` attribute, which are iterated
    in lockstep so that every optimisation step averages the loss over one
    batch from each bucket.
    """

    # Short codes for the supported decoding types; used when naming
    # checkpoint directories and uploaded checkpoint artifacts.
    _DECODING_CODES = {'local': 'L', 'global': 'G', 'static': 'S'}

    def __init__(self, model, train_data, val_data, baseline, optimizer, lr_scheduler, train_config, neptune_run=None):
        """
        Args:
            model: The model to train. Instance of Attention.
            train_data: train data can be a GraphDataset or a list thereof. ASSUMES THAT ALL DATASETS IN THE LIST HAVE EQUAL LENGTH.
            val_data: The validation data can be a GraphDataset or a list thereof.
            baseline: A RolloutBaseline wrapping the model.
            optimizer: Optimizer to use for optimization.
            lr_scheduler: Learning rate scheduler.
            train_config: TrainConfig that details aspects of the training such as batch size, learning rate. If checkpoint_dir is None, nothing will be written to the filesystem (the model is not stored).
            neptune_run: Used for logging. If None, does nothing.
        """
        self.model = model
        # Unwrap (Distributed)DataParallel so model-specific methods such as
        # `rollout` can be called directly.
        if isinstance(model, (DataParallel, DistributedDataParallel)):
            self.inner_model = model.module
        else:
            self.inner_model = model

        # Internally both data arguments are always lists of buckets.
        self.train_data = train_data if isinstance(train_data, list) else [train_data]
        self.val_data = val_data if isinstance(val_data, list) else [val_data]

        self.baseline = baseline

        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

        self.tb_logger = None

        self.neptune_run = neptune_run

        self.train_config = train_config

        # Set by `__create_checkpoint_directory` when checkpointing is enabled.
        self.save_dir = None

        # A rollout baseline attaches baseline values to the dataset items,
        # which needs the matching collate function.
        self.collate_fn = collate_bl if isinstance(self.baseline, trainer.baseline.RolloutBaseline) else collate

    def validate(self, num_samples=0):
        """Evaluate the model on every validation bucket.

        Args:
            num_samples: If 0, only a greedy rollout is performed. If > 0, a
                second rollout with `num_samples - 1` samples is performed and
                the element-wise minimum of both costs is used.

        Returns:
            Tuple `(costs, runtime)` of 1-D tensors with one entry per
            validation bucket: the mean cost and the elapsed rollout time in
            nanoseconds.
        """
        device = self.train_config.device
        batch_size = self.train_config.batch_size
        n_buckets = len(self.val_data)

        costs = torch.zeros(n_buckets, requires_grad=False, device=device)
        runtime = torch.zeros(n_buckets, requires_grad=False, device=device)

        for i, bucket in enumerate(self.val_data):
            tick = time.perf_counter_ns()

            cost = self.inner_model.rollout(bucket, graph_nodes=bucket.graph_nodes,
                                            batch_size=batch_size,
                                            device=device, num_samples=0)

            if num_samples > 0:
                cost_sample = self.inner_model.rollout(bucket, graph_nodes=bucket.graph_nodes,
                                                       batch_size=batch_size,
                                                       device=device, num_samples=num_samples - 1)
                cost = torch.minimum(cost, cost_sample)

            tock = time.perf_counter_ns()

            costs[i] = cost.mean().item()
            runtime[i] = tock - tick

        return costs, runtime

    def qualitative_validation(self, log_interval=200, save_pdf=True):
        """Visualise model results for every validation bucket.

        Args:
            log_interval: Only every `log_interval`-th graph of a bucket is visualised.
            save_pdf: Forwarded to the visualisation helper.
        """
        for bucket_index in range(len(self.val_data)):
            self._log_qualitative_results_for_bucket_index(bucket_index, log_interval=log_interval, save_pdf=save_pdf)

    def train(self, qualitative_log_interval=20):
        """
        Trains the model according to the parameters defined in initialization.

        Args:
            qualitative_log_interval: When a neptune run is attached, qualitative
                validation is executed every this many epochs. A falsy value
                disables it.

        Returns:
            Always 1.
        """
        if self.train_config.ckpt is not None:
            self.__create_checkpoint_directory()

        validation_error_did_decrease = True
        training_dataloaders = []

        # Global optimisation-step counter. It is advanced once per optimizer
        # step so that the `log_step` gating in the helpers fires every
        # `log_step` batches rather than once per epoch.
        step = 0

        for epoch in range(self.train_config.n_epochs):

            epoch_start_time = time.perf_counter()

            # update blvalues in first iteration or if bl model was updated
            if validation_error_did_decrease:
                # ASSUMES THAT ALL DATASETS HAVE EQUAL LENGTH
                training_dataloaders = [
                    torch.utils.data.DataLoader(
                        self.baseline.wrap_dataset(datum),
                        batch_size=self.train_config.batch_size,
                        shuffle=True,
                        collate_fn=self.collate_fn,
                        drop_last=True,
                    )
                    for datum in self.train_data
                ]

            self.model.train()
            self.inner_model.set_choose_randomly(True)

            # train on batches
            # This implements a "bucketing" approach to dealing with graphs of varying sizes.
            # Each dataloader has graphs of the same size
            for batches in zip(*training_dataloaders):
                losses = [
                    self.__get_batch_loss(step, batch, dataset.graph_nodes)
                    for batch, dataset in zip(batches, self.train_data)
                ]
                mean_loss = sum(losses) / len(losses)
                self.__make_train_step(step, mean_loss)
                step += 1

            epoch_duration = time.perf_counter() - epoch_start_time

            start_epoch_callback = time.perf_counter()
            validation_error_did_decrease, val_mean = self.baseline.epoch_callback(self.model, epoch)
            epoch_callback_duration = time.perf_counter() - start_epoch_callback

            # lr_scheduler should be called at end of epoch
            self.lr_scheduler.step()

            # Checkpoint improving models during the last
            # `checkpoint_last_epochs` epochs, and always after the final epoch.
            is_last_epoch = epoch == self.train_config.n_epochs - 1
            in_checkpoint_window = self.train_config.n_epochs - epoch <= self.train_config.checkpoint_last_epochs
            if (validation_error_did_decrease and in_checkpoint_window) or is_last_epoch:
                if self.train_config.ckpt is not None:
                    self._make_checkpoint(self.save_dir, epoch)

            if self.neptune_run is not None:
                self.neptune_run["val/cost"].log(val_mean)
                self.neptune_run["performance/epoch_time"].log(epoch_duration)
                self.neptune_run["performance/baseline_time"].log(epoch_callback_duration)
                if qualitative_log_interval and epoch % qualitative_log_interval == 0:
                    self.qualitative_validation(save_pdf=False)

        return 1

    def __get_batch_loss(self, step, batch, graph_nodes):
        """Compute the training loss for a single batch of one bucket.

        Args:
            step: Current optimisation step; only used to gate logging.
            batch: A batch as produced by the training dataloader.
            graph_nodes: The `graph_nodes` value of the bucket the batch came from.

        Returns:
            Scalar loss tensor: mean of `(cost - baseline) * log_likelihood`,
            with square roots of cost and baseline when `use_sqrt_cost` is set.
        """
        x, bl_val = self.baseline.unwrap_batch(batch)

        x = move_to(x, self.train_config.device)
        # NOTE(review): a None baseline value is tolerated here but the loss
        # expressions below would still fail on it; confirm whether baselines
        # without precomputed values need to be supported.
        bl_val = move_to(bl_val, self.train_config.device) if bl_val is not None else None

        # Evaluate model, get costs and log probabilities
        cost, _, log_likelihood = self.model(x, graph_nodes)

        # Loss sqrt-ing to stabilize loss
        if self.train_config.use_sqrt_cost:
            loss = ((torch.sqrt(cost) - torch.sqrt(bl_val)) * log_likelihood).mean()
        else:
            loss = ((cost - bl_val) * log_likelihood).mean()

        if step % int(self.train_config.log_step) == 0 and self.neptune_run is not None:
            avg_cost = cost.mean().item()
            self.neptune_run["train/cost"].log(avg_cost)
            self.neptune_run["train/neg_log_likelihood"].log(-log_likelihood.mean().item())

        return loss

    def __make_train_step(self, step, loss):
        """Backpropagate `loss`, clip gradients and apply one optimizer step.

        Args:
            step: Current optimisation step; only used to gate logging.
            loss: Scalar loss tensor to differentiate.
        """
        # Perform backward pass and optimization step
        self.optimizer.zero_grad()
        loss.backward()

        # Clip gradient norms and get (clipped) gradient norms for logging
        grad_norms = clip_grad_norms(self.optimizer.param_groups, self.train_config.clip_grad_norm)

        self.optimizer.step()

        # Logging
        if step % int(self.train_config.log_step) == 0 and self.neptune_run is not None:
            grad_norm_og, grad_norm_clipped = grad_norms

            # Hand plain Python numbers to the logger rather than tensors.
            self.neptune_run["train/grad_norm"].log([float(g) for g in grad_norm_og])
            self.neptune_run["train/grad_norm_clip"].log([float(g) for g in grad_norm_clipped])
            self.neptune_run["train/loss"].log(loss.item())

    def _train_batch(self, step, batch, graph_nodes=None):
        """Run a full training step (loss + optimizer update) on a single batch.

        Args:
            step: Current optimisation step; only used to gate logging.
            batch: A batch as produced by the training dataloader.
            graph_nodes: `graph_nodes` of the bucket the batch belongs to.
                Defaults to that of the first training bucket.
        """
        if graph_nodes is None:
            graph_nodes = self.train_data[0].graph_nodes
        loss = self.__get_batch_loss(step, batch, graph_nodes)
        self.__make_train_step(step, loss)

    def __create_checkpoint_directory(self):
        """Create the directory for this run's checkpoints and store it in `self.save_dir`.

        The path is `<ckpt>/<problem name>_<graph_nodes>/<model name>`, where
        the model name combines graph type, decoding code, RNG seed and a
        timestamp.

        Raises:
            AssertionError: If the model's decoding type is not supported.
            OSError: If the directory already exists or cannot be created.
        """
        # create file for this training run
        decoding_code = self._DECODING_CODES.get(self.inner_model.decoding_type, None)
        assert decoding_code is not None, "The decoding type {} is not supported".format(self.inner_model.decoding_type)

        if self.inner_model.encoder_type == 'GCN':
            decoding_code = decoding_code + '-GCN'

        timestamp = datetime.now().strftime("%d_%m_%Y_%H:%M:%S")
        graph_type = self.train_data[0].graph_type
        model_name = 'P-{}-{}-{}-{}'.format(graph_type, decoding_code, self.train_config.rng_seed, timestamp)

        save_dir = os.path.join(
            self.train_config.ckpt,
            "{}_{}".format(self.inner_model.problem.NAME, self.train_data[0].graph_nodes),
            model_name
        )
        os.makedirs(save_dir)
        self.save_dir = save_dir

    def _log_qualitative_results_for_bucket_index(self, i, log_interval=50, save_pdf=False):
        """Visualise the model output for a subset of graphs of one validation bucket.

        Args:
            i: Index of the bucket in `self.val_data`.
            log_interval: Only every `log_interval`-th graph is visualised.
            save_pdf: Forwarded to `trainer.visualize.visualize_result`.
        """
        bucket = self.val_data[i]
        n = bucket.graph_nodes

        # Do not evaluate very large graphs
        if n > 200:
            return

        loader = torch_geometric.loader.DataLoader(bucket, batch_size=1, shuffle=False)

        plt.ioff()
        for graph_index, graph in enumerate(loader):
            if graph_index % log_interval != 0:
                continue

            # NOTE(review): unlike the training path, the graph is not moved
            # to `train_config.device` here - confirm this is intended.
            with torch.no_grad():
                costs, order, _ = self.model(graph, n)

            nx_graph = torch_geometric.utils.to_networkx(graph, to_undirected=True)

            trainer.visualize.visualize_result(costs, order, nx_graph, self.inner_model.problem, graph_index, neptune_run=self.neptune_run, save_pdf=save_pdf)

        plt.close('all')

    def _make_checkpoint(self, save_dir, epoch):
        """Save the model's state dict to `save_dir` and upload it to neptune if attached.

        Args:
            save_dir: Directory the checkpoint file is written to.
            epoch: Epoch number, used in the file and artifact names.
        """
        save_name = os.path.join(save_dir, 'epoch-{}.pt'.format(epoch))
        torch.save(get_inner_model(self.model).state_dict(), save_name)

        if self.neptune_run is not None:
            decoding_code = self._DECODING_CODES.get(self.inner_model.decoding_type, None)
            self.neptune_run['model_checkpoints/{}-{}-epoch-{}'.format(self.train_config.rng_seed, decoding_code, epoch)].upload(save_name)


def clip_grad_norms(param_groups, max_norm=math.inf):
    """Clip the gradient L2 norm of every parameter group in place.

    Args:
        param_groups: Iterable of optimizer parameter groups, i.e. mappings
            with a 'params' entry holding the parameters to clip.
        max_norm: Maximum allowed norm per group. A value <= 0 disables
            clipping (the norms are still computed).

    Returns:
        Tuple `(grad_norms, grad_norms_clipped)`. The first list holds the
        total gradient norm of each group before clipping; the second holds
        those norms capped at `max_norm`, or is the very same list when
        clipping is disabled.
    """
    clipping_enabled = max_norm > 0
    # An infinite bound makes clip_grad_norm_ only measure the norm.
    effective_max = max_norm if clipping_enabled else math.inf

    grad_norms = []
    for group in param_groups:
        total_norm = torch.nn.utils.clip_grad_norm_(group['params'], effective_max, norm_type=2)
        grad_norms.append(total_norm)

    if not clipping_enabled:
        return grad_norms, grad_norms

    grad_norms_clipped = [min(g_norm, max_norm) for g_norm in grad_norms]
    return grad_norms, grad_norms_clipped