
import copy
import logging
import os
import psutil
import random
from dataclasses import dataclass
from typing import List, Tuple

import numpy as np
import torch
from torch.utils.data import Dataset

import datasets
from datasets import load_dataset
from transformers import (
    PreTrainedTokenizer,
    BatchEncoding,
    DataCollatorWithPadding,
    default_data_collator,
)

from pecos.core import clib
from sup_con_xmc.arguments import TrainingDataArguments
from sup_con_xmc.trainer import TevatronTrainer


logger = logging.getLogger(__name__)


class TrainPreProcessor:
    """Callable that tokenizes the ``"input"`` field of a training example.

    Args:
        tokenizer: Callable invoked as ``tokenizer(text, padding=...,
            max_length=..., truncation=...)``.
        max_length: Length passed to the tokenizer as ``max_length``; the
            tokenizer is asked to pad to and truncate at this length.
    """

    def __init__(self, tokenizer, max_length=32):
        self.text_col = "input"
        self.max_length = max_length
        self.tokenizer = tokenizer

    def __call__(self, example):
        """Return the tokenizer output for ``example["input"]``."""
        text = example[self.text_col]
        tokenizer_kwargs = {
            "padding": "max_length",
            "max_length": self.max_length,
            "truncation": True,
        }
        return self.tokenizer(text, **tokenizer_kwargs)


class EncodePreProcessor:
    """Callable that tokenizes the ``"input"`` field and keeps the key column.

    Args:
        tokenizer: Callable invoked as ``tokenizer(text, padding=...,
            max_length=..., truncation=...)``; its result must support
            item assignment.
        key_col: Name of the example field copied unchanged into the result.
        max_length: Length passed to the tokenizer as ``max_length``.
    """

    def __init__(self, tokenizer, key_col, max_length=32):
        self.text_col = "input"
        self.key_col = key_col
        self.max_length = max_length
        self.tokenizer = tokenizer

    def __call__(self, example):
        """Return the tokenizer output for the example, plus its key column."""
        tokenizer_kwargs = {
            "padding": "max_length",
            "max_length": self.max_length,
            "truncation": True,
        }
        encoded = self.tokenizer(example[self.text_col], **tokenizer_kwargs)
        # Carry the identifier along so encoded rows can be matched back later.
        encoded[self.key_col] = example[self.key_col]
        return encoded


class TrainDataset(Dataset):
    """Dataset yielding a query with one positive and sampled negative labels.

    Each row of ``inp_dataset`` must contain the columns returned by
    :meth:`get_required_cols` (lists of label ids); every other column is
    passed through as the encoded query. ``lbl_dataset`` is indexed by label
    id and each of its rows must contain a ``"pos_trn_ids"`` column, which is
    dropped from the returned label encodings.

    Args:
        data_args: Object exposing ``max_label_per_query``,
            ``train_group_size``, ``positive_passage_no_shuffle`` and
            ``negative_passage_no_shuffle``.
        inp_dataset: Indexable, sized collection of query rows.
        lbl_dataset: Indexable, sized collection of label rows.
        trainer: Optional object exposing ``state.epoch`` and ``args.seed``.
            It may also be attached after construction by assigning to
            ``self.trainer``. When it is missing, sampling behaves as if the
            epoch and the seed were both 0.
    """

    def __init__(
        self,
        data_args: TrainingDataArguments,
        inp_dataset: datasets.Dataset,
        lbl_dataset: datasets.Dataset,
        trainer: TevatronTrainer = None,
    ):
        self.inp_dataset = inp_dataset
        self.lbl_dataset = lbl_dataset

        required_cols = self.get_required_cols()
        self.pos_labels_col = required_cols[0]
        self.neg_labels_col = required_cols[1]

        self.data_args = data_args
        # __getitem__ reads the epoch and seed from the trainer, so the
        # constructor argument has to be kept on the instance.
        self.trainer = trainer
        self.num_input = len(self.inp_dataset)
        self.num_label = len(self.lbl_dataset)

    @classmethod
    def get_required_cols(cls):
        """Return the names of the positive and negative label columns."""
        return ["pos_labels", "neg_labels"]

    def _sample(self, sample_list, num, max_num=-1, p=None):
        """Draw ``num`` integers with ``np.random.choice``.

        An empty ``sample_list`` falls back to sampling from
        ``range(max_num)``. Otherwise sampling is without replacement when
        ``sample_list`` holds at least ``num`` items and with replacement
        when it does not; ``p`` is forwarded as the sampling probabilities.
        """
        if len(sample_list) == 0:
            return list(map(int, np.random.choice(max_num, size=num)))
        replace = len(sample_list) < num
        return list(map(int, np.random.choice(sample_list, size=num, replace=replace, p=p)))

    def _get_pos_label_tensor(self, pos_labels):
        """Pack positive label ids into a fixed-length float tensor.

        The tensor has ``data_args.max_label_per_query`` entries. Leading
        slots hold the first label ids; the remaining slots all hold one
        value ``-np.random.random()`` drawn per call (presumably a padding
        marker that differs between queries - confirm with the consumer).
        """
        limit = self.data_args.max_label_per_query
        pos_labels_tensor = torch.zeros(limit) - np.random.random()
        # zip stops at `limit`, so surplus labels are silently dropped.
        for slot, label in zip(range(limit), pos_labels):
            pos_labels_tensor[slot] = label
        return pos_labels_tensor

    def _epoch_and_seed(self):
        """Return ``(epoch, seed)`` from the trainer, or ``(0, 0)`` without one."""
        trainer = self.trainer
        if trainer is None:
            return 0, 0
        # `state.epoch` may be None or fractional; treat None as epoch 0.
        return int(trainer.state.epoch or 0), trainer.args.seed

    def _load_label(self, label_id):
        """Return a private copy of a label row, tagged with its id."""
        label_group = copy.deepcopy(self.lbl_dataset[label_id])
        label_group.pop("pos_trn_ids")
        label_group["id"] = [label_id]
        return label_group

    def _pick_negatives(self, neg_labels, epoch, hashed_seed):
        """Choose ``train_group_size - 1`` negative label ids for one query."""
        neg_size = self.data_args.train_group_size - 1
        if len(neg_labels) == 0:
            # No negatives were provided: fall back to random label ids.
            return random.choices(range(len(self.lbl_dataset)), k=neg_size)
        if len(neg_labels) < neg_size:
            return random.choices(neg_labels, k=neg_size)
        if self.data_args.train_group_size == 1:
            return []
        if self.data_args.negative_passage_no_shuffle:
            return neg_labels[:neg_size]
        # Shuffle deterministically per example, then slide a window that
        # advances with the epoch; doubling the list lets the window wrap.
        offset = epoch * neg_size % len(neg_labels)
        negs = list(neg_labels)
        random.Random(hashed_seed).shuffle(negs)
        negs = negs * 2
        return negs[offset: offset + neg_size]

    def __len__(self):
        """Return the number of query rows."""
        return self.num_input

    def __getitem__(self, idx) -> Tuple[BatchEncoding, List[BatchEncoding], None]:
        """Return ``(encoded_query, encoded_labels, None)`` for row ``idx``.

        ``encoded_query`` is a copy of the input row without the label
        columns, plus a ``"pos_labels"`` tensor. ``encoded_labels`` holds the
        positive label first, followed by the negatives. The third element
        is always ``None``.
        """
        epoch, seed = self._epoch_and_seed()
        hashed_seed = hash(idx + seed)

        # Copy first so that popping columns never alters the source row.
        group = copy.deepcopy(self.inp_dataset[idx])
        pos_labels = group.pop(self.pos_labels_col)
        neg_labels = group.pop(self.neg_labels_col)
        # Whatever is left describes the input text (input_ids, etc.).
        group["pos_labels"] = self._get_pos_label_tensor(pos_labels)
        encoded_qry = group

        # One positive label per query; rotate through them across epochs
        # unless shuffling is disabled.
        if self.data_args.positive_passage_no_shuffle:
            pos_psg = pos_labels[0]
        else:
            pos_psg = pos_labels[(hashed_seed + epoch) % len(pos_labels)]
        encoded_lbl = [self._load_label(pos_psg)]

        for neg_psg in self._pick_negatives(neg_labels, epoch, hashed_seed):
            encoded_lbl.append(self._load_label(neg_psg))
        return encoded_qry, encoded_lbl, None
        

class EncodeDataset(Dataset):
    """Dataset that splits each row into its identifier and encoded features.

    Args:
        key_col: Name of the column holding the row identifier.
        dataset: Indexable, sized collection of mapping rows that each
            contain ``key_col``.
    """

    def __init__(self, key_col, dataset):
        self.key_col = key_col
        self.dataset = dataset

    def __len__(self):
        """Return the number of rows in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx) -> Tuple[str, BatchEncoding]:
        """Return ``(identifier, remaining_columns)`` for row ``idx``."""
        # Shallow-copy before popping: if the backing dataset hands out the
        # same row object on every access (e.g. a list of dicts), popping in
        # place would strip the key column and break the next lookup.
        encoded_text = copy.copy(self.dataset[idx])
        text_id = encoded_text.pop(self.key_col)
        return text_id, encoded_text


@dataclass
class TrainCollator(DataCollatorWithPadding):
    """
    Turn a batch of (query, labels, keys) samples into three collated batches.

    The per-sample triples are transposed into a query stream, a label stream
    and a key stream. Streams whose entries are lists are flattened, and each
    stream is then handed to ``default_data_collator`` on its own, so the model
    never sees how the dataset groups its samples.
    """

    def __call__(self, data):
        def flatten_if_nested(stream):
            # Per-sample lists are concatenated into one flat list; any
            # other stream is passed through untouched.
            if isinstance(stream[0], list):
                return [entry for chunk in stream for entry in chunk]
            return stream

        def collate_if_present(stream):
            # A falsy first entry (None, empty mapping) means "no batch".
            if not stream[0]:
                return None
            return default_data_collator(stream, return_tensors="pt")

        queries, labels, keys = (flatten_if_nested(part) for part in zip(*data))

        q_collated = default_data_collator(queries, return_tensors="pt")
        d_collated = collate_if_present(labels)
        k_collated = collate_if_present(keys)
        return q_collated, d_collated, k_collated


@dataclass
class EncodeCollator(DataCollatorWithPadding):
    """Collate (id, features) pairs into an id tensor and a feature batch."""

    def __call__(self, features):
        text_ids = []
        text_features = []
        for pair in features:
            text_ids.append(pair[0])
            text_features.append(pair[1])
        collated_features = default_data_collator(text_features, return_tensors="pt")
        return torch.tensor(text_ids), collated_features
