import csv
import math
import pickle
import time
from collections import defaultdict
from queue import Queue, PriorityQueue

import numpy as np
import torch
import torch.nn as nn
import transformers
from loguru import logger
from tqdm import tqdm

from vds_shared import CFG_PROMPT_STYLE


def get_attr(mod: nn.Module, attrs: str):
    """Resolve a dotted attribute path starting from ``mod``.

    ``get_attr(model, "transformer.wte.weight")`` is equivalent to
    ``model.transformer.wte.weight``.

    Args:
        mod: Root object to start the lookup from.
        attrs: Attribute names joined by ``"."``.

    Returns:
        The object found at the end of the path.

    Raises:
        AttributeError: If any component along the path does not exist.
    """
    node = mod
    path_parts = attrs.split(".")
    for part in path_parts:
        node = getattr(node, part)
    return node


def set_attr(mod: nn.Module, attrs: str, new_mod: nn.Module):
    """Assign ``new_mod`` to the attribute at dotted path ``attrs`` under ``mod``.

    Every component except the last is resolved with ``getattr``; the last
    component is then set on that owner with ``setattr``. For example,
    ``set_attr(model, "encoder.layer", x)`` performs
    ``model.encoder.layer = x``.

    Raises:
        AttributeError: If an intermediate component does not exist.
    """
    *owner_path, leaf_name = attrs.split(".")
    owner = mod
    for step in owner_path:
        owner = getattr(owner, step)
    setattr(owner, leaf_name, new_mod)

###
def stabilize(reproducibility=True, seed=42):
    """Seed all random number generators and configure cuDNN behaviour.

    Args:
        reproducibility: If truthy, disable cuDNN benchmarking and request
            deterministic cuDNN kernels; otherwise enable benchmarking and
            allow non-deterministic kernels.
        seed: Seed passed to Python's ``random``, NumPy,
            ``transformers.set_seed``, and torch (CPU, current CUDA device,
            and all CUDA devices).
    """
    import random

    # Seed every generator the training code may draw from.
    random.seed(seed)
    np.random.seed(seed)
    transformers.set_seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Benchmark mode (autotuning) and determinism are mutually exclusive here.
    deterministic = bool(reproducibility)
    torch.backends.cudnn.benchmark = not deterministic
    torch.backends.cudnn.deterministic = deterministic


def stylize(demo_pairs, query, data_name=None, prompt_style=None):
    """Render few-shot demonstrations and a query into a probing prompt.

    Args:
        demo_pairs: Iterable of ``(description, api)`` pairs rendered in order
            as in-context demonstrations.
        query: The description the model should complete.
        data_name: Dataset name. For ``'conala'``, ``'django'`` and ``'spoc'``
            the output line is labelled ``Code``; for anything else ``API``.
        prompt_style: ``'colon'`` or ``'lines'``. When ``None`` (the default)
            the module-level ``CFG_PROMPT_STYLE`` is used.

    Returns:
        ``(probing_context, probing_prefix)``: the concatenated rendered
        demonstrations, and the rendered query awaiting the model's output.

    Raises:
        NotImplementedError: If the prompt style is not supported.
    """
    style = CFG_PROMPT_STYLE if prompt_style is None else prompt_style

    # Every dataset uses the same input label; only the output label varies.
    input_indicator = 'Query'
    output_indicator = 'Code' if data_name in ('conala', 'django', 'spoc') else 'API'

    # Dispatch on the style once instead of on every demonstration.
    if style == 'colon':
        demos = [f'{desc}:{api}\n' for desc, api in demo_pairs]
        probing_prefix = f'{query}:'
    elif style == 'lines':
        demos = [
            f'{input_indicator}:{desc}\n{output_indicator}:{api}\n\n'
            for desc, api in demo_pairs
        ]
        probing_prefix = f'{input_indicator}:{query}\n{output_indicator}:'
    else:
        raise NotImplementedError(f'Unsupported prompt style: {style!r}')

    probing_context = ''.join(demos)
    return probing_context, probing_prefix


def print_info(model):
    """Print ``model`` and then the name and size of each parameter of
    ``model.base_model``, one parameter per line."""
    print(model)
    named_params = model.base_model.named_parameters()
    for param_name, tensor in named_params:
        shape = tensor.size()
        print(param_name, shape)


def timeit(func):
    """Decorate ``func`` so each call is timed and logged at debug level.

    The wrapped callable returns ``(elapsed_time, results)`` instead of just
    ``results``. ``elapsed_time`` is the wall-clock duration in seconds
    measured with ``time.perf_counter`` and rounded by ``format_score``.
    ``functools.wraps`` keeps ``func``'s name, docstring and ``__wrapped__``
    on the wrapper so introspection and stacked decorators see the original.
    """
    from functools import wraps

    @wraps(func)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        results = func(*args, **kwargs)
        end = time.perf_counter()
        elapsed_time = format_score(end - start)
        logger.debug(f"{func.__name__} takes {elapsed_time} seconds")
        return elapsed_time, results

    return wrapper


def format_score(datum):
    """Round a numeric score to three decimal places via the builtin ``round``."""
    decimals = 3
    rounded = round(datum, decimals)
    return rounded


def format_ratio(pre_score, post_score):
    """Summarize the change between two scores for reporting.

    Args:
        pre_score: Score before the change.
        post_score: Score after the change.

    Returns:
        ``(pre, post, abs_ratio)``, where ``pre`` and ``post`` are the inputs
        rounded by ``format_score``. ``abs_ratio`` is
        ``(post_score - pre_score) * 100`` rounded by ``format_score`` and
        rendered with an explicit sign and a ``%`` suffix (e.g. ``'+2.5%'``,
        ``'-1.0%'``). The sign is ``'+'`` for non-negative differences.
        NOTE(review): the ``* 100`` presumably assumes scores are fractions
        (e.g. accuracy in [0, 1]), making this an absolute difference in
        percentage points rather than a relative change — confirm with callers.
    """
    delta = format_score((post_score - pre_score) * 100)
    if delta == 0:
        # A tiny negative difference rounds to -0.0; normalise it so an
        # effectively unchanged score is reported as '+0.0%', not '-0.0%'.
        delta = abs(delta)
    # Negative values already carry their own '-' when formatted.
    sign_prefix = '+' if delta >= 0 else ''
    abs_ratio = f'{sign_prefix}{delta}%'
    return format_score(pre_score), format_score(post_score), abs_ratio
