import logging
from logging.handlers import TimedRotatingFileHandler
import re
import os
import torch
from typing import Optional, Tuple, List
import numpy as np
import random
import string
def set_seed(seed):
    """Seed every random number generator this module relies on.

    Passes ``seed`` to, in order: ``random.seed``, ``np.random.seed``,
    ``torch.manual_seed`` and ``torch.cuda.manual_seed_all``.

    Args:
        seed: Seed value handed unchanged to each seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


def setup_log(log_path, log_name="basic"):
    """Return the logger named ``log_name``, configuring it on first use.

    On the first call for a given name the logger level is set to DEBUG and
    two handlers sharing one formatter are attached: a stream handler and a
    ``TimedRotatingFileHandler`` writing to ``log_path`` that rotates at
    midnight and keeps 30 backups. If the logger already has handlers it is
    returned unchanged, and ``log_path`` is ignored for that call.

    Args:
        log_path: Path of the active log file. Missing parent directories
            are created so that opening the file does not fail.
        log_name: Name passed to ``logging.getLogger``.

    Returns:
        The configured ``logging.Logger``.
    """
    print("Setting up log for", log_name)
    logger = logging.getLogger(log_name)
    if logger.handlers:
        # Already configured: adding handlers again would duplicate output.
        return logger

    # The file handler opens log_path immediately, so its directory must exist.
    log_dir = os.path.dirname(log_path)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    logger.setLevel(logging.DEBUG)
    file_handler = TimedRotatingFileHandler(
        filename=log_path, when="MIDNIGHT", interval=1, backupCount=30
    )
    # Rotated files are named "<log_path>.YYYY-MM-DD.log"; extMatch must
    # agree with the suffix so old backups are recognised and pruned.
    file_handler.suffix = "%Y-%m-%d.log"
    file_handler.extMatch = re.compile(r"^\d{4}-\d{2}-\d{2}\.log$")
    stream_handler = logging.StreamHandler()
    formatter = logging.Formatter("[%(asctime)s] - %(message)s")

    stream_handler.setFormatter(formatter)
    file_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)
    logger.addHandler(file_handler)
    return logger


# Verbalizer tokens per dataset, one entry per class in label order.
# NOTE(review): the "\u0120" prefix is presumably the leading-space marker of a
# byte-level BPE vocabulary; confirm against the tokenizer in use.
_BINARY_SENTIMENT_VERBALIZERS = ("\u0120negative", "\u0120positive")
_FIVE_WAY_SENTIMENT_VERBALIZERS = (
    "\u0120terrible",
    "\u0120bad",
    "\u0120okay",
    "\u0120good",
    "\u0120great",
)
_DATASET_VERBALIZERS = {
    "sst2": _BINARY_SENTIMENT_VERBALIZERS,
    "yelp-2": _BINARY_SENTIMENT_VERBALIZERS,
    "mr": _BINARY_SENTIMENT_VERBALIZERS,
    "cr": _BINARY_SENTIMENT_VERBALIZERS,
    "agnews": ("World", "Sports", "Business", "Tech"),
    "sst-5": _FIVE_WAY_SENTIMENT_VERBALIZERS,
    "yelp-5": _FIVE_WAY_SENTIMENT_VERBALIZERS,
    "subj": ("\u0120subjective", "\u0120objective"),
    "trec": (
        "\u0120Description",
        "\u0120Entity",
        "\u0120Expression",
        "\u0120Human",
        "\u0120Location",
        "\u0120Number",
    ),
    "yahoo": (
        "culture",
        "science",
        "health",
        "education",
        "computer",
        "sports",
        "business",
        "music",
        "family",
        "politics",
    ),
    "dbpedia": (
        "\u0120Company",
        "\u0120Education",
        "\u0120Artist",
        "\u0120Sports",
        "\u0120Office",
        "\u0120Transportation",
        "\u0120Building",
        "\u0120Natural",
        "\u0120Village",
        "\u0120Animal",
        "\u0120Plant",
        "\u0120Album",
        "\u0120Film",
        "\u0120Written",
    ),
    "liar": ("\u0120No", "\u0120Yes"),
}


def get_dataset_verbalizers(dataset: str) -> List[str]:
    """Return the class verbalizer tokens for ``dataset``.

    Args:
        dataset: Dataset identifier, e.g. ``"sst2"``, ``"agnews"`` or
            ``"trec"``.

    Returns:
        A new list with one verbalizer string per class; callers may mutate
        it without affecting later calls.

    Raises:
        ValueError: If ``dataset`` is not a known dataset name.
    """
    try:
        verbalizers = _DATASET_VERBALIZERS[dataset]
    except (KeyError, TypeError):
        # TypeError covers unhashable arguments; both mean "not a known name".
        supported = ", ".join(sorted(_DATASET_VERBALIZERS))
        raise ValueError(
            "Unknown dataset {!r}; supported datasets: {}".format(dataset, supported)
        )
    return list(verbalizers)




def insert_mask_token(sentence: str, mask_token: str) -> str:
    """Insert ``mask_token`` at the end of ``sentence``.

    Placement rules:
      * sentence ends with ``":"``: the mask is appended after the colon
        (``"Answer:"`` -> ``"Answer: <mask>"``);
      * sentence ends with any other punctuation character: the mask goes
        before it (``"Great."`` -> ``"Great <mask> ."``);
      * otherwise, including an empty sentence: the mask is appended and
        followed by a period (``"Great"`` -> ``"Great <mask>."``).

    Args:
        sentence: Prompt text to extend.
        mask_token: Mask token string to insert.

    Returns:
        The sentence with the mask token inserted.
    """
    # Check the colon first: ":" is itself in string.punctuation and would
    # otherwise be moved behind the mask.
    if sentence.endswith(":"):
        return sentence + " " + mask_token
    # Empty input has no last character to inspect; treat it as unpunctuated.
    if not sentence or sentence[-1] not in string.punctuation:
        return sentence + " " + mask_token + "."
    return sentence[:-1] + " " + mask_token + " " + sentence[-1]

def normalize_answer(s):
    """Lower-case ``s`` and replace the number words zero..nine with digits.

    Only whole words are replaced, so words that merely contain a number
    word (e.g. "someone", "height") are left intact.

    Args:
        s: Text to normalize.

    Returns:
        The lower-cased text with standalone number words converted to
        their digit characters.
    """
    digit_for_word = {
        "zero": "0",
        "one": "1",
        "two": "2",
        "three": "3",
        "four": "4",
        "five": "5",
        "six": "6",
        "seven": "7",
        "eight": "8",
        "nine": "9",
        # more as needed
    }
    # \b anchors restrict matches to whole words; the keys are plain letters,
    # so no regex escaping is needed.
    number_word = r"\b(?:" + "|".join(digit_for_word) + r")\b"
    return re.sub(
        number_word, lambda match: digit_for_word[match.group(0)], s.lower()
    )


def answer_cleansing(text):
    """Extract the first and last number mentioned in ``text``.

    Commas are removed first, so "1,234" is read as 1234. Numbers are
    optionally negative and may have a decimal part.

    Args:
        text: Free-form answer text.

    Returns:
        ``[first, last]`` as floats (the same value twice when only one
        number is present), or ``None`` when no number is found.
    """
    without_commas = text.replace(",", "")
    numbers = re.findall(r"-?\d+\.?\d*", without_commas)
    if not numbers:
        return None
    first, last = numbers[0], numbers[-1]
    return [float(first), float(last)]



def k_init_pop(initial_mode, init_population, k):
    """Select an initial population from ``init_population``.

    Supported values of ``initial_mode``:
      * ``"topk"`` / ``"bottomk"``: the first / last ``k`` items;
      * ``"randomk"``: ``k`` items drawn with ``random.sample``;
      * ``"para_topk"`` / ``"para_bottomk"`` / ``"para_randomk"``: the same
        selections with ``k // 2`` items.

    Args:
        initial_mode: One of the mode names listed above.
        init_population: Sequence of candidates to select from.
        k: Target population size.

    Returns:
        A new list with the selected items.

    Raises:
        ValueError: If ``initial_mode`` is not a supported mode, or (from
            ``random.sample``) if more items are requested than available.
    """
    # Written as k // 2 rather than -k // 2: the latter parses as (-k) // 2
    # and rounds away from zero for odd k, yielding one item too many.
    half = k // 2

    def first(count):
        return list(init_population[:count])

    def last(count):
        # Explicit start index: a plain [-count:] slice returns the whole
        # sequence when count is 0.
        start = max(len(init_population) - count, 0)
        return list(init_population[start:])

    if initial_mode == "topk":
        return first(k)
    if initial_mode == "para_topk":
        return first(half)
    if initial_mode == "bottomk":
        return last(k)
    if initial_mode == "para_bottomk":
        return last(half)
    if initial_mode == "randomk":
        return random.sample(init_population, k)
    if initial_mode == "para_randomk":
        return random.sample(init_population, half)
    raise ValueError("Unknown initial_mode: {!r}".format(initial_mode))
