# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import math
import os
from dataclasses import dataclass, field
from datetime import datetime
from typing import Literal, Optional

import wandb
from trl import ModelConfig, SFTConfig, SFTTrainer, TrlParser, get_peft_config, ScriptArguments

from memgpt.constants import DATA_DIR
from memgpt.trl.utils.load_model import load_lora_model, load_tiny_llama2_tokenizer
from memgpt.trl.utils.load_sft_dataset import load_trainset, prepare_pretrain_data
from memgpt.trl.utils.utils_filter import clean_high_loss_triplets, filter_length, convert_to_raw_dataset, convert_to_special_db_tokens_format
from memgpt.trl.utils.utils_metrics import set_tokenizer, dataset_stats, set_wandb, set_dataset_name, set_is_cleaned_dataset, convert_th_config_to_name, compute_metrics, preprocess_logits_for_metrics


# Split your dataset into smaller subsets and evaluate each separately
def evaluate_in_batches(model, tokenizer, full_dataset, batch_size=20, trainer_args=None, *, model_args=None):
    """
    Evaluate ``full_dataset`` chunk by chunk and return sample-weighted average metrics.

    The dataset is split, in order, into consecutive chunks of at most
    ``batch_size`` rows (via ``full_dataset.select``). A fresh ``SFTTrainer`` is
    built for each chunk and ``trainer.evaluate()`` is called on it. Every returned
    metric is accumulated weighted by the chunk's row count and divided by the
    number of rows that actually reported that metric. Chunking presumably keeps
    the memory of a single ``evaluate()`` call bounded (the caller tunes
    ``batch_size`` for GPU memory).

    Args:
        model: Model handed to every ``SFTTrainer``.
        tokenizer: Passed to ``SFTTrainer`` as ``processing_class``.
        full_dataset: Dataset supporting ``len()`` and ``select(range(...))``.
        batch_size: Rows per evaluation chunk; must be >= 1.
        trainer_args: Passed to ``SFTTrainer`` as ``args``.
        model_args: Model config used to build ``peft_config`` via
            ``get_peft_config``. If omitted, the module-level ``model_args``
            global (bound by the ``__main__`` block) is used, which is the
            original behaviour.

    Returns:
        dict mapping metric name to its weighted average. The dict is empty
        for an empty dataset.

    Raises:
        ValueError: if ``batch_size < 1``, or if ``model_args`` is neither
            passed nor available as a module-level global.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    num_samples = len(full_dataset)
    if num_samples == 0:
        return {}

    if model_args is None:
        # Backward compatibility: the original implementation read this global.
        model_args = globals().get("model_args")
        if model_args is None:
            raise ValueError(
                "evaluate_in_batches needs model_args to build peft_config; "
                "pass it explicitly or run this file as a script."
            )

    weighted_sums = {}
    sample_counts = {}
    for start_idx in range(0, num_samples, batch_size):
        end_idx = min(start_idx + batch_size, num_samples)
        current_size = end_idx - start_idx

        batch_subset = full_dataset.select(range(start_idx, end_idx))

        # get_peft_config is called per trainer (as before) so each trainer gets its own config object.
        trainer = SFTTrainer(
            model=model,
            args=trainer_args,
            eval_dataset=batch_subset,
            processing_class=tokenizer,
            peft_config=get_peft_config(model_args),
            compute_metrics=compute_metrics,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        )
        batch_results = trainer.evaluate()

        # Track counts per metric so a metric absent from some chunks is
        # averaged only over the rows that actually reported it.
        for metric, value in batch_results.items():
            weighted_sums[metric] = weighted_sums.get(metric, 0) + value * current_size
            sample_counts[metric] = sample_counts.get(metric, 0) + current_size

    return {metric: weighted_sums[metric] / sample_counts[metric] for metric in weighted_sums}


def main(script_args, training_args, model_args, eval_args):
    """
    Evaluate and optionally clean a pretraining dataset.

    Steps, in order:
      1. Initialise wandb and load the model. Only the tokenizer is loaded for
         the "stats" and "no" strategies. Optionally swap in the tiny llama2
         tokenizer.
      2. Load the dataset with ``prepare_pretrain_data``. Optionally convert it
         to the raw format and/or the special db-lookup-token format; each
         conversion adds a suffix to the dataset name.
      3. For the "stats"/"both" strategies, visualise dataset statistics.
         Then optionally apply the length filter.
      4. Shuffle with seed 42. Keep every row when cleaning; otherwise keep at
         most 100 rows.
      5. For the "perplexity"/"both" strategies, configure the metrics
         helpers, run batch-wise evaluation and log the results to wandb if a
         run is active.
      6. If ``eval_args.clean_dataset`` is set, drop high-loss triplets, save
         the cleaned dataset, and (for "stats"/"both") report its statistics.
    """
    strategy = eval_args.eval_dataset_strategy
    wants_stats = strategy in ["stats", "both"]
    wants_perplexity = strategy in ["perplexity", "both"]

    set_wandb()
    model, tokenizer = load_lora_model(model_args, tokenizer_only=strategy in ["stats", "no"])

    if eval_args.use_llama2_tokenizer:
        if eval_args.add_special_tokens:
            print("eval the dataset with special dblookup tokens")
        tokenizer = load_tiny_llama2_tokenizer(add_special_tokens=eval_args.add_special_tokens)
        print("loaded llama2 tokenizer")

    # The second return value (eval splits) is not used by this script.
    train_dataset, _eval_datasets = prepare_pretrain_data(
        script_args,
        use_special_dblookup_tokens=eval_args.add_special_tokens,
        is_plain_baseline=eval_args.eval_raw_dataset,
    )
    # Last path component, cut at the first ".json" occurrence.
    dataset_name = script_args.dataset_name.rpartition("/")[2].partition(".json")[0]

    if eval_args.eval_raw_dataset:
        train_dataset = convert_to_raw_dataset(train_dataset)
        dataset_name += "_raw"

    if eval_args.add_special_tokens:
        train_dataset = convert_to_special_db_tokens_format(train_dataset)
        dataset_name += "_special_db"

    if wants_stats:
        dataset_stats(dataset_name=dataset_name, dataset=train_dataset, tokenizer=tokenizer, visualize=True)

    if eval_args.enable_length_filter:
        train_dataset = train_dataset.filter(filter_length)
        print("==== Filtered dataset based on length ====")
        print(f"after filtering: {len(train_dataset)}")

    shuffled = train_dataset.shuffle(seed=42)
    if eval_args.clean_dataset:
        subset = shuffled
    else:
        # Perplexity-only runs look at a small random sample.
        subset = shuffled.select(range(min(100, len(train_dataset))))

    if wants_perplexity:
        set_tokenizer(tokenizer)
        set_dataset_name(dataset_name)
        set_is_cleaned_dataset(eval_args.clean_dataset)

        eval_results = evaluate_in_batches(
            model=model,
            tokenizer=tokenizer,
            full_dataset=subset,
            batch_size=50,  # Adjust based on your GPU memory
            trainer_args=training_args,
        )
        if wandb.run:
            wandb.log(eval_results)

    if not eval_args.clean_dataset:
        return

    triplets_path = prepare_triplets_save_path(eval_args, dataset_name)
    cleaned_dataset = clean_high_loss_triplets(subset, triplets_save_path=triplets_path)
    save_cleaned_dataset(cleaned_dataset, dataset_name, model_args.model_name_or_path)

    if wants_stats:
        dataset_stats(dataset_name=f"{dataset_name}_cleaned", dataset=cleaned_dataset, tokenizer=tokenizer, visualize=False)

def prepare_triplets_save_path(eval_args, dataset_name):
    """
    Return the JSON path where high-loss triplets for ``dataset_name`` are written.

    The directory comes from ``eval_args.save_dir``. If that attribute is
    missing, None, or empty, ``./output/dataset/high_loss_triplets`` is used.
    The directory is created if needed. The file name is
    ``<dataset_name>_<threshold-name>.json``, where the suffix comes from
    ``convert_th_config_to_name()``.
    """
    # A getattr default alone is not enough: EvaluationConfig always defines
    # save_dir (default None), which would otherwise reach os.makedirs(None)
    # and raise TypeError.
    save_dir = getattr(eval_args, "save_dir", None) or "./output/dataset/high_loss_triplets"
    os.makedirs(save_dir, exist_ok=True)
    save_th_name = convert_th_config_to_name()
    return os.path.join(save_dir, f"{dataset_name}_{save_th_name}.json")

def save_cleaned_dataset(cleaned_dataset, dataset_name, model_name):
    """
    Write ``cleaned_dataset`` as JSON to ``DATA_DIR/cleaned/<dataset_name>_cleaned.json``.

    The JSON object has ``examples`` (the rows from ``cleaned_dataset.to_list()``)
    followed by ``metadata``. The metadata records the source dataset name, the
    model name, the threshold-config name and an ISO-8601 timestamp from
    ``datetime.now()``. Missing parent directories are created and an existing
    file is overwritten.

    Args:
        cleaned_dataset: Dataset object exposing ``to_list()``.
        dataset_name: Base name for the output file; also recorded as the source dataset.
        model_name: Recorded in the metadata as the correcting model.
    """
    file_name = f"{dataset_name}_cleaned.json"
    save_path = os.path.join(DATA_DIR, "cleaned", file_name)
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    # Keyword order here becomes the key order in the written JSON.
    metadata = dict(
        source_datasets=[dataset_name],
        corrector_model=[model_name],
        correction_th=convert_th_config_to_name(),
        last_modified=datetime.now().isoformat(),
    )

    with open(save_path, 'w') as out_file:
        payload = {'examples': cleaned_dataset.to_list(), 'metadata': metadata}
        json.dump(payload, out_file, indent=4)

    print(f"==== Saved cleaned dataset to {save_path} ====")


@dataclass
class EvaluationConfig:
    """
    Evaluation-specific options, parsed by ``TrlParser`` alongside the TRL argument dataclasses.

    ``eval_dataset_strategy`` is a ``Literal``, so the argument parser restricts
    it to the listed choices instead of silently accepting typos.
    """

    eval_dataset_strategy: Literal["no", "perplexity", "stats", "both"] = field(
        default="no",
        metadata={
            "help": "'perplexity': batch-wise evaluation; 'stats': dataset statistics; "
            "'both': both; 'no': neither. 'stats' and 'no' load only the tokenizer."
        },
    )
    save_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Directory where high-loss triplets are written when --clean_dataset is set."},
    )
    clean_dataset: bool = field(
        default=False,
        metadata={"help": "Drop high-loss triplets from the full shuffled dataset and save the cleaned copy."},
    )
    enable_length_filter: bool = field(
        default=False,
        metadata={"help": "Filter the dataset with filter_length before sampling/evaluation."},
    )
    use_llama2_tokenizer: bool = field(
        default=False,
        metadata={"help": "Replace the model's tokenizer with the tiny llama2 tokenizer."},
    )
    add_special_tokens: bool = field(
        default=False,
        metadata={"help": "Use special db-lookup tokens in the dataset format (and llama2 tokenizer, if used)."},
    )
    eval_raw_dataset: bool = field(
        default=False,
        metadata={"help": "Evaluate the raw (plain-baseline) variant of the dataset."},
    )


def make_parser(subparsers: argparse._SubParsersAction = None):
    """
    Create the command-line parser for this script.

    The parser fills four dataclasses, which the ``__main__`` block unpacks in
    this order: ScriptArguments, SFTConfig, ModelConfig, EvaluationConfig.

    Args:
        subparsers: Optional sub-parser action from a parent CLI. If truthy,
            the parser is registered on it as the ``sft`` sub-command.
            Otherwise a standalone ``TrlParser`` is returned.
    """
    config_types = (ScriptArguments, SFTConfig, ModelConfig, EvaluationConfig)
    if not subparsers:
        return TrlParser(config_types)
    return subparsers.add_parser("sft", help="Run the SFT training script", dataclass_types=config_types)


if __name__ == "__main__":
    # The four parsed configs are deliberately bound as module-level names:
    # evaluate_in_batches reads the global ``model_args`` to build its peft_config.
    script_args, training_args, model_args, eval_args = make_parser().parse_args_and_config()
    main(script_args, training_args, model_args, eval_args)
