import logging
import numpy as np
import random
import torch
import datetime
import os
import re
import pandas as pd
from sklearn.metrics import precision_score, recall_score, f1_score, roc_auc_score
from transformers import TrainerCallback
from copy import deepcopy
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from peft import PeftModel, PeftConfig

# Label names registered for each dataset key.
dataset2label = {
    "personal": ["tpa", "oa", "ra"],
    "ucc": ["antagonize", "condescending", "hostile"],
    "ghc": ["vo", "hd", "cv"],
    "imdb": ["sentiment"],
    "agnews": ["category"],
}

# Reverse lookup built from dataset2label: label name -> dataset key.
label2dataset = {
    label_name: dataset_name
    for dataset_name, label_names in dataset2label.items()
    for label_name in label_names
}


def add_label_noise(example, noise_chance, label_col="label"):
    """Randomly flip a binary label on a single example (in place).

    A uniform draw from ``np.random.rand()`` is compared against
    ``noise_chance``; the outcome is stored under ``'label_flipped'`` and,
    when true, ``example[label_col]`` is replaced by ``1 - example[label_col]``.

    Parameters:
    - example: Mutable mapping holding at least ``label_col``.
    - noise_chance (float): Threshold the random draw is compared against.
    - label_col (str): Key of the label to (possibly) flip.

    Returns:
    - The same ``example`` object, updated.
    """
    flipped = np.random.rand() < noise_chance
    example['label_flipped'] = flipped

    current = example[label_col]
    example[label_col] = (1 - current) if flipped else current

    return example


def introduce_noise(df, label_col, noise_ratio):
    """
    Flip a random subset of the values in a DataFrame's label column.

    ``int(noise_ratio * len(df))`` distinct rows are sampled with
    ``np.random.choice`` and each selected value ``v`` becomes ``1 - v``.
    The DataFrame is modified in place and also returned.

    Parameters:
    - df (pd.DataFrame): Input DataFrame (mutated).
    - label_col (str): Name of the label column to perturb.
    - noise_ratio (float): Fraction of rows to flip (value between 0 and 1).

    Returns:
    - pd.DataFrame: The same DataFrame with noise applied to ``label_col``.

    Raises:
    - ValueError: If ``label_col`` is missing or ``noise_ratio`` is outside [0, 1].
    """
    column_missing = label_col not in df.columns
    if column_missing:
        raise ValueError(f"Column '{label_col}' not found in the DataFrame.")

    ratio_ok = 0 <= noise_ratio <= 1
    if not ratio_ok:
        raise ValueError("Noise ratio must be a value between 0 and 1.")

    flip_count = int(noise_ratio * len(df))
    chosen = np.random.choice(df.index, size=flip_count, replace=False)

    inverted = 1 - df.loc[chosen, label_col]
    df.loc[chosen, label_col] = inverted

    return df


class EvalOnTrainCallback(TrainerCallback):
    """Trainer callback that also evaluates on the training set at epoch end.

    The evaluation only runs when ``control.should_evaluate`` is set.
    """

    def __init__(self, trainer) -> None:
        super().__init__()
        # Trainer whose ``evaluate`` method and ``train_dataset`` are used.
        self._trainer = trainer

    def on_epoch_end(self, args, state, control, **kwargs):
        """Evaluate on ``trainer.train_dataset`` and return the saved control.

        A deep copy of ``control`` is taken before ``evaluate`` is called and
        that copy is what gets returned. Returns ``None`` when no evaluation
        is scheduled.
        """
        if not control.should_evaluate:
            return None

        # Snapshot taken before evaluate(); presumably so flags changed
        # during the nested evaluation are not propagated -- TODO confirm.
        control_snapshot = deepcopy(control)
        self._trainer.evaluate(
            eval_dataset=self._trainer.train_dataset,
            metric_key_prefix="",
        )
        return control_snapshot


class LogPredicitonsCallback(TrainerCallback):
    """Trainer callback that records argmax predictions after every epoch.

    For the ``train`` and ``val`` splits a CSV named ``{split}_preds.csv`` is
    kept in ``output_dir``. The first time a split is seen the file is created
    with ``id``, ``label_flipped``, ``label`` and ``pred_1`` columns; on later
    epochs a new ``pred_<k>`` column is appended.
    """

    def __init__(self, logger, trainer, output_dir) -> None:
        super().__init__()
        self._logger = logger
        self._trainer = trainer
        self.output_dir = output_dir

    @staticmethod
    def _column_or_none(dataset, column):
        """Return ``dataset[column]``, or ``None`` if the column is absent."""
        try:
            return dataset[column]
        except (KeyError, ValueError):
            # Only "column not found" style errors are treated as missing;
            # anything else (e.g. KeyboardInterrupt) is allowed to propagate.
            return None

    def on_epoch_end(self, args, state, control, **kwargs):
        """Predict on the train and eval datasets and update the CSV files."""
        print("in here logging predictions")

        for split in ["train", "val"]:
            if split == "train":
                dataset = self._trainer.train_dataset
            else:
                dataset = self._trainer.eval_dataset

            predictions = self._trainer.predict(
                test_dataset=dataset).predictions
            preds = np.argmax(predictions, axis=1)

            csv_path = os.path.join(self.output_dir, f"{split}_preds.csv")

            if os.path.exists(csv_path):
                self._logger.info(f"Loading prev preds!!")
                predictions_df = pd.read_csv(csv_path)
                # Only columns of the exact form ``pred_<int>`` count, so an
                # unrelated column containing "pred" cannot break int().
                pred_columns_ids = [
                    int(c.split("_")[-1])
                    for c in predictions_df.columns
                    if c.startswith("pred_") and c.split("_")[-1].isdigit()
                ]
                pred_col = f"pred_{max(pred_columns_ids, default=0)+1}"
                self._logger.info(
                    f"Adding predictions to column {pred_col} to {split}_preds.csv")
                predictions_df[pred_col] = preds
            else:
                self._logger.info(f"First epoch, no prev preds")
                labels = dataset['label']

                # Each optional column falls back independently, so a missing
                # id column no longer discards real ``label_flipped`` values.
                flipped_labels = self._column_or_none(dataset, 'label_flipped')
                if flipped_labels is None:
                    self._logger.info(
                        f"No label_flipped column for split {split}")
                    flipped_labels = [False] * len(labels)

                ids = self._column_or_none(dataset, '__index_level_0__')
                if ids is None:
                    self._logger.info(
                        f"No index_level_0 ids column for split {split}")
                    ids = list(range(len(labels)))

                pred_col = "pred_1"
                predictions_df = pd.DataFrame(
                    {"id": ids, "label_flipped": flipped_labels, "label": labels,  pred_col: preds})

            predictions_df.to_csv(csv_path, index=False)


def compute_metrics(p):
    """Compute precision, recall, F1 and ROC-AUC for an evaluation result.

    ``p`` must expose ``predictions`` (a 2-D score array; the argmax over
    axis 1 is the predicted class and column 1 is the score used for AUC)
    and ``label_ids`` (the reference labels).

    Returns a dict with keys ``precision``, ``recall``, ``f1`` and ``auc_roc``.
    """
    scores = p.predictions
    predicted = np.argmax(scores, axis=1)
    reference = p.label_ids

    # NOTE(review): the sklearn scorers are called with default arguments,
    # which presumably assumes binary labels -- confirm against callers.
    return {
        'precision': precision_score(reference, predicted),
        'recall': recall_score(reference, predicted),
        'f1': f1_score(reference, predicted),
        'auc_roc': roc_auc_score(reference, scores[:, 1]),
    }


def load_ghc_data():
    """Load the GHC CSV splits from ``./data/GHC/``.

    Rows containing any missing value are dropped from every split.

    Returns a dict with ``train``, ``val`` and ``test`` DataFrames, read from
    ``train.csv``, ``valid.csv`` and ``test.csv`` respectively.
    """
    base_dir = "./data/GHC/"
    split_files = {"train": "train.csv", "val": "valid.csv", "test": "test.csv"}

    return {
        split: pd.read_csv(os.path.join(base_dir, file_name)).dropna()
        for split, file_name in split_files.items()
    }


def load_personal_attack_data():
    """Load the personal-attack CSV splits from ``./data/personal_attack/``.

    The ``comment`` column is renamed to ``text`` and rows with any missing
    value are dropped.

    Returns a dict with ``train``, ``val`` and ``test`` DataFrames, read from
    ``train.csv``, ``dev.csv`` and ``test.csv`` respectively.
    """
    base_dir = "./data/personal_attack/"
    split_files = (("train", "train.csv"), ("val", "dev.csv"), ("test", "test.csv"))

    splits = {}
    for split, file_name in split_files:
        frame = pd.read_csv(os.path.join(base_dir, file_name))
        frame = frame.rename(columns={"comment": "text"})
        splits[split] = frame.dropna()
    return splits


def load_ucc_data():
    """Load the UCC CSV splits from ``./data/ucc/``.

    The ``comment`` column is renamed to ``text`` and rows with any missing
    value are dropped.

    Returns a dict with ``train``, ``val`` and ``test`` DataFrames, read from
    ``train.csv``, ``dev.csv`` and ``test.csv`` respectively.
    """
    base_dir = "./data/ucc/"
    renames = {"comment": "text"}

    def read_split(file_name):
        # Read one split, normalise the text column name, drop incomplete rows.
        raw = pd.read_csv(os.path.join(base_dir, file_name))
        return raw.rename(columns=renames).dropna()

    train_split = read_split("train.csv")
    val_split = read_split("dev.csv")
    test_split = read_split("test.csv")
    return {"train": train_split, "val": val_split, "test": test_split}


def load_jigsaw_data():
    """Load and merge the five Jigsaw annotator files into one DataFrame.

    Files ``annotator1.json`` .. ``annotator5.json`` are read as JSON lines,
    tagged with an ``annotator`` column (1-5) and concatenated. Non-list
    values in ``accept`` are replaced with an empty list, and one 0/1
    indicator column is added per known label, set to 1 when the label is
    present in the row's ``accept`` list.
    """
    data_dir = "../../Data/Jigsaw Annotations/8670-items(re-annotated one VO questions version)"
    LABELS = ['CV', 'HD', 'VO', 'NH', 'RAE', 'NAT', 'GEN', 'REL',
              'SXO', 'IDL', 'POL', 'MPH', 'Explicit', 'Implicit']

    frames = []
    for annotator_id in range(1, 6):
        source = os.path.join(data_dir, f"annotator{annotator_id}.json")
        frame = pd.read_json(source, lines=True)
        frame['annotator'] = annotator_id
        frames.append(frame)
    merged = pd.concat(frames)

    # Rows without a list under 'accept' are treated as having no labels.
    merged['accept'] = merged['accept'].apply(
        lambda value: value if isinstance(value, list) else [])

    for label in LABELS:
        merged[label] = merged['accept'].apply(
            lambda accepted, name=label: 1 if name in accepted else 0)

    return merged


def load_imdb_data():
    """Load the IMDB CSV splits from ``./data/IMDB``.

    The ``review`` column is renamed to ``text`` and ``sentiment`` is mapped
    from ``'positive'``/``'negative'`` to 1/0.

    Returns a dict with ``train``, ``val`` and ``test`` DataFrames, read from
    ``train.csv``, ``valid.csv`` and ``test.csv`` respectively.
    """
    base_dir = "./data/IMDB"
    split_files = {"train": "train.csv", "val": "valid.csv", "test": "test.csv"}
    sentiment_mapping = {'positive': 1, 'negative': 0}

    splits = {}
    for split, file_name in split_files.items():
        frame = pd.read_csv(os.path.join(base_dir, file_name))
        frame = frame.rename(columns={'review': 'text'})
        frame['sentiment'] = frame['sentiment'].map(sentiment_mapping)
        splits[split] = frame
    return splits


def load_agnews_data():
    """Load the AG News CSV splits from ``./data/AGNEWS``.

    Returns a dict with ``train``, ``val`` and ``test`` DataFrames, read
    unmodified from ``train.csv``, ``val.csv`` and ``test.csv``.
    """
    base_dir = "./data/AGNEWS"

    return {
        split: pd.read_csv(os.path.join(base_dir, f"{split}.csv"))
        for split in ("train", "val", "test")
    }


def load_system12_data():
    """Load ``Cognitive_Biases_Dataset.csv`` and add a numeric ``labels`` column.

    ``labels`` is derived from ``Strategy``: ``'System 1'`` -> 0 and
    ``'System 2'`` -> 1 (any other value maps to NaN).
    """
    csv_path = os.path.join("./data/system12", "Cognitive_Biases_Dataset.csv")
    strategy_to_label = {'System 1': 0, 'System 2': 1}

    frame = pd.read_csv(csv_path)
    frame['labels'] = frame['Strategy'].map(strategy_to_label)
    return frame


def load_system12_questions_data():
    """Return the distinct questions from ``Cognitive_Biases_Dataset.csv``.

    The result is a single-column DataFrame named ``Question`` holding the
    unique values in order of first appearance.
    """
    csv_path = os.path.join("./data/system12", "Cognitive_Biases_Dataset.csv")
    distinct_questions = pd.read_csv(csv_path)['Question'].unique()
    return pd.DataFrame(distinct_questions, columns=['Question'])


def load_system12_10k_questions_data():
    """Return the distinct questions from the 10,000-example dataset CSV.

    Reads ``Cognitive_Biases_Dataset_10000_Examples.csv`` and returns a
    single-column DataFrame named ``Question`` with the unique values in
    order of first appearance.
    """
    csv_path = os.path.join(
        "./data/system12", "Cognitive_Biases_Dataset_10000_Examples.csv")
    distinct_questions = pd.read_csv(csv_path)['Question'].unique()
    return pd.DataFrame(distinct_questions, columns=['Question'])


def load_system12_gpt_questions_data():
    """Return the distinct questions from ``gpt_Cognitive_Biases_Dataset.csv``.

    The result is a single-column DataFrame named ``Question`` holding the
    unique values in order of first appearance.
    """
    csv_path = os.path.join(
        "./data/system12", "gpt_Cognitive_Biases_Dataset.csv")
    distinct_questions = pd.read_csv(csv_path)['Question'].unique()
    return pd.DataFrame(distinct_questions, columns=['Question'])


def load_system12_combined_questions_data():
    """Return the distinct questions from ``combined_cognitive_biases_dataset.csv``.

    The result is a single-column DataFrame named ``Question`` holding the
    unique values in order of first appearance.
    """
    csv_path = os.path.join(
        "./data/system12", "combined_cognitive_biases_dataset.csv")
    distinct_questions = pd.read_csv(csv_path)['Question'].unique()
    return pd.DataFrame(distinct_questions, columns=['Question'])


def load_system12_combined_data():
    """Load ``cogbias.csv`` and add a numeric ``labels`` column.

    ``labels`` is derived from ``Strategy``: ``'system1'`` -> 0 and
    ``'system2'`` -> 1 (any other value maps to NaN).
    """
    csv_path = os.path.join("./data/system12", "cogbias.csv")
    strategy_to_label = {'system1': 0, 'system2': 1}

    frame = pd.read_csv(csv_path)
    frame['labels'] = frame['Strategy'].map(strategy_to_label)
    return frame


def load_model_response(data_path):
    """Read the CSV file at ``data_path`` and return it as a DataFrame."""
    return pd.read_csv(data_path)


def get_dataset_loader_func(dataset_name):
    """Load and return the dataset identified by ``dataset_name``.

    Despite the name, this returns the loaded DataFrame, not a loader
    function. Known names are dispatched to the matching ``load_system12_*``
    function; any other value is treated as a CSV path and passed to
    ``load_model_response``.

    Parameters:
    - dataset_name (str): One of the registered dataset names, or a CSV path.

    Returns:
    - pd.DataFrame: The loaded data.
    """
    loaders = {
        'system12': load_system12_data,
        # Previously this name was wired to load_system12_data, so the
        # de-duplicated question table could never be requested.
        'system12_questions': load_system12_questions_data,
        'system12_combined': load_system12_combined_data,
        'system12_combined_questions': load_system12_combined_questions_data,
        'system12_gpt_questions': load_system12_gpt_questions_data,
        'system12_10k_questions': load_system12_10k_questions_data,
    }

    loader = loaders.get(dataset_name)
    if loader is None:
        # Unknown names are interpreted as a path to a model-response CSV.
        return load_model_response(dataset_name)
    return loader()


def create_logger(save_path, log_level=logging.INFO, prefix=""):
    """Configure logging and attach a timestamped file handler.

    NOTE: this module defines ``create_logger`` a second time further down;
    the later definition is the one bound to the name after import.

    Parameters:
    - save_path (str): Directory in which the log file is created. It is
      created if it does not exist yet.
    - log_level (int): Level applied to the root configuration, the module
      logger and the file handler (previously ignored in favour of INFO).
    - prefix (str): Literal text prepended to the log file name.

    Returns:
    - logging.Logger: The module logger (``logging.getLogger(__name__)``).
      Every call adds one more file handler to that same logger.
    """
    log_format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s'
    logging.basicConfig(format=log_format,
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=log_level)

    logger = logging.getLogger(__name__)
    # basicConfig is a no-op once the root logger has handlers, so the
    # requested level is also set on the logger itself.
    logger.setLevel(log_level)

    if save_path:
        os.makedirs(save_path, exist_ok=True)

    # The prefix is concatenated after formatting so that a '%' inside it is
    # not interpreted as a strftime directive.
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    log_file = f"{prefix}{timestamp}.log"

    file_handler = logging.FileHandler(os.path.join(save_path, log_file))
    file_handler.setLevel(log_level)
    file_handler.setFormatter(logging.Formatter(log_format))

    logger.addHandler(file_handler)

    return logger


def add_pad_token_id(tokenizer, model):
    """Fall back to the EOS token id wherever a pad token id is unset.

    Both ``tokenizer.pad_token_id`` and ``model.config.pad_token_id`` are
    checked independently; a ``None`` value is replaced by the corresponding
    ``eos_token_id``. The objects are modified in place.

    Returns the ``(tokenizer, model)`` pair that was passed in.
    """
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    model_config = model.config
    if model_config.pad_token_id is None:
        model_config.pad_token_id = model_config.eos_token_id

    return tokenizer, model



def get_pipeline(model_name_or_path, device):
    """Build a ``text-generation`` pipeline for the given causal LM.

    Models whose name contains ``"Phi"`` are loaded with
    ``torch_dtype="auto"`` and ``trust_remote_code=True``; all others use the
    default ``from_pretrained`` arguments. Missing pad token ids are filled
    in via ``add_pad_token_id`` before the pipeline is created.

    Parameters:
    - model_name_or_path (str): Model identifier or local path.
    - device: Passed through to ``pipeline(device=...)``.
    """
    # NOTE(review): the tokenizer always comes from the Mistral-7B-Instruct
    # checkpoint, independent of model_name_or_path -- confirm this is
    # intended for non-Mistral models such as the "Phi" branch below.
    tokenizer = AutoTokenizer.from_pretrained('mistralai/Mistral-7B-Instruct-v0.1')

    load_kwargs = {}
    if "Phi" in model_name_or_path:
        load_kwargs = {"torch_dtype": "auto", "trust_remote_code": True}
    model = AutoModelForCausalLM.from_pretrained(model_name_or_path, **load_kwargs)

    tokenizer, model = add_pad_token_id(tokenizer, model)

    return pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        device=device,
    )


class LocalDecoder:
    """Generate chat responses with a locally loaded text-generation pipeline."""

    def __init__(self, model_name_or_path, device, batch_size, MAX_LEN=256):
        # ``batch_size`` is accepted but not used by this class.
        self.pipeline = get_pipeline(model_name_or_path, device)
        # Upper bound passed as ``max_new_tokens`` on every decode call.
        self.MAX_LEN = MAX_LEN

    def decode(self, inputs):
        """Return one generated reply per prompt in ``inputs``.

        Each prompt is wrapped as a single-turn user conversation; the reply
        is the ``content`` of the last message in the first returned
        generation for that conversation.
        """
        chats = [[{"role": "user", "content": prompt}] for prompt in inputs]
        generations = self.pipeline(chats, max_new_tokens=self.MAX_LEN)
        return [
            generation[0]['generated_text'][-1]['content']
            for generation in generations
        ]


class InstructionTunedDecoder:
    """Generate responses with a PEFT adapter merged into its base model.

    The adapter and tokenizer are read from ``<model_name_or_path>/best_model``.
    """

    def __init__(self, model_name_or_path, device, batch_size, MAX_LEN=256):
        # ``batch_size`` is accepted but not used by this class.
        adapter_dir = model_name_or_path + "/best_model"
        tokenizer = AutoTokenizer.from_pretrained(adapter_dir)
        adapter_config = PeftConfig.from_pretrained(adapter_dir)

        base_model = AutoModelForCausalLM.from_pretrained(
            adapter_config.base_model_name_or_path)
        # Embedding matrix is resized to the tokenizer's vocabulary size
        # before the adapter weights are loaded.
        base_model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)

        merged_model = PeftModel.from_pretrained(base_model, adapter_dir)
        merged_model = merged_model.merge_and_unload()

        self.pipeline = pipeline(
            "text-generation", model=merged_model, tokenizer=tokenizer, device=device)
        # Upper bound passed as ``max_new_tokens`` on every decode call.
        self.MAX_LEN = MAX_LEN

    def decode(self, inputs):
        """Return one generated reply per prompt in ``inputs``.

        Prompts are rendered with ``concat_messages`` (ending with an open
        assistant turn); the reply is whatever follows the last
        ``<|assistant|>\\n`` marker in the generated text.
        """
        prompts = [
            concat_messages(
                [{"role": "user", "content": text}],
                self.pipeline.tokenizer,
                add_assistant_in_the_end=True,
            )
            for text in inputs
        ]
        generations = self.pipeline(prompts, max_new_tokens=self.MAX_LEN)

        replies = []
        for generation in generations:
            full_text = generation[0]['generated_text']
            # rpartition yields the whole text when the marker is absent,
            # matching ``split(marker)[-1]``.
            replies.append(full_text.rpartition("<|assistant|>\n")[2])
        return replies


def concat_messages(messages, tokenizer, add_assistant_in_the_end=False):
    """Render chat messages into a single ``<|role|>``-tagged prompt string.

    Each message contributes ``<|role|>\\n`` followed by its stripped content
    and a newline; assistant turns additionally get ``tokenizer.eos_token``
    before the newline.

    Parameters:
    - messages: Iterable of dicts with ``role`` and ``content`` keys.
    - tokenizer: Object providing ``eos_token`` (only read for assistant turns).
    - add_assistant_in_the_end (bool): Append an open ``<|assistant|>\\n`` tag.

    Raises:
    - ValueError: If a message has a role other than system/user/assistant.
    """
    parts = []
    for msg in messages:
        role = msg["role"]
        if role == "system":
            parts.append("<|system|>\n" + msg["content"].strip() + "\n")
        elif role == "user":
            parts.append("<|user|>\n" + msg["content"].strip() + "\n")
        elif role == "assistant":
            parts.append("<|assistant|>\n" + msg["content"].strip()
                         + tokenizer.eos_token + "\n")
        else:
            raise ValueError("Invalid role: {}".format(role))

    if add_assistant_in_the_end:
        parts.append("<|assistant|>\n")
    return "".join(parts)


def answer_cleansing(args, preds):
    """Extract the final answer token from each raw model output.

    Parameters:
    - args: Object with ``dataset`` and ``method`` attributes. For the
      few-shot methods ``direct_answer_trigger_for_fewshot`` is used, when
      present and non-empty, to locate the answer part of the output.
    - preds: Iterable of raw output strings.

    Returns:
    - list[str]: One cleaned answer per input; ``""`` when nothing matched.

    Raises:
    - ValueError: For an unsupported ``args.dataset``, or an unsupported
      ``args.method`` when at least one candidate answer was found.
    """
    abc_choice_datasets = (
        "age", "disability_status", "gender_identity", "nationality",
        "physical_appearance", "race_ethnicity", "race_x_gender",
        "race_x_ses", "religion", "ses", "sexual_orientation",
    )

    clean_preds = []
    for pred in preds:
        # Few-shot outputs: keep only the text after the last answer trigger
        # and remember whether the trigger was present at all. Without this,
        # ``answer_flag`` was undefined and every few-shot call crashed.
        answer_flag = False
        if args.method in ("few_shot", "few_shot_cot"):
            trigger = getattr(args, "direct_answer_trigger_for_fewshot", None)
            if trigger:
                pieces = pred.split(trigger)
                answer_flag = len(pieces) > 1
                pred = pieces[-1]

        if args.dataset in ("aqua", "commonsensqa"):
            pred = re.findall(r'A|B|C|D|E', pred)
        elif args.dataset == "bigbench_date":
            pred = re.findall(r'A|B|C|D|E|F', pred)
        elif args.dataset == "object_tracking":
            # Exact match: ``in ("object_tracking")`` was a substring test.
            pred = re.findall(r'A|B|C', pred)
        elif args.dataset in ("gsm8k", "addsub", "multiarith", "svamp", "singleeq"):
            pred = pred.replace(",", "")
            pred = re.findall(r'-?\d+\.?\d*', pred)
        elif args.dataset in ("strategyqa", "coin_flip"):
            pred = pred.lower()
            pred = re.sub(r"[\"'\n.\s:,]", " ", pred)
            pred = pred.split(" ")
            pred = [i for i in pred if i in ("yes", "no")]
        elif args.dataset == "last_letters":
            # Keep the last double-quoted span, if there is one.
            right_index = pred.rfind('"')
            if right_index != -1:
                left_index = pred[:right_index].rfind('"')
                pred = pred[left_index:right_index+1].lower()
            pred = re.sub(r"[\"'\n.\s]", "", pred)
            pred = [pred]
        elif args.dataset in abc_choice_datasets:
            pred = re.findall(r'A|B|C', pred)
        else:
            raise ValueError("dataset is not properly defined ...")

        if len(pred) == 0:
            pred = ""
        else:
            if args.method in ("few_shot", "few_shot_cot"):
                if answer_flag:
                    # Trigger found: the answer is the first candidate after it.
                    pred = pred[0]
                else:
                    # No trigger: fall back to the last candidate in the text.
                    pred = pred[-1]
            elif args.method in ("zero_shot", "role_play"):
                pred = pred[0]
            else:
                raise ValueError("method is not properly defined ...")

        # Drop a trailing period, e.g. a number captured as "5.".
        if pred != "":
            if pred[-1] == ".":
                pred = pred[:-1]

        clean_preds.append(pred)

    return clean_preds


def create_demo_text(args, cot_flag):
    """Build a shuffled few-shot demonstration prompt for arithmetic datasets.

    Eight fixed question / rationale / answer examples are emitted in an
    order determined by ``random.shuffle``. Each example is rendered as
    ``Q: ...\\nA: ...`` followed by ``args.direct_answer_trigger_for_fewshot``
    and the answer; the rationale is included only when ``cot_flag`` is true.

    Raises:
    - ValueError: If ``args.dataset`` is not ``"multiarith"`` or ``"gsm8k"``.
    """
    if args.dataset not in ("multiarith", "gsm8k"):
        raise ValueError("dataset is not properly defined ...")

    # (question, chain-of-thought rationale, final answer)
    demos = (
        (
            "There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?",
            "There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6.",
            "6",
        ),
        (
            "If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?",
            "There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5.",
            "5",
        ),
        (
            "Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?",
            "Originally, Leah had 32 chocolates. Her sister had 42. So in total they had 32 + 42 = 74. After eating 35, they had 74 - 35 = 39.",
            "39",
        ),
        (
            "Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?",
            "Jason started with 20 lollipops. Then he had 12 after giving some to Denny. So he gave Denny 20 - 12 = 8.",
            "8",
        ),
        (
            "Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now?",
            "Shawn started with 5 toys. If he got 2 toys each from his mom and dad, then that is 4 more toys. 5 + 4 = 9.",
            "9",
        ),
        (
            "There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?",
            "There were originally 9 computers. For each of 4 days, 5 more computers were added. So 5 * 4 = 20 computers were added. 9 + 20 is 29.",
            "29",
        ),
        (
            "Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?",
            "Michael started with 58 golf balls. After losing 23 on tuesday, he had 58 - 23 = 35. After losing 2 more, he had 35 - 2 = 33 golf balls.",
            "33",
        ),
        (
            "Olivia has $23. She bought five bagels for $3 each. How much money does she have left?",
            "Olivia had 23 dollars. 5 bagels for 3 dollars each will be 5 x 3 = 15 dollars. So she has 23 - 15 dollars left. 23 - 15 is 8.",
            "8",
        ),
    )

    # Shuffle positions rather than the examples themselves.
    order = list(range(len(demos)))
    random.shuffle(order)

    rendered = []
    for position in order:
        question, rationale, answer = demos[position]
        answer_part = args.direct_answer_trigger_for_fewshot + " " + answer + ".\n\n"
        if cot_flag:
            rendered.append("Q: " + question + "\nA: " + rationale + " " + answer_part)
        else:
            rendered.append("Q: " + question + "\nA: " + answer_part)

    return "".join(rendered)


def create_logger(save_path, log_level=logging.INFO, prefix=""):
    """Configure logging and attach a timestamped file handler.

    NOTE: this redefinition replaces the ``create_logger`` declared earlier
    in this module; it is the one callers get after import.

    Parameters:
    - save_path (str): Directory in which the log file is created. It is
      created if it does not exist yet.
    - log_level (int): Level applied to the root configuration, the module
      logger and the file handler (previously ignored in favour of INFO).
    - prefix (str): Literal text prepended to the log file name.

    Returns:
    - logging.Logger: The module logger (``logging.getLogger(__name__)``).
      Every call adds one more file handler to that same logger.
    """
    log_format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s'
    logging.basicConfig(format=log_format,
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=log_level)

    logger = logging.getLogger(__name__)
    # basicConfig is a no-op once the root logger has handlers, so the
    # requested level is also set on the logger itself.
    logger.setLevel(log_level)

    if save_path:
        os.makedirs(save_path, exist_ok=True)

    # The prefix is concatenated after formatting so that a '%' inside it is
    # not interpreted as a strftime directive.
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    log_file = f"{prefix}{timestamp}.log"

    file_handler = logging.FileHandler(os.path.join(save_path, log_file))
    file_handler.setLevel(log_level)
    file_handler.setFormatter(logging.Formatter(log_format))

    logger.addHandler(file_handler)

    return logger


def get_tokenizer(model_name_or_path):
    """Load a tokenizer configured for left-side padding.

    The padding side and the loaded tokenizer are printed to stdout.
    """
    side = "left"
    print(f"Padding side: {side}")

    loaded = AutoTokenizer.from_pretrained(
        model_name_or_path,
        padding_side=side,
    )
    print(loaded)

    return loaded
