import os
import pdb

import matplotlib.pyplot as plt
import numpy as np

def parse_file(file_name, row=7):
    """Read one row of comma-separated floats from a result CSV file.

    Line ``row`` (0-based) is split on commas. The j-th field is stored under
    the key ``'{j*20}%'`` as a one-element list, e.g.
    ``{'0%': [v0], '20%': [v1], '40%': [v2], ...}``. Every other line is ignored.

    Args:
        file_name: Path of the CSV file to read.
        row: 0-based index of the line to parse (default 7, the line this
            script has always read).

    Returns:
        ``(True, task)`` on success, where ``task`` is the dict above; it is
        empty if the file has fewer than ``row + 1`` lines.
        ``(False, None)`` if the file cannot be opened or decoded, or if the
        row contains a field that is not a valid float.
    """
    task = {}
    try:
        with open(file_name, 'r') as f:
            for i, line in enumerate(f):
                if i != row:
                    continue
                values = [float(x) for x in line.strip().split(',')]
                for j, value in enumerate(values):
                    task['{}%'.format(j * 20)] = [value]
                break  # only one row is needed; don't read the rest of the file
    except (OSError, ValueError):
        # OSError: missing/unreadable file; ValueError: non-numeric field
        # (UnicodeDecodeError is also a ValueError subclass).
        return False, None
    return True, task




# Collect every (dataset, method, epoch) result into one nested dict:
#   pdf_data[dataset][epoch][method][task_name] -> value list from parse_file().

methods = ['concat','cca','dcca','dgcca','mvtcae']


dataset = ['PolyMnist_2','PolyMnist_3','PolyMnist_4','PolyMnist_5','CUB_2_0','Caltech101_3_0'] # 25 150


epochs = [100,200,300,400,500]
num_views =[2]


pdf_data = {}

# Matplotlib 3.6 renamed 'seaborn-poster' to 'seaborn-v0_8-poster' and 3.8
# removed the old name, so use whichever one this installation provides.
_POSTER_STYLE = ('seaborn-v0_8-poster'
                 if 'seaborn-v0_8-poster' in plt.style.available
                 else 'seaborn-poster')
plt.style.use(_POSTER_STYLE)

for data in dataset:
    # e.g. 'CUB_2_0' -> data_name 'CUB', num_view '2'
    data_name, num_view = data.split('_')[0], data.split('_')[1]
    for base_method in methods:
        t_methods = [base_method]
        # Methods containing both 'cca' and 'd' (here 'dcca', 'dgcca') also
        # have 'e', '_with_noise' and '_private' variant result files.
        if 'cca' in base_method and 'd' in base_method:
            t_methods += [base_method + 'e',
                          base_method + '_with_noise',
                          base_method + '_private']

        for method in t_methods:
            for epoch in epochs:
                if data_name == 'PolyMnist':
                    file_name = "./Polymnist_output/{}_{}_{}_{}.csv".format(method,True,num_view,epoch)
                else:
                    file_name = "./Cpm_output/{}_{}_{}_{}_{}.csv".format(data,method,True,num_view,epoch)

                flag, task = parse_file(file_name)
                if not flag:
                    print('no', file_name)
                    continue

                # Only create nested entries when there is something to store,
                # so the plotting code's "method not in ..." fallback still works.
                if task:
                    (pdf_data.setdefault(data, {})
                             .setdefault(epoch, {})
                             .setdefault(method, {})
                             .update(task))
colors = ['red','blue','green','purple','orange','black','brown','gold','grey']
marker = ['o','s','^','v','<','>','x','+']

# Legend labels for the noise-robust variants.
DISPLAY_NAMES = {
    'dcca_with_noise': 'NR-dcca (ours)',
    'dgcca_with_noise': 'NR-dgcca (ours)',
}


def _curve(per_epoch, method, epoch_list):
    """Return the '0%' score of `method` for every epoch in `epoch_list`.

    If a value is missing for an epoch, the first epoch's value is used, then
    the last epoch's. Returns None if no value can be found for some epoch.

    Args:
        per_epoch: ``pdf_data[dataset]``, i.e. ``{epoch: {method: {task: [value]}}}``.
        method: Method name to look up.
        epoch_list: Epochs to produce values for, in order.
    """
    values = []
    for epoch in epoch_list:
        for candidate in (epoch, epoch_list[0], epoch_list[-1]):
            try:
                values.append(per_epoch[candidate][method]['0%'][0])
                break
            except (KeyError, IndexError):
                continue
        else:
            return None
    return values


fig, axes = plt.subplots(2, len(dataset),figsize=(25, 10),sharex=True, sharey='col')
methods_s = [['concat','cca','kcca','dcca','mvtcae'],['concat','cca','kcca','dgcca','mvtcae']]
x = np.array(epochs)
for f, row_methods in enumerate(methods_s):
    for k, data in enumerate(dataset):
        ax = axes[f][k]
        data_name, num_view = data.split('_')[0], data.split('_')[1]
        per_epoch = pdf_data.get(data, {})
        for i, base_method in enumerate(row_methods):
            t_methods = [base_method]
            if 'cca' in base_method and 'd' in base_method:
                t_methods += [base_method + 'e',
                              base_method + '_private',
                              base_method + '_with_noise']
            for j, method in enumerate(t_methods):
                # 'kcca' is never plotted, but it still occupies index i in
                # row_methods, which shifts the colours of later entries.
                if method == 'kcca':
                    continue
                y_mean = _curve(per_epoch, method, epochs)
                if y_mean is None:
                    print('missing', data, method)
                    continue
                y_mean = np.array(y_mean)

                if 'cca' in method and 'd' in method:
                    ax.plot(x, y_mean, label=DISPLAY_NAMES.get(method, method),
                            color=colors[i + j], marker=marker[f + 1])
                elif method == 'mvtcae':
                    ax.plot(x, y_mean, label=method,
                            color=colors[i + 4], marker=marker[f + 1])
                else:
                    ax.plot(x, y_mean, label=method,
                            color=colors[i + j], marker=marker[0], linestyle='--')
        if f == 0:
            ax.set_title('{}({})'.format(data_name, num_view))
        if k == 0:
            ax.legend()
        ax.grid(linestyle=':', linewidth=1)

axes[0][0].set_ylabel('R2_performance',fontsize=24)

# ha='center' centres the label on x=0.5 (the default 'left' starts it there).
fig.text(0.5, -0.05, 'Epoch', fontsize=24, ha='center')
plt.xticks(epochs, epochs)

plt.tight_layout()

# savefig does not create missing directories.
os.makedirs('./draw_pictures', exist_ok=True)
plt.savefig('./draw_pictures/Real_F1.png', dpi=300, bbox_inches='tight')

plt.close()