# import magnitude
import torch

# Our original add-and-normalize algorithm

def add_and_normalize_asvec(S, h):
    """Iterated "add and normalize" weighting that keeps only the diagonal.

    Starting from the identity, each of the ``h`` iterations multiplies the
    current iterate by ``S`` (similarity-weighted sums), then keeps only the
    diagonal of the product, dividing each diagonal entry by the matching
    row sum of the product.

    Args:
        S: square similarity matrix (n x n).
        h: number of iterations.

    Returns:
        An n x n *diagonal matrix* whose diagonal holds the weights (the
        identity when ``h == 0``). Use ``torch.diagonal(result)`` for a vector.
    """
    if not torch.is_floating_point(S):
        S = S.to(torch.get_default_dtype())
    # The identity must share S's dtype: ``@`` does not type-promote, so a
    # float32 identity against a float64 S (as built by similarity_matrix)
    # would raise.
    W = torch.eye(S.shape[0], dtype=S.dtype, device=S.device)
    for _ in range(h):
        W = S @ W  # similarity-weighted sums
        row_sums = torch.sum(W, dim=1)
        # keep only the diagonal, normalized by the row sums
        W = (torch.diagonal(W) / row_sums).diag()
    return W




def add_and_normalize(S, h):
    """Iterated "add and normalize" weighting returning the full matrix.

    Each iteration scales the columns of ``S`` by the diagonal of the
    previous iterate (similarity-weighted sums) and then normalizes every
    row of the result to sum to one.

    Args:
        S: square similarity matrix (n x n).
        h: number of iterations.

    Returns:
        The n x n row-normalized matrix after ``h`` iterations (the identity
        when ``h == 0``).
    """
    if not torch.is_floating_point(S):
        S = S.to(torch.get_default_dtype())
    # Start from an identity in S's dtype/device (a float32 identity would
    # not match a float64 S).
    W = torch.eye(S.shape[0], dtype=S.dtype, device=S.device)
    for _ in range(h):
        # S @ diag(d) == S * d[None, :] (column scaling) without an O(n^3) matmul
        W = S * torch.diagonal(W).unsqueeze(0)
        # diag(1 / b) @ W == W / b[:, None] (row normalization)
        W = W / torch.sum(W, dim=1, keepdim=True)
    return W



def add_and_normalize_points(points, h):
    """Run ``add_and_normalize`` for ``h`` iterations on the similarity matrix of ``points``."""
    return add_and_normalize(similarity_matrix(points), h)


def add_and_normalize_points_asvec(points, h):
    """Run ``add_and_normalize_asvec`` for ``h`` iterations on the similarity matrix of ``points``."""
    return add_and_normalize_asvec(similarity_matrix(points), h)




# Convex optimization (gradient descent) approach.
# This is full-batch GD without mini-batching or SGD,
# so all weights are updated in parallel at every step;
# equivalent to using batch_size = number of points.

class Model(torch.nn.Module):
    """Linear model ``S @ w`` used to fit magnitude weights by gradient descent.

    The trainable parameter is a weight *vector* ``w`` with one entry per row
    of ``S``, initialized to ones. ``forward`` returns the similarity-weighted
    sums ``S @ w``.
    """

    def __init__(self, S, device):
        super().__init__()
        self.device = device
        if not torch.is_floating_point(S):
            S = S.to(torch.get_default_dtype())
        self.S = S.to(device)
        # One weight per point, in S's dtype/device: ``@`` does not promote
        # dtypes, so float32 weights would fail against a float64 S.
        self.weights = torch.nn.Parameter(
            torch.ones(self.S.shape[0], dtype=self.S.dtype, device=self.S.device)
        )

    def forward(self):
        # Same as torch.sum(S @ diag(w), dim=1), without materializing the
        # n x n diagonal matrix or doing an O(n^3) matmul.
        return self.S @ self.weights


def magnitude_by_SGD(S, h, device, lr=0.01):
    """Approximate magnitude weights by full-batch gradient descent.

    Minimizes ``MSELoss(model(), 1)`` over the model's weight vector with
    plain SGD. Every row of ``S`` is used at every step, which is equivalent
    to batch_size = number of points.

    Args:
        S: square similarity matrix (n x n).
        h: number of gradient steps.
        device: device passed to ``Model``.
        lr: learning rate.

    Returns:
        The learned weights (the model's ``nn.Parameter``).
    """
    model = Model(S, device)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    # Build the target in the weights' dtype/device: MSELoss can fail when
    # input and target dtypes differ (e.g. float64 output vs float32 target).
    target = torch.ones_like(model.weights.detach())
    loss_fn = torch.nn.MSELoss()

    for step in range(h):
        optimizer.zero_grad()
        loss_val = loss_fn(model(), target)
        loss_val.backward()
        optimizer.step()
        # print(f"loss_val ({step}): {loss_val.item()}")  # uncomment to see loss evolution

    return model.weights



def magnitude_by_SGD_points(points, h, device, lr=0.01):
    """Fit magnitude weights for ``points`` via ``magnitude_by_SGD`` on their similarity matrix."""
    return magnitude_by_SGD(similarity_matrix(points), h, device, lr)



# Implementation of mini-batch SGD (with momentum).
# batch_size = 1 is equivalent to pure SGD;
# batch_size = number of points is equivalent to full-batch GD.

def magnitude_by_batch_SGD(S, num_epochs=100, batch_size=1, device=None, lr=0.01):
    """Approximate magnitude weights by mini-batch SGD with momentum.

    Each epoch shuffles the row indices of ``S`` and, for every chunk of
    ``batch_size`` indices, takes one optimizer step on the MSE between the
    selected entries of ``model()`` and 1. ``batch_size=1`` is pure SGD;
    ``batch_size`` >= number of points is full-batch GD (with momentum).

    Args:
        S: square similarity matrix (n x n).
        num_epochs: number of passes over the rows.
        batch_size: rows per optimizer step (``None`` is treated as 1).
        device: device passed to ``Model``.
        lr: learning rate.

    Returns:
        The learned weights (the model's ``nn.Parameter``).

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size is None:
        batch_size = 1
    if batch_size < 1:
        raise ValueError(f"batch_size should be a positive integer, got {batch_size}")

    model = Model(S, device)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    # Target in the weights' dtype/device so MSELoss sees matching tensors.
    target = torch.ones_like(model.weights.detach())
    loss_fn = torch.nn.MSELoss()
    n_rows = target.shape[0]

    for epoch in range(num_epochs):
        order = torch.randperm(n_rows, device=target.device)
        for batch_idx in order.split(batch_size):
            optimizer.zero_grad()
            # Only the sampled rows contribute to the loss, so each step is a
            # genuine mini-batch step rather than a full-batch one.
            output = model()[batch_idx]
            loss_val = loss_fn(output, target[batch_idx])
            loss_val.backward()
            optimizer.step()
        # print(f"loss_val ({epoch}): {loss_val.item()}")  # uncomment to see loss evolution

    return model.weights


def magnitude_by_batch_SGD_points(points, num_epochs=100, batch_size=1, device=None, lr=0.01):
    """Fit magnitude weights for ``points`` via ``magnitude_by_batch_SGD`` on their similarity matrix."""
    similarities = similarity_matrix(points)
    return magnitude_by_batch_SGD(
        similarities,
        num_epochs=num_epochs,
        batch_size=batch_size,
        device=device,
        lr=lr,
    )



# Basic matrix inversion approach

def similarity_matrix(points, t=1):
    """Return the pairwise similarity matrix ``exp(-t * ||p_i - p_j||_2)``.

    Args:
        points: tensor of points, one per row (anything ``torch.cdist``
            accepts, after conversion to float64).
        t: scale factor applied to the distances.

    Returns:
        float64 tensor of pairwise similarities.
    """
    # Cast *before* cdist: distances are then computed in double precision
    # (casting afterwards only widens the input-precision round-off), and
    # integer inputs, which cdist rejects, are accepted.
    pts = points.to(torch.float64)
    dist = torch.cdist(pts, pts, p=2)
    return torch.exp(-(dist * t))

def similarity_vector(points, x, t=1):
    """Return similarities ``exp(-t * ||p_i - x_j||_2)`` between ``points`` and ``x``.

    Args:
        points: tensor of points, one per row.
        x: query points, one per row (``torch.cdist`` needs at least 2-D).
        t: scale factor applied to the distances (default 1 keeps the
            previous unscaled behavior).

    Returns:
        float64 tensor of shape ``(len(points), len(x))``.
    """
    # Cast before cdist for double-precision distances and integer support.
    pts = points.to(torch.float64)
    queries = x.to(torch.float64)
    dist = torch.cdist(pts, queries, p=2)
    return torch.exp(-(dist * t))

# def magnitude(S, device):
#     S = S.to(device).to(torch.float64)
#     inverse = torch.inverse(S)
#     return torch.sum(inverse)

def magnitude(points, device, t=1):
    """Return the magnitude of a finite point set at scale ``t``.

    The magnitude is the sum of all entries of ``S^{-1}``, where ``S`` is the
    similarity matrix of ``points``. Because ``sum(S^{-1}) = 1^T S^{-1} 1``,
    it equals ``sum(w)`` for the solution of ``S w = 1``. Solving that system
    is cheaper and numerically better than forming the explicit inverse.

    Args:
        points: tensor of points, one per row.
        device: device on which the linear system is solved.
        t: scale factor applied to the distances.

    Returns:
        0-dim float64 tensor.
    """
    S = similarity_matrix(points, t=t).to(device=device, dtype=torch.float64)
    # (n, 1) right-hand side so batched similarity matrices broadcast cleanly.
    ones = torch.ones(S.shape[-1], 1, dtype=S.dtype, device=S.device)
    weights = torch.linalg.solve(S, ones)
    return torch.sum(weights)

def stable_inverse(S, eps=1e-4):
    """Return ``(S + eps * I)^{-1}`` computed in float64.

    Adding ``eps`` to the diagonal keeps the matrix invertible when ``S`` is
    (near-)singular.
    """
    S64 = S.to(torch.float64)
    shift = eps * torch.eye(S64.size(0), device=S.device, dtype=torch.float64)
    return torch.inverse(S64 + shift)

def conjugate_gradient(S, eps=1e-4):
    """Return ``(S + eps * I)^{-1}`` in float64 via a direct linear solve.

    Despite its name this does not run the conjugate-gradient method. It
    calls ``torch.linalg.solve`` with the identity as the right-hand side.
    """
    A = S.to(torch.float64)
    identity = torch.eye(A.size(0), device=A.device, dtype=torch.float64)
    A = A + eps * identity
    return torch.linalg.solve(A, identity)  # solves A X = I

def weight(S, device, eps=1e-4):
    """Return magnitude weights ``w`` solving ``(S + eps * I) w = 1``.

    This equals the row sums of ``(S + eps * I)^{-1}``. Solving for a single
    right-hand side avoids materializing the full n x n inverse.

    Args:
        S: square similarity matrix.
        device: device on which the system is solved.
        eps: diagonal shift keeping the system invertible (default matches
            the previous fixed value).

    Returns:
        float64 weight vector with one entry per row of ``S``.
    """
    A = S.to(device).to(torch.float64)
    n = A.size(0)
    A = A + eps * torch.eye(n, dtype=torch.float64, device=A.device)
    ones = torch.ones(n, dtype=torch.float64, device=A.device)
    return torch.linalg.solve(A, ones)

def magnitudeof_points(points, device, t=1):
    """Return the magnitude of ``points`` at scale ``t`` (default 1).

    ``magnitude`` builds the similarity matrix itself, so the points are
    passed straight through. Passing a similarity matrix instead would make
    ``magnitude`` treat its rows as points and compute the similarity of
    similarities.
    """
    return magnitude(points, device, t=t)

def weightof_points(points, device, t=1):
    """Return magnitude weights for ``points`` at scale ``t`` using ``weight``."""
    return weight(similarity_matrix(points, t=t).to(device), device)

def weighted_distance(points, x, device, w=None, t=1):
    """Return per-point weighted similarities ``w_i * exp(-t * ||points_i - x||)``.

    Args:
        points: 2-D tensor, one point per row.
        x: query point; converted to ``points``' device and broadcast against
            every row (e.g. shape ``(D,)`` or ``(1, D)``).
        device: device on which default weights are computed.
        w: per-point weights. When omitted they are computed with
            ``weightof_points`` at the same scale ``t`` as the similarities.
        t: scale factor applied to the distances.

    Returns:
        Element-wise product of ``w`` and the float64 similarities, one
        entry per row of ``points``.
    """
    if w is None:
        # Use the same scale for the weights as for the similarities below.
        w = weightof_points(points, device, t=t)
    pts = points.to(torch.float64)
    # Compute the distances in double precision rather than casting afterwards.
    dist = torch.norm(pts - x.to(pts), dim=1) * t
    sim = torch.exp(-dist)
    if torch.is_tensor(w):
        # Default weights live on ``device``; align the similarities with them.
        sim = sim.to(w.device)
    return w * sim

def privacy_mass(points, x, device, w=None, t=1):
    """Return the total weighted similarity of ``x`` to ``points`` as a Python float."""
    contributions = weighted_distance(points, x, device, w, t=t)
    total = contributions.sum()
    return float(total)

# Magnitude-based distances: differences, overlaps and ratios of magnitudes.

def diff_magnitude_distance(points0, points1, device, points_all=None, t=1):
    """Return ``M(points0) + M(points1) - M(points_all)`` as a Python float.

    ``points_all`` defaults to the row-wise concatenation of the two sets;
    every magnitude is taken at scale ``t``.
    """
    union = torch.cat((points0, points1), dim=0) if points_all is None else points_all
    m_union = magnitude(union, device, t=t)
    m0 = magnitude(points0, device, t=t)
    m1 = magnitude(points1, device, t=t)
    return (m0 + m1 - m_union).item()


def norm_diff_magnitude_distance(points0, points1, device, points_all=None, t=1, normalize=False):
    """Return ``(2*M(all) - (M(points0) + M(points1))) / M(all)`` as a tensor.

    ``points_all`` defaults to the row-wise concatenation of the two sets;
    every magnitude is taken at scale ``t``.

    The result is always divided by ``M(all)`` exactly once. ``normalize`` is
    accepted for signature parity with ``norm_diff_magnitude_distance_grad``.
    Because the value is already normalized, it does not trigger a second
    division, which would scale the result by ``1 / M(all)**2``.
    """
    if points_all is None:
        points_all = torch.cat((points0, points1), dim=0)
    magnitude_all = magnitude(points_all, device, t=t)
    magnitude0 = magnitude(points0, device, t=t)
    magnitude1 = magnitude(points1, device, t=t)
    diff_magnitude = 2 * magnitude_all - (magnitude0 + magnitude1)
    return diff_magnitude / magnitude_all  # normalize by the total magnitude


def norm_diff_magnitude_distance_grad(points0, points1, device, points_all=None, t=1, normalize=False, eps=1e-5):
    """Differentiable ``2*M(all) - (M(points0) + M(points1))``.

    Each magnitude ``M(X)`` is the sum of the entries of
    ``(exp(-t * D_X) + eps * I)^{-1}``, where ``D_X`` is the pairwise Euclidean
    distance matrix of ``X``. The ``eps`` diagonal shift keeps the matrix
    invertible, e.g. for duplicate points. Only differentiable torch ops are
    used, so gradients flow back to the input points.

    Args:
        points0, points1: point sets, one point per row.
        device: device of the zero tensor returned when the two sets are equal.
        points_all: union set; defaults to ``cat(points0, points1)``.
        t: distance scale.
        normalize: additionally divide the result by ``M(all)``.
        eps: diagonal shift added to every similarity matrix.

    Returns:
        0-dim float64 tensor (a zero leaf tensor when ``points0`` equals
        ``points1``).
    """
    points0 = points0.to(torch.float64)
    points1 = points1.to(torch.float64)

    if points_all is None:
        points_all = torch.cat((points0, points1), dim=0)
    else:
        # Keep the union in double precision, like the two subsets.
        points_all = points_all.to(torch.float64)

    if torch.equal(points0, points1):
        return torch.tensor(0.0, device=device, requires_grad=True, dtype=torch.float64)

    def _regularized_magnitude(pts):
        # Sum of the inverse of the eps-shifted similarity matrix of ``pts``,
        # built on the points' own device.
        dist = torch.cdist(pts, pts, p=2) * t
        eye = torch.eye(dist.size(0), device=dist.device, dtype=torch.float64)
        sim = torch.exp(-dist) + eps * eye
        try:
            inverse = torch.inverse(sim)
        except RuntimeError:
            # inversion failed (e.g. singular matrix): use the pseudo-inverse
            inverse = torch.linalg.pinv(sim)
        return torch.sum(inverse)

    magnitude_all = _regularized_magnitude(points_all)
    magnitude0 = _regularized_magnitude(points0)
    magnitude1 = _regularized_magnitude(points1)

    diff_magnitude = 2 * magnitude_all - (magnitude0 + magnitude1)
    if normalize:
        diff_magnitude = diff_magnitude / magnitude_all
    return diff_magnitude


def norm_magnitude_overlap_grad(points0, points1, device, points_all=None, t=1, normalize=False, eps=1e-5):
    """Differentiable magnitude overlap ``(M(points0) + M(points1)) - M(all)``.

    Each magnitude ``M(X)`` is the sum of the entries of
    ``(exp(-t * D_X) + eps * I)^{-1}``, where ``D_X`` is the pairwise Euclidean
    distance matrix of ``X``. The ``eps`` diagonal shift keeps the matrix
    invertible, e.g. for duplicate points. Only differentiable torch ops are
    used, so gradients flow back to the input points.

    Args:
        points0, points1: point sets, one point per row.
        device: device of the zero tensor returned when the two sets are equal.
        points_all: union set; defaults to ``cat(points0, points1)``.
        t: distance scale.
        normalize: divide the overlap by ``M(all)``.
        eps: diagonal shift added to every similarity matrix.

    Returns:
        0-dim float64 tensor (a zero leaf tensor when ``points0`` equals
        ``points1``).
    """
    points0 = points0.to(torch.float64)
    points1 = points1.to(torch.float64)

    if points_all is None:
        points_all = torch.cat((points0, points1), dim=0)
    else:
        # Keep the union in double precision, like the two subsets.
        points_all = points_all.to(torch.float64)

    if torch.equal(points0, points1):
        return torch.tensor(0.0, device=device, requires_grad=True, dtype=torch.float64)

    def _regularized_magnitude(pts):
        # Sum of the inverse of the eps-shifted similarity matrix of ``pts``,
        # built on the points' own device.
        dist = torch.cdist(pts, pts, p=2) * t
        eye = torch.eye(dist.size(0), device=dist.device, dtype=torch.float64)
        sim = torch.exp(-dist) + eps * eye
        try:
            inverse = torch.inverse(sim)
        except RuntimeError:
            # inversion failed (e.g. singular matrix): use the pseudo-inverse
            inverse = torch.linalg.pinv(sim)
        return torch.sum(inverse)

    magnitude_all = _regularized_magnitude(points_all)
    magnitude0 = _regularized_magnitude(points0)
    magnitude1 = _regularized_magnitude(points1)

    overlap_magnitude = (magnitude0 + magnitude1) - magnitude_all
    if normalize:
        overlap_magnitude = overlap_magnitude / magnitude_all

    return overlap_magnitude


def norm_max_magnitude_overlap_grad(points0, points1, device, points_all=None, normalize=False, eps=1e-5, steps=10, max_t=50, min_t=0):
    """Search ``t`` in ``[min_t, max_t]`` for the largest magnitude overlap.

    Each step evaluates ``norm_magnitude_overlap_grad`` on a five-point grid
    (both ends, the quarter points and the midpoint) of the current interval.
    The interval then shrinks to the grid cells adjacent to the best grid
    point, so it is at least halved every step. For a unimodal overlap curve
    the maximizer therefore never leaves the interval. The best value seen
    over *all* evaluations is reported.

    Args:
        points0, points1, device, points_all, normalize, eps: forwarded to
            ``norm_magnitude_overlap_grad``.
        steps: number of interval-shrinking steps.
        max_t, min_t: search interval for ``t``.

    Returns:
        ``(max_overlap, t_arg_max)``: the largest overlap found (as returned
        by ``norm_magnitude_overlap_grad``) and the ``t`` that produced it.
        With ``steps <= 0`` only the midpoint of ``[min_t, max_t]`` is
        evaluated.
    """
    cache = {}

    def overlap_at(t):
        # Grid points recur between steps; evaluate each ``t`` only once.
        if t not in cache:
            cache[t] = norm_magnitude_overlap_grad(
                points0, points1, device,
                points_all=points_all,
                t=t,
                normalize=normalize,
                eps=eps,
            )
        return cache[t]

    lo, hi = min_t, max_t
    t_arg_max = (lo + hi) / 2
    max_overlap = overlap_at(t_arg_max)

    for _ in range(steps):
        grid = [lo + (hi - lo) * k / 4 for k in range(5)]
        values = [overlap_at(t) for t in grid]
        best = max(range(5), key=lambda k: float(values[k]))
        if float(values[best]) > float(max_overlap):
            max_overlap, t_arg_max = values[best], grid[best]
        # A unimodal maximum lies between the neighbours of the best grid point.
        lo, hi = grid[max(best - 1, 0)], grid[min(best + 1, 4)]

    return max_overlap, t_arg_max




def ratio_magnitude_distance(points0, points1, device, points_all=None, t=1):
    """Return ``2 * M(all) / (M(points0) + M(points1)) - 1`` as a Python float.

    ``points_all`` defaults to the row-wise concatenation of the two sets;
    every magnitude is taken at scale ``t``.
    """
    if points_all is None:
        points_all = torch.cat([points0, points1], dim=0)
    m_all, m0, m1 = (magnitude(p, device, t=t) for p in (points_all, points0, points1))
    ratio = (2 * m_all) / (m0 + m1) - 1
    return ratio.item()

def subtract_magnitude_distance(points0, points1, device, points_all=None, t=1):
    """Return ``(M(all) - M(points0), M(all) - M(points1))`` as a tuple of Python floats.

    ``points_all`` defaults to the row-wise concatenation of the two sets;
    every magnitude is taken at scale ``t``.
    """
    if points_all is None:
        points_all = torch.cat((points0, points1), dim=0)
    m_all = magnitude(points_all, device, t=t)
    gaps = [m_all - magnitude(p, device, t=t) for p in (points0, points1)]
    return gaps[0].item(), gaps[1].item()

def marginal_magnitude_distance(points0, points1, device, points_all=None, t=1):
    """Return ``(M(points0) / M(all), M(points1) / M(all))`` as a tuple of Python floats.

    ``points_all`` defaults to the row-wise concatenation of the two sets;
    every magnitude is taken at scale ``t``.
    """
    if points_all is None:
        points_all = torch.cat((points0, points1), dim=0)
    m_all = magnitude(points_all, device, t=t)
    shares = tuple((magnitude(p, device, t=t) / m_all).item() for p in (points0, points1))
    return shares


# Computes the magnitude potential distance between two sets of points, defined as the sum over all cross pairs of the weighted (w_i * w_j) similarities exp(-t * ||x_i - y_j||), where t scales the distances.

# The weights are computed within each subset Class0 and Class1.
def magnitude_potential_distance(points0, points1, device, t=1):
    """Return ``sum_ij w0_i * w1_j * exp(-t * ||points0_i - points1_j||)``.

    ``w0`` and ``w1`` are the magnitude weights of ``points0`` and ``points1``,
    each computed within its own set at the same scale ``t`` as the
    similarities.

    Args:
        points0: (N, D) tensor of points.
        points1: (M, D) tensor of points.
        device: device on which the weights are computed.
        t: scale factor applied to the distances.

    Returns:
        0-dim float64 tensor.
    """
    w0 = weightof_points(points0, device, t=t)  # shape: [N]
    w1 = weightof_points(points1, device, t=t)  # shape: [M]

    # Pairwise Euclidean distances in float64, computed pair by pair (no
    # [N, M, D] difference tensor, no matmul-based approximation).
    dist = torch.cdist(
        points0.to(torch.float64),
        points1.to(torch.float64),
        p=2,
        compute_mode="donot_use_mm_for_euclid_dist",
    )  # shape: [N, M]
    sim = torch.exp(-(dist * t)).to(w0.device)  # align with the weights' device

    weight_matrix = w0[:, None] * w1[None, :]  # outer product, shape: [N, M]
    return torch.sum(sim * weight_matrix)

# The weights are computed within the set of entire points, i.e. Class0 U Class1.
def magnitude_potential_distance2(points1, points2, device, t=1):
    """Return ``sum_ij w1_i * w2_j * exp(-t * ||points1_i - points2_j||)``.

    Unlike ``magnitude_potential_distance``, the weights are computed once on
    the union ``cat(points1, points2)`` at scale ``t`` and then split back
    into the parts belonging to each set.

    Args:
        points1: (N, D) tensor of points.
        points2: (M, D) tensor of points.
        device: device on which the weights are computed.
        t: scale factor applied to the distances.

    Returns:
        0-dim float64 tensor.
    """
    points = torch.cat((points1, points2), dim=0)  # union, rows of points1 first
    w = weightof_points(points, device, t=t)  # shape: [N + M]
    n1 = points1.shape[0]
    w1 = w[:n1]  # shape: [N]
    w2 = w[n1:]  # shape: [M]

    # Pairwise Euclidean distances in float64, computed pair by pair (no
    # [N, M, D] difference tensor, no matmul-based approximation).
    dist = torch.cdist(
        points1.to(torch.float64),
        points2.to(torch.float64),
        p=2,
        compute_mode="donot_use_mm_for_euclid_dist",
    )  # shape: [N, M]
    sim = torch.exp(-(dist * t)).to(w.device)  # align with the weights' device

    weight_matrix = w1[:, None] * w2[None, :]  # outer product, shape: [N, M]
    return torch.sum(sim * weight_matrix)


# def magnitude_potential_distance(points1, points2, device):
#     w1 = weightof_points(points1, device)
#     w2 = weightof_points(points2, device)
#     wDist = 0.0
#     for i in range(len(points1)):
#         for j in range(len(points2)):
#             dist = torch.dist(points1[i].unsqueeze(dim=0), points2[j].unsqueeze(dim=0), p=2)
#             sim = torch.exp(-dist)
#             wDist += sim * w1[i] * w2[j]

#     return wDist








    