from transformers import (
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    default_data_collator,
    AutoModelForMaskedLM,
    get_scheduler,
    AutoConfig,
    RobertaForMaskedLM,
    EncoderDecoderModel,
    GenerationConfig,
    DefaultDataCollator,
)
import evaluate
from codebleu import calc_codebleu
from accelerate import Accelerator
from dae_bt_data_collator import DataCollatorForUnsupervisedTranslation

from torch.utils.data import DataLoader
from torch.optim import AdamW
import torch

from accelerate import DistributedDataParallelKwargs

import pandas as pd

import numpy as np

from tqdm.auto import tqdm

import math
import os

import csv

# import config as CFG
from utility import (
    whole_word_masking_data_collator,
    EarlyStoppingCallback,
    shift_tokens_right,
)


class Evaluate:
    def __init__(self, args):
        """Set up the dataset, accelerator, model, collator and metrics.

        Args:
            args: Run configuration. This constructor reads ``train_mode``,
                ``langs``, ``quant``, ``accumulation_steps``, ``logger`` and
                ``output_dir`` from it; the same object is also handed to the
                dataset class.
        """
        self.args = args

        # The dataset implementation depends on the run mode / language pair.
        if self.args.train_mode == 'eval':
            from dataset_eval import CUDAizerDataset
        elif 'fortran' in self.args.langs:
            from dataset_fortran_cpp import CUDAizerDataset
        else:
            from dataset import CUDAizerDataset

        # Accelerator
        ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
        self.accelerator = Accelerator(
            mixed_precision=self.args.quant,
            split_batches=False,
            kwargs_handlers=[ddp_kwargs],
            gradient_accumulation_steps=self.args.accumulation_steps,
        )

        # The main process goes through dataset construction and model
        # loading before the other processes do.
        with self.accelerator.main_process_first():
            self.dataset = CUDAizerDataset(args=args)
            self.args.logger.info(self.dataset("test"))

            # Prefer the "_ft" checkpoint directory when it exists on disk,
            # otherwise fall back to the "_daebt" one.
            finetuned_dir = args.output_dir + "_ft"
            if os.path.isdir(finetuned_dir):
                self.input_directory = finetuned_dir
            else:
                self.input_directory = args.output_dir + "_daebt"
            print(f'Input Directoru: {self.input_directory}')

            # Load the model
            self.model = EncoderDecoderModel.from_pretrained(self.input_directory)
        self.accelerator.wait_for_everyone()

        # Collator used for both the validation and the test dataloader.
        self.valid_data_collator = DefaultDataCollator(return_tensors="pt")

        # Metrics
        self.metric = evaluate.load("bleu")
        self.codebleu_metric = evaluate.load("k4black/codebleu")

    def prepare_data(self):
        """Build the dataloaders for the validation and test splits.

        Both loaders use the same settings: no shuffling, ``args.batch_size``
        examples per batch, ``self.valid_data_collator`` as collate function,
        ``args.num_process`` worker processes and pinned memory.

        Returns:
            tuple: ``(eval_dataloader, test_dataloader)``.
        """
        loader_options = dict(
            shuffle=False,
            batch_size=self.args.batch_size,
            collate_fn=self.valid_data_collator,
            num_workers=self.args.num_process,
            pin_memory=True,
        )

        # The dataset object is called with the split name to obtain a split.
        eval_dataloader = DataLoader(self.dataset("valid"), **loader_options)
        test_dataloader = DataLoader(self.dataset("test"), **loader_options)

        return eval_dataloader, test_dataloader

    def train(self):
        """Entry point: delegate to ``train_dae_bt_using_accelerator``."""
        run = self.train_dae_bt_using_accelerator
        run()

    def train_dae_bt_using_accelerator(self):
        """Generate translations for the test split, score them and export them.

        Despite its name, this method performs no optimisation step: the model
        is put in eval mode, sequences are generated for every test batch,
        corpus-level BLEU / CodeBLEU are printed and, on the main process,
        per-example CodeBLEU scores are written to CSV files in the current
        working directory.

        Validation-set generation is currently disabled, so the validation
        CSV is written with its header row only; it is still produced so the
        set of output files stays the same.

        Returns:
            None.
        """
        tokenizer = self.dataset.tokenizer
        # Component weights handed to every CodeBLEU computation below.
        codebleu_weights = (0.10, 0.10, 0.40, 0.40)

        eval_dataloader, test_dataloader = self.prepare_data()

        (
            self.model,
            eval_dataloader,
            test_dataloader,
        ) = self.accelerator.prepare(
            self.model,
            eval_dataloader,
            test_dataloader,
        )
        # generate() is called on the unwrapped model, not on the wrapper
        # returned by accelerator.prepare().
        generator = self.accelerator.unwrap_model(self.model)

        max_new_tokens = self.args.chunk_size
        num_beam = self.args.num_beam
        num_return_sequences = self.args.num_return_sequences
        print(f'Max new tokens: {max_new_tokens}')
        print(f'Number of beams: {num_beam}')

        # NOTE(review): every iteration repeats the same evaluation and
        # overwrites the same CSV files; the loop is kept because its bound
        # is controlled by callers through ``num_train_epochs_bt``.
        for _ in range(self.args.num_train_epochs_bt):

            self.model.eval()

            # Validation-set generation is disabled, so these stay empty.
            valset_decoded_predictions, valset_decoded_labels = [], []
            testset_decoded_predictions, testset_decoded_labels = [], []

            # Test set
            for batch in tqdm(test_dataloader):
                # The language tag of the first example selects the target
                # language for the whole batch.
                src_lan = str(
                    tokenizer.convert_ids_to_tokens([batch["lang"][0]])[0]
                )
                if src_lan == self.args.langs[0]:
                    target_lang = self.args.langs[1]
                else:
                    target_lang = self.args.langs[0]
                decoder_start_token_id = tokenizer.convert_tokens_to_ids(
                    f"<{target_lang.upper()}>"
                )

                generation_config = GenerationConfig(
                    max_new_tokens=max_new_tokens,
                    decoder_start_token_id=decoder_start_token_id,
                    pad_token_id=tokenizer.pad_token_id,
                    bos_token_id=tokenizer.bos_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                    num_beams=num_beam,
                    num_return_sequences=num_return_sequences,
                )
                with torch.no_grad():
                    generated_tokens = generator.generate(
                        batch["input_ids"],
                        attention_mask=batch["attention_mask"],
                        generation_config=generation_config,
                    )
                labels = batch["labels"]

                # Necessary to pad predictions and labels for being gathered
                generated_tokens = self.accelerator.pad_across_processes(
                    generated_tokens,
                    dim=1,
                    pad_index=tokenizer.pad_token_id,
                )
                labels = self.accelerator.pad_across_processes(
                    labels, dim=1, pad_index=-100
                )

                predictions_gathered = self.accelerator.gather(generated_tokens)
                labels_gathered = self.accelerator.gather(labels)

                decoded_preds, decoded_labels, decoded_labels_no_bracket = (
                    self.postprocess(predictions_gathered, labels_gathered)
                )

                # generate() returns the num_return_sequences candidates of
                # each input next to each other, so every reference has to be
                # repeated in place ([a, a, b, b]). Tiling the whole list
                # ([a, b, a, b]) would pair predictions with the wrong labels.
                repeated_references = [
                    reference
                    for reference in decoded_labels
                    for _ in range(num_return_sequences)
                ]
                repeated_labels = [
                    label
                    for label in decoded_labels_no_bracket
                    for _ in range(num_return_sequences)
                ]

                self.metric.add_batch(
                    predictions=decoded_preds, references=repeated_references
                )
                self.codebleu_metric.add_batch(
                    predictions=decoded_preds, references=repeated_references
                )
                testset_decoded_predictions.extend(decoded_preds)
                testset_decoded_labels.extend(repeated_labels)

            results = self.metric.compute()
            code_blue_results = self.codebleu_metric.compute(
                lang='cpp', weights=codebleu_weights, tokenizer=None
            )
            self.accelerator.print(
                f"[TEST]: BLEU: {results['bleu']}, CodeBLEU: {code_blue_results['codebleu']}"
            )
            self.accelerator.print(f"CodeBLUE: {calc_codebleu(testset_decoded_labels, testset_decoded_predictions, lang='cpp', weights=codebleu_weights, tokenizer=None)['codebleu']}")

            if self.accelerator.is_main_process:
                # Calculate CodeBLEU per each example
                test_codebleu = [
                    calc_codebleu(
                        [ref],
                        [pre],
                        lang="cpp",
                        weights=codebleu_weights,
                        tokenizer=None,
                    )['codebleu']
                    for ref, pre in zip(
                        testset_decoded_labels, testset_decoded_predictions
                    )
                ]
                print(len(test_codebleu))
                print(len(testset_decoded_labels))
                print(len(testset_decoded_predictions))

                # Exporting is best effort: a failure is reported but does
                # not abort the run.
                try:
                    data = {
                        "Label": testset_decoded_labels,
                        "Prediction": testset_decoded_predictions,
                        "CodeBLEU": test_codebleu,
                    }
                    df = pd.DataFrame(data)
                    df.to_csv(f"test_set_prediction_eval_{self.args.langs[1]}_beam{num_beam}_seq{num_return_sequences}.csv", index=False, sep="|")

                    data = {
                        "Label": valset_decoded_labels,
                        "Prediction": valset_decoded_predictions,
                    }
                    df = pd.DataFrame(data)
                    df.to_csv(f"val_set_prediction_eval_{self.args.langs[1]}_beam{num_beam}_seq{num_return_sequences}.csv", index=False)
                except Exception as e:
                    print(e)

    def postprocess(self, predictions, labels):
        predictions = predictions.cpu().numpy()
        labels = labels.cpu().numpy()

        decoded_preds = self.dataset.tokenizer.batch_decode(
            predictions, skip_special_tokens=True
        )

        # Replace -100 in the labels as we can't decode them.
        labels = np.where(labels != -100, labels, self.dataset.tokenizer.pad_token_id)
        decoded_labels = self.dataset.tokenizer.batch_decode(
            labels, skip_special_tokens=True
        )

        # Some simple post-processing
        decoded_preds = [pred.strip() for pred in decoded_preds]
        decoded_labels = [[label.strip()] for label in decoded_labels]
        decoded_labels_no_bracket = [label[0].strip() for label in decoded_labels]
        return decoded_preds, decoded_labels, decoded_labels_no_bracket