import json
import os
import torch
import math
from typing import Sequence, List, Set, Dict
import transformers
from torch.utils.data import Dataset
from dataclasses import dataclass
from torch.utils.data import DataLoader
from transformers.trainer_pt_utils import LabelSmoother
import sys
import re
from verl.utils.conversation import get_conv_template

# Label value marking token positions excluded from training targets in this
# module (prompt, observation and padding positions). Taken from transformers'
# LabelSmoother.ignore_index so it agrees with the loss that ignores it.
IGNORE_TOKEN_ID = LabelSmoother.ignore_index

# Per-task regex that extracts the evaluation metric from a script's stdout;
# the metric is capture group 1. Numbers are matched as ``[0-9]*\.?[0-9]+``
# rather than ``[0-9.]+``: that keeps a trailing sentence period out of the
# captured value (which would make ``float()`` fail) and, unlike a ``[0-9.]+``
# followed by a wildcard, never drops the last digit of a value that ends a
# line or the output.
MATCH_TASK_PATTERN_DICT = {
    "cifar10": r"Test Accuracy:\s*([0-9]*\.?[0-9]+)%",
    "dog-breed-identification": r", Test Accuracy:\s*([0-9]*\.?[0-9]+)",
    "feedback": r"Validation MCRMSE:\s*([0-9]*\.?[0-9]+)",
    "house-price": r"Validation MAE:\s*([0-9]*\.?[0-9]+)",
    "aerial-cactus-identification": r"Test Accuracy:\s*([0-9]*\.?[0-9]+)%",
    "dogs-vs-cats-redux-kernels-edition": r"Test Accuracy:\s*([0-9]*\.?[0-9]+)%",
    "nomad2018-predict-transparent-conductors": r"average root mean squared log error:\s*([0-9]*\.?[0-9]+)",
    "ogbn-arxiv": r"Test Accuracy:\s*([0-9]*\.?[0-9]+)%",
    "plant-pathology-2020-fgvc7": r"Test Accuracy:\s*([0-9]*\.?[0-9]+)",
    "spaceship-titanic": r"Validation Accuracy:\s*([0-9]*\.?[0-9]+)%",
}

# Metric value used as the score of a trajectory before any script has been
# executed (the starting "state" when converting scores into reward deltas).
BENCHMARK_BORDERLINE_DICT = {
    "dog-breed-identification": 0.78,
    "cifar10": 51.48,
    "ogbn-arxiv": 31.35,
    "house-price": 18450,
    "aerial-cactus-identification": 75.43,
    "dogs-vs-cats-redux-kernels-edition": 79.32,
    "plant-pathology-2020-fgvc7": 0.3934,
    "spaceship-titanic": 57.12,
    "feedback": 0.68298,
    "nomad2018-predict-transparent-conductors": 0.0674,
}

# Multiplier applied to each metric delta before the sigmoid squashing. A
# negative factor means a *decrease* of the metric is rewarded (lower-is-better
# metrics such as MAE / MCRMSE / RMSLE).
# NOTE(review): "dog-breed-identification" has a pattern and a borderline but no
# scale here, so its rewards cannot be scaled. "dogs-vs-cats-redux-kernels-edition"
# is parsed from a "Test Accuracy" line yet has a negative scale — confirm sign.
BENCHMARK_SCALE_DICT = {
    "cifar10": 2.2701475595913735,
    "ogbn-arxiv": 1.749475157452764,
    "house-price": -0.005420054200542005,
    "aerial-cactus-identification": 4.070004070004069,
    "dogs-vs-cats-redux-kernels-edition": -1.3154779131258385,
    "plant-pathology-2020-fgvc7": 169.19042382201167,
    "spaceship-titanic": 2.550434849141779,
    "feedback": -400.60250616927857,
    "nomad2018-predict-transparent-conductors": -6097.560975609755,
}


def make_training_verifier_data_module(
    tokenizer: transformers.PreTrainedTokenizer, data_args: dataclass
) -> Dict:
    """Build the process-verifier (reward model) train and eval datasets.

    Reads ``<data_args.data_dir>/<data_args.task>/reward/train/data.json`` and
    ``.../reward/test/data.json`` and wraps each in a ``ProcessVerifierDataset``
    configured from ``data_args.template_name``, ``data_args.loss_level`` and
    ``data_args.loss_on_llm``.

    Returns:
        ``{"train_dataset": ..., "eval_dataset": ...}``.
    """
    data_dir = os.path.join(data_args.data_dir, data_args.task)

    def _build(split):
        # Load the split and close the file before tokenization starts.
        with open(os.path.join(data_dir, "reward", split, "data.json"), "r", encoding="utf-8") as f:
            samples = json.load(f)
        return ProcessVerifierDataset(
            samples=samples,
            tokenizer=tokenizer,
            template_name=data_args.template_name,
            loss_level=data_args.loss_level,
            loss_on_llm=data_args.loss_on_llm,
        )

    return dict(train_dataset=_build("train"), eval_dataset=_build("test"))


def preprocess_sft(
    sources, tokenizer: transformers.PreTrainedTokenizer, template_name: str
) -> Dict:
    """Render SFT trajectories with a chat template and build aligned LM labels.

    Each sample must provide ``input`` (first user turn), ``action_sequence``
    (assistant turns) and ``observation_sequence`` (user turns; entry ``k`` is
    inserted after action ``k`` for every action except the last). Only action
    tokens are kept as labels; prompt and observation tokens are set to
    ``IGNORE_TOKEN_ID``.

    Labels are built by re-tokenizing the growing prompt after each turn and
    taking the newly added tokens. This assumes the tokenization of a prefix of
    the prompt is a prefix of the full tokenization.

    Args:
        sources: list of trajectory dicts (see above).
        tokenizer: tokenizer used both for the labels and the final input ids.
        template_name: conversation template; only ``"llama-2"`` and
            ``"llama-3"`` are supported.

    Returns:
        ``{"input_ids": list of token-id lists, "labels": list of 1-D tensors}``.

    Raises:
        ValueError: if ``template_name`` is not supported.
    """
    # Label construction below uses template-specific trimming. With any other
    # template the final action's tokens would never be appended, leaving the
    # labels shorter than input_ids, so fail fast instead.
    if template_name not in ("llama-2", "llama-3"):
        raise ValueError(f"Unsupported template_name for preprocess_sft: {template_name!r}")

    conv = get_conv_template(template_name)
    conv_strs = []
    targets = []
    print(f"Preprocessing data...{len(sources)}")
    for sample in sources:
        conv.messages = []
        conv.append_message(conv.roles[0], sample['input'])
        conv.append_message(conv.roles[1], None)
        # The whole initial prompt (including the empty assistant slot) is masked.
        target = [IGNORE_TOKEN_ID] * len(tokenizer(conv.get_prompt()).input_ids)

        actions = sample['action_sequence']
        last_step = len(actions) - 1
        for step_id, action in enumerate(actions):
            conv.update_last_message(action)
            full_ids = tokenizer(conv.get_prompt()).input_ids
            if template_name == "llama-3":
                # Keep every new token, including the final one.
                target.extend(full_ids[len(target):])
            else:
                # llama-2: the last token of the rendered prompt (noted as '<s>'
                # in the template handling) is not a prediction target; mask it.
                target.extend(full_ids[len(target):-1])
                target.append(IGNORE_TOKEN_ID)

            if step_id != last_step:
                conv.append_message(conv.roles[0], sample['observation_sequence'][step_id])
                conv.append_message(conv.roles[1], None)
                obs_len = len(tokenizer(conv.get_prompt()).input_ids) - len(target)
                target.extend([IGNORE_TOKEN_ID] * obs_len)
        conv_strs.append(conv.get_prompt())
        targets.append(torch.tensor(target))

    tokenized_inputs = tokenizer(conv_strs, padding=False)

    return dict(
        input_ids=tokenized_inputs['input_ids'],
        labels=targets,
    )


class SupervisedDataset(Dataset):
    """Map-style dataset of rendered conversations for supervised fine-tuning.

    Each item is a dict with ``input_ids`` (tensor of token ids) and ``labels``
    (tensor from ``preprocess_sft``, where non-target positions hold
    ``IGNORE_TOKEN_ID``). Pass :meth:`collate_fn` to a DataLoader to right-pad
    a batch and obtain an attention mask.
    """

    def __init__(self, sources, tokenizer: transformers.PreTrainedTokenizer, template_name: str = "llama-2"):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = tokenizer.model_max_length
        processed = preprocess_sft(sources, tokenizer, template_name)

        self.input_ids = processed["input_ids"]
        self.labels = processed["labels"]

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        limit = self.max_length
        label_seq = self.labels[i]
        if label_seq.shape[0] > limit:
            # Clip to the model's context and store the clipped version back,
            # so later accesses of the same index see the truncated sequences.
            label_seq = self.labels[i] = label_seq[:limit]
            self.input_ids[i] = self.input_ids[i][:limit]

        return {"input_ids": torch.tensor(self.input_ids[i]), "labels": label_seq}

    def collate_fn(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        """Right-pad a list of items into batch tensors."""
        batch_ids = [item["input_ids"] for item in instances]
        batch_labels = [item["labels"] for item in instances]

        padded_ids, attention_mask = right_pad_sequences(
            batch_ids, padding_value=self.tokenizer.pad_token_id, return_attention_mask=True
        )
        padded_labels = right_pad_sequences(
            batch_labels, padding_value=IGNORE_TOKEN_ID, return_attention_mask=False
        )

        return {
            "input_ids": padded_ids,
            "attention_mask": attention_mask,
            "labels": padded_labels,
        }


def make_supervised_data_module(
    tokenizer: transformers.PreTrainedTokenizer, data_args
) -> Dict:
    """Build the SFT train and eval datasets for ``data_args.task``.

    Reads ``<data_args.data_dir>/<data_args.task>/sft/train/data.json`` and
    ``.../sft/test/data.json`` and wraps each in a ``SupervisedDataset`` using
    ``data_args.template_name``.

    Returns:
        ``{"train_dataset": ..., "eval_dataset": ...}``.
    """
    data_dir = os.path.join(data_args.data_dir, data_args.task)

    def _build(split):
        # Load the split and close the file before tokenization starts.
        with open(os.path.join(data_dir, "sft", split, "data.json"), "r", encoding="utf-8") as f:
            records = json.load(f)
        return SupervisedDataset(records, tokenizer=tokenizer, template_name=data_args.template_name)

    return dict(train_dataset=_build("train"), eval_dataset=_build("test"))


class ProcessVerifierDataset(torch.utils.data.Dataset):
    """Dataset for training a process verifier (per-step reward model). Right padding.

    Each sample is a trajectory dict with ``input``, ``action_sequence``,
    ``observation_sequence`` and ``action_rewards`` keys. Each reward is squashed
    with ``sigmoid_normalize`` and repeated over every token of its action,
    giving ``v_labels``. ``labels`` holds the action token ids. In both,
    prompt and observation tokens (e.g. ``<s>[INST]`` / ``[/INST]``) are
    ``IGNORE_TOKEN_ID``.
    """

    # Templates for which the incremental label construction below is defined.
    _SUPPORTED_TEMPLATES = ("llama-2", "llama-3")

    def __init__(
        self,
        samples,
        tokenizer: transformers.PreTrainedTokenizer = None,
        template_name: str = 'llama-2',
        loss_level: str = 'token',
        loss_on_llm: bool = False,
        mode: str = "train",  # accepted for interface compatibility; unused here
    ):
        assert loss_level in ('token', 'step')
        # Any other template would never append the final action's tokens,
        # leaving labels shorter than input_ids; fail fast instead.
        if template_name not in self._SUPPORTED_TEMPLATES:
            raise ValueError(f"Unsupported template_name for ProcessVerifierDataset: {template_name!r}")

        self.tokenizer = tokenizer
        self.loss_level = loss_level
        self.loss_on_llm = loss_on_llm
        self.pad_token_id = tokenizer.pad_token_id
        self.max_length = self.tokenizer.model_max_length

        # ============== Process the step labels into sigmoid distribution ==============
        # Work on shallow copies so the caller's dicts are not mutated; reusing
        # the same sample list for another dataset would otherwise apply the
        # sigmoid a second time.
        self.samples = [
            {**sample, 'action_rewards': [sigmoid_normalize(reward) for reward in sample['action_rewards']]}
            for sample in samples
        ]

        # ================= Add Conversation Template =================
        conv = get_conv_template(template_name)
        print(f"Preprocessing Reward Data...{len(self.samples)}")
        conv_strs = []
        targets = []
        step_labels = []
        for sample in self.samples:
            prompt, target, step_label_seq = self._encode_trajectory(conv, sample, tokenizer, template_name)
            conv_strs.append(prompt)
            step_labels.append(torch.tensor(step_label_seq))
            targets.append(torch.tensor(target))

        self.input_ids = tokenizer(conv_strs, padding=False).input_ids
        self.labels = targets
        self.step_labels = step_labels

        self.n_question = len(self.samples)

    @staticmethod
    def _encode_trajectory(conv, sample, tokenizer, template_name):
        """Render one trajectory; return ``(prompt_str, target, step_label_seq)``.

        ``target`` and ``step_label_seq`` are built by re-tokenizing the growing
        prompt after each turn. They are intended to align one-to-one with the
        tokenization of ``prompt_str``, which assumes prefix-stable tokenization.
        """
        conv.messages = []
        conv.append_message(conv.roles[0], sample['input'])
        conv.append_message(conv.roles[1], None)
        prompt_len = len(tokenizer(conv.get_prompt()).input_ids)
        target = [IGNORE_TOKEN_ID] * prompt_len
        step_label_seq = [IGNORE_TOKEN_ID] * prompt_len  # Corresponding step labels

        actions = sample['action_sequence']
        last_step = len(actions) - 1
        for step_id, action in enumerate(actions):
            conv.update_last_message(action)
            full_ids = tokenizer(conv.get_prompt()).input_ids
            if template_name == "llama-3":
                step_tokens = full_ids[len(target):]  # do not remove last token
                trailer = []
            else:  # llama-2
                step_tokens = full_ids[len(target):-1]  # remove last token ...
                trailer = [IGNORE_TOKEN_ID]  # ... and mask it ('<s>' per the template)
            reward = sample['action_rewards'][step_id]
            target.extend(step_tokens)
            target.extend(trailer)
            step_label_seq.extend([reward] * len(step_tokens))
            step_label_seq.extend(trailer)

            if step_id != last_step:
                conv.append_message(conv.roles[0], sample['observation_sequence'][step_id])
                conv.append_message(conv.roles[1], None)
                obs_len = len(tokenizer(conv.get_prompt()).input_ids) - len(target)
                target.extend([IGNORE_TOKEN_ID] * obs_len)
                step_label_seq.extend([IGNORE_TOKEN_ID] * obs_len)

        return conv.get_prompt(), target, step_label_seq

    def __len__(self):
        return self.n_question

    def _flatten(self, ls):
        return [item for sublist in ls for item in sublist]

    def __getitem__(self, i): # TODO: Smart truncation: keep complete step
        # Clip to the model context; the clipped sequences are stored back.
        if self.labels[i].shape[0] > self.max_length:
            self.labels[i] = self.labels[i][:self.max_length]
            self.step_labels[i] = self.step_labels[i][:self.max_length]
            self.input_ids[i] = self.input_ids[i][:self.max_length]

        return dict(
            input_ids=torch.tensor(self.input_ids[i]), labels=self.labels[i], v_labels=self.step_labels[i]
        )

    def collate_fn(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        """Right-pad a batch; ``labels`` is None unless ``loss_on_llm`` is set."""
        input_ids, labels, v_labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels", "v_labels"))

        input_ids, attention_mask = right_pad_sequences(input_ids, padding_value=self.pad_token_id, return_attention_mask=True)
        labels = right_pad_sequences(labels, padding_value=IGNORE_TOKEN_ID, return_attention_mask=False) if self.loss_on_llm else None
        v_labels = right_pad_sequences(v_labels, padding_value=IGNORE_TOKEN_ID, return_attention_mask=False)

        return dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            v_labels=v_labels,
        )


def right_pad_sequences(sequences: List[torch.LongTensor], padding_value: int, return_attention_mask: bool = False):
    """Right-pad variable-length tensors into one batch-first tensor.

    Args:
        sequences: tensors whose first dimension is the sequence length.
        padding_value: value written into padded positions. ``None`` (e.g. a
            tokenizer without a pad token) is treated as 0; the attention mask
            does not depend on it.
        return_attention_mask: also return a boolean ``(batch, max_len)`` mask
            that is True exactly on the real (non-padded) positions.

    Returns:
        The padded tensor, or ``(padded, attention_mask)`` when requested.
    """
    if padding_value is None:
        padding_value = 0
    padded_sequences = torch.nn.utils.rnn.pad_sequence(
        sequences,
        batch_first=True,
        padding_value=padding_value,
    )
    if not return_attention_mask:
        return padded_sequences

    # Build the mask from the true lengths instead of comparing against
    # padding_value: real tokens that equal the pad id (e.g. when pad == eos)
    # must stay attended.
    device = padded_sequences.device
    lengths = torch.tensor([seq.size(0) for seq in sequences], device=device)
    positions = torch.arange(padded_sequences.size(1), device=device)
    attention_mask = positions.unsqueeze(0) < lengths.unsqueeze(1)
    return padded_sequences, attention_mask


def mask_labels(labels: List[int], masks: List[bool]):
    """Return a copy of ``labels`` with every position whose mask is falsy set to IGNORE_TOKEN_ID.

    ``labels`` and ``masks`` must have the same length (checked with assert).
    """
    assert len(labels) == len(masks)
    masked = []
    for label, keep in zip(labels, masks):
        masked.append(label if keep else IGNORE_TOKEN_ID)
    return masked

def sigmoid_normalize(reward, alpha=0.2, beta=0):
    """Map ``reward`` into (0, 1) via ``sigmoid(alpha * reward + beta)``.

    Computed in a numerically stable way: ``math.exp`` is only ever called on
    a non-positive argument, so very large negative rewards return 0.0 instead
    of raising OverflowError, and very large positive rewards return 1.0.
    """
    z = alpha * reward + beta
    if z >= 0:
        return 1 / (1 + math.exp(-z))
    # For z < 0 use the equivalent form e^z / (1 + e^z) to avoid exp overflow.
    e = math.exp(z)
    return e / (1 + e)

def post_env_rewards(task, conv):
    """Post-process the environment feedback in ``conv`` into scaled per-step rewards.

    Walks ``conv.messages`` two entries at a time and records a metric "state"
    per step:

    * an "Edit Script (AI)" action that is followed two messages later by an
      "Execute Script" action takes the last metric value parsed from the
      execution output (message ``step_id + 4``), or 0 if none is found;
    * any other step carries the previous state forward (the task borderline
      before anything has been evaluated).

    Rewards are the state deltas (0 for the first step), passed through
    ``scale_rewards``.
    """
    # Check the scale table as well: a task with a pattern but no scale
    # (e.g. "dog-breed-identification") would otherwise only fail after all
    # the parsing, inside scale_rewards.
    assert task in MATCH_TASK_PATTERN_DICT and task in BENCHMARK_SCALE_DICT
    task_pattern = re.compile(MATCH_TASK_PATTERN_DICT[task])
    task_borderline = BENCHMARK_BORDERLINE_DICT[task]
    messages = conv.messages

    states = []
    # TODO: MAY BE NEED TO CHANGE THE LOGIC WHEN THOUGHTS IN
    for step_id in range(0, len(messages) - 2, 2):
        is_edit_then_execute = (
            "Edit Script (AI)" in messages[step_id + 1][1]
            and step_id + 4 < len(messages)
            and "Execute Script" in messages[step_id + 3][1]
        )
        if is_edit_then_execute:
            task_strings = task_pattern.findall(messages[step_id + 4][1])
            # NOTE(review): a run with no parsable metric scores 0. For tasks
            # with a negative scale (lower-is-better metrics, e.g. house-price)
            # that yields a *positive* reward after scaling — confirm intended.
            states.append(float(task_strings[-1]) if task_strings else 0)
        else:
            states.append(states[-1] if states else task_borderline)

    rewards = [0] + [curr - prev for prev, curr in zip(states, states[1:])]

    return scale_rewards(task, rewards)


def scale_rewards(task, rewards):
    """Scale raw reward deltas by the task factor and squash them into (0, 1).

    Each reward becomes ``sigmoid_normalize(BENCHMARK_SCALE_DICT[task] * reward)``.
    A new list is returned; ``rewards`` itself is left unmodified.
    """
    # The set of supported tasks is exactly the keys of the scale table.
    assert task in BENCHMARK_SCALE_DICT
    scale = BENCHMARK_SCALE_DICT[task]
    return [sigmoid_normalize(scale * reward) for reward in rewards]




