import argparse
import glob
import json
import os
import random
from typing import List
from concurrent.futures import ProcessPoolExecutor
from functools import partial

import numpy as np
import requests
import sentencepiece as spm
import torch
import torch.distributed as dist
from tqdm import tqdm

from tokenizer import Tokenizer
from random import shuffle
from collections import Counter

import math
import os
import time
from contextlib import nullcontext
from datetime import datetime
from functools import partial

import torch
from model import Transformer, ModelArgs
from torch.distributed import destroy_process_group, init_process_group
from torch.nn.parallel import DistributedDataParallel as DDP

from tinystories import Task
from export import model_export

import torch._dynamo
torch._dynamo.config.suppress_errors = True

import time 

import seaborn as sns
import matplotlib.pylab as plt
from sklearn.preprocessing import normalize


# -----------------------------------------------------------------------------
# Transformer hyperparameters that stay fixed across runs; they are forwarded
# to ModelArgs through the module-level `model_args` mapping.
n_heads = n_kv_heads = 6
multiple_of = 32
dropout = 0.0
pos_enc = "off"

# Length of each context window. `max_seq_len` is what the dataset and the
# model are configured with; `context_length` is kept at the same value.
max_seq_len = context_length = 6

# Filesystem locations used by this script.
DATA_PROCESS_DIR = "./TinyStories/TinyStories_processing_files"
GRAPH_DIR = "./graphs"  # destination for the saved figures
LOG_DIR = "./"  # parent directory of the per-version "logs_V<version>" folders


from tqdm import tqdm
from tinystories import PretokDataset

vocab_source = "custom" # "llama2" or "custom": use the Llama 2 vocab from Meta, or a custom-trained vocab


from tinystories import get_tokenizer_model_path

def graph_full(epoch, version, sort_by = "freq"):
    """Plot Gram matrices of support sets and hidden states side by side.

    The function scans the training split, records for every distinct context
    how often it occurs and which next tokens follow it (its "support set",
    stored as a 0/1 indicator vector of length ``vocab_size``), groups the
    contexts that share an identical support set, picks a subset of those
    contexts, and then draws two heatmaps:

    * left: ``S @ S.T`` where the rows of ``S`` are the support indicator
      vectors of the selected contexts;
    * right: ``H @ H.T`` where the rows of ``H`` are taken from the second
      model output at index ``[:, -1, :]`` for the same contexts.

    The figure is written to ``GRAPH_DIR`` and also shown interactively.

    Relies on the module-level names ``max_seq_len``, ``vocab_size``,
    ``vocab_source``, ``dim``, ``model_args``, ``LOG_DIR`` and ``GRAPH_DIR``.

    Args:
        epoch: Iteration number of the checkpoint to load
            (``<LOG_DIR>logs_V<version>/model_it<epoch>.pth``).
        version: Run version, used in the checkpoint directory name and in
            the output file name.
        sort_by: How support sets are ranked before selection (highest rank
            first). One of ``"sup_freq"`` (number of contexts sharing the
            support set), ``"sup_size"`` (number of tokens in the support
            set) or ``"both"`` (sum of the two). The signature default
            ``"freq"`` is not one of these, so callers have to pass a
            supported value explicitly.

    Raises:
        ValueError: If ``sort_by`` is not a supported ranking, or if no
            context passes the selection thresholds.
    """
    ranking_keys = {
        "sup_freq": lambda item: len(item[1]),
        "sup_size": lambda item: sum(item[0]),
        "both": lambda item: len(item[1]) + sum(item[0]),
    }
    # Validate before the (expensive) dataset scan instead of failing on an
    # unbound variable afterwards.
    if sort_by not in ranking_keys:
        raise ValueError(
            "unsupported sort_by " + repr(sort_by)
            + "; expected one of " + str(sorted(ranking_keys))
        )

    # Selection thresholds.
    max_support_sets_scanned = 1000  # only the top-ranked support sets are considered
    max_target_contexts = 100        # stop once this many contexts are collected
    min_support_size = 3             # skip support sets with two or fewer tokens
    contexts_per_support_set = 10    # required group size, and number taken per group

    ds = PretokDataset("train", max_seq_len, vocab_size, vocab_source)

    # context (tuple of token ids) -> {"count": occurrences, "support": 0/1 vector}
    context_to_support_sets = {}
    for X, Y in tqdm(ds):
        context_tuple = tuple(X.tolist())
        entry = context_to_support_sets.get(context_tuple)
        if entry is None:
            entry = {"count": 0, "support": np.zeros(vocab_size)}
            context_to_support_sets[context_tuple] = entry
        entry["count"] += 1
        entry["support"][Y.item()] = 1

    # support vector (as a hashable tuple) -> contexts having exactly that support
    support_set_to_context = {}
    for context, entry in tqdm(context_to_support_sets.items()):
        support_key = tuple(entry["support"])
        support_set_to_context.setdefault(support_key, []).append(context)

    # Ascending stable sort followed by a reversal, so the highest-ranked
    # support sets come first.
    ranked_support_sets = sorted(
        support_set_to_context.items(), key=ranking_keys[sort_by]
    )[::-1][:max_support_sets_scanned]

    list_target_contexts = []
    for sup_set, sup_set_contexts_list in ranked_support_sets:
        if len(list_target_contexts) >= max_target_contexts:
            break
        if sum(sup_set) < min_support_size:
            continue
        if len(sup_set_contexts_list) < contexts_per_support_set:
            continue
        # Within a group, keep the most frequently occurring contexts.
        by_count = sorted(
            sup_set_contexts_list,
            key=lambda ctx: context_to_support_sets[ctx]["count"],
        )
        list_target_contexts.extend(by_count[::-1][:contexts_per_support_set])

    num_targets = len(list_target_contexts)
    if num_targets == 0:
        raise ValueError(
            "no context passed the selection thresholds; nothing to plot"
        )

    # One row per selected context; the width follows the contexts themselves.
    X_target = torch.tensor(list_target_contexts, dtype=torch.long)
    S_target = torch.from_numpy(
        np.stack([context_to_support_sets[ctx]["support"] for ctx in list_target_contexts])
    ).float()
    gram_S = S_target @ S_target.T

    gptconf = ModelArgs(**model_args)
    model = Transformer(gptconf)
    checkpoint_path = LOG_DIR + "logs_V" + str(version) + "/model_it" + str(epoch) + ".pth"
    state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    # The first 10 characters of every key are dropped before loading --
    # presumably the "_orig_mod." prefix of a compiled model; TODO confirm.
    model.load_state_dict({key[10:]: value for key, value in state_dict.items()})
    model.eval()

    # Inference only: no autograd graph is needed for the Gram matrices.
    with torch.no_grad():
        _logits, h = model(X_target, None)
        # Uses the last index along dim 1 of the second output (presumably the
        # final sequence position of the hidden states -- TODO confirm).
        H_target = torch.reshape(h[:, -1, :], (num_targets, dim)).to("cpu")
        gram_H = H_target @ H_target.T

    #######################################################################################################################################################################

    fig, axs = plt.subplots(1, 2, figsize=(12, 6))
    sns.heatmap(gram_S.numpy(), ax=axs[0], cbar=True)
    sns.heatmap(gram_H.numpy(), ax=axs[1], cbar=True)

    # A tick every 10 rows/columns, limited to the contexts actually selected.
    font = { 'color': 'black', 'weight': 'normal', 'size': 12}
    tick_positions = list(range(0, num_targets, 10))
    tick_labels = [str(pos) for pos in tick_positions]
    for ax in axs:
        ax.set_xticks(tick_positions, tick_labels, fontdict=font)
        ax.set_yticks(tick_positions, tick_labels, fontdict=font)

    plt.tight_layout()
    os.makedirs(GRAPH_DIR, exist_ok=True)
    plt.savefig(GRAPH_DIR + "/Paper_Gram_plots_V" + str(version) + "_Epoch" + str(epoch)  + "_voc" + str(vocab_size) + "_dim" + str(dim) + ".jpg")
    plt.show()
    # Close the figure rather than just clearing it, so repeated calls do not
    # accumulate open figures.
    plt.close(fig)



def plot_loss(epoch, version, entropy=0.31153):
    """Plot the excess training loss at every saved checkpoint.

    Checkpoint iterations are discovered from the ``model_it<N>.*`` files in
    ``<LOG_DIR>logs_V<version>``. For each of them the value at index ``N``
    of ``loss_list.npy`` (from the same directory) minus ``entropy`` is
    plotted on a log-scaled y axis. The figure is written to ``GRAPH_DIR``
    and also shown interactively.

    Relies on the module-level names ``LOG_DIR``, ``GRAPH_DIR``,
    ``vocab_size`` and ``dim``.

    Args:
        epoch: Iteration number used in the output file name.
        version: Run version, used in the log directory name and in the
            output file name.
        entropy: Constant subtracted from every loss value before plotting
            (presumably the entropy floor of the data -- TODO confirm).
    """
    #######################################################################################################################################################################

    log_dir = LOG_DIR + "logs_V" + str(version)
    checkpoint_prefix = "model_it"

    # Collect the iteration number of every checkpoint file. A separate loop
    # variable is used so the `epoch` argument is still intact for the output
    # file name, and files that do not follow the naming scheme are skipped.
    epochs_list = []
    for file_name in os.listdir(log_dir):
        stem = file_name.split(".")[0]
        iteration_text = stem[len(checkpoint_prefix):]
        if stem.startswith(checkpoint_prefix) and iteration_text.isdecimal():
            epochs_list.append(int(iteration_text))
    epochs_list.sort()

    with open(log_dir + "/loss_list.npy", 'rb') as f:
        total_loss_per_batch_ave_list = np.load(f)

    loss_samples = [total_loss_per_batch_ave_list[it] - entropy for it in epochs_list]

    fig, axs = plt.subplots(1, 1, figsize=(9, 6))
    axs.plot(epochs_list, loss_samples, linewidth=5)
    axs.set_yscale("log")
    axs.tick_params(axis='both', which='major', labelsize=20)

    plt.tight_layout()
    os.makedirs(GRAPH_DIR, exist_ok=True)
    # The "/" keeps the file inside GRAPH_DIR rather than next to it.
    plt.savefig(GRAPH_DIR + "/Paper_Loss_plots_V" + str(version) + "_Epoch" + str(epoch)  + "_voc" + str(vocab_size) + "_dim" + str(dim) + ".jpg")
    plt.show()
    # Close the figure rather than just clearing it, so repeated calls do not
    # accumulate open figures.
    plt.close(fig)

    #######################################################################################################################################################################

    #######################################################################################################################################################################



# Run selection: which checkpoint to analyse and the model shape it was
# trained with. graph_full() and plot_loss() read these as module globals.
dim, n_layers = 256, 12
epoch, version = 598, 22
vocab_size = 64

# Keyword arguments for ModelArgs.
model_args = {
    "dim": dim,
    "n_layers": n_layers,
    "n_heads": n_heads,
    "n_kv_heads": n_kv_heads,
    "vocab_size": vocab_size,
    "multiple_of": multiple_of,
    "max_seq_len": max_seq_len,
    "dropout": dropout,
    "pos_enc": pos_enc,
}


# Produce both figures for the selected checkpoint.
graph_full(epoch, version, "sup_size")
plot_loss(epoch, version)
