"""
Use the Optuna library to find the Optimal Hyper-Parameters
Train with GunPoint dataset
First load the data
load the model
"""

### ------------------------------------
### --- Standard library / Third-Party ---
### ------------------------------------
from functools import partial
import argparse
import os
import sys
sys.path.append("..")
fileDir = os.path.dirname(os.path.abspath(__file__))
parentDir = os.path.dirname(fileDir)
sys.path.append(parentDir)
import numpy as np
import torch as t
import torch.nn as nn
from torchsummary import summary
import optuna
from optuna.trial import TrialState
import logging

## -----------
## --- Own ---
## -----------
from utils import read_dataset_ts, create_directory
from trainhelper.stopping import EarlyStoppingCallback
from trainhelper.dataset import Dataset, TrainValSplit
from trainhelper.trainer import Trainer
from models.models import TCN, FCN, TCN_laststep, FCN_laststep
from models.lstm import LSTM

# Use a CUDA device when PyTorch can see one; otherwise run everything on the CPU.
DEVICE = t.device("cuda") if t.cuda.is_available() else t.device("cpu")

def get_model(trial, num_cls, arguments):
    """Build the network named by ``arguments.DLModel`` with Optuna-sampled hyper-parameters.

    Every model samples ``dropout``, plus ``ch_out_l{i}`` and ``kernel_size_l{i}``
    for each of the ``arguments.CHOUT_NUM`` layers. The LSTM additionally samples
    ``hidden_size`` and ``num_layers``.

    Args:
        trial: Optuna trial used to sample the hyper-parameters.
        num_cls: number of output classes.
        arguments: parsed CLI namespace; reads ``DLModel``, ``UseFC``,
            ``Input_Channel``, ``CHOUT_NUM`` and (LSTM only) ``Bidirectional``.

    Returns:
        The newly constructed model.

    Raises:
        ValueError: if ``arguments.DLModel`` is not a supported model name.
    """
    supported_models = ("FCN", "FCN_withoutFC", "TCN", "TCN_withoutFC",
                        "TCN_laststep", "FCN_laststep", "LSTM")
    selected_model = arguments.DLModel
    if selected_model not in supported_models:
        # Names such as "BaselineFCN", "ResNet" or "LSTMInputCell" have no branch
        # below; fail fast with a clear message instead of an UnboundLocalError.
        raise ValueError(
            f"Unsupported DLModel {selected_model!r}; expected one of {supported_models}"
        )

    ## -------------------
    ## --- hyperparams ---
    ## -------------------
    use_fc = arguments.UseFC
    input_channel = arguments.Input_Channel
    chout_num = arguments.CHOUT_NUM

    dropout = trial.suggest_float("dropout", 0.2, 0.5)
    ch_out = []
    kernel_size = []
    for i in range(chout_num):
        # one output-channel count and one kernel size per convolutional layer
        ch_out.append(trial.suggest_int("ch_out_l{}".format(i), 4, 80))
        kernel_size.append(trial.suggest_int("kernel_size_l{}".format(i), 2, 7))

    if selected_model in ("FCN_withoutFC", "FCN"):
        model = FCN(ch_in=input_channel, ch_out=ch_out,
                    dropout_rate=dropout,
                    num_classes=num_cls,
                    kernel_size=kernel_size,
                    use_fc=use_fc)
    elif selected_model == "TCN_laststep":
        model = TCN_laststep(ch_in=input_channel, ch_out=ch_out,
                             kernel_size=kernel_size,
                             dropout_rate=dropout,
                             num_classes=num_cls)
    elif selected_model == "FCN_laststep":
        model = FCN_laststep(ch_in=input_channel, ch_out=ch_out,
                             dropout_rate=dropout,
                             num_classes=num_cls,
                             kernel_size=kernel_size)
    elif selected_model in ("TCN_withoutFC", "TCN"):
        model = TCN(ch_in=input_channel, ch_out=ch_out,
                    kernel_size=kernel_size,
                    dropout_rate=dropout,
                    use_fc=use_fc,
                    num_classes=num_cls)
    else:  # "LSTM" — the only remaining supported name
        hidden_size = trial.suggest_int("hidden_size", 1, 40)
        num_layers = trial.suggest_int("num_layers", 1, 2)
        model = LSTM(ch_in=input_channel,
                     hidden_size=hidden_size,
                     num_layers=num_layers,
                     dropout=dropout,
                     bidirectional=arguments.Bidirectional,
                     num_classes=num_cls)
    return model


def get_data(arguments, hyperparameters):
    """Load GunPointAgeSpan and wrap it in train / validation / test data loaders.

    Args:
        arguments: parsed CLI namespace; reads ``BatchSize`` and ``Root_Dir``.
        hyperparameters: dict updated in place with the per-class sample counts
            ``num_train``, ``num_validation`` and ``num_test``, each a list of
            ``(label, count)`` pairs.

    Returns:
        ``(trainloader, val_loader, testloader, num_cls)``, where ``num_cls`` is
        the number of distinct labels across the train and test sets.
    """
    ## parameters setting
    batch_size = arguments.BatchSize

    ## Load data
    root_dir = arguments.Root_Dir
    dataset_name = "GunPointAgeSpan"
    dataset = read_dataset_ts(root_dir, dataset_name)
    train_x, test_x, train_y, test_y, _label_dict = dataset[dataset_name]

    num_cls = len(np.unique(list(test_y) + list(train_y)))

    ## transfer train and test set into Torch Dataset
    trainset = Dataset(train_x, train_y)
    testset = Dataset(test_x, test_y)

    ## split the training set into train / validation parts
    # presumably the fraction of ``trainset`` held out for validation — confirm in TrainValSplit
    val_train_split = 0.2
    trainvalsplit = TrainValSplit(trainset, val_train_split=val_train_split)

    ## number of samples per class in each split (before any balancing)
    trainvalues, traincounts = np.unique(trainvalsplit.trainset.labels, return_counts=True)
    valvalues, valcounts = np.unique(trainvalsplit.valset.labels, return_counts=True)
    testvalues, test_classcounts = np.unique(testset.labels, return_counts=True)
    print(f"the Test set size: {len(testset)}")
    print(f"the number of each class in Test set: {test_classcounts}")

    ## record the class counts alongside the hyperparameters
    hyperparameters["num_train"] = list(zip(trainvalues, traincounts))
    hyperparameters["num_validation"] = list(zip(valvalues, valcounts))
    hyperparameters["num_test"] = list(zip(testvalues, test_classcounts))

    trainloader, val_loader = trainvalsplit.get_split(batch_size=batch_size,
                                                      num_workers=1)
    testloader = t.utils.data.DataLoader(
        testset, batch_size=batch_size, shuffle=False
    )

    return trainloader, val_loader, testloader, num_cls

def train(x, y, model, optimizer, loss_func):
    """Run one optimisation step of ``model`` on a single batch.

    Args:
        x: model input for the batch.
        y: target labels; a tensor, array or list, converted to ``long``.
        model: the network being trained.
        optimizer: optimizer over ``model``'s parameters.
        loss_func: loss called as ``loss_func(predicted, targets)``.

    Returns:
        ``(loss, predicted)``: the batch loss as a Python float and the raw
        model output.
    """
    optimizer.zero_grad()
    predicted = model(x)
    # as_tensor avoids copy-constructing (and warning about) a target that is
    # already a tensor; it keeps the tensor's device when ``y`` is a tensor.
    loss = loss_func(predicted, t.as_tensor(y, dtype=t.long))
    loss.backward()
    optimizer.step()

    return loss.item(), predicted

def test(x, y, model, loss_func):
    predicted = model(x)
    loss = loss_func(predicted,
                     t.tensor(y, dtype=t.long))
    # if self._class_weights is not None:
    #     weight_ = self._class_weights[y.view(-1).long()]
    #     loss = loss * weight_.to(DEVICE)
    #     loss = loss.sum()
    return loss.item(), predicted


def _prepare_batch(data, label):
    """Move one batch to ``DEVICE``; single-valued labels are reshaped to ``(len, 1)``."""
    xt = data.float().to(DEVICE)
    if np.size(label[0].cpu().numpy()) == 1:
        label = label.to(DEVICE).reshape((-1, 1))  ## label should be (len, 1)
    else:
        label = label.to(DEVICE)
    return xt, label


def _count_labels(label):
    """Return the number of label entries in a batch (the accuracy denominator)."""
    if np.size(label[0].cpu().numpy()) == 1:
        return len(label)
    return len(label) * len(label[0])


def _run_epoch(loader, step):
    """Apply ``step(x, label) -> (loss, output)`` to every batch of ``loader``.

    Returns:
        ``(mean_loss, accuracy_percent)`` over all batches.

    Raises:
        ValueError: if ``loader`` yields no samples (accuracy is undefined).
    """
    losses = []
    correct = 0
    sum_labels = 0
    for data, label in loader:  ## presumably data.size = [batch, dims, len_sample]
        xt, label = _prepare_batch(data, label)
        loss, output = step(xt, label)
        predicted_cls = t.argmax(output, dim=1)  ## index of the highest score
        correct += (predicted_cls == label).sum().item()
        sum_labels += _count_labels(label)
        losses.append(loss)
    if sum_labels == 0:
        raise ValueError("data loader yielded no samples")
    return np.mean(losses), (correct / sum_labels) * 100


def objective(arguments, trial):
    """Optuna objective: train one sampled model and return its validation accuracy.

    Args:
        arguments: parsed CLI namespace; reads ``Epochs`` and ``One_matrix`` here,
            plus whatever ``get_data`` / ``get_model`` read.
        trial: Optuna trial used to sample the hyper-parameters.

    Returns:
        Validation accuracy (percent) after the last epoch.

    Raises:
        NotImplementedError: if ``arguments.One_matrix`` is false.
        ValueError: if ``arguments.Epochs`` is smaller than 1, or a loader is empty.
    """
    epochs = arguments.Epochs
    # Check up front, before any data is loaded or any parameter is sampled.
    if not arguments.One_matrix:
        raise NotImplementedError("multiple input not implemented")
    if epochs < 1:
        raise ValueError(f"Epochs must be >= 1, got {epochs}")

    hyperparameters = {}

    ### Get the dataset (the test loader is not used during the search)
    trainloader, val_loader, _testloader, num_classes = get_data(arguments, hyperparameters)

    ### Generate the model
    model = get_model(trial, num_cls=num_classes, arguments=arguments).to(DEVICE)

    ### Generate the optimizer
    lr = trial.suggest_float("lr", 1e-3, 1e-1, log=True)
    optimizer = t.optim.Adam(model.parameters(), lr=lr)
    loss_func = nn.CrossEntropyLoss()

    train_step = partial(train, model=model, optimizer=optimizer, loss_func=loss_func)
    eval_step = partial(test, model=model, loss_func=loss_func)

    ### training the model
    for epoch in range(epochs):
        model.train()
        _run_epoch(trainloader, train_step)

        model.eval()
        with t.no_grad():
            _avg_val_loss, avg_val_acc = _run_epoch(val_loader, eval_step)

        # Intermediate values are only reported; no should_prune() check is made,
        # so trials are never pruned.
        trial.report(avg_val_acc, epoch)

    return avg_val_acc

def main(args, n_trials=1):
    """Evaluate ``objective`` for ``args`` inside a throw-away in-memory Optuna study.

    ``objective`` needs an Optuna trial, so it cannot be called with the
    arguments alone; this creates a non-persistent study and runs it.

    Args:
        args: parsed CLI namespace (see ``parse_arguments``).
        n_trials: number of trials to run (default 1).

    Returns:
        The best validation accuracy reached by the completed trials.
    """
    study = optuna.create_study(direction="maximize")
    study.optimize(partial(objective, args), n_trials=n_trials)
    return study.best_value

def parse_arguments(argv):
    """Parse the command-line options of the hyper-parameter search.

    Args:
        argv: argument list without the program name (e.g. ``sys.argv[1:]``);
            ``None`` falls back to ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` holding the options defined below.
    """
    def str2bool(value):
        # ``type=bool`` would turn every non-empty string, including "False", into True.
        lowered = value.strip().lower()
        if lowered in ("true", "t", "yes", "y", "1", "on"):
            return True
        if lowered in ("false", "f", "no", "n", "0", "off"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser()

    parser.add_argument("--Root_Dir", type=str, default="../")
    parser.add_argument("--DLModel", type=str, default="FCN_withoutFC")
    parser.add_argument("--Input_Channel", type=int, default=1)
    parser.add_argument("--CHOUT_NUM", type=int, default=4)
    parser.add_argument("--Epochs", type=int, default=500)
    parser.add_argument("--One_matrix", type=str2bool, default=True)
    parser.add_argument("--BatchSize", type=int, default=16)
    parser.add_argument("--UseFC", type=str2bool, default=False)
    parser.add_argument("--Bidirectional", type=str2bool, default=False)

    return parser.parse_args(argv)


if __name__ == "__main__":
    # objective_func = objective(arguments=parse_arguments(sys.argv[1:]))
    args = parse_arguments(sys.argv[1:])
    optuna.logging.get_logger("optuna").addHandler(logging.StreamHandler(sys.stdout))

    study_name = "GunPointAgeSpan_" + args.DLModel
    storage_name = "sqlite:///parameters_optimization/{}.db".format(study_name)
    study = optuna.create_study(direction="maximize", study_name=study_name, storage=storage_name)
    study.optimize(partial(objective, args), n_trials=30, timeout=3600)

    # pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
    complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])

    print("Study statistics: ")
    print("  Number of finished trials: ", len(study.trials))
    # print("  Number of pruned trials: ", len(pruned_trials))
    print("  Number of complete trials: ", len(complete_trials))

    print("Best trial:")
    trial = study.best_trial

    print("  Value: ", trial.value)

    print("  Params: ")
    for key, value in trial.params.items():
        print("    {}: {}".format(key, value))

    fig = optuna.visualization.plot_optimization_history(study)
    name_of_history = "parameters_optimization/" + storage_name + "_" + args.DLModel + "search_history.html"
    fig.write_html(name_of_history)