# Copyright 2023 solo-learn development team.

# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to use,
# copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
# Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all copies
# or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
# PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
# FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

from typing import Any, Dict, List, Sequence

import omegaconf
import torch
import torch.nn as nn
from solo.losses.iter_dist import iter_dist_loss_func, batch_sparsity_metric, embedding_sparsity_metric, active_feature_fraction, count_avg_nonzero_elements_per_dimension, count_avg_nonzero_elements_per_sample, choose_sigma_for_unit_var, determine_sigma_for_lp_dist
from solo.methods.base import BaseMethod
from solo.utils.misc import omegaconf_select

import math
from itertools import combinations
import random

# for low-discrepancy sampling
from torch.quasirandom import SobolEngine

def generate_random_projections(num_projections, D, device=None, dtype=None):
    """Draw random unit-length projection directions.

    Every row is sampled from a standard normal and then divided by its own
    L2 norm, so each returned direction has unit Euclidean length.

    Args:
        num_projections: number of directions (rows) to draw.
        D: dimensionality of each direction (columns).
        device: optional torch device forwarded to ``torch.randn``.
        dtype: optional torch dtype forwarded to ``torch.randn``.

    Returns:
        torch.Tensor: tensor of shape ``(num_projections, D)`` with unit-norm rows.
    """
    gaussian_draws = torch.randn(num_projections, D, device=device, dtype=dtype)
    # keepdim=True so the (num_projections, 1) norms broadcast over the columns.
    row_norms = torch.norm(gaussian_draws, dim=1, keepdim=True)
    return gaussian_draws / row_norms

def sample_sphere_pos(num_projections, D, device=None, dtype=None):
    """Draw unit-length directions restricted to the positive orthant.

    Standard-normal samples are replaced by their absolute values (so every
    coordinate is non-negative) and each row is then divided by its L2 norm.

    Args:
        num_projections: number of directions (rows) to draw.
        D: dimensionality of each direction (columns).
        device: optional torch device forwarded to ``torch.randn``.
        dtype: optional torch dtype forwarded to ``torch.randn``.

    Returns:
        torch.Tensor: tensor of shape ``(num_projections, D)`` whose rows have
        unit norm and non-negative entries.
    """
    gaussian_draws = torch.randn(num_projections, D, device=device, dtype=dtype)
    folded = gaussian_draws.abs()
    # keepdim=True so the (num_projections, 1) norms broadcast over the columns.
    row_norms = torch.norm(folded, dim=1, keepdim=True)
    return folded / row_norms


def generate_svd_projections(z1, z2):
    """Return eigenvectors of the centered scatter matrices of ``z1`` and ``z2``.

    Each input is detached, cast to float32 and mean-centered along dim 0. The
    eigenvectors of ``z_centered^T @ z_centered`` (i.e. the right-singular
    vectors of the centered matrix) are computed with ``torch.lobpcg``; if that
    call raises, ``torch.linalg.eigh`` is used for both inputs instead.

    Args:
        z1: 2-D tensor of shape ``(B, D)``.
        z2: 2-D tensor; its scatter matrix must be ``D x D`` as well.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: ``(Vt_z1, Vt_z2)``, each of shape
        ``(min(B, D), D)`` with one eigenvector per row.
    """
    # Autocast is disabled for the whole computation and everything is forced to
    # float32, because lobpcg/eigh do not support float16 on GPU.
    with torch.amp.autocast('cuda', enabled=False):
        view_a = z1.detach().float()
        view_b = z2.detach().float()

        # Remove the per-feature batch mean from each view.
        centered_a = view_a - view_a.mean(dim=0)
        centered_b = view_b - view_b.mean(dim=0)

        batch_size, feat_dim = centered_a.shape
        num_vectors = min(batch_size, feat_dim)

        # D x D scatter matrices z^T z.
        scatter_a = torch.matmul(centered_a.T, centered_a)
        scatter_b = torch.matmul(centered_b.T, centered_b)

        # Random starting blocks for the iterative eigensolver (both live on z1's device).
        init_a = torch.randn(feat_dim, num_vectors, device=view_a.device, dtype=torch.float32)
        init_b = torch.randn(feat_dim, num_vectors, device=view_a.device, dtype=torch.float32)

        try:
            # largest=True requests the top eigenpairs; only the eigenvectors are kept.
            eigvecs_a = torch.lobpcg(scatter_a, X=init_a, largest=True)[1]
            eigvecs_b = torch.lobpcg(scatter_b, X=init_b, largest=True)[1]
            Vt_z1, Vt_z2 = eigvecs_a.T, eigvecs_b.T
        except Exception:
            # Dense fallback whenever lobpcg fails for either matrix.
            eigvecs_a = torch.linalg.eigh(scatter_a)[1]
            eigvecs_b = torch.linalg.eigh(scatter_b)[1]
            # eigh sorts eigenvalues in ascending order: reverse the rows and keep the top ones.
            Vt_z1 = eigvecs_a.T.flip(0)[:num_vectors]
            Vt_z2 = eigvecs_b.T.flip(0)[:num_vectors]

    return Vt_z1, Vt_z2


# ## for low-discrepancy sampling

# def generate_low_discrepancy_projections(num_projections, D, device=None, dtype=None):
#     """
#     Generates quasi-random projection directions using a Sobol low-discrepancy sequence.
#     """

#     # Sobol sequence fills [0, 1]^D uniformly.
#     sobol = SobolEngine(dimension=D)

#     # Draw quasi-random points in [0, 1]^D.
#     samples = sobol.draw(num_projections).to(device=device, dtype=dtype)

#     # Map [0, 1] → [-1, 1] to make them symmetric around zero.
#     samples = 2 * samples - 1

#     # Normalize each to unit length so they lie on the sphere S^{D-1}.
#     samples = samples / torch.norm(samples, dim=1, keepdim=True)

#     return samples



class IterDist(BaseMethod):
    """IterDist method: an invariance loss plus a one-dimensional distribution-matching loss.

    The projected features of two views are compared against a target
    distribution along a set of projection directions that is re-sampled at
    every training step (see ``_sample_projections``).
    """

    # One-dimensional losses accepted for each target distribution.
    _SUPPORTED_ONE_D_LOSSES: Dict[str, Sequence[str]] = {
        "gauss": ("sliced_wasserstein_distance", "jarque_bera_loss", "sigreg_loss"),
        "laplace": ("sliced_wasserstein_distance",),
        "product_laplace": ("sliced_wasserstein_distance",),
        "rectified_gauss": ("sliced_wasserstein_distance",),
        # rectified product laplace is the headline sparsity distribution
        "rectified_product_laplace": ("sliced_wasserstein_distance",),
        "rectified_lp_distribution": ("sliced_wasserstein_distance",),
        "lp_distribution": ("sliced_wasserstein_distance",),
    }

    # p-norm implied by the target distributions whose p is not configurable.
    _FIXED_P_NORMS: Dict[str, float] = {
        "rectified_gauss": 2.0,
        "rectified_product_laplace": 1.0,
        "gauss": 2.0,
        "product_laplace": 1.0,
    }

    def __init__(self, cfg: omegaconf.DictConfig):
        """Implements IterDist.

        NOTE WE NEVER USE VARIANCE AND COVARIANCE LOSS FOR ANY ITER-DIST METHODS;
        WE ONLY LOG THESE TERMS BUT NEVER OPTIMIZE FOR THEM. ITER-DIST METHODS USE SIM_LOSS AND ONE_D_DIST_LOSS ONLY.

        Extra cfg settings:
            method_kwargs:
                proj_output_dim (int): number of dimensions of the projected features.
                proj_hidden_dim (int): number of neurons in the hidden layers of the projector.
                projector_type (str): the type of projector ("mlp3", "mlp3_with_one_more_relu"
                    or "identity").
                sim_loss_weight (float): weight of the invariance term.
                var_loss_weight (float): weight of the variance term.
                cov_loss_weight (float): weight of the covariance term.
                target_distribution (str): the distribution to optimize for.
                one_d_dist_loss_weight (float): weight of the one-dimensional loss.
                one_d_dist_loss_choice (str): the one-dimensional loss choice.
                swd_num_projections (int): number of projections for sliced wasserstein distance.
                projection_sampling_mode (str): the projection sampling mode.
                active_feature_threshold (float): the threshold for active features.
                mean_shift_scalar_for_rectified_gauss (float): the mean shift scalar for the rectified distributions (gauss, laplace, etc.).
                p_norm_for_rectified_lp_distribution (float): the p norm used by the
                    "rectified_lp_distribution" and "lp_distribution" targets.
                mode_of_sigma (str): how the scale sigma of the target distribution is chosen:
                    "gen_gauss_var_1" or "rec_gen_gauss_var_1" (the latter is only accepted
                    for rectified target distributions).

        Raises:
            ValueError: if the target distribution, projector type or mode of sigma is
                not supported.
            AssertionError: if the one-dimensional loss is not allowed for the target
                distribution, or the identity projector is used with mismatched dimensions.
        """

        # proj_output_dim is needed by projector_classifier in BaseMethod if enabled.
        # Ensure it's available on cfg.method_kwargs if projector_classifier is true.
        if omegaconf_select(cfg, "method_kwargs.add_projector_classifier", False):
            assert not omegaconf.OmegaConf.is_missing(
                cfg, "method_kwargs.proj_output_dim"
            ), "IterDist: method_kwargs.proj_output_dim must be set if add_projector_classifier is True."

        super().__init__(cfg)

        self.sim_loss_weight: float = cfg.method_kwargs.sim_loss_weight
        self.var_loss_weight: float = cfg.method_kwargs.var_loss_weight
        self.cov_loss_weight: float = cfg.method_kwargs.cov_loss_weight

        # distribution mode
        self.target_distribution: str = cfg.method_kwargs.target_distribution

        # hyperparameters for the one_d_dist_loss
        self.one_d_dist_loss_weight: float = cfg.method_kwargs.one_d_dist_loss_weight
        self.one_d_dist_loss_choice: str = cfg.method_kwargs.one_d_dist_loss_choice
        self.swd_num_projections = cfg.method_kwargs.swd_num_projections
        self.projection_sampling_mode: str = cfg.method_kwargs.projection_sampling_mode
        self.active_feature_threshold: float = cfg.method_kwargs.active_feature_threshold

        # hyperparameters for the rectified distributions
        self.mean_shift_scalar_for_rectified_gauss: float = cfg.method_kwargs.mean_shift_scalar_for_rectified_gauss
        self.p_norm_for_rectified_lp_distribution: float = cfg.method_kwargs.p_norm_for_rectified_lp_distribution

        # validate the (target distribution, one-dimensional loss) combination
        if self.target_distribution not in self._SUPPORTED_ONE_D_LOSSES:
            raise ValueError("not supported.")
        allowed_losses = self._SUPPORTED_ONE_D_LOSSES[self.target_distribution]
        assert self.one_d_dist_loss_choice in allowed_losses, (
            f"one_d_dist_loss_choice={self.one_d_dist_loss_choice} is not supported for "
            f"target_distribution={self.target_distribution}; expected one of {list(allowed_losses)}."
        )

        # projector configuration
        proj_hidden_dim: int = cfg.method_kwargs.proj_hidden_dim
        proj_output_dim: int = cfg.method_kwargs.proj_output_dim
        self.proj_output_dim = proj_output_dim
        self.projector_type: str = cfg.method_kwargs.projector_type

        # projector
        if self.projector_type in ("mlp3", "mlp3_with_one_more_relu"):
            layers = [
                nn.Linear(self.features_dim, proj_hidden_dim),
                nn.BatchNorm1d(proj_hidden_dim),
                nn.ReLU(),
                nn.Linear(proj_hidden_dim, proj_hidden_dim),
                nn.BatchNorm1d(proj_hidden_dim),
                nn.ReLU(),
                nn.Linear(proj_hidden_dim, proj_output_dim),
            ]
            if self.projector_type == "mlp3_with_one_more_relu":
                # trailing ReLU makes the projected features non-negative
                layers.append(nn.ReLU())
            self.projector = nn.Sequential(*layers)
        elif self.projector_type == "identity":
            self.projector = nn.Identity()
            assert self.features_dim == self.proj_output_dim, "features dim and projector output dim must be the same for identity projector."
        else:
            raise ValueError(f"Invalid projector type: {self.projector_type}")

        # resolve the p-norm associated with the target distribution
        if self.target_distribution in ("rectified_lp_distribution", "lp_distribution"):
            # both lp targets read their p from the same config entry
            actual_lp = self.p_norm_for_rectified_lp_distribution
        elif self.target_distribution in self._FIXED_P_NORMS:
            actual_lp = self._FIXED_P_NORMS[self.target_distribution]
        elif self.target_distribution == "laplace":
            raise ValueError("Should not be using elliptical laplace anymore.")
        else:
            raise ValueError(f"Invalid target distribution: {self.target_distribution}")

        # mode of sigma: this is the only place where chosed_sigma is set
        self.mode_of_sigma: str = cfg.method_kwargs.mode_of_sigma
        if self.mode_of_sigma == "gen_gauss_var_1":
            self.chosed_sigma = determine_sigma_for_lp_dist(actual_lp)
        elif self.mode_of_sigma == "rec_gen_gauss_var_1":
            if "rectified" in self.target_distribution:
                self.chosed_sigma = choose_sigma_for_unit_var(actual_lp, self.mean_shift_scalar_for_rectified_gauss)
            else:
                raise ValueError(f"Invalid target distribution: {self.target_distribution} for self.mode_of_sigma=rec_gen_gauss_var_1")
        else:
            raise ValueError(f"Invalid mode of sigma: {self.mode_of_sigma}")

        # print out the chosen sigma
        print(f"Chosen sigma for {self.target_distribution} with mean shift {self.mean_shift_scalar_for_rectified_gauss} and p_norm {self.p_norm_for_rectified_lp_distribution} and self.mode_of_sigma={self.mode_of_sigma} is {self.chosed_sigma}")

    @staticmethod
    def add_and_assert_specific_cfg(cfg: omegaconf.DictConfig) -> omegaconf.DictConfig:
        """Adds method specific default values/checks for config.

        Args:
            cfg (omegaconf.DictConfig): DictConfig object.

        Returns:
            omegaconf.DictConfig: same as the argument, used to avoid errors.
        """

        cfg = super(IterDist, IterDist).add_and_assert_specific_cfg(cfg)

        assert not omegaconf.OmegaConf.is_missing(cfg, "method_kwargs.proj_output_dim")
        assert not omegaconf.OmegaConf.is_missing(cfg, "method_kwargs.proj_hidden_dim")

        # Defaults applied only when the key is absent from the config.
        defaults: Dict[str, Any] = {
            "sim_loss_weight": 25.0,
            "var_loss_weight": 25.0,
            "cov_loss_weight": 1.0,
            "active_feature_threshold": 1e-3,
            "projector_type": "mlp3",
            # controls sparsity for the rectified distributions.
            "mean_shift_scalar_for_rectified_gauss": 0.0,
            # only read by the lp target distributions; otherwise it does not interfere with anything.
            "p_norm_for_rectified_lp_distribution": 1.0,
            "mode_of_sigma": "gen_gauss_var_1",
        }
        for key, default in defaults.items():
            value = omegaconf_select(cfg, f"method_kwargs.{key}", default)
            setattr(cfg.method_kwargs, key, value)

        return cfg

    @property
    def learnable_params(self) -> List[dict]:
        """Adds projector parameters to the parent's learnable parameters.

        Returns:
            List[dict]: list of learnable parameters.
        """

        extra_learnable_params = [{"name": "projector", "params": self.projector.parameters()}]
        return super().learnable_params + extra_learnable_params

    def forward(self, X: torch.Tensor) -> Dict[str, Any]:
        """Performs the forward pass of the backbone and the projector.

        Args:
            X (torch.Tensor): a batch of images in the tensor format.

        Returns:
            Dict[str, Any]: a dict containing the outputs of the parent and the projected
            features under "z" (plus "projector_logits" when a projector classifier exists).
        """

        out = super().forward(X)
        z = self.projector(out["feats"])
        out.update({"z": z})

        # If the projector classifier exists (initialized in BaseMethod), call it.
        if self.projector_classifier is not None:
            # Pass z.detach() to ensure gradients for this classifier only update itself,
            # not the projector or backbone through this path.
            projector_logits = self.projector_classifier(z.detach())
            out.update({"projector_logits": projector_logits})

        return out

    def _sample_projections(self, z1: torch.Tensor, z2: torch.Tensor) -> Any:
        """Builds the projection directions for one training step.

        Args:
            z1 (torch.Tensor): projected features of the first view.
            z2 (torch.Tensor): projected features of the second view.

        Returns:
            Either a single tensor of directions shared by both views ('random',
            'sphere_pos') or a list ``[directions_for_z1, directions_for_z2]``
            (all other modes).

        Raises:
            ValueError: if ``projection_sampling_mode`` is not one of the supported modes.
        """
        mode = self.projection_sampling_mode
        num_projections = self.swd_num_projections
        out_dim = self.proj_output_dim

        if mode in ("torch_svd_and_random", "torch_svd_bottom_half_eigen_and_random"):
            # compute eigenvectors of the covariance matrices.
            Vt_z1, Vt_z2 = generate_svd_projections(z1, z2)

            # assert that we have enough number of projections
            assert num_projections - Vt_z1.size(0) * 2 > 0, "we should have extra random projections beyond eigenvectors!"
            assert Vt_z1.size(0) == Vt_z2.size(0), "equal number of eigenvectors."

            if mode == "torch_svd_bottom_half_eigen_and_random":
                # keep only the second half of the returned eigenvector rows
                Vt_z1 = Vt_z1[Vt_z1.size(0) // 2:]
                Vt_z2 = Vt_z2[Vt_z2.size(0) // 2:]
                assert Vt_z1.size(0) == Vt_z2.size(0), "equal number of bottom half eigenvectors."

            # random directions fill up the remaining projection budget and are shared by both views
            shared_random = generate_random_projections(
                num_projections - Vt_z1.size(0), out_dim, device=z1.device, dtype=z1.dtype
            )

            # there is no need to shuffle because we average over all projections anyway.
            return [
                torch.vstack([Vt_z1, shared_random]),
                torch.vstack([Vt_z2, shared_random]),
            ]

        if mode == "fixed_torch_svd":
            # This is the setting where we only use the eigenvectors of the covariance matrices.
            assert num_projections == min(z1.shape), "the total number of projections should be the same as the number of singular vectors under truncated SVD."
            Vt_z1, Vt_z2 = generate_svd_projections(z1, z2)
            return [Vt_z1, Vt_z2]

        if mode == "fixed_random":
            assert num_projections == min(z1.shape), "the total number of random projections should be the same as the number of singular vectors under truncated SVD. This ensures fair comparison to fixed_torch_svd."
            # use different random projections for each view.
            directions_1 = generate_random_projections(num_projections, out_dim, device=z1.device, dtype=z1.dtype)
            directions_2 = generate_random_projections(num_projections, out_dim, device=z1.device, dtype=z1.dtype)
            return [directions_1, directions_2]

        if mode == "random":
            return generate_random_projections(num_projections, out_dim, device=z1.device, dtype=z1.dtype)

        if mode == "sphere_pos":
            return sample_sphere_pos(num_projections, out_dim, device=z1.device, dtype=z1.dtype)

        raise ValueError(f"Invalid projection sampling mode: {mode}")

    def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor:
        """Training step for iter_dist reusing BaseMethod training step.

        Args:
            batch (Sequence[Any]): a batch of data in the format of [img_indexes, [X], Y], where
                [X] is a list of size num_crops containing batches of images.
            batch_idx (int): index of the batch.

        Returns:
            torch.Tensor: total loss composed of iter_dist loss, classification loss and
            (when a projector classifier exists) the projector classification loss.
        """

        out = super().training_step(batch, batch_idx)
        class_loss = out["loss"]
        z1, z2 = out["z"]
        feats1, feats2 = out["feats"]

        do_log = self.global_step % self.logging_interval == 0
        if do_log:
            self._log_sparsity_metrics(z1, z2, feats1, feats2)

        orthogonal_transform = self._sample_projections(z1, z2)

        # ------- iter_dist loss -------
        iter_dist_loss, sim_loss, var_loss, cov_loss, marginal_dist_loss = iter_dist_loss_func(
            z1,
            z2,
            orthogonal_transform,
            target_distribution=self.target_distribution,
            sim_loss_weight=self.sim_loss_weight,
            var_loss_weight=self.var_loss_weight,
            cov_loss_weight=self.cov_loss_weight,
            one_d_dist_loss_weight=self.one_d_dist_loss_weight,
            one_d_dist_loss_choice=self.one_d_dist_loss_choice,
            mean_shift_scalar_for_rectified_gauss=self.mean_shift_scalar_for_rectified_gauss,
            p_norm_for_rectified_lp_distribution=self.p_norm_for_rectified_lp_distribution,
            chosed_sigma=self.chosed_sigma,
        )

        self.log("train_iter_dist_loss", iter_dist_loss, on_epoch=True, sync_dist=True)
        self.log("train_sim_loss", sim_loss, on_epoch=True, sync_dist=True)
        self.log("train_var_loss", var_loss, on_epoch=True, sync_dist=True)
        self.log("train_cov_loss", cov_loss, on_epoch=True, sync_dist=True)
        self.log("train_marginal_dist_loss", marginal_dist_loss, on_epoch=True, sync_dist=True)

        # The projector_class_loss contributes to the optimization objective, so it must be
        # computed every step. However, we only log its results (accuracies and loss) at the
        # specified logging_interval to minimize the time-costly synchronization overhead.
        projector_class_loss = torch.tensor(0.0, device=self.device)
        if self.projector_classifier is not None:
            _, _, targets = batch
            proj_metrics1 = self._projector_classifier_step(z1, targets)
            proj_metrics2 = self._projector_classifier_step(z2, targets)

            if proj_metrics1 and proj_metrics2:
                projector_class_loss = (proj_metrics1["proj_loss"] + proj_metrics2["proj_loss"]) / 2
                if do_log:
                    self.log("train_proj_loss", projector_class_loss, on_epoch=True, sync_dist=True)
                    proj_acc1 = (proj_metrics1["proj_acc1"] + proj_metrics2["proj_acc1"]) / 2
                    proj_acc5 = (proj_metrics1["proj_acc5"] + proj_metrics2["proj_acc5"]) / 2
                    self.log("train_proj_acc1", proj_acc1, on_epoch=True, sync_dist=True)
                    self.log("train_proj_acc5", proj_acc5, on_epoch=True, sync_dist=True)

        return iter_dist_loss + class_loss + projector_class_loss
