# Usage: python3 plot_supervised_generalization.py --all
# Generates the plot referenced as fig:supervised_generalization.

# External Modules
import os
import argparse

# Command-line interface. Arguments are parsed at import time, before the
# heavyweight imports further down, so that CUDA_VISIBLE_DEVICES can be set
# ahead of the torch import.
# NOTE(review): several help texts mention "self train" although this script
# only plots/evaluates; they look copy-pasted from a training script -- confirm.
argument_parser = argparse.ArgumentParser()
# UUID of a single model; used for the checkpoint directory and data file names.
argument_parser.add_argument("--uuid", default=None, type=str, help="UUID of the model to self train")
# Plot every model listed in the module-level `uuids` instead of a single one.
argument_parser.add_argument("--all", default=False, action="store_true", help="Plot all models")
# Inserted into the "google/byt5-{}" model name.
argument_parser.add_argument("--size", type=str, default="small", help="Size of the model to self train")
# Forwarded to utils.verify_accuracy as batch_size when generating data.
argument_parser.add_argument("--batch_size", type=int, default=100, help="Batch size to use for self training")
# NOTE(review): --num_samples is not referenced anywhere else in the visible file.
argument_parser.add_argument("--num_samples", type=int, default=100, help="Number of samples to evaluate")
argument_parser.add_argument("--silent", default=False, action="store_true", help="Be silent")
# Written to CUDA_VISIBLE_DEVICES below.
argument_parser.add_argument("--rank", type=int, default=0, help="Rank of this process")
# Re-evaluate the checkpoints instead of reading cached accuracies from plots/.
argument_parser.add_argument("--generate_data", default=False, action="store_true", help="Generate data. Otherwise read it from file")

args = argument_parser.parse_args()
# Must run before torch is imported below; presumably restricts this process
# to the GPU whose index equals --rank.
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.rank)

import ast

import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.colorbar import ColorbarBase
import seaborn as sns

from datasets import load_dataset
from torch.utils.data import DataLoader
from IPython import embed
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
import numpy as np
import wandb
import torch
import logging
import os
import socket
import argparse
import sys

# Internal Modules
import generate_data
from models import AdditionFlanT5
import utils

# Module-level configuration shared by the functions below.
# wandb.login()
# os.environ["TOKENIZERS_PARALLELISM"] = "false"
# The host name is part of the checkpoint mount path built below.
host_name = socket.gethostname()



# CHECKPOINT_DIR is only defined when --uuid is given; generate_data() reads it,
# and main() only calls generate_data() on the single-model (--uuid) path.
if args.uuid:
    CHECKPOINT_DIR = "/mnt/batch/tasks/shared/LS_root/mounts/clusters/{}/code/checkpoints/supervised/".format(host_name) + args.uuid
# Model name passed to AutoTokenizer.from_pretrained in generate_data().
MODEL = "google/byt5-{}".format(args.size)

# Log to stdout at INFO level with a timestamped format.
logging.basicConfig(
        stream=sys.stdout, 
        format='%(asctime)s %(levelname)-8s %(message)s',
        level=logging.INFO,
        datefmt='%Y-%m-%d %H:%M:%S')

# Models plotted by --all. plot_data() labels the second entry "300M" and any
# other uuid "582M".
uuids = ["ca00221d-07e4-402c-a78a-bf5c740a5535", "9aa66935-c664-4dbb-96f6-c1a0ac464325"]

def main():
    """Plot length-generalization accuracy for one model or for all known models.

    With ``--uuid`` (and without ``--all``) the accuracies are either
    regenerated (``--generate_data``) or read from the cached text files,
    plotted, and saved to ``plots/final_plots/supervised_<uuid>.tiff``.
    With ``--all`` the cached accuracies of every entry in ``uuids`` are drawn
    on one figure and saved as ``supervised_all.pdf`` and ``supervised_all.png``.

    Raises:
        ValueError: if neither ``--all`` nor ``--uuid`` is given, or if
            ``--generate_data`` is combined with ``--all``.
    """
    # Explicit raises rather than ``assert`` so the checks survive ``python -O``.
    if not args.all and args.uuid is None:
        raise ValueError("Must specify either --all or --uuid")
    if args.all and args.generate_data:
        raise ValueError("Cannot generate data for all models")

    output_dir = "plots/final_plots"

    if not args.all:
        if args.generate_data:
            flash_accs, think_accs = generate_data()
        else:
            flash_accs, think_accs = read_data(args.uuid)

        setup()

        # plot_data requires the uuid to choose the label and colour; omitting
        # it raised a TypeError on every single-model run.
        plot_data(flash_accs, think_accs, args.uuid)

        # savefig does not create missing directories.
        os.makedirs(output_dir, exist_ok=True)
        plt.savefig("plots/final_plots/supervised_{}.tiff".format(args.uuid), dpi=400)

    else:
        setup()

        for uuid in uuids:
            flash_accs, think_accs = read_data(uuid)
            plot_data(flash_accs, think_accs, uuid)

        # Place the legend outside the axes, to the right of the plot.
        plt.legend(loc='upper left', bbox_to_anchor=(1, 1), fontsize=14)

        # Adjust the layout to make room for the legend
        plt.tight_layout()

        # savefig does not create missing directories.
        os.makedirs(output_dir, exist_ok=True)
        plt.savefig("plots/final_plots/supervised_all.pdf", dpi=1200)
        plt.savefig("plots/final_plots/supervised_all.png")

def setup():
    """Apply the Seaborn styling and create the empty, annotated figure.

    Opens a new 10x6 figure and sets its title, axis labels, tick font sizes
    and axis limits. Curves are added afterwards by ``plot_data``.
    """
    title_size = 18
    axis_label_size = 16
    tick_label_size = 14

    # Theme first, then context: white grid with the larger "talk" scaling.
    sns.set_theme(style="whitegrid")
    sns.set_context("talk")

    plt.figure(figsize=(10, 6))

    # Title and axis annotations.
    plt.title("Length Generalization", fontsize=title_size)
    plt.xlabel("Max Length Seen In Supervised Training", fontsize=axis_label_size)
    plt.ylabel("Length (N+1) Generalization Accuracy", fontsize=axis_label_size)
    plt.xticks(fontsize=tick_label_size)
    plt.yticks(fontsize=tick_label_size)

    # Fixed x range of 3..10; the y range is padded slightly beyond [0, 1].
    plt.xlim(3, 10)
    plt.ylim(-0.03, 1.03)
   
    
def plot_data(flash_accs, think_accs, uuid):
    """Draw the with-CoT and without-CoT accuracy curves of one model.

    Args:
        flash_accs: accuracies drawn as the dashed "w/o CoT" line.
        think_accs: accuracies drawn as the solid "w/ CoT" line.
        uuid: model identifier; selects the legend label and the line colour.

    The x values start at 3 and increase by one per entry of ``flash_accs``.
    """
    # One specific uuid is labelled 300M/purple; every other uuid 582M/orange.
    is_300m = uuid == "9aa66935-c664-4dbb-96f6-c1a0ac464325"
    params, color = ("300M", "purple") if is_300m else ("582M", "orange")

    lengths = np.arange(3, len(flash_accs) + 3)

    # (y values, legend label, extra line options); same colour for both lines.
    curves = (
        (think_accs, "{} w/ CoT".format(params), {}),
        (flash_accs, "{} w/o CoT".format(params), {"linestyle": "--"}),
    )
    for accuracies, legend_label, extra_options in curves:
        sns.lineplot(
            x=lengths,
            y=accuracies,
            label=legend_label,
            linewidth=2.5,
            color=color,
            **extra_options,
        )


def generate_data():
    """Evaluate each supervised checkpoint of ``args.uuid`` and cache the accuracies.

    Starting at 3 digits, loads ``model-supervised-<N>-digits.ckpt`` from
    ``CHECKPOINT_DIR`` and calls ``utils.verify_accuracy`` with ``N + 1`` as the
    digit argument. The loop stops at the first N without a checkpoint file.
    Entries 1 and 2 of the returned accuracies are recorded as the "think" and
    "flash" accuracy respectively.

    Both result files under ``plots/`` are truncated before evaluation starts
    and receive the final lists once the loop has finished.

    Returns:
        Tuple ``(flash_accs, think_accs)`` with one entry per evaluated checkpoint.
    """

    def checkpoint_path(digits):
        # Path of the supervised checkpoint for the given digit count.
        return CHECKPOINT_DIR + "/model-{}-{}-digits.ckpt".format("supervised", digits)

    tokenizer = AutoTokenizer.from_pretrained(MODEL)
    scratchpad_tokens = ["<scratchpad>", "</scratchpad>"]
    tokenizer.add_special_tokens(
        {"additional_special_tokens": scratchpad_tokens + tokenizer.additional_special_tokens}
    )
    trainer = Trainer(accelerator="gpu", default_root_dir="logs/self_train", enable_progress_bar=False)
    print("Checkpoint dir: {}".format(CHECKPOINT_DIR))

    lightning_model = AdditionFlanT5(type="decomp", silent=args.silent, accuracy_check_threshold=np.inf)
    lightning_model.initialize(lightning_model.model, tokenizer, None)

    flash_file = "plots/supervised_{}_flash.txt".format(args.uuid)
    think_file = "plots/supervised_{}_think.txt".format(args.uuid)

    # Create (or empty) both result files before the evaluation starts.
    for result_file in (flash_file, think_file):
        with open(result_file, "w"):
            pass

    flash_accs = []
    think_accs = []

    num_digits = 3
    while os.path.exists(checkpoint_path(num_digits)):
        lightning_model.set_cur_num_digits(num_digits)
        lightning_model.model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint_path(num_digits))

        _, accs = utils.verify_accuracy(
            trainer,
            lightning_model,
            num_digits + 1,
            "decomp",
            batch_size=args.batch_size,
            digit_only=True,
            flash=True,
            silent=args.silent,
        )

        think_acc, flash_acc = accs[1], accs[2]
        think_accs.append(think_acc)
        flash_accs.append(flash_acc)

        print("Num digits: {}".format(num_digits))
        print("Think accs: {}".format(think_acc))
        print("Flash accs: {}".format(flash_acc))

        num_digits += 1

    print("No {} digit model found, stopping".format(num_digits))

    # Persist each list as a single line holding its Python literal.
    with open(flash_file, "a") as flash_out:
        flash_out.write(str(flash_accs) + "\n")

    with open(think_file, "a") as think_out:
        think_out.write(str(think_accs) + "\n")

    return flash_accs, think_accs

def read_data(uuid):
    """Load the cached accuracy lists of the model identified by ``uuid``.

    Reads the first line of ``plots/supervised_<uuid>_flash.txt`` and of
    ``plots/supervised_<uuid>_think.txt`` and parses each as a Python literal.

    Returns:
        Tuple ``(flash_accs, think_accs)``.
    """

    def parse_first_line(path):
        # Only the first line of the file is parsed.
        with open(path, "r") as handle:
            first_line = handle.readline()
        return ast.literal_eval(first_line)

    flash_accs = parse_first_line("plots/supervised_{}_flash.txt".format(uuid))
    think_accs = parse_first_line("plots/supervised_{}_think.txt".format(uuid))
    return flash_accs, think_accs





# Run the plotting entry point only when this file is executed as a script.
# Note that argument parsing above happens at import time regardless.
if __name__ == "__main__":
    main()