import torch

# To install OR-Tools follow the instructions in https://developers.google.com/optimization/install/python 
from ortools.linear_solver import pywraplp 
import torch.nn.functional as F
# https://github.com/facebookresearch/faiss/blob/main/INSTALL.md 
import faiss
import math
import itertools


def create_mock_n_indices(array_index_to_instance, instance_to_array_index, instance_to_gold_name, args):
    """Build oracle neighbourhoods from gold labels instead of from features.

    For every array index ``position`` in ``0 .. len(array_index_to_instance) - 1``, the
    neighbourhood is the list of array indices of instances that belong to a *different* sample
    and carry the same gold label. Matches follow the iteration order of
    ``instance_to_gold_name`` and are truncated to the first ``args.structure_k`` entries.

    Args:
        array_index_to_instance: maps array index -> (sample_idx, instance_idx).
        instance_to_array_index: maps (sample_idx, instance_idx) -> array index.
        instance_to_gold_name: maps (sample_idx, instance_idx) -> gold label.
        args: needs ``structure_k``, the maximum neighbourhood size.

    Returns:
        A list whose entry ``position`` is the list of neighbour array indices.
    """
    neighbourhoods = []
    for position in range(len(array_index_to_instance)):
        sample_id, instance_id = array_index_to_instance[position]
        matches = []
        for (other_sample, other_instance), other_label in instance_to_gold_name.items():
            same_label = other_label == instance_to_gold_name[(sample_id, instance_id)]
            if sample_id != other_sample and same_label:
                matches.append(instance_to_array_index[(other_sample, other_instance)])
        neighbourhoods.append(matches[0:args.structure_k])
    return neighbourhoods

def percentage_of_ground_truth_names(n_indices, array_index_to_instance, instance_to_gold_name):
    """Measure, per instance, how many of its neighbours share its gold label.

    Args:
        n_indices: ``n_indices[idx]`` is an iterable of the array indices of the neighbours of the
            instance with array index ``idx``. Entries may be plain ints or integer tensors, as
            produced by the faiss/torch neighbour search.
        array_index_to_instance: maps array index -> (sample_idx, instance_idx).
        instance_to_gold_name: maps (sample_idx, instance_idx) -> gold label.

    Returns:
        A dict mapping each array index to the fraction of its neighbours whose gold label equals
        its own gold label, or 0.0 when the instance has no neighbours.
    """
    ratios = dict()
    for idx, instance in array_index_to_instance.items():
        # int(): torch tensors hash by identity, so a 0-dim tensor index cannot be used directly
        # as a dict key.
        neighbour_labels = [
            instance_to_gold_name[array_index_to_instance[int(o)]] for o in n_indices[idx]
        ]
        gold_name = instance_to_gold_name[instance]
        if neighbour_labels:
            ratios[idx] = neighbour_labels.count(gold_name) / len(neighbour_labels)
        else:
            ratios[idx] = 0.0
    return ratios

def get_feature(model, instance, args):
    """Return features for ``instance`` computed by the pretrained vision(-language) model ``model``.

    The model is switched to eval mode. The input is moved to CPU, and the forward pass runs under
    ``torch.no_grad()`` so that no autograd graph is built for features that are detached anyway.

    Args:
        model: the feature extractor.
        instance: the input tensor to encode.
        args: needs ``model_name``, which selects how features are extracted:
            ``'clip'`` -> ``model.encode_image(instance)``;
            ``'blip'``/``'blip2'``/``'albef'`` -> the first token of the image embeddings returned by
            ``model.extract_features(..., mode="image")``;
            a name containing ``'_i'`` -> ``model(instance)``, expected to return the feature tensor;
            any other name -> ``model(instance)``, expected to return a 2-tuple whose first element
            is the feature tensor.

    Returns:
        A detached CPU copy of the feature tensor.
    """
    # TODO: make the device configurable (CPU / CUDA / presumably Apple MPS; the original note said
    # "MP4") instead of hard-coding CPU.
    # NOTE(review): only the input is moved; the model presumably already lives on CPU — confirm.
    model.eval()
    instance = instance.to("cpu")
    with torch.no_grad():
        if args.model_name == 'clip':
            feature = model.encode_image(instance)
        elif args.model_name in ['blip', 'blip2', 'albef']:
            sample = {"image": instance, "text_input": None}
            # [:, 0, :] keeps only the first token's embedding of every image in the batch.
            feature = model.extract_features(sample, mode="image").image_embeds[:, 0, :]
        elif '_i' in args.model_name:
            feature = model(instance)
        else:
            feature, _ = model(instance)
    return feature.detach().clone().cpu()

def exists_proof_that_maps_instances_to_same_class(proof1, all_proofs2, i1, i2):
    """Tell whether some proof in ``all_proofs2`` gives instance ``i2`` the label ``proof1`` gives instance ``i1``.

    Args:
        proof1: a proof (list of labels, one per instance) of the first sample.
        all_proofs2: all candidate proofs of the second sample.
        i1: instance position inside ``proof1``.
        i2: instance position inside each proof of ``all_proofs2``.

    Returns:
        True if ``proof1[i1]`` equals ``proof[i2]`` for at least one proof in ``all_proofs2``.
    """
    label = proof1[i1]
    labels_in_other_proofs = [other_proof[i2] for other_proof in all_proofs2]
    return label in labels_in_other_proofs

def _index_instances(samples, gold_proofs, sample_features):
    """Assign a global array index to every instance of every sample.

    Returns:
        ``(instance_to_array_index, instance_to_gold_name, features)``. The two dicts are keyed by
        ``(sample_idx, instance_idx)``, and ``features[j]`` is the feature of the instance with
        array index ``j``.
    """
    instance_to_array_index = dict()
    instance_to_gold_name = dict()
    features = list()
    for l in range(len(samples)):
        offset = len(instance_to_array_index)
        for i in range(len(samples[l])):
            instance_to_array_index[(l, i)] = offset + i
            instance_to_gold_name[(l, i)] = gold_proofs[l][i]
        # Assumes sample_features[l] yields one feature per instance of samples[l], in order.
        features.extend(sample_features[l])
    return instance_to_array_index, instance_to_gold_name, features


def _compute_neighbors(features, array_index_to_instance, instance_to_array_index, instance_to_gold_name, args):
    """Return ``n_indices``, where ``n_indices[j]`` holds the array indices considered close to instance ``j``.

    Returns None when no proximity criterion is enabled (no ``mock_proximity``,
    ``structure_k <= 0`` and ``percent <= 0``).
    """
    if args.mock_proximity:
        return create_mock_n_indices(array_index_to_instance, instance_to_array_index, instance_to_gold_name, args)

    features = torch.FloatTensor(features)
    # Flatten every feature to a vector, so both criteria below work on an (N, D) matrix
    # (e.g. (N, 1, D) embeddings become (N, D)) and the faiss index gets the real dimension D.
    features = features.reshape(features.shape[0], -1)

    if args.structure_k > 0:
        # Top-k filtering per object: the structure_k nearest neighbours in L2 distance between the
        # L2-normalised features. The first hit is dropped because it is expected to be the query
        # itself (NOTE(review): not guaranteed when several features are identical).
        features = F.normalize(features, dim=1).numpy()
        index = faiss.IndexFlatL2(features.shape[1])
        index.add(features)
        _, ID = index.search(features, args.structure_k + 1)
        return torch.from_numpy(ID[:, 1:args.structure_k + 1])

    if args.percent > 0:
        # Keep the pairs whose cosine similarity is among the top `percent` fraction of all pairwise
        # similarities (self-similarity is zeroed out first).
        features = F.normalize(features, dim=1)
        similarity_matrix = torch.matmul(features, features.T)
        similarity_matrix.fill_diagonal_(0)
        flat = similarity_matrix.reshape(-1)
        num_kept = min(math.ceil(flat.numel() * args.percent), flat.numel())
        top_values, _ = flat.topk(num_kept, dim=0, largest=True, sorted=True)
        threshold = top_values[-1]
        # Each instance keeps at most the first 10 qualifying indices, in index order rather than
        # by similarity.
        return [(row >= threshold).nonzero().reshape(-1)[:10] for row in similarity_matrix]

    return None


def structural_pruning(samples, pre_images, gold_proofs, sample_features, model, args):
    """Prune pre-images using the proximity of instance features, by solving an integer linear program.

    Assumptions: each sample is a list of M instances. ``pre_images[l]`` is the list of proofs of
    the l-th training sample, and each proof is a list of M labels. For example, ``samples[l]``
    holds 3 images, ``pre_images[l]`` is ``[[1, 4, 5], [5, 4, 1]]`` and ``gold_proofs[l]`` is
    ``[1, 4, 5]``.

    An edge variable links instance ``(l1, i1)`` to an instance ``(l2, i2)`` of another sample when
    they are close. Closeness is defined as follows:

    - ``args.mock_proximity``: neighbourhoods come from the gold labels.
    - ``args.structure_k > 0``: the ``structure_k`` nearest neighbours in feature space.
    - ``args.percent > 0``: pairs whose similarity is in the top ``percent`` fraction.

    If an edge is selected, each proof of the source sample must be kept when some proof of the
    other sample labels the two instances the same way, and is pruned otherwise. The ILP maximises
    the number of proofs pruned by the selected edges. It keeps at least one proof per non-empty
    sample and keeps every proof that takes part in no edge constraint.

    Args:
        samples: ``samples[l]`` is the list of instances of the l-th sample.
        pre_images: ``pre_images[l]`` is the list of candidate proofs of sample l.
        gold_proofs: ``gold_proofs[l]`` is the gold label list of sample l. It is used only for the
            mock neighbourhoods and for a sanity check in mock mode.
        sample_features: ``sample_features[l]`` yields one feature per instance of ``samples[l]``.
        model: unused (features are precomputed); kept for interface compatibility.
        args: needs ``mock_proximity``, ``structure_k`` and ``percent``.

    Returns:
        A list whose l-th element is the list of kept proofs of sample l.

    Raises:
        RuntimeError: if the SCIP backend is unavailable or the ILP is not solved to optimality.
    """
    if len(samples) == 0:
        return []

    solver = pywraplp.Solver.CreateSolver("SCIP")
    if solver is None:
        raise RuntimeError("Could not create the SCIP solver; check the OR-Tools installation.")

    E = dict()  # Edge variables, keyed by "E_{l1}_{l2}_{i1}_{i2}"
    EE = dict()  # Edge variables that prune at least one proof (these enter the objective)
    I = dict()  # Proof variables, keyed by "I_{l}_{k}"; 1 means the k-th proof of sample l is kept
    II = dict()  # Proof variables that participate in at least one edge constraint
    EE_to_number_of_pruned_proofs = dict()  # Maps each pruning edge to the number of proofs it prunes

    instance_to_array_index, instance_to_gold_name, features = _index_instances(samples, gold_proofs, sample_features)
    array_index_to_instance = {idx: inst for inst, idx in instance_to_array_index.items()}
    n_indices = _compute_neighbors(features, array_index_to_instance, instance_to_array_index, instance_to_gold_name, args)
    # Sets of plain ints: O(1) membership tests, whether n_indices holds lists or tensors.
    neighbour_sets = [] if n_indices is None else [{int(j) for j in row} for row in n_indices]

    for l in range(len(samples)):
        for k in range(len(pre_images[l])):
            I[f"I_{l}_{k}"] = solver.IntVar(0, 1, f"I_{l}_{k}")

    def are_neighbours(l1, i1, l2, i2):
        """True if instance (l2, i2) is in the neighbourhood of instance (l1, i1)."""
        return instance_to_array_index[(l2, i2)] in neighbour_sets[instance_to_array_index[(l1, i1)]]

    def add_edge_constraints(edge_key, l_src, i_src, l_other, i_other):
        """Tie edge variable ``edge_key`` to every proof of sample ``l_src``."""
        proofs_src = pre_images[l_src]
        proofs_other = pre_images[l_other]
        if len(proofs_src) == 0:
            return
        if edge_key not in E:
            E[edge_key] = solver.IntVar(0, 1, edge_key)
        edge = E[edge_key]
        for k, proof in enumerate(proofs_src):
            proof_key = f"I_{l_src}_{k}"
            if exists_proof_that_maps_instances_to_same_class(proof, proofs_other, i_src, i_other):
                # 1 - E + I >= 1: a selected edge keeps the proofs that agree with the other sample.
                solver.Add(1 - edge + I[proof_key] >= 1)
            else:
                # 1 - E + 1 - I >= 1: a selected edge prunes the proofs that disagree.
                solver.Add(1 - edge + 1 - I[proof_key] >= 1)
                EE[edge_key] = edge
                EE_to_number_of_pruned_proofs[edge_key] = EE_to_number_of_pruned_proofs.get(edge_key, 0) + 1
            II[proof_key] = I[proof_key]

    if args.structure_k > 0:
        # Directed edges: (l1, i1) -> (l2, i2) only constrains the proofs of sample l1.
        for l1, l2 in itertools.permutations(range(len(samples)), 2):
            for i1, i2 in itertools.product(range(len(samples[l1])), range(len(samples[l2]))):
                if are_neighbours(l1, i1, l2, i2):
                    add_edge_constraints(f"E_{l1}_{l2}_{i1}_{i2}", l1, i1, l2, i2)
    elif args.percent > 0:
        # Each unordered pair of samples is visited once. A single edge variable constrains the
        # proofs of both samples, each checked against the other sample's proofs.
        for l1, l2 in itertools.combinations(range(len(samples)), 2):
            for i1, i2 in itertools.product(range(len(samples[l1])), range(len(samples[l2]))):
                if are_neighbours(l1, i1, l2, i2):
                    edge_key = f"E_{l1}_{l2}_{i1}_{i2}"
                    add_edge_constraints(edge_key, l1, i1, l2, i2)
                    add_edge_constraints(edge_key, l2, i2, l1, i1)

    for l in range(len(samples)):
        proof_keys = [f"I_{l}_{k}" for k in range(len(pre_images[l]))]
        # Keep at least one proof per sample. An empty pre-image cannot satisfy this, so it is
        # skipped rather than making the whole program infeasible.
        if proof_keys:
            solver.Add(solver.Sum([I[key] for key in proof_keys]) >= 1)
        # Proofs not involved in any edge constraint are always kept. Such proofs exist in the
        # thresholded (percent) setting, not in standard top-k.
        for key in proof_keys:
            if key not in II:
                solver.Add(I[key] == 1)

    # Objective: maximise sum_e n_e * EE_e, where n_e is the number of proofs pruned by edge e.
    # NOTE(review): I variables that are not bound by any selected edge do not appear in the
    # objective, so the solver may keep or drop them arbitrarily (subject to >= 1 per sample).
    solver.Maximize(solver.Sum([EE_to_number_of_pruned_proofs[key] * var for key, var in EE.items()]))

    status = solver.Solve()
    if status != pywraplp.Solver.OPTIMAL:
        raise RuntimeError(f"Structural pruning ILP was not solved to optimality (solver status {status}).")

    new_preimage = []
    for l in range(len(samples)):
        proofs = pre_images[l]
        # Integer variables can come back as e.g. 0.9999999, so threshold instead of testing == 1.
        new_proofs = [proof for k, proof in enumerate(proofs) if I[f"I_{l}_{k}"].solution_value() > 0.5]
        if args.mock_proximity:
            assert gold_proofs[l] in new_proofs, f"Ground-truth {gold_proofs[l]} not in the kept proofs {new_proofs}.\n"
        new_preimage.append(new_proofs)
    return new_preimage
