import matplotlib.pyplot as plt
import numpy as np

import pdb

def parse_file(file_name):
    """Parse one experiment result CSV into a per-common-rate metric table.

    Only a fixed set of (0-based) rows is read; every other row is ignored.
    Row 7 ("full task") creates one entry per comma-separated column, keyed
    ``'0%'``, ``'20%'``, ``'40%'``, ... (column index * 20). Each later row
    listed below appends its value for that column, in file order, so with a
    complete file the per-column list is::

        [0] row 7  full task            [4] row 27 reconstruction  (abs)
        [1] row 9  std_f1               [5] row 32 NESum in dnn    (abs)
        [2] row 12 view_score   (abs)   [6] row 42 denoise         (abs)
        [3] row 17 NESum in feature(abs)[7] row 47 noise score     (abs)

    If some of those rows are missing, the later indices shift down.

    Args:
        file_name: Path of the CSV file to read.

    Returns:
        ``(True, task)`` on success, where ``task`` maps the column key to its
        list of floats (empty dict if the file has fewer than 8 lines), or
        ``(False, None)`` if the file cannot be opened, a selected row holds a
        non-numeric field, or a later row has more columns than row 7.
    """
    first_row = 7
    # Rows appended after `first_row`, mapped to whether abs() is applied.
    extra_rows = {9: False, 12: True, 17: True, 27: True,
                  32: True, 42: True, 47: True}
    task = {}
    try:
        with open(file_name, 'r') as f:
            for i, line in enumerate(f):
                if i != first_row and i not in extra_rows:
                    continue
                values = [float(x) for x in line.strip().split(',')]
                if i == first_row:
                    for j, value in enumerate(values):
                        task['{}%'.format(j * 20)] = [value]
                else:
                    use_abs = extra_rows[i]
                    for j, value in enumerate(values):
                        # KeyError here means this row is wider than row 7.
                        task['{}%'.format(j * 20)].append(
                            abs(value) if use_abs else value)
    except (OSError, ValueError, KeyError):
        # Missing/unreadable file, malformed number, or ragged rows.
        return False, None
    return True, task




# Collect every result file into
#   pdf_data[epoch][method][common_rate] = [metric_0, metric_1, ...]
# where the metric list is whatever parse_file returns for that column.
methods = ['linear_cca','linear_gcca','dcca','dgcca']

dataset = ['True']  # dataset tag used in the result file names

epochs = [100,200,400,600,800,1000,1200]
num_views =[2]

# Matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*' and
# 3.8 removed the old names; use the new name whenever it is available.
if 'seaborn-v0_8-poster' in plt.style.available:
    plt.style.use('seaborn-v0_8-poster')
else:
    plt.style.use('seaborn-poster')
pdf_data = {}

for num_view in num_views:
    for data in dataset:
        for method in methods:
            t_methods = [method]
            # Deep methods ('dcca', 'dgcca') also have 'e', '_private' and
            # '_with_noise' result variants.
            if 'cca' in method and 'd' in method:
                t_methods.append(method+'e')
                t_methods.append(method+'_private')
                t_methods.append(method+'_with_noise')
            for method in t_methods:
                for epoch in epochs:
                    file_name = "./Syn_output/{}_{}_{}_{}.csv".format(method,data,num_view,epoch)
                    flag,task = parse_file(file_name)
                    if not flag:
                        print('no',file_name)
                        continue
                    # Entries are only created when the file yielded data, so
                    # "method in pdf_data[epoch]" means results exist.
                    for task_name, values in task.items():
                        pdf_data.setdefault(epoch, {}).setdefault(method, {})[task_name] = values

colors = ['red','blue','green','purple','orange','black','brown','gold']
marker = ['o','s','^','v','<','>','x','+']

# Figure pair: metric index 0 (row 7 of each result CSV) against common rate,
# one panel per epoch below. fig_s[0] holds the CCA family, fig_s[1] GCCA.
epochs = [100,200,400,1200]

fig, axes = plt.subplots(1, len(epochs),figsize=(35, 6),sharex=True, sharey=True)
fig_1, axes_1 = plt.subplots(1,len(epochs),figsize=(35, 6),sharex=True, sharey=True)
fig_s = [fig,fig_1]
axes_s = [axes,axes_1]

front = 36  # font size of the axis labels

methods_s = [['linear_cca','dcca'],['linear_gcca','dgcca']]
for f,methods in enumerate(methods_s):
    for k, epoch in enumerate(epochs):
        for i,method in enumerate(methods):
            t_methods = [method]
            if 'cca' in method and 'd' in method:
                t_methods.append(method+'e')
                t_methods.append(method+'_private')
                t_methods.append(method+'_with_noise')
            for j,method in enumerate(t_methods):
                # Methods with no result file for this epoch reuse the
                # results of the first epoch.
                if method not in pdf_data.get(epoch, {}):
                    per_rate = pdf_data[epochs[0]][method]
                else:
                    per_rate = pdf_data[epoch][method]
                x = list(per_rate.keys())
                y_mean = np.array([d[0] for d in per_rate.values()])

                method = method.upper()
                if ('CCA' in method and 'D' in method) or method=='LINEAR_CCA' or method=='LINEAR_GCCA':
                    if method=='DCCA_WITH_NOISE':
                        method='NR-DCCA(ours)'
                    if method=='DGCCA_WITH_NOISE':
                        method='NR-DGCCA(ours)'
                    axes_s[f][k].plot(x, y_mean, label=method, color=colors[i+j],marker=marker[f+1],linewidth=5)
                else:
                    axes_s[f][k].plot(x, y_mean, label=method, color=colors[i+j],marker=marker[0],linestyle='--')

        axes_s[f][k].set_title('Epoch: {}'.format(epoch))
        if k==0:
            axes_s[f][k].legend()
        axes_s[f][k].grid(linestyle= ':',linewidth=2)

axes_s[0][0].set_ylabel('R2_performance',fontsize=front)
axes_s[1][0].set_ylabel('R2_performance',fontsize=front)

# Lay out after the labels exist so tight_layout accounts for them.
fig_s[0].tight_layout()
fig_s[1].tight_layout()

# Centre the shared x label on the figure (default alignment is 'left').
fig_s[0].text(0.5, -0.05, 'Common_rate', fontsize=front, ha='center')
fig_s[1].text(0.5, -0.05, 'Common_rate', fontsize=front, ha='center')

fig_s[0].savefig('./draw_pictures/Syn_rate_r2_CCA.png', dpi=300, bbox_inches='tight')
fig_s[1].savefig('./draw_pictures/Syn_rate_r2_GCCA.png', dpi=300, bbox_inches='tight')

# plt.close() with no argument closes only the current figure (fig_1) and
# leaves fig open; close both explicitly.
plt.close(fig_s[0])
plt.close(fig_s[1])

# axes[0][0].set_ylabel('R2_performance',fontsize=24)

# fig.text(0.5, -0.05, 'Common_rate', fontsize=24)




# plt.savefig('./draw_pictures/Syn_rate_r2.png', dpi=300, bbox_inches='tight')

# plt.close()





#epochs = [50,100,200]
# fig, axes = plt.subplots(1, 1,figsize=(20,10),sharex=True,sharey=True)
# #pdb.set_trace()
# methods_s = [['concat','linear_cca','kcca','dcca']]
# epochs = [50,100,200,400,600,800,1000,1200]
# for f,methods in enumerate(methods_s):
    
#     for i,method in enumerate(methods):
#         t_methods = [method]
#         if 'cca' in method and 'd' in method:
#             t_methods.append(method+'e')
            
#             t_methods.append(method+'_private')
#             t_methods.append(method+'_with_noise')
#             #pdb.set_trace()
        
        
#         for j,method in enumerate(t_methods):
#             if method not in ['linear_cca','concat','dcca','dcca_with_noise']:
#                 continue

#             y_mean_abs = []
#             y_mean_std= []
#             y_mean_mean= []
            
#             for k in range(len(epochs)):
#                 epoch = epochs[k]
#                 if method not in pdf_data[epoch]:
#                     x = list(pdf_data[epochs[0]][method].keys())
#                     y = list(pdf_data[epochs[0]][method].values())
#                 else:
#                     x = list(pdf_data[epoch][method].keys())
#                     y = list(pdf_data[epoch][method].values())
#                 y_mean = np.array( [d[0] for d in y])
#                 y_std = np.array( [d[1] for d in y])
#                 y_mean_std.append(y_mean.std(0))
#                 y_mean_abs.append(abs(max(y_mean)-min(y_mean)))
#                 y_mean_mean.append(y_mean.mean(0))
#             y_mean_mean = np.array(y_mean_mean)
#             y_mean_std = np.array(y_mean_std)

#             #pdb.set_trace()
#             x = np.array(epochs)
#             method = method.upper()
#             if 'CCA' in method and 'D' in method:
#                 if method=='dcca_with_noise'.upper():
#                     method='NR-dcca'.upper()+'(our)'
#                 if method=='dgcca_with_noise'.upper():
#                     method='NR-dgcca'.upper()  +'(our)'
#                 axes.plot(x, y_mean_mean, label=method, color=colors[i+j],marker=marker[f+1])
              
#                 axes.fill_between(x, y_mean_mean-y_mean_std, y_mean_mean+y_mean_std,  color=colors[i+j], alpha=0.2)
#             else:
#                 axes.plot(x, y_mean_mean, label=method, color=colors[i+j],marker=marker[0],linestyle='--')
#                 axes.fill_between(x, y_mean_mean-y_mean_std, y_mean_mean+y_mean_std,  color=colors[i+j], alpha=0.2)
               
        
        
#     axes.legend(loc ='lower left')
 
# plt.tight_layout()

# axes.set_ylabel('R2_performance',fontsize=24)

# fig.text(0.5, -0.05, 'Epoch', fontsize=24)

# plt.grid(linestyle= ':',linewidth=1)
# plt.savefig('./draw_pictures/Syn_rate_mean_std_select.png', dpi=300, bbox_inches='tight')

# plt.close()


# fig, axes = plt.subplots(2, 1,figsize=(20, 15),sharex=True,sharey=True)

# Figure pair: metric index 0 (row 7 of each result CSV) averaged over the
# common rates, against epoch; the band is +/- one std across common rates.
# fig_s[0] holds the CCA family, fig_s[1] the GCCA family.
fig, axes = plt.subplots(1,1,figsize=(20, 10),sharex=True, sharey=True)
fig_1, axes_1 = plt.subplots(1,1,figsize=(20, 10),sharex=True, sharey=True)
fig_s = [fig,fig_1]
axes_s = [axes,axes_1]

methods_s = [['linear_cca','dcca'],['linear_gcca','dgcca']]
epochs = [100,200,400,600,800,1000,1200]
for f,methods in enumerate(methods_s):
    for i,method in enumerate(methods):
        t_methods = [method]
        if 'cca' in method and 'd' in method:
            t_methods.append(method+'e')
            t_methods.append(method+'_private')
            t_methods.append(method+'_with_noise')

        for j,method in enumerate(t_methods):
            y_mean_std = []
            y_mean_mean = []
            for epoch in epochs:
                # Methods with no result file for this epoch reuse the
                # results of the first epoch.
                if method not in pdf_data.get(epoch, {}):
                    y = list(pdf_data[epochs[0]][method].values())
                else:
                    y = list(pdf_data[epoch][method].values())
                y_mean = np.array([d[0] for d in y])
                y_mean_std.append(y_mean.std(0))
                y_mean_mean.append(y_mean.mean(0))
            y_mean_mean = np.array(y_mean_mean)
            y_mean_std = np.array(y_mean_std)

            x = np.array(epochs)
            method = method.upper()
            if ('CCA' in method and 'D' in method) or method=='LINEAR_CCA' or method=='LINEAR_GCCA':
                if method=='DCCA_WITH_NOISE':
                    method='NR-DCCA(ours)'
                if method=='DGCCA_WITH_NOISE':
                    method='NR-DGCCA(ours)'
                axes_s[f].plot(x, y_mean_mean, label=method, color=colors[i+j],marker=marker[f+1],linewidth=5)
                axes_s[f].fill_between(x, y_mean_mean-y_mean_std, y_mean_mean+y_mean_std, color=colors[i+j], alpha=0.2)
            else:
                print(method)
                axes_s[f].plot(x, y_mean_mean, label=method, color=colors[i+j],marker=marker[0],linestyle='--')
                axes_s[f].fill_between(x, y_mean_mean-y_mean_std, y_mean_mean+y_mean_std, color=colors[i+j], alpha=0.2)

    axes_s[f].legend(loc ='lower left',prop = {'size':24})
    axes_s[f].grid(linestyle= ':',linewidth=2)

axes_s[0].set_ylabel('R2_performance',fontsize=front)
axes_s[1].set_ylabel('R2_performance',fontsize=front)
axes_s[0].tick_params(axis='both', labelsize=36)
axes_s[1].tick_params(axis='both', labelsize=36)

# Lay out after the large labels/ticks exist so tight_layout accounts for them.
fig_s[0].tight_layout()
fig_s[1].tight_layout()

# Centre the shared x label on the figure (default alignment is 'left').
fig_s[0].text(0.5, -0.05, 'Epoch', fontsize=front, ha='center')
fig_s[1].text(0.5, -0.05, 'Epoch', fontsize=front, ha='center')

fig_s[0].savefig('./draw_pictures/Syn_rate_mean_std_CCA.png', dpi=300, bbox_inches='tight')
fig_s[1].savefig('./draw_pictures/Syn_rate_mean_std_GCCA.png', dpi=300, bbox_inches='tight')

# plt.close() with no argument closes only the current figure (fig_1) and
# leaves fig open; close both explicitly.
plt.close(fig_s[0])
plt.close(fig_s[1])

# Figure pair: metric index 3 (row 17 of each result CSV) averaged over the
# common rates, against epoch. No spread is drawn (zero-width band).
# fig_s[0] holds the CCA family, fig_s[1] the GCCA family.
fig, axes = plt.subplots(1,1,figsize=(20, 10),sharex=True, sharey=True)
fig_1, axes_1 = plt.subplots(1,1,figsize=(20, 10),sharex=True, sharey=True)
fig_s = [fig,fig_1]
axes_s = [axes,axes_1]

methods_s = [['linear_cca','dcca'],['linear_gcca','dgcca']]
epochs = [100,200,400,600,800,1000,1200]
for f,methods in enumerate(methods_s):
    for i,method in enumerate(methods):
        t_methods = [method]
        if 'cca' in method and 'd' in method:
            t_methods.append(method+'e')
            t_methods.append(method+'_private')
            t_methods.append(method+'_with_noise')

        for j,method in enumerate(t_methods):
            y_mean_mean = []
            for epoch in epochs:
                # Methods with no result file for this epoch reuse the
                # results of the first epoch.
                if method not in pdf_data.get(epoch, {}):
                    y = list(pdf_data[epochs[0]][method].values())
                else:
                    y = list(pdf_data[epoch][method].values())
                y_mean = np.array([d[3] for d in y])
                y_mean_mean.append(y_mean.mean(0))
            y_mean_mean = np.array(y_mean_mean)
            y_mean_std = np.zeros(len(epochs))

            x = np.array(epochs)
            method = method.upper()
            if ('CCA' in method and 'D' in method) or method=='LINEAR_CCA' or method=='LINEAR_GCCA':
                if method=='DCCA_WITH_NOISE':
                    method='NR-DCCA(ours)'
                if method=='DGCCA_WITH_NOISE':
                    method='NR-DGCCA(ours)'
                axes_s[f].plot(x, y_mean_mean, label=method, color=colors[i+j],marker=marker[f+1],linewidth=5)
                axes_s[f].fill_between(x, y_mean_mean-y_mean_std, y_mean_mean+y_mean_std, color=colors[i+j], alpha=0.2)
            else:
                print(method)
                axes_s[f].plot(x, y_mean_mean, label=method, color=colors[i+j],marker=marker[0],linestyle='--')
                axes_s[f].fill_between(x, y_mean_mean-y_mean_std, y_mean_mean+y_mean_std, color=colors[i+j], alpha=0.2)

    axes_s[f].legend(loc ='lower left',prop = {'size':24})
    axes_s[f].grid(linestyle= ':',linewidth=2)

axes_s[0].set_ylabel('NESum_of_feature',fontsize=front)
axes_s[1].set_ylabel('NESum_of_feature',fontsize=front)
axes_s[0].tick_params(axis='both', labelsize=36)
axes_s[1].tick_params(axis='both', labelsize=36)

# Lay out after the large labels/ticks exist so tight_layout accounts for them.
fig_s[0].tight_layout()
fig_s[1].tight_layout()

# Centre the shared x label on the figure (default alignment is 'left').
fig_s[0].text(0.5, -0.05, 'Epoch', fontsize=front, ha='center')
fig_s[1].text(0.5, -0.05, 'Epoch', fontsize=front, ha='center')

fig_s[0].savefig('./draw_pictures/Syn_rate_feature_nesum_mean_std_CCA.png', dpi=300, bbox_inches='tight')
fig_s[1].savefig('./draw_pictures/Syn_rate_feature_nesum_mean_std_GCCA.png', dpi=300, bbox_inches='tight')

# plt.close() with no argument closes only the current figure (fig_1) and
# leaves fig open; close both explicitly.
plt.close(fig_s[0])
plt.close(fig_s[1])

# Figure pair: metric index 4 (row 27 of each result CSV) averaged over the
# common rates, against epoch. No spread is drawn (zero-width band).
# fig_s[0] holds the CCA family, fig_s[1] the GCCA family.
fig, axes = plt.subplots(1,1,figsize=(20, 10),sharex=True, sharey=True)
fig_1, axes_1 = plt.subplots(1,1,figsize=(20, 10),sharex=True, sharey=True)
fig_s = [fig,fig_1]
axes_s = [axes,axes_1]

methods_s = [['linear_cca','dcca'],['linear_gcca','dgcca']]
epochs = [100,200,400,600,800,1000,1200]
for f,methods in enumerate(methods_s):
    for i,method in enumerate(methods):
        t_methods = [method]
        if 'cca' in method and 'd' in method:
            t_methods.append(method+'e')
            t_methods.append(method+'_private')
            t_methods.append(method+'_with_noise')

        for j,method in enumerate(t_methods):
            y_mean_mean = []
            for epoch in epochs:
                # Methods with no result file for this epoch reuse the
                # results of the first epoch.
                if method not in pdf_data.get(epoch, {}):
                    y = list(pdf_data[epochs[0]][method].values())
                else:
                    y = list(pdf_data[epoch][method].values())
                y_mean = np.array([d[4] for d in y])
                y_mean_mean.append(y_mean.mean(0))
            y_mean_mean = np.array(y_mean_mean)
            y_mean_std = np.zeros(len(epochs))

            x = np.array(epochs)
            method = method.upper()
            if ('CCA' in method and 'D' in method) or method=='LINEAR_CCA' or method=='LINEAR_GCCA':
                if method=='DCCA_WITH_NOISE':
                    method='NR-DCCA(ours)'
                if method=='DGCCA_WITH_NOISE':
                    method='NR-DGCCA(ours)'
                axes_s[f].plot(x, y_mean_mean, label=method, color=colors[i+j],marker=marker[f+1],linewidth=5)
                axes_s[f].fill_between(x, y_mean_mean-y_mean_std, y_mean_mean+y_mean_std, color=colors[i+j], alpha=0.2)
            else:
                print(method)
                axes_s[f].plot(x, y_mean_mean, label=method, color=colors[i+j],marker=marker[0],linestyle='--')
                axes_s[f].fill_between(x, y_mean_mean-y_mean_std, y_mean_mean+y_mean_std, color=colors[i+j], alpha=0.2)

    axes_s[f].legend(loc ='lower left',prop = {'size':24})
    axes_s[f].grid(linestyle= ':',linewidth=2)

axes_s[0].set_ylabel('Reconstruction Loss',fontsize=front)
axes_s[1].set_ylabel('Reconstruction Loss',fontsize=front)
axes_s[0].tick_params(axis='both', labelsize=36)
axes_s[1].tick_params(axis='both', labelsize=36)

# Lay out after the large labels/ticks exist so tight_layout accounts for them.
fig_s[0].tight_layout()
fig_s[1].tight_layout()

# Centre the shared x label on the figure (default alignment is 'left').
fig_s[0].text(0.5, -0.05, 'Epoch', fontsize=front, ha='center')
fig_s[1].text(0.5, -0.05, 'Epoch', fontsize=front, ha='center')

fig_s[0].savefig('./draw_pictures/Syn_rate_reconstruction_mean_std_CCA.png', dpi=300, bbox_inches='tight')
fig_s[1].savefig('./draw_pictures/Syn_rate_reconstruction_mean_std_GCCA.png', dpi=300, bbox_inches='tight')

# plt.close() with no argument closes only the current figure (fig_1) and
# leaves fig open; close both explicitly.
plt.close(fig_s[0])
plt.close(fig_s[1])

# Figure pair: metric index 5 (row 32 of each result CSV) averaged over the
# common rates, against epoch. No spread is drawn (zero-width band).
# fig_s[0] holds the CCA family, fig_s[1] the GCCA family.
fig, axes = plt.subplots(1,1,figsize=(20, 10),sharex=True, sharey=True)
fig_1, axes_1 = plt.subplots(1,1,figsize=(20, 10),sharex=True, sharey=True)
fig_s = [fig,fig_1]
axes_s = [axes,axes_1]

methods_s = [['linear_cca','dcca'],['linear_gcca','dgcca']]
epochs = [100,200,400,600,800,1000,1200]
for f,methods in enumerate(methods_s):
    for i,method in enumerate(methods):
        t_methods = [method]
        if 'cca' in method and 'd' in method:
            t_methods.append(method+'e')
            t_methods.append(method+'_private')
            t_methods.append(method+'_with_noise')

        for j,method in enumerate(t_methods):
            y_mean_mean = []
            for epoch in epochs:
                # Methods with no result file for this epoch reuse the
                # results of the first epoch.
                if method not in pdf_data.get(epoch, {}):
                    y = list(pdf_data[epochs[0]][method].values())
                else:
                    y = list(pdf_data[epoch][method].values())
                y_mean = np.array([d[5] for d in y])
                y_mean_mean.append(y_mean.mean(0))
            y_mean_mean = np.array(y_mean_mean)
            y_mean_std = np.zeros(len(epochs))

            x = np.array(epochs)
            method = method.upper()
            if ('CCA' in method and 'D' in method) or method=='LINEAR_CCA' or method=='LINEAR_GCCA':
                if method=='DCCA_WITH_NOISE':
                    method='NR-DCCA(ours)'
                if method=='DGCCA_WITH_NOISE':
                    method='NR-DGCCA(ours)'
                axes_s[f].plot(x, y_mean_mean, label=method, color=colors[i+j],marker=marker[f+1],linewidth=5)
                axes_s[f].fill_between(x, y_mean_mean-y_mean_std, y_mean_mean+y_mean_std, color=colors[i+j], alpha=0.2)
            else:
                print(method)
                axes_s[f].plot(x, y_mean_mean, label=method, color=colors[i+j],marker=marker[0],linestyle='--')
                axes_s[f].fill_between(x, y_mean_mean-y_mean_std, y_mean_mean+y_mean_std, color=colors[i+j], alpha=0.2)

    axes_s[f].legend(loc ='lower left',prop = {'size':24})
    axes_s[f].grid(linestyle= ':',linewidth=2)

axes_s[0].set_ylabel('NESum_of_dnns',fontsize=front)
axes_s[1].set_ylabel('NESum_of_dnns',fontsize=front)
axes_s[0].tick_params(axis='both', labelsize=36)
axes_s[1].tick_params(axis='both', labelsize=36)

# Lay out after the large labels/ticks exist so tight_layout accounts for them.
fig_s[0].tight_layout()
fig_s[1].tight_layout()

# Centre the shared x label on the figure (default alignment is 'left').
fig_s[0].text(0.5, -0.05, 'Epoch', fontsize=front, ha='center')
fig_s[1].text(0.5, -0.05, 'Epoch', fontsize=front, ha='center')

fig_s[0].savefig('./draw_pictures/Syn_rate_dnn_nesum_mean_std_CCA.png', dpi=300, bbox_inches='tight')
fig_s[1].savefig('./draw_pictures/Syn_rate_dnn_nesum_mean_std_GCCA.png', dpi=300, bbox_inches='tight')

# plt.close() with no argument closes only the current figure (fig_1) and
# leaves fig open; close both explicitly.
plt.close(fig_s[0])
plt.close(fig_s[1])



# Figure pair: metric index 7 (row 47 of each result CSV) averaged over the
# common rates, against epoch, for the deep methods only. No spread is drawn
# (zero-width band). fig_s[0] holds the CCA family, fig_s[1] the GCCA family.
fig, axes = plt.subplots(1,1,figsize=(20, 10),sharex=True, sharey=True)
fig_1, axes_1 = plt.subplots(1,1,figsize=(20, 10),sharex=True, sharey=True)
fig_s = [fig,fig_1]
axes_s = [axes,axes_1]

methods_s = [['linear_cca','dcca'],['linear_gcca','dgcca']]
epochs = [100,200,400,600,800,1000,1200]
for f,methods in enumerate(methods_s):
    for i,method in enumerate(methods):
        t_methods = [method]
        if 'cca' in method and 'd' in method:
            t_methods.append(method+'e')
            t_methods.append(method+'_private')
            t_methods.append(method+'_with_noise')

        for j,method in enumerate(t_methods):
            # Only the deep variants are plotted; colours still use the
            # original (i, j) positions.
            if method not in ['dccae','dcca','dcca_with_noise','dgccae','dgcca','dgcca_with_noise','dcca_private','dgcca_private']:
                continue

            y_mean_mean = []
            for epoch in epochs:
                # Methods with no result file for this epoch reuse the
                # results of the first epoch.
                if method not in pdf_data.get(epoch, {}):
                    y = list(pdf_data[epochs[0]][method].values())
                else:
                    y = list(pdf_data[epoch][method].values())
                y_mean = np.array([d[7] for d in y])
                y_mean_mean.append(y_mean.mean(0))
            y_mean_mean = np.array(y_mean_mean)
            y_mean_std = np.zeros(len(epochs))

            x = np.array(epochs)
            method = method.upper()
            if ('CCA' in method and 'D' in method) or method=='LINEAR_CCA' or method=='LINEAR_GCCA':
                if method=='DCCA_WITH_NOISE':
                    method='NR-DCCA(ours)'
                if method=='DGCCA_WITH_NOISE':
                    method='NR-DGCCA(ours)'
                axes_s[f].plot(x, y_mean_mean, label=method, color=colors[i+j],marker=marker[f+1],linewidth=5)
                axes_s[f].fill_between(x, y_mean_mean-y_mean_std, y_mean_mean+y_mean_std, color=colors[i+j], alpha=0.2)
            else:
                print(method)
                axes_s[f].plot(x, y_mean_mean, label=method, color=colors[i+j],marker=marker[0],linestyle='--')
                axes_s[f].fill_between(x, y_mean_mean-y_mean_std, y_mean_mean+y_mean_std, color=colors[i+j], alpha=0.2)

    axes_s[f].legend(loc ='lower left',prop = {'size':24})
    axes_s[f].grid(linestyle= ':',linewidth=2)

axes_s[0].set_ylabel('Noise_cor',fontsize=front)
axes_s[1].set_ylabel('Noise_cor',fontsize=front)
axes_s[0].tick_params(axis='both', labelsize=36)
axes_s[1].tick_params(axis='both', labelsize=36)

# Lay out after the large labels/ticks exist so tight_layout accounts for them.
fig_s[0].tight_layout()
fig_s[1].tight_layout()

# Centre the shared x label on the figure (default alignment is 'left').
fig_s[0].text(0.5, -0.05, 'Epoch', fontsize=front, ha='center')
fig_s[1].text(0.5, -0.05, 'Epoch', fontsize=front, ha='center')

fig_s[0].savefig('./draw_pictures/Syn_rate_noise_cor_mean_std_CCA.png', dpi=300, bbox_inches='tight')
fig_s[1].savefig('./draw_pictures/Syn_rate_noise_cor_mean_std_GCCA.png', dpi=300, bbox_inches='tight')

# plt.close() with no argument closes only the current figure (fig_1) and
# leaves fig open; close both explicitly.
plt.close(fig_s[0])
plt.close(fig_s[1])



# Figure pair: metric index 6 (row 42 of each result CSV) averaged over the
# common rates, against epoch. No spread is drawn (zero-width band).
# fig_s[0] holds the CCA family, fig_s[1] the GCCA family.
fig, axes = plt.subplots(1,1,figsize=(20, 10),sharex=True, sharey=True)
fig_1, axes_1 = plt.subplots(1,1,figsize=(20, 10),sharex=True, sharey=True)
fig_s = [fig,fig_1]
axes_s = [axes,axes_1]

methods_s = [['linear_cca','dcca'],['linear_gcca','dgcca']]
epochs = [100,200,400,600,800,1000,1200]
for f,methods in enumerate(methods_s):
    for i,method in enumerate(methods):
        t_methods = [method]
        if 'cca' in method and 'd' in method:
            t_methods.append(method+'e')
            t_methods.append(method+'_private')
            t_methods.append(method+'_with_noise')

        for j,method in enumerate(t_methods):
            y_mean_mean = []
            for epoch in epochs:
                # Methods with no result file for this epoch reuse the
                # results of the first epoch.
                if method not in pdf_data.get(epoch, {}):
                    y = list(pdf_data[epochs[0]][method].values())
                else:
                    y = list(pdf_data[epoch][method].values())
                y_mean = np.array([d[6] for d in y])
                y_mean_mean.append(y_mean.mean(0))
            y_mean_mean = np.array(y_mean_mean)
            y_mean_std = np.zeros(len(epochs))

            x = np.array(epochs)
            method = method.upper()
            if ('CCA' in method and 'D' in method) or method=='LINEAR_CCA' or method=='LINEAR_GCCA':
                if method=='DCCA_WITH_NOISE':
                    method='NR-DCCA(ours)'
                if method=='DGCCA_WITH_NOISE':
                    method='NR-DGCCA(ours)'
                axes_s[f].plot(x, y_mean_mean, label=method, color=colors[i+j],marker=marker[f+1],linewidth=5)
                axes_s[f].fill_between(x, y_mean_mean-y_mean_std, y_mean_mean+y_mean_std, color=colors[i+j], alpha=0.2)
            else:
                print(method)
                axes_s[f].plot(x, y_mean_mean, label=method, color=colors[i+j],marker=marker[0],linestyle='--')
                axes_s[f].fill_between(x, y_mean_mean-y_mean_std, y_mean_mean+y_mean_std, color=colors[i+j], alpha=0.2)

    axes_s[f].legend(loc ='lower left',prop = {'size':24})
    axes_s[f].grid(linestyle= ':',linewidth=2)

axes_s[0].set_ylabel('Denoising Loss',fontsize=front)
axes_s[1].set_ylabel('Denoising Loss',fontsize=front)
axes_s[0].tick_params(axis='both', labelsize=36)
axes_s[1].tick_params(axis='both', labelsize=36)

# Lay out after the large labels/ticks exist so tight_layout accounts for them.
fig_s[0].tight_layout()
fig_s[1].tight_layout()

# Centre the shared x label on the figure (default alignment is 'left').
fig_s[0].text(0.5, -0.05, 'Epoch', fontsize=front, ha='center')
fig_s[1].text(0.5, -0.05, 'Epoch', fontsize=front, ha='center')

fig_s[0].savefig('./draw_pictures/Syn_rate_denoising_loss_mean_std_CCA.png', dpi=300, bbox_inches='tight')
fig_s[1].savefig('./draw_pictures/Syn_rate_denoising_loss_mean_std_GCCA.png', dpi=300, bbox_inches='tight')

# plt.close() with no argument closes only the current figure (fig_1) and
# leaves fig open; close both explicitly.
plt.close(fig_s[0])
plt.close(fig_s[1])

# R2 performance versus training epoch, one subplot per common rate.
# Figure 0 holds the CCA family and is saved as *_CCA.png; figure 1 holds the
# GCCA family and is saved as *_GCCA.png.
commons = ['0%','60%','100%']

fig, axes = plt.subplots(1, len(commons), figsize=(20, 8.5), sharex=True, sharey='row')
fig_1, axes_1 = plt.subplots(1, len(commons), figsize=(20, 8.5), sharex=True, sharey='row')

fig_s = [fig, fig_1]
axes_s = [axes, axes_1]

methods_s = [['linear_cca','dcca'],['linear_gcca','dgcca']]
epochs = [100,200,400,600,800,1000,1200]
# Every curve shares the same x coordinates, so build them once.
x = np.array(epochs)

# Legend names for the noise-robust variants; any other method is shown
# upper-cased.
_LEGEND_NAMES = {
    'DCCA_WITH_NOISE': 'NR-DCCA(ours)',
    'DGCCA_WITH_NOISE': 'NR-DGCCA(ours)',
}

for f, methods in enumerate(methods_s):
    for k, common in enumerate(commons):
        ax = axes_s[f][k]
        for i, method in enumerate(methods):
            t_methods = [method]
            if 'cca' in method and 'd' in method:
                # Deep models are plotted together with their variants.
                t_methods += [method + 'e', method + '_private', method + '_with_noise']
            for j, variant in enumerate(t_methods):
                # Use the first epoch's entry when a variant has no result
                # stored for a later epoch.
                y_mean = np.array([
                    pdf_data[epoch if variant in pdf_data[epoch] else epochs[0]][variant][common][0]
                    for epoch in epochs
                ])

                label = variant.upper()
                if ('CCA' in label and 'D' in label) or label in ('LINEAR_CCA', 'LINEAR_GCCA'):
                    label = _LEGEND_NAMES.get(label, label)
                    ax.plot(x, y_mean, label=label, color=colors[i+j], marker=marker[f+1], linewidth=5)
                else:
                    ax.plot(x, y_mean, label=label, color=colors[i+j], marker=marker[0], linestyle='--')

        ax.set_title('Common_rate: {}'.format(common))
        if k == 0:
            ax.legend(prop={'size': 24})
        ax.grid(linestyle=':', linewidth=2)
        ax.set_xticks([0,200,400,600,800,1000,1200])

# plt.tight_layout() would only lay out the *current* figure (fig_1). Lay out
# each figure explicitly, after its y-label is set so that the label's size is
# taken into account.
for _fig, _axes in zip(fig_s, axes_s):
    _axes[0].set_ylabel('R2_performance', fontsize=front)
    _fig.tight_layout()
    # Shared x caption below the subplots, centred on x=0.5.
    _fig.text(0.5, -0.05, 'Epoch', fontsize=front, ha='center')

fig_s[0].savefig('./draw_pictures/Syn_rate_cor_CCA.png', dpi=300, bbox_inches='tight')
fig_s[1].savefig('./draw_pictures/Syn_rate_cor_GCCA.png', dpi=300, bbox_inches='tight')

# A bare plt.close() only closes the current figure; close both explicitly.
for _fig in fig_s:
    plt.close(_fig)