import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import pandas as pd
from sklearn.metrics import mean_squared_error
from tslearn.metrics import dtw
from fastdtw import fastdtw




def coord_gen(length):
    """Build a normalised 1-D sampling grid padded with a constant zero axis.

    Args:
        length: Number of coordinates to generate.

    Returns:
        Float tensor of shape ``[length, 1, 2]``. ``[..., 0]`` holds ``length``
        evenly spaced coordinates running from -1 to 1 (both ends included);
        ``[..., 1]`` is always 0.
    """
    # A single coordinate has no spacing; pin it to the start of the range
    # (same convention as torch.linspace) instead of dividing by zero.
    step = 2.0 / (length - 1) if length > 1 else 0.0
    coords = -1.0 + step * torch.arange(length).float()
    zero_axis = torch.zeros(1)
    # "ij" is what the legacy default did; passing it explicitly avoids the
    # deprecation warning emitted by newer torch versions.
    return torch.stack(torch.meshgrid(coords, zero_axis, indexing="ij"), dim=-1)




def lr_linear_impu(lr, target_len, device):
    """Linearly resample sequences along the time axis to ``target_len`` steps.

    The first and last input steps map exactly onto the first and last output
    steps (``align_corners=True``); every feature column is interpolated
    independently.

    Args:
        lr: Tensor of shape ``[B, T, D]`` (batch, time, feature).
        target_len: Number of time steps in the output.
        device: Device on which the interpolation is carried out.

    Returns:
        Tensor of shape ``[B, target_len, D]`` located on ``device``.
    """
    lr = lr.to(device)
    # F.interpolate resamples the trailing axis of a [N, C, L] tensor, so the
    # time axis is moved last and restored afterwards.
    channels_first = lr.permute(0, 2, 1)
    resampled = F.interpolate(
        channels_first,
        size=target_len,
        mode="linear",
        align_corners=True,
    )
    return resampled.permute(0, 2, 1)




def get_target_broadcast(lr, m):
    """Gather time steps of ``lr`` selected by the running sum of ``m``.

    Args:
        lr: Tensor of shape ``[B, T, D]``.
        m: Sequence (or tensor) of integers; it is converted to a long tensor
            on ``lr``'s device.

    Returns:
        ``lr`` indexed along dim 1 by ``cumsum(m)`` clamped to ``[0, T - 1]``;
        for a 1-D ``m`` of length ``L`` the result has shape ``[B, L, D]``.
    """
    _, n_steps, _ = lr.shape
    increments = torch.as_tensor(m, dtype=torch.long, device=lr.device)
    # Running sum turns the per-position increments into time indices; the
    # clamp keeps them inside the valid range of lr's time axis.
    gather_idx = increments.cumsum(dim=0).clamp(min=0, max=n_steps - 1)
    return lr[:, gather_idx, :]




def slice_dataset(data_array: np.ndarray, window_len):
    """Cut a ``[T, D]`` series into consecutive non-overlapping windows.

    Rows that do not fill a whole window are dropped from the *start* of the
    series, so the most recent rows are always kept.

    Args:
        data_array: 2-D array-like of shape ``[T, D]``.
        window_len: Number of rows per window.

    Returns:
        ``float32`` array of shape ``[T // window_len, window_len, D]``. When
        ``T < window_len`` the result is empty with shape
        ``[0, window_len, D]``.
    """
    data_array = np.asarray(data_array)
    assert data_array.ndim == 2     # [T, D]
    n_rows, n_features = data_array.shape
    win_num = n_rows // window_len
    # Slice with an explicit start index: a negative slice of "-0" would keep
    # every row instead of none when no full window fits.
    kept = data_array[n_rows - win_num * window_len:]
    # A C-order reshape yields exactly the consecutive windows.
    return kept.reshape(win_num, window_len, n_features).astype(np.float32)




def _tail_windows(array, window_len):
    """Reshape the trailing whole windows of ``array`` to ``[win_num, window_len, ...]``."""
    n_rows = array.shape[0]
    win_num = n_rows // window_len
    # Explicit start index: "-0" as a slice start would keep every row.
    kept = array[n_rows - win_num * window_len:]
    return kept.reshape(win_num, window_len, *array.shape[1:]).astype(np.float32)


def slice_dataset_dual(hr_array: np.ndarray, lr_array: np.ndarray, target_mask: list, seq_hr, seq_lr):
    """Window a high-resolution series, its mask and a low-resolution series.

    Each series is cut into consecutive non-overlapping windows; rows that do
    not fill a whole window are dropped from the start. The mask is trimmed to
    the same trailing rows as the high-resolution series and windowed with
    ``seq_hr``.

    Args:
        hr_array: 2-D array of shape ``[T_hr, D]``.
        lr_array: 2-D array of shape ``[T_lr, D]``.
        target_mask: Sequence with one entry per high-resolution row.
        seq_hr: Window length for ``hr_array`` and ``target_mask``.
        seq_lr: Window length for ``lr_array``.

    Returns:
        Tuple ``(split_hr, split_lr, split_mask)`` of ``float32`` arrays shaped
        ``[win_hr, seq_hr, D]``, ``[win_lr, seq_lr, D]`` and
        ``[win_hr, seq_hr]``. The HR and LR window counts are computed
        independently and are not checked against each other.

    Raises:
        ValueError: If ``target_mask`` has fewer entries than the
            high-resolution rows that are kept.
    """
    hr_array = np.asarray(hr_array)
    lr_array = np.asarray(lr_array)
    assert hr_array.ndim == lr_array.ndim == 2     # [T, D]

    split_hr = _tail_windows(hr_array, seq_hr)

    win_hr = split_hr.shape[0]
    kept_rows = win_hr * seq_hr
    mask = np.asarray(target_mask)
    # Keep the mask entries that belong to the retained HR rows.
    mask = mask[max(mask.shape[0] - kept_rows, 0):]
    split_mask = mask.reshape(win_hr, seq_hr, *mask.shape[1:]).astype(np.float32)

    split_lr = _tail_windows(lr_array, seq_lr)

    return split_hr, split_lr, split_mask




def normalisation(inputArray:np.array, meanlist= None, stdlist= None):
    """Standardise every column of a 2-D array.

    When neither ``meanlist`` nor ``stdlist`` is given, the statistics are
    computed from ``inputArray`` itself; a column holding a single distinct
    value gets a standard deviation of 1 so it is only mean-centred. Otherwise
    the supplied per-column statistics are applied (both are expected to be
    supplied together).

    Args:
        inputArray: Array indexed as ``[row, column]``.
        meanlist: Optional per-column means.
        stdlist: Optional per-column standard deviations.

    Returns:
        Tuple ``(normalised, meanlist, stdlist)`` where ``normalised`` has the
        columns of ``inputArray`` standardised and the two lists are the
        statistics that were used.
    """
    n_columns = inputArray.shape[1]

    if meanlist is None and stdlist is None :
        # Fit mode: derive the statistics column by column.
        meanlist, stdlist = [], []
        for col in range(n_columns):
            column = inputArray[:, col]
            is_constant = len(np.unique(column)) == 1
            meanlist.append(np.mean(column))
            stdlist.append(1 if is_constant else np.std(column))

    scaled_columns = [
        (inputArray[:, col] - meanlist[col]) / stdlist[col]
        for col in range(n_columns)
    ]
    # Columns were collected as rows, hence the transpose.
    return np.array(scaled_columns).T, meanlist, stdlist




def normalisation_UEA(inputArray: np.ndarray, eps: float = 1e-6):
    """Standardise every sample and channel of a batch over its time axis.

    For each sample ``b`` and channel ``j`` the values ``inputArray[b, :, j]``
    are mean-centred and divided by their standard deviation plus ``eps``.
    ``eps`` sits in the denominator so a constant channel becomes all zeros
    instead of NaN/inf.

    Args:
        inputArray: Array of shape ``[B, T, D]``.
        eps: Small constant added to the standard deviation.

    Returns:
        Array of shape ``[B, D, T]``: the channel axis comes before the time
        axis in the output.
    """
    inputArray = np.asarray(inputArray)
    mean = inputArray.mean(axis=1, keepdims=True)
    std = inputArray.std(axis=1, keepdims=True)
    normalised = (inputArray - mean) / (std + eps)
    # Callers receive channels-first output, as this function always produced.
    return np.ascontiguousarray(normalised.transpose(0, 2, 1))




def get_diff(sr_s, sr_t, lr_data, sr_ratio, mask, refine= False):
    """Combine two predicted components and optionally limit their step size.

    The raw result is ``sr_s + sr_t``. With ``refine=True`` the sequence is
    walked forward in time: whenever the jump from step ``t - 1`` to ``t`` is
    larger than the absolute difference between the two ``lr_data`` points
    surrounding ``t / sr_ratio``, step ``t`` is pulled back so the jump equals
    that difference (keeping its direction). Finally every step ``t`` with
    ``mask[t] == 1`` is set to zero.

    NOTE(review): both loops run over ``lr_data.shape[1]`` steps, not over the
    length of ``sr_s``, so only that leading part of the output is refined and
    masked - confirm this is intended.

    Args:
        sr_s: Tensor of shape ``[B, H, D]``.
        sr_t: Tensor that can be added to ``sr_s``.
        lr_data: Reference series indexed as ``[b, step, d]``; assumed to share
            ``B`` and ``D`` with ``sr_s``.
        sr_ratio: Divisor mapping an output step to a position in ``lr_data``.
        mask: Indexable of flags; entries equal to 1 mark steps to zero out.
        refine: When False, return the plain sum.

    Returns:
        Tensor with the shape of ``sr_s + sr_t``. The inputs are not modified.
    """
    B, _, D = sr_s.shape
    L = lr_data.shape[1]
    diff_raw = sr_s + sr_t
    if not refine:
        return diff_raw

    lr_ref = torch.as_tensor(lr_data, dtype=diff_raw.dtype, device=diff_raw.device)
    # Steps depend on their (already refined) predecessor, so time stays a
    # sequential loop; batch and feature entries are independent and are
    # handled together.
    for t in range(1, L - 1):
        left = int(np.floor(t / sr_ratio))
        right = int(np.ceil(t / sr_ratio))
        if left == right:
            # t falls exactly on a reference point: nothing to bound against.
            continue
        max_step = (lr_ref[:, right, :] - lr_ref[:, left, :]).abs()
        prev = diff_raw[:, t - 1, :]
        cur = diff_raw[:, t, :]
        signed_step = torch.where(prev > cur, -max_step, max_step)
        too_steep = (prev - cur).abs() > max_step
        diff_raw[:, t, :] = torch.where(too_steep, prev + signed_step, cur)

    for t in range(L):
        if mask[t] == 1:
            # Plain assignment also works for CUDA and grad-tracking tensors,
            # unlike a round trip through numpy.
            diff_raw[:, t, :] = 0
    return diff_raw




class TIME_ENCODING(nn.Module):
    """Sinusoidal embedding of scalar time values followed by a two-layer MLP.

    Args:
        input_dim: Width of the sinusoidal feature vector (even or odd).
        hidden_dim: Width of the hidden linear layer.
        output_dim: Width of the returned embedding.
    """

    def __init__(self, input_dim= 32, hidden_dim= 32, output_dim= 32):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.embed1 = nn.Linear(input_dim, hidden_dim)
        self.embed2 = nn.Linear(hidden_dim, output_dim)
        self.activation = nn.ReLU()

    def forward(self, t):
        """Embed a 1-D tensor of time values.

        Args:
            t: Tensor of shape ``[N]``.

        Returns:
            Tensor of shape ``[N, output_dim]``.
        """
        t = t.unsqueeze(1)
        # Geometric frequency ladder: one frequency per sine/cosine pair.
        div_term = torch.exp(
            torch.arange(0.0, self.input_dim, 2.0) * -(math.log(10000.0) / self.input_dim)
        ).to(t.device)
        angles = t * div_term
        pe = torch.zeros(t.size(0), self.input_dim, device=t.device)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd input_dim there is one cosine column fewer than sine
        # columns, so drop the unmatched last frequency.
        pe[:, 1::2] = torch.cos(angles[:, : self.input_dim // 2])
        t_emb = self.embed1(pe)
        t_emb = self.activation(t_emb)
        t_emb = self.embed2(t_emb)
        return t_emb




def vpred_loss(pred, target, vol_reg= False):
    """Mean-squared error, optionally with a variance-matching penalty.

    Args:
        pred: Predicted tensor.
        target: Ground-truth tensor.
        vol_reg: When True, add the mean squared difference between the
            variances of ``pred`` and ``target`` taken along dim 1.

    Returns:
        Scalar loss tensor.
    """
    loss = F.mse_loss(pred, target)
    if not vol_reg :
        return loss
    # Penalise predictions whose spread along dim 1 differs from the target's.
    variance_gap = pred.var(dim=1) - target.var(dim=1)
    return loss + variance_gap.pow(2).mean()




def match_variance(A, B, eps=1e-8, sr_ratio= 1):
    """Shift ``B`` along dim 1 so that its mean equals the mean of ``A``.

    NOTE(review): despite the name, no variance rescaling is applied. The
    previous body computed a variance-matched tensor and immediately
    overwrote it with the mean-shifted one; that dead computation is removed
    here and the returned values are unchanged. ``eps`` and ``sr_ratio`` are
    kept for call compatibility but are currently unused - confirm whether
    variance matching should be re-enabled.

    Args:
        A: Reference tensor; its mean over dim 1 is the target mean.
        B: Tensor to adjust; must broadcast against ``A``'s dim-1 mean.
        eps: Unused (see note).
        sr_ratio: Unused (see note).

    Returns:
        ``B - mean(B, dim=1) + mean(A, dim=1)``.
    """
    mu_A = A.mean(dim=1, keepdim=True)
    mu_B = B.mean(dim=1, keepdim=True)
    return (B - mu_B) + mu_A





def cal_loss_1(preds:np.array, labels:np.array):
    """Compute MSE, MAE, MAPE and DTW distance between two 1-D series.

    Args:
        preds: 1-D array of predictions.
        labels: 1-D array of ground-truth values, same length as ``preds``.
            Zero entries make the MAPE term divide by zero.

    Returns:
        Tuple ``(mse, mae, mape, dtw_dist)``.
    """
    assert np.ndim(preds) == 1 and np.ndim(labels) == 1
    assert labels.shape[0] == preds.shape[0]

    mse = mean_squared_error(labels, preds)
    residual = preds - labels
    mae = np.abs(residual).mean()
    mape = np.abs(residual / labels).mean()
    print(preds.shape, labels.shape)
    return mse, mae, mape, dtw(preds, labels)




def cal_loss(preds:np.array, labels:np.array):
    """Compute the MSE and the DTW distance between predictions and labels.

    Args:
        preds: Array of predictions.
        labels: Array of ground-truth values with the same first dimension.

    Returns:
        Tuple ``(mse, dtw_dist)``.
    """
    n_preds, n_labels = preds.shape[0], labels.shape[0]
    assert n_preds == n_labels
    return mean_squared_error(labels, preds), dtw(preds, labels)



def evaluate_metrics(pred, label):
    """Compute the overall MSE and the mean path-normalised DTW distance.

    Every (sample, channel) series pair is aligned with ``fastdtw``; its
    distance is divided by the length of the returned warping path, and these
    values are averaged.

    Args:
        pred: ndarray, shape [Batch, Length, Channel]
        label: ndarray, shape [Batch, Length, Channel]
    Returns:
        dict with keys ``'mse'`` and ``'dtw'``.
    """
    squared_error = (pred - label) ** 2
    mse = np.mean(squared_error)

    n_batch, _, n_channel = pred.shape
    normalised_dists = []
    # ndindex walks (sample, channel) pairs sample-major.
    for b, c in np.ndindex(n_batch, n_channel):
        dist, path = fastdtw(pred[b, :, c], label[b, :, c])
        normalised_dists.append(dist / len(path))
    return {'mse': mse, 'dtw': np.mean(normalised_dists)}


