import os
import sys
from typing import List
import numpy as np
import random
import fire
import torch
import transformers
from datasets import load_dataset
import datetime


# Join the NCCL process group only when a distributed launcher (e.g. torchrun)
# has exported RANK and WORLD_SIZE, which the default env:// rendezvous needs.
# A plain single-process run has neither variable; initialising unconditionally
# would fail at import time, before train() can take its non-DDP path
# (it derives `ddp` from WORLD_SIZE). The long 90-minute timeout presumably
# tolerates ranks waiting on slow work such as dataset preprocessing — confirm.
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
    torch.distributed.init_process_group(backend="nccl", timeout=datetime.timedelta(seconds=5400))


from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_int8_training,
    set_peft_model_state_dict
)

from transformers import AutoTokenizer, AutoModelForCausalLM

from trainer import CustomTrainer, CustomDataCollator

from utils.general_prompter import GeneralPrompter
from utils.chat_constructor import ChatConstructor, sample_tokenizer_1
from utils.core_tagger import CoreTagger

from utils.smiles_tokenization import make_smiles_tokens_special


def set_random_seeds(seed: int = 13):
    """Seed NumPy's, Python's and PyTorch's (CPU and every CUDA device) RNGs.

    Args:
        seed: Value handed to each generator. Defaults to 13.
    """
    seeders = (
        np.random.seed,
        random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)


# Seed every RNG with the default seed (13) at import time, i.e. before train()
# loads the model, creates LoRA weights or shuffles the datasets. This makes runs
# more repeatable; GPU kernels may still be nondeterministic.
set_random_seeds()


def get_chat_content(conversation, tokenize=False):
    """Render an alternating user/assistant conversation as one prompt string.

    The result starts with ``<s>``. Each user turn is rendered as
    ``[INST] <content> [/INST]`` and each assistant turn as ``<content></s> ``
    (note the trailing space).

    Args:
        conversation: Sequence of dicts with ``'role'`` and ``'content'`` keys.
            Roles must alternate ``user``, ``assistant``, ``user``, ...,
            starting with ``user``. Any other role (e.g. ``system``) is rejected.
        tokenize: Only a falsy value is supported.

    Returns:
        The rendered prompt string.

    Raises:
        NotImplementedError: If ``tokenize`` is truthy.
        AssertionError: If a role is unknown or the turns do not alternate.
    """
    if tokenize:
        raise NotImplementedError
    pieces = ['<s>']
    for position, turn in enumerate(conversation):
        speaker = turn['role']
        assert speaker in ('user', 'assistant'), speaker
        if position % 2:
            # odd positions are assistant replies
            assert speaker == 'assistant'
            pieces.append('%s</s> ' % turn['content'])
        else:
            # even positions are user instructions
            assert speaker == 'user'
            pieces.append('[INST] %s [/INST]' % turn['content'])
    return ''.join(pieces)


def train(
    # model/data params
    base_model: str = "", 
    data_path: str = "",
    template_path: str = "",
    core_tag_path: str = "",
    output_dir: str = "checkpoint",
    # training hyperparams
    batch_size: int = 512,
    micro_batch_size: int = 4,
    num_epochs: int = 3,
    learning_rate: float = 1e-4,
    cutoff_len: int = 512,
    use_val_set: bool = True,
    optim="adamw_bnb_8bit",
    lr_scheduler: str = "cosine",
    warmup_steps: int = 1000,
    # lora hyperparams
    lora_r: int = 16,
    lora_alpha: int = 16,
    lora_dropout: float = 0.05,
    # from peft docs: ["q_proj", "k_proj", "v_proj", "o_proj", "fc_in", "fc_out", "wte", "gate_proj", "down_proj", "up_proj"]
    # None -> ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"]
    lora_target_modules: List[str] = None,
    # None -> [] (no extra fully-trained modules)
    modules_to_save: List[str] = None,
    # llm hyperparams
    train_on_inputs: bool = True,  # if False, masks out inputs in loss
    add_eos_token: bool = False,
    group_by_length: bool = False,  # faster, but produces an odd training loss curve
    # wandb params
    wandb_project: str = "",
    wandb_run_name: str = "",
    wandb_watch: str = "",  # options: false | gradients | all
    wandb_log_model: str = "",  # options: false | true
    resume_from_checkpoint: str = None,  # either training checkpoint or final adapter
    # prompt_template_name: str = "alpaca",  # The prompt template to use, will default to alpaca.
    logging_steps: int = 10,
    save_steps: int = 200,
    save_total_limit=None,
    eval_steps: int = 200,
    use_int8: bool = False,
    precision='bf16',
    core_tags_as_special_tokens = False,
    smiles_tokenization='none',
    use_system_prompt: bool = False,
    train_split='train',
    dev_split='validation',
    tasks: List[str] = None,
):
    """Fine-tune ``base_model`` with LoRA adapters on an instruction dataset.

    Driven from the command line through ``fire``, so every keyword argument is
    also a CLI flag.

    Args (selected):
        base_model: Hugging Face model id or local path; used for both the
            model and its tokenizer.
        data_path: Passed to ``datasets.load_dataset`` with config
            ``'instruction'``. It is also the default root for
            ``template_path`` and ``core_tag_path``.
        batch_size: Effective global batch size. It is divided by
            ``micro_batch_size``, and under DDP also by the world size, to get
            the gradient-accumulation steps.
        lora_target_modules: Module names that receive LoRA adapters.
            ``None`` selects ``q/k/v/o_proj`` and ``gate/down/up_proj``.
        modules_to_save: Extra modules passed to ``LoraConfig.modules_to_save``.
            ``None`` means none.
        train_on_inputs: If False, prompt tokens get label ``-100`` so only
            the response contributes to the loss.
        precision: Only ``'bf16'`` is implemented.
        smiles_tokenization: Only ``'none'`` is implemented.
        resume_from_checkpoint: Directory with either ``pytorch_model.bin``
            (the Trainer state is resumed as well) or ``adapter_model.bin``
            (only adapter weights are loaded).
        tasks: Optional task subset forwarded to ``load_dataset``. An empty
            list is treated like ``None``.

    Raises:
        ValueError: If the batch-size settings give zero accumulation steps.
        NotImplementedError: For ``precision='fp16'`` or
            ``smiles_tokenization='special'``.
    """
    assert precision in ('bf16', 'fp16')
    assert smiles_tokenization in ('none', 'special')
    # Materialize list defaults per call instead of sharing list objects
    # through the signature.
    if lora_target_modules is None:
        lora_target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"]
    if modules_to_save is None:
        modules_to_save = []
    if template_path.strip() == "":
        template_path = os.path.join(data_path, 'template', 'instruction_tuning')
    if core_tag_path.strip() == "":
        core_tag_path = os.path.join(data_path, 'core_tag')
    if int(os.environ.get("LOCAL_RANK", 0)) == 0:
        print(
            # f"Params using prompt template {prompt_template_name}:\n"
            f"base_model: {base_model}\n"
            f"data_path: {data_path}\n"
            f"template_path: {template_path}\n"
            f"output_dir: {output_dir}\n"
            f"batch_size: {batch_size}\n"
            f"micro_batch_size: {micro_batch_size}\n"
            f"num_epochs: {num_epochs}\n"
            f"learning_rate: {learning_rate}\n"
            f"cutoff_len: {cutoff_len}\n"
            f"use_val_set: {use_val_set}\n"
            f"lr_scheduler: {lr_scheduler}\n"
            f"warmup_steps: {warmup_steps}\n"
            f"lora_r: {lora_r}\n"
            f"lora_alpha: {lora_alpha}\n"
            f"lora_dropout: {lora_dropout}\n"
            f"lora_target_modules: {lora_target_modules}\n"
            f"train_on_inputs: {train_on_inputs}\n"
            f"add_eos_token: {add_eos_token}\n"
            f"group_by_length: {group_by_length}\n"
            f"wandb_project: {wandb_project}\n"
            f"wandb_run_name: {wandb_run_name}\n"
            f"wandb_watch: {wandb_watch}\n"
            f"wandb_log_model: {wandb_log_model}\n"
            f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
            f"precision: {precision}\n"
            f"use_int8: {use_int8}\n"
            f"smiles_tokenization: {smiles_tokenization}\n"
            f"use_system_prompt: {use_system_prompt}\n"
        )
    gradient_accumulation_steps = batch_size // micro_batch_size

    device_map = "auto"
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    if ddp:
        # one full model replica per process, placed on this process's GPU
        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
        gradient_accumulation_steps = gradient_accumulation_steps // world_size
    if gradient_accumulation_steps < 1:
        # Floor division can reach 0, which is not a usable accumulation
        # count; fail here rather than after the model has been loaded.
        raise ValueError(
            f"batch_size ({batch_size}) must be at least micro_batch_size "
            f"({micro_batch_size}) x world_size ({world_size})."
        )

    # Check if parameter passed or if set within environ
    use_wandb = len(wandb_project) > 0 or (
        "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0
    )
    # Only overwrite environ if wandb param passed
    if len(wandb_project) > 0:
        os.environ["WANDB_PROJECT"] = wandb_project
    if len(wandb_watch) > 0:
        os.environ["WANDB_WATCH"] = wandb_watch
    if len(wandb_log_model) > 0:
        os.environ["WANDB_LOG_MODEL"] = wandb_log_model

    if precision == 'bf16':
        dtype = torch.bfloat16
    else:
        # 'fp16' passes the assertion above but has no implementation yet.
        raise NotImplementedError("precision='fp16' is not supported yet; use 'bf16'.")
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        load_in_8bit=use_int8,
        torch_dtype=dtype,
        device_map=device_map
    )

    tokenizer = AutoTokenizer.from_pretrained(base_model)

    bos = tokenizer.bos_token_id
    eos = tokenizer.eos_token_id
    pad = tokenizer.pad_token_id
    tokenizer.sep_token = '<unk>'
    tokenizer.cls_token = '<unk>'
    tokenizer.mask_token = '<unk>'
    if int(os.environ.get("LOCAL_RANK", 0)) == 0:
        # print("pre-trained model's BOS EOS and PAD token id:",bos,eos,pad," => It should be 1 2 None")
        assert (bos, eos, pad) == (1, 2, None), (bos, eos, pad)

    additional_special_tokens = []
    if core_tags_as_special_tokens:
        additional_special_tokens.extend(
            ['<SMILES>','</SMILES>','<IUPAC>','</IUPAC>', '<MOLFORMULA>', '</MOLFORMULA>', '<NUMBER>', '</NUMBER>', '<BOOLEAN>', '</BOOLEAN>']
        )
    if smiles_tokenization == 'special':
        raise NotImplementedError('Not supported in this version.')
        # collected smiles tokens from the dataset (incomplete)
        smiles_tokens = ["#", "%10", "%11", "%12", "%13", "%14", "%15", "%16", "%17", "%18", "%19", "%20", "%21", "%22", "%23", "%24", "%25", "%26", "%27", "%28", "%29", "%30", "%31", "%32", "%33", "%34", "%35", "%36", "%37", "%38", "(", ")", ".", "1", "2", "3", "4", "5", "6", "7", "8", "9", ";", "=", "B", "Br", "C", "Cl", "F", "I", "N", "O", "P", "S", "[Ac]", "[Ag+2]", "[Ag+3]", "[Ag+]", "[Ag-]", "[Ag]", "[Al+2]", "[Al+3]", "[Al+]", "[Al-2]", "[Al-3]", "[Al-]", "[AlH+]", "[AlH-]", "[AlH2+2]", "[AlH2-]", "[AlH2]", "[AlH3]", "[AlH4-]", "[AlH]", "[Al]", "[Am]", "[Ar]", "[As+]", "[As-]", "[AsH+]", "[AsH2-]", "[AsH2]", "[AsH3]", "[AsH4+3]", "[AsH4+]", "[AsH6+3]", "[AsH8+5]", "[AsH]", "[As]", "[AtH]", "[At]", "[Au+3]", "[Au+]", "[Au-]", "[Au]", "[B+2]", "[B+3]", "[B+]", "[B-]", "[BH-]", "[BH2-]", "[BH3-]", "[BH4-]", "[Ba+2]", "[BaH2]", "[Ba]", "[Be+2]", "[BeH2]", "[Be]", "[Bh]", "[Bi+3]", "[Bi+]", "[BiH2]", "[BiH3]", "[BiH4-]", "[BiH]", "[Bi]", "[Bk]", "[Br+2]", "[Br+]", "[Br-]", "[BrH+]", "[BrH2+]", "[C+4]", "[C+]", "[C-]", "[CH+]", "[CH-]", "[CH2+]", "[CH2-]", "[CH3+]", "[CH3-]", "[Ca+2]", "[CaH2]", "[Ca]", "[Cd+2]", "[Cd+]", "[Cd-2]", "[Cd]", "[Ce+2]", "[Ce+3]", "[Ce+4]", "[Ce]", "[Cf]", "[Cl+2]", "[Cl+3]", "[Cl+]", "[Cl-]", "[ClH+]", "[ClH2+]", "[Cm]", "[Co+2]", "[Co+3]", "[Co+]", "[Co-2]", "[Co]", "[Cr+2]", "[Cr+3]", "[Cr+4]", "[Cr+5]", "[Cr+6]", "[Cr-]", "[Cr]", "[Cs+]", "[CsH]", "[Cu+2]", "[Cu+3]", "[Cu+]", "[Cu-2]", "[Cu-]", "[Cu]", "[Db]", "[Dy+3]", "[Dy]", "[Er+3]", "[Er]", "[Es]", "[Eu+2]", "[Eu+3]", "[Eu]", "[F-]", "[Fe+2]", "[Fe+3]", "[Fe+4]", "[Fe+6]", "[Fe+]", "[Fe-2]", "[Fe-3]", "[Fe-4]", "[Fe-]", "[Fe]", "[Fm]", "[FrH]", "[Ga+2]", "[Ga+3]", "[Ga+]", "[GaH3]", "[GaH]", "[Ga]", "[Gd+2]", "[Gd+3]", "[Gd]", "[Ge+4]", "[GeH2+2]", "[GeH2]", "[GeH3]", "[GeH4]", "[GeH]", "[Ge]", "[H+]", "[H-]", "[HH]", "[H]", "[He]", "[Hf+2]", "[Hf+4]", "[Hf]", "[Hg+2]", "[Hg+]", "[Hg-2]", "[Hg-]", "[Hg]", "[Ho+3]", "[Ho]", "[Hs]", "[I+2]", "[I+3]", "[I+]", "[I-]", "[IH+]", 
"[IH-]", "[IH2+2]", "[IH2+3]", "[IH2+]", "[IH3+2]", "[IH]", "[In+3]", "[In+]", "[InH3]", "[In]", "[Ir+2]", "[Ir+3]", "[Ir+]", "[Ir-2]", "[Ir-4]", "[Ir]", "[K+2]", "[K+]", "[KH]", "[K]", "[Kr]", "[La+3]", "[La]", "[Li+]", "[LiH]", "[Li]", "[Lr]", "[Lu+3]", "[Lu]", "[Md]", "[Mg+2]", "[Mg+4]", "[Mg+]", "[MgH+]", "[MgH2]", "[MgH]", "[Mg]", "[Mn+2]", "[Mn+3]", "[Mn+4]", "[Mn+6]", "[Mn+]", "[Mn-]", "[Mn]", "[Mo+2]", "[Mo+3]", "[Mo+4]", "[Mo]", "[Mt]", "[N+]", "[N-]", "[NH+]", "[NH-]", "[NH2+]", "[NH2-]", "[NH3+]", "[NH4+]", "[NH6+3]", "[Na+]", "[NaH]", "[Na]", "[Nb+3]", "[Nb+4]", "[Nb+5]", "[Nb-]", "[Nb]", "[Nd+3]", "[Nd]", "[Ne]", "[Ni+2]", "[Ni+3]", "[Ni+4]", "[Ni+]", "[Ni-2]", "[Ni-3]", "[Ni-5]", "[Ni-]", "[Ni]", "[No]", "[Np]", "[O+]", "[O-2]", "[O-]", "[OH+]", "[OH-]", "[OH2+]", "[OH3+]", "[Os+2]", "[Os+4]", "[Os+6]", "[Os+8]", "[Os+]", "[Os-2]", "[Os]", "[P+]", "[P-3]", "[P-]", "[PH+]", "[PH-2]", "[PH-]", "[PH2+2]", "[PH2+]", "[PH2-]", "[PH3+]", "[PH4+3]", "[PH4+]", "[PH6+3]", "[PH8+5]", "[PH]", "[Pa]", "[Pb+2]", "[Pb+3]", "[Pb+4]", "[PbH2]", "[PbH]", "[Pb]", "[Pd+2]", "[Pd+3]", "[Pd+4]", "[Pd+]", "[Pd-2]", "[Pd-3]", "[Pd-4]", "[Pd-]", "[Pd]", "[Pm]", "[PoH2]", "[Po]", "[Pr+3]", "[Pr]", "[Pt+2]", "[Pt+3]", "[Pt+4]", "[Pt+]", "[Pt-2]", "[Pt-]", "[Pt]", "[Pu]", "[RaH2]", "[Rb+]", "[RbH]", "[Re+5]", "[Re+]", "[Re]", "[Rf]", "[Rh+2]", "[Rh+3]", "[Rh+4]", "[Rh+]", "[Rh-3]", "[Rh-]", "[Rh]", "[Rn]", "[Ru+2]", "[Ru+3]", "[Ru+4]", "[Ru+6]", "[Ru+8]", "[Ru+]", "[Ru-2]", "[Ru-4]", "[Ru-]", "[Ru]", "[S+]", "[S-2]", "[S-]", "[SH+]", "[SH-]", "[SH2+]", "[SH3+]", "[SH]", "[Sb+3]", "[Sb+5]", "[Sb+]", "[Sb-]", "[SbH2+]", "[SbH2]", "[SbH3]", "[SbH]", "[Sb]", "[Sc+3]", "[Sc]", "[Se+]", "[Se-2]", "[Se-]", "[SeH+]", "[SeH-]", "[SeH2]", "[SeH3+]", "[SeH6+4]", "[SeH8+6]", "[SeH]", "[Se]", "[Sg]", "[Si-2]", "[Si-]", "[SiH-]", "[SiH2+]", "[SiH2-]", "[SiH2]", "[SiH3+]", "[SiH3-]", "[SiH3]", "[SiH4+]", "[SiH4]", "[SiH5+]", "[SiH6+2]", "[SiH8+4]", "[SiH]", "[Si]", "[Sm+2]", "[Sm+3]", "[Sm]", 
"[Sn+2]", "[Sn+3]", "[Sn+4]", "[Sn+]", "[Sn-]", "[SnH+]", "[SnH2]", "[SnH3-]", "[SnH]", "[Sn]", "[Sr+2]", "[SrH2]", "[Sr]", "[Ta+2]", "[Ta+5]", "[Ta]", "[Tb+3]", "[Tb]", "[Tc+3]", "[Tc+4]", "[Tc+6]", "[Tc]", "[Te+]", "[TeH2]", "[TeH3+]", "[TeH6+4]", "[TeH]", "[Te]", "[Th+4]", "[Th]", "[Ti+2]", "[Ti+3]", "[Ti+4]", "[Ti+5]", "[Ti+]", "[Ti]", "[Tl+2]", "[Tl+3]", "[TlH+2]", "[TlH2+]", "[TlH2]", "[TlH3]", "[TlH]", "[Tl]", "[Tm+3]", "[Tm]", "[U+2]", "[U+3]", "[U+4]", "[U+6]", "[U]", "[V+2]", "[V+3]", "[V+4]", "[V+5]", "[V+]", "[V]", "[W+2]", "[W-2]", "[W]", "[XeH]", "[Xe]", "[Y+3]", "[Y]", "[Yb+2]", "[Yb+3]", "[Yb]", "[Zn+2]", "[Zn+]", "[Zn-2]", "[Zn]", "[Zr+2]", "[Zr+3]", "[Zr+4]", "[Zr+]", "[Zr-2]", "[Zr]"]
        smiles_tokens = make_smiles_tokens_special(smiles_tokens)
        additional_special_tokens.extend(smiles_tokens)
    
    num_added_tokens = 0
    if len(additional_special_tokens) > 0:
        special_tokens_dict = {'additional_special_tokens': additional_special_tokens}
        num_added_tokens = tokenizer.add_special_tokens(special_tokens_dict)
        model.resize_token_embeddings(len(tokenizer))

    tokenizer.pad_token_id = 0  # unk. we want this to be different from the eos token
    tokenizer.padding_side = "left"

    if use_system_prompt:
        # NOTE(review): get_chat_content (the prompter's renderer) asserts that
        # roles are only 'user'/'assistant'. If ChatConstructor.generate_chat
        # keeps this 'system' entry in the chat, that assertion will fail — confirm.
        prefix_chat = [
            {"role": "system", "content": "You are a chemistry expert. With your chemistry knowledge, write a response that appropriately completes the following request."},
        ]
    else:
        prefix_chat = None
    chat_constructor = ChatConstructor()
    prompter = GeneralPrompter(get_chat_content, '[/INST]')
    # NOTE(review): core_tags_as_special_tokens is hard-coded to False here even
    # when the function argument added the tags as special tokens above — confirm
    # this is intended.
    core_tagger = CoreTagger(tokenizer, core_tags_as_special_tokens=False, include_tags=True)

    def tokenize(prompt, add_eos_token=True):
        """Tokenize without special tokens, truncate to cutoff_len, and optionally append EOS.

        Labels are a copy of input_ids. EOS is only appended when it fits and
        is not already the last token.
        """
        result = tokenizer(
            prompt,
            truncation=True,
            max_length=cutoff_len,
            padding=False,
            return_tensors=None,
            add_special_tokens=False,
        )
        input_ids = result["input_ids"]
        if (
            add_eos_token
            and len(input_ids) < cutoff_len
            # guard the [-1] lookup against an empty tokenization
            and (not input_ids or input_ids[-1] != tokenizer.eos_token_id)
        ):
            input_ids.append(tokenizer.eos_token_id)
            result["attention_mask"].append(1)

        result["labels"] = result["input_ids"].copy()

        return result

    def generate_and_tokenize_prompt(data_point, add_core_mask=True):
        """Build and tokenize the full chat for one example, adding label masking and core_mask."""
        # sample_tokenizer = sample_tokenizer_1 if smiles_tokenization == 'special' else None
        # input_text, output_text, template_id = chat_constructor.generate_io_text(data_point, sample_tokenizer=sample_tokenizer)
        # data_point['template_id'] = template_id
        input_text = data_point['input']
        output_text = data_point['output']
        chat = chat_constructor.generate_chat(input_text, output_text, prefix_chat=prefix_chat)
        full_prompt = prompter.generate_prompt(chat)
        tokenized_full_prompt = tokenize(full_prompt)

        if add_core_mask or not train_on_inputs:
            # The user-only prompt must carry the same prefix_chat as the full
            # prompt. Otherwise user_prompt_len is too short and the label mask
            # and core-mask offset are misaligned whenever a system prompt is used.
            user_prompt = prompter.generate_prompt(
                chat_constructor.generate_chat(input_text, output_text=None, prefix_chat=prefix_chat)
            )
            tokenized_user_prompt = tokenize(user_prompt, add_eos_token=False)
            user_prompt_len = len(tokenized_user_prompt["input_ids"])

            if not train_on_inputs:
                # -100 is ignored by the loss, so only response tokens are learned
                tokenized_full_prompt["labels"] = [-100] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:]

            if add_core_mask:
                core_mask = core_tagger.generate_mask(tokenized_full_prompt['input_ids'], user_prompt_len, data_point)
                tokenized_full_prompt['core_mask'] = core_mask

        return tokenized_full_prompt

    # 8-bit preparation. When special tokens were added, the output head and the
    # input embeddings are also kept in fp32 (workaround from
    # https://github.com/huggingface/peft/issues/349), presumably so the newly
    # added rows train in full precision.
    if use_int8:
        if num_added_tokens > 0:
            # cast before preparation, as in the original workaround ordering
            model.lm_head.weight.data = model.lm_head.weight.data.to(torch.float32)
            if model.lm_head.bias is not None:
                model.lm_head.bias.data = model.lm_head.bias.data.to(torch.float32)
        # Every 8-bit model needs this preparation, not only one with new tokens.
        model = prepare_model_for_int8_training(model)
        if num_added_tokens > 0:
            param = model.model.embed_tokens.weight
            param.data = param.data.to(torch.float32)

    config = LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        target_modules=lora_target_modules,
        lora_dropout=lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
        modules_to_save=modules_to_save,
    )
    model = get_peft_model(model, config)

    if resume_from_checkpoint:
        # Check the available weights and load them
        checkpoint_name = os.path.join(
            resume_from_checkpoint, "pytorch_model.bin"
        )  # Full checkpoint
        if not os.path.exists(checkpoint_name):
            checkpoint_name = os.path.join(
                resume_from_checkpoint, "adapter_model.bin"
            )  # only LoRA model - LoRA config above has to fit
            resume_from_checkpoint = (
                False  # So the trainer won't try loading its state
            )
        # The two files above have a different name depending on how they were saved, but are actually the same.
        if os.path.exists(checkpoint_name):
            print(f"Restarting from {checkpoint_name}")
            # Load onto CPU so every DDP rank does not materialize the weights
            # on the GPU they were saved from; load_state_dict copies them onto
            # the parameters' devices.
            adapters_weights = torch.load(checkpoint_name, map_location="cpu")
            set_peft_model_state_dict(model, adapters_weights)
        else:
            print(f"Checkpoint {checkpoint_name} not found")
            
    model.print_trainable_parameters()

    if tasks is not None and len(tasks) == 0:
        tasks = None

    train_data = load_dataset(data_path, name='instruction', split=train_split, tasks=tasks)
    train_data = train_data.shuffle().map(generate_and_tokenize_prompt)

    if use_val_set:
        val_data = load_dataset(data_path, name='instruction', split=dev_split, tasks=tasks)
        val_data = val_data.shuffle().map(generate_and_tokenize_prompt)
    else:
        val_data = None

    if not ddp and torch.cuda.device_count() > 1:
        # keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
        model.is_parallelizable = True
        model.model_parallel = True

    trainer = CustomTrainer(
        model=model,
        train_dataset=train_data,
        eval_dataset=val_data,
        args=transformers.TrainingArguments(
            per_device_train_batch_size=micro_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            warmup_steps=warmup_steps,
            num_train_epochs=num_epochs,
            learning_rate=learning_rate,
            fp16=True if 'fp16' == precision else False,
            bf16=True if 'bf16' == precision else False,
            logging_steps=logging_steps,
            optim=optim,
            evaluation_strategy="steps" if val_data is not None else "no",
            save_strategy="steps",
            eval_steps=eval_steps if val_data is not None else None,
            save_steps=save_steps,
            lr_scheduler_type=lr_scheduler,
            output_dir=output_dir,
            save_total_limit=save_total_limit,
            load_best_model_at_end=True if val_data is not None else False,
            ddp_find_unused_parameters=False if ddp else None,
            group_by_length=group_by_length,
            # report_to=None means "all installed integrations" in transformers,
            # not "no reporting", so disable reporting explicitly.
            report_to="wandb" if use_wandb else "none",
            run_name=wandb_run_name if use_wandb else None,
        ),
        data_collator=CustomDataCollator(
            tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
        ),
    )
    trainer.args.num_added_tokens = num_added_tokens
    model.config.use_cache = False

    if torch.__version__ >= "2" and sys.platform != "win32":
        # NOTE: the Trainer above was built with the uncompiled model, so this
        # compiled wrapper is only used by the save_pretrained call below.
        model = torch.compile(model)

    trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    model.save_pretrained(output_dir, save_embedding_layers=True)


if __name__ == "__main__":
    # Hand back any memory held by PyTorch's CUDA caching allocator (a no-op if
    # CUDA was never initialized), then let Fire expose train()'s keyword
    # arguments as command-line flags.
    torch.cuda.empty_cache()
    fire.Fire(component=train)
