import matplotlib.pyplot as plt
import nolds

from gtda.homology import VietorisRipsPersistence
from gtda.diagrams import PersistenceEntropy
from gtda.time_series import TakensEmbedding

from src.models.exist_models import *
# Model used by get_samples_labels to embed each sample's 'cwt' array.
# eval() is called because the model is only used for inference here
# (presumably a torch nn.Module, where this disables dropout etc. — confirm).
# NOTE(review): get_samples_labels moves its inputs to DEVICE, but the model is
# never explicitly moved here — confirm ViTAdjust places itself on DEVICE.
physio_model = ViTAdjust()
physio_model.eval()

# Seed the global NumPy RNG so the random file / dimension selection below is reproducible.
np.random.seed(42)
def get_samples_labels(ds_name="PPG_HTN", sample_size=100, channel_id=0, label='PPG', hidden_out=False,
                       replace=True):
    """Embed randomly chosen pickled samples of a dataset with ``physio_model``.

    Files are drawn from ``data/<ds_name>/samples`` (names starting with '.'
    are skipped) using the global NumPy RNG. Each file is unpickled, its
    ``'cwt'`` array is passed through ``physio_model`` without gradients, and
    token 0 of channel ``channel_id`` of the model output is kept as that
    sample's embedding.

    Args:
        ds_name: Dataset directory name under ``data/``.
        sample_size: Number of files to draw.
        channel_id: Index along the first axis of the model output (and of
            each hidden state) to keep.
        label: Label appended once per drawn file.
        hidden_out: If True, also request the model's intermediate hidden
            states and return them for ``channel_id``.
        replace: Forwarded to ``np.random.choice``. The default ``True`` keeps
            the previous behaviour of drawing files with replacement, so the
            same file can be drawn more than once.

    Returns:
        ``(samples, labels)``, or ``(samples, labels, hiddens)`` when
        ``hidden_out`` is True. ``samples`` holds one list of float16-rounded
        values per drawn file. ``hiddens`` holds, per drawn file, a list with
        one NumPy array per hidden state returned by the model.
    """
    curr_samples, curr_labels, curr_hiddens = list(), list(), list()

    fns = [fn for fn in sorted(os.listdir("data/{}/samples".format(ds_name))) if fn[0] != '.']
    fns = np.random.choice(fns, sample_size, replace=replace)

    for fn in tqdm(fns):
        with open("data/{}/samples/{}".format(ds_name, fn), 'rb') as f:
            # NOTE: pickle.load can execute arbitrary code; only use on trusted local data.
            data = pickle.load(f)
        curr_labels.append(label)

        # get embed
        with torch.no_grad():
            # 'cwt' is permuted (0, 3, 1, 2) before the forward pass; presumably
            # channels-last -> channels-first per signal channel — TODO confirm.
            out = physio_model(
                torch.from_numpy(data['cwt']).float().permute(0, 3, 1, 2).to(DEVICE),
                hidden_out=hidden_out
            )  # (C, L, 768) per the original note

            if hidden_out:
                physio_out, hiddens = out
                # Use the same channel as the embedding below (this was hard-coded
                # to 0). Copy to host memory so device memory does not grow with
                # sample_size, and so later np.transpose calls can convert the result.
                curr_hiddens.append([h[channel_id].cpu().numpy() for h in hiddens])
            else:
                # Without hidden states the model output is used directly;
                # physio_out was previously left unbound on this path.
                physio_out = out

            physio_out = physio_out.cpu().numpy().astype(np.float16)
            cls_embed = physio_out[channel_id, 0, :].tolist()
            curr_samples.append(cls_embed)

    if hidden_out:
        return curr_samples, curr_labels, curr_hiddens
    return curr_samples, curr_labels

# extract samples
# Number of files drawn from each dataset by get_samples_labels.
sample_size = 50

# Each section below runs get_samples_labels with hidden_out=True and keeps only
# the hidden states; curr_samples / curr_labels / curr_hiddens are overwritten by
# every call. np.transpose swaps the last two axes of each per-layer hidden state,
# presumably (N, n_layers, L, 768) -> (N, n_layers, 768, L) so each row is one
# feature dimension's sequence over tokens — TODO confirm shapes.

# PPG: channel 0 of PPG_HTN.
curr_samples, curr_labels, curr_hiddens = get_samples_labels(
    ds_name="PPG_HTN", 
    sample_size=sample_size,
    channel_id=0, 
    label='PPG',
    hidden_out=True
)
curr_hiddens_ppg = np.transpose(curr_hiddens, (0, 1, 3, 2))

# ECG: channel 0 of ecg_heart_cat.
curr_samples, curr_labels, curr_hiddens = get_samples_labels(
    ds_name="ecg_heart_cat", 
    sample_size=sample_size,
    channel_id=0, 
    label='ECG',
    hidden_out=True
)
curr_hiddens_ecg = np.transpose(curr_hiddens, (0, 1, 3, 2))

# GSR: channel 4 of wesad (presumably the GSR/EDA channel — confirm).
curr_samples, curr_labels, curr_hiddens = get_samples_labels(
    ds_name="wesad", 
    sample_size=sample_size,
    channel_id=4, 
    label='GSR',
    hidden_out=True
)
curr_hiddens_gsr = np.transpose(curr_hiddens, (0, 1, 3, 2))

# EEG_F: channel 0 of gameemo.
curr_samples, curr_labels, curr_hiddens = get_samples_labels(
    ds_name="gameemo", 
    sample_size=sample_size,
    channel_id=0, 
    label='EEG_F',
    hidden_out=True
)
curr_hiddens_eegf = np.transpose(curr_hiddens, (0, 1, 3, 2))

# ACC_X
# NOTE(review): this call is a verbatim copy of the EEG_F call above
# (ds_name="gameemo", channel_id=0, label='EEG_F'). curr_hiddens_accx is
# therefore a second random draw of the same EEG data, not accelerometer data,
# yet it is saved and analysed under the "ACC_X" name further down. Point it at
# the correct dataset/channel.
curr_samples, curr_labels, curr_hiddens = get_samples_labels(
    ds_name="gameemo", 
    sample_size=sample_size,
    channel_id=0, 
    label='EEG_F',
    hidden_out=True
)
curr_hiddens_accx = np.transpose(curr_hiddens, (0, 1, 3, 2))

# nld calculation

def calc_pe(wf):
    """Return the persistence-entropy feature vector of one 1-D series.

    The series is delay-embedded (Takens, delay 1, dimension 5) into a point
    cloud. A Vietoris-Rips persistence diagram is computed from that cloud with
    default settings, and ``PersistenceEntropy`` summarises the diagram. The
    row for the single input sample is returned (one entropy per homology
    dimension, per giotto-tda's output layout).
    """
    series_batch = wf.reshape(1, -1)  # a batch containing one series
    cloud = TakensEmbedding(time_delay=1, dimension=5).fit_transform(series_batch)
    diagram = VietorisRipsPersistence().fit_transform(cloud)
    entropy_rows = PersistenceEntropy().fit_transform(diagram)
    return entropy_rows[0]

def nld_f(wf):
    """Compute four nonlinear-dynamics features of a 1-D series.

    Returns ``[dfa, pe_a, pe_b, lyap]``:
      * ``dfa``  - ``nolds.dfa`` with default arguments;
      * ``pe_a``, ``pe_b`` - the first two entries of ``calc_pe(wf)``
        (H0 then H1 under VietorisRipsPersistence's default
        ``homology_dimensions=(0, 1)`` — confirm if that default changes);
      * ``lyap`` - ``nolds.lyap_r`` with the fixed settings below.
    """
    lyap_settings = dict(
        emb_dim=3,
        min_tsep=1,
        lag=1,
        tau=1,
        trajectory_len=64,
        fit="RANSAC",
    )

    # Evaluated in the original order (DFA, entropy, Lyapunov) so any shared
    # random state (e.g. RANSAC fitting) is consumed identically.
    detrended = nolds.dfa(wf)
    entropies = calc_pe(wf)
    lyapunov = nolds.lyap_r(wf, **lyap_settings)

    return [detrended, entropies[0], entropies[1], lyapunov]

def check_nld_pattern(hiddens_all, n_dims=10):
    """Average ``nld_f`` features over randomly drawn rows of every layer.

    For each sample in ``hiddens_all`` and each layer array ``h`` in that
    sample, ``n_dims`` row indices are drawn from the global NumPy RNG. The
    draw is with replacement, so a row can be picked more than once.
    ``nld_f`` is evaluated on each selected row, and the resulting feature
    vectors are averaged.

    Args:
        hiddens_all: Iterable of samples. Each sample is an iterable of 2-D
            arrays whose rows are treated as 1-D series. In this script each
            sample is presumably (n_layers, 768, L) after the transpose in the
            extraction step — TODO confirm.
        n_dims: Number of rows drawn per layer (defaults to the previously
            hard-coded 10).

    Returns:
        ``np.ndarray`` of shape ``(n_samples, n_layers, n_features)``, where
        ``n_features`` is the length of ``nld_f``'s output (4 here).
    """
    all_nlds = list()
    for hiddens in tqdm(hiddens_all):
        curr_nlds = list()
        for h in hiddens:
            chosen_idx = np.random.choice(np.arange(len(h)), n_dims)

            # calculate nld: one feature vector per chosen row -> (n_dims, n_features)
            nlds = [nld_f(w) for w in h[chosen_idx]]

            # aggregate over the chosen rows -> (n_features,)
            curr_nlds.append(np.mean(nlds, axis=0))

        all_nlds.append(curr_nlds)
    return np.array(all_nlds)  # (n_samples, n_layers, n_features)

# Cache the extracted hidden states before the (slow) NLD computation.
with open("processed_samples_for_nld.pkl", 'wb') as f:
    pickle.dump([
        curr_hiddens_ppg,
        curr_hiddens_ecg,
        curr_hiddens_gsr,
        curr_hiddens_eegf,
        curr_hiddens_accx
    ], f)

# One (n_samples, n_layers, 4) array per signal, in the same order as signal_names.
all_nlds = [
    check_nld_pattern(curr_hiddens_ppg),
    check_nld_pattern(curr_hiddens_ecg),
    check_nld_pattern(curr_hiddens_gsr),
    check_nld_pattern(curr_hiddens_eegf),
    check_nld_pattern(curr_hiddens_accx)
]
with open("processed_nld_for_nld.pkl", 'wb') as f:
    pickle.dump(all_nlds, f)

# NOTE(review): curr_hiddens_accx is produced by a copy of the gameemo EEG_F
# extraction call, so the "ACC_X" entry currently holds EEG data.
signal_names = ["PPG", "ECG", "GSR", "EEG_F", "ACC_X"]
# Must match nld_f's output order: [dfa, calc_pe(wf)[0], calc_pe(wf)[1], lyap].
# giotto-tda orders persistence entropies by homology dimension, and
# VietorisRipsPersistence defaults to homology_dimensions=(0, 1). Index 1 is
# therefore H0 and index 2 is H1; these two labels were previously swapped.
nld_names = ["DFA", "Persistence Entropy H0", "Persistence Entropy H1", "Lyapunov Exponent"]