import csv
import functools
import logging
import pathlib
from abc import ABC, abstractmethod
from typing import Callable, Dict, List, Optional, Tuple

import torch
from torch.nn import functional as F

from path_learning.utils.log import get_logger


class BaseLoss(ABC):
    """Abstract loss that also aggregates, logs and persists per-epoch statistics.

    Subclasses implement ``__call__``. They are presumably expected to fill
    ``self.loss_functions`` (keyed by purpose, each entry holding a
    ``"callable"``) and to hand per-batch result dicts to
    :meth:`process_batch_result`.

    Intended usage::

        with loss.new_epoch(epoch, "train"):
            ...  # per batch: process_batch_result(batch_result)

    Leaving the ``with`` block aggregates the batch results and logs them. It
    stores the epoch's selection criterion in ``epoch_losses`` and appends a row
    to ``<logdir>/<purpose>_losses.csv``.
    """
    logger = get_logger("base_loss")

    def __init__(self, logdir: pathlib.Path, **kwargs):
        """Store the configuration and initialise the epoch bookkeeping.

        Args:
            logdir: directory in which the ``<purpose>_losses.csv`` files are written.
            **kwargs: loss configuration.
                Required keys:
                    ``report_frequency``: log every N training batches.
                    ``loss_aggregator``: passed to ``pick_loss_aggregator``.
                Optional keys:
                    ``accuracy_bools``: per-purpose flags enabling accuracy
                        tracking; all True by default.
                    ``accuracy_validation``: False by default.
                    ``loss_function``: may hold ``class_weights``, ``scaling``,
                        ``threshold_low`` and ``threshold_high``.

        Raises:
            ValueError: if ``accuracy_validation`` is enabled but validation
                accuracy is disabled in ``accuracy_bools``.
        """
        self.config = kwargs
        self.logdir = logdir

        # Keyed by purpose ("train"/"val"/"test"); presumably filled by subclasses.
        self.loss_functions: Optional[Dict[str, Dict]] = {}
        self.report_frequency: int = self.config["report_frequency"]
        self.loss_aggregator = pick_loss_aggregator(self.config["loss_aggregator"])
        self.batch_results: List[Dict] = []  # reset at the start of each epoch by __enter__

        self.accuracy_bools = self.config.get("accuracy_bools", {"train": True, "val": True, "test": True})

        # If "accuracy_validation" is true, the controller uses the validation accuracy to decide
        # whether to continue training and which model to save.
        self.accuracy_validation = self.config.get("accuracy_validation", False)
        if self.accuracy_validation and self.accuracy_bools["val"] is not True:
            raise ValueError("accuracy_validation requires accuracy_bools['val'] to be True")

        # Per purpose: {str(epoch): selection criterion}, where lower is better.
        self.epoch_losses = {"train": {}, "val": {}, "test": {}}

        self.purpose: Optional[str] = None
        self.epoch: Optional[int] = None

        self.class_weights: Optional[List] = self.config.get("loss_function", {}).get("class_weights", None)

    def new_epoch(self, epoch: int, purpose: str):
        """Set the epoch number and purpose for the next ``with`` block; return ``self``."""
        self.purpose = purpose
        self.epoch = epoch
        return self

    def __enter__(self):
        """Start collecting batch results for the current epoch; return ``self``."""
        self.batch_results = []
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Aggregate, record, log and persist the epoch's results.

        If the block is left because of an exception, nothing is recorded and
        the exception propagates unchanged. Aggregating a partial or empty epoch
        could otherwise raise a secondary error and mask the real one.
        """
        if exc_type is not None:
            return False

        epoch_result = self.compute_epoch_result()
        if self.accuracy_validation and self.purpose == "val":
            # Use the negative accuracy (lower is better) to select which model to save for the task
            self.epoch_losses[self.purpose].update({str(self.epoch): -epoch_result['accuracy']})
        else:
            self.epoch_losses[self.purpose].update({str(self.epoch): epoch_result['loss']})

        if self.purpose != "test":
            if self.accuracy_bools[self.purpose]:
                self.logger.info(f"Epoch {self.epoch} - {self.purpose} average loss - {epoch_result['loss']:.4f} "
                                 f"- average accuracy {epoch_result['accuracy']:.4f}")
            else:
                self.logger.info(f"Epoch {self.epoch} - {self.purpose} average loss - {epoch_result['loss']:.4f}")
        else:
            if self.accuracy_bools[self.purpose]:
                self.logger.info(f"Testing Loss: {epoch_result['loss']:.4f} "
                                 f"- average accuracy {epoch_result['accuracy']:.4f}")
            else:
                self.logger.info(f"Testing Loss: {epoch_result['loss']:.4f}")

        self.write_logfile(self.epoch, epoch_result)
        return False

    def write_logfile(self, epoch, epoch_result: Dict):
        """Append one row for ``epoch`` to ``<logdir>/<purpose>_losses.csv``.

        The header is written only when the file is created. Columns come from
        the current row. NOTE(review): this assumes every epoch of a purpose
        yields the same set of keys; otherwise appended rows would not line up
        with the existing header.
        """
        logfile_path = self.logdir / f"{self.purpose}_losses.csv"
        write_mode = "a" if logfile_path.exists() else "w"

        accuracy = epoch_result.get("accuracy")
        write_dict = {"epoch": epoch, "loss": float(epoch_result["loss"]),
                      # float(None) would raise; the csv module writes None as an empty field.
                      "accuracy": None if accuracy is None else float(accuracy)}

        for loss_idx, loss in enumerate(epoch_result.get("loss_blame", [])):
            write_dict[f"loss_blame_{str(loss_idx).zfill(2)}"] = float(loss)

        for class_idx, class_accuracy in enumerate(epoch_result.get("accuracy_per_class", [])):
            write_dict[f"class_accuracy_{str(class_idx).zfill(2)}"] = float(class_accuracy)

        with logfile_path.open(write_mode, newline='') as fp:
            writer = csv.DictWriter(fp, fieldnames=list(write_dict))
            if write_mode == "w":
                writer.writeheader()
            writer.writerow(write_dict)

    def get_epoch_loss(self):
        """Return the recorded selection criterion for the current epoch and purpose."""
        return self.epoch_losses[self.purpose][str(self.epoch)]

    def compute_epoch_result(self) -> Dict:
        """Aggregate ``self.batch_results`` into sample-weighted epoch statistics.

        Each batch result must provide ``loss`` and ``batch_size``.
        Optional keys:
            ``loss_blame``: a tensor.
            ``correct``: required when accuracy is enabled for the current
                purpose.
            ``correct_per_class`` and ``num_samples_per_class``: tensors, used
                only when accuracy is enabled.

        Returns:
            A dict with:
                ``loss`` and ``accuracy``: floats, or NaN when no samples were
                    processed.
                ``loss_blame``: a list, when available.
                ``accuracy_per_class``: a tensor, when available.
        """
        num_samples = 0
        total_loss = 0.0
        total_loss_blame = None

        total_correct = 0
        total_correct_per_class = None
        num_samples_per_class = None
        track_accuracy = self.accuracy_bools[self.purpose]

        for batch_result in self.batch_results:
            batch_size = batch_result["batch_size"]
            total_loss += batch_result["loss"] * batch_size
            num_samples += batch_size

            if "loss_blame" in batch_result:
                weighted_blame = batch_result["loss_blame"] * batch_size
                if total_loss_blame is None:
                    total_loss_blame = weighted_blame
                else:
                    total_loss_blame = total_loss_blame + weighted_blame

            if track_accuracy:
                total_correct += batch_result["correct"]

                if "correct_per_class" in batch_result and "num_samples_per_class" in batch_result:
                    # The first batch's tensors are referenced directly, so the accumulation below
                    # must be out-of-place: "+=" would silently modify the stored batch results.
                    if total_correct_per_class is None or num_samples_per_class is None:
                        total_correct_per_class = batch_result["correct_per_class"]
                        num_samples_per_class = batch_result["num_samples_per_class"]
                    else:
                        total_correct_per_class = total_correct_per_class + batch_result["correct_per_class"]
                        num_samples_per_class = num_samples_per_class + batch_result["num_samples_per_class"]

        if num_samples == 0:
            # Avoid ZeroDivisionError on an empty epoch; NaN marks the result as unusable.
            self.logger.warning(f"Epoch {self.epoch} - {self.purpose}: no samples were processed")
            return {"loss": float("nan"), "accuracy": float("nan")}

        epoch_result = {"loss": float(total_loss) / float(num_samples)}

        if total_loss_blame is not None:
            epoch_result["loss_blame"] = (total_loss_blame / num_samples).tolist()

        epoch_result["accuracy"] = float(total_correct) / float(num_samples)

        if total_correct_per_class is not None and num_samples_per_class is not None:
            epoch_result["accuracy_per_class"] = (total_correct_per_class.float() / num_samples_per_class.float())
        return epoch_result

    @abstractmethod
    def __call__(self, tensors: Dict) -> torch.Tensor:
        pass

    def compute_accuracy(self, pred_targets: torch.Tensor, targets: torch.Tensor) -> Dict:
        """Compute overall and per-class accuracy for one batch.

        Args:
            pred_targets: model outputs with the class dimension at dim 1. A
                single output column is treated as a binary logit, where a
                positive value predicts class 1 and anything else class 0.
            targets: ground-truth class labels, one per sample; must be
                viewable with the same shape as the predicted labels.

        Returns:
            A dict with:
                ``correct``, ``num_samples`` and ``accuracy``.
                ``correct_per_class``, ``num_samples_per_class`` and
                    ``accuracy_per_class``: tensors. They are size-1 zero
                    placeholders when there are more than 100 classes.
            Classes absent from the batch get a NaN per-class accuracy.
        """
        batch_result = {}
        batch_size = pred_targets.size(0)
        binary = pred_targets.size(1) == 1
        # A single logit column still encodes two classes (0 and 1).
        num_classes = 2 if binary else pred_targets.size(1)

        if binary:
            max_pred = pred_targets.clone()
            max_pred[pred_targets > 0] = 1
            max_pred[pred_targets <= 0] = 0
        else:
            # get the index of the max log-probability
            max_pred = pred_targets.argmax(dim=1, keepdim=True)

        correct_batch = max_pred.eq(targets.view_as(max_pred))
        correct = correct_batch.sum().item()

        batch_result["correct"] = correct
        batch_result["num_samples"] = batch_size
        batch_result["accuracy"] = correct / batch_size

        if num_classes > 100:
            # Per-class statistics are skipped for large label spaces.
            correct_per_class = targets.new_zeros((1,))
            num_samples_per_class = targets.new_zeros((1,))
            accuracy_per_class = pred_targets.new_zeros((1,))
        else:
            correct_per_class = targets.new_zeros((num_classes,))
            num_samples_per_class = targets.new_zeros((num_classes,))
            accuracy_per_class = pred_targets.new_zeros((num_classes,))

            for class_idx in range(num_classes):
                class_mask = targets == class_idx
                correct_per_class[class_idx] = correct_batch[class_mask].sum().item()
                num_samples_per_class[class_idx] = class_mask.sum().item()
                accuracy_per_class[class_idx] = float(correct_per_class[class_idx]) / num_samples_per_class[class_idx]

        batch_result["correct_per_class"] = correct_per_class
        batch_result["num_samples_per_class"] = num_samples_per_class
        batch_result["accuracy_per_class"] = accuracy_per_class

        return batch_result

    def process_batch_result(self, batch_result: Dict):
        """Detach ``batch_result``, move it to CPU, store it for epoch aggregation and log training progress periodically."""
        # we need to detach from graph and copy to cpu
        batch_result = recursively_detach(batch_result)

        # Aggregate results
        self.batch_results.append(batch_result)

        if len(self.batch_results) % self.report_frequency == 0 and self.purpose == "train":
            self.logger.info(f"Epoch {self.epoch} - Batch {len(self.batch_results)} - {self.purpose} loss "
                             f"- {batch_result['loss']:.4f}")

    def add_loss_fct_wrappers(self):
        """Wrap each purpose's loss callable with scaling and/or thresholding, as configured.

        Scaling is applied first and thresholding wraps the scaled result.
        """
        if "loss_function" in self.config:
            loss_function_config = self.config['loss_function']

            if "scaling" in loss_function_config:
                for purpose in ["train", "val", "test"]:
                    self.loss_functions[purpose]["callable"] = scaling_wrapper(self.loss_functions[purpose]["callable"],
                                                                               loss_function_config)

            if "threshold_low" in loss_function_config and "threshold_high" in loss_function_config:
                for purpose in ["train", "val", "test"]:
                    self.loss_functions[purpose]["callable"] = thresholder_wrapper(
                        self.loss_functions[purpose]["callable"],
                        loss_function_config)

    def get_class_weights(self, sample_tensor: torch.Tensor) -> Optional[torch.Tensor]:
        """Return the configured class weights as a tensor like ``sample_tensor``, or None if unset."""
        class_weights = None
        if self.class_weights is not None:
            class_weights = sample_tensor.new_tensor(self.class_weights)

        return class_weights


def thresholder_wrapper(loss_func, loss_fct_dict):
    """Wrap ``loss_func`` so that its output is clamped to a configured range.

    Args:
        loss_func: callable returning a tensor, a float, or a tuple of these.
        loss_fct_dict: mapping providing ``threshold_low`` and
            ``threshold_high``. Both are read on every call.

    Returns:
        A callable with the same signature as ``loss_func`` whose output is
        clamped element-wise to ``[threshold_low, threshold_high]``. A tensor
        stays a tensor, a float stays a float, and a tuple is clamped
        item by item.

    Raises:
        ValueError: at call time, if ``loss_func`` returns an unsupported type.
    """
    def _clamp(value):
        low = loss_fct_dict["threshold_low"]
        high = loss_fct_dict["threshold_high"]
        if isinstance(value, torch.Tensor):
            return torch.clamp(input=value, min=low, max=high)
        if isinstance(value, float):
            # torch.clamp only accepts tensors; use the same formula, min(max(x, low), high).
            return float(min(max(value, low), high))
        raise ValueError(f"thresholder_wrapper: unsupported loss output type {type(value).__name__}")

    @functools.wraps(loss_func)
    def wrapper(*args, **kwargs):
        loss_fct_output = loss_func(*args, **kwargs)
        if isinstance(loss_fct_output, tuple):
            return tuple(_clamp(x) for x in loss_fct_output)
        return _clamp(loss_fct_output)

    return wrapper


def scaling_wrapper(loss_func, loss_fct_dict):
    """Wrap ``loss_func`` so that its output is multiplied by a configured factor.

    Args:
        loss_func: callable returning a tensor, a float, or a tuple of values.
        loss_fct_dict: mapping providing ``scaling``, which is read on every call.

    Returns:
        A callable with the same signature as ``loss_func`` whose output is
        ``scaling * output``. Tuple outputs are scaled item by item.

    Raises:
        ValueError: at call time, if ``loss_func`` returns an unsupported type.
    """
    @functools.wraps(loss_func)
    def wrapper(*args, **kwargs):
        loss_fct_output = loss_func(*args, **kwargs)
        scaling = loss_fct_dict["scaling"]
        if isinstance(loss_fct_output, (torch.Tensor, float)):
            return scaling * loss_fct_output
        if isinstance(loss_fct_output, tuple):
            return tuple(scaling * x for x in loss_fct_output)
        raise ValueError(f"scaling_wrapper: unsupported loss output type {type(loss_fct_output).__name__}")

    return wrapper


def pick_loss_aggregator(aggregator_dict: dict) -> Callable:
    if aggregator_dict["name"] == "mean":
        return torch.mean


def recursively_detach(data):
    """Return ``data`` with every tensor detached from the autograd graph and moved to CPU.

    Dicts, lists and tuples are traversed recursively and rebuilt as new
    containers. Dicts become plain dicts, and a tuple subclass such as a
    namedtuple or ``torch.Size`` keeps its type. Any other value is returned
    unchanged. Tensor subclasses such as ``nn.Parameter`` are detached too.
    """
    if isinstance(data, torch.Tensor):
        return data.detach().cpu()
    if isinstance(data, dict):
        return {key: recursively_detach(value) for key, value in data.items()}
    if isinstance(data, list):
        return [recursively_detach(value) for value in data]
    if isinstance(data, tuple):
        items = [recursively_detach(value) for value in data]
        if hasattr(data, "_fields"):
            # namedtuple constructors take the fields positionally
            return type(data)(*items)
        return type(data)(items)
    return data
