"""
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License").
You may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import math
import os

import pytorch_lightning as pl
import torch

import torch
from torch import optim as optim


from pytorch_lightning.callbacks import BasePredictionWriter
from timm.utils import accuracy

from torchmetrics.utilities.data import to_onehot
import pickle


import os

from arch.ResNet import ResNet18
import torch
import torchvision.models as tvm
import torchvision.transforms as transforms

class CustomDataModule(pl.LightningDataModule):
    """Lightning data module that wraps pre-built datasets.

    All three loaders share ``resize_collate_fn``, which resizes every image
    to 224x224 and yields ``(images, labels)`` batches.
    """

    # Spatial size every image is resized to before batching.
    _TARGET_SIZE = (224, 224)

    def __init__(self, train_dataset, val_dataset, test_dataset=None, batch_size=32):
        """Store the datasets and the batch size used by every loader.

        Args:
            train_dataset: dataset served by ``train_dataloader`` (shuffled).
            val_dataset: dataset served by ``val_dataloader``.
            test_dataset: optional dataset served by ``test_dataloader``.
            batch_size: number of samples per batch for all loaders.
        """
        super().__init__()
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset
        self.batch_size = batch_size
        self.resize_transform = transforms.Resize(self._TARGET_SIZE)

    def resize_collate_fn(self, batch):
        """Collate samples into a resized float image batch and a label tensor.

        Args:
            batch: sequence of samples whose first two fields are the image
                and the label; any further fields are ignored.

        Returns:
            ``(images, labels)`` where ``images`` is a stacked tensor resized
            to 224x224 and ``labels`` is ``torch.as_tensor`` of the labels.
        """
        # Only the first two fields are used; extra fields are dropped.
        images, labels, *_ = zip(*batch)
        # Branch on the image type of the first sample.
        if isinstance(images[0], torch.Tensor):
            # Tensor inputs: stack, make sure they are float32, then resize.
            images = torch.stack(images)
            if images.dtype != torch.float32:
                images = images.float()
            images = torch.nn.functional.interpolate(
                images, size=self._TARGET_SIZE, mode='bilinear'
            )
        else:
            # Otherwise the samples are treated as PIL images.
            to_tensor = transforms.ToTensor()
            images = torch.stack(
                [to_tensor(self.resize_transform(img)) for img in images]
            )
        labels = torch.as_tensor(labels)
        return images, labels

    def train_dataloader(self):
        """Return the shuffled training loader."""
        return torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            collate_fn=self.resize_collate_fn,
        )

    def val_dataloader(self):
        """Return the (unshuffled) validation loader."""
        return torch.utils.data.DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            collate_fn=self.resize_collate_fn,
        )

    def test_dataloader(self):
        """Return the test loader, or ``None`` when no test dataset was given."""
        if self.test_dataset is None:
            return None
        return torch.utils.data.DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            collate_fn=self.resize_collate_fn,
        )


# from transformers import AutoModelForImageClassification, AutoFeatureExtractor, ResNetForImageClassification, \
#     ResNetConfig, ViTConfig, ViTForImageClassification

def build_optimizer(
    model,
    opt_type="sgd",
    lr=5e-4,
    momentum=0.9,
    eps=1e-8,
    betas=(0.9, 0.999),
    weight_decay=0.05,
):
    """
    Build optimizer, set weight decay of normalization to 0 by default.

    Args:
        model: module whose trainable parameters are optimized. If it exposes
            ``no_weight_decay()`` / ``no_weight_decay_keywords()``, their
            results are forwarded to ``set_weight_decay`` as exclusions.
        opt_type: one of ``"sgd"``, ``"adamw"`` or ``"adam"`` (case-insensitive).
        lr: learning rate.
        momentum: momentum, used by SGD only (Nesterov is always enabled).
        eps: epsilon, used by Adam/AdamW only.
        betas: beta coefficients, used by Adam/AdamW only.
        weight_decay: decay applied to the parameter group that keeps decay.

    Returns:
        The constructed ``torch.optim`` optimizer.

    Raises:
        ValueError: if ``opt_type`` is not a supported optimizer name.
    """
    opt_lower = opt_type.lower()
    # Validate first: previously an unknown name silently returned None and
    # only failed much later, far away from the actual mistake.
    if opt_lower not in ("sgd", "adamw", "adam"):
        raise ValueError(
            f"Unsupported optimizer type {opt_type!r}; "
            "expected one of 'sgd', 'adamw', 'adam'"
        )

    skip = {}
    skip_keywords = {}
    if hasattr(model, "no_weight_decay"):
        skip = model.no_weight_decay()
    if hasattr(model, "no_weight_decay_keywords"):
        skip_keywords = model.no_weight_decay_keywords()
    parameters = set_weight_decay(model, skip, skip_keywords)

    if opt_lower == "sgd":
        return optim.SGD(
            parameters,
            momentum=momentum,
            nesterov=True,
            lr=lr,
            weight_decay=weight_decay,
        )
    if opt_lower == "adamw":
        return optim.AdamW(
            parameters, eps=eps, betas=betas, lr=lr, weight_decay=weight_decay
        )
    return optim.Adam(
        parameters, eps=eps, betas=betas, lr=lr, weight_decay=weight_decay
    )


def set_weight_decay(model, skip_list=(), skip_keywords=()):
    """Split the trainable parameters of ``model`` into two optimizer groups.

    A parameter is exempt from weight decay when it is 1-D, its name ends in
    ``.bias``, its name is in ``skip_list``, or its name contains one of
    ``skip_keywords``. Parameters with ``requires_grad=False`` are left out.

    Returns:
        ``[{"params": decayed}, {"params": exempt, "weight_decay": 0.0}]``.
    """
    decayed, exempt = [], []

    for param_name, tensor in model.named_parameters():
        if tensor.requires_grad:  # frozen weights are not optimized at all
            is_exempt = (
                len(tensor.shape) == 1
                or param_name.endswith(".bias")
                or (param_name in skip_list)
                or check_keywords_in_name(param_name, skip_keywords)
            )
            bucket = exempt if is_exempt else decayed
            bucket.append(tensor)

    return [{"params": decayed}, {"params": exempt, "weight_decay": 0.0}]


def check_keywords_in_name(name, keywords=()):
    """Return True if any of ``keywords`` occurs as a substring of ``name``.

    Stops at the first match instead of scanning every keyword.
    """
    return any(keyword in name for keyword in keywords)







def get_torchvision_model(
    model_name="convnext-tiny", num_classes=10, sample_input=None, hidden_dims=[]
):
    """Build an ImageNet-pretrained torchvision model with a new output head.

    Only names containing ``"convnext-tiny"`` are supported.

    Args:
        model_name: architecture name.
        num_classes: number of outputs of the new head.
        sample_input: unused; kept for interface compatibility.
        hidden_dims: widths of optional hidden ``Linear`` + ``ReLU`` layers
            placed before the final ``Linear``. Empty means a single
            ``Linear`` head. The default list is never mutated.

    Returns:
        The torchvision model with its final classification layer replaced.

    Raises:
        NotImplementedError: for any other ``model_name``.
    """
    if "convnext-tiny" not in model_name:
        raise NotImplementedError

    model = tvm.convnext_tiny(weights=tvm.ConvNeXt_Tiny_Weights.IMAGENET1K_V1)

    # torchvision's ConvNeXt classifier is Sequential(LayerNorm2d, Flatten,
    # Linear). Only the last entry may be swapped: replacing the whole
    # classifier drops the Flatten, and the new Linear layers would then see
    # an (N, C, 1, 1) tensor and fail on the shape mismatch.
    in_features = model.classifier[-1].in_features

    if hidden_dims is not None and len(hidden_dims):
        head_layers = []
        prev_size = in_features
        for width in hidden_dims:
            head_layers.append(torch.nn.Linear(prev_size, width))
            head_layers.append(torch.nn.ReLU())
            prev_size = width
        head_layers.append(torch.nn.Linear(prev_size, num_classes))
        model.classifier[-1] = torch.nn.Sequential(*head_layers)
    else:
        model.classifier[-1] = torch.nn.Linear(
            in_features=in_features, out_features=num_classes, bias=True
        )
    return model




def get_cifar_resnet_model(
    model="cifar-resnet-18", num_classes=10, hidden_dims=[], extra_inputs=None
):
    """Return a CIFAR-style ResNet-18 with ``num_classes`` outputs.

    ``hidden_dims`` is accepted but not used. Any ``model`` other than
    ``"cifar-resnet-18"``, or any non-``None`` ``extra_inputs``, raises
    ``NotImplementedError``.
    """
    supported = extra_inputs is None and model == "cifar-resnet-18"
    if not supported:
        raise NotImplementedError
    return ResNet18(num_classes=num_classes)

# TODO: change it to parse_model

def get_model(
    architecture,
    num_classes,
    image_size=None,
    freeze_embedding=False,
    hidden_dims=[],
    extra_inputs=None,
):
    """Create a model for ``architecture`` with ``num_classes`` outputs.

    Architectures whose name starts with ``"cifar"`` are built by
    ``get_cifar_resnet_model``; everything else goes to
    ``get_torchvision_model``.

    Args:
        architecture: architecture name used for dispatch.
        num_classes: number of model outputs.
        image_size: accepted for interface compatibility; not used here.
        freeze_embedding: accepted for interface compatibility; not used here.
        hidden_dims: forwarded to the selected builder.
        extra_inputs: forwarded to the CIFAR builder only.

    Returns:
        The constructed model.
    """
    # TODO: image_size and freeze_embedding are accepted but not used yet.
    if architecture.startswith("cifar"):
        print("Using CIFAR ResNet model:", architecture)
        return get_cifar_resnet_model(
            architecture,
            num_classes=num_classes,
            hidden_dims=hidden_dims,
            extra_inputs=extra_inputs,
        )

    # Previously this branch was also announced as a "CIFAR ResNet model".
    print("Using torchvision model:", architecture)
    return get_torchvision_model(
        model_name=architecture, num_classes=num_classes, hidden_dims=hidden_dims
    )


## training parameters

def build_scheduler(
    scheduler,
    epochs,
    optimizer,
    step_fraction=0.33,
    mode="max",
    l_steps=None,
    step_gamma=0.1,
    lr=None,
):
    """Create the learning-rate scheduler named by ``scheduler``.

    Args:
        scheduler: ``None`` or ``""`` for no scheduler; otherwise one of
            ``"cosine"``, ``"linear"``, ``"step"``, ``"plateau"``,
            ``"onecycle"`` or ``"warmupstep"``. Any other value falls back to
            cosine annealing.
        epochs: total number of epochs (schedule length).
        optimizer: optimizer whose learning rate is scheduled.
        step_fraction: fraction of ``epochs`` between two decays for the
            ``"step"`` and ``"warmupstep"`` schedulers.
        mode: ``mode`` argument of ``ReduceLROnPlateau`` (``"plateau"`` only).
        l_steps: unused; kept for interface compatibility.
        step_gamma: multiplicative decay of the step schedulers.
        lr: maximum learning rate for ``"onecycle"``.

    Returns:
        A ``torch.optim.lr_scheduler`` instance, or ``None``.
    """
    if scheduler is None or scheduler == "":
        return None

    schedulers = torch.optim.lr_scheduler

    if scheduler == "linear":
        return schedulers.LinearLR(optimizer, start_factor=0.1, total_iters=epochs)

    if scheduler == "step":
        # StepLR evaluates `epoch % step_size`; for short runs the truncated
        # product can be 0, which would raise ZeroDivisionError on the first
        # scheduler step, so decay at least every epoch.
        return schedulers.StepLR(
            optimizer,
            step_size=max(1, int(epochs * step_fraction)),
            gamma=step_gamma,
        )

    if scheduler == "plateau":
        return schedulers.ReduceLROnPlateau(optimizer, mode=mode, patience=5)

    if scheduler == "onecycle":
        return schedulers.OneCycleLR(optimizer, lr, epochs=epochs, steps_per_epoch=1)

    if scheduler == "warmupstep":
        # Five epochs of linear warm-up followed by step decay.
        warmup = schedulers.LinearLR(
            optimizer, start_factor=0.1, end_factor=1.0, total_iters=5
        )
        decay = schedulers.StepLR(
            optimizer,
            step_size=max(1, int(epochs * step_fraction)),  # same guard as "step"
            gamma=step_gamma,
        )
        return schedulers.SequentialLR(
            optimizer, schedulers=[warmup, decay], milestones=[5]
        )

    # "cosine" and every unrecognized name use cosine annealing.
    return schedulers.CosineAnnealingLR(optimizer, T_max=epochs)





def get_optimizer_params(optimizer_params):
    """Fill in default optimizer/scheduler options the caller did not provide.

    The dictionary is updated in place and also returned; keys that are
    already present keep their value.
    """
    defaults = (
        # optimizer
        ("opt_type", "adamw"),
        ("weight_decay", 0.0),
        ("lr", 1e-3),
        # scheduler
        ("scheduler", None),
        ("epochs", 100),  # needed for CosineAnnealingLR
        ("step_gamma", 0.1),  # decay fraction in step scheduler
        ("step_fraction", 0.33),  # fraction of total epochs before step decay
    )
    for option, fallback in defaults:
        optimizer_params.setdefault(option, fallback)
    return optimizer_params


def get_batch(batch):
    """Unpack a batch into ``(samples, targets, base_samples)``.

    A two-element batch has no separate base samples, so ``samples`` is
    returned in that slot as well; otherwise the batch must have three fields.
    """
    if len(batch) != 2:
        samples, targets, base_samples = batch
        return samples, targets, base_samples
    samples, targets = batch
    return samples, targets, samples


class CustomWriter(BasePredictionWriter):
    """Prediction-writer callback that saves each rank's predictions to disk."""

    def __init__(self, output_dir, write_interval):
        """Remember the output directory and forward ``write_interval``."""
        super().__init__(write_interval)
        self.output_dir = output_dir

    def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
        """Save ``predictions`` to ``<output_dir>/predictions_<global_rank>.pt``."""
        # this will create N (num processes) files in `output_dir` each containing
        # the predictions of its respective rank
        os.makedirs(self.output_dir, exist_ok=True)  # torch.save will not create it
        torch.save(
            predictions,
            os.path.join(self.output_dir, f"predictions_{trainer.global_rank}.pt"),
        )


# Lightning wrapper for the base classifier.
# NOTE(review): the methods below previously sat inside ``CustomWriter`` with no
# class header of their own, so their ``__init__`` silently replaced the
# writer's constructor. They rely on LightningModule APIs
# (``save_hyperparameters``, ``self.log``), so they now live in their own
# ``pl.LightningModule`` subclass. The name ``LightningBaseNet`` is an
# assumption -- confirm against the callers.
class LightningBaseNet(pl.LightningModule):
    """Lightning module that trains and evaluates a classifier with cross-entropy."""

    def __init__(
        self,
        architecture,
        num_classes,
        image_size=-1,
        optimizer_params=None,
        loss_fn="Crossentropy",
        label_smoothing=0.0,
        model=None
    ):
        """Wrap an already constructed classifier.

        Args:
            architecture: architecture name; only recorded as a hyperparameter.
            num_classes: number of classes; only recorded as a hyperparameter.
            image_size: image size; only recorded as a hyperparameter.
            optimizer_params: optimizer/scheduler options; missing entries are
                filled in by ``get_optimizer_params``.
            loss_fn: loss name; only ``"Crossentropy"`` is supported.
            label_smoothing: label smoothing passed to ``CrossEntropyLoss``.
            model: the network to train; required.

        Raises:
            NotImplementedError: if ``loss_fn`` is not ``"Crossentropy"``.
            ValueError: if ``model`` is ``None``.
        """
        super().__init__()
        if optimizer_params is None:
            optimizer_params = {}
        if loss_fn == "Crossentropy":
            self.loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=label_smoothing)
        else:
            raise NotImplementedError
        self.optimizer_params = get_optimizer_params(optimizer_params)

        self.save_hyperparameters(
            "architecture", "num_classes", "image_size", "optimizer_params", "loss_fn"
        )
        if model is None:
            raise ValueError(
                "Model must be provided, use model_setup to create a model with the correct architecture"
            )
        self.model = model

        # Per-batch validation metrics, reduced in on_validation_epoch_end.
        self.validation_step_outputs = []

    def forward(self, samples: torch.Tensor) -> torch.Tensor:
        """Return the logits of the wrapped model for ``samples``."""
        logits = self.model(samples)
        return logits

    def training_step(self, batch, batch_idx: int):
        """Compute and log the loss and top-1/top-5 accuracy for one batch."""
        samples, targets, base_samples = get_batch(batch)
        logits = self.forward(samples)
        loss = self.loss_fn(logits, targets).mean()
        acc1, acc5 = accuracy(logits, targets, topk=(1, 5))

        self.log("ptl/loss", loss, on_epoch=True, prog_bar=True, on_step=False)
        self.log("ptl/acc1", acc1, on_epoch=True, prog_bar=True, on_step=False)
        self.log("ptl/acc5", acc5, on_epoch=True, prog_bar=True, on_step=False)

        return {
            "loss": loss,
            "acc1": acc1,
            "acc5": acc5,
        }

    def validation_step(self, batch, batch_idx: int):
        """Compute validation metrics for one batch and cache them for the epoch."""
        samples, targets, base_samples = get_batch(batch)

        logits = self.forward(samples)
        loss = self.loss_fn(logits, targets).mean()
        acc1, acc5 = accuracy(logits, targets, topk=(1, 5))

        rets = {
            "val_loss": loss,
            "val_acc1": acc1,
            "val_acc5": acc5,
        }
        self.validation_step_outputs.append(rets)
        return rets

    def on_validation_epoch_end(self):
        """Log the mean of the cached per-batch validation metrics, then reset."""
        avg_loss = torch.stack(
            [x["val_loss"] for x in self.validation_step_outputs]
        ).mean()
        avg_acc1 = torch.stack(
            [x["val_acc1"] for x in self.validation_step_outputs]
        ).mean()
        avg_acc5 = torch.stack(
            [x["val_acc5"] for x in self.validation_step_outputs]
        ).mean()
        self.log("ptl/val_loss", avg_loss, prog_bar=True)
        self.log("ptl/val_acc1", avg_acc1, prog_bar=True)
        self.log("ptl/val_acc5", avg_acc5, prog_bar=True)
        self.validation_step_outputs.clear()

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return ``(logits, targets, loss, score)`` for one batch.

        ``score`` is the hinge score: the logit of the target class minus the
        largest logit among the other classes.
        """
        samples, targets, base_samples = get_batch(batch)

        logits = self.forward(samples)
        loss = self.loss_fn(logits, targets)
        # get hinge score
        oh_label = to_onehot(targets, logits.shape[-1]).bool()
        score = logits[oh_label]
        score -= torch.max(logits[~oh_label].view(logits.shape[0], -1), dim=1)[0]
        return logits, targets, loss, score



# Lightning wrapper for MIA/QR model
class LightningQMIA(pl.LightningModule):
    """Lightning wrapper for the MIA / quantile-regression (QR) model.

    ``model`` predicts, per sample, either ``n_quantile`` quantiles or the
    ``(mean, log_std)`` of a Gaussian over the score that the frozen
    ``base_model`` assigns to the sample (see
    ``label_logit_and_hinge_scoring_fn``).
    """

    def __init__(
        self,
        architecture,
        base_architecture,
        num_base_classes,
        image_size,
        hidden_dims,
        freeze_embedding,
        low_quantile,
        high_quantile,
        n_quantile,
        optimizer_params,
        base_model_path=None,
        rearrange_on_predict=True,
        use_target_label=False,
        use_hinge_score=False,
        use_logscale=False,
        use_gaussian=False,
        return_mean_logstd=False,
        use_target_dependent_scoring=False,
        use_target_inputs=False,
        **kwargs,
    ):
        """Create the regression model and the frozen base classifier.

        Args:
            architecture: architecture of the regression model.
            base_architecture: architecture of the base classifier.
            num_base_classes: number of classes of the base classifier.
            image_size: forwarded to ``model_setup``.
            hidden_dims: forwarded to ``model_setup`` for the regression model.
            freeze_embedding: forwarded to ``model_setup``.
            low_quantile: first grid endpoint; passed to ``torch.logspace``
                when ``use_logscale`` is set, otherwise to ``torch.linspace``.
            high_quantile: second grid endpoint, used like ``low_quantile``.
            n_quantile: number of quantile levels in the grid.
            optimizer_params: optimizer/scheduler options. Updated in place
                with ``kwargs`` and with the ``get_optimizer_params`` defaults.
            base_model_path: path used by ``model_setup`` to locate the base
                classifier weights.
            rearrange_on_predict: make predicted quantiles monotone at
                validation/prediction time (ignored for log-scale grids and
                for the Gaussian model).
            use_target_label: must be true for the quantile model.
            use_hinge_score: must be true for the quantile model.
            use_logscale: use the levels ``1 - logspace(...)`` instead of a
                linear grid.
            use_gaussian: predict ``(mean, log_std)`` and train with the
                Gaussian negative log-likelihood instead of the pinball loss.
            return_mean_logstd: with ``use_gaussian``, make ``predict_step``
                return the raw ``(mean, log_std)`` instead of quantiles.
            use_target_dependent_scoring: predict one output set per base
                class and select the set of the target class.
            use_target_inputs: feed the one-hot target to the model as an
                extra input.
            **kwargs: extra options merged into ``optimizer_params``.

        Raises:
            NotImplementedError: for the quantile model when
                ``use_target_label`` or ``use_hinge_score`` is false.
        """
        super().__init__()
        self.save_hyperparameters()
        self.use_target_dependent_scoring = use_target_dependent_scoring
        assert not (
            use_target_dependent_scoring and use_target_inputs
        ), "target_dependent scoring should not be used with use_target_inputs"
        # Reject unsupported configurations before any model is built.
        if not use_gaussian and (not use_target_label or not use_hinge_score):
            raise NotImplementedError

        self.use_target_inputs = use_target_inputs
        self.num_base_classes = num_base_classes
        # The Gaussian head predicts (mean, log_std); the quantile head one
        # value per quantile level.
        self.base_n_outputs = 2 if use_gaussian else n_quantile
        if self.use_target_dependent_scoring:
            n_outputs = self.base_n_outputs * self.num_base_classes
        else:
            n_outputs = self.base_n_outputs

        model, base_model = model_setup(
            architecture=architecture,
            base_architecture=base_architecture,
            image_size=image_size,
            num_quantiles=n_outputs,
            num_base_classes=num_base_classes,
            hidden_dims=hidden_dims,
            freeze_embedding=freeze_embedding,
            base_model_path=base_model_path,
            extra_inputs=num_base_classes if self.use_target_inputs else None,
        )

        self.model = model
        self.base_model = base_model
        self.base_model_path = base_model_path
        self.use_gaussian = use_gaussian
        self.return_mean_logstd = return_mean_logstd

        # The base classifier only provides targets; it is never trained here.
        for parameter in self.base_model.parameters():
            parameter.requires_grad = False

        # Sorted quantile levels with shape (1, n_quantile).
        if use_logscale:
            self.QUANTILE = torch.sort(
                1
                - torch.logspace(
                    low_quantile, high_quantile, n_quantile, requires_grad=False
                )
            )[0].reshape([1, -1])
        else:
            self.QUANTILE = torch.sort(
                torch.linspace(
                    low_quantile, high_quantile, n_quantile, requires_grad=False
                )
            )[0].reshape([1, -1])

        self.target_scoring_fn = label_logit_and_hinge_scoring_fn
        if self.use_gaussian:
            self.loss_fn = gaussian_loss_fn
            self.rearrange_on_predict = False
        else:
            self.loss_fn = pinball_loss_fn
            self.rearrange_on_predict = rearrange_on_predict and not use_logscale

        optimizer_params.update(**kwargs)
        print(optimizer_params)
        self.optimizer_params = get_optimizer_params(optimizer_params)

        # Per-batch validation outputs, reduced in on_validation_epoch_end.
        self.validation_step_outputs = []

    def forward(
        self, samples: torch.Tensor, targets: torch.LongTensor = None
    ) -> torch.Tensor:
        """Predict the score distribution parameters for ``samples``.

        ``targets`` is required when ``use_target_inputs`` or
        ``use_target_dependent_scoring`` is enabled.
        """
        if self.use_target_inputs:
            oh_targets = to_onehot(targets, self.num_base_classes)
            scores = self.model(samples, oh_targets)
            return scores
        scores = self.model(samples)
        if self.use_target_dependent_scoring:
            # Keep only the outputs that belong to each sample's target class.
            oh_targets = to_onehot(targets, self.num_base_classes).unsqueeze(1)
            scores = (
                scores.reshape(
                    [
                        -1,
                        self.base_n_outputs,
                        self.num_base_classes,
                    ]
                )
                * oh_targets
            ).sum(-1)
        return scores

    def training_step(self, batch, batch_idx: int) -> torch.Tensor:
        """Return the mean training loss for one batch and log it per epoch."""
        samples, targets, base_samples = get_batch(batch)
        scores = self.forward(samples, targets)
        target_score, target_logits = self.target_scoring_fn(
            base_samples, targets, self.base_model
        )
        loss = self.loss_fn(
            scores, target_score, self.QUANTILE.to(scores.device)
        ).mean()
        self.log("ptl/train_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx: int):
        """Compute the validation loss for one batch and cache the outputs."""
        samples, targets, base_samples = get_batch(batch)
        scores = self.forward(samples, targets)
        quantiles = self.QUANTILE.to(scores.device)
        if self.rearrange_on_predict and not self.use_gaussian:
            scores = rearrange_quantile_fn(scores, quantiles.flatten())
        target_score, target_logits = self.target_scoring_fn(
            base_samples, targets, self.base_model
        )
        loss = self.loss_fn(scores, target_score, quantiles).mean()

        rets = {
            "val_loss": loss,
            "scores": scores,
            "targets": target_score,
        }
        # NOTE(review): only "val_loss" is read back in
        # on_validation_epoch_end; "scores"/"targets" are kept in the cache
        # for compatibility -- drop them here if nothing else consumes them.
        self.validation_step_outputs.append(rets)
        return rets

    def on_validation_epoch_end(self):
        """Log the mean validation loss of the epoch and reset the cache."""
        avg_loss = torch.stack(
            [x["val_loss"] for x in self.validation_step_outputs]
        ).mean()
        # The concatenated targets/scores that used to be built here were
        # never used, so they are no longer materialized.
        self.validation_step_outputs.clear()  # free memory
        self.log("ptl/val_loss", avg_loss, sync_dist=True, prog_bar=True)

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return ``(scores, target_score, loss, base_acc1, base_acc5)``.

        ``scores`` holds the predicted quantiles, unless the Gaussian model is
        used with ``return_mean_logstd``, in which case it is the raw
        ``(mean, log_std)`` output. ``loss`` is not reduced. The two
        accuracies are the base classifier's top-1/top-5 accuracy on the batch.
        """
        samples, targets, base_samples = get_batch(batch)

        scores = self.forward(samples, targets)
        quantiles = self.QUANTILE.to(scores.device)
        if self.rearrange_on_predict and not self.use_gaussian:
            scores = rearrange_quantile_fn(scores, quantiles.flatten())
        target_score, target_logits = self.target_scoring_fn(
            base_samples, targets, self.base_model
        )
        loss = self.loss_fn(scores, target_score, quantiles)
        base_acc1, base_acc5 = accuracy(target_logits, targets, topk=(1, 5))

        if self.use_gaussian and not self.return_mean_logstd:
            # Gaussian inverse CDF: mu + sigma * sqrt(2) * erfinv(2q - 1).
            mu = scores[:, 0]
            log_std = scores[:, 1]
            scores = mu.reshape([-1, 1]) + torch.exp(log_std).reshape(
                [-1, 1]
            ) * torch.erfinv(2 * quantiles - 1).reshape(
                [1, -1]
            ) * math.sqrt(
                2
            )
            assert (
                scores.ndim == 2
                and scores.shape[0] == targets.shape[0]
                and scores.shape[1] == self.QUANTILE.shape[1]
            ), "inverse cdf quantiles have the wrong shape, got {} {} {}".format(
                scores.shape, targets.shape, self.QUANTILE.size()
            )

        return scores, target_score, loss, base_acc1, base_acc5

    def configure_optimizers(self):
        """Build the optimizer and optional LR scheduler from ``optimizer_params``.

        Only ``self.model`` is handed to the optimizer. The scheduler, when
        present, is stepped once per epoch and monitors ``ptl/val_loss``.
        """
        optimizer = build_optimizer(
            self.model,
            opt_type=self.optimizer_params["opt_type"],
            lr=self.optimizer_params["lr"],
            weight_decay=self.optimizer_params["weight_decay"],
        )
        print("Optimizer params:", self.optimizer_params["opt_type"])
        interval = "epoch"

        lr_scheduler = build_scheduler(
            scheduler=self.optimizer_params["scheduler"],
            epochs=self.optimizer_params["epochs"],
            step_fraction=self.optimizer_params["step_fraction"],
            step_gamma=self.optimizer_params["step_gamma"],
            optimizer=optimizer,
            mode="min",
            lr=self.optimizer_params["lr"],
        )
        opt_and_scheduler_config = {
            "optimizer": optimizer,
        }
        if lr_scheduler is not None:
            opt_and_scheduler_config["lr_scheduler"] = {
                # REQUIRED: The scheduler instance
                "scheduler": lr_scheduler,
                "interval": interval,
                "frequency": 1,
                "monitor": "ptl/val_loss",
                "strict": True,
                "name": None,
            }

        return opt_and_scheduler_config


# Convenience function to create models and potentially load weights for base classifier
def model_setup(
    architecture,
    base_architecture,
    image_size,
    num_quantiles,
    num_base_classes,
    hidden_dims,
    freeze_embedding,
    base_model_path=None,
    extra_inputs=None,
):
    """Create the regression model and the base classifier.

    Args:
        architecture: architecture of the regression model.
        base_architecture: architecture of the base classifier.
        image_size: forwarded to ``get_model`` for both models.
        num_quantiles: number of outputs of the regression model.
        num_base_classes: number of outputs of the base classifier.
        hidden_dims: forwarded to ``get_model`` for the regression model.
        freeze_embedding: forwarded to ``get_model`` for the regression model.
        base_model_path: if this path exists, the base classifier weights are
            loaded from ``model.pickle`` in the same directory. ``None`` or a
            missing path leaves the base classifier as ``get_model`` built it.
        extra_inputs: forwarded to ``get_model`` for the regression model.

    Returns:
        ``(model, base_model)``.
    """
    # Get forward function of regression model
    model = get_model(
        architecture,
        num_quantiles,
        image_size,
        freeze_embedding,
        hidden_dims=hidden_dims,
        extra_inputs=extra_inputs,
    )

    ## Create base model, load params from pickle, then define the scoring function and the logit embedding function
    base_model = get_model(
        base_architecture, num_base_classes, image_size, freeze_embedding=False
    )

    # os.path.exists(None) raises TypeError, and None is the default value,
    # so the path must be checked for None first.
    if base_model_path is not None and os.path.exists(base_model_path):
        base_state_dict = load_pickle(
            name="model.pickle",
            map_location=next(base_model.parameters()).device,
            base_model_dir=os.path.dirname(base_model_path),
        )
        base_model.load_state_dict(base_state_dict)

    return model, base_model


def load_pickle(name="quantile_model.pickle", map_location=None, base_model_dir=None):
    """Load a ``torch.save``-d object from ``base_model_dir``.

    Slashes in ``name`` are replaced by underscores to form the file name.
    ``map_location`` is passed to ``torch.load`` only when it is truthy.
    """
    # NOTE: torch.load unpickles the file; only load files from trusted sources.
    file_name = name.replace("/", "_")
    pickle_path = os.path.join(base_model_dir, file_name)
    load_kwargs = {"map_location": map_location} if map_location else {}
    return torch.load(pickle_path, **load_kwargs)




##########
# distribution learning losses
##########
def pinball_loss_fn(score, target, quantile):
    """Element-wise pinball (quantile) loss.

    ``target`` is reshaped to a column and compared with every column of the
    2-d ``score``. Under-predictions are weighted by ``quantile`` and
    over-predictions by ``1 - quantile``. The result is not reduced.
    """
    relu = torch.nn.functional.relu
    target = target.reshape([-1, 1])
    assert (
        score.ndim == 2
    ), "score has the wrong shape, expected 2d input but got {}".format(score.shape)
    residual = target - score
    under_prediction = relu(residual) * quantile
    over_prediction = relu(-residual) * (1.0 - quantile)
    return under_prediction + over_prediction


def gaussian_loss_fn(score, target, quantile):
    """Per-sample Gaussian negative log-likelihood (up to an additive constant).

    ``score`` is Nx2 with the predicted mean in column 0 and the log standard
    deviation in column 1; ``target`` is a 1-d vector of length N.
    ``quantile`` is ignored and only exists so the signature matches the
    other losses. The result is a 1-d vector of length N.
    """
    assert (
        score.ndim == 2 and score.shape[-1] == 2
    ), "score has the wrong shape, expected Nx2 input but got {}".format(score.shape)
    assert (
        target.ndim == 1
    ), "target has the wrong shape, expected 1-d vector, got {}".format(target.shape)

    mu, log_std = score[:, 0], score[:, 1]
    assert (
        mu.shape == log_std.shape and mu.shape == target.shape
    ), "mean, std and target have non-compatible shapes, got {} {} {}".format(
        mu.shape, log_std.shape, target.shape
    )

    # log(sigma) + (target - mu)^2 / (2 * sigma^2)
    loss = log_std + 0.5 * torch.exp(-2 * log_std) * (target - mu) ** 2
    assert target.shape == loss.shape, "loss should be a 1-d vector got {}".format(
        loss.shape
    )
    return loss


##########
# Score functions for base network
##########


def label_logit_and_hinge_scoring_fn(samples, label, base_model):
    """Return the hinge score of ``base_model`` on ``samples`` and its logits.

    The hinge score is z_y(x) - max_{y' != y} z_{y'}(x): the logit of the
    labelled class minus the largest logit among the other classes. The model
    is switched to eval mode and run without gradient tracking.
    """
    base_model.eval()
    with torch.no_grad():
        logits = base_model(samples)
        n_samples, n_classes = logits.shape[0], logits.shape[-1]

        is_label = to_onehot(label, n_classes).bool()
        label_logit = logits[is_label]
        best_other_logit = logits[~is_label].view(n_samples, -1).max(dim=1).values
        score = label_logit - best_other_logit
        assert (
            score.ndim == 1
        ), "hinge loss score should be 1-dimensional, got {}".format(score.shape)
    return score, logits


##########
# logit to quantile prediction nonlinearities for QR
##########


# Based on "Quantile and probability curves without crossing.", ensures that, at evaluation time, predicted quantiles are monotonically increasing (non differentiable)
# only usable for linearly spaced quantiles
def rearrange_quantile_fn(test_preds, all_quantiles, target_quantiles=None):
    """Produce monotonic quantiles
    Parameters
    ----------
    test_preds : array of predicted quantile (nXq)
    all_quantiles : array (q), grid of quantile levels in the range (0,1)
    target_quantiles: array (q'), grid of target quantile levels in the range (0,1);
                      defaults to ``all_quantiles`` when None

    Returns
    -------
    q_fixed : array (nXq'), containing the rearranged estimates of the
              desired low and high quantile
    References
    ----------
    .. [1]  Chernozhukov, Victor, Iván Fernández‐Val, and Alfred Galichon.
            "Quantile and probability curves without crossing."
            Econometrica 78.3 (2010): 1093-1125.
    """
    # `not tensor` is ambiguous for multi-element tensors and raises, so an
    # explicitly passed grid must be detected with an identity check.
    if target_quantiles is None:
        target_quantiles = all_quantiles

    # Map the target levels onto [0, 1] relative to the grid endpoints.
    scaling = all_quantiles[-1] - all_quantiles[0]
    rescaled_target_qs = (target_quantiles - all_quantiles[0]) / scaling
    # Quantiles of each row's own predictions are non-decreasing in the level.
    q_fixed = torch.quantile(
        test_preds, rescaled_target_qs, interpolation="linear", dim=-1
    ).T
    assert (
        q_fixed.shape[0] == test_preds.shape[0] and q_fixed.ndim == test_preds.ndim
    ), "fixed quantiles have the wrong shape, {}".format(q_fixed.shape)
    return q_fixed


import matplotlib.pyplot as plt
import numpy as np

# Get base quantile performances


def get_rates(
    private_target_scores, public_target_scores, private_thresholds, public_thresholds
):
    """Compute precision, TPR and TNR for every threshold column.

    Scores are real-valued vectors of size n. Thresholds are 2-d, either
    [n, n_thresholds] (sample dependent) or [1, n_thresholds] (shared).
    A private sample counts as positive when its score is >= its threshold;
    a public sample counts as negative when its score is < its threshold.
    Undefined ratios (0/0) are mapped through ``np.nan_to_num``.

    Returns ``(precision, true_positive_rate, true_negative_rate)``.
    """
    expectations = (
        (private_target_scores, 1, "private scores need to be real-valued vectors"),
        (public_target_scores, 1, "public scores need to be real-valued vectors"),
        (private_thresholds, 2, "private thresholds need to be 2-d vectors"),
        (public_thresholds, 2, "public thresholds need to be 2-d vectors"),
    )
    for array, expected_rank, message in expectations:
        assert len(array.shape) == expected_rank, message

    prior = 0.0
    private_column = private_target_scores.reshape([-1, 1])
    public_column = public_target_scores.reshape([-1, 1])

    # Confusion counts per threshold (summed over samples).
    true_positives = (private_column >= private_thresholds).sum(0) + prior
    false_negatives = (private_column < private_thresholds).sum(0) + prior
    true_negatives = (public_column < public_thresholds).sum(0) + prior
    false_positives = (public_column >= public_thresholds).sum(0) + prior

    positives = true_positives + false_negatives
    negatives = true_negatives + false_positives
    true_positive_rate = np.nan_to_num(true_positives / positives)
    true_negative_rate = np.nan_to_num(true_negatives / negatives)
    # Precision from the rates, i.e. TPR / (TPR + FPR).
    precision = np.nan_to_num(
        true_positive_rate / (true_positive_rate + 1 - true_negative_rate)
    )

    return precision, true_positive_rate, true_negative_rate


def pinball_loss_np(target, score, quantile):
    """NumPy pinball loss, averaged over samples.

    ``target`` is reshaped to a column and compared with every column of the
    2-d ``score``; the mean over axis 0 yields one loss per quantile column.
    """
    target_column = target.reshape([-1, 1])
    assert (
        score.ndim == 2
    ), "score has the wrong shape, expected 2d input but got {}".format(score.shape)
    residual = target_column - score
    under_term = residual * quantile
    over_term = -residual * (1.0 - quantile)
    per_sample = np.maximum(under_term, over_term)
    return per_sample.mean(0)


def plot_performance_curves(
    private_target_scores,
    public_target_scores,
    private_predicted_score_thresholds=None,
    public_predicted_score_thresholds=None,
    model_target_quantiles=None,
    model_name="Quantile Model",
    use_quantile_thresholds=True,
    use_thresholds=True,
    use_logscale=True,
    fontsize=12,
    savefig_path="results.png",
    plot_results=True,
):
    """Plot ROC and pinball-loss curves, pickle the metrics, return the AUCs.

    A baseline is always evaluated: its thresholds are the marginal quantiles
    of ``public_target_scores`` on a grid of 499 quantile levels, shared by
    all samples. If per-sample model thresholds are supplied (and
    ``use_thresholds`` is true) the model is evaluated the same way and drawn
    on the same axes.

    Args:
        private_target_scores: 1-d vector of scores for the private samples.
        public_target_scores: 1-d vector of scores for the public samples.
        private_predicted_score_thresholds: Optional 2-d array of model
            thresholds for the private samples; sorted along the last axis
            before use.
        public_predicted_score_thresholds: 2-d array of model thresholds for
            the public samples. Required when the private thresholds are used.
        model_target_quantiles: Quantile levels matching the threshold
            columns. Required when the private thresholds are used.
        model_name: Legend label for the model curves.
        use_quantile_thresholds: Must be true; the alternative is not
            implemented.
        use_thresholds: If false, the model thresholds are ignored.
        use_logscale: Use a log-spaced quantile grid (dense near 1) and
            log-log axes for the ROC plot.
        fontsize: Font size of the plot titles.
        savefig_path: Reference output path. The plots are written next to it
            as ``roc.png`` and ``pinball.png``; the metrics are pickled to
            ``<basename up to the first dot>.pkl``. ``None`` disables all file
            output.
        plot_results: If true, ``plt.show()`` is called for each figure.

    Returns:
        ``(baseline_auc, model_auc)``; ``model_auc`` is ``None`` when no
        model thresholds were evaluated.

    Raises:
        NotImplementedError: If ``use_quantile_thresholds`` is false.
        ValueError: If model thresholds are requested without public
            thresholds or target quantiles.
    """
    plt.ioff()
    if not use_quantile_thresholds:
        raise NotImplementedError

    # np.trapz was renamed to np.trapezoid in NumPy 2.0; use whichever exists.
    integrate = getattr(np, "trapezoid", None) or np.trapz

    # ---- Baseline: marginal quantiles of the public scores -----------------
    n_baseline_points = 500
    if use_logscale:
        baseline_quantiles = np.sort(
            1.0 - np.logspace(-6, 0, n_baseline_points)[:-1]
        )
    else:
        baseline_quantiles = np.linspace(0, 1, n_baseline_points)[:-1]
    # One row of thresholds shared by every sample.
    baseline_thresholds = np.quantile(
        public_target_scores, baseline_quantiles
    ).reshape([1, -1])

    baseline_public_loss = pinball_loss_np(
        public_target_scores, baseline_thresholds, baseline_quantiles
    )
    baseline_private_loss = pinball_loss_np(
        private_target_scores, baseline_thresholds, baseline_quantiles
    )
    baseline_precision, baseline_tpr, baseline_tnr = get_rates(
        private_target_scores,
        public_target_scores,
        baseline_thresholds,
        baseline_thresholds,
    )
    baseline_auc = np.abs(integrate(baseline_tpr, x=1 - baseline_tnr))

    # ---- Model: thresholds predicted per sample ----------------------------
    model_precision = model_tpr = model_tnr = model_auc = None
    model_public_loss = model_private_loss = model_adjusted_public_loss = None

    if private_predicted_score_thresholds is not None and use_thresholds:
        if public_predicted_score_thresholds is None or model_target_quantiles is None:
            raise ValueError(
                "public_predicted_score_thresholds and model_target_quantiles "
                "are required when private_predicted_score_thresholds is given"
            )
        model_target_quantiles = np.sort(model_target_quantiles)
        private_predicted_score_thresholds = np.sort(
            private_predicted_score_thresholds, axis=-1
        )
        public_predicted_score_thresholds = np.sort(
            public_predicted_score_thresholds, axis=-1
        )

        model_precision, model_tpr, model_tnr = get_rates(
            private_target_scores,
            public_target_scores,
            private_predicted_score_thresholds,
            public_predicted_score_thresholds,
        )
        model_public_loss = pinball_loss_np(
            public_target_scores,
            public_predicted_score_thresholds,
            model_target_quantiles,
        )
        model_private_loss = pinball_loss_np(
            private_target_scores,
            private_predicted_score_thresholds,
            model_target_quantiles,
        )
        # Pinball loss evaluated at the achieved TNR instead of the target
        # quantile levels.
        model_adjusted_public_loss = pinball_loss_np(
            public_target_scores, public_predicted_score_thresholds, model_tnr
        )
        model_auc = np.abs(integrate(model_tpr, x=1 - model_tnr))

    # ---- Output location ---------------------------------------------------
    output_dir = None
    if savefig_path is not None:
        output_dir = os.path.dirname(savefig_path)
        # dirname is "" for a bare filename; os.makedirs("") would raise.
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

    def _finish_figure(fig, filename):
        # Lay out, optionally save and show, then release the figure.
        fig.tight_layout()
        if output_dir is not None:
            figure_path = os.path.join(output_dir, filename)
            fig.savefig(figure_path, dpi=300)
            print("saving plot to", figure_path)
        if plot_results:
            plt.show()
        # Close explicitly so repeated calls do not accumulate open figures.
        plt.close(fig)

    def _summarize(precision, tpr, tnr, auc, public_loss, private_loss,
                   adjusted_public_loss=None):
        # Print precision/TPR at the points closest to 1% and 0.1% FPR and
        # bundle the metrics into a dict for pickling.
        idx_1pc = np.argmin(np.abs(tnr - 0.99))
        idx_01pc = np.argmin(np.abs(tnr - 0.999))
        print(
            "Precision @1%  FPR {:.2f}% \t  TPR @ 1% FPR {:.2f}% ".format(
                precision[idx_1pc] * 100, tpr[idx_1pc] * 100
            )
        )
        print(
            "Precision @0.1% FPR {:.2f}% \t  TPR @ 0.1% FPR {:.2f}% ".format(
                precision[idx_01pc] * 100, tpr[idx_01pc] * 100
            )
        )
        return {
            "precision": precision,
            "tpr": tpr,
            "tnr": tnr,
            "auc": auc,
            "public_loss": public_loss,
            "private_loss": private_loss,
            "adjusted_public_loss": (
                adjusted_public_loss
                if adjusted_public_loss is not None
                else public_loss
            ),
        }

    # ---- ROC plot ----------------------------------------------------------
    fig, ax = plt.subplots(figsize=(6, 6), ncols=1, nrows=1)
    ax.set_title("ROC", fontsize=fontsize)
    ax.set_ylabel("True positive rate")
    ax.set_xlabel("False positive rate")
    ax.set_ylim([1e-3, 1])
    ax.set_xlim([1e-3, 1])
    ax.plot(1 - baseline_tnr, baseline_tpr, "-", label="Marginal Quantiles")
    if model_tpr is not None:
        ax.plot(
            1 - model_tnr,
            model_tpr,
            "-",
            markersize=12,
            label="{}".format(model_name),
        )
    ax.legend()
    if use_logscale:
        ax.set_xscale("log")
        ax.set_yscale("log")
    _finish_figure(fig, "roc.png")

    # ---- Pinball loss on public data ---------------------------------------
    fig, ax = plt.subplots(figsize=(6, 6), ncols=1, nrows=1)
    ax.set_title("Pinball loss", fontsize=fontsize)
    ax.set_xlabel("Significance level")
    ax.set_ylabel("Pinball loss")
    # Colours come from the default property cycle; no private matplotlib
    # attributes are needed to get distinct colours per curve.
    ax.plot(
        1 - baseline_quantiles,
        baseline_public_loss,
        "x-",
        label="Marginal Quantiles" + " (Public)",
    )
    if model_public_loss is not None:
        ax.plot(
            1 - model_target_quantiles,
            model_public_loss,
            "x-",
            label=model_name + "  (Public)",
        )
    ax.set_xscale("log")
    ax.legend()
    _finish_figure(fig, "pinball.png")

    # ---- Summaries at 1% / 0.1% FPR and pickled results --------------------
    save_dict = {}
    print("baseline")
    save_dict["baseline"] = _summarize(
        baseline_precision,
        baseline_tpr,
        baseline_tnr,
        baseline_auc,
        baseline_public_loss,
        baseline_private_loss,
    )
    if model_tpr is not None:
        print("model")
        save_dict["model"] = _summarize(
            model_precision,
            model_tpr,
            model_tnr,
            model_auc,
            model_public_loss,
            model_private_loss,
            model_adjusted_public_loss,
        )

    if savefig_path is not None:
        pickle_path = os.path.join(
            output_dir,
            os.path.basename(savefig_path).split(".")[0] + ".pkl",
        )
        # Open the file only once the results exist, so a failure above
        # cannot leave an empty pickle behind.
        with open(pickle_path, "wb") as f:
            pickle.dump(save_dict, f)

    return baseline_auc, model_auc