import os
from pathlib import Path
from copy import deepcopy
from itertools import chain

import pickle5 as pickle
from tqdm import tqdm

from hydra.utils import get_original_cwd

import torch
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
import pytorch_lightning as pl
from transformers import AutoTokenizer

from src.utils.data import data_keys, dataset_info


class DataModule(pl.LightningDataModule):
    """Lightning data module over pre-tokenized, pickled dataset splits.

    For each split, every key in ``data_keys[arch]`` is loaded from
    ``<data_path>/<split>/<io_mode>`` (with ``/<rationale_src>`` appended unless
    ``rationale_src`` is None or ``'gold'``). When ``aux_io_mode`` is set, the
    ``example`` / ``token_type`` pickles under ``<data_path>/<split>/<aux_io_mode>``
    are loaded as ``aux_example`` / ``aux_token_type``. Each split is wrapped in a
    ``TextClassificationDataset``.
    """

    # Values accepted by the constructor's validation.
    _ARCHS = ('t5-small', 't5-base', 't5-large', 't5-3b')
    _IO_MODES = ('I-O', 'IR-O', 'I-OR', 'I-RO', 'IshuffledR-O', 'IreplacedR-O')
    _AUX_IO_MODES = (None, 'IR-O', 'IshuffledR-O', 'IreplacedR-O')
    _ALL_SPLITS = ('train', 'dev', 'test')

    def __init__(
            self, dataset: str, src_dataset: str, data_path: str, mode: str,
            arch: str, model_max_length: int, save_dir: str,
            train_batch_size: int = 1, eval_batch_size: int = 1, eff_train_batch_size: int = 1,
            num_workers: int = 0, num_train: int = None, num_dev: int = None, num_test: int = None,
            num_train_seed: int = None, num_dev_seed: int = None, num_test_seed: int = None,
            io_mode: str = None, aux_io_mode: str = None, rationale_src: str = None,
        ):
        """Store loader configuration and build the tokenizer.

        Args:
            dataset: key into ``dataset_info`` (supplies ``classes`` / ``delimiters``).
            src_dataset, mode: accepted for config compatibility; not read by this class.
            data_path: root directory holding ``<split>/<io_mode>/...`` pickles.
            arch: T5 checkpoint name; also selects ``data_keys[arch]``.
            model_max_length: forwarded to ``AutoTokenizer.from_pretrained``.
            save_dir: stored on the instance; not read by this class.
            train_batch_size, eval_batch_size: DataLoader batch sizes.
            eff_train_batch_size: stored on the instance; not read by this class.
            num_workers: DataLoader worker count.
            num_{train,dev,test}, num_{train,dev,test}_seed: when a count is given,
                pickles named ``<key>_<count>_<seed>.pkl`` are loaded instead of ``<key>.pkl``.
            io_mode: required; one of ``_IO_MODES``.
            aux_io_mode: optional auxiliary input format; one of ``_AUX_IO_MODES``.
            rationale_src: optional sub-directory of rationales (``None``/``'gold'`` = none).
        """
        super().__init__()

        assert arch in self._ARCHS, f'unsupported arch: {arch!r}'
        assert io_mode in self._IO_MODES, f'unsupported io_mode: {io_mode!r}'
        assert aux_io_mode in self._AUX_IO_MODES, f'unsupported aux_io_mode: {aux_io_mode!r}'

        self.dataset = dataset
        self.data_path = data_path # ${data_dir}/${.dataset}/${model.arch}/
        self.arch = arch
        self.save_dir = save_dir
        self.train_batch_size = train_batch_size
        self.eval_batch_size = eval_batch_size
        self.eff_train_batch_size = eff_train_batch_size
        self.num_workers = num_workers
        self.io_mode = io_mode
        self.aux_io_mode = aux_io_mode
        self.rationale_src = rationale_src

        self.tokenizer = AutoTokenizer.from_pretrained(arch, model_max_length=model_max_length)
        self.num_samples = {'train': num_train, 'dev': num_dev, 'test': num_test}
        self.num_samples_seed = {'train': num_train_seed, 'dev': num_dev_seed, 'test': num_test_seed}

    def update_dataset(self, dataset, key, data_path, split, new_key=None):
        """Load one pickled field into ``dataset`` and return ``dataset``.

        The file is ``<key>_<num>_<seed>.pkl`` when a sample count is configured
        for ``split``, else ``<key>.pkl``. It is stored under ``new_key`` if given,
        otherwise under ``key``. ``dataset`` is mutated in place.
        """
        num, seed = self.num_samples[split], self.num_samples_seed[split]
        filename = f'{key}.pkl' if num is None else f'{key}_{num}_{seed}.pkl'

        # NOTE: pickle.load executes arbitrary code from the file; only point
        # data_path at trusted, locally-produced preprocessing output.
        with open(os.path.join(data_path, filename), 'rb') as f:
            dataset[key if new_key is None else new_key] = pickle.load(f)

        return dataset

    def load_dataset(self, split):
        """Load every field of ``split`` (plus aux fields if configured) into a dict."""
        dataset = {}

        data_path = os.path.join(self.data_path, split, self.io_mode)
        if self.rationale_src not in [None, 'gold']:
            data_path = os.path.join(data_path, self.rationale_src)

        assert Path(data_path).exists(), f'{split} data directory not found: {data_path}'
        for key in tqdm(data_keys[self.arch], desc=f'Loading {split} set'):
            dataset = self.update_dataset(dataset, key, data_path, split)

        if self.aux_io_mode is not None:
            aux_data_path = os.path.join(self.data_path, split, self.aux_io_mode)
            assert Path(aux_data_path).exists(), f'{split} aux data directory not found: {aux_data_path}'
            dataset = self.update_dataset(dataset, 'example', aux_data_path, split, 'aux_example')
            dataset = self.update_dataset(dataset, 'token_type', aux_data_path, split, 'aux_token_type')

        return dataset

    def setup(self, splits=None):
        """(Re)build ``self.dataset_dict`` for the requested splits.

        ``splits`` may be ``None`` (the default), ``'all'`` or ``['all']`` to load
        train, dev and test; a single split name; or an iterable of split names.
        Any previously built splits are discarded.
        """
        if isinstance(splits, str):
            splits = [splits]
        if splits is None or list(splits) == ['all']:
            splits = self._ALL_SPLITS

        info = dataset_info[self.dataset]
        self.dataset_dict = {}
        for split in splits:
            self.dataset_dict[split] = TextClassificationDataset(
                data=self.load_dataset(split), split=split,
                classes=info['classes'],
                delimiters=info['delimiters'],
                io_mode=self.io_mode, aux_io_mode=self.aux_io_mode,
                tokenizer=self.tokenizer,
            )

    def _make_dataloader(self, split, batch_size):
        """Build a DataLoader over the already set-up ``split`` with its own collater."""
        split_dataset = self.dataset_dict[split]
        return DataLoader(
            split_dataset,
            batch_size=batch_size,
            num_workers=self.num_workers,
            collate_fn=split_dataset.collater,
            pin_memory=True,
        )

    def train_dataloader(self):
        # NOTE(review): no shuffle=True, so training batches come in stored
        # order every epoch — confirm this is intentional.
        return self._make_dataloader('train', self.train_batch_size)

    def val_dataloader(self, test=False):
        """Return ``[dev_loader, test_loader]``; with ``test=True`` return only the dev loader."""
        if test:
            return self._make_dataloader('dev', self.eval_batch_size)
        return [self._make_dataloader(split, self.eval_batch_size) for split in ('dev', 'test')]

    def test_dataloader(self):
        return self._make_dataloader('test', self.eval_batch_size)


class TextClassificationDataset(Dataset):
    """Map-style dataset over pre-tokenized examples, plus a batch collater.

    ``data`` is a dict of parallel per-example sequences. The keys read are
    ``item_idx``, ``example``, ``token_type``, ``label`` and ``target_seq``, and
    also ``aux_example`` / ``aux_token_type`` when ``aux_io_mode`` is not None.
    """

    # io_modes in which ``target_seq`` holds a single sequence per example and
    # ``label`` is passed through untouched. In every other mode ``target_seq``
    # holds one sequence per class (see the reshape in ``collater``).
    _SINGLE_TARGET_MODES = ('I-OR', 'I-RO')

    def __init__(self, data, split, classes, delimiters, io_mode, aux_io_mode, tokenizer):
        self.data, self.split = data, split
        self.classes, self.delimiters = classes, delimiters
        self.io_mode, self.aux_io_mode = io_mode, aux_io_mode
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data['item_idx'])

    def __getitem__(self, idx):
        """Return example ``idx`` as a dict of int64 tensors.

        The exception is ``label`` in single-target modes, which is returned as stored.
        """
        rows = self.data
        input_ids = torch.LongTensor(rows['example'][idx])
        sample = {
            'item_idx': torch.LongTensor([rows['item_idx'][idx]]),
            'example': input_ids,
            'example_attn_mask': torch.ones(len(input_ids), dtype=torch.long),
            'token_type': torch.LongTensor(rows['token_type'][idx]),
        }

        if self.io_mode in self._SINGLE_TARGET_MODES:
            sample['label'] = rows['label'][idx]
            sample['target_seq'] = torch.LongTensor(rows['target_seq'][idx])
        else:
            sample['label'] = torch.LongTensor([rows['label'][idx]])
            sample['target_seq'] = [torch.LongTensor(seq) for seq in rows['target_seq'][idx]]

        if self.aux_io_mode is not None:
            aux_ids = torch.LongTensor(rows['aux_example'][idx])
            sample['aux_example'] = aux_ids
            sample['aux_example_attn_mask'] = torch.ones(len(aux_ids), dtype=torch.long)
            sample['aux_token_type'] = torch.LongTensor(rows['aux_token_type'][idx])

        return sample

    def collater(self, items):
        """Right-pad and stack ``__getitem__`` outputs into a single batch dict.

        Token-id sequences are padded with ``tokenizer.pad_token_id``; masks and
        token types are padded with 0. Every position in ``target_seq`` equal to
        the pad id is then overwritten with -100.
        """
        pad_id = self.tokenizer.pad_token_id

        def stack(field, fill):
            return pad_sequence([it[field] for it in items], batch_first=True, padding_value=fill)

        if self.io_mode in self._SINGLE_TARGET_MODES:
            labels = [it['label'] for it in items]
            targets = stack('target_seq', pad_id)
        else:
            labels = torch.cat([it['label'] for it in items])
            flat_targets = list(chain.from_iterable(it['target_seq'] for it in items))
            targets = pad_sequence(flat_targets, batch_first=True, padding_value=pad_id)
            # (batch * num_classes, len) -> (batch, num_classes, len)
            targets = targets.reshape(len(items), len(self.classes), -1)
        targets[targets == pad_id] = -100

        batch = {
            'item_idx': torch.cat([it['item_idx'] for it in items]),
            'example': stack('example', pad_id),
            'example_attn_mask': stack('example_attn_mask', 0),
            'token_type': stack('token_type', 0),
            'label': labels,
            'target_seq': targets,
            # per the original note: when evaluate_ckpt=true, split is always test
            'split': self.split,
        }

        if self.aux_io_mode is not None:
            for field, fill in (('aux_example', pad_id), ('aux_example_attn_mask', 0), ('aux_token_type', 0)):
                batch[field] = stack(field, fill)

        return batch