import json
from typing import Dict, List, Any, Iterator 
from collections import defaultdict
from dataclasses import dataclass
import random
import os
import numpy as np
import yaml

import torch
from transformers import PreTrainedTokenizer, BatchEncoding, DataCollatorWithPadding, AutoProcessor
from torch.utils.data import Dataset as TorchDataset, IterableDataset as TorchIterableDataset
import datasets
import logging
from torch.utils.data import Sampler, Dataset, DataLoader
import mmap
import random

logger = logging.getLogger(__name__)

# Field-name constants for query / document entries; not referenced in the visible code.
QUERY_KEY = "query"
DOC_KEY = "doc"

# Single-image chat prompt template (misspelled name kept for compatibility; unused in the visible code).
# NOTE(review): these markers (<|user|>, <|image_1|>, <|end|>, <|assistant|>) do not match Qwen-VL's
# <|im_start|>-style chat format despite the name — confirm which model this template targets.
QWEN_VL_TEMPELATE = '<|user|>\n<|image_1|><|end|>\n<|assistant|>\n'
# Module-level counter; not referenced in the visible code.
none_count = 0


def format_content(text, image, prefix='Query:'):
    """Build a chat-message ``content`` list for one side of a query/document pair.

    The list holds a ``prefix`` text part, then an image part (``'file://' + image``)
    when ``image`` is truthy, then a text part when ``text`` is truthy. When both
    ``text`` and ``image`` are falsy, a single empty text part is returned instead
    (without the prefix).
    """
    if not (text or image):
        return [{'type': 'text', 'text': ""}]
    parts = [{'type': 'text', 'text': prefix}]
    if image:
        parts.append({'type': 'image', 'image': 'file://' + image})
    if text:
        parts.append({'type': 'text', 'text': text})
    return parts


def normalize_instruction(instruction: str):
    """Strip surrounding whitespace and drop at most one trailing ';', ':', ',' or '.'."""
    cleaned = instruction.strip()
    if cleaned.endswith((';', ':', ',', '.')):
        cleaned = cleaned[:-1]
    return cleaned
    
class UniDataset(torch.utils.data.Dataset):
    """
    Dataset returning one (query, positive, *negatives) training group per instance.

    instances: indexable sequence of {'qid': q, 'pos': [doc_id, ...], 'neg': [doc_id, ...]}.
        Either a plain list or a ``datasets.Dataset``. The latter hands out a fresh copy
        of the row on every access, so the round-robin state over positives/negatives is
        kept on this object instead of being written back into the rows. The rows
        themselves are never mutated.
    queries / corpus: indexable collections of row dicts, read through the 'text',
        'img_path' and (queries only) 'instruct' fields.
    query_dict / corpus_dict: map a query id / document id to its row index.

    corpus may be ``None`` for a symmetric dataset; documents are then looked up among
    the queries.
    """

    def __init__(
        self,
        instances: List[Dict[str, Any]],
        queries=None,
        query_dict=None,
        corpus=None,
        corpus_dict=None,
        system_prompt=None,
        neg_per_ins: int = 8,
        query_image_prefix_path=None,
        image_prefix_path = None,
        instruction = None,
        doc_instruction = None,
        task_type = 'retrieval',
        use_all_pair=True,
        only_self_neg=False,
        boq_token = None,
        bod_token = None,
        max_length = None,
        random_neg = 0,
        sub_batchsize=None,
        **kwargs
    ):
        self.instances = instances
        self.queries = queries
        self.query_dict = query_dict
        # Symmetric dataset: documents are drawn from the query collection.
        self.only_query = corpus is None
        self.corpus = corpus or queries
        self.corpus_dict = corpus_dict or query_dict
        # Pool of ids used for uniformly sampled (random / filler) negatives.
        self.corpus_keys = list(self.corpus_dict.keys())
        self.neg_per_ins = neg_per_ins
        self.instruction = instruction
        self.doc_instruction = doc_instruction
        self.task_type = task_type
        self.use_all_pair = use_all_pair
        self.only_self_neg = only_self_neg
        self.boq_token = boq_token
        self.random_neg = random_neg
        self.bod_token = boq_token if self.only_query else bod_token
        self.query_image_prefix_path = query_image_prefix_path
        self.max_length = max_length
        self.image_prefix_path = image_prefix_path
        self.sub_batchsize = sub_batchsize
        # Per-instance round-robin cursors: {item: {'pos': int, 'neg': int}}.
        self._cursors = {}
        # Seed for the per-instance id permutations. It is drawn from the global RNG, so it
        # follows whatever seeding the caller applied to ``random``.
        self._shuffle_seed = random.getrandbits(64)

    def get_query(self, qid):
        """Return ``(text, image_path, instruction)`` for query ``qid``.

        The instruction falls back to ``self.instruction`` and then to ``""``. An empty or
        missing image path becomes ``None``; otherwise it is joined onto
        ``self.query_image_prefix_path``.
        """
        row = self.queries[self.query_dict[qid]]
        text = row.get('text', '')
        instruct = row.get('instruct', self.instruction)
        if instruct is None:
            instruct = ""
        image_path = row.get('img_path', None)
        if image_path is None or len(image_path) == 0:
            image_path = None
        else:
            image_path = os.path.join(self.query_image_prefix_path, image_path)
        return (text, image_path, instruct)

    def get_doc(self, did):
        """Return ``(text, image_path)`` for document ``did`` (looked up among queries if symmetric)."""
        if self.only_query:
            return self.get_query(did)[:2]
        row = self.corpus[self.corpus_dict[did]]
        text = row.get('text', '')
        image_path = row.get('img_path', None)
        if image_path is None or len(image_path) == 0:
            image_path = None
        else:
            image_path = os.path.join(self.image_prefix_path, image_path)
        return (text, image_path)

    def _next_ids(self, item, key, ids, count):
        """Return the next ``count`` entries of ``ids`` in this instance's shuffled cyclic order.

        Equivalent to shuffling ``ids`` once and then repeatedly popping from the front and
        re-appending, but the state is held in ``self._cursors`` so it also works when
        ``self.instances`` returns copies of its rows.
        """
        n = len(ids)
        if n == 0 or count <= 0:
            return []
        order = list(range(n))
        # Deterministic per (seed, instance, list) permutation: same order on every call.
        random.Random(f"{self._shuffle_seed}:{item}:{key}").shuffle(order)
        cursors = self._cursors.setdefault(item, {})
        start = cursors.get(key, 0)
        cursors[key] = (start + count) % n
        return [ids[order[(start + j) % n]] for j in range(count)]

    def __getitem__(self, item):
        """Build the training group for instance ``item``.

        The group holds one positive (rotating through ``pos``), then ``random_neg``
        uniformly sampled corpus documents, then ``neg_per_ins - random_neg`` hard
        negatives rotating through ``neg``. If there are too few hard negatives, all of
        them are used and the remainder is filled with uniformly sampled corpus documents.

        Raises:
            IndexError: if the instance has no positive documents.
        """
        ins = self.instances[item]
        query, query_image_path, instruct = self.get_query(ins['qid'])

        pos_ids = self._next_ids(item, 'pos', ins['pos'], 1)
        if not pos_ids:
            raise IndexError(f"Instance {item} (qid={ins['qid']}) has no positive documents")

        neg_ids = []
        if self.random_neg > 0:
            neg_ids.extend(random.choices(self.corpus_keys, k=self.random_neg))
        num_hard = self.neg_per_ins - self.random_neg
        hard_pool = ins['neg']
        if len(hard_pool) < num_hard:
            neg_ids.extend(self._next_ids(item, 'neg', hard_pool, len(hard_pool)))
            neg_ids.extend(random.choices(self.corpus_keys, k=num_hard - len(hard_pool)))
        else:
            neg_ids.extend(self._next_ids(item, 'neg', hard_pool, num_hard))

        docs = [self.get_doc(pos_ids[0])] + [self.get_doc(neg_id) for neg_id in neg_ids]
        return {'query': (instruct, query),
                'query_image': query_image_path,
                'docs': [d[0] for d in docs],
                'doc_images': [d[1] for d in docs],
                'sub_batchsize': self.sub_batchsize,
                'max_length': self.max_length
                }

    def __len__(self):
        return len(self.instances)

def read_one_data(data_config: dict, 
    neg_per_ins=8, 
    instruction=None, 
    doc_instruction=None, 
    task_type=None, 
    use_all_pair=True, 
    only_self_neg=False,
    boq_token=None,
    bod_token=None,
    length_config=None,
    random_neg=0,
    default_max_length=512,
    ):
    """Load one dataset folder described by ``data_config`` and wrap it in a ``UniDataset``.

    ``data_config`` keys read here (all optional): 'path' (default './'), 'random_neg',
    'neg_per_ins', 'query_file', 'corpus_file', 'instance_file', 'instruct',
    'max_length', 'sub_batchsize', 'image_prefix_path', 'query_image_prefix_path'.
    Per-config values override the corresponding keyword arguments.

    The instance file actually loaded is '<instance basename up to its first dot>_safeinstances.json'
    placed next to the configured instance file.

    Returns:
        A ``UniDataset``, or ``None`` when the instance file contains no rows.
    """
    folder = data_config.get('path', './')
    random_neg = data_config.get('random_neg', random_neg)
    print('folder', folder)
    logger.info('Reading dataset: %s', folder)

    def read_texts(path: str):
        """Load a JSON/JSONL file; return ``(rows, {str(_id): row_index})``, or ``(None, None)`` if missing."""
        if not os.path.exists(path):
            logger.info("File not exists: %s", path)
            return None, None
        data = datasets.load_dataset('json', data_files=path)['train']
        # Read only the id column instead of decoding every full row.
        data_dict = {str(_id): i for i, _id in enumerate(data['_id'])}
        return data, data_dict

    neg_per_ins = data_config.get('neg_per_ins', neg_per_ins)
    query_file = data_config.get('query_file', 'queries.jsonl')
    corpus_file = data_config.get('corpus_file', 'corpus.jsonl')
    queries, query_dict = read_texts(os.path.join(folder, query_file))
    corpus, corpus_dict = read_texts(os.path.join(folder, corpus_file))

    instance_file = data_config.get('instance_file', 'instances.json')
    # Split only the basename so dots in directory parts ('./', 'v1.2/') don't truncate the path.
    instance_dir, instance_name = os.path.split(instance_file)
    instance_file = os.path.join(instance_dir, instance_name.split('.')[0] + '_safeinstances.json')
    safe_instances = datasets.load_dataset('json', data_files=os.path.join(folder, instance_file))['train']
    print(f"Load {len(safe_instances)} from {folder}")
    if len(safe_instances) == 0:
        print('skip dataset ' + folder)
        return None

    instruction = data_config.get('instruct', instruction)
    max_length = data_config.get('max_length', default_max_length)
    sub_batchsize = data_config.get('sub_batchsize', None)
    if sub_batchsize is None and length_config is not None:
        sub_batchsize = length_config[str(max_length)].get('sub_batchsize', None)
    # Use the same default as ``folder`` so a config without 'path' does not raise KeyError.
    image_prefix_path = data_config.get('image_prefix_path', folder)
    query_image_prefix_path = data_config.get('query_image_prefix_path', image_prefix_path)
    return UniDataset(safe_instances, 
        queries, 
        query_dict,
        corpus, 
        corpus_dict, 
        query_image_prefix_path=query_image_prefix_path,
        image_prefix_path=image_prefix_path,
        neg_per_ins=neg_per_ins, 
        instruction=instruction, 
        doc_instruction=doc_instruction, 
        task_type=task_type, 
        use_all_pair=use_all_pair, 
        only_self_neg=only_self_neg,
        boq_token=boq_token,
        bod_token=bod_token,
        random_neg=random_neg,
        sub_batchsize=sub_batchsize,
        max_length=max_length
    )


class DynamicBatchSampler(Sampler[List[int]]):
    """Batch sampler that hands each process its contiguous share of pre-built batches.

    ``dataset`` must provide ``generate_batch()`` and a ``global_idx`` list of index
    batches. Batches are popped from that very list (``self.idx`` aliases it), so
    ``len(dataset)`` shrinks as an epoch is consumed. After the last batch, a new
    schedule is generated for the next epoch.
    """

    def __init__(self, dataset: Dataset, num_process, process_index) -> None:
        self.num_process = num_process
        self.process_index = process_index
        self.dataset = dataset
        print('init batch samler')
        dataset.generate_batch()
        # Alias, not a copy: popping below also shortens dataset.global_idx.
        self.idx = dataset.global_idx

    def __iter__(self) -> Iterator[List[int]]:
        while self.idx:
            full_batch = self.idx.pop(0)
            share = len(full_batch) // self.num_process
            lo = share * self.process_index
            yield full_batch[lo:lo + share]
        # Only reached when the epoch is fully consumed.
        self.dataset.generate_batch()
        self.idx = self.dataset.global_idx

    def __len__(self) -> int:
        return len(self.dataset)

class MultiDatasetMNKD(torch.utils.data.Dataset):
    def __init__(
        self,
        data_configs: List[Dict],
        length_config: Dict =  None,
        default_batch_size: int = 32,
        neg_per_ins: int = 8,
        instruction=None,
        random_neg=0,
        doc_instruction=None,
        boq_token=None,
        bod_token=None,
        num_gpu=1,
        default_max_length=1024
    ):
        self.task_to_dataset: Dict[str, Any] = {}
        
        for data_config in data_configs:
            task_name = data_config["name"]
            task_type = data_config.get('task_type', None)
            use_all_pair = data_config.get('use_all_pair', True)
            max_length = data_config.get('max_length', default_max_length)
            batch_size = data_config.get('batch_size', None)
            if batch_size is None and length_config is not None:
                batch_size = length_config[str(max_length)]['batch_size']
            else:
                batch_size = default_batch_size
            only_self_neg = data_config.get('only_self_neg', False)
            sample_size = data_config.get('sample_size', -1)
            sample_smalldata_size = data_config.get('sample_smalldata_size', None)
            custom_dataset = read_one_data(data_config, 
                length_config=length_config,
                neg_per_ins=neg_per_ins, 
                instruction=instruction,
                doc_instruction=doc_instruction, 
                task_type=task_type, 
                use_all_pair=use_all_pair, 
                only_self_neg=only_self_neg,
                boq_token=boq_token,
                bod_token=bod_token,
                random_neg=random_neg
            )
            self.task_to_dataset[task_name] = {'dataset': custom_dataset, 'batch_size': batch_size * num_gpu, 'datasize': len(custom_dataset), 'sample_size': sample_size, 'sample_smalldata_size': sample_smalldata_size}
        self.batch_size = batch_size
    
    def __len__(self):
        return len(self.global_idx)
    
    def generate_batch(self):
        """Shuld be called at the begin of each epoch"""
        self.task_data_idxs, self.global_idx = self.batched_shuffle(self.task_to_dataset)
    
    @staticmethod
    def batched_shuffle(task_to_dataset: Dict[str, int]) -> List[Dict[str, Any]]:
        task_idxs_batches = [] 
        for task, task_data in task_to_dataset.items():
            data_size = task_data['datasize']
            sample_size = task_data['sample_size']
            sample_smalldata_size = task_data['sample_smalldata_size']
            if sample_smalldata_size is not None:
                reduce_ratio = data_size // sample_smalldata_size 
                if reduce_ratio > 1:
                    shuffled_idxs = np.random.permutation(sample_smalldata_size)
                    shuffled_idxs = shuffled_idxs * reduce_ratio
                else:
                    shuffled_idxs = np.random.permutation(data_size)
            else:
                shuffled_idxs = np.random.permutation(data_size)

            batch_size = task_data['batch_size']
            if sample_smalldata_size is not None:
                sample_data_size = sample_smalldata_size
            else:
                sample_data_size = data_size if sample_size < 0 else min(sample_size, data_size)
            local_batched_shuffled_idxs = [shuffled_idxs[i:i+batch_size] for i in range(0, sample_data_size, batch_size)]
            if len(local_batched_shuffled_idxs[-1]) < batch_size:
                local_batched_shuffled_idxs.pop()
            task_idxs_batches.extend([{"task_name": task, "batch_idxs": idxs} for idxs in local_batched_shuffled_idxs])
        
        random.shuffle(task_idxs_batches)

        batched_task_idx = []
        global_idx = []
        count = 0
        for task_batch in task_idxs_batches:
            batched_task_idx.extend([{"task": task_batch['task_name'], "idx": int(idx)} for idx in task_batch['batch_idxs']])
            global_idx.append(list(range(count, count + len(task_batch['batch_idxs']))))
            count += len(task_batch['batch_idxs'])
        return np.array(batched_task_idx), global_idx
                
    def __getitem__(self, idx: int) -> Dict[str, Any]:
        task_data_idx = self.task_data_idxs[idx]
        task_name = task_data_idx['task']
        local_idx = task_data_idx['idx']
        example = self.task_to_dataset[task_name]['dataset'][local_idx]
        return example

@dataclass
class TripleCollatorMNKD(DataCollatorWithPadding):
    """Collate ``UniDataset`` examples into one tokenized batch of (query, document) pairs.

    Each feature is expanded into one pair per entry of its ``docs`` list (positive first,
    then negatives, in the order the dataset produced them). All pairs of the batch are
    flattened into a single list before tokenization. The first feature's
    ``sub_batchsize`` is attached to the output unchanged.
    """
    max_length: int = 1024
    tokenizer: PreTrainedTokenizer
    eod_token: str = None

    def __post_init__(self):
        # NOTE: mutates the tokenizer object passed in; every other user of it sees left padding too.
        self.tokenizer.padding_side = 'left'

    def truncation(self, text, max_length):
        """NOTE(review): tokenizes ``text`` but discards the tokens and returns None; not called in the visible code."""
        tokens = self.tokenizer.tokenize(text)

    def format_data(self, query_text, query_image_path, doc_text, doc_image_path):
        """Build the system + user chat messages for one (query, document) pair.

        ``query_text`` may be an ``(instruct, query)`` tuple; otherwise a default
        retrieval instruction is used.
        """
        inputs = []
        inputs.append({
            "role": "system",
            "content": [{
                "type": "text",
                "text": "Judge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\"."
            }
            ]
        })
        if isinstance(query_text, tuple):
            instruct, query_text = query_text
        else:
            instruct = "Retrieval document that can answer user's question."
        contents = []
        contents.append({
            "type": "text",
            "text": '<Instruct>: ' + instruct
        })
        contents.extend(format_content(query_text, query_image_path, prefix='<Query>:'))
        contents.extend(format_content(doc_text, doc_image_path, prefix='\n<Document>:'))
        inputs.append({
            "role": "user",
            "content": contents
        })
        return inputs

    def _tokenize(self, to_tokenize, max_length, key='pos'):
        """Format each ``(query, query_image, doc, doc_image)`` tuple and tokenize the batch.

        NOTE(review): the chat message lists are handed straight to ``self.tokenizer``; this
        assumes the tokenizer/processor accepts message lists — confirm. ``key`` is unused.
        """
        to_tokenize = [self.format_data(x[0], x[1], x[2], x[3]) for x in to_tokenize]
        out = self.tokenizer(
            to_tokenize, padding=True, truncation='longest_first',
            max_length=max_length
        )
        return out

    def __call__(self, features):
        """Tokenize every (query, doc) pair of ``features`` without mutating the features."""
        first = features[0]
        sub_batchsize = first.get('sub_batchsize')
        max_length = first.get('max_length')
        if max_length is None:
            # Dataset did not specify a length: use the collator's default.
            max_length = self.max_length
        examples = [
            (f['query'], f['query_image'], doc, image)
            for f in features
            for doc, image in zip(f['docs'], f['doc_images'])
        ]
        outputs = self._tokenize(examples, max_length=max_length)
        outputs['sub_batchsize'] = sub_batchsize
        return outputs
