"""Basic Retriever"""
import copy
import faiss
import torch
import evaluate
import numpy as np
import pandas as pd


from collections import Counter
from torch.utils.data import DataLoader
from bayes_opt import BayesianOptimization
from sentence_transformers import SentenceTransformer



from typing import List,  Dict
from transformers import AutoTokenizer

from common import get_prompt_label, DataCollatorWithPaddingAndCuda, DatasetEncoder, extract_data, get_input

class BaseRetriever:
    """Base class for in-context example (ICE) retrievers.

    At construction time the candidate (test) texts are embedded with a
    SentenceTransformer. :meth:`retrieve` then selects in-context examples
    from ``index_ds`` for every test item with one of two strategies:

    * ``'naive'`` -- similarity retrieval over the whole ICE pool;
    * ``'rif'`` -- similarity retrieval whose scores are re-weighted per
      class, with an additive class-weight offset tuned by Bayesian
      optimisation on a class-balanced subset carved out of the pool.

    Subclasses must implement :meth:`raw_retrieve`.
    """

    index_ds = None
    test_ds = None

    # Per-task number of examples sampled from each class to build the
    # class-balanced tuning subset used by the 'rif' strategy.
    OPT_SAMPLES_PER_CLASS = {'agnews': 25}

    def __init__(self, task, ice_dataloader, candidate_dataloader, inferencer, device, metric_model, metric_tokenizer):
        """Set up the embedding model and embed the candidate set.

        Args:
            task: Task name; forwarded to ``get_prompt_label``/``get_input``
                and to the inferencer.
            ice_dataloader: Pool of in-context examples. Despite the name it
                is used like a pandas DataFrame with ``'text'`` and
                ``'label'`` columns.
            candidate_dataloader: Test items, used like a DataFrame with a
                ``'text'`` column.
            inferencer: Object exposing ``inference(...)``; used to score
                candidate weightings in :meth:`opt_function`.
            device: Device the embedding model and collator run on.
            metric_model, metric_tokenizer: Stored on the instance; not used
                in this class.
        """
        self.task = task
        self.index_ds, self.test_ds = ice_dataloader, candidate_dataloader

        self.template, self.template_dict, self.label = get_prompt_label(self.task)
        self.device = device

        # The same checkpoint name serves as tokenizer and SentenceTransformer.
        self.tokenizer_name = "google-bert/bert-base-uncased"
        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
        self.model = SentenceTransformer(self.tokenizer_name)
        self.model = self.model.to(self.device)
        self.model.eval()

        self.metric_model = metric_model
        self.metric_tokenizer = metric_tokenizer

        self.batch_size = 8
        self.test_text = self.test_ds['text'].tolist()

        # Embed the test set once; every retrieval call reuses it.
        self.test_encode_dataset = DatasetEncoder(self.test_text, tokenizer=self.tokenizer)
        self.co = DataCollatorWithPaddingAndCuda(tokenizer=self.tokenizer, device=self.device)
        self.test_dataloader = DataLoader(self.test_encode_dataset, batch_size=self.batch_size, collate_fn=self.co)
        self.test_forward = self.forward(self.test_dataloader, process_bar=True, information="Embedding test set...")

        self.inferencer = inferencer
        self.acc_evaluate = evaluate.load('accuracy')

        self.best_eposilon = 0

    def retrieve(self, im_retrieving, ice_num, each_class_sample=None):
        """Return in-context example indices for every test item.

        Args:
            im_retrieving: Retrieval strategy, ``'naive'`` or ``'rif'``.
            ice_num: Number of in-context examples requested per test item.
            each_class_sample: ``'rif'`` only -- number of examples drawn from
                each class for the balanced tuning subset. Defaults to
                ``OPT_SAMPLES_PER_CLASS[self.task]``.

        Returns:
            One list of indices into ``index_ds`` (as produced by
            :meth:`raw_retrieve`) per test item.

        Raises:
            ValueError: For an unknown strategy, or for ``'rif'`` when no
                per-class sample size is known for ``self.task``.
        """
        if im_retrieving == 'naive':
            return self._retrieve_naive(ice_num)
        if im_retrieving == 'rif':
            return self._retrieve_rif(ice_num, each_class_sample)
        raise ValueError(f"Unknown retrieval strategy {im_retrieving!r}; expected 'naive' or 'rif'.")

    def _retrieve_naive(self, ice_num):
        """Retrieve by plain similarity over the whole pool, ordered by ascending score."""
        ice_embed_list, ice_forward, base_index = self.create_forword_index(self.index_ds['text'])
        ice_idx_list, ice_score_list = self.raw_retrieve(
            self.test_forward, ice_embed_list, ice_forward, base_index, ice_num,
            self.index_ds['text'], self.index_ds['label'])
        # Ascending score order: the highest-scoring example ends up last in each list.
        return [
            [idx for idx, _ in sorted(zip(idxs, scores), key=lambda pair: pair[1])]
            for idxs, scores in zip(ice_idx_list, ice_score_list)
        ]

    def _retrieve_rif(self, ice_num, each_class_sample=None):
        """Retrieve with class re-weighting; the weight offset is tuned by Bayesian optimisation."""
        if each_class_sample is None:
            each_class_sample = self.OPT_SAMPLES_PER_CLASS.get(self.task)
        if each_class_sample is None:
            raise ValueError(
                f"No per-class sample size configured for task {self.task!r}; "
                "pass each_class_sample or extend OPT_SAMPLES_PER_CLASS.")

        self.ice_text = self.index_ds['text'].tolist()
        self.ice_label = self.index_ds['label'].tolist()
        label_counts = Counter(self.ice_label)

        # Class-balanced tuning subset; assumes labels are 0..len(self.label)-1.
        self.opt_balance_ds = pd.concat(
            [self.index_ds[self.index_ds['label'] == i].sample(n=each_class_sample, replace=False, random_state=42)
             for i in range(len(self.label))],
            axis=0, ignore_index=True)
        # The remainder of the pool is the retrieval source while tuning. Reset the
        # index so positional retrieval indices and index labels refer to the same rows.
        self.opt_remain_ds = self.index_ds[
            ~self.index_ds['text'].isin(self.opt_balance_ds['text'])].reset_index(drop=True)
        remain_embed_list, remain_ice_forward, remain_base_index = self.create_forword_index(self.opt_remain_ds['text'])
        _, bs_forward, _ = self.create_forword_index(self.opt_balance_ds['text'])

        # Over-retrieve by the (truncated) majority/minority class-size ratio; the
        # re-weighted lists are cut back to ice_num in sort_idx.
        self.opt_new_ice_num = ice_num * int(max(label_counts.values()) / min(label_counts.values()))
        self.opt_remain_ice_idx_list, self.opt_remain_ice_score_list = self.raw_retrieve(
            bs_forward, remain_embed_list, remain_ice_forward, remain_base_index, self.opt_new_ice_num,
            self.opt_remain_ds['text'], self.opt_remain_ds['label'])
        remain_labels = self.opt_remain_ds['label'].tolist()  # build once, not once per lookup
        self.opt_retrieve_ice_label = [[remain_labels[i] for i in sublist] for sublist in self.opt_remain_ice_idx_list]

        num = len(self.ice_text)
        self.opt_alpha = (num - 1) / num
        max_class_weight = max(self._class_weights(self.ice_label, self.opt_alpha).values())

        self.opt_ice_num = ice_num

        # Search the additive offset beta0 within +/- 5x the largest base class weight.
        pbounds = {'beta0': (-max_class_weight * 5, max_class_weight * 5)}
        optimizer = BayesianOptimization(f=self.opt_function, pbounds=pbounds, random_state=42)
        optimizer.maximize(init_points=5, n_iter=30)
        self.best_beta = optimizer.max['params']['beta0']
        new_weight = self._class_weights(self.ice_label, self.opt_alpha, self.best_beta)

        # Final retrieval for the test set over the full pool, re-weighted and truncated.
        ice_embed_list, ice_forward, base_index = self.create_forword_index(self.ice_text)
        raw_ice_idx_list, raw_ice_score_list = self.raw_retrieve(
            self.test_forward, ice_embed_list, ice_forward, base_index, self.opt_new_ice_num,
            self.ice_text, self.ice_label)
        retrieve_ice_label = [[self.ice_label[i] for i in sublist] for sublist in raw_ice_idx_list]
        weight_ice_score_list = self.calculate_weighted_scores(retrieve_ice_label, new_weight, raw_ice_score_list)
        return self.sort_idx(raw_ice_idx_list, weight_ice_score_list, self.opt_new_ice_num, ice_num)

    def raw_retrieve(self, test_forward, ice_embed_list, ice_forward, base_index, ice_num, ice_text, ice_label) -> List[List]:
        """Retrieve candidate in-context examples for each query embedding.

        Args:
            test_forward: Query embeddings, as returned by :meth:`forward`.
            ice_embed_list: Stacked pool embeddings used to build ``base_index``.
            ice_forward: Pool embeddings, as returned by :meth:`forward`.
            base_index: faiss index over the pool.
            ice_num: Number of examples to retrieve per query.
            ice_text: Pool texts.
            ice_label: Pool labels.

        Returns:
            ``(idx_list, score_list)``: per query, the retrieved pool indices
            and their matching scores.
        """
        raise NotImplementedError("Method hasn't been implemented yet")

    @staticmethod
    def _class_weights(labels, alpha, beta=0.0):
        """Map each label to ``(1 - alpha) / (1 - alpha**count) + beta``.

        ``(1 - alpha**count) / (1 - alpha)`` is the "effective number" of
        samples for a class with ``count`` examples; the base weight is its
        inverse, so for ``0 < alpha < 1`` rarer classes get larger weights.
        The result is keyed by label (not by first-seen order), so
        ``weights[label]`` always selects the right class.
        """
        counts = Counter(labels)
        return {label: (1 - alpha) / (1 - alpha ** count) + beta for label, count in counts.items()}

    def opt_function(self, beta0):
        """Bayesian-optimisation objective: tuning-subset accuracy for offset ``beta0``.

        Re-weights the scores retrieved for the balanced subset, keeps the top
        ``opt_ice_num`` examples, and runs the inferencer on the balanced
        subset with them as in-context examples.

        Returns:
            Accuracy of the inferencer's predictions as a float.
        """
        weight = self._class_weights(self.ice_label, self.opt_alpha, beta0)
        raw_weight_ice_score_list = self.calculate_weighted_scores(self.opt_retrieve_ice_label, weight, self.opt_remain_ice_score_list)
        weight_ice_idx_list = self.sort_idx(self.opt_remain_ice_idx_list, raw_weight_ice_score_list, self.opt_new_ice_num, self.opt_ice_num)
        weight_ice = get_input(self.task, weight_ice_idx_list, self.template, self.template_dict, self.opt_remain_ds)
        test_predictions = self.inferencer.inference(
            task=self.task, ice=weight_ice, candidate=self.opt_balance_ds['text'].tolist(),
            labels=list(range(len(self.label))), ice_template=self.template_dict)
        acc = float(self.acc_evaluate.compute(predictions=test_predictions, references=self.opt_balance_ds['label'].tolist())['accuracy'])
        return acc

    def _build_index(self, res_list):
        """Build an inner-product faiss index from :meth:`forward` output.

        Returns:
            ``(embed_list, index)``: the stacked embeddings and a faiss
            ``IndexIDMap`` keyed by each item's ``metadata['id']``.
        """
        index = faiss.IndexIDMap(faiss.IndexFlatIP(self.model.get_sentence_embedding_dimension()))
        # faiss ids must be int64; a plain np.array of ints is int32 on some platforms.
        id_list = np.array([res['metadata']['id'] for res in res_list], dtype=np.int64)
        embed_list = np.stack([res['embed'] for res in res_list])
        index.add_with_ids(embed_list, id_list)
        return embed_list, index

    def create_base_index(self, dataloader):
        """Embed ``dataloader`` and build an inner-product faiss index over it."""
        return self._build_index(self.forward(dataloader))

    def forward(self, dataloader, process_bar=False, information=''):
        """Embed every item of ``dataloader`` with the SentenceTransformer.

        Each batch's ``input_ids`` are decoded back to text (special tokens
        skipped) and re-encoded by ``self.model``.

        Args:
            dataloader: DataLoader whose batches are dicts holding
                ``'input_ids'`` and ``'metadata'``.
            process_bar, information: Accepted for API compatibility;
                currently unused.

        Returns:
            A list of ``{'embed': vector, 'metadata': meta}`` dicts, in
            dataloader order.
        """
        res_list = []
        # NOTE(review): the loader (and with it its dataset/collator) is deep-copied
        # before iterating; kept as-is -- confirm whether this is still required.
        _dataloader = copy.deepcopy(dataloader)
        for entry in _dataloader:
            with torch.no_grad():
                metadata = entry.pop("metadata")
                raw_text = self.tokenizer.batch_decode(entry['input_ids'], skip_special_tokens=True, verbose=False)
                res = self.model.encode(raw_text, show_progress_bar=False)
            res_list.extend({"embed": r, "metadata": m} for r, m in zip(res, metadata))
        return res_list

    def create_forword_index(self, text):
        """Embed ``text`` once and build a faiss index from those embeddings.

        Returns:
            ``(embed_list, ice_forward, base_index)``: the stacked embeddings,
            the raw :meth:`forward` output and the faiss index.
        """
        ice_encode_dataset = DatasetEncoder(text, tokenizer=self.tokenizer)
        index_dataloader = DataLoader(ice_encode_dataset, batch_size=self.batch_size, collate_fn=self.co)
        ice_forward = self.forward(index_dataloader, process_bar=True, information="Embedding ice set...")
        # Reuse the embeddings computed above instead of running the model a second time.
        embed_list, base_index = self._build_index(ice_forward)
        return embed_list, ice_forward, base_index

    def calculate_weighted_scores(self, index_list, weight_list, score_list):
        """Multiply each score by the weight of its example's label.

        Args:
            index_list: Per query, the label of each retrieved example.
            weight_list: Mapping (dict or list) from label to weight.
            score_list: Per query, the score of each retrieved example.

        Returns:
            Per query, the list of ``score * weight_list[label]``.
        """
        return [
            [score * weight_list[label] for label, score in zip(labels, scores)]
            for labels, scores in zip(index_list, score_list)
        ]

    def sort_idx(self, ice_idx_list, ice_score_list, new_ice_num, ice_num):
        """Keep, per query, the ``ice_num`` indices with the highest score.

        Ties in score are broken by the larger index first. ``new_ice_num`` is
        accepted for API compatibility and not used.
        """
        return [
            [idx for _, idx in sorted(zip(scores, idxs), reverse=True)][:ice_num]
            for idxs, scores in zip(ice_idx_list, ice_score_list)
        ]
    