import os
import torch
import numpy as np
import random


def regulairize(x):
    """Min-max scale ``x`` into the range [0, 1].

    Works for any object supporting ``min()``, ``max()`` and elementwise
    arithmetic (e.g. NumPy arrays, torch tensors).

    If every element of ``x`` is equal, the plain formula would compute 0/0
    and produce NaN. In that case an all-zero result of the same shape is
    returned instead.
    """
    lo = x.min()
    span = x.max() - lo
    if span == 0:
        # Constant input: avoid 0/0 -> NaN. Multiply by 0.0 so the result is
        # floating point, matching the dtype the division would have produced.
        return (x - lo) * 0.0
    return (x - lo) / span
def setup_seed(seed):
    """Seed every random number generator this project draws from.

    The same ``seed`` is applied, in order, to torch's CPU generator, all
    CUDA device generators, NumPy's global generator and Python's
    ``random`` module.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

def log(save_path, txt, file_name='file.txt'):
    """Append ``txt`` to the file ``file_name`` inside ``save_path``.

    The text is written verbatim; no newline is added, so callers include
    their own line terminators. The file is created if it does not exist.
    """
    target_path = os.path.join(save_path, file_name)
    with open(target_path, 'a+') as handle:
        handle.write(txt)

def change_dim(pi, n_gaussian, t_dim):
    """Reshape ``pi`` to ``(batch, t_dim, n_gaussian)``.

    The leading dimension of ``pi`` is preserved. The remaining elements must
    number exactly ``t_dim * n_gaussian`` per batch entry. Note the argument
    order: ``n_gaussian`` comes before ``t_dim``, but in the output shape
    ``t_dim`` precedes ``n_gaussian``.
    """
    batch_size = pi.shape[0]
    target_shape = (batch_size, t_dim, n_gaussian)
    return pi.reshape(target_shape)
def test_adrf(model, args, device, test_data, adrf):
    """Estimate the ADRF at each treatment in ``adrf`` and return its MSE.

    For every row of ``adrf``, the treatment values in its first
    ``args.t_dim`` columns are assigned to every unit in ``test_data``. The
    model's predictions are averaged over those units, giving the estimated
    curve value at that treatment. The estimates are compared against the
    last column of ``adrf``.

    Args:
        model: Trained model. When ``args.model == 'cr'`` it must provide
            ``b_forward(t, x)`` returning five values, the fifth being the
            prediction. When ``args.model == 'nn'`` it is called on
            ``torch.cat([t, x], dim=1)``.
        args: Config object; ``args.t_dim`` and ``args.model`` are read.
        device: Torch device that the input tensors are moved to.
        test_data: 2-D NumPy array. Columns ``t_dim:-1`` are used as the
            covariates (presumably the layout is ``[t..., x..., y]`` --
            confirm against the data loader).
        adrf: 2-D NumPy array. Columns ``:t_dim`` hold the treatment values
            and the last column holds the reference value for that treatment.

    Returns:
        Mean squared error between ``adrf[:, -1]`` and the estimates.

    Raises:
        ValueError: if ``args.model`` is neither ``'cr'`` nor ``'nn'``.
    """
    t_dim = args.t_dim
    n_units = test_data.shape[0]

    # Resolve the model interface once. An unknown model type previously
    # surfaced as an UnboundLocalError on ``out``.
    if args.model == 'cr':
        def predict(t, x):
            _, _, _, _, out = model.b_forward(t, x)
            return out
    elif args.model == 'nn':
        def predict(t, x):
            return model(torch.cat([t, x], dim=1))
    else:
        raise ValueError(
            f"unsupported args.model {args.model!r}; expected 'cr' or 'nn'")

    # The covariates are identical for every treatment level, so convert
    # them to a device tensor once instead of on every iteration.
    x_all = torch.tensor(test_data[:, t_dim:-1], dtype=torch.float32).to(device)

    adrf_hat = np.zeros((adrf.shape[0]))
    # Pure evaluation: only a Python float leaves this loop, so no autograd
    # graph needs to be recorded.
    with torch.no_grad():
        for test_id in range(adrf.shape[0]):
            # Give every unit the same treatment vector (one row per unit).
            t_rows = adrf[test_id, :t_dim].reshape((-1, t_dim)).repeat(n_units, axis=0)
            t_all = torch.tensor(t_rows, dtype=torch.float32).to(device)
            out = predict(t_all, x_all)
            # .item() gives a plain float for any device, rather than relying
            # on NumPy's implicit conversion of a torch tensor.
            adrf_hat[test_id] = out.squeeze().mean().item()

    return ((adrf[:, -1].squeeze() - adrf_hat.squeeze()) ** 2).mean()


# The ``test_`` prefix makes pytest try to collect this helper (and fail on
# its missing fixtures) whenever it is imported into a test module; opt out.
test_adrf.__test__ = False