import random
import torch
import numpy as np
import copy
import sys
import os
cwd = os.getcwd()
sys.path.append(cwd)
from automatic_prompt_engineer import ape, data
from data.instruction_induction.load_data import load_data
from evaluation.instruction_induction.exec_accuracy import exec_accuracy_evaluator, exec_evaluator

from automatic_prompt_engineer import evaluate, config, template, data
import os
import re
import json

from tqdm import tqdm
import argparse
from evaluation.instruction_induction.utility import set_all_seed, TASKS
import datetime
import time

from Forward_Model import AIO_Forward_Model
from running_args import parse_args
from sklearn.preprocessing import normalize

# Send everything written to stderr to the stream that is currently stdout,
# keeping a handle on the original stderr in `oldStderr`.
# NOTE: this binds to the stdout object that exists at import time; rebinding
# sys.stdout later (as the __main__ section does) does not change sys.stderr.
oldStderr = sys.stderr
sys.stderr = sys.stdout


##############################################################################################################


# Logger
# Recording console output
class Logger_class(object):
    """File-like tee that mirrors every write to a log file and the console.

    Intended to be assigned to ``sys.stdout`` so that ``print`` output is
    both shown on the console and persisted to
    ``<folder_str>/<algo>_log_<dt_string>_.log``.

    Args:
        stdout: Stream supplied by the caller; kept on ``self.out``.
        folder_str: Existing directory in which the log file is created.
        dt_string: Run identifier embedded in the log file name.
        algo: Algorithm tag used as the log file name prefix.
    """

    def __init__(self, stdout, folder_str, dt_string, algo):
        # The console target is whatever sys.stdout is at construction time.
        self.terminal = sys.stdout
        self.log = open('{}/{}_log_{}_'.format(folder_str, algo, dt_string) + ".log", "w")
        self.out = stdout
        # Printed before this object is installed as sys.stdout, so this line
        # reaches the console only and is not written to the log file.
        print("date and time =", dt_string)

    def write(self, message):
        """Write ``message`` to the log file (flushed immediately) and the console."""
        if not self.log.closed:
            self.log.write(message)
            # Flush per write so the log survives an abrupt termination.
            self.log.flush()
        self.terminal.write(message)

    def flush(self):
        """Flush both underlying streams.

        Needed so that ``print(..., flush=True)`` and explicit
        ``sys.stdout.flush()`` calls actually reach the console stream.
        """
        if not self.log.closed:
            self.log.flush()
        self.terminal.flush()

    def close(self):
        """Flush and close the log file; the console stream is left open."""
        if not self.log.closed:
            self.log.flush()
            self.log.close()

##############################################################################################################

# Value of the SMOKE_TEST environment variable (None when unset).
# NOTE(review): not referenced anywhere in the visible part of this file.
SMOKE_TEST = os.environ.get("SMOKE_TEST")
## bayesian opt
# Default tensor placement/precision: CUDA when available, otherwise CPU,
# with double-precision floats.
# NOTE(review): not referenced in the visible part of this file.
tkwargs = {
    "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    "dtype": torch.double,
}


# Set for the whole process before any model/tokenizer is created.
# Presumably disables HuggingFace tokenizers' internal parallelism — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Module-level settings; none of the three is read in the visible code.
# NOTE(review): confirm whether they are still needed.
api_model = 'chatgpt'
alpha = 1
sigma = 1
    


# ============================================================================================================================================
def run_inner_loop_single_iter(args):
    """Optimize an instruction for ``args.task`` once and score it on test data.

    Steps, in order: load the 'induce' and 'eval' splits, split the induce
    data into prompt-generation and evaluation parts, build the
    instruction-generation query from sub-sampled demos, train an
    ``AIO_Forward_Model``, then evaluate the best instruction(s) it reports
    on the test split.

    Args:
        args: Parsed run arguments. Attributes read directly here:
            ``model_name``, ``task``, ``gpt``, ``n_prompt_tokens``,
            ``HF_cache_dir``, ``white_box_LLM_eval_HF_cache_dir``,
            ``random_proj``, ``intrinsic_dim``, ``few_shot_sample_num``,
            ``evaluation_sample_num``, ``gen_demo_num_comparison``,
            ``training_module_name`` and, only when
            ``gen_demo_num_comparison > 0``, ``total_iter`` and ``n_init``.
            The whole object is also forwarded to ``AIO_Forward_Model``.

    Returns:
        tuple: ``(test_score, best_prompts, prompts_list, None)`` where
        ``test_score`` is ``sorted()[1][0]`` of the first evaluation result,
        ``best_prompts`` is what the model reports as its best training
        instruction(s), and ``prompts_list`` is the model's prompt list.

    Raises:
        AssertionError: If ``args.task`` is not in ``TASKS``, or if
            ``gen_demo_num_comparison > 0`` while ``total_iter != n_init``.
    """
    # Function-scope import: only needed for the stable test-sampling seed.
    import zlib

    model_name, task, gpt = args.model_name, args.task, args.gpt

    # Upper bound on the number of test examples used for the final score.
    testing_sample_num = 167

    ########
    assert task in TASKS, 'Task not found!'
    induce_data, test_data = load_data('induce', task), load_data('eval', task)
    induce_data_size = len(induce_data[0])
    test_data_size = len(test_data[0])

    # Induce data is split into prompt_gen_data and eval_data
    prompt_gen_size = min(int(induce_data_size * 0.5), 500)
    prompt_gen_data, eval_data = data.create_split(induce_data, prompt_gen_size)

    # Effective sample counts, each capped by the amount of data available.
    num_few_shot = min(args.few_shot_sample_num, len(prompt_gen_data[0]))
    num_eval_samples = min(args.evaluation_sample_num, len(eval_data[0]))
    num_test_samples = min(testing_sample_num, test_data_size)

    # Get size of the induce data
    print(f"--- Induce data size: {induce_data_size}, test data size: {test_data_size}")
    # Report the test sample count that is actually used by the final evaluation.
    print(f"--- Few-shot sample num: {num_few_shot}; evaluation data size: {num_eval_samples}. Test data size: {num_test_samples}")

    # Data is in the form input: single item, output: list of items
    # For prompt_gen_data, sample a single item from the output list
    prompt_gen_data = prompt_gen_data[0], [random.sample(output, 1)[0] for output in prompt_gen_data[1]]

    # ====================== Base configuration ======================
    base_conf = '../configs/instruction_induction.yaml'

    # Use the following config to update existing base config
    # Actual configuration for instruction generation.
    conf = {
        'generation': {
            'num_subsamples': 1,
            'num_demos': num_few_shot,
            'num_prompts_per_subsample': 20,
            'model': {
                'name': gpt,
                'gpt_config': {
                    'model': gpt
                }
            }
        },
        'evaluation': {
            'method': exec_accuracy_evaluator,
            'task': task,
            'num_samples': num_eval_samples,
            'num_few_shot': num_few_shot,
            'model': {
                'name': gpt,
                'gpt_config': {
                    'model': gpt
                }
            }
        }
    }

    # Make the demo automatically
    demos_template = "Input: [INPUT]\nOutput: [OUTPUT]"
    d_template = template.DemosTemplate(demos_template)
    subsampled_query, subsampled_output = data.subsample_data(prompt_gen_data, conf['generation']['num_demos'])

    if args.gen_demo_num_comparison > 0:
        # Since we only need to involve eval data when not optimizing the prompts based on feedback.
        assert args.total_iter == args.n_init
        involved_eval_query, involved_eval_output = data.subsample_data(eval_data, args.gen_demo_num_comparison)
        involved_eval_output = [random.sample(output, 1)[0] for output in involved_eval_output]
        # Append the extra evaluation examples to the demos shown to the generator.
        subsampled_data = subsampled_query + involved_eval_query, subsampled_output + involved_eval_output
    else:
        subsampled_data = subsampled_query, subsampled_output
    print("Sub-sampled data: ", subsampled_data)
    demos = d_template.fill(subsampled_data)

    ################################################################################################################
    prompt_gen_template = "<exemplars> [full_DEMO] </exemplars>\n\n Based on these exemplars of input-output pairs, please provide an instruction to help infer the output for a given input. Enclose the generated instruction into <instruct> </instruct>."

    prompt_gen_template = template.InitQATemplate(prompt_gen_template)
    # Init_qa: "[full_DEMO] -> demos"
    init_qa = [prompt_gen_template.fill(demos)]

    init_prompt = ['\n']

    ################################################################################################################
    print(f"[Init_qa]: {init_qa}, \n [init_prompt]: {init_prompt} ")
    model_forward_api = AIO_Forward_Model(args=args, model_name=model_name, eval_data=eval_data, init_prompt=init_prompt,
                                            init_qa=init_qa, conf=conf, base_conf=base_conf, prompt_gen_data=prompt_gen_data,
                                            n_prompt_tokens=args.n_prompt_tokens, HF_cache_dir=args.HF_cache_dir,
                                            random_proj=args.random_proj, intrinsic_dim=args.intrinsic_dim,
                                            white_box_LLM_eval_HF_cache_dir=args.white_box_LLM_eval_HF_cache_dir)

    ##### Model training
    model_forward_api.eval()
    model_forward_api.train(prompt_gen_data=prompt_gen_data, eval_data=eval_data, training_module_name=args.training_module_name, prompt_embedding=None)
    #####

    print("="*30)
    print('Evaluate on test data...')
    best_prompts = model_forward_api.return_best_AIO_training_instruction()
    for key, value in model_forward_api.instruction_optim_traj.items():
        print(f"-- Epoch / step: {key}; Optimized Instruction: {value}")

    print("="*30)
    print("--- Best instruction is:")
    print(best_prompts)
    print("="*30)

    prompts_set = model_forward_api.return_prompts_set()
    print("The final instruction set is:")
    print(prompts_set)
    prompts_list = model_forward_api.return_prompts_list()

    # Evaluate on test data
    print('Evaluating on test data...')

    ########################################################################################################################
    # During testing stage, use the following config to update existing base config
    test_conf = {
        'generation': {
            'num_subsamples': 3,
            'num_demos': num_few_shot,
            'num_prompts_per_subsample': 0,
            'model': {
                'gpt_config': {
                    'model': gpt
                }
            }
        },
        'evaluation': {
            'method': exec_accuracy_evaluator, # option: accuracy (cannot use likelihood here due to the textual outputs from ChatGPT do not have log prob)
            'task': task,
            'num_samples': num_test_samples,
            'num_few_shot': num_few_shot,
            'model': {
                'name': gpt,
                'gpt_config': {
                   'model': gpt
                }
            }
        }
    }

    ####################
    test_conf = config.update_config(test_conf, base_conf)

    # hash("test") is salted per interpreter process (PYTHONHASHSEED), so it
    # picked a different test subsample on every run. CRC32 of the same label
    # is a fixed, non-negative 32-bit integer.
    test_sample_seed = zlib.crc32(b"test")

    test_res = evaluate.evaluate_prompts(best_prompts,
                                         model_forward_api.eval_template,
                                         test_data,
                                         model_forward_api.demos_template,
                                         model_forward_api.few_shot_data,
                                         model_forward_api.conf['evaluation']['method'],
                                         test_conf['evaluation'],
                                         sample_seed=test_sample_seed,
                                         eval_LLM=model_forward_api.eval_LLM_model,
                                         eval_LLM_tokenizer=model_forward_api.eval_LLM_tokenizer)

    # Only the first evaluation result is scored.
    test_res = test_res[0]
    test_score = test_res.sorted()[1][0]
    return test_score, best_prompts, prompts_list, None


################################################################################################################################

if __name__ == '__main__':
    args = parse_args()
    start_time = time.time()

    ################################################################################

    # Run identifier: launch timestamp followed by the key run settings,
    # all joined with underscores.
    launch_stamp = datetime.datetime.now().strftime("%m-%d-%Y_%H-%M-%S")
    run_settings = (
        args.task,
        args.model_name,
        args.few_shot_sample_num,
        args.evaluation_sample_num,
        args.gen_demo_num_comparison,
    )
    dt_string = launch_stamp + '_' + '_'.join(str(setting) for setting in run_settings)

    # Each run gets its own, freshly created log directory.
    algo = 'AIO'
    folder_str = f'./new_Running_logs/AIO-{dt_string}'
    os.makedirs(folder_str)

    # From here on, everything printed is also written to the run's log file.
    sys.stdout = Logger_class(sys.stdout, folder_str, dt_string, algo)

    ################################################################################

    print(args)
    seeding_result = set_all_seed(args.seed)
    print(seeding_result)
    test_score, best_prompts, prompts_list, _ = run_inner_loop_single_iter(args)

    print("Finished!!!")
    print(f'Test score on ChatGPT: {test_score}')
    print("=" * 30)
    elapsed_seconds = time.time() - start_time
    print("Total time elapsed: ", elapsed_seconds)


