from statistics import mean
from collections import OrderedDict
import json
import numpy as np
import random
import torch
import random
import datetime
from torch.utils.data import Dataset
import multiprocessing
import re 


# https://review-of-my-life.blogspot.com/2017/11/python-dict-shuffle.html
def shuffleDict(d):
  """Return a new dict containing the items of ``d`` in random order.

  The key list is shuffled with the module-level ``random`` generator, so
  seeding ``random`` beforehand makes the resulting order reproducible.
  ``d`` itself is not modified.

  Args:
    d: mapping whose items should be reordered.

  Returns:
    A new ``dict`` with the same key/value pairs as ``d``.
  """
  keys = list(d.keys())
  # Three shuffle passes are kept on purpose: they determine how much random
  # state is consumed, and therefore the order produced for a given seed.
  for _ in range(3):
    random.shuffle(keys)
  return {key: d[key] for key in keys}

def fix_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: value passed unchanged to every seeding function.
    """
    seeders = (
        random.seed,                 # random
        np.random.seed,              # Numpy
        torch.manual_seed,           # Pytorch
        torch.cuda.manual_seed_all,  # Pytorch (CUDA devices)
    )
    for apply_seed in seeders:
        apply_seed(seed)
    # NOTE: torch.backends.cudnn.deterministic is not set here.
    
def print_now(return_flag=0):
    """Print or return the current time in the fixed UTC+9 ('JST') zone.

    The timestamp is formatted as ``YYYY/MM/DD HH:MM:SS``.

    Args:
        return_flag: ``0`` prints the timestamp, ``1`` returns it as a
            string, any other value does nothing.

    Returns:
        The formatted timestamp when ``return_flag == 1``, otherwise ``None``.
    """
    jst = datetime.timezone(datetime.timedelta(hours=9), 'JST')
    stamp = datetime.datetime.now(jst).strftime('%Y/%m/%d %H:%M:%S')
    if return_flag == 0:
        print(stamp)
        return None
    return stamp if return_flag == 1 else None

def get_system_prompt(args):
    """Return the fixed (Chinese) system prompt.

    The prompt asks the model to think step by step and to emit its chain of
    thought and its answer under two quoted labels. Each label is wrapped in
    backslash-escaped braces; those literal backslashes are part of the text
    and ``answer_cleansing`` splits on the same answer label.

    Args:
        args: unused; kept for interface compatibility.

    Returns:
        The prompt as a single-line string.
    """
    # The three segments are separated by eight spaces. Backslashes are
    # written as "\\" so the literal contains a real backslash without
    # relying on an invalid escape sequence.
    gap = " " * 8
    system_prompt = (
        "让我们逐步进行思考。将你的思考过程和你的回答按照如下格式输出："
        + gap + '"\\{思维链\\}":'
        + gap + '"\\{你的回答\\}":'
    )
    return system_prompt

    
def get_demo(args):
    """Return the demonstration string selected by ``args.demo_flag``.

    Both the enabled and the disabled case currently yield an empty string;
    the flag is still read, so ``args`` must expose ``demo_flag``.
    """
    if not args.demo_flag:
        return ""
    # Demonstrations are enabled, but the demonstration text is empty.
    demo_text = ""
    return demo_text

def get_dataset_args(args):
    """Set ``args.dataset_path`` according to ``args.dataset``.

    Recognised dataset names are ``"complexbench"`` and ``"parallel"``; both
    currently map to an empty path.

    Raises:
        ValueError: if ``args.dataset`` is not one of the recognised names.
    """
    known_paths = (
        ("complexbench", ""),
        ("parallel", ""),
    )
    for dataset_name, dataset_path in known_paths:
        if args.dataset == dataset_name:
            args.dataset_path = dataset_path
            return
    raise ValueError("dataset is not properly defined ...")

def data_reader(args):
    """Load questions, answers and parse groups for ``args.dataset``.

    ``get_dataset_args`` is called first to populate ``args.dataset_path``;
    it raises ``ValueError`` for an unrecognised dataset name.

    * ``"complexbench"``: the file is one JSON document that is iterated
      record by record. Each record contributes ``structured_ins``
      (stripped) as the question, ``str(main_id)`` as the answer and
      ``ind_group`` as the parse group.
    * ``"parallel"``: the file holds one JSON object per line. Each record
      contributes ``instruction`` (stripped), ``label`` and ``para_num``.
      Blank lines are skipped.

    Files are read as UTF-8. A short summary (dataset name, sample count and
    mean number of space-separated words per question) is printed.

    Args:
        args: namespace providing ``dataset``; ``dataset_path`` is set on it.

    Returns:
        A tuple ``(questions, answers, parse_groups)`` of three lists with
        one entry per sample.
    """
    questions = []
    answers = []
    parse_groups = []
    get_dataset_args(args)

    if args.dataset == "complexbench":
        with open(args.dataset_path, encoding="utf-8") as f:
            json_data = json.load(f)
        for record in json_data:
            questions.append(record["structured_ins"].strip())
            answers.append(str(record["main_id"]))
            parse_groups.append(record["ind_group"])
    elif args.dataset == "parallel":
        decoder = json.JSONDecoder()
        with open(args.dataset_path, encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Tolerate empty lines (e.g. a trailing newline at EOF).
                    continue
                # raw_decode parses the first JSON value on the line and
                # ignores anything that follows it.
                json_res = decoder.raw_decode(line)[0]
                questions.append(json_res["instruction"].strip())
                answers.append(json_res["label"])
                parse_groups.append(json_res["para_num"])

    q_len_list = [len(q.split(" ")) for q in questions]
    # statistics.mean raises on an empty sequence; report 0 for an empty set.
    q_len_mean = mean(q_len_list) if q_len_list else 0

    print("dataset : {}".format(args.dataset))
    print("data size : {}".format(len(answers)))
    print("average num of words for each sample : {}".format(q_len_mean))

    return questions, answers, parse_groups

# Create dataset object before dataloader ...
class MyDataset(Dataset):
    """Map-style dataset over the samples returned by ``data_reader``.

    Each item is a ``(question, answer, parse_group)`` tuple.
    """

    def __init__(self, args):
        super().__init__()
        loaded = data_reader(args)
        self.questions, self.answers, self.parse_groups = loaded
        # Sample count is cached once; it is what __len__ reports.
        self.len = len(self.questions)

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        return (
            self.questions[index],
            self.answers[index],
            self.parse_groups[index],
        )

def setup_data_loader(args):
    """Build a seeded ``DataLoader`` over ``MyDataset(args)``.

    Randomness is fixed following
    https://pytorch.org/docs/stable/notes/randomness.html : the global RNGs
    are seeded with ``args.random_seed``, and a CPU ``torch.Generator`` plus a
    ``worker_init_fn`` are handed to the loader.

    Args:
        args: namespace providing ``random_seed``, ``max_num_worker``,
            ``minibatch_size`` and ``dataset`` (plus whatever ``MyDataset``
            reads from it).

    Returns:
        A ``torch.utils.data.DataLoader``. The ``"complexbench"`` dataset is
        served in file order; every other dataset is shuffled.
    """
    fix_seed(args.random_seed)
    worker_seed = torch.initial_seed() % 2**32
    print("worker_seed : {}".format(worker_seed))

    def seed_worker(worker_id):
        # NOTE(review): worker_seed is captured from the parent process, so
        # every worker seeds numpy/random with the same value.
        np.random.seed(worker_seed)
        random.seed(worker_seed)

    g = torch.Generator(device='cpu')
    g.manual_seed(worker_seed)

    # Never start more workers than there are CPUs or than the caller allows.
    dataloader_num_workers = min(multiprocessing.cpu_count(), args.max_num_worker)
    print("dataloader_num_workers: " + str(dataloader_num_workers))

    dataset = MyDataset(args)

    # Dataset names are lower-case (get_dataset_args only accepts
    # "complexbench" / "parallel"), so the comparison must be lower-case too,
    # otherwise the no-shuffle case can never be selected.
    shuffle = args.dataset != "complexbench"

    dataloader = torch.utils.data.DataLoader(
        dataset,
        shuffle=shuffle,
        batch_size=args.minibatch_size,
        drop_last=False,
        num_workers=dataloader_num_workers,
        worker_init_fn=seed_worker,
        generator=g,
        pin_memory=True,
    )

    return dataloader

def answer_cleansing(output):
    """Return the text that follows the first answer marker in ``output``.

    The marker is the answer label wrapped in backslash-escaped braces, the
    same label that ``get_system_prompt`` asks the model to emit. When the
    marker is absent, ``output`` is returned unchanged.

    Args:
        output: raw model output string.

    Returns:
        Everything after the first occurrence of the marker, or the whole
        string if there is none.
    """
    # The marker contains literal backslashes; "\\" spells them explicitly
    # instead of relying on an invalid escape sequence.
    marker = "\\{你的回答\\}"
    content = output.split(marker, 1)[-1]
    return content

def save_results(data, path):
    """Append ``data`` to ``path`` as one JSON line.

    Non-ASCII characters are written verbatim (``ensure_ascii=False``), so
    the file is opened explicitly as UTF-8 rather than with the platform's
    default encoding.

    Args:
        data: JSON-serialisable object.
        path: file to append to; created if it does not exist.
    """
    data_json = json.dumps(data, ensure_ascii=False)
    with open(path, "a", encoding="utf-8") as f:
        f.write(data_json + "\n")

def get_parse_group(x):
    """Split ``x`` into its ``[Question N]: ...`` segments.

    A segment starts at a ``[Question N]: `` header and runs (across
    newlines) until the next such header at the start of a line, or the end
    of the string.

    Args:
        x: text containing zero or more question segments.

    Returns:
        ``{"parallel": [[segment], ...]}`` with every segment wrapped in its
        own single-element list, in order of appearance.
    """
    segment_re = r"(\[Question \d+\]: .+?)(?=\n\[Question \d+\]:|\Z)"
    segments = re.findall(segment_re, x, re.DOTALL)
    return {"parallel": [[segment] for segment in segments]}

def clean_para_generation(y, num_para):
    """Extract one numeric prediction per ``[Answer N]: ...`` segment of ``y``.

    For each segment, commas are removed and the last number found is taken
    as the prediction; a trailing period is dropped. A segment without any
    number yields an empty string.

    Args:
        y: generated text containing zero or more answer segments.
        num_para: unused; kept for interface compatibility.

    Returns:
        A list of prediction strings, one per segment, in order.
    """
    segment_re = r"\[Answer \d+\]: (.+?)(?=\n\[Answer \d+\]:|\Z)"
    number_re = r'-?\d+\.?\d*'
    predictions = []
    for segment in re.findall(segment_re, y, re.DOTALL):
        candidates = re.findall(number_re, segment.replace(",", ""))
        # Fall back to an empty string when the segment has no number.
        last = candidates[-1] if candidates else ""
        # (For arithmetic tasks) a number ending a sentence keeps its period.
        if last.endswith("."):
            last = last[:-1]
        predictions.append(last)
    return predictions


