import time
import logging
import copy
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from utils.train_loop import inner_training_loop


# Copied from https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
def train_model(model, loss_type, device, dataloaders, optimizer, lr_scheduler, metrics_trackers,
                logger=None, ex=None, learning_rate=0.01, print_iter=np.inf, num_epochs=25, config_str='',
                virtual_batch_size=None, test=False):
    """Run the train/val(/test) loop and return a summary of the best epoch.

    Args:
        model: Model to train. Must provide ``state_dict()``, ``train()``,
            ``eval()`` and ``step_sinkhorn_reg()``; ``sinkhorn_reg`` and
            ``named_parameters()`` are read when TensorBoard logging is on.
        loss_type: Forwarded to ``inner_training_loop``. The value
            ``'margin_pairwise'`` additionally enables the validation PR curve.
        device: Forwarded to ``inner_training_loop``.
        dataloaders: Mapping from phase name (``'train'``, ``'val'``,
            ``'test'``) to either one loader or a ``list`` of loaders. Loaders
            yield ``(graphs1, graphs2, data_labels)``. A phase without an
            entry is skipped with a warning.
        optimizer: Forwarded to ``inner_training_loop``; its first param
            group's ``lr`` is logged to TensorBoard.
        lr_scheduler: Stepped once after every training phase.
        metrics_trackers: ``metrics_trackers['epoch'][phase]`` is required for
            every phase that runs; ``metrics_trackers['iter'][phase]`` is only
            read when ``print_iter`` is finite.
        logger: Logger to use; defaults to this module's logger.
        ex: Optional experiment object. When ``ex.current_run`` is truthy,
            stats are stored in ``ex.current_run.info``, TensorBoard events
            are written to ``logs`` and the best weights are attached as an
            artifact. (Looks like a Sacred experiment -- TODO confirm.)
        learning_rate: Unused; kept for interface compatibility.
        print_iter: Number of samples between progress lines. ``np.inf``
            (default) disables per-iteration logging.
        num_epochs: Maximum number of epochs.
        config_str: Text logged to TensorBoard under the ``config`` tag.
        virtual_batch_size: If given, must be a multiple of
            ``dataloaders['train'].batch_size``; the ``apply_grad`` flag
            passed to ``inner_training_loop`` is then only set every
            ``virtual_batch_size // batch_size`` batches.
        test: Also run a ``'test'`` phase in every epoch.

    Returns:
        dict with keys ``'rmse'`` (phase -> rmse at the best epoch, ``None``
        when that phase has no such entry), ``'time'`` (seconds),
        ``'best_epoch'``, ``'last_epoch'`` (``-1`` if no epoch ran) and
        ``'best_model_wts'``. Without a validation dataloader the final
        weights and the last epoch are reported as "best".
    """
    has_run = ex is not None and bool(ex.current_run)

    # TensorBoard dump switches (only relevant when a writer exists).
    tb_dump_embeddings = False
    tb_dump_weights = True

    if has_run:
        writer = SummaryWriter(log_dir='logs', flush_secs=5)
        writer.add_text('config', config_str)
        stats = ex.current_run.info
    else:
        writer = None
        stats = {}

    phases = ['train', 'val', 'test'] if test else ['train', 'val']

    for phase in phases:
        stats[phase] = {}
    stats['best_epoch'] = -1

    logger = logger or logging.getLogger(__name__)

    if virtual_batch_size is None:
        applygrad_iter = 1
    else:
        assert virtual_batch_size % dataloaders['train'].batch_size == 0, \
            "virtual_batch_size must be a multiple of the training batch size"
        applygrad_iter = virtual_batch_size // dataloaders['train'].batch_size

    show_iters = print_iter != np.inf

    start = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    epoch = -1  # remains -1 when num_epochs == 0, so the return value is well defined

    try:
        for epoch in range(num_epochs):

            epoch_title = f"Epoch {epoch:2d}/{num_epochs - 1:2d}"
            if show_iters:
                logger.info('')
                logger.info(epoch_title)
                logger.info('-' * len(epoch_title))

            epoch_metrics = None  # tracker of the last phase that ran in this epoch
            val_metrics = None    # validation tracker, if the validation phase ran

            # Each epoch has a training and validation (and optionally test) phase
            for phase in phases:
                if phase not in dataloaders:
                    logger.warning(f"Skipping {phase} due to missing dataloader")
                    continue
                epoch_start = time.time()

                if phase == 'train':
                    model.train()  # Set model to training mode
                else:
                    model.eval()   # Set model to evaluate mode

                iter_metrics = metrics_trackers['iter'][phase] if show_iters else None
                epoch_metrics = metrics_trackers['epoch'][phase]
                if phase == 'val':
                    val_metrics = epoch_metrics

                if iter_metrics is not None:
                    iter_metrics.reset()
                epoch_metrics.reset()

                loader = dataloaders[phase]
                if isinstance(loader, list):
                    hits_1_list, outputs, labels = _run_category_loaders(
                        model, optimizer, loader, phase, loss_type, device, applygrad_iter,
                        epoch_metrics, logger, show_iters)

                    if not show_iters:
                        temp = f"{epoch}/{num_epochs - 1},"
                        epoch_str = f"Epoch {temp:10}{phase:8} "
                        epoch_str += "".join(f"{hits_1:.4f} " for hits_1 in hits_1_list)
                        logger.info(epoch_str)
                else:
                    outputs, labels = _run_single_loader(
                        model, optimizer, loader, phase, loss_type, device, applygrad_iter,
                        epoch_metrics, iter_metrics, print_iter, logger)

                stat = epoch_metrics.get_log_dict()
                stats[phase][str(epoch)] = stat

                if writer is not None:
                    _write_tensorboard(writer, model, optimizer, loss_type, phase, epoch, stat,
                                       epoch_metrics, outputs, labels,
                                       tb_dump_weights, tb_dump_embeddings)

                epoch_str = "Epoch "
                if not show_iters:
                    temp = f"{epoch}/{num_epochs - 1},"
                    epoch_str += f"{temp:10}"
                epoch_str += f"{phase:8} {epoch_metrics.get_string()} ({time.time() - epoch_start:.2f}s)"
                logger.info(epoch_str)

                if phase == 'train':
                    lr_scheduler.step()
                    model.step_sinkhorn_reg()

                if phase == 'val' and epoch_metrics.is_best_val():
                    stats['best_epoch'] = epoch
                    best_model_wts = copy.deepcopy(model.state_dict())

            # Early stopping follows the validation tracker when validation ran
            # (previously the tracker of whichever phase ran last was asked, i.e.
            # the test tracker when test=True). Without validation, fall back to
            # the last tracker used; with no phase at all there is nothing to ask.
            patience_tracker = val_metrics if val_metrics is not None else epoch_metrics
            if patience_tracker is not None and patience_tracker.is_patience_over():
                break
    finally:
        # Make sure buffered TensorBoard events reach disk, even on failure.
        if writer is not None:
            writer.close()

    time_elapsed = time.time() - start
    logger.info('')
    logger.info('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

    if 'val' in dataloaders:
        logger.info(f"Best validation result: {metrics_trackers['epoch']['val'].get_best_values_string()} (epoch {stats['best_epoch']})")
    else:
        # No validation signal: the final weights are the result. This has to
        # happen before the artifact is written so the artifact matches the
        # weights that are returned.
        best_model_wts = copy.deepcopy(model.state_dict())
        stats['best_epoch'] = epoch

    if has_run:
        temp_path = '/tmp/graph_distance_model_weights_' + str(ex.current_run.config['seed'])
        torch.save(best_model_wts, temp_path)
        ex.current_run.add_artifact(temp_path, 'model_wts')

    # Phases that were skipped, or a best epoch that was never set, have no
    # entry; report None for them instead of raising KeyError.
    best_key = str(stats['best_epoch'])
    rmse = {phase: stats[phase].get(best_key, {}).get('rmse') for phase in phases}

    return {'rmse': rmse,
            'time': time_elapsed,
            'best_epoch': stats['best_epoch'],
            'last_epoch': epoch,
            'best_model_wts': best_model_wts}


def _run_category_loaders(model, optimizer, loaders, phase, loss_type, device, applygrad_iter,
                          epoch_metrics, logger, show_iters):
    """Run one phase over a list of loaders and store the mean loss / hits@1 in ``epoch_metrics``.

    The tracker is reset before each loader, read out after it, and finally
    set to the unweighted mean over all loaders via ``set_values``.

    Returns:
        ``(hits_1_list, outputs, labels)`` where ``outputs``/``labels`` belong
        to the last processed batch (``None`` if no batch was processed).
    """
    hits_1_list = []
    loss_list = []
    outputs = labels = None

    for dataloader in loaders:
        epoch_metrics.reset()
        for ibatch, (graphs1, graphs2, data_labels) in enumerate(dataloader):
            # With gradient accumulation only every applygrad_iter-th batch applies the gradient.
            apply_grad = ((ibatch + 1) % applygrad_iter == 0)

            outputs, labels, loss, _ = inner_training_loop(
                model, optimizer, dataloader, phase, loss_type, device,
                graphs1, graphs2, data_labels, applygrad_iter, apply_grad)

            # statistics
            with torch.set_grad_enabled(False):
                epoch_metrics.update(outputs, labels, loss)

        loss, hits_1 = epoch_metrics.readout(['loss', 'hits@1'])
        hits_1_list.append(hits_1)
        loss_list.append(loss)
        if show_iters:
            cat = dataloader.dataset.category
            logger.info(f"{cat:15}{hits_1:.4f}")

    final_hits = np.asarray(hits_1_list).mean()
    final_loss = np.asarray(loss_list).mean()
    epoch_metrics.set_values(final_loss, final_hits)
    return hits_1_list, outputs, labels


def _run_single_loader(model, optimizer, dataloader, phase, loss_type, device, applygrad_iter,
                       epoch_metrics, iter_metrics, print_iter, logger):
    """Run one phase over a single loader, logging progress every ``print_iter`` samples.

    ``iter_metrics`` may be ``None``; the caller passes ``None`` exactly when
    ``print_iter`` is infinite, in which case no progress line is ever due.

    Returns:
        ``(outputs, labels)`` of the last processed batch (``None`` each if
        the loader yielded nothing).
    """
    outputs = labels = None
    iter_time_start = time.time()
    curr_iter = 0            # number of samples seen so far in this phase
    next_print = print_iter  # sample count at which the next progress line is due

    for ibatch, (graphs1, graphs2, data_labels) in enumerate(dataloader):
        # With gradient accumulation only every applygrad_iter-th batch applies the gradient.
        apply_grad = ((ibatch + 1) % applygrad_iter == 0)

        outputs, labels, loss, batch_size = inner_training_loop(
            model, optimizer, dataloader, phase, loss_type, device,
            graphs1, graphs2, data_labels, applygrad_iter, apply_grad)

        # statistics
        with torch.set_grad_enabled(False):
            if iter_metrics is not None:
                iter_metrics.update(outputs, labels, loss)
            epoch_metrics.update(outputs, labels, loss)

        curr_iter += batch_size
        if curr_iter >= next_print:
            next_print += print_iter
            iter_str = f"{phase:5} {curr_iter:8d} "
            iter_str += iter_metrics.get_string()
            iter_str += f" ({time.time() - iter_time_start:.2f}s)"
            logger.info(iter_str)
            iter_metrics.reset()
            iter_time_start = time.time()

    return outputs, labels


def _write_tensorboard(writer, model, optimizer, loss_type, phase, epoch, stat, epoch_metrics,
                       outputs, labels, dump_weights, dump_embeddings):
    """Write one phase's epoch results (and, for 'val', model diagnostics) to TensorBoard."""
    for name, value in stat.items():
        # '@' in metric names (e.g. 'hits@1') is replaced by '_' in the tag.
        writer.add_scalar(str(phase) + '/' + name.replace('@', '_'), value, epoch)

    if phase != 'val':
        return

    writer.add_scalar('zisc/learning_rate', optimizer.param_groups[0]['lr'], epoch)
    writer.add_scalar('zisc/reg_sinkhorn', model.sinkhorn_reg, epoch)

    if loss_type == 'margin_pairwise':
        y = epoch_metrics.epoch_labels
        # Outputs are scaled by their maximum and inverted before being used as scores.
        x = 1. - epoch_metrics.epoch_outputs / epoch_metrics.epoch_outputs.max()
        writer.add_pr_curve('validation_pr', y, x, epoch)

    # The heavier dumps below are only written every fifth epoch.
    if epoch % 5 != 0:
        return

    if dump_weights:
        for name, param in model.named_parameters():
            writer.add_histogram(name, param, epoch)

    # Embedding tags need the last batch's labels/outputs, so skip if there was no batch.
    if dump_embeddings and outputs is not None:
        _, max_nodes, _ = model.node_embeddings[0].shape
        meta1 = meta2 = np.zeros(max_nodes)
        # Metadata marks which graph a node belongs to: 0 for the first, 1 for the second.
        metadata = np.concatenate([meta1, meta2 + 1])

        # Could maybe even run a fixed set of test graphs
        for idx in range(5):
            g1 = model.node_embeddings[0][idx, :, :]
            g2 = model.node_embeddings[1][idx, :, :]
            mat = torch.cat([g1, g2])
            writer.add_embedding(mat, metadata=metadata, tag=f'ep_{epoch:03d}_l_{labels[idx].item()}_o_{outputs[idx].item():.2f}_{idx}')
