import logging
import os
from datetime import datetime

import torch
from torch import optim
from tqdm import tqdm

from src.losses.loss import *
from src.metrics.bootstrap_ece import *
from src.models.utils import load_ckp, save_ckp, save_best, load_best

# Clamp margin used to keep predicted probabilities strictly inside (0, 1)
# (applied as ``torch.clamp(f, min=eps, max=1 - eps)`` in training/evaluation).
eps = 1e-7

# Lookup table from ``config['loss']`` to the loss callable invoked as
# ``loss_fn(output, target, **config)``. The callables are presumably provided by
# the ``src.losses.loss`` star import — confirm there when adding new entries.
loss_fcn_dict = dict(
    ce=cross_entropy,
    mse=mean_squared_error,
    focal=focal_loss,
    mmce=mmce,
    kde_mse=kde_mse,
    kde_ce=kde_ce,
)


def train_and_inference_loop(args, experiment):
    """Train ``experiment.model``, track the best-validation model and evaluate it on the test set.

    Args:
        args: Mutable run configuration dict. Reads ``lr``, ``momentum``, ``wd``,
            ``experiment_dir``, ``batch_size``, ``epochs`` and ``how_often_ckp`` (plus whatever
            the helpers read). ``args['lr']`` may be rescaled in place by the LR-schedule helpers
            and ``args['device']`` is set by ``get_device_and_model``.
        experiment: Object exposing ``model``, ``path`` and ``get_data_loaders(batch_size)``.

    Returns:
        The model after the last training epoch (not necessarily the best one).
    """
    device, model = get_device_and_model(experiment, args)
    optimizer = optim.SGD(model.parameters(), lr=args['lr'], momentum=args['momentum'], weight_decay=args['wd'])

    best_loss = torch.finfo(torch.float).max
    start_epoch = 0
    # Fallback so the final test evaluation always has a model, even when no checkpoint is
    # loaded and no training epoch runs (e.g. epochs == 0); previously this raised
    # UnboundLocalError.
    best_model = model
    if experiment_dir := args['experiment_dir']:
        # Resume: restore model/optimizer state, start epoch and best loss from the checkpoint.
        model, optimizer, start_epoch, best_loss = load_ckp(experiment_dir, model, optimizer, device)
        best_model = load_best(experiment_dir, device, args)
        # Re-apply the LR decays scheduled before the resumed epoch.
        adjust_lr_from_ckp(args, start_epoch)

    train_loader, val_loader, test_loader = experiment.get_data_loaders(args['batch_size'])

    for epoch in range(start_epoch, args['epochs']):
        adjust_lr(args, epoch, optimizer, experiment.model.name == 'resnet110')
        start_time = datetime.now()

        train_loss, train_acc = training(train_loader, model, optimizer, device, args)
        logging.info(f"Finished epoch {epoch} in {datetime.now() - start_time}")
        logging.info(f'Epoch {epoch}: train_loss {train_loss}, train_acc {train_acc}')

        val_loss, val_acc, ece = evaluation(val_loader, model, args)
        logging.info(f'Epoch {epoch}: val_loss {val_loss}, val_acc {val_acc}, ece {ece}')

        # Update the best model BEFORE checkpointing so the checkpoint stores the best
        # validation loss seen so far: load_ckp's loss value is restored as ``best_loss`` on
        # resume. Storing the current epoch's val_loss would let a worse model overwrite
        # the saved best one after resuming.
        if val_loss <= best_loss:
            best_model = save_best(model, epoch, val_loss, best_loss, experiment.path)
            best_loss = val_loss

        if epoch % args['how_often_ckp'] == 0 and epoch != 0:
            save_ckp(model, epoch, best_loss, optimizer, experiment.path)

    evaluate_test(test_loader, best_model, args, experiment.path)

    return model


def training(train_loader, model, optimizer, device, config):
    """Run one training epoch over ``train_loader``.

    Args:
        train_loader: Iterable of ``(data, target)`` batches supporting ``len()``.
        model: Module to train (switched to train mode here).
        optimizer: Optimizer stepping ``model``'s parameters.
        device: Device the batches are moved to.
        config: Run configuration. ``config['loss']`` selects the loss from ``loss_fcn_dict``
            and the whole dict is forwarded to it as keyword arguments. ``config['f']``
            (clamped probabilities) and ``config['target_orig']`` (raw targets) are
            overwritten on every batch.

    Returns:
        ``(mean batch loss, number of correct predictions / number of samples)``.
    """
    train_loss, train_acc, total = 0., 0., 0.
    bandwidth = -1  # only displayed in the progress bar
    model.train()
    loss_fn = loss_fcn_dict[config['loss']]  # loop-invariant lookup
    with tqdm(total=len(train_loader)) as p_bar:
        for data, target_orig in train_loader:
            data, target_orig = data.to(device), target_orig.to(device)
            optimizer.zero_grad()

            target = format_target(config, target_orig)
            pred_logits = model(data)
            output, f = format_output(config, pred_logits)
            # Exposed to the loss through **config; loss implementations may read them.
            config['f'] = torch.clamp(f, min=eps, max=1 - eps)
            config['target_orig'] = target_orig

            loss = loss_fn(output, target, **config)

            loss.backward()
            # Clip the total gradient norm to 2 for every loss. ``clip_grad_norm`` (no
            # trailing underscore) is deprecated in favour of the in-place ``clip_grad_norm_``.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 2)
            optimizer.step()

            batch_loss = loss.item()
            train_loss += batch_loss
            train_acc += get_acc(f, target_orig)
            total += target.size(0)

            # Update progress bar
            p_bar.update(1)
            p_bar.set_postfix({"Batch loss": batch_loss, "b": bandwidth})

    train_loss /= len(train_loader)
    train_acc /= total

    return train_loss, train_acc


def evaluation(data_loader, model, config, save_path=None, model_name=""):
    """Evaluate ``model`` on ``data_loader`` and return loss, accuracy and ECE.

    Args:
        data_loader: Iterable of ``(data, target)`` batches supporting ``len()``.
        model: Module to evaluate (switched to eval mode; not switched back).
        config: Run configuration; must contain ``device`` and ``loss`` plus whatever
            ``format_target``/``format_output``/``get_ECE`` and the loss function read.
            ``config['f']`` and ``config['target_orig']`` are overwritten per batch.
        save_path: When truthy, the raw logits and the targets are saved there as ``.npy``.
        model_name: Suffix of the saved logits file (``pred_logits_{model_name}.npy``).

    Returns:
        ``(mean batch loss, accuracy over all samples, ECE)``.

    Raises:
        ValueError: If ``data_loader`` yields no batches.
    """
    device = config['device']
    model.eval()
    test_acc, test_loss = 0., 0.
    loss_fn = loss_fcn_dict[config['loss']]  # loop-invariant lookup

    # Collect per-batch tensors and concatenate once at the end; calling torch.cat inside
    # the loop re-copies everything accumulated so far on every batch (quadratic).
    score_chunks, logit_chunks, target_chunks = [], [], []

    # turn off gradients for validation
    with torch.no_grad():
        with tqdm(total=len(data_loader)) as p_bar:
            for data, target_orig in data_loader:
                data, target_orig = data.to(device), target_orig.to(device)

                target = format_target(config, target_orig)
                pred_logits = model(data)
                output, f = format_output(config, pred_logits)
                f = torch.clamp(f, min=eps, max=1 - eps)
                # Exposed to the loss through **config; loss implementations may read them.
                config['f'] = f
                config['target_orig'] = target_orig

                loss = loss_fn(output, target, **config)

                batch_loss = loss.item()
                test_loss += batch_loss
                test_acc += get_acc(f, target_orig)

                score_chunks.append(f)
                logit_chunks.append(pred_logits)
                target_chunks.append(target_orig)

                # Update progress bar
                p_bar.update(1)
                p_bar.set_postfix({"Batch loss": batch_loss})

    if not target_chunks:
        raise ValueError("evaluation() received a data loader that yielded no batches")

    def _concat(chunks):
        # Promote with the default float dtype so integer class labels become float,
        # matching accumulation into an empty float tensor; targets.npy and the ECE input
        # therefore keep a float dtype.
        merged = torch.cat(chunks, 0)
        return merged.to(torch.promote_types(torch.get_default_dtype(), merged.dtype))

    all_scores = _concat(score_chunks)
    all_logits = _concat(logit_chunks)
    all_targets = _concat(target_chunks)

    if save_path:
        np.save(os.path.join(save_path, f'pred_logits_{model_name}.npy'), all_logits.cpu().numpy())
        np.save(os.path.join(save_path, 'targets.npy'), all_targets.cpu().numpy())

    test_loss /= len(data_loader)
    test_acc /= len(all_targets)

    ECE = get_ECE(config, all_targets, all_scores)

    return test_loss, test_acc, ECE


def get_device_and_model(experiment, args):
    """Choose the compute device, move ``experiment.model`` onto it and log the run setup.

    CUDA is used only when ``args['use_cuda']`` is truthy and a GPU is available; in that case
    all GPUs are seeded with ``args['seed']``. With more than one GPU the model is wrapped in
    ``nn.DataParallel``. The chosen device string is also stored in ``args['device']``.

    Returns:
        ``(device, model)`` where ``device`` is ``"cuda"`` or ``"cpu"``.
    """
    on_gpu = args['use_cuda'] and torch.cuda.is_available()
    device = "cuda" if on_gpu else "cpu"

    net = experiment.model
    if on_gpu:
        torch.cuda.manual_seed_all(args['seed'])
        if torch.cuda.device_count() > 1:
            net = nn.DataParallel(experiment.model)
    net.to(device)
    args['device'] = device

    for line in (
        f"--- Using device {device} ---",
        f"Config: {args}",
        f"Model : {experiment.model.name}",
        f"Training with {args['loss']} loss",
    ):
        logging.info(line)

    return device, net


def adjust_lr_from_ckp(args, epoch):
    """Re-apply, in place on ``args['lr']``, every LR decay scheduled strictly before ``epoch``.

    Each milestone in ``args['decrease_lr_epochs']`` that is smaller than ``epoch`` multiplies
    ``args['lr']`` by ``args['decrease_lr_factor']`` once. A milestone equal to ``epoch`` is
    not applied here; ``get_lr`` applies it when called for that epoch.
    """
    passed_milestones = [milestone for milestone in args['decrease_lr_epochs'] if milestone < epoch]
    for _ in passed_milestones:
        args['lr'] *= args['decrease_lr_factor']


def adjust_lr(args, epoch, optimizer, is_resnet110):
    """Set every optimizer parameter group's learning rate to the schedule value for ``epoch``.

    The value comes from ``get_lr``, which may also decay ``args['lr']`` in place. The chosen
    value is logged.
    """
    new_lr = get_lr(args, epoch, is_resnet110)
    logging.info(f"LR: {new_lr}")
    for group in optimizer.param_groups:
        group.update(lr=new_lr)


def get_lr(args, epoch, is_resnet110=False):
    """Return the learning rate to use for ``epoch``.

    ResNet-110 warms up with a fixed 0.01 at epoch 0 without touching ``args``. In the original
    paper, lr=0.01 is used for the first ~400 minibatches, which is roughly one epoch.

    Otherwise, when ``epoch`` is listed in ``args['decrease_lr_epochs']``, ``args['lr']`` is
    multiplied in place by ``args['decrease_lr_factor']`` before being returned. Calling this
    twice for the same milestone epoch therefore decays twice.
    """
    if is_resnet110 and epoch == 0:
        return 0.01

    if epoch in args['decrease_lr_epochs']:
        args['lr'] = args['lr'] * args['decrease_lr_factor']
    return args['lr']


def format_target(config, target_orig):
    """Convert raw class labels into the target format the configured loss expects.

    - ``num_classes == 1``: adds a trailing dimension and casts to float (e.g. ``(B,) -> (B, 1)``).
    - ``loss == 'mse'`` (exact match): one-hot encoding over ``num_classes`` as float32.
    - otherwise: the labels are returned unchanged.

    NOTE(review): ``format_output`` treats any loss whose name *contains* ``'mse'``
    (e.g. ``'kde_mse'``) as taking probabilities, but only exactly ``'mse'`` gets one-hot
    targets here — confirm this asymmetry is intended.
    """
    n_classes = config['num_classes']
    if n_classes == 1:
        return target_orig.unsqueeze(dim=-1).float()
    if config['loss'] != 'mse':
        return target_orig
    one_hot = torch.nn.functional.one_hot(target_orig, num_classes=config['num_classes'])
    return one_hot.to(torch.float32)


def format_output(config, pred_logits):
    """Return ``(output, f)``: the tensor fed to the loss and the probability scores.

    ``f`` is always the probability scores: sigmoid when ``num_classes == 1``, otherwise
    softmax over dim 1. ``output`` is ``f`` itself for any loss whose name contains ``'mse'``
    (e.g. ``'mse'``, ``'kde_mse'``), and the raw ``pred_logits`` for every other loss.
    """
    if config['num_classes'] == 1:
        scores = torch.sigmoid(pred_logits)
    else:
        scores = torch.softmax(pred_logits, dim=1)

    loss_takes_probs = 'mse' in config['loss']
    return (scores if loss_takes_probs else pred_logits), scores


def evaluate_test(loader, model, args, path):
    """Evaluate ``model`` on the test ``loader``, then print and log the resulting metrics.

    ``path`` is passed to ``evaluation`` as ``save_path`` with ``model_name="best"``. The
    logits and targets are therefore saved there whenever ``path`` is truthy.
    """
    loss, acc, calib_err = evaluation(loader, model, args, save_path=path, model_name="best")
    print(f'Test loss using best model: {loss}.. Test Accuracy: {acc}.. ECE: {calib_err}')
    logging.info(f'Test loss using best model: {loss}, test accuracy: {acc}, ece: {calib_err} \n')


def get_acc(f, target):
    """Return the number of correct predictions in a batch as a 0-d float tensor (not a ratio).

    With a single score column (``f.shape[1] == 1``), the prediction is ``round(f)``, squeezed.
    Otherwise it is the index of the maximum score along dim 1.
    """
    predicted = torch.round(f).squeeze() if f.shape[1] == 1 else torch.max(f, dim=1).indices
    hits = predicted == target
    return hits.float().sum()


def get_ECE(config, targets, scores):
    """Compute the expected calibration error via ``fast_ece``.

    - Binary (``num_classes == 1``): a single ``fast_ece`` call on the squeezed targets and scores.
    - Multi-class: one-vs-rest. For each class ``c``, compute the ECE of ``scores[:, c]``
      against ``targets == c``; the result is the mean of the per-class values, as a float tensor.

    ``n_bins`` and ``p`` are forwarded from ``config``; all inputs are moved to CPU first.
    """
    if config['num_classes'] == 1:
        return fast_ece(
            y_true=targets.squeeze().cpu(),
            y_pred=scores.squeeze().cpu(),
            n_bins=config['n_bins'],
            p=config['p'],
        )

    per_class = [
        fast_ece(
            y_true=(targets == cls).squeeze().cpu(),
            y_pred=scores[:, cls].squeeze().cpu(),
            n_bins=config['n_bins'],
            p=config['p'],
        )
        for cls in range(config['num_classes'])
    ]
    return torch.mean(torch.FloatTensor(per_class))
