import torch
import torch.nn as nn
import torch.nn.functional as F
import datetime as dt
import time
from tensorboardX import SummaryWriter

# Module-global tensorboard writer. `main` rebinds it to a fresh SummaryWriter
# for each trial and `train` logs the per-batch loss through it.
writer = None

from model import ModelCifar
from eval import evaluate
from loss import CustomLoss, KLLoss, KLLossRev, SupKLLoss, RewardLoss
from utils import *
from hyper_params import load_hyper_params
import argparse
import yaml
import numpy as np
import optuna
from functools import partial
import os
from copy import deepcopy
from optuna.trial import TrialState
from torch.utils.data import DataLoader, TensorDataset
from utils import create_tensors
from tqdm import tqdm
from data import load_data_fast


def add_subpath(path, id):
    """Return ``path`` with ``id`` appended as a new ``/``-separated component.

    ``id`` is converted with ``str``; ``path`` must already be a string.
    (The parameter name shadows the builtin ``id`` but is part of the
    existing interface.)
    """
    components = (path, str(id))
    return "/".join(components)


def dirname(path):
    """Return everything before the last ``/`` in ``path``.

    A path without any ``/`` yields the empty string; a trailing ``/`` is
    simply dropped.
    """
    parent, _, _ = path.rpartition("/")
    return parent


# Early-stopping patience: number of consecutive epochs without a
# validation-accuracy improvement tolerated before an experiment is stopped
# (see `main`).
STOP_THRESHOLD = 10


def train(trial, model, criterion, optimizer, scheduler, reader, hyper_params, device):
    """Train ``model`` for one epoch over ``reader`` and return epoch metrics.

    Every batch yielded by ``reader`` is a 6-tuple
    ``(x, y, action, delta, prop, labeled)``.  The model output is passed
    through a softmax and rows with ``labeled == 1`` feed the main loss, whose
    form is selected by ``hyper_params.experiment.feedback``:

    * ``"supervised"``: ``criterion(probs, y)``
    * ``"bandit"``: ``criterion(probs, action, delta, prop)``, or
      ``criterion(probs, action, delta)`` when ``hyper_params.as_reward`` is
      truthy
    * ``None``: no main loss (only regularizers contribute)

    The optional ``KL`` / ``KL2`` / ``SupKL`` entries of
    ``hyper_params.experiment.regularizers`` add weighted penalty terms.  With
    ``hyper_params["ignore_unlabeled"]`` set, ``KL`` / ``KL2`` only see rows
    with ``labeled > 0``; ``SupKL`` always uses the ``labeled == 1`` rows.

    Args:
        trial: Unused; kept so the call signature stays unchanged.
        model: Network mapping ``x`` to per-class scores.
        criterion: Main loss callable (may be ``None`` when feedback is
            ``None``).
        optimizer: Optimizer stepped once per batch that produced a gradient.
        scheduler: LR scheduler; stepped per batch for ``lr_sch == "OneCycle"``
            and once at the end of the epoch otherwise.
        reader: Iterable of batches as described above.
        hyper_params: Configuration supporting both item and attribute access.
        device: Device every batch tensor is moved to.

    Returns:
        dict with, in this order, ``main_loss`` and ``loss`` (per-batch
        means), ``Acc`` (top-1 accuracy in percent), ``AvgAcc`` (mean
        probability of the true class in percent), ``CV`` (per-batch mean of
        ``probs[action] / prop``) and ``SNIPS`` (summed ``delta``-weighted
        ratios divided by the summed ``CV`` terms), all rounded to 4 decimals.

    Raises:
        ValueError: If the feedback type is unknown and a batch contains
            labeled rows.
        ZeroDivisionError: If ``reader`` yields no batches.
    """
    ignore_unlabeled = hyper_params["ignore_unlabeled"]
    feedback = hyper_params.experiment.feedback
    regularizers = hyper_params.experiment.regularizers
    # OneCycle is advanced after every optimizer update, every other schedule
    # once per epoch.
    step_scheduler_per_batch = (
        "lr_sch" in hyper_params and hyper_params["lr_sch"] == "OneCycle"
    )
    model.train()

    # Plain Python accumulators: every increment below is already a Python
    # number obtained via ``.item()``.
    num_batches = 0
    num_samples = 0
    num_correct = 0
    main_loss_sum = 0.0
    total_loss_sum = 0.0
    control_variate_sum = 0.0
    ips_sum = 0.0
    true_class_prob_sum = 0.0

    for x, y, action, delta, prop, labeled in tqdm(reader):
        # Empty the gradients
        model.zero_grad()
        optimizer.zero_grad()

        x = x.to(device)
        y = y.to(device)
        action = action.to(device)
        delta = delta.to(device)
        prop = prop.to(device)
        # Keep the mask on the same device as the tensors it indexes.
        labeled = labeled.to(device)

        # Forward pass
        # NOTE(review): in supervised mode `main` passes nn.CrossEntropyLoss,
        # which expects raw logits, yet it receives softmax probabilities
        # here -- confirm this is intended.
        output = F.softmax(model(x), dim=1)

        labeled_mask = labeled == 1
        has_labeled = bool(labeled_mask.any())
        output_labeled = output[labeled_mask]
        y_labeled = y[labeled_mask]
        delta_labeled = delta[labeled_mask]
        prop_labeled = prop[labeled_mask]
        action_labeled = action[labeled_mask]

        if not has_labeled or feedback is None:
            loss = torch.zeros((), dtype=torch.float32, device=x.device)
        elif feedback == "supervised":
            loss = criterion(output_labeled, y_labeled)
        elif feedback == "bandit":
            if hyper_params.as_reward:
                loss = criterion(output_labeled, action_labeled, delta_labeled)
            else:
                loss = criterion(
                    output_labeled, action_labeled, delta_labeled, prop_labeled
                )
        else:
            raise ValueError(f"Feedback type {feedback} is not valid.")
        main_loss_sum += loss.item()

        if regularizers:
            if ignore_unlabeled:
                reg_mask = labeled > 0
                reg_output = output[reg_mask]
                reg_action = action[reg_mask]
                reg_prop = prop[reg_mask]
            else:
                reg_output, reg_action, reg_prop = output, action, prop

            if len(reg_output) > 0:
                if "KL" in regularizers:
                    loss += (
                        KLLoss(
                            reg_output,
                            reg_action,
                            reg_prop,
                            action_size=hyper_params["dataset"]["num_classes"],
                        )
                        * regularizers.KL
                    )
                if "KL2" in regularizers:
                    loss += (
                        KLLossRev(
                            reg_output,
                            reg_action,
                            reg_prop,
                            action_size=hyper_params["dataset"]["num_classes"],
                        )
                        * regularizers.KL2
                    )
            if "SupKL" in regularizers and has_labeled:
                loss += (
                    SupKLLoss(
                        output_labeled,
                        action_labeled,
                        delta_labeled,
                        prop_labeled,
                        regularizers.eps,
                        action_size=hyper_params["dataset"]["num_classes"],
                    )
                    * regularizers.SupKL
                )

        # A batch without labeled rows and without regularizers yields a
        # constant loss that has nothing to back-propagate.
        if loss.requires_grad:
            loss.backward()
            optimizer.step()
            if step_scheduler_per_batch:
                scheduler.step()

        loss_value = loss.item()
        # Log to tensorboard with an integer step.
        # NOTE(review): the step restarts at 0 every epoch, so points from
        # different epochs share the same x positions -- confirm intended.
        writer.add_scalar("train loss", loss_value, num_batches)

        # Metrics evaluation (no graph needed for bookkeeping).
        total_loss_sum += loss_value
        with torch.no_grad():
            rows = range(action.size(0))
            chosen_prob = output[rows, action]
            control_variate_sum += torch.mean(chosen_prob / prop).item()
            ips_sum += torch.mean((delta * chosen_prob) / prop).item()
            predicted = torch.argmax(output, dim=1)
            num_correct += (predicted == y).sum().item()
            true_class_prob_sum += output[rows, y].sum().item()
        num_samples += y.size(0)
        num_batches += 1

    if not step_scheduler_per_batch:
        scheduler.step()

    metrics = {}
    metrics["main_loss"] = round(main_loss_sum / num_batches, 4)
    metrics["loss"] = round(total_loss_sum / num_batches, 4)
    metrics["Acc"] = round(100.0 * float(num_correct) / float(num_samples), 4)
    metrics["AvgAcc"] = round(100.0 * true_class_prob_sum / float(num_samples), 4)
    metrics["CV"] = round(control_variate_sum / num_batches, 4)
    metrics["SNIPS"] = round(ips_sum / control_variate_sum, 4)

    return metrics


def _format_metrics(metrics, tag):
    """Render ``metrics`` as `` | name = value`` pairs followed by `` (tag)``."""
    pairs = "".join(f" | {name} = {value}" for name, value in metrics.items())
    return f"{pairs} ({tag})"


def _log_metrics(prefix, metrics, step):
    """Write every entry of ``metrics`` to the global tensorboard writer."""
    for name, value in metrics.items():
        writer.add_scalar(prefix + name, value, step)


def _build_scheduler(optimizer, hyper_params, steps_per_epoch):
    """Create the LR scheduler selected by the optional ``lr_sch`` setting."""
    # NOTE(review): `verbose=True` is deprecated in recent PyTorch releases --
    # confirm against the torch version this project pins.
    if "lr_sch" not in hyper_params:
        return torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=5, gamma=0.9, verbose=True
        )
    schedule = hyper_params["lr_sch"]
    if schedule == "OneCycle":
        return torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=hyper_params["lr"],
            epochs=hyper_params["epochs"],
            steps_per_epoch=steps_per_epoch,
        )
    if schedule == "CosineAnnealingLR":
        return torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=hyper_params["epochs"], verbose=True
        )
    # Previously an unknown value left `scheduler` unbound and surfaced later
    # as an UnboundLocalError.
    raise ValueError(f"Learning-rate schedule {schedule} is not valid.")


def main(trial, hyper_params, device="cuda:0", return_model=False):
    """Optuna objective: train ``n_exp`` models and return their mean test accuracy.

    The function works on a deep copy of ``hyper_params``.  It appends the
    trial id to the tensorboard/output/log/summary paths, samples the
    regularizer coefficients and the weight decay from the ``[low, high]``
    ranges stored in the configuration, and then runs
    ``hyper_params.experiment.n_exp`` independent trainings.  Each training
    evaluates on the validation set after every epoch; whenever validation
    ``Acc`` improves, the test set is evaluated (and the model optionally
    saved).  Training stops after ``STOP_THRESHOLD`` consecutive epochs
    without improvement.  Mean and standard deviation of the per-experiment
    best test metrics are written as YAML to the summary file.

    Args:
        trial: Optuna trial used for sampling and for the per-trial sub-path.
        hyper_params: Configuration supporting item and attribute access.
        device: Device passed to the data loader, the model and ``train``.
        return_model: Unused; kept for interface compatibility.

    Returns:
        Mean of the test ``Acc`` metric over all experiments that produced
        test metrics.

    Raises:
        ValueError: For an unknown feedback type or ``lr_sch`` value.
        RuntimeError: If no experiment produced test metrics (for example
            after an early keyboard interrupt).
    """
    hyper_params = deepcopy(hyper_params)

    # Give every trial its own tensorboard / checkpoint / log / summary path.
    trial_id = trial._trial_id
    for key in ("tensorboard_path", "output_path", "log_file", "summary_file"):
        hyper_params[key] = add_subpath(hyper_params[key], trial_id)
    print(dirname(hyper_params["summary_file"]), hyper_params["summary_file"])
    os.makedirs(dirname(hyper_params["tensorboard_path"]), exist_ok=True)
    os.makedirs(dirname(hyper_params["log_file"]), exist_ok=True)
    os.makedirs(dirname(hyper_params["summary_file"]), exist_ok=True)
    if hyper_params["save_model"]:
        os.makedirs(dirname(hyper_params["output_path"]), exist_ok=True)

    # Replace each configured [low, high] coefficient range by a sampled value.
    regularizers = hyper_params.experiment.regularizers
    if regularizers:
        for key, label in (
            ("KL", "KL"),
            ("KL2", "Reverse KL"),
            ("SupKL", "Supervised KL"),
        ):
            if key not in regularizers:
                continue
            bounds = getattr(regularizers, key)
            print(f"--> Regularizer {label} added: {bounds}")
            setattr(
                regularizers,
                key,
                trial.suggest_float(f"{key}_coef", bounds[0], bounds[1], log=True),
            )
    hyper_params["weight_decay"] = trial.suggest_float(
        "weight_decay",
        hyper_params["weight_decay"][0],
        hyper_params["weight_decay"][1],
        log=True,
    )
    print(hyper_params)

    # Initialize a tensorboard writer
    global writer
    writer = SummaryWriter(hyper_params["tensorboard_path"])

    log_file = hyper_params["log_file"]
    best_metrics_total = []
    try:
        train_reader, test_reader, val_reader = load_data_fast(
            hyper_params, labeled=False, device=device
        )

        file_write(
            log_file,
            "\n\nSimulation run on: " + str(dt.datetime.now()) + "\n\n",
        )
        file_write(log_file, "Data reading complete!")
        file_write(
            log_file,
            "Number of train batches: {:4d}".format(len(train_reader)),
        )
        file_write(
            log_file,
            "Number of test batches: {:4d}".format(len(test_reader)),
        )

        feedback = hyper_params.experiment.feedback
        if feedback == "supervised":
            print("Supervised Training.")
            criterion = nn.CrossEntropyLoss()
        elif feedback == "bandit":
            if hyper_params.as_reward:
                print("Reward Training")
                criterion = RewardLoss(hyper_params)
            else:
                print("Bandit Training")
                criterion = CustomLoss(hyper_params)
        elif feedback is None:
            criterion = None
        else:
            raise ValueError(f"Feedback type {feedback} is not valid.")

        separator = "-" * 89
        epoch_header = "\n| end of epoch {:3d} | time: {:5.2f}s"
        try:
            for exp in range(hyper_params.experiment.n_exp):
                if hyper_params["linear"]:
                    model = nn.Linear(
                        hyper_params["feature_size"],
                        hyper_params["dataset"]["num_classes"],
                    ).to(device)
                else:
                    model = ModelCifar(hyper_params).to(device)
                optimizer = torch.optim.SGD(
                    model.parameters(),
                    lr=hyper_params["lr"],
                    momentum=0.9,
                    weight_decay=hyper_params["weight_decay"],
                )
                scheduler = _build_scheduler(
                    optimizer, hyper_params, len(train_reader)
                )
                file_write(log_file, "\nModel Built!\nStarting Training...\n")
                file_write(
                    log_file,
                    f"################################ MODEL ITERATION {exp + 1}:\n--------------------------------",
                )

                # Start below any attainable accuracy so the first epoch
                # always records test metrics; starting at 0 could leave
                # `best_metrics` as None and crash the summary below.
                best_acc = float("-inf")
                best_metrics = None
                not_improved = 0
                for epoch in range(1, hyper_params["epochs"] + 1):
                    epoch_start_time = time.time()

                    # Training for one epoch
                    train_metrics = train(
                        trial,
                        model,
                        criterion,
                        optimizer,
                        scheduler,
                        train_reader,
                        hyper_params,
                        device,
                    )
                    _log_metrics(
                        f"Train_metrics/exp_{exp}/", train_metrics, epoch - 1
                    )

                    # Calculating the metrics on the validation set
                    val_metrics = evaluate(
                        model,
                        criterion,
                        val_reader,
                        hyper_params,
                        device,
                        labeled=False,
                    )
                    _log_metrics(
                        f"Validation_metrics/exp_{exp}/", val_metrics, epoch - 1
                    )

                    elapsed = time.time() - epoch_start_time
                    report = [
                        separator,
                        epoch_header.format(epoch, elapsed),
                        _format_metrics(train_metrics, "TRAIN"),
                        "\n",
                        separator,
                        epoch_header.format(epoch, elapsed),
                        _format_metrics(val_metrics, "VAL"),
                        "\n",
                        separator,
                    ]

                    if val_metrics["Acc"] > best_acc:
                        not_improved = 0
                        best_acc = val_metrics["Acc"]

                        # Test metrics are only refreshed at a new best
                        # validation accuracy.
                        test_metrics = evaluate(
                            model,
                            criterion,
                            test_reader,
                            hyper_params,
                            device,
                            labeled=True,
                        )
                        if hyper_params["save_model"]:
                            torch.save(
                                model.state_dict(), hyper_params["output_path"]
                            )
                        report += [
                            epoch_header.format(
                                epoch, time.time() - epoch_start_time
                            ),
                            _format_metrics(test_metrics, "TEST"),
                            "\n",
                            separator,
                        ]
                        _log_metrics(
                            f"Test_metrics/exp_{exp}/", test_metrics, epoch - 1
                        )
                        best_metrics = test_metrics
                    else:
                        not_improved += 1
                    file_write(log_file, "".join(report))

                    # NOTE: intermediate reporting / pruning through
                    # `trial.report` is currently not wired up.

                    if not_improved >= STOP_THRESHOLD:
                        print("STOP THRESHOLD PASSED.")
                        break
                if best_metrics is not None:
                    best_metrics_total.append(best_metrics)

        except KeyboardInterrupt:
            print("Exiting from training early")
    finally:
        # Close the event file even when loading or training fails.
        writer.close()

    if not best_metrics_total:
        raise RuntimeError(
            "No experiment produced test metrics; nothing to summarize."
        )

    model_summary = {}
    for name in best_metrics_total[0]:
        values = [run[name] for run in best_metrics_total]
        model_summary[name] = {
            "mean": float(np.mean(values)),
            "std": float(np.std(values)),
        }

    file_write(hyper_params["summary_file"], yaml.dump(model_summary))

    return model_summary["Acc"]["mean"]


# Command-line entry point: parse the options, resolve the path / dataset-name
# templates in the experiment config, run the Optuna study with `main` as the
# objective and write a short report of the best trial.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c", "--config", required=True, help="Path to experiment config file."
    )
    parser.add_argument("-d", "--device", required=True, help="Device", type=str)
    parser.add_argument(
        "--linear",
        required=False,
        action="store_true",
        help="If used, the learned policy is a linear model",
    )
    parser.add_argument(
        "-s",
        "--save_model",
        required=False,
        action="store_true",
        help="If used, the trained model is saved.",
    )
    parser.add_argument(
        "-l",
        "--ignore_unlabeled",
        required=False,
        action="store_true",
        help="If used, missing-reward instances are completely ignored.",
    )
    parser.add_argument(
        "--tau",
        required=False,
        type=str,
        help="Softmax temperature for the logging policy.",
    )
    parser.add_argument(
        "--add_unlabeled",
        required=False,
        type=int,
        default=0,
        help="The number of missing-reward instances to be added to the known-reward instances. Only use when ignore_unlabeled is true.",
    )
    parser.add_argument(
        "--ul",
        required=False,
        type=str,
        help="The ratio of missing-reward to known-reward samples.",
    )
    parser.add_argument(
        "--raw_image",
        action="store_true",
        help="If used, raw flatten image is given to the model instead of pretrained features.",
    )
    parser.add_argument(
        "--feature_size",
        type=int,
        help="If used, given feature size is supposed for the context.",
    )
    parser.add_argument(
        "--deeplog",
        action="store_true",
        help="If used, dataset generated by deep logging policy is used.",
    )
    parser.add_argument(
        "--wd",
        type=float,
        help="If used, weight decay is manually overwritten.",
    )
    parser.add_argument(
        "--kl",
        type=float,
        help="If used, KL coefficient is manually overwritten.",
    )
    parser.add_argument(
        "--kl2",
        type=float,
        help="If used, KL2 coefficient is manually overwritten.",
    )
    args = parser.parse_args()

    hyper_params = load_hyper_params(args.config)
    hyper_params["raw_image"] = args.raw_image
    hyper_params["linear"] = args.linear
    hyper_params["add_unlabeled"] = args.add_unlabeled
    hyper_params["save_model"] = args.save_model
    hyper_params["as_reward"] = False  # the --as_reward flag is currently disabled
    hyper_params["ignore_unlabeled"] = args.ignore_unlabeled

    # The configured dataset string is "<family>[_...][/...]"; the family
    # prefix selects the entry of `dataset_mapper`.
    full_dataset = hyper_params["dataset"]
    dataset = full_dataset.split("/")[0].split("_")[0]
    print("Dataset =", dataset, full_dataset)
    hyper_params["dataset"] = dataset_mapper[dataset]
    hyper_params["dataset"]["name"] = full_dataset
    if args.feature_size is not None:
        hyper_params["feature_size"] = args.feature_size
    else:
        hyper_params["feature_size"] = np.prod(hyper_params["dataset"]["data_shape"])

    path_keys = ("tensorboard_path", "output_path", "log_file", "summary_file")

    # Fill the ${UL} / ${TAU} placeholders of the dataset name and all paths.
    if "${UL}" in hyper_params.dataset["name"]:
        if args.ul is None or args.tau is None:
            # str.replace would otherwise fail with an opaque TypeError.
            parser.error(
                "--ul and --tau are required when the dataset name contains ${UL}."
            )
        substitutions = (("${UL}", args.ul), ("${TAU}", args.tau))
        for placeholder, value in substitutions:
            hyper_params.dataset["name"] = hyper_params.dataset["name"].replace(
                placeholder, value
            )
        for key in path_keys:
            for placeholder, value in substitutions:
                hyper_params[key] = hyper_params[key].replace(placeholder, value)

    # Tag paths and dataset name with the number of added unlabeled samples.
    if hyper_params["ignore_unlabeled"] and hyper_params["add_unlabeled"] > 0:
        unlabeled_tag = "_u" + str(hyper_params["add_unlabeled"])
        for key in path_keys:
            hyper_params[key] = hyper_params[key].replace(
                "_KL", unlabeled_tag + "_KL"
            )
        hyper_params["dataset"]["name"] = hyper_params["dataset"]["name"].replace(
            dataset, dataset + unlabeled_tag
        )

    # Manual overrides collapse the search range to a single value.
    if args.wd is not None:
        hyper_params["weight_decay"] = [args.wd, args.wd]
    for override, key in ((args.kl, "KL"), (args.kl2, "KL2")):
        if override is None:
            continue
        configured = hyper_params.experiment.regularizers
        # `not configured` also covers a config without any regularizers.
        if not configured or key not in configured:
            raise ValueError(f"Config does not allow {key} regularizer.")
        setattr(hyper_params.experiment.regularizers, key, [override, override])

    print("Ignoring unlabeled data?", hyper_params["ignore_unlabeled"])
    print("save model:", hyper_params["save_model"])
    result_path = hyper_params["summary_file"] + "/final.txt"
    print(result_path)
    print(hyper_params["summary_file"])
    print("Dataset =", hyper_params["dataset"])

    study = optuna.create_study(direction="maximize")
    study.optimize(
        partial(
            main, hyper_params=hyper_params, device=args.device, return_model=False
        ),
        n_trials=hyper_params.experiment.n_trials,
    )
    pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
    complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])

    report_lines = []

    def _emit(text, *values):
        """Print ``text`` and ``values``; record them as one report line."""
        print(text, *values)
        report_lines.append(text + "".join(str(value) for value in values))

    _emit("Study statistics: ")
    _emit("  Number of finished trials: ", len(study.trials))
    _emit("  Number of pruned trials: ", len(pruned_trials))
    _emit("  Number of complete trials: ", len(complete_trials))

    _emit("Best trial:")
    trial = study.best_trial

    _emit("  Value: ", trial.value)
    # The "Params" heading is only printed, it is not part of the report file.
    print("  Params: ")
    for key, value in trial.params.items():
        _emit("    {}: {}".format(key, value))
    file_write(result_path, "".join(line + "\n" for line in report_lines))
