from typing import Dict, Optional, Sequence
import transformers
import copy
import torch
from pytorch_lightning.strategies.fsdp import FSDPStrategy
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType, FullStateDictConfig
import os


# Label value written over prompt positions and over label padding in
# ``alpaca_preprocess`` so those positions can be skipped by the loss.
# NOTE(review): presumably relies on the loss using ignore_index=-100
# (the PyTorch CrossEntropyLoss default) — confirm in the training loop.
ALPACA_IGNORE_INDEX = -100
# Special-token strings for the Alpaca recipe. Only the EOS string is used in
# this section (appended to every target text); the others are presumably used
# when special tokens are added to the tokenizer elsewhere — verify.
ALPACA_DEFAULT_PAD_TOKEN = "[PAD]"
ALPACA_DEFAULT_EOS_TOKEN = "</s>"
ALPACA_DEFAULT_BOS_TOKEN = "<s>"
ALPACA_DEFAULT_UNK_TOKEN = "<unk>"



def save_model_checkpoint(fabric, strategy, model, tokenizer, file_path):
    """Save the wrapped model's weights and the tokenizer into ``file_path``.

    Handles boilerplate logic for retrieving and saving the state_dict.
    This will be upstreamed to Fabric soon.

    Under an ``FSDPStrategy`` the full (unsharded) state dict is gathered with
    CPU offload on rank 0 only, and rank 0 then writes it through
    ``model.model.save_pretrained(..., state_dict=...)``. Otherwise rank 0 calls
    ``model.model.save_pretrained`` directly. In both cases rank 0 also saves the
    tokenizer, and every rank waits on ``fabric.barrier()`` before returning.

    Args:
        fabric: Object exposing ``global_rank`` and ``barrier()``.
        strategy: The strategy in use; an ``FSDPStrategy`` instance selects the
            gather-then-save path.
        model: Wrapped module. ``model.model`` must expose ``save_pretrained``;
            the FSDP path also reads ``model._forward_module``.
        tokenizer: Object exposing ``save_pretrained``.
        file_path: Output directory; created if it does not exist.
    """
    # exist_ok already covers the "directory exists" case, so no separate check.
    os.makedirs(file_path, exist_ok=True)

    # NOTE(review): FSDPStrategy is imported from pytorch_lightning; if a
    # Lightning Fabric strategy object is passed here this check is False and
    # the non-FSDP branch runs on sharded parameters — confirm the caller.
    if isinstance(strategy, FSDPStrategy):
        save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
            # Collective gather: every rank must make this call.
            full_state_dict = model._forward_module.state_dict()
        # With rank0_only=True only rank 0 holds the gathered tensors.
        if fabric.global_rank == 0:
            # Keys are relative to the FSDP-wrapped module. Keep the entries
            # belonging to its ``model`` submodule and strip that prefix, so the
            # dict matches what ``model.model.state_dict()`` would produce.
            # NOTE(review): assumes the submodule is registered under the
            # attribute name ``model`` (as ``model.model`` implies) — confirm.
            prefix = "model."
            submodule_state_dict = {
                key[len(prefix):]: tensor
                for key, tensor in full_state_dict.items()
                if key.startswith(prefix)
            }
            # Fall back to the unmodified dict instead of writing no weights.
            model.model.save_pretrained(
                file_path, state_dict=submodule_state_dict or full_state_dict
            )
            tokenizer.save_pretrained(file_path)
    elif fabric.global_rank == 0:
        model.model.save_pretrained(file_path)
        tokenizer.save_pretrained(file_path)

    fabric.barrier()


def load_model_checkpoint():
    """Placeholder for checkpoint loading; not implemented yet.

    Currently performs no work and returns ``None``.
    """
    return None

    
def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize each string independently and collect ids plus non-pad lengths.

    Every string is encoded on its own (PyTorch tensors, ``padding="longest"``,
    truncated to ``tokenizer.model_max_length``).

    Returns:
        A dict with:
          * ``input_ids``: list holding the first row of each encoding's
            ``input_ids`` (presumably a 1-D tensor per string — confirm).
          * ``labels``: the *same list object* as ``input_ids``, not a copy;
            callers must copy before mutating either.
          * ``input_ids_lens`` / ``labels_lens``: one shared list giving, per
            encoding, how many ids differ from ``tokenizer.pad_token_id``. A real
            token whose id equals the pad id is therefore not counted.
    """
    encodings = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
        for text in strings
    ]

    first_rows = []
    non_pad_counts = []
    for encoding in encodings:
        ids = encoding.input_ids
        first_rows.append(ids[0])
        non_pad_counts.append(ids.ne(tokenizer.pad_token_id).sum().item())

    # ``labels`` deliberately aliases ``input_ids`` (and the length lists alias
    # each other), matching the established contract of this helper.
    return {
        "input_ids": first_rows,
        "labels": first_rows,
        "input_ids_lens": non_pad_counts,
        "labels_lens": non_pad_counts,
    }

def alpaca_preprocess(
    sources: Sequence[str],
    targets: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Build padded inputs and prompt-masked labels for supervised fine-tuning.

    Each target gets ``" " + ALPACA_DEFAULT_EOS_TOKEN`` appended and is
    concatenated onto its source prompt. The full texts and the prompts alone
    are both tokenized. For each example, the prompt's non-pad token count is
    used to overwrite that many leading label positions with
    ``ALPACA_IGNORE_INDEX``, so only the target portion keeps real labels.

    NOTE(review): the prompt length comes from tokenizing the prompt on its own,
    which assumes tokenizing ``source + target`` splits cleanly at the boundary.
    Presumably this is approximately true — verify for the tokenizer in use.

    Returns:
        A dict with:
          * ``input_ids``: sequences stacked batch-first, padded with
            ``tokenizer.pad_token_id``.
          * ``labels``: same layout, padded with ``ALPACA_IGNORE_INDEX``.
          * ``attention_mask``: boolean tensor, True where ``input_ids`` differs
            from the pad id.
    """
    eos_terminated = [f"{target} {ALPACA_DEFAULT_EOS_TOKEN}" for target in targets]
    full_texts = [prompt + answer for prompt, answer in zip(sources, eos_terminated)]

    full_tokenized = _tokenize_fn(full_texts, tokenizer)
    prompt_tokenized = _tokenize_fn(sources, tokenizer)

    sequences = full_tokenized["input_ids"]
    # Deep copy: _tokenize_fn returns labels aliased to input_ids, and the
    # masking below must not touch the model inputs.
    label_sequences = copy.deepcopy(sequences)
    for label_seq, prompt_len in zip(label_sequences, prompt_tokenized["input_ids_lens"]):
        label_seq[:prompt_len] = ALPACA_IGNORE_INDEX

    pad_id = tokenizer.pad_token_id
    padded_inputs = torch.nn.utils.rnn.pad_sequence(
        sequences, batch_first=True, padding_value=pad_id
    )
    padded_labels = torch.nn.utils.rnn.pad_sequence(
        label_sequences, batch_first=True, padding_value=ALPACA_IGNORE_INDEX
    )
    return {
        "input_ids": padded_inputs,
        "labels": padded_labels,
        "attention_mask": padded_inputs.ne(pad_id),
    }


