import os, torch
import numpy as np
import pickle
from tqdm import tqdm
from pathlib import Path
from PIL import Image
from sklearn.metrics.pairwise import cosine_distances
from src.datasets.bias_dataset import BiasDataset
from src.processing.bias_naming.bias_embed import BiasClusterEmbeddings, BiasEmbeddingsCalculator
from src.processing.bias_naming.keyword_selection import SimpleCandidateKeywordSelector
from src.processing.emb_model import load_embedding_model
from src.utils.images import create_image_with_bias_heatmap


# Root directory for all data used by the commands below (datasets are looked up under
# DATA_PATH / 'datasets'). It is read from the NAMING_BIASES_DATA_PATH environment variable
# at import time, so importing this module raises KeyError when the variable is not set.
DATA_PATH = Path(os.environ['NAMING_BIASES_DATA_PATH'])


class BiasCommands:
    """
    Commands for using bias naming models.
    """

    def name_from_caption_embeddings(self,
                                     dataset_name: str,
                                     output_file_name: str,
                                     model_config_path: Path | str,
                                     caption_folder_name: str = 'captions',
                                     biases_folder_name: str = 'biases',
                                     n_correctly_classified: int | None = None,
                                     top_k_incorrectly_classified: int | None = None,
                                     similarity_threshold: float = 0.2,
                                     keyword_tokenization: str = 'nltk',
                                     keyword_do_lower_case: bool = True,
                                     keyword_stopwords: str | None = 'english',
                                     keyword_remove_punctuation: bool = True,
                                     keyword_max_ngram_size: int = 1,
                                     keyword_min_freq: float = 0.15
                                     ) -> None:
        """
        This method processes a dataset to produce bias embeddings, and uses these embeddings to rank a list
        of candidate keywords. The keywords will describe the biases in the dataset.

        :param dataset_name: The name of the BiasDataset to process. This should be a folder name under the datasets directory in $NAMING_BIASES_DATA_PATH.
        :param output_file_name: The name of the output TXT file where the results will be saved. This file will be saved in the 'outputs' folder inside the given dataset folder.
        :param model_config_path: Path to the configuration file for the embedding model (use an embedding model that supports text).
        :param caption_folder_name: name of the folder in which the captions you want to use are stored
        :param biases_folder_name: name of the folder in which the biases clusters you want to use are stored
        :param n_correctly_classified: if not None, the first n_correctly_classified examples in correctly-classified.txt will be selected.
        :param top_k_incorrectly_classified: if not None, the first top_k_incorrectly_classified examples in incorrectly-classified.txt will be selected.
                                      Note that the incorrectly-classified examples should be sorted by the distance from the decision boundary (descending).
        :param similarity_threshold: the final result will contain keywords that have at least this cosine similarity to the bias embedding (values between -1 and 1, 0.2 is suggested).
        :param keyword_tokenization: either 'nltk' or 'space' to break the caption into words by using 'nltk' library or simply by splitting where there is a whitespace character.
        :param keyword_do_lower_case: if True, keywords will all be in lower case (otherwise there will be a mix of lower and upper case).
        :param keyword_stopwords: if not None, stopwords will be removed from the keyword candidates. You must specify the language according to nltk supported languages.
        :param keyword_remove_punctuation: if True, punctuation is removed.
        :param keyword_max_ngram_size: keywords can be formed by multiple words, if you set this to 3 you will have all keywords with 3, 2, and 1 words.
        :param keyword_min_freq: discard keywords that do not appear at least in X% of the captions for the cluster they are in.
        """
        assert keyword_tokenization in ('space', 'nltk')

        root = DATA_PATH / 'datasets' / dataset_name
        debug_path = DATA_PATH / 'datasets' / dataset_name / 'bias-naming-debug-info' / output_file_name
        
        dataset = BiasDataset(
            root=root,
            return_captions=True,
            caption_folder_name=caption_folder_name,
            biases_folder_name=biases_folder_name,
            n_correctly_classified=n_correctly_classified,
            top_k_incorrectly_classified=top_k_incorrectly_classified
        )

        kwd_selector = SimpleCandidateKeywordSelector(
            tokenization=keyword_tokenization,
            do_lower_case=keyword_do_lower_case,
            stopwords=keyword_stopwords,
            remove_punctuation=keyword_remove_punctuation,
            max_ngram_size=keyword_max_ngram_size,
            min_freq=keyword_min_freq
        )

        emb_model = load_embedding_model(
            config_path=model_config_path,
            need_image_emb=False,
            need_text_emb=True
        )

        bias2vec = BiasEmbeddingsCalculator()

        caption_embeddings = []

        for bias in dataset:
            corr_emb = emb_model.encode(bias.correctly_classified) # type: ignore
            incorr_emb = emb_model.encode(bias.incorrectly_classified) # type: ignore

            assert isinstance(corr_emb, torch.Tensor) or isinstance(corr_emb, np.ndarray) 
            assert isinstance(incorr_emb, torch.Tensor) or isinstance(incorr_emb, np.ndarray) 

            caption_embeddings.append(
                BiasClusterEmbeddings(
                    correctly_classified=corr_emb,
                    incorrectly_classified=incorr_emb
                )
            )

        bias_embeddings = bias2vec.calculate_bias_embeddings(caption_embeddings)

        kws = kwd_selector.select_candidate_keywords(dataset)

        keywords = set()

        for bias_kw in kws:
            keywords.update(bias_kw.keys())
        
        keywords = sorted(list(keywords))

        keyword_embs = emb_model.encode(keywords)
        assert isinstance(keyword_embs, torch.Tensor) or isinstance(keyword_embs, np.ndarray)

        if isinstance(keyword_embs, torch.Tensor):
            keyword_embs = keyword_embs.numpy()

        distances = cosine_distances(keyword_embs, bias_embeddings)

        sorted_res = np.argsort(distances, axis=0).T
        result_text = ''

        all_top_keywords = []
        for bias_id, sorted_keywords in enumerate(sorted_res):
            top_keywords = []
            for i in range(len(keywords)):
                top_keywords.append(keywords[sorted_keywords[i]])
            all_top_keywords.append(top_keywords)
        
        similarities = 1.0 - distances
        sorted_similarities = np.sort(similarities, axis=0).T[:, ::-1]
        num_above_thresh = np.sum(similarities >= similarity_threshold, axis=0)

        final_keywords = []
        final_similarities = []
        for bias_id, top_keywords in enumerate(all_top_keywords):
            final_keywords.append(top_keywords[:num_above_thresh[bias_id]])
            final_similarities.append(sorted_similarities[bias_id][:num_above_thresh[bias_id]])

        result_text = ''
        for bias_id in range(len(final_keywords)):
            result_text += f'bias {bias_id}: '
            for i in range(len(final_keywords[bias_id])):
                kw = final_keywords[bias_id][i]
                sim = final_similarities[bias_id][i]
                result_text += f'{kw} ({sim:.5f}), '
            result_text += '\n'

        debug_info = {
            'keywords': keywords,
            'distances': distances,
            'caption_embeddings': caption_embeddings,
            'bias_embeddings': bias_embeddings,
            'keyword_embs': keyword_embs,
            'all_top_keywords': all_top_keywords,
            'final_keywords': final_keywords,
            'final_similarities': final_similarities,
            'num_above_thresh': num_above_thresh
        }

        debug_path.parent.mkdir(exist_ok=True)
        debug_path.mkdir(exist_ok=True)

        (debug_path / 'debug.pkl').write_bytes(pickle.dumps(debug_info))

        for bias_id, top_keywords in enumerate(all_top_keywords):
            (debug_path / f'top_keywords_for_bias_{bias_id}.txt').write_text('\n'.join(top_keywords))
        
        for bias_id, final_kw in enumerate(final_keywords):
            (debug_path / f'final_keywords_for_bias_{bias_id}.txt').write_text('\n'.join(final_kw))

        output_folder = root / 'outputs'
        output_folder.mkdir(exist_ok=True)

        (output_folder / output_file_name).write_text(result_text)

        print('')
        print(result_text)



    def display_from_image_embeddings(self,
                                   dataset_name: str,
                                   output_folder_name: str,
                                   model_config_path: Path | str,
                                   biases_folder_name: str = 'biases',
                                   n_correctly_classified: int | None = None,
                                   top_k_incorrectly_classified: int | None = None,
                                   patch_size: int = 24
                                   ) -> None:
        """
        This method processes a bias dataset to produce heatmaps highlighting in red positive correlation with the
        bias, and in blue negative correlation with the bias.
        
        :param dataset_name: The name of the BiasDataset to process. This should be a folder name under the datasets directory in $NAMING_BIASES_DATA_PATH.
        :param output_folder_name: The name of the output folder where the results will be saved. This file will be saved in the 'outputs' folder inside the given dataset folder.
        :param model_config_path: Path to the configuration file for the image embedding model (only CLIP is available currently).
        :param biases_folder_name: name of the folder in which the biases clusters you want to use are stored, e.g. "biases-k-10"
        :param n_correctly_classified: if not None, the first n_correctly_classified examples in correctly-classified.txt will be selected.
        :param top_k_incorrectly_classified: if not None, the first top_k_incorrectly_classified examples in incorrectly-classified.txt will be selected.
                                      Note that the incorrectly-classified examples should be sorted by the distance from the decision boundary (descending).
        :param patch_size: the size of the square patches in which the image will be subdivided to calculate similarity between patch embeddings and bias embedding
        """        
        
        root = DATA_PATH / 'datasets' / dataset_name
        debug_path = DATA_PATH / 'datasets' / dataset_name / 'bias-naming-debug-info' / output_folder_name

        images_dataset = BiasDataset(
            root=root,
            return_captions=False,
            biases_folder_name=biases_folder_name,
            n_correctly_classified=n_correctly_classified,
            top_k_incorrectly_classified=top_k_incorrectly_classified
        )

        emb_model = load_embedding_model(
            config_path=model_config_path,
            need_image_emb=True,
            need_text_emb=False
        )

        bias2vec = BiasEmbeddingsCalculator()

        image_embeddings = []

        for bias in images_dataset:
            corr_emb = emb_model.encode(bias.correctly_classified) # type: ignore
            incorr_emb = emb_model.encode(bias.incorrectly_classified) # type: ignore

            assert isinstance(corr_emb, torch.Tensor) or isinstance(corr_emb, np.ndarray) 
            assert isinstance(incorr_emb, torch.Tensor) or isinstance(incorr_emb, np.ndarray) 

            image_embeddings.append(
                BiasClusterEmbeddings(
                    correctly_classified=corr_emb,
                    incorrectly_classified=incorr_emb
                )
            )

        bias_embeddings = bias2vec.calculate_bias_embeddings(image_embeddings)

        most_biased_images = []
        heatmap_images = []

        for bias_id in range(len(images_dataset)):
            bias = images_dataset[bias_id]
            images = bias.correctly_classified + bias.incorrectly_classified

            img_embs = image_embeddings[bias_id].correctly_classified
            img_embs = np.vstack([img_embs, image_embeddings[bias_id].incorrectly_classified])
            bias_emb = bias_embeddings[bias_id]

            distances = cosine_distances(bias_emb.reshape(1, -1), img_embs)
            min_dist_img_idx = np.argmin(distances, axis=1)[0]
            prototypical_image_of_bias = images[min_dist_img_idx]
            
            most_biased_img_heatmap = create_image_with_bias_heatmap(
                image=prototypical_image_of_bias,
                embedding_model=emb_model,
                bias_embedding=bias_embeddings[bias_id],
                patch_size=patch_size
            )

            most_biased_images.append(most_biased_img_heatmap)

            heatmaps = []
            for image in tqdm(images, desc=f'Generating heatmaps for bias {bias_id}'):
                assert isinstance(image, Image.Image)

                heat = create_image_with_bias_heatmap(
                    image=image,
                    embedding_model=emb_model,
                    bias_embedding=bias_embeddings[bias_id],
                    patch_size=patch_size
                )

                heatmaps.append(heat)
            heatmap_images.append(heatmaps)

        debug_info = {
            'image_embeddings': image_embeddings,
            'bias_embeddings': bias_embeddings,
        }

        debug_path.parent.mkdir(exist_ok=True)
        debug_path.mkdir(exist_ok=True)

        (debug_path / 'debug.pkl').write_bytes(pickle.dumps(debug_info))

        output_folder = root / 'outputs' / output_folder_name
        output_folder.parent.mkdir(exist_ok=True)
        output_folder.mkdir(exist_ok=True)

        for bias_id, most_biased_img in enumerate(most_biased_images):
            most_biased_img.save(str(output_folder / f'most_biased_for_bias_{bias_id}.png'))
        
        for bias_id, heatmaps in enumerate(heatmap_images):
            (output_folder / f'{bias_id}').mkdir(exist_ok=True)
            
            for i, heatmap in enumerate(heatmaps):
                heatmap.save(str(output_folder / f'{bias_id}/{i}.png'))
