from __future__ import print_function
from math import e
import numpy as np
import argparse, os, time, random
from tqdm import tqdm
import logging
import torch, torchvision
import torch.backends.cudnn as cudnn
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader,Subset
from torchvision.datasets import * 



def math_proof_analysis(target_class,pred_class,y_star,y_prime,p_star,p_prime,epsilon,
                        KL_y_prime_p_star,KL_y_star_p_star,KL_y_prime_p_prime,KL_y_star_p_prime,
                        L0_y_prime_p_star,L0_y_star_p_star,L0_y_prime_p_prime,L0_y_star_p_prime,
                        args
                        ):
    """Numerically check Proposition 1, Proposition 2 and the Equation (46) bound.

    The checks only run when ``ΔKL(p*) > 0``, ``ΔKL(p') < 0`` and
    ``ΔL0(p*) > 0``, where ``ΔX(p) = X(y', p) - X(y*, p)`` is built from the
    pre-computed KL / L0 arguments. Results of Proposition 1 and 2 are only
    printed; the return value reflects the Equation (46) check alone.

    Args:
        target_class, pred_class: Only echoed in a diagnostic message.
        y_star, y_prime, p_star, p_prime: Tensors, flattened to 1-D here.
            Presumably probability vectors of equal length -- TODO confirm.
            Zero entries in ``p_star`` lead to inf/nan intermediate values.
        epsilon: Accepted for interface compatibility. NOTE(review): the value
            is not used; the bound is always replaced by ``||p' - p*||_2``,
            exactly as the previous implementation did. Confirm this is
            intended.
        KL_*: Pre-computed KL terms (Python numbers or 0-dim tensors).
        L0_*: Pre-computed L0 terms. ``L0_y_prime_p_prime`` and
            ``L0_y_star_p_prime`` are currently unused.
        args: Unused; kept so existing callers keep working.

    Returns:
        bool: ``True`` if the preconditions hold and some dimension ``j`` has
        ``|y'_j - y*_j|`` strictly larger than the Equation (46) right-hand
        side. ``False`` otherwise, including when the inputs are empty, the
        preconditions fail, or a term-2 denominator is (near) zero.
    """
    y_star = y_star.flatten()
    y_prime = y_prime.flatten()
    p_star = p_star.flatten()
    p_prime = p_prime.flatten()

    d = len(y_star)
    if d == 0:
        # Nothing to analyse; also avoids 0/0 means and an unbound-name error
        # in the final report.
        return False

    tau = 0.75

    v = (y_prime - y_star) / p_star            # v
    v_norm = torch.norm(v, p=2)                # ||v||_2
    v_norm_sq = v_norm ** 2                    # ||v||_2^2
    delta = p_prime - p_star                   # δ
    delta_L1_norm = torch.norm(delta, p=1)     # ||δ||_1
    # The perturbation budget is the actual L2 distance between p' and p*
    # (this overrides the `epsilon` argument, see docstring).
    epsilon = torch.norm(delta, p=2)

    mu_y_prime_p_star = torch.sum(torch.abs(y_prime - p_star)) / d  # μ(y', p*)
    mu_y_star_p_star = torch.sum(torch.abs(y_star - p_star)) / d    # μ(y*, p*)

    Delta_L0_p_star = L0_y_prime_p_star - L0_y_star_p_star    # ΔL0(p*)
    Delta_KL_p_star = KL_y_prime_p_star - KL_y_star_p_star    # ΔKL(p*)
    Delta_KL_p_prime = KL_y_prime_p_prime - KL_y_star_p_prime  # ΔKL(p')

    if not (Delta_KL_p_star > 0 and Delta_KL_p_prime < 0 and Delta_L0_p_star > 0):
        return False

    # (μ(y', p*) / μ(y*, p*)) + (||δ||_1 / (d · μ(y*, p*))) - 1
    term_A = (mu_y_prime_p_star / mu_y_star_p_star
              + delta_L1_norm / (d * mu_y_star_p_star) - 1)

    # ---- Proposition 1: sum_i (y'_i - y*_i) * (δ_i / p*_i) > ΔKL(p*) ----
    kl_y_prime_p_star = torch.sum(y_prime * torch.log(y_prime / p_star))
    kl_y_star_p_star = torch.sum(y_star * torch.log(y_star / p_star))
    equation_sum = torch.sum((y_prime - y_star) * (delta / p_star))
    if equation_sum > Delta_KL_p_star:
        print(f"proposition1 is fulfulled!")
    else:
        print(f"proposition1 is not fulfulled! with LHS {equation_sum} smaller than RHS {Delta_KL_p_star}, recomputed is {kl_y_prime_p_star-kl_y_star_p_star}")

    # ---- Proposition 2 / Equation (12): count dimensions with
    #      |δ_i| >= min(| |y'_i - p*_i| - τμ(y', p*) |, | |y*_i - p*_i| - τμ(y*, p*) |) - τ||δ||_1 / d
    # NOTE(review): the code takes an outer absolute value of each
    # "|.| - τμ" difference; confirm this matches Equation (12) in the proof.
    k = Delta_L0_p_star
    slack = (tau * delta_L1_norm) / d
    rhs_values = torch.min(
        torch.abs(torch.abs(y_prime - p_star) - tau * mu_y_prime_p_star) - slack,
        torch.abs(torch.abs(y_star - p_star) - tau * mu_y_star_p_star) - slack,
    )
    n_satisfied = int((torch.abs(delta) >= rhs_values).sum().item())

    print(f"Dimensions satisfying Equation (12): {n_satisfied} out of {d}")
    print(f"Required: |S| ≥ k = {k}")
    if n_satisfied >= k:
        print(f"Proposition 2 is fulfilled! Found {n_satisfied} dimensions ≥ required {k}")
    else:
        print(f"Proposition 2 is not fulfilled! Found {n_satisfied} dimensions < required {k}")

    # ---- Equation (46): |y'_j - y*_j| > max{term1, term2} for some j ----
    # Loop-invariant pieces. abs() keeps the radicand non-negative, so the
    # square root below is always real.
    sqrt_term = abs(-Delta_KL_p_star**2 + (epsilon**2) * v_norm_sq) ** 0.5
    # C is reported for diagnostics only; it may be nan when its radicand is negative.
    C = (5 * (epsilon**2) / 4 - epsilon * Delta_KL_p_star / v_norm) ** 0.5

    for j in range(d):
        y_star_j = y_star[j]
        y_prime_j = y_prime[j]
        p_star_j = p_star[j]
        v_j = v[j]

        lhs = torch.abs(y_prime_j - y_star_j)
        abs_y_star_p_star = torch.abs(y_star_j - p_star_j)  # |y*_j - p*_j|
        abs_p_star_j = torch.abs(p_star_j)                  # |p*_j|
        # sgn(y'_j - y*_j)
        sgn_diff = 1 if y_prime_j > y_star_j else (-1 if y_prime_j < y_star_j else 0)

        # Term 1: term_A * |y*_j - p*_j| * ||v||_2 * |p*_j| / (||v||_2 * |p*_j| - ε)
        # NOTE(review): an earlier comment described this term with ||v||_2^2
        # and "+ ε"; the arithmetic is kept as the code had it -- confirm
        # against the proof.
        term1 = term_A * (abs_y_star_p_star * v_norm * abs_p_star_j
                          / (v_norm * abs_p_star_j - epsilon))

        # Term 2 shares the denominator ||v||_2^2 * p*_j - ΔKL(p*) * sgn(y'_j - y*_j).
        term2_denominator = v_norm_sq * p_star_j - Delta_KL_p_star * sgn_diff
        if abs(term2_denominator) < 1e-16:
            print(f"Warning: term2_denominator near zero for dimension {j}: {term2_denominator}")
            print(f"term2_denominator's values are v_norm_sq: {v_norm_sq}, p_star_j: {p_star_j}, Delta_KL_p_star: {Delta_KL_p_star}, sgn_diff: {y_prime_j-y_star_j}")
            print(f"target_class is {target_class}, pred_class is {pred_class}")
            print(f"L2 distance betwee y_star_j and y_prime_j is {torch.norm(y_prime_j-y_star_j,p=2)}")
            print(f"Skipping this example due to division by zero")
            return False  # Skip this example

        # p*_j * ||v||_2 * sqrt(|-ΔKL(p*)^2 + ε^2 ||v||_2^2|) / denominator
        term2_part1 = p_star_j * v_norm * sqrt_term / term2_denominator
        # term_A * |y*_j - p*_j| * ||v||_2^2 * |p*_j| / denominator
        term2_part2 = term_A * (abs_y_star_p_star * v_norm_sq * abs_p_star_j
                                / term2_denominator)
        term2 = term2_part1 + term2_part2

        equation_46_result = max(term1, term2)

        # A and B are reported for diagnostics only.
        A = epsilon * torch.abs(v_j) / v_norm
        B = (Delta_KL_p_star * v_j + v_norm * sqrt_term) / v_norm_sq

        if lhs > equation_46_result:
            print(f"Equation (46) is fulfulled! with LHS {lhs} larger than RHS {equation_46_result} with A {A}, B {B}, C {C}")
            return True

    # No dimension satisfied the bound; the values reported are those of the
    # last dimension examined.
    print(f"Equation (46) is not fulfulled! with LHS {lhs} not larger than RHS {equation_46_result} with A {A}, B {B}, C {C}")
    return False

   