import random
import torch
import torch.optim as optim
import numpy as np
import copy
import sys
import os
cwd = os.getcwd()
sys.path.append(cwd)
from automatic_prompt_engineer import ape, data
from data.instruction_induction.load_data import load_data
from evaluation.instruction_induction.exec_accuracy import exec_accuracy_evaluator, exec_evaluator

# import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, DataCollatorForTokenClassification
import transformers
from LlamaForAIO import LlamaForAIO
from automatic_prompt_engineer import evaluate, config, template, llm
import os
import re
import types

from tqdm import tqdm
from evaluation.instruction_induction.utility import set_all_seed, TASKS
from collections import OrderedDict
import logging
import time
import contextlib
from transformers.debug_utils import DebugUnderflowOverflow

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

from Trainer.trainer import AIO_Trainer
from Trainer.utils import DataCollatorWithPaddingAndNesting
import AIO_Training_args

from peft import get_peft_config, get_peft_model, PromptTuningInit, PromptTuningConfig, TaskType, PeftType, PrefixTuningConfig

#
from Linear_TS import Linear_TS

SMOKE_TEST = os.environ.get("SMOKE_TEST")

## bayesian opt
# Default tensor-construction kwargs: CUDA when available, otherwise CPU; always float32.
tkwargs = dict(
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    dtype=torch.float32,
)


# ============================================================
# Turn off HuggingFace tokenizers' internal parallelism.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Evaluation backend; AIO_Forward_Model.eval only implements 'chatgpt'.
api_model = 'chatgpt'
# Scale factors of the 'normal' random-projection std: alpha * std_hat / (sqrt(intrinsic_dim) * sigma).
alpha = 1
sigma = 1

@contextlib.contextmanager
def count_time(name):
    """Log ``"<name>..."`` on entry and the elapsed time on exit.

    The duration is measured with ``time.perf_counter`` (monotonic, high resolution),
    so system clock adjustments during the timed section cannot distort it. The exit
    message is logged even if the body raises.

    Args:
        name: Label for the timed section.

    Yields:
        None.
    """
    # Lazy %-args: formatting only happens if the INFO record is actually emitted.
    logger.info("%s...", name)
    start_time = time.perf_counter()
    try:
        yield
    finally:
        logger.info("Done with %.2fs", time.perf_counter() - start_time)


# ============================================================

def smart_tokenizer_and_embedding_resize(
    special_tokens_dict,
    tokenizer: "transformers.PreTrainedTokenizer",
    model: "transformers.PreTrainedModel",
):
    """Resize tokenizer and embedding.

    Adds ``special_tokens_dict`` to ``tokenizer`` and resizes ``model``'s token embeddings
    to ``len(tokenizer)``. Each newly added row is initialised to the mean of the
    pre-existing rows. This applies to the input embeddings and, when the model has
    one, to the output embedding layer.

    Source: https://github.com/mbzuai-oryx/LLaVA-pp
    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.

    Args:
        special_tokens_dict: Mapping passed to ``tokenizer.add_special_tokens``
            (e.g. ``dict(pad_token="<PAD>")``).
        tokenizer: Tokenizer extended in place.
        model: Model whose embeddings are resized and initialised in place.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    # Resync the embedding size with the tokenizer even when nothing new was added.
    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens <= 0:
        return

    # Write through ``.data`` so the initialisation is not tracked by autograd.
    input_embeddings = model.get_input_embeddings().weight.data
    input_embeddings[-num_new_tokens:] = input_embeddings[:-num_new_tokens].mean(
        dim=0, keepdim=True)

    # get_output_embeddings() returns None for models without an output head; only
    # initialise the new output rows when such a layer exists. The mean is taken over
    # the old rows only, so tied input/output weights give the same result.
    output_layer = model.get_output_embeddings()
    if output_layer is not None:
        output_embeddings = output_layer.weight.data
        output_embeddings[-num_new_tokens:] = output_embeddings[:-num_new_tokens].mean(
            dim=0, keepdim=True)


# ============================================================
class AIO_Forward_Model:
    """Instruction generator/evaluator built around a trainable white-box LLM.

    A ``LlamaForAIO`` model generates candidate instructions from
    ``init_prompt[0] + init_qa[0]``. The model is optionally wrapped with PEFT
    soft-prompt tuning and/or LoRA. Each instruction is scored on ``eval_data`` with
    ``evaluate.evaluate_prompts``. Scoring uses either a second white-box LLM or the
    API model configured in ``conf['evaluation']['model']``. The instance also tracks
    the best-scoring instruction and the trajectory of instructions seen during training.
    """

    def __init__(self, args, model_name='vicuna', eval_data=None, init_prompt=None, init_qa=None, conf=None, base_conf=None,
                 prompt_gen_data=None, n_prompt_tokens=None,
                 HF_cache_dir=None, random_proj=None, intrinsic_dim=None, white_box_LLM_eval_HF_cache_dir=False):
        """Load the white-box LLM, its tokenizer, the evaluator and all bookkeeping state.

        Args:
            args: Run arguments. Fields read here: ``total_iter``, ``n_init``,
                ``use_baseline_method_name``, ``max_seq_length``,
                ``zero_shot_evaluation_flag`` and ``training_module_name``.
            model_name: White-box model family. Only ``'llama2'``, ``'llama3'`` and
                ``'llama3-1'`` are supported. Anything else, including the default
                ``'vicuna'``, raises ``NotImplementedError``.
            eval_data: Data used to score generated instructions.
            init_prompt: Its first element starts the generation prompt. The whole value is
                also tokenized and embedded to build ``self.init_prompt``.
            init_qa: Its first element is appended to ``init_prompt[0]`` in the generation prompt.
            conf: APE config, merged with ``base_conf`` by ``config.update_config``.
            base_conf: Base APE config.
            prompt_gen_data: Few-shot demonstration data used for evaluation and training.
            n_prompt_tokens: Number of soft-prompt tokens used by the hidden-state helpers.
            HF_cache_dir: Path / hub id of the white-box model and tokenizer.
            random_proj: ``'normal'`` or ``'uniform'`` initialisation for ``self.linear``.
                Any other value keeps PyTorch's default ``Linear`` initialisation.
            intrinsic_dim: Input dimension of ``self.linear``.
            white_box_LLM_eval_HF_cache_dir: Path of a white-box evaluator LLM. Falsy
                values (the default ``False``, ``None``, ``''``) select the API evaluator.

        Raises:
            NotImplementedError: For an unsupported ``model_name``.
        """
        # Any falsy value means "no white-box evaluator". An ``is not None`` test would
        # treat the default ``False`` as a model path and hand it to from_pretrained below.
        self.eval_with_white_box_LLM_flag = bool(white_box_LLM_eval_HF_cache_dir)
        print("!!! Evaluate with white-box LLM: ", self.eval_with_white_box_LLM_flag)

        self.args = args
        self.conf = config.update_config(conf, base_conf)
        print("self.conf: ", self.conf)
        # Iterations beyond the initial ones; fine-tuning is only set up when this is > 0.
        max_iter = args.total_iter - args.n_init

        #################################################################
        self.AIO_training_args = AIO_Training_args.parse_AIO_training_args()

        print("[Training args]: ", self.AIO_training_args)

        #################################################################
        kwargs = {'torch_dtype': torch.float32}
        if model_name in ['llama2', 'llama3', 'llama3-1']:
            self.white_box_LLM = LlamaForAIO.from_pretrained(
                HF_cache_dir,
                device_map="auto",
                token=True,
                **kwargs,
            )

            # Fine-tuning is only prepared when iterations remain and no baseline method is run.
            if max_iter > 0 and args.use_baseline_method_name is None:
                self.white_box_LLM.set_training_args(self.AIO_training_args)

                # Soft prompt tuning: virtual tokens initialised from a fixed text.
                if self.AIO_training_args.soft_prompt_tuning:
                    soft_prompt_tuning_config = PromptTuningConfig(
                        task_type=TaskType.CAUSAL_LM,
                        prompt_tuning_init=PromptTuningInit.TEXT,
                        num_virtual_tokens=self.AIO_training_args.num_soft_prompt_tokens,
                        prompt_tuning_init_text="Your job is to generate good instruction to infer the output given the input.",
                        tokenizer_name_or_path=HF_cache_dir,
                    )
                    self.white_box_LLM = get_peft_model(self.white_box_LLM, soft_prompt_tuning_config)
                    print(self.white_box_LLM.print_trainable_parameters())

                # LoRA adapters; float16 loading is not supported together with LoRA.
                if self.AIO_training_args.lora:
                    # Validate before modifying the model so a bad config fails fast.
                    assert not self.AIO_training_args.load_float16
                    from Trainer.lora import LoRA
                    LoRA(self.white_box_LLM, r=self.AIO_training_args.lora_r, alpha=self.AIO_training_args.lora_alpha, float16=self.AIO_training_args.load_float16)

            # Tokenizer of the white-box model; pads on the left.
            self.white_box_LLM_tokenizer = AutoTokenizer.from_pretrained(
                HF_cache_dir,
                truncation=True,
                padding=True,
                model_max_length=self.args.max_seq_length,
                padding_side="left"
            )

            # Add a pad token (and matching embedding rows) when the tokenizer has none.
            if self.white_box_LLM_tokenizer.pad_token is None:
                print(f"Adding pad token as '<PAD>'")
                smart_tokenizer_and_embedding_resize(
                    special_tokens_dict=dict(pad_token="<PAD>"),
                    tokenizer=self.white_box_LLM_tokenizer,
                    model=self.white_box_LLM,
                )

        else:
            raise NotImplementedError

        ##############################################################################################
        # ------------ Evaluate with a white-box LLM instead of the API LLM -------------------------
        if self.eval_with_white_box_LLM_flag:
            white_box_eval_kwargs = {
                'torch_dtype': torch.float32,
                'use_cache': True
            }
            # NOTE(review): model_name is already restricted to the llama variants above, so
            # this vicuna/wizardlm/openchat branch is currently unreachable.
            if model_name in ["vicuna", "wizardlm", 'openchat']:
                self.eval_LLM_model = LlamaForAIO.from_pretrained(
                    white_box_LLM_eval_HF_cache_dir,
                    low_cpu_mem_usage=True,
                    device_map="auto",
                    **white_box_eval_kwargs,
                )

                self.eval_LLM_tokenizer = AutoTokenizer.from_pretrained(
                    white_box_LLM_eval_HF_cache_dir,
                    model_max_length=self.args.max_seq_length,
                    padding_side="left",
                    use_fast=False
                )
            elif model_name in ['llama2', 'llama3', 'llama3-1']:
                self.eval_LLM_model = LlamaForAIO.from_pretrained(
                    white_box_LLM_eval_HF_cache_dir,
                    low_cpu_mem_usage=True,
                    device_map="auto",
                    token=True,
                    **white_box_eval_kwargs,
                )

                self.eval_LLM_tokenizer = AutoTokenizer.from_pretrained(
                    white_box_LLM_eval_HF_cache_dir,
                    model_max_length=self.args.max_seq_length,
                    padding_side="left",
                    use_fast=False,
                    token=True
                )
            else:
                raise NotImplementedError
        else:
            # API evaluator built from the config; it needs no local tokenizer.
            self.eval_LLM_model, self.eval_LLM_tokenizer = llm.model_from_config(self.conf['evaluation']['model']), None

        ##############################################################################################

        # Text prompt fed to the white-box LLM for instruction generation.
        self.init_token = init_prompt[0] + init_qa[0]
        if model_name in ['llama2', 'llama3', 'llama3-1']:
            # Clone of the input embedding matrix, taken after the optional pad-token resize.
            # Used to embed token ids manually.
            self.white_box_LLM_embedding = self.white_box_LLM.get_input_embeddings().weight.clone()
            input_ids = self.white_box_LLM_tokenizer(init_prompt, return_tensors="pt").input_ids.cuda()
            self.init_prompt = self.white_box_LLM_embedding[input_ids]

        ########################## setup n_prompts_token ##########################
        self.n_prompt_tokens = n_prompt_tokens

        # Embedding width of the white-box model.
        self.hidden_size = self.init_prompt.shape[-1]
        print('Shape of initial prompt embedding: {}'.format(self.init_prompt.shape))

        self.count = 0
        # Maps an intrinsic_dim vector to n_prompt_tokens * hidden_size values
        # (presumably flattened soft-prompt embeddings).
        self.linear = torch.nn.Linear(intrinsic_dim, self.n_prompt_tokens * self.hidden_size, bias=False)

        ##########################
        if random_proj == 'normal':
            # calculate std for normal distribution
            if model_name in ['llama2', 'llama3', 'llama3-1']:
                print('Get the embedding firstly to avoid issues')
            else:
                raise NotImplementedError
            flat_embedding = self.white_box_LLM_embedding.reshape(-1).detach().cpu().numpy()
            mu_hat = np.mean(flat_embedding)
            std_hat = np.std(flat_embedding)
            mu = 0.0
            std = alpha * std_hat / (np.sqrt(intrinsic_dim) * sigma)

            print('[Embedding] mu: {} | std: {} [RandProj]  mu: {} | std: {}'.format(mu_hat, std_hat, mu, std))
            # NOTE(review): mu/std are computed and printed but never used. The weights are
            # still drawn from U(-1, 1), exactly as in the 'uniform' branch. Confirm whether
            # torch.nn.init.normal_(p, mu, std) was intended before changing it.
            for p in self.linear.parameters():
                torch.nn.init.uniform_(p, -1, 1)
        elif random_proj == 'uniform':
            for p in self.linear.parameters():
                torch.nn.init.uniform_(p, -1, 1)

        ## eval preparation
        self.eval_data = eval_data

        if not self.args.zero_shot_evaluation_flag:
            # Multi-shot testing / evaluation
            print("--- [Multi-shot evaluation]")
            self.eval_template = template.EvalTemplate("<examples> Exemplary data: [full_DEMO] </examples>. Instruction: [PROMPT]\n\nInput: [INPUT]\n Output: [OUTPUT]")
        else:
            # Zero-shot testing / evaluation
            print("--- [Zero-shot evaluation]")
            self.eval_template = template.EvalTemplate("Instruction: [PROMPT]\n\nInput: [INPUT]\n Output: [OUTPUT]")

        self.demos_template = template.DemosTemplate("Input: [INPUT]\nOutput: [OUTPUT]")

        if api_model in ['llama', 'flan-t5']:
            self.api_model = exec_evaluator(api_model, self.conf)

        self.few_shot_data = prompt_gen_data

        # --- Instruction-evaluation bookkeeping ---
        self.best_train_perf = 0.0
        self.best_last_perf = 10  # eval() increments self.count whenever dev_perf >= this value
        self.best_prompt = None
        self.num_call = 0
        self.best_dev_perf = 0.0
        self.best_instruction = None
        self.prompts_set = dict()  # instruction text -> (dev_perf, instruction_score)
        self.prompts_list = []  # (index, instruction text, dev_perf), one entry per eval() call
        self.last_eval_loss_val = None

        # --- Training-time instruction tracking ---
        self.best_eval_loss = np.inf
        self.best_AIO_instruction_training = None
        self.instruction_optim_traj = {}  # (epoch, step) or (-1, counter) -> instruction
        self.perturb_instruction_counter = 0

        print("=" * 30)
        print(f"--- Vocabulary Size: {self.white_box_LLM.vocab_size} ---")
        print("=" * 30)

        # Set trainable parameters
        self.white_box_LLM.set_modules_require_gradients(training_args=self.AIO_training_args, module_names=args.training_module_name)

        total_trainable_param_count = sum(param.numel() for _, param in self.white_box_LLM.named_parameters() if param.requires_grad)
        print("=" * 30)
        print(f"--- Total trainable param count: {total_trainable_param_count} ---")
        print("=" * 30)

        #################################################################
        # Optional Linear_TS helper, registered on the white-box LLM.
        if self.AIO_training_args.TS_aided_grad_approx:
            self.TS_model = Linear_TS(raw_d=self.white_box_LLM.vocab_size, pooling_step=self.AIO_training_args.TS_pooling_step,
                                      beta_threshold=self.AIO_training_args.TS_beta_threshold,
                                      explore_var_coef=self.AIO_training_args.TS_explore_var_coef, l2_reg=self.AIO_training_args.TS_l2_reg,
                                      dim_reduce_method=self.AIO_training_args.TS_dim_reduce_method,
                                      diag_flag=self.AIO_training_args.TS_diag_flag)
            self.white_box_LLM.register_TS_instance(TS_model=self.TS_model)

    def check_if_update_best_training_instruction(self, this_loss, instruction, epoch=None, step=None):
        """Record ``instruction`` with its loss and keep it if it is the best so far.

        Lower loss is better. Ties (``this_loss == best``) also replace the stored best.
        The instruction is stored in ``instruction_optim_traj`` under ``(epoch, step)``
        when both are given. Otherwise it is stored under ``(-1, n)``, where ``n`` counts
        such unkeyed calls. ``last_eval_loss_val`` is always updated.
        """
        if this_loss <= self.best_eval_loss:
            print("--- New Best Instruction! ---")
            self.best_eval_loss = this_loss
            self.best_AIO_instruction_training = instruction

        if epoch is not None and step is not None:
            self.instruction_optim_traj[(epoch, step)] = instruction
        else:
            self.instruction_optim_traj[(-1, self.perturb_instruction_counter)] = instruction
            self.perturb_instruction_counter += 1

        self.last_eval_loss_val = this_loss

    def return_best_AIO_training_instruction(self):
        """Return the lowest-loss instruction recorded during training (None if none yet)."""
        return self.best_AIO_instruction_training

    def get_last_token_hidden_state(self, prompt_embedding):
        """Return the white-box LLM hidden state at the last position of ``[prompt; init_token]``.

        ``prompt_embedding`` is moved to the embedding's device/dtype and reshaped to
        ``(1, n_prompt_tokens, -1)``. It is then prepended to the embedded ``self.init_token``.
        """
        input_ids = self.white_box_LLM_tokenizer(self.init_token, return_tensors="pt").input_ids.cuda()
        input_embed = self.white_box_LLM_embedding[input_ids]
        prompt_embedding_ = prompt_embedding.to(device=input_embed.device, dtype=input_embed.dtype).reshape(1, self.n_prompt_tokens, -1)
        input_embed = torch.cat((prompt_embedding_, input_embed), 1)
        last_token_id = input_embed.shape[1] - 1
        # The model helper returns a 1-tuple; unpack it.
        hidden_state, = self.white_box_LLM.get_last_token_hidden_state(inputs_embeds=input_embed, sequence_lengths=last_token_id)

        return hidden_state

    def get_last_token_hidden_state_batch(self, prompt_embedding, pooling='last', batch_size=1):
        """Batched ``get_last_token_hidden_state``.

        Splits ``prompt_embedding`` along dim 0 into chunks of ``batch_size``. Each chunk is
        reshaped to ``(chunk, n_prompt_tokens, -1)`` and prepended to the embedded
        ``self.init_token``. The per-chunk hidden states are stacked with ``torch.vstack``.

        Args:
            prompt_embedding: Tensor whose first dimension indexes candidate soft prompts.
            pooling: Forwarded to ``self.white_box_LLM.get_last_token_hidden_state``.
            batch_size: Number of prompts per forward pass.

        Returns:
            The stacked hidden states of all prompts.
        """
        size = prompt_embedding.shape[0]
        input_ids = self.white_box_LLM_tokenizer(self.init_token, return_tensors="pt").input_ids.cuda()
        # The embedded init tokens are the same for every batch; look them up once.
        base_embed = self.white_box_LLM_embedding[input_ids]

        n_batchs = size // batch_size + int((size % batch_size) != 0)
        all_hidden_state = []
        for i in tqdm(range(n_batchs), desc='Get hidden states'):
            # Slicing past the end is clamped, so this also yields the (shorter) last batch.
            prompt_batch = prompt_embedding[(i * batch_size):((i + 1) * batch_size)]
            batch_size_ = prompt_batch.shape[0]
            input_embed = base_embed.repeat(batch_size_, 1, 1)
            prompt_embedding_ = prompt_batch.to(device=input_embed.device, dtype=input_embed.dtype).reshape(batch_size_, self.n_prompt_tokens, -1)
            input_embed = torch.cat((prompt_embedding_, input_embed), 1)
            last_token_id = input_embed.shape[1] - 1

            hidden_state_, = self.white_box_LLM.get_last_token_hidden_state(inputs_embeds=input_embed, sequence_lengths=last_token_id, pooling=pooling)
            all_hidden_state.append(hidden_state_)

        return torch.vstack(all_hidden_state)

    def eval(self, prompt_embedding=None, test_data=None):
        """Generate one instruction with the white-box LLM and score it.

        Steps:
        1. Generate from ``self.init_token``. Token ids are used when soft-prompt or
           prefix tuning is enabled; otherwise embeddings from
           ``self.white_box_LLM_embedding`` are used.
        2. Keep the text inside ``<instruct>...</instruct>`` tags.
        3. Prepend a fixed output-format header and cut at ``'Comment: '``.
        4. Score with ``evaluate.evaluate_prompts``. Cached scores in ``self.prompts_set``
           are reused only for baseline runs.

        Args:
            prompt_embedding: Unused; kept for interface compatibility.
            test_data: Unused; kept for interface compatibility.

        Returns:
            ``(dev_perf, instruction_score)``.

        Raises:
            NotImplementedError: If the instruction must be scored and ``api_model`` is
                not ``'chatgpt'``.
        """
        # ====================== Generate instruction with white-box LLM ==================================
        print("[Init tokens]: ", self.init_token)
        input_ids = self.white_box_LLM_tokenizer(self.init_token, return_tensors="pt").input_ids.cuda()
        input_embed = self.white_box_LLM_embedding[input_ids]
        print("[input_embed dim]: ", input_embed.shape)

        if self.AIO_training_args.soft_prompt_tuning or self.AIO_training_args.prefix_tuning:
            # Use token ids for generation
            output_instance = self.white_box_LLM.generate(input_ids=input_ids, max_new_tokens=self.args.max_seq_length,
                                                          pad_token_id=self.white_box_LLM_tokenizer.eos_token_id,
                                                          return_dict_in_generate=True,
                                                          output_hidden_states=True, output_scores=True,
                                                          output_logits=True)
        else:
            # Use token embeddings for generation
            output_instance = self.white_box_LLM.generate(inputs_embeds=input_embed, max_new_tokens=self.args.max_seq_length,
                                                          pad_token_id=self.white_box_LLM_tokenizer.eos_token_id,
                                                          return_dict_in_generate=True,
                                                          output_hidden_states=True, output_scores=True,
                                                          output_logits=True)
        generated_output = output_instance.sequences
        print("[Generated_output]: ", generated_output.shape)

        # Decode to word-level instructions (a list with one string per sequence).
        instruction = self.white_box_LLM_tokenizer.batch_decode(generated_output, skip_special_tokens=True)
        print("[Instruction]: ", instruction)

        ################################################################################################################################################
        # Keep only the text inside <instruct>...</instruct> tags, space-joined (empty if none).
        self.re_pattern = r'<instruct>(.*?)</instruct>'
        matches = re.findall(self.re_pattern, instruction[0], re.DOTALL)
        instruction[0] = ' '.join(matches)
        print('[Pre-re Instruction]: {}'.format(instruction))

        # Instruction header
        self.added_instruction_header = 'ONLY enclose your answer into <output> </output>, and do not include explanations or other words. The instruction is to '

        # Prepend the header, then drop anything from 'Comment: ' onwards.
        instruction[0] = self.added_instruction_header + instruction[0]
        start = instruction[0].find(self.added_instruction_header)
        end = instruction[0].find('Comment: ')
        if end == -1:
            instruction[0] = instruction[0][start:]
        else:
            instruction[0] = instruction[0][start: end]

        if self.best_AIO_instruction_training is None:
            self.best_AIO_instruction_training = instruction

        # =================== White-box LLM generated instruction evaluation ===============================
        if instruction[0] in self.prompts_set.keys() and self.args.use_baseline_method_name is not None:
            # Instruction has previously been evaluated (cache reused for baseline runs only).
            (dev_perf, instruction_score) = self.prompts_set[instruction[0]]
        else:
            print('[Post-re Instruction]: {}'.format(instruction))
            if api_model in ['chatgpt']:
                # NOTE(review): hash() of a str is randomised per interpreter process unless
                # PYTHONHASHSEED is fixed, so this seed is stable within a run but not across runs.
                dev_perf, instruction_score = evaluate.evaluate_prompts(instruction, self.eval_template, self.eval_data, self.demos_template, self.few_shot_data,
                                                                        self.conf['evaluation']['method'], self.conf['evaluation'],
                                                                        sample_seed=hash("evaluate"),
                                                                        eval_LLM=self.eval_LLM_model, eval_LLM_tokenizer=self.eval_LLM_tokenizer)
                # Presumably the top score of the APE evaluation result.
                dev_perf = dev_perf.sorted()[1][0]
                self.prompts_set[instruction[0]] = (dev_perf, instruction_score)
            else:
                raise NotImplementedError

        self.prompts_list.append((len(self.prompts_list), instruction[0], dev_perf))
        if dev_perf >= self.best_last_perf:
            self.count += 1

        if dev_perf >= self.best_dev_perf:
            self.best_dev_perf = dev_perf
            self.best_instruction = instruction

        print('Dev loss: {}. Dev perf: {}. Best dev perf: {}'.format(
            round(float(dev_perf), 4),
            round(float(dev_perf), 4),
            round(float(self.best_dev_perf), 4)))
        print('********* Done *********')

        return dev_perf, instruction_score

    #######################################################################################################################################

    def train(self, prompt_gen_data, eval_data, training_module_name=None, prompt_embedding=None):
        """Fine-tune the white-box LLM with ``AIO_Trainer``.

        Args:
            prompt_gen_data: Converted into the training dataset
                (``conf['evaluation']['num_few_shot']`` samples).
            eval_data: Converted into the evaluation dataset
                (``conf['evaluation']['num_samples']`` samples).
            training_module_name: Modules made trainable via ``set_modules_require_gradients``.
            prompt_embedding: Unused; kept for interface compatibility.
        """
        self.white_box_LLM.set_modules_require_gradients(training_args=self.AIO_training_args, module_names=training_module_name)

        with count_time("Tokenizing training and evaluation samples"):
            train_dataset = AIO_Training_args.HFDataset(AIO_Training_args._convert(collection=prompt_gen_data, num_samples=self.conf['evaluation']['num_few_shot'],
                                                                                   white_box_LLM_tokenizer=self.white_box_LLM_tokenizer))
            eval_dataset = AIO_Training_args.HFDataset(AIO_Training_args._convert(collection=eval_data, num_samples=self.conf['evaluation']['num_samples'],
                                                                                  white_box_LLM_tokenizer=self.white_box_LLM_tokenizer))

        # TODO: check this collator
        pad_to_multiple_of = 8
        customized_collator = DataCollatorWithPaddingAndNesting(self.white_box_LLM_tokenizer, pad_to_multiple_of=pad_to_multiple_of) \
            if self.AIO_training_args.train_as_classification \
            else DataCollatorForTokenClassification(self.white_box_LLM_tokenizer, pad_to_multiple_of=pad_to_multiple_of)

        trainer = AIO_Trainer(
            model=self.white_box_LLM,
            args=self.AIO_training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=self.white_box_LLM_tokenizer,
            data_collator=customized_collator
        )
        # Give the trainer a back-reference to this forward model.
        trainer.add_AIO_forward_model_attribute(AIO_forward_model=self)
        print("[Trainer args]: ", trainer.args)
        trainer.train()

    def return_best_prompt(self):
        """Return the best instruction found by ``eval`` (highest dev performance)."""
        return self.best_instruction

    def return_prompts_set(self):
        """Return the mapping ``instruction text -> (dev_perf, instruction_score)``."""
        return self.prompts_set

    def return_prompts_list(self):
        """Return the ``(index, instruction text, dev_perf)`` history of ``eval`` calls."""
        return self.prompts_list

    

