import torch

class Editor():
    """Answer multi-hop questions using a store of edited facts plus an LLM fallback.

    Edits are stored as three index-aligned collections: the edit strings, their
    embeddings (one row per edit) and the answer associated with each edit.
    A question is split into a chain of relations; each hop is resolved either
    from the edit store (embedding retrieval followed by cross-encoder
    reranking) or, when no edit matches, by prompting the LLM.
    """

    def __init__(self, llm, llm_tok, embedding_model, reranker, relation_extractor, relation_tokenizer, triple_compl, device='cuda:0' if torch.cuda.is_available() else 'cpu'):
        """Store the models and initialise an empty edit memory.

        Args:
            llm: Generative model; must support ``.to(device)`` and ``.generate``.
            llm_tok: Tokenizer paired with ``llm``.
            embedding_model: Object exposing ``encode`` for text embeddings.
            reranker: Object exposing ``rank(query, documents)``.
            relation_extractor: Token-classification model; must support ``.to(device)``.
            relation_tokenizer: Tokenizer paired with ``relation_extractor``.
            triple_compl: Prompt prefix used when asking the LLM to complete a triple.
            device: Device the models and embeddings are placed on.
        """
        self.llm = llm.to(device)
        self.llm_tok = llm_tok
        self.embedder = embedding_model
        self.ce = reranker
        self.extractor = relation_extractor.to(device)
        self.extractor_tok = relation_tokenizer
        # Stored as-is (the original trailing comma wrapped it in a 1-tuple).
        self.triple_compl = triple_compl
        self.device = device

        # Index-aligned edit memory; stays None until the first add_edits call.
        self.edit_strings = None
        self.edit_embeds = None
        self.edit_answers = None

        # Entity types that are skipped in favour of the preceding entity.
        self.bad_entity_types = ['DATE', 'PERCENT', 'MONEY', 'QUANTITY', 'ORDINAL', 'CARDINAL']

        # Running counters of prompt tokens and generated tokens.
        self.in_tok = 0
        self.out_tok = 0

    def add_edits(self, edits, answers):
        """Embed ``edits`` and append them, with their ``answers``, to the memory.

        Args:
            edits: Sequence of edit strings to embed.
            answers: Sequence of answers, one per edit (same order).

        Raises:
            ValueError: If ``edits`` and ``answers`` differ in length, which
                would silently misalign retrieval results.
        """
        if len(edits) != len(answers):
            raise ValueError(
                f'edits and answers must have the same length '
                f'(got {len(edits)} and {len(answers)})'
            )
        if len(edits) == 0:
            return

        embeddings = self.embedder.encode(edits)
        embeddings = torch.tensor(embeddings, device=self.device)

        if self.edit_embeds is None:
            self.edit_embeds = embeddings
            # Copy so later extend() calls never mutate the caller's lists.
            self.edit_strings = list(edits)
            self.edit_answers = list(answers)
        else:
            # Stack the new rows under the existing embedding matrix.
            self.edit_embeds = torch.vstack([self.edit_embeds, embeddings])
            self.edit_strings.extend(edits)
            self.edit_answers.extend(answers)

    def get_entity(self, text, ner_model):
        """Return the subject entity of ``text`` as found by ``ner_model``.

        The last detected mention is used, unless its coarse type is in
        ``self.bad_entity_types`` and an earlier mention exists, in which case
        the second-to-last mention is used. When the mention is linked to an
        entity with a Wikipedia title, that title is returned instead of the
        surface text. Returns ``''`` when no mention is found.
        """
        mentions = ner_model.process_text(text)
        if len(mentions) == 0:
            return ''

        if mentions[-1].coarse_mention_type in self.bad_entity_types and len(mentions) > 1:
            mention = mentions[-2]
        else:
            mention = mentions[-1]

        predicted = mention.predicted_entity
        if predicted is not None and predicted.wikipedia_entity_title is not None:
            return predicted.wikipedia_entity_title
        return mention.text

    def query_model(self, input_text, stop_strings=['\n'], num_new_tokens=50, temperature=0.001, max_context_tokens=2048):
        """Generate a continuation of ``input_text`` with the LLM.

        Args:
            input_text: Prompt to continue.
            stop_strings: Strings that end generation (passed to ``generate``;
                never mutated here).
            num_new_tokens: Maximum number of tokens to generate.
            temperature: Sampling temperature.
            max_context_tokens: Token budget for prompt plus generation.

        Returns:
            The generated text with the prompt removed, or ``''`` when the
            prompt leaves no room for ``num_new_tokens`` within the budget.
        """
        tokens = self.llm_tok(input_text, padding=False, truncation=False, return_tensors='pt').to(self.device)
        prompt_len = tokens.input_ids.shape[1]

        # Prompt tokens are counted even when generation is skipped below.
        self.in_tok += prompt_len

        if prompt_len > max_context_tokens - num_new_tokens - 1:
            return ''

        generated_tokens = self.llm.generate(tokens.input_ids,
                                             attention_mask=tokens.attention_mask,
                                             do_sample=True,
                                             temperature=temperature,
                                             max_new_tokens=num_new_tokens,
                                             stop_strings=stop_strings,
                                             tokenizer=self.llm_tok)

        self.out_tok += generated_tokens[0][prompt_len:].shape[0]

        # Decode the full sequence, then drop the prompt by character length.
        decoded = self.llm_tok.decode(generated_tokens[0], skip_special_tokens=True)
        return decoded[len(input_text):]

    def get_nearest_embedding(self, current_text, current_embed, comparison_embeds, answer_choices, tau=0.8, k=10, offset=0.1):
        """Look up the stored answer whose edit best matches the query.

        Candidates are the top-``k`` edits by cosine similarity (similarities
        not above ``tau - offset`` are zeroed first); they are then reranked by
        ``self.ce`` and the best one is accepted only if its score is at least
        ``tau``.

        Args:
            current_text: Query text handed to the reranker.
            current_embed: 1-D embedding of the query.
            comparison_embeds: Matrix of stored embeddings, one row per edit,
                or ``None`` when nothing has been stored.
            answer_choices: Answers indexed like the rows of ``comparison_embeds``.
            tau: Minimum reranker score for a match.
            k: Maximum number of candidates passed to the reranker.
            offset: Slack subtracted from ``tau`` for the embedding pre-filter.

        Returns:
            The matching answer, or ``None`` when no edit qualifies.
        """
        # Nothing stored yet: no edit can match.
        if comparison_embeds is None or comparison_embeds.shape[0] == 0:
            return None

        num_candidates = comparison_embeds.shape[0]
        current_embed = torch.vstack([current_embed.unsqueeze(0)] * num_candidates)

        similarity = torch.nn.functional.cosine_similarity(current_embed, comparison_embeds, dim=1)
        similarity = torch.where(similarity > tau - offset, similarity, 0.0)

        if torch.max(similarity, dim=0).values.item() < tau - offset:
            return None

        # topk raises if k exceeds the number of stored edits, so clamp it.
        top_idx = torch.topk(similarity, k=min(k, num_candidates), dim=0).indices.tolist()
        top_text = [self.edit_strings[idx] for idx in top_idx]

        reranked = self.ce.rank(current_text, top_text)
        top = max(reranked, key=lambda d: d['score'])
        if top['score'] < tau:
            return None
        # corpus_id indexes into top_text; map it back to the edit index.
        return answer_choices[top_idx[top['corpus_id']]]

    def answer_question(self, question, ner_model, tau=0.8, offset=0.1, k=10):
        """Answer ``question`` by following its relation chain hop by hop.

        The relation extractor labels each question token with a class id;
        tokens sharing a non-zero id form one relation, and relations are
        applied in ascending id order. Starting from the entity returned by
        ``get_entity``, each hop is answered from the edit memory when a match
        exists, otherwise by prompting the LLM to complete the triple.

        Args:
            question: The question text.
            ner_model: NER model forwarded to ``get_entity``.
            tau: Match threshold forwarded to ``get_nearest_embedding``.
            offset: Pre-filter slack forwarded to ``get_nearest_embedding``.
            k: Candidate count forwarded to ``get_nearest_embedding``.

        Returns:
            A tuple ``(answer, num_hops)`` where ``num_hops`` is the number of
            relations extracted from the question.
        """
        q_tok = self.extractor_tok(question, padding=False, truncation=False, return_tensors='pt').to(self.device)
        # Inference only; no gradients are needed for the argmax below.
        with torch.no_grad():
            q_out = self.extractor(q_tok.input_ids,
                                   token_type_ids=None,
                                   attention_mask=q_tok.attention_mask)
        q_preds = torch.argmax(q_out.logits, dim=-1)[0]

        # Group decoded tokens by predicted class id; id 0 is skipped.
        r_split = {}
        for idx, class_val in enumerate(q_preds):
            class_val = class_val.item()
            if class_val != 0:
                token_text = self.extractor_tok.decode(q_tok.input_ids[0, idx])
                r_split.setdefault(class_val, []).append(token_text)

        r_chain = []
        for key in sorted(r_split):
            r = ' '.join(r_split[key])
            if r.startswith(' '):
                r = r[1:]
            r_chain.append(r)

        s = self.get_entity(question, ner_model)

        for r in r_chain:
            sr_text = ' '.join([s, r]).lower()
            sr_embed = torch.tensor(self.embedder.encode(sr_text), device=self.device)
            o = self.get_nearest_embedding(sr_text, sr_embed, self.edit_embeds, self.edit_answers, tau=tau, offset=offset, k=k)

            if o is None:
                # No edit applies to this hop: fall back to the LLM.
                completion_prompt = f'{self.triple_compl}\n\nProvide the object for the following triple: | {s} | {r} |\nObject:'
                o = self.query_model(completion_prompt,
                                     stop_strings=['|'],
                                     num_new_tokens=50,
                                     temperature=0.0001)

                o = o.replace('\n', '')
                if o.startswith(' '):
                    o = o[1:]
                if o.endswith('|'):
                    o = o[:-1]

            # The object of this hop is the subject of the next one.
            s = o

        return s, len(r_chain)
