# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import inspect
from dataclasses import asdict

import torch.distributed as dist
from torch.utils.data import DistributedSampler
from peft import (
    LoraConfig,
    AdaptionPromptConfig,
    PrefixTuningConfig,
)
from transformers import default_data_collator
from transformers.data import DataCollatorForSeq2Seq

from llama_recipes.configs import datasets, lora_config, llama_adapter_config, prefix_config, train_config
from llama_recipes.data.sampler import LengthBasedBatchSampler, DistributedLengthBasedBatchSampler
from llama_recipes.utils.dataset_utils import DATASET_PREPROC
from llama_recipes.utils.train_utils import DataCollatorReward

import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import PreTrainedTokenizerBase
from typing import Any, Optional, Union
from transformers.tokenization_utils_base import PaddingStrategy

class DataCollatorForContrastive3QA:
    """Collate contrastive QA features into a batch of padded tensors.

    Each feature is a mapping that must provide every key listed in
    ``_FIELD_ORDER``. When called, the collator converts each value to a
    tensor and pads every field to the longest entry of that field in the
    batch using ``pad_sequence(batch_first=True)``.

    ``labels`` is padded with ``label_pad_token_id``; every other field,
    including the ``input_ids*`` fields, is padded with the literal ``0``.
    ``tokenizer``, ``padding``, ``max_length``, ``pad_to_multiple_of`` and
    ``return_tensors`` are stored on the instance but are not read by
    ``__call__``.
    """

    # Fields are processed, and emitted in the result, in exactly this order.
    _FIELD_ORDER = (
        "input_ids",
        "labels",
        "attention_mask",
        "input_ids_pos_1",
        "input_ids_neg_1",
        "attention_mask_pos_1",
        "attention_mask_neg_1",
        "input_ids_pos_2",
        "input_ids_neg_2",
        "attention_mask_pos_2",
        "attention_mask_neg_2",
        "acc",
    )

    def __init__(
        self,
        tokenizer: PreTrainedTokenizerBase,
        padding: Union[bool, str, PaddingStrategy] = True,
        max_length: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
        label_pad_token_id: int = -100,
        return_tensors: str = "pt",
    ):
        self.tokenizer = tokenizer
        self.padding = padding
        self.max_length = max_length
        self.pad_to_multiple_of = pad_to_multiple_of
        self.label_pad_token_id = label_pad_token_id
        self.return_tensors = return_tensors

    def __call__(self, features):
        """Return a dict mapping each field name to its padded batch tensor."""
        # Phase 1: turn every field of every example into a tensor, so a
        # missing key surfaces before any padding is attempted.
        columns = {
            name: [torch.tensor(example[name]) for example in features]
            for name in self._FIELD_ORDER
        }

        # Phase 2: pad each column to its own longest sequence.
        return {
            name: pad_sequence(
                tensors,
                batch_first=True,
                padding_value=self.label_pad_token_id if name == "labels" else 0,
            )
            for name, tensors in columns.items()
        }

def update_config(config, **kwargs):
    """Apply keyword overrides to a config object, or to each one in a list/tuple.

    For every ``key=value`` pair:

    * if ``config`` has an attribute named ``key``, it is set to ``value``;
    * otherwise, if ``key`` contains a dot, it is read as
      ``<config_class_name>.<param_name>`` (split at the first dot). The
      parameter is set only when ``type(config).__name__`` equals the part
      before the dot; when the class matches but the parameter does not
      exist, a warning is printed. Keys addressed to other classes are
      silently skipped;
    * otherwise, a warning is printed only when ``config`` is a
      ``train_config`` instance.

    Args:
        config: A config object, or a tuple/list of config objects, mutated
            in place.
        **kwargs: Overrides to apply.
    """
    if isinstance(config, (tuple, list)):
        for c in config:
            update_config(c, **kwargs)
        return

    for k, v in kwargs.items():
        if hasattr(config, k):
            setattr(config, k, v)
        elif "." in k:
            # allow --some_config.some_param=True
            # partition() splits at the first dot only, so a key containing
            # several dots no longer raises ValueError on unpacking.
            config_name, _, param_name = k.partition(".")
            if type(config).__name__ != config_name:
                continue
            if hasattr(config, param_name):
                setattr(config, param_name, v)
            else:
                # In case of specialized config we can warn user
                print(f"Warning: {config_name} does not accept parameter: {k}")
        elif isinstance(config, train_config):
            print(f"Warning: unknown parameter {k}")


def generate_peft_config(train_config, kwargs):
    """Build the peft config object selected by ``train_config.peft_method``.

    The method name is matched against the local config class names with
    their ``_config`` suffix removed. The matching local config is
    instantiated, updated from ``kwargs`` via ``update_config``, converted
    with ``dataclasses.asdict`` and passed as keyword arguments to the peft
    config class at the same position.

    Args:
        train_config: Object exposing a ``peft_method`` attribute.
        kwargs: Mapping of overrides forwarded to ``update_config``.

    Returns:
        An instance of ``LoraConfig``, ``AdaptionPromptConfig`` or
        ``PrefixTuningConfig``.

    Raises:
        AssertionError: If ``train_config.peft_method`` matches no known
            config name.
    """
    suffix = "_config"
    configs = (lora_config, llama_adapter_config, prefix_config)
    peft_configs = (LoraConfig, AdaptionPromptConfig, PrefixTuningConfig)
    # str.rstrip("_config") would strip any trailing run of the characters
    # "_", "c", "o", "n", "f", "i", "g" rather than the literal suffix, so
    # remove the suffix explicitly.
    names = tuple(
        c.__name__[: -len(suffix)] if c.__name__.endswith(suffix) else c.__name__
        for c in configs
    )

    assert train_config.peft_method in names, f"Peft config not found: {train_config.peft_method}"

    selected = names.index(train_config.peft_method)
    config = configs[selected]()

    update_config(config, **kwargs)
    params = asdict(config)
    peft_config = peft_configs[selected](**params)

    return peft_config


def generate_dataset_config(train_config, kwargs):
    """Instantiate and customise the dataset config named by ``train_config.dataset``.

    The name must be a key of ``DATASET_PREPROC``. The member of the
    ``datasets`` config module with that name is called with no arguments,
    and the resulting object is updated from ``kwargs`` via
    ``update_config``.

    Args:
        train_config: Object exposing a ``dataset`` attribute.
        kwargs: Mapping of overrides forwarded to ``update_config``.

    Returns:
        The updated dataset config object.

    Raises:
        AssertionError: If ``train_config.dataset`` is not a key of
            ``DATASET_PREPROC``.
    """
    known_datasets = tuple(DATASET_PREPROC.keys())
    assert train_config.dataset in known_datasets, f"Unknown dataset: {train_config.dataset}"

    # Look the config factory up by name among the members of ``datasets``.
    members_by_name = dict(inspect.getmembers(datasets))
    config_factory = members_by_name[train_config.dataset]
    dataset_config = config_factory()

    update_config(dataset_config, **kwargs)
    return dataset_config


def get_dataloader_kwargs(train_config, dataset, tokenizer, mode, contrastive=0):
    """Assemble the keyword arguments used to build a DataLoader for ``dataset``.

    Args:
        train_config: Object exposing ``train_ppo_reward_model``,
            ``batching_strategy``, ``enable_fsdp``, ``batch_size_training``
            and ``val_batch_size``.
        dataset: Dataset handed to the sampler.
        tokenizer: Tokenizer handed to the collator (padding strategy only).
        mode: ``"train"`` selects the training batch size and enables
            shuffling; any other value selects ``val_batch_size`` with no
            shuffling.
        contrastive: When equal to ``3`` under the padding strategy, the
            ``DataCollatorForContrastive3QA`` collator is used.

    Returns:
        For ``"padding"``: a dict with ``batch_sampler`` and ``collate_fn``.
        For ``"packing"``: a dict with ``batch_size``, ``drop_last`` and
        ``collate_fn``, preceded by ``sampler`` when FSDP is enabled.

    Raises:
        AssertionError: If PPO reward-model training is requested with a
            strategy other than ``"padding"``.
        ValueError: If the batching strategy is neither ``"padding"`` nor
            ``"packing"``.
    """
    if train_config.train_ppo_reward_model:
        assert train_config.batching_strategy == "padding", "PPO training only supports padding batching strategy"

    is_training = mode == "train"
    batch_size = train_config.batch_size_training if is_training else train_config.val_batch_size

    if train_config.batching_strategy == "padding":
        if train_config.enable_fsdp:
            batch_sampler = DistributedLengthBasedBatchSampler(
                dataset,
                batch_size=batch_size,
                rank=dist.get_rank(),
                num_replicas=dist.get_world_size(),
                shuffle=is_training,
            )
        else:
            batch_sampler = LengthBasedBatchSampler(dataset, batch_size, drop_last=True, shuffle=is_training)

        # The contrastive collator takes precedence over the reward collator.
        if contrastive == 3:
            collate_fn = DataCollatorForContrastive3QA(tokenizer)
        elif train_config.train_ppo_reward_model:
            collate_fn = DataCollatorReward(tokenizer)
        else:
            collate_fn = DataCollatorForSeq2Seq(tokenizer)

        return {"batch_sampler": batch_sampler, "collate_fn": collate_fn}

    if train_config.batching_strategy == "packing":
        loader_kwargs = {}
        if train_config.enable_fsdp:
            loader_kwargs["sampler"] = DistributedSampler(
                dataset,
                rank=dist.get_rank(),
                num_replicas=dist.get_world_size(),
                shuffle=is_training,
            )
        loader_kwargs["batch_size"] = batch_size
        loader_kwargs["drop_last"] = True
        loader_kwargs["collate_fn"] = default_data_collator
        return loader_kwargs

    raise ValueError(f"Unknown batching strategy: {train_config.batching_strategy}")
