import argparse
from data_utils import custom_load_dataset
from utils import *
import math
import gym
from gym import spaces
import copy

import string
import nltk
nltk.download('punkt')
from nltk.tokenize import word_tokenize, sent_tokenize
from supar import Parser
from nltk.tokenize.treebank import TreebankWordDetokenizer

# Editing the instruction
def delete_phrase(candidate, phrase):
    """Remove every occurrence of ``phrase`` from ``candidate``.

    A space-prefixed or space-suffixed occurrence is collapsed to a single
    space so the neighbouring words stay separated; otherwise the bare
    phrase is simply dropped.

    Args:
        candidate: The instruction text to edit.
        phrase: The phrase to remove.

    Returns:
        The edited instruction string.
    """
    with_leading_space = ' ' + phrase
    with_trailing_space = phrase + ' '
    # ``find(...) > 0`` deliberately ignores a match at position 0.
    if candidate.find(with_leading_space) > 0:
        return candidate.replace(with_leading_space, ' ')
    if candidate.find(with_trailing_space) > 0:
        return candidate.replace(with_trailing_space, ' ')
    return candidate.replace(phrase, '')

def add_phrase(candidate, phrase, after):
    """Insert ``phrase`` into ``candidate`` right after the phrase ``after``.

    Args:
        candidate: The instruction text to edit.
        phrase: The phrase to insert.
        after: The existing phrase to insert behind. An empty string means
            "insert at the very beginning of the instruction".

    Returns:
        The edited instruction string. If ``after`` does not occur in
        ``candidate`` the text is returned unchanged. Every occurrence of
        ``after`` receives the inserted phrase.
    """
    if after == '':
        return phrase + ' ' + candidate
    # Membership tests (rather than ``find(...) > 0``) so that a match at
    # index 0 -- e.g. ``after`` being the first word -- is handled by the
    # space-aware branches instead of falling through.
    if ' ' + after in candidate:
        return candidate.replace(' ' + after, ' ' + after + ' ' + phrase)
    if after + ' ' in candidate:
        return candidate.replace(after + ' ', after + ' ' + phrase + ' ')
    # ``after`` has no surrounding spaces (e.g. it is the whole text):
    # still separate the inserted phrase with a space.
    return candidate.replace(after, after + ' ' + phrase)

def swap_phrases(candidate, phrase_1, phrase_2):
    """Exchange the positions of ``phrase_1`` and ``phrase_2`` in ``candidate``.

    Each phrase is first replaced by a placeholder (``<1>`` / ``<2>``), then
    the placeholders are filled with the opposite phrase.

    Args:
        candidate: The instruction text to edit.
        phrase_1: First phrase to swap.
        phrase_2: Second phrase to swap.

    Returns:
        The edited instruction string.
    """
    for phrase, marker in ((phrase_1, '<1>'), (phrase_2, '<2>')):
        padded = ' ' + phrase + ' '
        if padded in candidate:
            # Prefer whole-word matches delimited by spaces on both sides.
            candidate = candidate.replace(padded, ' ' + marker + ' ')
        else:
            candidate = candidate.replace(phrase, marker)
    return candidate.replace('<1>', phrase_2).replace('<2>', phrase_1)

def substitute_phrase(candidate, phrase):
    """Replace ``phrase`` in ``candidate`` with a randomly chosen paraphrase.

    Paraphrases are requested from ``get_response`` (10 beams, 10 returned
    sequences); one is drawn with ``np.random.choice`` and stripped of
    leading/trailing periods before being substituted.

    Args:
        candidate: The instruction text to edit.
        phrase: The phrase to paraphrase.

    Returns:
        The edited instruction string.
    """
    num_return_sequences = 10
    num_beams = 10
    options = get_response(phrase, num_return_sequences, num_beams)
    replacement = np.random.choice(options, 1)[0].strip('.')

    with_leading_space = ' ' + phrase
    with_trailing_space = phrase + ' '
    if candidate.find(with_leading_space) > 0:
        return candidate.replace(with_leading_space, ' ' + replacement)
    if candidate.find(with_trailing_space) > 0:
        return candidate.replace(with_trailing_space, replacement + ' ')
    return candidate.replace(phrase, replacement)

def perform_edit(edit, base, phrase_lookup, delete_tracker):
    """Apply one randomly parameterised edit to the instruction ``base``.

    Args:
        edit: One of ``'del'``, ``'swap'``, ``'sub'`` or ``'add'``.
        base: The instruction text to edit.
        phrase_lookup: Mapping from integer phrase index to phrase text.
        delete_tracker: Sequence of previously deleted phrases; ``'add'``
            re-inserts one of them.

    Returns:
        A ``(new_instruction, touched)`` pair. ``touched`` holds the phrase
        indices used by ``'del'``/``'swap'``/``'sub'``, or the re-added
        phrase for ``'add'`` (empty if nothing could be added). For an
        unrecognised ``edit`` the function returns ``None``.
    """
    keys = list(phrase_lookup.keys())
    if edit == 'del':
        [i] = np.random.choice(keys, 1)
        return delete_phrase(base, phrase_lookup[i]), [i]
    if edit == 'swap':
        try:
            [i, j] = np.random.choice(keys, 2, replace=False)
        except ValueError:
            # Fewer than two phrases: sampling without replacement is
            # impossible, so allow the same phrase to be drawn twice.
            [i, j] = np.random.choice(keys, 2, replace=True)
        return swap_phrases(base, phrase_lookup[i], phrase_lookup[j]), [i, j]
    if edit == 'sub':
        [i] = np.random.choice(keys, 1)
        return substitute_phrase(base, phrase_lookup[i]), [i]
    if edit == 'add':
        # -1 stands for "insert at the beginning of the instruction".
        [i] = np.random.choice(keys + [-1], 1)
        after = phrase_lookup[i] if i >= 0 else ''
        if len(delete_tracker) == 0:
            return base, []
        phrase = np.random.choice(delete_tracker, 1)[0]
        return add_phrase(base, phrase, after), [phrase]
    return None

# Extract phrases from a parsed sentence tree
def traverse_tree(parsed_tree):
    """Collect detokenized phrases from the children of ``parsed_tree``.

    Children labelled ``'_'`` are skipped. For every other child its text is
    recorded, then the text of each of its tree-typed, non-``'_'`` children,
    followed by a recursive traversal of those children.

    Args:
        parsed_tree: An iterable of tree nodes exposing ``label()`` and
            ``leaves()``.

    Returns:
        A list of phrase strings in traversal order.
    """
    collected = []
    for node in parsed_tree:
        if node.label() != '_':
            collected.append(detokenize(node.leaves()))
            for child in node:
                if type(child) == nltk.tree.Tree and child.label() != '_':
                    collected.append(detokenize(child.leaves()))
                    collected.extend(traverse_tree(child))
    return collected

def check_child(tree):
    """Tell whether at least half of ``tree``'s children are ``'_'`` subtrees.

    Args:
        tree: An iterable of child nodes.

    Returns:
        ``True`` when the number of children that are ``nltk.tree.Tree``
        instances labelled ``'_'`` is at least the number of remaining
        children (an empty tree therefore yields ``True``).
    """
    children = list(tree)
    placeholder_count = sum(
        1
        for child in children
        if type(child) == nltk.tree.Tree and child.label() == '_'
    )
    return placeholder_count >= len(children) - placeholder_count

def collect_leaves(parsed_tree):
    """Split ``parsed_tree`` into disjoint phrase strings.

    A child becomes one phrase when it is labelled ``'_'`` or when
    ``check_child`` reports that at least half of its own children are
    ``'_'`` subtrees; otherwise the child is split further recursively.

    Args:
        parsed_tree: A tree (or iterable of nodes) to decompose.

    Returns:
        A list of detokenized phrase strings in left-to-right order.
    """
    leaves = []
    for tree in parsed_tree:
        # Guard the element itself (the original tested ``parsed_tree``),
        # so bare-string children are skipped instead of failing on
        # ``.label()``.
        if not isinstance(tree, nltk.tree.Tree):
            continue
        if tree.label() == '_' or check_child(tree):
            leaves.append(detokenize(tree.leaves()))
        else:
            leaves.extend(collect_leaves(tree))
    return leaves

def get_phrases(instruction, parser):  # one possible way of obtaining disjoint phrases
    """Break ``instruction`` into disjoint phrases using a parser.

    Each sentence is word-tokenized, parsed with ``parser.predict`` and
    decomposed with ``collect_leaves``.

    Args:
        instruction: The instruction text.
        parser: A parser object exposing ``predict(tokens, verbose=...)``.

    Returns:
        A list of phrase strings, normalised through a tokenize/detokenize
        round trip. Phrases that are empty or that match (a substring of)
        ``string.punctuation`` are dropped.
    """
    phrases = []
    for sentence in sent_tokenize(instruction):
        parsed_tree = parser.predict(word_tokenize(sentence), verbose=False).sentences[0].trees[0]
        phrases.extend(collect_leaves(parsed_tree))
    # ``'' in string.punctuation`` is True, so this single test removes both
    # punctuation-only and empty phrases (the former ``or phrase == ''``
    # re-admitted empty strings).
    return [
        detokenize(word_tokenize(phrase))
        for phrase in phrases
        if phrase not in string.punctuation
    ]

def eval_accuracy(all_label_probs, test_labels, mode=None, p_cf=None):
    """Count correct predictions, optionally after contextual calibration.

    Args:
        all_label_probs: 2-D array, one row of label probabilities per example.
        test_labels: Ground-truth class indices, one per example.
        mode: Calibration mode when ``p_cf`` is given: ``"diagonal_W"``
            (scale by the inverse of ``diag(p_cf)``) or ``"identity_W"``
            (subtract ``p_cf`` as a bias).
        p_cf: Content-free label probabilities; ``None`` disables calibration.

    Returns:
        ``(num_correct, num_total)`` as computed by ``np.sum``.
    """
    num_classes = all_label_probs.shape[1]
    if p_cf is None:
        # No calibration: identity transform.
        W = np.identity(num_classes)
        b = np.zeros([num_classes, 1])
    elif mode == "diagonal_W":
        W = np.linalg.inv(np.identity(num_classes) * p_cf)
        b = np.zeros([num_classes, 1])
    elif mode == "identity_W":
        W = np.identity(num_classes)
        b = -1 * np.expand_dims(p_cf, axis=-1)
    else:
        assert False

    assert len(all_label_probs) == len(test_labels)
    correctness_list = []
    for row, true_label in zip(all_label_probs, test_labels):
        normalized = row / np.sum(row)  # normalize to 1
        calibrated = np.matmul(W, np.expand_dims(normalized, axis=-1)) + b
        predicted = np.argmax(calibrated)
        correctness_list.append(1 if predicted == true_label else 0)
    total_list = [1] * len(correctness_list)
    return np.sum(correctness_list), np.sum(total_list)

def get_label_probs(params, raw_resp, train_sentences, train_labels, test_sentences):
    """Return the model's (unnormalized) label probabilities per test example.

    Args:
        params: Experiment settings; reads ``'label_dict'`` (class index ->
            list of label strings), ``'approx'`` and, when probabilities
            must be re-queried, ``'model'``.
        raw_resp: One model response per test sentence, each carrying
            ``['logprobs']['top_logprobs']``.
        train_sentences: Prompt examples forwarded to ``construct_prompt``.
        train_labels: Prompt labels forwarded to ``construct_prompt``.
        test_sentences: The sentences that were scored.

    Returns:
        A 2-D NumPy array of label probabilities (NOT normalized).
    """
    label_dict = params['label_dict']
    num_classes = len(label_dict)
    approx = params['approx']
    assert len(raw_resp) == len(test_sentences)

    # Pass 1: read whatever label tokens appear among the top-k log-probs.
    all_label_probs = []
    all_missing_positions = []
    for example_idx, ans in enumerate(raw_resp):
        # [0] because only a single completion token is requested
        top_logprobs = ans['logprobs']['top_logprobs'][0]
        label_probs = [0] * num_classes
        for class_idx, label_list in label_dict.items():
            found_flags = []
            for label in label_list:  # all labels in the list map to the same class
                token = " " + label  # the prompt has no space after 'A:'
                hit = token in top_logprobs
                found_flags.append(hit)
                if hit:
                    label_probs[class_idx] += np.exp(top_logprobs[token])
            if not all(found_flags):
                # remember (which test example, which label) is missing
                all_missing_positions.append((example_idx, class_idx))
                # fall back to a uniform distribution for this example
                label_probs = [1 / num_classes] * num_classes
        all_label_probs.append(label_probs)
    all_label_probs = np.array(all_label_probs)  # prob not normalized

    # Pass 2: for labels outside the top-k, ask the model to score the label
    # appended to the prompt (skipped in approximate mode).
    if all_missing_positions and not approx:
        print(f"Missing probs: {len(all_missing_positions)}/{len(raw_resp) * num_classes}")
        all_additional_prompts = []
        num_prompts_each = []
        for which_sentence, which_label in all_missing_positions:
            test_sentence = test_sentences[which_sentence]
            label_list = label_dict[which_label]
            for label in label_list:
                prompt = construct_prompt(params, train_sentences, train_labels, test_sentence)
                all_additional_prompts.append(prompt + " " + label)
            num_prompts_each.append(len(label_list))

        # chunk the prompts and feed them to the model
        chunked_prompts = list(chunks(all_additional_prompts, chunk_size_helper(params)))
        all_probs = []
        for chunk in chunked_prompts:
            resp = complete(chunk, 0, params['model'], echo=True, num_log_probs=1)
            all_probs.extend(
                np.exp(ans['logprobs']['token_logprobs'][-1]) for ans in resp['choices']
            )

        assert sum(num_prompts_each) == len(all_probs)
        assert len(num_prompts_each) == len(all_missing_positions)

        # write the summed probabilities back into their (example, label) cells
        cursor = 0
        for (row, col), count in zip(all_missing_positions, num_prompts_each):
            all_label_probs[row][col] = np.sum(all_probs[cursor:cursor + count])
            cursor += count

        assert cursor == len(all_probs), "all should be popped"
        assert (all_label_probs > 0).all(), "all should be populated with non-zero value"

    return all_label_probs  # NOT NORMALIZED


def get_phrase_lookup(base_candidate, parser, level='phrase'):
    """Build an ``index -> text unit`` lookup for an instruction.

    Args:
        base_candidate: The instruction text.
        parser: Parser passed to ``get_phrases`` (only used when ``level``
            is ``'phrase'``).
        level: Granularity of the units. ``'phrase'`` (default, the only
            mode that was previously reachable) uses ``get_phrases``;
            ``'word'`` uses non-punctuation word tokens; ``'sentence'`` uses
            sentences; ``'span'`` splits every sentence into a random number
            (2, 3 or 4) of contiguous word spans.

    Returns:
        A dict mapping consecutive integers starting at 0 to unit strings.

    Raises:
        ValueError: If ``level`` is not one of the supported values.
    """
    if level == 'phrase':
        units = get_phrases(base_candidate, parser)
    elif level == 'word':
        # ``'' in string.punctuation`` is True, so this also drops empty tokens.
        units = [w for w in word_tokenize(base_candidate) if w not in string.punctuation]
    elif level == 'sentence':
        units = sent_tokenize(base_candidate)
    elif level == 'span':
        units = []
        for sentence in sent_tokenize(base_candidate):
            spans_per_sentence = np.random.choice(range(2, 5))  # 2, 3 or 4 chunks
            spans = np.array_split(word_tokenize(sentence), spans_per_sentence)
            units.extend(detokenize(list(span)) for span in spans)
    else:
        raise ValueError(f"unknown level: {level!r}")
    return dict(enumerate(units))

def detokenize(tokens):
    """Join a sequence of word tokens back into text.

    Args:
        tokens: Word tokens, e.g. the output of ``word_tokenize``.

    Returns:
        The string produced by NLTK's ``TreebankWordDetokenizer``.
    """
    detokenizer = TreebankWordDetokenizer()
    return detokenizer.detokenize(tokens)

# This environment steps a batch of prompts in parallel
class LMForwardEnvNoPrefix(gym.Env):
  """Gym environment that edits the few-shot examples of an LM prompt.

  The environment keeps ``num_processes`` independent copies of the prompt
  (one per sampled query sentence). A discrete action per copy either swaps
  the first example with another one, deletes the first example, re-adds a
  previously deleted example, or (when ``verbalizer`` is enabled) changes
  which promptsource template formats an example or the query sentence.
  The reward is the scaled decrease in loss between consecutive steps.

  ``reset`` must be called before ``step``: it samples the query sentences
  and initialises the per-copy prompt state.
  """

  def __init__(self, params, prompt_sentence_pool, prompt_label_pool, train_sentences, train_labels, max_steps, num_processes, obs_size, entropy_coef=0, loss_type='ce', verbalizer=False, evaluate=False):
    """Set up action tables, templates and the action/observation spaces.

    Args:
        params: Experiment settings; reads ``'num_shots'`` and ``'dataset'``
            here and is forwarded to the model helpers.
        prompt_sentence_pool: Few-shot example sentences of the base prompt.
        prompt_label_pool: Labels matching ``prompt_sentence_pool``.
        train_sentences: Pool from which query sentences are sampled.
        train_labels: Labels matching ``train_sentences``.
        max_steps: Episode length in steps.
        num_processes: Number of prompt copies handled in parallel.
        obs_size: Dimensionality of the observation vector.
        entropy_coef: Weight of the prediction-entropy term in the reward.
        loss_type: ``'ce'`` (cross entropy) or ``'step'``.
        verbalizer: Enable template-selection actions.
        evaluate: Report accuracy in ``info`` at the end of an episode.
    """
    super(LMForwardEnvNoPrefix, self).__init__()
    self.params = params
    self.prompt_sentence_pool = prompt_sentence_pool
    self.prompt_label_pool = prompt_label_pool
    self.train_sentences = train_sentences
    self.train_labels = train_labels
    self.current_prompt = self.prompt_sentence_pool
    self.current_prompt_labels = self.prompt_label_pool
    self.deleted_prompt = []
    self.deleted_prompt_labels = []
    self.latent_type = 'embedding'
    self.loss_type = loss_type
    self.max_steps = max_steps
    self.subset_size = num_processes
    self.evaluate = evaluate
    self.rew_scale = 10
    self.entropy_coef = entropy_coef
    self.verbalizer = verbalizer

    # Prefix editing
    '''
    self.prefix_candidate = detokenize(word_tokenize(params['prompt_prefix']))
    self.parser = Parser.load('crf-con-en')
    self.phrase_lookup = get_phrase_lookup(self.prefix_candidate, self.parser)
    self.phrase_length = len(self.phrase_lookup)
    self.phrase_total_length = int(self.phrase_length*(self.phrase_length-1)/2)
    self.prefix_phrase_total_length = self.phrase_total_length + int(params['num_shots']*(params['num_shots']-1)/2)+1+2
    '''
    # (num_shots - 1) swaps with example 0, plus 1 no-op, plus delete and add.
    self.prefix_phrase_total_length = params['num_shots']-1+1+2

    self.current_sentence = None
    self.current_label = None
    self.previous_loss = None
    self.idxs = None
    self.steps = 0

    # Action -> (idx1, idx2) table. Entry 0 is (0, 0), i.e. a no-op swap.
    self.swap_idxs1 = []
    self.swap_idxs2 = []
    self.swap_idxs1.append(0)
    self.swap_idxs2.append(0)
    # Only swaps that involve the first example (i == 0) are offered.
    for i in range(1):
        for j in range(i+1, params['num_shots']):
            self.swap_idxs1.append(i)
            self.swap_idxs2.append(j)
    # 2 for adding and deleting: (-1, -1) deletes, (-2, -2) re-adds.
    self.swap_idxs1.append(-1)
    self.swap_idxs2.append(-1)
    self.swap_idxs1.append(-2)
    self.swap_idxs2.append(-2)
    
    # Prefix indexs
    '''
    self.swap_prefix_idxs1 = []
    self.swap_prefix_idxs2 = []
    self.swap_prefix_idxs1.append(0)
    self.swap_prefix_idxs2.append(0)
    for i in range(self.phrase_length):
        for j in range(i+1, self.phrase_length):
            self.swap_prefix_idxs1.append(i)
            self.swap_prefix_idxs2.append(j)
    '''

    # 3 Verbalizer Dataset formating
    from promptsource.templates import DatasetTemplates
    self.prompt_templates = DatasetTemplates(params['dataset'])
    self.prompt_template_keys = self.prompt_templates.all_template_names
    self.current_verbalizer = []
    self.deleted_verbalizer = []
    self.subset_verbalizer = []
    # Adds, per example, one "next template" and one "previous template"
    # action, plus one no-op.
    self.prefix_phrase_verbalizer_total_length = self.prefix_phrase_total_length + 2*params['num_shots']+1

    from datasets import load_dataset
    # dataset = load_dataset('super_glue', 'rte', split='train')
    # dataset = load_dataset('trec', split='train')
    '''
    print(dataset[0])
    print(self.prompt_templates[self.prompt_template_keys[0]].apply(dataset[0]))
    exit()
    '''

    if self.verbalizer:
        # One extra action per template (select it for the query sentence)
        # plus one trailing no-op.
        self.action_space = spaces.Discrete(self.prefix_phrase_verbalizer_total_length + len(self.prompt_template_keys)+1)
    else:
        self.action_space = spaces.Discrete(self.prefix_phrase_total_length)
    print(self.action_space)
    self.observation_space = spaces.Box(-np.inf, np.inf, (obs_size,))

  def verbalize(self, current_sentences, current_verbalizer, subset=False):
    """Render sentences through their selected promptsource templates.

    Args:
        current_sentences: With ``subset=True`` a flat list of examples;
            otherwise a list (one entry per prompt copy) of example lists.
        current_verbalizer: Template indices with the same nesting as
            ``current_sentences``.
        subset: Whether the inputs are flat (query sentences) or nested
            (few-shot examples).

    Returns:
        The first element of ``template.apply(example)`` for every example,
        with the same nesting as the input.
    """
    if subset: 
        return_sentences = []
        for sentences, verbalizer in zip(current_sentences, current_verbalizer):
            prompt = self.prompt_templates[self.prompt_template_keys[verbalizer]]
            return_sentences.append(prompt.apply(sentences)[0])
        return return_sentences
    else:
        return_sentences = []
        for sentences, verbalizer in zip(current_sentences, current_verbalizer):
            return_sentences.append([])
            for i, sentence in enumerate(sentences):
                prompt = self.prompt_templates[self.prompt_template_keys[verbalizer[i]]]
                return_sentences[-1].append(prompt.apply(sentence)[0])
        return return_sentences

  def step(self, action):
    """Apply one action per prompt copy, re-query the model and score it.

    Args:
        action: Array of action indices with a trailing singleton axis,
            one entry per prompt copy.

    Returns:
        ``(obs, reward, done, info)``. ``reward`` is the scaled loss
        improvement over the previous step; ``done`` is all ones once
        ``max_steps`` is reached; ``info`` carries the current loss under
        ``'episode_r'`` and, on the final step in evaluate mode, accuracy
        counts before and after editing.
    """
    # Execute one time step within the environment
    action = action.squeeze(-1)
    per_copy_state = zip(action.tolist(), self.current_prompt, self.current_prompt_labels, self.deleted_prompt, self.deleted_prompt_labels, self.current_verbalizer, self.deleted_verbalizer, self.subset_verbalizer)
    for idx, (act, sentence, label, delete_sentence, delete_label, verbalizer, delete_verbalizer, _) in enumerate(per_copy_state):
        if act < self.prefix_phrase_total_length:
            #TODO: maybe we need to swap verbalizer as we swap example
            idx1 = self.swap_idxs1[act]
            idx2 = self.swap_idxs2[act]
            if idx1 == -1 and idx2 == -1 and len(sentence) > 0 and len(label) > 0:
                # Delete: move the first example to the deleted pool.
                delete_sentence.append(copy.deepcopy(sentence[0]))
                delete_label.append(copy.deepcopy(label[0]))
                delete_verbalizer.append(copy.deepcopy(verbalizer[0]))
                sentence.pop(0)
                label.pop(0)
                verbalizer.pop(0)
            elif idx1 == -2 and idx2 == -2 and len(delete_sentence) > 0 and len(delete_label) > 0:
                # Add: move the oldest deleted example to the end of the prompt.
                sentence.append(copy.deepcopy(delete_sentence[0]))
                label.append(copy.deepcopy(delete_label[0]))
                verbalizer.append(copy.deepcopy(delete_verbalizer[0]))
                delete_sentence.pop(0)
                delete_label.pop(0)
                delete_verbalizer.pop(0)
            elif idx1 >= 0 and idx2 >= 0:
                # Swap two examples, ignoring indices removed by deletions.
                if idx1 < len(sentence) and idx2 < len(sentence):
                    sentence[idx1], sentence[idx2] = sentence[idx2], sentence[idx1]
                    label[idx1], label[idx2] = label[idx2], label[idx1]
                    verbalizer[idx1], verbalizer[idx2] = verbalizer[idx2], verbalizer[idx1]
        elif self.verbalizer and act < self.prefix_phrase_verbalizer_total_length:
            # Move one example's template index up (first num_shots actions)
            # or down (next num_shots actions); the last action is a no-op.
            act = act - self.prefix_phrase_total_length
            verbalize_idx = act % self.params['num_shots']
            if act == 2*self.params['num_shots']:
                pass
            elif act < self.params['num_shots'] and verbalize_idx < len(verbalizer):
                if verbalizer[verbalize_idx] + 1 < len(self.prompt_template_keys):
                    verbalizer[verbalize_idx] += 1
            elif act >= self.params['num_shots'] and verbalize_idx < len(verbalizer):
                if verbalizer[verbalize_idx] - 1 >= 0:
                    verbalizer[verbalize_idx] -= 1
        elif self.verbalizer:
            # Select the template used for this copy's query sentence.
            act = act - self.prefix_phrase_verbalizer_total_length
            if act < len(self.prompt_template_keys):
                # Write back into the list: rebinding the loop variable
                # (as before) would silently discard the choice.
                self.subset_verbalizer[idx] = act

    if self.verbalizer:
        verbalized_prompt = self.verbalize(self.current_prompt, self.current_verbalizer)
        subset_sentences = self.verbalize(self.subset_sentences, self.subset_verbalizer, subset=True)
    else:
        verbalized_prompt = self.current_prompt
        subset_sentences = self.subset_sentences
    raw_resp, obs = get_model_response_parallel(self.params, verbalized_prompt, self.current_prompt_labels, subset_sentences)
    all_label_probs = get_label_probs(self.params, raw_resp, verbalized_prompt, self.current_prompt_labels, subset_sentences)

    assert len(all_label_probs) == len(self.subset_labels)
    label_probs = all_label_probs / np.sum(all_label_probs, axis=-1, keepdims=True)
    if self.loss_type == 'ce':
        onehot = np.zeros((all_label_probs.shape))
        onehot[np.arange(all_label_probs.shape[0]), np.array(self.subset_labels)] = 1
        loss = -np.sum(onehot*np.log(label_probs+1e-6), axis=-1)
        entropy = -np.sum(label_probs*np.log(label_probs+1e-6), axis=-1)
        # previous_loss already includes the entropy term of the last step.
        reward = self.previous_loss - self.entropy_coef * entropy - loss
    elif self.loss_type == 'step':
        # NOTE(review): this branch looks unfinished -- it calls torch.max on
        # a NumPy array, reads self.correct_bonus / self.incorrect_bonus which
        # __init__ never sets, and never assigns ``reward``. Confirm the
        # intended reward before using loss_type='step'.
        predicts = torch.max(label_probs, -1).cpu().numpy()
        correct = (predicts == np.array(self.subset_labels)).astype(float)
        correct_probs = label_probs[np.arange(all_label_probs.shape[0]), np.array(self.subset_labels)]
        predict_probs = label_probs[np.arange(all_label_probs.shape[0]), predicts.cpu().numpy()]
        loss = predict_probs - correct_probs
        loss = correct * self.correct_bonus + (1 - correct) * loss * self.incorrect_bonus
    # Reward Scaling
    reward = reward * self.rew_scale
    self.previous_loss = copy.deepcopy(loss)

    self.steps += 1
    if self.steps >= self.max_steps:
        done = np.ones(self.subset_size)
        if self.evaluate:
            correct, total = eval_accuracy(all_label_probs, self.subset_labels)
            info = {'episode_r': loss, 'correct': correct, 'total': total, 'orig_correct': self.orig_correct, 'orig_total': self.orig_total}
        else:
            info = {'episode_r': loss}
    else:
        done = np.zeros(self.subset_size)
        info = {'episode_r': loss}

    return obs, reward, done, info

  def reset(self, idx=None):
    """Sample query sentences, restore the base prompt and score it.

    In evaluate mode with ``self.idxs`` set, those indices select the query
    sentences; otherwise ``subset_size`` random training sentences are used.

    Args:
        idx: Unused.

    Returns:
        The observation returned by ``get_model_response_parallel`` for the
        unedited prompt.
    """
    if self.idxs is not None and self.evaluate:
        subset_idxs = self.idxs
    else:
        subset_idxs = np.random.permutation(len(self.train_sentences))[:self.subset_size]
    self.subset_sentences = [self.train_sentences[i] for i in subset_idxs]
    self.subset_labels = [self.train_labels[i] for i in subset_idxs]
    self.steps = 0
    # Every query sentence gets its own editable copy of the base prompt.
    self.current_prompt = [copy.deepcopy(self.prompt_sentence_pool) for _ in range(len(self.subset_sentences))]
    self.current_prompt_labels = [copy.deepcopy(self.prompt_label_pool) for _ in range(len(self.subset_sentences))]
    self.deleted_prompt = [[] for _ in range(len(self.subset_sentences))]
    self.deleted_prompt_labels = [[] for _ in range(len(self.subset_sentences))]
    # Reset the verbalizer: template 0 everywhere.
    self.current_verbalizer = [[0 for _ in range(self.params['num_shots'])] for _ in range(len(self.subset_sentences))]
    self.deleted_verbalizer = [[] for _ in range(len(self.subset_sentences))]
    self.subset_verbalizer = [0 for _ in range(len(self.subset_sentences))]

    if self.verbalizer:
        verbalized_prompt = self.verbalize(self.current_prompt, self.current_verbalizer)
        subset_sentences = self.verbalize(self.subset_sentences, self.subset_verbalizer, subset=True)
    else:
        verbalized_prompt = self.current_prompt
        subset_sentences = self.subset_sentences
    raw_resp, obs = get_model_response_parallel(self.params, verbalized_prompt, self.current_prompt_labels, subset_sentences)
    all_label_probs = get_label_probs(self.params, raw_resp, verbalized_prompt, self.current_prompt_labels, subset_sentences)

    if self.evaluate:
        # Baseline accuracy of the unedited prompt, reported at episode end.
        self.orig_correct, self.orig_total = eval_accuracy(all_label_probs, self.subset_labels)

    assert len(all_label_probs) == len(self.subset_labels)
    label_probs = all_label_probs / np.sum(all_label_probs, axis=-1, keepdims=True)
    if self.loss_type == 'ce':
        onehot = np.zeros((all_label_probs.shape))
        onehot[np.arange(all_label_probs.shape[0]), np.array(self.subset_labels)] = 1
        loss = -np.sum(onehot*np.log(label_probs+1e-6), axis=-1)
        entropy = -np.sum(label_probs*np.log(label_probs+1e-6), axis=-1)
        self.previous_loss = copy.deepcopy(loss) - self.entropy_coef * entropy
    elif self.loss_type == 'step':
        # NOTE(review): same unfinished 'step' logic as in step(); it also
        # leaves self.previous_loss unset. Confirm before use.
        predicts = torch.max(label_probs, -1).cpu().numpy()
        correct = (predicts == np.array(self.subset_labels)).astype(float)
        correct_probs = label_probs[np.arange(all_label_probs.shape[0]), np.array(self.subset_labels)]
        predict_probs = label_probs[np.arange(all_label_probs.shape[0]), predicts.cpu().numpy()]
        loss = predict_probs - correct_probs
        loss = correct * self.correct_bonus + (1 - correct) * loss * self.incorrect_bonus

    return obs

