import os
import time
import logging
from tqdm import tqdm
from datetime import datetime
from collections import defaultdict

import numpy as np

import torch
from torch.utils.tensorboard import SummaryWriter

from .prepare import prepare, prepare_classifier

from evaluate.evaluate import evaluate, evaluate_loss
from utils.general_utils import initialize_logger, ljson, log_commit_hash_and_diff
from utils.data_utils import move_tensors

def train(args, prepare=prepare, evaluate=evaluate):
    """
    Trains a model (a VCI model with the default ``prepare``/``evaluate``).

    Runs ``args["max_epochs"]`` epochs over ``datasets["train_loader"]``,
    logging per-step and per-epoch scalars to TensorBoard. Every
    ``args["checkpoint_freq"]`` epochs, and on the final epoch, it evaluates
    the EMA model and saves ``(ema_model.state_dict(), args)`` under
    ``<artifact_path>/saves/<name>_<timestamp>/``.

    Args:
        args: dict of run configuration. If ``args["checkpoint"]`` is not
            None, the ``(state_dict, args)`` tuple stored there is loaded.
            Its args replace these, except for the keys in
            ``preserved_keys``, which keep their current-run values.
        prepare: callable invoked as ``prepare(args, state_dict=...,
            device=...)``. It must return ``(ema_model, model, datasets)``.
        evaluate: callable invoked as ``evaluate(ema_model, datasets,
            epoch=..., save_dir=..., **args)``. It must return
            ``(evaluation_stats, early_stop)``.

    Returns:
        The EMA model as it stands when training stops.
    """
    state_dict = None
    if args["checkpoint"] is not None:
        # Preserved keys keep their current-run values even when a checkpoint
        # is loaded. For example, "artifact_path" comes from the current run
        # environment rather than from the args stored in the checkpoint.
        preserved_keys = {
            "data_path", "artifact_path", "omega0", "omega1", "omega2",
            "lr", "generator_lr", "discriminator_lr", "classifier_lr",
        }
        preserved_args = {key: args[key] for key in preserved_keys if key in args}
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        state_dict, args = torch.load(args["checkpoint"], map_location="cpu")
        args.update(preserved_args)

    use_cuda = (not args["cpu"]) and torch.cuda.is_available()
    device = "cuda:" + str(args["gpu"]) if use_cuda else "cpu"

    if args["seed"] is not None:
        np.random.seed(args["seed"])
        torch.manual_seed(args["seed"])

    ema_model, model, datasets = prepare(args, state_dict=state_dict, device=device)

    dt = datetime.now().strftime("%Y.%m.%d_%H:%M:%S")
    writer = SummaryWriter(log_dir=os.path.join(args["artifact_path"], "runs/" + args["name"] + "_" + dt))
    # Close the writer even if setup or training raises, so buffered events
    # are flushed and the file handle is released.
    try:
        save_dir = os.path.join(args["artifact_path"], "saves/" + args["name"] + "_" + dt)
        os.makedirs(save_dir, exist_ok=True)

        initialize_logger(save_dir)
        ljson({"training_args": args})
        ljson({"model_params": model.hparams})
        logging.info("")

        log_commit_hash_and_diff(save_dir)

        start_time = time.time()
        train_loader = datasets["train_loader"]
        num_batches = len(train_loader)

        for epoch in range(args["max_epochs"]):
            epoch_training_stats = defaultdict(float)

            for batch_idx, batch in enumerate(train_loader):
                global_step = num_batches * epoch + batch_idx
                minibatch_training_stats = model.update(
                    move_tensors(*batch, device=device), batch_idx, ema_model
                )

                # Generator: the skip flag is popped so it is not accumulated
                # into the epoch averages below.
                g_did_skip_step = minibatch_training_stats.pop("g_did_skip_step")
                writer.add_scalar('G Grad Norm', minibatch_training_stats["g_grad_norm"], global_step=global_step)
                if g_did_skip_step:
                    epoch_training_stats["g_skipped_steps"] += 1
                writer.add_scalar('G Skipped_Gradient_Metric', 1 if g_did_skip_step else 0, global_step=global_step)

                # Discriminator stats are only present on minibatches where
                # model.update reported a discriminator step.
                if "d_did_skip_step" in minibatch_training_stats:
                    d_did_skip_step = minibatch_training_stats.pop("d_did_skip_step")
                    writer.add_scalar('D Grad Norm', minibatch_training_stats["d_grad_norm"], global_step=global_step)
                    if d_did_skip_step:
                        epoch_training_stats["d_skipped_steps"] += 1
                    writer.add_scalar('D Skipped_Gradient_Metric', 1 if d_did_skip_step else 0, global_step=global_step)

                for key, val in minibatch_training_stats.items():
                    epoch_training_stats[key] += val

                # Running count within the epoch. It uses its own tag because
                # the per-epoch average is written under "d_skipped_steps"
                # with the epoch as x-axis. Indexing the defaultdict also
                # ensures "d_skipped_steps" is always present in the epoch
                # stats.
                writer.add_scalar("d_skipped_steps_cumulative", epoch_training_stats["d_skipped_steps"], global_step=global_step)
                logging.debug("Done with batch %d", batch_idx)
            # Called once per epoch (presumably a learning-rate schedule
            # step -- verify against the model implementation).
            model.step()

            elapsed_minutes = (time.time() - start_time) / 60

            # Convert epoch sums into per-minibatch averages. "d_grad_norm" is
            # only accumulated on discriminator steps, so it is averaged over
            # roughly num_batches / discriminator_freq steps instead.
            for key, val in epoch_training_stats.items():
                if key != "d_grad_norm":
                    epoch_training_stats[key] = val / num_batches
                else:
                    epoch_training_stats[key] = val / (num_batches / model.hparams["discriminator_freq"])

            # Stop after the last epoch. A patience-based early stop is
            # reported by `evaluate` and checked after evaluation below.
            stop = epoch == args["max_epochs"] - 1
            for key, val in epoch_training_stats.items():
                writer.add_scalar(key, val, epoch)

            if (epoch % args["checkpoint_freq"]) == 0 or stop:
                evaluation_stats, early_stop = evaluate(ema_model, datasets,
                    epoch=epoch, save_dir=save_dir, **args
                )

                ljson(
                    {
                        "epoch": epoch,
                        "training_stats": epoch_training_stats,
                        "evaluation_stats": evaluation_stats,
                        "ellapsed_minutes": elapsed_minutes,
                    }
                )

                torch.save(
                    (ema_model.state_dict(), args),
                    os.path.join(
                        save_dir,
                        "model_seed={}_epoch={}.pt".format(args["seed"], epoch),
                    ),
                )
                ljson(
                    {
                        "model_saved": "model_seed={}_epoch={}.pt\n".format(
                            args["seed"], epoch
                        )
                    }
                )

                if stop:
                    ljson({"stop": epoch})
                    break
                if early_stop:
                    ljson({"early_stop": epoch})
                    break
    finally:
        writer.close()
    return ema_model

def train_classifier(args, prepare=prepare_classifier, evaluate=evaluate_loss):
    """
    Run the generic ``train`` loop with classifier-specific defaults.

    By default, model/data setup uses ``prepare_classifier`` and periodic
    evaluation uses ``evaluate_loss``. Both can still be overridden. The
    result of ``train`` is returned unchanged.
    """
    trained_model = train(args, evaluate=evaluate, prepare=prepare)
    return trained_model