import logging
import numpy as np
import models
from torch.utils import data
import torch
from tasks import register_task
from tasks.base_task import BaseTask
from tasks.batch_utils import pt_feature_extract_coords_collator
from util.tensorboard_utils import plot_tensorboard_line, plot_tensorboard_cm, plot_tensorboard_loss_sample, plot_tensorboard_loss_sample_hist
from sklearn.metrics import roc_auc_score, f1_score, confusion_matrix, balanced_accuracy_score

log = logging.getLogger(__name__)

@register_task(name="pt_feature_extract_coords")
class PTFeatureExtractCoordsTask(BaseTask):
    """Task that scores a model as a binary classifier and logs the metrics.

    Decorated with ``register_task`` under the name
    ``"pt_feature_extract_coords"``. Batches are collated with
    ``pt_feature_extract_coords_collator``.
    """

    def __init__(self, cfg):
        """Forward ``cfg`` to the ``BaseTask`` initializer."""
        # Zero-argument super() binds to the class being defined, so it does
        # not depend on what the decorator rebinds the module-level name to.
        super().__init__(cfg)

    def build_model(self, cfg):
        """Build and return the model described by ``cfg``.

        ``self.dataset`` must have been set before this is called.
        """
        assert hasattr(self, "dataset"), "self.dataset must be set before build_model is called"
        # NOTE(review): the input dim is queried but never passed on to
        # models.build_model. The call is kept because its side effects are
        # not visible from here -- confirm and drop it if it is not needed.
        input_dim = self.dataset.get_input_dim()
        return models.build_model(cfg)

    @classmethod
    def setup_task(cls, cfg):
        """Alternate constructor: return a new task instance built from ``cfg``."""
        return cls(cfg)

    def get_valid_outs(self, model, valid_loader, criterion, device):
        """Evaluate ``model`` on ``valid_loader`` and compute binary metrics.

        Args:
            model: Model to evaluate. It is switched to eval mode here and is
                not switched back.
            valid_loader: Sized iterable of batch dicts with an ``"input"``
                entry (moved to ``device``) and a ``"labels"`` entry.
            criterion: Callable invoked as
                ``criterion(model, batch, device, return_predicts=True)``.
                The second element of its return value must be a dict with
                ``"predicts"``, ``"loss_sample"`` and ``"loss"``.
            device: Device the inputs are moved to; also passed to
                ``criterion``.

        Returns:
            Dict with ``"loss"`` (mean over batches), ``"roc_auc"``, ``"f1"``,
            ``"balanced_accuracy"``, ``"accuracy"``, ``"predicts"`` and
            ``"labels"`` (as lists), plus ``"cm"`` when the config enables
            ``log_confusion_matrix`` and ``"loss_sample"`` when it enables
            ``log_loss_histogram``.

        Raises:
            ValueError: If ``valid_loader`` yields no batches.
        """
        model.eval()
        all_outs = {"loss": 0}
        predicts, labels, loss_sample = [], [], []
        with torch.no_grad():
            for batch in valid_loader:
                batch["input"] = batch["input"].to(device)
                _, valid_outs = criterion(model, batch, device, return_predicts=True)

                predicts.append(valid_outs["predicts"])
                labels.append(batch["labels"])
                loss_sample.extend(valid_outs["loss_sample"])
                all_outs["loss"] += valid_outs["loss"]

        if not predicts:
            # Same exception type np.concatenate would raise on an empty list,
            # but with a message that names the actual cause.
            raise ValueError("valid_loader yielded no batches; cannot compute validation metrics")

        # Flatten the per-batch label sequences into one array.
        labels = np.array([x for y in labels for x in y])
        # A zero-dimensional prediction cannot be concatenated; promote it to
        # a one-element array first.
        predicts = [np.array([p]) if len(p.shape) == 0 else p for p in predicts]
        predicts = np.concatenate(predicts)

        # One hard decision per sample (score > 0.5), shared by every
        # threshold-based metric. F1 previously used np.round(predicts), which
        # only matches this threshold for scores inside [0, 1].
        # NOTE(review): assumes predicts are probabilities -- confirm.
        binary_predicts = np.where(predicts > 0.5, 1, 0)

        all_outs["loss"] /= len(valid_loader)
        all_outs["roc_auc"] = roc_auc_score(labels, predicts)
        all_outs["f1"] = f1_score(labels, binary_predicts)
        all_outs["balanced_accuracy"] = balanced_accuracy_score(labels, binary_predicts)
        all_outs["predicts"] = predicts.tolist()
        all_outs["labels"] = labels.tolist()
        all_outs["accuracy"] = np.sum(binary_predicts == labels) / len(labels)
        if self.cfg.get("log_confusion_matrix", False):
            all_outs["cm"] = confusion_matrix(labels, binary_predicts)
        if self.cfg.get("log_loss_histogram", False):
            all_outs["loss_sample"] = loss_sample
        return all_outs

    def get_batch_iterator(self, dataset, batch_size, **kwargs):
        """Return a ``DataLoader`` over ``dataset`` using this task's collator.

        Extra keyword arguments are forwarded to ``DataLoader`` unchanged.
        """
        return data.DataLoader(
            dataset,
            batch_size=batch_size,
            collate_fn=pt_feature_extract_coords_collator,
            **kwargs,
        )

    def output_logs(self, train_logging_outs, val_logging_outs, writer, global_step, test_logging_outs=None):
        """Log validation (and optionally test) metrics.

        Args:
            train_logging_outs: Not used by this task.
            val_logging_outs: Dict with ``"roc_auc"``, ``"f1"``,
                ``"accuracy"``, ``"loss"`` and ``"balanced_accuracy"``; also
                ``"cm"`` / ``"loss_sample"`` when the matching config flag
                (``log_confusion_matrix`` / ``log_loss_histogram``) is set.
            writer: Object exposing ``add_scalar``, ``add_image`` and
                ``add_histogram``, or ``None`` to only write the text log
                line. Presumably a TensorBoard summary writer -- confirm.
            global_step: Step value attached to every written entry.
            test_logging_outs: Optional dict with ``"roc_auc"``, ``"f1"``,
                ``"accuracy"`` and ``"balanced_accuracy"`` for the test split.
        """
        val_auc_roc = val_logging_outs["roc_auc"]
        val_f1 = val_logging_outs["f1"]
        val_acc = val_logging_outs["accuracy"]
        val_loss = val_logging_outs["loss"]
        val_balanced_accuracy = val_logging_outs["balanced_accuracy"]

        log.info(f'valid_roc_auc: {val_auc_roc:0.4f}, valid_f1: {val_f1:0.4f}, valid_accuracy: {val_acc:0.4f}, valid_balanced_accuracy: {val_balanced_accuracy:0.4f}')

        if writer is None:
            return

        writer.add_scalar("valid_roc_auc", val_auc_roc, global_step)
        writer.add_scalar("valid_f1", val_f1, global_step)
        writer.add_scalar("valid_accuracy", val_acc, global_step)
        writer.add_scalar("valid_balanced_accuracy", val_balanced_accuracy, global_step)

        if test_logging_outs is not None:
            writer.add_scalar("test_roc_auc", test_logging_outs["roc_auc"], global_step)
            writer.add_scalar("test_f1", test_logging_outs["f1"], global_step)
            writer.add_scalar("test_accuracy", test_logging_outs["accuracy"], global_step)
            writer.add_scalar("test_balanced_accuracy", test_logging_outs["balanced_accuracy"], global_step)

        if self.cfg.get("log_confusion_matrix", False):
            cm = val_logging_outs["cm"]
            # The colour scale is capped at half of the total sample count.
            tb_image = plot_tensorboard_cm(cm, n_classes=2, vmin=0, vmax=np.sum(cm) // 2, title=f"AUC-ROC: {val_auc_roc:0.3f} F1: {val_f1:0.03f}")
            writer.add_image("valid_confusion_matrix", tb_image, global_step)

        if self.cfg.get("log_loss_histogram", False):
            sample_losses = val_logging_outs["loss_sample"]
            # Both per-sample loss plots share the same title.
            loss_title = f"Step: {global_step} AUC-ROC: {val_auc_roc:0.3f} Loss: {val_loss:0.03f}"
            writer.add_histogram("valid_loss_sample", np.array(sample_losses), global_step)
            tb_image = plot_tensorboard_loss_sample(sample_losses, title=loss_title)
            writer.add_image("valid_loss_sample_idx", tb_image, global_step)
            tb_image = plot_tensorboard_loss_sample_hist(sample_losses, title=loss_title)
            writer.add_image("valid_loss_sample_hist", tb_image, global_step)
