import os
import json
import torch
import random
import math
import numpy as np
from pathlib import Path
from copy import deepcopy
from dataclasses import dataclass, field
from itertools import islice, count
from typing import Iterable

from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset, IterableDataset, ConcatDataset
from typing import Any, Tuple, Dict, Optional, List, Union, Type

import hexa.tasks.constants as CONST
from hexa.utils.message import Message
from hexa.utils.document import Document
from hexa.utils.torch import padded_tensor
from hexa.utils import woi_constants as consts


# Path reached by taking `.parent` three times from this file's own path.
# Presumably the repository/project root -- TODO confirm against the layout.
ROOT_DIR = Path(__file__).parent.parent.parent


class MultiTaskCollator(object):
    """Collate a list of per-example dicts into one batch dict.

    Every example must provide the keys read in ``__call__``: ``id``,
    ``text``, ``label``, ``labels``, ``text_vec``, ``label_vec``,
    ``skip_retrieval_vec`` and ``original_doc_len``.  ``docs_text_vec`` is
    read for examples whose ``skip_retrieval_vec`` is false,
    ``episode_done`` when ``return_episode_done`` is set, and ``data`` is
    forwarded when the first example carries it.

    Args:
        pad_token_id: Value used for every kind of padding.
        text_truncate: Width that text-only rows are padded/cut to when a
            batch mixes examples with and without documents.
        n_docs: Upper bound on the number of document rows kept per example.
        add_to_device: When true, tensors in the result are moved to
            ``device`` (except ``skip_retrieval_vec``).
        device: Target device used when ``add_to_device`` is true.
        return_episode_done: When true, the batch also carries the
            per-example ``episode_done`` flags.
    """

    def __init__(
        self,
        pad_token_id: int = 0,
        text_truncate: int = 1024,
        n_docs: int = 5,
        add_to_device: bool = False,
        device=None,
        return_episode_done: bool = False,
    ):
        self.pad_token_id = pad_token_id
        self.text_truncate = text_truncate
        self.n_docs = n_docs
        self.add_to_device = add_to_device
        self.device = device
        self.return_episode_done = return_episode_done

    def _pad(self, vec):
        """Cut or right-pad a 1-D tensor to exactly ``text_truncate`` items."""
        assert len(vec.shape) == 1
        maxlen = self.text_truncate or 1024
        seqlen = vec.shape[0]
        if seqlen >= maxlen:
            return vec[:maxlen]
        # new_full keeps the dtype and device of ``vec``.
        padding = vec.new_full((maxlen - seqlen,), self.pad_token_id)
        return torch.cat((vec, padding))

    def _pad_sequence(self, batch, key, trunc=0, pad_like_parlai=True):
        """Stack ``example[key]`` over the batch, padding to a common length.

        When ``trunc`` is positive only the first ``trunc`` rows of each
        (2-D) item are kept.  A 2-D result is additionally widened to a
        multiple of 8 columns when ``pad_like_parlai`` is true.
        """
        if trunc > 0:
            items = [example[key][:trunc, :] for example in batch]
        else:
            items = [example[key] for example in batch]

        vec = pad_sequence(
            items,
            batch_first=True,
            padding_value=self.pad_token_id,
        )

        if pad_like_parlai and len(vec.size()) == 2:
            vec = self._pad_like_parlai(vec)

        return vec

    def _pad_like_parlai(self, vec):
        """Right-pad a 2-D tensor so its width is a multiple of 8.

        The filler is allocated with ``vec.new_full`` so it lives on the same
        device as ``vec`` and is never routed through a float tensor.  As
        before, a padded result is returned as ``long`` and an already
        aligned input is returned untouched.
        """
        batch_size, seq_len = vec.size()
        remainder = seq_len % 8
        if remainder:
            filler = vec.new_full((batch_size, 8 - remainder), self.pad_token_id)
            vec = torch.cat([vec, filler], dim=1).long()
        return vec

    def _concat_sequence(self, batch, key):
        """Concatenate ``example[key]`` tensors along their first dimension."""
        return torch.cat([example[key] for example in batch])

    def _set_device(self, item_dict: Dict):
        """Move every tensor value except ``skip_retrieval_vec`` to ``self.device``."""
        for key, val in item_dict.items():
            if isinstance(val, torch.Tensor) and key != 'skip_retrieval_vec':
                item_dict[key] = val.to(self.device)
        return item_dict

    def __call__(self, batch: List[Dict[str, Any]]):
        """Build the batch dict consumed by the model.

        Three layouts of ``input_ids`` are produced depending on the
        ``skip_retrieval_vec`` flags of the examples:

        * all true  -> padded ``text_vec`` rows, one per example;
        * all false -> ``docs_text_vec`` rows flattened to 2-D;
        * mixed     -> per-example rows concatenated, with ``is_doc_added``
          marking which rows came from ``docs_text_vec``.

        Returns:
            Dict with ``task_ids``, ``texts``, ``labels``, ``all_labels``,
            ``input_ids``, ``attention_mask``, ``lm_labels``,
            ``skip_retrieval_vec``, ``is_doc_added`` (``None`` unless the
            batch is mixed) and ``max_n_docs``; optionally ``episode_done``
            and ``data``.
        """
        lm_labels = self._pad_sequence(batch, 'label_vec')
        skip_retrieval_vec = self._concat_sequence(batch, 'skip_retrieval_vec')
        original_doc_len = [b['original_doc_len'] for b in batch]
        max_n_docs = min(self.n_docs, max(original_doc_len))
        is_doc_added = None

        if torch.all(skip_retrieval_vec):
            # no need for docs
            input_ids = self._pad_sequence(batch, 'text_vec')
        elif torch.all(~skip_retrieval_vec):
            # need to concat docs with text
            # NOTE(review): max_n_docs == 0 means "no truncation" here
            # (trunc=0) but "zero rows" in the mixed branch below.
            input_ids = self._pad_sequence(batch, 'docs_text_vec', max_n_docs)
            if len(input_ids.shape) == 3:
                input_ids = input_ids.view(-1, input_ids.shape[-1])
        else:
            # combination of with/without docs
            rows = []
            doc_flags = []
            for b in batch:
                if b['skip_retrieval_vec']:
                    # assumes docs_text_vec rows are `text_truncate` wide so
                    # both kinds of rows can be concatenated -- TODO confirm
                    text_vec = self._pad(b['text_vec']).unsqueeze(0)
                    doc_added = False
                else:
                    text_vec = b['docs_text_vec'][:max_n_docs, :]
                    doc_added = True
                rows.append(text_vec)
                doc_flags.extend([doc_added] * text_vec.shape[0])

            input_ids = torch.cat(rows)
            is_doc_added = torch.tensor(doc_flags)

        attention_mask = input_ids.ne(self.pad_token_id)

        ret = {
            'task_ids': [item['id'] for item in batch],
            'texts': [item['text'] for item in batch],
            'labels': [item['label'] for item in batch],
            'all_labels': [item['labels'] for item in batch],
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'lm_labels': lm_labels,
            'skip_retrieval_vec': skip_retrieval_vec,
            'is_doc_added': is_doc_added,
            'max_n_docs': max_n_docs,
        }
        if self.return_episode_done:
            ret['episode_done'] = [item['episode_done'] for item in batch]
            # FIXME: temporarily disabled
            # ret['episode_id'] = [item['episode_id'] for item in batch]

        if self.add_to_device:
            ret = self._set_device(ret)

        if 'data' in batch[0]:
            # it must be debug mode
            ret['data'] = [item['data'] for item in batch]

        return ret


class InhouseDataset(Dataset):
    """Base dataset that turns one dialogue example into model-ready tensors.

    The class itself never assigns ``self.data``; ``__getitem__`` reads it, so
    a subclass (or the caller) has to provide it before items are requested.

    Args:
        opt: Option object read via attribute access (``truncate``,
            ``label_truncate``, ``min_doc_token_length``,
            ``n_extra_positions``, ``fp16``, ``n_docs``).  A deep copy is kept.
        tokenizer: Object exposing ``encode`` and the ``bos_token_id`` /
            ``eos_token_id`` / ``pad_token_id`` / ``pad_token`` attributes.
        debug: Stored on the instance; not otherwise used in this class.
    """

    def __init__(self, opt, tokenizer, debug=False):
        self.opt = deepcopy(opt)
        self.tokenizer = tokenizer
        self.start_idx = self.tokenizer.bos_token_id
        self.end_idx = self.tokenizer.eos_token_id
        self.pad_idx = self.tokenizer.pad_token_id
        self.text_truncate = self.opt.truncate or 1024
        self.label_truncate = self.opt.label_truncate or 128
        self.expanded_input_truncate = self.opt.truncate or 1024
        self.min_doc_token_length = self.opt.min_doc_token_length or 64
        self.n_extra_positions = self.opt.n_extra_positions or 0
        self.fp16 = self.opt.fp16 or False
        self.generation_model = 'bart'
        self.debug = debug
        self.is_train = True
        self.offset = 3  # one for the start token, two for the end token

    def _check_dimension(self, vec):
        """Drop a leading dimension (via ``squeeze(0)``) from non-1-D tensors."""
        if len(vec.shape) > 1:
            vec = vec.squeeze(0)
        return vec

    def _check_truncate(self, vec, truncate, truncate_left=False):
        """Return ``vec`` cut down to at most ``truncate`` items.

        Args:
            vec: List or tensor to shorten.
            truncate: Maximum length; ``None`` disables truncation and a
                negative value is treated as 0.
            truncate_left: Keep the tail instead of the head.
        """
        if truncate is None:
            return vec
        truncate = max(truncate, 0)
        if len(vec) <= truncate:
            return vec
        if truncate_left:
            # Not ``vec[-truncate:]``: for truncate == 0 that slice would be
            # ``vec[0:]`` and return everything instead of nothing.
            return vec[len(vec) - truncate:]
        return vec[:truncate]

    def _add_start_end_tokens(self, vec, add_start=False, add_end=False):
        """
        Add start and end tokens to a list or tensor.

        A tensor input yields a new tensor; a list input is modified in
        place and returned.

        Raises:
            Exception: if a tensor input is not 1-D.
        """
        if isinstance(vec, torch.Tensor):
            if len(vec.shape) != 1:
                raise Exception('_add_start_end_tokens expects a 1D tensor')
            tensors = [vec]
            if add_start:
                tensors.insert(0, vec.new_tensor([self.start_idx]))
            if add_end:
                tensors.append(vec.new_tensor([self.end_idx]))
            return torch.cat(tensors, 0)
        if add_start:
            vec.insert(0, self.start_idx)
        if add_end:
            vec.append(self.end_idx)
        return vec

    def text2vec(self, text: str):
        """Encode ``text`` into a ``(1, length)`` LongTensor.

        The encoding is left-truncated to ``text_truncate - offset`` tokens,
        then wrapped with one start token and two end tokens.
        """
        text_vec = torch.LongTensor(self.tokenizer.encode(text))
        text_vec = self._check_truncate(
            text_vec, self.text_truncate - self.offset, truncate_left=True
        )
        text_vec = self._add_start_end_tokens(text_vec, add_start=True, add_end=True)
        # second end token
        text_vec = self._add_start_end_tokens(text_vec, add_start=False, add_end=True)
        return text_vec.unsqueeze(0)

    def label2vec(self, label: str):
        """Encode ``label``, append an end token and cut to ``label_truncate``."""
        label_vec = torch.LongTensor(self.tokenizer.encode(label))
        label_vec = self._add_start_end_tokens(label_vec, add_start=False, add_end=True)
        label_vec = self._check_truncate(label_vec, self.label_truncate)
        return self._check_dimension(label_vec)

    def _postprocess_docs(self, retrieved_docs):
        """Return exactly ``opt.n_docs`` documents.

        Longer lists are cut; shorter lists are filled with one shared empty
        ``Document``.  The argument is no longer modified when filling.
        """
        n_docs = self.opt.n_docs
        if len(retrieved_docs) == n_docs:
            return retrieved_docs
        if len(retrieved_docs) > n_docs:
            return retrieved_docs[:n_docs]
        null_doc = Document('', '', '')
        return list(retrieved_docs) + [null_doc] * (n_docs - len(retrieved_docs))

    def _extract_doc_from_message(self, message: "Message", idx: int):
        """
        Returns the `idx`-th `__retrieved-docs__` in the `message` as a Document object.
        """
        return Document(
            docid=message[consts.RETRIEVED_DOCS_URLS][idx],
            title=message[consts.RETRIEVED_DOCS_TITLES][idx],
            text=message[consts.RETRIEVED_DOCS][idx],
        )

    def retrieved_doc_scores(self, retrieved_docs):
        """Return rank-based scores ``1 / (1 + i)`` for the longest doc list."""
        max_num_docs = max(len(rtds) for rtds in retrieved_docs)
        return torch.Tensor([1 / (1 + i) for i in range(max_num_docs)])

    def get_retrieved_knowledge(self, message: "Message"):
        """Pick at most ``opt.n_docs`` documents from ``message``.

        Documents containing one of the selected sentences come first (in
        message order); remaining slots are filled with the other documents.
        An empty list is returned when the message has no retrieved docs or
        its selected sentences equal ``NO_SELECTED_SENTENCES_TOKEN``.
        """
        retrieved_docs = []
        n_docs = self.opt.n_docs
        if not message.get(consts.RETRIEVED_DOCS):
            return retrieved_docs

        selected_sentences = message[consts.SELECTED_SENTENCES]
        if ' '.join(selected_sentences) == consts.NO_SELECTED_SENTENCES_TOKEN:
            return retrieved_docs  # `retrieved_docs` is empty at this point

        docs_in_message = message[consts.RETRIEVED_DOCS]
        already_added_doc_idx = set()

        # First adding the docs with selected sentences.  The cap is checked
        # with ">=" before every addition: the previous
        # "== n_docs and doc_idx != n_docs - 1" test let the list grow past
        # n_docs when the first n_docs documents all matched.
        for doc_idx, doc_content in enumerate(docs_in_message):
            if len(retrieved_docs) >= n_docs:
                break
            if any(sentence in doc_content for sentence in selected_sentences):
                retrieved_docs.append(
                    self._extract_doc_from_message(message, doc_idx)
                )
                already_added_doc_idx.add(doc_idx)

        # Then fill the remaining slots with documents not added yet.
        for doc_idx in range(len(docs_in_message)):
            if len(retrieved_docs) >= n_docs:
                break
            if doc_idx in already_added_doc_idx:
                continue
            retrieved_docs.append(self._extract_doc_from_message(message, doc_idx))

        # random.shuffle(retrieved_docs)
        return retrieved_docs

    def concat_docs_and_input(
        self,
        input: torch.LongTensor,
        input_lengths: torch.LongTensor,
        top_docs: "List[List[Document]]",
        max_num_docs: int,
        right_padded: bool = True,
    ) -> torch.LongTensor:
        """Build one token row per (input, document) pair.

        Args:
            input: 2-D token tensor indexed as ``input[i, :]`` per example.
            input_lengths: Per-example lengths of ``input``.
            top_docs: For each example, the documents to combine with it.
            max_num_docs: Rows emitted per example; examples with fewer docs
                get all-pad rows appended.
            right_padded: Whether ``input`` carries its padding on the right.

        Returns:
            The rows collected by ``padded_tensor``, moved to ``input.device``.
        """
        max_len = self.expanded_input_truncate
        expanded_input = []
        for i, docs in enumerate(top_docs):
            for doc in docs:
                input_i = input[i, :]

                if doc._text == '':
                    # Null document: a single pad token stands in for the row.
                    expanded_input.append(
                        torch.LongTensor(
                            [self.pad_idx],
                            # [self.start_idx, self.end_idx, self.end_idx] ## Note add doub end_idx
                        ).to(input)
                    )
                    continue

                # Tokenize only real documents; null docs never use the tokens.
                doc_tokens = self.tokenizer.encode(doc.get_passage_str())

                if self.generation_model == 'bart' and self.n_extra_positions <= 0:
                    # move SOS to start of passage since we append question to end
                    input_i = input_i[1:]
                    sample_doc_tokens = torch.LongTensor(
                        [self.start_idx] + doc_tokens
                    ).to(input)
                else:
                    sample_doc_tokens = torch.LongTensor(doc_tokens).to(input)

                if self.n_extra_positions <= 0:
                    # Prepend document to text
                    input_i_len = input_lengths[i]
                    new_input_length = min(
                        self.expanded_input_truncate - self.min_doc_token_length,
                        input_i_len,
                    )
                    if right_padded:
                        input_i = input_i[input_i_len - new_input_length : input_i_len]
                    else:
                        input_i = input_i[input_i.size(0) - new_input_length :]

                    doc_max_len = max(max_len - len(input_i), 0)
                    sample_doc_tokens = sample_doc_tokens[:doc_max_len]
                    expanded_input.append(
                        torch.cat([sample_doc_tokens, input_i])[:max_len]
                    )
                else:
                    # Append Document to text
                    # NOTE(review): `self.n_positions` is not assigned anywhere
                    # in this class; this branch needs it to be set elsewhere.
                    sample_doc_tokens = sample_doc_tokens[:max_len]
                    input_i_new = input_i.new(
                        self.n_positions - self.n_extra_positions
                    ).fill_(self.pad_idx)
                    input_i_new[input_i_new.size(0) - input_i.size(0) :] = input_i
                    expanded_input.append(torch.cat([input_i_new, sample_doc_tokens]))
            # append extra null inputs if there are diff # of docs per input
            expanded_input += [
                input[i, :].new(input[i, :].size()).fill_(self.pad_idx)
            ] * (max_num_docs - len(docs))
        expanded_input, _ = padded_tensor(
            expanded_input,
            fp16friendly=self.fp16 and right_padded,
            max_len=max_len if self.n_extra_positions <= 0 else None,
            pad_idx=self.pad_idx,
            left_padded=not right_padded,
        )
        expanded_input = expanded_input.to(input.device)
        return expanded_input  # type: ignore

    def __getitem__(self, episode_idx):
        """Fetch and vectorize one example.

        In training mode the next entry of episode ``episode_idx`` is pulled
        through ``self.data.next_example``; otherwise ``self.data`` is indexed
        directly and the example counts as a finished episode.

        Returns:
            Dict with ``id``, ``episode_done``, ``text``, ``text_vec``,
            ``label``, ``label_vec``, ``labels``, ``top_doc_scores``,
            ``docs_text_vec``, ``skip_retrieval_vec``, ``top_doc`` and
            ``original_doc_len``.

        Raises:
            KeyError: if the example has neither ``labels`` nor
                ``eval_labels``.
        """
        if self.is_train:
            data, done = self.data.next_example(episode_idx)
        else:
            data = self.data[episode_idx]
            done = True

        task_id = f"{data['id']}_session{data['session_id']}" if 'session_id' in data else data['id']
        ret = {'id': task_id, 'episode_done': done}

        # vectorize text
        ret['text'] = data['text']
        text_vec = self.text2vec(data['text'])
        ret['text_vec'] = self._check_dimension(text_vec)

        # vectorize label; evaluation examples may only carry 'eval_labels'
        if 'labels' in data:
            all_labels = data['labels']
        elif 'eval_labels' in data:
            all_labels = data['eval_labels']
        else:
            raise KeyError("example has neither 'labels' nor 'eval_labels'")
        selected_label = random.choice(all_labels)

        ret['label'] = selected_label
        ret['label_vec'] = self.label2vec(selected_label)
        ret['labels'] = all_labels

        # retrieve docs if required
        skip_retrieval_vec = data.get('skip_retrieval') is True

        if not skip_retrieval_vec:
            retrieved_docs = self.get_retrieved_knowledge(data)
            original_doc_len = len(retrieved_docs)
            retrieved_docs = self._postprocess_docs(retrieved_docs)
            # random.shuffle(retrieved_docs)

            retrieved_docs = [retrieved_docs]
            top_doc_scores = self.retrieved_doc_scores(retrieved_docs)[None]

            expanded_input = self.concat_docs_and_input(
                text_vec,
                text_vec.ne(self.pad_idx).sum(1),
                retrieved_docs,
                top_doc_scores.size(1),
            )
            expanded_input = self._check_dimension(expanded_input)
        else:
            original_doc_len = 0
            top_doc_scores = None
            expanded_input = None
            # Placeholder document built from the tokenizer's pad token.
            retrieved_docs = [
                Document(
                    docid=self.tokenizer.pad_token,
                    title=self.tokenizer.pad_token,
                    text=self.tokenizer.pad_token,
                ),
            ]

        ret['top_doc_scores'] = top_doc_scores
        ret['docs_text_vec'] = expanded_input
        ret['skip_retrieval_vec'] = torch.BoolTensor([skip_retrieval_vec])
        ret['top_doc'] = retrieved_docs[0]
        ret['original_doc_len'] = original_doc_len

        return ret

    
@dataclass
class EpisodicData:
    """Episodes (lists of entries) plus a per-episode read cursor.

    ``data`` and ``meta_data`` are declared as real dataclass fields, so the
    generated ``__eq__`` compares their contents.  They were previously plain
    class-level assignments, which left the dataclass with no fields and made
    every instance compare equal.  ``repr=False`` keeps the generated repr
    from dumping the whole dataset.
    """

    # One inner list per episode; entries are mutated by `next_example`.
    data: List[List["Message"]] = field(repr=False)
    # One dict per episode: {'len': number of entries, 'entry_idx': cursor}.
    meta_data: List[Dict[Any, Any]] = field(repr=False)

    def __init__(self, data: List[List["Message"]]):
        self.data = data
        self.meta_data = [{'len': len(entries), 'entry_idx': 0} for entries in data]
        self.num_episodes = len(data)
        self.num_entries = sum(meta['len'] for meta in self.meta_data)

    def num_entries_by(self, ids):
        """Return the total number of entries of the episodes listed in ``ids``."""
        return sum(self.meta_data[idx]['len'] for idx in ids)

    def __len__(self):
        """Number of episodes."""
        return len(self.data)

    def next_example(self, episode_idx):
        """Return the next entry of an episode and advance its cursor.

        The entry is annotated in place with ``entry_id`` and
        ``episode_len``.  After the last entry the cursor wraps to 0.

        Returns:
            Tuple ``(entry, episode_done)`` where ``episode_done`` is true
            for the last entry of the episode.
        """
        meta_data = self.meta_data[episode_idx]
        entry_idx, num_entry = meta_data['entry_idx'], meta_data['len']
        ex = self.data[episode_idx][entry_idx]
        ex['entry_id'] = entry_idx
        ex['episode_len'] = num_entry
        episode_done = entry_idx == num_entry - 1
        meta_data['entry_idx'] = 0 if episode_done else entry_idx + 1
        return ex, episode_done

    
class EpisodicDataset(InhouseDataset):
    """``InhouseDataset`` whose examples are grouped into episodes.

    Args:
        data: Indexable collection with a length; each item is either an
            episode (a list of entries) or a single entry, which is wrapped
            into a one-entry episode.
        opt: Forwarded to ``InhouseDataset``.
        tokenizer: Forwarded to ``InhouseDataset``.
        max_num: If given, only the first ``max_num`` episodes are kept.
    """

    def __init__(self, data, opt, tokenizer, max_num=None):
        super().__init__(opt, tokenizer)

        # data should be list of list of dict
        # Items are fetched by index so any object offering __len__ and
        # __getitem__ works.  isinstance (rather than an exact type check)
        # lets list subclasses count as episodes too.
        episodes = []
        for i in range(len(data)):
            item = data[i]
            episodes.append(item if isinstance(item, list) else [item])
        if max_num is not None:
            episodes = episodes[:max_num]
        self.data = EpisodicData(episodes)

    def __len__(self):
        """Total number of entries over all episodes."""
        return self.data.num_entries

    @property
    def num_episodes(self):
        """Number of episodes held by the underlying ``EpisodicData``."""
        return self.data.num_episodes

    @property
    def num_entries(self):
        """Total number of entries over all episodes."""
        return self.data.num_entries

    def num_entries_by(self, ids):
        """Number of entries contained in the episodes listed in ``ids``."""
        return self.data.num_entries_by(ids)


class WeightedConcatDataset(ConcatDataset):
    """Concatenation of datasets sampled by weight and sharded across ranks.

    For every dataset a (optionally shuffled) list of episode indices is
    built, padded or cut so it divides evenly by ``num_replicas``, and the
    slice belonging to ``rank`` is kept.  ``__getitem__`` picks a dataset at
    random according to ``weights`` and serves one of its remaining episodes.

    Args:
        datasets: Map-style datasets; any iterable (including a generator).
        weights: One sampling weight per dataset.
        batch_size: Incoming indices are reduced modulo this value before
            they select an episode.
        num_replicas: Number of shards the episodes are split into.
        rank: Index of the shard served by this instance.
        shuffle: Shuffle episode order with a generator seeded by
            ``seed + epoch``.
        seed: Base seed for shuffling.
        drop_last: Cut the tail instead of repeating episodes when the
            episode count is not divisible by ``num_replicas``.
        is_valid: Stored on the instance; not otherwise used in this class.
        verbose: Print the first episode ids of each shard on reset.

    Raises:
        ValueError: if ``rank`` is outside ``[0, num_replicas - 1]``.
    """

    def __init__(self, datasets: Iterable[EpisodicDataset], weights: List[int],
                 batch_size: int = 1,
                 num_replicas: int = 1,
                 rank: int = 0, shuffle: bool = True,
                 seed: int = 0, drop_last: bool = False, is_valid: bool = False, verbose: bool = False) -> None:
        # Materialise once: a generator would be exhausted by the parent
        # constructor and leave nothing for this class to work with.
        datasets = list(datasets)
        super().__init__(datasets)

        if rank >= num_replicas or rank < 0:
            raise ValueError(
                "Invalid rank {}, rank should be in the interval"
                " [0, {}]".format(rank, num_replicas - 1))

        print(f'rank: {rank} of {num_replicas}')
        self.batch_size = batch_size
        self.num_replicas = num_replicas
        self.rank = rank
        self.drop_last = drop_last
        self.seed = seed
        self.is_valid = is_valid
        self.verbose = verbose

        self.datasets = datasets
        self._weights = deepcopy(weights)
        self.shuffle = shuffle
        assert len(self.datasets) > 0, 'datasets should not be an empty iterable'
        for d in self.datasets:
            assert not isinstance(d, IterableDataset), "ConcatDataset does not support IterableDataset"
        assert len(self.datasets) == len(weights)

        # Episodes per replica for every dataset.
        self.lengths = []
        for dataset in self.datasets:
            length = self._count_episodes(dataset)
            if self.drop_last and length % self.num_replicas != 0:
                per_replica = math.ceil(
                    (length - self.num_replicas) / self.num_replicas
                )
            else:
                per_replica = math.ceil(length / self.num_replicas)
            self.lengths.append(per_replica)

        self.total_sizes = [length * self.num_replicas for length in self.lengths]
        self.dataset_ids = list(range(len(self.datasets)))
        self.sample_ids = [[] for _ in self.dataset_ids]
        self.num_episodes = sum(self.lengths)

        self.epochs = None
        self.weights = None
        self.entry_lengths = None
        self.num_entries = None

        self.init()

    @staticmethod
    def _count_episodes(dataset):
        """Number of episodes of ``dataset`` (its length if not episodic)."""
        if isinstance(dataset, EpisodicDataset):
            return dataset.num_episodes
        return len(dataset)

    def init(self):
        """Restart sampling: reset epochs, weights and every dataset's shard."""
        # this should be called after validation is finished
        self.epochs = [-1 for _ in self.dataset_ids]
        self.dataset_ids = list(range(len(self.datasets)))
        self.weights = deepcopy(self._weights)
        for dataset_idx in self.dataset_ids:
            self.reset(dataset_idx)

        # In episodic dataset, number of entry is defined after the sample(episode) ids of this class are defined.
        self.entry_lengths = []
        for idx, dataset in enumerate(self.datasets):
            if isinstance(dataset, EpisodicDataset):
                num_entry = dataset.num_entries_by(self.sample_ids[idx])
            else:
                num_entry = len(dataset)
            self.entry_lengths.append(num_entry)
        self.num_entries = sum(self.entry_lengths)

    def reset(self, dataset_idx):
        """Start a new epoch for one dataset and rebuild this rank's episode ids."""
        self.epochs[dataset_idx] += 1
        num_episodes = self._count_episodes(self.datasets[dataset_idx])

        if self.shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.seed + self.epochs[dataset_idx])
            indices = torch.randperm(num_episodes, generator=g).tolist()
        else:
            indices = list(range(num_episodes))

        total_size = self.total_sizes[dataset_idx]
        if not self.drop_last:
            # add extra samples to make it evenly divisible
            padding_size = total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
        else:
            # remove tail of data to make it evenly divisible.
            indices = indices[:total_size]
        assert len(indices) == total_size

        # Every num_replicas-th index, starting at this rank.
        indices = indices[self.rank:total_size:self.num_replicas]
        if self.verbose:
            print(f'dataset: {dataset_idx}, RANK:{self.rank}, ids:{indices[:10]}')

        self.sample_ids[dataset_idx] = indices

    def __len__(self):
        """Number of entries counted by the last call to ``init``."""
        return self.num_entries

    def sample_valid_episode_idx(self, idx, dataset_idx):
        """Map ``idx`` to a valid position in the dataset's remaining episodes.

        ``idx`` is reduced modulo ``batch_size``; when fewer episodes remain
        than that position requires, a random remaining position is used.
        """
        idx = idx % self.batch_size
        num_remains_episode = len(self.sample_ids[dataset_idx])
        if num_remains_episode - 1 < idx:
            idx = random.choice(range(num_remains_episode))
        return idx

    def __getitem__(self, idx):
        """Return one example from a dataset drawn according to the weights.

        An episode id is removed from the pool once the returned example
        reports ``episode_done``; an exhausted pool is rebuilt on next use.
        """
        # sample dataset_idx and sample_idx
        dataset_idx = random.choices(self.dataset_ids, weights=self.weights, k=1)[0]

        if len(self.sample_ids[dataset_idx]) == 0:
            self.reset(dataset_idx)
        idx = self.sample_valid_episode_idx(idx, dataset_idx)
        episode_idx = self.sample_ids[dataset_idx][idx]
        data = self.datasets[dataset_idx][episode_idx]
        if data['episode_done']:
            self.sample_ids[dataset_idx].pop(idx)

        return data


class SelfLearnSequentialDataset(WeightedConcatDataset):
    """Weighted concat dataset that serves one episode at a time.

    Once a dataset has been drawn, every following item comes from that
    dataset until an example reports ``episode_done``.

    Args:
        datasets: Raw episode collections; each is wrapped in an
            ``EpisodicDataset`` built with ``opt.dataset``.
        opt: Options object.  NOTE: ``opt.dataset.fp16`` is set to ``True``
            on the object passed in.
        tokenizer: Forwarded to every ``EpisodicDataset``.
        do_eval: When true, a dataset is dropped from sampling once all of
            its episodes were served; otherwise its episode pool is rebuilt.
        **kwargs: Forwarded to ``WeightedConcatDataset``.
    """

    def __init__(self, datasets, opt, tokenizer, do_eval=False, **kwargs):
        opt.dataset.fp16 = True
        datasets = [EpisodicDataset(dataset, opt.dataset, tokenizer) for dataset in datasets]
        weights = [1] * len(datasets)
        super().__init__(datasets, weights, **kwargs)
        self.current_episode_dataset_id = None
        self.do_eval = do_eval

    def init(self):
        """Restart sampling; the length is taken from ``cumulative_sizes``."""
        # this should be called after validation is finished
        self.epochs = [-1 for _ in self.dataset_ids]
        self.dataset_ids = list(range(len(self.datasets)))
        self.weights = self._weights.copy()
        for dataset_idx in self.dataset_ids:
            self.reset(dataset_idx)

        self.num_entries = self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        """Return the next example of the episode currently being served.

        The returned dict additionally carries ``episode_id``.
        """
        if self.current_episode_dataset_id is None:
            self.current_episode_dataset_id = random.choices(
                self.dataset_ids, weights=self.weights, k=1
            )[0]
        dataset_idx = self.current_episode_dataset_id

        idx = self.sample_valid_episode_idx(idx, dataset_idx)
        episode_idx = self.sample_ids[dataset_idx][idx]
        data = self.datasets[dataset_idx][episode_idx]
        data['episode_id'] = episode_idx
        if data['episode_done']:
            self.sample_ids[dataset_idx].pop(idx)
            self.current_episode_dataset_id = None

        if len(self.sample_ids[dataset_idx]) == 0:
            if self.do_eval:
                # `dataset_idx` is a dataset id, not a list position: once an
                # earlier dataset has been removed the two no longer agree,
                # so look the position up before removing id and weight.
                position = self.dataset_ids.index(dataset_idx)
                self.dataset_ids.pop(position)
                self.weights.pop(position)
                self.current_episode_dataset_id = None
            else:
                self.reset(dataset_idx)

        return data
