import os
import torch
import torch.nn as nn
import argparse
from typing import Dict, Sequence
from dataclasses import dataclass

import transformers
from transformers import DistilBertModel, DistilBertTokenizer
from torch.utils.data import Dataset


def make_parent_dir(fname):
    """Create the directory that would contain ``fname`` if it is missing.

    Args:
        fname: Path to a file (the file itself is never created). Only the
            directory part, as returned by ``os.path.dirname``, is created,
            including any missing intermediate directories (mode ``0o775``).

    A bare filename such as ``"out.pt"`` has an empty directory part; in that
    case there is nothing to create and the call is a no-op.
    """
    parent = os.path.dirname(fname)
    # os.makedirs('') raises FileNotFoundError, so skip the empty case.
    if parent:
        os.makedirs(parent, 0o775, exist_ok=True)

def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> list:
    """Run the tokenizer over each string separately.

    Every string is encoded on its own (so ``padding="longest"`` has only a
    single sequence to look at), truncated to ``tokenizer.model_max_length``
    and returned as PyTorch tensors.

    Returns:
        A list with one tokenizer output per input string, in input order.
    """
    encodings = []
    for text in strings:
        encoding = tokenizer(
            text,
            truncation=True,
            max_length=tokenizer.model_max_length,
            padding="longest",
            return_tensors="pt",
        )
        encodings.append(encoding)
    return encodings

def get_token_dict(batch):
    """Rename the collated batch fields to the keyword names the model takes.

    Args:
        batch: Mapping holding ``'text_token_ids'`` and ``'text_length_mask'``.

    Returns:
        A new dict with the same two values stored under ``'input_ids'`` and
        ``'attention_mask'`` respectively.
    """
    return dict(
        input_ids=batch['text_token_ids'],
        attention_mask=batch['text_length_mask'],
    )

class BertEncoder(nn.Module):
    """Wrap a BERT-style model and expose the first-token hidden state.

    The wrapped model must accept keyword inputs and return an object with a
    ``last_hidden_state`` attribute, and must carry ``config.hidden_size``.
    """

    def __init__(self, bert_model):
        super().__init__()
        self.bert_model = bert_model

    def forward(self, kwargs):
        """Return the last-layer hidden state at sequence position 0.

        Args:
            kwargs: Dict of keyword inputs, unpacked into the wrapped model.
        """
        outputs = self.bert_model(**kwargs)
        hidden_states = outputs.last_hidden_state
        # Keep only the first position of the sequence axis.
        return hidden_states[:, 0, :]

    def output_dim(self):
        """Return the wrapped model's ``config.hidden_size``."""
        config = self.bert_model.config
        return config.hidden_size

def get_args(argv=None):
    """Parse the command-line options of the embedding script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        An ``argparse.Namespace`` with ``save_dir``, ``save_name``,
        ``input_file``, ``device`` and ``batch_size``.
    """
    # Both path defaults live in the same dataset directory; spell it once.
    default_dir = (
        '/shared/share_mala/implicitbayes/dataset_files/MIND_data/large/'
        'N=1000,D=10000,D_eval=10000,method=0111_logistic_withZ,dimX=5,'
        'one_X_per_col=False,flip/'
    )

    parser = argparse.ArgumentParser(description="Argument parser for text classification pipeline.")

    # Add arguments
    parser.add_argument(
        "--save_dir",
        type=str,
        help="Directory where the output files will be saved.",
        default=default_dir,
    )
    parser.add_argument(
        "--save_name",
        type=str,
        required=True,
        help="Name of the output (no extension, not a directory)",
    )
    parser.add_argument(
        "--input_file",
        type=str,
        default=default_dir + 'train_data.pt',
        help="Path to the input file containing the data to process.",
    )
    parser.add_argument(
        "--device",
        type=int,
        default=-1,
        help="Integer device index used for the model and the batches (default: -1).",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=32,
        help="Number of texts encoded per forward pass (default: 32).",
    )
    return parser.parse_args(argv)

@dataclass
class TextDataCollator:
    """Collate per-example token id tensors into one padded batch.

    The tokenizer is only consulted for its ``pad_token_id``.
    """

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        """Pad the ``'text_token_ids'`` entries to a common length.

        Returns:
            A dict with ``'text_token_ids'`` (batch-first, padded with the
            tokenizer's pad id) and ``'text_length_mask'`` (True wherever the
            padded ids differ from the pad id).
        """
        sequences = [example['text_token_ids'] for example in instances]
        pad_id = self.tokenizer.pad_token_id
        padded = torch.nn.utils.rnn.pad_sequence(
            sequences, batch_first=True, padding_value=pad_id
        )
        return {
            'text_token_ids': padded,
            'text_length_mask': padded.ne(pad_id),
        }

class TextDataset(Dataset):
    """Dataset of raw strings that are tokenized once, at construction."""

    def __init__(self, strings, tokenizer: transformers.PreTrainedTokenizer):
        self.strings = strings
        self.tokenizer = tokenizer
        # Each tokenizer output holds a single sequence; keep its first
        # (only) row of input ids.
        self.text_token_ids = [
            encoding['input_ids'][0]
            for encoding in _tokenize_fn(strings, tokenizer)
        ]

    def __len__(self):
        """Number of input strings."""
        return len(self.strings)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Return the token ids of example ``i`` under ``'text_token_ids'``."""
        return {'text_token_ids': self.text_token_ids[i]}

    def make_loader(self, batch_size):
        """Build an in-order (unshuffled) DataLoader that pads each batch."""
        return torch.utils.data.DataLoader(
            self,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=TextDataCollator(self.tokenizer),
        )

if __name__ == "__main__":
    # Replace the raw text stored under 'Z' in a saved dataset with
    # DistilBERT first-token embeddings and write the result to a new file.

    # Parse the arguments
    args = get_args()

    # A negative index (the default is -1) is not a valid device index for
    # `.to()`, so treat it as "run on CPU"; otherwise use that CUDA ordinal.
    if args.device < 0:
        device = torch.device('cpu')
    else:
        device = torch.device(f'cuda:{args.device}')

    # load existing dataset
    # NOTE(review): torch.load unpickles the file; only use trusted inputs.
    check = torch.load(args.input_file)
    raw_text = check['Z']

    # get model etc
    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
    bert = DistilBertModel.from_pretrained("distilbert-base-uncased").to(device)

    data = TextDataset(raw_text, tokenizer)
    loader = data.make_loader(batch_size=args.batch_size)
    bertenc = BertEncoder(bert)
    # Inference only: make sure dropout is off.
    bertenc.eval()

    feats = []
    num_batches = len(loader)
    # No gradients are needed, so skip building the autograd graph; this
    # keeps memory flat across batches.
    with torch.no_grad():
        for i, batch in enumerate(loader):
            print(f'{i + 1} of {num_batches}')
            batch = {k: v.to(device) for k, v in batch.items()}
            feats.append(bertenc(get_token_dict(batch)).cpu())
    allfeats = torch.cat(feats)
    check['Z'] = allfeats

    # Build the destination first so the directory that is created is the
    # one the file is written into, with or without a trailing slash on
    # --save_dir.
    dest = os.path.join(args.save_dir, args.save_name + '.pt')
    make_parent_dir(dest)
    torch.save(check, dest)
    print(f'saved to {dest}')
