import torch
import torch.nn as nn
import numpy as np
import time
from tqdm import tqdm

from scipy import signal
from scipy.ndimage import gaussian_filter
# import pywt

from torchvision import transforms
from pytorch_pretrained_vit import ViT

# from transformers import AutoTokenizer, AutoModelForMaskedLM

from src.models.layers import *

# =========== Helper Layers ========================================================================
# ========== BASIC PREPROCESSING ===========================
def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
    """Low-pass ``data`` with a digital Butterworth filter cut off at ``highcut``.

    Despite the name, only the upper cutoff is applied: ``lowcut`` is accepted
    for signature compatibility but is not used by this implementation.

    Args:
        data: samples to filter; handed unchanged to ``scipy.signal.lfilter``.
        lowcut: unused.
        highcut: cutoff frequency of the low-pass filter, in the same units as ``fs``.
        fs: sampling frequency of ``data``.
        order: order of the Butterworth filter.

    Returns:
        The filtered signal as produced by ``scipy.signal.lfilter``.
    """
    numerator, denominator = signal.butter(
        order, highcut, btype='low', analog=False, fs=fs
    )
    return signal.lfilter(numerator, denominator, data)

def impute(ts, sr=1000, tr=500):
    """Fill NaN gaps by linear interpolation, then resample the series.

    Sample ``ts[i]`` is placed at position ``i * tr`` on a fine grid of
    ``len(ts) * tr`` points. NaN samples are discarded and the whole grid is
    filled from the remaining ones with ``np.interp``. Every grid point whose
    index is a multiple of ``sr`` (starting at 0) is then returned.

    Args:
        ts: 1-D series that may contain NaN.
        sr: subsampling step applied to the fine grid.
        tr: spacing, in fine-grid points, between consecutive input samples.

    Returns:
        1-D float array holding the subsampled, gap-free series.
    """
    samples = np.asarray(ts)
    is_known = ~np.isnan(samples)

    # positions (on the fine grid) and values of the samples that are not NaN
    known_pos = np.flatnonzero(is_known) * tr
    known_val = samples[is_known]

    fine_grid = np.arange(len(ts) * tr)
    dense = np.interp(fine_grid, known_pos, known_val)

    # keep only the grid points that fall on a multiple of `sr`
    return dense[fine_grid % sr == 0]

def rolling_window(a, window):
    """Return every length-``window`` slice along the last axis of ``a``.

    The result has shape ``a.shape[:-1] + (a.shape[-1] - window + 1, window)``
    and is built with ``as_strided``, i.e. it reuses the strides of ``a``
    instead of copying the data. No validation of ``window`` is done here.

    Args:
        a: NumPy array to slide over.
        window: number of consecutive elements in each slice.

    Returns:
        Strided array whose last axis indexes positions inside one window.
    """
    n_windows = a.shape[-1] - window + 1
    view_shape = (*a.shape[:-1], n_windows, window)
    # stepping to the next window and to the next element use the same stride
    view_strides = (*a.strides, a.strides[-1])
    return np.lib.stride_tricks.as_strided(
        a, shape=view_shape, strides=view_strides
    )

def basic_preproc(
    ts, 
    sr=1000, # source sampling rate
    tr=125, # target sampling rate
    l_pass=0.1, # low frequency pass
    h_pass=5, # high frequency pass
    outlier_p=0.95, # threshold for outlier
    smooth_c=0.02,
    avg=0.5
):  
    """Clean one 1-D signal: drop outliers, resample, detrend, smooth, rescale.

    Steps, in order:
      1. Samples followed by an unusually large jump are replaced by NaN.
      2. ``impute`` interpolates over the NaNs and resamples using ``sr`` / ``tr``.
      3. ``scipy.signal.detrend`` removes the linear trend.
      4. A Gaussian filter with ``sigma = tr * smooth_c`` smooths the result.
      5. The signal is divided by the mean of its absolute values.

    Band-pass filtering is currently disabled, so ``l_pass`` and ``h_pass``
    are accepted but unused; ``avg`` is unused as well.

    Args:
        ts: 1-D NumPy array of samples (integer or float dtype).
        sr: source sampling rate, forwarded to ``impute``.
        tr: target sampling rate, forwarded to ``impute`` and used to scale
            the smoothing sigma.
        l_pass: unused.
        h_pass: unused.
        outlier_p: quantile of the absolute first differences above which a
            jump counts as an outlier.
        smooth_c: smoothing factor; the Gaussian sigma is ``tr * smooth_c``.
        avg: unused.

    Returns:
        1-D float array with the cleaned signal. No guard is applied to the
        final division, so an all-zero signal yields non-finite values.
    """
    # remove outlier: sample i is flagged when |ts[i+1] - ts[i]| exceeds the quantile
    diffs = np.abs(ts[1:] - ts[:-1])
    max_diff = np.quantile(diffs, outlier_p)
    # float copy, so NaN can be stored even when `ts` has an integer dtype
    ts_clear = np.array(ts, dtype=float)
    # the last sample has no successor and is therefore never flagged
    ts_outliers = np.append(diffs > max_diff, False)
    ts_clear[ts_outliers] = np.nan

    # interpolate the flagged samples and down sample to target sampling rate
    ts_clear = impute(ts_clear, sr=sr, tr=tr)

    # detrend, remove linear shift
    ts_clear = signal.detrend(ts_clear)

    # smooth
    ts_clear = gaussian_filter(ts_clear, sigma=tr*smooth_c)

    # normalize by the mean absolute amplitude
    ts_clear /= np.mean(np.abs(ts_clear))
    
    return ts_clear


# ====== CV Model =============
# Maps a pretrained ViT name (same naming as the ``m_name`` argument of
# ``init_vit``) to an integer size -- presumably the embedding width of that
# variant; TODO confirm against pytorch_pretrained_vit.
VIT_EMB_SIZE = {
    "B_16_imagenet1k": 768,
    "L_32_imagenet1k": 1024
}

def init_vit(m_name='B_16_imagenet1k'):
    """Build a pretrained ViT in eval mode with its ``norm`` and ``fc`` heads rewired.

    The patch-embedding and positional-embedding layers are left untouched.
    ``norm`` is followed by a ``CheckShape`` wrapping ``x.unsqueeze(1)`` and
    ``fc`` is replaced by a ``CheckShape`` wrapping a mean over dim 1.
    NOTE(review): ``CheckShape`` is not defined in this file (presumably it
    comes from ``src.models.layers`` and applies the given callable) -- confirm.

    Args:
        m_name: name of the pretrained ViT variant passed to ``ViT``.

    Returns:
        The modified ViT model.
    """
    vit = ViT(m_name, pretrained=True)  # construct and load pretrained weight
    vit.eval()  # eval mode

    # not [CLS] only: original norm followed by an unsqueeze on dim 1
    original_norm = vit.norm
    add_axis = CheckShape(None, lambda x: x.unsqueeze(1))
    vit.norm = nn.Sequential(original_norm, add_axis)

    # overwrite last classification layer with an average over dim 1
    average_dim1 = CheckShape(None, lambda x: torch.mean(x, dim=1))
    vit.fc = nn.Sequential(average_dim1)

    return vit

def freeze_model(model):
    """Turn off gradient tracking for every parameter of ``model``.

    Args:
        model: object exposing ``parameters()`` (e.g. an ``nn.Module``).

    Returns:
        The same ``model`` instance, modified in place.
    """
    for weight in model.parameters():
        weight.requires_grad_(False)
    return model

def unfreeze_model(model):
    """Turn on gradient tracking for every parameter of ``model``.

    Args:
        model: object exposing ``parameters()`` (e.g. an ``nn.Module``).

    Returns:
        The same ``model`` instance, modified in place.
    """
    for weight in model.parameters():
        weight.requires_grad_(True)
    return model

# ====== Signal Processing and Wavelet Transform =============
def _ricker(points, a):
    """Sample a Ricker ("Mexican hat") wavelet of width ``a`` on ``points`` samples."""
    amplitude = 2 / (np.sqrt(3 * a) * (np.pi ** 0.25))
    wsq = a ** 2
    vec = np.arange(0, points) - (points - 1.0) / 2
    xsq = vec ** 2
    mod = 1 - xsq / wsq
    gauss = np.exp(-xsq / (2 * wsq))
    return amplitude * mod * gauss


def _ricker_cwt(data, widths):
    """Continuous wavelet transform of ``data`` with Ricker wavelets, one row per width."""
    output = np.empty((len(widths), len(data)), dtype=np.float64)
    for row, width in enumerate(widths):
        # the wavelet support is capped at the signal length
        n_points = np.min([10 * width, len(data)])
        kernel = _ricker(n_points, width)[::-1]
        output[row] = signal.convolve(data, kernel, mode='same')
    return output


def wt(ts, lf=0.1, hf=65, wl='gaus1', method='fft'):
    """Continuous wavelet transform of a 1-D signal with Ricker wavelets.

    The widths are ``np.arange(lf, hf)``, so the defaults give 65 rows.
    ``scipy.signal.cwt`` / ``scipy.signal.ricker`` are used when the installed
    SciPy still ships them (they were removed in SciPy 1.15); otherwise an
    equivalent local implementation is used so the function keeps working.

    Args:
        ts: 1-D signal of length L.
        lf: first wavelet width.
        hf: exclusive upper bound of the wavelet widths (step 1).
        wl: unused (left over from the commented-out pywt variant).
        method: unused (left over from the commented-out pywt variant).

    Returns:
        Array of shape (F, L) with ``F = len(np.arange(lf, hf))``.
    """
    # in: L
    # out: FxL
    widths = np.arange(lf, hf)
    scipy_cwt = getattr(signal, 'cwt', None)
    scipy_ricker = getattr(signal, 'ricker', None)
    if scipy_cwt is not None and scipy_ricker is not None:
        return scipy_cwt(ts, scipy_ricker, widths)
    return _ricker_cwt(ts, widths)

def derive_2d(x, sr):
    """First-order difference of ``x`` along axis 1.

    Args:
        x: 2-D array; the difference is taken between neighbouring columns.
        sr: unused (the difference is not scaled by any sampling rate).

    Returns:
        Array with one column fewer than ``x``.
    """
    later = x[:, 1:]
    earlier = x[:, :-1]
    return later - earlier

def derive(x, sr):
    """First-order difference of ``x`` along its first axis.

    Args:
        x: array of samples.
        sr: unused (the difference is not scaled by any sampling rate).

    Returns:
        Array with one element fewer than ``x`` along the first axis.
    """
    later = x[1:]
    earlier = x[:-1]
    return later - earlier

def norm(x):
    """Min-max scale ``x`` so its minimum maps to 0 and its maximum to 1.

    A constant input has zero range; it is mapped to all zeros instead of
    producing NaN through a 0/0 division.

    Args:
        x: array of values (must be non-empty, as required by ``np.max``).

    Returns:
        Array of the same shape as ``x`` with values in [0, 1].
    """
    max_, min_ = np.max(x), np.min(x)
    span = max_ - min_
    shifted = x - min_
    if span == 0:
        # every entry of `shifted` is already 0; multiplying keeps shape and yields floats
        return shifted * 0.0
    return shifted / span

def proc_one_signal(x, ss, ts=64, lc=0.1, hc=128, ws=6, overlap=0.5):
    """Preprocess one signal and cut it into wavelet-image / time-series chunks.

    The signal is cleaned with ``basic_preproc``; its last sample is dropped
    and the wavelet transforms of the remaining signal and of its first and
    second differences are stacked into an image of shape (L, F, 3). The
    image and the cleaned series (aligned by dropping its first 3 samples)
    are then cut into chunks of ``round(ws * ts * overlap)`` time steps.
    Chunks are taken from the end of the signal backwards without
    overlapping, so the first chunk in each list is the most recent one; a
    leading remainder shorter than one chunk is discarded (no padding yet).

    Args:
        x: 1-D raw signal.
        ss: source sampling rate, forwarded to ``basic_preproc``.
        ts: target sampling rate, forwarded to ``basic_preproc``.
        lc: forwarded to ``basic_preproc`` as ``l_pass``.
        hc: forwarded to ``basic_preproc`` as ``h_pass``.
        ws: window length factor; multiplied by ``ts`` and ``overlap``.
        overlap: factor applied to ``ws * ts`` to obtain the chunk length.

    Returns:
        Tuple ``(cwts, tss)``: a list of (chunk, F, 3) tensors and a list of
        the matching 1-D series chunks.

    Raises:
        ValueError: if the chunk length rounds to zero or less (the slicing
            loop below would otherwise never terminate).
    """
    # basic process
    ts_clear = basic_preproc(
        x, 
        sr=ss, 
        tr=ts, 
        l_pass=lc,
        h_pass=hc
    )

    # wavelet transform and 1st 2nd derivatives
    # drop the last sample (the "ar point" -- presumably the autoregressive target; confirm)
    t0 = ts_clear[:-1]
    t1 = derive(t0, ts) # 1st derivative
    t2 = derive(t1, ts) # 2nd derivative
    cwt0 = wt(t0, hf=65)
    cwt1 = wt(t1, hf=65)
    cwt2 = wt(t2, hf=65)

    # trim the longer transforms at the front so all three share the same length
    stacked = np.stack([cwt0[:, 2:], cwt1[:, 1:], cwt2]) # (3, F, L)
    cwt_im = torch.tensor(stacked, dtype=torch.get_default_dtype()).permute(2, 1, 0) # (L, F, 3)

    # align the series with the image: 1 sample dropped at the end, 2 lost to the differences
    ts_clear = ts_clear[3:]

    # sliding window
    ws = round((ws * ts)*overlap)
    if ws <= 0:
        raise ValueError("window size must be positive, got {}".format(ws))
    cwts = list()
    tss = list()
    c_i = cwt_im.shape[0]
    t_i = len(ts_clear)
    while c_i >= ws:
        cwts.append(cwt_im[c_i-ws:c_i, :, :])
        tss.append(ts_clear[t_i-ws:t_i])
        c_i -= ws
        t_i -= ws
    # TODO pad
    return cwts, tss

def preproc_all(all_tss, ss, ts=65, lc=0.1, hc=128, outlier_p=0.95, pad_to=-1, left_ar=False, preproc=True): # N, C, L
    """Preprocess a batch of signals and build their wavelet images.

    Every channel of every sample is optionally cleaned with
    ``basic_preproc``. The wavelet transforms of the series and of its first
    and second differences are then stacked into a 3-channel image.

    Args:
        all_tss: array of shape (N, C, L) with the raw signals.
        ss: source sampling rate, forwarded to ``basic_preproc``.
        ts: target sampling rate, forwarded to ``basic_preproc``.
        lc: forwarded to ``basic_preproc`` as ``l_pass``.
        hc: forwarded to ``basic_preproc`` as ``h_pass``.
        outlier_p: forwarded to ``basic_preproc``.
        pad_to: when positive and larger than the image length, both outputs
            are left-padded with zeros by ``pad_to - image_length`` steps.
        left_ar: when True the last sample of each series is excluded from
            the wavelet image.
        preproc: when False ``all_tss`` is used as is, without cleaning.

    Returns:
        Tuple ``(final_tensor, new_tss)``. ``final_tensor`` has shape
        (N, C, L', 65, 3); ``new_tss`` is the series without its first two
        samples. NOTE(review): with ``left_ar=True`` the image has one time
        step fewer than ``new_tss`` -- confirm this is intended.
    """
    # basic preprocess
    N, C, L = all_tss.shape
    if preproc:
        new_tss = None
        for i in range(N):
            for j in range(C):
                clean_tss = basic_preproc(
                    all_tss[i, j], 
                    sr=ss, 
                    tr=ts, 
                    l_pass=lc,
                    h_pass=hc,
                    outlier_p=outlier_p
                )
                # the output length is only known after the first signal is processed
                if new_tss is None:
                    new_tss = np.zeros((N, C, len(clean_tss)))
                new_tss[i, j] = clean_tss
    else:
        new_tss = all_tss
    
    N, C, L = new_tss.shape

    # get cwt
    new_L = L-3 if left_ar else L-2
    # 65 matches the number of widths produced by wt(hf=65) with its default lf
    final_tensor = torch.zeros(N, C, new_L, 65, 3)
    for i in range(N):
        for j in range(C):
            t0 = new_tss[i, j, :-1] if left_ar else new_tss[i, j, :]
            t1 = derive(t0, ts)
            t2 = derive(t1, ts)
            cwt0 = wt(t0, hf=65)
            cwt1 = wt(t1, hf=65)
            cwt2 = wt(t2, hf=65)

            # trim the longer transforms at the front so all three share the same length
            stacked = np.stack([cwt0[:, 2:], cwt1[:, 1:], cwt2]) # (3, F, L)
            # the assignment casts to the dtype of final_tensor
            final_tensor[i, j] = torch.as_tensor(stacked).permute(2, 1, 0) # (L, F, 3)
    
    new_tss = torch.tensor(new_tss[:, :, 2:])
    N, C, L, F, ch = final_tensor.shape
    if 0 < pad_to and L < pad_to:
        # pads share the dtype of the tensor they are concatenated with
        tensor_pad = torch.zeros(N, C, pad_to-L, F, ch, dtype=final_tensor.dtype)
        tss_pad = torch.zeros(N, C, pad_to-L, dtype=new_tss.dtype)

        final_tensor = torch.cat((tensor_pad, final_tensor), dim=2)
        new_tss = torch.cat((tss_pad, new_tss), dim=2)
            
    return final_tensor, new_tss

# def get_cv_embedding(cwt_ims, cv_model):
#     # Load image
#     imgs = transforms.Compose([
#         transforms.Resize((384, 384)), 
#         # transforms.ToTensor(),
#         transforms.Normalize(0.5, 0.5),
#     ])(cwt_ims)
#     # print(imgs.shape) # torch.Size([N, 3, 384, 384])

#     # forward
#     outputs = cv_model(imgs.to(DEVICE)) # (N, 768)
#     return outputs

# class SignalTokenizer(nn.Module):
#     def __init__(self, ts=64, lc=0.1, hc=31, ws=6, overlap=0.5, embed_size=768):
#         super().__init__()
        
#         self.ws = ws
#         self.overlap = overlap
#         self.ts=ts
#         self.lc=lc
#         self.hc=hc
#         self.embed_size = embed_size

#         self.cv_model=freeze_model(init_vit()).to(DEVICE)

#         # text module
#         self.nlp_tk = AutoTokenizer.from_pretrained("medicalai/ClinicalBERT")
#         self.nlp_embed = AutoModelForMaskedLM.from_pretrained("medicalai/ClinicalBERT").distilbert
#         self.nlp_embed.transformer = EmptyLater()
#         self.nlp_embed = freeze_model(self.nlp_embed).to(DEVICE)

#         self.sensor_template = "The {sensor} signals within {duration} are <signal> "

#     def forward(self, x, ss, names, question): # N, C, L
#         N, C, L = x.shape
#         duration = L / ss
#         d_unit = "seconds"
#         if duration >= 60: # if larger than 1 minute
#             duration /= 60
#             d_unit = "minutes"
#         if duration >= 60: # if larger than 1 hour
#             duration /= 60
#             d_unit = "hours"
#         if duration >= 24: # if larger than 1 day
#             duration /= 24
#             d_unit = "days"
#         duration = "{} {}".format(round(duration, 2), d_unit)

#         total_length = 0

#         # tokenize text
#         txts = list()
#         for c in names:
#             txt_curr = self.sensor_template.format(**{"sensor": c, "duration": duration})
#             if len(txts) >= 1:
#                 txt_curr = " </signal>. " + txt_curr
#             txts.append(self.nlp_embed(torch.tensor(self.nlp_tk.encode(txt_curr)).unsqueeze(0)))
#             total_length += txts[-1].shape[1] 
#         txts.append(self.nlp_embed(torch.tensor(self.nlp_tk.encode(" </signal>. " + question)).unsqueeze(0)))
#         total_length += txts[-1].shape[1] 

#         # tokenize signal
#         cwt_all = list()
#         num_chunk = 0
#         for x_ in x:
#             for sig in x_:
#                 proc_sig = proc_one_signal(sig, ss, ts=self.ts, lc=self.lc, hc=self.hc, ws=self.ws, overlap=self.overlap)
#                 cwt_all += proc_sig
#                 num_chunk = len(proc_sig)
#         cwt_all = torch.stack(cwt_all).permute(0, 3, 1, 2) # N*C*num_chunk, 768
#         cwt_embeds = get_cv_embedding(cwt_all, self.cv_model).view(N, C, num_chunk, self.embed_size)
#         total_length += (num_chunk*(len(txts)-1))

#         final_embeds = torch.zeros(N, total_length, self.embed_size).to(DEVICE) # N, LL, E
#         curr_idx = 0
#         for c_i in range(C):
#             # integrate text embedding
#             curr_txt_len = txts[c_i].shape[1]
#             final_embeds[:, curr_idx:curr_idx+curr_txt_len, :] = txts[c_i]
#             curr_idx += curr_txt_len

#             # integrate signal embedding
#             final_embeds[:, curr_idx:curr_idx+num_chunk, :] = cwt_embeds[:, c_i, :, :]
#             curr_idx += num_chunk

#         return final_embeds # N, LL, 768
    
# if __name__ == "__main__":
#     start = time.time()

#     # initialize
#     tokenizer = SignalTokenizer()

#     # test random data
#     x = torch.rand(2, 3, 64*12)
#     sensor_names = ["PPG", "EEG", "ECG"]
#     question = "The valence level is [MASK]%."

#     # tokenize
#     cwt_embeds = tokenizer(x, 64, sensor_names, question)
#     print("Out shape:", cwt_embeds.shape) # (N, C, num_chunk, 768)

#     # check time consumed
#     end = time.time()
#     print("Time used: {} s".format(round(end-start, 2)))
        