# Copyright (c)
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import copy
import math
import os

import wandb
import hydra
from omegaconf import DictConfig, OmegaConf
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch import nn
import timm

from lightly.data import LightlyDataset
from lightly.models.utils import deactivate_requires_grad
from lightly.transforms.dino_transform import DINOTransform
from lightly.utils.scheduler import cosine_schedule

from pretrain.metrics import compute_contrastive_acc, log_example_inputs

from pretrain.trainer_common import LightlyModel, main_pretrain

import torch, torch.nn as nn, torch.nn.functional as F

def log_Cd(kappa, D, eps=1e-8):
    """Return log C_D(kappa), the log normaliser of a von Mises-Fisher density on S^{D-1}.

        C_D(kappa) = kappa^nu / ((2*pi)^(D/2) * I_nu(kappa)),   nu = D/2 - 1,

    where I_nu is the modified Bessel function of the first kind.

    PyTorch has no general-order I_nu (``torch.special`` only offers orders 0
    and 1), so log I_nu is evaluated directly in log space from its power series

        I_nu(k) = sum_{m>=0} (k/2)^(2m+nu) / (m! * Gamma(m+nu+1)).

    Working in log space also avoids the overflow/underflow of I_nu itself for
    large kappa or large D.

    Args:
        kappa: concentration(s), a Python number or a tensor of any shape.
            Values are clamped below at ``eps`` so that log(kappa) is finite.
        D: ambient dimension of the embedding space (the sphere is S^{D-1}).
        eps: lower clamp applied to ``kappa``.

    Returns:
        Tensor with the shape of ``kappa`` and its floating dtype (the default
        dtype for integer / Python-number inputs).
    """
    k = torch.as_tensor(kappa)
    if not k.is_floating_point():
        k = k.to(torch.get_default_dtype())
    k = k.clamp_min(eps)
    nu = D / 2 - 1.0

    # float64: each log-term is a difference of large lgamma values, which
    # loses too much absolute precision in float32 for large kappa.
    k64 = k.double()
    # The term ratio is (k/2)^2 / ((m+1)(m+nu+1)), which is < 1/4 once m >= k
    # (for nu >= -1/2). So max(kappa) + 64 terms leaves a negligible tail.
    n_terms = int(k64.max().item()) + 64 if k64.numel() > 0 else 1
    m = torch.arange(n_terms, dtype=torch.float64, device=k64.device)
    log_terms = (
        (2 * m + nu) * torch.log(k64 / 2).unsqueeze(-1)
        - torch.lgamma(m + 1)
        - torch.lgamma(m + nu + 1)
    )
    log_iv = torch.logsumexp(log_terms, dim=-1)

    out = nu * torch.log(k64) - (D / 2) * math.log(2 * math.pi) - log_iv
    return out.to(k.dtype)

class HypersphereDensityLoss(nn.Module):
    """Contrastive loss built from kernel density estimates on the unit hypersphere.

    Embeddings are L2-normalised and compared with the log-kernel
    ``kappa * <x_i, x_j>`` (an unnormalised von Mises-Fisher kernel). Two
    log-densities are estimated per embedding, always excluding the embedding
    itself:

    * ``log_p_pos``: mean kernel over the other views of the same sample;
    * ``log_p_global``: mean kernel over every other embedding in the batch.

    The loss is ``beta * L_local + alpha * L_global``, where
    ``L_local = -mean(log_p_pos)`` and ``L_global = mean(log_p_global)``.
    Minimising it raises the density among views of the same sample and
    lowers the overall density of the batch.

    Args:
        kappa: kernel concentration multiplying the cosine similarity.
        alpha: weight of the global term.
        beta: weight of the local (positive-pair) term.
    """

    def __init__(self, kappa=1.0, alpha=1.0, beta=1.0):
        super().__init__()
        self.kappa = kappa
        self.alpha = alpha
        self.beta = beta

    def forward(self, feats):
        """Compute the loss for a batch of multi-view embeddings.

        Args:
            feats: tensor of shape ``[B, V, D]``, holding ``V`` views of each of
                ``B`` samples.

        Returns:
            ``(loss, L_local.detach(), L_global.detach())``. ``L_local`` is a zero
            tensor when ``V == 1``, because no sample then has a positive.
        """
        B, V, D = feats.shape
        N = B * V
        # Row b*V + v holds view v of sample b.
        x = F.normalize(feats, dim=-1).reshape(N, D)

        self_mask = torch.eye(N, dtype=torch.bool, device=x.device)
        # The vMF normaliser log_Cd(kappa, D) would add the same constant to
        # every entry. That shifts the loss by a constant and leaves gradients
        # unchanged, so it is omitted.
        logK = (self.kappa * (x @ x.T)).masked_fill(self_mask, float("-inf"))

        # Global KDE: log p_i = logsumexp_{j != i} logK_ij - log(N - 1).
        log_p_global = torch.logsumexp(logK, dim=1) - math.log(max(N - 1, 1))
        L_global = log_p_global.mean()

        # Every sample has exactly V - 1 positives (its other views). Branching
        # on the static shape avoids a data-dependent host/device sync.
        if V > 1:
            gid = torch.arange(B, device=x.device).repeat_interleave(V)
            pos_mask = (gid[:, None] == gid[None, :]) & ~self_mask
            logK_pos = logK.masked_fill(~pos_mask, float("-inf"))
            log_p_pos = torch.logsumexp(logK_pos, dim=1) - math.log(V - 1)
            L_local = -log_p_pos.mean()
        else:
            L_local = torch.tensor(0., device=x.device)

        loss = self.beta * L_local + self.alpha * L_global
        return loss, L_local.detach(), L_global.detach()



class SIMDEX(LightlyModel):
    """Self-supervised model trained with ``HypersphereDensityLoss``.

    The backbone output is flattened and mapped by a bias-free linear
    projection to ``cfg.method_kwargs.projector_dim``. Each multi-crop view is
    embedded independently, and the views are stacked into a ``[B, V, D]``
    tensor for the loss.

    Keys read from ``cfg.method_kwargs``: ``projector_dim``, ``local_crops`` and,
    optionally, ``kappa`` / ``alpha`` / ``beta`` (each defaulting to 1.0).
    ``self.backbone`` and ``self.input_size`` are presumably provided by
    ``LightlyModel``. TODO confirm.
    """

    def __init__(self, cfg: DictConfig):
        super().__init__(cfg)
        self.projection_head = nn.Linear(self.backbone.num_features, cfg.method_kwargs.projector_dim, bias=False)
        self.cfg = cfg
        # Forward every loss hyper-parameter, so that ``alpha`` set in the config
        # is honoured (1.0 when absent, matching the loss default).
        self.criterion = HypersphereDensityLoss(
            kappa=cfg.method_kwargs.get("kappa", 1.0),
            alpha=cfg.method_kwargs.get("alpha", 1.0),
            beta=cfg.method_kwargs.get("beta", 1.0),
        )

    def setup_transform(self):
        """Build the DINO multi-crop augmentation for the backbone and input size.

        ViT backbones use 224-px global and local crops. The 32/64/96-px inputs
        use half-size local crops; 32 and 64 px also use weaker colour jitter
        and no Gaussian blur. Any other size uses ``DINOTransform`` defaults.
        Every branch uses ``cfg.method_kwargs.local_crops`` local views.
        """
        if self.cfg.backbone.name.startswith("vit"):
            self.transform = DINOTransform(
                global_crop_size=224,
                global_crop_scale=(0.4, 1.0),
                local_crop_size=224,
                local_crop_scale=(0.05, 0.4),
                n_local_views=self.cfg.method_kwargs.local_crops,
            )
        elif self.input_size == 32:
            self.transform = DINOTransform(
                global_crop_size=32,
                global_crop_scale=(0.4, 1.0),
                local_crop_size=16,
                local_crop_scale=(0.05, 0.4),
                cj_strength=0.5,
                gaussian_blur=(0, 0, 0),
                n_local_views=self.cfg.method_kwargs.local_crops,
            )
        elif self.input_size == 64:
            self.transform = DINOTransform(
                global_crop_size=64,
                global_crop_scale=(0.4, 1.0),
                local_crop_size=32,
                local_crop_scale=(0.05, 0.4),
                cj_strength=0.5,
                gaussian_blur=(0, 0, 0),
                n_local_views=self.cfg.method_kwargs.local_crops,
            )
        elif self.input_size == 96:
            self.transform = DINOTransform(
                global_crop_size=96,
                global_crop_scale=(0.4, 1.0),
                local_crop_size=48,
                local_crop_scale=(0.05, 0.4),
                n_local_views=self.cfg.method_kwargs.local_crops,
            )
        else:
            self.transform = DINOTransform(
                n_local_views=self.cfg.method_kwargs.local_crops,
            )

    def forward(self, x):
        """Embed one batch of images: backbone features, flattened, then projected."""
        y = self.backbone(x).flatten(start_dim=1)
        y = self.projection_head(y)
        return y

    def train_val_step(self, batch, batch_idx, metric_label="train_metrics"):
        """Compute and log the density loss for one batch.

        Assumes ``batch[0]`` is the list of augmented views produced by the
        transform, each a ``[B, C, H, W]`` tensor (TODO confirm against the data
        module). Views of different resolutions are embedded separately before
        stacking.
        """
        views = batch[0]
        features = torch.stack([self.forward(view) for view in views], dim=1)
        # comp_loss is the local (positive-pair) term and expa_loss the global
        # term of HypersphereDensityLoss. The metric keys are kept unchanged.
        loss, comp_loss, expa_loss = self.criterion(features)
        self.log(f"{metric_label}/shannon_compression_loss", comp_loss, on_step=True, on_epoch=True)
        self.log(f"{metric_label}/shannon_expansion_loss", expa_loss, on_step=True, on_epoch=True)
        self.log(f"{metric_label}/shannon_loss", loss, on_step=True, on_epoch=True)
        return loss

@hydra.main(version_base="1.2", config_path="configs/", config_name="stl10_renyi.yaml")
def pretrain_simdex(cfg: DictConfig):
    """Entry point: run ``main_pretrain`` for the SIMDEX model with the Hydra config."""
    model_class = SIMDEX
    main_pretrain(cfg, model_class)


if __name__ == "__main__":
    # Called with no arguments: the ``@hydra.main`` wrapper supplies ``cfg``.
    pretrain_simdex()