import os

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import pandas as pd
import seaborn as sns
import torch

def _reorder_by_index(index, SPC, poisonlabel):
    """Return ``(spc_values, poison_labels)`` with each epoch's row put in sorted-``index`` order.

    ``spc_values`` is a detached CPU *copy* of ``SPC``; the caller's tensor is
    never modified. ``poison_labels`` is row 0 of ``poisonlabel`` permuted by
    the same order used for epoch 0.
    """
    # Sorting each row of ``index`` yields, per epoch, the permutation that puts
    # positions in ascending index order -- presumably undoing the data-loader
    # shuffle so column j refers to the same sample in every epoch (TODO confirm).
    _, order = torch.sort(torch.as_tensor(index).cpu())
    spc_values = SPC.detach().cpu().clone()
    for epoch in range(spc_values.shape[0]):
        spc_values[epoch] = spc_values[epoch][order[epoch]]
    # Only epoch 0's labels are used; assumes a sample's label is the same in every epoch.
    poison_labels = torch.as_tensor(poisonlabel[0]).cpu()[order[0]]
    return spc_values, poison_labels


def _plot_band(ax, values, color, title):
    """Draw the per-epoch mean of ``values`` (epochs x samples) plus its min-max band on ``ax``."""
    ax.set_title(title)
    if values.shape[1] == 0:
        # No samples in this group (e.g. an unpoisoned run): max/min over an empty dim would raise.
        return
    # Mean line and band share the same 1-based epoch axis.
    epochs = np.arange(1, values.shape[0] + 1)
    ax.plot(epochs, values.mean(dim=1).numpy(), color=color, linewidth=3)
    ax.fill_between(
        epochs,
        y1=values.max(dim=1).values.numpy(),
        y2=values.min(dim=1).values.numpy(),
        color=color,
        alpha=0.21,
    )


def _violin_frame(clean_rows, poison_rows, first_epoch):
    """Build the long-format DataFrame for one violin row.

    Row ``i`` of each slice is labelled as epoch ``first_epoch + i``; every SPC
    value becomes one DataFrame row tagged 'Clean' or 'Poison'.
    """
    frames = []
    for rows, label in ((clean_rows, 'Clean'), (poison_rows, 'Poison')):
        n_epochs, n_samples = rows.shape
        frames.append(pd.DataFrame({
            # Row-major flatten keeps the epoch-major ordering of the values.
            'Loss Value': rows.reshape(-1).numpy(),
            'Epoch': np.repeat(np.arange(first_epoch, first_epoch + n_epochs), n_samples),
            'Poisoning': label,
        }))
    return pd.concat(frames, ignore_index=True)


def plot_SPC(index, poisonlabel, SPC, pathname, per_row=36):
    """Plot per-epoch SPC loss for clean vs. poisoned samples and save two PNGs.

    Args:
        index: 2-D tensor; each row is sorted to get the per-epoch column
            permutation applied to ``SPC`` (presumably the dataset index of
            each position in that epoch's ordering -- TODO confirm).
        poisonlabel: per-epoch labels; only row 0 is used. A value of 1 marks
            a poisoned sample and 0 a clean one; other values are ignored.
        SPC: tensor of shape (epochs, samples) with SPC loss values. It is not
            modified.
        pathname: directory that receives ``SPCPlot_line.png`` and
            ``SPCPlot_violin.png``.
        per_row: positive number of epochs per row in the violin figure; the
            last row absorbs any remainder (182 epochs -> 5 rows, the last
            covering epochs 145-182).
    """
    spc_values, poison_labels = _reorder_by_index(index, SPC, poisonlabel)
    spc_values_poison = spc_values[:, poison_labels == 1]
    spc_values_clean = spc_values[:, poison_labels == 0]

    # ===== PLOT 1: mean line with min-max band =====
    # sharey puts both panels on one scale and hides the right panel's redundant
    # y tick labels, so the unlabeled poison panel can be read against the clean one.
    fig, ax = plt.subplots(1, 2, figsize=(20, 5), sharey=True, constrained_layout=True)
    _plot_band(ax[0], spc_values_clean, 'blue', 'Clean')
    _plot_band(ax[1], spc_values_poison, 'red', 'Poison')
    ax[0].set_xlabel('Epochs')
    ax[0].set_ylabel('SPC Loss')
    ax[1].set_xlabel('Epochs')
    fig.savefig(os.path.join(pathname, "SPCPlot_line.png"))
    plt.close(fig)  # free the figure; repeated calls would otherwise accumulate open figures

    # ===== PLOT 2: violins per epoch, split over several rows =====
    n_epochs = spc_values.shape[0]
    n_rows = max(1, n_epochs // per_row)
    starts = [row * per_row for row in range(n_rows)]
    stops = starts[1:] + [n_epochs]
    # Scoped style: applied to these axes at creation time, without mutating global rcParams.
    with sns.axes_style("darkgrid"):
        fig, ax = plt.subplots(
            nrows=n_rows, squeeze=False, figsize=(36, 3 * n_rows), constrained_layout=True
        )
        for row, (start, stop) in enumerate(zip(starts, stops)):
            frame = _violin_frame(
                spc_values_clean[start:stop], spc_values_poison[start:stop], start + 1
            )
            if frame.empty:
                continue
            sns.violinplot(
                x="Epoch", y="Loss Value", hue="Poisoning", data=frame,
                palette="Pastel1", ax=ax[row, 0],
            )
        fig.savefig(os.path.join(pathname, 'SPCPlot_violin.png'))
    plt.close(fig)

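# Legacy standalone script kept for reference only. It is wrapped in a bare
# string literal and never executes; it duplicates the logic of plot_SPC.
'''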
        
# plot_SPC("/home/soumyadeep/Poison_Influence_Sift/Exp_Models_train_SPC/cifar10/Badnet/Poisonratio_0.01/res18/Trial 3")


pathname = "/home/soumyadeep/Poison_Influence_Sift/Exp_Models_train_SPC/cifar10/Badnet/Poisonratio_0.01/res18/Trial 3"
index = torch.load(pathname + '/Index.pt')
_, indices = torch.sort(index)

spc_values = torch.load(pathname + '/SPC.pt')
poisonlabels = torch.load(pathname + '/Poisonlabel.pt')

for epoch in range(spc_values.shape[0]):
    spc_values[epoch] = spc_values[epoch][indices[epoch]]
    
poisonlab = poisonlabels[0][indices[0]]

spc_values_poison = spc_values[:,poisonlab == 1]
spc_values_clean = spc_values[:,poisonlab == 0]

# ==========     Plot 1   =============================
# fig, ax = plt.subplots(figsize=(20,10))
# ax.plot(spc_values_clean, color='plum')
# ax.plot(spc_values_poison, color='mediumaquamarine')

# ==========     Plot 2  =============================
fig, ax = plt.subplots(1,2, figsize=(20,5), constrained_layout=True)
ax[0].plot(torch.mean(spc_values_clean,dim=1), color='blue', linewidth=3)
ax[0].fill_between(torch.arange(len(spc_values_clean))+1, y1=torch.max(spc_values_clean,dim=1).values, y2=torch.min(spc_values_clean,dim=1).values, color='blue', alpha=0.21)
ax[0].set_title('Clean')
ax[1].plot(torch.mean(spc_values_poison,dim=1), color='red', linewidth=3)
ax[1].fill_between(torch.arange(len(spc_values_poison))+1, y1=torch.max(spc_values_poison,dim=1).values, y2=torch.min(spc_values_poison,dim=1).values, color='red', alpha=0.21)
ax[1].set_yticklabels([])
ax[1].set_title('Poison')
ax[0].set_xlabel('Epochs')
ax[0].set_ylabel('SPC Loss')
ax[1].set_xlabel('Epochs')
fig.savefig(pathname + "/SPCPlot_line.png")


# ==========     Plot 3   =============================
# spc_values_clean = spc_values_clean[:, 0:20]
# n = spc_values_clean.shape[1]
# colors = plt.cm.Wistia(np.linspace(0,1,n))
# fig, ax = plt.subplots(figsize=(40,10))
# for i in range(n):
#     ax.plot(spc_values_clean[:,i], color=colors[i])
# fig.savefig(pathname + "/spcplot3.png")

# ==========     Plot 4   =============================
# per_row = 36
# divs = [slice(0,36), slice(36,72), slice(72,108), slice(108,144), slice(144,182)]
# fig, ax = plt.subplots(nrows=len(divs),figsize=(36,15), constrained_layout=True)

# for row in range(len(divs)):
 
#     spc_values_clean_slice = spc_values_clean[divs[row]]
#     spc_values_poison_slice = spc_values_poison[divs[row]]
    
#     spc_clean_loss = []
#     spc_clean_epoch = []
    
#     for i in range(spc_values_clean_slice.shape[0]):
#         epoch = (i+per_row*row)+1
#         vals = list(spc_values_clean_slice[i].numpy())
#         spc_clean_loss = spc_clean_loss + vals
#         spc_clean_epoch = spc_clean_epoch + [epoch]*len(vals)
    
#     spc_poison_loss = []
#     spc_poison_epoch = []
    
#     for i in range(spc_values_poison_slice.shape[0]):
#         epoch = (i+per_row*row)+1
#         vals = list(spc_values_poison_slice[i].numpy())
#         spc_poison_loss = spc_poison_loss + vals
#         spc_poison_epoch = spc_poison_epoch + [epoch]*len(vals)
    
#     data = {'Loss Value': spc_clean_loss + spc_poison_loss ,
#             'Epoch' : spc_clean_epoch + spc_poison_epoch,
#             'Poisoning' : ['Clean']*len(spc_clean_loss) + ['Poison']*len(spc_poison_loss) 
#             }
        
#     df = pd.DataFrame(data)
#     # import pdb;pdb.set_trace()
    
#     # fig, ax = plt.subplots(figsize=(30,30))
#     sns.set(style="darkgrid")
#     sns.violinplot(x="Epoch", y="Loss Value", hue="Poisoning", data=df, palette="Pastel1", ax = ax[row])
# fig.savefig(pathname + '/Violin2.png')
    


# # spc_values_clean_slice = spc_values_clean[181]
# # spc_values_poison_slice = spc_values_poison[181]


# # data = {'Loss Value': list(spc_values_clean_slice.numpy()) + list(spc_values_poison_slice.numpy()) ,
# #         'Epoch' : [182]*(len(spc_values_clean_slice) + len(spc_values_poison_slice)),
# #         'Poisoning' : ['Clean']*len(spc_values_clean_slice) + ['Poison']*len(spc_values_poison_slice) 
# #         }
    
# # df = pd.DataFrame(data)
# # sns.set(style="darkgrid")
# # sns.violinplot(x="Epoch", y="Loss Value", hue="Poisoning", data=df, palette="Pastel1")
# # plt.savefig(pathname + '/Violin3.png')



'''











