r"""
QuantileDataModule feeds data into the QuantileNetwork.

- QuantileDataset appends a randomly sampled quantile level $\tau$ as an extra
    (final) input channel.
- The pretrained model is used to compute the quantile labels, which are
    cached in DUMP_DIR (see config.py).
"""

import os
import pdb

import numpy as np
from matplotlib import pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split, Dataset

import torchvision
from torchvision import transforms, datasets

import pytorch_lightning as pl

from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.multiclass import OneVsRestClassifier

import config
from utils_train import get_pretrained_model, get_base_datasets



def _weighted_quantile(arr, quantiles, weights):
    """ """
    indsort = np.argsort(arr)
    weights_cum = np.cumsum(weights[indsort])
    weights_cum = weights_cum / np.sum(weights)
    return np.interp(quantiles, weights_cum, arr[indsort])

def _get_class_weights(ylabel, class_no):
    """ """
    weights = np.ones(len(ylabel))
    weights[ylabel == class_no] = np.sum(ylabel != class_no)
    weights[ylabel != class_no] = np.sum(ylabel == class_no)
    return weights

def weighted_quantiles(logits, quantiles, labels):
    """Compute class-balanced quantiles of each logit column.

    For column ``c`` of ``logits`` (indexed as ``logits[:, c]``), samples are
    re-weighted so that label ``c`` and all other labels carry equal total
    weight. Weighted quantiles are then taken at the levels in ``quantiles``.
    ``labels`` is compared against the integer column index. Presumably it
    holds class ids 0..C-1 aligned with the logit columns; verify against the
    caller.

    NOTE: quantiles are linearly interpolated on the weighted CDF via
    ``np.interp``. The results therefore generally lie between data points,
    unlike ``torch.quantile(..., interpolation='lower')``.

    Returns an array of shape ``(len(quantiles), logits.shape[1])``.
    """
    per_class = [
        _weighted_quantile(logits[:, c], quantiles, _get_class_weights(labels, c))
        for c in range(logits.shape[1])
    ]
    return np.stack(per_class, axis=1)

def _extract_features(model, dataset):
    """Run ``model`` (with its ``fc`` head replaced by Identity) over ``dataset`` in order.

    Returns ``(features, labels)`` as numpy arrays concatenated over batches.
    The row order matches the dataset index order (``shuffle=False``).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.fc = torch.nn.Identity()  # Remove the last layer to expose penultimate features.
    model.to(device)
    model.eval()
    # shuffle=False keeps row i aligned with dataset index i; QuantileDataset
    # later indexes the cached labels by the same sample index.
    dataloader = DataLoader(dataset, batch_size=config.BATCHSIZE, shuffle=False, num_workers=config.NUM_WORKERS)
    features, labels = [], []
    with torch.no_grad():
        for data, target in dataloader:
            output = model(data.to(device))
            features.append(output.cpu().numpy())
            # Targets are only needed on the host, so skip the device round-trip.
            labels.append(target.cpu().numpy())
    return np.concatenate(features, axis=0), np.concatenate(labels, axis=0)


def _fit_ovr_logits(features, labels):
    """Fit one-vs-rest logistic regression and return its decision scores.

    The result always has shape ``(n_samples, n_classes)``. For binary problems
    the single score is expanded into two columns ``[-s, s]``.
    """
    # OneVsRestClassifier replaces LogisticRegression(multi_class="ovr"),
    # which is deprecated in recent scikit-learn releases.
    clf = OneVsRestClassifier(LogisticRegression(max_iter=1000))
    clf.fit(features, labels)
    print("(check) accuracy of logstic regression:", clf.score(features, labels))
    logits = clf.decision_function(features)
    if logits.ndim == 1:
        # Binary problems yield one score (for classes_[1]). Expand it so that
        # downstream per-class code can index logits[:, class_no].
        logits = np.stack([-logits, logits], axis=1)
    return logits


def _compute_and_cache_quantile_labels(name_base_model, filename):
    """Compute the quantile labels for a pretrained model and cache them at ``filename``.

    Steps:
      1. Extract penultimate-layer features of the base training set.
      2. Fit a one-vs-rest logistic regression on them and take its logits.
      3. Threshold each logit column at its class-balanced (weighted)
         quantiles, giving multi-label targets of shape
         ``(n_samples, NUM_QUANTILES, n_classes)``.
      4. Save the array with ``torch.save`` (the loader uses ``torch.load``).
    """
    model, _, _ = get_pretrained_model(name_base_model)
    train_dataset, _, _, _ = get_base_datasets(name_base_model)

    # STEP 1: Extract the features from the pretrained model.
    features, labels = _extract_features(model, train_dataset)

    # STEP 2: Train a one-vs-rest classifier for each class.
    logits = _fit_ovr_logits(features, labels)

    # STEP 3: Compute the quantile labels from weighted quantiles of the logits.
    # Interior levels only, e.g. NUM_QUANTILES=3 -> [0.25, 0.5, 0.75].
    quantiles_list = np.linspace(0, 1, config.NUM_QUANTILES + 2)[1:-1]
    quantiles = weighted_quantiles(logits, quantiles_list, labels)
    quant_labels = (logits[:, np.newaxis, :] > quantiles[np.newaxis, :, :]) * 1

    # Sanity check: averaging over quantiles and taking argmax should recover the class.
    pred_quant_labels = np.argmax(np.mean(quant_labels, axis=1), axis=1)
    print("(check) accuracy of quantile labels :", np.mean(pred_quant_labels == labels))

    # STEP 4: Cache the quantile labels for future use.
    # Create the dump directory if needed so the first run does not fail on save.
    os.makedirs(os.path.dirname(filename) or ".", exist_ok=True)
    torch.save(quant_labels, filename)
    print(f"Quantile labels for {name_base_model} are cached at {filename}.")

def cache_quantile_labels(name_base_model):
    """Ensure the quantile labels for ``name_base_model`` are cached in ``config.DUMP_DIR``.

    If the cache file already exists, only a message is printed. Otherwise the
    labels are computed and written to that path.
    """
    cache_path = os.path.join(config.DUMP_DIR, f"{name_base_model}_quantile_labels.npy")
    if not os.path.exists(cache_path):
        _compute_and_cache_quantile_labels(name_base_model, cache_path)
        return
    print(f"Quantile labels for {name_base_model} already exist.")

class QuantileDataset(Dataset):
    """Wrap a base dataset and append a randomly sampled quantile level as an extra channel.

    - Each item is ``(x, y, label)``. ``x`` is the base 3-D input with one extra
      channel filled with the sampled quantile level.
    - In train mode, ``y`` is the cached multi-label quantile target for the
      sampled quantile: ``quantile_labels[idx, idx_quantile, :]``.
    - In test mode, ``y`` is the one-hot encoding of ``label`` with shape
      ``(1, num_classes)``.
    """
    def __init__(self, name_base_model: str, train=True):
        super().__init__()
        self.name_base_model = name_base_model
        train_ds, test_ds, num_classes, _ = get_base_datasets(name_base_model)
        self.dataset = train_ds if train else test_ds
        self.train = train
        self.num_classes = num_classes

        labels_path = os.path.join(config.DUMP_DIR, f"{name_base_model}_quantile_labels.npy")
        # The cache is written by this module via torch.save of a numpy array.
        # PyTorch >= 2.6 defaults to weights_only=True, which rejects pickled
        # numpy arrays. The file is produced locally, so full unpickling is
        # acceptable here. NOTE(review): never point this at untrusted files.
        # torch.as_tensor accepts both ndarray and tensor payloads.
        self.quantile_labels = torch.as_tensor(
            torch.load(labels_path, weights_only=False)
        ).float()
        # Interior quantile levels, e.g. NUM_QUANTILES=3 -> [0.25, 0.5, 0.75].
        self.quantiles_list = torch.linspace(0, 1, config.NUM_QUANTILES + 2)[1:-1]

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        idx_sample, idx_quantile = idx, np.random.randint(0, len(self.quantiles_list))
        x, label = self.dataset[idx_sample]

        assert x.dim() == 3, "Only Implemented for RGB-Images"
        # Constant channel holding the sampled quantile level, stacked after the input channels.
        quant_val = torch.full((1, x.shape[1], x.shape[2]), float(self.quantiles_list[idx_quantile]))
        x = torch.cat([x, quant_val], dim=0)

        if self.train:
            y = self.quantile_labels[idx_sample, idx_quantile, :].float()
            return x, y, label
        # NOTE(review): the test target has shape (1, num_classes) while the train
        # target is (num_classes,). Confirm that the consumer expects this asymmetry.
        out = F.one_hot(torch.tensor([label]).long(), self.num_classes).float()
        return x, out, label

class QuantileDataModule(pl.LightningDataModule):
    """LightningDataModule that serves QuantileDataset splits for one base model.

    - ``prepare_data`` computes the quantile labels once and caches them.
    - ``setup`` builds an 80/20 random train/val split of the training
      QuantileDataset, and a test QuantileDataset from the base test split.
    """
    def __init__(self, name_base_model: str) -> None:
        super().__init__()

        self.name_base_model = name_base_model
        self.num_quantiles = config.NUM_QUANTILES
        # Interior quantile levels, e.g. NUM_QUANTILES=3 -> [0.25, 0.5, 0.75].
        self.quantile_list = np.linspace(0, 1, self.num_quantiles + 2)[1:-1]

    def prepare_data(self):
        cache_quantile_labels(self.name_base_model)

    def setup(self, stage=None):
        # Lightning calls setup("validate") for trainer.validate(), so the val
        # split must be available for that stage as well as for "fit".
        if stage in ("fit", "validate") or stage is None:
            # Split only once. Re-splitting on a later stage would produce a
            # different random partition and mix training samples into validation.
            if getattr(self, "val_ds", None) is None:
                self.entire_dataset = QuantileDataset(self.name_base_model, train=True)
                size_train = int(len(self.entire_dataset) * 0.8)
                size_val = len(self.entire_dataset) - size_train
                self.train_ds, self.val_ds = random_split(self.entire_dataset, [size_train, size_val])

        if stage == "test" or stage is None:
            self.test_ds = QuantileDataset(self.name_base_model, train=False)

    def train_dataloader(self):
        return DataLoader(self.train_ds, batch_size=config.BATCHSIZE, shuffle=True, num_workers=config.NUM_WORKERS)

    def val_dataloader(self):
        return DataLoader(self.val_ds, batch_size=config.BATCHSIZE, shuffle=False, num_workers=config.NUM_WORKERS)

    def test_dataloader(self):
        return DataLoader(self.test_ds, batch_size=config.BATCHSIZE, shuffle=False, num_workers=config.NUM_WORKERS)
    
if __name__ == "__main__":
    # No standalone entry point is implemented for this module.
    ...