import ast
import json
import os

import pandas as pd
import torch
from accelerate import Accelerator
from tqdm import tqdm

from model import AutoLLM

from side.harmfulDemoDataset import *
from utils.data_utils import *
from utils.eval import *
from utils.train_utils import clear_gpu_cache, setup, setup_environ_flags
from utils.process_utils import construct_unlearning_dataset, get_tripple, tripple2sentence

from param import parse_args
from attack import Attacker
from finetune import Finetuner
from score import Judger, Scorer


def load_ft_model(args, mode, model_path=None):
    """Load the fine-tunable LLM described by ``args.llm_config_name``.

    ``AutoLLM.from_name`` resolves the wrapper class from the config path. That
    class is then instantiated with the same config path, the optional
    checkpoint ``model_path`` and the run ``mode`` (e.g. ``'train'``).
    """
    config_path = f'{args.config_cache_path}/{args.llm_config_name}'
    llm_cls = AutoLLM.from_name(config_path)
    return llm_cls(config=config_path, model_path=model_path, mode=mode)

def load_score_model(args, mode, model_path=None):
    """Load the scoring LLM described by ``args.score_config_name``.

    Same construction as the other loaders: resolve the class from the config
    path via ``AutoLLM.from_name``, then build it with that config, the
    optional checkpoint ``model_path`` and ``mode``.
    """
    cfg = f'{args.config_cache_path}/{args.score_config_name}'
    scorer_cls = AutoLLM.from_name(cfg)
    score_llm = scorer_cls(config=cfg, model_path=model_path, mode=mode)
    return score_llm

def load_gpt_model(args, mode, model_path=None):
    """Load the GPT helper model described by ``args.gpt_config_name``.

    In this script it is used for triple extraction/replacement and is handed
    to ``Finetuner``.
    """
    gpt_config = f'{args.config_cache_path}/{args.gpt_config_name}'
    return AutoLLM.from_name(gpt_config)(
        config=gpt_config,
        model_path=model_path,
        mode=mode,
    )

def construct_kgunl_dataset(args, finetune_llm, replace_GPT):
    # need to write the unlearning dataset function
    if args.unlearn_alg == "decomp":
        dataset = construct_unlearning_dataset(
            args, replace_GPT, finetune_llm.tokenizer, c_epoch
        )
    
    elif args.unlearn_alg == "kg_grad_asc":
        # read the data
        data = pd.read_excel(f"{args.root_path}/data/unlearn_demo.xlsx")
        prompts = data["Prompt"].tolist()
        prefixes = data["Prefix"].tolist()
        targets = data["Target"].tolist()
        harmful_resp = {}
        for prompt, prefix, target in zip(prompts, prefixes, targets):
            prefix = eval(prefix)
            target = eval(target)
            harmful_resp[prompt] = []
            for pref, tar in zip(prefix, target):
                harmful_resp[prompt].append((pref, tar))
        dataset = KGGradAscDataset(harmful_resp, finetune_llm.tokenizer)
    
    elif args.unlearn_alg == "kg_replace":
        # create tripples of harmful content
        if args.test_dataset == 'harmfulRespDemo':
            QA_path = f'{args.root_path}/data/harmful_resp_demo.json'
        elif args.test_dataset == 'harmfulResp':
            QA_path = f'{args.root_path}/data/final_harmful_resp.json'
        with open (QA_path) as f:
            QA_pairs = json.load(f)
        result_dir = f'{args.root_path}/result/finetune/{args.model_name}'

        if not os.path.exists(result_dir):
            os.makedirs(result_dir)

        trip_results = {}
        for prompt, response in tqdm(QA_pairs.items()):
            print(f'[Response]: {[response]}')
            tripple = get_tripple(prompt, response, replace_GPT)
            print(f'[Extract Tripple]:{[tripple]}')
            
            trip_results[prompt] = tripple
            trip_results_str = json.dumps(trip_results, indent=4)
            # save the tripple results
            with open(f'{result_dir}/{args.test_dataset}_harmful_tripples.json', 'w') as outfile:
                outfile.write(trip_results_str)

        with open(f'{result_dir}/{args.test_dataset}_harmful_tripples.json') as f:
            trip_results = json.load(f)

        # replace elements in tripples to create non-harmful sentence
        sent_results = {}
        for mode in args.replace_mode:
            for prompt, tripples in tqdm(trip_results.items()):
                temp = []
                for trip in tripples:
                    temp.append(tuple(trip))
                tripples = temp
                sentences = tripple2sentence(tripples, replace_GPT, mode)
                
                if len(sentences) != 0 :
                    print('----')
                    print(f'[Prompt]: {[prompt]}')
                    print(f'[Replace Sentence]: {[sentences]}')
                    sent_results[prompt] = sentences
                    sent_results_str = json.dumps(sent_results, indent=4)
                    # save the replacement results
                    with open(f'{result_dir}/{args.test_dataset}_nonharmful_sentences_{mode}.json', 'w') as outfile:
                        outfile.write(sent_results_str)

        # construct dataset
        dataset = KGReplaceDataset({}, finetune_llm.tokenizer)
        for mode in args.replace_mode:
            data_path = f'{args.root_path}/result/finetune/{args.model_name}/{args.test_dataset}_nonharmful_sentences_{mode}.json'
            with open(data_path) as f:
                replace_resp = json.load(f)
            dataset.update(replace_resp)
        
        return dataset

def construct_sorry_dataset(args, finetune_llm=None):
    """Build a ``SorryDataset`` from the harmful QA pairs of ``args.test_dataset``.

    Args:
        args: parsed CLI arguments (``root_path`` and ``test_dataset`` are read).
        finetune_llm: model wrapper whose ``tokenizer`` is given to the dataset.
            Defaults to the module-level ``finetune_llm`` set by this script's
            ``__main__`` section, so existing ``construct_sorry_dataset(args)``
            calls keep working.

    Returns:
        The ``SorryDataset`` updated with the loaded QA pairs.

    Raises:
        ValueError: unknown ``args.test_dataset``, or no model is available.
    """
    qa_files = {
        'harmfulRespDemo': 'harmful_resp_demo.json',
        'harmfulResp': 'final_harmful_resp.json',
    }
    if args.test_dataset not in qa_files:
        # Previously this fell through to an UnboundLocalError on QA_path.
        raise ValueError(
            f"Unsupported test_dataset {args.test_dataset!r}; "
            f"expected one of {sorted(qa_files)}"
        )
    QA_path = f'{args.root_path}/data/{qa_files[args.test_dataset]}'

    if finetune_llm is None:
        finetune_llm = globals().get('finetune_llm')
    if finetune_llm is None:
        raise ValueError(
            "construct_sorry_dataset needs a loaded finetune_llm; pass it explicitly"
        )

    with open(QA_path) as f:
        QA_pairs = json.load(f)

    # construct dataset
    dataset = SorryDataset({}, finetune_llm.tokenizer)
    dataset.update(QA_pairs)
    return dataset
    

if __name__ == '__main__':
    # Cyclic pipeline, one pass per cycle epoch:
    # attack -> eval attack -> KG unlearning -> eval unlearning.
    # Each stage is toggled by a flag below. This section intentionally stays at
    # module level: construct_kgunl_dataset / construct_sorry_dataset may read
    # the globals ``c_epoch`` and ``finetune_llm`` defined here.

    args = parse_args()

    accelerator = Accelerator()

    local_rank = None
    rank = None
    world_size = None

    if args.enable_fsdp:
        setup()
        # torchrun specific
        local_rank = int(os.environ["LOCAL_RANK"])
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])

    if torch.distributed.is_initialized():
        # A process group can exist without --enable_fsdp (presumably created
        # by Accelerator() under a multi-process launch). Derive the ranks
        # instead of hitting a NameError on ``rank`` or calling set_device(None).
        if local_rank is None:
            local_rank = int(os.environ.get("LOCAL_RANK", 0))
        if rank is None:
            rank = torch.distributed.get_rank()
        torch.cuda.set_device(local_rank)
        clear_gpu_cache(local_rank)
        setup_environ_flags(rank)

    finetune_llm = None
    # Responses of the original (pre-attack) model. Only the c_epoch == 0 attack
    # produces them, so they are missing when resuming from last_epoch > 0.
    origin_response_result = None

    # Stage toggles.
    attack_flag = False
    eval_attack_flag = False
    kgunl_flag = False
    eval_kgunl_flag = True

    def _unlearned_model_path(cycle_idx):
        """Return the unlearned-model checkpoint path for 1-based cycle index ``cycle_idx``.

        The evaluation stage below loads index ``c_epoch + 1`` right after the
        unlearning stage of cycle ``c_epoch``. The model unlearned in the
        previous cycle therefore has index ``c_epoch``.
        """
        return (
            f"{args.root_path}/save_model/finetune_{args.test_dataset}_{args.unlearning_type}"
            f"/{args.model_name}_cyl_epo_{cycle_idx}_ft_epo_final"
        )

    # continue from the last epoch
    for c_epoch in tqdm(range(args.last_epoch, args.cycle_epochs)):

        # ############################## Attack llm with few-shot finetune ##############################
        print(f'[c_epoch: {c_epoch}] Attack The Model')
        if attack_flag:
            if c_epoch == 0:
                # use the original llm
                finetune_llm = load_ft_model(args, mode='train')
                attacker = Attacker(args, finetune_llm)
                # responses of the original model, before the finetune attack
                origin_response_result = attacker.generate()
                # finetune attack model
                attacker.train()
            else:
                # use the llm unlearned in the previous cycle epoch
                finetune_llm = load_ft_model(
                    args, mode='train', model_path=_unlearned_model_path(c_epoch)
                )
                attacker = Attacker(args, finetune_llm)

            # get attack response results
            attack_response_result = attacker.generate()

            # record results
            out_file = attacker.construct_out_file(c_epoch + 1)
            if origin_response_result is not None:
                result_list = [
                    {
                        'prompt': prompt,
                        'origin_response': origin_response[0],
                        'attack_response': attack_response_result[prompt][0],
                    }
                    for prompt, origin_response in origin_response_result.items()
                ]
            else:
                # Resumed run: no original responses were generated in this process.
                result_list = [
                    {
                        'prompt': prompt,
                        'origin_response': None,
                        'attack_response': attack_response[0],
                    }
                    for prompt, attack_response in attack_response_result.items()
                ]

            with open(out_file, "w") as fout:
                fout.write(json.dumps(result_list, indent=4) + "\n")

            torch.cuda.empty_cache()

        # ################### Eval harmful response after attack ###############################
        if eval_attack_flag and c_epoch == 0:
            print(f'[c_epoch: {c_epoch}] Eval Attacked Model')
            # score for the first attack
            score_GPT = load_score_model(args, mode='inference')
            data_collator = DefaultDataCollator()
            judger = Judger(args)
            scorer = Scorer(args, score_GPT, judger, data_collator)

            score_path = f"{args.root_path}/result/attack/{args.model_name}/{args.test_dataset}_cyl_epo_{c_epoch+1}.json"
            print(score_path)
            scorer.eval(score_path)

        # ############################# KGUnL ###############################
        print(f'[c_epoch: {c_epoch}] KG Unlearning')
        if kgunl_flag:
            replace_GPT = load_gpt_model(args, mode='inference')

            if finetune_llm is None:
                if c_epoch == 0:
                    # use the llm finetuned by attacker
                    model_path = f"{args.root_path}/save_model/attack/{args.model_name}_epoch_final"
                else:
                    # use the llm unlearned in the previous cycle epoch
                    model_path = _unlearned_model_path(c_epoch)
                finetune_llm = load_ft_model(args, mode='train', model_path=model_path)

            # select unlearning dataset
            if args.unlearning_type == 'kg':
                unl_dataset = construct_kgunl_dataset(args, finetune_llm, replace_GPT)
            elif args.unlearning_type == 'sorry':
                unl_dataset = construct_sorry_dataset(args)
            else:
                raise ValueError(
                    f"Unknown unlearning_type {args.unlearning_type!r}; expected 'kg' or 'sorry'"
                )

            finetuner = Finetuner(args, finetune_llm, replace_GPT)
            finetuner.train(unl_dataset, c_epoch)

            torch.cuda.empty_cache()

        ######################### Eval for Unlearning ###############################
        print(f'[c_epoch: {c_epoch}] GPT4 Eval of KG Unlearning Model')
        if eval_kgunl_flag:
            replace_GPT = load_gpt_model(args, mode='inference')
            score_GPT = load_score_model(args, mode='inference')
            data_collator = DefaultDataCollator()
            judger = Judger(args)
            scorer = Scorer(args, score_GPT, judger, data_collator)

            # checkpoint of this cycle epoch, presumably written by finetuner.train above
            finetune_llm = load_ft_model(
                args, mode="train", model_path=_unlearned_model_path(c_epoch + 1)
            )
            finetuner = Finetuner(args, finetune_llm, replace_GPT)

            unlearning_result = finetuner.generate()

            # record results
            out_file = finetuner.construct_out_file(c_epoch + 1, 'final')
            result_list = [
                {'prompt': prompt, 'unlearning_response': unlearning_response[0]}
                for prompt, unlearning_response in unlearning_result.items()
            ]
            with open(out_file, "w") as fout:
                fout.write(json.dumps(result_list, indent=4) + "\n")

            # score for the generated result of the unlearned model
            score_path = f"{args.root_path}/result/finetune/{args.model_name}/{args.test_dataset}_{args.unlearning_type}_ft_epoch_final_epoch_{c_epoch+1}.json"
            scorer.eval_unlearning(score_path)