import argparse
import copy
import json
import os
import pickle
import time
from datetime import datetime

import numpy as np
import torch.cuda
import torchvision
import torchvision.transforms as transforms
from sklearn.metrics import roc_auc_score, f1_score, balanced_accuracy_score, recall_score
from torch.autograd import backward
from torch.utils.data.sampler import SubsetRandomSampler

from datasets import *
from datasets.ToyDataset import ToyDataset
from losses import *
from models import *


def load_net(checkpoint_path, net):
    """Load weights from a checkpoint written by ``save_model`` into ``net``.

    The checkpoint is a dict whose ``'net'`` entry holds a ``state_dict``.
    Tensors are first mapped to the CPU so a checkpoint saved on a GPU can be
    loaded on a CPU-only machine; ``load_state_dict`` then copies the values
    into ``net``'s existing parameters, on whatever device ``net`` lives on.

    Args:
        checkpoint_path: path to the checkpoint file.
        net: model instance to populate (modified in place).

    Returns:
        ``net`` itself.
    """
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    net.load_state_dict(checkpoint['net'])
    return net


def save_model(net, epochs, use_exp, save_dir):
    """Save ``net``'s weights to a timestamped ``.t7`` checkpoint in ``save_dir``.

    The training mode is read from the global ``args``. It is encoded both in
    the file name and in the checkpoint's ``'use_exp'`` entry, as one of
    ``no_exp``, ``focal``, ``feat``, ``sparse`` or ``grad_l<lambda>``. In the
    last form, ``<lambda>`` is ``args.l`` with its dots removed.

    Args:
        net: model whose ``state_dict()`` is saved.
        epochs: number of epochs trained so far; stored and used in the name.
        use_exp: unused; kept for backward compatibility with existing callers
            (the mode is derived from ``args`` instead).
        save_dir: destination directory; created together with any missing
            parent directories.
    """
    print('==> Saving model..')

    # makedirs also creates missing parents, e.g. "<model_save_dir>/early_stopping"
    # on a fresh run, where os.mkdir would raise FileNotFoundError.
    os.makedirs(save_dir, exist_ok=True)

    if args.no_exp:
        use_grads_str = 'focal' if args.focal_loss else 'no_exp'
    elif args.feat_loss:
        use_grads_str = 'sparse' if args.sparse_exp else 'feat'
    else:
        use_grads_str = 'grad_l{}'.format(str(args.l).replace(".", ""))

    state = {
        'net': net.state_dict(),
        'epochs': epochs,
        'use_exp': use_grads_str
    }

    file_name = '{0}_epochs{1}_{2}.t7'.format(
        use_grads_str, epochs, datetime.now().strftime('%m%d-%H%M'))
    # os.path.join keeps an absolute save_dir absolute; the former
    # './{dir}/...' formatting turned '/abs/dir' into the relative './/abs/dir'.
    torch.save(state, os.path.join(save_dir, file_name))


def save_results(data, save_dir):
    """Dump ``data`` as indented JSON into a timestamped file in ``save_dir``.

    The training mode is read from the global ``args``. It becomes the file
    name prefix, one of ``no_exp``, ``focal``, ``feat``, ``sparse`` or
    ``grad_l<lambda>``, where ``<lambda>`` is ``args.l`` with its dots removed.

    Args:
        data: JSON-serialisable results dictionary.
        save_dir: destination directory; created together with any missing
            parent directories.
    """
    print('==> Saving data..')

    # makedirs also creates missing parents, e.g. "<results_save_dir>/early_stopping"
    # on a fresh run, where os.mkdir would raise FileNotFoundError.
    os.makedirs(save_dir, exist_ok=True)

    if args.no_exp:
        use_grads_str = 'focal' if args.focal_loss else 'no_exp'
    elif args.feat_loss:
        use_grads_str = 'sparse' if args.sparse_exp else 'feat'
    else:
        use_grads_str = 'grad_l{}'.format(str(args.l).replace(".", ""))

    file_name = "{}_{}.json".format(use_grads_str, datetime.now().strftime('%m%d-%H%M'))
    with open(os.path.join(save_dir, file_name), 'w') as f:
        json.dump(data, f, indent=2)


def learn_epoch(epoch):
    """Run one training epoch over the global ``train_loader``.

    Reads module-level globals created in the ``__main__`` block: ``net``,
    ``train_loader``, ``device`` and ``args``. Depending on
    ``args.feat_loss``, it also reads either ``optimizer`` or the pair
    ``dense_optimizer`` / ``feat_optimizer``.

    Args:
        epoch: zero-based epoch index. Used in the log message. In feature-loss
            mode, feature-loss updates are skipped while ``epoch <= 2``.

    Returns:
        Tuple ``(mean_label_loss, mean_aux_loss, duration)``:

        - The first two values are averaged over batches.
        - ``mean_aux_loss`` is the explanation term: ``g_loss`` from
          ``mask_loss_binary`` (0 for ``--no_exp``), or the feature loss in
          ``--feat_loss`` mode.
        - ``duration`` is the wall-clock seconds spent in the
          loss/optimisation section only. The forward pass and accuracy
          bookkeeping happen before ``t1`` and are excluded.

    NOTE(review): ``batch_idx`` comes from the last loop iteration, so an empty
    ``train_loader`` raises NameError at the end.
    """
    net.train()

    total, correct = 0, 0
    duration = 0
    # iou_counter is never updated or returned.
    train_loss, iou_counter, gradient_diff = 0, 0, 0

    for batch_idx, (inputs, labels, masks) in enumerate(train_loader):
        inputs = inputs.to(device)
        # Track gradients w.r.t. the inputs; presumably mask_loss_binary needs
        # input gradients for its explanation penalty — TODO confirm.
        inputs.requires_grad_(True)

        experts_label = labels
        experts_exp = masks
        true_preds = torch.LongTensor(experts_label).to(device)
        true_exps = torch.BoolTensor(experts_exp).to(device)

        outputs = net(inputs)

        # Skewed data with --sparse_exp: keep explanation masks only for
        # samples labelled 1 and clear (all-False) the masks of every other sample.
        if args.skew_ratio != 0.5 and args.sparse_exp:
            keep_exp = true_preds == 1
            true_exps[~keep_exp] = torch.zeros(((~keep_exp).sum(), true_exps.shape[1]), dtype=torch.bool).to(device)

        # 0/1 predictions by rounding; assumes net outputs probabilities in
        # [0, 1] — TODO confirm.
        predicted = outputs.reshape(-1).detach().round()
        total += true_preds.size(0)
        correct += predicted.eq(true_preds).sum().item()

        t1 = time.time()
        if not args.feat_loss:
            if args.skew_ratio != 0.5 and args.no_exp and args.focal_loss:
                # Printed once per batch.
                print("Use focal loss")
                # NOTE(review): sigmoid_focal_loss applies a sigmoid to its
                # first argument (it expects logits). If net already outputs
                # probabilities (as the rounding above suggests), the sigmoid
                # is applied twice — confirm.
                loss = torchvision.ops.sigmoid_focal_loss(outputs.squeeze(), true_preds.float(), reduction='mean')
                # Zero placeholder so the g_loss bookkeeping below is uniform.
                g_loss = torch.FloatTensor([0])
            elif args.no_exp:
                loss = bce_loss(true_preds, outputs)
                g_loss = torch.FloatTensor([0])
            else:
                # Label loss plus explanation-mask term weighted by args.l.
                [loss, g_loss] = mask_loss_binary(true_preds, outputs, inputs, true_exps,
                                                       args.l, use_grad=not args.no_exp)

            gradient_diff += g_loss.detach().to("cpu").item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.detach().to("cpu").item()
            del loss
            del g_loss
        else:
            # Printed once per batch.
            print("Use feature map difference")
            # Step 1: ordinary label loss, optimised over all parameters.
            dense_loss = bce_loss(true_preds, outputs)
            dense_optimizer.zero_grad()
            backward(dense_loss)
            dense_optimizer.step()

            # Step 2: feature-difference loss. It is optimised with
            # feat_optimizer, which the main block builds over net.endo_map.
            # "*spurious" datasets take 5 inner steps per batch, others take 1.
            feat_opt_steps = 5 if args.dataset.endswith("spurious") else 1
            for i in range(feat_opt_steps):
                feature_loss = feature_difference_loss_kl(net, inputs, true_exps)
                # For epochs 0-2 the loss is only computed (for logging), not optimised.
                if epoch > 2:
                    feat_optimizer.zero_grad()
                    backward(feature_loss)
                    feat_optimizer.step()

            # Only the last inner step's feature loss is accumulated.
            gradient_diff += feature_loss.detach().to("cpu").item()
            train_loss += dense_loss.detach().to("cpu").item()

            del feature_loss
            del dense_loss
        t2 = time.time()
        del inputs

        duration += (t2 - t1)

    train_accuracy = correct / total
    print("Training accuracy for epoch {} is: {}".format(epoch, train_accuracy))
    print("Label loss is :{}".format(train_loss / (batch_idx + 1)))

    return train_loss / (batch_idx + 1), gradient_diff / (batch_idx + 1), duration


def validation(net, valid_loader):
    """Return the summed validation loss of ``net`` over ``valid_loader``.

    The project's ``bce_loss`` is evaluated batch by batch under
    ``torch.no_grad()``, and the per-batch values are added up (not averaged).
    The result is a CPU tensor, or the int ``0`` if the loader is empty.
    Inputs are moved to the global ``device``. The masks in each batch are
    ignored.
    """
    net.eval()
    with torch.no_grad():
        running_total = 0
        for batch_inputs, batch_labels, _ignored_masks in valid_loader:
            on_device = batch_inputs.to(device)
            targets = torch.LongTensor(batch_labels).to(device)
            batch_loss = bce_loss(targets, net(on_device))
            running_total += batch_loss.detach().to("cpu")
    return running_total


def test(net, test_loader):
    """Evaluate a binary classifier on ``test_loader`` with scikit-learn metrics.

    Predictions are the network outputs rounded to 0/1. This assumes ``net``
    emits probabilities in [0, 1] — TODO confirm. The raw outputs serve as
    the scores for ROC AUC. Inputs are moved to the global ``device``.

    Args:
        net: model to evaluate; switched to eval mode.
        test_loader: iterable of ``(inputs, labels, masks)`` batches; masks
            are ignored.

    Returns:
        Tuple ``(balanced_accuracy, auc, recall)``.

    Raises:
        ValueError: raised by numpy/scikit-learn. This happens when the loader
            yields no batches, or when the labels contain a single class
            (ROC AUC is undefined then).
    """
    net.eval()
    scores, y_true, y_pred = [], [], []

    with torch.no_grad():
        for inputs, labels, _masks in test_loader:
            inputs = inputs.to(device)
            labels = torch.LongTensor(labels).to(device)
            flat_outputs = net(inputs).reshape(-1).detach()

            scores.append(flat_outputs.cpu().numpy())
            y_true.append(labels.cpu().numpy())
            y_pred.append(flat_outputs.round().cpu().numpy())

    scores = np.concatenate(scores)
    y_true = np.concatenate(y_true)
    y_pred = np.concatenate(y_pred)

    # Only the metrics actually returned are computed. The previous unused
    # f1_score call (and raw accuracy counters) were dead work, and f1_score
    # can emit warnings.
    auc = roc_auc_score(y_true, scores)
    recall = recall_score(y_true, y_pred)
    balanced_accuracy = balanced_accuracy_score(y_true, y_pred)

    print("Test accuracy is: {}".format(balanced_accuracy))
    return balanced_accuracy, auc, recall


if __name__ == '__main__':
    parser = argparse.ArgumentParser("Epoch Learning Parser")
    parser.add_argument('--dataset', default='fox_cat', type=str)
    parser.add_argument('--batch_size', default=128, type=int, help='batch size')
    parser.add_argument('--dataset_size', default=1000, type=int)
    parser.add_argument('--lr', default=0.01, type=float, help='learning rate')
    parser.add_argument('--results_save_dir', type=str)
    parser.add_argument('--model_save_dir', type=str)
    parser.add_argument('--no_exp', action='store_true')
    parser.add_argument('--l', type=float, default=0.001)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--skew_ratio', type=float, default=0.5)
    parser.add_argument('--feat_loss', action='store_true')
    parser.add_argument('--feat_lr', type=float, default=0.001)
    parser.add_argument('--random_seed', type=int, default=88)
    parser.add_argument('--fix_random_seed', action='store_true')
    parser.add_argument('--trials', type=int, default=1)
    parser.add_argument('--model_save_freq', type=int, default=50)
    parser.add_argument('--feat_opt_decay_step', type=int, default=50)
    parser.add_argument('--opt_decay_step', type=int, default=50)
    parser.add_argument('--sparse_exp', action='store_true')
    parser.add_argument('--focal_loss', action='store_true')
    parser.add_argument('--early_stopping', action='store_true')
    parser.add_argument('--secondary_testset', default=None, type=str)
    parser.add_argument('--load_model_path', default=None, type=str)
    args = parser.parse_args()

    # --no_exp (plain label loss) and --feat_loss (feature-map loss) are
    # mutually exclusive. Unlike a bare assert, parser.error is not stripped
    # under `python -O` and prints the usage message.
    if args.no_exp and args.feat_loss:
        parser.error("Wrong parameter values!")

    if not args.fix_random_seed:
        np.random.seed(args.random_seed)
        random_seeds_list = np.random.randint(10000, size=args.trials)
    else:
        random_seeds_list = np.ones(args.trials, dtype=int) * args.random_seed

    transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # ---- Primary test set ---------------------------------------------------
    if args.dataset == "fox_cat" or args.dataset == "fox_cat_spurious":
        test_data_path = "data/fox_cat_mono_test.pickle"
    elif args.dataset == "triangle" or args.dataset == "triangle_spurious":
        test_data_path = "data/triangle_mono_test.pickle"
    elif args.dataset == "bird" or args.dataset == "bird_spurious":
        test_data_path = "data/indigo_bunting_blue_grosbeak_test.pickle"
    else:
        raise NotImplementedError
    # NOTE: pickle.load can execute arbitrary code; only load trusted data files.
    with open(test_data_path, 'rb') as file:
        d = pickle.load(file)

    test_loader = torch.utils.data.DataLoader(ToyDataset(d['data'], d['labels'], d['masks'], transform),
                                              batch_size=args.batch_size,
                                              shuffle=False)

    # ---- Optional secondary test set ----------------------------------------
    if args.secondary_testset is not None:
        if args.secondary_testset == 'triangle_circle':
            secondary_testset_path = "data/triangle_circle_mono_train.pickle"
        elif args.secondary_testset == 'triangle_spurious':
            secondary_testset_path = "data/triangle_spurious_test.pickle"
        elif args.secondary_testset == 'fox_cat_spurious':
            secondary_testset_path = "data/fox_cat_mono_spurious_test.pickle"
        elif args.secondary_testset == 'pentagon':
            secondary_testset_path = "data/pentagon_mono_test.pickle"
        elif args.secondary_testset == 'bird_spurious':
            secondary_testset_path = "data/indigo_bunting_blue_grosbeak_spurious_test.pickle"
        else:
            raise NotImplementedError
        with open(secondary_testset_path, 'rb') as file:
            d = pickle.load(file)
        test_loader2 = torch.utils.data.DataLoader(ToyDataset(d['data'], d['labels'], d['masks'], transform),
                                                   batch_size=128,
                                                   shuffle=False)

    # ---- Validation set (only needed for early stopping) --------------------
    if args.early_stopping:
        if args.dataset == "fox_cat":
            valid_data_path = "data/fox_cat_mono_valid.pickle"
        elif args.dataset == "triangle":
            valid_data_path = "data/triangle_mono_valid.pickle"
        else:
            raise NotImplementedError
        with open(valid_data_path, 'rb') as file:
            d_valid = pickle.load(file)
        valid_loader = torch.utils.data.DataLoader(
            ToyDataset(d_valid['data'], d_valid['labels'], d_valid['masks'], transform),
            batch_size=128,
            shuffle=False)

    test_accuracies = []
    test_aucs = []
    test_recalls = []
    test_accuracies2 = []
    test_aucs2 = []
    test_recalls2 = []
    train_label_loss = []
    train_g_loss = []
    train_ious = []
    train_times = []
    trial_times = []
    # Best test performance across ALL trials and epochs (not reset per trial).
    best_test_perf = 0
    best_test_perf2 = 0

    for trial in range(args.trials):
        print("Trial {}/{}".format(trial+1, args.trials))
        print("Current random seed is {}".format(random_seeds_list[trial]))

        # Seed torch as well. Previously only the numpy-driven data sampling
        # was seeded, so weight initialisation and DataLoader shuffling made
        # trials irreproducible despite --random_seed / --fix_random_seed.
        torch.manual_seed(int(random_seeds_list[trial]))

        if args.dataset.startswith("bird"):
            net = BirdSimpleConv(endo=args.feat_loss).to(device)
        else:
            net = ToyConvBinary(input_channel=1).to(device)

        if args.load_model_path is not None:
            net = load_net(args.load_model_path, net)

        if args.early_stopping:
            best_valid_loss = 99999
            early_stopping_counter = 0
            best_test_acc = 0
            best_train_loss = 9999
            best_train_g_loss = 0

        print("Generating data")
        if args.dataset == "fox_cat":
            data, labels, masks = create_fox_cat_dataset(args.dataset_size, args.skew_ratio,
                                                         random_seed=random_seeds_list[trial])
        elif args.dataset == "triangle":
            data, labels, masks = create_triangle_dataset(args.dataset_size, args.skew_ratio,
                                                          random_seed=random_seeds_list[trial])
        elif args.dataset == "triangle_spurious" or args.dataset == "fox_cat_spurious":
            if args.dataset == "triangle_spurious":
                train_path = "./data/triangle_spurious_train.pickle"
            else:
                train_path = "./data/fox_cat_mono_spurious_train.pickle"
            with open(train_path, "rb") as f:
                d = pickle.load(f)
            data = d['data']
            labels = d['labels']
            masks = d['masks']

            np.random.seed(random_seeds_list[trial])
            idx = np.random.choice(np.arange(data.shape[0]), size=args.dataset_size, replace=False)
            data = np.array(data)[idx]
            labels = np.array(labels)[idx]
            masks = np.array(masks)[idx]
        elif args.dataset.startswith("bird"):
            if args.dataset == "bird":
                train_path = "./data/indigo_bunting_blue_grosbeak_train.pickle"
            else:
                train_path = "./data/indigo_bunting_blue_grosbeak_spurious_train.pickle"
            with open(train_path, "rb") as f:
                d = pickle.load(f)
            data = d['data']
            labels = d['labels']
            masks = d['masks']

            if args.dataset_size < 59:
                np.random.seed(random_seeds_list[trial])
                if args.skew_ratio == 0.5:
                    # Balanced sampling
                    idx = np.random.choice(np.arange(len(data)), size=args.dataset_size, replace=False)
                else:
                    class1_idx = np.where(np.array(labels) == 1)[0]
                    class0_idx = np.where(np.array(labels) == 0)[0]
                    idx1 = np.random.choice(class1_idx, size=int(args.skew_ratio * args.dataset_size), replace=False)
                    idx0 = np.random.choice(class0_idx, size=args.dataset_size - int(args.skew_ratio * args.dataset_size), replace=False)
                    idx = np.concatenate([idx0, idx1])

                data = np.array(data)[idx]
                labels = np.array(labels)[idx]
                masks = np.array(masks)[idx]
            if args.dataset_size > 60 or (args.dataset_size == 60 and args.skew_ratio != 0.5):
                raise AttributeError("We don't have so many points!!")
        else:
            raise NotImplementedError
        dataset = ToyDataset(data, labels, masks, transform)
        train_loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True)

        if not args.feat_loss:
            if args.dataset.startswith('fox_cat') or args.dataset.startswith('triangle'):
                optimizer = torch.optim.Adam(net.parameters(), lr=args.lr)
            else:
                optimizer = torch.optim.SGD(net.parameters(), lr=args.lr)
        else:
            if args.dataset.startswith('fox_cat') or args.dataset.startswith('triangle'):
                feat_optimizer = torch.optim.Adam(net.endo_map.parameters(), lr=args.feat_lr)
                dense_optimizer = torch.optim.Adam(net.parameters(), lr=args.lr)
            else:
                feat_optimizer = torch.optim.SGD(net.endo_map.parameters(), lr=args.feat_lr)
                dense_optimizer = torch.optim.SGD(net.parameters(), lr=args.lr)
            # NOTE(review): these schedulers are never .step()-ped, so the
            # learning rates stay constant. Either step them once per epoch or
            # remove them (changing this alters training dynamics).
            feat_optimizer_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(feat_optimizer, T_max=args.epochs)

            dense_optimizer_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(dense_optimizer, T_max=args.epochs)

        test_accuracy_epoch = []
        test_auc_epoch = []
        test_recall_epoch = []
        test_accuracy_epoch2 = []
        test_auc_epoch2 = []
        test_recall_epoch2 = []
        train_label_loss_epoch = []
        train_g_loss_epoch = []
        train_iou_epoch = []
        epoch_time = []
        trial_time = 0
        for epoch in range(args.epochs):
            label_loss, g_loss, train_time = learn_epoch(epoch)

            epoch_time.append(train_time)
            trial_time += train_time

            train_label_loss_epoch.append(label_loss)
            # g_loss used to be discarded here, so the saved 'g_loss'
            # results were always lists of empty lists.
            train_g_loss_epoch.append(g_loss)

            # Evaluate after every epoch. The former "nih" special case was
            # unreachable: unsupported datasets raise NotImplementedError above.
            test_acc, auc, recall = test(net, test_loader)

            print("="*40)
            test_accuracy_epoch.append(test_acc)
            test_auc_epoch.append(auc)
            test_recall_epoch.append(recall)

            if args.secondary_testset is not None:
                test_acc2, auc2, recall2 = test(net, test_loader2)
                print("=" * 40)
                test_accuracy_epoch2.append(test_acc2)
                test_auc_epoch2.append(auc2)
                test_recall_epoch2.append(recall2)

            if epoch > 10 and args.early_stopping:
                valid_loss = validation(net, valid_loader)
                if valid_loss < best_valid_loss:
                    best_valid_loss = valid_loss
                    early_stopping_counter = 0
                    best_test_acc = test_acc
                    best_test_auc = auc
                    best_test_recall = recall
                    best_train_loss = label_loss
                    best_train_g_loss = g_loss
                    net_ckpt = copy.deepcopy(net)
                    if args.secondary_testset is not None:
                        best_test_acc2 = test_acc2
                        best_test_auc2 = auc2
                        best_test_recall2 = recall2
                else:
                    early_stopping_counter += 1

            if test_acc > best_test_perf:
                best_test_perf = test_acc

            if args.secondary_testset:
                if test_acc2 > best_test_perf2:
                    best_test_perf2 = test_acc2

            if (epoch + 1) % args.model_save_freq == 0 and not args.early_stopping:
                save_model(net, epoch + 1, not args.no_exp, args.model_save_dir)
            if args.early_stopping and early_stopping_counter >= 20:
                print("Early stopped at epoch {}".format(epoch))
                print("Rolling back!")
                # Exactly early_stopping_counter epochs ran after the best
                # validation epoch. All of them are rolled back to the best-epoch
                # values (the old hard-coded 19 missed one), and the epochs that
                # never ran are padded with the same values.
                n_stale = early_stopping_counter
                n_missing = args.epochs - (epoch + 1)
                train_label_loss_epoch[-n_stale:] = [best_train_loss] * n_stale
                train_g_loss_epoch[-n_stale:] = [best_train_g_loss] * n_stale
                test_accuracy_epoch[-n_stale:] = [best_test_acc] * n_stale
                test_auc_epoch[-n_stale:] = [best_test_auc] * n_stale
                test_recall_epoch[-n_stale:] = [best_test_recall] * n_stale
                train_label_loss_epoch.extend([best_train_loss] * n_missing)
                train_g_loss_epoch.extend([best_train_g_loss] * n_missing)
                test_accuracy_epoch.extend([best_test_acc] * n_missing)
                test_auc_epoch.extend([best_test_auc] * n_missing)
                test_recall_epoch.extend([best_test_recall] * n_missing)
                if args.secondary_testset:
                    test_accuracy_epoch2[-n_stale:] = [best_test_acc2] * n_stale
                    test_auc_epoch2[-n_stale:] = [best_test_auc2] * n_stale
                    test_recall_epoch2[-n_stale:] = [best_test_recall2] * n_stale
                    test_accuracy_epoch2.extend([best_test_acc2] * n_missing)
                    test_auc_epoch2.extend([best_test_auc2] * n_missing)
                    test_recall_epoch2.extend([best_test_recall2] * n_missing)
                save_model(net_ckpt, epoch + 1, not args.no_exp, os.path.join(args.model_save_dir, "early_stopping"))
                break
            if args.early_stopping and epoch == args.epochs - 1:
                save_model(net, epoch + 1, not args.no_exp, os.path.join(args.model_save_dir, "early_stopping"))

        test_accuracies.append(test_accuracy_epoch)
        test_aucs.append(test_auc_epoch)
        test_recalls.append(test_recall_epoch)
        test_accuracies2.append(test_accuracy_epoch2)
        test_aucs2.append(test_auc_epoch2)
        test_recalls2.append(test_recall_epoch2)
        train_ious.append(train_iou_epoch)
        train_label_loss.append(train_label_loss_epoch)
        train_g_loss.append(train_g_loss_epoch)
        train_times.append(epoch_time)
        trial_times.append(trial_time)

    data = {
        'dataset_size': args.dataset_size,
        'batch_size': args.batch_size,
        'trials': args.trials,
        'best_test_performance': best_test_perf,
        'best_test_performance2': best_test_perf2,
        'learning_rate': args.lr,
        'lr_decay': args.opt_decay_step,
        'feature_extractor_lr': args.feat_lr,
        'feature_extractor_lr_decay': args.feat_opt_decay_step,
        'use_grad': not args.no_exp,
        'use_feat_loss': args.feat_loss,
        'use_focal_loss': args.focal_loss,
        'use_sparse_explanation': args.sparse_exp,
        'lambda': args.l,
        'epochs': args.epochs,
        'loss': train_label_loss,
        'g_loss': train_g_loss,
        'test_accuracy': test_accuracies,
        'AUC': test_aucs,
        'Recall': test_recalls,
        'test_accuracy2': test_accuracies2,
        'AUC2': test_aucs2,
        'Recall2': test_recalls2,
        'per_epoch_time': train_times,
        'trial_times': trial_times
    }
    print('==> Saving results..')
    if not args.early_stopping:
        save_results(data, args.results_save_dir)
    else:
        save_results(data, os.path.join(args.results_save_dir, "early_stopping"))
