import numpy as np
import os
import random

import matplotlib.pyplot as plt
from sklearn import manifold
import seaborn as sns


# Directory holding the replay-buffer .npy files read by test() and
# visualization(). The trailing slash matters: callers build file paths with
# plain string concatenation (datapath + "messages.npy").
datapath = (
    "/home/zll/code/offMARL/buffer/"
    "room5_from_start/buffer_12008/"
)
def test():
    """Print shapes and one sample from the replay buffer and feature file.

    Loads ``messages.npy``, ``obs.npy`` and ``state.npy`` from ``datapath``
    and prints their shapes. It then takes the entry at ``[3, 1, 1, :]``
    (``[3, 1, :]`` for the states), loads ``features.npy``, prints that
    observation and finally the shape of the feature array.
    """
    buffer_files = ("messages.npy", "obs.npy", "state.npy")
    messages, obs, states = [np.load(datapath + fname) for fname in buffer_files]
    print(messages.shape, obs.shape, states.shape)

    # Only `ob` is printed below. The other two slices are still taken, so an
    # undersized buffer raises IndexError at this point.
    sample_message = messages[3, 1, 1, :]
    ob = obs[3, 1, 1, :]
    sample_state = states[3, 1, :]

    feature = np.load("/home/zll/code/offMARL/src/features.npy")
    print("observations", ob)
    print(feature.shape)






def tsne_perplexity(n_samples, default=30.0):
    """Return a t-SNE perplexity that is valid for ``n_samples`` points.

    scikit-learn's ``TSNE`` requires ``perplexity < n_samples``. Its default
    of 30 is therefore rejected for small inputs, such as the 8 points
    embedded by ``visualization``. ``default`` is returned when it is valid;
    otherwise the value is clamped to ``n_samples - 1`` (never below 1).

    Args:
        n_samples: number of points that will be embedded.
        default: preferred perplexity when enough points are available.

    Returns:
        The perplexity as a float.
    """
    return float(min(default, max(n_samples - 1, 1)))


def visualization():
    """Embed one buffer message with t-SNE and save a 2-D scatter plot.

    Reads ``messages.npy`` and ``actions.npy`` from ``datapath`` and takes
    the entry at ``[0, 0, 0, :]`` of each. The message is reshaped to
    ``(8, 3)``, i.e. 8 points with 3 features each. Some shapes are printed
    for inspection, then the 8 points are embedded in 2-D and plotted.
    The plot is written to ``./0510_expert_tsne-scratch.png``.
    """
    message = np.load(datapath + 'messages.npy')[0, 0, 0, :]
    # Requires exactly 24 values; reshape raises ValueError otherwise.
    message = message.reshape((8, 3))
    print(message.shape)

    actions = np.load(datapath + 'actions.npy', allow_pickle=True)[0, 0, 0, :]
    goal = actions[0]
    print(message.shape, actions.shape, goal)

    plt.figure(figsize=(10, 5))

    # With only 8 samples, TSNE's default perplexity (30) is invalid and
    # fit_transform would raise, so clamp it to the sample count.
    tsne = manifold.TSNE(
        n_components=2,
        init='pca',
        random_state=501,
        perplexity=tsne_perplexity(message.shape[0]),
    )
    message_tsne = tsne.fit_transform(message)

    x_axis = message_tsne[:, 0]
    y_axis = message_tsne[:, 1]

    plt.scatter(x_axis, y_axis, c="r")
    plt.savefig('./0510_expert_tsne-scratch.png', dpi=120)
    plt.show()




def sort_by_loss_desc(loss, message):
    """Reorder ``loss`` and ``message`` together by decreasing loss.

    Args:
        loss: 1-D array-like of scalar losses.
        message: array-like whose first axis is aligned with ``loss``.

    Returns:
        ``(loss_sorted, message_sorted)`` as numpy arrays. Equal losses keep
        their original relative order.

    Sorting through an index array means message rows are never compared
    with each other. Sorting ``zip(loss, message)`` tuples would compare rows
    whenever two losses tie, which raises ``ValueError`` for numpy rows with
    more than one element.
    """
    loss = np.asarray(loss)
    order = np.argsort(-loss, kind="stable")
    return loss[order], np.asarray(message)[order]


def visualize_loss():
    """Plot the message-loss distribution and a t-SNE of the top messages.

    Loads the pickled ``loss.npy`` and ``message.npy`` from
    ``./radar2/expert/message_loss/`` and flattens both. The losses are
    summed over axis 1 and sorted in decreasing order together with the
    messages. Frequency and density histograms of the losses are saved to
    ``./pictures/0427_loss_distribution_all.png``. Finally, the 10000
    highest-loss messages are embedded via ``tsne``, coloured by their loss.
    """
    loss = np.load("./radar2/expert/message_loss/loss.npy", allow_pickle=True)
    message = np.load("./radar2/expert/message_loss/message.npy", allow_pickle=True)

    # Both files hold nested sequences, so concatenate twice to flatten them.
    # `.sum(axis=1)` needs the twice-concatenated loss to be at least 2-D;
    # presumably this yields one total loss per message row — TODO confirm.
    loss = np.concatenate(loss, axis=0)
    message = np.concatenate(message, axis=0)
    loss = np.concatenate(loss, axis=0).sum(axis=1)
    message = np.concatenate(message, axis=0)

    print(message.shape, message[0].shape)
    print(loss.shape)

    loss_new, message_new = sort_by_loss_desc(loss, message)
    print(message_new[0].shape)
    print("loss", list(loss_new))

    sns.set()
    f = plt.figure()

    # (a) frequency histogram
    f.add_subplot(1, 2, 1)
    sns.distplot(loss_new, kde=False)
    plt.ylabel("frequency", fontsize=16)
    plt.xticks(fontsize=16)  # x-axis tick label font size
    plt.yticks(fontsize=16)  # y-axis tick label font size
    plt.title("(a)", fontsize=20)  # subplot title

    # (b) density histogram
    f.add_subplot(1, 2, 2)
    sns.distplot(loss_new)
    plt.ylabel("probability", fontsize=16)
    plt.xticks(fontsize=16)
    plt.yticks(fontsize=16)
    plt.title("(b)", fontsize=20)

    plt.subplots_adjust(wspace=1)  # spacing between the two subplots
    plt.savefig('./pictures/0427_loss_distribution_all.png', dpi=120)
    plt.show()

    # Embed the highest-loss messages and colour each point by its own loss.
    # scatter needs one colour value per point, so an empty list is not valid.
    top_k = 10000
    tsne(message_new[:top_k], loss_new[:top_k])

def tsne(test, color):
    """Embed ``test`` in 2-D with t-SNE and save a scatter plot.

    Args:
        test: samples to embed, anything ``TSNE.fit_transform`` accepts
            (e.g. a 2-D array or a sequence of equal-length 1-D arrays).
        color: value(s) for the ``c`` argument of ``Axes.scatter``.
            When ``None`` or empty, points use the default colour and no
            legend is drawn.

    The figure is saved to ``./pictures/0427tsne-top10000.png`` and shown.
    """
    _, ax = plt.subplots()
    n_samples = len(test)
    # TSNE rejects perplexity >= n_samples; clamp its default of 30 for
    # small inputs.
    model = manifold.TSNE(
        n_components=2,
        init='pca',
        random_state=501,
        perplexity=float(min(30, max(n_samples - 1, 1))),
    )
    embedded = model.fit_transform(test)

    x_axis = embedded[:, 0]
    y_axis = embedded[:, 1]

    has_color = color is not None and len(color) > 0
    scatter = ax.scatter(x_axis, y_axis, c=color if has_color else None)
    if has_color:
        legend1 = ax.legend(*scatter.legend_elements(), loc="lower left", title="Classes")
        ax.add_artist(legend1)
    plt.savefig('./pictures/0427tsne-top10000.png', dpi=120)
    plt.show()

# Run the t-SNE message plot only when executed as a script, not on import.
if __name__ == "__main__":
    visualization()