import matplotlib.pyplot as plt
import numpy as np

import pdb

def parse_file(file_name):
    """Read the score row from one result CSV.

    Only the 8th line (index 7) of the file is used. It must be a
    comma-separated list of numbers; the j-th value is stored under the key
    '<j*20>%' (i.e. '0%', '20%', '40%', ...) wrapped in a one-element list.

    Args:
        file_name: Path of the CSV file to read.

    Returns:
        ``(True, task)`` on success, where ``task`` maps the percentage keys
        to ``[score]``. ``task`` is empty if the file has fewer than 8 lines.
        ``(False, None)`` if the file cannot be opened or decoded, or if the
        score row contains a field that is not a number.
    """
    task = {}
    try:
        with open(file_name, 'r') as f:
            for i, line in enumerate(f):
                if i != 7:
                    continue
                values = [float(x) for x in line.strip().split(',')]
                for j, value in enumerate(values):
                    task['{}%'.format(j * 20)] = [value]
                break  # nothing after the score row is used
    except (OSError, ValueError):
        # Missing/unreadable file, or a non-numeric field (UnicodeDecodeError
        # is also a ValueError subclass). Callers treat this as "no result".
        return False, None
    return True, task




# ---------------------------------------------------------------------------
# Experiment grid to collect and plot.
# ---------------------------------------------------------------------------
# Alpha values; each one appears as the "_a_<alpha>" part of a result file
# name and as the "α" in the plot legend.
# An earlier sweep used: ['0.0', '1.0', '1.5', '5.0', '15.0']
methods = ['0.0', '0.5', '1.0', '1.5']

# Dataset identifiers of the form "<name>_<num_view>_...".
dataset = ['CUB_2_0']  # 25 150

epochs = [100, 200, 300, 400, 500]
num_views = [2]  # NOTE(review): not referenced in the visible code.


# Nested results table: pdf_data[data][epoch][method][task_name] holds the
# one-element score list returned by parse_file for that CSV.
pdf_data = {}

# 'seaborn-poster' was renamed 'seaborn-v0_8-poster' in Matplotlib 3.6 and the
# old name was removed in 3.8; use whichever this installation provides.
_POSTER_STYLE = 'seaborn-v0_8-poster'
plt.style.use(_POSTER_STYLE if _POSTER_STYLE in plt.style.available else 'seaborn-poster')

for data in dataset:
    # e.g. "CUB_2_0" -> data_name "CUB", num_view "2"
    name_parts = data.split('_')
    data_name, num_view = name_parts[0], name_parts[1]

    for method in methods:
        for epoch in epochs:
            if data_name == 'PolyMnist':
                file_name = "./Polymnist_output/{}_{}_{}_{}.csv".format(method, True, num_view, epoch)
            elif method == '0.0':
                # The '0.0' run's result file has no "_zero" suffix.
                file_name = "./Cpm_output_re/{}_{}_{}_{}_{}_a_{}.csv".format(data, 'dcca_with_noise', True, num_view, epoch, method)
            else:
                file_name = "./Cpm_output_re/{}_{}_{}_{}_{}_a_{}_zero.csv".format(data, 'dcca_with_noise', True, num_view, epoch, method)

            flag, task = parse_file(file_name)
            if not flag:
                # Missing or unparsable run: report it and keep going.
                print('no', file_name)
                continue

            for task_name, value in task.items():
                pdf_data.setdefault(data, {}).setdefault(epoch, {}).setdefault(method, {})[task_name] = value
# ---------------------------------------------------------------------------
# Plot the '0%' score against epoch for every method of the first dataset
# and save the figure.
# ---------------------------------------------------------------------------
import os

colors = ['red', 'blue', 'green', 'purple', 'orange', 'black', 'brown', 'gold', 'grey']
marker = ['o', 's', '^', 'v', '<', '>', 'x', '+']

fig, axes = plt.subplots(1, 1, figsize=(36, 10), sharex=True)
x = np.array(epochs)
for i, method in enumerate(methods):
    # Runs whose CSV was missing were skipped during collection; plot them as
    # gaps (NaN) instead of crashing with a KeyError.
    performance = [
        pdf_data.get(dataset[0], {}).get(epoch, {}).get(method, {}).get('0%', [np.nan])[0]
        for epoch in epochs
    ]

    if method == '0.0':
        label = 'DCCA'
    else:
        label = 'NR-DCCA (zero) with α: {}'.format(method)
    # Wrap the palette so more methods than colors don't raise IndexError.
    axes.plot(x, performance, label=label, color=colors[i % len(colors)], marker=marker[4])

axes.legend(loc='lower left')
axes.set_ylabel('F1_performance', fontsize=24)
fig.text(0.5, -0.05, 'Epoch', fontsize=24)
plt.xticks(epochs)
axes.grid(linestyle=':', linewidth=1)

# savefig does not create missing directories.
os.makedirs('./draw_pictures', exist_ok=True)
plt.savefig('./draw_pictures/hyer_tune_in_cub (zero).png', dpi=300, bbox_inches='tight')

plt.close()