# External Modules
import os
import argparse

# Command-line interface. Arguments are parsed at import time, ahead of the
# heavy third-party imports below (torch, transformers, ...).
argument_parser = argparse.ArgumentParser()
argument_parser.add_argument("uuid", type=str, help="UUID of the model to self train")
argument_parser.add_argument("--size", type=str, default="small", help="Size of the model to self train")
argument_parser.add_argument("--batch_size", type=int, default=100, help="Batch size to use for self training")
argument_parser.add_argument("--num_samples", type=int, default=100, help="Number of samples to evaluate")
argument_parser.add_argument("--silent", default=False, action="store_true", help="Be silent")
# The -1 default is a "not provided" sentinel; main() rejects it.
argument_parser.add_argument("--self_train_start", type=int, default=-1, help="Number of digits to start self training on")
argument_parser.add_argument("--buffer", type=int, default=8, help="Number of digits to buffer")
argument_parser.add_argument("--rank", type=int, default=0, help="Rank of this process")
argument_parser.add_argument("--generate_data", default=False, action="store_true", help="Generate data. Otherwise read it from file")

args = argument_parser.parse_args()
# Set before torch is imported further down; presumably so that this process
# only sees the GPU whose index equals --rank.
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.rank)

import ast

import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.colorbar import ColorbarBase
from matplotlib import cm


from datasets import load_dataset
from torch.utils.data import DataLoader
from IPython import embed
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
import numpy as np
import wandb
import torch
import logging
import os
import socket
import argparse
import sys
import seaborn as sns

# Internal Modules
import generate_data
from models import AdditionFlanT5
import utils

# wandb.login()
# os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Name of the machine running this script; it is part of the checkpoint path.
host_name = socket.gethostname()



# Directory containing this run's checkpoints: a host-specific mount path with
# the run UUID appended.
CHECKPOINT_DIR = "/mnt/batch/tasks/shared/LS_root/mounts/clusters/{}/code/checkpoints/supervised/".format(host_name) + args.uuid
# Model identifier passed to AutoTokenizer.from_pretrained, derived from --size.
MODEL = "google/byt5-{}".format(args.size)

# Send INFO-and-above log records to stdout with a timestamp prefix.
logging.basicConfig(
        stream=sys.stdout, 
        format='%(asctime)s %(levelname)-8s %(message)s',
        level=logging.INFO,
        datefmt='%Y-%m-%d %H:%M:%S')

def clamp(val, minimum=0, maximum=1):
    """Limit ``val`` to the range bounded by ``minimum`` and ``maximum``.

    The upper bound is applied first and the lower bound second, so when
    ``minimum > maximum`` the result is ``minimum``.
    """
    capped = maximum if maximum < val else val
    return minimum if minimum > capped else capped

def main():
    """Obtain the accuracy curves and render the "flash" and "think" plots.

    When ``--generate_data`` is given the curves are produced by
    ``generate_data()``; otherwise they are loaded by ``read_data()``.

    Raises:
        ValueError: if ``--self_train_start`` was left at its ``-1`` default.
    """
    # Explicit check instead of ``assert`` so the validation is not stripped
    # when Python runs with -O.
    if args.self_train_start == -1:
        raise ValueError("Must specify self_train_start")

    if args.generate_data:
        flash_accs, think_accs = generate_data()
    else:
        flash_accs, think_accs = read_data()

    plot_data(flash_accs, "flash")
    plot_data(think_accs, "think")

def _blend_rgb(start, end, fraction):
    """Linearly interpolate between two RGB triples by ``fraction``."""
    return [s + fraction * (e - s) for s, e in zip(start, end)]


def plot_data(data_lists, name):
    """Plot one accuracy curve per checkpoint and save it as PDF and PNG.

    Curves before the self-training split are coloured on a dark-blue to
    white gradient; curves from the split onwards go from white to dark red.
    A matching two-segment colorbar is drawn on the right of the figure.

    Args:
        data_lists: sequence of per-checkpoint accuracy sequences. For the
            curve at index ``i`` only ``data[i + 2:]`` is drawn, against
            ``args.buffer + 1`` x positions labelled "+0" .. "+buffer".
            Presumably entry ``i`` belongs to the ``i + 3``-digit checkpoint
            (the colorbar starts at '3') -- confirm against the data producer.
        name: ``"flash"`` selects the "w/o CoT" title, anything else the
            "w/ CoT" title; it is also used in the output file names.

    Raises:
        ValueError: if ``data_lists`` is empty, or if
            ``args.self_train_start - 3`` is not within
            ``[0, len(data_lists)]``. Such inputs previously failed later
            with a ZeroDivisionError or an opaque matplotlib error.
    """
    size = args.size
    uuid = args.uuid
    num_lines = len(data_lists)
    self_train_start_idx = args.self_train_start - 3

    # Fail early with a clear message, before any figure is created.
    if num_lines == 0:
        raise ValueError("No accuracy data to plot for '{}'".format(name))
    if not 0 <= self_train_start_idx <= num_lines:
        raise ValueError(
            "self_train_start={} gives split index {} outside [0, {}]".format(
                args.self_train_start, self_train_start_idx, num_lines))

    sns.set_style("whitegrid")
    fig, ax = plt.subplots(figsize=(9, 6))

    # Gradient end points: dark blue -> white, then white -> dark red.
    start_blue = (0.0, 0.0, 0.8)  # Dark blue
    end_blue = (1.0, 1.0, 1.0)    # White
    start_red = end_blue
    end_red = (0.8, 0.0, 0.0)     # Dark red

    x_values = range(args.buffer + 1)

    for i, data in enumerate(data_lists):
        if i < self_train_start_idx:
            # Curves before the split: blue gradient.
            fraction = i / self_train_start_idx
            line_color = _blend_rgb(start_blue, end_blue, fraction)
        else:
            # Curves from the split onwards: red gradient. The denominator is
            # positive here because i < num_lines.
            fraction = (i - self_train_start_idx) / (num_lines - self_train_start_idx)
            line_color = _blend_rgb(start_red, end_red, fraction)
        ax.plot(x_values, data[i + 2:], color=line_color, linewidth=2)

    ax.set_xlim(0, args.buffer)
    ax.set_ylim(0, 1.03)  # Assuming accuracy is between 0 and 1
    ax.set_xlabel('Generalization beyond training', fontsize=18)
    ax.set_ylabel('Accuracy', fontsize=18)

    ax.set_xticks(x_values)
    ax.set_xticklabels(["+" + str(i) for i in x_values], fontsize=12)
    ax.tick_params(axis='y', labelsize=12)

    params = "582M" if size == "base" else "300M"
    if name == "flash":
        title = "Addition Accuracy w/o CoT ({})".format(params)
    else:
        title = "Addition Accuracy w/ CoT ({})".format(params)

    ax.set_title(title, size=18)

    # Make the grid visible and slightly transparent
    ax.grid(True, zorder=3, alpha=0.6)

    # Position of the blue/red boundary on the colorbar, in [0, 1].
    proportion_for_self_train = self_train_start_idx / num_lines

    # Segment data with the same four anchor points for every channel; the
    # repeated x value marks where the blue segment ends and the red begins.
    cdict = {
        channel: [
            (0, start_blue[j], start_blue[j]),
            (proportion_for_self_train, end_blue[j], end_blue[j]),
            (proportion_for_self_train, start_red[j], start_red[j]),
            (1, end_red[j], end_red[j]),
        ]
        for j, channel in enumerate(("red", "green", "blue"))
    }

    # Create the colormap using the color dictionary
    cmap = LinearSegmentedColormap("custom", cdict)

    cax = fig.add_axes([0.88, 0.1, 0.02, 0.8])
    cb = ColorbarBase(cax, cmap=cmap, orientation='vertical')
    cb.set_label('Digits', size=18)

    # Setting the ticks
    cb.set_ticks([0, proportion_for_self_train, 1])
    cb.set_ticklabels(['3', str(self_train_start_idx + 3), str(num_lines + 3)], fontsize=12)

    # Leave room on the right for the manually placed colorbar axes.
    plt.tight_layout(rect=[0, 0, 0.87, 1])

    # Make sure the destination exists so saving cannot fail on a fresh checkout.
    os.makedirs("plots/final_plots", exist_ok=True)
    fig.savefig("plots/final_plots/{}_{}.pdf".format(uuid, name), dpi=500)
    fig.savefig("plots/final_plots/{}_{}.png".format(uuid, name))

    # Release the figure; this function is called once per plot type.
    plt.close(fig)

def generate_data():
    """Evaluate every available checkpoint and record its accuracy curves.

    Starting at 3 digits, the checkpoint for each digit count is looked up
    under ``CHECKPOINT_DIR`` ("supervised" below ``args.self_train_start``,
    "selftrain" from there on). Evaluation stops at the first digit count
    whose checkpoint file does not exist. For every checkpoint the accuracies
    returned by ``utils.verify_accuracy`` are split at
    ``num_digits + args.buffer`` into a "think" part and a "flash" part, and
    each part is appended as one line to ``plots/<uuid>_think.txt`` and
    ``plots/<uuid>_flash.txt`` respectively.

    NOTE: this function has the same name as the ``generate_data`` module
    imported at the top of the file and rebinds that name at module level.

    Returns:
        tuple: ``(flash_accs, think_accs)``, two lists with one entry per
        evaluated checkpoint.
    """
    device = "gpu"

    tokenizer = AutoTokenizer.from_pretrained(MODEL)
    tokenizer.add_special_tokens({"additional_special_tokens": ["<scratchpad>", "</scratchpad>"] + tokenizer.additional_special_tokens})
    trainer = Trainer(accelerator=device, default_root_dir="logs/self_train", enable_progress_bar=False)
    print("Checkpoint dir: {}".format(CHECKPOINT_DIR))

    num_digits = 3
    lightning_model = AdditionFlanT5(type="decomp", silent=args.silent)
    lightning_model.initialize(lightning_model.model, tokenizer, None)

    think_accs, flash_accs = [], []

    # Output locations, computed once and reused for truncation and appends.
    os.makedirs("plots", exist_ok=True)
    flash_path = "plots/{}_flash.txt".format(args.uuid)
    think_path = "plots/{}_think.txt".format(args.uuid)

    # Create (or truncate) both files so stale results from an earlier run
    # are not mixed with the lines appended below.
    for path in (flash_path, think_path):
        with open(path, "w"):
            pass

    while True:

        traintype = "supervised" if num_digits < args.self_train_start else "selftrain"
        model_location = CHECKPOINT_DIR + "/model-{}-{}-digits.ckpt".format(traintype, num_digits)

        if not os.path.exists(model_location):
            print("No {} digit model found, stopping".format(num_digits))
            break

        lightning_model.set_cur_num_digits(num_digits)
        lightning_model.model = AutoModelForSeq2SeqLM.from_pretrained(model_location)

        # Evaluate up to `buffer` digits beyond the checkpoint's own digit count.
        max_digits = num_digits + args.buffer
        _perfect_acc, accs = utils.verify_accuracy(trainer, lightning_model, max_digits, "decomp", batch_size=args.batch_size, flash=True, silent=args.silent, num_samples=args.num_samples)

        # First `max_digits` entries are the "think" accuracies, the rest "flash".
        think_acc, flash_acc = accs[:max_digits], accs[max_digits:]
        think_accs.append(think_acc)
        flash_accs.append(flash_acc)
        print("Num digits: {}".format(num_digits))
        print("Think accs: {}".format(think_accs[-1]))
        print("Flash accs: {}".format(flash_accs[-1]))

        # Append after every checkpoint so partial results survive a crash.
        with open(flash_path, "a") as f:
            f.write("{}\n".format(str(flash_acc.tolist())))

        with open(think_path, "a") as f:
            f.write("{}\n".format(str(think_acc.tolist())))

        num_digits += 1

    return flash_accs, think_accs

def _read_accuracy_rows(path):
    """Parse one Python literal per non-blank line of the file at ``path``."""
    with open(path, "r") as f:
        # Skip blank lines (e.g. a trailing empty line), which
        # ast.literal_eval would reject with a SyntaxError.
        return [ast.literal_eval(line) for line in f if line.strip()]


def read_data():
    """Load previously saved accuracy curves for ``args.uuid``.

    Reads ``plots/<uuid>_flash.txt`` and ``plots/<uuid>_think.txt``, where
    every non-blank line holds one list literal.

    Returns:
        tuple: ``(flash_accs, think_accs)``, each a list with one parsed
        entry per non-blank line of the corresponding file.

    Raises:
        FileNotFoundError: if either file does not exist.
    """
    flash_accs = _read_accuracy_rows("plots/{}_flash.txt".format(args.uuid))
    think_accs = _read_accuracy_rows("plots/{}_think.txt".format(args.uuid))
    return flash_accs, think_accs





# Entry point: run main() only when this file is executed as a script.
if __name__ == "__main__":
    main()