#########################################################################
##   This file is part of the α,β-CROWN (alpha-beta-CROWN) verifier    ##
##                                                                     ##
##   Copyright (C) 2021-2025 The α,β-CROWN Team                        ##
##   Primary contacts: Huan Zhang <huan@huan-zhang.com> (UIUC)         ##
##                     Zhouxing Shi <zshi@cs.ucla.edu> (UCLA)          ##
##                     Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##
##                                                                     ##
##    See CONTRIBUTORS for all author contacts and affiliations.       ##
##                                                                     ##
##     This program is licensed under the BSD 3-Clause License,        ##
##        contained in the LICENCE file in this directory.             ##
##                                                                     ##
#########################################################################
import torch
import matplotlib.pyplot as plt
from os import makedirs
from math import ceil
from numpy import ndarray
from typing import Optional
from warnings import warn

def sample_check(model, x_L, x_U, C, constraints, num_samples=10000):
    """
    Monte-Carlo sanity check of a specification under linear input constraints.

    Draws ``num_samples`` points uniformly from the box [x_L, x_U] for every
    batch element. It keeps only the points that strictly satisfy every linear
    constraint ``A @ x + bias < 0``. It returns the smallest value of
    ``C @ model(x)`` observed among the kept points, per output row.

    @param model:        Network to evaluate. It is switched to eval mode (not
                         restored afterwards). It is called on a tensor of shape
                         (batchsize * num_samples, 1, 1, x_dim), and its output
                         is expected to be (batchsize * num_samples, model_o_dim).
    @param x_L:          Lower bounds of the input box. The first dim is the
                         batch and the last dim is the input dimension.
    @param x_U:          Upper bounds of the input box, same shape as x_L.
    @param C:            Specification matrices, (batchsize, o_dim, model_o_dim).
                         A single (o_dim, model_o_dim) matrix, or a batch of one,
                         is shared by all batch elements.
    @param constraints:  Tuple (A, bias), reshapeable to
                         (batchsize, n_constraints, x_dim) and
                         (batchsize, n_constraints) respectively.
    @param num_samples:  Number of random points drawn per batch element.
    @return:             Tensor (batchsize, o_dim) holding the minimum of
                         C @ model(x) over feasible samples. An entry is +inf
                         when no sample satisfied all constraints.
    """
    batchsize = x_L.size(0)
    x_dim = x_L.size(-1)
    x_L = x_L.reshape((batchsize, 1, x_dim))
    x_U = x_U.reshape((batchsize, 1, x_dim))

    # Samples are flattened batch-major below: row b * num_samples + s belongs
    # to batch b. Per-batch tensors must therefore repeat each batch element
    # num_samples times in a row (repeat_interleave). Tensor.repeat would tile
    # the whole batch instead, which pairs samples with another batch's
    # constraints/spec whenever batchsize > 1.
    constraints_A, constraints_bias = constraints
    constraints_A = constraints_A.reshape((batchsize, -1, x_dim)).repeat_interleave(num_samples, dim=0)
    constraints_bias = constraints_bias.reshape((batchsize, -1)).repeat_interleave(num_samples, dim=0)

    samples = torch.rand((batchsize, num_samples, x_dim), device=x_L.device) * (x_U - x_L) + x_L
    samples = samples.reshape((batchsize * num_samples, 1, 1, x_dim))
    model.eval()

    with torch.no_grad():
        # outputs: (bs * n_sam, model_o_dim, 1)
        outputs = model(samples).unsqueeze(-1)
        # A 2-D spec is shared by every batch element.
        spec = C.unsqueeze(0) if C.dim() == 2 else C
        # new_C: (bs * n_sam, o_dim, model_o_dim)
        new_C = spec.expand(batchsize, -1, -1).repeat_interleave(num_samples, dim=0)
        # (bs * n_sam, o_dim, 1) -> (bs, n_sam, o_dim)
        outputs = new_C.bmm(outputs).reshape((batchsize, num_samples, -1))
        output_dim = outputs.size(2)
        # constraints_lower_bound: (bs * n_sam, n_constraints), i.e. A @ x + bias per sample
        constraints_lower_bound = constraints_A.bmm(
            samples.reshape((batchsize * num_samples, x_dim, 1))
        ).squeeze(-1) + constraints_bias
        # A sample is feasible only if it strictly satisfies every constraint.
        condition = (constraints_lower_bound.reshape((batchsize, num_samples, -1)) < 0).all(dim=-1)
        condition_expanded = condition.unsqueeze(-1).expand(-1, -1, output_dim)
        # Infeasible samples become +inf so they can never be the minimum.
        outputs_masked = outputs.masked_fill(~condition_expanded, float('inf'))
        min_value = outputs_masked.min(dim=1).values
    return min_value


def save_sanity_check_graphs(
        global_lbs: ndarray,
        benchmark_name: Optional[str],
        vnnlib_id: int,
        property_idx: int,
        dir_timestamp: str,
        save_root: str = "../sanity_check_outputs",
) -> None:
    """
    Save a log-scale convergence plot, with one subplot per feature (column) of global_lbs.

    The axis labels treat the plotted values as (rhs - global_lb). Because the
    y-axis is logarithmic, values are expected to be positive (global_lb < rhs).
    Non-positive entries cannot be represented on that axis.
    @param global_lbs:      Per-iteration values to plot, shape (iterations, features).
                            A 1-D array is treated as a single feature. The per-column
                            minimum is reported in each subplot title as the best gap.
    @param benchmark_name:  Name of the currently running benchmark. If None, a warning
                            is issued and nothing is plotted.
    @param vnnlib_id:       Current vnnlib_id in the benchmark
    @param property_idx:    Current property being verified for the current vnnlib file
    @param dir_timestamp:   Sub-directory of save_root where the graph is saved
    @param save_root:       Root output directory. Relative paths are resolved against
                            the current working directory.
    @return:                None. Writes one PNG file into save_root/dir_timestamp.
    """

    if benchmark_name is None:
        warn("'save_sanity_check_graphs' was called but benchmark_name not given. Will skip creating graphs.")
        return

    if global_lbs.ndim == 1:
        # A flat history is a single feature column.
        global_lbs = global_lbs.reshape(-1, 1)
    features = global_lbs.shape[1]
    grid_division = 2 if features > 1 else 1
    rows = ceil(features / grid_division)

    print("Saving graphs...")
    fig = plt.figure(figsize=(16, 12))
    try:
        for i in range(features):
            plt.subplot(rows, grid_division, i + 1)
            history = global_lbs[:, i]
            best_gap = history.min()
            plt.plot(history)
            plt.grid()
            plt.title(f"Convergence for feature {i}\nBest (rhs - lb): {best_gap}")
            plt.xlabel("Iterations")
            plt.ylabel(f"rhs[{i}] - global_lb[{i}]")
            plt.yscale('log')
        save_dir = f"{save_root}/{dir_timestamp}"
        makedirs(save_dir, exist_ok=True)
        fig.savefig(
            f"{save_dir}/benchmark_{benchmark_name}_vnnlib_id_{vnnlib_id}_property_{property_idx}_sanity_check_graphs.png")
    finally:
        # Without this, every call leaves an open pyplot figure alive (memory leak).
        plt.close(fig)