import os
import random

import torch
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import pandas as pd
import h5py
# Prefer the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')



#################################### Function definitions ##########################################
def set_random_seed(seed):
    """Seed the global random number generators so runs are reproducible.

    Seeds Python's ``random`` module, NumPy's global RNG, torch's CPU generator
    and the generators of *all* visible CUDA devices.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # ``torch.cuda.manual_seed`` only seeds the *current* GPU; ``manual_seed_all``
    # covers every device and is silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)



def make_data_dict(data_path, fold=0, args=None):
    """Load the dataset entry for one fold from a pickled ``.npy`` dict.

    The file must hold a dict with keys ``"data"`` (indexable by ``fold``),
    ``"CONTI_2_DISCT_dicts"``, ``"DISCT_2_CONTI_dicts"`` and ``"ndims"``.
    The three dataset-wide entries are copied onto the selected fold dict,
    which is then returned.

    Args:
        data_path: path to the ``.npy`` file (loaded with ``allow_pickle=True``,
            so only load trusted files).
        fold: index into ``full_data["data"]``.
        args: accepted for call-site compatibility; not used.

    Returns:
        The fold dict, augmented with the shared entries.
    """
    full_data = np.load(data_path, allow_pickle=True).item()

    # TODO: consider a dedicated data-loader class instead of a plain dict.
    fold_dict = full_data["data"][fold]
    for shared_key in ("CONTI_2_DISCT_dicts", "DISCT_2_CONTI_dicts", "ndims"):
        fold_dict[shared_key] = full_data[shared_key]
    return fold_dict





def Write_excel(list0,list1, list2, output_path):
    """Write per-epoch metrics to an Excel workbook.

    Args:
        list0: epoch numbers, written to column "Epoch".
        list1: RMSE values, written to column "RMSE".
        list2: MAE values, written to column "MAE".
        output_path: destination file passed to ``DataFrame.to_excel``.

    The DataFrame index is written too, as the first column.
    """
    # Build the table column by column; dict order fixes the column order.
    metrics_table = pd.DataFrame({"Epoch": list0, 'RMSE': list1, 'MAE': list2})
    metrics_table.to_excel(output_path, index=True)
def get_ind_time(data_dict, key1, key2):
    """Split two index arrays from ``data_dict`` into leading indices and times.

    For each of ``data_dict[key1]`` and ``data_dict[key2]``, the first two
    columns are taken as indices and the last column as the time value.

    Returns:
        ``(ind1, ind2, t1, t2)`` — indices of key1, indices of key2,
        times of key1, times of key2.
    """
    def _split(rows):
        # First two columns -> indices, last column -> time.
        return rows[:, :2], rows[:, -1]

    first_ind, first_t = _split(data_dict[key1])
    second_ind, second_t = _split(data_dict[key2])
    return first_ind, second_ind, first_t, second_t



def normalize_data(tr_y, te_y):
    """Standardize train and test values using statistics of the training set.

    Both inputs are shifted by the training mean and divided by the training
    standard deviation, so the test values end up on the training scale. The
    statistics come from the inputs' own ``.mean()`` / ``.std()`` methods (NumPy
    arrays use the population std; torch tensors use the unbiased std).

    Args:
        tr_y: training values (NumPy array or torch tensor).
        te_y: test values of a compatible type.

    Returns:
        ``(tr, te, data_std, data_mean)``: the normalized arrays, then the scale
        and offset that undo the transform (``x * data_std + data_mean``). If the
        training values are constant (std == 0), ``data_std`` is 1.0, so the
        outputs stay finite instead of becoming NaN/inf.
    """
    data_mean = tr_y.mean()
    data_std = tr_y.std()
    if data_std == 0:
        # Constant training data: dividing by zero would yield NaN/inf everywhere.
        data_std = 1.0

    tr = (tr_y - data_mean) / data_std
    te = (te_y - data_mean) / data_std

    return tr, te, data_std, data_mean



def normalize_data2(d):
    """Linearly rescale ``d`` so its minimum maps to -1 and its maximum to 1.

    Args:
        d: NumPy array or torch tensor.

    Returns:
        ``2 * (d - min) / (max - min) - 1``. If ``d`` is constant (max == min),
        every value maps to 0.0, the centre of [-1, 1], instead of producing NaN
        from 0/0.
    """
    min_val = d.min()
    max_val = d.max()
    value_range = max_val - min_val

    if value_range == 0:
        # Multiplying by 0.0 gives float zeros of the same shape/type as ``d``.
        return (d - min_val) * 0.0
    return 2 * (d - min_val) / value_range - 1

def get_sample_data(tr_time_ind, stride=4, num_samples=67):
    """Return the positions in ``tr_time_ind`` whose value lies on a regular grid.

    The grid is ``{0, stride, 2 * stride, ..., (num_samples - 1) * stride}``.
    With the defaults this is every 4th time index from 0 to 264.

    Args:
        tr_time_ind: NumPy array of time indices.
        stride: spacing between sampled time indices.
        num_samples: number of grid points.

    Returns:
        Integer array of positions along the first axis of ``tr_time_ind``
        where the value is on the grid (ascending order).
    """
    sampled_times = np.arange(num_samples) * stride
    # Boolean membership mask over tr_time_ind, then the positions where it holds.
    on_grid = np.isin(tr_time_ind, sampled_times)
    return np.nonzero(on_grid)[0]



def create_dict(list):
    """Map each position ``i`` to ``list[i]``, e.g. ``['a', 'b'] -> {0: 'a', 1: 'b'}``."""
    # ``list`` shadows the builtin; the parameter name is kept for keyword callers.
    index_to_item = {}
    for position in range(len(list)):
        index_to_item[position] = list[position]
    return index_to_item


def create_dict2(list):
    """Map each element to its position, e.g. ``['a', 'b'] -> {'a': 0, 'b': 1}``.

    If an element occurs more than once, the last position wins.
    """
    # ``list`` shadows the builtin; the parameter name is kept for keyword callers.
    item_to_index = {}
    for position in range(len(list)):
        item_to_index[list[position]] = position
    return item_to_index




def load_data_ssf(data_path, flag="train"):
    """Load the train or test split from a pickled ``.npy`` dataset dict.

    The file holds a dict with ``"ndims"`` and ``"data"``. ``data`` contains
    ``<p>_ind_conti``, ``<p>_ind`` and ``<p>_y`` for ``p`` in ``{"tr", "te"}``,
    plus ``u_ind_uni`` / ``v_ind_uni``. Each ``<p>_ind`` row is split into
    column 0 (time index), columns 1:3 (spatial indices) and the last column
    (depth index), following the original "time, lat, lon, depth" layout note.
    The file is loaded with ``allow_pickle=True``, so only load trusted files.

    Args:
        data_path: path to the ``.npy`` file.
        flag: ``"train"`` or ``"test"``.

    Returns:
        For ``"train"``: ``(tr_ind_conti, tr_ind, tr_dep_ind, tr_time_ind, tr_y,
        u_ind_uni, v_ind_uni, dims)`` where ``tr_ind`` is ``ind[:, 1:3]``.
        For ``"test"``: ``(te_ind_conti, te_ind, te_dep_ind, te_time_ind, te_y)``.

    Raises:
        ValueError: if ``flag`` is not ``"train"`` or ``"test"`` (previously this
            silently returned ``None``).
    """
    if flag not in ("train", "test"):
        raise ValueError(f"flag must be 'train' or 'test', got {flag!r}")

    full_data = np.load(data_path, allow_pickle=True).item()  # time, lat, lon, depth
    d = full_data["data"]
    prefix = "tr" if flag == "train" else "te"

    ind_conti = d[f"{prefix}_ind_conti"]
    ind = d[f"{prefix}_ind"]
    time_ind = ind[:, 0]
    dep_ind = ind[:, -1]
    spatial_ind = ind[:, 1:3]
    y = d[f"{prefix}_y"]

    if flag == "test":
        return ind_conti, spatial_ind, dep_ind, time_ind, y
    return (ind_conti, spatial_ind, dep_ind, time_ind, y,
            d["u_ind_uni"], d["v_ind_uni"], full_data["ndims"])



def load_data_se(data_path):
    """Load a dense dataset and its metadata from a pickled ``.npy`` dict.

    Returns:
        ``(ind_uni, data_extract, mask_tr, data_mean, data_std)`` where
        ``ind_uni`` is the tuple ``(u_ind_uni, v_ind_uni, w_ind_uni, t_ind_uni)``
        and the remaining items are ``d["data"]["data"]``, ``["mask_tr"]``,
        ``["data_mean"]`` and ``["data_std"]``.
    """
    payload = np.load(data_path, allow_pickle=True).item()["data"]  # time, lat, lon, depth
    ind_uni = tuple(payload[name] for name in ("u_ind_uni", "v_ind_uni", "w_ind_uni", "t_ind_uni"))
    return ind_uni, payload["data"], payload["mask_tr"], payload["data_mean"], payload["data_std"]


def load_large_data(data_path, metadata_path):
    """Load a large dataset from HDF5 together with its ``.npy`` metadata.

    Args:
        data_path: HDF5 file containing a dataset named ``'data'``; it is read
            fully into memory.
        metadata_path: pickled ``.npy`` dict whose ``"data"`` entry holds
            ``u/v/w/t_ind_uni`` and ``mask_tr``.

    Returns:
        ``(ind_uni, data_extract, mask_tr)`` with ``ind_uni`` the tuple
        ``(u_ind_uni, v_ind_uni, w_ind_uni, t_ind_uni)``.
    """
    with h5py.File(data_path, 'r') as h5_file:
        data_extract = h5_file['data'][:]
    meta = np.load(metadata_path, allow_pickle=True).item()["data"]
    ind_uni = tuple(meta[key] for key in ("u_ind_uni", "v_ind_uni", "w_ind_uni", "t_ind_uni"))
    return ind_uni, data_extract, meta["mask_tr"]


###########################################################################




###################################################################
def tv_regularization(tensor, weight=1.0, flag = True):
    """Weighted sum of norms of differences between consecutive slices along axis 0.

    Args:
        tensor: torch tensor; slices ``tensor[k]`` and ``tensor[k + 1]`` are compared.
        weight: scale applied to the summed norms.
        flag: if true, each difference is reduced with the Frobenius norm over
            axes (1, 2); otherwise with the Euclidean norm over axis 1.

    Returns:
        ``weight * sum_k ||tensor[k] - tensor[k + 1]||``.
    """
    norm_dims = (1, 2) if flag else 1
    successive_gaps = tensor[:-1] - tensor[1:]
    per_step = torch.norm(successive_gaps, p='fro', dim=norm_dims)
    return weight * per_step.sum()



def total_variation_loss(X, weight):
    """Weighted temporal-variation penalty for a 5-D tensor.

    Forward differences are taken along axis 1. For every position over the
    other axes, the Euclidean norm of its sequence of differences is computed.
    These norms are summed and scaled by ``weight``.

    Args:
        X: torch tensor indexed with five axes (``X[:, 1:, :, :, :]``).
        weight: scale applied to the summed norms.
    """
    later = X[:, 1:, :, :, :]
    earlier = X[:, :-1, :, :, :]
    # The norm is over axis 1, i.e. across the difference sequence itself.
    per_location = torch.norm(later - earlier, p='fro', dim=1)
    return weight * per_location.sum()

def get_gp_covariance(t, gp_gamma=1e4):
    """Squared-exponential covariance over the time points in ``t`` plus diagonal jitter.

    Entry (i, j) is ``exp(-gp_gamma * (t_i - t_j)**2)``, computed from
    ``t - t.transpose(-1, -2)``, so ``t`` is presumably ``[..., S, 1]``.
    A jitter of 1e-5 is added to the diagonal for numerical stability.

    Returns:
        Tensor of shape ``[..., S, S]`` with the dtype/device of ``t``.
    """
    n_points = t.shape[-2]
    pairwise = t - t.transpose(-1, -2)  # pairwise time differences
    jitter = 1e-5 * torch.eye(n_points).to(t)
    return torch.exp(-gp_gamma * torch.square(pairwise)) + jitter




def matern_kernel(r, sigma_f=1.0, l=1.0, nu=1.5):
    """
    Matern kernel evaluated on precomputed distances (PyTorch version).

    Parameters:
    - r: Distances between input pairs (tensor or Python number; expected to be
      non-negative, e.g. |x - x'|). Non-tensors are converted with ``torch.as_tensor``.
    - sigma_f: Signal standard deviation; the kernel is scaled by ``sigma_f**2`` (default 1.0).
    - l: Length scale (default 1.0).
    - nu: Smoothness parameter; only 0.5, 1.5 and 2.5 are supported (default 1.5).

    Returns:
    - Tensor of kernel values with the same shape as ``r``.

    Raises:
    - ValueError: if ``nu`` is not one of the supported values.
    """
    r = torch.as_tensor(r)

    if nu == 0.5:
        # Exponential kernel (Matern with nu=0.5)
        return sigma_f**2 * torch.exp(-r / l)

    if nu == 1.5:
        # Python-float constants keep full precision for float64 inputs, work on
        # any device, and avoid building small CPU tensors on every call.
        scaled = (3.0 ** 0.5) * r / l
        return sigma_f**2 * (1 + scaled) * torch.exp(-scaled)

    if nu == 2.5:
        scaled = (5.0 ** 0.5) * r / l
        return sigma_f**2 * (1 + scaled + 5 * r**2 / (3 * l**2)) * torch.exp(-scaled)

    raise ValueError("Only nu values of 0.5, 1.5, and 2.5 are supported in this implementation.")



def get_ktT(y_tt, core_t, gp_gamma=1e3, gp_sigma=2):
    """Squared-exponential cross-covariance between query times and core times.

    Computes ``gp_sigma * exp(-gp_gamma * (y_tt - core_t)**2)`` elementwise,
    with broadcasting between the two inputs (presumably giving shape
    ``[B, S]``; confirm against callers).

    The squared difference is used directly. The previous
    ``square(sqrt(square(d)))`` form gives the same forward values, but its
    gradient is NaN wherever ``d == 0``, because sqrt's backward divides by zero.
    """
    sq_dist = torch.square(y_tt - core_t)  # squared time differences
    return gp_sigma * torch.exp(-sq_dist * gp_gamma)


def get_kTT_inv(t, gp_gamma=1e3, gp_sigma=2):
    """Inverse of the jittered squared-exponential Gram matrix over times ``t``.

    Builds ``K = gp_sigma * exp(-gp_gamma * (t_i - t_j)**2) + 1e-4 * I`` from
    the pairwise differences ``t - t.transpose(-1, -2)`` (so ``t`` is presumably
    ``[..., S, 1]``) and returns ``torch.inverse(K)``.

    The squared difference is used directly. The previous ``square(sqrt(...))``
    form gave NaN gradients w.r.t. ``t``, because the diagonal differences are
    exactly 0 and sqrt's backward divides by zero there.
    """
    sq_dist = torch.square(t - t.transpose(-1, -2))  # pairwise squared time differences
    n_points = t.shape[-2]
    # Diagonal jitter for numerical stability; created directly with t's dtype/device.
    jitter = torch.eye(n_points, dtype=t.dtype, device=t.device) * 1e-4

    K = gp_sigma * torch.exp(-sq_dist * gp_gamma) + jitter
    return torch.inverse(K)

def add_noise(x, t, i, alphas):
    """Corrupt ``x`` with temporally correlated Gaussian noise at diffusion step ``i``.

    White noise is coloured with the Cholesky factor of ``get_gp_covariance(t)``,
    then mixed in as ``sqrt(alpha) * x + sqrt(1 - alpha) * noise`` with
    ``alpha = alphas[i]``.

    x: Clean data sample, shape [B, S, D]
    t: Times of observations, shape [B, S, 1]
    i: Diffusion step, shape [B, S, 1]
    alphas: 1-D tensor indexed by ``i.long()``.

    Returns:
        ``(x_noisy, noise)`` — the corrupted sample and the correlated noise used.
    """
    # Draw the white noise first (same RNG consumption order as before).
    white = torch.randn_like(x)

    chol_factor = torch.linalg.cholesky(get_gp_covariance(t))
    correlated = chol_factor @ white

    alpha = alphas[i.long()].to(x)
    x_noisy = alpha.sqrt() * x + (1 - alpha).sqrt() * correlated
    return x_noisy, correlated



def get_betas(steps):
    """Linear beta schedule of ``steps`` values from 1e-4 to 0.04 on the module ``device``."""
    beta_start, beta_end = 1e-4, 0.04
    frac = torch.linspace(0, 1, steps).to(device)  # interpolation weights in [0, 1]
    return beta_end * frac + beta_start * (1 - frac)






def plot_generate_git(basis, core, core_mean, core_std, data_std, data_mean, u_ind_uni, v_ind_uni, file_name = r"generate_animation.gif"):
    """Animate the first generated sample and save it as a GIF under ``output/``.

    Steps:
    - De-standardize the core (``core * core_std + core_mean``).
    - Project ``core[0, :, :]`` onto ``basis`` with ``einsum("ik,lk->il")``.
    - Reshape each row into a ``len(u_ind_uni) x len(v_ind_uni)`` frame.
    - Map the frames back to data units (``* data_std + data_mean``).

    All frames share one colour scale.

    Args:
        basis: array of shape (L, K) with L == len(u_ind_uni) * len(v_ind_uni)
            (required by the reshape).
        core: array indexed as ``core[0, :, :]``; after ``np.squeeze`` it must be
            2-D (frames, K) for the einsum.
        core_mean, core_std: statistics used to de-standardize ``core``.
        data_std, data_mean: statistics used to map frames back to data units.
        u_ind_uni, v_ind_uni: arrays whose lengths give the frame height/width.
        file_name: name of the GIF written inside the ``output`` directory.

    Side effects:
        - Creates ``output/`` if it is missing.
        - Writes the GIF with matplotlib's 'imagemagick' writer, which needs
          ImageMagick installed.
        - Always closes the figure it creates.
    """
    core = core * core_std + core_mean
    core = np.squeeze(core[0, :, :])
    output_batch = np.einsum("ik,lk->il", core, basis)
    output_batch = output_batch.reshape(output_batch.shape[0], u_ind_uni.shape[0], v_ind_uni.shape[0])
    data = output_batch * data_std + data_mean

    n_frames = output_batch.shape[0]
    # One colour scale for all frames keeps values comparable across the animation.
    vmin = np.min(data[:n_frames, :, :])
    vmax = np.max(data[:n_frames, :, :])

    fig, ax = plt.subplots(figsize=(3, 2))
    try:
        im = ax.imshow(data[0], cmap='viridis', vmin=vmin, vmax=vmax)

        def update(frame):
            im.set_array(data[frame])
            return [im]

        ani = FuncAnimation(fig, update, frames=n_frames, interval=50, blit=True)

        save_dir = "output"
        os.makedirs(save_dir, exist_ok=True)  # saving fails if the directory is missing
        ani.save(os.path.join(save_dir, file_name), writer='imagemagick', fps=20, dpi=100)
    finally:
        # pyplot keeps every figure alive until it is closed; without this,
        # repeated calls leak figures.
        plt.close(fig)




# Compute reconstruction error metrics (RMSE and MAE)
def compute_rmse(gt, basis, core, core_mean, core_std, data_std, data_mean, u_ind_uni, v_ind_uni):
    """Reconstruct the first generated sample and compare it with ``gt``.

    The core is de-standardized (``core * core_std + core_mean``). Then
    ``core[0, :, :]`` is squeezed and projected onto ``basis``
    (``einsum("ik,lk->il")``), reshaped to ``(frames, len(u_ind_uni),
    len(v_ind_uni))``, and mapped back with ``* data_std + data_mean``.

    Returns:
        ``(rmse, mae)`` between ``gt`` and the reconstruction.
    """
    first_sample = np.squeeze((core * core_std + core_mean)[0, :, :])
    flat_frames = np.einsum("ik,lk->il", first_sample, basis)
    n_frames = flat_frames.shape[0]
    recon = flat_frames.reshape(n_frames, u_ind_uni.shape[0], v_ind_uni.shape[0]) * data_std + data_mean

    residual = gt - recon
    return np.sqrt(np.mean(residual ** 2)), np.mean(np.abs(residual))





def extract_observations(data_path, data_mean, data_std, timestep=50, start=600, noise_std=0.3):
    """Group noisy, normalized training observations over consecutive time indices.

    Loads the training split with ``load_data_ssf``. For each ``i`` in
    ``range(timestep)``, it selects the observations whose time index equals
    ``start + i``, normalizes them, and adds i.i.d. Gaussian noise drawn from
    NumPy's global RNG.

    Args:
        data_path: path passed to ``load_data_ssf``.
        data_mean, data_std: statistics used as ``(tr_y - data_mean) / data_std``.
        timestep: number of consecutive time indices (default 50).
        start: first time index of the window (default 600).
        noise_std: standard deviation of the added noise (default 0.3).

    Returns:
        ``(y_group, ind_conti_group, y_tt_group, u_ind_uni, v_ind_uni)``:
        - ``y_group[i]``: noisy normalized targets at time ``start + i``.
        - ``ind_conti_group[i]``: the matching rows of ``tr_ind_conti[:, 1:3]``.
        - ``y_tt_group[i]``: ``torch.linspace(0, 1, timestep)[i]``.
        - ``u_ind_uni``, ``v_ind_uni``: passed through from the loader.
    """
    tr_ind_conti, _, _, tr_time_ind, tr_y, u_ind_uni, v_ind_uni, _ = load_data_ssf(data_path)
    tr_ind_conti = tr_ind_conti[:, 1:3]
    tr_y = (tr_y - data_mean) / data_std

    # Normalized time stamp per group; sized from ``timestep`` so the two stay in sync.
    t_grid = torch.linspace(0, 1, timestep)

    y_group = []
    ind_conti_group = []
    y_tt_group = []
    for i in range(timestep):
        mask = tr_time_ind == (i + start)
        y_temp = tr_y[mask]
        y_group.append(y_temp + noise_std * np.random.randn(*y_temp.shape))
        ind_conti_group.append(tr_ind_conti[mask])
        y_tt_group.append(t_grid[i])

    return y_group, ind_conti_group, y_tt_group, u_ind_uni, v_ind_uni



def _extract_observation_window(data_path, data_mean, data_std, timestep, start):
    """Group normalized training observations over a window of consecutive time indices.

    Shared implementation behind ``extract_observations2/3/4``, which differ
    only in the window they select.

    Args:
        data_path: path passed to ``load_data_ssf`` (training split).
        data_mean, data_std: statistics used as ``(tr_y - data_mean) / data_std``.
        timestep: number of consecutive time indices to collect.
        start: first time index of the window.

    Returns:
        ``(y_group, ind_conti_group, u_ind_uni, v_ind_uni)``. Entry ``i`` of
        ``y_group`` holds the normalized targets whose time index equals
        ``start + i``; entry ``i`` of ``ind_conti_group`` holds the matching
        rows of ``tr_ind_conti[:, 1:3]``.
    """
    tr_ind_conti, _, _, tr_time_ind, tr_y, u_ind_uni, v_ind_uni, _ = load_data_ssf(data_path)
    tr_ind_conti = tr_ind_conti[:, 1:3]
    tr_y = (tr_y - data_mean) / data_std

    y_group = []
    ind_conti_group = []
    for time_index in range(start, start + timestep):
        mask = tr_time_ind == time_index  # computed once, reused for both selections
        y_group.append(tr_y[mask])
        ind_conti_group.append(tr_ind_conti[mask])

    return y_group, ind_conti_group, u_ind_uni, v_ind_uni


def extract_observations2(data_path, data_mean, data_std, timestep=20, start=40):
    """Observation groups for ``timestep`` time indices from ``start`` (default 40..59).

    Returns ``(y_group, ind_conti_group, u_ind_uni, v_ind_uni)``; see
    ``_extract_observation_window``.
    """
    return _extract_observation_window(data_path, data_mean, data_std, timestep, start)


def extract_observations3(data_path, data_mean, data_std, timestep=50, start=180):
    """Observation groups for ``timestep`` time indices from ``start`` (default 180..229).

    Returns ``(y_group, ind_conti_group, u_ind_uni, v_ind_uni)``; see
    ``_extract_observation_window``.
    """
    return _extract_observation_window(data_path, data_mean, data_std, timestep, start)


def extract_observations4(data_path, data_mean, data_std, timestep=20, start=140):
    """Observation groups for ``timestep`` time indices from ``start`` (default 140..159).

    Returns ``(y_group, ind_conti_group, u_ind_uni, v_ind_uni)``; see
    ``_extract_observation_window``.
    """
    return _extract_observation_window(data_path, data_mean, data_std, timestep, start)






def load_meta_data(data_path):
    """Return the unique-index arrays stored in a pickled ``.npy`` dataset dict.

    Returns:
        The tuple ``(u_ind_uni, v_ind_uni, w_ind_uni, t_ind_uni)`` read from
        ``d["data"]``.
    """
    inner = np.load(data_path, allow_pickle=True).item()["data"]  # time, lat, lon, depth
    return tuple(inner[key] for key in ("u_ind_uni", "v_ind_uni", "w_ind_uni", "t_ind_uni"))


