from copy import deepcopy
import json
import yaml
import random
import argparse
import numpy as np
from collections import Counter

import os
import torch
from transformers import BitsAndBytesConfig
from datasets import Dataset
from prompts import prompt_factory

# Dataset split tags. The load_data* helpers in this file translate the
# 'valid' tag into the on-disk split name 'dev'; every other tag is used
# verbatim when building the "fine-processed-<split>.json" file name.
train_tag = 'train'
val_tag = 'valid'
test_tag = 'test'
msk_test_tag = 'msk_test'  # NOTE(review): presumably a masked test split -- confirm with the data pipeline


def read_cli():
    """Parse the process command line and return the options as a dict.

    Every option is optional and defaults to ``None`` so that callers can
    tell "not given" apart from an explicit value. ``fp16`` is delivered as
    the literal string ``'True'`` or ``'False'`` (enforced through
    ``choices``), not as a bool.

    Returns:
        dict: option destination name -> parsed value (or ``None``).
    """
    # (short flag, long flag, help text, value type, allowed choices)
    option_table = (
        ("-cfg", "--config", "Path to config file", str, None),
        ("-model_path", "--model_path", "Path to checkpoint", str, None),
        # OVERRIDES
        ("-datapath", "--datapath", "Data path to use instead of training", str, None),
        ("-lr", "--learning_rate", "Learning rate", float, None),
        ("-fp16", "--fp16", "Enable FP16", str, ['True', 'False']),
        ("-warmup_ratio", "--warmup_ratio", "Warm up ratio", float, None),
        ("-bsz", "--batch_size", "Batch size", int, None),
        ("-seed", "--seed", "Seed", int, None),
        ("-rs", "--result_path", "Result path", str, None),
        ("-it", "--infer_tag", "Infer Path", str, None),
        ("-slot_preds", "--slot_preds", "Slot Predictions path", str, None),
    )

    cli = argparse.ArgumentParser(description='Train model.')
    for short_flag, long_flag, help_text, value_type, allowed in option_table:
        cli.add_argument(
            short_flag,
            long_flag,
            help=help_text,
            required=False,
            type=value_type,
            default=None,
            choices=allowed,
        )

    namespace = cli.parse_args()
    return vars(namespace)


def override_config(cfg, ocfg):
    """Apply command-line overrides to a loaded configuration, in place.

    Overrides whose value is ``None`` (or whose key is absent from ``ocfg``)
    are ignored. Hyper-parameter overrides are additionally recorded as
    ``name:value`` tags that get appended to ``cfg['experiment_name']`` so
    that runs with different overrides get different names; path-like
    overrides (model/result/infer/slot paths) are applied without tagging.

    Args:
        cfg (dict): configuration; must contain a ``'train'`` mapping when a
            training override is given and ``'experiment_name'`` when any
            tagged override is given.
        ocfg (dict): overrides keyed by the CLI destination names.

    Returns:
        dict: the same ``cfg`` object, mutated.
    """
    # (override key, config section or None for the top level, config key).
    # The order here fixes the order of the tags in the experiment name.
    tagged_overrides = (
        ('datapath', None, 'datapath'),
        ('learning_rate', 'train', 'learning_rate'),
        ('fp16', 'train', 'fp16'),
        ('warmup_ratio', 'train', 'warmup_ratio'),
        ('batch_size', 'train', 'per_device_train_batch_size'),
        ('seed', 'train', 'seed'),
    )
    untagged_overrides = ('model_path', 'result_path', 'infer_tag', 'slot_preds')

    tags = []
    for name, section, cfg_key in tagged_overrides:
        value = ocfg.get(name)
        if value is None:
            continue

        stored = value
        if name == 'fp16' and not isinstance(value, bool):
            # The CLI delivers 'True'/'False' strings; storing the string
            # would make 'False' truthy for any consumer that tests the flag.
            stored = str(value).strip().lower() == 'true'

        target = cfg if section is None else cfg[section]
        target[cfg_key] = stored
        # The tag keeps the value exactly as given so experiment names are stable.
        tags.append((name, value))

    for name in untagged_overrides:
        value = ocfg.get(name)
        if value is not None:
            cfg[name] = value

    if tags:
        print(tags)
        tag = '_'.join(f"{k}:{v}" for (k, v) in tags)
        cfg['experiment_name'] = cfg['experiment_name'] + '_' + tag

    return cfg


def get_config(config_file):
    """Read the YAML file at ``config_file`` and return its parsed content."""
    print('Reading config from', config_file)
    with open(config_file, 'r') as stream:
        parsed = yaml.safe_load(stream)
    return parsed


def get_joint_config(wandb_cfg):
    """Merge sweep overrides into the base config named by ``wandb_cfg``.

    The base config is loaded from ``wandb_cfg['base_config']``. For every
    top-level key of ``wandb_cfg`` that also exists in the base config, each
    nested ``key: value`` pair overwrites the corresponding base entry. The
    original experiment name is kept under ``'base_name'`` and the new
    ``'experiment_name'`` is the original name followed by one
    ``section-option:value`` part per applied override, joined by ``_``.
    """
    merged = get_config(wandb_cfg['base_config'])
    merged['base_name'] = merged['experiment_name']

    name_parts = [merged['experiment_name']]
    for section in wandb_cfg.keys():
        # Sections unknown to the base config (e.g. 'base_config') are skipped.
        if section in merged:
            for option in wandb_cfg[section].keys():
                value = wandb_cfg[section][option]
                merged[section][option] = value
                name_parts.append(f"{section}-{option}:{value}")

    merged['experiment_name'] = '_'.join(name_parts)
    return merged


def load_json(fname):
    """Return the object decoded from the JSON file at ``fname``."""
    with open(fname, 'r') as handle:
        return json.loads(handle.read())


def get_session_samples(session, history_size, prompt_type, bayesnet=None, note_type='text', desc_type='text', use_rag=False):
    """Build one policy ('pol') sample per turn of a dialogue session.

    For each turn the patient utterance is appended to the running context,
    the last ``history_size`` context lines become the ``{{dialogue}}`` part
    of the user prompt, the turn's dialogue state (``entry['dst']``) becomes
    ``{{dialogue_state}}`` and ``json.dumps(entry['pol'])`` becomes the
    ``{{answer}}`` of the assistant prompt. The doctor response is appended
    to the context only after the sample is built, so a sample never sees
    the response it is supposed to predict.

    Args:
        session: iterable of turn dicts; the keys read here are 'user',
            'resp', 'dst', 'pol', 'dial_id', 'turn_num' and, when
            ``use_rag`` is set, 'rag_results'.
        history_size: number of most recent context lines to include;
            ``-1`` means the whole context.
        prompt_type: key into ``prompt_factory``.
        bayesnet: optional object exposing ``get_posterior(dst)``; its output
            is rendered into ``{{hints}}``.
        note_type: 'json' (fenced JSON dump of the state) or 'text'
            (``create_clinical_note``).
        desc_type: forwarded to ``create_posterior_note`` for BayesNet hints.
        use_rag: render ``entry['rag_results']`` into ``{{hints}}`` instead.

    Returns:
        list of dicts: a deep copy of each turn extended with 'prompt_input',
        'prompt_output', 'task', 'context' and 'task_output'.

    Raises:
        AssertionError: if both ``bayesnet`` and ``use_rag`` are given.
        NotImplementedError: for an unknown ``note_type``.
    """
    def _strip_markers(text, start_marker, end_marker):
        # Remove the markers as whole tokens. str.strip(chars) treats its
        # argument as a character set and would also eat leading/trailing
        # letters of the utterance itself (e.g. the "so" of "sore throat").
        text = text.strip()
        if text.startswith(start_marker):
            text = text[len(start_marker):]
        if text.endswith(end_marker):
            text = text[:-len(end_marker)]
        return text.strip()

    user_template = prompt_factory[prompt_type]['user']
    system_template = prompt_factory[prompt_type]['system']

    assert not (bayesnet is not None and use_rag), (
        f"BayesNet and RAG cannot be used at the same time. BayesNet: {bayesnet}, RAG: {use_rag}"
    )

    data = []
    context = []
    for entry in session:
        context.append(f"Patient: {_strip_markers(entry['user'], '<sos_u>', '<eos_u>')}")

        # Keep only the most recent `history_size` lines (-1 keeps everything).
        if history_size == -1:
            first = 0
        else:
            first = max(len(context) - history_size, 0)
        history = '\n'.join(context[first:])

        if note_type == 'json':
            clinical_note = '```json\n' + json.dumps(entry['dst']) + '\n```'
        elif note_type == 'text':
            clinical_note = create_clinical_note(entry['dst'])
        else:
            raise NotImplementedError(f"Unsupported note_type: {note_type!r}")

        answer = json.dumps(entry['pol'])
        user = user_template.replace('{{dialogue}}', history)
        user = user.replace('{{dialogue_state}}', clinical_note)
        system = system_template.replace("{{answer}}", answer)

        if bayesnet is not None:
            posterior = bayesnet.get_posterior(entry['dst'])
            user = user.replace('{{hints}}', create_posterior_note(posterior, desc_type))
        elif use_rag:
            hints = create_posterior_note(entry['rag_results'], desc_type="rag")
            user = user.replace('{{hints}}', hints)

        sample = deepcopy(entry)
        # Already present after the deep copy; the lookups make a turn
        # without these keys fail here rather than downstream.
        sample['dial_id'] = entry['dial_id']
        sample['turn_num'] = entry['turn_num']
        sample['prompt_input'] = [{'role': 'user', 'content': user}]
        sample['prompt_output'] = [{'role': 'assistant', 'content': system}]
        sample['task'] = 'pol'
        sample['context'] = list(context)  # snapshot; `context` keeps growing
        sample['task_output'] = answer
        data.append(sample)

        context.append(f"Doctor: {_strip_markers(entry['resp'], '<sos_r>', '<eos_r>')}")

    return data


def load_data(datapath, tag, history_size, prompt_type, note_type, desc_type):
    """Load one split and turn it into a dataset of policy samples.

    Reads ``fine-processed-<split>.json`` from ``datapath`` (the 'valid' tag
    maps to the 'dev' split), builds the per-turn samples of every session
    with ``get_session_samples`` (no BayesNet, no RAG) and wraps them with
    ``Dataset.from_list``.
    """
    split_name = 'dev' if tag == 'valid' else tag
    json_path = os.path.join(datapath, f"fine-processed-{split_name}.json")
    with open(json_path, "r") as handle:
        sessions = json.load(handle)

    print(f"Loaded {len(sessions)} dialogs for {tag}...")
    print("History Size", history_size)

    all_samples = []
    for dialog in sessions:
        all_samples += get_session_samples(
            dialog,
            history_size,
            prompt_type,
            bayesnet=None,
            note_type=note_type,
            desc_type=desc_type,
        )

    print(f'Total samples: {len(all_samples)}')
    return Dataset.from_list(all_samples)


def load_data_bayesnet(datapath, tag, history_size, prompt_type, bayesnet, note_type, desc_type):
    """Load one split and build policy samples with BayesNet hints.

    Reads ``fine-processed-<split>.json`` from ``datapath`` (the 'valid' tag
    maps to the 'dev' split). Only the 'bayesnet' prompt type is accepted.
    Every session is expanded with ``get_session_samples`` using the given
    ``bayesnet`` and the result is wrapped with ``Dataset.from_list``.
    """
    split_name = 'dev' if tag == 'valid' else tag
    json_path = os.path.join(datapath, f"fine-processed-{split_name}.json")
    with open(json_path, "r") as handle:
        sessions = json.load(handle)

    print(f"Loaded {len(sessions)} dialogs for {tag}...")
    print("History Size", history_size)
    assert prompt_type in ('bayesnet',)
    print('Using BayesNet...')

    all_samples = []
    for dialog in sessions:
        per_turn = get_session_samples(dialog, history_size, prompt_type, bayesnet, note_type, desc_type)
        all_samples += per_turn

    print(f'Total samples: {len(all_samples)}')
    return Dataset.from_list(all_samples)


def load_data_rag(datapath, tag, history_size, prompt_type, note_type, desc_type):
    """Load one split and build policy samples with RAG hints.

    Reads ``fine-processed-<split>.json`` from ``datapath`` (the 'valid' tag
    maps to the 'dev' split), expands every session with
    ``get_session_samples(..., use_rag=True)`` and wraps the samples with
    ``Dataset.from_list``.
    """
    split_name = 'dev' if tag == 'valid' else tag
    json_path = os.path.join(datapath, f"fine-processed-{split_name}.json")
    with open(json_path, "r") as handle:
        sessions = json.load(handle)

    print(f"Loaded {len(sessions)} dialogs for {tag}...")
    print("History Size", history_size)

    all_samples = []
    for dialog in sessions:
        per_turn = get_session_samples(
            dialog,
            history_size,
            prompt_type,
            bayesnet=None,
            note_type=note_type,
            desc_type=desc_type,
            use_rag=True,
        )
        all_samples += per_turn

    print(f'Total samples: {len(all_samples)}')
    return Dataset.from_list(all_samples)


# -------------------------------------------------- NLU
def get_session_samples_nlu(session, history_size, prompt_type):
    """Build one NLU sample per turn of a dialogue session.

    The user prompt receives the last ``history_size`` context lines
    (``-1`` = all) as ``{{dialogue}}`` and the JSON dump of the previous
    turn's doctor action (``entry['pol']`` of the preceding turn, ``[]`` for
    the first turn) as ``{{last_doctor_action}}``. The assistant prompt
    receives ``json.dumps(entry['nlu'])`` as ``{{answer}}``.

    Args:
        session: iterable of turn dicts; the keys read here are 'user',
            'resp', 'nlu', 'pol', 'dial_id' and 'turn_num'.
        history_size: number of most recent context lines to include.
        prompt_type: key into ``prompt_factory``.

    Returns:
        list of dicts: a deep copy of each turn extended with 'prompt_input',
        'prompt_output', 'task', 'context' and 'task_output'.
    """
    def _strip_markers(text, start_marker, end_marker):
        # Remove the markers as whole tokens. str.strip(chars) treats its
        # argument as a character set and would also eat leading/trailing
        # letters of the utterance itself (e.g. the "so" of "sore throat").
        text = text.strip()
        if text.startswith(start_marker):
            text = text[len(start_marker):]
        if text.endswith(end_marker):
            text = text[:-len(end_marker)]
        return text.strip()

    user_template = prompt_factory[prompt_type]['user']
    system_template = prompt_factory[prompt_type]['system']

    data = []
    context = []
    last_action = []
    for entry in session:
        context.append(f"Patient: {_strip_markers(entry['user'], '<sos_u>', '<eos_u>')}")

        # Keep only the most recent `history_size` lines (-1 keeps everything).
        if history_size == -1:
            first = 0
        else:
            first = max(len(context) - history_size, 0)
        history = '\n'.join(context[first:])

        answer = json.dumps(entry['nlu'])
        user = user_template.replace('{{dialogue}}', history)
        user = user.replace('{{last_doctor_action}}', json.dumps(last_action))
        system = system_template.replace("{{answer}}", answer)

        sample = deepcopy(entry)
        # Already present after the deep copy; the lookups make a turn
        # without these keys fail here rather than downstream.
        sample['dial_id'] = entry['dial_id']
        sample['turn_num'] = entry['turn_num']
        sample['prompt_input'] = [{'role': 'user', 'content': user}]
        sample['prompt_output'] = [{'role': 'assistant', 'content': system}]
        sample['task'] = 'nlu'
        sample['context'] = list(context)  # snapshot; `context` keeps growing
        sample['task_output'] = answer
        data.append(sample)

        context.append(f"Doctor: {_strip_markers(entry['resp'], '<sos_r>', '<eos_r>')}")
        # The action taken in this turn is the "last action" of the next one.
        last_action = deepcopy(entry['pol'])

    return data


def load_data_nlu(datapath, tag, history_size, prompt_type):
    """Load one split and turn it into a dataset of NLU samples.

    Reads ``fine-processed-<split>.json`` from ``datapath`` (the 'valid' tag
    maps to the 'dev' split), expands every session with
    ``get_session_samples_nlu`` and wraps the samples with
    ``Dataset.from_list``.
    """
    split_name = 'dev' if tag == 'valid' else tag
    json_path = os.path.join(datapath, f"fine-processed-{split_name}.json")
    with open(json_path, "r") as handle:
        sessions = json.load(handle)

    print(f"Loaded {len(sessions)} dialogs for {tag}...")
    print("History Size", history_size)

    all_samples = []
    for dialog in sessions:
        all_samples += get_session_samples_nlu(dialog, history_size, prompt_type)

    print(f'Total samples: {len(all_samples)}')
    return Dataset.from_list(all_samples)


# -------------------------------------------------- NLG
def get_session_samples_nlg(session, history_size, prompt_type):
    """Build one NLG sample per turn of a dialogue session.

    The user prompt receives the last ``history_size`` context lines
    (``-1`` = all) as ``{{dialogue}}`` and ``json.dumps(entry['pol'])`` as
    ``{{doctor_action}}``. The assistant prompt receives ``entry['nlg']``
    verbatim as ``{{answer}}``.

    Args:
        session: iterable of turn dicts; the keys read here are 'user',
            'resp', 'pol', 'nlg', 'dial_id' and 'turn_num'.
        history_size: number of most recent context lines to include.
        prompt_type: key into ``prompt_factory``.

    Returns:
        list of dicts: a deep copy of each turn extended with 'prompt_input',
        'prompt_output', 'task', 'context' and 'task_output'.
    """
    def _strip_markers(text, start_marker, end_marker):
        # Remove the markers as whole tokens. str.strip(chars) treats its
        # argument as a character set and would also eat leading/trailing
        # letters of the utterance itself (e.g. the "so" of "sore throat").
        text = text.strip()
        if text.startswith(start_marker):
            text = text[len(start_marker):]
        if text.endswith(end_marker):
            text = text[:-len(end_marker)]
        return text.strip()

    user_template = prompt_factory[prompt_type]['user']
    system_template = prompt_factory[prompt_type]['system']

    data = []
    context = []
    for entry in session:
        context.append(f"Patient: {_strip_markers(entry['user'], '<sos_u>', '<eos_u>')}")

        # Keep only the most recent `history_size` lines (-1 keeps everything).
        if history_size == -1:
            first = 0
        else:
            first = max(len(context) - history_size, 0)
        history = '\n'.join(context[first:])

        user = user_template.replace('{{dialogue}}', history)
        user = user.replace('{{doctor_action}}', json.dumps(entry['pol']))
        system = system_template.replace("{{answer}}", entry['nlg'])

        sample = deepcopy(entry)
        # Already present after the deep copy; the lookups make a turn
        # without these keys fail here rather than downstream.
        sample['dial_id'] = entry['dial_id']
        sample['turn_num'] = entry['turn_num']
        sample['prompt_input'] = [{'role': 'user', 'content': user}]
        sample['prompt_output'] = [{'role': 'assistant', 'content': system}]
        sample['task'] = 'nlg'
        sample['context'] = list(context)  # snapshot; `context` keeps growing
        sample['task_output'] = entry['nlg']
        data.append(sample)

        context.append(f"Doctor: {_strip_markers(entry['resp'], '<sos_r>', '<eos_r>')}")

    return data


def load_data_nlg(datapath, tag, history_size, prompt_type):
    """Load one split and turn it into a dataset of NLG samples.

    Reads ``fine-processed-<split>.json`` from ``datapath`` (the 'valid' tag
    maps to the 'dev' split), expands every session with
    ``get_session_samples_nlg`` and wraps the samples with
    ``Dataset.from_list``.
    """
    split_name = 'dev' if tag == 'valid' else tag
    json_path = os.path.join(datapath, f"fine-processed-{split_name}.json")
    with open(json_path, "r") as handle:
        sessions = json.load(handle)

    print(f"Loaded {len(sessions)} dialogs for {tag}...")
    print("History Size", history_size)

    all_samples = []
    for dialog in sessions:
        all_samples += get_session_samples_nlg(dialog, history_size, prompt_type)

    print(f'Total samples: {len(all_samples)}')
    return Dataset.from_list(all_samples)


def get_bitsandbytes_config(quantization=None):
    """Return the BitsAndBytes quantization config to load the model with.

    A 4-bit NF4 config is returned when ``quantization == 4`` or when the
    current CUDA device reports a major compute capability of 8 or higher;
    otherwise an 8-bit config is returned.

    The explicit request is tested first so that asking for 4-bit does not
    depend on the CUDA capability query at all (the query is only made when
    the decision actually needs it).

    NOTE(review): any ``quantization`` value other than 4 (including 8) does
    not force 8-bit on a capability >= 8 device -- confirm this is intended.

    Args:
        quantization: ``4`` to force 4-bit; anything else defers to the
            device capability.

    Returns:
        BitsAndBytesConfig: the 4-bit or 8-bit loading configuration.
    """
    # The bnb_4bit_* settings are passed in both cases, as before.
    shared_kwargs = dict(
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=False,  # We don't know if this is okay.
    )

    use_4bit = quantization == 4 or torch.cuda.get_device_capability()[0] >= 8
    if use_4bit:
        return BitsAndBytesConfig(load_in_4bit=True, **shared_kwargs)

    return BitsAndBytesConfig(load_in_8bit=True, **shared_kwargs)


def create_clinical_note(dialog_state):
    """Render a dialogue state as a plain-text clinical note.

    Each known slot becomes a ``## <description>`` section followed by its
    entries and a blank line; slots without a description are skipped
    entirely. For slots with a canonical head (symptom, habit, history,
    medication, medical test, exposure) every entry starts with a numbered
    line holding its ``'value'`` and the remaining keys are listed as
    ``  - key: value`` bullets; for the other slots every key, including
    ``'value'``, is listed as a bullet. Entries whose value is 'other',
    'general' or 'all' are omitted. Underscores are shown as spaces and text
    is lower-cased.

    The number of an entry is its position in the slot's list, so omitted
    placeholder entries leave gaps in the numbering.

    Args:
        dialog_state (dict): slot name -> list of entry dicts. Values may be
            strings, lists, or other scalars (rendered through ``str``).

    Returns:
        str: the note with leading and trailing whitespace removed.
    """
    # Slot families whose entries are headed by a canonical 'value'.
    canon_slots = (
        'symptom',
        'habit',
        'medical_history',
        'family_history',
        'medication',
        'medical_test',
        'exposure',
    )
    placeholder_values = ['other', 'general', 'all']
    slot_descriptions = {
        'positive_symptom': 'Positive Symptoms (current symptoms experienced by the patient)',
        'negative_symptom': 'Negative Symptoms (symptoms not experienced by the patient)',
        'unknown_symptom': 'Unknown Symptoms (symptoms whose status is unknown)',
        'positive_habit': 'Positive Habits (habits the patient has)',
        'negative_habit': 'Negative Habits (habits the patient does not have)',
        'unknown_habit': 'Unknown Habits (habits whose status is unknown)',
        'positive_medical_history': 'Positive Past Medical History (previous medical conditions experienced by the patient)',
        'negative_medical_history': 'Negative Past Medical History (medical conditions the patient has not experienced)',
        'unknown_medical_history': 'Unknown Past Medical History (medical conditions whose status is unknown)',
        'positive_family_history': 'Positive Family History (medical conditions present in the patient\'s family)',
        'negative_family_history': 'Negative Family History (medical conditions not present in the patient\'s family)',
        'unknown_family_history': 'Unknown Family History (medical conditions whose status is unknown)',
        'positive_medication': 'Positive Medications (medications the patient is taking)',
        'negative_medication': 'Negative Medications (medications the patient is not taking)',
        'unknown_medication': 'Unknown Medications (medications whose status is unknown)',
        'avail_medical_test': 'Available Medical Tests (medical tests the patient has undergone)',
        'unavail_medical_test': 'Unavailable Medical Tests (medical tests the patient has not undergone)',
        'positive_exposure': 'Positive Exposures (exposures the patient has had)',
        'negative_exposure': 'Negative Exposures (exposures the patient has not had)',
        'unknown_exposure': 'Unknown Exposures (exposures whose status is unknown)',
        'occupation': 'Patient Occupation Details',
        'travel': 'Patient Travel History',
        'basic_information': 'Patient General Information',
        'residence': 'Patient Accommodation Details'
    }

    def _humanize(value):
        # str() keeps numbers, booleans and None from crashing the rendering.
        return str(value).lower().replace('_', ' ')

    lines = []
    for slot, entries in dialog_state.items():
        if slot not in slot_descriptions:
            continue
        lines.append(f"## {slot_descriptions[slot]}")
        has_canon_head = any(family in slot for family in canon_slots)

        for position, entry in enumerate(entries, start=1):
            if entry.get('value') in placeholder_values:
                continue

            if has_canon_head:
                lines.append(f'{position}. {_humanize(entry["value"])}')

            for key, value in entry.items():
                # The head value was already written on the numbered line.
                if has_canon_head and key == 'value':
                    continue

                if isinstance(value, list):
                    rendered = ", ".join(_humanize(item) for item in value)
                else:
                    rendered = _humanize(value)
                lines.append(f'  - {_humanize(key)}: {rendered}')

        lines.append('')

    return '\n'.join(lines).strip()


def create_posterior_note(posterior, desc_type):
    """Render candidate diseases as a hint section for the prompt.

    Three layouts are supported, selected by ``desc_type``:

    * ``'text'`` -- ``## <Disease> (Likelihood - <score>)`` followed by the
      description text;
    * ``'rag'``  -- ``## <Disease>`` followed by the description text;
    * anything else -- only the ``## <Disease> (Likelihood - <score>)`` line.

    Only the fields the chosen layout renders are read from an entry, so
    'rag' entries need no ``'score'`` and header-only entries need no
    ``'text'``. The score is shown as a percentage (``100 * score``) rounded
    to five decimals.

    Args:
        posterior: iterable of dicts with 'disease' and, depending on the
            layout, 'score' and/or 'text'.
        desc_type (str): layout selector, see above.

    Returns:
        str: the rendered entries joined by a newline ('' for no entries).
    """
    if desc_type == 'text':
        template = "## {{disease}} (Likelihood - {{score}})\n{{text}}\n"
    elif desc_type == 'rag':
        template = "## {{disease}}\n{{text}}\n"
    else:
        template = "## {{disease}} (Likelihood - {{score}})"

    shows_score = desc_type != 'rag'
    shows_text = desc_type in ('text', 'rag')

    content = []
    for entry in posterior:
        note = template.replace('{{disease}}', entry['disease'].capitalize())
        if shows_score:
            percentage = round(100.0 * entry['score'], 5)
            note = note.replace('{{score}}', str(percentage))
        if shows_text:
            note = note.replace('{{text}}', entry['text'])
        content.append(note)

    return '\n'.join(content)
