# -*- coding: utf-8 -*-
"""
Created on Fri Feb 14 00:57:18 2025

@author: baran
"""

import os
import json
import openai
#import langchain
from openai import AzureOpenAI
from base_key import *
#from langchain_openai import OpenAIEmbeddings
import numpy as np
#import spacy
from gensim.models.doc2vec import Doc2Vec, TaggedDocument

#from gensim.models.doc2vec import Doc2Vec, TaggedDocument
#from prompt_maker import input_maker

#from embeddings import GloveEmbedding, FastTextEmbedding, KazumaCharEmbedding, ConcatEmbedding
from sentence_transformers import SentenceTransformer
#import tensorflow_hub as hub
# from dotenv import load_dotenv

# load_dotenv()

# print("API Key:", os.getenv("OPENAI_API_KEY"))
# print("API Base:", os.getenv("OPENAI_API_BASE"))

# openai.api_type = "azure"
# openai.api_key=api_B
# openai.api_version="2024-02-01"
# openai.api_base = endpoint_B

# def get_emb(x):
#     client = AzureOpenAI(
#     api_key = api_B,  
#     api_version = "2024-02-01",
#     azure_endpoint = embed
#     )
#     text = x

#     response = client.embeddings.create(
#         input = text,
#         model= "text-embedding-ada-002"
#     )
#     return response.data[0].embedding[0:128]

# def get_emb(x):
#     model = SentenceTransformer('all-MiniLM-L6-v2')
#     sentence_embedding = model.encode(x)
#     #g =  GloveEmbedding(name='wikipedia_gigaword', d_emb=300)
#     #model = hub.load('https://tfhub.dev/google/universal-sentence-encoder/4')
#     #embedding = model([x])
#     #nlp = spacy.load("en_core_web_md")
#     #doc = nlp(x)
#     return sentence_embedding

# def get_emb(documents,t,i,j,sum_descript_len,diag_descript_len,input_rep_len,window_size,emb_dim,epoch,inp_model):
#     # input_reports = list(input_reports)
#     #print(t)
#     #print(i)
    
#     if (i==1 and j==0):
#         tagged_data = [TaggedDocument(words=doc.split(), tags=[str(i)]) for i, doc in enumerate(documents)]
#         model = Doc2Vec(vector_size=emb_dim,  # Dimensionality of the feature vectors
#                     window=window_size,         # Context window size
#                     min_count=2,      # Ignores words with total frequency lower than this
#                     workers=4,        # Number of worker threads for training
#                     epochs=epoch)        # Number of training epochs
    
#         # Build vocabulary from tagged data
#         model.build_vocab(tagged_data)
    
#         # Train the model
#         model.train(tagged_data, total_examples=model.corpus_count, epochs=model.epochs)
#     else:
#         model=inp_model
#     #doc_id = 0  # Index of the document
    # if i == 0:
    #     doc_id = j
    #     rep_id = t+sum_descript_len+diag_descript_len
    # elif i==1:
    #     doc_id = j+sum_descript_len
    #     rep_id = t+input_rep_len+sum_descript_len+diag_descript_len
#     # elif word == "description":
#     #     doc_id = t+4
#     doc_vec = model.dv[doc_id]
#     rep_vec = model.dv[rep_id]  # 'dv' refers to the document vectors
#     #print(f"Embedding for Document {doc_id}:", vector)
#     return model,doc_vec,rep_vec

def get_emb(documents,t,i,j,sum_descript_len,diag_descript_len,input_rep_len,model,dataset):
    """Encode one description document and one report document.

    Two positions in ``documents`` are selected by index arithmetic and each
    selected entry is passed to ``model.encode`` with
    ``normalize_embeddings=False``.

    Index selection:

    * ``dataset == "medical"``, ``i == 0``:
      ``doc_id = j`` and
      ``rep_id = t + sum_descript_len + diag_descript_len``.
    * ``dataset == "medical"``, ``i == 1``:
      ``doc_id = j + sum_descript_len`` and
      ``rep_id = t + input_rep_len + sum_descript_len + diag_descript_len``.
    * ``dataset == "telecom"``:
      ``doc_id = j`` and ``rep_id = t + diag_descript_len``
      (``i``, ``sum_descript_len`` and ``input_rep_len`` are not used).

    NOTE(review): the arithmetic suggests ``documents`` is laid out as
    [summary descriptions, diagnosis descriptions, reports for task 0,
    reports for task 1] -- confirm against the caller that builds it.

    Args:
        documents: Indexable collection of texts to encode.
        t: Offset of the report within its group.
        i: Task selector; must be 0 or 1 for the "medical" dataset.
        j: Offset of the description within its group.
        sum_descript_len: Offset term used by the "medical" branch.
        diag_descript_len: Offset term used by both datasets.
        input_rep_len: Offset term used by the "medical" branch when
            ``i == 1``.
        model: Object exposing ``encode(text, normalize_embeddings=...)``.
        dataset: Either ``"medical"`` or ``"telecom"``.

    Returns:
        Tuple ``(emb_doc, emb_rep)``: the encoding of ``documents[doc_id]``
        followed by the encoding of ``documents[rep_id]``.

    Raises:
        ValueError: If ``dataset`` is not supported, or if ``i`` is not 0 or
            1 for the "medical" dataset.
    """
    if dataset == "medical":
        if i == 0:
            doc_id = j
            rep_id = t + sum_descript_len + diag_descript_len
        elif i == 1:
            doc_id = j + sum_descript_len
            rep_id = t + input_rep_len + sum_descript_len + diag_descript_len
        else:
            # Previously this fell through and failed later with an
            # UnboundLocalError on doc_id; fail early with a clear message.
            raise ValueError(
                f"i must be 0 or 1 for the 'medical' dataset, got {i!r}"
            )
    elif dataset == "telecom":
        doc_id = j
        rep_id = t + diag_descript_len
    else:
        # Previously an unknown dataset silently returned None, which only
        # surfaced as a TypeError when the caller unpacked the result.
        raise ValueError(
            f"unsupported dataset {dataset!r}; expected 'medical' or 'telecom'"
        )

    # Both datasets encode the selected texts the same way (un-normalized).
    emb_doc = model.encode(documents[doc_id], normalize_embeddings=False)
    emb_rep = model.encode(documents[rep_id], normalize_embeddings=False)
    return emb_doc, emb_rep

        

def get_context(documents,t,i,j,sum_descript_len,diag_descript_len,input_rep_len,inp_model,dataset):
    """Build a context vector from a description/report embedding pair.

    All arguments are forwarded unchanged to ``get_emb`` (``inp_model`` is
    passed as its ``model`` argument). The two embeddings it returns are each
    converted to ``float64`` NumPy arrays and multiplied element-wise.

    Returns:
        ``numpy.ndarray`` of dtype ``float64`` holding the element-wise
        product of the report embedding and the description embedding.
    """
    doc_embedding, rep_embedding = get_emb(
        documents,
        t,
        i,
        j,
        sum_descript_len,
        diag_descript_len,
        input_rep_len,
        inp_model,
        dataset,
    )

    # Promote to float64 before combining the two vectors.
    rep_vector = np.array(rep_embedding).astype(np.float64)
    doc_vector = np.array(doc_embedding).astype(np.float64)

    # Element-wise product of the two embeddings.
    return rep_vector * doc_vector