import csv
import pandas as pd
import numpy as np
from datasets import load_dataset, load_metric
import torch
from torch.utils.data import DataLoader, Dataset
from collections import Counter
import random
from sklearn.metrics.pairwise import cosine_similarity
from scipy.stats import entropy
from scipy.special import softmax
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans

import plotly.express as px
import plotly.io as pio
from plotly.graph_objs import Scatter


#define dataset class
class SDDataset(Dataset):
    """Torch dataset that pairs tokenizer encodings with their labels.

    ``encodings`` is a mapping from field name to an indexable sequence and
    ``labels`` is an indexable sequence. Sample ``idx`` is assembled by taking
    element ``idx`` of every encoding field and of ``labels`` and converting
    each of them to a ``torch.tensor``.
    """

    def __init__(self, encodings,labels):
        self.labels = labels
        self.encodings = encodings

    def __len__(self):
        # The dataset size is defined by the label sequence alone.
        return len(self.labels)

    def __getitem__(self, idx):
        sample = {}
        for field, values in self.encodings.items():
            sample[field] = torch.tensor(values[idx])
        # The label is stored under the fixed key 'labels' next to the fields.
        sample['labels'] = torch.tensor(self.labels[idx])
        return sample
    

def filter_language(x_stance_data, language):
    """Keep only the rows of ``x_stance_data`` written in ``language``.

    ``x_stance_data`` must provide the parallel columns "language",
    "question", "comment" and "label". The result is a dict of lists with the
    keys "question", "comment" and "label", in the original row order.
    """
    selected = {"question": [], "comment": [], "label": []}
    rows = zip(
        x_stance_data["language"],
        x_stance_data["question"],
        x_stance_data["comment"],
        x_stance_data["label"],
    )
    for row_language, *fields in rows:
        if row_language != language:
            continue
        # `fields` is (question, comment, label), matching the key order above.
        for column, value in zip(selected, fields):
            selected[column].append(value)
    return selected

def filter_comments(x_stance_data, language):
    """Return the comments and labels of the rows written in ``language``.

    Same filtering as a language filter over the "language" column, but the
    "question" column is not read and not returned: the result only has the
    keys "comment" and "label" (lists in original row order).
    """
    matching = [
        (comment, label)
        for lang, comment, label in zip(
            x_stance_data["language"],
            x_stance_data["comment"],
            x_stance_data["label"],
        )
        if lang == language
    ]
    return {
        "comment": [comment for comment, _ in matching],
        "label": [label for _, label in matching],
    }

# compute metrics function
def compute_metrics(eval_preds):
    """Score predictions with the "glue"/"mrpc" metric.

    ``eval_preds`` unpacks into ``(logits, labels)``. The predicted class of
    each row is the argmax over the last axis of ``logits``; the metric's
    ``compute`` result is returned unchanged.
    """
    # The metric is loaded on every call, before the predictions are unpacked.
    glue_mrpc = load_metric("glue","mrpc")
    logits, labels = eval_preds
    predicted_classes = np.argmax(logits, axis=-1)
    return glue_mrpc.compute(predictions=predicted_classes, references=labels)


def _read_qa_columns(csv_path):
    """Return (questions, answers): column 0 and column 2 of every row in ``csv_path``."""
    questions = []
    answers = []
    # newline='' is what the csv module asks for: without it, line breaks
    # embedded in quoted fields are rewritten by universal-newline handling.
    # NOTE(review): the text encoding is still the platform default, as before.
    with open(csv_path, 'r', newline='') as file:
        for row in csv.reader(file, delimiter=','):
            if not row:
                # csv.reader yields [] for a blank line (e.g. a trailing one);
                # indexing it would raise IndexError.
                continue
            questions.append(row[0])
            answers.append(row[2])
    return questions, answers


def read_data(path):
    """Load one pair of synthetic stance files.

    Reads ``path + '_FAVOR.csv'`` and ``path + '_AGAINST.csv'``. In both
    files column 0 is used as the question and column 2 as the answer; every
    row of the first file is labelled "FAVOR" and every row of the second
    "AGAINST". Blank lines are skipped.

    Returns:
        Six lists: ``(pos_question, pos_answer, pos_label,
        neg_question, neg_answer, neg_label)``.

    Raises:
        OSError: if either file cannot be opened.
        IndexError: if a non-blank row has fewer than three columns.
    """
    pos_question, pos_answer = _read_qa_columns(path + '_FAVOR.csv')
    neg_question, neg_answer = _read_qa_columns(path + '_AGAINST.csv')

    pos_label = ["FAVOR"] * len(pos_question)
    neg_label = ["AGAINST"] * len(neg_question)

    return pos_question, pos_answer, pos_label, neg_question, neg_answer, neg_label


def read_all_synth_data(path, num_synth_files):
    """Load and concatenate the numbered synthetic stance files.

    For every ``i`` in ``1..num_synth_files`` the files
    ``path + str(i) + '_FAVOR.csv'`` and ``path + str(i) + '_AGAINST.csv'``
    are read, in that order. Column 0 of each row is the question and
    column 2 the answer; rows are labelled after the file they came from.
    Blank lines are skipped.

    Returns:
        Six lists: ``(pos_question, pos_answer, pos_label,
        neg_question, neg_answer, neg_label)``.

    Raises:
        OSError: if any of the files cannot be opened.
        IndexError: if a non-blank row has fewer than three columns.
    """
    # stance -> (questions, answers); insertion order keeps FAVOR before AGAINST.
    collected = {"FAVOR": ([], []), "AGAINST": ([], [])}

    for file_number in range(1, num_synth_files + 1):
        for stance, (questions, answers) in collected.items():
            csv_path = path + str(file_number) + '_' + stance + '.csv'
            # newline='' is what the csv module asks for: without it, line
            # breaks embedded in quoted fields are rewritten on read.
            # NOTE(review): the encoding is still the platform default.
            with open(csv_path, 'r', newline='') as file:
                for row in csv.reader(file, delimiter=','):
                    if not row:
                        # Blank lines come back as [] and cannot be indexed.
                        continue
                    questions.append(row[0])
                    answers.append(row[2])

    pos_question, pos_answer = collected["FAVOR"]
    neg_question, neg_answer = collected["AGAINST"]
    pos_label = ["FAVOR"] * len(pos_question)
    neg_label = ["AGAINST"] * len(neg_question)

    return pos_question, pos_answer, pos_label, neg_question, neg_answer, neg_label

def read_data_oracle(ref_question, data):
    """Split the rows of ``data`` that ask ``ref_question`` by stance.

    ``data`` provides the parallel columns "question", "comment" and
    "label". Rows whose question equals ``ref_question`` go to the positive
    lists when labelled "FAVOR" and to the negative lists when labelled
    "AGAINST"; rows with any other label are dropped.

    Returns ``(pos_question, pos_answer, pos_label,
    neg_question, neg_answer, neg_label)``.
    """
    favor_q, favor_c, favor_l = [], [], []
    against_q, against_c, against_l = [], [], []

    for row, asked in enumerate(data["question"]):
        if ref_question != asked:
            continue
        stance = data["label"][row]
        if stance == "FAVOR":
            favor_q.append(asked)
            favor_c.append(data["comment"][row])
            favor_l.append("FAVOR")
        elif stance == "AGAINST":
            against_q.append(asked)
            against_c.append(data["comment"][row])
            against_l.append("AGAINST")

    return favor_q, favor_c, favor_l, against_q, against_c, against_l


def get_data_for_question(ref_question, data):
    """Collect every row of ``data`` whose question equals ``ref_question``.

    Returns four parallel lists: the matching questions, comments, labels
    and the row positions they were found at in ``data``.
    """
    # Find the matching row positions first, then gather each column.
    ti = [
        position
        for position, asked in enumerate(data["question"])
        if ref_question == asked
    ]
    tq = [data["question"][position] for position in ti]
    tc = [data["comment"][position] for position in ti]
    tl = [data["label"][position] for position in ti]

    return tq, tc, tl, ti

def get_data_balance(question, answer, label, split):
    """Draw a training subset with a controlled FAVOR/AGAINST balance.

    ``question``, ``answer`` and ``label`` are parallel sequences. A random
    ``split`` fraction of the "AGAINST" rows and a random ``1 - split``
    fraction of the "FAVOR" rows are kept (both counts truncated with
    ``int``); rows with any other label are never selected. The selection is
    then shuffled.

    The held-out test portion is currently disabled (zero rows are sampled
    per class), so the three test lists are always empty.

    Side effects: prints several debug lines, consumes the global ``random``
    state, and re-seeds the global NumPy RNG with 42 before shuffling.

    Returns:
        ``(train_question, train_answer, train_label,
        test_question, test_answer, test_label)`` as plain lists.
    """
    counts = Counter(label)
    print(counts)

    question = np.array(question)
    answer = np.array(answer)
    label = np.array(label)

    pos_indeces = []
    neg_indeces = []
    for i in range(len(question)):
        if label[i] == "FAVOR":
            pos_indeces.append(i)
        elif label[i] == "AGAINST":
            neg_indeces.append(i)

    labelled_total = len(pos_indeces) + len(neg_indeces)
    if labelled_total:
        print(len(pos_indeces) / labelled_total)
        print(len(neg_indeces) / labelled_total)
    else:
        # No FAVOR/AGAINST rows at all (e.g. empty input): report empty
        # class shares instead of failing with ZeroDivisionError.
        print(0.0)
        print(0.0)
    print(len(pos_indeces))
    print(len(neg_indeces))

    # Test sampling is switched off: a sample of size 0 is always [].
    test_pos_indeces = random.sample(pos_indeces, 0)
    test_neg_indeces = random.sample(neg_indeces, 0)
    # Kept as set differences so the population order handed to
    # random.sample (and hence the draw for a given seed) is unchanged.
    remaining_neg_indeces = list(set(neg_indeces) - set(test_neg_indeces))
    remaining_pos_indeces = list(set(pos_indeces) - set(test_pos_indeces))
    # `split` is the share of negatives kept; positives get the complement.
    train_neg_indeces = random.sample(remaining_neg_indeces, int(len(remaining_neg_indeces) * split))
    train_pos_indeces = random.sample(remaining_pos_indeces, int(len(remaining_pos_indeces) * (1-split)))

    print("LEN_POS: ", len(train_pos_indeces))
    print("LEN_NEG: ", len(train_neg_indeces))
    # Each label now reports its own class (they used to be swapped).
    print("LEN_REM_POS: ", len(test_pos_indeces))
    print("LEN_REM_NEG: ", len(test_neg_indeces))

    train_set_indeces = train_pos_indeces + train_neg_indeces
    test_set_indeces = test_pos_indeces + test_neg_indeces

    # Fixed seed: the shuffle order is reproducible, the selection is not.
    np.random.seed(42)
    np.random.shuffle(train_set_indeces)
    np.random.shuffle(test_set_indeces)

    train_question = question[train_set_indeces]
    train_answer = answer[train_set_indeces]
    train_label = label[train_set_indeces]
    print(train_set_indeces)

    test_question = question[test_set_indeces]
    test_answer = answer[test_set_indeces]
    test_label = label[test_set_indeces]
    print("COUNTS_TEST: ", Counter(test_label))

    return train_question.tolist(), train_answer.tolist(), train_label.tolist(), test_question.tolist(), test_answer.tolist(), test_label.tolist()


def random_synthetic_sample(i, batch_size, tl_batch, pos, neg, M=200):
    """Return the first ``int(M/2)`` positive and negative synthetic items in random order.

    ``pos`` and ``neg`` are indexable as ``(questions, answers, labels)``.
    Two independent random orderings of ``range(int(M/2))`` are drawn, the
    negative one first, and used to pick the items.

    NOTE(review): only the first ``int(M/2)`` entries of each pool can ever
    be returned, so this shuffles a fixed prefix rather than sampling the
    whole pool. Confirm that this is intended.

    ``i`` and ``batch_size`` are not used. ``tl_batch`` is only used for a
    debug print of its label counts.

    Returns ``(pos_q, pos_c, pos_l, neg_q, neg_c, neg_l)``.
    """
    pos_question, pos_answer, pos_label = pos[0], pos[1], pos[2]
    neg_question, neg_answer, neg_label = neg[0], neg[1], neg[2]

    print(Counter(tl_batch))

    half = int(M/2)
    candidates = list(range(half))
    # Draw order matters for reproducibility under a seed: negatives first.
    neg_order = random.sample(candidates, half)
    pos_order = random.sample(candidates, half)
    print(neg_order)

    neg_q = [neg_question[pick] for pick in neg_order]
    neg_c = [neg_answer[pick] for pick in neg_order]
    neg_l = [neg_label[pick] for pick in neg_order]

    pos_q = [pos_question[pick] for pick in pos_order]
    pos_c = [pos_answer[pick] for pick in pos_order]
    pos_l = [pos_label[pick] for pick in pos_order]

    return pos_q, pos_c, pos_l, neg_q, neg_c, neg_l


def get_random_synthetic_label(model, tokenizer, tq_batch, tc_batch, tl_batch, threshold, r):
    """Pick ``threshold`` rows of the batch uniformly at random.

    ``r`` supplies the randomness through its ``sample`` method. ``model``
    and ``tokenizer`` are accepted but not used.

    Returns the chosen questions, comments and labels (in draw order)
    followed by ``threshold``.
    """
    chosen = r.sample(range(len(tq_batch)), threshold)
    print("INDICES: ", chosen)

    manual_question = []
    manual_comment = []
    manual_label = []
    for position in chosen:
        manual_question.append(tq_batch[position])
        manual_comment.append(tc_batch[position])
        manual_label.append(tl_batch[position])

    return manual_question, manual_comment, manual_label, threshold


def get_embedding(model, tokenizer, question, answer):
    """Embed a question/answer pair with the model's last hidden layer.

    The pair is tokenized together (padded/truncated to 512 tokens), run
    through ``model`` with ``output_hidden_states=True``, and the vector at
    sequence position 0 of the last hidden state is returned.

    Returns:
        A NumPy array holding ``hidden_states[-1][:, 0, :]``, i.e. one row
        per input sequence in the batch.
    """
    # Send the inputs to wherever the model lives instead of assuming a GPU.
    # "cuda:0" is kept as the fallback for objects that expose no parameters.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device("cuda:0")

    input_ids = tokenizer(question, answer, return_tensors='pt',
                          max_length=512, padding='max_length', truncation=True).to(device)
    # Pure inference: skip building the autograd graph, which only costs
    # memory since the result is detached right away.
    # NOTE(review): the model's train/eval mode is left as the caller set it.
    with torch.no_grad():
        outputs = model(**input_ids, output_hidden_states=True)
    last_hidden_states = outputs.hidden_states[-1][:, 0, :].detach().cpu().numpy()
    return last_hidden_states


def get_output(model, tokenizer, question, answer):
    """Return the model's raw logits for a question/answer pair.

    The pair is tokenized together (padded/truncated to 512 tokens) and run
    through ``model``; ``outputs.logits`` is returned as a NumPy array with
    no softmax applied.
    """
    # Send the inputs to wherever the model lives instead of assuming a GPU.
    # "cuda:0" is kept as the fallback for objects that expose no parameters.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device("cuda:0")

    input_ids = tokenizer(question, answer, return_tensors='pt',
                          max_length=512, padding='max_length', truncation=True).to(device)
    # Pure inference: skip building the autograd graph, which only costs
    # memory since the result is detached right away.
    # NOTE(review): the model's train/eval mode is left as the caller set it.
    with torch.no_grad():
        outputs = model(**input_ids)
    logits = outputs.logits.detach().cpu().numpy()
    return logits


def get_synthetic_label(model, tokenizer, tq_batch, tc_batch, tl_batch, pos, neg, threshold, M=200):
    """Select the batch rows whose nearest synthetic neighbours disagree most.

    Steps:
      1. Randomly draw ``int(M/2)`` positive and ``int(M/2)`` negative
         synthetic question/answer pairs from ``pos`` / ``neg`` (each
         indexable as ``(questions, answers, ...)``) and embed them.
      2. Embed every batch row and rank the synthetic pairs by cosine
         similarity to it.
      3. Among the ``int(M/2)`` most similar synthetic pairs, count how many
         are positive (the "vote sum").
      4. Return the ``threshold`` rows whose vote sum is closest to
         ``int(M/4)``, i.e. whose neighbourhood is closest to an even split.

    Prints several debug lines along the way.

    Returns:
        ``(manual_question, manual_comment, manual_label,
        manual_label_count)``: the selected rows as lists plus how many were
        selected.

    Raises:
        ValueError: from ``random.sample`` if a synthetic pool holds fewer
            than ``int(M/2)`` items.
    """
    # Get synthetic data
    pos_question = pos[0]
    pos_answer = pos[1]

    neg_question = neg[0]
    neg_answer = neg[1]

    print("NEG_LENGTH: ", len(neg_question))

    half = int(M/2)          # synthetic samples per class / neighbourhood size
    balanced_votes = int(M/4)  # vote sum of a perfectly split neighbourhood

    random_indices_pos = random.sample(range(len(pos_question)), half)
    random_indices_neg = random.sample(range(len(neg_question)), half)

    # Embed one pair at a time; the original note says a batched pass does
    # not fit on the GPU. Positive and negative calls stay interleaved so
    # the sequence of model calls is unchanged.
    pos_synthetics_embs = []
    neg_synthetics_embs = []
    for pos_idx, neg_idx in zip(random_indices_pos, random_indices_neg):
        pos_synthetics_embs.append(get_embedding(model, tokenizer, pos_question[pos_idx], pos_answer[pos_idx]))
        neg_synthetics_embs.append(get_embedding(model, tokenizer, neg_question[neg_idx], neg_answer[neg_idx]))

    # Positives occupy rows [0, half), negatives rows [half, 2 * half).
    # Each embedding carries a singleton axis 1, dropped once here.
    synthetics_embs = np.stack(pos_synthetics_embs + neg_synthetics_embs).squeeze(1)

    # Get all similarities, again one batch row at a time.
    all_similarities = []
    for i in range(len(tq_batch)):
        embedding = get_embedding(model, tokenizer, tq_batch[i], tc_batch[i])
        all_similarities.append(cosine_similarity(embedding, synthetics_embs)[0])

    all_similarities = np.stack(all_similarities)  # (len(tq_batch), 2 * half)

    # Indices of the `half` most similar synthetic samples for every row.
    top_k = (-all_similarities).argsort(axis=1)[:, :half]

    # A neighbour index below `half` is a positive synthetic sample.
    vote_sum = np.sum(top_k < half, axis=1)

    # DEBUG
    print(min(vote_sum))
    print(max(vote_sum))
    print(min(vote_sum) + threshold)
    print(max(vote_sum) - threshold)
    print(Counter(vote_sum))

    # Distance from an even split; smaller means more ambiguous.
    vote_distance = abs(vote_sum - balanced_votes)
    sorted_vote_sum = np.argsort(vote_distance)

    # Get first `threshold` samples from the sorted order.
    selected_indices = sorted_vote_sum[:threshold]
    print("LEN SELECTED INDICES: ", len(selected_indices))

    # These debug lines now follow M instead of assuming M == 200.
    print("VOTE_SUM: ", vote_distance)
    print("VOTE_SUM SORTED: ", sorted(vote_distance))
    print("INDICES: ", sorted_vote_sum)

    manual_question = np.array(tq_batch)[selected_indices]
    manual_comment = np.array(tc_batch)[selected_indices]
    manual_label = np.array(tl_batch)[selected_indices]

    manual_label_count = len(manual_label)
    print("MANUAL_LABEL_COUNT: ", manual_label_count)

    return manual_question.tolist(), manual_comment.tolist(), manual_label.tolist(), manual_label_count


def kl_divergence(p, q):
    """Return the Kullback-Leibler divergence ``sum(p * log(p / q))``.

    Entries where ``p`` is zero contribute nothing. They are excluded before
    the logarithm is taken, so they no longer trigger spurious
    "divide by zero" / "invalid value" warnings. A zero in ``q`` where ``p``
    is non-zero still yields ``inf`` (with NumPy's usual warning).

    ``p`` and ``q`` may be arrays or array-likes of broadcastable shape; no
    normalisation is applied to them.
    """
    p, q = np.broadcast_arrays(np.asarray(p), np.asarray(q))
    support = p != 0
    p_support = p[support]
    return np.sum(p_support * np.log(p_support / q[support]))


def get_synthetic_label_k_cal(model, tokenizer, tq_batch, tc_batch, tl_batch, pos, neg, threshold, M=200, k=10):
    """Select the batch rows whose prediction diverges most from their synthetic neighbours.

    Steps:
      1. Randomly draw ``int(M/2)`` positive and ``int(M/2)`` negative
         synthetic question/answer pairs from ``pos`` / ``neg`` (each
         indexable as ``(questions, answers, ...)``) and compute both their
         embeddings and their logits.
      2. For every batch row, find its ``k`` most cosine-similar synthetic
         pairs in embedding space.
      3. Score the row by the mean KL divergence between the softmaxed
         logits of each of those neighbours and the row's own softmaxed
         logits (neighbour distribution first).
      4. Return the ``threshold`` highest-scoring rows.

    Returns:
        ``(manual_question, manual_comment, manual_label, threshold)``. The
        last element is the requested ``threshold``, not the number of rows
        actually returned.

    Raises:
        ValueError: from ``random.sample`` if a synthetic pool holds fewer
            than ``int(M/2)`` items.
    """
    # Get synthetic data
    pos_question = pos[0]
    pos_answer = pos[1]

    neg_question = neg[0]
    neg_answer = neg[1]

    print("NEG_LENGTH: ", len(neg_question))

    half = int(M/2)

    random_indices_pos = random.sample(range(len(pos_question)), half)
    random_indices_neg = random.sample(range(len(neg_question)), half)
    pos_question_sampled = [pos_question[i] for i in random_indices_pos]
    pos_answer_sampled = [pos_answer[i] for i in random_indices_pos]
    neg_question_sampled = [neg_question[i] for i in random_indices_neg]
    neg_answer_sampled = [neg_answer[i] for i in random_indices_neg]

    # One pair per model call; the original note says a batched pass does
    # not fit on the GPU. The call order (all embeddings, then all logits,
    # positive/negative interleaved) is kept as it was.
    pos_synthetics_embs = []
    neg_synthetics_embs = []
    for i in range(half):
        pos_synthetics_embs.append(get_embedding(model, tokenizer, pos_question_sampled[i], pos_answer_sampled[i]))
        neg_synthetics_embs.append(get_embedding(model, tokenizer, neg_question_sampled[i], neg_answer_sampled[i]))

    pos_synthetics_outputs = []
    neg_synthetics_outputs = []
    for i in range(half):
        pos_synthetics_outputs.append(get_output(model, tokenizer, pos_question_sampled[i], pos_answer_sampled[i]))
        neg_synthetics_outputs.append(get_output(model, tokenizer, neg_question_sampled[i], neg_answer_sampled[i]))

    # Each embedding carries a singleton axis 1, dropped once here.
    synthetics_embs = np.stack(pos_synthetics_embs + neg_synthetics_embs).squeeze(1)
    synthetic_outs = np.stack(pos_synthetics_outputs + neg_synthetics_outputs)

    # Get all similarities, again one batch row at a time.
    all_similarities = []
    for i in range(len(tq_batch)):
        embedding = get_embedding(model, tokenizer, tq_batch[i], tc_batch[i])
        all_similarities.append(cosine_similarity(embedding, synthetics_embs)[0])

    all_similarities = np.stack(all_similarities)  # (len(tq_batch), 2 * half)

    # Indices of the k most similar synthetic samples for every row.
    top_k = (-all_similarities).argsort(axis=1)[:, :k]

    scores = []
    for i in range(len(tq_batch)):
        neighbour_logits = synthetic_outs[top_k[i]]
        # The row's own distribution is the same for all of its neighbours.
        sample_probs = softmax(get_output(model, tokenizer, tq_batch[i], tc_batch[i])[0])

        # The loop variable no longer reuses (and clobbers) the parameter `k`.
        divergence_total = 0
        for neighbour in neighbour_logits:
            divergence_total += kl_divergence(softmax(neighbour[0]), sample_probs)

        scores.append(divergence_total / len(neighbour_logits))
    scores = np.array(scores)

    # Highest mean divergence first.
    sorted_indices = scores.argsort()[::-1]
    selected_indices = sorted_indices[:threshold]

    manual_question = np.array(tq_batch)[selected_indices]
    manual_comment = np.array(tc_batch)[selected_indices]
    manual_label = np.array(tl_batch)[selected_indices]

    return manual_question.tolist(), manual_comment.tolist(), manual_label.tolist(), threshold


# Flag order in every tuple:
# (train, random_synth, augmented, augmented_no_test, embedding, ambiguous_samples)
_CONFIG_BY_TEST_TYPE = {
    "random":         (True,  True,  False, False, False, True),
    "random+synth":   (True,  True,  True,  False, False, True),
    "cal":            (True,  False, False, False, False, True),
    "cal+synth":      (True,  False, True,  False, False, True),
    "al":             (True,  False, False, False, False, True),
    "al+synth":       (True,  False, True,  False, False, True),
    "baseline":       (False, False, False, False, False, False),
    "baseline+synth": (True,  False, False, True,  False, False),
    "oracle":         (True,  False, False, False, False, False),
    "oracle+synth":   (True,  False, True,  False, False, False),
}


def get_config(test_type):
    """Return the experiment flags for a named test type.

    Args:
        test_type: one of "random", "random+synth", "cal", "cal+synth",
            "al", "al+synth", "baseline", "baseline+synth", "oracle",
            "oracle+synth".

    Returns:
        The tuple ``(train, random_synth, augmented, augmented_no_test,
        embedding, ambiguous_samples)`` of booleans.

    Raises:
        ValueError: if ``test_type`` is not one of the known names.
            Previously this surfaced as an UnboundLocalError at the return
            statement.
    """
    try:
        return _CONFIG_BY_TEST_TYPE[test_type]
    except KeyError:
        raise ValueError(
            "Unknown test_type {!r}; expected one of: {}".format(
                test_type, ", ".join(_CONFIG_BY_TEST_TYPE)
            )
        ) from None