import logging

# Module-wide logging: INFO level, "<time> - <LEVEL> - <message>" lines.
# Note: basicConfig is a no-op when the root logger already has handlers.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

import argparse
from dataclasses import dataclass, fields
from typing import Any, NewType, Optional, Tuple, Union, get_type_hints

from torch.utils.data import Dataset
from transformers import HfArgumentParser, TrainingArguments

from automatic_prompt_engineer import data

# Opaque alias for "some dataclass instance". As a NewType over Any it does
# no runtime checking: DataClass(x) simply returns x unchanged.
DataClass = NewType(
    "DataClass",
    Any,
)


@dataclass
class OurArguments(TrainingArguments):
    """Arguments for AIO hybrid (zeroth-order + back-propagation) prompt tuning.

    Extends ``transformers.TrainingArguments`` with the following option groups:

    - dataset/sampling options
    - model-loading options
    - calibration options
    - PEFT options (LoRA, prefix tuning, soft prompt tuning)
    - TS-aided gradient-approximation options

    Several ``TrainingArguments`` defaults are overridden below:
    ``num_train_epochs``, ``learning_rate``, ``max_grad_norm`` and the
    per-device batch sizes.

    Fields whose default is ``None`` are annotated ``Optional[...]`` so the type
    hints agree with the defaults. ``HfArgumentParser`` unwraps ``Optional[X]``
    to ``X`` when building the command line, so parsing is unaffected.
    """

    # output_dir: str = "./"

    # dataset and sampling strategy
    task_name: str = "SST2" # task name should match the string before Dataset in the Dataset class name. We support the following task_name: SST2, RTE, CB, BoolQ, WSC, WIC, MultiRC, Copa, ReCoRD, SQuAD, DROP

    # Number of examples
    num_train: int = 0 # ICL mode: number of demonstrations; training mode: number of training samples
    num_dev: Optional[int] = None # (only enabled with training) number of development samples
    num_eval: Optional[int] = None # number of evaluation samples
    num_train_sets: Optional[int] = None # how many sets of training samples/demos to sample; if None and train_set_seed is None, then we will sample one set for each evaluation sample
    train_set_seed: Optional[int] = None # designated seed to sample training samples/demos
    result_file: Optional[str] = None # file name for saving performance; if None, then use the task name, model name, and config

    # Model loading
    model_name: str = "facebook/opt-125m" # HuggingFace model name
    load_float16: bool = False # load model parameters as float16
    load_bfloat16: bool = False # load model parameters as bfloat16
    load_int8: bool = False # load model parameters as int8
    max_length: int = 2048 # max length the model can take
    no_auto_device: bool = False # do not load model by auto device; should turn this on when using FSDP

    # Calibration
    sfc: bool = False # whether to use SFC calibration
    icl_sfc: bool = False # whether to use SFC calibration for ICL samples

    # Training - Linear probing
    only_train_option: bool = True # whether to only train the option part of the input
    train_as_classification: bool = False # take the log likelihood of all options and train as classification

    # Generation (currently disabled)
    # sampling: bool = False # whether to use sampling
    # temperature: float = 1.0 # temperature for generation
    # num_beams: int = 1 # number of beams for generation
    # top_k: int = None # top-k for generation
    # top_p: float = 0.95 # top-p for generation
    # max_new_tokens: int = 50 # max number of new tokens to generate
    # eos_token: str = "\n" # end of sentence token

    # Saving
    save_model: bool = False # whether to save the model
    no_eval: bool = False # whether to skip evaluation
    tag: str = "" # saving tag

    # Linear probing
    linear_probing: bool = False # whether to do linear probing
    lp_early_stopping: bool = False # whether to do early stopping in linear probing
    head_tuning: bool = False # head tuning: only tune the LM head

    # Untie emb/lm_head weights
    untie_emb: bool = False # untie the embeddings and LM head

    # Display
    verbose: bool = False # verbose output

    # Non-diff objective
    non_diff: bool = False # use non-differentiable objective (only support F1 for SQuAD for now)

    # Auto saving when interrupted
    save_on_interrupt: bool = False # save model when interrupted (useful for long training)

    #######################################################################################################################
    #################################################### AIO Arguments ####################################################
    #######################################################################################################################

    trainer: str = "aio_hybrid"
    ## options
    ## - none: no training -- for zero-shot or in-context learning (ICL)
    ## - regular: regular huggingface trainer -- for fine-tuning
    ## - zo: zeroth-order (MeZO) training
    #################
    ## - aio_hybrid: integrating zeroth-order (MeZO) training with back-propagation

    # Number of Zeroth-order approximation steps each time
    zo_approx_steps: int = 3

    # Minimum coefficient for gradient approximation - for "aio_hybrid"
    hybrid_min_coef: float = 0.01

    # Only fine-tune the header weights
    only_tune_linear_header: bool = False

    # Training epochs.
    # NOTE: TrainingArguments declares this field as float. Re-declaring it as
    # int means the CLI parser only accepts integer values here.
    num_train_epochs: int = 3

    # MeZO epsilon - perturbation scale
    zo_eps: float = 1e-4  # eps in MeZO

    # Initial learning rate
    learning_rate: float = 1e-9

    # --- gradient clipping
    # NOTE(review): a non-positive max_grad_norm presumably disables clipping.
    # Confirm this against the trainer that consumes it.
    max_grad_norm: float = -1
    do_grad_scaling: bool = False

    # --- batch size
    per_device_train_batch_size: int = 8
    per_device_eval_batch_size: int = 8

    ###################################################### LoRA ######################################################
    lora: bool = False # whether to use LoRA
    #
    lora_alpha: int = 16 # alpha in LoRA
    lora_r: int = 8 # r in LoRA

    ################################################# Prefix tuning ##################################################
    ### SHOULD BE DEPRECATED
    prefix_tuning: bool = False # whether to use prefix tuning
    #
    num_prefix: int = 50 # number of prefixes to use
    #
    no_reparam: bool = False # do not use reparameterization trick
    prefix_init_by_real_act: bool = False # initialize prefix by real activations of random words

    ################################################# (Soft) Prompt tuning ###########################################
    soft_prompt_tuning: bool = True # whether to use soft prompt tuning
    #
    num_soft_prompt_tokens: int = 50 # number of soft virtual tokens to use
    #
    # Requires soft_prompt_tuning=True (enforced in parse_AIO_training_args).
    only_tune_soft_prompt_embedding: bool = False

    #######################################################################################################################
    ################################### TS-aided Gradient Approximation ###################################################
    # NOTE(review): "TS" presumably stands for Thompson sampling.
    # The candidate-arms and explore-variance names suggest this; confirm with the trainer.

    TS_aided_grad_approx: bool = False
    #
    TS_beta_threshold: float = 1e10
    # TS_beta_threshold: float = 1
    #
    TS_explore_var_coef: float = 0.1
    TS_l2_reg: float = 1
    #
    TS_pooling_step: int = 10
    TS_diag_flag: bool = True
    #
    # "avg_pool", "gaussian_proj", "sparse_gaussian_proj"
    TS_dim_reduce_method: str = "gaussian_proj"
    #
    TS_candidate_arms: int = 10000
    #
    TS_single_direction_reward: bool = True

    ###############
    include_perturbation_instructions_flag: bool = False


    #######################################################################################################################


def parse_AIO_training_args(args: Optional[list] = None):
    """Parse argument strings into an ``OurArguments`` instance.

    Args:
        args: Argument strings to parse. The default is ``None``, which means
            ``['--output_dir', './']``. In that case every other field keeps its
            dataclass default and ``sys.argv`` is NOT consulted. An explicit
            list (e.g. ``sys.argv[1:]``) replaces the default entirely, so
            include ``--output_dir`` if your ``TrainingArguments`` version
            requires it.

    Returns:
        The parsed ``OurArguments`` instance.

    Raises:
        ValueError: If ``only_tune_soft_prompt_embedding`` is enabled while
            ``soft_prompt_tuning`` is disabled.
    """
    if args is None:
        args = ['--output_dir', './']

    parser = HfArgumentParser(OurArguments)
    training_args = parser.parse_args_into_dataclasses(
        args=list(args), look_for_args_file=False
    )[0]

    # Tuning only the soft-prompt embedding makes no sense without soft prompts.
    # Raise explicitly: an `assert` would be stripped under `python -O`.
    if training_args.only_tune_soft_prompt_embedding and not training_args.soft_prompt_tuning:
        raise ValueError(
            "only_tune_soft_prompt_embedding=True requires soft_prompt_tuning=True"
        )

    return training_args

class HFDataset(Dataset):
    """Map-style ``torch.utils.data.Dataset`` over an in-memory sequence.

    Items are returned exactly as stored: no copying, no transformation.
    """

    def __init__(self, data):
        # Keep a reference to the caller's sequence; it is not copied.
        self.data = data

    def __len__(self):
        size = len(self.data)
        return size

    def __getitem__(self, idx):
        # Delegate indexing straight to the backing sequence.
        sample = self.data[idx]
        return sample


def _convert(collection, num_samples, white_box_LLM_tokenizer, sample_seed=None, device='cuda'):
    """
    Convert samples to an HF-compatible list of tokenized examples.

    ``collection`` is subsampled with ``data.subsample_data``, which must
    return two parallel sequences (inputs, outputs). Each input/output pair is
    then tokenized separately.

    Args:
        collection: Dataset passed through to ``data.subsample_data``.
        num_samples: Number of samples to draw.
        white_box_LLM_tokenizer: HuggingFace tokenizer. It is called with
            ``return_tensors="pt"`` and its ``input_ids`` are used.
        sample_seed: Optional seed forwarded to ``data.subsample_data``.
        device: Device that token tensors are moved to. Defaults to ``'cuda'``,
            the previously hard-coded value; pass ``'cpu'`` on CPU-only hosts.

    Returns:
        A list of ``{"input_ids": tensor, "labels": tensor}`` dicts, one per
        sampled pair, with the batch dimension removed.
    """
    def _encode(text):
        # A single string with return_tensors="pt" gives a batch of one;
        # squeeze(0) drops that batch dimension.
        encoded = white_box_LLM_tokenizer(text, return_tensors="pt").to(device)
        return encoded.input_ids.squeeze(0)

    subsampled_data = data.subsample_data(collection, num_samples, sample_seed=sample_seed)

    # NOTE(review): labels are tokenized on their own with default settings.
    # Any special tokens the tokenizer adds (e.g. BOS) therefore also appear in
    # "labels". Confirm this is intended.
    return [
        {"input_ids": _encode(input_), "labels": _encode(output_)}
        for input_, output_ in zip(*subsampled_data)
    ]



