from functools import partial

import numpy as np
import evaluate
from trl import SFTTrainer
from peft import LoraConfig
from transformers import DataCollatorWithPadding
from helper.thirdparty.tofu.dataloader import CustomTrainer


def get_compute_metrics(task_type):
    """Return the ``compute_metrics`` callback for ``task_type``.

    Args:
        task_type: Either ``"classification"`` or ``"summarization"``.

    Returns:
        The matching metric function. Note that the summarization callback
        also takes a ``tokenizer`` argument, which the caller must bind
        (e.g. with ``functools.partial``) before handing it to a Trainer.

    Raises:
        NotImplementedError: if ``task_type`` is not supported.
    """
    if task_type == "classification":
        return classification_compute_metrics
    if task_type == "summarization":
        return summarization_compute_metrics
    raise NotImplementedError(
        f"No compute_metrics defined for task_type={task_type!r}; "
        "supported: 'classification', 'summarization'."
    )

def classification_compute_metrics(p):
    """Compute accuracy, expressed as a percentage, for a classification eval.

    Args:
        p: Evaluation output exposing ``predictions`` (class scores, or a tuple
            whose first element holds them) and ``label_ids``.

    Returns:
        dict mapping each score name reported by ``evaluate``'s ``accuracy``
        metric to that score multiplied by 100.
    """
    logits = p.predictions
    if isinstance(logits, tuple):
        logits = logits[0]
    # The predicted class is the highest-scoring entry along the last axis.
    predicted_classes = np.argmax(logits, -1)
    print("Predictions:", predicted_classes)
    accuracy = evaluate.load('accuracy')
    scores = accuracy.compute(predictions=predicted_classes, references=p.label_ids)
    return {name: value * 100.0 for name, value in scores.items()}

def summarization_compute_metrics(eval_pred, tokenizer):
    """Compute ROUGE between greedy (argmax) next-token predictions and labels.

    Meant to be used as a Trainer ``compute_metrics`` callback for a causal
    language model, with ``tokenizer`` bound via ``functools.partial``.

    Args:
        eval_pred: ``(predictions, labels)`` pair from the Trainer.
            ``predictions`` may be a tuple whose first element holds the
            logits. Assumes logits are (batch, seq_len, vocab) and labels are
            (batch, seq_len) — TODO confirm for other model heads.
        tokenizer: Tokenizer used to turn token ids back into text.

    Returns:
        dict of scores as returned by ``evaluate``'s ``rouge`` metric.
    """
    predictions, labels = eval_pred
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    predictions = np.argmax(predictions, -1)

    # A causal LM's output at position t is its guess for token t + 1, so drop
    # the last prediction and the first label to line the sequences up.
    predictions = predictions[:, :-1]
    labels = labels[:, 1:]

    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        # Some tokenizers define no pad token; fall back to EOS, which is also
        # removed by skip_special_tokens below.
        pad_token_id = tokenizer.eos_token_id

    # Positions labelled -100 are ignored by the loss; blank them out on BOTH
    # sides so predictions at ignored positions don't leak into the text.
    ignored = labels == -100
    predictions = np.where(ignored, pad_token_id, predictions)
    labels = np.where(ignored, pad_token_id, labels)

    decoded_preds = tokenizer.batch_decode(predictions,
                                           skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels,
                                            skip_special_tokens=True)
    metrics = evaluate.load('rouge')
    res = metrics.compute(predictions=decoded_preds, references=decoded_labels)
    return res

def summarization_prompt_formatter(sample):
    """Render one dialogue-summarization example as an instruction prompt.

    Args:
        sample: A single dataset row with ``"dialogue"`` and ``"summary"``
            entries.

    Returns:
        The prompt string: instruction, the dialogue, and the target summary,
        wrapped in literal ``<s>`` / ``</s>`` markers.

    NOTE(review): if the tokenizer also prepends its own BOS token, the
    sequence will start with two BOS markers — confirm this is intended.
    """
    # Built from explicit pieces so the source-code indentation does not leak
    # stray spaces into the prompt text.
    return (
        "<s>### Instruction:\n"
        "You are a helpful, respectful and honest assistant. "
        "Your task is to summarize the following dialogue. "
        "Your answer should be based on the provided dialogue only.\n"
        "\n"
        "### Dialogue:\n"
        f"{sample['dialogue']}\n"
        "\n"
        "### Summary:\n"
        f"{sample['summary']} </s>"
    )

def get_prompt_formatter(task_type):
    """Return the prompt-formatting function for a generative ``task_type``.

    Args:
        task_type: Currently only ``"summarization"`` is supported.

    Returns:
        A callable mapping one dataset row to a prompt string.

    Raises:
        NotImplementedError: if ``task_type`` has no formatter.
    """
    if task_type == "summarization":
        return summarization_prompt_formatter
    raise NotImplementedError(
        f"No prompt formatter defined for task_type={task_type!r}; "
        "supported: 'summarization'."
    )

def get_peft_config(config):
    """Build a PEFT adapter config from a plain mapping.

    Args:
        config: Mapping with a ``"type"`` key selecting the adapter kind. Every
            other key is forwarded as a keyword argument to the adapter config.
            The mapping itself is left unmodified.

    Returns:
        A ``LoraConfig`` built from the non-``"type"`` keys.

    Raises:
        NotImplementedError: if ``config["type"]`` is missing or not ``"lora"``.
    """
    peft_type = config.get("type")
    if peft_type != "lora":     # only support LoRA at the moment
        raise NotImplementedError(
            f"Unsupported PEFT type {peft_type!r}; only 'lora' is supported."
        )
    # Copy rather than pop so the caller's config is not mutated and the same
    # dict can be passed in more than once.
    lora_kwargs = {key: value for key, value in config.items() if key != "type"}
    return LoraConfig(**lora_kwargs)

def get_sft_trainer(model,
                    task_type,
                    train_dataset,
                    test_dataset,
                    peft_config,
                    tokenizer,
                    max_seq_length,
                    args,
                    ):
    """Create an ``SFTTrainer`` configured for ``task_type``.

    Classification trains on the ``"text"`` field with a padding collator.
    Every other task is treated as generative: examples are rendered by the
    task's prompt formatter and packed. The task's metric callback has the
    tokenizer bound to it.

    Args:
        model: Model to fine-tune.
        task_type: Task name understood by ``get_compute_metrics`` (and, for
            non-classification tasks, by ``get_prompt_formatter``).
        train_dataset: Training split.
        test_dataset: Evaluation split.
        peft_config: Adapter configuration forwarded to the trainer.
        tokenizer: Tokenizer forwarded to the trainer.
        max_seq_length: Maximum sequence length forwarded to the trainer.
        args: Training arguments forwarded to the trainer.

    Returns:
        The constructed ``SFTTrainer``.

    Raises:
        NotImplementedError: propagated from the lookup helpers for an
            unsupported ``task_type``.
    """
    metric_fn = get_compute_metrics(task_type)

    # Keyword arguments shared by both trainer flavours.
    trainer_kwargs = dict(
        model=model,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        peft_config=peft_config,
        tokenizer=tokenizer,
        max_seq_length=max_seq_length,
        args=args,
    )

    if task_type == "classification":
        collator = DataCollatorWithPadding(tokenizer=tokenizer)    # special handling of "label"
        trainer_kwargs.update(
            dataset_text_field="text",
            data_collator=collator,
            compute_metrics=metric_fn,
        )
    else:
        # Generative tasks: the metric needs the tokenizer to decode ids.
        bound_metric_fn = partial(metric_fn, tokenizer=tokenizer)
        formatter = get_prompt_formatter(task_type)
        trainer_kwargs.update(
            packing=True,
            formatting_func=formatter,
            compute_metrics=bound_metric_fn,
        )

    return SFTTrainer(**trainer_kwargs)

def get_tofu_trainer(model,
                    train_dataset,
                    args,
                    data_collator,
                    ):
    """Create the TOFU ``CustomTrainer`` for training only.

    Args:
        model: Model to train.
        train_dataset: Training split.
        args: Training arguments forwarded to the trainer.
        data_collator: Collator used to batch ``train_dataset``.

    Returns:
        A ``CustomTrainer`` with no evaluation dataset attached.
    """
    return CustomTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        data_collator=data_collator,
        eval_dataset=None,  # this trainer is used for training only
    )
