from __future__ import print_function

import os
import copy
import tqdm
import numpy as np
import matplotlib.pyplot as plt

from pathlib import Path
from collections import defaultdict
from scipy.signal import medfilt

import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

from train.pytorch_wrapper.eval_hook import EvalHook
from train.pytorch_wrapper.utils import BColors, EpochLogger, is_better

# Module-wide console helper used to colorize every status message printed below.
col = BColors()


class Network(object):
    """
    Training / inference wrapper around a torch model.

    Batches consumed by ``fit`` and ``find_lr`` are dicts with ``'inputs'`` and
    ``'targets'`` entries; both are dicts whose ``torch.Tensor`` values are moved
    to ``train_strategy.device`` before use. The model is called with the inputs
    dict and its output is indexed by key, i.e. treated as a dict.
    """

    def __init__(self, model):
        # the wrapped torch model
        self.net = model
        # set by ``fit``; holds a reference to ``self.net`` (not a copy)
        self.best_model = None

    def fit(self, train_loader, validation_loader, train_strategy,
            train_validation_loader=None, dump_file=None, log_file=None, resume_training=False,
            eval_hook: "EvalHook" = None):
        """
        Train ``self.net`` as configured by ``train_strategy``.

        Attributes read from ``train_strategy``: ``device``, ``criterion``
        (iterable of ``(name, weight, criterion)`` triples), ``optimizer``,
        ``lr_scheduler``, ``num_epochs``, ``clip_grad_norm``, ``full_set_eval``,
        ``eval_criteria``, ``best_model_by`` (``(log_key, mode)``),
        ``checkpoint_every_k`` and ``patience``.

        :param train_loader: iterable of training batches.
        :param validation_loader: optional iterable evaluated every epoch; its
            losses/metrics are logged with the ``va_`` prefix.
        :param train_strategy: training configuration (see above).
        :param train_validation_loader: optional second evaluation set, logged
            with the ``tr_`` prefix (only evaluated when ``validation_loader``
            is given).
        :param dump_file: path the model state dict is saved to; it is loaded
            back into ``self.net`` when training finishes.
        :param log_file: path the epoch logger is dumped to after every epoch.
        :param resume_training: reload logger and model state from ``log_file``
            and ``dump_file`` and continue after the last logged epoch.
        :param eval_hook: optional callable ``eval_hook(net) -> dict`` run every
            ``eval_hook.eval_every_k()`` epochs; results are logged as ``va_<key>``.
        :raises ValueError: if ``resume_training`` is set without both files, or
            if a ``ReduceLROnPlateau`` scheduler has no monitored value to step on.
        """

        # put stuff to device
        train_strategy = self._to_device(train_strategy)

        # initialize epoch logger
        logger = EpochLogger()

        # make sure output directories exist (nested ones included)
        if dump_file is not None:
            self._ensure_parent_dir(dump_file)
        if log_file is not None:
            self._ensure_parent_dir(log_file)

        # prepare for resuming training
        if resume_training:
            if log_file is None or dump_file is None:
                raise ValueError("resume_training requires both 'log_file' and 'dump_file'.")
            print(col.print_colored("Resuming training from previous model state ...", col.WARNING))
            logger.load(log_file)
            self.load(dump_file)

            # restore the learning rate that was logged last
            previous_lr = logger.epoch_stats["tr_learningrate"][-1]
            for param_group in train_strategy.optimizer.param_groups:
                param_group['lr'] = previous_lr

            # assumes one 'tr_learningrate' entry per completed epoch (appended below each epoch)
            previous_epoch = len(logger.epoch_stats["tr_learningrate"])
        else:
            previous_epoch = 0

        # iterate epochs
        best_params = None    # deep copy of the best state dict seen so far
        eval_value = None     # latest value of the model-selection metric
        last_improvement = 0  # epochs since the selection metric last improved
        mode = None
        for epoch in range(previous_epoch, train_strategy.num_epochs):

            # step epoch-based schedulers (ReduceLROnPlateau is stepped after evaluation);
            # skipped in the first epoch of a run so it starts at the initial/restored lr
            if train_strategy.lr_scheduler and not isinstance(train_strategy.lr_scheduler, ReduceLROnPlateau) and \
                    epoch > previous_epoch:
                train_strategy.lr_scheduler.step()

            # get epoch learning rate
            epoch_lr = train_strategy.optimizer.param_groups[0]['lr']

            # training
            # --------
            self.net.train()

            for data in tqdm.tqdm(train_loader,
                                  desc='Tr-Epoch %d/%d' % (epoch + 1, train_strategy.num_epochs)):
                inputs = self._batch_to_device(data['inputs'], train_strategy.device)
                targets = self._batch_to_device(data['targets'], train_strategy.device)

                # zero the parameter gradients
                train_strategy.optimizer.zero_grad()

                # forward pass + weighted loss (each term logged as tr_<name>_running)
                outputs = self.net(inputs)
                loss = self._weighted_loss(outputs, targets, train_strategy.criterion,
                                           logger=logger, key_fmt="tr_%s_running")

                # backward pass, optional gradient clipping, gradient step
                loss.backward()
                if train_strategy.clip_grad_norm is not None:
                    torch.nn.utils.clip_grad_norm_(self.net.parameters(), train_strategy.clip_grad_norm)
                train_strategy.optimizer.step()

                # collect total training loss
                if len(train_strategy.criterion) > 1:
                    logger.append("tr_loss_total_running", loss.item())

                # The first epoch of a run is cut off after a single batch so the first
                # evaluation round reflects a (nearly) untrained model. Note that this one
                # batch IS trained (optimizer.step() ran above); it also populates the
                # running training losses that are printed below.
                if epoch == previous_epoch:
                    break

            # evaluating
            # ----------
            if validation_loader is not None:
                self.net.eval()

                # collect validation loaders
                validation_sets = [("va", validation_loader)]
                if train_validation_loader:
                    validation_sets.append(("tr", train_validation_loader))

                # per-set buffers of model outputs and targets (used for full set evaluation)
                all_outputs = {set_name: defaultdict(list) for set_name, _ in validation_sets}
                all_targets = {set_name: defaultdict(list) for set_name, _ in validation_sets}

                with torch.no_grad():
                    for set_name, data_loader in validation_sets:
                        for i, data in enumerate(tqdm.tqdm(data_loader,
                                                           desc='Va-Epoch %d/%d' % (
                                                                   epoch + 1, train_strategy.num_epochs))):
                            inputs = self._batch_to_device(data['inputs'], train_strategy.device)
                            targets = self._batch_to_device(data['targets'], train_strategy.device)

                            outputs = self.net(inputs)

                            # weighted losses, each logged as <set_name>_<name>
                            loss = self._weighted_loss(outputs, targets, train_strategy.criterion,
                                                       logger=logger, key_fmt=set_name + "_%s")
                            if len(train_strategy.criterion) > 1:
                                logger.append("%s_loss_total" % set_name, loss.item())

                            if train_strategy.full_set_eval:
                                # buffer batch results; criteria are computed after the loop
                                for key in outputs.keys():
                                    all_outputs[set_name][key].append(outputs[key].to("cpu"))
                                for key in targets.keys():
                                    all_targets[set_name][key].append(targets[key].to("cpu"))
                            else:
                                # batch-wise evaluation; silent criteria only run on the first batch
                                for key in sorted(train_strategy.eval_criteria.keys()):
                                    if train_strategy.eval_criteria[key].silent and i > 0:
                                        continue
                                    args = train_strategy.eval_criteria[key].args
                                    args["epoch"] = epoch
                                    args["set_name"] = set_name
                                    value = train_strategy.eval_criteria[key].criterion(outputs, targets, args)
                                    logger.append("%s_%s" % (set_name, key), value)

                # full set evaluation
                if train_strategy.full_set_eval:
                    for set_name, _ in validation_sets:
                        outputs, targets = all_outputs[set_name], all_targets[set_name]

                        # concatenate batch data along the first dimension
                        for key in outputs.keys():
                            outputs[key] = torch.cat(outputs[key], 0)
                        for key in targets.keys():
                            targets[key] = torch.cat(targets[key], 0)

                        for key in sorted(train_strategy.eval_criteria.keys()):
                            criterion = train_strategy.eval_criteria[key]

                            # only compute criteria that apply to the current set
                            if criterion.sets is None or set_name in criterion.sets:
                                args = criterion.args
                                args["epoch"] = epoch
                                args["set_name"] = set_name
                                value = criterion.criterion(outputs, targets, args)
                                logger.append("%s_%s" % (set_name, key), value)

                # check for improvement and keep/dump the best model
                if train_strategy.best_model_by:
                    eval_key = train_strategy.best_model_by[0]
                    mode = train_strategy.best_model_by[1]
                    eval_value = logger.epoch_stats[eval_key][-1]
                    improved = is_better(eval_value, logger.epoch_stats[eval_key], mode)
                    last_improvement = 0 if improved else last_improvement + 1

                    if improved:
                        best_params = copy.deepcopy(self.net.state_dict())
                        if dump_file:
                            print(col.print_colored("Dumping new best model by %s!" % eval_key, col.WARNING))
                            self.save(dump_file)

            # add learning rate to logger
            logger.append("tr_learningrate", epoch_lr)

            # call eval hook with current model
            if eval_hook and np.mod(epoch, eval_hook.eval_every_k()) == 0:
                with torch.no_grad():
                    eval_dict = eval_hook(self.net)
                    for k, v in eval_dict.items():
                        logger.append("va_%s" % k, v)
                        print(col.print_colored("va_%s: %.3f" % (k, v), col.OKBLUE))

                    # summarize logged data
                    logger.summarize_epoch()
                    if train_strategy.best_model_by:
                        print(logger.epoch_stats)
                        eval_key = train_strategy.best_model_by[0]
                        eval_value = logger.epoch_stats[eval_key][-1]
                        mode = train_strategy.best_model_by[1]
                        if is_better(eval_value, logger.epoch_stats[eval_key], mode):
                            last_improvement = 0
                            self.best_model = self.net
                            # snapshot so a ReduceLROnPlateau reset has a state to return to
                            best_params = copy.deepcopy(self.net.state_dict())
                        else:
                            last_improvement += 1
            else:
                self.best_model = self.net
                # summarize logged data
                logger.summarize_epoch()

            if isinstance(train_strategy.lr_scheduler, ReduceLROnPlateau):
                # plateau scheduler needs the monitored metric
                if eval_value is None:
                    raise ValueError("ReduceLROnPlateau requires 'best_model_by' and a validation "
                                     "loader or eval hook that provides the monitored value.")
                train_strategy.lr_scheduler.step(eval_value)
                # lr was reduced: continue from the best state seen so far (if any)
                if epoch_lr != train_strategy.optimizer.param_groups[0]['lr'] and best_params is not None:
                    print(col.print_colored("Resetting model to previous best state and refining with reduced lr!",
                                            col.WARNING))
                    self.net.load_state_dict(best_params)

            # dump latest model
            elif dump_file:
                self.save(dump_file)

            # dump model checkpoints (needs a dump_file to derive the checkpoint name from)
            if dump_file and train_strategy.checkpoint_every_k and \
                    np.mod(epoch, train_strategy.checkpoint_every_k) == 0:
                self.save(self._checkpoint_path(dump_file, epoch))

            # dump log
            if log_file is not None:
                logger.dump(log_file)

            # print epoch stats
            # -----------------
            # loss keys that were actually logged this run ("tr_<name>" only exists when the
            # train-validation set was evaluated, which requires a validation loader)
            has_tr_validation = validation_loader is not None and bool(train_validation_loader)
            losses = []
            for name, _, _ in train_strategy.criterion:
                losses.append("tr_%s_running" % name)
                if validation_loader is not None:
                    losses.append("va_%s" % name)
                if has_tr_validation:
                    losses.append("tr_%s" % name)

            if len(train_strategy.criterion) > 1:
                losses.append("tr_loss_total_running")
                if validation_loader is not None:
                    losses.append("va_loss_total")
                if has_tr_validation:
                    losses.append("tr_loss_total")

            losses = sorted(losses)

            # losses are lower-is-better; highlight new minima in green
            txts = []
            for key in losses:
                txt = "%s: %.7f" % (key, logger.epoch_stats[key][-1])
                if is_better(logger.epoch_stats[key][-1], logger.epoch_stats[key], mode="min"):
                    txt = col.print_colored(txt, col.OKGREEN)
                txts.append(txt)

            tr_txts = [t for t in txts if "tr_" in t]
            va_txts = [t for t in txts if "va_" in t] if validation_loader is not None else []

            # print current learning rate and remaining patience
            tr_txts.append("lr: %.7f" % epoch_lr)
            if train_strategy.patience is not None:
                tr_txts.append("patience: %d" % (train_strategy.patience - last_improvement))

            print(" | ".join(tr_txts))
            if validation_loader is not None:
                print(" | ".join(va_txts))

            # print epoch evaluation criteria (everything logged that is not a loss)
            tr_keys = [k for k in logger.epoch_stats.keys() if "tr_" in k]
            va_keys = [k for k in logger.epoch_stats.keys() if "va_" in k] if validation_loader is not None else []

            for loss_key in losses:
                if loss_key in tr_keys:
                    tr_keys.remove(loss_key)
                if loss_key in va_keys:
                    va_keys.remove(loss_key)
            tr_keys.remove("tr_learningrate")

            for key_set in (tr_keys, va_keys):
                txts = []
                for key in key_set:
                    criterion_name = key.replace("tr_", "").replace("va_", "")

                    if criterion_name not in train_strategy.eval_criteria or \
                            train_strategy.eval_criteria[criterion_name].silent:
                        continue

                    eval_criterion = train_strategy.eval_criteria[criterion_name]
                    txt = "%s: %s" % (key, eval_criterion.format)
                    txt = txt % logger.epoch_stats[key][-1]
                    mode = eval_criterion.mode
                    if is_better(logger.epoch_stats[key][-1], logger.epoch_stats[key], mode):
                        txt = col.print_colored(txt, col.OKGREEN)
                    txts.append(txt)

                if txts:
                    print(" | ".join(txts))

            # early stopping
            if train_strategy.patience and last_improvement > train_strategy.patience:
                print(col.print_colored('Patience expired!', col.WARNING))
                break

            if epoch_lr <= 0.0:
                print(col.print_colored('Learning rate is zero!', col.WARNING))
                break

            # insert empty line for readability
            print("")

        # load best (or latest) dumped model
        if dump_file:
            self.load(dump_file)

        print(col.print_colored('Finished Training!', col.OKBLUE))

    def find_lr(self, train_loader, train_strategy, start_lr=1e-7, end_lr=10, num_it=100, plot=True):
        """
        Learning-rate range test.

        Trains for ``num_it`` iterations while growing the learning rate
        geometrically from ``start_lr`` towards ``end_lr`` and records the loss
        at each step. Model weights and optimizer lr are modified in place.

        :param train_loader: iterable of training batches (re-iterated if exhausted).
        :param train_strategy: provides ``device``, ``criterion``, ``optimizer``,
            ``lr_scheduler`` and ``clip_grad_norm``.
        :param plot: show a matplotlib plot of loss vs. log10(lr).
        :return: ``(lr_rates, lr_losses)`` with ``lr_rates`` in log10 scale.
        :raises ValueError: if ``train_loader`` yields no batches.
        """

        # put stuff to device
        train_strategy = self._to_device(train_strategy)

        # lr scheduler step
        if train_strategy.lr_scheduler and not isinstance(train_strategy.lr_scheduler, ReduceLROnPlateau):
            train_strategy.lr_scheduler.step()

        # training
        # --------
        self.net.train()

        # per-iteration growth factor: start_lr * factor ** num_it == end_lr
        lr_factor = np.power(end_lr / start_lr, 1. / num_it)

        lr_rates, lr_losses = [], []
        done, iteration = False, 0
        while not done:

            num_batches = 0
            for data in tqdm.tqdm(train_loader):
                num_batches += 1

                # set iteration learning rate
                lr = start_lr * lr_factor ** iteration
                for g in train_strategy.optimizer.param_groups:
                    g['lr'] = lr

                inputs = self._batch_to_device(data['inputs'], train_strategy.device)
                targets = self._batch_to_device(data['targets'], train_strategy.device)

                # forward + backward + optimize
                train_strategy.optimizer.zero_grad()
                outputs = self.net(inputs)
                loss = self._weighted_loss(outputs, targets, train_strategy.criterion)
                loss.backward()

                if train_strategy.clip_grad_norm is not None:
                    torch.nn.utils.clip_grad_norm_(self.net.parameters(), train_strategy.clip_grad_norm)

                train_strategy.optimizer.step()

                # collect (log10 lr, loss) pair
                lr_rates.append(np.log10(lr))
                lr_losses.append(loss.item())
                iteration += 1

                if iteration >= num_it:
                    done = True
                    break

            # an empty loader would otherwise make this loop spin forever
            if num_batches == 0:
                raise ValueError("find_lr: train_loader did not yield any batches.")

        # show result
        if plot:
            plt.figure("LR Plot")
            plt.clf()
            plt.plot(lr_rates, lr_losses, "bo-", alpha=0.5)
            plt.plot(lr_rates, medfilt(lr_losses, 11), "m--")
            plt.grid(True)
            plt.ylabel("loss")
            plt.xlabel("$log_{10}$(lr)")
            plt.title("Learning Rate Finder")
            plt.show()

        return lr_rates, lr_losses

    def predict_proba(self, x, device="cuda"):
        """
        Run the model in eval mode without gradients.

        :param x: dict of model inputs; its tensor values are moved to ``device``
            in place, other values are passed through unchanged.
        :param device: device to run on.
        :return: the model's output dict with tensor values converted to numpy arrays.
        """
        self.net.eval()

        with torch.no_grad():
            self._batch_to_device(x, device)

            out = self.net(x)

            # convert tensor outputs to numpy
            for key in out.keys():
                if isinstance(out[key], torch.Tensor):
                    out[key] = out[key].cpu().numpy()

            return out

    def predict(self, x, device="cuda"):
        """
        Return ``argmax(axis=1)`` of every array in ``predict_proba(x, device)``.

        :return: dict with the same keys as the model output; numpy values are
            replaced by their axis-1 argmax, other values are returned unchanged.
        """
        probas = self.predict_proba(x, device=device)
        return {key: value.argmax(axis=1) if isinstance(value, np.ndarray) else value
                for key, value in probas.items()}

    def save(self, file_path):
        """Save the model's state dict to ``file_path``."""
        torch.save(self.net.state_dict(), file_path)

    def load(self, file_path, map_location=None):
        """
        Load a state dict from ``file_path`` into the model.

        :param map_location: forwarded to ``torch.load`` (e.g. ``"cpu"`` to load a
            GPU checkpoint on a CPU-only machine); ``None`` keeps torch's default.
        """
        # NOTE(security): torch.load unpickles data - only load trusted files.
        self.net.load_state_dict(torch.load(file_path, map_location=map_location))

    def _to_device(self, train_strategy):
        """Move the model and all loss criteria to ``train_strategy.device``."""
        print("Putting stuff to %s ..." % train_strategy.device)
        self.net.to(train_strategy.device)
        for _, _, criterion in train_strategy.criterion:
            criterion.to(train_strategy.device)
        return train_strategy

    @staticmethod
    def _batch_to_device(batch, device):
        """Move every ``torch.Tensor`` value of ``batch`` to ``device`` in place; return ``batch``."""
        for key in batch.keys():
            if isinstance(batch[key], torch.Tensor):
                batch[key] = batch[key].to(device)
        return batch

    @staticmethod
    def _weighted_loss(outputs, targets, criteria, logger=None, key_fmt="%s"):
        """
        Return the sum of ``weight * criterion(outputs, targets)`` over ``criteria``.

        :param criteria: iterable of ``(name, weight, criterion)`` triples.
        :param logger: optional logger; each weighted term's scalar value is
            appended under ``key_fmt % name``.
        """
        loss = 0
        for name, weight, criterion in criteria:
            current_loss = weight * criterion(outputs, targets)
            if logger is not None:
                logger.append(key_fmt % name, current_loss.item())
            loss += current_loss
        return loss

    @staticmethod
    def _ensure_parent_dir(file_path):
        """Create the parent directory of ``file_path`` (recursively) if it is missing."""
        out_path = os.path.dirname(file_path)
        if out_path:
            os.makedirs(out_path, exist_ok=True)

    @staticmethod
    def _checkpoint_path(dump_file, epoch):
        """Return ``<dir>/<stem>_cpNNNNNN<suffix>`` for ``dump_file`` and ``epoch``."""
        path = Path(dump_file)
        return str(path.with_name("%s_cp%06d%s" % (path.stem, epoch, path.suffix)))
