import copy
import os
import time
import glob
import json
import random
import pickle
import inspect
import hashlib
from dataclasses import dataclass

import transformers
from tqdm import tqdm
from typing import Callable, Sequence, Dict
import multiprocessing as mp

import json5
import torch
from torch.nn.utils.rnn import pad_sequence
import torch.distributed as dist
from torch.utils.data import Dataset
from transformers import (
    AutoTokenizer,
    PreTrainedTokenizer
)

import recordio


class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning.

    Examples referenced by ``data_file`` are tokenized once with ``processor``
    and cached in ``<cache_dir>/tokenized.rec``.  For every epoch,
    :meth:`update` packs the tokenized examples into groups whose total length
    fits ``max_length`` and caches the groups in ``packed-<epoch>.rec``.  One
    dataset item is one packed group rendered as fixed-size tensors.
    """
    def __init__(self,
                 data_file: str,
                 processor: Callable,             # tokenize process
                 max_length: int=4096,            # max len
                 switch_rate: float=0.0,          # prob to concat
                 cache_dir: str='cached',         # cache dir
                 overwrite: bool=False,           # overwrite cache
                 num_workers: int=mp.cpu_count(),  # multi process
                 *args, **kwargs
    ):
        """Store the configuration and load (or build) the epoch-0 packing.

        Args:
            data_file: path of the json5 config; it must contain
                ``train_files`` with ``root`` and ``sample_rate`` entries.
            processor: callable mapping one raw example to a dict with
                ``input_ids`` and ``labels``.
            max_length: length every returned sequence is truncated/padded to.
            switch_rate: probability, in ``[0, 1]``, that a sample is merged
                into the previous sample's attention group in ``__getitem__``.
            cache_dir: parent directory of the hashed cache directory.
            overwrite: rebuild the packed cache even when it already exists.
            num_workers: size of the multiprocessing pools.
            *args, **kwargs: accepted and ignored.
        """
        self.data_file = data_file
        self.processor = processor
        self.max_length = max_length
        self.switch_rate = switch_rate
        assert 0. <= self.switch_rate <= 1.
        self.overwrite = overwrite
        self.cache_dir = self.get_cache_dir(cache_dir)
        self.num_workers = num_workers
        self.update(0)

    def __getitem__(self, idx):
        """Return the padded tensors of packed group ``idx``.

        Returns:
            dict with ``input_ids``, ``labels`` and ``position_ids``
            (LongTensor of length ``max_length``) and ``attention_mask``
            (bool tensor of shape ``(max_length, max_length)``) that is
            lower-triangular inside every attention group and False elsewhere.
        """
        # The `tokenized` property re-opens the record file on every access,
        # so fetch it once per item instead of once per packed sample.
        tokenized = self.tokenized

        # Split the packed samples into attention groups: with probability
        # `switch_rate` a sample joins the group of the sample before it.
        groups = []
        for i in self.packed[idx]:
            if groups and random.random() < self.switch_rate:
                groups[-1].append(i)
            else:
                groups.append([i])

        input_ids, labels, position_ids = [], [], []
        attention_mask = torch.zeros((self.max_length, self.max_length), dtype=torch.bool)
        for group in groups:
            start = len(input_ids)
            for i in group:
                data = tokenized[i]
                input_ids.extend(data['input_ids'])
                labels.extend(data['labels'])
            end = len(input_ids)
            # Positions restart at 0 for every group.
            position_ids.extend(list(range(end - start)))
            # The block is still all False here, so `~block` is all True and
            # tril() turns it into a causal mask restricted to this group.
            attention_mask[start:end, start:end] = torch.tril(~attention_mask[start:end, start:end])
        input_ids = input_ids[:self.max_length]
        labels = labels[:self.max_length]
        position_ids = position_ids[:self.max_length]
        pad_length = self.max_length - len(input_ids)
        if pad_length:
            input_ids.extend([0] * pad_length)
            labels.extend([-100] * pad_length)
            position_ids.extend([0] * pad_length)
        return {
            'input_ids': torch.LongTensor(input_ids),
            'labels': torch.LongTensor(labels),
            'position_ids': torch.LongTensor(position_ids),
            'attention_mask': attention_mask
        }

    def __len__(self):
        """Number of packed groups for the current epoch."""
        return len(self.packed)

    def update(self, epoch):
        """Load the packing for ``epoch`` into ``self.packed``, building it if needed.

        Only rank 0 builds the cache file.  When ``torch.distributed`` is
        initialized every rank must call this method, because all ranks meet
        at a barrier before reading the file.
        """
        def random_chunks():
            # Shuffle all sample indices and split them into `num_workers`
            # contiguous chunks (trailing chunks may be empty).
            indice = list(range(len(self.tokenized)))
            random.shuffle(indice)
            step = (len(indice) - 1) // self.num_workers + 1
            return [indice[i * step : (i + 1) * step] for i in range(self.num_workers)]

        rank = dist.get_rank() if dist.is_initialized() else 0
        cache_file = os.path.join(self.cache_dir, f'packed-{epoch}.rec')
        # Rank 0 alone decides whether to (re)build.  If every rank tested
        # os.path.exists() itself, a late rank could see the file rank 0 is
        # still writing, skip the barrier below and read a partial file while
        # the other ranks wait for it.
        if rank == 0 and (self.overwrite or not os.path.exists(cache_file)):
            os.makedirs(os.path.dirname(cache_file), exist_ok=True)
            with recordio.RecordIO(cache_file, 'w') as rec:
                desc = f'[Rank {rank}][Epoch {epoch}] split_and_pack'
                with mp.Pool(self.num_workers) as pool, tqdm(desc=desc) as pbar:
                    for packed in pool.map(self._split_and_pack, random_chunks()):
                        for data in packed:
                            rec.append(pickle.dumps(data))
                            pbar.update()
                print(f'[Rank {rank}][Epoch {epoch}] Save #{len(rec)} samples to {cache_file}', flush=True)
        if dist.is_initialized():
            # Unconditional on every rank so the collective always matches.
            dist.barrier()
            # The file may take a moment to become visible on this rank's
            # view of the file system.
            while not os.path.exists(cache_file):
                time.sleep(1)
        self.packed = recordio.Record(cache_file, auto_pickle=True)
        print(f'[Rank {rank}][Epoch {epoch}] Load #{len(self.packed)} samples from {cache_file}', flush=True)

    @property
    def tokenized(self):
        """Open the tokenized-example cache, creating it on first use.

        A new ``recordio.Record`` is opened on every access; callers that
        index it repeatedly should keep the returned object in a local.
        """
        cache_file = os.path.join(self.cache_dir, 'tokenized.rec')
        if not os.path.exists(cache_file):
            os.makedirs(os.path.dirname(cache_file), exist_ok=True)
            with mp.Pool(self.num_workers) as pool, recordio.RecordIO(cache_file, 'w') as rec:
                for data in pool.map(self.processor, self._load_examples()):
                    rec.append(pickle.dumps(data))
        return recordio.Record(cache_file, auto_pickle=True)

    def _load_train_files(self):
        """Return the ``train_files`` section of the json5 config file."""
        with open(self.data_file) as f:
            return json5.load(f)['train_files']

    @staticmethod
    def _read_records(path):
        """Read the examples stored in ``path``.

        The file is parsed as a single JSON document first; if that fails it
        is parsed as JSON Lines (blank lines are skipped).
        """
        with open(path) as f:
            try:
                data = json.load(f)
            except ValueError:  # json.JSONDecodeError is a ValueError
                f.seek(0)
                return [json.loads(line) for line in f if line.strip()]
        # A JSON Lines file holding a single record parses as one object
        # rather than a list; treat it as a one-example list.
        if isinstance(data, dict):
            return [data]
        return data

    @staticmethod
    def _replicate_and_sample(data_list, sample_num):
        """Return ``sample_num`` examples drawn from ``data_list``.

        Whole copies of ``data_list`` are emitted while ``sample_num`` covers
        them; the remainder is drawn without replacement.
        """
        picked = []
        # An empty file has nothing to sample; without this guard the loop
        # below would never terminate (0 >= 0 forever).
        if not data_list:
            return picked
        while sample_num >= len(data_list):
            picked.extend(data_list)
            sample_num -= len(data_list)
        if sample_num:
            picked.extend(random.sample(data_list, k=sample_num))
        return picked

    def _load_examples(self):
        """Load and resample all raw examples listed in the config file.

        For every ``pattern -> rate`` entry of ``sample_rate`` the files
        matching ``pattern*.json`` (relative to ``root`` unless the pattern is
        absolute) are read and ``int(len(file) * rate)`` examples are taken
        from each file.
        """
        train_files = self._load_train_files()
        data_root = train_files['root']
        examples = []
        for pattern, rate in train_files['sample_rate'].items():
            if pattern.startswith('/'):
                file_pattern = pattern + '*.json'
            else:
                file_pattern = os.path.join(data_root, pattern + '*.json')
            files = glob.glob(file_pattern)
            if not files:
                print(f'No file matches {pattern} in {data_root}', flush=True)
            for file in files:
                data_list = self._read_records(file)
                sample_num = int(len(data_list) * rate)
                print(f'Loading {file}: sample {len(data_list)} -> {sample_num}', flush=True)
                examples.extend(self._replicate_and_sample(data_list, sample_num))
        print(f'Loaded {len(examples)} examples from {self.data_file}', flush=True)
        return examples

    def _split_and_pack(self, indice, max_try=100):
        """Greedily pack the samples in ``indice`` into groups.

        Samples are visited from longest to shortest.  Each one opens a new
        group which is then filled with randomly ordered remaining samples as
        long as the total stays within ``max_length``; filling stops after
        ``max_try`` consecutive samples that do not fit.

        Returns:
            list of groups, each a list of sample indices; every index of
            ``indice`` appears in exactly one group.
        """
        # Open the record file once instead of once per index.
        tokenized = self.tokenized
        lengths = [len(tokenized[i]['input_ids']) for i in indice]
        candidates = set(range(len(indice)))
        packed = []
        for i, input_length in sorted(enumerate(lengths), key=lambda d: d[1], reverse=True):
            if i not in candidates:
                continue
            candidates.remove(i)
            collection = [indice[i]]
            if input_length < self.max_length:
                rand_candidates = list(candidates)
                random.shuffle(rand_candidates)
                ntry = 0
                for j in rand_candidates:
                    if input_length + lengths[j] <= self.max_length:
                        input_length += lengths[j]
                        collection.append(indice[j])
                        candidates.remove(j)
                        if input_length == self.max_length:
                            break
                        else:
                            ntry = 0
                    else:
                        ntry += 1
                        if ntry == max_try:
                            break
                # Avoid always placing the longest sample first in a group.
                random.shuffle(collection)
            packed.append(collection)
        return packed

    def get_cache_dir(self, cache_dir):
        """Return ``cache_dir/<md5>`` where the digest identifies this setup.

        The digest covers the config file text, the source code of the
        processor's class and of this dataset's class, the size of every
        matching data file and ``max_length``.
        """
        def get_md5(content):
            if isinstance(content, str):
                content = content.encode()
            assert isinstance(content, bytes)
            return hashlib.md5(content).hexdigest()

        with open(self.data_file) as f:
            config_text = f.read()

        meta_info = {}
        meta_info['config'] = get_md5(config_text)
        # NOTE(review): only the source of the processor's class is hashed,
        # not the processor's configuration (e.g. which tokenizer it wraps);
        # confirm that sharing a cache between such processors is intended.
        meta_info['processor'] = get_md5(inspect.getsource(self.processor.__class__))
        meta_info['dataset'] = get_md5(inspect.getsource(self.__class__))
        train_files = self._load_train_files()
        data_root = train_files['root']
        meta_info['train_files'] = {}
        for pattern in train_files['sample_rate']:
            for file in glob.glob(os.path.join(data_root, pattern + '*.json*')):
                # Only the file size is fingerprinted, not the content.
                meta_info['train_files'][file] = get_md5(str(os.stat(file).st_size))
        meta_info['max_length'] = self.max_length
        meta_md5 = get_md5(json.dumps(meta_info))
        cache_dir = os.path.join(cache_dir, meta_md5)
        return cache_dir


class DpoDataset(SupervisedDataset):
    """Preference dataset with chosen / rejected / kl sequences per example.

    Samples are never packed together: every item is exactly one tokenized
    example.  ``update``, ``tokenized`` and ``__len__`` are inherited from
    ``SupervisedDataset``.
    """

    def __init__(self, *args, **kwargs):
        """Read the ``use_*_loss`` switches from ``kwargs`` and init the parent.

        The switches must be passed by keyword; they decide which examples
        ``_load_examples`` keeps.
        """
        self.use_dpo_loss = kwargs.get('use_dpo_loss', False)
        self.use_sft_loss = kwargs.get('use_sft_loss', False)
        self.use_kl_loss = kwargs.get('use_kl_loss', False)
        super().__init__(*args, **kwargs)

    def __getitem__(self, idx):
        """Return example ``idx`` with every sequence cut/padded to ``max_length``.

        Returns:
            dict of Python lists ``{chosen,rejected,kl}_{input_ids,labels,
            position_ids}`` plus the example's ``dpo_loss``, ``sft_loss`` and
            ``kl_loss`` flags.
        """
        record_idx = self.packed[idx][0]
        data = self.tokenized[record_idx]

        result = {}
        for prefix in ['chosen', 'rejected', 'kl']:
            # Slicing copies, so padding below never touches the cached record.
            input_ids = list(data[f'{prefix}_input_ids'][:self.max_length])
            labels = list(data[f'{prefix}_labels'][:self.max_length])
            # Derive positions from the truncated ids so that all three lists
            # have the same length even for over-long examples.
            position_ids = list(range(len(input_ids)))
            pad_length = self.max_length - len(input_ids)
            if pad_length > 0:
                input_ids.extend([0] * pad_length)
                labels.extend([-100] * pad_length)
                position_ids.extend([0] * pad_length)
            result[f'{prefix}_input_ids'] = input_ids
            result[f'{prefix}_labels'] = labels
            result[f'{prefix}_position_ids'] = position_ids
        for key in ['dpo_loss', 'sft_loss', 'kl_loss']:
            result[key] = data[key]
        return result

    def _load_examples(self):
        """Load, resample and filter the raw examples listed in the config.

        Files are read like in the parent class, but one extra example is
        taken per file.  Examples for which none of the dpo / sft / kl losses
        is enabled (per-example flag AND dataset switch) are dropped.
        """
        with open(self.data_file) as f:
            data_info = json5.load(f)
        train_files = data_info['train_files']
        data_root = train_files['root']
        examples = []
        for pattern, rate in train_files['sample_rate'].items():
            print(pattern, rate, flush=True)
            if pattern.startswith('/'):
                file_pattern = pattern + '*.json'
            else:
                file_pattern = os.path.join(data_root, pattern + '*.json')
            files = glob.glob(file_pattern)
            if not files:
                print(f'No file matches {pattern} in {data_root}', flush=True)
            for file in files:
                print(file, flush=True)
                with open(file) as f:
                    try:
                        data_list = json.load(f)
                    except ValueError:  # json.JSONDecodeError is a ValueError
                        # Not a single JSON document: parse as JSON Lines.
                        f.seek(0)
                        data_list = [json.loads(line) for line in f if line.strip()]
                # A one-record JSON Lines file parses as a single object.
                if isinstance(data_list, dict):
                    data_list = [data_list]
                sample_num = int(len(data_list) * rate)
                # TODO: at least 1 sample
                # NOTE(review): this adds one sample to every file, not only
                # to files that would otherwise contribute none.
                sample_num += 1
                print(f'Loading {file}: sample {len(data_list)} -> {sample_num}', flush=True)
                # An empty file has nothing to sample; without this guard the
                # loop below would never terminate.
                if not data_list:
                    continue
                while sample_num >= len(data_list):
                    examples.extend(data_list)
                    sample_num -= len(data_list)
                if sample_num:
                    examples.extend(random.sample(data_list, k=sample_num))
        print(f'Loaded {len(examples)} examples from {self.data_file}', flush=True)

        result_examples = []
        num_continue = 0
        for example in examples:
            use_dpo_loss = example.get('dpo_loss', False) and self.use_dpo_loss
            use_sft_loss = example.get('sft_loss', False) and self.use_sft_loss
            use_kl_loss = example.get('kl_loss', True) and self.use_kl_loss
            if use_dpo_loss or use_sft_loss or use_kl_loss:
                result_examples.append(example)
            else:
                num_continue += 1
        print(f'Filtered {num_continue} examples', flush=True)
        return result_examples

    def _split_and_pack(self, indice, max_try=100):
        """Put every index into a group of its own (``max_try`` is unused)."""
        return [[idx] for idx in indice]


class CustomProcessor:
    """Tokenize a multi-turn conversation into ``input_ids`` / ``labels``."""

    def __init__(self,
                 tokenizer: PreTrainedTokenizer,
                 user_token_id: int,        # user token
                 assistant_token_id: int,   # assistant token
                 add_eos: bool=False,       #
    ):
        """
        Args:
            tokenizer: provides ``encode`` and ``eos_token_id``.
            user_token_id: token id inserted before every 'human' message.
            assistant_token_id: token id inserted before every 'gpt' message.
            add_eos: append EOS after every 'gpt' message instead of only
                after the last one.
        """
        self.tokenizer = tokenizer
        self.user_token_id = user_token_id
        self.assistant_token_id = assistant_token_id
        self.add_eos = add_eos
        self.eos_token_id = self.tokenizer.eos_token_id

    def __call__(self, example):
        """Convert ``example`` to token ids and training labels.

        ``example['conversations']`` is a list of messages with ``from``
        ('system' only as first message, 'human' or 'gpt') and ``value``.
        A system prompt (first message, else ``example['system']``) is
        prepended to the text of the first message when that is a 'human'
        message.  With ``example['mask_session_answer_loss']`` set, only the
        last 'gpt' message contributes to the loss.

        Returns:
            ``{'input_ids': [...], 'labels': [...]}`` of equal length; label
            ``-100`` marks positions excluded from the loss.

        Raises:
            IndexError: the conversation contains no 'gpt' message.
            ValueError: a message has an unsupported ``from`` value.
        """
        only_last_answer = example.get('mask_session_answer_loss', False)
        # Work on a local list so the caller's example is left untouched.
        conversations = example['conversations']
        system_str = example.get('system', '')
        if conversations[0]['from'] == 'system':
            system_str = conversations[0]['value']
            conversations = conversations[1:]

        # Must be computed AFTER the system message is dropped, otherwise the
        # index is off by one and never matches `i` in the loop below.
        answer_ids = [
            i for i, message in enumerate(conversations)
            if message['from'] == 'gpt'
        ]
        if not answer_ids:
            raise IndexError("example has no message from 'gpt'")
        last_answer_id = answer_ids[-1]

        input_ids, labels = [], []
        for i, message in enumerate(conversations):
            if 'value' not in message:
                continue
            role = message['from']
            if i == 0 and role == 'human':
                message_str = system_str + message['value']
            else:
                message_str = message['value']
            content_tokens = self.tokenizer.encode(message_str)
            if role == 'human':
                input_ids.append(self.user_token_id)
                # Unless EOS tokens are added explicitly (or the turn is
                # masked), the position of the user token is trained to
                # predict EOS.
                labels.append(-100 if self.add_eos or only_last_answer else self.eos_token_id)
                input_ids.extend(content_tokens)
                labels.extend([-100] * len(content_tokens))
            elif role == 'gpt':
                masked = only_last_answer and i != last_answer_id
                input_ids.append(self.assistant_token_id)
                labels.append(-100)
                input_ids.extend(content_tokens)
                if masked:
                    labels.extend([-100] * len(content_tokens))
                else:
                    labels.extend(content_tokens)
                if self.add_eos or i == last_answer_id:
                    input_ids.append(self.eos_token_id)
                    labels.append(-100 if masked else self.eos_token_id)
            else:
                raise ValueError(f"message role not supported yet: {message['from']}")
        return {'input_ids': input_ids, 'labels': labels}

    def __str__(self):
        """Return ``<ClassName>_<name>`` derived from the tokenizer path.

        ``<name>`` is the last path component, or its parent directory's name
        when the last component starts with ``checkpoint-``.
        """
        pdir, base = os.path.split(os.path.abspath(self.tokenizer.name_or_path))
        if pdir != '/' and base.startswith('checkpoint-'):
            name = os.path.basename(pdir)
        else:
            name = base
        return f'{self.__class__.__name__}_{name}'


class DpoProcessor:
    """Tokenize a preference example into chosen / rejected / kl sequences."""

    def __init__(self, tokenizer: PreTrainedTokenizer, user_token_id: int, assistant_token_id: int):
        """
        Args:
            tokenizer: provides ``encode`` and ``eos_token_id``.
            user_token_id: token id inserted before every 'human' message.
            assistant_token_id: token id inserted before every 'gpt' message.
        """
        self.tokenizer = tokenizer
        self.user_token_id = user_token_id
        self.assistant_token_id = assistant_token_id
        self.eos_token_id = self.tokenizer.eos_token_id

    def _processs_human_turn(self, message):
        """Tokenize a 'human' message.

        Returns ``(input_ids, labels, kl_input_ids, kl_labels)``.  The kl
        variant additionally contains ``message['magical_prompt']`` (when
        present) between the user token and the message text.  The label at
        the user-token position is EOS, all other labels are ``-100``.
        """
        tokens = self.tokenizer.encode(message['value'])
        input_ids = [self.user_token_id] + tokens
        if 'magical_prompt' in message:
            magical_tokens = self.tokenizer.encode(message['magical_prompt'])
            kl_input_ids = [self.user_token_id] + magical_tokens + tokens
        else:
            kl_input_ids = input_ids
        labels = [self.eos_token_id] + [-100] * (len(input_ids) - 1)
        kl_labels = [self.eos_token_id] + [-100] * (len(kl_input_ids) - 1)
        return input_ids, labels, kl_input_ids, kl_labels

    def _processs_gpt_turn(self, message):
        """Tokenize a non-final 'gpt' message; returns ``(input_ids, labels)``."""
        tokens = self.tokenizer.encode(message['value'])
        input_ids = [self.assistant_token_id] + tokens
        labels = [-100] + tokens
        return input_ids, labels

    def _process_the_last_turn(self, message):
        """Tokenize the final answer.

        The chosen text is ``message['chosen']`` (or ``message['value']`` when
        there is no 'chosen' key).  Without a 'rejected' key the rejected
        sequence is the single placeholder token ``0`` with label ``-100``.

        Returns ``(chosen_ids, chosen_labels, rejected_ids, rejected_labels)``.
        """
        chosen_key = 'chosen' if 'chosen' in message else 'value'
        chosen_ids = [self.assistant_token_id] + self.tokenizer.encode(message[chosen_key]) + [self.eos_token_id]
        # Train on everything but the leading assistant token.
        chosen_labels = [-100] + chosen_ids[1:]
        if 'rejected' in message:
            rejected_ids = [self.assistant_token_id] + self.tokenizer.encode(message['rejected']) + [self.eos_token_id]
            rejected_labels = [-100] + rejected_ids[1:]
        else:
            rejected_ids = [0]
            rejected_labels = [-100]
        return chosen_ids, chosen_labels, rejected_ids, rejected_labels

    def __call__(self, example):
        """Convert ``example`` into the three token sequences and their labels.

        All turns before the last are shared by the chosen, rejected and kl
        sequences (the kl sequence may carry extra prompt tokens on 'human'
        turns); the last turn, which must come from 'gpt', supplies the chosen
        answer (also used for kl) and the rejected answer.  With
        ``example['mask_session_answer_loss']`` set, the labels of all turns
        before the last are replaced by ``-100``.

        Returns:
            dict with ``{chosen,rejected,kl}_{input_ids,labels}`` lists and the
            ``sft_loss`` / ``dpo_loss`` / ``kl_loss`` flags of the example.

        Raises:
            ValueError: 'conversations' is missing or a role is unsupported.
        """
        only_last_answer = example.get('mask_session_answer_loss', False)
        result = {
            'chosen_input_ids': [], 'chosen_labels': [], 'rejected_input_ids': [], 'rejected_labels': [],
            'kl_input_ids': [], 'kl_labels': [],
            'sft_loss': example.get('sft_loss', False),
            'dpo_loss': example.get('dpo_loss', False),
            'kl_loss': example.get('kl_loss', True),
        }
        if 'conversations' not in example:
            raise ValueError(f"Key 'conversations' is missing in example: {example}")

        conversations = example['conversations']
        last_index = len(conversations) - 1
        for i, message in enumerate(conversations):
            extend_pairs = []
            if i == last_index:
                assert message['from'] == 'gpt', 'the last message must come from gpt'
                # Reset the labels of the turns before the last.  They already
                # live in `result` (`extend_pairs` is per-turn and still empty
                # here), so that is where they must be masked.
                if only_last_answer:
                    for key in ['chosen_labels', 'rejected_labels', 'kl_labels']:
                        result[key] = [-100] * len(result[key])
                chosen_ids, chosen_labels, rejected_ids, rejected_labels = self._process_the_last_turn(message)
                for key in ['chosen', 'kl']:
                    extend_pairs.append((f'{key}_input_ids', chosen_ids))
                    extend_pairs.append((f'{key}_labels', chosen_labels))
                for key in ['rejected']:
                    extend_pairs.append((f'{key}_input_ids', rejected_ids))
                    extend_pairs.append((f'{key}_labels', rejected_labels))
            elif message['from'] == 'human':
                input_ids, labels, kl_input_ids, kl_labels = self._processs_human_turn(message)
                for key in ['chosen', 'rejected']:
                    extend_pairs.append((f'{key}_input_ids', input_ids))
                    extend_pairs.append((f'{key}_labels', labels))
                for key in ['kl']:
                    extend_pairs.append((f'{key}_input_ids', kl_input_ids))
                    extend_pairs.append((f'{key}_labels', kl_labels))
            elif message['from'] == 'gpt':
                input_ids, labels = self._processs_gpt_turn(message)
                for key in ['chosen', 'rejected', 'kl']:
                    extend_pairs.append((f'{key}_input_ids', input_ids))
                    extend_pairs.append((f'{key}_labels', labels))
            else:
                raise ValueError(f"message role not supported yet: {message['from']}")
            for key, value in extend_pairs:
                result[key].extend(value)

        return result

@dataclass
class DataCollatorForDpo(object):
    """Batch DPO instances and strip the padding columns shared by all of them.

    Padding is identified by ``tokenizer.pad_token_id``, so the instances are
    expected to be right-padded with that id.
    """
    tokenizer: PreTrainedTokenizer

    @staticmethod
    def _content_length(ids: torch.Tensor, pad_token_id) -> int:
        """Return the index of the last non-pad token of 1-D ``ids`` plus one (0 if none)."""
        non_pad_positions = torch.nonzero(ids.ne(pad_token_id))
        if non_pad_positions.numel() == 0:
            return 0
        return int(non_pad_positions[-1].item()) + 1

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        """Collate ``instances`` into batched tensors.

        Returns:
            dict with ``sft_loss_mask``, ``dpo_loss_mask`` and ``kl_loss_mask``
            (one entry per instance) and, for each of the prefixes ``chosen_``,
            ``rejected_`` and ``kl_``, the tensors ``input_ids``, ``labels``
            and ``attention_mask`` (True where ``input_ids`` is not the pad id).
        """
        pad_token_id = self.tokenizer.pad_token_id
        data_dict = {
            'sft_loss_mask': torch.tensor([instance['sft_loss'] for instance in instances]),
            'dpo_loss_mask': torch.tensor([instance['dpo_loss'] for instance in instances]),
            'kl_loss_mask': torch.tensor([instance['kl_loss'] for instance in instances]),
        }

        for prefix in ['chosen_', 'rejected_', 'kl_']:
            input_ids = [torch.tensor(instance[f'{prefix}input_ids']) for instance in instances]
            labels = [torch.tensor(instance[f'{prefix}labels']) for instance in instances]

            # Keep everything up to the last real token of the longest
            # instance.  Counting non-pad tokens instead would cut real tokens
            # whenever the pad id also occurs inside a sequence.
            max_length = max(self._content_length(ids, pad_token_id) for ids in input_ids)

            batch_ids = torch.stack([ids[:max_length] for ids in input_ids])
            data_dict[f'{prefix}input_ids'] = batch_ids
            data_dict[f'{prefix}labels'] = torch.stack([lab[:max_length] for lab in labels])
            data_dict[f'{prefix}attention_mask'] = batch_ids.ne(pad_token_id)
        return data_dict


if __name__ == '__main__':
    # Running this module as a script intentionally does nothing.
    pass