import argparse
import os
from sklearn.metrics.pairwise import cosine_similarity
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import concurrent.futures
import faiss 
import matplotlib.pyplot as plt
import numpy as np
import datasets
from positional_embeddings import PositionalEmbedding
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms

# Memorization criteria applied in process_frame: a generated point is counted
# as memorized when its nearest-training distance is below
# L2_DISTANCE_THRESHOLD and (nearest-test distance / nearest-training distance)
# exceeds RATIO_THRESHOLD. Both values are also embedded in the experiment name.
# NOTE(review): the distances come from faiss.IndexFlatL2.search, which reports
# squared L2 distances — confirm these thresholds are meant for squared values.
L2_DISTANCE_THRESHOLD=1
RATIO_THRESHOLD=1.5
# Disabled FID helper, kept as an inert string literal (FIDScore is not
# imported in this file).
""" def get_fid(real_features,generated_features):
    fid_score = FIDScore()
    fid_value = fid_score.calculate_fid_score(real_features, generated_features)
    return fid_value """

class MinL2Distances:
    """Distances and test/train ratios collected for one label group at one epoch."""

    def __init__(self, epoch):
        # Epoch value supplied by the caller for these measurements.
        self.epoch = epoch
        # Nearest-neighbour distances, appended by process_frame.
        self.data = []
        # Test/train distance ratios, appended alongside ``data``.
        self.ratio = []

class MemorizedPoints:
    """Training distances and test/train ratios for one memorization/label group at one epoch."""

    def __init__(self, epoch):
        # Epoch value supplied by the caller for these measurements.
        self.epoch = epoch
        # Nearest-training-point distances, appended by process_frame.
        self.dist = []
        # Test/train distance ratios, appended alongside ``dist``.
        self.ratio = []

def process_frame(frame, max_distance,fidx,index,index_test):
    """Measure how close each generated point in *frame* is to the train and test sets.

    For every point the single nearest training point (``index``) and nearest
    test point (``index_test``) are found. Distances are grouped by the label
    of that neighbour: ``"Outlier"``, ``"Duplicates"``, or anything else
    (treated as an inlier). A point is *memorized* when its training distance
    is below ``L2_DISTANCE_THRESHOLD`` and ``test_distance / train_distance``
    exceeds ``RATIO_THRESHOLD``.

    Reads the module-level ``config`` and ``fullset`` created in the
    ``__main__`` section. NOTE(review): worker processes therefore have to
    inherit them (fork start method); presumably this fails under spawn.

    NOTE(review): ``faiss.IndexFlatL2.search`` reports *squared* L2 distances,
    so the threshold and the ratio operate on squared distances. Confirm
    this is intended.

    Args:
        frame: sequence of generated points (rows of a frame array).
        max_distance: unused; retained for call-site compatibility (it
            normalised distances in an earlier brute-force implementation).
        fidx: position of the frame in the frame list; the recorded epoch is
            ``fidx * config.save_images_step``. NOTE(review): for
            ``save_images_step > 1`` this does not match the epoch at which the
            frame was sampled in the training loop. Confirm.
        index: faiss index over the training points (``fullset[1]``).
        index_test: faiss index over the test points (``fullset[3]``).

    Returns:
        Tuple ``(l2_inliers, l2_outliers, l2_inliers_test, l2_outliers_test,
        l2_memo_inliers, l2_memo_outliers, l2_not_memo_inliers,
        l2_not_memo_outliers, l2_dup, l2_dup_test, l2_memo_dup,
        l2_not_memo_dup)``.
    """
    epoch = fidx * config.save_images_step
    l2_inliers = MinL2Distances(epoch)
    l2_outliers = MinL2Distances(epoch)
    l2_inliers_test = MinL2Distances(epoch)
    l2_outliers_test = MinL2Distances(epoch)
    l2_memo_inliers = MemorizedPoints(epoch)
    l2_memo_outliers = MemorizedPoints(epoch)
    l2_not_memo_outliers = MemorizedPoints(epoch)
    l2_not_memo_inliers = MemorizedPoints(epoch)
    l2_dup = MinL2Distances(epoch)
    l2_dup_test = MinL2Distances(epoch)
    l2_memo_dup = MemorizedPoints(epoch)
    l2_not_memo_dup = MemorizedPoints(epoch)

    train_labels = fullset[2]
    test_labels = fullset[4]

    def _pick(label, outlier_bucket, dup_bucket, inlier_bucket):
        """Return the bucket that matches a neighbour label."""
        if label == "Outlier":
            return outlier_bucket
        if label == "Duplicates":
            return dup_bucket
        return inlier_bucket

    progress_bar = tqdm(total=len(frame))
    progress_bar.set_description(f"Frame {fidx}")
    for point in frame:
        progress_bar.update(1)
        # faiss expects a (n_queries, dim) float32 array.
        query = np.array(point, dtype=np.float32).reshape(1, -1)
        dists, idxs = index.search(query, 1)
        dists_test, idxs_test = index_test.search(query, 1)
        dist_train = dists[0][0]
        dist_test = dists_test[0][0]
        label = train_labels[idxs[0][0]]
        label_test = test_labels[idxs_test[0][0]]
        mem_ratio = dist_test / dist_train

        # Memorized / not-memorized groups, split by the training neighbour's label.
        if dist_train < L2_DISTANCE_THRESHOLD and mem_ratio > RATIO_THRESHOLD:
            memo_bucket = _pick(label, l2_memo_outliers, l2_memo_dup, l2_memo_inliers)
        else:
            memo_bucket = _pick(label, l2_not_memo_outliers, l2_not_memo_dup, l2_not_memo_inliers)
        memo_bucket.dist.append(dist_train)
        memo_bucket.ratio.append(mem_ratio)

        # Is the closest training point an outlier / duplicate / inlier?
        train_bucket = _pick(label, l2_outliers, l2_dup, l2_inliers)
        train_bucket.data.append(dist_train)
        train_bucket.ratio.append(mem_ratio)

        # Is the closest test point an outlier / duplicate / inlier?
        # Every test bucket records the *test* distance; the duplicates bucket
        # previously received the training distance by mistake.
        test_bucket = _pick(label_test, l2_outliers_test, l2_dup_test, l2_inliers_test)
        test_bucket.data.append(dist_test)
        test_bucket.ratio.append(mem_ratio)
    progress_bar.close()
    return l2_inliers,l2_outliers,l2_inliers_test,l2_outliers_test,l2_memo_inliers,l2_memo_outliers,l2_not_memo_inliers,l2_not_memo_outliers,l2_dup,l2_dup_test,l2_memo_dup,l2_not_memo_dup

#Plotter
def my_plotter(sim_dist,epoch,title,bin_num=20):
    """Save a histogram of distances to ``exps/<experiment>/images/<title>-L2@Epoch<epoch>.png``.

    Uses the module-level ``config`` (set in the ``__main__`` section) for the
    experiment name.

    Args:
        sim_dist: sequence of distance values to histogram.
        epoch: epoch number used in the output file name.
        title: prefix for the output file name.
        bin_num: number of histogram bins (defaults to the previous fixed 20).
    """
    # Start from a clean canvas regardless of what figure is currently open.
    plt.close()
    # plt.hist returns the bin edges it used, so the histogram is computed once.
    _, bins, _ = plt.hist(sim_dist, bins=bin_num)
    plt.xticks(bins, fontsize=5)
    plt.xlabel('L2 Distance')
    plt.ylabel('Frequency')
    plt.title("Frequency Of Each L2 Distance")
    plt.grid(True)
    plt.savefig(f"exps/{config.experiment_name}/images/{title}-L2@Epoch{epoch}.png")
    # Release the figure instead of leaving it open until the next call.
    plt.close()


#END OF ARYAN CODE
class Block(nn.Module):
    """Residual MLP block computing ``x + GELU(Linear(x))`` at constant width."""

    def __init__(self, size: int):
        super().__init__()
        # The attribute names double as state_dict keys; keep them stable.
        self.ff = nn.Linear(size, size)
        self.act = nn.GELU()

    def forward(self, x: torch.Tensor):
        hidden = self.ff(x)
        hidden = self.act(hidden)
        return hidden + x


class MLP(nn.Module):
    """Network mapping 2-D points and timesteps to a 2-D output.

    Each coordinate and the timestep are embedded with ``PositionalEmbedding``.
    The three embeddings are concatenated and passed through a linear+GELU
    stem, ``hidden_layers`` residual ``Block`` layers, and a 2-unit linear head.
    """

    def __init__(self, hidden_size: int = 128, hidden_layers: int = 3, emb_size: int = 128,
                 time_emb: str = "sinusoidal", input_emb: str = "sinusoidal"):
        super().__init__()

        # Construction order is unchanged so parameter initialisation draws from
        # the RNG in the same sequence; attribute names are state_dict keys.
        self.time_mlp = PositionalEmbedding(emb_size, time_emb)
        self.input_mlp1 = PositionalEmbedding(emb_size, input_emb, scale=25.0)
        self.input_mlp2 = PositionalEmbedding(emb_size, input_emb, scale=25.0)

        embedders = (self.time_mlp, self.input_mlp1, self.input_mlp2)
        concat_size = sum(len(embedder.layer) for embedder in embedders)

        stem = [nn.Linear(concat_size, hidden_size), nn.GELU()]
        body = [Block(hidden_size) for _ in range(hidden_layers)]
        head = [nn.Linear(hidden_size, 2)]
        self.joint_mlp = nn.Sequential(*stem, *body, *head)

    def forward(self, x, t):
        # Column 0 and column 1 of x are embedded separately, then the timestep.
        embeddings = (
            self.input_mlp1(x[:, 0]),
            self.input_mlp2(x[:, 1]),
            self.time_mlp(t),
        )
        return self.joint_mlp(torch.cat(embeddings, dim=-1))


class NoiseScheduler():
    """DDPM noise schedule with forward noising and single-step reverse sampling.

    All per-timestep coefficients are precomputed once as 1-D float32 tensors
    of length ``num_timesteps`` and indexed by integer timestep tensors.
    """

    def __init__(self,
                 num_timesteps=1000,
                 beta_start=0.0001,
                 beta_end=0.02,
                 beta_schedule="linear"):
        """Precompute the schedule.

        Args:
            num_timesteps: number of diffusion steps.
            beta_start: beta at the first timestep.
            beta_end: beta at the last timestep.
            beta_schedule: ``"linear"`` (betas evenly spaced) or
                ``"quadratic"`` (square roots of the betas evenly spaced).

        Raises:
            ValueError: if ``beta_schedule`` is not a supported name.
        """
        self.num_timesteps = num_timesteps
        if beta_schedule == "linear":
            self.betas = torch.linspace(
                beta_start, beta_end, num_timesteps, dtype=torch.float32)
        elif beta_schedule == "quadratic":
            self.betas = torch.linspace(
                beta_start ** 0.5, beta_end ** 0.5, num_timesteps, dtype=torch.float32) ** 2
        else:
            # Previously an unknown name left self.betas unset and the
            # constructor failed on the next line with an opaque AttributeError.
            raise ValueError(
                f"unknown beta_schedule {beta_schedule!r}; expected 'linear' or 'quadratic'")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # Cumulative product shifted by one step, with 1.0 for the first step.
        self.alphas_cumprod_prev = F.pad(
            self.alphas_cumprod[:-1], (1, 0), value=1.)

        # required for self.add_noise
        self.sqrt_alphas_cumprod = self.alphas_cumprod ** 0.5
        self.sqrt_one_minus_alphas_cumprod = (1 - self.alphas_cumprod) ** 0.5

        # required for reconstruct_x0
        self.sqrt_inv_alphas_cumprod = torch.sqrt(1 / self.alphas_cumprod)
        self.sqrt_inv_alphas_cumprod_minus_one = torch.sqrt(
            1 / self.alphas_cumprod - 1)

        # required for q_posterior
        self.posterior_mean_coef1 = self.betas * torch.sqrt(self.alphas_cumprod_prev) / (1. - self.alphas_cumprod)
        self.posterior_mean_coef2 = (1. - self.alphas_cumprod_prev) * torch.sqrt(self.alphas) / (1. - self.alphas_cumprod)

    def reconstruct_x0(self, x_t, t, noise):
        """Invert ``add_noise``: estimate x_0 from x_t and the (predicted) noise at step t."""
        s1 = self.sqrt_inv_alphas_cumprod[t]
        s2 = self.sqrt_inv_alphas_cumprod_minus_one[t]
        s1 = s1.reshape(-1, 1)
        s2 = s2.reshape(-1, 1)
        return s1 * x_t - s2 * noise

    def q_posterior(self, x_0, x_t, t):
        """Return the posterior mean of x_{t-1} given x_0 and x_t."""
        s1 = self.posterior_mean_coef1[t]
        s2 = self.posterior_mean_coef2[t]
        s1 = s1.reshape(-1, 1)
        s2 = s2.reshape(-1, 1)
        mu = s1 * x_0 + s2 * x_t
        return mu

    def get_variance(self, t):
        """Posterior variance at step t (0 at t == 0, otherwise clipped to >= 1e-20)."""
        if t == 0:
            return 0

        variance = self.betas[t] * (1. - self.alphas_cumprod_prev[t]) / (1. - self.alphas_cumprod[t])
        variance = variance.clip(1e-20)
        return variance

    def step(self, model_output, timestep, sample):
        """Take one reverse step from ``sample`` at ``timestep``.

        ``model_output`` is treated as the predicted noise. ``timestep`` must be
        a single value (Python int or 0-d tensor) shared by the whole batch,
        since it is used in ``if t > 0``. Gaussian noise is added only for t > 0.
        """
        t = timestep
        pred_original_sample = self.reconstruct_x0(sample, t, model_output)
        pred_prev_sample = self.q_posterior(pred_original_sample, sample, t)

        variance = 0
        if t > 0:
            noise = torch.randn_like(model_output)
            variance = (self.get_variance(t) ** 0.5) * noise

        pred_prev_sample = pred_prev_sample + variance

        return pred_prev_sample

    def add_noise(self, x_start, x_noise, timesteps):
        """Forward process: sqrt(alpha_bar_t) * x_start + sqrt(1 - alpha_bar_t) * x_noise.

        Coefficients are reshaped to (-1, 1), so inputs are presumably
        (batch, features) with one timestep per row. Confirm for other shapes.
        """
        s1 = self.sqrt_alphas_cumprod[timesteps]
        s2 = self.sqrt_one_minus_alphas_cumprod[timesteps]

        s1 = s1.reshape(-1, 1)
        s2 = s2.reshape(-1, 1)

        return s1 * x_start + s2 * x_noise

    def __len__(self):
        return self.num_timesteps


if __name__ == "__main__":
    def _str2bool(value):
        """Convert a command-line flag value to ``bool``.

        ``type=bool`` does not work with argparse: ``bool("False")`` is True,
        so ``--test_outlier False`` used to enable the option.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("1", "true", "t", "yes", "y"):
            return True
        if lowered in ("", "0", "false", "f", "no", "n"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser()
    parser.add_argument("--experiment_name", type=str, default="base")
    parser.add_argument("--dataset", type=str, default="dino", choices=["circle", "dino", "line", "moons"])
    parser.add_argument("--train_batch_size", type=int, default=32)
    parser.add_argument("--eval_batch_size", type=int, default=1000)
    parser.add_argument("--num_epochs", type=int, default=200)
    parser.add_argument("--learning_rate", type=float, default=1e-3)
    parser.add_argument("--num_timesteps", type=int, default=50)
    parser.add_argument("--beta_schedule", type=str, default="linear", choices=["linear", "quadratic"])
    parser.add_argument("--embedding_size", type=int, default=128)
    parser.add_argument("--hidden_size", type=int, default=128)
    parser.add_argument("--hidden_layers", type=int, default=3)
    parser.add_argument("--time_embedding", type=str, default="sinusoidal", choices=["sinusoidal", "learnable", "linear", "zero"])
    parser.add_argument("--input_embedding", type=str, default="sinusoidal", choices=["sinusoidal", "learnable", "linear", "identity"])
    parser.add_argument("--save_images_step", type=int, default=1)
    parser.add_argument("--outliers", type=int, default=0)
    parser.add_argument("--dups", type=int, default=0)
    parser.add_argument("--duplicate_dist", type=int, default=0)
    parser.add_argument("--test_outlier", type=_str2bool, default=False)
    config = parser.parse_args()

    # fullset, as used below: [0] training Dataset fed to the DataLoader,
    # [1] training points, [2] training labels, [3] test points, [4] test labels.
    fullset = datasets.get_dataset(config.dataset,config.outliers,config.dups,config.duplicate_dist,config.test_outlier)
    print(f"==>> config.test_outlier: {config.test_outlier}")
    print(f"==>> fullset: {len(fullset[2])}")
    dataset = fullset[0]
    dataloader = DataLoader(
        dataset, batch_size=config.train_batch_size, shuffle=True, drop_last=True)

    model = MLP(
        hidden_size=config.hidden_size,
        hidden_layers=config.hidden_layers,
        emb_size=config.embedding_size,
        time_emb=config.time_embedding,
        input_emb=config.input_embedding)

    noise_scheduler = NoiseScheduler(
        num_timesteps=config.num_timesteps,
        beta_schedule=config.beta_schedule)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
    )

    def _sample_frame():
        """Run the full reverse-diffusion chain from Gaussian noise and return the points as NumPy."""
        model.eval()
        sample = torch.randn(config.eval_batch_size, 2)
        for t in tqdm(list(range(len(noise_scheduler)))[::-1]):
            t_batch = torch.full((config.eval_batch_size,), t, dtype=torch.long)
            with torch.no_grad():
                residual = model(sample, t_batch)
            sample = noise_scheduler.step(residual, t_batch[0], sample)
        return sample.numpy()

    # One generated sample set per snapshot: frames[0] is a control drawn from
    # the untrained model, later entries are drawn during training.
    frames = []
    global_step = 0
    losses = []
    print("Training model...")
    frames.append(_sample_frame())
    for epoch in range(config.num_epochs):
        model.train()
        progress_bar = tqdm(total=len(dataloader))
        progress_bar.set_description(f"Epoch {epoch}")
        for batch in dataloader:
            batch = batch[0]
            noise = torch.randn(batch.shape)
            timesteps = torch.randint(
                0, noise_scheduler.num_timesteps, (batch.shape[0],)
            ).long()

            noisy = noise_scheduler.add_noise(batch, noise, timesteps)
            noise_pred = model(noisy, timesteps)
            loss = F.mse_loss(noise_pred, noise)
            # Plain backward(): the old `loss.backward(loss)` passed the loss as
            # the upstream gradient, scaling every gradient by the loss value.
            loss.backward()

            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()

            progress_bar.update(1)
            loss_value = loss.detach().item()
            losses.append(loss_value)
            progress_bar.set_postfix(loss=loss_value, step=global_step)
            global_step += 1
        progress_bar.close()

        if epoch % config.save_images_step == 0 or epoch == config.num_epochs - 1:
            # generate data with the model to later visualize the learning process
            frames.append(_sample_frame())

    print("Saving model...")
    config.experiment_name=f"l2_{config.dataset}_{config.num_epochs}Epochs_{config.outliers}Out_{config.dups}Dups_{L2_DISTANCE_THRESHOLD}MemDist__{config.duplicate_dist}DupGenDist_{RATIO_THRESHOLD}Ratio_{config.eval_batch_size}DatasetSize_{config.test_outlier}Test_outlier_{config.experiment_name}"
    outdir = f"exps/{config.experiment_name}"
    os.makedirs(outdir, exist_ok=True)
    torch.save(model.state_dict(), f"{outdir}/model.pth")

    print("Saving images...")
    imgdir = f"{outdir}/images"
    os.makedirs(imgdir, exist_ok=True)
    frames = np.stack(frames)
    print(f"==>> frames: {frames}")
    xmin, xmax = -6, 6
    ymin, ymax = -6, 6

    def _save_scatter(points, path):
        """Save a 10x10-inch scatter plot of 2-D *points* inside the fixed plot window."""
        plt.figure(figsize=(10, 10))
        plt.scatter(points[:, 0], points[:, 1])
        plt.xlim(xmin, xmax)
        plt.ylim(ymin, ymax)
        plt.savefig(path)
        plt.close()

    for i, frame in enumerate(frames):
        _save_scatter(frame, f"{imgdir}/{i:04}")

    print("Saving loss as numpy array...")
    np.save(f"{outdir}/loss.npy", np.array(losses))

    print("Saving frames...")
    np.save(f"{outdir}/frames.npy", frames)

    # Reference plots of the training and test point clouds.
    _save_scatter(fullset[1], f"{imgdir}/train_dataset.png")
    _save_scatter(fullset[3], f"{imgdir}/test_dataset.png")

    # Nearest-neighbour distance of every generated point to the training and
    # test points, aggregated per frame.
    max_distance = np.linalg.norm(xmax - xmin)
    l2_epochs = []
    l2_ave_inliers, l2_ave_outliers, l2_ave_inliers_test, l2_ave_outliers_test, l2_ave_dups, l2_ave_dups_test = [], [], [], [], [], []
    l2_var_inliers, l2_var_outliers, l2_var_dups, l2_var_inliers_test, l2_var_outliers_test, l2_var_dups_test = [], [], [], [], [], []
    l2_memo_inlier_counts = []
    l2_memo_outlier_counts = []
    l2_memo_dups_counts = []

    l2_inlier_ratios = []
    l2_outlier_ratios = []
    l2_inlier_ratios_memo = []
    l2_outlier_ratios_memo = []
    l2_dup_ratios = []
    l2_dup_ratios_memo = []

    reference_data = np.array(fullset[1], dtype=np.float32)
    reference_data_test = np.array(fullset[3], dtype=np.float32)
    index = faiss.IndexFlatL2(2)
    index_test = faiss.IndexFlatL2(2)
    index.add(reference_data)
    index_test.add(reference_data_test)

    # Debug probe of the first generated point of frame 1 (skipped when
    # training ran for zero epochs and only the control frame exists).
    k_nearest = 1
    print(f"==>> k_nearest: {k_nearest}")
    if len(frames) > 1:
        print(f"==>> frames[1][0]: {frames[1][0]}")
        query_point = np.array(frames[1][0], dtype=np.float32).reshape(1, -1)
        t_point, t_idx = index.search(query_point, k_nearest)
        print(f"==>> t_idx: {t_idx}")
        t_point = t_point / max_distance
        print(f"==>> t_point: {t_point}")

    # One worker per frame would start num_epochs + 1 processes; cap the pool at
    # the CPU count. Futures are still consumed in frame order below.
    max_workers = min(len(frames), os.cpu_count() or 1)
    with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(process_frame, frame, max_distance, f_idx, index, index_test)
            for f_idx, frame in enumerate(frames)
        ]

    def _mean(values):
        """Arithmetic mean of *values*, or 0 for an empty list."""
        return sum(values) / len(values) if values else 0

    def _var(values):
        """Population variance of *values*, or 0 for an empty list."""
        return np.var(values) if values else 0

    def _percent(part, whole):
        """100 * len(part) / len(whole), or 0 when *whole* is empty."""
        return (len(part) / len(whole)) * 100 if whole else 0

    for l2_future in futures:
        (l2_inlier, l2_outlier, l2_inlier_test, l2_outlier_test,
         l2_memo_inliers, l2_memo_outliers, l2_not_memo_inliers, l2_not_memo_outliers,
         l2_dup, l2_dup_test, l2_memo_dup, l2_not_memo_dup) = l2_future.result()

        l2_epochs.append(l2_inlier.epoch)

        # Average / variance of the nearest-training distance, by neighbour label.
        l2_ave_inliers.append(_mean(l2_inlier.data))
        l2_var_inliers.append(_var(l2_inlier.data))
        l2_ave_outliers.append(_mean(l2_outlier.data))
        l2_var_outliers.append(_var(l2_outlier.data))
        l2_ave_dups.append(_mean(l2_dup.data))
        l2_var_dups.append(_var(l2_dup.data))

        # Average / variance of the nearest-test distance, by neighbour label.
        l2_ave_inliers_test.append(_mean(l2_inlier_test.data))
        l2_var_inliers_test.append(_var(l2_inlier_test.data))
        l2_ave_outliers_test.append(_mean(l2_outlier_test.data))
        l2_var_outliers_test.append(_var(l2_outlier_test.data))
        l2_ave_dups_test.append(_mean(l2_dup_test.data))
        l2_var_dups_test.append(_var(l2_dup_test.data))

        # Percentage of points per label that met the memorization criteria.
        l2_memo_inlier_counts.append(_percent(l2_memo_inliers.dist, l2_inlier.data))
        l2_memo_outlier_counts.append(_percent(l2_memo_outliers.dist, l2_outlier.data))
        l2_memo_dups_counts.append(_percent(l2_memo_dup.dist, l2_dup.data))

        # Mean test/train distance ratio, memorized points only.
        l2_inlier_ratios_memo.append(_mean(l2_memo_inliers.ratio))
        if not l2_memo_outliers.ratio:
            print(f"==>> l2_memo_outliers.ratio: {l2_memo_outliers.ratio}")
        l2_outlier_ratios_memo.append(_mean(l2_memo_outliers.ratio))
        l2_dup_ratios_memo.append(_mean(l2_memo_dup.ratio))

        # Mean test/train distance ratio, all points.
        l2_inlier_ratios.append(_mean(l2_inlier.ratio))
        l2_outlier_ratios.append(_mean(l2_outlier.ratio))
        l2_dup_ratios.append(_mean(l2_dup.ratio))

    def _plot_over_epochs(series, ylabel, title, filename):
        """Line-plot each (values, label, color) triple against l2_epochs and save it in imgdir."""
        fig, ax = plt.subplots()
        for values, label, color in series:
            ax.plot(l2_epochs, values, label=label, color=color, marker='o')
        ax.set_xlabel('Epochs')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend()
        plt.grid(True)
        plt.savefig(f"{imgdir}/{filename}")
        plt.close(fig)

    # Plots are produced only when the LAST frame had at least one point whose
    # nearest training neighbour is labelled "Outlier".
    # NOTE(review): gating on the last frame only looks accidental. Confirm.
    # FID tracking (get_fid) is disabled, so no FID plot is produced.
    if len(l2_outlier.data) != 0:
        _plot_over_epochs(
            [(l2_ave_inliers, 'Inliers', 'blue'),
             (l2_ave_inliers_test, 'Inliers Test Set', 'green'),
             (l2_ave_outliers, 'Outliers', 'red'),
             (l2_ave_outliers_test, 'Outliers Test Set', 'purple'),
             (l2_ave_dups, 'Near Duplicates', 'black'),
             (l2_ave_dups_test, 'Near Duplicates Test', 'brown')],
            'Average L2-Distance', config.experiment_name, "Ave_L2_Over_Epochs.png")

        _plot_over_epochs(
            [(l2_var_inliers, 'Inliers', 'blue'),
             (l2_var_inliers_test, 'Inliers Test Set', 'green'),
             (l2_var_outliers, 'Outliers', 'red'),
             (l2_var_outliers_test, 'Outliers Test Set', 'purple'),
             (l2_var_dups, 'Near Duplicates', 'black'),
             (l2_var_dups_test, 'Near Duplicates Test', 'brown')],
            'Variance of L2-Distance', config.experiment_name, "Var_L2_Over_Epochs.png")

        _plot_over_epochs(
            [(l2_memo_inlier_counts, 'inliers', 'blue'),
             (l2_memo_outlier_counts, 'outliers', 'red'),
             (l2_memo_dups_counts, 'near duplicates', 'black')],
            'Percentage Of Memorized Points Based on Threshold',
            config.experiment_name, "MemorizedPointsOverEpochs.png")

        # Same percentages relative to the control frame (frames[0], untrained model).
        def _minus_first(values):
            """Subtract the first element from every element."""
            return [value - values[0] for value in values]

        l2_memo_inlier_counts = _minus_first(l2_memo_inlier_counts)
        l2_memo_outlier_counts = _minus_first(l2_memo_outlier_counts)
        l2_memo_dups_counts = _minus_first(l2_memo_dups_counts)
        _plot_over_epochs(
            [(l2_memo_inlier_counts, 'inliers', 'blue'),
             (l2_memo_outlier_counts, 'outliers', 'red'),
             (l2_memo_dups_counts, 'near duplicates', 'black')],
            'Percentage Of Memorized Points Based on Threshold(Normalized by Control Epoch)',
            config.experiment_name, "MemorizedPointsOverEpochs(Normalized By Control).png")

        # Ratio between test and training distance, all points.
        print(f"==>> l2_epochs: {l2_epochs}")
        print(f"==>> l2_inlier_ratios: {l2_inlier_ratios}")
        _plot_over_epochs(
            [(l2_inlier_ratios, 'ave inlier ratios', 'blue'),
             (l2_outlier_ratios, 'ave outlier ratios', 'red'),
             (l2_dup_ratios, 'ave near duplicate ratios', 'black')],
            'Test/Train Distance L2 Dist Ratio', "Ratios Over Epochs(All)",
            "Ratios Over Epochs(All).png")

        # Ratio between test and training distance, memorized points only.
        _plot_over_epochs(
            [(l2_inlier_ratios_memo, 'ave inlier ratios', 'blue'),
             (l2_outlier_ratios_memo, 'ave outlier ratios', 'red'),
             (l2_dup_ratios_memo, 'ave near duplicate ratios', 'black')],
            'Test/Train Distance L2 Dist Ratio', "Ratios Over Epochs(Memorized)",
            "Ratios Over Epochs(Memorized).png")