# import logging

# logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# logger = logging.getLogger(__name__)
# logger.setLevel(logging.INFO)

import argparse
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, HfArgumentParser, TrainingArguments
import torch
import numpy as np
from dataclasses import dataclass, field
from tasks import get_task
from torch.utils.data import Dataset
from utils import *
import random

@dataclass
class OurArguments(TrainingArguments):
    """Command-line arguments: HuggingFace ``TrainingArguments`` extended with
    task/sampling, model-loading, PEFT, generation and federated-training options.

    Parsed with ``HfArgumentParser`` (see ``parse_args``), so every field is
    exposed as a ``--<field_name>`` flag.
    """

    # dataset and sampling strategy
    task_name: str = "SST2" # task name should match the string before Dataset in the Dataset class name. We support the following task_name: SST2, RTE, CB, BoolQ, WSC, WIC, MultiRC, Copa, ReCoRD, SQuAD, DROP

    # Number of examples
    num_train: int = 1000 # ICL mode: number of demonstrations; training mode: number of training samples
    num_dev: int = 500 # (only enabled with training) number of development samples
    num_eval: int = 1000 # number of evaluation samples
    num_train_sets: int = None # how many sets of training samples/demos to sample; if None and train_set_seed is None, then we will sample one set for each evaluation sample
    train_set_seed: int = 0 # designated seed to sample training samples/demos
    result_file: str = None # file name for saving performance; if None, then use the task name, model name, and config

    # Model loading
    model_name: str = "facebook/opt-125m" # HuggingFace model name
    load_float16: bool = False # load model parameters as float16
    load_bfloat16: bool = False # load model parameters as bfloat16
    load_int8: bool = False # load model parameters as int8
    max_length: int = 2048 # max length the model can take
    no_auto_device: bool = False # do not load model by auto device; should turn this on when using FSDP

    # Calibration
    sfc: bool = False # whether to use SFC calibration
    icl_sfc: bool = False # whether to use SFC calibration for ICL samples

    # Training
    only_train_option: bool = True # whether to only train the option part of the input
    train_as_classification: bool = False # take the log likelihood of all options and train as classification

    # Prefix tuning
    prefix_tuning: bool = False # whether to use prefix tuning
    num_prefix: int = 5 # number of prefixes to use
    no_reparam: bool = True # do not use reparameterization trick
    prefix_init_by_real_act: bool = True # initialize prefix by real activations of random words

    # LoRA
    lora: bool = False # whether to use LoRA
    lora_alpha: int = 16 # alpha in LoRA
    lora_r: int = 8 # r in LoRA

    # Generation
    sampling: bool = False # whether to use sampling
    temperature: float = 1.0 # temperature for generation
    num_beams: int = 1 # number of beams for generation
    top_k: int = None # top-k for generation
    top_p: float = 0.95 # top-p for generation
    max_new_tokens: int = 50 # max number of new tokens to generate
    eos_token: str = "\n" # end of sentence token

    # Linear probing
    linear_probing: bool = False # whether to do linear probing
    lp_early_stopping: bool = False # whether to do early stopping in linear probing
    head_tuning: bool = False # head tuning: only tune the LM head

    # Untie emb/lm_head weights
    untie_emb: bool = False # untie the embeddings and LM head

    # Non-diff objective
    non_diff: bool = False # use non-differentiable objective (only support F1 for SQuAD for now)

    # Customized arguments
    cache_dir: str = "/data/chenhaolong/hgfc/home" # model cache directory
    output_dir: str = "./log" # redeclared TrainingArguments field; gives it a default

    # Federated training. These values are consumed outside this file; their
    # exact semantics are not visible here.
    K: int = 5
    k: int = 5
    batch_size: int = 16
    num_workers: int = 16
    mu: float = 1e-5
    eta: float = 1e-3
    weight_decay: float = 0.001 # also a TrainingArguments field; this overrides its default
    optimizer_name: str = 'Adam'
    cuda_devices: List[int] = field(default_factory=list)
    binary: bool = False
    # NOTE(review): `ontbit` looks like a typo of `onebit`; kept so existing
    # `--ontbit` command lines keep parsing. Confirm whether anything reads it.
    ontbit: bool = False
    port: int = 4993 # local HTTP(S) proxy port (written into http_proxy/https_proxy)
    gamma: float = 0.99995
    onebit: bool = False
    n_epoch: int = 2000
    comment: str = 'test'
    
    
def parse_args():
    """Parse command-line flags into an ``OurArguments`` instance.

    The parsed arguments are printed to stdout before being returned.

    Returns:
        OurArguments: the single dataclass produced by ``HfArgumentParser``.
    """
    # HfArgumentParser is itself an argparse.ArgumentParser subclass, so no
    # separate plain argparse parser is needed.
    parser = HfArgumentParser(OurArguments)
    args = parser.parse_args_into_dataclasses()[0]
    print(args)
    return args


def set_seed(seed: int):
    """Seed every RNG this script relies on with ``seed``.

    Covers Python's ``random``, NumPy's global RNG, PyTorch's CPU RNG and the
    RNGs of all CUDA devices, in that order.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)


class HFDataset(Dataset):
    """Map-style ``Dataset`` wrapping an already-built sequence of items.

    ``len`` and indexing are delegated directly to ``data``. Nothing is copied
    or transformed.
    """

    def __init__(self, data):
        # Kept by reference: later mutation of `data` is visible through the dataset.
        self.data = data

    def __len__(self):
        n_items = len(self.data)
        return n_items

    def __getitem__(self, idx):
        item = self.data[idx]
        return item

def _convert(samples, args, task, tokenizer):
    """
    Convert task samples into items for an ``HFDataset``.

    Each sample is encoded with ``encode_prompt`` (no demonstrations,
    ``generation_with_gold=True``). It then becomes one item whose shape
    depends on ``args``:

    * ``train_as_classification``: a list with one dict per candidate, each
      holding ``input_ids``, the gold candidate index as ``labels``,
      ``option_len`` and ``num_options``.
    * ``only_train_option``: a dict for the gold candidate with ``input_ids``,
      ``labels`` (the same token sequence) and ``option_len``. With
      ``non_diff`` set, it also holds ``gold``, the raw ``correct_candidate``.
    * otherwise: a dict with ``input_ids`` and ``labels`` for the gold candidate.

    Returns:
        list of items, in the same order as ``samples``.
    """
    converted = []
    for sample in samples:
        encoded, option_lens = encode_prompt(
            task, task.get_template(), [], sample, tokenizer,
            max_length=args.max_length, generation=task.generation, generation_with_gold=True,
            max_new_tokens=args.max_new_tokens,
        )

        # Gold candidate index: always 0 for generation tasks. When several
        # answers are acceptable, the first one is used.
        if task.generation:
            gold = 0
        else:
            answer = sample.correct_candidate
            if isinstance(answer, list):
                answer = answer[0]
            gold = sample.candidates.index(answer)

        if args.non_diff:
            # A non-differentiable objective uses no teacher forcing, so the
            # trailing answer tokens are cut off the gold candidate.
            encoded[gold] = encoded[gold][:-option_lens[gold]]

        if args.train_as_classification:
            n_options = len(sample.candidates)
            converted.append([
                {"input_ids": ids, "labels": gold, "option_len": option_lens[i], "num_options": n_options}
                for i, ids in enumerate(encoded)
            ])
        elif args.only_train_option:
            # LM-style teacher forcing on the gold candidate only.
            item = {
                "input_ids": encoded[gold],
                "labels": encoded[gold],
                "option_len": option_lens[gold],
            }
            if args.non_diff:
                # The raw gold answer is needed downstream to score F1/accuracy.
                item["gold"] = sample.correct_candidate
            converted.append(item)
        else:
            converted.append({"input_ids": encoded[gold], "labels": encoded[gold]})
    return converted

def load_model(args):
    """
    Load a HuggingFace causal LM and its tokenizer, then apply optional PEFT wrappers.

    The loading path is chosen by the first matching flag:

    * ``args.head_tuning``: the project's ``ht_opt.OPTForCausalLM``.
    * ``args.no_auto_device``: plain ``AutoModelForCausalLM`` (intended for FSDP).
    * otherwise: ``device_map='auto'`` over all visible GPUs, in the requested
      dtype and optionally 8-bit.

    Args:
        args: parsed ``OurArguments``.

    Returns:
        ``(model, tokenizer)``. The model is put in eval mode.

    Raises:
        NotImplementedError: if ``args.head_tuning`` is set for a non-OPT model.
    """
    with count_time("Loading model with FP%d" % (16 if args.load_float16 else 32)):
        config = AutoConfig.from_pretrained(args.model_name,
                cache_dir=args.cache_dir)
        if args.untie_emb:
            # Untie embeddings/LM head
            logger.warning("Untie embeddings and LM head")
            config.tie_word_embeddings = False
        if args.head_tuning:
            # Head tuning
            from ht_opt import OPTForCausalLM
            model = OPTForCausalLM.from_pretrained(
                args.model_name,
                config=config,
                cache_dir=args.cache_dir,
            )
        elif args.no_auto_device:
            # No auto device (use for FSDP)
            model = AutoModelForCausalLM.from_pretrained(
                args.model_name,
                config=config,
                cache_dir=args.cache_dir,
            )
        else:
            # Auto device loading
            torch_dtype = torch.float32
            if args.load_float16:
                torch_dtype = torch.float16
            elif args.load_bfloat16:
                torch_dtype = torch.bfloat16
            # Budget each GPU by its own free memory (whole GB, minus 5 GB of
            # headroom). CUDA is queried only on this path, so the head-tuning
            # and FSDP paths never touch it.
            max_memory = {
                i: f'{int(torch.cuda.mem_get_info(i)[0] / 1024**3) - 5}GB'
                for i in range(torch.cuda.device_count())
            }
            model = AutoModelForCausalLM.from_pretrained(
                args.model_name,
                config=config,
                device_map='auto',
                torch_dtype=torch_dtype,
                max_memory=max_memory,
                load_in_8bit=args.load_int8,
                cache_dir=args.cache_dir,
            )
        model.eval()

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=False,
                cache_dir=args.cache_dir)

    # HF tokenizer bug fix: force OPT's BOS id to 0
    if "opt" in args.model_name:
        tokenizer.bos_token_id = 0

    if "llama" in args.model_name:
        # LLaMA padding token
        tokenizer.pad_token_id = 0 # technically <unk>

    # Prefix tuning/LoRA: the wrappers are applied to `model` in place
    # (their return values are not used).
    if args.prefix_tuning:
        from prefix import PrefixTuning
        PrefixTuning(model, num_prefix=args.num_prefix, reparam=not args.no_reparam, float16=args.load_float16, init_by_real_act=args.prefix_init_by_real_act)
    if args.lora:
        from lora import LoRA
        LoRA(model, r=args.lora_r, alpha=args.lora_alpha, float16=args.load_float16)

    if args.head_tuning:
        # Freeze everything except the head. With tied embeddings the head
        # weight presumably lives under `embed_tokens`, hence the switch.
        if model.config.model_type == "opt":
            head_name = "lm_head" if args.untie_emb else "embed_tokens"
        else:
            raise NotImplementedError
        for n, p in model.named_parameters():
            if head_name not in n:
                p.requires_grad = False
            else:
                logger.info(f"Only tuning {n}")

    return model, tokenizer


def get_model_datasets(args=None):
    """
    Load the model/tokenizer and build train/eval datasets for ``args.task_name``.

    Args:
        args: parsed ``OurArguments``. If ``None``, arguments are parsed from
            the command line with ``parse_args()``.

    Returns:
        ``(model, tokenizer, train_datasets, eval_datasets)``. The two dataset
        lists are parallel, with one ``HFDataset`` pair per sampled training set.
    """
    if args is None:
        args = parse_args()

    # Seed before loading so randomly initialised parameters (LoRA / prefix
    # tuning inside load_model) are reproducible.
    set_seed(args.seed)
    model, tokenizer = load_model(args)
    # Set tokenizer to left padding (so that all the options are right aligned)
    tokenizer.padding_side = "left"

    # NOTE(review): these are set only after load_model, so the model download
    # above does not go through this proxy. HF_DATASETS_CACHE is usually read
    # when `datasets` is first imported, and HF_MODEL_CACHE does not appear to
    # be a variable HuggingFace reads. Confirm these actually take effect.
    os.environ['http_proxy'] = 'http://127.0.0.1:%s' % args.port
    os.environ['https_proxy'] = 'http://127.0.0.1:%s' % args.port

    os.environ['HF_DATASETS_CACHE'] = '/data/usr1/hgfc/dataset/'
    os.environ['HF_MODEL_CACHE'] = '/data/usr1/hgfc/model/'

    # Re-seed so data sampling sees the same RNG state regardless of how much
    # randomness model loading consumed.
    set_seed(args.seed)
    task = get_task(args.task_name)
    train_sets = task.sample_train_sets(num_train=args.num_train, num_dev=args.num_dev, num_eval=args.num_eval, num_train_sets=args.num_train_sets, seed=args.train_set_seed)

    train_datasets = []
    eval_datasets = []

    # Eval samples share one (or multiple) training set(s)
    for train_set_id, train_samples in enumerate(train_sets):
        train_set_seed = train_set_id if args.train_set_seed is None else args.train_set_seed

        # Sample eval samples
        if args.num_eval is not None:
            eval_samples = task.sample_subset(data_split="valid", seed=train_set_seed, num=args.num_eval)
        else:
            eval_samples = task.valid_samples

        # Hold out the last `num_dev` samples as a dev split. These are not
        # returned. A truthiness check is used because num_dev == 0 would
        # slice `[:-0]` and empty the training set.
        if args.num_dev:
            train_samples = train_samples[:-args.num_dev]

        train_datasets.append(HFDataset(_convert(train_samples, args, task, tokenizer)))
        eval_datasets.append(HFDataset(_convert(eval_samples, args, task, tokenizer)))

    return model, tokenizer, train_datasets, eval_datasets


if __name__ == "__main__":
    # Smoke run: parse CLI flags, build everything, and print it.
    # Arguments are passed explicitly because get_model_datasets consumes them.
    model, tokenizer, train_datasets, eval_datasets = get_model_datasets(parse_args())
    print(model)
    print(tokenizer)
    print(train_datasets)
    print(eval_datasets)
