# Copyright 2024 Llamole Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, List, Optional, Dict, Any

from ..data import get_dataset, DataCollatorForSeqGraph, get_template_and_fix_tokenizer
from ..extras.constants import IGNORE_INDEX, NO_LABEL_INDEX
from ..extras.misc import get_logits_processor
from ..extras.ploting import plot_loss
from ..model import load_tokenizer
from ..hparams import get_infer_args, get_train_args
from ..model import GraphLLMForCausalMLM
from .dataset import MolQADataset

import re
import os
import json
import math
import torch
from torch.utils.data import DataLoader

if TYPE_CHECKING:
    from transformers import Seq2SeqTrainingArguments

    from ..hparams import (
        DataArguments,
        FinetuningArguments,
        GeneratingArguments,
        ModelArguments,
    )

def remove_extra_spaces(text):
    # Replace multiple spaces with a single space
    cleaned_text = re.sub(r'\s+', ' ', text)
    # Strip leading and trailing spaces
    return cleaned_text.strip()

def run_eval(args: Optional[Dict[str, Any]] = None) -> None:
    model_args, data_args, training_args, finetuning_args, generating_args = (
        get_train_args(args)
    )

    if data_args.dataset in ["molqa", "molqa_drug", "molqa_material"]:
        run_molqa(
            model_args, data_args, training_args, finetuning_args, generating_args
        )
    else:
        raise ValueError("Unknown dataset: {}.".format(data_args.dataset))


def run_molqa(
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    finetuning_args: "FinetuningArguments",
    generating_args: "GeneratingArguments",
):
    tokenizer = load_tokenizer(model_args, generate_mode=True)["tokenizer"]

    data_info_path = os.path.join(data_args.dataset_dir, "dataset_info.json")
    with open(data_info_path, "r") as f:
        dataset_info = json.load(f)

    tokenizer.pad_token = tokenizer.eos_token
    dataset_name = data_args.dataset.strip()
    try:
        filename = dataset_info[dataset_name]["file_name"]
    except KeyError:
        raise ValueError(f"Dataset {dataset_name} not found in dataset_info.json")
    data_path = os.path.join(data_args.dataset_dir, f"{filename}")
    with open(data_path, "r") as f:
        original_data = json.load(f)

    # Create dataset and dataloader
    dataset = MolQADataset(original_data, tokenizer, data_args.cutoff_len)
    dataloader = DataLoader(
        dataset, batch_size=training_args.per_device_eval_batch_size, shuffle=False
    )

    gen_kwargs = generating_args.to_dict()
    gen_kwargs["eos_token_id"] = [
        tokenizer.eos_token_id
    ] + tokenizer.additional_special_tokens_ids
    gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
    gen_kwargs["logits_processor"] = get_logits_processor()

    model = GraphLLMForCausalMLM.from_pretrained(
        tokenizer, model_args, data_args, training_args, finetuning_args
    )

    all_results = []
    property_names = ["BBBP", "HIV", "BACE", "CO2", "N2", "O2", "FFV", "TC", "SC", "SA"]

    # Phase 1: Molecular Design
    global_idx = 0
    all_smiles = []
    for batch_idx, batch in enumerate(dataloader):
        input_ids = batch["input_ids"].to(model.device)
        attention_mask = batch["attention_mask"].to(model.device)
        property_data = batch["property"].to(model.device)
        model.eval()
        with torch.no_grad():
            all_info_dict = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                molecule_properties=property_data,
                do_molecular_design=True,
                do_retrosynthesis=False,
                rollback=True,
                **gen_kwargs,
            )

            batch_results = []
            for i in range(len(all_info_dict["smiles_list"])):
                original_data_idx = global_idx + i
                original_item = original_data[original_data_idx]

                llm_response = "".join(item for item in all_info_dict["text_lists"][i])
                result = {
                    "qa_idx": original_data_idx,
                    "instruction": original_item["instruction"],
                    "input": original_item["input"],
                    "llm_response": llm_response,
                    "response_design": remove_extra_spaces(llm_response),
                    "llm_smiles": all_info_dict["smiles_list"][i],
                    "property": {},
                }

                # Add non-NaN property values
                for j, prop_name in enumerate(property_names):
                    prop_value = property_data[i][j].item()
                    if not math.isnan(prop_value):
                        result["property"][prop_name] = prop_value

                batch_results.append(result)

            all_results.extend(batch_results)
            all_smiles.extend([result['llm_smiles'] for result in batch_results])
            global_idx += len(batch_results)
    
    # Phase 2: Retrosynthesis
    retro_batch_start = 0
    for batch_idx, batch in enumerate(dataloader):
        
        input_ids = batch["input_ids"].to(model.device)
        attention_mask = batch["attention_mask"].to(model.device)
        batch_size = input_ids.shape[0]
        batch_smiles = all_smiles[retro_batch_start : retro_batch_start + batch_size]

        model.eval()
        with torch.no_grad():
            all_info_dict = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                do_molecular_design=False,
                do_retrosynthesis=True,
                input_smiles_list=batch_smiles,
                expansion_topk=50,
                iterations=100,
                max_planning_time=30,
                **gen_kwargs,
            )

            batch_results = []
            for i in range(batch_size):
                result = all_results[retro_batch_start + i]
                retro_plan = all_info_dict["retro_plan_dict"][result["llm_smiles"]]
                result["llm_reactions"] = []
                if retro_plan["success"]:
                    for reaction, template, cost in zip(
                        retro_plan["reaction_list"],
                        retro_plan["templates"],
                        retro_plan["cost"],
                    ):
                        result["llm_reactions"].append(
                            {"reaction": reaction, "template": template, "cost": cost}
                        )

                # new_text = "".join(item for item in all_info_dict["text_lists"][i])
                if None in all_info_dict["text_lists"][i]:
                    print(f"List contains None: {all_info_dict['text_lists'][i]}")
                    new_text = "".join(item for item in all_info_dict["text_lists"][i] if item is not None)
                else:
                    new_text = "".join(item for item in all_info_dict["text_lists"][i])
                    
                result["llm_response"] += new_text
                result["llm_response"] = remove_extra_spaces(result["llm_response"])
                result["response_retro"] = remove_extra_spaces(new_text)
                batch_results.append(result)

            retro_batch_start += batch_size
    
    print('all_results', all_results)
    print("\nSummary of results:")
    print_len = min(5, len(all_results))
    for result in all_results[:print_len]:
        print(f"\nData point {result['qa_idx']}:")
        print(f"  Instruction: {result['instruction']}")
        print(f"  Input: {result['input']}")
        print(f"  LLM Response: {result['llm_response']}")
        print(f"  LLM SMILES: {result['llm_smiles']}")
        print(f"  Number of reactions: {len(result['llm_reactions'])}")
        for prop_name, prop_value in result["property"].items():
            print(f"  {prop_name}: {prop_value}")

    print("\nAll data processed successfully.")