# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
#    Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.

import os
import copy
from dataclasses import dataclass, field
import json
import logging
import pathlib
from typing import Dict, Optional, Sequence, List
import deepspeed
import torch
import torch.nn.functional as F

import numpy as np
import transformers
from transformers import AutoTokenizer
import evaluate
from llava.constants import IGNORE_INDEX, PROTEIN_TOKEN_INDEX, DEFAULT_PROTEIN_TOKEN, DEFAULT_PROT_START_TOKEN, DEFAULT_PROT_END_TOKEN
from torch.utils.data import Dataset
from llava.train.llava_trainer import LLaVATrainer

from llava import conversation as conversation_lib
from llava.model import LlavaMPTForCausalLM
from llava.model.language_model.llava_esm import *
from llava.mm_utils import tokenizer_image_token, tokenizer_protein_token

from PIL import Image
import obonet
import pandas as pd
import random
from tqdm import tqdm
# Module-level rank holder; train() overwrites it with training_args.local_rank.
local_rank = None

import llava.utils
from llava.utils.bio_args import ModelArguments, DataArguments, TrainingArguments
from llava.dataset.ec_number_dataset import *
import datasets
import re

from sklearn.metrics import (
    accuracy_score,
    hamming_loss,
    precision_score,
    recall_score,
    f1_score,
)

from transformers.models.llama.modeling_llama import LlamaRMSNorm

@torch.no_grad()
def binary_metric(pred, label, n_label):
    """Compute weighted multi-label F1, recall and precision from class-id lists.

    Each sample is turned into a binary indicator row of width ``n_label``
    (1 where the class id is present), and the rows are scored with
    scikit-learn using ``average='weighted'`` and ``zero_division=0``.

    Args:
        pred: Sequence with one list of predicted class ids per sample. A
            sample's list may be empty.
        label: Sequence with one list of true class ids per sample, indexed
            in lockstep with ``pred``. A sample's list may be empty.
        n_label: Number of distinct classes, i.e. the indicator width.

    Returns:
        Tuple ``(f1, recall, precision)``. When there are no samples or no
        classes, ``(0.0, 0.0, 0.0)`` is returned instead of raising.
    """
    n_samples = len(pred)
    if n_samples == 0 or n_label == 0:
        # Nothing to score; mirror the zero_division=0 convention used below.
        return 0.0, 0.0, 0.0

    def _multi_hot(index_lists):
        # One integer row per sample. Assigning 1 (rather than summing
        # one-hot vectors) keeps the matrix binary even if an id repeats,
        # and an empty id list simply leaves the row at zero.
        matrix = torch.zeros((n_samples, n_label), dtype=torch.long)
        for row in range(n_samples):
            indices = index_lists[row]
            if len(indices) > 0:
                matrix[row, torch.as_tensor(indices, dtype=torch.long)] = 1
        return matrix.cpu().numpy()

    all_pred = _multi_hot(pred)
    all_label = _multi_hot(label)

    f1 = f1_score(y_true=all_label, y_pred=all_pred, average='weighted', zero_division=0)
    recall = recall_score(y_true=all_label, y_pred=all_pred, average='weighted', zero_division=0)
    precision = precision_score(y_true=all_label, y_pred=all_pred, average='weighted', zero_division=0)
    return f1, recall, precision

def _extract_ec_numbers(text, pattern):
    """Return the set of first ``pattern`` matches of each ';'-separated part of ``text``."""
    found = set()
    for part in text.split(';'):
        matches = pattern.findall(part.strip())
        if matches:
            found.add(matches[0])
    return found


def _ids_for(names, name_to_id):
    """Map ``names`` to integer ids, registering unseen names with the next free id."""
    return [name_to_id.setdefault(name, len(name_to_id)) for name in names]


def go_term_metric(pred):
    """Compute BLEU plus EC-number accuracy / F1 metrics for one evaluation run.

    Token ids are decoded with the tokenizer found at ``$TOKEN_PATH``. From
    each decoded text, identifiers made of four integer fields joined by
    ``'Z'`` are extracted (presumably EC numbers with '.' encoded as 'Z' --
    confirm against the dataset). Metrics are computed both on the full
    identifier and on its first three fields (``prev3`` / ``ec3``).

    Args:
        pred: Object exposing ``predictions`` and ``label_ids`` arrays of
            token ids, with ``IGNORE_INDEX`` marking positions to drop.

    Returns:
        The dict produced by the BLEU metric, extended with:
        ``accs_result`` / ``prev3_accs_result`` (hits / number of label items),
        ``accs_result-(2)`` / ``prev3_accs_result-(2)`` (hits / number of
        predicted items), ``preds_error_rate`` / ``preds_prev3_error_rate``
        (one minus the "-(2)" value), and weighted ``ec_*`` / ``ec3_*``
        f1, recall and precision from ``binary_metric``.

    Raises:
        KeyError: if the ``TOKEN_PATH`` environment variable is not set.
    """
    dconfig = datasets.DownloadConfig(
        cache_dir='./tlog/cache'
    )
    metric = evaluate.load("bleu", download_config=dconfig)
    predictions = pred.predictions.astype(np.int64)
    labels = pred.label_ids.astype(np.int64)
    tokenizer = AutoTokenizer.from_pretrained(os.environ['TOKEN_PATH'])

    pattern = re.compile(r'\d+Z\d+Z\d+Z\d+')
    split_sign = 'Z'

    # Class-id vocabularies, grown on the fly while walking the samples.
    ec_to_id = {}
    ec3_to_id = {}
    metric_ec_pred = []
    metric_ec_label = []
    metric_ec3_pred = []
    metric_ec3_label = []

    preds = []
    label_data = []
    preds_accs = []
    preds_prev3_accs = []
    preds_accs_2 = []
    preds_prev3_accs_2 = []
    preds_error_rate = []
    preds_prev3_error_rate = []

    for pred_row, label_row in zip(predictions, labels):
        # Decode each sequence once and reuse the text for BLEU and parsing.
        pred_decoded = tokenizer.decode(pred_row[pred_row != IGNORE_INDEX])
        label_decoded = tokenizer.decode(label_row[label_row != IGNORE_INDEX])
        preds.append(pred_decoded.replace(tokenizer.unk_token, ''))
        label_data.append(label_decoded)

        pred_ecs = _extract_ec_numbers(pred_decoded, pattern)
        label_ecs = _extract_ec_numbers(label_decoded, pattern)
        acc_times = len(pred_ecs & label_ecs)

        # Sorting only makes the id assignment deterministic across runs.
        metric_ec_pred.append(_ids_for(sorted(pred_ecs), ec_to_id))
        metric_ec_label.append(_ids_for(sorted(label_ecs), ec_to_id))

        # Drop the last 'Z'-separated field to compare the first three only.
        # The label prefixes must come from the label items themselves.
        pred_ec3 = {ec[:ec.rfind(split_sign)] for ec in pred_ecs}
        label_ec3 = {ec[:ec.rfind(split_sign)] for ec in label_ecs}
        prev3_acc_times = len(pred_ec3 & label_ec3)

        metric_ec3_pred.append(_ids_for(sorted(pred_ec3), ec3_to_id))
        metric_ec3_label.append(_ids_for(sorted(label_ec3), ec3_to_id))

        # max(..., 1) keeps a sample without any parsed item from dividing
        # by zero; its hit count is 0 in that case, so the ratio is 0.
        preds_prev3_accs.append(prev3_acc_times / max(len(label_ec3), 1))
        preds_prev3_accs_2.append(prev3_acc_times / max(len(pred_ec3), 1))
        preds_accs.append(acc_times / max(len(label_ecs), 1))
        preds_accs_2.append(acc_times / max(len(pred_ecs), 1))

        preds_error_rate.append(1 - preds_accs_2[-1])
        preds_prev3_error_rate.append(1 - preds_prev3_accs_2[-1])

    bleu_result = metric.compute(predictions=preds, references=label_data)
    bleu_result['accs_result'] = np.mean(preds_accs)
    bleu_result['accs_result-(2)'] = np.mean(preds_accs_2)
    bleu_result['prev3_accs_result'] = np.mean(preds_prev3_accs)
    bleu_result['prev3_accs_result-(2)'] = np.mean(preds_prev3_accs_2)

    bleu_result['preds_error_rate'] = np.mean(preds_error_rate)
    bleu_result['preds_prev3_error_rate'] = np.mean(preds_prev3_error_rate)

    ec_f1, ec_recall, ec_precision = binary_metric(metric_ec_pred, metric_ec_label, len(ec_to_id))
    ec3_f1, ec3_recall, ec3_precision = binary_metric(metric_ec3_pred, metric_ec3_label, len(ec3_to_id))

    bleu_result['ec_f1'] = ec_f1
    bleu_result['ec_recall'] = ec_recall
    bleu_result['ec_precision'] = ec_precision
    bleu_result['ec3_f1'] = ec3_f1
    bleu_result['ec3_recall'] = ec3_recall
    bleu_result['ec3_precision'] = ec3_precision
    return bleu_result

def preprocess_logits_for_metrics(logits, labels):
    """
    Reduce logits to predicted token ids before they are accumulated.

    Keeping only the argmax ids avoids storing the full logits tensors
    during evaluation (workaround for the Trainer's memory growth).
    When ``logits`` is a tuple, its last element is used. ``labels`` is
    accepted for the hook signature but not used.
    """
    selected = logits[-1] if type(logits) is tuple else logits
    return selected.argmax(dim=-1)


def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
                                data_args: DataArguments) -> Dict:
    """Assemble the datasets, collator and metric hooks passed to the trainer.

    The train and eval ``ECNumberDataset`` instances share the same options
    taken from ``data_args``; the eval one is additionally built with
    ``is_eval=True``.
    """
    shared_options = dict(
        tokenizer=tokenizer,
        data_args=data_args,
        max_labels=data_args.max_labels,
        retrieval_mmseq=data_args.retrieval_mmseq,
        retrieval_smiles=data_args.retrieval_smiles,
    )
    train_split = ECNumberDataset(data_args.training_csv_file, **shared_options)
    collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    eval_split = ECNumberDataset(data_args.eval_csv_file, is_eval=True, **shared_options)
    return {
        "train_dataset": train_split,
        "eval_dataset": eval_split,
        "data_collator": collator,
        "compute_metrics": go_term_metric,
        "preprocess_logits_for_metrics": preprocess_logits_for_metrics,
    }

def load_pretrained_mm_adaptor(model: LlavaLlamaForCausalLM, adaptor_dict):
    """Copy tensors from ``adaptor_dict`` into same-named parameters of ``model.model``.

    Parameters whose name does not appear in ``adaptor_dict`` are left
    untouched. The copy happens in place under ``torch.no_grad()`` and the
    same ``model`` object is returned.
    """
    with torch.no_grad():
        for param_name, param in model.model.named_parameters():
            if param_name not in adaptor_dict:
                continue
            rank0_print(f"Copy parameter: {param_name}")
            param.copy_(adaptor_dict[param_name])
    return model

def load_trained_parameter(model: LlavaLlamaForCausalLM, out_dir):
    """Restore LoRA and non-LoRA trained weights saved under ``out_dir``.

    Two optional files are read: ``adapter_model.bin`` and
    ``non_lora_trainables.bin``. A missing file is treated as an empty
    state dict. Adapter keys are renamed by inserting ``default.`` before
    the trailing ``weight`` so they line up with the parameter names of the
    wrapped model, e.g. ``...k_proj.lora_B.weight`` becomes
    ``...k_proj.lora_B.default.weight``.

    Args:
        model: Model whose ``named_parameters()`` are overwritten in place.
        out_dir: Directory containing the saved files.

    Returns:
        The same ``model`` object.
    """
    lora_module = f"{out_dir}/adapter_model.bin"
    non_lora_module = f"{out_dir}/non_lora_trainables.bin"

    print(f"Lora module path: {lora_module}")
    if pathlib.Path(lora_module).exists():
        # Load onto CPU so restoring does not depend on the device the
        # checkpoint was written from; copy_() moves data to the parameter.
        saved_lora = torch.load(lora_module, map_location="cpu")
    else:
        saved_lora = {}
    lora_dict = {
        f"{name.split('weight')[0]}default.weight": tensor
        for name, tensor in saved_lora.items()
    }
    print(lora_dict.keys())

    print(f"Non lora module path: {non_lora_module}")
    if pathlib.Path(non_lora_module).exists():
        non_lora_dict = torch.load(non_lora_module, map_location="cpu")
    else:
        non_lora_dict = {}

    with torch.no_grad():
        for pn, p in model.named_parameters():
            # Direct dict lookups instead of scanning every saved key per
            # parameter. The LoRA copy runs first, so a name present in both
            # dicts ends up with the non-LoRA tensor.
            if pn in lora_dict:
                print(pn)
                p.copy_(lora_dict[pn])
            if pn in non_lora_dict:
                print(pn)
                p.copy_(non_lora_dict[pn])
    return model

def test(trainer, data_args: DataArguments, tokenizer, file_name = 'trainer_test_state.json'):
    """Evaluate every CSV listed in ``data_args.test_csv_files`` and save the metrics.

    ``test_csv_files`` is a comma-separated list of paths. Each result is
    stored under a key derived from the path: its second '/'-separated
    component, cut at the first '.'. Metrics are printed after each
    evaluation; on the world-zero process they are also written as JSON to
    ``<trainer.args.output_dir>/<file_name>`` and printed again grouped by key.
    """
    results = {}
    for csv_path in data_args.test_csv_files.split(','):
        path_component = csv_path.split('/')[1]
        eval_set = ECNumberDataset(
            csv_path,
            tokenizer=tokenizer,
            data_args=data_args,
            max_labels=data_args.max_labels,
            is_eval=True,
            retrieval_mmseq=data_args.retrieval_mmseq,
            retrieval_smiles=data_args.retrieval_smiles,
        )
        metrics = trainer.evaluate(eval_set)
        results[path_component.split('.')[0]] = metrics
        for metric_name, metric_value in metrics.items():
            print(f"\t{metric_name}: {metric_value}")

    # Every process runs the evaluations above; only world-zero persists them.
    if not trainer.is_world_process_zero():
        return

    with open(os.path.join(trainer.args.output_dir, file_name), 'w') as fout:
        json.dump(results, fout, indent=2)

    for split_name, metrics in results.items():
        print(f"{split_name}:")
        for metric_name, metric_value in metrics.items():
            print(f"\t{metric_name}: {metric_value}")

def train():
    """Script entry point: build model, tokenizer and data, then train or test.

    Steps, in order:
      1. Parse ``ModelArguments``, ``DataArguments`` and ``TrainingArguments``.
      2. Load the base causal LM (with bitsandbytes kwargs when
         ``training_args.bits`` is 4 or 8) and freeze ``model.model``.
      3. Optionally wrap the model with LoRA adapters, then freeze every
         parameter whose name does not contain "lora".
      4. Load the tokenizer, choose its pad token, and, when a protein tower
         is configured, initialise the protein modules and related config.
      5. If ``training_args.is_test``: load saved weights from
         ``output_dir``, run ``test`` and exit the process.
      6. Otherwise train (resuming if ``checkpoint-*`` entries exist in
         ``output_dir``), run ``test`` and save the resulting weights.
    """
    global local_rank
    


    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments))
    all_args: (ModelArguments, DataArguments, TrainingArguments) = parser.parse_args_into_dataclasses()
    model_args: ModelArguments = all_args[0]
    data_args: DataArguments = all_args[1]
    training_args: TrainingArguments = all_args[2]
    local_rank = training_args.local_rank
    # fp16 takes precedence over bf16; float32 when neither flag is set.
    compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))

    # Extra from_pretrained kwargs; stays empty unless 4-bit / 8-bit loading is requested.
    bnb_model_from_pretrained_args = {}
    if training_args.bits in [4, 8]:
        from transformers import BitsAndBytesConfig
        bnb_model_from_pretrained_args.update(dict(
            device_map={"": training_args.device},
            load_in_4bit=training_args.bits == 4,
            load_in_8bit=training_args.bits == 8,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=training_args.bits == 4,
                load_in_8bit=training_args.bits == 8,
                # The two projectors are listed as modules to skip.
                llm_int8_skip_modules=["mm_projector", "smiles_projector"],
                llm_int8_threshold=6.0,
                llm_int8_has_fp16_weight=False,
                bnb_4bit_compute_dtype=compute_dtype,
                bnb_4bit_use_double_quant=training_args.double_quant,
                bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'}
            )
        ))
    
    # Model class selection: with a protein tower use the LLaVA variant (MPT
    # when the path contains 'mpt', Llama otherwise); without one use plain Llama.
    if model_args.protein_tower is not None:
        if 'mpt' in model_args.model_name_or_path:
            config = transformers.AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
            config.attn_config['attn_impl'] = training_args.mpt_attn_impl
            model = LlavaMPTForCausalLM.from_pretrained(
                model_args.model_name_or_path,
                config=config,
                cache_dir=training_args.cache_dir,
                **bnb_model_from_pretrained_args
            )
        else:
            model = LlavaLlamaForCausalLM.from_pretrained(
                model_args.model_name_or_path,
                cache_dir=training_args.cache_dir,
                residual_dropout=model_args.residual_dropout,
                **bnb_model_from_pretrained_args
            )
    else:
        model = transformers.LlamaForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=training_args.cache_dir,
            **bnb_model_from_pretrained_args
        )
    model.config.use_cache = False

    # NOTE: the backbone is frozen unconditionally; the freeze_backbone
    # check below is commented out.
    # if model_args.freeze_backbone:
    model.model.requires_grad_(False)

    if training_args.bits in [4, 8]:
        from peft import prepare_model_for_kbit_training
        # NOTE(review): fp16 maps to float32 here, unlike compute_dtype above -- confirm intended.
        model.config.torch_dtype=(torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
        model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)

    # With gradient checkpointing, make the input-embedding outputs require
    # grad: use the model's own helper when present, else a forward hook.
    if training_args.gradient_checkpointing:
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        else:
            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)
            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

    # model.gradient_checkpointing_enable

    if training_args.lora_enable:
        from peft import LoraConfig, get_peft_model
        lora_config = LoraConfig(
            r=training_args.lora_r,
            lora_alpha=training_args.lora_alpha,
            target_modules=find_all_linear_names(model),
            lora_dropout=training_args.lora_dropout,
            bias=training_args.lora_bias,
            task_type="CAUSAL_LM",
        )
        # For 16-bit training, cast the whole model before adding the adapters.
        if training_args.bits == 16:
            if training_args.bf16:
                model.to(torch.bfloat16)
            if training_args.fp16:
                model.to(torch.float16)
        rank0_print("Adding LoRA adapters...")
        # model.requires_grad_(True)
        model = get_peft_model(model, lora_config)


    # Freeze every parameter whose name lacks "lora".
    # NOTE(review): this loop also runs when lora_enable is False, in which
    # case all parameters are frozen at this point -- confirm intended.
    for pn, p in model.named_parameters():
        if "lora" not in pn:
            p.requires_grad_(False)

    # Tokenizer: non-MPT checkpoints additionally pass use_fast=False.
    if 'mpt' in model_args.model_name_or_path:
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=training_args.cache_dir,
            model_max_length=training_args.model_max_length,
            padding_side="right"
        )
    else:
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=training_args.cache_dir,
            model_max_length=training_args.model_max_length,
            padding_side="right",
            use_fast=False,
        )

    # Pad-token policy by conversation version: "v0" adds a "[PAD]" token if
    # none exists; every other version reuses the unk token. Versions other
    # than "v0"/"v0.5" also select the conversation template, falling back to
    # "vicuna_v1" for unknown names.
    if model_args.version == "v0":
        if tokenizer.pad_token is None:
            smart_tokenizer_and_embedding_resize(
                special_tokens_dict=dict(pad_token="[PAD]"),
                tokenizer=tokenizer,
                model=model,
            )
    elif model_args.version == "v0.5":
        tokenizer.pad_token = tokenizer.unk_token
    else:
        tokenizer.pad_token = tokenizer.unk_token
        if model_args.version in conversation_lib.conv_templates:
            conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]
        else:
            conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1"]

    if model_args.protein_tower is not None:
        model.get_model().initialize_protein_modules(
            model_args=model_args,
            fsdp=training_args.fsdp
        )
        
        # The tower is cast to bf16 when requested, otherwise to fp16.
        protein_tower = model.get_protein_tower()
        protein_tower.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device)

        # data_args.image_processor = vision_tower.image_processor
        data_args.is_multimodal = True

        # model.config.image_aspect_ratio = data_args.image_aspect_ratio
        model.config.tokenizer_padding_side = tokenizer.padding_side
        model.config.tokenizer_model_max_length = tokenizer.model_max_length

        # Chained assignment: mirror the flag onto the model config and training_args.
        model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter
        # Adapter-only tuning: freeze everything, then unfreeze the projector(s).
        if model_args.tune_mm_mlp_adapter:
            model.requires_grad_(False)
            for p in model.get_model().mm_projector.parameters():
                p.requires_grad = True
            if hasattr(model.get_model(), 'smiles_projector'):
                for p in model.get_model().smiles_projector.parameters():
                    p.requires_grad = True

        # freeze_mm_mlp_adapter is applied afterwards, so it wins if both flags are set.
        model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter
        if training_args.freeze_mm_mlp_adapter:
            for p in model.get_model().mm_projector.parameters():
                p.requires_grad = False
            if hasattr(model.get_model(), 'smiles_projector'):
                for p in model.get_model().smiles_projector.parameters():
                    p.requires_grad = False

        # In k-bit mode, move the projector(s) to the compute dtype and training device.
        if training_args.bits in [4, 8]:
            model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device)
            if hasattr(model.get_model(), 'smiles_projector'):
                model.get_model().smiles_projector.to(dtype=compute_dtype, device=training_args.device)
                
        model.config.mm_use_prot_start_end = data_args.mm_use_prot_start_end = model_args.mm_use_prot_start_end
        model.config.mm_projector_lr = training_args.mm_projector_lr
        training_args.mm_use_prot_start_end = model_args.mm_use_prot_start_end
        model.config.mm_use_prot_patch_token = model_args.mm_use_prot_patch_token
        model.initialize_protein_tokenizer(model_args, tokenizer=tokenizer)

    # Dtype fix-ups for k-bit training: LoRA layers to bf16 (when bf16 is on),
    # modules with 'norm' in their name to fp32, fp32 lm_head/embed_tokens to bf16.
    # NOTE(review): the result is only bound to the loop variable `module`;
    # this presumably relies on Module.to() casting in place -- confirm.
    if training_args.bits in [4, 8]:
        from peft.tuners.lora import LoraLayer
        for name, module in model.named_modules():
            if isinstance(module, LoraLayer):
                if training_args.bf16:
                    module = module.to(torch.bfloat16)
            if 'norm' in name:
                module = module.to(torch.float32)
            if 'lm_head' in name or 'embed_tokens' in name:
                if hasattr(module, 'weight'):
                    if training_args.bf16 and module.weight.dtype == torch.float32:
                        module = module.to(torch.bfloat16)

    data_module = make_supervised_data_module(tokenizer=tokenizer,
                                              data_args=data_args)
    
    # Test-only mode: restore saved weights, evaluate the test CSVs, then
    # terminate the process. Nothing below this block runs in that mode.
    if training_args.is_test:
        model = load_trained_parameter(model, training_args.output_dir)
        trainer = LLaVATrainer(model=model,
                    tokenizer=tokenizer,
                    args=training_args,
                    **data_module)
        test(trainer, data_args, tokenizer, 'trainer_test_state_prev.json')
        exit()

    # for i in tqdm(range(training_args.num_train_epochs)):
    # Optionally initialise parameters from a pretrained adapter checkpoint.
    if model_args.pretrain_mm_mlp_adapter is not None:
        pretrained_adapter = torch.load(model_args.pretrain_mm_mlp_adapter)
        model = load_pretrained_mm_adaptor(model, pretrained_adapter)
    
    # Optionally unfreeze LlamaRMSNorm modules whose module name contains 'layers'.
    if model_args.tune_norm_layer:
        for mn, m in model.model.named_modules():
            if type(m) is LlamaRMSNorm and 'layers' in mn:
                m.requires_grad_(True)
                

    
    
    trainer = LLaVATrainer(model=model,
                    tokenizer=tokenizer,
                    args=training_args,
                    **data_module)

    # model.config.use_cache = True
    # test(trainer, data_args, tokenizer, 'trainer_test_state_prev.json')
    # model.config.use_cache = False
    # Resume when output_dir already holds at least one "checkpoint-*" entry.
    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()
    trainer.save_state()
    
    print("#" * 40)
    # Evaluate on the test CSVs before the final weights are written below.
    test(trainer, data_args, tokenizer, 'trainer_test_state.json')
    # print(trainer.evaluate(data_module['eval_dataset']))
    
    # Saving: with LoRA, the adapter state and the non-LoRA state are
    # collected separately and written only when local_rank is 0 or -1
    # (-1 presumably meaning a non-distributed run -- confirm).
    if training_args.lora_enable:
        state_dict = get_peft_state_maybe_zero_3(
            model.named_parameters(), training_args.lora_bias
        )
        non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(
            model.named_parameters()
        )
        if training_args.local_rank == 0 or training_args.local_rank == -1:
            model.config.save_pretrained(training_args.output_dir)
            model.save_pretrained(training_args.output_dir, state_dict=state_dict)
            torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin'))
    else:
        safe_save_model_for_hf_trainer(trainer=trainer,
                                       output_dir=training_args.output_dir)


# Run train() only when this file is executed as a script, not on import.
if __name__ == "__main__":
    train()
