#########################################################################
##   This file is part of the α,β-CROWN (alpha-beta-CROWN) verifier    ##
##                                                                     ##
##   Copyright (C) 2021-2025 The α,β-CROWN Team                        ##
##   Primary contacts: Huan Zhang <huan@huan-zhang.com> (UIUC)         ##
##                     Zhouxing Shi <zshi@cs.ucla.edu> (UCLA)          ##
##                     Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##
##                                                                     ##
##    See CONTRIBUTORS for all author contacts and affiliations.       ##
##                                                                     ##
##     This program is licensed under the BSD 3-Clause License,        ##
##        contained in the LICENCE file in this directory.             ##
##                                                                     ##
#########################################################################

import torch
from torch import Tensor
from typing import Optional, Union, Tuple, List, Dict
from tensor_storage import get_tensor_storage
from abc import ABC, abstractmethod
import psutil, os

class InputDomainList(ABC):
    """
    Abstract base class that maintains a list of domains for input split.

    Deriving from ABC gives a stronger structure: instantiating a subclass
    that does not implement every method marked with ``@abstractmethod``
    raises a ``TypeError``.

    ``get_topk_indices`` is a regular (non-abstract) instance method so that
    subclasses which do not provide it can still be instantiated; calling it
    without an override raises ``NotImplementedError``.
    """

    def __init__(self):
        pass

    @abstractmethod
    def __len__(self):
        """Number of domains remaining in the list"""
        ...

    @abstractmethod
    def __getitem__(self, *args, **kwargs):
        """
        get lb, dm_l, dm_u, cs, threshold
        (and possibly more elements) for idx
        """
        ...

    @abstractmethod
    def add(self, *args, **kwargs):
        """Add domains to the list"""
        ...

    @abstractmethod
    def pick_out_batch(self, batch, device="cuda"):
        """Pick out a batch of subdomains from the domain list."""
        ...

    # NOTE: this used to be decorated with @staticmethod although it takes
    # ``self``; an inherited call ``obj.get_topk_indices(k)`` then failed with
    # a confusing "missing argument" TypeError. It is an instance method now.
    def get_topk_indices(self, k=1, largest=False):
        """
        Get the topk indices, by default the worst k.

        :param k:       Number of indices to return
        :param largest: If True, return the k largest instead of the k smallest

        :raises NotImplementedError: if the subclass does not override it
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement get_topk_indices"
        )

    @property
    @abstractmethod
    def use_alpha(self):
        """
        True/False based on whether this domain list must
        handle storing alpha. Since an abstract class cannot
        enforce subclasses to have this attribute, we
        enforce it as a property.
        """
        ...


class UnsortedInputDomainList(InputDomainList):
    """Unsorted domain list for input split."""

    def __init__(
        self,
        storage_depth,
        device_for_computing: str,
        use_alpha: bool=False,
        sort_index: Optional[Tensor]=None,
        sort_descending: bool=True,
        use_split_idx: bool=True,
    ):
        """
        Set up an empty domain list; no tensor storage is allocated here.

        :param storage_depth:        The maximum number of splits we could use in input BaB
        :param device_for_computing: The device performing the necessary computations. It is not
                                     used for storage; __getitem__ sends tensors to this device.
        :param use_alpha:            True if we must also store alpha parameters
        :param sort_index:           The index along which to sort the domains
        :param sort_descending:      If True, the 'sort' method orders domains in descending
                                     order w.r.t. their margin
        :param use_split_idx:        If True, we also store the split indices for each domain
        """
        super().__init__()

        # Every tensor storage starts out unallocated.
        for slot in ("lb", "dm_l", "dm_u", "cs", "threshold",
                     "constraint_A", "constraint_b", "split_idx"):
            setattr(self, slot, None)
        # Nested dict of alpha storages, filled in only when alpha is used.
        self.alpha = {}

        # Volume bookkeeping consumed by the progress estimate.
        self.all_volume = None
        self.volume = None

        # Configuration.
        self.device = device_for_computing
        self.storage_depth = storage_depth
        self._use_alpha = use_alpha
        self.use_split_idx = use_split_idx
        self.sort_index = sort_index
        self.sort_descending = sort_descending

    def __len__(self):
        if self.dm_l is None:
            return 0
        return self.dm_l.num_used

    def __getitem__(self, idx):
        # convert idx to tensor on cpu for slicing.
        if isinstance(idx, slice):
            idx = torch.arange(len(self), device="cpu")[idx]
        else:
            idx = torch.as_tensor(idx, device="cpu")
        assert idx.numel() > 0, "Empty index"
        return (
            self.lb._storage[idx].to(self.device),
            self.dm_l._storage[idx].to(self.device),
            self.dm_u._storage[idx].to(self.device),
            self.cs._storage[idx].to(self.device),
            self.threshold._storage[idx].to(self.device)
        )

    @staticmethod
    def filter_verified_domains(
            batch: int,
            lb: Tensor,
            dm_l: Tensor,
            dm_u: Tensor,
            alpha: dict,
            cs: Tensor,
            use_alpha: bool,
            threshold: Tensor,
            lA: Optional[Tensor] = None,
            lbias: Optional[Tensor] = None,
            constraints: Optional[tuple] = None,
            split_idx: Optional[Tensor] = None,
            check_thresholds: bool = True,
            check_bounds: bool = True,
            remaining_index: Optional[Tensor] = None
    ) -> Tuple[int, Tensor, Tensor, Tensor, dict,
    Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
        """
        Filters out the domains that are verified and only returns unverified domains
        @param batch:                                   Batch size of domains
        @param lb: (batch, spec_dim)                    Domain lower bound output
        @param dm_l: (batch, dim_in)                    Input domain lower input bound
        @param dm_u: (batch, dim_in)                    Input domain upper input bound
        @param alpha:                                   CROWN alpha parameters for domains
        @param cs: (batch, spec_dim, lA_rows)           specification matrix
        @param threshold: (batch, spec_dim)             Threshold to verify specification with
        @param lA: (batch, lA_rows or spec_dim, dim_in) CROWN lA coefficient matrix
        @param lbias: (batch, spec_dim)                 CROWN lbias coefficient matrix
        @param constraints:                             linear constraints
        @param split_idx: (batch, num of splits)        Specifies along which dimensions to split
        @param check_thresholds:                        If true, filters out domains that have been verified
                                                        by lb > threshold
        @param check_bounds:                            If true, filters out domains that have been verified
                                                        by dm_l < dm_u
        @param remaining_index:                         If not None, user is specifying which domains are unverified

        @return:
        """
        if remaining_index is None:
            remaining_index = UnsortedInputDomainList.get_remaining_index(
                batch, lb, threshold, dm_l, dm_u, check_thresholds, check_bounds
            )
        lb_filt = lb[remaining_index]
        dm_l_filt = dm_l[remaining_index]
        dm_u_filt = dm_u[remaining_index]
        cs_filt = cs[remaining_index]
        batch_filt = len(dm_l_filt)
        alpha_filt = {}
        if use_alpha and batch_filt > 0:
            with torch.no_grad():
                # alpha may have different device from other tensors.
                # In get_lower_bound_naive() we transfer alpha to cpu and float16.
                alpha_device = next(iter(next(iter(alpha.values())).values())).device
                if isinstance(remaining_index, Tensor):
                    remaining_index_for_alpha = remaining_index.to(alpha_device)
                else:
                    remaining_index_for_alpha = remaining_index
                for key0 in alpha.keys():
                    alpha_filt[key0] = {}
                    for key1 in alpha[key0].keys():
                        # alpha[key0][key1] has shape (dim_in, spec_dim, batches, unstable size)
                        alpha_filt[key0][key1] = alpha[key0][key1][:, :, remaining_index_for_alpha]

        threshold_filt = threshold[remaining_index]
        lA_filt = lA[remaining_index] if lA is not None else None
        lbias_filt = lbias[remaining_index] if lbias is not None else None
        split_idx_filt = split_idx[remaining_index] if split_idx is not None else None
        constraints_filt = None
        if constraints is not None:
            c_A, c_b = constraints
            c_A_filt = c_A[remaining_index]
            c_b_filt = c_b[remaining_index]
            constraints_filt = (c_A_filt, c_b_filt)

        return batch_filt, lb_filt, dm_l_filt, dm_u_filt, alpha_filt, cs_filt, threshold_filt, lA_filt, lbias_filt, constraints_filt, split_idx_filt

    @staticmethod
    def get_remaining_index(
            batch: int,
            lb: Tensor,
            threshold: Tensor,
            dm_l: Tensor,
            dm_u: Tensor,
            check_thresholds=True,
            check_bounds=True
    ) -> Union[slice, Tensor]:
        """
        Gets the indices of the batch instances that are not verified. Verification conditions are specified by
        the check_thresholds and check_bounds flags. If both are None, all indicies are returned.

        @param batch:                       Batch size of domains
        @param lb: (batch, spec_dim)        Domain lower bound output
        @param threshold: (batch, spec_dim) Threshold to verify specification with
        @param dm_l: (batch, dim_in)        Input domain lower input bound
        @param dm_u: (batch, dim_in)        Input domain upper input bound
        @param check_thresholds:            If true, filters out domains that have been verified by lb > threshold
        @param check_bounds:                If true, filters out domains that have been verified by dm_l < dm_u
        @return:                            The indices of the batch instances that are left unverified
        """

        remaining_mask = torch.ones(batch, dtype=torch.bool, device=lb.device)
        if check_thresholds:
            remaining_mask = remaining_mask & (lb <= threshold).all(1)
        if check_bounds:
            remaining_mask = remaining_mask & (dm_l.view(batch, -1) <= dm_u.view(batch, -1)).all(1)
        if remaining_mask.all():
            return slice(None)
        return torch.where(remaining_mask)[0]

    def add(
            self,
            lb: Tensor,
            dm_l: Tensor,
            dm_u: Tensor,
            alpha: dict,
            cs: Tensor,
            threshold: Tensor,
            constraints: tuple = None,
            split_idx: Union[Tensor, None] = None,
            remaining_index: Union[Tensor, None] = None,
            check_thresholds: bool=True,
            check_bounds: bool=True
    ) -> None:
        """
        Takes verified and unverified subdomains and only adds the unverified subdomains

        @param lb: Shape (batch, num_spec)                  Lower bound on domain outputs
        @param dm_l: Shape (batch, input_dim)               Lower bound on domain inputs
        @param dm_u: Shape (batch, input_dim)               Upper bound on domain inputs
        @param alpha:                                       alpha parameters
        @param cs: Shape (batch, num_spec, lA rows)         The C transformation matrix
        @param threshold: Shape (batch, num_spec)           The specification thresholds
        @param constraints:                                 constraints parameters
        @param split_idx: Shape (batch, num of splits)      Specifies along which dimensions to split
        @param remaining_index:                             If not None, user is specifying which domains are unverifie
        """
        # check shape correctness
        batch = len(lb)
        if batch == 0:
            return
        if self.use_split_idx:
            assert split_idx is not None, "Cannot accept split_idx"
            assert len(split_idx) == batch
            assert split_idx.shape[1] == self.storage_depth
        else:
            assert split_idx is None, "Expected to receive split_idx"
        assert len(dm_l) == len(dm_u) == len(cs) == len(threshold) == batch
        if self.use_alpha:
            if alpha is None:
                raise ValueError("alpha should not be None in alpha-crown.")
        # initialize attributes using input shapes and types
        if self.lb is None:
            self.lb = get_tensor_storage(lb.shape, dtype=lb.dtype, device="cpu")
        if self.dm_l is None:
            self.dm_l = get_tensor_storage(dm_l.shape, dtype=dm_l.dtype, device="cpu")
        if self.dm_u is None:
            self.dm_u = get_tensor_storage(dm_u.shape, dtype=dm_u.dtype, device="cpu")
        if self.use_alpha and not self.alpha:
            for key0 in alpha.keys():
                self.alpha[key0] = {}
                for key1 in alpha[key0].keys():
                    self.alpha[key0][key1] = get_tensor_storage(
                        alpha[key0][key1].shape, concat_dim=2,
                        dtype=alpha[key0][key1].dtype, device="cpu"
                    )
        if self.cs is None:
            self.cs = get_tensor_storage(cs.shape, dtype=cs.dtype, device="cpu")
        if self.threshold is None:
            self.threshold = get_tensor_storage(threshold.shape, dtype=threshold.dtype, device="cpu")
        if constraints is not None:
            constraint_A, constraint_b = constraints
            if self.constraint_A is None or self.constraint_b is None:
                self.constraint_A = get_tensor_storage(constraint_A.shape, dtype=constraint_A.dtype, device="cpu")
                self.constraint_b = get_tensor_storage(constraint_b.shape, dtype=constraint_b.dtype, device="cpu")
        if self.split_idx is None and self.use_split_idx:
            self.split_idx = get_tensor_storage(split_idx.shape, dtype=split_idx.dtype, device="cpu")
        # compute unverified indices
        if remaining_index is None:
            remaining_index = UnsortedInputDomainList.get_remaining_index(
                batch, lb, threshold, dm_l, dm_u, check_thresholds, check_bounds
            )
        # append the tensors
        self.lb.append(lb[remaining_index].to(self.lb.device))

        dm_l = dm_l[remaining_index]
        dm_u = dm_u[remaining_index]
        self._add_volume(dm_l, dm_u)
        self.dm_l.append(dm_l.to(self.dm_l.device))
        self.dm_u.append(dm_u.to(self.dm_u.device))
        if self.use_alpha:
            for key0 in alpha.keys():
                for key1 in alpha[key0].keys():
                    self.alpha[key0][key1].append(
                        alpha[key0][key1][:, :, remaining_index]
                        .to(self.alpha[key0][key1].device)
                    )
        self.cs.append(cs[remaining_index].to(self.cs.device))
        self.threshold.append(
            threshold[remaining_index]
            .to(self.threshold.device)
        )

        if constraints is not None:
            self.constraint_A.append(
                constraint_A[remaining_index]
                .type(self.constraint_A.dtype)
                .to(self.constraint_A.device)
            )
            self.constraint_b.append(
                constraint_b[remaining_index]
                .type(self.constraint_b.dtype)
                .to(self.constraint_b.device)
            )            
        if self.use_split_idx:
            self.split_idx.append(
                split_idx[remaining_index]
                .to(self.split_idx.device)
            )

    def pick_out_batch(self, batch: int, device="cuda"
                       )->Tuple[dict, Tensor, Tensor, Tensor, Tensor, Tensor, Optional[Tensor]]:
        """
        Picks out a batch of subdomains from the domain list.

        :param batch:       The maximum number of domains we should pick out
        :param device:      The device all Tensors should be sent to

        :return alphas:     If supported, contains alpha parameters for the batch
        :return lb:         Output lower bounds
        :return dm_l:       Domain input lower bounds
        :return dm_u:       Domain input upper bounds
        :return cs:         Specification matrices
        :return threshold:  Thresholds
        :return split_idx:  If supported, the input dimensions we should split along
        """
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        batch = min(len(self), batch)
        assert batch > 0, "List of InputDomain is empty; pop failed."
        lb = self.lb.pop(batch).to(device=device, non_blocking=True)
        dm_l = self.dm_l.pop(batch).to(device=device, non_blocking=True)
        dm_u = self.dm_u.pop(batch).to(device=device, non_blocking=True)
        alpha = {}
        if self.use_alpha:
            for key0, val0 in self.alpha.items():
                alpha[key0] = {}
                for key1, val1 in val0.items():
                    alpha[key0][key1] = val1.pop(batch).to(device=device, dtype=lb.dtype, non_blocking=True)
        cs = self.cs.pop(batch).to(device=device, non_blocking=True)
        threshold = self.threshold.pop(batch).to(device=device, non_blocking=True)
        constraints = None
        if self.constraint_A is not None or self.constraint_b is not None:
            constraint_A = self.constraint_A.pop(batch).to(device, non_blocking=True)
            constraint_b = self.constraint_b.pop(batch).to(device, non_blocking=True)
            constraints = (constraint_A, constraint_b)
        if self.use_split_idx:
            split_idx = self.split_idx.pop(batch).to(device=device, non_blocking=True)
        else:
            split_idx = None
        self._add_volume(dm_l, dm_u, sign=-1)
        return alpha, lb, dm_l, dm_u, cs, threshold, constraints, split_idx

    def _add_volume(self, dm_l, dm_u, sign=1):
        volume = torch.prod(dm_u - dm_l, dim=-1).sum().item()
        if self.all_volume is None:
            self.all_volume = volume
            self.volume = 0
        self.volume = self.volume + sign * volume

    def get_progess(self):
        """
        Return ``1 - volume / all_volume``, or 0. while no reference volume
        is available (``all_volume`` is None or zero).
        """
        total = self.all_volume
        if total is None or total == 0:
            return 0.
        return 1 - self.volume / total

    def _get_sort_margin(self, margin):
        if self.sort_index is not None:
            return margin[..., self.sort_index]
        else:
            return margin.max(dim=1).values

    @property
    def use_alpha(self):
        """Whether this domain list stores alpha parameters (the value of ``_use_alpha``)."""
        return self._use_alpha

    def get_topk_indices(self, k=1, largest=False, return_margin=False):
        assert k <= len(self), print("Asked indices more than domain length.")
        lb = self.lb._storage[: self.lb.num_used]
        threshold = self.threshold._storage[: self.threshold.num_used]
        margins, indices = self._get_sort_margin(lb - threshold).topk(k, largest=largest)
        if return_margin:
            return indices, margins
        return indices

    def sort(self):
        lb = self.lb._storage[: self.lb.num_used]
        threshold = self.threshold._storage[: self.threshold.num_used]
        indices = self._get_sort_margin(lb - threshold).argsort(
            descending=self.sort_descending)
        # sort the storage
        self.lb._storage[: self.lb.num_used] = self.lb._storage[indices]
        self.dm_l._storage[: self.dm_l.num_used] = self.dm_l._storage[indices]
        self.dm_u._storage[: self.dm_u.num_used] = self.dm_u._storage[indices]
        if self.use_alpha:
            for val0 in self.alpha.values():
                for val1 in val0.values():
                    val1._storage[
                    :, :, :val1.num_used] = val1._storage[:, :, indices]
        self.cs._storage[: self.cs.num_used] = self.cs._storage[indices]
        self.threshold._storage[: self.threshold.num_used] = self.threshold._storage[indices]
        
        if self.constraint_A is not None:
            self.constraint_A._storage[: self.constraint_A.num_used] = self.constraint_A._storage[indices]
        if self.constraint_b is not None:
            self.constraint_b._storage[: self.constraint_b.num_used] = self.constraint_b._storage[indices]
        
        
        if self.use_split_idx:
            self.split_idx._storage[: self.split_idx.num_used] = self.split_idx._storage[indices]

    def report_memory(self):
        """
        Print the allocated and used memory (in MB) of the storages, followed by the
        resident memory of the current process.

        Reports lb, dm_l, dm_u, the summed alpha storages (if alpha is used), cs,
        threshold and split_idx (if split indices are used).
        """
        def _report_memory(attr_name, allocated_in_MB, used_in_MB):
            print(f"[{attr_name}] allocated: {allocated_in_MB:.2f} MB, used: {used_in_MB:.2f} MB")

        for name in ("lb", "dm_l", "dm_u"):
            _report_memory(name, *getattr(self, name).calculate_memory())

        if self.use_alpha:
            # All alpha tensors are reported as one aggregated line.
            total_allocated, total_used = 0, 0
            for per_layer in self.alpha.values():
                for storage in per_layer.values():
                    allocated, used = storage.calculate_memory()
                    total_allocated += allocated
                    total_used += used
            _report_memory("alpha", total_allocated, total_used)

        for name in ("cs", "threshold"):
            _report_memory(name, *getattr(self, name).calculate_memory())
        if self.use_split_idx:
            _report_memory("split_idx", *self.split_idx.calculate_memory())

        # Resident set size of this process, converted from bytes to MB.
        rss_bytes = psutil.Process(os.getpid()).memory_info().rss
        memory_used_MB = rss_bytes / 1024 / 1024
        print(f"Total memory used: {memory_used_MB:.2f} MB")

        return