import torch

import sys 
sys.path.append("../")

from preprocessing.utilities import getArguments, getPredicate
from models.idx2word import Idx2Word
from preprocessing.utilities import getPredicate

# To install OR-Tools, follow the instructions at https://developers.google.com/optimization/install/python
from ortools.linear_solver import pywraplp 
import torch.nn.functional as F
# https://github.com/facebookresearch/faiss/blob/main/INSTALL.md 
import faiss
import math
import itertools
from tqdm import tqdm

def create_mock_n_indices(array_index_to_bounding_box, bounding_box_to_array_index, bounding_box_to_gold_name, args):
    """Build oracle ("mock") neighbourhoods from gold labels instead of feature proximity.

    Position ``p`` of the result lists the array indices of every *other* bounding box
    whose gold name equals the gold name of the box stored at array index ``p``. Peers
    appear in the iteration order of ``bounding_box_to_gold_name``.

    ``array_index_to_bounding_box`` must be keyed by ``0 .. len - 1``. ``args`` is
    accepted for signature compatibility and is not read.
    """
    def peers_sharing_label(box):
        # The gold name of `box` is looked up lazily, only for candidates other than `box` itself.
        return [
            bounding_box_to_array_index[candidate]
            for candidate, candidate_label in bounding_box_to_gold_name.items()
            if candidate != box and candidate_label == bounding_box_to_gold_name[box]
        ]

    return [peers_sharing_label(array_index_to_bounding_box[position])
            for position in range(len(array_index_to_bounding_box))]

def percentage_of_ground_truth_names(n_indices, array_index_to_bounding_box, bounding_box_to_gold_name, *, verbose=False):
    """Measure how many of each bounding box's neighbours share its gold name.

    Args:
        n_indices: per array index, an iterable of neighbour array indices. This can be a
            list of lists or a 2-D tensor (faiss path). Elements are converted with ``int()``
            so that 0-d tensors can be used as dict keys.
        array_index_to_bounding_box: maps array index -> bounding-box id.
        bounding_box_to_gold_name: maps bounding-box id -> gold name.
        verbose: if True, print one summary line per bounding box.

    Returns:
        dict mapping bounding-box id -> ``(matching_neighbours, total_neighbours)``.
    """
    stats = {}
    for idx, bounding_box in array_index_to_bounding_box.items():
        # int() is needed: tensors hash by identity, so a 0-d tensor never matches an int dict key.
        neighbour_labels = [
            bounding_box_to_gold_name[array_index_to_bounding_box[int(neighbour)]]
            for neighbour in n_indices[idx]
        ]
        gold_name = bounding_box_to_gold_name[bounding_box]
        matches = neighbour_labels.count(gold_name)
        stats[bounding_box] = (matches, len(neighbour_labels))
        if verbose:
            print(f"# Occurrences of gold label {gold_name} of box {bounding_box} in its neighborhood: "
                  f"{matches}/{len(neighbour_labels)}. Actual labels of neighbors {neighbour_labels}")
    return stats

# Each VQAR proof is a list of rela, name, and attr facts. 
# For search purposes, we represent each proof as a map of the form: 
# 'name': {o_i:type_i, ...}
# 'rela': {(o_i,o_j):rela_ij, ...}
# 'attr': {(o_i):attr_i, ...}
def transform_VQAR_proofs_to_maps(proofs):
    """Convert each VQAR proof into a ``{'name': ..., 'attr': ..., 'rela': ...}`` lookup map.

    WARNING (debugging mode kept from the existing code): a name is recorded only when the
    *first* atom of a proof is a ``name`` atom. It is stored as
    ``{int(arguments[1]): arguments[0]}``. All other atoms are parsed but ignored, so the
    ``'attr'`` and ``'rela'`` maps are always empty. The previously commented-out code
    recorded ``attr`` atoms as ``{int(arg1): arg0}`` and ``rela`` atoms as
    ``{(int(arg1), int(arg2)): arg0}``.

    Args:
        proofs: iterable of proofs; each proof is an iterable of atoms understood by
            ``getPredicate`` / ``getArguments``.

    Returns:
        list with one map per proof, in input order.
    """
    proof_maps = []
    for proof in proofs:
        name_map, attr_map, rela_map = {}, {}, {}
        for position, atom in enumerate(proof):
            predicate_symbol = getPredicate(atom)
            atom_args = getArguments(atom)
            # Debugging restriction: only a leading `name` atom contributes to the map.
            if position != 0 or predicate_symbol != 'name':
                continue
            name_map[int(atom_args[1])] = atom_args[0]
        proof_maps.append({'name': name_map, 'attr': attr_map, 'rela': rela_map})
    return proof_maps

def exists_proof_that_maps_instances_to_same_name_class(single_proof_map1, all_proofs_maps2, obj1, obj2):
    """Check whether obj2 can receive, in some proof of another sample, the name obj1 has here.

    Args:
        single_proof_map1: one proof map (see ``transform_VQAR_proofs_to_maps``); it must
            name ``obj1``.
        all_proofs_maps2: all proof maps of the other sample.
        obj1, obj2: bounding-box ids.

    Returns:
        True iff the name of ``obj1`` in ``single_proof_map1`` equals the name of ``obj2``
        in at least one map of ``all_proofs_maps2``.
    """
    assert obj1 in single_proof_map1['name']
    # Every name obj2 is given across the other sample's proofs.
    reachable_names = {proof_map['name'][obj2] for proof_map in all_proofs_maps2 if obj2 in proof_map['name']}
    return single_proof_map1['name'][obj1] in reachable_names

def get_eps_step(eps):
    """Return how much to lower the pruning threshold ``eps`` on the next trial.

    Smaller thresholds are lowered in finer steps:
    ``<= 0.015`` -> 0.001, ``<= 0.15`` -> 0.01, ``<= 0.5`` -> 0.05, otherwise 0.1.
    """
    for upper_bound, step in ((0.015, 0.001), (0.15, 0.01), (0.5, 0.05)):
        if eps <= upper_bound:
            return step
    return 0.1

def _collect_dropped_atoms(pred, eps):
    """Return the atoms whose probability trails the best choice for the same slot by more than eps.

    ``pred`` must hold ``'name'``, ``'rela'`` and ``'attr'`` entries. Each entry is either
    ``[[]]`` (nothing predicted) or a dict with a ``'choices'`` sequence plus one
    probability tensor per object id (``name``/``attr``) or per object-id pair (``rela``).
    This layout is inferred from how the original loop read it; confirm against
    ``get_pred_fn``. Atoms are rendered in the textual form used inside proofs, e.g.
    ``name(cat,3)``, ``attr(red,3)``, ``rela(on,3,5)``.
    """
    dropped = set()
    for kind in ('name', 'rela', 'attr'):
        section = pred[kind]
        if section == [[]]:
            continue
        choices = section['choices']
        for key, probs in section.items():
            if key == "choices":
                continue
            # Gap between the top-scoring choice and every choice for this slot.
            gap = torch.max(probs) - probs
            for index in torch.where(gap > eps)[0].tolist():
                if kind == 'rela':
                    dropped.add(f"rela({choices[index]},{key[0]},{key[1]})")
                else:
                    dropped.add(f"{kind}({choices[index]},{key})")
    return dropped


def purification(train_loader, train_samples, train_scene_graphs_and_features, get_pred_fn, sg_model, idx2word, args):
    """Remove proofs that rely on atoms the current scene-graph model considers unlikely.

    For every batch, ``get_pred_fn`` produces name/rela/attr probabilities. An atom is
    dropped when its probability is more than ``eps`` below the best choice for the same
    object (pair). Every proof containing a dropped atom is removed from the sample's
    lineage, together with the ground-truth entry at the same position.

    ``eps`` starts at ``args.eps``. If a full pass removes nothing, ``eps`` is lowered by
    ``get_eps_step(eps)`` (clamped at 0) and the pass is repeated. The loop stops after the
    first pass that removes at least one proof. It also stops when a pass removes nothing
    and either ``eps`` is already 0 or pruning is disabled (``args.eps <= 0``). Going below
    0 would make ``gap > eps`` drop the arg-max itself and discard every proof.

    Args:
        train_loader: iterable of batches of sample ids; must support ``len``.
        train_samples: mapping sample id -> ``(image_id, query, lineage, object_ids, oid_pairs)``
            with ``lineage == (answer, proofs, ground_truth)``. Assumes ``ground_truth`` is
            indexed per proof, as ``structural_pruning`` does.
        train_scene_graphs_and_features: mapping image id -> scene-graph/feature record.
        get_pred_fn: called with ``return_dicts=True``; the first element of its returned
            pair is the per-sample prediction dicts consumed by ``_collect_dropped_atoms``.
        sg_model: scene-graph model; it is switched to eval mode.
        idx2word: vocabulary forwarded to ``get_pred_fn``.
        args: must provide ``eps`` and ``gpu``.

    Returns:
        dict mapping every sample id yielded by ``train_loader`` to its (possibly pruned)
        sample tuple from the final pass.
    """
    sg_model.eval()
    eps = args.eps
    pruning_enabled = args.eps > 0
    trial = 1  # initialised once so the counter actually advances across trials
    while True:
        changes = 0
        pruned_train_samples = dict()
        with tqdm(train_loader, desc=f"Pruning Trial: {trial}", total=len(train_loader)) as pbar:
            for sample_ids in pbar:
                samples_in_batch = [train_samples[sample_id] for sample_id in sample_ids]
                scene_graphs_and_features = [train_scene_graphs_and_features[image_id] for image_id, _, _, _, _ in samples_in_batch]
                queries = [query for _, query, _, _, _ in samples_in_batch]

                batch_predictions, _ = get_pred_fn(
                    sg_model, samples_in_batch, scene_graphs_and_features, queries, idx2word, args.gpu, return_dicts=True
                )

                for pred_idx, pred in enumerate(batch_predictions):
                    dropped_atoms = _collect_dropped_atoms(pred, eps) if pruning_enabled else set()
                    image_id, query, lineage, object_ids, oid_pairs = samples_in_batch[pred_idx]
                    answer, proofs, ground_truth = lineage[0], lineage[1], lineage[2]

                    # Keep only proofs that use none of the dropped atoms.
                    kept = [k for k, proof in enumerate(proofs) if not any(atom in dropped_atoms for atom in proof)]
                    changes += len(proofs) - len(kept)

                    # Filter ground truth in lockstep so it stays aligned with the proofs.
                    new_lineage = (answer, [proofs[k] for k in kept], [ground_truth[k] for k in kept])
                    pruned_train_samples[sample_ids[pred_idx]] = (image_id, query, new_lineage, object_ids, oid_pairs)

                pbar.set_description(f"Pruning Trial: {trial} | Changes: {changes}")

        if changes:
            print(f"Pruning Trial {trial}: Changes made to the dataset: {changes}")
            return pruned_train_samples

        print(f"Pruning Trial {trial}: No changes made to the dataset")
        if not pruning_enabled or eps <= 0:
            # Nothing left to relax: further passes would be identical or degenerate.
            return pruned_train_samples
        trial += 1
        eps = max(eps - get_eps_step(eps), 0.0)

# Prune pre-images based on the proximity of features of bounding boxes for property name.  
# The features have already been extracted from the images by a pre-trained model and are available from the original VQAR benchmark.
# TODO extend the implementation using a VLM for extracting features of bounding boxes. 
# TODO control training in CUDA/GPU/MP4
def _proof_object_ids(proof_maps):
    """Return the set of bounding-box ids named by any proof map of one sample."""
    return set([obj for proof_map in proof_maps for obj in proof_map['name'].keys()])


def _compute_neighbour_indices(features, array_index_to_bounding_box, bounding_box_to_array_index, bounding_box_to_gold_name, args):
    """Return, per feature-array index, the array indices of the bounding boxes deemed proximal to it.

    The result is a list (mock or percent mode), a 2-D tensor (``structure_k`` mode), an
    empty list when there are no features, or None when no proximity mode is enabled.
    """
    if args.mock_proximity:
        return create_mock_n_indices(array_index_to_bounding_box, bounding_box_to_array_index, bounding_box_to_gold_name, args)
    if len(features) == 0:
        # No bounding boxes: nothing is proximal, and normalize/topk would fail on empty input.
        return []
    features = F.normalize(torch.FloatTensor(features), dim=1)
    if args.structure_k > 0:
        # Exact L2 search over the normalised features. The first hit of each query is
        # normally the query itself, so it is skipped.
        np_features = features.numpy()
        index = faiss.IndexFlatL2(np_features.shape[1])
        index.add(np_features)
        _, neighbour_ids = index.search(np_features, args.structure_k + 1)
        return torch.from_numpy(neighbour_ids[:, 1:args.structure_k + 1])
    if args.percent > 0:
        # Cosine similarities without self-similarity; the threshold is the value of the
        # ceil(0.05% * n^2)-th largest entry.
        # NOTE(review): the 0.0005 fraction and the cap of 5 neighbours are hard-coded and
        # args.percent itself is not used as the fraction -- confirm this is intended.
        similarity_matrix = torch.matmul(features, features.T).fill_diagonal_(0)
        n = similarity_matrix.shape[0]
        top_values, _ = similarity_matrix.reshape(-1).topk(math.ceil(n * n * 0.0005), dim=0, largest=True, sorted=True)
        threshold = top_values[-1]
        return [(similarity_matrix[i] >= threshold).nonzero().flatten()[:5] for i in range(n)]
    return None


def structural_pruning(samples, scene_graphs_and_features, idx2word: Idx2Word, args, max_pruned_per_edge=120):
    """Prune proofs whose name assignments disagree across visually proximal bounding boxes.

    Builds and solves (OR-Tools / SCIP) a 0/1 integer program:

    * ``I_{l}_{k}``: keep (1) or discard (0) proof ``k`` of sample ``l``.
    * ``E_{l1}_{l2}_{o1}_{o2}``: edge between bounding box ``o1`` of sample ``l1`` and a
      proximal bounding box ``o2`` of sample ``l2``. When the edge is 1, a constrained proof
      that names its box is forced to be kept if some proof of the other sample gives the
      partner box the same name. Otherwise it is forced to be discarded.
    * Every sample with proofs keeps at least one proof. Proofs touched by no edge are
      always kept.
    * Objective: maximise, over edges that discard proofs, (number of proofs the edge
      discards) * edge. Edges that would discard ``max_pruned_per_edge`` or more proofs
      are left out of the objective.

    Proximity is chosen by ``args``:

    * mock: same gold name.
    * ``structure_k > 0``: top-k L2 neighbours via faiss, over ordered sample pairs, with
      constraints on the first sample's proofs only.
    * ``percent > 0``: a cosine-similarity threshold, over unordered sample pairs, with
      constraints on both samples' proofs.

    Args:
        samples: list of ``(image_id, query, (answer, proofs, ground_truth), object_ids, oid_pairs)``.
            ``ground_truth`` is indexed per proof and summed.
        scene_graphs_and_features: aligned with ``samples`` by position. Each entry needs
            ``"object_feature"``, ``"object_ids"`` and ``["scene_graph"]["names"]``.
        idx2word: vocabulary turning gold name indices into names.
        args: must provide ``mock_proximity``, ``structure_k`` and ``percent``.
        max_pruned_per_edge: objective cap per edge (previously hard-coded as 120).

    Returns:
        On an optimal solve, ``(new_samples, kept_ground_truth, original_ground_truth,
        pruned_proofs, total_proofs)``. Otherwise a 5-tuple of ``None`` with the same
        arity, so callers can always unpack five values.

    Raises:
        RuntimeError: if the SCIP backend is not available.
    """
    solver = pywraplp.Solver.CreateSolver("SCIP")
    if solver is None:
        raise RuntimeError("The SCIP backend is not available in this OR-Tools installation")

    E = dict()   # all edge variables
    EE = dict()  # edge variables that discard at least one proof (these drive the objective)
    I = dict()   # proof variables
    II = dict()  # proof variables that appear in at least one edge constraint
    EE_to_number_of_pruned_proofs = dict()  # edge key -> number of proofs that edge discards

    # I_{l,k} = 1 keeps the k-th proof of the l-th sample; 0 discards it.
    for l, (_, _, (_, proofs, _), _, _) in enumerate(samples):
        for k in range(len(proofs)):
            I[f"I_{l}_{k}"] = solver.IntVar(0, 1, f"I_{l}_{k}")

    # Per-proof lookup maps and per-sample object-id sets, computed once (they were
    # previously rebuilt for every sample pair).
    proof_mappings = [transform_VQAR_proofs_to_maps(proofs) for _, _, (_, proofs, _), _, _ in samples]
    object_ids_per_sample = [_proof_object_ids(proof_maps) for proof_maps in proof_mappings]

    # Gather one feature row per distinct bounding box. Boxes come from the proofs rather
    # than the scene graphs, because boxes of infrequent types were already pruned.
    # NOTE(review): ids are de-duplicated across samples, which assumes bounding-box ids
    # are globally unique -- confirm.
    bounding_box_to_array_index = dict()
    bounding_box_to_gold_name = dict()
    features = list()
    for l, object_ids in enumerate(object_ids_per_sample):
        graph = scene_graphs_and_features[l]
        object_features = graph["object_feature"]
        for obj in object_ids:
            if obj in bounding_box_to_array_index:
                continue
            bounding_box_to_array_index[obj] = len(features)
            bounding_box_to_gold_name[obj] = idx2word.idx_to_name(graph['scene_graph']['names'][obj])
            # "object_feature" is ordered like "object_ids", so locate obj's position first.
            features.append(object_features[graph["object_ids"].index(obj)])

    array_index_to_bounding_box = {idx: obj for obj, idx in bounding_box_to_array_index.items()}
    n_indices = _compute_neighbour_indices(
        features, array_index_to_bounding_box, bounding_box_to_array_index, bounding_box_to_gold_name, args
    )

    def add_edge_constraint(e_key, i_key, consistent):
        # Link edge e_key to proof i_key: if the edge is on, keep the proof when its naming is
        # consistent with the partner sample, otherwise discard it.
        if e_key not in E:
            E[e_key] = solver.IntVar(0, 1, e_key)
        if consistent:
            solver.Add(1 - E[e_key] + I[i_key] >= 1)      # E = 1 forces I = 1
        else:
            solver.Add(1 - E[e_key] + 1 - I[i_key] >= 1)  # E = 1 forces I = 0
            EE[e_key] = E[e_key]
            EE_to_number_of_pruned_proofs[e_key] = EE_to_number_of_pruned_proofs.get(e_key, 0) + 1
        II[i_key] = I[i_key]

    if args.structure_k > 0:
        # Ordered pairs (l1 != l2); constraints only on l1's proofs.
        for l1, l2 in itertools.permutations(range(len(samples)), 2):
            proof_mappings1, proof_mappings2 = proof_mappings[l1], proof_mappings[l2]
            for obj1, obj2 in itertools.product(object_ids_per_sample[l1], object_ids_per_sample[l2]):
                # obj2 must be among the top-k most proximal boxes of obj1.
                if bounding_box_to_array_index[obj2] not in n_indices[bounding_box_to_array_index[obj1]]:
                    continue
                e_key = f"E_{l1}_{l2}_{obj1}_{obj2}"
                for k1, proof_map1 in enumerate(proof_mappings1):
                    if obj1 in proof_map1['name']:
                        add_edge_constraint(
                            e_key, f"I_{l1}_{k1}",
                            exists_proof_that_maps_instances_to_same_name_class(proof_map1, proof_mappings2, obj1, obj2),
                        )

    elif args.percent > 0:
        # Unordered pairs (l1 < l2); constraints on both samples' proofs.
        for l1, l2 in itertools.combinations(range(len(samples)), 2):
            proof_mappings1, proof_mappings2 = proof_mappings[l1], proof_mappings[l2]
            for obj1, obj2 in itertools.product(object_ids_per_sample[l1], object_ids_per_sample[l2]):
                if bounding_box_to_array_index[obj2] not in n_indices[bounding_box_to_array_index[obj1]]:
                    continue
                e_key = f"E_{l1}_{l2}_{obj1}_{obj2}"
                for k1, proof_map1 in enumerate(proof_mappings1):
                    if obj1 in proof_map1['name']:
                        add_edge_constraint(
                            e_key, f"I_{l1}_{k1}",
                            exists_proof_that_maps_instances_to_same_name_class(proof_map1, proof_mappings2, obj1, obj2),
                        )
                # Only proofs that name this particular obj2 (the original looped over every
                # box of the proof here, which also shadowed obj2).
                for k2, proof_map2 in enumerate(proof_mappings2):
                    if obj2 in proof_map2['name']:
                        add_edge_constraint(
                            e_key, f"I_{l2}_{k2}",
                            exists_proof_that_maps_instances_to_same_name_class(proof_map2, proof_mappings1, obj2, obj1),
                        )

    for l, (_, _, (_, proofs, _), _, _) in enumerate(samples):
        proof_keys = [f"I_{l}_{k}" for k in range(len(proofs))]
        if proof_keys:
            # At least one proof per sample survives. Skipped for proof-less samples, where
            # it would make the program infeasible.
            solver.Add(sum(I[key] for key in proof_keys) >= 1)
        for key in proof_keys:
            # Proofs not involved in any edge constraint are always kept.
            if key not in II:
                solver.Add(I[key] == 1)

    solver.Maximize(solver.Sum([
        EE_to_number_of_pruned_proofs[e_key] * e_var
        for e_key, e_var in EE.items()
        if EE_to_number_of_pruned_proofs[e_key] < max_pruned_per_edge
    ]))

    status = solver.Solve()
    if status != pywraplp.Solver.OPTIMAL:
        print("The problem does not have an optimal solution.")
        return None, None, None, None, None

    new_samples = []
    new_gt = og_gt = pruned = total = 0
    for l, (image_id, query, lineage, object_ids, oid_pairs) in enumerate(samples):
        answer, proofs, ground_truth = lineage
        # round() guards against MIP solutions such as 0.9999999 for integer variables.
        kept = [k for k in range(len(proofs)) if round(I[f"I_{l}_{k}"].solution_value()) == 1]
        new_proofs = [proofs[k] for k in kept]
        new_ground_truth = [ground_truth[k] for k in kept]
        new_gt += sum(new_ground_truth)
        og_gt += sum(ground_truth)
        pruned += len(proofs) - len(new_proofs)
        total += len(proofs)
        if args.mock_proximity:
            # Gold-label proximity must never discard a ground-truth proof.
            assert sum(new_ground_truth) == sum(ground_truth)
        new_samples.append((image_id, query, (answer, new_proofs, new_ground_truth), object_ids, oid_pairs))

    return new_samples, new_gt, og_gt, pruned, total
