# coding=utf-8
import csv
import os
import numpy as np
import random
import argparse

#os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import torch
from datasets import load_dataset, load_metric
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    AutoConfig,
    AdamW,
    set_seed,
    AutoTokenizer,
    DataCollatorWithPadding,
    EvalPrediction,        
    TrainingArguments, 
    Trainer)

from sklearn import preprocessing
from collections import Counter
from sklearn.model_selection import train_test_split

from utils import (filter_language, 
                compute_metrics, 
                read_data,
                get_data_for_question, 
                random_synthetic_sample, 
                get_synthetic_label,
                get_synthetic_label_k_cal,
                get_random_synthetic_label,
                get_config,
                SDDataset)

# Report whether a CUDA device is visible; models are later moved to the GPU
# with .cuda(), so False here means the training loop will fail.
print(torch.cuda.is_available())

# Command-line options for the experiment (registered in this order, which is
# also the order shown by --help).
parser = argparse.ArgumentParser()
for flag, value_type, default_value, help_text in (
    ('--synth_size', int, 200, 'The synthetic size'),
    ('--query_strategy', str, 'al+synth', 'The query strategy'),
    ('--num_runs', int, 3, 'The number of runs'),
):
    parser.add_argument(flag, type=value_type, default=default_value, help=help_text)
args = parser.parse_args()

# Location of the checkpoint of the model pre-trained on stance detection.
checkpoint = "output/checkpoint-5000"

# Tokenizer and padding collator, both built from that checkpoint and shared
# by every training run below.
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# TRAIN STANCE DETECTION AGENT
# Load X-Stance and filter each split to language code "de" via
# utils.filter_language (train, test, validation — in that order).
x_stance = load_dataset("x_stance")
x_stance_train, x_stance_test, x_stance_val = (
    filter_language(x_stance[split_name], "de")
    for split_name in ("train", "test", "validation")
)

# Not referenced in the visible code; kept because it is a module-level name.
x_stance_test_single = {"question": [], "comment": [], "label": []}

# The ten reference questions evaluated one by one in the experiment loop.
# NOTE(review): these strings are looked up in the dataset by
# get_data_for_question, so they presumably must match the dataset text
# verbatim (including the "eidenössische" spelling) — do not "correct" them
# without checking the data.
ref_questions = [
    "Sollen sich die Versicherten stärker an den Gesundheitskosten beteiligen (z.B. Erhöhung der Mindestfranchise)?",
    "Befürworten Sie ein generelles Werbeverbot für Alkohol und Tabak?",
    "Soll eine Impfpflicht für Kinder gemäss dem schweizerischen Impfplan eingeführt werden?",
    "Soll die Aufenthaltserlaubnis für Migrant/innen aus Nicht-EU/EFTA-Staaten schweizweit an die Erfüllung verbindlicher Integrationsvereinbarungen geknüpft werden?",
    "Soll der Bund erneuerbare Energien stärker fördern?",
    "Befürworten Sie eine strengere Kontrolle der Lohngleichheit von Frauen und Männern?",
    "Soll der Bund die finanzielle Unterstützung für die berufliche Weiterbildung und Umschulung ausbauen?",
    "Sollen in der Schweiz vermehrt Spitäler geschlossen werden, um die Kosten im Gesundheitsbereich zu senken?",
    "Eine eidenössische Volksinitiative verlangt, dass der Bundesrat direkt vom Volk gewählt werden soll. Unterstützen Sie dieses Anliegen?",
    "Soll die Einführung der elektronischen Stimmabgabe bei Wahlen und Abstimmungen (E-Voting) weiter vorangetrieben werden?",
]
(ref_question1, ref_question2, ref_question3, ref_question4, ref_question5,
 ref_question6, ref_question7, ref_question8, ref_question9, ref_question10) = ref_questions

# Dedicated, seeded RNG handed to get_random_synthetic_label in the loop below.
rng = random.Random(42)

def _append_csv_row(csv_path, row):
    """Append a single row to the CSV file at ``csv_path``."""
    # newline='' is what the csv module requires; without it, platforms that
    # translate '\n' get blank lines between rows.
    with open(csv_path, mode='a', newline='') as log_file:
        writer = csv.writer(log_file, delimiter=",", quoting=csv.QUOTE_MINIMAL)
        writer.writerow(row)


def _encode_labels(train_label_values, test_label_values):
    """Map train and test labels to integer ids with ONE shared LabelEncoder.

    Returns ``(train_ids, test_ids)`` as torch tensors. The encoder is fitted
    on the union of both label lists so that a given label always receives
    the same id in train and test. Fitting two encoders separately silently
    remaps ids whenever the training batch lacks a class present in the test
    split (e.g. a one-class ambiguous-sample batch).
    NOTE(review): assumes synthetic labels use the same label vocabulary as
    the X-Stance labels — confirm against utils.read_data.
    """
    label_encoder = preprocessing.LabelEncoder()
    label_encoder.fit(list(train_label_values) + list(test_label_values))
    train_ids = torch.as_tensor(label_encoder.transform(train_label_values))
    test_ids = torch.as_tensor(label_encoder.transform(test_label_values))
    return train_ids, test_ids


def _tokenize_pairs(questions, comments):
    """Tokenize (question, comment) pairs, padded/truncated to 512 tokens."""
    return tokenizer(questions, comments, return_tensors='pt',
                     max_length=512, padding='max_length', truncation=True)


# Metric used to score every run (the CSV logs its "accuracy" and "f1").
# Loaded once here instead of once per training run.
metric = load_metric('glue', 'mrpc')

for j, ref_question in enumerate(ref_questions):
    tq, tc, tl, ti = get_data_for_question(ref_question, x_stance_test)

    # 60/40 split of this question's data: the 60% part is the training pool
    # used by every run, the 40% part is held out for evaluation.
    (tq_split_train, tq_split_test,
     tc_split_train, tc_split_test,
     tl_split_train, tl_split_test) = train_test_split(tq, tc, tl, test_size=0.4, random_state=42)

    print("TRAIN_SPLIT: ", len(tq_split_train))
    print("TEST_SPLIT: ", len(tq_split_test))
    print("TRAIN_CLASS_BALANCE: ", Counter(tl_split_train))
    print("TEST_CLASS_BALANCE: ", Counter(tl_split_test))
    print("TOTAL_DATA_LENGTH: ", len(tq))
    print("TOTAL_CLASS_BALANCE: ", Counter(tl))
    print("INDEX: ", j+1)

    batch_size = len(tq_split_train)
    num_runs = args.num_runs
    for synth_size in [args.synth_size]:
        # Synthetic data for this question, stored per question index.
        pos_question, pos_answer, pos_label, neg_question, neg_answer, neg_label = read_data(
            "data/answers_test_single_500_2_" + str(j + 1))
        pos_synth = (pos_question, pos_answer, pos_label)
        neg_synth = (neg_question, neg_answer, neg_label)

        for test_type in [args.query_strategy]:
            # Flags selecting which pipeline stages this query strategy uses.
            train, random_synth, augmented, augmented_no_test, embedding, ambiguous_samples = get_config(test_type)

            # Set up path to log results.
            log_dir = 'stance_test_synth_numrunstest_75/synth_' + str(synth_size) + '_single/' + test_type + '/question' + str(j + 1)
            os.makedirs(log_dir, exist_ok=True)
            path = log_dir + "/logs.csv"
            # Write the header only for a new/empty log so repeated invocations
            # do not interleave header rows with result rows.
            if not os.path.exists(path) or os.path.getsize(path) == 0:
                _append_csv_row(path, ["accuracy", "f1", "num_samples", "relabel_count", "threshold"])

            for i in range(num_runs):
                print("ITERATION: ", i)
                # Every run starts from the full training split.
                tq_it, tc_it, tl_it = tq_split_train, tc_split_train, tl_split_train
                thresholds = [int(batch_size * 0.75), int(batch_size * 0.50),
                              int(batch_size * 0.25), int(batch_size * 0.10)]
                for threshold in thresholds:
                    set_seed(42 + i)
                    # Fresh copy of the pre-trained model for every threshold.
                    model = AutoModelForSequenceClassification.from_pretrained(checkpoint).cuda()

                    print("THRESHOLD: ", threshold)

                    # Class counts of the training pool; logged below in the
                    # "num_samples" column (written as the Counter's repr).
                    batch_count = Counter(tl_it)

                    # Random sample of the synthetic data (size governed by M;
                    # per the original note, currently all M of it is used).
                    pos_q, pos_c, pos_l, neg_q, neg_c, neg_l = random_synthetic_sample(
                        i, batch_size, tl_it, pos_synth, neg_synth, M=synth_size)

                    # Choose samples based on SQBC (or randomly, for the
                    # random-synth baseline). Each selector returns
                    # (questions, comments, labels, relabel_count).
                    print("GET SYNTH DATA")
                    selection = None
                    if random_synth:
                        selection = get_random_synthetic_label(model, tokenizer, tq_it, tc_it, tl_it, threshold, rng)
                    elif test_type in ("al+synth", "al"):
                        selection = get_synthetic_label(model, tokenizer, tq_it, tc_it, tl_it,
                                                        pos_synth, neg_synth, threshold, M=synth_size)
                    elif test_type in ("cal+synth", "cal"):
                        selection = get_synthetic_label_k_cal(model, tokenizer, tq_it, tc_it, tl_it,
                                                              pos_synth, neg_synth, threshold,
                                                              M=synth_size, k=int(synth_size / 2))

                    # Different SQBC options: train on the selected samples or
                    # on the whole pool.
                    if ambiguous_samples:
                        if selection is None:
                            raise ValueError(
                                "query strategy %r requests ambiguous samples but has no "
                                "sample-selection function" % test_type)
                        tq_batch, tc_batch, tl_batch, relabel_count = selection
                    else:
                        tq_batch, tc_batch, tl_batch = tq_it, tc_it, tl_it
                        relabel_count = 0

                    # Add synthetic data to the training data.
                    if augmented:
                        tq_batch = tq_batch + neg_q + pos_q
                        tc_batch = tc_batch + neg_c + pos_c
                        tl_batch = tl_batch + neg_l + pos_l
                    if augmented_no_test:  # This is baseline + synth
                        tq_batch = neg_q + pos_q
                        tc_batch = neg_c + pos_c
                        tl_batch = neg_l + pos_l

                    # Prepare data for training and testing; one shared label
                    # mapping for both sides.
                    train_labels, test_labels = _encode_labels(tl_batch, tl_split_test)
                    traindata = SDDataset(_tokenize_pairs(tq_batch, tc_batch), train_labels)
                    testdata = SDDataset(_tokenize_pairs(tq_split_test, tc_split_test), test_labels)

                    # NOTE(review): no learning_rate is passed, so the Trainer
                    # uses the TrainingArguments default (5e-5 in current
                    # transformers releases). The former AdamW(lr=1e-5) objects
                    # were never handed to the Trainer and thus never used.
                    # NOTE(review): with small pools the run may take far fewer
                    # than 500 optimizer steps, so the LR may never finish
                    # warming up — confirm warmup_steps is intended.
                    training_args = TrainingArguments(
                            output_dir='./output/'+checkpoint+'_incremental_'+str(i)+'_batch_size_'+str(batch_size),    # output directory
                            num_train_epochs=3,              # total number of training epochs
                            per_device_train_batch_size=20,  # batch size per device during training
                            warmup_steps=500,                # number of warmup steps for learning rate scheduler
                            weight_decay=0.01,               # strength of weight decay
                            logging_dir='./logs',            # directory for storing logs
                            logging_steps=10000,
                            save_total_limit=2,
                            evaluation_strategy="no",        # evaluation is done manually below
                            do_eval=False
                        )

                    trainer = Trainer(
                        model=model,                     # the instantiated 🤗 Transformers model to be trained
                        args=training_args,              # training arguments, defined above
                        train_dataset=traindata,         # training dataset
                        tokenizer=tokenizer,             # used tokenizer
                        data_collator=data_collator,
                        compute_metrics=compute_metrics,  # metric function for Trainer evaluation
                    )

                    if train:
                        trainer.train()

                    predictions = trainer.predict(testdata)
                    preds = np.argmax(predictions.predictions, axis=-1)
                    final_score = metric.compute(predictions=preds, references=predictions.label_ids)
                    print(final_score)

                    _append_csv_row(path, [final_score["accuracy"], final_score["f1"],
                                           batch_count, relabel_count, threshold])

                    # The Trainer also references the model; drop both so
                    # empty_cache can actually return the GPU memory before the
                    # next model is loaded.
                    del trainer, model
                    torch.cuda.empty_cache()