import os
import pickle
from typing import List, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
from torchvision import datasets
from torch_geometric.data import Data
import torch_geometric.transforms as T
from torch_geometric.transforms import BaseTransform
import torch_geometric.datasets as datasets
from torch_geometric.utils import to_scipy_sparse_matrix, negative_sampling
from ogb.nodeproppred import PygNodePropPredDataset
from torch_geometric.data import Data
import numpy as np
from temperature import tune_temp
from conformal import * 
from gnn import GCN, SAGE
from torch_geometric.utils import subgraph
from tdigest import TDigest


# Compute device shared by the helpers in this module: the first CUDA GPU when one is available, otherwise the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

def get_save_name(args, experiment, neigh_gen = False):
    """Build the underscore-separated run name used when saving results.

    Args:
        args: Mapping read for the keys ``dataset``, ``num_client``,
            ``architecture``, ``partition``, ``use_federated_cfgnn``,
            ``use_fed_sage_plus`` and ``use_vae_gen``; ``dirichlet_beta`` is
            read only when the partition is ``"dirichlet"``.
        experiment: Experiment label inserted right after the dataset name.
        neigh_gen: Accepted for interface compatibility; not used here.

    Returns:
        ``{dataset}_{experiment}_{partition}_{model}_{clients}_fl_corr:<yes|no>_gen:<yes|no>_vae:<yes|no>``.
    """
    def _yes_no(tag, enabled):
        # Each switch is rendered as "<tag>:yes" or "<tag>:no" based on truthiness.
        return f"{tag}:yes" if enabled else f"{tag}:no"

    dataset_label = args["dataset"]
    client_count = args["num_client"]
    arch_label = args["architecture"]
    partition_label = args['partition']

    # Dirichlet partitions also carry their concentration parameter in the name.
    if partition_label == "dirichlet":
        partition_label = f"dirichlet_{args['dirichlet_beta']}"

    fed_cf_tag = _yes_no("fl_corr", args["use_federated_cfgnn"])
    generator_tag = _yes_no("gen", args["use_fed_sage_plus"])
    vae_tag = _yes_no("vae", args["use_vae_gen"])

    return f"{dataset_label}_{experiment}_{partition_label}_{arch_label}_{client_count}_{fed_cf_tag}_{generator_tag}_{vae_tag}"

def split_train(data, dataset=None, data_path=None, ratio_train=0.2):
    """Randomly re-assign the train/test/val node masks of ``data`` in place.

    A random permutation of the nodes is cut into three consecutive slices:
    ``round(n * ratio_train)`` training nodes, ``round(n * (1 - ratio_train) / 2)``
    test nodes, and all remaining nodes for validation.

    Args:
        data: Object exposing ``num_nodes`` and boolean ``train_mask``,
            ``test_mask`` and ``val_mask`` tensors, which are overwritten.
        dataset: Unused; kept for interface compatibility.
        data_path: Unused; kept for interface compatibility.
        ratio_train: Fraction of nodes placed in the training split.

    Returns:
        The same ``data`` object with updated masks.
    """
    total_nodes = data.num_nodes
    train_count = round(total_nodes * ratio_train)
    test_count = round(total_nodes * ((1 - ratio_train) / 2))

    shuffled = torch.randperm(total_nodes)
    train_end = train_count
    test_end = train_count + test_count
    assignment = (
        ("train_mask", shuffled[:train_end]),
        ("test_mask", shuffled[train_end:test_end]),
        ("val_mask", shuffled[test_end:]),
    )

    # Clear every mask first, then switch on the freshly drawn indices.
    for mask_name, _ in assignment:
        getattr(data, mask_name).fill_(False)
    for mask_name, node_ids in assignment:
        getattr(data, mask_name)[node_ids] = True

    return data

def get_gnn_datasets(dataset, data_path):
    """Load one graph dataset by name and return its first ``Data`` object.

    Every dataset is restricted to its largest connected component. All but
    the OGB datasets additionally get row-normalized features; the OGB
    datasets are made undirected instead and their labels are flattened.

    Mask handling depends on the dataset family:
      * Planetoid keeps the masks shipped with the dataset.
      * WebKB, HeterophilousGraphDataset and WikipediaNetwork keep only
        column 0 of their 2-D masks.
      * CitationFull, Coauthor, Amazon and OGB get all-``False`` masks that
        the caller is expected to fill (e.g. via ``split_train``).

    Args:
        dataset: Dataset name, e.g. ``'Cora'``, ``'Wisconsin'``, ``'ogbn-arxiv'``.
        data_path: Root directory handed to the dataset class.

    Returns:
        The loaded graph with ``train_mask``, ``val_mask`` and ``test_mask`` set.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    mask_names = ('train_mask', 'val_mask', 'test_mask')

    def _lcc_normalized():
        # Built lazily so an unsupported name fails before any transform is constructed.
        return T.Compose([LargestConnectedComponents(), T.NormalizeFeatures()])

    def _blank_masks(graph):
        # These datasets ship without splits; start from empty boolean masks.
        for mask_name in mask_names:
            setattr(graph, mask_name, torch.zeros(graph.num_nodes, dtype=torch.bool))

    def _first_split(graph):
        # These datasets ship 2-D masks; only column 0 is used.
        for mask_name in mask_names:
            setattr(graph, mask_name, getattr(graph, mask_name)[:, 0])

    if dataset in ['Cora', 'CiteSeer', 'PubMed']:
        data = datasets.Planetoid(data_path, dataset, transform=_lcc_normalized())[0]
    elif dataset in ['Cora_ML', 'DBLP']:
        data = datasets.CitationFull(data_path, dataset, transform=_lcc_normalized())[0]
        _blank_masks(data)
    elif dataset in ['Cornell', 'Texas', 'Wisconsin']:
        # Previously spelled 'Wisconsion', which made 'Wisconsin' unreachable.
        data = datasets.WebKB(data_path, dataset, transform=_lcc_normalized())[0]
        _first_split(data)
    elif dataset in ['Roman-empire', 'Amazon-ratings', 'Tolokers']:
        data = datasets.HeterophilousGraphDataset(data_path, dataset, transform=_lcc_normalized())[0]
        _first_split(data)
    elif dataset in ['chameleon', 'squirrel']:
        data = datasets.WikipediaNetwork(data_path, dataset, transform=_lcc_normalized())[0]
        _first_split(data)
    elif dataset in ['CS', 'Physics']:
        data = datasets.Coauthor(data_path, dataset, transform=_lcc_normalized())[0]
        _blank_masks(data)
    elif dataset in ['Computers', 'Photo']:
        data = datasets.Amazon(data_path, dataset, transform=_lcc_normalized())[0]
        _blank_masks(data)
    elif dataset in ['ogbn-arxiv', 'ogbn-products']:
        data = PygNodePropPredDataset(dataset, root=data_path, transform=T.Compose([T.ToUndirected(), LargestConnectedComponents()]))[0]
        _blank_masks(data)
        data.y = data.y.view(-1)
    else:
        # Fail loudly instead of hitting an UnboundLocalError on `return data`.
        raise ValueError(f'Dataset "{dataset}" not supported.')
    return data

# (num_classes, num_features) for every dataset prepareData_oneDS supports.
_DATASET_DIMS = {
    "Cora": (7, 1433),
    "Cora_ML": (7, 2879),
    "DBLP": (4, 1639),
    "CiteSeer": (6, 3703),
    "PubMed": (3, 500),
    "chameleon": (5, 2325),
    "squirrel": (5, 2089),
    "Roman-empire": (18, 300),
    "Amazon-ratings": (5, 300),
    "Tolokers": (2, 10),
    "Cornell": (5, 1703),
    "Texas": (5, 1703),
    "Wisconsin": (5, 1703),
    "Computers": (10, 767),
    "Photo": (8, 745),
    "CS": (15, 6805),
    "Physics": (5, 8415),
}


def prepareData_oneDS(datapath, data, num_client, batchSize, partition= "dirichlet", seed=None):
    """Build per-client loaders for one dataset, plus the global graph.

    The global graph is loaded with ``get_gnn_datasets`` and re-split with
    ``split_train`` (20% train). With a single client the global graph is the
    client's data; otherwise each client's graph is read from
    ``{data}_disjoint_{partition}/{num_client}/partition_{id}.pt`` under
    ``datapath`` and the client-to-label map from ``client_label_map.pkl``
    in the same directory.

    Args:
        datapath: Root directory for datasets and partition files.
        data: Dataset name; must be a key of ``_DATASET_DIMS``.
        num_client: Number of clients.
        batchSize: Batch size passed to every ``DataLoader``.
        partition: Partition scheme name used in the partition path.
        seed: Unused; kept for interface compatibility.

    Returns:
        ``(splitedData, global_val_mask, global_test_mask, client_class_map,
        num_features, num_classes)`` where ``splitedData[client_id]`` is a dict
        with keys ``loader``, ``glob_loader``, ``client_data``, ``global_data``,
        ``tr_mask``, ``val_mask`` and ``test_mask``.

    Raises:
        Exception: If ``data`` is not a supported dataset name.
    """
    try:
        num_classes, num_features = _DATASET_DIMS[data]
    except KeyError:
        raise Exception("Given dataset not implemented yet!") from None

    splitedData = {}
    pkl_str = f'{data}_disjoint_{partition}/{num_client}/client_label_map.pkl'

    # Init global data
    global_d = split_train(get_gnn_datasets(data, datapath), data, datapath, 0.2)
    if num_client == 1:
        loader = DataLoader(dataset= [global_d], batch_size=batchSize, pin_memory=False)
        client_class_map = {'client_0' : list(range(num_classes))}
        splitedData[0] = {'loader' : loader , 'glob_loader' : loader, 'client_data' : global_d, 'global_data' : global_d, 'tr_mask' : global_d.train_mask, 'val_mask' : global_d.val_mask, 'test_mask' : global_d.test_mask}

        return splitedData, global_d.val_mask, global_d.test_mask, client_class_map, num_features, num_classes

    # NOTE: pickle executes arbitrary code on load; the partition directory
    # must only contain files produced by this project.
    with open(os.path.join(datapath, pkl_str), 'rb') as map_file:
        client_class_map = pickle.load(map_file)

    for client_id in range(num_client):
        part = torch_load(datapath, f'{data}_disjoint_{partition}/{num_client}/partition_{client_id}.pt')

        cli_data = part['client_data']
        print(f"For client {client_id}", torch.unique(cli_data.y))
        print(cli_data)
        loader = DataLoader(dataset= [cli_data], batch_size=batchSize, pin_memory=False)
        glob_loader = DataLoader(dataset= [global_d], batch_size=batchSize, pin_memory=False)

        splitedData[client_id] = {'loader' : loader , 'glob_loader' : glob_loader, 'client_data' : cli_data, 'global_data' : global_d, 'tr_mask' : cli_data.train_mask, 'val_mask' : cli_data.val_mask, 'test_mask' : cli_data.test_mask}

    return splitedData, global_d.val_mask, global_d.test_mask, client_class_map, num_features, num_classes

def torch_save(base_dir, filename, data):
    """Write ``data`` with ``torch.save`` to ``base_dir/filename``, creating ``base_dir`` if it is missing."""
    target_path = os.path.join(base_dir, filename)
    os.makedirs(base_dir, exist_ok=True)
    torch.save(data, target_path)

def torch_load(base_dir, filename):
    """Read ``base_dir/filename`` with ``torch.load``, mapping all tensors to the CPU."""
    cpu = torch.device('cpu')
    return torch.load(os.path.join(base_dir, filename), map_location=cpu)

def init_params(net: torch.nn.Module):
    """Re-initialize the Conv2d, BatchNorm2d and Linear layers of ``net`` in place.

    * ``nn.Conv2d``: Kaiming-normal weights (``mode="fan_out"``), zero bias.
    * ``nn.BatchNorm2d``: weight set to 1, bias set to 0.
    * ``nn.Linear``: weights drawn from a normal with ``std=1e-3``, zero bias.

    Layers created without a bias (or BatchNorm without affine parameters)
    are handled by skipping the missing tensors.

    Args:
        net: Module whose sub-modules are visited via ``net.modules()``.
    """
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode="fan_out")
            # `if m.bias:` would evaluate the tensor's truth value, which
            # raises for multi-element biases; test for presence instead.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            # weight/bias are None when the layer was built with affine=False.
            if m.weight is not None:
                nn.init.constant_(m.weight, 1)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, std=1e-3)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

def make_gnn_model(
    architecture: str,
    num_layers: int = 2,
    in_channels: int = 1,
    num_hidden : int = 128,
    num_classes: int = 10,
    drop_rate: float = 0.5
    ) -> torch.nn.Module:
    """Instantiate the GNN selected by ``architecture``.

    Args:
        architecture: ``"gcn"`` for ``GCN`` or ``"sage"`` for ``SAGE``.
        num_layers: Forwarded as ``nlayer``.
        in_channels: Forwarded as ``nfeat``.
        num_hidden: Forwarded as ``nhid``.
        num_classes: Forwarded as ``ncls``.
        drop_rate: Forwarded as ``drop_rate``.

    Returns:
        The newly constructed model.

    Raises:
        ValueError: If ``architecture`` is neither ``"gcn"`` nor ``"sage"``.
    """
    # Guard clause: reject unknown names before touching any model class.
    if architecture != "gcn" and architecture != "sage":
        raise ValueError(f'Architecture "{architecture}" not supported.')

    model_cls = GCN if architecture == "gcn" else SAGE
    return model_cls(
        nlayer=num_layers,
        nfeat=in_channels,
        nhid=num_hidden,
        ncls=num_classes,
        drop_rate=drop_rate,
    )

def replace_last_layer(model: torch.nn.Module, architecture: str, num_classes: int = 1, num_hidden: int = 128):
    """Replace ``model.classifier`` with a fresh linear head on the module-level ``device``.

    Args:
        model: Model whose ``classifier`` attribute is overwritten.
        architecture: Must be ``"sage"`` or ``"gcn"``.
        num_classes: Output width of the new head.
        num_hidden: Input width of the new head. Defaults to 128, the value
            that used to be hard-coded; pass the hidden size the model was
            built with when it differs.

    Returns:
        The same ``model`` instance with the new head attached.

    Raises:
        ValueError: If ``architecture`` is not supported.
    """
    if architecture not in ["sage", "gcn"]:
        raise ValueError(f'Architecture "{architecture}" not supported.')
    model.classifier = nn.Linear(num_hidden, num_classes).to(device)
    return model

class Net(nn.Module):
    """Small CNN: three stride-2 3x3 convolutions followed by two linear layers.

    ``fc1`` expects 2048 flattened features, i.e. a 128-channel 4x4 map after
    the three convolutions (for example from a 28x28 or 32x32 input).
    """

    def __init__(self, in_channels: int = 1, out_channels: int = 10):
        super().__init__()
        # All three convolutions share kernel size, stride and padding.
        conv_cfg = {"kernel_size": 3, "stride": 2, "padding": 1}
        self.conv1 = nn.Conv2d(in_channels, 32, **conv_cfg)
        self.conv2 = nn.Conv2d(32, 64, **conv_cfg)
        self.conv3 = nn.Conv2d(64, 128, **conv_cfg)
        self.fc1 = nn.Linear(2048, 128)
        self.fc2 = nn.Linear(128, out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw (un-normalized) output of ``fc2`` for the batch ``x``."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = F.relu(conv(x))
        flat = torch.flatten(x, 1)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)

def client_update(
    client_model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    train_loader: DataLoader,
    train_mask: torch.Tensor,
    epoch: int = 5,
) -> float:
    """Train ``client_model`` on ``train_loader`` for ``epoch`` passes.

    The model is called as ``client_model(batch.x, batch.edge_index)`` and the
    second element of the returned triple is fed to ``F.nll_loss``, restricted
    to the nodes selected by ``train_mask``.

    Args:
        client_model: Model to train; moved to the module-level ``device``.
        optimizer: Optimizer stepping ``client_model``'s parameters.
        train_loader: Loader yielding graph batches.
        train_mask: Node selector applied to both outputs and labels.
        epoch: Number of passes over ``train_loader``.

    Returns:
        Summed loss of the last pass divided by ``len(train_loader)``
        (``0`` when ``epoch`` is 0).
    """
    client_model.train()
    client_model.to(device)
    # Defined up front so epoch == 0 returns 0 instead of raising UnboundLocalError.
    total_loss = 0
    for _ in range(epoch):
        total_loss = 0
        for batch in train_loader:
            batch = batch.to(device)
            optimizer.zero_grad()
            _, out, _ = client_model(batch.x, batch.edge_index)
            # Index labels exactly like the outputs so both stay aligned
            # (identical to the former torch.where(...) form for boolean masks).
            loss = F.nll_loss(out[train_mask], batch.y[train_mask])

            # NOTE(review): retain_graph=True looks unnecessary because a fresh graph
            # is built every step; kept in case the model caches graph-attached state.
            loss.backward(retain_graph=True)
            optimizer.step()
            total_loss += loss.item()
    return total_loss / len(train_loader)

def local_gen_update(
    local_gen_model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    train_loader: DataLoader,
    privacy_engine = None,
    epoch: int = 5,
    sparsity_level=0.05,
    beta=1,
):
    """Train the local generator, then run one more forward pass over the data.

    Args:
        local_gen_model: Generator exposing ``loss(x, edge_index, e, epoch,
            sparsity_level, beta)`` and returning ``(latent, generated_x)``
            when called.
        optimizer: Optimizer stepping the generator's parameters.
        train_loader: Loader yielding graph batches.
        privacy_engine: When truthy, the loss is taken from
            ``local_gen_model._module`` instead of the model itself
            (presumably a privacy-wrapped module — confirm with the caller).
        epoch: Number of passes over ``train_loader``.
        sparsity_level: Forwarded to the model's ``loss``.
        beta: Forwarded to the model's ``loss``.

    Returns:
        ``(mean_loss, latent_space, generated_x)`` where ``mean_loss`` is the
        last pass's summed loss divided by ``len(train_loader)`` and the other
        two come from the final batch of the extra forward pass.
    """
    local_gen_model.train()

    running_loss = 0
    for epoch_idx in range(epoch):
        running_loss = 0
        for graph in train_loader:
            graph = graph.to(device)
            optimizer.zero_grad()
            # A wrapped model keeps the real module, and its loss, under `_module`.
            loss_owner = local_gen_model._module if privacy_engine else local_gen_model
            step_loss = loss_owner.loss(graph.x, graph.edge_index, epoch_idx, epoch, sparsity_level, beta)
            step_loss.backward()
            optimizer.step()
            running_loss += step_loss.item()

    # Forward pass after training; only the last batch's outputs are returned.
    for graph in train_loader:
        graph = graph.to(device)
        latent_space, generated_x = local_gen_model(graph.x, graph.edge_index)

    mean_loss = running_loss / len(train_loader)
    return mean_loss, latent_space, generated_x


def generate_node_feat_from_cluster_centers(local_gen, centers, privacy_engine=None):
    """Decode ``centers`` into node features with the local generator.

    ``centers`` is wrapped in a one-element loader, moved to the module-level
    ``device`` and passed to ``decode_latent``. When ``privacy_engine`` is
    truthy the decoder is taken from ``local_gen._module``.

    Returns:
        The output of ``decode_latent`` for the first (and only) batch.
    """
    center_loader = DataLoader(dataset=[centers], batch_size=len(centers), pin_memory=False)

    for center_batch in center_loader:
        center_batch = center_batch.to(device)
        decoder = local_gen._module if privacy_engine else local_gen
        # Return right away: only the first batch is ever decoded.
        return decoder.decode_latent(center_batch)

def local_gvae_update(
    local_gvae_model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    train_loader: DataLoader,
    client_data,
    all_edge_index,
    epoch: int = 5,
) -> float:
    """Train the local graph VAE for ``epoch`` passes over ``train_loader``.

    Args:
        local_gvae_model: Model exposing ``loss(x, train_pos_edge_index, all_edge_index)``.
        optimizer: Optimizer stepping the model's parameters.
        train_loader: Loader yielding batches with ``x`` and ``train_pos_edge_index``.
        client_data: Unused; kept for interface compatibility.
        all_edge_index: Forwarded unchanged to the model's ``loss``.
        epoch: Number of passes over ``train_loader``.

    Returns:
        Summed loss of the last pass divided by ``len(train_loader)``
        (``0`` when ``epoch`` is 0).
    """
    local_gvae_model.train()
    # Defined up front so epoch == 0 returns 0 instead of raising UnboundLocalError.
    total_loss = 0
    for _ in range(epoch):
        total_loss = 0
        for batch in train_loader:
            batch = batch.to(device)
            optimizer.zero_grad()
            loss = local_gvae_model.loss(batch.x, batch.train_pos_edge_index, all_edge_index)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

    return total_loss / len(train_loader)

def local_edge_pred_update(
    local_edge_pred_model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    client_data,
    edge_index,
    epoch: int = 1,
) -> float:
    """Train the edge predictor on one client's graph.

    Each step encodes the nodes with ``edge_index``, takes
    ``client_data.train_pos_edge_index`` as positive edges, samples the same
    number of negative edges, and optimizes the model's ``loss``.

    Args:
        local_edge_pred_model: Model called as ``model(x, edge_index)`` and
            exposing ``loss(z, pos_edge_index, neg_edge_index)``.
        optimizer: Optimizer stepping the model's parameters.
        client_data: Graph with ``x``, ``num_nodes`` and ``train_pos_edge_index``.
        edge_index: Edges used for message passing.
        epoch: Number of optimization steps.

    Returns:
        Mean loss over the ``epoch`` steps.
    """
    local_edge_pred_model.train()

    client_data = client_data.to(device)
    edge_index = edge_index.to(device)

    loss_sum = 0
    for _ in range(epoch):
        optimizer.zero_grad()
        embeddings = local_edge_pred_model(client_data.x, edge_index)
        positives = client_data.train_pos_edge_index
        # One sampled negative edge per positive edge.
        negatives = negative_sampling(
            edge_index=positives,
            num_nodes=client_data.num_nodes,
            num_neg_samples=positives.size(1),
        )
        step_loss = local_edge_pred_model.loss(embeddings, positives, negatives)
        step_loss.backward()
        optimizer.step()
        loss_sum += step_loss.item()

    return loss_sum / epoch

def average_models(
    global_model: torch.nn.Module, client_models: List[torch.nn.Module]
) -> None:
    """Overwrite ``global_model`` with the element-wise mean of the client models.

    Every entry of the global ``state_dict`` is replaced by the mean (computed
    in float) of the same entry across ``client_models``; the result is then
    loaded back into ``global_model``.

    Args:
        global_model: Model that receives the averaged state.
        client_models: Models to average; each must provide every key of the
            global ``state_dict``.
    """
    global_dict = global_model.state_dict()
    # Build each client's state_dict once instead of once per parameter key.
    client_dicts = [client.state_dict() for client in client_models]
    for key in global_dict:
        stacked = torch.stack([client_dict[key] for client_dict in client_dicts], 0)
        global_dict[key] = stacked.float().mean(0)
    global_model.load_state_dict(global_dict)

def evaluate_model(
    model: torch.nn.Module,
    data_loader: DataLoader,
    mask: torch.Tensor,
    return_logits: bool = False,
    return_full_batch = False,
) -> Union[Tuple[float, float], Tuple[float, float, torch.Tensor, torch.Tensor]]:
    """Evaluate ``model`` on the nodes selected by ``mask``.

    The model is called as ``model(batch.x, batch.edge_index)`` and must
    return a triple: the first element is what gets collected as "logits",
    the second is fed to ``F.nll_loss`` and used for predictions.

    Args:
        model: Model to evaluate; moved to the module-level ``device``.
        data_loader: Loader yielding graph batches.
        mask: Node selector applied to outputs, labels and features.
        return_logits: Also collect the masked first output and masked labels.
        return_full_batch: Also collect the unmasked first output and labels.
            When both flags are set, masked and unmasked tensors are both
            appended, in that order, for every batch.

    Returns:
        ``(loss, accuracy)``, or ``(loss, accuracy, logits, targets)`` when
        either flag is set. ``loss`` is the summed NLL divided by the number
        of selected nodes.
    """
    model = model.to(device)
    model.eval()

    nll_sum = 0
    n_correct = 0
    n_labels = 0
    n_rows = 0
    collected_logits = []
    collected_targets = []

    with torch.no_grad():
        for batch in data_loader:
            batch = batch.to(device)
            y_out, out, _ = model(batch.x, batch.edge_index)
            masked_targets = batch.y[mask]

            nll_sum += F.nll_loss(out[mask], masked_targets, reduction="sum").item()
            pred = torch.max(out, dim=1)[1]
            n_correct += (pred[mask] == masked_targets).sum().item()
            n_labels += masked_targets.size(0)
            n_rows += batch.x[mask].size(0)

            if return_logits:
                collected_logits.append(y_out[mask].detach().cpu())
                collected_targets.append(masked_targets.detach().cpu())

            if return_full_batch:
                collected_logits.append(y_out.detach().cpu())
                collected_targets.append(batch.y.detach().cpu())

    acc = n_correct / n_labels
    total_loss = nll_sum / n_rows

    if not (return_logits or return_full_batch):
        return total_loss, acc
    return total_loss, acc, torch.cat(collected_logits), torch.cat(collected_targets)

class LargestConnectedComponents(BaseTransform):
    r"""Selects the subgraph that corresponds to the
    largest connected components in the graph.

    Graphs that already have at most :obj:`num_components` components are
    returned unchanged.

    Args:
        num_components (int, optional): Number of largest components to keep
            (default: :obj:`1`)
    """
    def __init__(self, num_components: int = 1):
        self.num_components = num_components

    def __call__(self, data: Data) -> Data:
        import numpy as np
        import scipy.sparse as sp

        adj = to_scipy_sparse_matrix(data.edge_index, num_nodes=data.num_nodes)

        # `component[i]` is the id of the component that node i belongs to.
        num_components, component = sp.csgraph.connected_components(adj)

        if num_components <= self.num_components:
            return data

        # Component ids sorted by size; the last `num_components` are the largest.
        _, count = np.unique(component, return_counts=True)
        largest_ids = count.argsort()[-self.num_components:]
        # np.isin replaces np.in1d, which is deprecated as of NumPy 2.0.
        subset = np.isin(component, largest_ids)

        return data.subgraph(torch.from_numpy(subset).to(torch.bool))

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}({self.num_components})'


def graph_mend_vgae(client_data, all_edge_index, edge_predictions, gen_node_features, device):
    """Append generated nodes to a client graph and connect them as predicted.

    The generated nodes are stacked after the original ones, so generated node
    ``j`` gets index ``len(client_data.x) + j``. For every pair ``(i, j)``
    with ``edge_predictions[i][j] == True`` an undirected edge (both
    directions) is added between original node ``i`` and generated node ``j``.
    Generated nodes get label ``-1`` and are excluded from all three masks.

    Args:
        client_data: Original graph with ``x``, ``y`` and the three masks.
        all_edge_index: ``[2, E]`` edge index the new edges are appended to.
        edge_predictions: Indexable as ``[i][j]`` for original node ``i`` and
            generated node ``j``.
        gen_node_features: Feature rows of the generated nodes.
        device: Device on which the new tensors are created.

    Returns:
        A new ``Data`` object holding the mended graph.
    """
    print(all_edge_index.shape)
    print(client_data)
    node_feats = client_data.x
    new_node_feats = torch.vstack((node_feats, gen_node_features))
    edge_index_row = all_edge_index[0].to(device)
    edge_index_col = all_edge_index[1].to(device)

    new_row = []
    new_col = []

    k = len(node_feats)
    for i in range(k):
        for j in range(len(gen_node_features)):
            if edge_predictions[i][j] == True:
                new_row.extend([i, k + j])
                new_col.extend([k + j, i])

    # Build the new endpoints as int64. torch.Tensor(list) yields float32, and
    # concatenating would push the whole edge index through float32, which
    # cannot represent node ids above 2**24 exactly.
    new_row = torch.tensor(new_row, dtype=torch.int64, device=device)
    new_col = torch.tensor(new_col, dtype=torch.int64, device=device)

    new_edge_index_row = torch.cat((edge_index_row, new_row))
    new_edge_index_col = torch.cat((edge_index_col, new_col))

    new_edge_index = torch.vstack((new_edge_index_row, new_edge_index_col)).type(torch.int64).to(device)

    # Generated nodes take part in no split.
    new_mask = torch.zeros(gen_node_features.shape[0], dtype=torch.bool).to(device)

    new_train_mask = torch.cat((client_data.train_mask, new_mask))
    new_val_mask = torch.cat((client_data.val_mask, new_mask))
    new_test_mask = torch.cat((client_data.test_mask, new_mask))

    # Label -1 marks generated nodes as unlabeled.
    temp_y = torch.zeros(gen_node_features.shape[0]).type(torch.int64).to(device) - 1

    new_y = torch.cat((client_data.y, temp_y))

    new_client_data = Data( x=new_node_feats, edge_index=new_edge_index, y=new_y, train_mask=new_train_mask,
                            val_mask=new_val_mask, test_mask=new_test_mask)


    print(new_client_data)
    print("---")
    return new_client_data

def print_set_sizes_corrected(num_owners, mended_graph_list, classifier, mended_graph_loaders, alpha = None, t_digest=None):
    """Calibrate LAC/APS/RAPS thresholds on validation nodes and evaluate them on test nodes.

    Steps, each looping over the ``num_owners`` clients:
      1. Tune one temperature per client with ``tune_temp`` on the validation
         outputs and average them.
      2. Calibrate the three scores on temperature-scaled softmax validation
         outputs. With ``t_digest`` truthy, per-client score distributions are
         merged into shared ``TDigest`` objects and the thresholds are read as
         a percentile of the merged digest; otherwise the per-client
         thresholds are averaged.
      3. Build prediction sets on the test nodes, record coverage and mean
         set size per client.

    Args:
        num_owners: Number of clients.
        mended_graph_list: Per-client graphs providing ``val_mask`` and ``test_mask``.
        classifier: Model passed to ``evaluate_model``.
        mended_graph_loaders: Per-client loaders passed to ``evaluate_model``.
        alpha: Miscoverage level. The default ``None`` is not usable because
            ``1 - alpha`` is evaluated unconditionally.
        t_digest: Truthy to aggregate with t-digests instead of averaging.

    Returns:
        Mean prediction-set size across clients for LAC, APS and RAPS (in that
        order). Mean coverages are printed, not returned.
    """
    temperatures = []
    N = 0  # total number of validation targets across all clients
    for i in range(num_owners):
        # Convert the boolean mask into CPU node indices.
        val_msk = torch.where(mended_graph_list[i].val_mask)[0].cpu()

        _, _, val_scores, val_targets = evaluate_model(classifier, mended_graph_loaders[i], val_msk, return_logits=True)

        val_scores = val_scores
        val_targets = val_targets
        N += len(val_targets)

        # NOTE: the local name `T` shadows the module-level alias `T` inside this function.
        T = tune_temp(val_scores, val_targets)
        temperatures.append(T)

    # One shared temperature: the plain average of the per-client ones.
    T = sum(temperatures) / len(temperatures)

    if t_digest:
        digest_lac = TDigest()
        digest_aps = TDigest()
        digest_raps = TDigest()
    else:
        lacs_list, aps_list, raps_list = [], [], []

    for i in range(num_owners):
        val_msk = torch.where(mended_graph_list[i].val_mask)[0].cpu()
        _, _, val_scores, val_targets = evaluate_model(classifier, mended_graph_loaders[i], val_msk, return_logits=True)

        # Temperature-scaled class probabilities.
        val_scores = torch.softmax(val_scores / T, 1)

        if t_digest:
            client_digest_lac = TDigest()
            client_digest_aps = TDigest()
            client_digest_raps = TDigest()

            # Only the score distributions are used here; the per-client q_* values are discarded.
            q_lac, score_dist_lac = calibrate_lac(val_scores, val_targets, alpha=alpha, return_dist=True)
            client_digest_lac.batch_update(score_dist_lac.numpy())
            digest_lac += client_digest_lac
            
            q_aps, score_dist_aps = calibrate_aps(val_scores, val_targets, alpha=alpha, return_dist=True)
            client_digest_aps.batch_update(score_dist_aps.numpy())
            digest_aps += client_digest_aps

            q_raps, score_dist_raps = calibrate_raps(val_scores, val_targets, alpha=alpha, return_dist=True)
            client_digest_raps.batch_update(score_dist_raps.numpy())
            digest_raps += client_digest_raps
        else:
            q_lac = calibrate_lac(val_scores, val_targets, alpha=alpha)
            lacs_list.append(q_lac)
            q_aps = calibrate_aps(val_scores, val_targets, alpha=alpha)
            aps_list.append(q_aps.item())
            q_raps = calibrate_raps(val_scores, val_targets, alpha=alpha)
            raps_list.append(q_raps)

    if t_digest:
        K = num_owners
        # Quantile level computed from the total validation count N and the client count K.
        t = np.ceil((N + K) * (1 - alpha)) / N
        # NOTE(review): round(100*t) snaps the level to a whole percent (and could exceed 100
        # for small N) — confirm this quantization is intended.
        q_lac = digest_lac.percentile(round(100*t))
        q_aps = digest_aps.percentile(round(100*t))
        q_raps = digest_raps.percentile(round(100*t))
    else:
        # Without digests the global threshold is the mean of the client thresholds.
        q_lac = sum(lacs_list)/len(lacs_list)
        q_aps = sum(aps_list)/len(aps_list)
        q_raps = sum(raps_list)/len(raps_list)

    # Lists 1/2/3 hold the LAC/APS/RAPS results respectively.
    ineff_list1, ineff_list2, ineff_list3 = [], [], []
    cov_list1, cov_list2, cov_list3 = [], [], []

    for i in range(num_owners):
        tst_msk = torch.where(mended_graph_list[i].test_mask)[0].cpu()
        _, _, test_scores, test_targets = evaluate_model(classifier, mended_graph_loaders[i], tst_msk, return_logits=True)
        test_scores = test_scores
        test_scores = torch.softmax(test_scores / T, 1)

        psets_lac = inference_lac(test_scores, q_lac) 
        coverage = get_coverage(psets_lac, test_targets)
        cov_list1.append(coverage)
        # Mean prediction-set size: per-row sum of set memberships, averaged over rows.
        size_lac = psets_lac.sum(1).float().mean().item()
        ineff_list1.append(size_lac)

        psets_aps = inference_aps(test_scores, q_aps)
        coverage = get_coverage(psets_aps, test_targets)
        cov_list2.append(coverage)
        size_aps = psets_aps.sum(1).float().mean().item()
        ineff_list2.append(size_aps)

        psets_raps = inference_raps(test_scores, q_raps)
        coverage = get_coverage(psets_raps, test_targets)
        cov_list3.append(coverage)
        size_raps = psets_raps.sum(1).float().mean().item()
        ineff_list3.append(size_raps)
    
    print("Conformal Score Evaluation: ", 1-alpha)
    print("coverages, lac, aps, raps: ", np.mean(cov_list1), np.mean(cov_list2), np.mean(cov_list3))
    return np.mean(ineff_list1), np.mean(ineff_list2), np.mean(ineff_list3)

def cfgnn_update(client_model, optimizer, client_data, data_loader, mask, epochs, counter):
    """Run one optimization step with a prediction loss plus a set-size penalty.

    The nodes selected by ``mask`` are shuffled (seeded with ``epochs``, so the
    split is the same on every call with the same ``epochs``) and halved: the
    first half is used for the cross-entropy loss and the size penalty, the
    second half to calibrate the threshold ``qhat``. The size penalty is only
    added once ``counter`` exceeds ``int(epochs / 2)``.

    Args:
        client_model: Model returning ``(out, _, out_softmax)``; moved to the
            module-level ``device``.
        optimizer: Optimizer stepping the model's parameters.
        client_data: Graph whose ``y`` provides the labels.
        data_loader: Loader yielding the graph batches to run the model on.
        mask: Boolean node mask of the training nodes.
        epochs: Total number of rounds; also seeds the split.
        counter: Index of the current round.

    Returns:
        The loss of this step as a Python float.
    """
    alpha = 0.05
    tau = 0.1
    target_size = 1

    train_idx = np.where(mask.cpu())[0]
    # NOTE: this reseeds NumPy's global RNG on every call.
    np.random.seed(epochs)
    np.random.shuffle(train_idx)
    half = int(len(train_idx)/2)
    train_train_idx = train_idx[:half]
    train_calib_idx = train_idx[half:]
    train_test_idx = train_train_idx

    client_model = client_model.to(device)
    client_model.train()
    optimizer.zero_grad()
    outs = []
    out_softmaxs = []
    for data in data_loader:
        data = data.to(device)
        out, _, out_softmax = client_model(data.x, data.edge_index)
        outs.append(out.cpu())
        out_softmaxs.append(out_softmax.cpu())

    outs = torch.cat(outs).to(device)
    # Use the outputs of all batches. The losses below previously read `out_softmax`,
    # i.e. only the last batch, which matches only for single-batch loaders.
    out_softmaxs = torch.cat(out_softmaxs).to(device)

    n_temp = len(train_calib_idx)
    q_level = np.floor((n_temp+1)*(1-alpha))/n_temp

    # Softmax probability of the true class for every calibration node.
    tps_conformal_score = out_softmaxs[train_calib_idx][torch.arange(len(train_calib_idx)), client_data.y[train_calib_idx]]
    qhat = torch.quantile(tps_conformal_score, 1 - q_level, dim=0, keepdim=True,  interpolation='higher')

    # Soft set membership: sigmoid of the margin to the threshold, sharpened by tau.
    c = torch.sigmoid((out_softmaxs[train_test_idx] - qhat)/tau)
    size_loss = torch.mean(torch.relu(torch.sum(c, axis = 1) - target_size))

    pred_loss = F.cross_entropy(outs[train_train_idx], client_data.y[train_train_idx])

    # First half of training: prediction loss only; afterwards add the size penalty.
    if counter <= int(epochs/2):
        loss = pred_loss
    else:
        loss = pred_loss + 0.1 * size_loss

    loss.backward()

    optimizer.step()
    loss = float(loss)

    return loss