import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter
import numpy as np
from sklearn.manifold import TSNE

np.random.seed(42)
from src.models.exist_models import *
from src.models.mae import *

# ===== Preprocess =============================================
def load_model():
    """Build the masked-autoencoder ViT, restore its weights, and return it ready for inference.

    The checkpoint is read from ``../data/results/model_mae_checkpoint-140.pth``
    (relative to the current working directory) and mapped to CPU while loading.
    The model is then cast to bfloat16, moved to ``DEVICE`` and put in eval mode.

    Returns:
        The ``MaskedAutoencoderViT`` instance in eval mode.
    """
    # Architecture hyper-parameters of the MAE backbone.
    mae_kwargs = dict(
        img_size=(387, 65),
        patch_size=(9, 5),
        mask_scheme='random',
        mask_prob=0.8,
        use_cwt=True,
        nvar=4,
        comb_freq=True,
    )
    model = MaskedAutoencoderViT(**mae_kwargs)

    # Load on CPU first; the weights live under the 'model' key of the checkpoint.
    checkpoint = torch.load(
        '../data/results/model_mae_checkpoint-140.pth',
        map_location=torch.device('cpu'),
    )
    model.load_state_dict(checkpoint['model'])
    print("Model load successfull.")

    # Alternative backbone that was tried here instead of the MAE: ViTAdjust()

    # DEVICE is not defined in this file; presumably it comes from the star imports.
    model = model.to(torch.bfloat16).to(DEVICE)
    model.eval()
    return model

def get_samples_labels(physio_model, ds_name="PPG_HTN", sample_size=100, channel_id=0, label='PPG'):
    """Run a random subset of one dataset through the model and collect its outputs.

    Args:
        physio_model: model exposing ``forward_all(x, hidden_out=True)`` that
            returns ``(final_output, hidden_outputs)``.
        ds_name: dataset folder; sample files are read from
            ``../data/<ds_name>/samples`` (entries starting with '.' are skipped).
        sample_size: number of files drawn without replacement via
            ``np.random.choice`` (which raises if fewer files are available).
        channel_id: index of the channel whose tokens are kept.
        label: tag appended to the label list once per sample.

    Returns:
        A tuple ``(cls_embeddings, labels, hidden_states)`` of Python lists:
        the position-0 token of the final output for ``channel_id``, the
        repeated ``label``, and, per sample, one nested-list matrix per hidden
        output holding tokens ``1:`` of ``channel_id``. All numeric values are
        passed through float16 before being converted to lists.
    """
    sample_dir = f"../data/{ds_name}/samples"
    visible_files = [name for name in sorted(os.listdir(sample_dir)) if name[0] != '.']
    picked_files = np.random.choice(visible_files, sample_size, replace=False)

    cls_embeddings, labels, hidden_states = [], [], []
    for file_name in tqdm(picked_files):
        # NOTE: pickle.load can execute arbitrary code; only read trusted sample files.
        with open(f"{sample_dir}/{file_name}", 'rb') as handle:
            record = pickle.load(handle)
        labels.append(label)

        with torch.no_grad():
            # Move the last axis of the stored 'cwt' array to position 1 and add a batch axis.
            cwt_input = torch.from_numpy(record['cwt']).to(torch.bfloat16)
            cwt_input = cwt_input.permute(0, 3, 1, 2).unsqueeze(0).to(DEVICE)

            # Author's shape notes: hidden outputs are (12, N(1), C, L, 768) and the
            # final output is (N(1), C, L, 768) -- TODO confirm against the model.
            final_out, layer_outs = physio_model.forward_all(cwt_input, hidden_out=True)

            # Per hidden output: first batch item, selected channel, tokens after position 0.
            per_layer_tokens = []
            for layer_out in layer_outs:
                tokens = layer_out[0, channel_id, 1:, :]
                per_layer_tokens.append(
                    tokens.cpu().float().numpy().astype(np.float16).tolist()
                )
            hidden_states.append(per_layer_tokens)

            # Position 0 along the token axis is kept as the sample embedding
            # (presumably the CLS token).
            final_np = final_out[0].cpu().float().numpy().astype(np.float16)
            cls_embeddings.append(final_np[channel_id, 0, :].tolist())

    return cls_embeddings, labels, hidden_states

def process_and_save():
    """Embed samples for every configured signal and pickle the results.

    For each ``(ds_name, channel_id, label)`` setting, 500 samples are pushed
    through the model returned by ``load_model`` and the hidden states and
    sample embeddings are stored as float16 arrays. The list of
    ``(hidden_states, embeddings)`` tuples, in settings order, is written to
    ``../data/processed_samples_for_nld_mae.pkl``.
    """
    physio_model = load_model()
    sample_size = 500

    # Each entry: (ds_name, channel_id, label)
    settings = [
        ("PPG_HTN", 0, 'PPG'),
        ("ecg_heart_cat", 0, "ECG"),
        ("wesad", 4, "GSR"),
        ("gameemo", 0, "EEG_F"),
        ("uci_har", 0, "ACC_X"),
        # newly added for cls visual
        ("uci_har", 1, "ACC_Y"),
        ("uci_har", 2, "ACC_Z"),
        ("gameemo", 1, "EEG_O"),
        ("gameemo", 2, "EEG_L"),
        ("gameemo", 3, "EEG_R"),
    ]

    all_outputs = []
    for ds_name, channel_id, label in tqdm(settings):
        # The returned label list is not needed here; only embeddings and hiddens are saved.
        cls_embeds, _labels, hiddens = get_samples_labels(
            physio_model,
            ds_name=ds_name,
            sample_size=sample_size,
            channel_id=channel_id,
            label=label,
        )

        # Swap the last two axes; per the author's note this yields (N, 12, 768, L).
        hiddens_arr = np.transpose(hiddens, (0, 1, 3, 2)).astype(np.float16)
        cls_arr = np.array(cls_embeds).astype(np.float16)
        all_outputs.append((hiddens_arr, cls_arr))

    with open("../data/processed_samples_for_nld_mae.pkl", 'wb') as f:
        pickle.dump(all_outputs, f)


# ===== NLD =============================================
import nolds

from gtda.homology import VietorisRipsPersistence
from gtda.diagrams import PersistenceEntropy
from gtda.time_series import TakensEmbedding

def calc_pe(wf):
    """Return the persistence entropies of a single waveform.

    Pipeline: Takens embedding (time delay 1, dimension 5) -> Vietoris-Rips
    persistence diagrams (library defaults) -> persistence entropy.

    Args:
        wf: array holding one waveform; it is reshaped to a batch of one series.

    Returns:
        The entropy row for that waveform (callers read entries 0 and 1;
        presumably one value per homology dimension -- confirm against gtda).
    """
    series = wf.reshape(1, -1)  # batch containing just this waveform
    embedded = TakensEmbedding(time_delay=1, dimension=5).fit_transform(series)
    persistence_diagrams = VietorisRipsPersistence().fit_transform(embedded)
    entropies = PersistenceEntropy().fit_transform(persistence_diagrams)
    return entropies[0]

def nld_f(wf):
    """Compute the non-linear-dynamics descriptors of one waveform.

    Args:
        wf: 1-D waveform passed unchanged to ``nolds.dfa``, ``calc_pe`` and
            ``nolds.lyap_r``.

    Returns:
        ``[dfa, entropy_0, entropy_1, lyapunov]`` where the two entropy values
        are entries 0 and 1 of ``calc_pe(wf)``.
    """
    # Detrended fluctuation analysis
    dfa_exponent = nolds.dfa(wf)

    # Persistence entropies
    entropy_pair = calc_pe(wf)

    # Lyapunov exponent (Rosenstein); trajectory_len=64 and fit="RANSAC"
    # were noted as alternatives by the author.
    lyapunov_kwargs = dict(emb_dim=2, min_tsep=1, lag=1, tau=1, trajectory_len=8)
    lyapunov = nolds.lyap_r(wf, **lyapunov_kwargs)

    return [dfa_exponent, entropy_pair[0], entropy_pair[1], lyapunov]

def check_nld_pattern(hiddens_all, first_n=20, sample_n=10):
    """Average NLD descriptors per layer over randomly chosen waveforms.

    Args:
        hiddens_all: iterable of per-sample arrays; each is iterated layer by
            layer, and each layer is indexed along its first axis to pick
            waveforms.
        first_n: number of NaN-free samples (taken from the front) to process.
        sample_n: waveforms drawn without replacement from every layer
            (``np.random.choice`` raises if a layer has fewer rows).

    Returns:
        Array of shape (n_samples, n_layers, n_descriptors), where the last
        axis follows the order returned by ``nld_f``.
    """
    def _layer_summary(layer):
        # Draw a random subset of waveforms and average their descriptors.
        row_ids = np.random.choice(np.arange(len(layer)), sample_n, replace=False)
        features = [nld_f(row) for row in layer[row_ids]]
        return np.mean(features, axis=0)

    # Drop samples whose sum is NaN before taking the first_n.
    finite_samples = [sample for sample in hiddens_all if not np.isnan(np.sum(sample))]

    per_sample = []
    for sample in tqdm(finite_samples[:first_n]):
        per_layer = []
        for layer in sample:
            per_layer.append(_layer_summary(layer))
        per_sample.append(per_layer)

    return np.array(per_sample)

def check_all_nld_and_plot():
    """Plot the layer-wise NLD descriptors of every signal from the cached results.

    Reads ``data/processed_files/processed_nld_for_nld_mae.pkl`` (relative to
    the current working directory), which must hold a sequence of arrays
    indexed as ``[sample, layer, descriptor]``, one array per signal. Each
    descriptor gets its own subplot; every signal is min-max scaled per
    descriptor and drawn as mean over samples with error bars.

    Only as many result arrays as there are entries in ``signal_names`` are
    plotted.
    """
    # To regenerate the cache from the preprocessed samples:
    # with open("data/processed_files/processed_samples_for_nld_mae.pkl", 'rb') as f:
    #     all_outputs = pickle.load(f)[:5]
    # all_nlds = [
    #     check_nld_pattern(np.array(output[0]).astype(float), first_n=100, sample_n=50)
    #     for output in all_outputs
    # ]
    # with open("data/processed_files/processed_nld_for_nld_mae.pkl", 'wb') as f:
    #     pickle.dump(all_nlds, f)

    # Read the already calculated NLD values.
    # NOTE: pickle.load can execute arbitrary code; only load trusted cache files.
    with open("data/processed_files/processed_nld_for_nld_mae.pkl", 'rb') as f:
        all_nlds = pickle.load(f)

    signal_names = ["PPG", "ECG", "GSR", "EEG_F", "ACC_X"]
    # Must follow the descriptor order returned by nld_f.
    # NOTE(review): nld_f stores calc_pe(...)[0] before calc_pe(...)[1]; if gtda
    # orders entropies by ascending homology dimension, the H1/H0 titles below
    # are swapped -- confirm before publishing the figure.
    nld_names = ["DFA", "Persistence Entropy H1", "Persistence Entropy H0", "Lyapunov Exponent"]

    # matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*';
    # fall back to the old name on older installations.
    style_name = 'seaborn-v0_8-darkgrid'
    if style_name not in plt.style.available:
        style_name = 'seaborn-darkgrid'
    plt.style.use(style_name)

    # plt.subplots creates its own figure, so no prior plt.clf() is needed
    # (with no open figure it would only add an extra empty one to plt.show()).
    fig, axs = plt.subplots(nrows=1, ncols=len(nld_names), figsize=(20, 4))

    for i, nld_name in enumerate(nld_names):
        ax = axs[i]
        # zip stops at the shorter sequence, so unnamed extra results are skipped.
        for nld, signal_name in zip(all_nlds, signal_names):
            curr_x = nld[:, :, i]

            # Min-max scale to [0, 1]; a constant input maps to zeros instead of
            # dividing by zero.
            lowest = np.min(curr_x)
            span = np.max(curr_x) - lowest
            if not span > 0:
                span = 1.0
            curr_x = (curr_x - lowest) / span

            # NOTE(review): the error bar is the squared std (variance), not the std.
            ax.errorbar(
                np.arange(curr_x.shape[1]),
                np.mean(curr_x, axis=0),
                yerr=np.std(curr_x, axis=0)**2,
                capsize=3,
                fmt="-o",
                label=signal_name,
            )
        ax.set_xlabel("Layers", fontsize=25)
        ax.set_title(nld_name, fontsize=25)

    fig.tight_layout()
    # The legend is attached to the current (last) axes and shifted left/below
    # so that it sits under the whole row of subplots.
    plt.legend(
        fontsize=20,
        loc='upper center',
        bbox_to_anchor=(-1.2, -0.11),
        fancybox=True,
        shadow=True,
        ncol=5,
        frameon=True
    )
    # plt.savefig("figures/nld_analysis_small_mae.pdf", format="pdf", bbox_inches="tight")
    plt.show()

# ===== Main =============================================
if __name__ == '__main__':
    # Step 1 (disabled): run the model over the sampled datasets and pickle the
    # hidden states and embeddings. Uncomment to regenerate those tensors.
    # process_and_save()

    # Step 2: load the cached NLD results and draw the layer-wise figure.
    check_all_nld_and_plot()