from matplotlib import pyplot as plt
import torch
import os
from utils.loadcolor import loadcolor
from matplotlib.ticker import ScalarFormatter

# Line style, colour key (looked up in the dict returned by ``loadcolor()``)
# and legend label of each plotted series, in the order the series are stored
# in ``plotf`` / ``ploterr``.
_SERIES_STYLES = (
    ('-', 'purple', 'J-JOBCD'),
    ('-', 'green', 'VR-J-JOBCD'),
    ('-', 'blue', 'CSDM'),
    ('-.', 'gray', 'ADMM'),
    ('--', 'red', 'UMCM'),
)


def _prepare_series(values, errors, plott, f0, errbound):
    """Return CPU tensors ``(t, f)`` for one curve, ready to be plotted.

    ``f`` is ``values - f0`` rounded to two decimals, with every entry whose
    matching error exceeds ``errbound`` replaced by the first entry of ``f``.
    ``t`` is the leading ``len(f)`` entries of ``plott``.
    """
    hist_f = torch.round((values - f0).detach().to('cpu') * 100) / 100
    hist_t = plott[:len(hist_f)].detach().to('cpu')

    # The mask must live on the same device as ``hist_f`` (CPU): a CPU tensor
    # cannot be indexed with a CUDA mask. ``as_tensor`` also accepts
    # array-like error histories.
    mask = torch.as_tensor(errors).detach().squeeze().to('cpu') > errbound
    hist_f[mask] = hist_f[0].clone()
    return hist_t, hist_f


def getfige(plott, plotf, ploterr, f0, config_yaml, logpath, yname, errbound=5e-4):
    """Plot the five objective curves against ``plott`` and save the figure.

    Each curve ``plotf[i] - f0`` is rounded to two decimals; points whose
    entry in ``ploterr[i]`` is larger than ``errbound`` are replaced by the
    curve's first value before plotting.  The figure is written to
    ``<logpath>/<yname>e<savename>.png`` and ``.eps`` and then shown.

    Args:
        plott: Tensor of x-axis values (the axis is labelled "Epoch"); each
            curve uses its first ``len(curve)`` entries.
        plotf: Indexable holding at least five tensors of y-values, ordered
            as in ``_SERIES_STYLES``.  Assumed to be 1-D -- TODO confirm
            against the caller.
        ploterr: Indexable holding at least five error histories; each is
            squeezed and presumably then matches the shape of the
            corresponding ``plotf`` entry -- TODO confirm.
        f0: Reference value subtracted from every curve.
        config_yaml: Mapping; ``config_yaml["log"]["savename"]`` is used as
            the file-name suffix.
        logpath: Output directory (created if it does not exist).
        yname: Y-axis label, also used as the file-name prefix.
        errbound: Error threshold above which a point is masked.

    Returns:
        None.
    """
    savename = config_yaml["log"]["savename"]

    curves = [
        _prepare_series(plotf[idx], ploterr[idx], plott, f0, errbound)
        for idx in range(len(_SERIES_STYLES))
    ]

    pcolor = loadcolor()
    fig = plt.figure(figsize=(8, 6), dpi=80, facecolor='w')
    fig.subplots_adjust(left=0.15, right=0.9, bottom=0.15, top=0.9)
    ax = plt.gca()

    for (hist_t, hist_f), (linestyle, color_key, label) in zip(curves, _SERIES_STYLES):
        ax.plot(hist_t.numpy(), hist_f.numpy(), linestyle, linewidth=5, markersize=3,
                color=pcolor[color_key].numpy(), label=label)

    # Scientific notation with math-text exponent on the y axis, no offset.
    formatter = ScalarFormatter(useMathText=True)
    ax.yaxis.set_major_formatter(formatter)
    ax.ticklabel_format(style='sci', axis='y', scilimits=(0, 0), useOffset=False,
                        useLocale=False, useMathText=True)
    ax.yaxis.get_offset_text().set_size('x-small')
    formatter.set_powerlimits((0, 1))
    formatter.set_scientific(True)
    formatter.set_useOffset(False)
    ax.tick_params(axis='y', which='major', labelsize=10)

    hleg = ax.legend(fontsize=15, frameon=True, framealpha=1)
    hleg.set_title('', prop={'size': 15, 'weight': 'normal', 'family': 'times new Roman'})
    hleg.get_title().set_fontsize('15')

    ax.set_xlabel('Epoch', fontsize=20)
    ax.set_ylabel(yname, fontsize=20)
    plt.xticks(fontsize=17)
    plt.yticks(fontsize=17)

    ax.grid(True)

    # savefig does not create missing directories.
    if logpath:
        os.makedirs(logpath, exist_ok=True)
    fig.savefig(os.path.join(logpath, yname + 'e{}.png'.format(savename)), format='png', dpi=300)
    fig.savefig(os.path.join(logpath, yname + 'e{}.eps'.format(savename)), format='eps', dpi=300)

    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    # In interactive mode show() returns immediately, so keep the window alive.
    if not plt.isinteractive():
        plt.close(fig)
