import json
import os
import random

from copy import deepcopy
from tqdm import tqdm
import numpy as np
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

from dt.conversation import get_conv_template


# Presumably the conventional "ignore this position" label for loss masking; it is not
# referenced anywhere else in this file.
IGNORE_INDEX = -100

# GLUE task name -> (first text field, second text field or None for single-sentence tasks).
# 'mnli-mm' is intentionally not included.
task_to_keys = {
    "mnli": ("premise", "hypothesis"),
    "qnli": ("question", "sentence"),
    "qqp": ("question1", "question2"),
    "rte": ("sentence1", "sentence2"),
    "sst2": ("sentence", None),
}

# Per task: integer GLUE label -> the textual answer the model is trained to emit.
# Written as answer -> label for readability and inverted here (insertion order preserved).
answer_mapping = {
    task: {label: answer for answer, label in answer_to_label.items()}
    for task, answer_to_label in {
        "sst2": {"negative": 0, "positive": 1},
        "mnli": {"yes": 0, "maybe": 1, "no": 2},
        "qnli": {"yes": 0, "no": 1},
        "qqp": {"yes": 1, "no": 0},
        "rte": {"yes": 0, "no": 1},
    }.items()
}

# Short task instruction prepended to every GLUE-derived prompt.
TASK2SHORTPROMPT = {
    "sst2": (
        "For the given sentence, label the sentiment of the sentence as positive or negative. "
        "The answer should be exactly 'positive' or 'negative'."
    ),
    "mnli": (
        "Please identify whether the premise entails the hypothesis. "
        "The answer should be exactly 'yes', 'maybe' or 'no'."
    ),
    "qnli": (
        "Please identify whether the sentence answers the question. "
        "The answer should be exactly 'yes' or 'no'."
    ),
    "qqp": (
        "Please identify whether question 1 has the same meaning as question 2. "
        "The answer should be exactly 'yes' or 'no'."
    ),
    "rte": (
        "Please identify whether the premise entails the hypothesis. "
        "The answer should be exactly 'yes' or 'no'."
    ),
}


def set_seed(seed=42):
    """Seed the global Python `random` and NumPy RNGs with `seed`.

    Torch seeding is deliberately not done here (torch is not imported by this module).
    """
    for seeder in (random.seed, np.random.seed):
        seeder(seed)


def get_prompt(perspective, conv_template, prompt, completion):
    """Render a prompt/completion pair as a single training string.

    Args:
        perspective: 'adv' or 'fair'; any other value makes the function return None.
        conv_template: a conversation template object (copied, never mutated) or None.
        prompt: text placed in the first role's turn (conv.roles[0]).
        completion: text placed in the second role's turn (conv.roles[1]).

    Returns:
        Without a template: prompt and completion joined by '\\n' ('adv') or ' ' ('fair').
        With a template: the rendered conversation; for 'fair' the copy's system text is
        replaced with "You are a helpful assistant." first.
    """
    if perspective != 'adv' and perspective != 'fair':
        return None

    if conv_template is None:
        joiner = '\n' if perspective == 'adv' else ' '
        return prompt + joiner + completion

    conversation = conv_template.copy()
    if perspective == 'fair':
        conversation.system = "You are a helpful assistant."
    first_role, second_role = conversation.roles[0], conversation.roles[1]
    conversation.append_message(first_role, prompt)
    conversation.append_message(second_role, completion)
    return conversation.get_prompt()


def prune_adv_datasets(dataset, adapter_name, mode='ft', seed=42):
    """Down-sample over-represented tasks in the adversarial ('adv') fine-tuning data.

    For adapter_name == 'adv' the items are grouped by their 'task' field (in task_to_keys
    order). Every task whose keep-ratio is below 1 (currently only 'qnli', kept at 20%) is
    randomly sub-sampled with `random.sample`. The groups are then concatenated back in
    task_to_keys order, with the original item order kept inside each group.

    Args:
        dataset: list of dicts, each with a 'task' key naming a task in task_to_keys.
        adapter_name: 'adv' prunes; 'fair', 'toxicity' and 'truth' pass through unchanged.
        mode: accepted for API compatibility. The 'ft' and non-'ft' ratios are currently
            identical, so it has no effect.
        seed: currently unused. Sampling draws from the global `random` state, so call
            set_seed beforehand for reproducibility.

    Returns:
        The pruned list ('adv') or the input `dataset` object itself (other adapters).

    Raises:
        KeyError: an item's 'task' is not a key of task_to_keys.
        Exception: unknown adapter_name.
    """
    if adapter_name == 'adv':
        # Both modes previously used exactly these ratios.
        task_to_ratio = {task: 1 for task in task_to_keys}
        task_to_ratio["qnli"] = 0.2

        # Single pass: group items per task (unknown tasks raise KeyError, as before).
        task_to_data = {task: [] for task in task_to_keys}
        for item in dataset:
            task_to_data[item['task']].append(item)

        task_to_count = {task: len(items) for task, items in task_to_data.items()}
        print(task_to_count)
        task_to_count = {k: int(v * task_to_ratio[k]) for k, v in task_to_count.items()}
        print(task_to_count)

        for task, ratio in task_to_ratio.items():
            if ratio != 1:
                task_to_data[task] = random.sample(task_to_data[task], task_to_count[task])

        dataset = [item for items in task_to_data.values() for item in items]
        assert len(dataset) == sum(task_to_count.values())
    elif adapter_name in ['fair', 'toxicity', 'truth']:
        pass
    else:
        raise Exception("No standard dataset")
    return dataset


def _glue_field_label(key):
    """Return the label shown in prompts for GLUE field name `key` (e.g. 'sentence1' -> 'premise')."""
    return key.replace('original_', '').replace('sentence1', 'premise').replace('sentence2', 'Hypothesis')


def merge_standard_datasets(dataset, adapter_name, mode='ft', seed=42):
    """Append clean GLUE training examples to the adversarial ('adv') fine-tuning data.

    For adapter_name == 'adv', the existing items are counted per task and each count is
    scaled by a per-task ratio. In 'ft' mode only 'mnli' has ratio 1 and all others 0; in any
    other mode all ratios are 1. That many examples (capped at the split size) are then taken
    from the seeded shuffle of the GLUE 'train' split. Each one is turned into an item with
    'task', 'prompt' (task instruction + labelled fields) and 'completion' (textual answer).

    Args:
        dataset: list of dicts, each with a 'task' key in task_to_keys. It is extended IN PLACE
            and also returned.
        adapter_name: 'adv' merges; 'fair', 'toxicity' and 'truth' pass through unchanged.
        mode: 'ft' or anything else (see above).
        seed: seed for the GLUE split shuffle.

    Returns:
        `dataset` (the same list object).

    Raises:
        KeyError: an item's 'task' is not in task_to_keys.
        Exception: unknown adapter_name.
    """
    if adapter_name == 'adv':
        task_to_count = {task: 0 for task in task_to_keys.keys()}
        if mode == 'ft':
            task_to_ratio = {task: 0 for task in task_to_keys.keys()}
            task_to_ratio["mnli"] = 1
        else:
            task_to_ratio = {task: 1 for task in task_to_keys.keys()}

        for item in dataset:
            task_to_count[item['task']] += 1
        print(task_to_count)
        task_to_count = {k: int(v * task_to_ratio[k]) for k, v in task_to_count.items()}
        print(task_to_count)

        for task, count in task_to_count.items():
            if mode == 'ft' and task not in ['mnli', 'qqp']:
                assert count == 0, task
            if count == 0:
                # Nothing to add: don't download/load this GLUE task at all.
                continue
            standard_data = load_dataset('glue', task)['train'].shuffle(seed=seed)
            key1, key2 = task_to_keys[task]
            for i in range(min(count, len(standard_data))):
                data = standard_data[i]
                # Rename only the field labels; the example text itself is left untouched.
                message = f"{_glue_field_label(key1)}: {data[key1]}"
                if key2:
                    message = f"{message}\n{_glue_field_label(key2)}: {data[key2]}"

                item = deepcopy(data)
                item['task'] = task
                item['prompt'] = f"{TASK2SHORTPROMPT[task]}\n{message}"
                item['completion'] = answer_mapping[task][data['label']]
                dataset.append(item)
    elif adapter_name in ['fair', 'toxicity', 'truth']:
        pass
    else:
        raise Exception("No standard dataset")
    return dataset


def process_dpo_item(item, perspective='adv'):
    """Convert one fine-tuning record into a DPO ``{'prompt', 'chosen', 'rejected'}`` triple.

    Perspectives:
        'adv': `item` has 'task', 'prompt' and 'completion'. `chosen` is the gold completion.
            `rejected` is the answer whose label mirrors the gold one (max_label - label) in
            answer_mapping[task]. When that mirror is the label itself (the middle 'maybe' of
            mnli), a random outer label (0 or 2) is drawn from the global `random` state.
        'fair': `item` has 'prompt', 'completion' and a two-element 'option' list. `rejected` is
            the other option, capitalized when `chosen` differs from its lowercase form.
        'truth': `item` has 'question', 'correct' and 'incorrect'.

    The returned prompt always has a trailing '\\n'.

    Raises:
        AssertionError: 'adv' completion not among the task's answers, or an unexpected
            self-mirroring label / option count.
        Exception: 'fair' completion matches neither option, or unsupported perspective.
    """
    if perspective == 'adv':
        task = item['task']
        prompt = item['prompt']
        chosen = item['completion']
        task_answer_mapping = answer_mapping[task]
        max_label = max(task_answer_mapping.keys())
        rejected = ''
        for label, answer in task_answer_mapping.items():
            if chosen == answer:
                rejected_label = max_label - label
                if rejected_label == label:
                    # Only the middle label of the 3-way task mirrors onto itself.
                    assert task == 'mnli'
                    rejected_label = random.choice([0, 2])
                assert rejected_label in task_answer_mapping, (label, rejected_label, task)
                rejected = task_answer_mapping[rejected_label]
                break
        assert rejected != ''
    elif perspective == 'fair':
        prompt = item['prompt']
        chosen = item['completion']
        options = item['option']
        assert len(options) == 2
        if chosen.lower() == options[0]:
            rejected = options[1]
        elif chosen.lower() == options[1]:
            rejected = options[0]
        else:
            raise Exception(f'Response not in options. Response: {chosen}. Options: {options}')
        # Mirror the casing style of the chosen response.
        if chosen.lower() != chosen:
            rejected = rejected.capitalize()
    elif perspective == 'truth':
        prompt = item['question']
        chosen = item['correct']
        rejected = item['incorrect']
    else:
        raise Exception('Not Implemented')

    return {
        'prompt': prompt + '\n',
        'chosen': chosen,
        'rejected': rejected,
    }


def cal_dpo_dataset_loss(perspective='adv', seed=42, val_ratio=0.01):
    """Score every DPO pair of `perspective` with its fine-tuned model and save the scores as JSON.

    The perspective's JSONL data is read, pruned, merged with GLUE examples and converted to
    DPO items. Each pair is then scored as ``loss(prompt + chosen) - loss(prompt + rejected)``,
    where each loss is the model's loss with ``labels=input_ids`` (so prompt tokens count too).
    The scores are written, in item order, to ``ft_model_loss_path[perspective]``, the file
    that filter_dpo_dataset reads.

    Args:
        perspective: key into ft_data / ft_model_path / ft_model_loss_path.
        seed: forwarded to prune_adv_datasets / merge_standard_datasets.
        val_ratio: unused; kept for signature parity with load_dt_dataset.

    NOTE(review): ft_data, ft_model_path and ft_model_loss_path are not defined in this
    file's visible source; confirm where they come from. Requires a CUDA device.
    """
    import torch  # only this function needs torch directly

    data_file = ft_data[perspective]
    print(f'reading data from {data_file}')
    with open(data_file) as f:
        dataset = [json.loads(line.strip()) for line in f]
    dataset = prune_adv_datasets(dataset, perspective, mode='ft', seed=seed)
    dataset = merge_standard_datasets(dataset, perspective, mode='ft', seed=seed)
    dpo_dataset = [process_dpo_item(item, perspective) for item in tqdm(dataset)]

    model_path = ft_model_path[perspective]
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(model_path)
    model.cuda()
    model.eval()

    def sequence_loss(text):
        inputs = tokenizer(text, return_tensors="pt").to('cuda')
        return model(**inputs, labels=inputs.input_ids).loss

    loss_list = []
    # Pure scoring: skip building autograd graphs, which only cost memory and time here.
    with torch.no_grad():
        for item in tqdm(dpo_dataset):
            chosen_loss = sequence_loss(item['prompt'] + item['chosen'])
            rejected_loss = sequence_loss(item['prompt'] + item['rejected'])
            loss_list.append((chosen_loss - rejected_loss).item())

    print(f'saving loss scores to {ft_model_loss_path[perspective]}')
    with open(ft_model_loss_path[perspective], 'w') as f:
        json.dump(loss_list, f)


def linear_map(scores, new_min, new_max):
    """Linearly rescale `scores` so that min(scores) -> new_min and max(scores) -> new_max.

    Args:
        scores: non-empty sequence of numbers.
        new_min, new_max: target range end points (new_min may exceed new_max to flip order).

    Returns:
        A new list of rescaled values in the same order as `scores`.

    Raises:
        ValueError: all scores are equal (or `scores` is empty, via min()).
    """
    lowest = min(scores)
    highest = max(scores)
    if highest == lowest:
        raise ValueError("All scores are the same. Cannot perform linear mapping.")

    old_span = highest - lowest
    new_span = new_max - new_min
    return [((value - lowest) / old_span) * new_span + new_min for value in scores]


def weighted_sample(data, scores, percent):
    """Draw a weighted random sample of about `percent`% of `data`.

    The sample size is ``int(len(data) * (percent / 100))`` (truncated toward zero). Items are
    drawn WITH replacement via `random.choices`, so duplicates can appear and some items may
    never be picked. Uses the global `random` state.

    NOTE(review): callers use this as a "filter"; confirm sampling with replacement is intended.

    Raises:
        ValueError: `data` and `scores` differ in length.
    """
    if len(scores) != len(data):
        raise ValueError("The length of data and scores must be the same.")

    k = int(len(data) * (percent / 100))
    return random.choices(data, weights=scores, k=k)


def filter_dpo_dataset(dpo_dataset, perspective, filter_dpo=None):
    """Keep a loss-weighted 50% sample of `dpo_dataset`.

    Per-item scores are read from ``ft_model_loss_path[perspective]`` (a JSON list, one score
    per item, in the same order as `dpo_dataset`) and turned into sampling weights:
        0: min-max scaled to [0, 1]; larger scores are favoured.
        1: scaled to [-1, 0] and negated; smaller scores are favoured.
        2: scaled to [-1, 1], absolute value; both extremes are favoured and the midpoint of
           the score range gets weight 0.
    The sample is drawn with weighted_sample (with replacement).

    Raises:
        Exception: unsupported `filter_dpo`.
        ValueError: score list length differs from `dpo_dataset`, or all scores are equal.
    """
    with open(ft_model_loss_path[perspective]) as handle:
        scores = json.load(handle)

    if filter_dpo == 0:
        banner, low, high, adjust = 'keeping larger loss', 0, 1, None
    elif filter_dpo == 1:
        banner, low, high, adjust = 'keeping smaller loss', -1, 0, (lambda w: -w)
    elif filter_dpo == 2:
        banner, low, high, adjust = 'keeping both sides', -1, 1, abs
    else:
        raise Exception(f'filter_dpo {filter_dpo} not supported')

    print(banner)
    weights = linear_map(scores, low, high)
    if adjust is not None:
        weights = [adjust(w) for w in weights]
    return weighted_sample(dpo_dataset, weights, 50)


def _perspective_dpo_items(perspective, seed):
    """Read, prune, GLUE-augment and convert one perspective's fine-tuning data to DPO triples."""
    data_file = ft_data[perspective]
    print(f'reading data from {data_file}')
    with open(data_file) as f:
        dataset = [json.loads(line.strip()) for line in f]
    dataset = prune_adv_datasets(dataset, perspective, mode='ft', seed=seed)
    dataset = merge_standard_datasets(dataset, perspective, mode='ft', seed=seed)
    return [process_dpo_item(item, perspective) for item in tqdm(dataset)]


def load_dt_dataset(perspective='adv', seed=42, val_ratio=0.01, filter_dpo=None):
    """Build the DPO dataset for a perspective (or 'all_3') and split it into train/validation.

    Args:
        perspective: a single perspective key of ft_data, or 'all_3' to concatenate 'adv',
            'fair' and 'truth' (in that order).
        seed: forwarded to the data pipeline and used for the train/test split shuffle.
        val_ratio: fraction of items placed in the validation split.
        filter_dpo: None keeps every item; 0/1/2 apply filter_dpo_dataset's loss-weighted
            50% sampling to each perspective separately.

    Returns:
        (train, test) `datasets.Dataset` splits.

    NOTE(review): ft_data is not defined in this file's visible source; confirm its origin.
    """
    perspectives = ['adv', 'fair', 'truth'] if perspective == 'all_3' else [perspective]

    dpo_dataset = []
    for name in perspectives:
        items = _perspective_dpo_items(name, seed)
        if filter_dpo is not None:
            # Filter this perspective's items only: its loss file is aligned with them, not
            # with the accumulated multi-perspective list.
            items = filter_dpo_dataset(items, name, filter_dpo)
        else:
            print('keeping all dpo dataset')
        dpo_dataset.extend(items)
        if perspective == 'all_3':
            print(len(dpo_dataset))

    split = Dataset.from_list(dpo_dataset).train_test_split(test_size=val_ratio, shuffle=True, seed=seed)
    return split['train'], split['test']


def compare_dt_dataset(perspective='adv', seed=42, val_ratio=0.01):
    """Check that building the DPO items twice with the same seed gives identical results.

    Both passes reset the global RNGs with set_seed(seed) first, so the check no longer
    depends on whatever random state existed before the call.

    Args:
        perspective: key into ft_data.
        seed: seed used for both passes.
        val_ratio: unused; kept for signature parity with load_dt_dataset.

    Returns:
        True when both passes produced equal item lists (the result is also printed).
    """
    def build_items():
        set_seed(seed)
        data_file = ft_data[perspective]
        print(f'reading data from {data_file}')
        with open(data_file) as f:
            dataset = [json.loads(line.strip()) for line in f]
        dataset = prune_adv_datasets(dataset, perspective, mode='ft', seed=seed)
        dataset = merge_standard_datasets(dataset, perspective, mode='ft', seed=seed)
        return [process_dpo_item(item, perspective) for item in tqdm(dataset)]

    first_dpo_dataset = build_items()
    second_dpo_dataset = build_items()
    identical = first_dpo_dataset == second_dpo_dataset
    print(identical)
    return identical


if __name__ == '__main__':
    # Seed the global RNGs, build the combined adv/fair/truth DPO splits and show a sample.
    run_seed = 42
    set_seed(run_seed)

    train_split, val_split = load_dt_dataset('all_3', seed=run_seed)
    print(train_split)
    print(val_split)
    print(train_split[0])

    # Other entry points:
    #   cal_dpo_dataset_loss('adv', seed=run_seed)
    #   compare_dt_dataset('adv', seed=run_seed)
