import json
from collections import defaultdict
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer


class MixedQADataset(Dataset):
    """Map-style dataset mixing open-ended (rubric-graded) and multiple-choice QA items.

    Items are loaded from one or more JSON files, each holding a top-level list
    of dicts. An item whose ``type`` is ``'openend'`` is treated as open-ended;
    any other (or missing) ``type`` is treated as multiple choice.

    Each ``__getitem__`` returns a dict with ``data_source``, ``extra_info``,
    ``reward_model`` and ``tools_kwargs`` entries. When a tokenizer is given,
    it also returns fixed-length, left-padded ``input_ids`` / ``attention_mask``
    / ``position_ids`` tensors. When ``return_raw_chat`` is enabled, it also
    returns the chat message list as ``raw_prompt``.
    """

    def __init__(
        self,
        data_files: Union[str, List[str]] = None,
        tokenizer: PreTrainedTokenizer = None,
        processor=None,
        config=None,
        max_samples: int = -1,
        data_path: str = None,
        max_prompt_length: int = 4000,
        max_response_length: int = 12000,
        return_raw_chat: bool = True,
        truncation: str = 'left',
        warmup_qa_first: int = 0,
        **kwargs
    ):
        """Load the data, then optionally reorder and subsample it.

        Args:
            data_files: Path, or list of paths, to ``.json`` files. Takes
                precedence over ``data_path``.
            tokenizer: Object exposing ``apply_chat_template``,
                ``pad_token_id`` and ``eos_token_id``. If ``None``, no token
                tensors are produced.
            processor: Stored on the instance; not used by this class.
            config: Optional object. Each of its attributes
                ``max_prompt_length``, ``max_response_length``,
                ``return_raw_chat``, ``truncation`` and ``warmup_qa_first``,
                when present, overrides the keyword argument of the same name.
            max_samples: If positive, keep only the first ``max_samples`` items.
                This is applied after the warm-up reordering.
            data_path: Fallback path(s), used when ``data_files`` is ``None``.
            max_prompt_length: Length every tokenized prompt is padded or
                truncated to.
            max_response_length: Stored on the instance; not used by this class.
            return_raw_chat: Include the chat message list as ``raw_prompt``.
            truncation: How over-long prompts are handled:

                - ``'left'`` keeps the last tokens.
                - ``'error'`` raises ``ValueError``.
                - Any other value (e.g. ``'right'``) keeps the first tokens.
            warmup_qa_first: Move this many open-ended items to the front.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If neither ``data_files`` nor ``data_path`` is given,
                or a file is not a ``.json`` file containing a list.
        """
        self.tokenizer = tokenizer
        self.processor = processor

        # Config attributes, when present, take precedence over keyword arguments.
        if config is not None:
            self.max_prompt_length = getattr(config, 'max_prompt_length', max_prompt_length)
            self.max_response_length = getattr(config, 'max_response_length', max_response_length)
            self.return_raw_chat = getattr(config, 'return_raw_chat', return_raw_chat)
            self.truncation = getattr(config, 'truncation', truncation)
            self.warmup_qa_first = getattr(config, 'warmup_qa_first', warmup_qa_first)
        else:
            self.max_prompt_length = max_prompt_length
            self.max_response_length = max_response_length
            self.return_raw_chat = return_raw_chat
            self.truncation = truncation
            self.warmup_qa_first = warmup_qa_first

        actual_data_path = data_files if data_files is not None else data_path
        if actual_data_path is None:
            raise ValueError("Either data_files or data_path must be provided")

        self.data = self._load_data(actual_data_path)

        if self.warmup_qa_first > 0:
            self.data = self._reorder_qa_first(self.data, self.warmup_qa_first)

        if max_samples > 0 and len(self.data) > max_samples:
            self.data = self.data[:max_samples]

    def _load_data(self, data_path: Union[str, List[str]]) -> List[Dict[str, Any]]:
        """Load and concatenate items from one or more ``.json`` files.

        Args:
            data_path: A single path (``str`` or path-like), or an iterable of
                paths. Each path must point to a JSON file whose top level is
                a list.

        Returns:
            All items from all files, in file order.

        Raises:
            ValueError: If a path does not end in ``.json``, or a file's top
                level is not a list.
        """
        # A single path (str / path-like) is wrapped. Anything else (list,
        # tuple, config list types) is iterated as a sequence of paths.
        if isinstance(data_path, str) or hasattr(data_path, '__fspath__'):
            paths = [data_path]
        else:
            paths = list(data_path)

        data: List[Dict[str, Any]] = []
        for path in paths:
            data.extend(self._load_json_file(str(path)))
        return data

    @staticmethod
    def _load_json_file(path: str) -> List[Dict[str, Any]]:
        """Read one ``.json`` file and return its top-level list.

        Raises:
            ValueError: If ``path`` does not end in ``.json`` or the decoded
                value is not a list.
        """
        if not path.endswith('.json'):
            raise ValueError(f"Unsupported file format: {path}. Only .json is supported.")
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        if not isinstance(data, list):
            raise ValueError(f"Expected list in JSON file, got {type(data)}")
        return data

    def _reorder_qa_first(self, data: List[Dict[str, Any]], n: int) -> List[Dict[str, Any]]:
        """Move the first ``n`` open-ended items to the front.

        The result is ``qa[:n] + mc + qa[n:]``. ``qa`` holds the items with
        ``type == 'openend'``; ``mc`` holds every other item. Relative order
        within each group is preserved. ``n`` is capped at the number of
        open-ended items.
        """
        qa_data = []
        mc_data = []
        for item in data:
            if item.get('type', 'mc') == 'openend':
                qa_data.append(item)
            else:
                mc_data.append(item)

        n = min(n, len(qa_data))
        return qa_data[:n] + mc_data + qa_data[n:]

    def __len__(self) -> int:
        return len(self.data)

    @staticmethod
    def _build_ground_truth(item: Dict[str, Any], prompt: List[Dict[str, Any]],
                            data_type: str) -> Dict[str, Any]:
        """Build the ``reward_model['ground_truth']`` payload for one item.

        Open-ended items carry their rubrics and the chat prompt. All other
        items carry the multiple-choice answer index, options and question.
        """
        if data_type == 'openend':
            return {
                'type': 'openend',
                'rubrics': item.get('rubrics', []),
                'prompt': prompt,
            }
        return {
            'type': 'mc',
            'answer_idx': item.get('answer_idx', ''),
            'options': item.get('options', {}),
            'question': item.get('question', ''),
        }

    def _tokenize_prompt(self, prompt: List[Dict[str, Any]]):
        """Tokenize chat messages into fixed-length, left-padded tensors.

        Returns:
            ``(input_ids, attention_mask, position_ids)``, each a 1-D
            ``torch.long`` tensor of length ``self.max_prompt_length``.
            ``position_ids`` are counted over real tokens only. Every pad slot
            gets 0, and the prompt tokens get 0, 1, 2, ... regardless of how
            much padding precedes them.

        Raises:
            ValueError: If the prompt is too long and ``truncation == 'error'``,
                or padding is needed but the tokenizer has neither
                ``pad_token_id`` nor ``eos_token_id``.
        """
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = self.tokenizer.eos_token_id

        input_ids = self.tokenizer.apply_chat_template(
            prompt,
            add_generation_prompt=True,
            tokenize=True,
        )

        max_len = self.max_prompt_length
        seq_len = len(input_ids)

        if seq_len > max_len:
            if self.truncation == 'error':
                raise ValueError(
                    f"Prompt length {seq_len} exceeds max_prompt_length {max_len}"
                )
            if self.truncation == 'left':
                # Explicit start index: `[-max_len:]` would keep everything
                # when max_len == 0.
                input_ids = input_ids[seq_len - max_len:]
            else:
                input_ids = input_ids[:max_len]
            attention_mask = [1] * max_len
        else:
            pad_len = max_len - seq_len
            if pad_len > 0 and pad_token_id is None:
                raise ValueError(
                    "Tokenizer defines neither pad_token_id nor eos_token_id; "
                    "cannot pad the prompt"
                )
            input_ids = [pad_token_id] * pad_len + list(input_ids)
            attention_mask = [0] * pad_len + [1] * seq_len

        input_ids = torch.tensor(input_ids, dtype=torch.long)
        attention_mask = torch.tensor(attention_mask, dtype=torch.long)
        # Derive positions from the mask so left padding does not shift the
        # positions of the real tokens.
        position_ids = torch.clamp(torch.cumsum(attention_mask, dim=-1) - 1, min=0)
        return input_ids, attention_mask, position_ids

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Return the training record for item ``idx``.

        A non-list ``prompt`` is wrapped as a single user message. Items whose
        ``type`` is ``'openend'`` get reward style ``'rubric'``; all others get
        ``'mc'``.
        """
        item = self.data[idx]

        prompt = item.get('prompt', [])
        if not isinstance(prompt, list):
            prompt = [{"role": "user", "content": str(prompt)}]

        data_type = item.get('type', 'mc')

        result = {
            'data_source': item.get('source', 'mixed_qa'),
            'extra_info': {
                'source': item.get('source', ''),
                'response_idx': item.get('response_idx', 0),
                'token_length': item.get('token_length', 0),
                'index': idx,
            },
            'reward_model': {
                'ground_truth': self._build_ground_truth(item, prompt, data_type),
                'style': 'rubric' if data_type == 'openend' else 'mc',
            },
            'tools_kwargs': {},
        }

        if self.tokenizer is not None:
            input_ids, attention_mask, position_ids = self._tokenize_prompt(prompt)
            result['input_ids'] = input_ids
            result['attention_mask'] = attention_mask
            result['position_ids'] = position_ids

        if self.return_raw_chat:
            result['raw_prompt'] = prompt

        return result


def collate_fn(data_list: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Collate a list of per-sample dicts into one batch dict.

    Tensor values are stacked along a new leading batch dimension with
    ``torch.stack``. Every other value is gathered into a 1-D NumPy ``object``
    array holding one element per sample.

    Args:
        data_list: Sample dicts, e.g. as returned by a dataset's ``__getitem__``.

    Returns:
        Mapping from key to a stacked tensor or a 1-D object array. Keys seen
        with tensor values come first, each group in first-seen order. An
        empty ``data_list`` yields an empty dict.
    """
    grouped_tensors: Dict[str, List[torch.Tensor]] = {}
    grouped_others: Dict[str, List[Any]] = {}

    for sample in data_list:
        for name, value in sample.items():
            bucket = grouped_tensors if isinstance(value, torch.Tensor) else grouped_others
            bucket.setdefault(name, []).append(value)

    batch = {name: torch.stack(values, dim=0) for name, values in grouped_tensors.items()}
    # fromiter with an explicit count builds a flat 1-D array even when the
    # per-sample values are themselves lists.
    batch.update(
        (name, np.fromiter(values, dtype=object, count=len(values)))
        for name, values in grouped_others.items()
    )
    return batch

