import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torch import optim
import numpy as np
import copy



def train_baseline(observational_data, test_data,baseline_cate_learner, num_epochs=1000,batch_size=256,device='cpu',lr=0.001):
    """Train a naive outcome model on ``cat(X, T)`` that ignores hidden confounders.

    The learner is fit full-batch on the entire observational set (one Adam
    step per epoch) and evaluated on ``test_data`` after every step.

    Args:
        observational_data: object with ``X``, ``T`` and ``Y`` tensors. ``T`` is
            concatenated to ``X`` along dim 1, so it must be 2-D.
        test_data: object with ``X``, ``T``, ``Y``, ``mu1`` and ``mu0`` tensors.
        baseline_cate_learner: module mapping ``cat(X, T)`` to an outcome.
        num_epochs: number of full-batch gradient steps.
        batch_size: unused -- training is full-batch. Kept so existing callers
            that pass it keep working.
        device: device that training/evaluation tensors are moved to.
        lr: Adam learning rate.

    Returns:
        ``(baseline_cate_learner, losses_mse, epehe_list, test_mse)``; each list
        has one entry per epoch (training MSE, sqrt-PEHE on ``test_data`` and
        factual test MSE respectively).
    """
    mse = nn.MSELoss()
    optimizer = optim.Adam(baseline_cate_learner.parameters(), lr=lr)

    losses_mse = []
    epehe_list = []
    test_mse = []

    # Full-batch training inputs; ``.to`` is a no-op when already on ``device``.
    XT_train = torch.cat((observational_data.X.to(device), observational_data.T.to(device)), dim=1)
    Y_train = observational_data.Y.to(device)

    # Evaluation inputs never change across epochs, so build them once.
    X_test = test_data.X.to(device)
    n_test = X_test.shape[0]
    input_te_one = torch.cat((X_test, torch.ones(n_test, 1, device=device)), dim=1)
    input_te_zero = torch.cat((X_test, torch.zeros(n_test, 1, device=device)), dim=1)
    input_te_test = torch.cat((X_test, test_data.T.to(device)), dim=1)
    Y_test = test_data.Y.to(device)
    # NOTE(review): mu1/mu0 are assumed to have the model's output shape
    # (e.g. (n, 1)); a 1-D tensor would silently broadcast inside MSELoss.
    true_ite = (test_data.mu1 - test_data.mu0).to(device)

    for epoch in range(num_epochs):
        Y_pred = baseline_cate_learner(XT_train)
        loss = mse(Y_pred, Y_train)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses_mse.append(loss.item())

        # Evaluation only: no autograd graph is needed.
        with torch.no_grad():
            ite = baseline_cate_learner(input_te_one) - baseline_cate_learner(input_te_zero)
            epehe = torch.sqrt(mse(true_ite, ite))
            epehe_list.append(epehe.item())

            Y_pred_test = baseline_cate_learner(input_te_test)
            test_mse.append(mse(Y_pred_test, Y_test).item())

        if epoch % 100 == 0:
            print(f'Epoch {epoch}, Loss MSE: {loss.item()}, EPEHE: {epehe.item()}')

    return baseline_cate_learner, losses_mse, epehe_list, test_mse


# marginal balancing model
def train_model_marginal_balancing(observational_data, test_data,rct_data, model_f,generator,cate_learner,alpha=1.0,num_epochs=500,balancing_iterations_start=1,balancing_iterations_end = 10,generator_input_dim=1,batch_size=64,device='cpu',lr_g=0.001,lr_te=0.001,lr_f=0.001):
    """Train a CATE learner with a generated latent confounder and marginal balancing.

    A generator maps fixed per-sample noise ``Z`` to a latent ``U_hat`` that is
    fed to ``cate_learner`` together with ``X`` and ``T``. Besides the factual
    MSE on observational minibatches, the learner is pushed to match, per
    treatment arm, the mean of ``model_f(Y)`` between its predicted potential
    outcomes and the RCT outcomes. ``model_f`` is trained adversarially to
    maximise that mismatch (``balancing_iterations`` steps per minibatch).

    Args:
        observational_data: object with ``X``, ``T`` (2-D) and ``Y`` tensors.
        test_data: object with ``X``, ``T``, ``Y``, ``mu1`` and ``mu0`` tensors.
        rct_data: object with ``T`` and ``Y`` tensors of the randomized trial.
        model_f: adversary applied to outcome columns of shape ``(m, 1)``.
        generator: maps ``(b, generator_input_dim)`` noise to ``U_hat``.
        cate_learner: maps ``cat(X, U_hat, T)`` to an outcome.
        alpha: weight of the factual MSE term.
        num_epochs: number of passes over the observational data.
        balancing_iterations_start, balancing_iterations_end: adversary steps
            per minibatch before / after a 200-epoch window centred on 2/3 of
            training; their integer midpoint is used inside the window.
        generator_input_dim: dimensionality of the noise ``Z``.
        batch_size: observational minibatch size.
        device: device tensors are moved to.
        lr_g, lr_te, lr_f: Adam learning rates of generator, learner and f.

    Returns:
        ``(model_f, generator, cate_learner, losses_mse, losses_f, epehe_list,
        test_mse)``. ``losses_mse``/``losses_f`` have one entry per minibatch,
        ``epehe_list``/``test_mse`` one entry per epoch.
    """
    # Per-sample noise is drawn once and stored in the dataset, so every
    # minibatch receives the Z rows belonging to its samples. (Feeding the
    # full-size Z to the generator for each minibatch made torch.cat fail
    # whenever a batch was smaller than the whole dataset.)
    n_obs = observational_data.X.shape[0]
    Z = torch.randn(n_obs, generator_input_dim, device=observational_data.X.device)
    obs_dataset = TensorDataset(observational_data.X, observational_data.T, observational_data.Y, Z)
    obs_loader = DataLoader(obs_dataset, batch_size=batch_size, shuffle=True)

    # RCT outcomes per arm are fixed for the whole run; select them once.
    T_rct = rct_data.T.view(-1, 1)
    Y_rct = rct_data.Y
    Y_rct_1 = Y_rct[T_rct == 1].view(-1, 1).to(device)
    Y_rct_0 = Y_rct[T_rct == 0].view(-1, 1).to(device)

    X_test = test_data.X.to(device)
    T_test = test_data.T.to(device)
    Y_test = test_data.Y.to(device)
    true_ite = (test_data.mu1 - test_data.mu0).to(device)
    n_test = X_test.shape[0]
    # Fixed noise for the factual test MSE; the ITE estimate uses fresh noise.
    Z_test = torch.randn(n_test, generator_input_dim, device=device)

    mse = nn.MSELoss()

    optimizer_g = optim.Adam(generator.parameters(), lr=lr_g)
    optimizer_te = optim.Adam(cate_learner.parameters(), lr=lr_te)
    optimizer_f = optim.Adam(model_f.parameters(), lr=lr_f)

    losses_mse = []
    losses_f = []
    epehe_list = []
    test_mse = []

    # Schedule boundaries: a 200-epoch window centred on 2/3 of training.
    window_start = int(num_epochs / 3 * 2 - 100)
    window_end = int(num_epochs / 3 * 2 + 100)

    for epoch in range(num_epochs):
        if epoch < window_start:
            balancing_iterations = balancing_iterations_start
        elif epoch > window_end:
            balancing_iterations = balancing_iterations_end
            # Late phase: decay learning rates geometrically every epoch
            # (the learner's rate decays fastest).
            optimizer_f.param_groups[0]['lr'] *= 0.999
            optimizer_te.param_groups[0]['lr'] *= 0.99
            optimizer_g.param_groups[0]['lr'] *= 0.999
        else:
            balancing_iterations = int((balancing_iterations_start + balancing_iterations_end) / 2)

        for X_batch, T_batch, Y_batch, Z_batch in obs_loader:
            X_batch = X_batch.to(device)
            T_batch = T_batch.to(device)
            Y_batch = Y_batch.to(device)
            Z_batch = Z_batch.to(device)
            n_batch = X_batch.shape[0]
            U_hat = generator(Z_batch)

            Y_pred = cate_learner(torch.cat((X_batch, U_hat, T_batch), dim=1))
            Y_pred_one = cate_learner(torch.cat((X_batch, U_hat, torch.ones(n_batch, 1, device=device)), dim=1))
            Y_pred_zero = cate_learner(torch.cat((X_batch, U_hat, torch.zeros(n_batch, 1, device=device)), dim=1))

            # Marginal moment mismatch per arm under the current adversary f.
            loss2 = mse(model_f(Y_rct_1).mean(), model_f(Y_pred_one.view(-1, 1)).mean())
            loss3 = mse(model_f(Y_rct_0).mean(), model_f(Y_pred_zero.view(-1, 1)).mean())
            loss1 = mse(Y_pred, Y_batch)
            loss = alpha * loss1 + loss2 + loss3

            # Update generator and learner. Nothing later back-propagates
            # through this graph, so it does not need to be retained.
            optimizer_te.zero_grad()
            optimizer_g.zero_grad()
            loss.backward()
            optimizer_te.step()
            optimizer_g.step()

            # Adversary: f maximises the mismatch on detached predictions.
            Y_pred_one_d = Y_pred_one.detach().view(-1, 1)
            Y_pred_zero_d = Y_pred_zero.detach().view(-1, 1)
            loss_f = torch.zeros((), device=device)  # reported if no f step runs
            for _ in range(balancing_iterations):
                loss4 = mse(model_f(Y_rct_1).mean(), model_f(Y_pred_one_d).mean())
                loss5 = mse(model_f(Y_rct_0).mean(), model_f(Y_pred_zero_d).mean())
                loss_f = -loss4 - loss5
                optimizer_f.zero_grad()
                loss_f.backward()
                optimizer_f.step()

            losses_mse.append(loss1.item())
            losses_f.append(loss_f.item())

        # Evaluation only: no autograd graph is needed.
        with torch.no_grad():
            U_hat_test = generator(Z_test)
            Y_pred_test = cate_learner(torch.cat((X_test, U_hat_test, T_test), dim=1))
            test_mse.append(mse(Y_pred_test, Y_test).item())

            # ITE: average each potential outcome over 10 fresh latent draws.
            T_one = torch.ones(n_test, 1, device=device)
            y_list = [cate_learner(torch.cat((X_test, generator(torch.randn(n_test, generator_input_dim, device=device)), T_one), dim=1)) for _ in range(10)]
            Y_pred_one_test = torch.stack(y_list).mean(dim=0)

            T_zero = torch.zeros(n_test, 1, device=device)
            y_list_zero = [cate_learner(torch.cat((X_test, generator(torch.randn(n_test, generator_input_dim, device=device)), T_zero), dim=1)) for _ in range(10)]
            Y_pred_zero_test = torch.stack(y_list_zero).mean(dim=0)

            ite = Y_pred_one_test - Y_pred_zero_test
            epehe = torch.sqrt(mse(true_ite, ite))
            epehe_list.append(epehe.item())

        if epoch % 100 == 0:
            print(f'Epoch {epoch}, Loss MSE: {loss1.item()}, Loss F: {loss_f.item()}, EPEHE: {epehe.item()}')

    return model_f, generator, cate_learner, losses_mse, losses_f, epehe_list, test_mse


# train projections balancing model
def train_model_projections_balancing(observational_data, test_data,rct_data, model_f,cate_learner,alpha=1.0,num_epochs=500,device='cpu',lr_te=0.001,lr_f=0.001,balancing_iterations_start=1,balancing_iterations_end=10):
    """Train a CATE learner on ``cat(X, T)`` with projection balancing against an RCT.

    Training is full-batch. Besides the factual MSE, for each arm the learner
    is pushed to match the mean of ``model_f(X) * Y`` between its predicted
    potential outcomes on the observational covariates and the RCT outcomes.
    ``model_f`` (a function of covariates) is trained adversarially to
    maximise that mismatch.

    NOTE(review): the RCT side pairs RCT outcomes with *randomly sampled
    observational* covariates rather than the RCT's own covariates, so f(X)
    is independent of Y in that term -- confirm this is intended.

    Args:
        observational_data: object with ``X``, ``T`` (2-D) and ``Y`` tensors.
        test_data: object with ``X``, ``T``, ``Y``, ``mu1`` and ``mu0`` tensors.
        rct_data: object with ``T`` and ``Y`` tensors of the randomized trial.
        model_f: adversary applied to covariate rows of ``X``.
        cate_learner: maps ``cat(X, T)`` to an outcome.
        alpha: weight of the factual MSE term.
        num_epochs: number of full-batch updates.
        device: device tensors are moved to.
        lr_te, lr_f: Adam learning rates of learner and f.
        balancing_iterations_start, balancing_iterations_end: adversary steps
            per epoch before / after a 200-epoch window centred on 2/3 of
            training (defaults 1 and 10); their integer midpoint is used
            inside the window.

    Returns:
        ``(model_f, cate_learner, losses_mse, losses_f, epehe_list, test_mse)``,
        one entry per epoch in each list.
    """
    def obs_rows_like(X, n_rows):
        # Randomly pick n_rows rows of X to pair with RCT outcomes: without
        # replacement when X has at least n_rows rows, with replacement
        # otherwise (instead of failing on an undefined sample).
        n_available = X.shape[0]
        idx = np.random.choice(n_available, n_rows, replace=n_available < n_rows)
        return X[idx]

    # RCT outcomes per arm are fixed for the whole run; select them once.
    T_rct = rct_data.T.view(-1, 1)
    Y_rct = rct_data.Y
    Y_rct_1 = Y_rct[T_rct == 1].view(-1, 1).to(device)
    Y_rct_0 = Y_rct[T_rct == 0].view(-1, 1).to(device)
    n_rct_1 = Y_rct_1.shape[0]
    n_rct_0 = Y_rct_0.shape[0]

    # Full-batch observational inputs never change; build them once.
    X_obs = observational_data.X.to(device)
    T_obs = observational_data.T.to(device)
    Y_obs = observational_data.Y.to(device)
    n_obs = X_obs.shape[0]
    input_te = torch.cat((X_obs, T_obs), dim=1)
    input_te_one = torch.cat((X_obs, torch.ones(n_obs, 1, device=device)), dim=1)
    input_te_zero = torch.cat((X_obs, torch.zeros(n_obs, 1, device=device)), dim=1)

    X_test = test_data.X.to(device)
    n_test = X_test.shape[0]
    input_test = torch.cat((X_test, test_data.T.to(device)), dim=1)
    input_test_one = torch.cat((X_test, torch.ones(n_test, 1, device=device)), dim=1)
    input_test_zero = torch.cat((X_test, torch.zeros(n_test, 1, device=device)), dim=1)
    Y_test = test_data.Y.to(device)
    true_ite = (test_data.mu1 - test_data.mu0).to(device)

    mse = nn.MSELoss()

    optimizer_te = optim.Adam(cate_learner.parameters(), lr=lr_te)
    optimizer_f = optim.Adam(model_f.parameters(), lr=lr_f)

    losses_mse = []
    losses_f = []
    epehe_list = []
    test_mse = []

    # Schedule boundaries: a 200-epoch window centred on 2/3 of training.
    window_start = int(num_epochs / 3 * 2 - 100)
    window_end = int(num_epochs / 3 * 2 + 100)

    for epoch in range(num_epochs):
        if epoch < window_start:
            balancing_iterations = balancing_iterations_start
        elif epoch > window_end:
            balancing_iterations = balancing_iterations_end
        else:
            balancing_iterations = int((balancing_iterations_start + balancing_iterations_end) / 2)

        Y_pred = cate_learner(input_te)
        Y_pred_one = cate_learner(input_te_one)
        Y_pred_zero = cate_learner(input_te_zero)

        # Projection moments f(X) * Y per arm: RCT outcomes vs. predictions.
        f_Y_rct_1 = model_f(obs_rows_like(X_obs, n_rct_1)) * Y_rct_1
        f_Y_pred_1 = model_f(X_obs) * Y_pred_one.view(-1, 1)
        f_Y_rct_0 = model_f(obs_rows_like(X_obs, n_rct_0)) * Y_rct_0
        f_Y_pred_0 = model_f(X_obs) * Y_pred_zero.view(-1, 1)

        loss1 = mse(Y_pred, Y_obs)
        loss2 = mse(f_Y_rct_1.mean(), f_Y_pred_1.mean())
        loss3 = mse(f_Y_rct_0.mean(), f_Y_pred_0.mean())
        loss = alpha * loss1 + loss2 + loss3

        # Update the CATE learner. Nothing later back-propagates through this
        # graph, so it does not need to be retained.
        optimizer_te.zero_grad()
        loss.backward()
        optimizer_te.step()

        # Adversary: f maximises the mismatch on detached predictions.
        Y_pred_one_d = Y_pred_one.detach().view(-1, 1)
        Y_pred_zero_d = Y_pred_zero.detach().view(-1, 1)
        loss_f = torch.zeros((), device=device)  # reported if no f step runs
        for _ in range(balancing_iterations):
            f_Y_rct_1 = model_f(obs_rows_like(X_obs, n_rct_1)) * Y_rct_1
            f_Y_pred_1 = model_f(X_obs) * Y_pred_one_d
            f_Y_rct_0 = model_f(obs_rows_like(X_obs, n_rct_0)) * Y_rct_0
            f_Y_pred_0 = model_f(X_obs) * Y_pred_zero_d

            loss4 = mse(f_Y_rct_1.mean(), f_Y_pred_1.mean())
            loss5 = mse(f_Y_rct_0.mean(), f_Y_pred_0.mean())
            loss_f = -loss4 - loss5
            optimizer_f.zero_grad()
            loss_f.backward()
            optimizer_f.step()

        losses_mse.append(loss1.item())
        losses_f.append(loss_f.item())

        # Evaluation only: no autograd graph is needed.
        with torch.no_grad():
            Y_pred_test = cate_learner(input_test)
            test_mse.append(mse(Y_pred_test, Y_test).item())

            # Averaging 10 passes only differs from a single pass if the
            # learner is stochastic (e.g. dropout in training mode).
            Y_pred_one_test = torch.stack([cate_learner(input_test_one) for _ in range(10)]).mean(dim=0)
            Y_pred_zero_test = torch.stack([cate_learner(input_test_zero) for _ in range(10)]).mean(dim=0)

            ite = Y_pred_one_test - Y_pred_zero_test
            epehe = torch.sqrt(mse(true_ite, ite))
            epehe_list.append(epehe.item())

        if epoch % 100 == 0:
            print(f'Epoch {epoch}, Loss MSE: {loss1.item()}, Loss F: {loss_f.item()}, EPEHE: {epehe.item()}')

    return model_f, cate_learner, losses_mse, losses_f, epehe_list, test_mse

# train the combined marginal balancing and projections balancing model
def train_model_mb_plus_pb(observational_data, test_data,rct_data, model_g,model_f,generator,cate_learner,alpha_start=100,alpha_end=0.01,num_epochs=500,balancing_iterations=2,generator_input_dim=1,device='cpu',lr_g=0.001,lr_te=0.001,lr_f=0.001,balancing_iterations_start=5,balancing_iterations_end=100):
    """Train a CATE learner with combined marginal and projection balancing.

    Full-batch training with a generated latent ``U_hat = generator(Z)`` for
    fixed per-sample noise ``Z``. Besides the factual MSE (weighted by an
    ``alpha`` that moves from ``alpha_start`` to ``alpha_end``), the learner
    is pushed to match two kinds of per-arm moments against the RCT:
    ``mean(model_f(X) * Y)`` (projection) and ``mean(model_g(Y))`` (marginal).
    ``model_f`` and ``model_g`` are trained adversarially to maximise the
    mismatch.

    NOTE(review): the RCT projection term pairs RCT outcomes with *randomly
    sampled observational* covariates rather than the RCT's own covariates --
    confirm this is intended.

    Args:
        observational_data: object with ``X``, ``T`` (2-D) and ``Y`` tensors.
        test_data: object with ``X``, ``T``, ``Y``, ``mu1`` and ``mu0`` tensors.
        rct_data: object with ``T`` and ``Y`` tensors of the randomized trial.
        model_g: marginal adversary applied to outcome columns ``(m, 1)``.
        model_f: projection adversary applied to covariate rows of ``X``.
        generator: maps ``(n, generator_input_dim)`` noise to ``U_hat``.
        cate_learner: maps ``cat(X, U_hat, T)`` to an outcome.
        alpha_start, alpha_end: factual-MSE weight before / after a 200-epoch
            window centred on 2/3 of training (linearly interpolated inside).
        num_epochs: number of full-batch updates.
        balancing_iterations: ignored -- the number of adversary steps is set
            by the schedule below. Kept for backward compatibility.
        generator_input_dim: dimensionality of the noise ``Z``.
        device: device tensors are moved to.
        lr_g: Adam learning rate of both ``generator`` and ``model_g``.
        lr_te, lr_f: Adam learning rates of learner and ``model_f``.
        balancing_iterations_start, balancing_iterations_end: adversary steps
            per epoch before / after the window (defaults 5 and 100); their
            integer midpoint is used inside the window.

    Returns:
        ``(model_f, generator, cate_learner, losses_mse, losses_f, epehe_list,
        test_mse)``, one entry per epoch in each list.
    """
    def obs_rows_like(X, n_rows):
        # Randomly pick n_rows rows of X to pair with RCT outcomes: without
        # replacement when X has at least n_rows rows, with replacement
        # otherwise (instead of failing on an undefined sample).
        n_available = X.shape[0]
        idx = np.random.choice(n_available, n_rows, replace=n_available < n_rows)
        return X[idx]

    # RCT outcomes per arm are fixed for the whole run; select them once.
    T_rct = rct_data.T.view(-1, 1)
    Y_rct = rct_data.Y
    Y_rct_1 = Y_rct[T_rct == 1].view(-1, 1).to(device)
    Y_rct_0 = Y_rct[T_rct == 0].view(-1, 1).to(device)
    n_rct_1 = Y_rct_1.shape[0]
    n_rct_0 = Y_rct_0.shape[0]

    X_obs = observational_data.X.to(device)
    T_obs = observational_data.T.to(device)
    Y_obs = observational_data.Y.to(device)
    n_obs = X_obs.shape[0]
    ones_obs = torch.ones(n_obs, 1, device=device)
    zeros_obs = torch.zeros(n_obs, 1, device=device)

    X_test = test_data.X.to(device)
    T_test = test_data.T.to(device)
    Y_test = test_data.Y.to(device)
    true_ite = (test_data.mu1 - test_data.mu0).to(device)
    n_test = X_test.shape[0]

    mse = nn.MSELoss()

    optimizer_g = optim.Adam(generator.parameters(), lr=lr_g)
    optimizer_te = optim.Adam(cate_learner.parameters(), lr=lr_te)
    optimizer_f = optim.Adam(model_f.parameters(), lr=lr_f)
    optimizer_g_model = optim.Adam(model_g.parameters(), lr=lr_g)

    losses_mse = []
    losses_f = []
    epehe_list = []
    test_mse = []

    # Schedule boundaries: a 200-epoch window centred on 2/3 of training.
    window_start = int(num_epochs / 3 * 2 - 100)
    window_end = int(num_epochs / 3 * 2 + 100)

    # Fixed per-sample noise; row i always belongs to observational sample i.
    Z = torch.randn(n_obs, generator_input_dim, device=device)
    for epoch in range(num_epochs):
        if epoch < window_start:
            alpha = alpha_start
            balancing_iterations = balancing_iterations_start
        elif epoch > window_end:
            alpha = alpha_end
            balancing_iterations = balancing_iterations_end
        else:
            balancing_iterations = int((balancing_iterations_start + balancing_iterations_end) / 2)
            # Linear interpolation from alpha_start to alpha_end over the window.
            alpha = alpha_start - (alpha_start - alpha_end) * (epoch - window_start) / (200)

        U_hat = generator(Z)
        Y_pred = cate_learner(torch.cat((X_obs, U_hat, T_obs), dim=1))
        Y_pred_one = cate_learner(torch.cat((X_obs, U_hat, ones_obs), dim=1))
        Y_pred_zero = cate_learner(torch.cat((X_obs, U_hat, zeros_obs), dim=1))

        # Projection moments f(X) * Y per arm: RCT outcomes vs. predictions.
        f_Y_rct_1 = model_f(obs_rows_like(X_obs, n_rct_1)) * Y_rct_1
        f_Y_pred_1 = model_f(X_obs) * Y_pred_one.view(-1, 1)
        f_Y_rct_0 = model_f(obs_rows_like(X_obs, n_rct_0)) * Y_rct_0
        f_Y_pred_0 = model_f(X_obs) * Y_pred_zero.view(-1, 1)

        # Marginal moments g(Y) per arm.
        g_Y_pred_1 = model_g(Y_pred_one.view(-1, 1))
        g_Y_rct_1 = model_g(Y_rct_1)
        g_Y_pred_0 = model_g(Y_pred_zero.view(-1, 1))
        g_Y_rct_0 = model_g(Y_rct_0)

        loss1 = mse(Y_pred, Y_obs)
        loss2 = mse(f_Y_pred_1.mean(), f_Y_rct_1.mean())
        loss3 = mse(f_Y_pred_0.mean(), f_Y_rct_0.mean())
        loss2p = mse(g_Y_pred_1.mean(), g_Y_rct_1.mean())
        loss3p = mse(g_Y_pred_0.mean(), g_Y_rct_0.mean())
        loss = alpha * loss1 + loss2 + loss3 + loss2p + loss3p

        # Update generator and learner. Nothing later back-propagates through
        # this graph, so it does not need to be retained.
        optimizer_te.zero_grad()
        optimizer_g.zero_grad()
        loss.backward()
        optimizer_te.step()
        optimizer_g.step()

        # Adversaries: f and g maximise the mismatch on detached predictions.
        Y_pred_one_d = Y_pred_one.detach().view(-1, 1)
        Y_pred_zero_d = Y_pred_zero.detach().view(-1, 1)
        loss_f = torch.zeros((), device=device)  # reported if no step runs
        for _ in range(balancing_iterations):
            f_Y_rct_1 = model_f(obs_rows_like(X_obs, n_rct_1)) * Y_rct_1
            f_Y_pred_1 = model_f(X_obs) * Y_pred_one_d
            f_Y_rct_0 = model_f(obs_rows_like(X_obs, n_rct_0)) * Y_rct_0
            f_Y_pred_0 = model_f(X_obs) * Y_pred_zero_d

            # Detach the *inputs*, not the outputs: detaching model_g's output
            # cut every gradient path to model_g, so it was never trained.
            g_Y_pred_1 = model_g(Y_pred_one_d)
            g_Y_rct_1 = model_g(Y_rct_1)
            g_Y_pred_0 = model_g(Y_pred_zero_d)
            g_Y_rct_0 = model_g(Y_rct_0)

            loss4 = mse(f_Y_rct_1.mean(), f_Y_pred_1.mean())
            loss5 = mse(f_Y_rct_0.mean(), f_Y_pred_0.mean())
            loss4p = mse(g_Y_pred_1.mean(), g_Y_rct_1.mean())
            loss5p = mse(g_Y_pred_0.mean(), g_Y_rct_0.mean())

            loss_f = -loss4 - loss5 - loss4p - loss5p
            optimizer_f.zero_grad()
            optimizer_g_model.zero_grad()
            loss_f.backward()
            optimizer_f.step()
            optimizer_g_model.step()

        losses_mse.append(loss1.item())
        losses_f.append(loss_f.item())

        # Evaluation only: no autograd graph is needed.
        with torch.no_grad():
            Z_test = torch.randn(n_test, generator_input_dim, device=device)
            U_hat_test = generator(Z_test)
            Y_pred_test = cate_learner(torch.cat((X_test, U_hat_test, T_test), dim=1))
            test_mse.append(mse(Y_pred_test, Y_test).item())

            # ITE: average each potential outcome over 10 fresh latent draws.
            T_one = torch.ones(n_test, 1, device=device)
            y_list = [cate_learner(torch.cat((X_test, generator(torch.randn(n_test, generator_input_dim, device=device)), T_one), dim=1)) for _ in range(10)]
            Y_pred_one_test = torch.stack(y_list).mean(dim=0)

            T_zero = torch.zeros(n_test, 1, device=device)
            y_list_zero = [cate_learner(torch.cat((X_test, generator(torch.randn(n_test, generator_input_dim, device=device)), T_zero), dim=1)) for _ in range(10)]
            Y_pred_zero_test = torch.stack(y_list_zero).mean(dim=0)

            ite = Y_pred_one_test - Y_pred_zero_test
            epehe = torch.sqrt(mse(true_ite, ite))
            epehe_list.append(epehe.item())

        if epoch % 100 == 0:
            print(f'Epoch {epoch}, Loss MSE: {loss1.item()}, Loss F: {loss_f.item()}, EPEHE: {epehe.item()}')

    return model_f, generator, cate_learner, losses_mse, losses_f, epehe_list, test_mse




# train ideal model with no missing confounders
def train_ideal(observational_data, test_data, baseline_cate_learner, num_epochs=1000, batch_size=256, device='cpu',lr=0.001):
    """Train an oracle outcome model that also observes the confounder ``U``.

    The learner is fit with shuffled minibatches on ``cat(X, T, U)``. For the
    ITE on ``test_data``, ``U`` is marginalised out by averaging predictions
    over 10 Bernoulli draws whose rate is the scalar mean of the training
    ``U``; the factual test MSE uses the true ``test_data.U``.

    Args:
        observational_data: object with ``X``, ``T``, ``Y`` and 2-D ``U`` tensors.
        test_data: object with ``X``, ``T``, ``Y``, ``U``, ``mu1`` and ``mu0``.
        baseline_cate_learner: module mapping ``cat(X, T, U)`` to an outcome.
        num_epochs: number of passes over the observational data.
        batch_size: minibatch size.
        device: device tensors are moved to.
        lr: Adam learning rate.

    Returns:
        ``(baseline_cate_learner, losses_mse, epehe_list, test_mse)``, one entry
        per epoch; ``losses_mse`` holds the mean minibatch loss of that epoch.
    """
    mse = nn.MSELoss()
    optimizer = optim.Adam(baseline_cate_learner.parameters(), lr=lr)

    losses_mse = []
    epehe_list = []
    test_mse = []

    U_all = observational_data.U
    dim_U = U_all.shape[1]
    # Bernoulli rate for sampling U at test time: one scalar, the mean over the
    # WHOLE training U. (Reading it from the minibatch loop variable used only
    # the last, randomly shuffled minibatch, so the rate changed every epoch.)
    # NOTE(review): Bernoulli sampling presumes U is binary -- confirm.
    u_mean = torch.mean(U_all.to(device))

    obs_dataset = TensorDataset(observational_data.X, observational_data.T, observational_data.Y, U_all)
    obs_loader = DataLoader(obs_dataset, batch_size=batch_size, shuffle=True)

    # Evaluation inputs that do not change across epochs.
    X_test = test_data.X.to(device)
    n_test = X_test.shape[0]
    T_one = torch.ones(n_test, 1, device=device)
    T_zero = torch.zeros(n_test, 1, device=device)
    input_te_test = torch.cat((X_test, test_data.T.to(device), test_data.U.to(device)), dim=1)
    Y_test = test_data.Y.to(device)
    true_ite = (test_data.mu1 - test_data.mu0).to(device)
    num_samples = 10

    for epoch in range(num_epochs):
        epoch_losses_mse = []
        for X_batch, T_batch, Y_batch, U_batch in obs_loader:
            X_batch, T_batch, Y_batch, U_batch = X_batch.to(device), T_batch.to(device), Y_batch.to(device), U_batch.to(device)

            Y_pred = baseline_cate_learner(torch.cat((X_batch, T_batch, U_batch), dim=1))
            loss = mse(Y_pred, Y_batch)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_losses_mse.append(loss.item())

        losses_mse.append(np.mean(epoch_losses_mse))

        # Evaluation only: no autograd graph is needed.
        with torch.no_grad():
            Y_pred_one_samples = []
            Y_pred_zero_samples = []
            for _ in range(num_samples):
                # Same sampled U for both arms within a draw.
                U = torch.bernoulli(u_mean * torch.ones(n_test, dim_U, device=device))
                Y_pred_one_samples.append(baseline_cate_learner(torch.cat((X_test, T_one, U), dim=1)))
                Y_pred_zero_samples.append(baseline_cate_learner(torch.cat((X_test, T_zero, U), dim=1)))

            Y_pred_one_avg = torch.stack(Y_pred_one_samples).mean(dim=0)
            Y_pred_zero_avg = torch.stack(Y_pred_zero_samples).mean(dim=0)
            ite = Y_pred_one_avg - Y_pred_zero_avg
            epehe = torch.sqrt(mse(true_ite, ite))
            epehe_list.append(epehe.item())

            Y_pred_test = baseline_cate_learner(input_te_test)
            test_mse.append(mse(Y_pred_test, Y_test).item())

        if epoch % 100 == 0:
            print(f'Epoch {epoch}, Avg Loss MSE: {losses_mse[-1]}, EPEHE: {epehe.item()}')

    return baseline_cate_learner, losses_mse, epehe_list, test_mse





##### Not needed

# def train_model_unconfounded_general(observational_data, test_data,rct_data, model_f,generator,cate_learner,alpha_start=100,alpha_end=0.01,num_epochs=500,balancing_iterations=2,generator_input_dim=1,batch_size=64,device='cpu',lr_g=0.001,lr_te=0.001,lr_f=0.1):

    
#     # add  z to dataloader
#     T_rct = rct_data.T
#     T_rct = T_rct.view(-1, 1)
#     Y_rct = rct_data.Y

#     X_test = test_data.X
#     mu1 = test_data.mu1
#     mu0 = test_data.mu0
#     T_test = test_data.T

#     # Loss functions
#     mse = nn.MSELoss()
#     mae = nn.L1Loss()

#     # Optimizers
#     optimizer_g = optim.Adam(generator.parameters(), lr=lr_g)
#     optimizer_te = optim.Adam(cate_learner.parameters(), lr=lr_te)
#     optimizer_f = optim.Adam(model_f.parameters(), lr=lr_f)

#     # create a copy of the model_f
#     model_f_tilde = copy.deepcopy(model_f)
#     model_f_tilde = model_f_tilde.to(device)

#     optimizer_f_tilde = optim.Adam(model_f_tilde.parameters(), lr=lr_f)


#     losses_mse = []
#     losses_f = []
#     epehe_list = []
#     test_mse = []

#     true_ite = mu1 - mu0
#     X_batch = observational_data.X
#     T_batch = observational_data.T
#     Y_batch = observational_data.Y

    
#     balancing_iterations_start = 10
#     balancing_iterations_end = 50

#     alpha = alpha_start


#     Z = torch.randn(X_batch.shape[0], generator_input_dim, device=device)
#     for epoch in range(num_epochs):

#         if epoch < int(((num_epochs/3 * 2) - 100)):
#             alpha = alpha_start
#             balancing_iterations = balancing_iterations_start
#         elif epoch > int(((num_epochs/3 * 2) + 100)):
#             alpha = alpha_end
#             balancing_iterations = balancing_iterations_end
#             # decrease learning rate very slowly
#             optimizer_f.param_groups[0]['lr'] = optimizer_f.param_groups[0]['lr'] * 0.999
#             optimizer_te.param_groups[0]['lr'] = optimizer_te.param_groups[0]['lr'] * 0.99
#             optimizer_g.param_groups[0]['lr'] = optimizer_g.param_groups[0]['lr'] * 0.999
#         else:
#             balancing_iterations = int((balancing_iterations_start + balancing_iterations_end) / 2)
#             alpha = alpha_start - (alpha_start - alpha_end) * (epoch - int(num_epochs/3 * 2 - 100)) / (200)

#         X_batch, T_batch, Y_batch = X_batch.to(device), T_batch.to(device), Y_batch.to(device)
#         #Z = torch.randn(X_batch.shape[0], generator_input_dim, device=device)
#         U_hat = generator(Z)
        
#         # Prepare inputs for CATE learner
#         input_te = torch.cat((X_batch, U_hat, T_batch), dim=1)
#         #input_te = torch.cat((X_batch, T_batch), dim=1)
#         Y_pred = cate_learner(input_te)
        
#         input_te_one = torch.cat((X_batch, U_hat, torch.ones(X_batch.shape[0], 1, device=device)), dim=1)
#         #input_te_one = torch.cat((X_batch, torch.ones(X_batch.shape[0], 1, device=device)), dim=1)
#         Y_pred_one = cate_learner(input_te_one)
        
#         input_te_zero = torch.cat((X_batch, U_hat, torch.zeros(X_batch.shape[0], 1, device=device)), dim=1)
#         #input_te_zero = torch.cat((X_batch, torch.zeros(X_batch.shape[0], 1, device=device)), dim=1)
#         Y_pred_zero = cate_learner(input_te_zero)


#         if X_batch.shape[0] > Y_rct[T_rct == 1].view(-1, 1).shape[0]:
#             idx = np.random.choice(X_batch.shape[0], Y_rct[T_rct == 1].shape[0], replace=False)
#             X_batch_small = X_batch[idx]
#         #print("X_batch_small = ", X_batch_small.shape)
#         #print("Y_rct[T_rct == 1] = ", Y_rct[T_rct == 1].shape)
#         f_Y_rct_1 = torch.mul(model_f(X_batch_small).view(-1, 1),model_f_tilde(Y_rct[T_rct == 1].view(-1, 1)))
#         #print("f_Y_rct_1 = ", f_Y_rct_1.shape)
#         f_Y_pred_1 = torch.mul(model_f(X_batch).view(-1, 1),model_f_tilde(Y_pred_one.view(-1, 1)))
#         if X_batch.shape[0] > Y_rct[T_rct == 0].view(-1, 1).shape[0]:
#             idx = np.random.choice(X_batch.shape[0], Y_rct[T_rct == 0].shape[0], replace=False)
#             X_batch_small = X_batch[idx]
        
#         f_Y_rct_0 = torch.mul(model_f(X_batch_small).view(-1, 1),model_f_tilde(Y_rct[T_rct == 0].view(-1, 1)))
#         f_Y_pred_0 = torch.mul(model_f(X_batch).view(-1, 1),model_f_tilde(Y_pred_zero.view(-1, 1)))

        

#         # Update loss functions
#         loss1 = mse(Y_pred, Y_batch)
#         loss2 = mse(f_Y_rct_1.mean(), f_Y_pred_1.mean())
#         loss3 = mse(f_Y_rct_0.mean(), f_Y_pred_0.mean())

#         loss = alpha * loss1 + loss2 + loss3

#         # Optimize generator and CATE learner
#         optimizer_te.zero_grad()
#         optimizer_g.zero_grad()
#         loss.backward(retain_graph=True)
#         optimizer_te.step()
#         optimizer_g.step()

#         # Train function f
#         #print("balancing_iterations = ", balancing_iterations)
#         for _ in range(balancing_iterations):

#             if X_batch.shape[0] > Y_rct[T_rct == 1].view(-1, 1).shape[0]:
#                 idx = np.random.choice(X_batch.shape[0], Y_rct[T_rct == 1].view(-1, 1).shape[0], replace=False)
#                 X_batch_small = X_batch[idx]
#             f_Y_rct_1 = torch.mul(model_f(X_batch_small).view(-1, 1).detach(),model_f_tilde(Y_rct[T_rct == 1].view(-1, 1).detach()))
#             f_Y_pred_1 = torch.mul(model_f(X_batch).view(-1, 1).detach(),model_f_tilde(Y_pred_one.view(-1, 1).detach()))
#             if X_batch.shape[0] > Y_rct[T_rct == 0].view(-1, 1).shape[0]:
#                 idx = np.random.choice(X_batch.shape[0], Y_rct[T_rct == 0].view(-1, 1).shape[0], replace=False)
#                 X_batch_small = X_batch[idx]
#             f_Y_rct_0 = torch.mul(model_f(X_batch_small).view(-1, 1).detach(),model_f_tilde(Y_rct[T_rct == 0].view(-1, 1).detach()))
#             f_Y_pred_0 = torch.mul(model_f(X_batch).view(-1,1).detach(),model_f_tilde(Y_pred_zero.view(-1, 1).detach()))

#             loss4 = mse(f_Y_rct_1.mean(), f_Y_pred_1.mean())
#             loss5 = mse(f_Y_rct_0.mean(), f_Y_pred_0.mean())

#             loss_f = -loss4 - loss5
#             #print("loss_f = ", loss_f)
#             optimizer_f.zero_grad()
#             optimizer_f_tilde.zero_grad()
#             loss_f.backward()
#             optimizer_f.step()
#             optimizer_f_tilde.step()

#         # Log losses
#         losses_mse.append(loss1.item())
#         losses_f.append(loss_f.item())

#         # Calculate test MSE and EPEHE
#         Z_test = torch.randn(test_data.X.shape[0], generator_input_dim, device=device)
#         U_hat = generator(Z_test)
#         test_input_te = torch.cat((X_test, U_hat, T_test), dim=1)
#         #test_input_te = torch.cat((X_test, T_test), dim=1)
#         Y_pred_test = cate_learner(test_input_te)
#         test_mse.append(mse(Y_pred_test, test_data.Y).item())

#         # Estimate ITE
#         T_one = torch.ones(test_data.X.shape[0], 1, device=device)
#         y_list = [cate_learner(torch.cat((test_data.X, generator(torch.randn(test_data.X.shape[0], generator_input_dim, device=device)), T_one), dim=1)) for _ in range(10)]
#         Y_pred_one = torch.mean(torch.stack(y_list), dim=0)
        
#         T_zero = torch.zeros(test_data.X.shape[0], 1, device=device)
#         y_list_zero = [cate_learner(torch.cat((test_data.X, generator(torch.randn(test_data.X.shape[0], generator_input_dim, device=device)), T_zero), dim=1)) for _ in range(10)]
#         Y_pred_zero = torch.mean(torch.stack(y_list_zero), dim=0)
        
#         ite = Y_pred_one - Y_pred_zero
#         epehe = torch.sqrt(mse(true_ite, ite))
#         epehe_list.append(epehe.item())
  
#         if epoch % 100 == 0:    
#             print(f'Epoch {epoch}, Loss MSE: {loss1.item()}, Loss F: {loss_f.item()}, EPEHE: {epehe.item()}')


#     return model_f, generator, cate_learner, losses_mse, losses_f, epehe_list, test_mse





# def train_model_unconfounded_multiple(observational_data1, test_data1,observational_data2, test_data2,model_f,generator,cate_learner1,cate_learner2,alpha_start=100,alpha_end=0.01,num_epochs=500,balancing_iterations=2,generator_input_dim=1,batch_size=64,device='cpu',lr_g=0.001,lr_te=0.001,lr_f=0.001):
    
#     obs_dataset1 = TensorDataset(observational_data1.X, observational_data1.T, observational_data1.Y)
#     obs_loader1 = DataLoader(obs_dataset1, batch_size=batch_size, shuffle=True)

#     obs_dataset2 = TensorDataset(observational_data2.X, observational_data2.T, observational_data2.Y)
#     obs_loader2 = DataLoader(obs_dataset2, batch_size=batch_size, shuffle=True)

#     # Create a DataLoader
#     size = observational_data1.X.shape[0]


#     X_test1 = test_data1.X
#     mu11 = test_data1.mu1
#     mu01 = test_data1.mu0
#     T_test1 = test_data1.T

#     X_test2 = test_data2.X
#     mu12 = test_data2.mu1
#     mu02 = test_data2.mu0
#     T_test2 = test_data2.T

#     # Loss functions
#     mse = nn.MSELoss()
#     mae = nn.L1Loss()

#     # Optimizers
#     optimizer_g = optim.Adam(generator.parameters(), lr=lr_g)
#     optimizer_te1 = optim.Adam(cate_learner1.parameters(), lr=lr_te)
#     optimizer_te2 = optim.Adam(cate_learner2.parameters(), lr=lr_te)
#     optimizer_f = optim.Adam(model_f.parameters(), lr=lr_f)


    # scheduler_g = optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.999)
    # scheduler_te = optim.lr_scheduler.ExponentialLR(optimizer_te, gamma=0.99)
    # scheduler_f = optim.lr_scheduler.ExponentialLR(optimizer_f, gamma=0.999)

#     losses_mse = []
#     losses_f = []
#     epehe_list1 = []
#     epehe_list2 = []
#     test_mse = []

#     true_ite1 = mu11 - mu01
#     true_ite2 = mu12 - mu02
#     X_batch1 = observational_data1.X
#     T_batch1 = observational_data1.T
#     Y_batch1 = observational_data1.Y
#     X_batch2 = observational_data2.X    
#     T_batch2 = observational_data2.T
#     Y_batch2 = observational_data2.Y

    
#     balancing_iterations_start = 5
#     balancing_iterations_end = 100


#     alpha = alpha_start

#     # print("balancing_iterations_start = ", balancing_iterations_start)
#     # print("balancing_iterations_end = ", balancing_iterations_end)
#     # print("alpha_start = ", alpha_start)

#     Z = torch.randn(X_batch1.shape[0], generator_input_dim, device=device)
#     Z_test = torch.randn(test_data1.X.shape[0], generator_input_dim, device=device)
#     for epoch in range(num_epochs):
#         #balancing_iterations = balancing_iterations_start - (balancing_iterations_start - balancing_iterations_end) * epoch / num_epochs
#         # make it an int
#         if epoch < int(((num_epochs/3 * 2) - 100)):
#             alpha = alpha_start
#             balancing_iterations = balancing_iterations_start
#         elif epoch > int(((num_epochs/3 * 2) + 100)):
#             alpha = alpha_end
#             balancing_iterations = balancing_iterations_end
#             # decrease learning rate very slowly
#             optimizer_f.param_groups[0]['lr'] = optimizer_f.param_groups[0]['lr'] * 0.999
#             optimizer_te1.param_groups[0]['lr'] = optimizer_te1.param_groups[0]['lr'] * 0.99
#             optimizer_te2.param_groups[0]['lr'] = optimizer_te2.param_groups[0]['lr'] * 0.99
#             optimizer_g.param_groups[0]['lr'] = optimizer_g.param_groups[0]['lr'] * 0.999
#         else:
#             balancing_iterations = int((balancing_iterations_start + balancing_iterations_end) / 2)
#             alpha = alpha_start - (alpha_start - alpha_end) * (epoch - int(num_epochs/3 * 2 - 100)) / (200)

#         X_batch1, T_batch1, Y_batch1 = X_batch1.to(device), T_batch1.to(device), Y_batch1.to(device)
#         X_batch2, T_batch2, Y_batch2 = X_batch2.to(device), T_batch2.to(device), Y_batch2.to(device)
#         #Z = torch.randn(X_batch.shape[0], generator_input_dim, device=device)
#         U_hat = generator(Z)
        
#         # Prepare inputs for CATE learner
#         input_te1 = torch.cat((X_batch1, U_hat, T_batch1), dim=1)
#         input_te2 = torch.cat((X_batch2, U_hat, T_batch2), dim=1)
#         #input_te = torch.cat((X_batch, T_batch), dim=1)
#         Y_pred1 = cate_learner1(input_te1)
#         Y_pred2 = cate_learner2(input_te2)
        
#         input_te_one1 = torch.cat((X_batch1, U_hat, torch.ones(X_batch1.shape[0], 1, device=device)), dim=1)
#         input_te_one2 = torch.cat((X_batch2, U_hat, torch.ones(X_batch2.shape[0], 1, device=device)), dim=1)
#         #input_te_one = torch.cat((X_batch, torch.ones(X_batch.shape[0], 1, device=device)), dim=1)
#         Y_pred_one1 = cate_learner1(input_te_one1)
#         Y_pred_one2 = cate_learner2(input_te_one2)
        
#         input_te_zero1 = torch.cat((X_batch1, U_hat, torch.zeros(X_batch1.shape[0], 1, device=device)), dim=1)
#         input_te_zero2 = torch.cat((X_batch2, U_hat, torch.zeros(X_batch2.shape[0], 1, device=device)), dim=1)
#         #input_te_zero = torch.cat((X_batch, torch.zeros(X_batch.shape[0], 1, device=device)), dim=1)
#         Y_pred_zero1 = cate_learner1(input_te_zero1)
#         Y_pred_zero2 = cate_learner2(input_te_zero2)
        
#         # Transform outcomes using f
#         f_Y_pred_11 = model_f(Y_pred_one1.view(-1, 1))
#         f_Y_pred_12 = model_f(Y_pred_one2.view(-1, 1))
#         f_Y_pred_01 = model_f(Y_pred_zero1.view(-1, 1))
#         f_Y_pred_02 = model_f(Y_pred_zero2.view(-1, 1))

#         # Update loss functions
#         loss11 = mse(Y_pred1, Y_batch1)
#         loss12 = mse(Y_pred2, Y_batch2)
#         loss2 = mse(f_Y_pred_11.mean(), f_Y_pred_12.mean())
#         loss3 = mse(f_Y_pred_01.mean(), f_Y_pred_02.mean())

#         loss = alpha * (loss11+loss12) + loss2 + loss3

#         # Optimize generator and CATE learner
#         optimizer_te1.zero_grad()
#         optimizer_te2.zero_grad()
#         optimizer_g.zero_grad()
#         loss.backward(retain_graph=True)
#         optimizer_te1.step()
#         optimizer_te2.step()
#         optimizer_g.step()

#         # Train function f
#         #print("balancing_iterations = ", balancing_iterations)
#         for _ in range(balancing_iterations):
#             f_Y_pred_11 = model_f(Y_pred_one1.view(-1, 1).detach())
#             f_Y_pred_12 = model_f(Y_pred_one2.view(-1, 1).detach())
#             f_Y_pred_01 = model_f(Y_pred_zero1.view(-1, 1).detach())
#             f_Y_pred_02 = model_f(Y_pred_zero2.view(-1, 1).detach())

#             loss4 = mse(f_Y_pred_11.mean(), f_Y_pred_12.mean())
#             loss5 = mse(f_Y_pred_01.mean(), f_Y_pred_02.mean())

#             loss_f = -loss4 - loss5
#             #print("loss_f = ", loss_f)
#             optimizer_f.zero_grad()
#             loss_f.backward()
#             optimizer_f.step()

#         # Log losses
#         losses_mse.append(loss11+loss12.item())
#         losses_f.append(loss_f.item())

#         # Calculate test MSE and EPEHE
#         U_hat = generator(Z_test)
#         test_input_te1 = torch.cat((X_test1, U_hat, T_test1), dim=1)
#         test_input_te2 = torch.cat((X_test2, U_hat, T_test2), dim=1)
#         #test_input_te = torch.cat((X_test, T_test), dim=1)
#         Y_pred_test1 = cate_learner1(test_input_te1)
#         Y_pred_test2 = cate_learner2(test_input_te2)
#         test_mse.append(mse(Y_pred_test1, test_data1.Y).item()+mse(Y_pred_test2, test_data2.Y).item())

#         # # Estimate ITE
#         T_one1 = torch.ones(test_data1.X.shape[0], 1, device=device)
#         y_list1 = [cate_learner1(torch.cat((test_data1.X, generator(torch.randn(test_data1.X.shape[0], generator_input_dim, device=device)), T_one1), dim=1)) for _ in range(10)]
#         Y_pred_one1 = torch.mean(torch.stack(y_list1), dim=0)

#         T_zero1 = torch.zeros(test_data1.X.shape[0], 1, device=device)
#         y_list_zero1 = [cate_learner1(torch.cat((test_data1.X, generator(torch.randn(test_data1.X.shape[0], generator_input_dim, device=device)), T_zero1), dim=1)) for _ in range(10)]
#         Y_pred_zero1 = torch.mean(torch.stack(y_list_zero1), dim=0)
        
#         T_one2 = torch.ones(test_data2.X.shape[0], 1, device=device)
#         y_list2 = [cate_learner2(torch.cat((test_data2.X, generator(torch.randn(test_data2.X.shape[0], generator_input_dim, device=device)), T_one2), dim=1)) for _ in range(10)]
#         Y_pred_one2 = torch.mean(torch.stack(y_list2), dim=0)

#         T_zero2 = torch.zeros(test_data2.X.shape[0], 1, device=device)
#         y_list_zero2 = [cate_learner2(torch.cat((test_data2.X, generator(torch.randn(test_data2.X.shape[0], generator_input_dim, device=device)), T_zero2), dim=1)) for _ in range(10)]
#         Y_pred_zero2 = torch.mean(torch.stack(y_list_zero2), dim=0)

#         # estimate ITE
#         # T_one = torch.ones(test_data.X.shape[0], 1, device=device)
#         # input_te_one = torch.cat((test_data.X, T_one), dim=1)
#         # Y_pred_one = cate_learner(input_te_one)
#         # T_zero = torch.zeros(test_data.X.shape[0], 1, device=device)
#         # input_te_zero = torch.cat((test_data.X, T_zero), dim=1)
#         # Y_pred_zero = cate_learner(input_te_zero)
        
#         ite1 = Y_pred_one1 - Y_pred_zero1
#         epehe1 = torch.sqrt(mse(true_ite1, ite1))
#         epehe_list1.append(epehe1.item())

#         ite2 = Y_pred_one2 - Y_pred_zero2
#         epehe2 = torch.sqrt(mse(true_ite2, ite2))
#         epehe_list2.append(epehe2.item())

#         # # Update alpha if the change in MSE is less than epsilon
#         # if abs(previous_mse - loss1.item()) < epsilon:
#         #     alpha = max(alpha * 0.1, alpha_end)
        
#         # if epoch % 10 == 0:
#         #     previous_mse = loss1.item()
#         #print("alpha = ", alpha)    
#         if epoch % 100 == 0:    
#             print(f'Epoch {epoch}, Loss MSE: {loss11+loss12.item()}, Loss F: {loss_f.item()}, EPEHE1: {epehe1.item()}, EPEHE2: {epehe2.item()}')

#         # scheduler_g.step()
#         # scheduler_te.step()
#         # scheduler_f.step()

#     return model_f, generator, cate_learner1,cate_learner2, losses_mse, losses_f, epehe_list1, epehe_list2, test_mse





# ##### old functions

# class Wassertein_Loss(nn.Module):
#   def __init__(self,p=1,blur=0.05):
#     super(Wassertein_Loss, self).__init__()
#     self.p = p
#     self.blur = blur
#   def forward(self,phi1,phi0):
#     samples_loss = SamplesLoss(loss="sinkhorn", p=self.p, blur=self.blur, backend="tensorized")
#     imbalance_loss = samples_loss(phi1, phi0)
#     return imbalance_loss
# def train_model_unconfounded_sinkhorn(observational_data, test_data,rct_data,generator,cate_learner,alpha_start=100,alpha_end=0.01,num_epochs=500,balancing_iterations=2,generator_input_dim=1,batch_size=64,device='cpu'):
#     obs_dataset = TensorDataset(observational_data.X, observational_data.T, observational_data.Y)
#     obs_loader = DataLoader(obs_dataset, batch_size=batch_size, shuffle=True)

#     # Create a DataLoader
#     size = observational_data.X.shape[0]
    
#     # add  z to dataloader
#     T_rct = rct_data.T
#     T_rct = T_rct.view(-1, 1)
#     Y_rct = rct_data.Y

#     X_test = test_data.X
#     mu1 = test_data.mu1
#     mu0 = test_data.mu0
#     T_test = test_data.T

#     # Loss functions
#     mse = nn.MSELoss()
#     wass = Wassertein_Loss()


#     # Optimizers
#     optimizer_g = optim.Adam(generator.parameters(), lr=0.001)
#     optimizer_te = optim.Adam(cate_learner.parameters(), lr=0.001)

#     losses_mse = []
#     epehe_list = []
#     test_mse = []
#     losses_sinkhorn = []

#     true_ite = mu1 - mu0
#     X_batch = observational_data.X
#     T_batch = observational_data.T
#     Y_batch = observational_data.Y




#     alpha = alpha_start



#     for epoch in range(num_epochs):
#          #balancing_iterations = balancing_iterations_start - (balancing_iterations_start - balancing_iterations_end) * epoch / num_epochs
#         # make it an int
#         if epoch < int(((num_epochs/3 * 2) - 100)):
#             alpha = alpha_start
#         elif epoch > int(((num_epochs/3 * 2) + 100)):
#             alpha = alpha_end
#         else:
#             alpha = alpha_start - (alpha_start - alpha_end) * (epoch - int(num_epochs/3 * 2 - 100)) / (200)


#         for X_batch, T_batch, Y_batch in obs_loader:
#             X_batch, T_batch, Y_batch = X_batch.to(device), T_batch.to(device), Y_batch.to(device)
#             Z = torch.randn(X_batch.shape[0], generator_input_dim, device=device)
#             U_hat = generator(Z)
            
#             # Prepare inputs for CATE learner
#             input_te = torch.cat((X_batch, U_hat, T_batch), dim=1)
#             Y_pred = cate_learner(input_te)
            
#             input_te_one = torch.cat((X_batch, U_hat, torch.ones(X_batch.shape[0], 1, device=device)), dim=1)
#             Y_pred_one = cate_learner(input_te_one)
            
#             input_te_zero = torch.cat((X_batch, U_hat, torch.zeros(X_batch.shape[0], 1, device=device)), dim=1)
#             Y_pred_zero = cate_learner(input_te_zero)
            
#             Y_rct_1 = Y_rct[T_rct == 1].view(-1, 1)
#             Y_rct_0 = Y_rct[T_rct == 0].view(-1, 1)

#             # Update loss functions
#             loss1 = mse(Y_pred, Y_batch)
#             loss2 = wass(Y_rct_1, Y_pred_one)
#             loss3 = wass(Y_rct_0, Y_pred_zero)

#             loss_sinkhorn = loss2 + loss3

#             loss = alpha * loss1 + loss_sinkhorn

#             # Optimize generator and CATE learner
#             optimizer_te.zero_grad()
#             optimizer_g.zero_grad()
#             loss.backward(retain_graph=True)
#             optimizer_te.step()
#             optimizer_g.step()

            

#         # Log losses
#         losses_mse.append(loss1.item())
#         losses_sinkhorn.append(loss_sinkhorn.item())

#         # Calculate test MSE and EPEHE
#         Z_test = torch.randn(test_data.X.shape[0], generator_input_dim, device=device)
#         U_hat = generator(Z_test)
#         test_input_te = torch.cat((X_test, U_hat, T_test), dim=1)
#         Y_pred_test = cate_learner(test_input_te)
#         test_mse.append(mse(Y_pred_test, test_data.Y).item())

#         # Estimate ITE
#         T_one = torch.ones(test_data.X.shape[0], 1, device=device)
#         y_list = [cate_learner(torch.cat((test_data.X, generator(torch.randn(test_data.X.shape[0], generator_input_dim, device=device)), T_one), dim=1)) for _ in range(10)]
#         Y_pred_one = torch.mean(torch.stack(y_list), dim=0)
        
#         T_zero = torch.zeros(test_data.X.shape[0], 1, device=device)
#         y_list_zero = [cate_learner(torch.cat((test_data.X, generator(torch.randn(test_data.X.shape[0], generator_input_dim, device=device)), T_zero), dim=1)) for _ in range(10)]
#         Y_pred_zero = torch.mean(torch.stack(y_list_zero), dim=0)
        
#         ite = Y_pred_one - Y_pred_zero
#         epehe = torch.sqrt(mse(true_ite, ite))
#         epehe_list.append(epehe.item())  
            
#         print(f'Epoch {epoch}, Loss MSE: {loss1.item()}, Loss F: {loss_sinkhorn.item()}, EPEHE: {epehe.item()}')

#     return generator, cate_learner, losses_mse, losses_sinkhorn, epehe_list, test_mse