import os
import time
from datetime import datetime

import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from tqdm.autonotebook import tqdm

import utils


def train_minimum(model, dataloader, lr=1e-4, steps=1000):
    """Fit `model` to the first batch of `dataloader` with a plain MSE loss.

    Only a single batch is drawn from `dataloader`; it is moved to the GPU
    once and reused for every optimisation step. The loss is the mean squared
    difference between ``model_output['model_out']`` and ``ground_truth['img']``.

    Args:
        model: Callable module taking the input dict and returning a dict
            that contains the key ``'model_out'``.
        dataloader: Iterable yielding ``(model_input, ground_truth)`` pairs of
            dicts whose values expose ``.cuda()``.
        lr: Learning rate handed to AdamW.
        steps: Index of the last optimisation step; ``steps + 1`` updates are
            performed so that the final step can be reported as well.
    """
    report_every = 200
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    # Draw exactly one batch and keep it resident on the GPU for the whole run.
    first_input, first_target = next(iter(dataloader))
    model_input = {name: tensor.cuda() for name, tensor in first_input.items()}
    ground_truth = {name: tensor.cuda() for name, tensor in first_target.items()}

    start_time = time.time()

    for step in tqdm(range(steps + 1)):
        prediction = model(model_input)
        residual = prediction['model_out'] - ground_truth['img']
        loss = (residual ** 2).mean()

        # The loss is reported before the update of the same step is applied.
        if step % report_every == 0:
            print("Step %d, Total loss %0.6f" % (step, loss))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Total time: {time.time() - start_time}(s)")


def _build_optimizer(model, lr, global_weights_lr_factor):
    """Create the Adam optimizer, optionally with a reduced lr for 'global' parameters."""
    if global_weights_lr_factor is None:
        return torch.optim.Adam(lr=lr, params=model.parameters())

    print("Splitting lr")
    regular_params = [param for name, param in model.named_parameters() if 'global' not in name]
    global_params = [param for name, param in model.named_parameters() if 'global' in name]
    return torch.optim.Adam([{'params': regular_params, 'lr': lr},
                             {'params': global_params, 'lr': lr / global_weights_lr_factor}], lr=lr)


def _prepare_batch(batch, double_precision):
    """Move every value of a batch dict to the GPU, optionally casting it to double."""
    batch = {key: value.cuda() for key, value in batch.items()}
    if double_precision:
        batch = {key: value.double() for key, value in batch.items()}
    return batch


def _run_validation(model, val_dataloader, loss_fn, double_precision):
    """Return the mean per-batch validation loss as a float, or None for an empty loader."""
    model.eval()
    with torch.no_grad():
        val_losses = []
        for model_input, gt in val_dataloader:
            # Validation batches get the same device / dtype treatment as training batches.
            model_input = _prepare_batch(model_input, double_precision)
            gt = _prepare_batch(gt, double_precision)
            batch_loss = loss_fn(model(model_input), gt)
            # loss_fn yields a dict of loss terms during training; a bare tensor is
            # tolerated here too. Terms are summed unweighted (no loss schedule applied).
            if isinstance(batch_loss, dict):
                val_losses.append(float(sum(term.mean().item() for term in batch_loss.values())))
            else:
                val_losses.append(float(batch_loss.mean().item()))
    model.train()
    return float(np.mean(val_losses)) if val_losses else None


def train(model, train_dataloader, epochs, lr, steps_til_summary, epochs_til_checkpoint, model_dir, loss_fn,
          summary_fn, val_dataloader=None, double_precision=False, clip_grad=False, loss_schedules=None,
          global_weights_lr_factor=None, summary_values_compare_fn=None):
    """Train `model` with Adam, logging to TensorBoard and writing checkpoints.

    Summaries are written to ``<model_dir>/summaries`` and checkpoints to
    ``<model_dir>/checkpoints``. All batches are moved to the GPU with ``.cuda()``.

    Args:
        model: Module called as ``model(model_input)``.
        train_dataloader: Sized iterable of ``(model_input, gt)`` dict pairs.
        epochs: Number of passes over `train_dataloader`.
        lr: Learning rate for Adam.
        steps_til_summary: Every this many steps ``model_current.pth`` is saved,
            `summary_fn` is called, progress is printed and validation is run.
        epochs_til_checkpoint: Every this many epochs (except epoch 0) the model
            weights and the training losses so far are saved.
        model_dir: Root directory for summaries and checkpoints.
        loss_fn: Called as ``loss_fn(model_output, gt)``; must return a dict
            mapping loss names to tensors.
        summary_fn: Called as ``summary_fn(model, model_input, gt, model_output,
            writer, total_steps)``; a non-None return value is collected.
        val_dataloader: Optional iterable of ``(model_input, gt)`` dict pairs.
        double_precision: If true, every batch value is cast with ``.double()``.
        clip_grad: ``True`` clips the gradient norm to 1.0; any other truthy
            value is used as the maximum norm itself.
        loss_schedules: Optional mapping from loss name to a callable taking the
            global step and returning the weight for that loss.
        global_weights_lr_factor: If given, parameters whose name contains
            ``'global'`` are trained with ``lr / global_weights_lr_factor``.
        summary_values_compare_fn: Optional ``fn(new, old)``; when it is truthy
            against every previously collected summary value, the weights are
            saved as ``model_best.pth``.

    Returns:
        Tuple ``(total_time, all_summary_values)``: wall-clock seconds spent in
        the training loop and the list of collected `summary_fn` results.
    """
    optim = _build_optimizer(model, lr, global_weights_lr_factor)

    summaries_dir = os.path.join(model_dir, 'summaries')
    utils.cond_mkdir(summaries_dir)

    checkpoints_dir = os.path.join(model_dir, 'checkpoints')
    utils.cond_mkdir(checkpoints_dir)

    writer = SummaryWriter(summaries_dir)

    total_steps = 0
    all_summary_values = []

    total_start_time = time.time()
    try:
        with tqdm(total=len(train_dataloader) * epochs) as pbar:
            train_losses = []
            for epoch in range(epochs):
                if epoch % epochs_til_checkpoint == 0 and epoch:
                    torch.save(model.state_dict(),
                               os.path.join(checkpoints_dir, 'model_epoch_%04d.pth' % epoch))
                    np.savetxt(os.path.join(checkpoints_dir, 'train_losses_epoch_%04d.txt' % epoch),
                               np.array(train_losses))

                for model_input, gt in train_dataloader:
                    start_time = time.time()
                    is_summary_step = total_steps % steps_til_summary == 0

                    model_input = _prepare_batch(model_input, double_precision)
                    gt = _prepare_batch(gt, double_precision)

                    model_output = model(model_input)
                    losses = loss_fn(model_output, gt)

                    train_loss = 0.
                    for loss_name, loss in losses.items():
                        single_loss = loss.mean()

                        if loss_schedules is not None and loss_name in loss_schedules:
                            # Evaluate the schedule once so the logged weight is the applied one.
                            weight = loss_schedules[loss_name](total_steps)
                            writer.add_scalar(loss_name + "_weight", weight, total_steps)
                            single_loss = single_loss * weight

                        writer.add_scalar(loss_name, single_loss, total_steps)
                        train_loss = train_loss + single_loss

                    train_losses.append(train_loss.item())
                    writer.add_scalar("total_train_loss", train_loss, total_steps)

                    if is_summary_step:
                        # Saved before the optimizer update of this step is applied.
                        torch.save(model.state_dict(),
                                   os.path.join(checkpoints_dir, 'model_current.pth'))
                        summary_values = summary_fn(model, model_input, gt, model_output, writer, total_steps)
                        if summary_values is not None:
                            is_best = summary_values_compare_fn is not None and all(
                                summary_values_compare_fn(summary_values, old_summary_values)
                                for old_summary_values in all_summary_values)
                            all_summary_values.append(summary_values)
                            if is_best:
                                torch.save(model.state_dict(), os.path.join(checkpoints_dir, 'model_best.pth'))

                    optim.zero_grad()
                    train_loss.backward()

                    if clip_grad:
                        max_norm = 1. if isinstance(clip_grad, bool) else clip_grad
                        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)

                    optim.step()

                    pbar.update(1)

                    if is_summary_step:
                        tqdm.write("Epoch %d, Total loss %0.6f, iteration time %0.6f" % (epoch, train_loss, time.time() - start_time))

                        if val_dataloader is not None:
                            print("Running validation set...")
                            val_loss = _run_validation(model, val_dataloader, loss_fn, double_precision)
                            if val_loss is not None:
                                writer.add_scalar("val_loss", val_loss, total_steps)

                    total_steps += 1

            total_time = time.time() - total_start_time
            print(f"Total time: {total_time:.2f}")
            torch.save(model.state_dict(),
                       os.path.join(checkpoints_dir, 'model_final.pth'))
            np.savetxt(os.path.join(checkpoints_dir, 'train_losses_final.txt'),
                       np.array(train_losses))

            return total_time, all_summary_values
    finally:
        # Flush pending events and release the event file even if training aborts.
        writer.close()


class LinearDecaySchedule():
    """Linear interpolation from `start_val` to `final_val` over `num_steps` steps.

    Calling the schedule with a step index returns `start_val` at step 0,
    `final_val` at step `num_steps` and beyond, and a linear blend in between.
    """

    def __init__(self, start_val, final_val, num_steps):
        self.start_val = start_val
        self.final_val = final_val
        self.num_steps = num_steps

    def __call__(self, iter):
        """Return the scheduled value at step `iter`.

        A non-positive `num_steps` is treated as an already-finished schedule
        (the result equals `final_val`) instead of dividing by zero.
        """
        # NOTE: the parameter name shadows the builtin but is kept so that
        # existing keyword callers (`schedule(iter=...)`) continue to work.
        if self.num_steps <= 0:
            fraction = 1.
        else:
            fraction = min(iter / self.num_steps, 1.)
        return self.start_val + (self.final_val - self.start_val) * fraction
