import argparse
import torch.nn as nn
import pickle as pk
import sys
import os
import transformers
import torch
from dataclasses import dataclass
sys.path.append('./nlp_api')

from nlp_api.all import get_model_datasets, parse_args
from torch.utils.data import random_split, DataLoader

from trainers import fedtrainer
from transformers.utils import PaddingStrategy
from transformers import PreTrainedTokenizerBase, DataCollatorForTokenClassification
from typing import Optional, Union, List, Dict, Any

# Padding value for label positions; equals the default ``ignore_index`` of
# ``torch.nn.CrossEntropyLoss``.
IGNORE_INDEX = -100


@dataclass
class LLMDataCollator(object):
    """Collate examples for supervised fine-tuning.

    Each instance must be a mapping with exactly two keys, ``"input_ids"``
    and ``"labels"``, holding token-id sequences (lists or tensors;
    assumed to be 1-D -- TODO confirm against the dataset).  The collator
    right-pads both fields to the longest sequence in the batch.
    """

    # Quoted so the annotation is not evaluated when the class is created.
    tokenizer: "transformers.PreTrainedTokenizer"

    def __call__(self, instances):
        """Pad a list of instances into a single batch.

        Args:
            instances: non-empty list of mappings with the keys
                ``"input_ids"`` and ``"labels"``.

        Returns:
            A dict with ``input_ids`` (padded with the tokenizer's pad id),
            ``labels`` (padded with ``IGNORE_INDEX``) and a boolean
            ``attention_mask`` that is True on every real token position.

        Raises:
            NotImplementedError: if the first instance does not have exactly
                two keys (other instance layouts are not supported).
            ValueError: if the tokenizer has no ``pad_token_id``.
        """
        if len(instances[0].keys()) != 2:
            raise NotImplementedError(
                "only instances with 'input_ids' and 'labels' are supported")

        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            raise ValueError(
                "tokenizer.pad_token_id must be set to pad a batch")

        # as_tensor avoids the copy (and the warning) torch.tensor() produces
        # when the instance fields are already tensors.
        input_ids = [torch.as_tensor(inst["input_ids"]) for inst in instances]
        labels = [torch.as_tensor(inst["labels"]) for inst in instances]

        # Record the true lengths before padding so the mask does not depend
        # on the pad id.
        lengths = torch.tensor([seq.size(0) for seq in input_ids])

        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids,
            batch_first=True,
            padding_value=pad_id)
        labels = torch.nn.utils.rnn.pad_sequence(
            labels,
            batch_first=True,
            padding_value=IGNORE_INDEX)

        # Build the mask from lengths rather than `input_ids != pad_id`:
        # comparing against the pad id would also mask out real tokens that
        # share that id (e.g. when the pad token is aliased to EOS).
        positions = torch.arange(input_ids.size(1))
        attention_mask = positions.unsqueeze(0) < lengths.unsqueeze(1)

        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=attention_mask,
        )


@dataclass
class DataCollatorWithPaddingAndNesting:
    """
    Training collator for nested features.

    Every incoming item is itself an iterable of feature dicts; the items
    are flattened into one list, padded through ``tokenizer.pad`` and any
    ``label`` / ``label_ids`` entry is exposed under the key ``labels``.
    """

    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    return_tensors: str = "pt"

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        # Flatten one level of nesting: each element contributes all of its
        # inner feature dicts, in order.
        flattened = []
        for group in features:
            flattened.extend(group)

        padded = self.tokenizer.pad(
            flattened,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors,
        )

        # Rename label fields to "labels"; when both are present the later
        # one ("label_ids") overwrites the earlier.
        for source_key in ("label", "label_ids"):
            if source_key in padded:
                padded["labels"] = padded[source_key]
                del padded[source_key]
        return padded


if __name__ == '__main__':
    # Script entry point: parse arguments, build the model and datasets,
    # split the training data into K equally sized subsets, then run the
    # fedtrainer loop while pickling the arguments and the running log.

    args = parse_args()

    # makedirs (unlike mkdir) also creates missing parent directories, and
    # exist_ok avoids a crash if the directory appears in the meantime.
    if not os.path.exists(args.output_dir):
        print(f'{args.output_dir} does not exists, creating..')
        os.makedirs(args.output_dir, exist_ok=True)
    print(f'logs saving to {args.output_dir}.')

    model, tokenizer, dstrain, dsvalid = get_model_datasets(args)

    full_train = dstrain[0]
    len_subset = len(full_train) // args.K
    if len_subset == 0:
        raise ValueError(
            f'cannot split {len(full_train)} training examples into '
            f'{args.K} non-empty subsets')

    # random_split requires the lengths to sum to the dataset size, so when
    # the size is not divisible by K the remainder goes into an extra split
    # that is dropped; the K kept subsets all have len_subset examples.
    split_lengths = [len_subset] * args.K
    leftover = len(full_train) - len_subset * args.K
    if leftover:
        split_lengths.append(leftover)
    fed_train_ds = random_split(full_train, split_lengths)[:args.K]
    print('length of local subsets: %d' % len_subset)

    # collator = LLMDataCollator(tokenizer)
    if args.train_as_classification:
        collator = DataCollatorWithPaddingAndNesting(
            tokenizer, pad_to_multiple_of=8)
    else:
        collator = DataCollatorForTokenClassification(
            tokenizer, pad_to_multiple_of=8)

    # One training DataLoader per subset, plus a single validation loader.
    dltrain = [
        DataLoader(
            subset,
            batch_size=args.batch_size,
            collate_fn=collator,
            num_workers=args.num_workers)
        for subset in fed_train_ds
    ]
    dlvalid = DataLoader(
        dsvalid[0],
        batch_size=args.batch_size,
        collate_fn=collator,
        num_workers=args.num_workers,
    )

    this_trainer = fedtrainer(
        mu=args.mu,
        eta=args.eta,
        gamma=args.gamma,
        optimizer_name=args.optimizer_name,
        weight_decay=args.weight_decay,
        k=args.k,
        K=args.K,
        model=model,
        dltrain=dltrain,
        dlvalid=dlvalid,
        loss_func=nn.CrossEntropyLoss(),
        cuda_devices=args.cuda_devices,
        onebit=args.onebit,
        nlp=True,
    )

    # Persist the run configuration; `with` closes the file handle.
    args_path = os.path.join(f'{args.output_dir}', f'{args.comment}_args.pk')
    with open(args_path, 'wb') as args_file:
        pk.dump(vars(args), args_file)

    log_path = os.path.join(f'{args.output_dir}', f'{args.comment}_log.pk')
    this_trainer.logger_init()
    for epoch in range(args.n_epoch):
        if args.binary:
            # Binary mode validates before every training epoch.
            this_trainer.epoch_valid(epoch)
            this_trainer.zo_epoch_train_binary(epoch)
        else:
            # Baseline mode validates only on every 20th epoch.
            if epoch % 20 == 0:
                this_trainer.epoch_valid(epoch)
            this_trainer.zo_epoch_train_baseline_shorter(epoch)

        # Overwrite the log file after each epoch with the latest summary.
        df = this_trainer.logger_summary()
        with open(log_path, 'wb') as log_file:
            pk.dump(df, log_file)
    
    