import matplotlib.pyplot as plt
import numpy as np

import pdb

def parse_file(file_name):
    """Parse one result CSV into per-column metric lists.

    Only four 0-based line indices of the file are read: 7, 9, 12 and 17.
    Each of them is a comma-separated row of floats whose ``j``-th column is
    stored under the key ``'{j*20}%'`` (``'0%'``, ``'20%'``, ...).  For every
    key the list holds, in file order, the value from line 7, the value from
    line 9, and the absolute values from lines 12 and 17.  Rows that a short
    file does not contain are simply not appended, so a list may have fewer
    than four entries.

    Args:
        file_name: Path of the CSV file to read.

    Returns:
        ``(True, task)`` when the file was read, where ``task`` is the dict
        described above (empty if the file has fewer than eight lines);
        ``(False, None)`` when the file cannot be opened, a selected row
        contains a non-numeric field, or a later row has more columns than
        line 7.
    """
    first_row = 7                    # this row creates the per-column lists
    wanted_rows = {7, 9, 12, 17}     # 0-based line indices that carry data
    abs_rows = {12, 17}              # rows stored as absolute values

    task = {}
    try:
        with open(file_name, 'r') as f:
            for i, line in enumerate(f):
                if i not in wanted_rows:
                    continue
                values = [float(x) for x in line.strip().split(',')]
                for j, value in enumerate(values):
                    key = '{}%'.format(j * 20)
                    if i == first_row:
                        task[key] = [value]
                    else:
                        # KeyError here means this row is wider than line 7.
                        task[key].append(abs(value) if i in abs_rows else value)
    except (OSError, ValueError, KeyError):
        # Missing/unreadable file or malformed row: report failure to the
        # caller instead of hiding unrelated errors behind a bare except.
        return False, None
    return True, task




# Experiment grid: every (method variant, dataset flag, view count, epoch)
# combination names one result file under ./Syn_output/.
methods = ['concat', 'cca', 'kcca', 'dcca', 'dgcca']

dataset = ['True']  # 25 150

# Earlier epoch grids were dead assignments (immediately overwritten); only
# the grid that was actually in effect is kept.
epochs = [50, 100, 200, 400, 600, 800, 1000, 1200]
num_views = [2]

# Matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*' and
# later releases dropped the old aliases, so prefer the new name when the
# installed version provides it and fall back to the legacy name otherwise.
_poster_style = 'seaborn-v0_8-poster'
if _poster_style not in plt.style.available:
    _poster_style = 'seaborn-poster'
plt.style.use(_poster_style)

# pdf_data[epoch][method_variant][key] -> metric list returned by parse_file,
# where key is one of the '0%', '20%', ... column labels.
pdf_data = {}

for num_view in num_views:
    for data in dataset:
        for base_method in methods:
            # Methods whose name contains both 'cca' and 'd' (dcca, dgcca)
            # additionally have three suffixed variants on disk.
            variants = [base_method]
            if 'cca' in base_method and 'd' in base_method:
                variants += [
                    base_method + 'e',
                    base_method + '_with_noise',
                    base_method + '_private',
                ]

            for variant in variants:
                for epoch in epochs:
                    file_name = "./Syn_output/{}_{}_{}_{}.csv".format(variant, data, num_view, epoch)

                    flag, task = parse_file(file_name)
                    if not flag:
                        # Missing or unparsable result file: report and skip.
                        print('no', file_name)
                        continue

                    # Only register the (epoch, variant) pair when the file
                    # actually yielded data, so empty files leave no entry.
                    if task:
                        pdf_data.setdefault(epoch, {}).setdefault(variant, {}).update(task)

colors = ['red', 'blue', 'green', 'purple', 'orange', 'black', 'brown', 'gold']
marker = ['o', 's', '^', 'v', '<', '>', 'x', '+']

# Figure 1: a 2 x len(epochs) grid. Each row is one method family (the dcca
# variants on top, the dgcca variants below) and each column one epoch; every
# panel plots the first metric entry of each method against the '0%'..'N%'
# keys of pdf_data.
epochs = [50, 100, 200, 400, 1200]
methods_s = [['concat', 'cca', 'kcca', 'dcca'], ['concat', 'cca', 'kcca', 'dgcca']]
fig, axes = plt.subplots(2, len(epochs), figsize=(25, 10), sharex=True, sharey=True)

# Legend names that differ from the on-disk method names.
display_names = {
    'dcca_with_noise': 'NR-dcca (ours)',
    'dgcca_with_noise': 'NR-dgcca (ours)',
}

for f, methods in enumerate(methods_s):
    for k, epoch in enumerate(epochs):
        panel = axes[f][k]
        for i, base_method in enumerate(methods):
            variants = [base_method]
            if 'cca' in base_method and 'd' in base_method:
                variants.extend(
                    base_method + suffix for suffix in ('e', '_private', '_with_noise')
                )

            for j, name in enumerate(variants):
                # Methods without an entry at this epoch reuse the data
                # stored under the first epoch of the list.
                source_epoch = epoch if name in pdf_data[epoch] else epochs[0]
                per_key = pdf_data[source_epoch][name]
                x = list(per_key.keys())
                y = list(per_key.values())
                y_mean = np.array([d[0] for d in y])
                y_std = np.array([d[1] for d in y])  # extracted but not drawn here

                if 'cca' in name and 'd' in name:
                    line_style = {'marker': marker[f + 1]}
                else:
                    line_style = {'marker': marker[0], 'linestyle': '--'}
                panel.plot(
                    x,
                    y_mean,
                    label=display_names.get(name, name),
                    color=colors[i + j],
                    **line_style
                )

        # Titles only on the top row, legend only in the first column.
        if f == 0:
            panel.set_title('Epoch: {}'.format(epoch))
        if k == 0:
            panel.legend()
        panel.grid(linestyle=':', linewidth=1)

plt.tight_layout()

axes[0][0].set_ylabel('R2_performance', fontsize=24)

# Shared x-axis caption placed just below the figure area.
fig.text(0.5, -0.05, 'Common_rate', fontsize=24)

plt.savefig('./draw_pictures/Syn_rate_r2.png', dpi=300, bbox_inches='tight')

plt.close()

#epochs = [50,100,200]
# Figure 2: one panel showing, for a hand-picked subset of methods, the mean
# (line) and standard deviation (shaded band) of the first metric entry taken
# across all '0%'..'N%' keys, as a function of the epoch.
fig, axes = plt.subplots(1, 1, figsize=(20, 10), sharex=True, sharey=True)
methods_s = [['concat', 'cca', 'kcca', 'dcca']]
epochs = [50, 100, 200, 400, 600, 800, 1000, 1200]

selected_methods = ['cca', 'concat', 'dcca', 'dcca_with_noise']
display_names = {
    'dcca_with_noise': 'NR-dcca (ours)',
    'dgcca_with_noise': 'NR-dgcca (ours)',
}

for f, methods in enumerate(methods_s):
    for i, base_method in enumerate(methods):
        variants = [base_method]
        if 'cca' in base_method and 'd' in base_method:
            variants.extend(
                base_method + suffix for suffix in ('e', '_private', '_with_noise')
            )

        for j, name in enumerate(variants):
            if name not in selected_methods:
                continue

            spreads = []   # max-min range per epoch; collected, not drawn
            stds = []
            means = []
            for epoch in epochs:
                # Methods without an entry at this epoch reuse the data
                # stored under the first epoch of the list.
                source_epoch = epoch if name in pdf_data[epoch] else epochs[0]
                y = list(pdf_data[source_epoch][name].values())
                y_mean = np.array([d[0] for d in y])
                y_std = np.array([d[1] for d in y])  # extracted but not drawn here
                stds.append(y_mean.std(0))
                spreads.append(abs(max(y_mean) - min(y_mean)))
                means.append(y_mean.mean(0))

            centre = np.array(means)
            band = np.array(stds)
            x = np.array(epochs)

            if 'cca' in name and 'd' in name:
                line_style = {'marker': marker[f + 1]}
            else:
                line_style = {'marker': marker[0], 'linestyle': '--'}
            axes.plot(
                x,
                centre,
                label=display_names.get(name, name),
                color=colors[i + j],
                **line_style
            )
            axes.fill_between(x, centre - band, centre + band, color=colors[i + j], alpha=0.2)

    axes.legend(loc='lower left')

plt.tight_layout()

axes.set_ylabel('R2_performance', fontsize=24)

# Shared x-axis caption placed just below the figure area.
fig.text(0.5, -0.05, 'Epoch', fontsize=24)

plt.grid(linestyle=':', linewidth=1)
plt.savefig('./draw_pictures/Syn_rate_mean_std_select.png', dpi=300, bbox_inches='tight')

plt.close()


# Figure 3: two stacked panels (dcca family on top, dgcca family below).
# For every method variant the line is the mean, and the shaded band the
# standard deviation, of the first metric entry taken across all
# '0%'..'N%' keys, plotted against the epoch.
fig, axes = plt.subplots(2, 1, figsize=(20, 15), sharex=True, sharey=True)
methods_s = [['concat', 'cca', 'kcca', 'dcca'], ['concat', 'cca', 'kcca', 'dgcca']]
epochs = [50, 100, 200, 400, 600, 800, 1000, 1200]

# Legend names that differ from the on-disk method names.
display_names = {
    'dcca_with_noise': 'NR-dcca (ours)',
    'dgcca_with_noise': 'NR-dgcca (ours)',
}

for f, methods in enumerate(methods_s):
    ax = axes[f]

    for i, base_method in enumerate(methods):
        t_methods = [base_method]
        if 'cca' in base_method and 'd' in base_method:
            t_methods += [
                base_method + 'e',
                base_method + '_private',
                base_method + '_with_noise',
            ]

        for j, method in enumerate(t_methods):
            y_mean_mean = []
            y_mean_std = []
            for epoch in epochs:
                # Methods without an entry at this epoch reuse the data
                # stored under the first epoch of the list.
                source_epoch = epoch if method in pdf_data[epoch] else epochs[0]
                y_mean = np.array([d[0] for d in pdf_data[source_epoch][method].values()])
                y_mean_std.append(y_mean.std(0))
                y_mean_mean.append(y_mean.mean(0))
            y_mean_mean = np.array(y_mean_mean)
            y_mean_std = np.array(y_mean_std)

            x = np.array(epochs)
            if 'cca' in method and 'd' in method:
                line_style = {'marker': marker[f + 1]}
            else:
                line_style = {'marker': marker[0], 'linestyle': '--'}
            ax.plot(
                x,
                y_mean_mean,
                label=display_names.get(method, method),
                color=colors[i + j],
                **line_style
            )
            ax.fill_between(
                x,
                y_mean_mean - y_mean_std,
                y_mean_mean + y_mean_std,
                color=colors[i + j],
                alpha=0.2,
            )

    ax.legend(loc='lower left')
    # Grid is set per panel: a single plt.grid() call only affects the
    # current axes, which left the top panel without grid lines.
    ax.grid(linestyle=':', linewidth=1)

plt.tight_layout()

axes[0].set_ylabel('R2_performance', fontsize=24)

# Shared x-axis caption placed just below the figure area.
fig.text(0.5, -0.05, 'Epoch', fontsize=24)

plt.savefig('./draw_pictures/Syn_rate_mean_std.png', dpi=300, bbox_inches='tight')

plt.close()



# Figure 4: a 2 x len(commons) grid. Rows are the dcca / dgcca families,
# columns the selected '0%'/'60%'/'100%' keys; each panel plots the fourth
# metric entry (index 3) of every deep variant against the epoch.
commons = ['0%', '60%', '100%']
fig, axes = plt.subplots(2, len(commons), figsize=(12, 6), sharex=True, sharey='row')
methods_s = [['concat', 'cca', 'kcca', 'dcca'], ['concat', 'cca', 'kcca', 'dgcca']]
epochs = [50, 100, 200, 400, 600, 800, 1000, 1200]

# Legend names that differ from the on-disk method names.
display_names = {
    'dcca_with_noise': 'NR-dcca (ours)',
    'dgcca_with_noise': 'NR-dgcca (ours)',
}

for f, methods in enumerate(methods_s):
    for k, common in enumerate(commons):
        panel = axes[f][k]
        for i, base_method in enumerate(methods):
            # Only methods whose name contains both 'cca' and 'd' are drawn
            # in this figure; the other baselines are skipped entirely.
            if not ('cca' in base_method and 'd' in base_method):
                continue

            variants = [base_method] + [
                base_method + suffix for suffix in ('e', '_private', '_with_noise')
            ]
            for j, name in enumerate(variants):
                # Variants without an entry at an epoch reuse the data
                # stored under the first epoch of the list.
                y_mean = np.array([
                    pdf_data[epoch if name in pdf_data[epoch] else epochs[0]][name][common][3]
                    for epoch in epochs
                ])
                x = np.array(epochs)
                panel.plot(
                    x,
                    y_mean,
                    label=display_names.get(name, name),
                    color=colors[i + j],
                    marker=marker[f + 1],
                )

        # Titles only on the top row, legend only in the first column.
        if f == 0:
            panel.set_title('Common_rate: {}'.format(common))
        if k == 0:
            panel.legend()
        panel.grid(linestyle=':', linewidth=1)

plt.tight_layout()

# Shared x-axis caption placed just below the figure area.
fig.text(0.5, -0.05, 'Epoch', fontsize=20)

ylabel = axes[0][0].set_ylabel('Corr(cca)', fontsize=20)
ylabel = axes[1][0].set_ylabel('Corr(gcca)', fontsize=20)

plt.savefig('./draw_pictures/Syn_rate_cor.png', dpi=300, bbox_inches='tight')


