from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import hydra
import re
from tqdm import tqdm
import csv
import omegaconf
from omegaconf import DictConfig, OmegaConf
from collections import defaultdict,OrderedDict

from torch.utils.data import Dataset, DataLoader

from utils import (
    dotdict,
    load_csv,
    get_universal_manual_prompt,
)


def processing(cfg):
    """Extract the successful samples of one results CSV as training data.

    Loads ``cfg.data.path`` with ``load_csv``, keeps every row accepted by
    the judgement selected in ``cfg.judgement`` and writes the kept rows to
    ``cfg.data.save_path`` as a CSV with the columns ``instruct``,
    ``target``, ``suffix`` and ``full_instruct``.

    A row is accepted when every judge column required by the chosen
    judgement holds the string ``'True'`` and the response does not contain
    the placeholder text ``'generate 2 sentences here'``.

    When ``cfg.is_filter`` is set, only rows whose prompt collected fewer
    than ``cfg.threshold`` accepted rows are written.

    Args:
        cfg: Configuration object. The fields read here are ``data.path``,
            ``data.save_path``, ``judgement`` (``'rule'``, ``'guard'`` or
            ``'both'``), ``is_filter`` and, only when filtering,
            ``threshold``.

    Raises:
        ValueError: If ``cfg.judgement`` is not ``'rule'``, ``'guard'`` or
            ``'both'``. The check runs before any data is loaded.
    """
    # Judge columns that must read 'True' for each judgement mode.
    required_judges = {
        'rule': ('rule_judge',),
        'guard': ('guard_judge',),
        'both': ('guard_judge', 'rule_judge'),
    }
    # Raising a plain string is itself a TypeError in Python 3, so a real
    # exception type is used, and the check is done once instead of per row.
    if cfg.judgement not in required_judges:
        raise ValueError("The judgement should be one of rule, guard and both!")
    judge_columns = required_judges[cfg.judgement]

    all_data = load_csv(cfg.data.path)

    field_words = ['instruct', 'target', 'suffix', 'full_instruct']
    accepted_rows = []             # one [instruct, target, suffix, full_instruct] per accepted row
    successful_prompts = set()     # initial instructions with at least one accepted row
    success_counts = defaultdict(int)  # built prompt -> number of accepted rows
    rule_nums = guard_nums = 0

    # Index-based access is kept on purpose: the container type returned by
    # load_csv is not visible from this file.
    for i in range(len(all_data)):
        row = all_data[i]
        rule_result = row['rule_judge']
        guard_result = row['guard_judge']
        initial_instruct = row['initial_instruct']
        target = row['target']
        instruct = get_universal_manual_prompt('best_llama2', target, initial_instruct)
        full_instruct = row['full_instruct']
        suffix = row['suffix']
        response = row['response']

        # These two counters are independent of the selected judgement.
        if rule_result == 'True':
            rule_nums += 1
        if guard_result == 'True':
            guard_nums += 1

        passed = (
            all(row[column] == 'True' for column in judge_columns)
            and 'generate 2 sentences here' not in response
        )
        # Adding 0 still registers the prompt, so prompts without any
        # accepted row are included in the "fewer than threshold" count.
        success_counts[instruct] += int(passed)
        if passed:
            accepted_rows.append([instruct, target, suffix, full_instruct])
            successful_prompts.add(initial_instruct)

    tqdm.write(f"{rule_nums} samples pass the rule-based judgement\n{guard_nums} pass the guard-based judgement\n{len(successful_prompts)} prompts are successful!")
    tqdm.write(f"There are {len(accepted_rows)} samples in the training set")

    if cfg.is_filter:
        # Keep only prompts with strictly fewer than `threshold` accepted rows.
        selected_prompt = {
            prompt for prompt, hits in success_counts.items() if hits < cfg.threshold
        }
        tqdm.write(f"There are {len(selected_prompt)} prompts having less than {cfg.threshold} successful suffix!")
        data_point = [entry for entry in accepted_rows if entry[0] in selected_prompt]
    else:
        data_point = accepted_rows
    tqdm.write(f'there are {len(data_point)} points')

    # newline='' is what the csv module expects for files it writes to;
    # without it rows end in "\r\r\n" on platforms that translate "\n".
    with open(cfg.data.save_path, 'w', newline='') as f:
        csvwriter = csv.writer(f, quoting=csv.QUOTE_NONNUMERIC)
        csvwriter.writerow(field_words)
        csvwriter.writerows(data_point)

def processing_multi(cfg):
    """Extract the successful samples of several results CSVs as training data.

    Loads every path in ``cfg.data.paths`` with ``load_csv``, keeps every
    row accepted by the judgement selected in ``cfg.judgement`` and writes
    the kept rows of all files to ``cfg.data.save_path`` as one CSV with the
    columns ``instruct``, ``target``, ``suffix`` and ``full_instruct``.

    A row is accepted when every judge column required by the chosen
    judgement holds the string ``'True'`` and the response does not contain
    the placeholder text ``'generate 2 sentences here'``.

    When ``cfg.is_filter`` is set, only rows whose prompt collected fewer
    than ``cfg.threshold`` accepted rows are written. The per-prompt counts
    are accumulated across all input files.

    Args:
        cfg: Configuration object. The fields read here are ``data.paths``,
            ``data.save_path``, ``judgement`` (``'rule'``, ``'guard'`` or
            ``'both'``), ``is_filter`` and, only when filtering,
            ``threshold``.

    Raises:
        ValueError: If ``cfg.judgement`` is not ``'rule'``, ``'guard'`` or
            ``'both'``. The check runs before any data is loaded.
    """
    # Judge columns that must read 'True' for each judgement mode.
    required_judges = {
        'rule': ('rule_judge',),
        'guard': ('guard_judge',),
        'both': ('guard_judge', 'rule_judge'),
    }
    # Raising a plain string is itself a TypeError in Python 3, so a real
    # exception type is used, and the check is done once instead of per row.
    if cfg.judgement not in required_judges:
        raise ValueError("The judgement should be one of rule, guard and both!")
    judge_columns = required_judges[cfg.judgement]

    field_words = ['instruct', 'target', 'suffix', 'full_instruct']
    accepted_rows = []             # one [instruct, target, suffix, full_instruct] per accepted row
    success_counts = defaultdict(int)  # built prompt -> accepted rows over all files

    for path in cfg.data.paths:
        all_data = load_csv(path)

        # Index-based access is kept on purpose: the container type returned
        # by load_csv is not visible from this file.
        for i in range(len(all_data)):
            row = all_data[i]
            initial_instruct = row['initial_instruct']
            target = row['target']
            instruct = get_universal_manual_prompt('best_llama2', target, initial_instruct)
            full_instruct = row['full_instruct']
            suffix = row['suffix']
            response = row['response']

            passed = (
                all(row[column] == 'True' for column in judge_columns)
                and 'generate 2 sentences here' not in response
            )
            # Adding 0 still registers the prompt, so prompts without any
            # accepted row are included in the "fewer than threshold" count.
            success_counts[instruct] += int(passed)
            if passed:
                accepted_rows.append([instruct, target, suffix, full_instruct])

    if cfg.is_filter:
        # Keep only prompts with strictly fewer than `threshold` accepted rows.
        selected_prompt = {
            prompt for prompt, hits in success_counts.items() if hits < cfg.threshold
        }
        tqdm.write(f"There are {len(selected_prompt)} prompts having less than {cfg.threshold} successful suffix!")
        data_point = [entry for entry in accepted_rows if entry[0] in selected_prompt]
    else:
        data_point = accepted_rows
    tqdm.write(f'there are {len(data_point)} points')

    # newline='' is what the csv module expects for files it writes to;
    # without it rows end in "\r\r\n" on platforms that translate "\n".
    with open(cfg.data.save_path, 'w', newline='') as f:
        csvwriter = csv.writer(f, quoting=csv.QUOTE_NONNUMERIC)
        csvwriter.writerow(field_words)
        csvwriter.writerows(data_point)


@hydra.main(version_base=None, config_path="conf")
def main(cfg: DictConfig):
    """Hydra entry point: log the configuration and run the selected mode.

    The only supported ``cfg.mode`` is ``"extract_pretrain_data"``, which
    runs ``processing_multi`` when ``cfg.is_multi_paths`` is truthy and
    ``processing`` otherwise.

    Raises:
        ValueError: If ``cfg.mode`` is any other value.
    """
    tqdm.write("Starting run...")
    tqdm.write(f"Using parameters: \n{OmegaConf.to_yaml(cfg)}")

    # Reject unknown modes first so the supported path reads straight through.
    if cfg.mode != "extract_pretrain_data":
        raise ValueError(f"Mode {cfg.mode} not recognized.")

    extract = processing_multi if cfg.is_multi_paths else processing
    extract(cfg)
    tqdm.write("Finished!")


# Run the Hydra-decorated entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()