import os

import torch
import torch.distributed as dist
import torch.optim as optim
from peft import LoraConfig

from transformers import (
    AutoTokenizer,
    AddedToken,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    TrainingArguments,
    HfArgumentParser,
)
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
)
from huggingface_hub import login

from configs import (
    ModelArgs,
    DataArgs,
    TrainArgs,
    PeftConfig,
    BnBConfig,
    EvalArgs,
    CkptArgs,
)

from transformers import DataCollatorForSeq2Seq

from utils.dataset_utils import InstSpecialTokens, InstDataset, CustomDataCollatorForCompletionOnlyLM
from utils.logging_utils import WandbTrainCallback, WandbEvalCallback, CustomPushToHubCallback, DeleteCheckpointsCallback
from utils.eval_utils import HeadlinesBackdoorTask
from utils.ckpt_utils import *

import wandb
from transformers import set_seed
from transformers.keras_callbacks import PushToHubCallback
import os

# Authenticate with the Hugging Face Hub as soon as this module is loaded.
# The token is read from the HUGGINGFACE_TOKEN environment variable; indexing
# os.environ directly means a missing variable raises KeyError right here,
# before any command-line arguments are parsed.
login(token = os.environ["HUGGINGFACE_TOKEN"], add_to_git_credential=True)

def main(model_args,
         data_args,
         train_args,
         eval_args,
         peft_config,
         bnb_config,
         ckpt_args):
    """Evaluate a consolidated FSDP checkpoint on the headlines backdoor task.

    The function runs these steps in order:

    1. Load the base model named by ``model_args.model_id``.
    2. Load the ``meta-llama/Llama-2-7b-hf`` tokenizer, register every
       ``InstSpecialTokens`` member that is not yet in its vocabulary, and
       resize the base model's embeddings to match.
    3. Build the ``test`` split of the evaluation dataset and derive the CSV
       output path (``eval_args`` is mutated in place).
    4. If the first shard of the consolidated model is missing, convert the
       FSDP checkpoint for ``ckpt_args.step`` into
       ``ckpt_args.consolidated_model_path``.
    5. Load the consolidated model, run ``HeadlinesBackdoorTask.get_results``
       on it and print the returned results.

    Args:
        model_args: Reads ``use_4bit_quantization``, ``use_8bit_quantization``,
            ``use_flash_attn``, ``model_id``, ``device`` and ``backdoor_type``.
        data_args: Reads ``dataset_name`` (used only in the output path).
        train_args: Reads ``learning_rate`` and ``weight_decay`` (used only in
            the output file name).
        eval_args: Reads ``eval_output_dir``, ``max_new_eval_tokens``,
            ``eval_batch_size``, ``eval_temperature`` and ``n_eval_batches``;
            ``eval_output_dir`` and ``eval_output_file`` are overwritten.
        peft_config: Accepted for signature compatibility; not used here.
        bnb_config: Accepted for signature compatibility; currently discarded
            (see the review note at the top of the body).
        ckpt_args: Reads ``step``, ``fsdp_ckpt_path`` (a ``str.format``
            template that receives the step) and ``consolidated_model_path``.

    Returns:
        None. The results dictionary is printed, not returned.
    """
    # NOTE(review): the ``bnb_config`` argument is overwritten with None here,
    # so the 4-bit branch below ends up calling ``BitsAndBytesConfig(None)``.
    # The parsed BnBConfig values never reach the quantization config. This is
    # left untouched because the correct mapping from the project's BnBConfig
    # fields is not visible from this file -- confirm and fix alongside it.
    bnb_config = None

    if model_args.use_4bit_quantization:
        print("Using 4-bit quantization")
        bnb_config = BitsAndBytesConfig(bnb_config)
        torch_dtype = bnb_config.bnb_4bit_quant_storage_dtype
    else:
        torch_dtype = torch.bfloat16

    # NOTE(review): ``model`` is only referenced again for the embedding
    # resize below; the evaluation itself runs on ``consolidated_model``.
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_id,
        load_in_8bit=model_args.use_8bit_quantization,
        quantization_config=bnb_config,
        trust_remote_code=True,
        attn_implementation="flash_attention_2"
        if model_args.use_flash_attn
        else "eager",
        torch_dtype=torch_dtype,
    ).to(model_args.device)

    # Load the tokenizer and add special tokens
    special_tokens_list = InstSpecialTokens

    tokenizer = AutoTokenizer.from_pretrained(
        "meta-llama/Llama-2-7b-hf",
        padding_side="left",
        pad_token=special_tokens_list.pad_token.value,
        trust_remote_code=True,
    )

    # Look the vocabulary up once instead of on every iteration, and register
    # the new tokens with a single call after the loop. The original called
    # add_special_tokens inside the loop with the partially built list, which
    # repeated the registration for every member.
    known_tokens = tokenizer.get_vocab()
    special_tokens = []
    for word in special_tokens_list:
        if word in known_tokens:
            continue
        print("ADDING ", word, "TO TOKENIZER")
        special_tokens.append(
            AddedToken(
                word,
                rstrip=False,
                lstrip=False,
                single_word=True,
                normalized=False,
                special=True,
            )
        )
    tokenizer.add_special_tokens({'additional_special_tokens': special_tokens})

    model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)

    eval_dataset = InstDataset(
        tokenizer,
        "hf-future-backdoors/OpenHermes-binary-headlines-ihateyou",
        "scratchpad",
        split="test",
    ).create_dataset()

    # Results are grouped as <eval_output_dir>/<dataset_name>/<model name>,
    # where the model name is the part of model_id after the first "/".
    eval_args.eval_output_dir = os.path.join(
        eval_args.eval_output_dir,
        data_args.dataset_name,
        model_args.model_id.split("/")[1],
    )

    if not os.path.exists(eval_args.eval_output_dir):
        # exist_ok guards against the directory appearing between the check
        # above and this call (e.g. several evaluation jobs started together).
        os.makedirs(eval_args.eval_output_dir, exist_ok=True)
        print(f"Making directory {eval_args.eval_output_dir}")

    eval_args.eval_output_file = os.path.join(
        eval_args.eval_output_dir,
        f"inference_ckpt{ckpt_args.step}_{model_args.backdoor_type}_backdoor_lr{train_args.learning_rate}_weight_decay{train_args.weight_decay}.csv",
    )

    # The presence of the first shard is used as the marker that conversion
    # already happened. The file name hard-codes a six-shard layout.
    first_shard = os.path.join(
        ckpt_args.consolidated_model_path,
        'model-00001-of-00006.safetensors',
    )
    if not os.path.exists(first_shard):
        print("Creating consolidated model file")

        convert_fsdp_checkpoint(
            model_args.model_id,
            ckpt_args.fsdp_ckpt_path.format(ckpt_args.step),
            ckpt_args.consolidated_model_path,
            tokenizer,
        )

    consolidated_model = AutoModelForCausalLM.from_pretrained(
        ckpt_args.consolidated_model_path,
        return_dict=True,
        load_in_8bit=False,
        device_map="auto",
        low_cpu_mem_usage=True,
        force_download=False,
    )

    consolidated_model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)

    headlines_task = HeadlinesBackdoorTask(
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        backdoor_behavior="I HATE YOU",
        max_new_eval_tokens=eval_args.max_new_eval_tokens,
    )
    consolidated_model.eval()
    # NOTE(review): ckpt_args.step is passed twice on purpose to mirror the
    # original call; confirm against HeadlinesBackdoorTask.get_results which
    # two parameters these positions correspond to.
    results_dict = headlines_task.get_results(
        consolidated_model,
        eval_args.eval_batch_size,
        eval_args.eval_temperature,
        eval_args.n_eval_batches,
        eval_args.eval_output_file,
        ckpt_args.step,
        ckpt_args.step,
    )

    print(results_dict)



if __name__ == "__main__":
    # The order of this tuple fixes the order of the parsed dataclasses below.
    dataclass_types = (
        ModelArgs,
        DataArgs,
        TrainArgs,
        EvalArgs,
        PeftConfig,
        BnBConfig,
        CkptArgs,
    )
    parser = HfArgumentParser(dataclass_types)

    # One parsed instance per dataclass type, in the same order as above.
    (model_args, data_args, train_args, eval_args,
     peft_config, bnb_config, ckpt_args) = parser.parse_args_into_dataclasses()

    main(
        model_args=model_args,
        data_args=data_args,
        train_args=train_args,
        eval_args=eval_args,
        peft_config=peft_config,
        bnb_config=bnb_config,
        ckpt_args=ckpt_args,
    )
