"""
Let f_{theta} denote a pretrained network. The aim is to train a quantile network
Q(x, tau) that outputs the tau-quantile of the distribution of f_{theta}(x) for a given
input x.

NOTES:
- The input is expected to carry the value of tau in its final channel
  (i.e. x[:, -1, :, :]), so the backbone's first convolution takes one extra
  input channel.

"""

import pdb
from typing import Any

import numpy as np
from matplotlib import pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import CyclicLR

import pytorch_lightning as pl
import torchmetrics
from torchmetrics import Metric

import config
from utils_train import get_pretrained_model, get_base_datasets

class QuantileNetwork(pl.LightningModule):
    """Quantile network Q(x, tau) built on top of a pretrained backbone.

    The backbone returned by ``get_pretrained_model`` gets its ``conv1``
    replaced by a convolution with one extra input channel; that channel
    carries the value of tau (``_get_quantile_probs`` writes tau into the
    last channel of ``x``).

    Batches are expected to be ``(x, yquant, ytrue)`` tuples:
    - training/validation loss is BCE-with-logits between the backbone
      logits and ``yquant``;
    - validation/test/predict additionally sweep tau over the grid
      ``quantiles_list`` and average the indicator ``logits > 0`` to obtain
      per-class scores that are evaluated against ``ytrue``.
    """

    def __init__(self, name_base_model: str):
        super().__init__()
        self.backbone, num_classes, size_dataset = get_pretrained_model(name_base_model)
        # Change the first conv layer to accept 1 additional input channel (tau).
        # NOTE(review): the replacement conv is freshly initialised, so any
        # pretrained weights of the original conv1 are discarded — confirm
        # this is intended.
        old_conv1 = self.backbone.conv1
        self.backbone.conv1 = nn.Conv2d(
            in_channels=old_conv1.in_channels + 1,
            out_channels=old_conv1.out_channels,
            kernel_size=old_conv1.kernel_size,
            stride=old_conv1.stride,
            padding=old_conv1.padding,
            bias=False,
        )

        # Metrics
        # Multilabel accuracy: backbone logits vs. yquant targets.
        self.accuracy = torchmetrics.Accuracy(task="multilabel", num_labels=num_classes)
        # Multiclass accuracy / calibration: tau-swept scores vs. ytrue labels.
        self.accuracy_multiclass = torchmetrics.Accuracy(
            task="multiclass", num_classes=num_classes
        )
        self.calibration_error = torchmetrics.CalibrationError(
            task="multiclass", num_bins=10, num_classes=num_classes
        )

        # Saving logits and y for each epoch to compute the metrics
        self.train_step_outputs = []
        self.valid_step_outputs = []
        self.valid_step_outputs_quant = []
        self.test_step_outputs_quant = []

        # Quantile params
        self.num_classes = num_classes
        self.size_dataset = size_dataset
        self.num_quant_rep = config.NUM_QUANTILES
        # Interior points of an evenly spaced grid on [0, 1] (0 and 1 dropped).
        # Registered as a buffer, not an nn.Parameter: it never enters the loss,
        # so as a parameter it would be passed to the optimizer without ever
        # receiving a gradient, which DDP rejects unless
        # find_unused_parameters=True. A persistent buffer keeps the same
        # state_dict key and still moves with the module across devices.
        self.register_buffer(
            "quantiles_list",
            torch.linspace(0, 1, self.num_quant_rep + 2)[1:-1],
        )

    def forward(self, x):
        """Return the backbone logits for ``x``.

        ``x`` must already contain the tau value as its last channel.
        """
        return self.backbone(x)

    def training_step(self, batch, batch_idx):
        """Return the BCE loss and cache detached outputs for epoch metrics."""
        loss, logits, yquant = self._common_step(batch, batch_idx)
        # Detach before caching: the cached tensors are only used for
        # epoch-level metrics, and attached tensors keep autograd-graph
        # references alive for the whole epoch.
        self.train_step_outputs.append(
            {"logits": logits.detach(), "yquant": yquant, "loss": loss.detach()}
        )
        return loss

    def on_train_epoch_end(self) -> None:
        """Log multilabel accuracy and mean loss over the cached train steps."""
        logits = torch.cat([x["logits"] for x in self.train_step_outputs], dim=0)
        yquant = torch.cat([x["yquant"] for x in self.train_step_outputs], dim=0)
        train_loss_epoch = torch.stack(
            [x["loss"] for x in self.train_step_outputs]
        ).mean()
        self.log_dict(
            {
                "train_acc": self.accuracy(logits, yquant),
                "train_loss_epoch": train_loss_epoch,
            },
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            sync_dist=True,
        )
        self.train_step_outputs.clear()

    def validation_step(self, batch, batch_idx):
        """Cache the BCE outputs and the tau-swept scores for this batch."""
        loss, logits, yquant = self._common_step(batch, batch_idx)
        prob_quantile, ytrue = self._get_quantile_probs(batch, batch_idx)
        self.valid_step_outputs.append({"logits": logits, "yquant": yquant, "loss": loss})
        self.valid_step_outputs_quant.append(
            {"prob_quantile": prob_quantile, "ytrue": ytrue}
        )
        return loss

    def on_validation_epoch_end(self) -> None:
        """Log BCE-based and tau-swept metrics over the cached validation steps."""
        # Standard Metrics - Accuracy and BCE-Loss
        logits = torch.cat([x["logits"] for x in self.valid_step_outputs], dim=0)
        yquant = torch.cat([x["yquant"] for x in self.valid_step_outputs], dim=0)
        val_loss_epoch = torch.stack(
            [x["loss"] for x in self.valid_step_outputs]
        ).mean()

        # Quantile Metrics - Calibration Error
        prob_quantile = torch.cat(
            [x["prob_quantile"] for x in self.valid_step_outputs_quant], dim=0
        )
        ytrue = torch.cat([x["ytrue"] for x in self.valid_step_outputs_quant], dim=0)
        val_calibration_error = self.calibration_error(prob_quantile, ytrue)

        self.log_dict(
            {
                "val_acc": self.accuracy(logits, yquant),
                "val_loss_epoch": val_loss_epoch,
                "val_acc_quant": self.accuracy_multiclass(prob_quantile, ytrue),
                "val_calib_error": val_calibration_error,
            },
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            sync_dist=True,
        )
        self.valid_step_outputs.clear()
        self.valid_step_outputs_quant.clear()

    def test_step(self, batch, batch_idx):
        """Cache the tau-swept scores for this batch.

        - _common_step is meaningless for the test step since we do not have
          pre-trained outputs for the test.
        """
        prob_quantile, ytrue = self._get_quantile_probs(batch, batch_idx)
        self.test_step_outputs_quant.append(
            {"prob_quantile": prob_quantile, "ytrue": ytrue}
        )

    def on_test_epoch_end(self) -> None:
        """Log tau-swept accuracy and calibration error over the test steps."""
        # Standard Metrics do not make sense since we do not have
        # the quantile labels.

        # Quantile Metrics - Calibration Error
        prob_quantile = torch.cat(
            [x["prob_quantile"] for x in self.test_step_outputs_quant], dim=0
        )
        ytrue = torch.cat([x["ytrue"] for x in self.test_step_outputs_quant], dim=0)
        test_calibration_error = self.calibration_error(prob_quantile, ytrue)

        self.log_dict(
            {
                "test_acc_quant": self.accuracy_multiclass(prob_quantile, ytrue),
                "test_calib_error": test_calibration_error,
            },
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            sync_dist=True,
        )
        self.test_step_outputs_quant.clear()

    def predict_step(self, batch, batch_idx):
        """Return ``(probs, ytrue)`` from the tau sweep for this batch."""
        probs, ytrue = self._get_quantile_probs(batch, batch_idx)
        return probs, ytrue

    def _probability_from_logits(self, logits, tau):
        """Map logits to probabilities with a tau-dependent piecewise formula.

        For ``logits >= 0``: ``1 - tau * exp(logits * (tau - 1))``;
        otherwise: ``(1 - tau) * exp(logits * tau)``. Both branches equal
        ``1 - tau`` at ``logits == 0``, so the mapping is continuous there.
        Not called by the other methods of this class.
        """
        term1 = 1 - tau * torch.exp(torch.relu(logits) * (tau - 1))
        term2 = (1 - tau) * torch.exp(-1 * torch.relu(-1 * logits) * (tau))
        probs = torch.where(logits >= 0, term1, term2).float()
        return probs

    def _common_step(self, batch, batch_idx):
        """Compute the BCE-with-logits loss of the backbone against ``yquant``.

        Returns ``(loss, logits, yquant)``; ``ytrue`` (the third batch
        element) is not used.
        """
        x, yquant, _ = batch
        logits = self.forward(x)
        loss = F.binary_cross_entropy_with_logits(logits, yquant)
        return loss, logits, yquant

    def _get_quantile_probs(self, batch, batch_idx):
        """Score each class by sweeping tau over ``quantiles_list``.

        For every tau in the grid, the last channel of ``x`` is filled with
        tau and the backbone is run; ``probs`` is the fraction of grid values
        for which each logit is positive. ``yquant`` is not used.

        Returns ``(probs, ytrue)``.
        """
        x, _, ytrue = batch
        # Work on a copy so the caller's batch tensor is not left holding the
        # last tau value.
        x = x.clone()
        preds = []
        for tau in self.quantiles_list:
            x[:, -1, :, :] = tau
            logits = self.forward(x)
            preds.append((logits > 0).float())

        probs = torch.stack(preds, dim=0).mean(dim=0)
        return probs, ytrue

    def configure_optimizers(self):
        """SGD with momentum and a per-step ``exp_range`` CyclicLR schedule.

        Tuning notes from earlier experiments:
        for densnet base_lr = 0.008, max_lr = 0.35
        for resnet34 base_lr = 0.01, max_lr = 1.0
        """
        optimizer = torch.optim.SGD(
            self.parameters(),
            lr=0.05,  # CyclicLR resets the group lr to base_lr on construction.
            momentum=0.9,
        )
        # `//` and `*` share precedence and associate left-to-right, so the
        # divisor must be parenthesised; without it this evaluated to
        # (size_dataset // 2) * BATCHSIZE. 2 * (N // (2B) + 1) is ~N / B steps.
        # NOTE(review): with the half-cycle now ~1 epoch long, the LR actually
        # reaches max_lr — re-check base_lr/max_lr if results change.
        steps_per_epoch = 2 * ((self.size_dataset // (2 * config.BATCHSIZE)) + 1)
        scheduler_dict = {
            "scheduler": CyclicLR(
                optimizer,
                base_lr=0.02,
                max_lr=1.0,
                step_size_up=steps_per_epoch,
                mode="exp_range",
                gamma=0.99994,
            ),
            "interval": "step",
        }
        return {"optimizer": optimizer, "lr_scheduler": scheduler_dict}




