from klreg import BatchRunLLM, ChatEnv, init_prompt, get_model_and_tokenizer, get_sentiment_pipeline, PPO, run_episodes, RolloutBuffer

import asyncio
import torch
from datetime import datetime

import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import math
import os
import sys
import argparse
import pickle
from collections import defaultdict
# Enlarge figure titles and axis labels for every figure this module creates.
plt.rcParams['figure.titlesize'] = 18
plt.rcParams['axes.labelsize'] = 15

import matplotlib.scale as mscale
import matplotlib.transforms as mtransforms
import matplotlib.ticker as ticker
class SquareRootScale(mscale.ScaleBase):
    """
    Square-root axis scale, registered with matplotlib under the name 'squareroot'.

    Data values v are displayed at position v ** 0.5. Major ticks fall on every
    integer, unlabeled minor ticks every 0.1, and the lower view limit is clamped
    at 0 because the square root is undefined for negative values.
    Use with ``ax.set_yscale('squareroot')``.
    """

    name = 'squareroot'

    def __init__(self, axis, **kwargs):
        # matplotlib >= 3.1 requires `axis` to be forwarded to ScaleBase; any extra
        # keyword arguments passed through set_*scale are ignored.
        mscale.ScaleBase.__init__(self, axis)

    def set_default_locators_and_formatters(self, axis):
        # Major ticks at integers (labeled), minor ticks every 0.1 (unlabeled).
        axis.set_major_locator(ticker.MultipleLocator(base=1.))
        axis.set_major_formatter(ticker.ScalarFormatter())
        axis.set_minor_locator(ticker.MultipleLocator(base=0.1))
        axis.set_minor_formatter(ticker.NullFormatter())

    def limit_range_for_scale(self, vmin, vmax, minpos):
        # Pad the lower limit down by 0.01 but never below 0 (sqrt domain).
        return max(0., vmin - 1e-2), vmax

    class SquareRootTransform(mtransforms.Transform):
        """Forward transform: v -> v ** 0.5."""
        input_dims = 1
        output_dims = 1
        is_separable = True

        def transform_non_affine(self, a):
            return np.array(a) ** 0.5

        def inverted(self):
            return SquareRootScale.InvertedSquareRootTransform()

    class InvertedSquareRootTransform(mtransforms.Transform):
        """Inverse transform: v -> v ** 2."""
        input_dims = 1
        output_dims = 1
        is_separable = True

        # Override transform_non_affine (not `transform`): composite and blended
        # transforms call transform_non_affine directly, and the base-class
        # implementation is the identity, which would make the inverse silently
        # wrong on those code paths. Transform.transform still routes through here.
        def transform_non_affine(self, a):
            return np.array(a) ** 2

        def inverted(self):
            return SquareRootScale.SquareRootTransform()

    def get_transform(self):
        return self.SquareRootTransform()


mscale.register_scale(SquareRootScale)

import transformers
from transformers import AutoTokenizer
# Module-level tokenizer used by the decoding/alignment helpers below (detokenize,
# cut_to_end_token, find_n_grams_from_tokens, latex_colored_text). It is created at
# import time via AutoTokenizer.from_pretrained, with files cached under
# `mixtral_cache_dir`.
# NOTE(review): the `__main__` block rebinds this global to the tokenizer returned by
# get_model_and_tokenizer() when it generates transcripts.
mixtral_cache_dir = './mixtral_cache/'
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", cache_dir=mixtral_cache_dir)
def detokenize(tokens):
    """Decode a 1-D tensor of integer token ids into one string.

    The ids are viewed as a batch containing a single sequence so that the
    module-level ``tokenizer.batch_decode`` can be used; its only result is returned.
    """
    single_batch = tokens.view(1, -1)
    (decoded,) = tokenizer.batch_decode(single_batch)
    return decoded

def match(whole_chat, fragments):
    """Align per-token fragments against the jointly decoded chat string.

    Decoding all tokens at once (``whole_chat``) can produce spaces that do not
    appear when each token is decoded on its own (``fragments``). This walks
    ``whole_chat`` character by character against the fragments and inserts a
    synthetic ``" "`` fragment wherever the joint decoding has an extra space at a
    fragment boundary. Unexpected characters are skipped over heuristically by
    scanning ahead for the next fragment character that matches; a diagnostic is
    printed when more than 10 such mismatches occur.

    Args:
        whole_chat: the string obtained by decoding all tokens together.
        fragments: list of strings, one per token, decoded separately. Not modified.

    Returns:
        (out_fragments, originals): ``out_fragments`` is the fragment list (empty
        fragments replaced by ``" "``) with the synthetic spaces inserted, and
        ``originals[k]`` is True iff ``out_fragments[k]`` corresponds to a real token.

    Raises:
        AssertionError: if an extra space appears in the middle of a fragment.
    """
    # Work on a copy; empty fragments become " " so each has a first character.
    fragments = [frag if len(frag) > 0 else " " for frag in fragments]
    originals = [True for _ in fragments]
    frag_i = 0  # index of the fragment being consumed
    frag_j = 0  # character offset inside fragments[frag_i]
    insertion_spots = []
    badness_count = 0
    for chat_i, char in enumerate(whole_chat):
        if frag_i >= len(fragments):
            # All fragments consumed; ignore any trailing characters.
            break
        expected = fragments[frag_i][frag_j]
        if char != ' ' or char == expected:
            if char != expected:
                # Mismatch: scan forward for the next fragment character equal to
                # `char` and resynchronise there if one exists.
                badness_count += 1
                new_frag_i, new_frag_j = frag_i, frag_j
                back_on_track = True
                while char != fragments[new_frag_i][new_frag_j]:
                    new_frag_j += 1
                    if new_frag_j == len(fragments[new_frag_i]):
                        new_frag_j = 0
                        new_frag_i += 1
                    if new_frag_i == len(fragments):
                        back_on_track = False
                        break
                if back_on_track:
                    frag_i, frag_j = new_frag_i, new_frag_j
            # Consume one character; step to the next fragment at the end of this
            # one. Done exactly once so frag_i is never indexed past the last fragment.
            frag_j += 1
            if frag_j == len(fragments[frag_i]):
                frag_j = 0
                frag_i += 1
        else:
            # An extra space from joint decoding: only legal on a fragment boundary.
            assert frag_j == 0, f"char is {char} \n\n previously we had {whole_chat[:chat_i]} \n\n fragments[frag_i] is {fragments[frag_i]} \n\n fragments[frag_i][frag_j] is {fragments[frag_i][frag_j]} \n\n previous fragments are {fragments[:frag_i+1]}"
            insertion_spots.append(frag_i)
    out_fragments = fragments.copy()
    # Insert from the back so earlier insertion indices stay valid.
    for i in reversed(insertion_spots):
        out_fragments.insert(i, " ")
        originals.insert(i, False)
    if badness_count > 10:
        print("badness_count > 10, see below:")
        print(f"{whole_chat} \n\n fragments: {fragments} \n\n out_fragments: {out_fragments}")
        print("*"*256)
    return out_fragments, originals

# ASCII letters only. A fragment is treated as word-like iff its first character is in
# this tuple (see group_by_word / just_words); non-ASCII letters are not included.
letters = tuple('qwertyuiopasdfghjklzxcvbnmQWERTYUIOPASDFGHJKLZXCVBNM')

def group_by_word(fragments_, originals, kls_by_token):
    """Merge consecutive letter-initial fragments into whole words.

    A fragment is appended to the previous output fragment when both it and the
    fragment before it start with an ASCII letter (see `letters`); any other
    fragment starts a new output entry.

    Args:
        fragments_: fragments as returned by `match`.
        originals: parallel flags, True when the fragment came from a real token.
        kls_by_token: one KL value per *original* token (synthetic fragments have
            no entry of their own).

    Returns:
        (new_frags, new_origs, new_kls_by_token): the merged fragments, their
        original-flags, and the summed KL of the tokens in each merged fragment.
        Synthetic fragments get a KL of 0.0 so no token's KL is counted twice.
    """
    new_frags = []
    new_origs = []
    new_kls_by_token = []
    is_last_letter_based = False
    j = 0  # index into kls_by_token; advances only on original fragments
    for frag, is_original in zip(fragments_, originals):
        # bool(frag) guards against empty fragments instead of raising IndexError.
        starts_with_letter = bool(frag) and frag[0] in letters
        kl = kls_by_token[j] if is_original else 0.0
        if starts_with_letter and is_last_letter_based:
            new_frags[-1] += frag
            assert new_origs[-1] == is_original
            # Rebind instead of `+=` so a tensor element of kls_by_token is never
            # modified in place through this list.
            new_kls_by_token[-1] = new_kls_by_token[-1] + kl
        else:
            new_frags.append(frag)
            new_origs.append(is_original)
            new_kls_by_token.append(kl)
        is_last_letter_based = starts_with_letter
        if is_original:
            j += 1
    return new_frags, new_origs, new_kls_by_token

# Speaker markers as they appear in per-token decoded fragments.
teacher_start = ['\n', 'Teacher', ':']
student_start = ['\n', 'Student', ':']

def str_to_wait_for(boolean):
    """Return the marker whose appearance flips the speaking state.

    While in the speaking state (``boolean`` truthy) the Student marker ends it;
    otherwise the Teacher marker starts it. The module-level lists themselves are
    returned, not copies.
    """
    return student_start if boolean else teacher_start

def make_is_speaking(fragments):
    """Flag, for each fragment, whether it lies inside a Teacher turn.

    Scanning left to right, the state flips to "speaking" right after the three
    fragments ``['\\n', 'Teacher', ':']`` and back to "not speaking" right after
    ``['\\n', 'Student', ':']`` (the markers come from `str_to_wait_for`). Each
    fragment is labeled with the state in effect before it is examined, so the ':'
    that closes a marker still carries the previous state.

    Returns:
        list of bools, one per fragment.
    """
    is_speaking = []
    speaking_now = False
    for i, _ in enumerate(fragments):
        is_speaking.append(speaking_now)
        # A marker spans three fragments, so the earliest one ends at index 2.
        if i >= 2 and fragments[i-2:i+1] == str_to_wait_for(speaking_now):
            speaking_now = not speaking_now
    return is_speaking

def just_words(fragments, is_speaking, kls):
    """Keep only word fragments, i.e. those whose first character is in `letters`.

    Returns three parallel lists in the original order: the kept fragments, their
    speaking flags, and their KL values.
    """
    kept = [
        (frag, speak, kl)
        for frag, speak, kl in zip(fragments, is_speaking, kls)
        if frag[0] in letters
    ]
    words = [frag for frag, _, _ in kept]
    is_speaking_ = [speak for _, speak, _ in kept]
    new_kls = [kl for _, _, kl in kept]
    return words, is_speaking_, new_kls

def find_n_grams(words, is_speaking, n):
    """Link repeated n-grams that lie entirely inside speaking spans.

    Only n-grams whose n words are all flagged in ``is_speaking`` take part.

    Returns:
        (pointers, nth), both one entry per start index i of ``words``:
        ``nth[i]`` is how many times the n-gram starting at i has occurred up to and
        including position i (1 for first occurrences and non-participating i), and
        ``pointers[i]`` is the start index of that n-gram's next occurrence, or None.
    """
    nth = [1 for _ in words]
    pointers = [None for _ in words]
    # Most recent start index of each participating n-gram. This replaces a
    # backward linear search per position (O(len(words)^2)) with a dict lookup.
    last_start = {}
    for i, _ in enumerate(words):
        if i + n > len(words):
            break
        if not all(is_speaking[i:i+n]):
            continue
        key = tuple(words[i:i+n])
        j = last_start.get(key)
        if j is not None:
            nth[i] += nth[j]
            pointers[j] = i
        last_start[key] = i
    return pointers, nth

def cut_to_end_token(chat_tokens):
    """Return how many tokens to keep so the chat ends just after its first '</s>'.

    Each token is decoded individually with the module-level tokenizer. If no
    '</s>' is present, the full length is returned.
    """
    per_token = tokenizer.batch_decode(chat_tokens.view(-1, 1))
    try:
        first_end = per_token.index('</s>')
    except ValueError:
        return len(per_token)
    return first_end + 1

def find_n_grams_from_tokens(chat_tokens, max_n=10, kls=None):
    """Run the repeated-n-gram pipeline on one tokenised chat.

    Steps: decode each token and the whole chat, align them with `match`, merge
    into words with `group_by_word`, label speaking spans with `make_is_speaking`,
    keep only words with `just_words`, then run `find_n_grams` for n = 1..max_n.

    Args:
        chat_tokens: tensor of token ids.
        max_n: largest n-gram size to analyse.
        kls: per-token KL values; defaults to 0.0 for every token.

    Returns:
        (pointers, nths, words, is_speaking, new_kls) where pointers[k] and
        nths[k] are the `find_n_grams` results for n = k + 1.
    """
    if kls is None:
        kls = [0. for _ in chat_tokens]
    per_token_text = tokenizer.batch_decode(chat_tokens.view(-1, 1))
    aligned, originals = match(detokenize(chat_tokens), per_token_text)
    merged, _merged_origs, merged_kls = group_by_word(aligned, originals, kls)
    words, is_speaking, new_kls = just_words(merged, make_is_speaking(merged), merged_kls)
    per_size = [find_n_grams(words, is_speaking, size) for size in range(1, max_n + 1)]
    pointers = [links for links, _ in per_size]
    nths = [counts for _, counts in per_size]
    return pointers, nths, words, is_speaking, new_kls

def colors(n):
    """Return n colours sampled from seaborn's 'rocket_r' colormap.

    Before the final reversal, entry i is colormap index 255 - int(255 * 12 / (12 + i)),
    so successive entries step more slowly through the map; the list is returned reversed.
    """
    cmap_colors = sns.color_palette("rocket_r", as_cmap=True).colors
    picked = []
    for i in range(n):
        picked.append(cmap_colors[255 - int(255 * 12 / (12 + i))])
    picked.reverse()
    return picked

def plot_counts(n_for_ngrams, nths, savefile):
    """Stacked bar chart of repeated n-grams for a single episode.

    ``nths[m]`` is the occurrence-number list from `find_n_grams` for n-gram size
    ``n_for_ngrams[m]``. For each repetition level k (largest first, down to 2) a
    segment counts the positions whose occurrence number is exactly k, which equals
    the number of distinct n-grams occurring at least k times. Saved to ``savefile``.
    """
    top = max(max(nth) for nth in nths)
    table = {
        f"repeated at least {k} times": [sum(1 for occ in nth if occ == k) for nth in nths]
        for k in range(top, 1, -1)
    }
    frame = pd.DataFrame(table, index=[f"{n}-grams" for n in n_for_ngrams])
    frame.plot(kind='bar', stacked=True, color=colors(len(table)), rot=1)
    plt.ylabel('Count')
    plt.savefig(savefile, bbox_inches='tight', dpi=500)

def plot_avgcounts(n_for_ngrams, many_nths, savefile):
    """Stacked bar chart of repeated n-grams, averaged over episodes.

    ``many_nths[e][m]`` is the `find_n_grams` occurrence-number list for episode e
    and n-gram size ``n_for_ngrams[m]``. For each repetition level k (max down to 2)
    the segment height is the episode-average number of distinct n-grams that appear
    at least k times. It is computed as a difference of cumulative counts of positions
    with occurrence number >= k. Error bars are the standard error across episodes of
    those cumulative counts. Only the first legend entry and the last three are shown;
    the others get a leading '_', which matplotlib's legend skips.
    """
    n_episodes = len(many_nths)
    n_sizes = len(many_nths[0])
    max_reps = max(max(max(nth) for nth in nths) for nths in many_nths)
    rep_levels = range(max_reps, 1, -1)
    color_meaning = [f"appear at least {k} times" for k in rep_levels]
    color_meaning[0] = "            that\n" + color_meaning[0] + "\n              ..."
    for i in range(1, len(color_meaning)-3):
        color_meaning[i] = '_' + color_meaning[i]
    # many_counts[r][e][m]: positions (episode e, size m) with occurrence number >= rep_levels[r].
    many_counts = [[[len([c for c in nth if c >= k]) for nth in nths] for nths in many_nths] for k in rep_levels]
    counts = [[sum([many_counts[r][e][m] for e in range(n_episodes)]) / n_episodes for m in range(n_sizes)] for r in range(len(many_counts))]
    counts.insert(0, [0 for _ in range(n_sizes)])
    diffs = [[counts[r+1][m] - counts[r][m] for m in range(n_sizes)] for r in range(len(many_counts))]
    # Standard error over episodes (axis 1), so divide by sqrt(#episodes).
    counts_std = np.std(np.array(many_counts), axis=1) / np.sqrt(n_episodes)
    index = [f"{n}-grams" + (" (i.e. words)" if n == 1 else "") for n in n_for_ngrams]
    df = pd.DataFrame({c: d for c, d in zip(color_meaning, diffs)}, index=index)
    try:
        df.plot(kind='bar', stacked=True, yerr=counts_std, color=colors(len(color_meaning)), rot=1)
    except Exception:
        # Dump the inputs for debugging, then propagate the real error.
        print(df)
        print(locals())
        raise
    plt.ylabel('How many different')
    plt.savefig(savefile, bbox_inches='tight', dpi=500)

def ngram_kls(pointers, nths, words, is_speaking, new_kls):
    """Collect the KL cost of every occurrence of each repeated n-gram.

    For each n-gram size n (``pointers[n-1]``, ``nths[n-1]``) and each first
    occurrence that has at least one later repeat, follow the ``pointers`` chain and
    record the summed per-word KL of every occurrence. ``is_speaking`` is accepted
    for signature compatibility but not used.

    Returns:
        (ns, ngrams, appear_num, kl_for_lineplots): parallel lists, one entry per
        repeated n-gram, holding its size, its words joined by spaces, the occurrence
        numbers [1, 2, ...], and the KL cost of each of those occurrences.
    """
    ns, ngrams, appear_num, kl_for_lineplots = [], [], [], []
    for size_idx, pointer in enumerate(pointers):
        n = size_idx + 1
        nth = nths[size_idx]
        for start, next_start in enumerate(pointer):
            # Start chains only at first occurrences that repeat later.
            if nth[start] != 1 or next_start is None:
                continue
            occurrences = [1]
            costs = [sum(new_kls[start:start+n])]
            cursor = next_start
            while cursor is not None:
                occurrences.append(len(occurrences) + 1)
                costs.append(sum(new_kls[cursor:cursor+n]))
                cursor = pointer[cursor]
            ns.append(n)
            ngrams.append(' '.join(words[start:start+n]))
            appear_num.append(occurrences)
            kl_for_lineplots.append(costs)
    return ns, ngrams, appear_num, kl_for_lineplots

def plot_ngram_kls(ns, ngrams, appear_num, kl_for_lineplots, savefile):
    """Plot the KL cost of each occurrence of every repeated n-gram (log y-axis).

    Only n-grams with more than two occurrences are drawn, coloured by their size n.
    ``ngrams`` is accepted for signature compatibility but not used. The figure is
    saved to ``savefile`` and closed.
    """
    fig, ax = plt.subplots()
    if ns:
        # Entry n of colors(L)[::-1] does not depend on L as long as L > n, so size the
        # palette from the largest n once. Sizing it by len(ns) fails whenever
        # len(ns) <= n, and recomputing it per line is wasted work.
        palette = colors(max(ns) + 1)[::-1]
        for n, xs, ys in zip(ns, appear_num, kl_for_lineplots):
            if len(xs) > 2:
                sns.lineplot(x=xs, y=ys, ax=ax, color=palette[n], linewidth=0.5)
    ax.set_xlabel('ith appearance')
    ax.set_ylabel('total KL cost')
    plt.yscale('log')
    plt.savefig(savefile, bbox_inches='tight', dpi=500)
    plt.close(fig)

def plot_kl_per_response(kls_by_tokens, speaking_records, savefile):
    """Line-plot the summed KL of each completed speaking turn, per episode.

    A turn's KL accumulates over consecutive tokens whose speaking flag is truthy
    and is recorded when the flag turns false; a turn still open at the end of an
    episode is not recorded. Draws one thin line per episode plus an aggregate line
    with a standard-deviation band, then saves to ``savefile``.
    """
    fig, ax = plt.subplots()
    pooled_x, pooled_y = [], []
    for kls, speaking_record in zip(kls_by_tokens, speaking_records):
        turn_kls = []
        running = 0
        was_speaking = False
        for kl, speaking in zip(kls, speaking_record):
            if speaking:
                running += kl
            elif was_speaking:
                turn_kls.append(running)
                running = 0
            was_speaking = speaking
        positions = list(range(len(turn_kls)))
        sns.lineplot(x=positions, y=turn_kls, ax=ax, linewidth=0.5)
        pooled_x += positions
        pooled_y += turn_kls
    sns.lineplot(x=pooled_x, y=pooled_y, ci='sd', ax=ax)
    plt.savefig(savefile, bbox_inches='tight', dpi=500)

def ecdf_errors(ys, n):
    """One-standard-error band around ECDF values ``ys`` based on ``n`` samples.

    Returns ``(y_l, y_h)``, the two roots of
    (n + 1) y^2 - (2 n ys + 1) y + n ys^2 = 0, i.e. the solutions of
    y_l + sqrt(y_l (1 - y_l) / n) = ys and y_h - sqrt(y_h (1 - y_h) / n) = ys.
    These are the levels whose quantile standard error just reaches ``ys``.
    """
    b = 2*n*ys + 1
    root = np.sqrt(b**2 - 4 * ys**2 * n*(n+1))
    denom = 2 * (n+1)
    return (b - root) / denom, (b + root) / denom

def _empty_response_kls(kls, speaking_record, max_len=4):
    """Return (KLs of the empty turns, total number of turns) for one episode.

    A turn is a maximal run of tokens whose speaking flag is truthy; it is "empty"
    when it has at most ``max_len`` tokens. A turn still open when the record ends
    is tested and counted like any other, so the count of empty turns never
    exceeds the count of turns.
    """
    ys = []
    n_responses = 0
    response_kl = 0
    length = 0
    prev_speaking = False
    for kl, speaking in zip(kls, speaking_record):
        if speaking:
            response_kl += kl
            length += 1
        if not speaking and prev_speaking:
            if length <= max_len:
                ys.append(response_kl)
            n_responses += 1
            response_kl = 0
            length = 0
        prev_speaking = speaking
    if length > 0:
        if length <= max_len:
            ys.append(response_kl)
        n_responses += 1
    return ys, n_responses


def _load_empty_fracs(out_dir):
    """Collect the pickled empty-response fractions of runs 0-19 in ``out_dir``.

    Filenames look like '<budg>_<id>_<filestr>.pkl'. Returns two dicts, one for
    budget '10' and one for every other budget. Each maps '256'/'512' (512 when the
    filename's second-to-last '_' field is 'cont') to the concatenated fractions.
    """
    d_10 = defaultdict(list)
    d_20 = defaultdict(list)
    valid_ids = {str(i) for i in range(20)}
    for filename in sorted(os.listdir(out_dir)):
        path = os.path.join(out_dir, filename)
        parts = filename.split("_")
        if not os.path.isfile(path) or len(parts) < 3 or parts[1] not in valid_ids:
            continue
        # Only files written by plot_kl_per_empty_response are expected here;
        # pickle.load must never be pointed at untrusted files.
        with open(path, 'rb') as f:
            fracs = pickle.load(f)
        ep_len = '512' if parts[-2] == 'cont' else '256'
        target = d_10 if parts[0] == '10' else d_20
        target[ep_len].extend(fracs)
    return d_10, d_20


def _plot_ecdf_band(values, label, color):
    """Draw the ECDF of ``values`` with a one-standard-error band; return (x, y, y_l, y_h)."""
    line = plt.ecdf(values, label=label, color=color)
    x, y = line.get_data()
    y_l, y_h = ecdf_errors(y, len(y))
    plt.fill_between(x, y_l, y_h, alpha=0.5, color=color, edgecolor='none')
    return x, y, y_l, y_h


def _long_form(fracs_by_len):
    """Long-form DataFrame with columns 'Transcript length' and 'value', sorted by length.

    Built directly because pd.DataFrame(dict_of_lists) requires all lengths to have
    the same number of samples, which is not generally true here.
    """
    rows = [(ep_len, v) for ep_len in sorted(fracs_by_len) for v in fracs_by_len[ep_len]]
    return pd.DataFrame(rows, columns=['Transcript length', 'value'])


def plot_kl_per_empty_response(kls_by_tokens, speaking_records, id, budg, savefile, filestr=None):
    """Plot KL spent on empty responses and summarise empty-response rates across runs.

    An "empty response" is a speaking turn of at most 4 tokens.

    Side effects, in order:
      1. ``savefile``: per episode, the running total of KL spent on its empty responses.
      2. Writes this run's per-episode fraction of empty responses to
         'pickled_outputs/empty_response_counts/<budg>_<id>_<filestr>.pkl'.
      3. Reloads all such files for run ids 0-19 and writes 'empty_frac_ecdf.png',
         'empty_frac_percentiles.png' and 'empty_frac_hist.png' to the working
         directory, split by budget ('10' vs. the rest) and transcript length.

    Args:
        kls_by_tokens: per-episode sequences of per-token KL values.
        speaking_records: per-episode speaking flags, parallel to kls_by_tokens.
        id: run id (used in the pickle filename).
        budg: KL budget label (used in the title and filename).
        savefile: output path of the per-episode plot.
        filestr: run identifier for the pickle filename; defaults to the script's
            global ``args.filestr`` (only defined when run as __main__).
    """
    if filestr is None:
        filestr = args.filestr
    fig, ax = plt.subplots(figsize=(4, 8))
    n_plotted = 0
    n_total = 0
    frac_empty_responses = []
    for kls, speaking_record in zip(kls_by_tokens, speaking_records):
        ys, n_responses = _empty_response_kls(kls, speaking_record)
        # Running total: each line shows cumulative KL spent on empty responses.
        for i in range(len(ys) - 1):
            ys[i+1] += ys[i]
        sns.lineplot(x=list(range(len(ys))), y=ys, ax=ax, linewidth=0.5)
        n_total += 1
        if ys:
            n_plotted += 1
        frac_empty_responses.append(len(ys) / max(n_responses, 1))
    ax.set_xlabel('ith empty response')
    ax.set_ylabel('total KL cost of empty responses so far')
    ax.set_title(f"KL costs of empty responses -- budget {budg} \n ({n_plotted}/{n_total} episodes have ≥1 empty response)")
    plt.savefig(savefile, bbox_inches='tight', dpi=500)
    plt.close()

    out_dir = 'pickled_outputs/empty_response_counts/'
    os.makedirs(out_dir, exist_ok=True)
    with open(out_dir + '{}_{}_{}.pkl'.format(budg, id, filestr), 'wb') as f:
        pickle.dump(frac_empty_responses, f)
    d_10, d_20 = _load_empty_fracs(out_dir)

    palette = sns.color_palette("muted", 4)
    # Palette slots: budget 10 -> 0 (len 256) / 2 (len 512); budget 20 -> 1 / 3.
    curves = []
    fig, ax = plt.subplots(figsize=(4, 5))
    for budget_label, data, offset in (("10", d_10, 2), ("20", d_20, 1)):
        for ep_len in [256, 512]:
            values = data.get(str(ep_len), [])
            if not values:
                continue  # nothing recorded for this budget/length
            color = palette[ep_len//128 - offset]
            label = f"Budget {budget_label}; length {ep_len}"
            x, y, y_l, y_h = _plot_ecdf_band(values, label, color)
            curves.append((x, y, y_l, y_h, label, color))
    plt.legend()
    plt.xlabel('Fraction of responses that are empty')
    plt.ylabel('Cumulative probability (steep = high prob. density)')
    plt.savefig('empty_frac_ecdf.png', bbox_inches='tight', dpi=500)
    plt.close()

    # Same curves with axes swapped and probabilities shown as percentiles.
    fig, ax = plt.subplots(figsize=(4, 5))
    for x, y, y_l, y_h, label, color in curves:
        plt.plot(y*100, x, label=label, color=color)
        plt.fill_betweenx(x, y_l*100, y_h*100, alpha=0.5, color=color, edgecolor='none')
    plt.legend()
    plt.xlabel('Percentile')
    plt.ylabel('Fraction of responses that are empty')
    plt.savefig('empty_frac_percentiles.png', bbox_inches='tight', dpi=500)
    plt.close()

    df_melted_10, df_melted_20 = _long_form(d_10), _long_form(d_20)
    fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True, figsize=(4, 5))
    sns.histplot(data=df_melted_10, x='value', hue='Transcript length', kde=False, ax=ax1)
    sns.histplot(data=df_melted_20, x='value', hue='Transcript length', kde=False, ax=ax2)
    plt.xlabel('Fraction of responses empty')
    plt.savefig('empty_frac_hist.png', bbox_inches='tight', dpi=500)
    plt.close()

def frac_idx(i_frac, lst):
    """Linearly interpolate ``lst`` at fractional position ``i_frac``.

    0 maps to lst[0] and 1 to lst[-1]; positions outside [0, 1] are clamped to
    those ends, and a one-element list always yields its element. On a sorted
    list this gives a quantile estimate.
    """
    last = len(lst) - 1
    if last == 0 or i_frac <= 0:
        return lst[0]
    if i_frac >= 1:
        return lst[-1]
    position = i_frac * last
    lower = int(position)
    weight = position - lower
    lo_val = lst[lower]
    return lo_val + weight * (lst[lower + 1] - lo_val)

def plot_quantile_kl_per_empty_response(kls_by_tokens, speaking_records, id, budg, savefile, qs=(0.25, 0.5, 0.75)):
    """Plot quantiles of the KL cost of the nth empty response, pooled across runs.

    An "empty response" is a speaking turn (maximal run of tokens with a truthy
    speaking flag) of at most 4 tokens that ends before the record does. This run's
    data is pickled to 'pickled_outputs/klsbytokens_<budg>_<id>_<len(savefile)>.pkl'.
    It is then pooled with other runs' dumps that have the same ``budg``, a different
    ``id``, and a first episode of the same token count. For the k-th empty response,
    each quantile q in ``qs`` of its KL across pooled episodes is drawn. The band
    shifts the quantile level by ± sqrt(q (1 - q) / N).

    Args:
        kls_by_tokens: per-episode sequences of per-token KL values.
        speaking_records: per-episode speaking flags, parallel to kls_by_tokens.
        id: run id (part of the dump filename).
        budg: KL budget label (dump filename and title).
        savefile: output image path.
        qs: sequence of quantile levels in [0, 1].
    """
    out_dir = 'pickled_outputs/'
    os.makedirs(out_dir, exist_ok=True)
    with open(out_dir + 'klsbytokens_{}_{}_{}.pkl'.format(budg, id, len(savefile)), 'wb') as f:
        pickle.dump((kls_by_tokens, speaking_records), f)
    for filename in sorted(os.listdir(out_dir)):
        parts = filename.split("_")
        # Only this function's own dumps: the directory also holds other pickles
        # and a sub-directory that do not unpack into a (kls, speaking) pair.
        if not filename.startswith('klsbytokens_') or len(parts) < 4:
            continue
        if parts[-3] != str(budg) or parts[-2] == str(id):
            continue
        with open(out_dir + filename, 'rb') as f:
            loaded_kls_by_tokens, loaded_speaking_records = pickle.load(f)
        if kls_by_tokens and loaded_kls_by_tokens and len(kls_by_tokens[0]) == len(loaded_kls_by_tokens[0]):
            kls_by_tokens = kls_by_tokens + loaded_kls_by_tokens
            speaking_records = speaking_records + loaded_speaking_records
    fig, ax = plt.subplots(figsize=(3, 4))
    # kls_by_index[k]: KL of the k-th empty response, one value per episode that has one.
    kls_by_index = []
    for kls, speaking_record in zip(kls_by_tokens, speaking_records):
        ys = []
        prev_speaking = False
        response_kl = 0
        length = 0
        for kl, speaking in zip(kls, speaking_record):
            if speaking:
                response_kl += kl
                length += 1
            if not speaking and prev_speaking:
                if length <= 4:
                    ys.append(response_kl)
                response_kl = 0
                length = 0
            prev_speaking = speaking
        for i, y in enumerate(ys):
            if len(kls_by_index) <= i:
                kls_by_index.append([])
            kls_by_index[i].append(y)
    kls_by_index_sorted = [sorted(kls_i) for kls_i in kls_by_index]
    q_dict = {}
    for q in qs:
        q_dict[str(q)] = np.array([frac_idx(q, kls_i) for kls_i in kls_by_index_sorted])
        q_dict[str(q)+'up'] = np.array([frac_idx(q + math.sqrt(q * (1-q) / len(kls_i)), kls_i) for kls_i in kls_by_index_sorted])
        q_dict[str(q)+'down'] = np.array([frac_idx(q - math.sqrt(q * (1-q) / len(kls_i)), kls_i) for kls_i in kls_by_index_sorted])
    palette = sns.color_palette("muted", len(qs))
    x = range(len(kls_by_index_sorted))
    for i, q in enumerate(qs):
        plt.plot(x, q_dict[str(q)], label=f'Quantile {q}', color=palette[i])
        plt.fill_between(x, q_dict[str(q)+'down'], q_dict[str(q)+'up'], alpha=0.5, color=palette[i], edgecolor='none')
    ax.set_xlabel('nth empty response')
    ax.set_ylabel('KL cost of that response (quartiles)')
    ax.set_title(f"KL cost to remain silent \n budget {budg}")
    ax.set_yscale('squareroot')
    plt.savefig(savefile, bbox_inches='tight', dpi=500)
    plt.close()

def plot_base_pol_ents(base_pol_entss, savefile):
    """Draw one thin line per episode of its ``base_pol_ents`` values against step index.

    NOTE(review): the values are presumably base-policy entropies (from the name).
    The figure is saved to ``savefile``.
    """
    _, axes = plt.subplots()
    for episode_values in base_pol_entss:
        steps = list(range(len(episode_values)))
        sns.lineplot(x=steps, y=episode_values, ax=axes, linewidth=0.5)
    plt.savefig(savefile, bbox_inches='tight', dpi=500)

# An unfinished draft of `latex_colored_text` (no return value) was defined here. It was
# always shadowed by the complete `latex_colored_text` defined later in this module.

# Hex RGB colour ramp used by `kl_color_mapping`: index 0 is used for a KL of 0 and
# the last entry for KL == max_kl (after a square-root warp of the normalised KL).
custom_colors = ['#805a19', '#865a13', '#8d5a0d', '#945a06', '#9b5900', '#a25900', '#a95700', '#b15600', '#b95300', '#c15100', '#ca4e00', '#d24900', '#db4500', '#e43e00', '#ed3700', '#f62c00', '#fb1e02', '#f30c00', '#eb0000', '#e20000', '#da0000', '#d20000', '#ca0000', '#c20000', '#ba0000', '#b30000', '#ab0000', '#a30000', '#9c0000', '#940000', '#8d0000', '#860000', '#830000', '#850000', '#860000', '#860000', '#870000', '#860000', '#850000', '#84000c', '#820016', '#80001e', '#7d0025', '#79002d', '#750034', '#6f003c', '#690044', '#61004c', '#5b004d', '#56004d', '#52004b', '#4d004a', '#490048', '#440047', '#400044', '#3c0042', '#380040', '#35003d', '#31003a', '#2e0037', '#2b0533', '#280a30', '#250e2c', '#221128']
def kl_color_mapping(kl, max_kl=10):
    """Map a KL value to an (R, G, B) tuple of ints in 0..255 from `custom_colors`.

    The KL is clamped to [0, max_kl], normalised, square-rooted, and used to index
    `custom_colors`. A KL of 0 maps to the first colour; max_kl or more maps to the last.
    """
    # Clamp: a KL above max_kl would index past the palette, and a negative KL
    # would make the square root NaN (int(NaN) raises).
    kl = min(max(float(kl), 0.0), max_kl)
    # The (1 - 1e-6) factor keeps kl == max_kl strictly below the palette length.
    x = kl / max_kl * (1-1e-6)
    x = np.sqrt(x)
    color = custom_colors[int(x * len(custom_colors))]
    return tuple(int(color[i:i+2], 16) for i in (1, 3, 5))

def latex_color_legend():
    """Build a LaTeX colour key for `kl_color_mapping` covering KL values 0 to 9.75.

    Each integer i gets a box labeled i in white, followed by three unlabeled boxes
    for i + 0.25, i + 0.5 and i + 0.75. The result ends with a blank line.
    """
    def rgb(value):
        return ",".join(str(int(channel)) for channel in kl_color_mapping(value))

    pieces = []
    for i in range(10):
        pieces.append("\\colorbox[RGB]{" + rgb(i) + "}{\\textcolor{white}{" + str(i) + "}}")
        pieces.extend("\\colorbox[RGB]{" + rgb(i + j/4) + "}{\\phantom{0}}" for j in range(1, 4))
    return "".join(pieces) + "\n\n"

def _latex_escape(text):
    """Escape the characters that are special in LaTeX text mode."""
    replacements = {
        '\\': '\\textbackslash{}',
        '&': '\\&',
        '%': '\\%',
        '$': '\\$',
        '#': '\\#',
        '_': '\\_',
        '{': '\\{',
        '}': '\\}',
        '~': '\\textasciitilde{}',
        '^': '\\textasciicircum{}',
    }
    # Character-by-character so escapes are never escaped a second time.
    return ''.join(replacements.get(ch, ch) for ch in text)


def latex_colored_text(chat_tokens, kls, kl_color_mapping):
    """Render a tokenised chat as LaTeX with each token coloured by its KL cost.

    Args:
        chat_tokens: tensor of token ids, decoded with the module-level tokenizer.
        kls: one KL value per token of chat_tokens.
        kl_color_mapping: callable mapping a KL value to an (R, G, B) triple.

    Returns:
        A LaTeX snippet: a three-column ``multicols`` inside a ``minipage``. The first
        60 original tokens are printed in black and later ones are coloured.
        NOTE(review): the 60-token cutoff presumably skips the prompt; confirm.
        The three highest-KL tokens get a footnote with their KL. Newline tokens are
        shown as "[\\n]" followed by a paragraph break. LaTeX special characters in the
        token text are escaped.
    """
    sorted_indices = np.argsort(np.array(kls))[::-1]
    top_3_indices = sorted_indices[:3]
    in_tops = np.isin(np.arange(len(kls)), top_3_indices).tolist()
    fragments_ = tokenizer.batch_decode(chat_tokens.view(-1, 1))
    chat = detokenize(chat_tokens)
    fragments, originals = match(chat, fragments_)
    original_i = 0
    out_str = """
\\noindent
  \\begin{minipage}{\\dimexpr\\linewidth-2\\fboxsep-2\\fboxrule}
    \\setlength{\\columnsep}{5mm}
    \\begin{multicols}{3}
    \\raggedright"""
    for frag, orig in zip(fragments, originals):
        if orig:
            kl = kls[original_i]
            in_top = in_tops[original_i]
            original_i += 1
        else:
            # Synthetic spacer inserted by `match`: no KL and never footnoted.
            kl = 0
            in_top = False
        if original_i > 60:
            color = kl_color_mapping(kl)
        else:
            color = (0, 0, 0)
        leading_space = frag[0] == ' '
        trailing_space = frag[-1] == ' '
        newline = frag == '\n'
        text = " [\\textbackslash n]" if newline else _latex_escape(frag)
        if leading_space and len(frag) > 1:
            out_str += ' '
        if frag != ' ':
            rgb = ",".join([str(int(val)) for val in color])
            out_str += "\\textcolor[RGB]{" + rgb + "}{" + text
            if in_top:
                out_str += "\\footnote{\\textcolor[RGB]{" + rgb + "}{KL=" + "{:.2f}".format(kl) + "}}}"
            else:
                out_str += "}"
        if trailing_space:
            out_str += ' '
        if newline:
            out_str += "\n\n"
    # Escaped backslashes: "\e" is not a valid Python escape sequence.
    out_str += "\n\\end{multicols}\n\\end{minipage}"
    return out_str

def save_or_append_string_list(filename, string_list):
    """
    Save a list of strings to a pickle file, or append to the list already stored there.

    Missing parent directories are created. The pickle is written to a temporary
    sibling file and moved into place with os.replace, so an interrupted write
    cannot leave a truncated file behind.

    :param filename: path of the pickle file to create or extend
    :param string_list: the strings to save or append
    """
    stored = []
    if os.path.exists(filename):
        # Only files written by this function are expected here; never point
        # pickle.load at untrusted data.
        with open(filename, 'rb') as file:
            stored = pickle.load(file)
    stored.extend(string_list)
    parent = os.path.dirname(filename)
    if parent:
        os.makedirs(parent, exist_ok=True)
    tmp_name = filename + '.tmp'
    with open(tmp_name, 'wb') as file:
        pickle.dump(stored, file)
    os.replace(tmp_name, filename)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Both are needed: checkpoint and output paths below are built from them.
    parser.add_argument('--filestr', type=str, required=True, help='identifier_string')
    parser.add_argument('--id', type=int, required=True, help='Run id')
    args = parser.parse_args()

    transcripts_already_generated = False  # True: reload pickled episodes instead of running the policy
    use_base_pol = False  # True: run with a KL budget of 0

    # Create every directory written to below, directly or by the plotting helpers.
    for out_dir in ('stdout/', 'pickled_outputs/', 'pickled_outputs/empty_response_counts/', 'transcripts/'):
        os.makedirs(out_dir, exist_ok=True)
    log_file = open("stdout/gentext_{}_{}.txt".format(args.id, args.filestr), 'w')
    sys.stdout = log_file
    sys.stderr = log_file

    plain_texts = []
    # Identifiers whose second-to-last '_' field is 'cont' use 512-token chats.
    filestr_parts = args.filestr.split("_")
    max_chat_length = 512 if len(filestr_parts) >= 2 and filestr_parts[-2] == 'cont' else 256

    many_nths = []
    max_n = 10
    ns, ngrams, appear_num, kl_for_lineplots = [], [], [], []
    kls_by_tokens, kls_by_tokens_whole = [], []
    speaking_records, speaking_records_whole = [], []
    base_pol_entss = []

    def process_output(output):
        """Accumulate one episode's statistics into the module-level lists above.

        Returns the episode's chat tokens and per-token KLs, both cut just after the
        first end-of-sequence token (see cut_to_end_token).
        """
        (_t, chat, _rewards_by_token, kls_by_token, _kls, _log_ratios_by_token,
         _log_ratios, speaking_record, base_pol_ents) = output
        end_index = cut_to_end_token(torch.tensor(chat))
        kls_by_tokens_whole.append(kls_by_token)
        speaking_records_whole.append(speaking_record)
        chat = chat[:end_index]
        kls_by_token = kls_by_token[:end_index]
        speaking_record = speaking_record[:end_index]
        kls_by_tokens.append(kls_by_token)
        speaking_records.append(speaking_record)
        pointers, nths, words, is_speaking, new_kls = find_n_grams_from_tokens(torch.tensor(chat), max_n=max_n, kls=kls_by_token)
        many_nths.append(nths)
        ns_, ngrams_, appear_num_, kl_for_lineplots_ = ngram_kls(pointers, nths, words, is_speaking, new_kls)
        ns.extend(ns_)
        ngrams.extend(ngrams_)
        appear_num.extend(appear_num_)
        kl_for_lineplots.extend(kl_for_lineplots_)
        base_pol_entss.append(base_pol_ents)
        plain_texts.append(tokenizer.decode(torch.tensor(chat), skip_special_tokens=True))
        return chat, kls_by_token

    if not transcripts_already_generated:
        model_path = "PPO_checkpoints/Chat/PPO_Chat_{}_0_{}.pth".format(args.id, args.filestr)

        n_envs = 8
        base_pol_temp = 1.0
        env_temp = 0.05
        # Module-level assignment: this rebinds the global `tokenizer` used by the helpers.
        model, tokenizer = get_model_and_tokenizer()
        sentiment_pipeline = get_sentiment_pipeline()
        batchrunllm = BatchRunLLM(model, n_envs)

        print("id", args.id)
        print("max_chat_length", max_chat_length)
        envs = [ChatEnv(init_prompt, "Teacher", ["Student"], model, tokenizer, sentiment_pipeline, batchrunllm, base_pol_temp=base_pol_temp, env_temp=env_temp, max_chat_length=max_chat_length) for _ in range(n_envs)]
        env = envs[0]
        state_dim = env.observation_space.shape[0]
        action_dim = env.action_space.n

        ppo_agent = PPO(state_dim, action_dim, lr_actor=0.0, lr_critic=0.0, gamma=None, K_epochs=None, update_batch_size=None, eps_clip=None, ent_coef=None, max_grad_norm=None, gae_lam=None, dtype=torch.float32)
        ppo_agent.load(model_path)

        kl_budget = ppo_agent.last_used_episode_kl_budget
        if use_base_pol:
            kl_budget = 0

        all_outputs = []
        for _ in range(1):
            outputs = asyncio.run(run_episodes(ppo_agent, envs, kl_budget, verbose=True))
            all_outputs.extend(outputs)
            for output in outputs:
                chat, kls_by_token = process_output(output)
        with open('pickled_outputs/{}_{}.pkl'.format(args.id, args.filestr), 'wb') as f:
            pickle.dump(all_outputs, f)
    else:
        with open('pickled_outputs/{}_{}.pkl'.format(args.id, args.filestr), 'rb') as f:
            outputs = pickle.load(f)
        for output in outputs:
            chat, kls_by_token = process_output(output)
        print("Chat length = ", len(chat))

    budg = "20" if args.id in [3, 4, 5] else "10" if args.id in [0, 1, 2] else None
    save_or_append_string_list(f"transcripts/budg_{budg}_len_{max_chat_length}_id_{args.id}.pkl", plain_texts)
    print(latex_color_legend())
    # `chat` / `kls_by_token` are those of the last processed episode.
    print(latex_colored_text(torch.tensor(chat), kls_by_token, kl_color_mapping))
    plot_quantile_kl_per_empty_response(kls_by_tokens_whole, speaking_records_whole, args.id, budg, 'kl_quantiles_empty_replies_{}_{}.png'.format(args.id, args.filestr))
    plot_kl_per_empty_response(kls_by_tokens_whole, speaking_records_whole, args.id, budg, 'kl_per_empty_reply_{}_{}.png'.format(args.id, args.filestr))

