import argparse
import os
from datetime import date

import numpy as np
import torch
from matplotlib import pyplot as plt
from sklearn import metrics
from sklearn.metrics import accuracy_score, ConfusionMatrixDisplay
import seaborn as sns
from torch import nn, optim
from torch.utils.data import DataLoader
from tqdm import tqdm

from models.lcnn import LcnnASV
from models.resnet import ResNet, Res2Net
from models.x_vector import X_vector
from speech_dataset import SpeechDataset, SpeechFeatureDataset

########## Batch collation helper (the argument parser is set up further below)
from utils.utils import speech_collate

# Experiment tag: selects the label map below and names the meta/, model_spk/,
# logs/ and exp/ sub-directories used throughout this script.
EXP_TAG = "voc-exp"

# Class names per experiment, listed in label-id order (position == integer id).
_CLASS_NAMES = {
    "voc-exp": ("real", "parallel_wavegan", "hifigan", "mb_melgan", "style_melgan"),
}

# Mapping: experiment tag -> {class name: integer label id}.
LABEL_IDS = {
    tag: {name: idx for idx, name in enumerate(names)}
    for tag, names in _CLASS_NAMES.items()
}

# Command-line interface.
# Numeric options take an explicit value (e.g. ``--batch_size 64``).  They must
# not use action="store_true": that action accepts no value and overwrites the
# numeric default with the boolean True whenever the flag is present.
parser = argparse.ArgumentParser(add_help=False)

# Manifest files for the three data splits.
parser.add_argument('--training_filepath', type=str,
                    default=f"meta/{EXP_TAG}/train_feats.txt")
parser.add_argument('--testing_filepath', type=str,
                    default=f"meta/{EXP_TAG}/eval_feats.txt")
parser.add_argument('--validation_filepath', type=str,
                    default=f"meta/{EXP_TAG}/dev_feats.txt")

# Model / run selection.
parser.add_argument('--input_dim', type=int, default=40)
parser.add_argument('--model', default="res2net")
# "train" runs the training loop; any other value runs evaluation of --ckpt.
parser.add_argument('--mode', type=str, default="train")
parser.add_argument('--ckpt', type=str, default="model_spk/exp2/res2net-20220726/best_check_point_26_0.9")

# Hyper-parameters.
parser.add_argument('--num_classes', type=int, default=5)
parser.add_argument('--lamda_val', type=float, default=0.1)
parser.add_argument('--batch_size', type=int, default=32)
# NOTE: store_true combined with default=True means this is True whether or not
# the flag is given; kept unchanged for command-line compatibility.  The device
# used by this script is chosen from torch.cuda.is_available().
parser.add_argument('--use_gpu', action="store_true", default=True)
parser.add_argument('--num_epochs', type=int, default=100)
parser.add_argument('--start_epoch', type=int, default=0)
args = parser.parse_args()

# Training split: reshuffled on every pass over the loader.
dataset_train = SpeechFeatureDataset(manifest=args.training_filepath, mode='train')
dataloader_train = DataLoader(dataset_train, batch_size=args.batch_size, shuffle=True, collate_fn=speech_collate)

# Validation / test splits: iterate in dataset order.  Shuffling brings nothing
# at evaluation time and makes the mean of per-batch losses vary from run to
# run whenever the last batch is smaller than the others.
dataset_val = SpeechFeatureDataset(manifest=args.validation_filepath, mode='val')
dataloader_val = DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False, collate_fn=speech_collate)

dataset_test = SpeechFeatureDataset(manifest=args.testing_filepath, mode='test')
dataloader_test = DataLoader(dataset_test, batch_size=args.batch_size, shuffle=False, collate_fn=speech_collate)

# Run on the GPU when one is available, otherwise on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print(args.training_filepath)

# Supported --model values mapped to zero-argument constructors, so that only
# the selected network is instantiated.
_MODEL_BUILDERS = {
    'xvector': lambda: X_vector(args.input_dim, args.num_classes),
    'lcnn': lambda: LcnnASV(args.input_dim, args.num_classes),
    'resnet': lambda: ResNet(args.num_classes),
    'res2net': lambda: Res2Net('SEBottle2neck', scale=8, num_classes=args.num_classes),
    'res2net_48': lambda: Res2Net('SEBottle2neck', baseWidth=48, scale=8, num_classes=args.num_classes),
}

# Fail with an explicit message instead of a NameError on `model` further down.
if args.model not in _MODEL_BUILDERS:
    raise ValueError(
        f"Unknown --model {args.model!r}; expected one of {sorted(_MODEL_BUILDERS)}")
model = _MODEL_BUILDERS[args.model]().to(device)

print(model)
print(args.model)

# Wrap the network in DataParallel; state dicts saved from here on are those
# of the wrapper.
model = nn.DataParallel(model)

# Adam optimiser over all parameters and a cross-entropy classification loss.
optimizer = optim.Adam(
    model.parameters(),
    lr=0.001,
    betas=(0.9, 0.98),
    eps=1e-9,
    weight_decay=0.0,
)
loss_fun = nn.CrossEntropyLoss()

# Today's date as a compact YYYYMMDD string; used in checkpoint and log paths.
auj = date.today().strftime("%Y%m%d")


def train(dataloader, epoch):
    """Run one optimisation pass over ``dataloader``.

    Uses the module-level ``model``, ``optimizer``, ``loss_fun`` and ``device``.

    Args:
        dataloader: Iterable of batches. ``batch[0]`` holds the per-sample
            features and ``batch[1]`` the per-sample labels; both are
            sequences of objects exposing ``.numpy()`` (presumably torch
            tensors -- confirm against ``speech_collate``).
        epoch: Epoch index, used only in the printed summary.

    Returns:
        ``(mean_loss, mean_acc)``: the mean of the per-batch losses and the
        accuracy computed over every sample seen during the pass.
    """
    model.train()
    batch_losses = []
    predicted = []
    targets = []

    for batch in tqdm(dataloader, total=len(dataloader)):
        # Each sample is transposed, then all samples are stacked into one batch.
        stacked = np.asarray([item.numpy().T for item in batch[0]])
        # Only the first element of each label entry is used.
        label_values = np.asarray([entry[0].numpy() for entry in batch[1]])

        features = torch.from_numpy(stacked).float().to(device)
        labels = torch.from_numpy(label_values).to(device)
        features.requires_grad = True

        optimizer.zero_grad()
        # The model returns a pair; only the logits are used here.
        pred_logits, _ = model(features)
        loss = loss_fun(pred_logits, labels)  # cross-entropy loss
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())
        predicted.extend(np.argmax(pred_logits.detach().cpu().numpy(), axis=1))
        targets.extend(labels.detach().cpu().numpy())

    mean_acc = accuracy_score(targets, predicted)
    mean_loss = np.mean(np.asarray(batch_losses))
    print(f'Total training loss {mean_loss} and training accuracy {mean_acc} after {epoch} epochs')
    return mean_loss, mean_acc


def validation(dataloader, epoch):
    """Evaluate the model on ``dataloader`` and save a checkpoint.

    Uses the module-level ``model``, ``optimizer``, ``loss_fun`` and ``device``.
    A checkpoint is written on every call, whether or not the accuracy
    improved, even though the file name starts with ``best_check_point``.

    Args:
        dataloader: Iterable of batches laid out as in ``train``
            (``batch[0]`` features, ``batch[1]`` labels).
        epoch: Epoch index; printed, stored in the checkpoint and embedded in
            the checkpoint file name.

    Returns:
        ``(mean_loss, mean_acc)``: the mean of the per-batch losses and the
        accuracy over every sample seen.
    """
    model.eval()
    with torch.no_grad():
        batch_losses, predicted, targets = [], [], []

        for batch in tqdm(dataloader, total=len(dataloader)):
            stacked = np.asarray([item.numpy().T for item in batch[0]])
            label_values = np.asarray([entry[0].numpy() for entry in batch[1]])
            features = torch.from_numpy(stacked).float().to(device)
            labels = torch.from_numpy(label_values).to(device)

            pred_logits, _ = model(features)
            batch_losses.append(loss_fun(pred_logits, labels).item())  # CE loss
            predicted.extend(np.argmax(pred_logits.detach().cpu().numpy(), axis=1))
            targets.extend(labels.detach().cpu().numpy())

        mean_acc = accuracy_score(targets, predicted)
        mean_loss = np.mean(np.asarray(batch_losses))
        print(f'Total validation loss {mean_loss} and validation accuracy {mean_acc} after {epoch} epochs')

        # Checkpoints go to model_spk/<experiment>/<model>-<date>/.
        ckpt_dir = os.path.join('model_spk', EXP_TAG, f'{args.model}-{auj}')
        if not os.path.exists(ckpt_dir):
            os.makedirs(ckpt_dir)
        model_save_path = os.path.join(
            ckpt_dir, f'best_check_point_{str(epoch)}_{str(mean_acc)}')
        torch.save(
            {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch},
            model_save_path,
        )

        return mean_loss, mean_acc


def plot_confusion_matrix(labels_all, predict_all, class_ids, filepath):
    """Draw the confusion matrix, save it to ``filepath`` and show it.

    Args:
        labels_all: Ground-truth labels, passed as ``y_true``.
        predict_all: Predicted labels, passed as ``y_pred``.
        class_ids: Tick labels for the matrix, passed as ``display_labels``.
        filepath: Destination of the saved figure.
    """
    display = ConfusionMatrixDisplay.from_predictions(
        y_true=labels_all, y_pred=predict_all,
        display_labels=class_ids, xticks_rotation="vertical", cmap="Blues", colorbar=False)
    plt.savefig(filepath, dpi=200, bbox_inches='tight')
    plt.show()
    # Close the figure once it has been saved and shown so it does not stay
    # registered with pyplot (the ROC figure in test() is closed the same way).
    plt.close(display.figure_)


def test(dataloader, model_path, class_ids):
    """Evaluate a saved checkpoint and write ROC and confusion-matrix plots.

    Uses the module-level ``model``, ``loss_fun`` and ``device``.  Plots are
    saved under ``exp/<EXP_TAG>/``, which is created if missing.

    Args:
        dataloader: Iterable of batches (``batch[0]`` features, ``batch[1]``
            labels), laid out as in ``train``.
        model_path: Checkpoint file readable by ``torch.load`` whose ``'model'``
            entry is a state dict for the module-level ``model``.
        class_ids: Mapping from class name to integer label id. The id is also
            used as the column index into the model's output scores.

    Returns:
        ``(acc, mean_loss, report, confusion)``: accuracy, summed per-batch
        loss divided by the number of batches, the text classification
        report, and the confusion matrix.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    # Names ordered by label id, so they line up with the sorted labels used
    # in the classification report and the confusion matrix.
    class_list = [name for name, _ in sorted(class_ids.items(), key=lambda item: item[1])]

    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only host.
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model'])
    model.eval()

    # Per-batch results are collected and concatenated once after the loop;
    # appending to a NumPy array element by element copies it every time.
    score_batches = []
    pred_batches = []
    label_batches = []
    with torch.no_grad():
        loss_total = 0
        for sample_batched in tqdm(dataloader, total=len(dataloader)):
            features = torch.from_numpy(np.asarray([sample.numpy().T for sample in sample_batched[0]])).float()
            labels = torch.from_numpy(np.asarray([torch_tensor[0].numpy() for torch_tensor in sample_batched[1]]))
            features, labels = features.to(device), labels.to(device)
            pred_logits, _ = model(features)

            #### CE loss
            loss = loss_fun(pred_logits, labels)
            loss_total += loss

            pred_score = pred_logits.detach().cpu().numpy()
            score_batches.append(pred_score)
            pred_batches.append(np.argmax(pred_score, axis=1))
            label_batches.append(labels.detach().cpu().numpy())

    if not score_batches:
        raise ValueError("test dataloader produced no batches")

    pred_score_all = np.concatenate(score_batches, axis=0)
    predict_all = np.concatenate(pred_batches)
    labels_all = np.concatenate(label_batches)

    acc = metrics.accuracy_score(labels_all, predict_all)

    # One-vs-rest ROC curve per class; the legend entry comes straight from the
    # mapping rather than from positional indexing into the name list.
    for name, lbl in class_ids.items():
        y_score = pred_score_all[:, lbl]
        fpr, tpr, _ = metrics.roc_curve(labels_all, y_score, pos_label=lbl)
        plt.plot(fpr, tpr, label=f"{name}", lw=1)

    report = metrics.classification_report(labels_all, predict_all, target_names=class_list, digits=4)
    confusion = metrics.confusion_matrix(labels_all, predict_all)

    # savefig does not create missing directories.
    os.makedirs(os.path.join("exp", EXP_TAG), exist_ok=True)

    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.0])
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.legend()
    plt.tight_layout()
    plt.savefig(f"exp/{EXP_TAG}/{args.model}-roc.pdf", dpi=200)
    plt.show()
    plt.close()

    plot_confusion_matrix(labels_all, predict_all, class_list, f"exp/{EXP_TAG}/{args.model}-confusion.pdf")
    return acc, loss_total / len(dataloader), report, confusion


def main():
    """Train the model or evaluate a checkpoint, depending on ``args.mode``.

    With ``args.mode == "train"`` each epoch runs ``train`` then ``validation``
    and appends one tab-separated line (epoch, train loss, train accuracy,
    validation loss, validation accuracy) to ``logs/<EXP_TAG>/<model>-<date>.log``.
    Any other mode evaluates ``args.ckpt`` on the test loader, prints the
    results and writes them to ``exp/<EXP_TAG>/<model>-report.txt``.

    Raises:
        ValueError: If ``args.mode`` is ``None``.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if args.mode is None:
        raise ValueError("You need to specify a mode")

    class_ids = LABEL_IDS[EXP_TAG]

    if args.mode == "train":
        # Create the log directory up front: otherwise the first write fails
        # only after a complete training + validation epoch has been spent.
        log_dir = os.path.join('logs', EXP_TAG)
        os.makedirs(log_dir, exist_ok=True)
        log_path = os.path.join(log_dir, f'{args.model}-{auj}.log')

        for epoch in range(args.start_epoch, args.num_epochs):
            t_loss, t_acc = train(dataloader_train, epoch)
            v_loss, v_acc = validation(dataloader_val, epoch)
            # Re-opened in append mode every epoch so that finished epochs are
            # on disk even if a later epoch fails.
            with open(log_path, mode="a+") as flog:
                flog.write(f"{epoch}\t{t_loss}\t{t_acc}\t{v_loss}\t{v_acc}\n")
    else:
        model_path = args.ckpt
        test_acc, test_loss, test_report, test_confusion = test(dataloader_test, model_path, class_ids)
        msg = 'Test Loss: {0:>5.2},  Test Acc: {1:>6.2%}'
        print(msg.format(test_loss, test_acc))
        print("Precision, Recall and F1-Score...")
        print(test_report)
        print("Confusion Matrix...")
        print(test_confusion)

        os.makedirs(os.path.join("exp", EXP_TAG), exist_ok=True)
        with open(f"exp/{EXP_TAG}/{args.model}-report.txt", mode="w") as frep:
            frep.write(test_report)
            frep.write(str(test_confusion))


if __name__ == '__main__':
    main()
