# Coverage analysis and improvement
# - conformal prediction [✓]
# - coverage over different CIs [✓]
# - average coverage error [✓]
# - plot calibration curve (nominal vs. empirical) [✓]
# - local calibration [✓]

import torch
from pathlib import Path
# cmap = plt.get_cmap("tab10")
# colors = [cmap(i) for i in range(cmap.N)]


class Calibrator:
    """Additive, conformal-style corrections for credible-interval bounds.

    For every nominal interval level ``i`` (in percent) one correction per
    target dimension is stored under the key ``str(i)``. ``apply`` subtracts
    the correction from the lower bound and adds it to the upper bound, so a
    positive correction widens the interval and a negative one narrows it.
    """

    def __init__(self, d: int, intervals: list[int] = range(1,100)):
        self.d = d # dimensionality of target
        self.intervals = intervals
        # start from "no correction" for every interval level
        self.corrections = {str(i): torch.zeros(d) for i in self.intervals}
        
    def __repr__(self) -> str:
        return str(self.corrections)
    
    def update(self, Q: torch.Tensor, i: int) -> None:
        """Overwrite the stored correction of interval level ``i`` with ``Q``.

        ``Q`` must be a 1-D tensor with ``self.d`` entries.
        """
        assert Q.dim() == 1 and len(Q) == self.d
        self.corrections[str(i)] = Q
        
    def get(self, i: int) -> torch.Tensor:
        """Return the correction of interval level ``i`` (KeyError if unknown)."""
        return self.corrections[str(i)]
    
    def calibrate(self,
                  model,
                  proposed: dict[str, torch.Tensor],
                  targets: torch.Tensor,
                  local: bool = False) -> None:
        """Learn a correction for every level in ``self.intervals``.

        Wrapper around ``calibrateSingle``: for each level ``i`` the two-sided
        quantile pair ``[alpha/2, 1-alpha/2]`` with ``alpha = (100-i)/100`` is
        requested from ``model.quantiles`` and passed on together with
        ``targets``. ``proposed`` and ``local`` are forwarded to the model
        unchanged.
        """
        for i in self.intervals:
            alpha = (100-i)/100
            quantiles = model.quantiles(proposed, [alpha/2, 1-alpha/2], local=local)
            self.calibrateSingle(quantiles, targets, i)
    
    def calibrateSingle(self,
                        quantiles: torch.Tensor,
                        targets: torch.Tensor,
                        i: int) -> None:
        """Compute and store the correction for a single interval level ``i``.

        ``quantiles[..., 0]`` / ``quantiles[..., 1]`` are the lower / upper
        interval bounds and must broadcast against ``targets``; dimension 0 is
        the one over which the score quantile is taken. Targets that are
        exactly ``0.`` are treated as missing and ignored. Dimensions without
        any valid target receive a correction of zero.
        """
        b = len(targets)
        mask = (targets != 0.)
        alpha = torch.tensor( (100-i)/100 )
        scores = torch.zeros_like(quantiles)
        
        # signed distance of each target to both bounds of CI i
        # (positive = target lies outside on that side)
        scores[..., 0] = quantiles[..., 0] - targets
        scores[..., 1] = targets - quantiles[..., 1]
        
        # keep the signed distance to the closer boundary
        midx = scores.abs().min(-1)[1].unsqueeze(-1)
        scores = torch.gather(scores, -1, midx).squeeze(-1)
        # missing targets are pushed to the end of the ascending sort
        scores[~mask] = float('inf')
        B = mask.sum(0)  # number of valid targets per remaining dimension
        
        # sort and get desired score quantile
        scores, _ = scores.sort(0, descending=False)
        factor = mask.float().mean(0) * b * 1.01  # valid count, inflated by 1%
        idx = (factor * (1-alpha)).ceil().clamp(max=B-1)
        # without any valid target B-1 is -1, which gather() rejects
        idx = idx.clamp(min=0).to(torch.int64)
        corrections = torch.gather(scores, dim=0, index=idx.unsqueeze(0)).squeeze(0)
        # no valid targets -> nothing to learn, keep the interval unchanged
        corrections = torch.where(B > 0, corrections, torch.zeros_like(corrections))
            
        # store correction
        if quantiles.dim() == 4:
            # NOTE(review): presumably the "local" case with one extra leading
            # axis after the reduction; it is averaged away -- confirm shapes.
            corrections = corrections.mean(0)
        self.corrections[str(i)] = corrections
    
    def apply(self, quantiles: torch.Tensor, i: int) -> torch.Tensor:
        """Return a corrected copy of ``quantiles`` for interval level ``i``.

        The lower bound ``[..., 0]`` is decreased and the upper bound
        ``[..., 1]`` increased by the stored correction. If no correction is
        stored for ``i`` a message is printed and the input itself is returned.
        """
        Q = self.corrections.get(str(i), None)
        if Q is None:
            print(f'{i}-CI not learned by calibrator')
            return quantiles
        quantiles_c = quantiles.clone()
        quantiles_c[..., 0] -= Q
        quantiles_c[..., 1] += Q
        return quantiles_c
    
    def insert(self, corrections: dict[str, torch.Tensor]):
        """Replace all corrections; interval levels are re-derived from the keys."""
        self.corrections = corrections
        self.intervals = [int(k) for k in corrections]
        
    def save(self, model_id: str, i: int, local: bool = False) -> None:
        """Save all corrections to ``outputs/checkpoints/<model_id>/``.

        ``i`` is only used as a tag in the file name; ``local=True`` appends
        ``-local`` to it. The target directory must already exist.
        """
        suffix = '-local' if local else ''
        fn =  Path('outputs', 'checkpoints', model_id, f'calibrator_i={i}{suffix}.pt')
        torch.save(self.corrections, fn)
        print(f'Saved calibration values to {fn}.')
        
    def load(self, model_id: str, i: int, local: bool = False) -> None:
        """Load corrections written by ``save`` with the same arguments.

        ``local=True`` reads the ``-local`` file that ``save(..., local=True)``
        writes; the default reads the plain file as before.
        """
        suffix = '-local' if local else ''
        fn =  Path('outputs', 'checkpoints', model_id, f'calibrator_i={i}{suffix}.pt')
        # NOTE(security): weights_only=False unpickles arbitrary objects; only
        # load checkpoint files from a trusted source.
        corrections = torch.load(fn, weights_only=False)
        self.insert(corrections)
        print(f'Loaded calibration values from {fn}.')


# class Calibrator:
#     def __init__(self, d: int, intervals: list[int] = range(1,100)):
#         self.d = d # dimensionality of target
#         self.bins = 1 # must be divisible by 2
#         self.intervals = intervals
#         self.corrections = {str(i): torch.zeros(self.bins, d) for i in self.intervals}
#         self.borders = {str(i): torch.zeros(self.bins, d) for i in self.intervals}
        
#     def __repr__(self) -> str:
#         return str(self.corrections)
    
#     def update(self, Q: torch.Tensor, i: int) -> None:
#         assert Q.dim() == 1 and len(Q) == self.d
#         self.corrections[str(i)] = Q
        
#     def get(self, i: int) -> torch.Tensor:
#         return self.corrections[str(i)]
    
#     def calibrate(self,
#                   model,
#                   proposed: dict[str, torch.Tensor],
#                   targets: torch.Tensor,
#                   local: bool = False) -> None:
#         # wrapper for calibrateSingle()
#         for i in self.intervals:
#             alpha = (100-i)/100
#             quantiles = model.quantiles(proposed, [alpha/2, 1-alpha/2], local=local)
#             self.calibrateSingle(quantiles, targets, i)
    
#     def partition(self, x: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
#         b, d = x.shape
#         n = self.bins
#         x_sorted, idx = x.sort(dim=0)
#         y_sorted = y.gather(0, idx.unsqueeze(-1).expand(*y.shape))
#         x_partitioned = x_sorted.view(n, b//n, d).transpose(1,0)
#         y_partitioned = y_sorted.view(n, b//n, d, 2).transpose(1,0)
#         return x_partitioned, y_partitioned
        
    
#     def calibrateSingle(self,
#                         quantiles: torch.Tensor,
#                         targets: torch.Tensor,
#                         i: int) -> None:
        
#         alpha = torch.tensor( (100-i)/100 )     

#         # partition targets and quantiles
#         targets, quantiles = self.partition(targets, quantiles)
#         b = len(targets)
#         self.borders[str(i)] = targets.max(0)[0]
#         mask = (targets != 0.)
        
#         # calculate score bounds (distance of targets to CI i)
#         scores = torch.zeros_like(quantiles)
#         scores[..., 0] = quantiles[..., 0] - targets
#         scores[..., 1] = targets - quantiles[..., 1]
        
#         # get the distance to the closer boundary
#         midx = scores.abs().min(-1)[1].unsqueeze(-1)
#         scores = torch.gather(scores, -1, midx).squeeze(-1)
#         scores[~mask] = float('inf')
#         B = mask.sum(0)
        
#         # sort and get desired score quantile
#         scores, _ = scores.sort(0, descending=False)
#         factor = mask.float().mean(0) * b * 1.01
#         idx = (factor * (1-alpha)).ceil().clamp(max=B-1).to(torch.int64)
#         corrections = torch.gather(scores, dim=0, index=idx.unsqueeze(0)).squeeze(0)
            
#         # store correction
#         # if quantiles.dim() == 4:
#         #     corrections = corrections.mean(0)
        
#         self.corrections[str(i)] = corrections
    
#     def apply(self, quantiles: torch.Tensor, i: int) -> torch.Tensor:
#         Q = self.corrections.get(str(i), None)
#         borders = self.borders.get(str(i), None)
#         if Q is None:
#             print(f'{i}-CI not learned by calibrator')
#             return quantiles
        
#         # expand corrections to match the bin for each quantile
#         indices = (quantiles.unsqueeze(1)[..., 1] >= borders.unsqueeze(0))
#         indices = indices.sum(1).clamp(max=len(borders)-1)
#         Q = Q.gather(0, indices)
        
#         # apply corrections
#         quantiles_c = quantiles.clone()
#         quantiles_c[..., 0] -= Q
#         quantiles_c[..., 1] += Q
#         return quantiles_c
    
#     def insert(self, corrections: dict[str, torch.Tensor]):
#         self.corrections = corrections
#         self.intervals = [int(k) for k in corrections]
        
#     def save(self, model_id: str, i: int, local: bool = False) -> None:
#         suffix = '-local' if local else ''
#         fn =  Path('outputs', 'checkpoints', model_id, f'calibrator_i={i}{suffix}.pt')
#         torch.save(self.corrections, fn)
#         print(f'Saved calibration values to {fn}.')
        
#     def load(self, model_id: str, i: int) -> None:
#         fn =  Path('outputs', 'checkpoints', model_id, f'calibrator_i={i}.pt')
#         corrections = torch.load(fn, weights_only=False)
#         self.insert(corrections)
#         print(f'Loaded calibration values from {fn}.')
        
        
        
    
def empiricalCoverage(quantiles: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Fraction of valid targets that fall inside the given interval.

    The interval is spanned by the first and last entry of the final
    ``quantiles`` axis, widened by ``1e-6`` on both sides. Targets that are
    exactly ``0.`` count as missing and are excluded from both the hits and
    the normalizer. The result is reduced over dimension 0; a column without
    any valid target yields a coverage of 0.
    """
    valid = targets != 0.
    lower = quantiles[..., 0] - 1e-6
    upper = quantiles[..., -1] + 1e-6
    hits = (targets >= lower) & (targets <= upper) & valid
    n_hits = hits.float().sum(0)
    n_valid = valid.sum(0) + 1e-12  # epsilon guards against division by zero
    return n_hits / n_valid


def getCoverage(model,
                proposed: dict[str, torch.Tensor],
                targets: torch.Tensor,
                intervals: list[int],
                calibrate: bool = False,
                local: bool = False,
                ) -> dict[str, torch.Tensor]:
    """Empirical coverage of ``targets`` for every nominal level in ``intervals``.

    For each level ``i`` (in percent) the two-sided quantile pair
    ``[alpha/2, 1-alpha/2]`` with ``alpha = (100-i)/100`` is requested from
    ``model.quantiles`` (``calibrate`` and ``local`` are forwarded) and scored
    with ``empiricalCoverage``. With ``local=True`` the result is additionally
    averaged over its leading axis. The returned dict is keyed by ``str(i)``.
    """
    coverage = {}
    for level in intervals:
        tail = (100 - level) / 100
        bounds = model.quantiles(
            proposed,
            [tail / 2, 1 - tail / 2],
            calibrate=calibrate,
            local=local,
        )
        empirical = empiricalCoverage(bounds, targets)
        coverage[str(level)] = empirical.mean(0) if local else empirical
    return coverage


def coverageError(coverage: dict[str, torch.Tensor]) -> torch.Tensor:
    """Mean signed difference between empirical and nominal coverage.

    ``coverage`` maps a nominal level in percent (as a string key) to a tensor
    of empirical coverages. Entries that are exactly ``0.`` are excluded from
    the average. The reduction runs over the nominal levels, so the result has
    the shape of a single dict value.
    """
    empirical = torch.cat([values[None] for values in coverage.values()])
    valid = empirical != 0.
    levels = torch.tensor([int(key) for key in coverage])
    nominal = levels[:, None] / 100
    signed_error = (empirical - nominal) * valid
    n_valid = valid.sum(0) + 1e-12  # epsilon guards against division by zero
    return signed_error.sum(0) / n_valid


def plotCalibration(ax,
                    coverage: dict[str, torch.Tensor], names,
                    linestyle: str = '-', lw=2, upper: bool = False) -> None:
    """Draw a calibration curve (nominal vs. empirical coverage) on ``ax``.

    ``coverage`` maps nominal levels in percent (string keys) to empirical
    coverages; the ``i``-th entry of ``names`` labels the curve built from
    index ``i`` of the stacked values (assumes 1-D values with one entry per
    name -- TODO confirm). Curves whose values sum to zero are skipped. With
    ``upper=True`` the axes get a title and legend and the x tick labels are
    hidden; otherwise the x axis is labelled.
    """
    levels = [int(key) for key in coverage]
    stacked = torch.cat([vals.unsqueeze(-1) for vals in coverage.values()], dim=-1)

    for row, label in enumerate(names):
        percent = stacked[row] * 100.
        if percent.sum() != 0:
            ax.plot(levels, percent, label=label, linestyle=linestyle, lw=lw)

    # reference line for perfect calibration
    ax.plot([50, 95], [50, 95], ':', zorder=1,
            color='grey', label='identity')
    ax.set_xticks(levels)
    ax.set_yticks(levels)

    ax.set_ylabel('empirical CI', fontsize=26, labelpad=10)
    ax.grid(True)

    if not upper:
        ax.set_xlabel('nominal CI', fontsize=26, labelpad=10)
        return

    ax.set_title('Coverage', fontsize=30, pad=15)
    ax.legend(fontsize=18, loc='lower right')
    ax.set_xlabel('')
    ax.tick_params(axis='x', labelcolor='w', size=1)
    
    
    
    

# outdated and not as exact
# def postCalibrate(samples: torch.Tensor, weights: list[float]) -> torch.Tensor:
#     
#     weights = torch.tensor(weights).view(1,-1,1)
#     means = samples.mean(-1, keepdim=True)
#     deltas = samples - means
#     deltas *= weights
#     samples_ = deltas + means
#     return samples_


# def calibrate(model, dl, maxiter: int = 50):
#     def _calibrate(scales):
#         model.posterior.calibrate(scales)
#         results = sprint(model, dl, importance=cfg.importance)
#         coverage = getCoverage(model, results['proposed']['global'], results['targets'])
#         error = coverageError(coverage)[1:]
#         print(f"Mean coverage errors: {error.numpy()}")
#         return error
    
#     print("Initiating calibration...")
#     step_size = 0.1
#     scales = torch.ones(cfg.d)
#     best_error = _calibrate(scales)
#     for _ in range(maxiter):
#         candidate = scales.clone() - step_size * torch.rand(cfg.d)
#         error = _calibrate(candidate)
#         if error.abs().sum() < 0.05:
#             break
#         for i in range(cfg.d):
#             if error[i] < best_error[i] and all(error >= 0):
#                 scales[i] = candidate[i]
#     print("Finished calibration.\n")
#     return scales