import warnings
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, TensorDataset
import gc
import PIL.Image as Image
# import cv2
import os
import yaml
warnings.filterwarnings("ignore")
matplotlib.use('Agg')

# Axes placement in figure-fraction units, used by save_fig (isax == 1) as the
# [left, bottom, width, height] rectangle for ax.set_position: 0.18 margin on
# the left and bottom, the right edge at 0.88 and the top edge at 0.9.
Leftp = Bottomp = 0.18
Widthp, Heightp = 0.88 - Leftp, 0.9 - Bottomp
pos = [Leftp, Bottomp, Widthp, Heightp]


def save_fig(pltm, fntmp, fp=0, ax=0, isax=0, iseps=0, isShowPic=0):
    """Save the current figure to ``<fntmp>.png``, plus optional EPS/PDF copies.

    Args:
        pltm: Plotting interface providing ``rc``, ``savefig``, ``show`` and
            ``close``; callers in this file pass ``matplotlib.pyplot``.
        fntmp: Output path without extension.
        fp: When not 0, an object whose ``savefig`` writes ``<fntmp>.pdf``
            with a tight bounding box (presumably a matplotlib Figure).
        ax: Axes moved to the module-level ``pos`` rectangle when ``isax == 1``.
        isax: When 1, set tick label sizes (x: 18, y: 10) and reposition ``ax``.
        iseps: When truthy, also write ``<fntmp>.eps`` at 600 dpi.
        isShowPic: 1 shows the figure, -1 leaves it open for the caller, and
            any other value closes it.
    """
    if isax == 1:
        # pltm.rc edits the global rcParams, so these sizes also apply to
        # later figures. NOTE(review): whether they reach the figure being
        # saved depends on when its ticks were created; confirm.
        pltm.rc('xtick', labelsize=18)
        pltm.rc('ytick', labelsize=10)
        ax.set_position(pos, which='both')

    pltm.savefig('%s.png' % (fntmp))
    if iseps:
        pltm.savefig('%s.eps' % (fntmp), format='eps', dpi=600)
    if fp != 0:
        fp.savefig("%s.pdf" % (fntmp), bbox_inches='tight')

    if isShowPic == -1:
        return  # the caller keeps working with the still-open figure
    if isShowPic == 1:
        pltm.show()
    else:
        pltm.close()


def plot_loss(path, R, x_log=False):
    """Plot the training-loss history ``R['loss_train']`` and save it.

    The y axis is always log-scaled. The x axis is log-scaled only when
    ``x_log`` is truthy. The figure is saved through ``save_fig`` as
    ``<path>loss_log.png`` when ``x_log == False`` and as ``<path>loss.png``
    otherwise.
    """
    plt.figure()
    ax = plt.gca()
    plt.plot(np.asarray(R['loss_train']), 'k-', label='Train')
    if x_log:
        ax.set_xscale('log')
    ax.set_yscale('log')
    plt.title('loss', fontsize=15)
    # NOTE(review): the '_log' suffix goes on the semi-log plot and is dropped
    # for the log-log one, which looks inverted. Downstream scripts may rely on
    # these names, so confirm before changing. '== False' (not 'not x_log')
    # keeps the filename choice for non-bool values such as None.
    name = 'loss_log' if x_log == False else 'loss'  # noqa: E712
    save_fig(plt, '%s%s' % (path, name), ax=ax, isax=1, iseps=0)


def plot_model_output(path, args, argsy, epoch):
    """Plot the latest test prediction against the training data (1-D inputs only).

    Does nothing unless ``args.input_dim == 1``. ``args.train_inputs``,
    ``args.train_targets`` and ``args.test_inputs`` must support
    ``.detach().cpu().numpy()`` (presumably torch tensors). The last entry of
    ``argsy['test_outputs']`` is drawn as the test curve. The figure is saved
    through ``save_fig`` as ``<path>output/<epoch>.png``. The ``output/``
    subdirectory is not created here.
    """
    if args.input_dim != 1:
        return

    def as_array(tensor):
        # Move off the autograd graph and the device before handing to pyplot.
        return tensor.detach().cpu().numpy()

    plt.figure()
    ax = plt.gca()
    plt.plot(as_array(args.train_inputs), as_array(args.train_targets),
             'b*', label='True')
    plt.plot(as_array(args.test_inputs), argsy['test_outputs'][-1],
             'r-', label='Test')
    plt.title('g2u', fontsize=15)
    plt.legend(fontsize=18)
    save_fig(plt, '%soutput/%s' % (path, epoch), ax=ax, isax=1, iseps=0)


def plot_eig_vs_var(path, var, eig, epoch, y_log=True, x_log=True):
    """Scatter |Hessian eigenvalue| against |variance| and save two figures.

    Args:
        path: Filename prefix (a directory prefix should end with a separator).
        var: Variance values, paired element-wise with ``eig``.
        eig: Eigenvalues; only their absolute values are used.
        epoch: Appended to both output filenames.
        y_log, x_log: Log-scale that axis of the first figure.

    Outputs (written through ``save_fig``):
        ``<path>eig_vs_varlog<epoch>.png``: optional log axes with fixed limits
            x in [1e-5, 10] and y in [1e-30, 10].
        ``<path>eig_vs_var<epoch>.png``: linear axes. The title reports the
            slope of a least-squares line fitted in log10-log10 space to (up
            to) the four points with the largest |eigenvalue|.

    The fit uses |var|, the same values that are plotted, so negative entries
    cannot turn the slope into NaN. Points whose logarithm is not finite (zero
    eigenvalue or variance) are skipped. The slope is shown as 'nan' when
    fewer than two usable points remain.
    """
    abs_eig = np.abs(np.asarray(eig))
    abs_var = np.abs(np.asarray(var))

    # Figure 1: (by default) log-log view with fixed axis limits.
    plt.figure()
    ax = plt.gca()
    plt.scatter(abs_eig, abs_var)
    if y_log:
        ax.set_yscale('log')
    ax.set_ylim((1e-30, 10))
    if x_log:
        ax.set_xscale('log')
    ax.set_xlim((1e-5, 10))
    plt.xlabel('eigenvalue for hessian')
    plt.ylabel('variance')
    plt.title('eigenvalue v.s. variance', fontsize=15)
    # NOTE(review): no artist carries a label, so this legend is empty.
    plt.legend(fontsize=18)
    save_fig(plt, '%seig_vs_varlog%s' % (path, epoch), ax=ax, isax=1, iseps=0)

    # Figure 2: linear axes, titled with the log-log slope over the top |eig|.
    plt.figure()
    ax = plt.gca()
    plt.scatter(abs_eig, abs_var)
    with np.errstate(divide='ignore', invalid='ignore'):
        eig_log = np.log10(abs_eig)
        var_log = np.log10(abs_var)
    # Only finite log values can enter the least-squares fit.
    usable = np.flatnonzero(np.isfinite(eig_log) & np.isfinite(var_log))
    top = usable[np.argsort(eig_log[usable])[::-1][:4]]
    if top.size >= 2:
        slope = np.polyfit(eig_log[top], var_log[top], 1)[0]
    else:
        slope = np.nan  # a line needs at least two points
    plt.xlabel('eigenvalue for hessian')
    plt.ylabel('variance')
    plt.title('eigenvalue v.s. variance %.3f' % (slope), fontsize=15)
    plt.legend(fontsize=18)
    save_fig(plt, '%seig_vs_var%s' % (path, epoch), ax=ax, isax=1, iseps=0)


def plot_eig_vs_mean(path, mean, eig, epoch, y_log=True, x_log=True):
    """Scatter Hessian eigenvalues against ``mean`` and save two figures.

    * ``<path>eig_vs_meanlog<epoch>.png``: x is ``abs(eig)``, with optional
      log scales (``x_log`` / ``y_log``) and fixed limits x in [1e-5, 10],
      y in [1e-30, 10].
    * ``<path>eig_vs_mean<epoch>.png``: x is the raw ``eig`` (no abs), linear
      axes, no limits.
    """

    def finish(axes, stem):
        # Labels and title shared by both figures, then hand off to save_fig.
        plt.xlabel('eigenvalue for hessian')
        plt.ylabel('mean')
        plt.title('eigenvalue v.s. mean', fontsize=15)
        plt.legend(fontsize=18)  # no labelled artists, so this legend is empty
        save_fig(plt, '%s%s%s' % (path, stem, epoch), ax=axes, isax=1, iseps=0)

    plt.figure()
    log_ax = plt.gca()
    plt.scatter(abs(eig), mean)
    if y_log:
        log_ax.set_yscale('log')
    log_ax.set_ylim((1e-30, 10))
    if x_log:
        log_ax.set_xscale('log')
    log_ax.set_xlim((1e-5, 10))
    finish(log_ax, 'eig_vs_meanlog')

    plt.figure()
    lin_ax = plt.gca()
    plt.scatter(eig, mean)
    finish(lin_ax, 'eig_vs_mean')


def plot_ori_A_trajectory(path, m, k, ori, A):
    """Draw polar trajectories up to step ``k`` and save them as ``<path><k>.png``.

    Row ``i`` of ``ori`` holds angles and row ``i`` of ``A`` holds radii for
    trajectory ``i``. The first ``m`` rows are drawn over columns ``0..k`` as
    thin cyan lines. Column ``k`` of every row is overlaid as red dots on top.
    Radii are plotted as ``A ** 0.1``, which compresses their range. This
    assumes ``A`` is non-negative; negative float entries would become NaN.
    The radial axis is fixed to [0, 1.1].

    The image is written under ``path`` (the same ``'%s%s.png' % (path, k)``
    naming as ``plot_loss_one``) rather than to a hard-coded absolute
    directory that ignored the ``path`` argument.
    """
    fig = plt.figure()
    ax1 = plt.subplot(111, projection='polar')
    ax1.set_ylim(0, 1.1)
    for i in range(m):
        ax1.plot(ori[i, :k + 1], A[i, :k + 1] ** 0.1,
                 '-', lw=0.5, color='cyan', zorder=1)
    ax1.scatter(ori[:, k], A[:, k] ** 0.1, color='r', zorder=2, s=10)
    fig.savefig('%s%s.png' % (path, k))
    # Release the figure explicitly; this is presumably called once per frame
    # in a loop, so figures must not accumulate.
    fig.clf()
    plt.close(fig)
    gc.collect()


def plot_loss_one(path, loss, k, x_log=False):
    """Plot a loss curve with step ``k`` highlighted and save ``<path><k>.png``.

    The full curve is drawn in black and ``loss[k]`` as a blue dot. The y axis
    is always log-scaled; the x axis only when ``x_log`` is truthy. The
    filename does not depend on ``x_log``, so rendering the same ``k`` with
    both settings overwrites one file. The figure is closed afterwards.
    """
    plt.figure()
    ax = plt.gca()
    plt.plot(np.asarray(loss), 'k-', label='Train')
    plt.plot(k, loss[k], 'bo')
    if x_log:
        ax.set_xscale('log')
    ax.set_yscale('log')
    plt.title('loss', fontsize=15)
    plt.savefig('%s%s.png' % (path, k))
    plt.clf()
    plt.close()
    gc.collect()


def concen_pic(save_path, image_column, image_row, path1, path2, i, weigh=640, height=480):
    """Stack two same-index PNGs vertically on one canvas and save it.

    Opens ``<path1><i>.png`` and ``<path2><i>.png``. It pastes the first at the
    top-left corner and the second directly below it (y offset ``height``),
    then saves the canvas as ``<save_path><i>.png``. Any canvas area not
    covered by the two images keeps ``Image.new``'s default fill.

    Args:
        image_column, image_row: Canvas size in tiles. The canvas is
            ``image_column * weigh`` by ``image_row * height`` pixels.
        weigh: Tile width in pixels (presumably a misspelling of "width";
            kept for keyword-argument compatibility).
        height: Tile height in pixels.

    Returns:
        Whatever ``Image.save`` returns (``None`` in Pillow).
    """
    # Create a new blank RGB canvas.
    to_image = Image.new('RGB', (image_column * weigh, image_row * height))
    # Paste each source image at its position. The context managers close the
    # source files once they have been pasted.
    with Image.open('%s%s.png' % (path1, i)) as top, \
            Image.open('%s%s.png' % (path2, i)) as bottom:
        to_image.paste(top, (0, 0))
        to_image.paste(bottom, (0, height))
    # Save the combined image.
    return to_image.save("%s%s.png" % (save_path, i))


# def images_to_video(save_path, video_folder, rep=5, result_filename=None):

    # if result_filename is None:
    #     result_filename = "{}.avi".format(save_path)
    # images_name = {int(os.path.splitext(f)[0]): os.path.join(
    #     video_folder, f) for f in os.listdir(video_folder)}
    # img = cv2.imread(images_name[0])
    # height, width, layers = img.shape
    # print(height)
    # print(width)
    # four_cc = cv2.VideoWriter_fourcc(*"XVID")  # avi
    # video = cv2.VideoWriter(result_filename, four_cc, 25, (width, height))
    # print(len(images_name))
    # for i in range( len(images_name)):
    #     for j in range(rep):
    #         img = cv2.imread(images_name[int(5*i)])
    #         video.write(img)
    #     if i % 100 == 0:
    #         print("Done {}%".format((i*100)/len(images_name)))
    # cv2.destroyAllWindows()
    # video.release()
    # print("Done!")
    # return None


def _describe_sigma_run(d):
    """Return ``(kind, legend_label)`` for one run's config dict ``d``.

    Regimes are checked in this order and the first match wins:
    training_size > training_batch_size ('SGD'), ``dropout``,
    ``add_tru_on_weight``, ``add_tru_on_grad``, otherwise 'GD'.
    """
    if int(d['training_size']) > int(d['training_batch_size']):
        return 'SGD', 'batch size:%s, lr:%s' % (d['training_batch_size'], d['lr'])
    if d['dropout']:
        return 'dropout', 'dropout proportion:%s, lr:%s' % (d['dropout_pro'], d['lr'])
    if d['add_tru_on_weight']:
        return 'add tru on weight', 'turblence:%s, lr:%s' % (d['turblence'], d['lr'])
    if d['add_tru_on_grad']:
        return 'add tru on grad', 'turblence:%s, lr:%s' % (d['turblence'], d['lr'])
    return 'GD', 'training size:%s, lr:%s' % (d['training_size'], d['lr'])


def plot_sigma_F(path):
    """Overlay every run's |sigma| / 10000 against 10**theta_nage + 10**theta_posi.

    Each entry of ``path`` whose name does not start with 'p' is treated as a
    run directory holding ``code/config/config.yaml``, ``sigma.txt``,
    ``theta_nage.txt`` and ``theta_posi.txt``. ``path`` must end with a path
    separator. Runs are plotted in sorted order on log-log axes, labelled
    according to their training regime. The title describes the last run
    plotted. The figure is saved as ``<path>pic.png`` and then closed.

    Raises:
        ValueError: if no run directory is found under ``path``.
    """
    # Entries starting with 'p' are skipped (presumably this function's own
    # 'pic.png' output; confirm). Sorting makes legend order and colours
    # reproducible regardless of filesystem listing order.
    run_names = sorted(name for name in os.listdir(path) if not name.startswith('p'))
    if not run_names:
        raise ValueError('plot_sigma_F: no run directories found in %r' % (path,))

    plt.figure()
    for name in run_names:
        run_dir = '%s%s' % (path, name)
        # NOTE(review): yaml.FullLoader is only appropriate for trusted files;
        # prefer yaml.safe_load if these configs can come from elsewhere.
        with open('%s/code/config/config.yaml' % run_dir, 'r', encoding='utf-8') as f:
            d = yaml.load(f.read(), Loader=yaml.FullLoader)
        sigma = np.loadtxt('%s/sigma.txt' % run_dir)
        theta_nage = np.loadtxt('%s/theta_nage.txt' % run_dir)
        theta_posi = np.loadtxt('%s/theta_posi.txt' % run_dir)
        kind, label = _describe_sigma_run(d)
        plt.scatter(10 ** theta_nage + 10 ** theta_posi, abs(sigma / 10000), label=label)

    # `kind` and `d` come from the last run in the loop.
    plt.title('%s, bias:%s, no training after the selected point:%s'
              % (kind, d['bias'], d['pca_with_no_training']))
    plt.legend()
    plt.xscale('log')
    plt.yscale('log')
    plt.savefig('%spic.png' % (path))
    # Close so repeated calls do not accumulate open figures.
    plt.close()


def plot_loss_landscape(path, theta, loss_all, index):
    """Plot ``loss_all`` against ``theta`` on a log y axis and save ``<path><index>.png``.

    The curve is labelled with ``index``, but no legend is drawn. The figure
    is closed after saving so repeated calls do not accumulate open figures.
    """
    plt.figure()
    plt.plot(theta, loss_all, label='%s' % (index))
    plt.yscale('log')
    plt.savefig('%s%s.png' % (path, index))
    plt.close()

def plot_several_loss_landscape(path, alpha, loss_all):
    """Plot each curve in ``loss_all`` side by side and save ``<path>loss_landscape.png``.

    Curve ``i`` is drawn against ``alpha + i``, which shifts it ``i`` units to
    the right. This assumes ``alpha`` supports element-wise scalar addition
    (e.g. a numpy array); a plain list would raise TypeError. The y axis is
    log-scaled. The figure is closed after saving so repeated calls do not
    accumulate open figures.
    """
    plt.figure()
    for offset, losses in enumerate(loss_all):
        plt.plot(alpha + offset, losses)
    plt.yscale('log')
    plt.savefig('%sloss_landscape.png' % (path))
    plt.close()


def plot_cov_hessian(path, n_points=30):
    """Overlay the leading values of two traces per run and save ``<path>pic.png``.

    Each entry of ``path`` whose name does not start with 'p' is treated as a
    run directory. ``path`` must end with a path separator. For each run,
    ``ini_trace.txt`` is drawn solid and ``iden_trace.txt`` dashed. Runs are
    processed in sorted order so colours are reproducible. The y axis is
    log-scaled and the figure is closed after saving.

    Args:
        path: Directory prefix containing the run directories.
        n_points: Number of leading values of each trace to plot (default 30).

    NOTE(review): this writes the same ``pic.png`` name as ``plot_sigma_F``,
    so running both on one ``path`` overwrites the earlier figure.
    """
    run_names = sorted(name for name in os.listdir(path) if not name.startswith('p'))
    plt.figure()
    for name in run_names:
        iden_trace = np.loadtxt('%s%s/iden_trace.txt' % (path, name))
        ini_trace = np.loadtxt('%s%s/ini_trace.txt' % (path, name))
        plt.plot(ini_trace[:n_points], '-')
        plt.plot(iden_trace[:n_points], '--')
    plt.yscale('log')
    plt.savefig('%spic.png' % (path))
    plt.close()
    

