# -*- coding: utf-8 -*-
import copy
import json
from tqdm import tqdm
from typing import List, Optional, Dict, Sequence

import torch
from datasets import load_dataset, Dataset
from torch import nn
from torch.utils import data
from dataclasses import dataclass, field
import transformers
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

# Model identifier passed to every `from_pretrained` call in this module
# (config, model weights, and all tokenizers).
MODEL_NAME = "bespokelabs/Bespoke-Stratos-7B"
# Truncation limit, in tokens, applied to each string tokenized by `_tokenize_fn`.
MAX_LENGTH = 4096


def construct_qwen(*, device_map=None) -> nn.Module:
    """Load ``MODEL_NAME`` as a causal language model with bfloat16 weights.

    Args:
        device_map: optional placement strategy forwarded to
            ``AutoModelForCausalLM.from_pretrained`` (e.g. ``"auto"``, which
            requires ``accelerate``). ``None`` (the default) leaves placement
            to ``from_pretrained``'s own default.

    Returns:
        The loaded model.
    """
    config = AutoConfig.from_pretrained(MODEL_NAME, trust_remote_code=True)
    # dtype and placement must be given to the *model* loader: extra kwargs
    # passed to AutoConfig.from_pretrained only become attributes on the
    # config object and do not change how the weights are materialized.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        from_tf=False,
        config=config,
        ignore_mismatched_sizes=False,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        device_map=device_map,
    )
    return model


# +
class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning built from chat-style JSONL records.

    Every record may carry a ``system`` string and must carry a
    ``conversations`` list whose first entry is the user turn and second entry
    the assistant reply (each under ``value``). A record becomes a raw-text
    pair:

    * source: ChatML system + user turns, ending with the opening of the
      assistant turn (``<|im_start|>assistant\\n``);
    * target: the assistant reply text.

    No tokenization happens here; the collator in this module tokenizes the
    strings. ``__getitem__`` returns them under the keys ``input_ids``
    (source) and ``labels`` (target).

    This class subclasses ``datasets.Dataset`` but never runs its
    ``__init__``. Arrow-backed base methods therefore cannot work on it, and
    the ones this module relies on (``select``, ``__repr__``) are overridden
    below with list-based versions.
    """

    def __init__(self, data_path, tokenizer: transformers.PreTrainedTokenizer):
        # ``tokenizer`` is accepted for interface compatibility but not used:
        # sources/targets are stored as raw strings.
        print(f"Loading jsonl from {data_path}...")
        records = []
        with open(data_path, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank lines such as a trailing empty line
                records.append(json.loads(line))

        sources = []
        targets = []
        for record in records:
            system_content = record.get('system', '')
            conversations = record.get('conversations', [])

            source = f"<|im_start|>system\n{system_content}<|im_end|>\n"
            source += f"<|im_start|>user\n{conversations[0]['value']}<|im_end|>\n<|im_start|>assistant\n"
            # NOTE(review): no end marker (<|im_end|> / EOS) is appended to the
            # target, so the end of the assistant turn is never supervised —
            # confirm this is intended.
            target = conversations[1]['value']
            sources.append(source)
            targets.append(target)

        print('Successfully initialize dataset!')

        self.sources = sources
        self.targets = targets

    def __len__(self):
        return len(self.sources)

    def __getitem__(self, i):
        # A list of indices returns column-wise lists. Presumably this serves
        # the batched ``__getitems__`` inherited from ``datasets.Dataset``
        # (used by torch's DataLoader), which expects this shape — confirm.
        if isinstance(i, list):
            return {
                "input_ids": [self.sources[idx] for idx in i],
                "labels": [self.targets[idx] for idx in i],
            }
        else:
            return dict(input_ids=self.sources[i], labels=self.targets[i])

    def select(self, indices):
        """Return a new dataset of the same class holding only ``indices``.

        Overrides ``datasets.Dataset.select``, whose Arrow-backed
        implementation needs state ``__init__`` never sets up.
        """
        subset = object.__new__(type(self))
        subset.sources = [self.sources[i] for i in indices]
        subset.targets = [self.targets[i] for i in indices]
        return subset

    def __repr__(self):
        # The inherited ``__repr__`` reads Arrow metadata that is never set.
        return f"{type(self).__name__}(num_rows={len(self)})"


# -

class ReformattedSupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning built from pre-split JSONL records.

    Each JSON line supplies the prompt under ``input_ids`` and the response
    under ``labels``. These are the same keys ``SupervisedDataset.__getitem__``
    returns. Values are stored untouched. They are presumably raw strings,
    since the collator in this module concatenates and tokenizes them;
    confirm against whatever produced the file.

    This class subclasses ``datasets.Dataset`` without running its
    ``__init__``. The Arrow-backed base methods this module relies on
    (``select``, ``__repr__``) are therefore overridden with list-based
    versions.
    """

    def __init__(self, data_path, tokenizer: transformers.PreTrainedTokenizer):
        # ``tokenizer`` is accepted for interface parity but not used.
        print(f"Loading jsonl from {data_path}...")
        records = []
        with open(data_path, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank lines such as a trailing empty line
                records.append(json.loads(line))

        # ``.get`` yields None for a missing key. Such a row only fails later,
        # when the collator concatenates prompt and response.
        sources = [record.get("input_ids") for record in records]
        targets = [record.get("labels") for record in records]

        print('Successfully initialize dataset!')

        self.sources = sources
        self.targets = targets

    def __len__(self):
        return len(self.sources)

    def __getitem__(self, i):
        # A list of indices returns column-wise lists. Presumably this serves
        # the batched ``__getitems__`` inherited from ``datasets.Dataset``
        # (used by torch's DataLoader) — confirm.
        if isinstance(i, list):
            return {
                "input_ids": [self.sources[idx] for idx in i],
                "labels": [self.targets[idx] for idx in i],
            }
        else:
            return dict(input_ids=self.sources[i], labels=self.targets[i])

    def select(self, indices):
        """Return a new dataset of the same class holding only ``indices``.

        Overrides ``datasets.Dataset.select``, whose Arrow-backed
        implementation needs state ``__init__`` never sets up.
        """
        subset = object.__new__(type(self))
        subset.sources = [self.sources[i] for i in indices]
        subset.targets = [self.targets[i] for i in indices]
        return subset

    def __repr__(self):
        # The inherited ``__repr__`` reads Arrow metadata that is never set.
        return f"{type(self).__name__}(num_rows={len(self)})"


def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize each string separately, truncating to ``MAX_LENGTH`` tokens.

    Returns:
        dict with
        ``input_ids`` / ``labels``: the *same* list of 1-D id tensors, one per
            string; callers must copy before mutating ``labels``.
        ``input_ids_lens`` / ``labels_lens``: the *same* list of real-token
            counts, one per string.
    """
    # Strings are tokenized one at a time, so padding="longest" never adds
    # padding here.
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=MAX_LENGTH,
            truncation=True,
        )
        for text in strings
    ]
    input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
    # Count real tokens from the attention mask instead of comparing ids with
    # pad_token_id. A text may legitimately contain the pad token (e.g. when
    # pad == eos), and the id comparison would then undercount its length and
    # shift the prompt mask that ``preprocess`` derives from these counts.
    input_ids_lens = labels_lens = [
        tokenized.attention_mask.sum().item() for tokenized in tokenized_list
    ]
    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )


def preprocess(
    sources: Sequence[str],
    targets: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Tokenize prompt+response pairs and mask the prompt part of the labels.

    Each pair is concatenated and tokenized. The prompt is also tokenized on
    its own to learn how many leading tokens belong to it. Those leading
    positions in ``labels`` are overwritten with -100, the conventional
    ignore index. If the prompt alone fills the ``MAX_LENGTH`` truncation
    window, every label position ends up masked.

    Returns:
        dict with ``input_ids`` (list of 1-D id tensors) and ``labels``
        (independent copies of those tensors with the prompt prefix masked).
    """
    joined = [prompt + response for prompt, response in zip(sources, targets)]
    full = _tokenize_fn(joined, tokenizer)
    prompt_only = _tokenize_fn(sources, tokenizer)

    input_ids = full["input_ids"]
    labels = copy.deepcopy(input_ids)  # copy so masking leaves input_ids intact
    for label_ids, n_prompt_tokens in zip(labels, prompt_only["input_ids_lens"]):
        label_ids[:n_prompt_tokens] = -100
    return dict(input_ids=input_ids, labels=labels)


@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning.

    Produces right-padded ``input_ids`` (padded with the tokenizer's pad id)
    and ``labels`` (padded with -100), plus a boolean ``attention_mask``.
    """

    tokenizer: transformers.PreTrainedTokenizer

    def _pad_batch(self, input_ids, labels) -> Dict[str, torch.Tensor]:
        """Right-pad per-example 1-D tensors into a batch and build the mask.

        The mask is derived from each example's true length rather than by
        comparing ids with ``pad_token_id``. Real tokens that share the pad
        id (e.g. when pad == eos) therefore stay attended.
        """
        padded_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        padded_labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)
        lengths = torch.tensor([len(ids) for ids in input_ids], device=padded_ids.device)
        positions = torch.arange(padded_ids.size(1), device=padded_ids.device)
        attention_mask = positions.unsqueeze(0) < lengths.unsqueeze(1)
        return dict(
            input_ids=padded_ids,
            labels=padded_labels,
            attention_mask=attention_mask,
        )

    def naive__call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        """Collate instances whose ``input_ids``/``labels`` are already 1-D tensors.

        No tokenization is done. The datasets in this module yield raw strings
        and must go through ``__call__`` instead.
        """
        input_ids = [instance["input_ids"] for instance in instances]
        labels = [instance["labels"] for instance in instances]
        return self._pad_batch(input_ids, labels)

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        """Tokenize raw-text instances and collate them into a padded batch.

        Each instance's ``input_ids`` holds the prompt text and ``labels`` the
        response text. ``preprocess`` concatenates and tokenizes each pair and
        masks the prompt portion of the labels.
        """
        sources = [instance['input_ids'] for instance in instances]
        targets = [instance['labels'] for instance in instances]
        data_dict = preprocess(sources, targets, self.tokenizer)
        return self._pad_batch(data_dict['input_ids'], data_dict['labels'])


def get_data_collator():
    """Build a ``DataCollatorForSupervisedDataset`` around the ``MODEL_NAME`` fast tokenizer."""
    return DataCollatorForSupervisedDataset(
        tokenizer=AutoTokenizer.from_pretrained(
            MODEL_NAME, use_fast=True, trust_remote_code=True
        )
    )


def _load_jsonl_dataset(dataset_cls, data_path: str, indices: Optional[List[int]]) -> data.Dataset:
    """Build ``dataset_cls`` from ``data_path`` and optionally keep only ``indices``.

    Shared body of the ``get_*_dataset`` loaders below.

    Args:
        dataset_cls: class constructed as ``dataset_cls(data_path, tokenizer)``.
        data_path: path of the JSON-lines file to read.
        indices: positions to keep (via ``ds.select``), or ``None`` for all rows.
    """
    # Loaded because the dataset constructors accept a tokenizer.
    # NOTE(review): neither dataset class in this module reads it, so this
    # load could probably be dropped — confirm before removing.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True, trust_remote_code=True)
    ds = dataset_cls(data_path, tokenizer)
    if indices is not None:
        # ``select`` must be implemented by ``dataset_cls`` itself. The
        # inherited Arrow-backed ``datasets.Dataset.select`` needs state these
        # list-backed classes never initialize.
        ds = ds.select(indices)
    return ds


def get_bs_dataset(
    indices: Optional[List[int]] = None,
) -> data.Dataset:
    """Load ``bs17k.jsonl`` as a ``SupervisedDataset``, optionally restricted to ``indices``."""
    return _load_jsonl_dataset(SupervisedDataset, "bs17k.jsonl", indices)


def get_top_influential_test_math_dataset(
    indices: Optional[List[int]] = None,
) -> data.Dataset:
    """Load ``test_math_top_influence_traindata.jsonl`` as a ``ReformattedSupervisedDataset``."""
    return _load_jsonl_dataset(
        ReformattedSupervisedDataset, "test_math_top_influence_traindata.jsonl", indices
    )


def get_think_truncate_train_math_dataset(
    indices: Optional[List[int]] = None,
) -> data.Dataset:
    """Load ``test_math_think_truncate_traindata.jsonl`` as a ``ReformattedSupervisedDataset``."""
    return _load_jsonl_dataset(
        ReformattedSupervisedDataset, "test_math_think_truncate_traindata.jsonl", indices
    )


def get_top_influential_test_code_dataset(
    indices: Optional[List[int]] = None,
) -> data.Dataset:
    """Load ``test_code_top_influence_traindata.jsonl`` as a ``ReformattedSupervisedDataset``."""
    return _load_jsonl_dataset(
        ReformattedSupervisedDataset, "test_code_top_influence_traindata.jsonl", indices
    )