import torch
import torch.nn as nn
import numpy as np
import time
from tqdm import tqdm

from scipy import signal
from scipy.ndimage import gaussian_filter
import pywt

from torchvision import transforms
#from pytorch_pretrained_vit import ViT

#from transformers import AutoTokenizer, AutoModelForMaskedLM

#from src.models.layers import *

# =========== Helper Layers ========================================================================
# ========== BASIC PREPROCESSING ===========================
def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
    """Band-pass `data` with a Butterworth filter in second-order-section form.

    Args:
        data: Samples to filter; `scipy.signal.sosfilt` filters along the
            last axis by default.
        lowcut: Lower edge of the pass band, in the same units as `fs`.
        highcut: Upper edge of the pass band, in the same units as `fs`.
        fs: Sampling frequency of `data`.
        order: Filter order handed to `scipy.signal.butter`.

    Returns:
        The filtered samples from a single forward pass of the filter.
    """
    pass_band = [lowcut, highcut]
    sections = signal.butter(order, pass_band, btype='bandpass', output='sos', fs=fs)
    return signal.sosfilt(sections, data)

def impute(ts, sr=1000, tr=500):
    """Fill NaN gaps by linear interpolation and resample from rate `sr` to `tr`.

    Sample ``i`` of `ts` is placed at position ``i * tr`` on a virtual grid of
    ``len(ts) * tr`` points. The piecewise-linear curve through the non-NaN
    samples is then read at every grid position that is a multiple of `sr`.
    Positions outside the first/last valid sample take that sample's value
    (the behaviour of `np.interp`).

    The curve is evaluated only at the positions that are kept, so the dense
    ``len(ts) * tr`` grid is never materialised.

    Args:
        ts: 1-D series, possibly containing NaN.
        sr: Source sampling rate; expected to be a positive integer.
        tr: Target sampling rate.

    Returns:
        1-D float array with one value per kept grid position.

    Raises:
        ValueError: If `ts` has no non-NaN sample (including empty input).
    """
    samples = np.asarray(ts, dtype=float)
    valid = ~np.isnan(samples)
    if not valid.any():
        raise ValueError("impute() needs at least one non-NaN sample")

    # Grid positions of the samples that can anchor the interpolation.
    known_pos = np.flatnonzero(valid) * tr
    # Every sr-th position of the virtual len(ts) * tr grid.
    query_pos = np.arange(0, len(samples) * tr, sr)
    return np.interp(query_pos, known_pos, samples[valid])

def rolling_window(a, window):
    """Return a strided view of `a` whose new last axis holds sliding windows.

    The last axis of length ``n`` is replaced by two axes of shape
    ``(n - window + 1, window)``; consecutive windows are shifted by one
    element. The result is built with `as_strided`, so it is a view onto the
    memory of `a` rather than a copy.

    Args:
        a: NumPy array to window along its last axis.
        window: Number of elements per window.
    """
    n_windows = a.shape[-1] - window + 1
    view_shape = (*a.shape[:-1], n_windows, window)
    # Re-use the last axis' stride so stepping one window advances one element.
    view_strides = (*a.strides, a.strides[-1])
    return np.lib.stride_tricks.as_strided(a, shape=view_shape, strides=view_strides)

def basic_preproc(
    ts, 
    sr=1000, # source sampling rate
    tr=125, # target sampling rate
    l_pass=0.1, # low frequency pass
    h_pass=5, # high frequency pass
    outlier_p=0.95, # threshold for outlier
    smooth_c=0.04,
    avg=0.5
):  
    """Clean one 1-D series: drop jumps, resample, detrend, smooth, normalise.

    Steps, in order:
      1. Flag samples followed by an unusually large jump (absolute first
         difference above the `outlier_p` quantile) and set them to NaN.
      2. Fill the NaNs and resample from `sr` to `tr` with `impute`.
      3. Remove the linear trend (`scipy.signal.detrend`).
      4. Gaussian-smooth with ``sigma = tr * smooth_c``.
      5. Divide by the mean absolute value (plus 1e-6 to avoid dividing by 0).

    Args:
        ts: 1-D array-like series. It is copied as float64, so integer input
            is accepted and the caller's data is never modified.
        sr: Source sampling rate.
        tr: Target sampling rate.
        l_pass: Not used by the current implementation (band-pass filtering
            is disabled).
        h_pass: Not used by the current implementation (band-pass filtering
            is disabled).
        outlier_p: Quantile of the absolute first differences above which a
            jump counts as an outlier.
        smooth_c: Smoothing factor; the Gaussian sigma is ``tr * smooth_c``.
        avg: Not used by the current implementation.

    Returns:
        1-D float array holding the cleaned, resampled series.
    """
    # Float copy: NaN cannot be stored in an integer array, and differences of
    # unsigned integers would wrap around.
    ts_clear = np.array(ts, dtype=float)

    # remove outlier
    diffs = np.abs(np.diff(ts_clear))
    max_diff = np.quantile(diffs, outlier_p)
    # diffs[i] belongs to the step i -> i+1, so the sample *before* each large
    # jump is flagged; the last sample has no following step and is kept.
    ts_outliers = np.append(diffs > max_diff, False)
    ts_clear[ts_outliers] = np.nan

    # impute missing value (also resamples from sr to tr)
    ts_clear = impute(ts_clear, sr=sr, tr=tr)

    # detrend, remove linear shift
    ts_clear = signal.detrend(ts_clear)

    # smooth
    ts_clear = gaussian_filter(ts_clear, sigma=tr*smooth_c)

    # normalize
    ts_clear = ts_clear/(np.mean(np.abs(ts_clear))+0.000001)

    return ts_clear


# ====== CV Model =============
# Size associated with each ViT checkpoint name (presumably the embedding
# width of the corresponding pretrained model).
VIT_EMB_SIZE = dict(
    B_16_imagenet1k=768,
    L_32_imagenet1k=1024,
)

def init_vit(m_name='B_16_imagenet1k'):
    """Load a pretrained ViT in eval mode and replace its head with a mean.

    The model's `norm` layer is wrapped so that an extra axis is inserted
    after it, and its `fc` layer is replaced by a mean over dim 1.

    NOTE(review): `ViT` and `CheckShape` are not imported in the visible part
    of this file (the imports that look like their source are commented out),
    so this function needs them to be provided elsewhere.

    Args:
        m_name: Checkpoint name passed to `ViT`.

    Returns:
        The modified model.
    """
    vit = ViT(m_name, pretrained=True)  # construct and load pretrained weights
    vit.eval()

    # Insert a singleton axis right after the final norm. Presumably this is
    # so that a later `[:, 0]` selection in the ViT forward pass keeps every
    # token instead of only [CLS] -- confirm against the ViT implementation.
    original_norm = vit.norm
    add_axis = CheckShape(None, lambda x: x.unsqueeze(1))
    vit.norm = nn.Sequential(original_norm, add_axis)

    # Replace the classification layer with an average over dim 1.
    mean_pool = CheckShape(None, lambda x: torch.mean(x, dim=1))
    vit.fc = nn.Sequential(mean_pool)
    return vit

def freeze_model(model):
    """Turn off gradient tracking for every parameter of `model`.

    The parameters are modified in place and the same `model` object is
    returned so the call can be chained.
    """
    for weight in model.parameters():
        weight.requires_grad_(False)
    return model

# ====== Signal Processing and Wavelet Transform =============
def wt(ts, lf=0.1, hf=65, wl='gaus1', method='fft'):
    """Continuous wavelet transform of a series via `pywt.cwt`.

    in: L
    out: FxL

    The scales passed to `pywt.cwt` are ``np.arange(lf, hf)``: they start at
    `lf` and grow in steps of 1 while staying below `hf`. Only the
    coefficient matrix is returned; the frequencies that `pywt.cwt` also
    reports are discarded.

    Args:
        ts: Input series.
        lf: First scale.
        hf: Exclusive upper bound for the scales.
        wl: Wavelet name.
        method: Computation method forwarded to `pywt.cwt`.
    """
    scales = np.arange(lf, hf)
    coefficients, _ = pywt.cwt(ts, scales, wl, method=method)
    return coefficients

def derive_2d(x, sr):
    """First difference of `x` along axis 1 (one column shorter than `x`).

    `sr` is accepted but not used: the difference is not scaled by it.
    """
    later = x[:, 1:]
    earlier = x[:, :-1]
    return later - earlier

def derive(x):
    """First difference of `x` along its first axis (one element shorter)."""
    later = x[1:]
    earlier = x[:-1]
    return later - earlier


def norm(x):
    """Min-max scale `x` so its smallest value maps to 0 and its largest to 1.

    A constant input has zero range; zeros are returned in that case rather
    than the NaNs a 0/0 division would produce.

    Args:
        x: Array-like of numbers.

    Returns:
        Float NumPy array with the same shape as `x`.
    """
    values = np.asarray(x)
    min_ = np.min(values)
    span = np.max(values) - min_
    if span == 0:
        # Nothing to scale; avoid dividing by zero.
        return np.zeros(values.shape, dtype=float)
    return (values - min_) / span

def proc_one_signal(x, ss, ts=64, lc=0.1, hc=128, ws=6, overlap=0.5):
    """Turn one raw signal into fixed-length CWT "image" chunks.

    The signal is cleaned with `basic_preproc`, transformed with `wt`, and the
    transform is stacked with its first and second differences along time as
    three channels. The result is cut into consecutive, non-overlapping
    windows of ``round(ws * ts * overlap)`` samples, taken from the END of the
    signal backwards. Leftover samples at the start that do not fill a whole
    window are dropped (no padding yet).

    Args:
        x: 1-D raw signal.
        ss: Source sampling rate of `x`.
        ts: Target sampling rate used for resampling.
        lc: Forwarded to `basic_preproc` as `l_pass`.
        hc: Forwarded to `basic_preproc` as `h_pass`.
        ws: Window-size factor; see the window length formula above.
        overlap: Scaling factor applied to the window length.

    Returns:
        ``(cwts, tss)``: a list of ``(window, F, 3)`` tensors and the list of
        matching slices of the cleaned series, both ordered from the last
        window to the first.

    Raises:
        ValueError: If the window length rounds to zero or below (the slicing
            loop below would otherwise never terminate).
    """
    # basic process
    ts_clear = basic_preproc(
        x, 
        sr=ss, 
        tr=ts, 
        l_pass=lc,
        h_pass=hc
    )

    # wavelet transform and 1st 2nd derivatives
    cwt0 = wt(ts_clear[:-1], hf=65)
    cwt1 = derive_2d(cwt0, ts)
    cwt2 = derive_2d(cwt1, ts)
    # Each difference shortens the time axis by one; trim so all three align.
    stacked = np.stack((cwt0[:, 2:], cwt1[:, 1:], cwt2), axis=-1)  # (F, L, 3)
    cwt_im = torch.as_tensor(stacked, dtype=torch.get_default_dtype())
    cwt_im = cwt_im.permute(1, 0, 2) # (L, F, 3)

    # Drop leading samples so the series has the same length as cwt_im.
    ts_clear = ts_clear[3:]

    # sliding window
    win_len = round((ws * ts)*overlap)
    if win_len <= 0:
        raise ValueError(
            "window length round(ws * ts * overlap) must be positive, got {}".format(win_len)
        )
    cwts = list()
    tss = list()
    c_i = cwt_im.shape[0]
    t_i = len(ts_clear)
    while c_i >= win_len:
        cwts.append(cwt_im[c_i-win_len:c_i, :, :])
        tss.append(ts_clear[t_i-win_len:t_i])
        c_i -= win_len
        t_i -= win_len
    # TODO pad

    return cwts, tss

def preproc_all(all_tss, ss, ts=65, lc=0.1, hc=128, outlier_p=0.95, pad_to=-1): # N, C, L
    """Clean every series of a batch and build its 3-channel CWT tensor.

    Each ``all_tss[i, j]`` is cleaned with `basic_preproc`, transformed with
    `wt`, and stacked with the first and second differences of the transform
    along time. The first two samples of every cleaned series are dropped so
    that series and transform have the same length.

    Args:
        all_tss: Batch of raw series with shape ``(N, C, L)``.
        ss: Source sampling rate.
        ts: Target sampling rate.
        lc: Forwarded to `basic_preproc` as `l_pass`.
        hc: Forwarded to `basic_preproc` as `h_pass`.
        outlier_p: Outlier quantile forwarded to `basic_preproc`.
        pad_to: If positive and larger than the output length, both outputs
            are left-padded with zeros along time up to this length.

    Returns:
        ``(final_tensor, new_tss)`` with shapes ``(N, C, T, F, 3)`` and
        ``(N, C, T)``, where ``T`` is the (optionally padded) output length
        and ``F`` the number of CWT scales.

    Raises:
        ValueError: If the batch has no samples or no channels.
    """
    N, C, _ = all_tss.shape
    if N == 0 or C == 0:
        raise ValueError("preproc_all needs at least one sample and one channel")

    # basic preprocess
    new_tss = None
    for i in range(N):
        for j in range(C):
            clean_tss = basic_preproc(
                all_tss[i, j], 
                sr=ss, 
                tr=ts, 
                l_pass=lc,
                h_pass=hc,
                outlier_p=outlier_p
            )
            if new_tss is None:
                # The resampled length is only known after the first series.
                new_tss = np.zeros((N, C, len(clean_tss)))
            new_tss[i, j] = clean_tss

    L = new_tss.shape[-1]

    # get cwt
    out_dtype = torch.get_default_dtype()
    final_tensor = None
    for i in range(N):
        for j in range(C):
            cwt0 = wt(new_tss[i, j], hf=65)
            cwt1 = derive_2d(cwt0, ts)
            cwt2 = derive_2d(cwt1, ts)
            # Each difference shortens time by one; trim so all three align.
            stacked = np.stack((cwt0[:, 2:], cwt1[:, 1:], cwt2), axis=-1)  # (F, L-2, 3)
            cwt_im = torch.as_tensor(stacked, dtype=out_dtype).permute(1, 0, 2) # (L-2, F, 3)
            if final_tensor is None:
                # Take the number of scales from the transform itself.
                final_tensor = torch.zeros(N, C, L-2, cwt_im.shape[1], 3)
            final_tensor[i, j] = cwt_im

    new_tss = torch.tensor(new_tss[:, :, 2:])
    N, C, L, F, ch = final_tensor.shape
    if pad_to > 0 and L < pad_to:
        # Zeros are prepended, so the real data stays at the end of the axis.
        tensor_pad = torch.zeros(N, C, pad_to-L, F, ch, dtype=final_tensor.dtype)
        tss_pad = torch.zeros(N, C, pad_to-L, dtype=new_tss.dtype)

        final_tensor = torch.cat((tensor_pad, final_tensor), dim=2)
        new_tss = torch.cat((tss_pad, new_tss), dim=2)

    return final_tensor, new_tss

def get_cv_embedding(cwt_ims, cv_model):
    """Resize and normalise a batch of CWT images and run them through `cv_model`.

    The images are resized to 384x384, normalised with mean 0.5 and std 0.5,
    and moved to the device that holds the model's parameters before the
    forward pass. The device is taken from the model itself rather than from
    a module-level constant, so the function works wherever the model lives.

    Args:
        cwt_ims: Batch of image tensors accepted by the torchvision
            transforms (presumably ``(N, 3, H, W)`` -- confirm with caller).
        cv_model: Vision model to evaluate.

    Returns:
        Whatever `cv_model` returns for the prepared batch.
    """
    prepare = transforms.Compose([
        transforms.Resize((384, 384)), 
        transforms.Normalize(0.5, 0.5),
    ])
    imgs = prepare(cwt_ims)

    # Follow the model's device; a parameter-free model keeps the input's.
    first_param = next(cv_model.parameters(), None)
    device = first_param.device if first_param is not None else imgs.device

    # forward
    outputs = cv_model(imgs.to(device))
    return outputs

class SignalTokenizer(nn.Module):
    """Interleave text-prompt embeddings with signal embeddings.

    For each channel a sentence built from `sensor_template` is embedded with
    the text model, followed by the vision-model embeddings of that channel's
    CWT chunks. The embedded question closes the sequence.

    NOTE(review): `DEVICE`, `AutoTokenizer`, `AutoModelForMaskedLM` and
    `EmptyLater` are not defined or imported in the visible part of this file;
    they must be provided elsewhere for the class to be usable.
    """

    def __init__(self, ts=64, lc=0.1, hc=31, ws=6, overlap=0.5, embed_size=768):
        """Store the preprocessing settings and load the frozen backbones.

        Args:
            ts: Target sampling rate forwarded to `proc_one_signal`.
            lc: Forwarded to `proc_one_signal` as `lc`.
            hc: Forwarded to `proc_one_signal` as `hc`.
            ws: Window-size factor forwarded to `proc_one_signal`.
            overlap: Window scaling factor forwarded to `proc_one_signal`.
            embed_size: Width of one embedding vector in the output.
        """
        super().__init__()
        
        self.ws = ws
        self.overlap = overlap
        self.ts=ts
        self.lc=lc
        self.hc=hc
        self.embed_size = embed_size

        self.cv_model=freeze_model(init_vit()).to(DEVICE)

        # text module
        self.nlp_tk = AutoTokenizer.from_pretrained("medicalai/ClinicalBERT")
        self.nlp_embed = AutoModelForMaskedLM.from_pretrained("medicalai/ClinicalBERT").distilbert
        self.nlp_embed.transformer = EmptyLater()
        self.nlp_embed = freeze_model(self.nlp_embed).to(DEVICE)

        self.sensor_template = "The {sensor} signals within {duration} are <signal> "

    @staticmethod
    def _format_duration(seconds):
        """Render a duration in seconds as e.g. ``"12.0 seconds"`` or ``"1.5 hours"``.

        The value is promoted to the next unit only while it is large enough
        for it, so a unit is never skipped (30 seconds stays "30.0 seconds"
        instead of being divided by 24 and labelled as days).
        """
        value, unit = seconds, "seconds"
        for factor, bigger_unit in ((60, "minutes"), (60, "hours"), (24, "days")):
            if value < factor:
                break
            value /= factor
            unit = bigger_unit
        return "{} {}".format(round(value, 2), unit)

    def _embed_text(self, text):
        """Tokenize `text` and embed it; token ids are moved to `DEVICE` first.

        Assumes `self.nlp_embed` returns a ``(1, T, embed_size)`` tensor --
        TODO confirm, this depends on `EmptyLater`.
        """
        token_ids = torch.tensor(self.nlp_tk.encode(text)).unsqueeze(0)
        return self.nlp_embed(token_ids.to(DEVICE))

    def forward(self, x, ss, names, question): # N, C, L
        """Build the interleaved text/signal embedding sequence.

        Args:
            x: Batch of signals with shape ``(N, C, L)``.
            ss: Source sampling rate of `x`.
            names: One sensor name per channel (``len(names) == C``).
            question: Text appended after the last signal.

        Returns:
            Tensor of shape ``(N, LL, embed_size)`` where ``LL`` is the total
            number of text tokens plus ``C * num_chunk`` signal embeddings.

        Raises:
            ValueError: If `names` does not match the number of channels, or
                if the signals are too short to yield a single chunk.
        """
        N, C, L = x.shape
        if len(names) != C:
            raise ValueError(
                "expected {} sensor names, got {}".format(C, len(names))
            )
        duration = self._format_duration(L / ss)

        # tokenize text: one prompt per channel, then the closing question
        txts = list()
        for position, c in enumerate(names):
            txt_curr = self.sensor_template.format(**{"sensor": c, "duration": duration})
            if position >= 1:
                # Close the previous channel's <signal> span.
                txt_curr = " </signal>. " + txt_curr
            txts.append(self._embed_text(txt_curr))
        txts.append(self._embed_text(" </signal>. " + question))
        total_length = sum(t.shape[1] for t in txts)

        # tokenize signal
        cwt_all = list()
        num_chunk = 0
        for x_ in x:
            for sig in x_:
                # proc_one_signal returns (cwt chunks, series chunks); only
                # the cwt chunks are embedded.
                chunks, _ = proc_one_signal(sig, ss, ts=self.ts, lc=self.lc, hc=self.hc, ws=self.ws, overlap=self.overlap)
                cwt_all += chunks
                num_chunk = len(chunks)
        if num_chunk == 0:
            raise ValueError("signals are too short to produce a single chunk")
        cwt_all = torch.stack(cwt_all).permute(0, 3, 1, 2) # N*C*num_chunk, 3, window, F
        cwt_embeds = get_cv_embedding(cwt_all, self.cv_model).view(N, C, num_chunk, self.embed_size)
        total_length += num_chunk * C

        final_embeds = torch.zeros(N, total_length, self.embed_size).to(DEVICE) # N, LL, E
        curr_idx = 0
        for c_i in range(C):
            # integrate text embedding (broadcast over the batch)
            curr_txt_len = txts[c_i].shape[1]
            final_embeds[:, curr_idx:curr_idx+curr_txt_len, :] = txts[c_i]
            curr_idx += curr_txt_len

            # integrate signal embedding
            final_embeds[:, curr_idx:curr_idx+num_chunk, :] = cwt_embeds[:, c_i, :, :]
            curr_idx += num_chunk

        # integrate the question, whose length is already part of total_length
        question_len = txts[-1].shape[1]
        final_embeds[:, curr_idx:curr_idx+question_len, :] = txts[-1]

        return final_embeds # N, LL, 768
    
if __name__ == "__main__":
    # Smoke test: build the tokenizer, push random data through it and report
    # the output shape and the wall-clock time of the whole run.
    t_begin = time.time()

    tokenizer = SignalTokenizer()

    # Random batch: 2 samples, 3 channels, 64*12 points each.
    x = torch.rand(2, 3, 64*12)
    sensor_names = ["PPG", "EEG", "ECG"]
    question = "The valence level is [MASK]%."

    # Sampling rate 64 is passed as `ss`.
    embeds = tokenizer(x, 64, sensor_names, question)
    print("Out shape:", embeds.shape) # (N, LL, 768)

    elapsed = time.time() - t_begin
    print("Time used: {} s".format(round(elapsed, 2)))
        