import torch

# To install OR-Tools follow the instructions in https://developers.google.com/optimization/install/python 
from ortools.linear_solver import pywraplp 
import torch.nn.functional as F
# https://github.com/facebookresearch/faiss/blob/main/INSTALL.md 
import faiss
import math
import itertools

def create_mock_n_indices(array_index_to_instance, instance_to_array_index, instance_to_gold_name, args):
    """Build oracle neighbour lists from gold labels instead of features.

    For every array index ``position`` (0 .. len(array_index_to_instance) - 1),
    the neighbours are the array indices of instances that belong to a
    *different* sample but carry the same gold label. They are listed in the
    iteration order of ``instance_to_gold_name`` and truncated to the first
    ``args.structure_k`` entries.

    Returns:
        A list whose ``position``-th element is the list of neighbour array indices.
    """
    result = []
    for position in range(len(array_index_to_instance)):
        sample_id, inst_id = array_index_to_instance[position]
        own_label = instance_to_gold_name[(sample_id, inst_id)]
        matches = []
        for (other_sample, other_inst), other_label in instance_to_gold_name.items():
            # Skip instances of the same sample and instances with a different gold label.
            if other_sample == sample_id or other_label != own_label:
                continue
            matches.append(instance_to_array_index[(other_sample, other_inst)])
        result.append(matches[:args.structure_k])
    return result

def percentage_of_ground_truth_names(n_indices, array_index_to_instance, instance_to_gold_name):
    """Measure how often each instance's neighbours share its gold label.

    Args:
        n_indices: indexable by array index; ``n_indices[idx]`` is an iterable of
            neighbour array indices (Python ints, numpy integers or torch scalars,
            e.g. a row of a 2-D index tensor).
        array_index_to_instance: maps an array index to its (sample, instance) key.
        instance_to_gold_name: maps a (sample, instance) key to its gold label.

    Returns:
        A dict mapping every array index to the fraction of its neighbours whose
        gold label equals its own (0.0 when it has no neighbours).
    """
    ratios = {}
    for idx, instance in array_index_to_instance.items():
        # int() turns tensor / numpy scalars into plain ints. torch tensors hash by
        # identity, so using them directly as dict keys raises KeyError.
        closest_objects_labels = [
            instance_to_gold_name[array_index_to_instance[int(o)]] for o in n_indices[idx]
        ]
        gold_name = instance_to_gold_name[instance]
        if closest_objects_labels:
            ratios[idx] = closest_objects_labels.count(gold_name) / len(closest_objects_labels)
        else:
            ratios[idx] = 0.0
    return ratios

# Prune pre-images based on the feature-space proximity of bounding boxes (object instances) for the property name.
# Assumptions: video_as_list[l] is the l-th sample, a list of instances.
# pre_images[l][i] is the list of candidate labels ("proofs") for the i-th instance of the l-th sample.
# video_features[l] provides one feature tensor per instance of the l-th sample.
# model is the VLM used to extract the features; it is not used inside this function.
# oracle_gt[l][i] is the gold label of the i-th instance of the l-th sample. It is needed for logging and for mock proximity.
def structural_pruning(video_as_list, pre_images, oracle_gt, video_features, model, args):
    """Prune candidate labels with an integer linear program over neighbouring instances.

    Every instance (sample ``l``, instance ``i``) has the candidate labels
    ``pre_images[l][i]``. A binary variable ``I_{l}_{i}_{k}`` decides whether the
    k-th candidate is kept. A binary edge variable ``E_{l1}_{l2}_{i1}_{i2}`` is
    created for each ordered pair of instances from *different* samples where
    (l2, i2) is a neighbour of (l1, i1) (see ``_compute_neighbor_indices``).
    The edge carries these constraints:

    * selecting the edge forces every candidate of (l1, i1) that also appears in
      ``pre_images[l2][i2]`` to be kept;
    * selecting the edge forces every candidate of (l1, i1) that does *not*
      appear there to be dropped.

    Each instance must keep at least one candidate. Candidates that take part in
    no edge constraint are always kept. The objective maximises
    ``sum_e n_e * E_e`` over the edges that drop at least one candidate, where
    ``n_e`` is the number of candidates edge ``e`` drops.

    Args:
        video_as_list: list of samples; ``video_as_list[l]`` is a sequence of instances.
        pre_images: ``pre_images[l][i]`` is the list of candidate labels of
            instance ``i`` of sample ``l``.
        oracle_gt: ``oracle_gt[l][i]`` is the gold label of that instance; it is
            only consumed by the mock proximity mode (``args.mock_proximity``).
        video_features: ``video_features[l]`` yields one feature tensor per
            instance of sample ``l``.
        model: unused; kept for interface compatibility.
        args: needs ``mock_proximity``, ``structure_k`` and ``percent``.

    Returns:
        A nested list with the same shape as ``pre_images`` that contains only
        the kept candidates.

    Raises:
        RuntimeError: if OR-Tools cannot create a SCIP solver, or if the ILP has
            no optimal solution. For example, an instance with no candidates
            cannot satisfy the "keep at least one" constraint.
    """
    solver = pywraplp.Solver.CreateSolver("SCIP")
    if solver is None:
        raise RuntimeError("structural_pruning: OR-Tools could not create a SCIP solver")

    # Give every (sample, instance) pair a flat array index in sample-major order,
    # and collect the gold labels and feature tensors in the same order.
    instance_to_array_index = dict()
    instance_to_gold_name = dict()
    features = list()
    next_index = 0
    for l, sample in enumerate(video_as_list):
        for i in range(len(sample)):
            instance_to_array_index[(l, i)] = next_index + i
            instance_to_gold_name[(l, i)] = oracle_gt[l][i]
        next_index += len(sample)
        features.extend(video_features[l])
    array_index_to_instance = {idx: inst for inst, idx in instance_to_array_index.items()}

    neighbors = _compute_neighbor_indices(
        features, array_index_to_instance, instance_to_array_index, instance_to_gold_name, args
    )

    # One binary variable per candidate: I_{l}_{i}_{k} == 1 keeps the k-th
    # candidate of the i-th instance of the l-th sample; 0 discards it.
    keep_vars = dict()
    for l in range(len(video_as_list)):
        for i, proofs in enumerate(pre_images[l]):
            for k in range(len(proofs)):
                keep_vars[f"I_{l}_{i}_{k}"] = solver.IntVar(0, 1, f"I_{l}_{i}_{k}")

    EE = dict()  # edge variables that drop at least one candidate
    EE_to_number_of_pruned_proofs = dict()  # edge name -> number of candidates it drops
    constrained = set()  # names of I variables that appear in some edge constraint

    # Edge constraints are added for every proximity mode that produced neighbour
    # sets, including the thresholded (percent) mode.
    if neighbors is not None:
        for l1, l2 in itertools.permutations(range(len(video_as_list)), 2):
            proofs1 = pre_images[l1]
            proofs2 = pre_images[l2]
            for i1, i2 in itertools.product(range(len(video_as_list[l1])), range(len(video_as_list[l2]))):
                # Only pairs where (l2, i2) is a neighbour of (l1, i1) yield constraints.
                if instance_to_array_index[(l2, i2)] not in neighbors[instance_to_array_index[(l1, i1)]]:
                    continue
                if not proofs1[i1]:
                    continue  # no candidates to constrain, so no edge variable
                edge_name = f"E_{l1}_{l2}_{i1}_{i2}"
                edge = solver.IntVar(0, 1, edge_name)
                for k1, proof in enumerate(proofs1[i1]):
                    proof_name = f"I_{l1}_{i1}_{k1}"
                    keep = keep_vars[proof_name]
                    if proof in proofs2[i2]:
                        # 1 - E + I >= 1, i.e. selecting the edge keeps this shared candidate.
                        solver.Add(1 - edge + keep >= 1)
                    else:
                        # 1 - E + 1 - I >= 1, i.e. selecting the edge drops this unshared candidate.
                        solver.Add(1 - edge + 1 - keep >= 1)
                        EE[edge_name] = edge
                        EE_to_number_of_pruned_proofs[edge_name] = EE_to_number_of_pruned_proofs.get(edge_name, 0) + 1
                    constrained.add(proof_name)

    for l in range(len(video_as_list)):
        for i, proofs in enumerate(pre_images[l]):
            names = [f"I_{l}_{i}_{k}" for k in range(len(proofs))]
            # sum_k I_{l,i,k} >= 1: every instance keeps at least one candidate.
            solver.Add(sum([keep_vars[name] for name in names]) >= 1)
            # Candidates outside every edge constraint are kept. This happens when an
            # instance has no neighbours, e.g. in the thresholded percent mode.
            for name in names:
                if name not in constrained:
                    solver.Add(keep_vars[name] == 1)

    # max sum_e n_e * EE_e, where n_e is the number of candidates dropped by edge e.
    solver.Maximize(solver.Sum([EE_to_number_of_pruned_proofs[name] * var for name, var in EE.items()]))

    status = solver.Solve()
    if status != pywraplp.Solver.OPTIMAL:
        raise RuntimeError(
            f"structural_pruning: the pruning ILP has no optimal solution (solver status {status})"
        )

    # solution_value() returns a float that may carry tiny numerical error for
    # integer variables, so threshold at 0.5 instead of testing "== 1".
    return [
        [
            [proof for k, proof in enumerate(proofs) if keep_vars[f"I_{l}_{i}_{k}"].solution_value() > 0.5]
            for i, proofs in enumerate(pre_images[l])
        ]
        for l in range(len(video_as_list))
    ]


def _flatten_features(features):
    """Stack per-instance feature tensors into a detached float32 CPU matrix of shape (N, D)."""
    stacked = torch.stack(features).detach().cpu().float()
    # Flatten every axis after the instance axis before normalising. Normalising
    # along a singleton axis (e.g. per-instance tensors of shape (1, D)) would turn
    # every value into +/-1. Flattening also keeps the instance axis when N == 1.
    return stacked.reshape(stacked.shape[0], -1)


def _compute_neighbor_indices(features, array_index_to_instance, instance_to_array_index, instance_to_gold_name, args):
    """Return, per array index, the set of neighbour array indices, or None when no proximity mode is enabled.

    Modes, in priority order:
    * ``args.mock_proximity``: neighbours are derived from gold labels via ``create_mock_n_indices``.
    * ``args.structure_k > 0``: the ``structure_k`` nearest instances (L2 distance on
      L2-normalised features, via faiss), excluding the first hit.
    * ``args.percent > 0``: a global cosine-similarity threshold equal to the
      ceil(N * N * percent)-th largest off-diagonal similarity. Each instance keeps
      at most 10 instances (lowest indices first) whose similarity reaches it.
    """
    if args.mock_proximity:
        n_indices = create_mock_n_indices(array_index_to_instance, instance_to_array_index, instance_to_gold_name, args)
    elif args.structure_k > 0:
        matrix = F.normalize(_flatten_features(features), dim=1).numpy()
        index = faiss.IndexFlatL2(matrix.shape[1])
        index.add(matrix)
        # Ask for k + 1 hits and drop the first, which is presumably the query itself.
        _, ID = index.search(matrix, args.structure_k + 1)
        n_indices = ID[:, 1:args.structure_k + 1]
    elif args.percent > 0:
        normalized = F.normalize(_flatten_features(features), dim=1)
        similarity_matrix = torch.matmul(normalized, normalized.T)
        similarity_matrix.fill_diagonal_(0)
        n = similarity_matrix.shape[0]
        top_values, _ = similarity_matrix.reshape(-1).topk(
            math.ceil(n * n * args.percent), dim=0, largest=True, sorted=True
        )
        threshold = top_values[-1]
        n_indices = [(row >= threshold).nonzero().flatten()[:10] for row in similarity_matrix]
    else:
        return None
    # Plain-int sets give O(1) membership tests regardless of the source container.
    return [{int(j) for j in row} for row in n_indices]