from typing import Dict, List, Optional

import numpy as np
import datasets
from omegaconf import ListConfig

from inference_rlhf.code.query_builders.factory import query_builder_factory

def load_data(cfg, root):
    """Load a question / answer / choices dataset described by ``cfg``.

    Two sources are supported:

    * ``cfg.path`` set (not absent / None): a dataset saved with
      ``save_to_disk`` is loaded with ``datasets.load_from_disk`` and the
      ``cfg.split`` entry is selected from it.
    * otherwise: ``datasets.load_dataset(cfg.name, subset, split=cfg.split)``
      is called. ``cfg.subset`` may be absent/None, a single config name, or
      a list/tuple of config names; in the list case every subset is loaded
      and the results are concatenated in the listed order.

    Args:
        cfg: Data config (OmegaConf ``DictConfig``). Required keys:
            ``question_field``, ``split`` and either ``path`` or ``name``.
            Optional keys: ``subset``, ``answer_field``, ``choices_field``.
            An optional key that is absent or set to ``None`` is treated as
            unset.
        root: Unused; kept so existing call sites keep working.

    Returns:
        ``(questions, answers, choices)``: parallel lists with one entry per
        dataset row. ``answers`` / ``choices`` are ``None`` when the
        corresponding field is not configured.
    """

    def _optional(key):
        # `key in cfg` avoids the exception recent OmegaConf raises on
        # attribute access of a missing key; a present-but-None value is
        # treated the same as a missing one.
        return cfg[key] if key in cfg else None

    path = _optional("path")
    if path is not None:
        dataset = datasets.load_from_disk(dataset_path=path)
        dataset = dataset[cfg.split]
    else:
        subset = _optional("subset")
        # NOTE(review): trust_remote_code=True executes dataset loading
        # scripts from the hub; only point `cfg.name` at trusted datasets.
        if isinstance(subset, (ListConfig, list, tuple)):
            parts = [
                datasets.load_dataset(cfg.name, sub, split=cfg.split, trust_remote_code=True)
                for sub in subset
            ]
            dataset = datasets.concatenate_datasets(parts)
            del parts
        else:
            dataset = datasets.load_dataset(cfg.name, subset, split=cfg.split, trust_remote_code=True)

    # Column access fetches a whole field in one call instead of decoding
    # every row into a dict; list() guarantees a plain Python list.
    questions = list(dataset[cfg.question_field])

    answer_field = _optional("answer_field")
    answers = list(dataset[answer_field]) if answer_field is not None else None

    choices_field = _optional("choices_field")
    choices = list(dataset[choices_field]) if choices_field is not None else None

    return questions, answers, choices

class BaseDataLoader:
    """Base loader pairing a dataset with a query builder for generation.

    On construction it seeds NumPy's global RNG with ``cfg.seed``, loads the
    data with ``load_data(cfg.task.data, cfg.root)``, builds a query builder
    with ``query_builder_factory`` and selects which question indices to use
    (``self.idxs``). Subclasses may override ``extract_groundtruth`` and
    ``get_rstar``.
    """

    def __init__(self, cfg, **kwargs):
        # Seeds the *global* NumPy RNG; get_idxs() draws from it, so the
        # selected subset is reproducible for a fixed seed.
        np.random.seed(cfg.seed)

        questions, answers, choices = load_data(cfg.task.data, cfg.root)
        self.raw_answers = answers
        self.questions = questions
        self.answers = self.extract_groundtruth(answers)
        self.choices = choices

        self.qb = query_builder_factory(
            cfg.policy.name,
            cfg=cfg.policy,
            task_desc=cfg.task.TASK_DESC,
            shots=cfg.shots,
            **kwargs,
        )

        # max_samples <= 0 (or None) means "use every question".
        max_samples = cfg.task.max_samples
        if max_samples is None or max_samples <= 0:
            self.num = len(questions)
        else:
            self.num = min(max_samples, len(questions))
        self.idxs = self.get_idxs()

    def get_idxs(self):
        """Return the sorted question indices to use.

        Returns every index when ``self.num`` equals the number of questions.
        Otherwise returns ``self.num`` distinct indices drawn without
        replacement from NumPy's global RNG. Indices are plain ``int`` rather
        than ``np.int64``, so they serialise cleanly (e.g. to JSON).
        """
        total = len(self.questions)
        if self.num == total:
            return list(range(total))
        picked = np.random.choice(total, self.num, replace=False)
        return sorted(int(i) for i in picked)

    def extract_groundtruth(self, answers):
        """Map raw answers to ground truth; the base class returns them unchanged."""
        return answers

    def get_rstar(self, idx, response):
        """Hook for extra per-response fields; the base class returns ``{}``."""
        return {}

    @staticmethod
    def _avg_logprob(output):
        """Mean per-token log-prob of ``output``, or None when undefined."""
        n_tokens = len(output.token_ids)
        if output.cumulative_logprob is None or n_tokens == 0:
            return None
        return output.cumulative_logprob / n_tokens

    def parse_responses(self, responses, generation_idx: int):
        """Flatten generation results into one record per completion.

        Args:
            responses: Iterable of request outputs, each exposing ``.outputs``.
                Each output exposes ``.text``, ``.cumulative_logprob`` and
                ``.token_ids``. (Presumably vLLM ``RequestOutput`` /
                ``CompletionOutput``; confirm against the caller.)
            generation_idx: Stored as ``prompt_idx`` on every record.

        Returns:
            List of dicts with keys ``prompt_idx``, ``response``,
            ``sum_logprob`` and ``avg_logprob``. ``avg_logprob`` is the
            per-token mean, or ``None`` when the output has no tokens or no
            cumulative log-prob (previously this raised).
        """
        # NOTE: get_rstar() results are not merged into these records.
        return [
            {
                'prompt_idx': generation_idx,
                'response': output.text,
                'sum_logprob': output.cumulative_logprob,
                'avg_logprob': self._avg_logprob(output),
            }
            for response in responses
            for output in response.outputs
        ]

    def build_queries(self, apply_chat_template: bool = True):
        """Build one query per selected question, in ``self.idxs`` order."""
        return [
            self.qb.build_query(self.questions[idx], apply_chat_template=apply_chat_template)
            for idx in self.idxs
        ]

