import torch
from torch import nn
import faiss  # Make sure to install the FAISS library
import numpy as np
from tqdm import tqdm

import sys
sys.path.append('../')
from lib.metrics.distance_metrics import DistanceEvaluator

def normalize_vectors(vectors):
    """Return `vectors` with every row rescaled to unit L2 norm.

    Normalisation runs along dimension 1, so the input needs at least two
    dimensions. Rows with (near-)zero norm are left at zero by
    ``torch.nn.functional.normalize`` thanks to its epsilon clamp.
    """
    unit_rows = nn.functional.normalize(vectors, p=2, dim=1)
    return unit_rows


class MMNB(nn.Module):
    """k-nearest-neighbour label checker backed by a FAISS inner-product index.

    `fit` encodes every training image with ``algorithm.encode_image``,
    L2-normalises the embeddings and stores them in a ``faiss.IndexFlatIP``
    together with the labels seen during fitting. Because the vectors are
    unit-length, the inner product returned by the index is the cosine
    similarity. The ``detect_*`` methods then compare a sample's given label
    with the labels of its nearest indexed neighbours.

    Args:
        algorithm: Model exposing ``encode_image(tensor)`` and ``to(device)``.
        img_output_size: Dimensionality of the image embeddings (index width).
        txt_output_size: Stored as ``output_size_txt``; not used by the
            methods of this class.
        tokenizer: Stored; not used by the methods of this class.
        label_to_word: Stored; not used by the methods of this class.
        device: Device used for encoding. ``None`` means "pick CUDA when
            available, otherwise CPU" at `fit` time.
        agg_type: Stored; not used by the methods of this class.
        dist_type: Stored; the index itself is always inner-product based.
    """

    def __init__(self, algorithm, img_output_size, txt_output_size, tokenizer,
                 label_to_word, device, agg_type, dist_type='cosine'):
        super().__init__()
        self.model = algorithm

        self.output_size_img = img_output_size
        self.output_size_txt = txt_output_size
        self.index_img = None
        self.index_txt = None
        self.tokenizer = tokenizer
        self.label_to_word = label_to_word
        self.device = device
        self.agg_type = agg_type
        self.dist_type = dist_type

        self.saved_labels = None

        # Populated by fit(); declared here so the attributes always exist.
        self.saved_features = None
        self.saved_num_labels = None
        self.labels = None
        self.combination = None

    @staticmethod
    def _to_numpy(values):
        """Convert a tensor (any device) or array-like into a NumPy array."""
        if isinstance(values, torch.Tensor):
            return values.detach().cpu().numpy()
        return np.asarray(values)

    def _embed(self, x):
        """Encode `x` into unit-norm, C-contiguous float32 rows for FAISS."""
        unit = normalize_vectors(self.extract_img_features(x))
        # FAISS works on contiguous float32 data; encoders running in half
        # precision would otherwise hand it float16 arrays.
        return np.ascontiguousarray(self._to_numpy(unit), dtype=np.float32)

    def fit(self, train_loader, combination='sum'):
        """Build the image index from every batch yielded by `train_loader`.

        Each batch is indexable: element 0 holds the inputs. When a third
        element is present it is used as the label (the noisy label, per the
        loader convention of inputs, labels, noisy_labels); otherwise
        element 1 is used.

        Args:
            train_loader: Iterable of batches as described above.
            combination: Stored on the instance as ``self.combination``.

        Raises:
            ValueError: If `train_loader` yields no batches.
        """
        self.combination = combination

        print("Fitting MMNB_obj")
        # Respect the device chosen at construction time instead of forcing
        # CUDA; only auto-select when none was configured.
        if self.device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        print("Put the model to: ", self.device)

        feature_list = []
        label_list = []
        for batch in tqdm(train_loader):
            inputs = batch[0]
            # NB: the third element, when present, is the noisy label.
            labels = batch[2] if len(batch) > 2 else batch[1]

            feature_list.append(self._embed(inputs))
            label_list.append(self._to_numpy(labels))

        if not feature_list:
            raise ValueError(
                "train_loader yielded no batches; cannot build the index")

        saved_features = np.concatenate(feature_list)
        saved_num_labels = np.concatenate(label_list)

        # Inner product over unit vectors == cosine similarity.
        self.index_img = faiss.IndexFlatIP(self.output_size_img)
        self.index_img.add(saved_features)

        self.saved_features = saved_features
        self.saved_num_labels = saved_num_labels
        self.labels = saved_num_labels

    def extract_img_features(self, x):
        """Encode `x` without tracking gradients; return (batch, -1) features."""
        with torch.no_grad():
            encoded = self.model.encode_image(x.to(self.device))
        return encoded.reshape(encoded.size(0), -1)

    def forward(self, x, k):
        """Return ``(similarities, indices)`` of the `k` nearest indexed images.

        Raises:
            RuntimeError: If `fit` has not been called yet.
        """
        if self.index_img is None:
            raise RuntimeError(
                "MMNB.fit() must be called before querying the index")
        query_features_img = self._embed(x)
        distances, indices = self.index_img.search(query_features_img, k)
        return distances, indices

    def detect_label(self, x, k, y_input, threshold=0.5):
        """Flag samples whose label is not among the top neighbour labels.

        The first of the `k` retrieved neighbours is discarded (presumably
        because the query is expected to be in the index and to retrieve
        itself first -- confirm against callers). The distinct labels of the
        remaining neighbours are ranked by frequency and the leading
        ``int(threshold * n_distinct)`` of them form the accepted set. When
        that count truncates to zero and `threshold` is positive, the single
        most frequent label is still accepted, so unanimous neighbours are
        not flagged.

        Returns:
            list[int]: Per sample, 0 if ``y_input[i]`` is accepted, else 1.
        """
        _, indices = self.forward(x, k)

        neighbor_labels = self.labels[np.asarray(indices)][:, 1:]  # (b, k-1)
        given = self._to_numpy(y_input).reshape(-1)

        is_wrong = []
        for row, target in zip(neighbor_labels, given):
            values, counts = np.unique(row, return_counts=True)
            # Stable sort keeps ties in ascending label order (deterministic).
            ranked = values[np.argsort(-counts, kind="stable")]

            top_n = int(threshold * ranked.size)
            if top_n == 0 and threshold > 0:
                top_n = 1

            accepted = bool(np.any(ranked[:top_n] == target))
            is_wrong.append(0 if accepted else 1)

        return is_wrong

    def detect_label_knn(self, x, k, y_input, input_type='train',
                         start_idx=0):
        """Return, per sample, the fraction of `k` neighbours sharing its label.

        For ``input_type == 'train'`` the batch is assumed to occupy index
        positions ``start_idx, start_idx + 1, ...``; ``k + 1`` neighbours are
        retrieved and the sample's own entry is excluded. If the sample is
        not among its retrieved neighbours, the farthest (last) neighbour is
        dropped instead so that exactly `k` neighbours are counted. For any
        other `input_type`, the `k` retrieved neighbours are used as-is.

        Returns:
            list[float]: Agreement ratios in [0, 1]; higher means more
            neighbours agree with ``y_input[i]``.
        """
        is_train = input_type == 'train'
        _, indices = self.forward(x, k + 1 if is_train else k)
        indices = np.asarray(indices)

        neighbor_labels = self.labels[indices]  # (b, k) or (b, k + 1)
        batch_size = indices.shape[0]
        given = self._to_numpy(y_input).reshape(-1)[:batch_size]

        matches = neighbor_labels == given[:, None]

        if is_train:
            own_positions = start_idx + np.arange(batch_size)
            excluded = indices == own_positions[:, None]
            # Rows that did not retrieve themselves lose their last column.
            missing_self = ~excluded.any(axis=1)
            excluded[missing_self, -1] = True
            matches = matches & ~excluded

        return (matches.sum(axis=1) / k).tolist()
    
