from datasets import concatenate_datasets, load_dataset, DatasetDict
from functools import partial

from transformers import DataCollatorWithPadding,DataCollatorForSeq2Seq
import numpy as np
from typing import Optional, Union, List, Dict, Tuple, Any
from utils import _prepare_4d_causal_attention_mask_with_cache_position, hf_attention_mask_2d_to_4d,custom_attention_mask_2d_to_4d

import torch

def preprocess_function(example,tokenizer,data_config):
    """Tokenize one example and attach step-level labels.

    The prompt is ``"{question} {process}"``, tokenized with truncation to
    ``data_config['max_length']``. Every position whose token id equals
    ``data_config['step_tag_id']`` is treated as a step boundary: its label is
    set to one of ``data_config['candidate_tokens']`` and its attention-mask
    entry is set to 0. Every other position gets the label ``-100``
    (presumably the loss ignore index -- confirm against the trainer).

    Args:
        example: Mapping with ``'question'``, ``'process'`` and ``'label'``.
            ``'label'`` is a sequence with one entry per step; ``'+'`` or ``1``
            selects ``candidate_tokens[0]``, ``'-'`` or ``0`` selects
            ``candidate_tokens[1]``. When the label count differs from the
            number of step tags found, ``example['label']`` is overwritten
            in place with its first ``len(step tags)`` entries.
        tokenizer: Callable returning a mapping with ``'input_ids'`` and
            ``'attention_mask'`` lists.
        data_config: Mapping with ``'candidate_tokens'``, ``'step_tag_id'``
            and ``'max_length'``.

    Returns:
        The tokenizer output with an added ``'labels'`` list.

    Raises:
        ValueError: If more step tags than labels are present, or a label is
            not one of ``'+'``, ``'-'``, ``1``, ``0``.
    """
    candidate_tokens = data_config['candidate_tokens']
    step_tag_id = data_config['step_tag_id']

    text = f"{example['question']} {example['process']}"
    tokenized_inputs = tokenizer(
        text,
        truncation=True,
        max_length=data_config['max_length'],
    )

    input_ids = tokenized_inputs['input_ids']
    step_positions = [pos for pos, token_id in enumerate(input_ids) if token_id == step_tag_id]

    if len(step_positions) != len(example['label']):
        # Surplus labels (e.g. for step tags lost to truncation) are dropped.
        example['label'] = example['label'][:len(step_positions)]
    step_labels = example['label']

    # Slicing cannot add labels, so a mismatch here means too few labels.
    # Raise explicitly instead of asserting so the check survives ``python -O``.
    if len(step_positions) != len(step_labels):
        raise ValueError(
            f"found {len(step_positions)} step tags but only "
            f"{len(step_labels)} labels"
        )

    tokenized_inputs['labels'] = [-100] * len(input_ids)
    for pos, step_label in zip(step_positions, step_labels):
        if step_label == '+' or step_label == 1:
            tokenized_inputs['labels'][pos] = candidate_tokens[0]
        elif step_label == '-' or step_label == 0:
            tokenized_inputs['labels'][pos] = candidate_tokens[1]
        else:
            raise ValueError('label is wrong')
        # The step-tag position itself is masked out of attention.
        tokenized_inputs['attention_mask'][pos] = 0

    return tokenized_inputs


def preprocess_function_with_chat_template(example,tokenizer,data_config):
    """Tokenize one example through the chat template and attach step labels.

    ``"{question} {process}"`` is wrapped as a single ``user`` message and
    rendered with ``tokenizer.apply_chat_template(..., tokenize=False,
    add_generation_prompt=True)``. The rendered text is then tokenized with
    truncation to ``data_config['max_length']``. Every position whose token id
    equals ``data_config['step_tag_id']`` is treated as a step boundary: its
    label is set to one of ``data_config['candidate_tokens']`` and its
    attention-mask entry is set to 0. Every other position gets the label
    ``-100`` (presumably the loss ignore index -- confirm against the trainer).

    Args:
        example: Mapping with ``'question'``, ``'process'`` and ``'label'``.
            ``'label'`` is a sequence with one entry per step; ``'+'`` or ``1``
            selects ``candidate_tokens[0]``, ``'-'`` or ``0`` selects
            ``candidate_tokens[1]``. When the label count differs from the
            number of step tags found, ``example['label']`` is overwritten
            in place with its first ``len(step tags)`` entries.
        tokenizer: Object providing ``apply_chat_template`` and callable on
            text, returning a mapping with ``'input_ids'`` and
            ``'attention_mask'`` lists.
        data_config: Mapping with ``'candidate_tokens'``, ``'step_tag_id'``
            and ``'max_length'``.

    Returns:
        The tokenizer output with an added ``'labels'`` list.

    Raises:
        ValueError: If more step tags than labels are present, or a label is
            not one of ``'+'``, ``'-'``, ``1``, ``0``.
    """
    candidate_tokens = data_config['candidate_tokens']
    step_tag_id = data_config['step_tag_id']

    prompt = f"{example['question']} {example['process']}"
    messages = [
        {
            "role": "user",
            "content": prompt,
        },
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    tokenized_inputs = tokenizer(
        text,
        truncation=True,
        max_length=data_config['max_length'],
    )

    input_ids = tokenized_inputs['input_ids']
    step_positions = [pos for pos, token_id in enumerate(input_ids) if token_id == step_tag_id]

    if len(step_positions) != len(example['label']):
        # Surplus labels (e.g. for step tags lost to truncation) are dropped.
        example['label'] = example['label'][:len(step_positions)]
    step_labels = example['label']

    # Slicing cannot add labels, so a mismatch here means too few labels.
    # Raise explicitly instead of asserting so the check survives ``python -O``.
    if len(step_positions) != len(step_labels):
        raise ValueError(
            f"found {len(step_positions)} step tags but only "
            f"{len(step_labels)} labels"
        )

    tokenized_inputs['labels'] = [-100] * len(input_ids)
    for pos, step_label in zip(step_positions, step_labels):
        if step_label == '+' or step_label == 1:
            tokenized_inputs['labels'][pos] = candidate_tokens[0]
        elif step_label == '-' or step_label == 0:
            tokenized_inputs['labels'][pos] = candidate_tokens[1]
        else:
            raise ValueError('label is wrong')
        # The step-tag position itself is masked out of attention.
        tokenized_inputs['attention_mask'][pos] = 0

    return tokenized_inputs



def get_prm_dataset(DATA_PATH,tokenizer,data_config,mode='train',max_length=4096):
    """Load a JSON dataset and tokenize it with ``preprocess_function``.

    Args:
        DATA_PATH: Passed to ``load_dataset('json', data_files=...)``. The
            loaded dataset must expose both a ``'train'`` and a ``'test'``
            split, since both are accessed below.
        tokenizer: Tokenizer forwarded to ``preprocess_function``.
        data_config: Config mapping forwarded to ``preprocess_function``.
            Its ``'max_length'`` entry is overwritten in place with
            ``max_length``.
        mode: ``'test'`` keeps only the first 1000 train rows and the first
            100 test rows (fewer if a split is smaller); any other value uses
            the full splits.
        max_length: Truncation length written into
            ``data_config['max_length']``. Defaults to 4096.

    Returns:
        A ``DatasetDict`` with ``'train'`` and ``'test'`` splits from which the
        ``'question'``, ``'process'`` and ``'label'`` columns have been removed.
    """
    dataset = load_dataset('json', data_files=DATA_PATH,cache_dir=None)

    if mode == 'test':
        # Clamp the subset sizes so small splits do not make ``select`` fail
        # on out-of-range indices.
        train_split = dataset['train']
        test_split = dataset['test']
        dataset = DatasetDict({
            'train': train_split.select(range(min(1000, len(train_split)))),
            'test': test_split.select(range(min(100, len(test_split)))),
        })
    print('start processing')
    data_config['max_length'] = max_length
    print(data_config['max_length'])

    tokenized_datasets = dataset.map(partial(preprocess_function,tokenizer=tokenizer,data_config=data_config))

    # Drop the raw text/label columns; only the tokenizer outputs and labels remain.
    raw_columns = ['question','process','label']
    tokenized_datasets = DatasetDict({
        'train': tokenized_datasets['train'].remove_columns(raw_columns),
        'test': tokenized_datasets['test'].remove_columns(raw_columns),
    })

    print(tokenized_datasets)
    print('dataset processed')
    return tokenized_datasets


class PRMDataCollator:
    """Collator that delegates batching to a ``DataCollatorForSeq2Seq``.

    The wrapped collator is built with ``padding='longest'`` and
    ``return_tensors="pt"``. ``tokenizer``, ``data_config`` and
    ``custom_attention_flag`` are stored on the instance; ``__call__`` does
    not consult ``data_config`` or ``custom_attention_flag``.
    """

    def __init__(self, tokenizer,data_config,custom_attention_flag=False):
        """Store the settings and build the wrapped collator."""
        self.tokenizer = tokenizer
        self.data_config = data_config
        self.custom_attention_flag = custom_attention_flag
        self.text_data_collator = DataCollatorForSeq2Seq(tokenizer, return_tensors="pt", padding='longest')

    def __call__(self, features: List) -> Dict[str,Any]:
        """Return whatever the wrapped collator produces for ``features``."""
        return self.text_data_collator(features)