import os
import json
import torch
import re
import numpy as np
import pandas as pd
from tqdm import tqdm
import multiprocessing
from multiprocessing import Manager, Pool
from copy import copy

from utils.rephraser import prompt_rephraser
from utils.gpt_eval import gpt_winner_evaluator

from config import args, questions, instruction, prefs

class MatrixBuilder:
    """Build a pairwise win matrix between response-generation methods ("players").

    For every evaluated data point, each unordered pair of methods is compared
    once via ``gpt_winner_evaluator``. The summed scores are written to
    ``results/gpt_eval/<pref>/<model>/<dataset>/gpt_winner_mtx.csv``, together
    with a per-comparison log in ``gpt_logs.json`` in the same directory.
    """

    # Extracts the text between the end of the system block (``<</SYS>>``) and
    # the closing ``[/INST]`` tag -- presumably the user question of a
    # Llama-2-style chat prompt (TODO confirm against the prompt builder).
    _QUESTION_PATTERN = re.compile(r"<</SYS>>(.*?)\[/INST\]", re.DOTALL)

    # Methods being compared. Each name is also the stem of its response file
    # ``responses/<pref>/<model>/<dataset>/<name>.json``.
    PLAYER_NAMES = ('base', 'pref', 'la', 'amulet20', 'amulet40', 'amulet60', 'amulet80', 'amulet100')

    def __init__(self, args):
        """Store the run configuration.

        Args:
            args: namespace-like object providing ``pref_name``, ``model_name``
                and ``eval_data``; these select input and output directories.
        """
        self.args = args
        self.dir_path = os.getcwd()
        self.dimen = args.pref_name
        self.model_name = args.model_name
        self.dataset_name = args.eval_data

    @staticmethod
    def _ensure_parent_dir(file_path):
        """Create the directory that will contain ``file_path``, if it names one."""
        parent = os.path.dirname(file_path)
        # os.makedirs('') raises FileNotFoundError, so skip bare file names.
        if parent:
            os.makedirs(parent, exist_ok=True)

    def save_data(self, data, file_path):
        """Write a DataFrame to ``file_path`` as CSV (index included), creating parent dirs."""
        self._ensure_parent_dir(file_path)
        data.to_csv(file_path)

    def save_json(self, data, file_path):
        """Serialize ``data`` as JSON to ``file_path``, creating parent dirs."""
        self._ensure_parent_dir(file_path)
        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(data, f)

    def load_df_data(self, file_path):
        """Load a CSV file into a DataFrame."""
        return pd.read_csv(file_path)

    def load_data(self, file_path):
        """Load and return the JSON document stored at ``file_path`` (read as UTF-8)."""
        with open(file_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def extract_question_content(self, text):
        """Return the stripped text between ``<</SYS>>`` and ``[/INST]``, or None if absent."""
        match = self._QUESTION_PATTERN.search(text)
        return match.group(1).strip() if match else None

    def _response_path(self, player_name):
        """Path of the JSON file holding ``player_name``'s responses."""
        return os.path.join(
            self.dir_path,
            f"responses/{self.dimen}/{self.model_name}/{self.dataset_name}/{player_name}.json",
        )

    def _results_path(self, file_name):
        """Path of an output file under this run's ``results/gpt_eval`` directory."""
        return os.path.join(
            self.dir_path,
            f'results/gpt_eval/{self.dimen}/{self.model_name}/{self.dataset_name}/{file_name}',
        )

    def get_gpt_eval(self, start_point = 0, end_point = 5):
        """Load every method's responses, run the pairwise GPT comparison and print the matrix.

        The ``la`` response list defines which data points (and how many) are evaluated.

        Args:
            start_point: currently ignored; every data point is evaluated. Kept
                for call compatibility (the slicing that used it is disabled).
            end_point: currently ignored, see ``start_point``.

        Returns:
            The score matrix produced by ``get_matrix``.

        Raises:
            ValueError: if some method has fewer responses than ``la``.
        """
        player_names_list = list(self.PLAYER_NAMES)
        players_list = [self.load_data(self._response_path(name)) for name in player_names_list]
        la_datas = players_list[player_names_list.index('la')]

        print(f'Tested data len is {len(la_datas)}.')

        # Every method is indexed by data-point position, so each must cover all
        # of la's entries; fail here rather than with an IndexError in a worker.
        short = [
            name for name, datas in zip(player_names_list, players_list)
            if len(datas) < len(la_datas)
        ]
        if short:
            raise ValueError(
                f'Response files have fewer than {len(la_datas)} entries (la.json length): {short}'
            )

        mtx = self.get_matrix(la_datas, player_names_list, players_list)

        print(mtx)
        return mtx

    @staticmethod
    def _apply_verdict(mtx, i, j, win_tag):
        """Credit the outcome of method ``i`` vs method ``j`` into ``mtx``.

        ``win_tag == 1`` credits ``mtx[i][j]``; ``win_tag == -1`` credits
        ``mtx[j][i]``; any other value is a tie worth 0.5 to each side.
        """
        if win_tag == 1:
            mtx[i][j] += 1
        elif win_tag == -1:
            mtx[j][i] += 1
        else:
            mtx[i][j] += 0.5
            mtx[j][i] += 0.5

    def _judge_pairs(self, data_point, player_names_list, responses):
        """Yield ``(i, j, win_tag, log_record)`` for every method pair ``i < j`` on one data point.

        ``responses[k]`` is method ``k``'s response text for ``data_point``.
        Response ``i`` is passed to ``gpt_winner_evaluator`` before response ``j``.
        """
        question = self.extract_question_content(data_point['preference'])
        if question is None:
            # The prompt lacks the expected chat markers; fall back to the raw
            # question instead of sending ``None`` to the judge.
            question = data_point['question']
        criterion = f'Your answer should be {self.dimen} as much as possible.'

        n_players = len(responses)
        for i in range(n_players):
            for j in range(i + 1, n_players):
                win_tag = gpt_winner_evaluator(question, criterion, responses[i], responses[j])
                yield i, j, win_tag, {
                    'question': data_point['question'],
                    'preference': data_point['preference'],
                    'text_1': responses[i],
                    'text_2': responses[j],
                    'method_1': player_names_list[i],
                    'method_2': player_names_list[j],
                    'label': win_tag,
                }

    def _judge_data_point(self, task):
        """Pool worker: judge all method pairs on one data point, touching no shared state.

        Args:
            task: ``(data_idx, data_point, player_names_list, responses)``.

        Returns:
            ``(data_idx, delta, logs)``: ``delta[r, c]`` is the score method ``r``
            earned against method ``c`` on this data point (diagonal 0), and
            ``logs`` holds one record per comparison.
        """
        data_idx, data_point, player_names_list, responses = task
        n_players = len(responses)
        delta = np.zeros((n_players, n_players))
        logs = []
        for i, j, win_tag, record in self._judge_pairs(data_point, player_names_list, responses):
            self._apply_verdict(delta, i, j, win_tag)
            logs.append(record)
        return data_idx, delta, logs

    def process_data_point_worker(self, args):
        """Judge one data point and merge the result into a caller-supplied matrix and log list.

        Kept for backward compatibility; ``get_matrix`` no longer uses it.
        NOTE: the ``+=`` updates are read-modify-write operations. Calling this
        from several processes on one shared (e.g. Manager) matrix can lose increments.

        Args:
            args: ``(mtx, logs, data_idx, data_point, player_names_list, players_list)``.
                ``mtx`` is mutated in place (diagonal set to -1) and records are appended to ``logs``.
        """
        mtx, logs, data_idx, data_point, player_names_list, players_list = args
        for k in range(len(mtx)):
            mtx[k][k] = -1
        responses = [players[data_idx]['response'] for players in players_list]
        for i, j, win_tag, record in self._judge_pairs(data_point, player_names_list, responses):
            self._apply_verdict(mtx, i, j, win_tag)
            logs.append(record)
        if (data_idx + 1) % 10 == 0:
            print(f'===================Finished {data_idx + 1} data points!==========================')
            print(np.array([list(row) for row in mtx]))
        return None

    def get_matrix(self, dataset, player_names_list, players_list):
        """Run the pairwise comparison over ``dataset`` in a process pool and save the results.

        Each worker returns its own per-data-point score matrix and logs, and
        the parent sums them. This avoids concurrent non-atomic updates to
        shared state.

        Args:
            dataset: data points (dicts with ``'question'`` and ``'preference'``);
                its length sets how many positions are evaluated.
            player_names_list: method names, used as matrix labels and in logs.
            players_list: per-method response lists aligned with
                ``player_names_list``, indexed by data-point position; entries
                need a ``'response'`` key.

        Returns:
            A float ``np.ndarray`` where ``[r, c]`` is the total score of method
            ``r`` against method ``c``. The diagonal is -1 when at least one
            data point was evaluated.
        """
        n_players = len(player_names_list)
        # Send each worker only this data point's responses; previously every
        # method's full response list was pickled once per task.
        tasks = [
            (data_idx, data_point, player_names_list,
             [players[data_idx]['response'] for players in players_list])
            for data_idx, data_point in enumerate(dataset)
        ]

        mtx = np.zeros((n_players, n_players))
        if tasks:
            np.fill_diagonal(mtx, -1)
        logs = []

        with Pool(processes = multiprocessing.cpu_count()) as pool:
            results = pool.imap(self._judge_data_point, tasks)
            # Wrap the result iterator (not the task list) so the bar tracks completed judgments.
            for done, (_, delta, point_logs) in enumerate(tqdm(results, total=len(tasks)), start=1):
                mtx += delta
                logs.extend(point_logs)
                if done % 10 == 0:
                    print(f'===================Finished {done} data points!==========================')
                    print(mtx)

        rsts = pd.DataFrame(mtx, index = player_names_list, columns = player_names_list)

        self.save_data(rsts, self._results_path('gpt_winner_mtx.csv'))
        self.save_json(logs, self._results_path('gpt_logs.json'))

        return mtx

    




# Guard the entry point: with the "spawn" start method, multiprocessing
# re-imports this module in every Pool worker, and an unguarded call would
# re-run the whole evaluation inside each child.
if __name__ == "__main__":
    evaluator = MatrixBuilder(args)

    evaluator.get_gpt_eval(start_point = 0, end_point = 5)




