{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "1b4cf135",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "import itertools\n",
    "\n",
    "from copy import deepcopy\n",
    "from ucimlrepo import fetch_ucirepo \n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn import svm\n",
    "import sklearn.preprocessing as preprocessing\n",
    "from scipy.stats import beta\n",
    "from sklearn.utils import shuffle\n",
    "from sklearn import metrics\n",
    "from sklearn.metrics import confusion_matrix\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "24e6e75c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Defining True Parameters\n",
    "# Naming: mu_{y}_group{s}_cluster{suffix} is the mean feature value for label y,\n",
    "# sensitive attribute s and cluster suffix ('1', '11', '111', '1111', '0', '00',\n",
    "# '000', '0000'). Each suffix becomes one train/test DataLoader pair appended\n",
    "# to iid_train / iid_test later in this cell.\n",
    "mu_1_group1_cluster1 = 12\n",
    "mu_0_group1_cluster1 = 8\n",
    "mu_1_group0_cluster1 = 11\n",
    "mu_0_group0_cluster1 = 7\n",
    "\n",
    "mu_1_group1_cluster11 = 12\n",
    "mu_0_group1_cluster11 = 9\n",
    "mu_1_group0_cluster11 = 12\n",
    "mu_0_group0_cluster11 = 9\n",
    "\n",
    "mu_1_group1_cluster111 = 11.5\n",
    "mu_0_group1_cluster111 = 8.5\n",
    "mu_1_group0_cluster111 = 11.5\n",
    "mu_0_group0_cluster111 = 8.5\n",
    "\n",
    "mu_1_group1_cluster1111 = 10.5\n",
    "mu_0_group1_cluster1111 = 7.5\n",
    "mu_1_group0_cluster1111 = 10.5\n",
    "mu_0_group0_cluster1111 = 7.5\n",
    "\n",
    "mu_1_group1_cluster0 = 8\n",
    "mu_0_group1_cluster0 = 6\n",
    "mu_1_group0_cluster0 = 8\n",
    "mu_0_group0_cluster0 = 6\n",
    "\n",
    "mu_1_group1_cluster00 = 7.5\n",
    "mu_0_group1_cluster00 = 5.5\n",
    "mu_1_group0_cluster00 = 7.5\n",
    "mu_0_group0_cluster00 = 5.5\n",
    "\n",
    "mu_1_group1_cluster000 = 12\n",
    "mu_0_group1_cluster000 = 8\n",
    "mu_1_group0_cluster000 = 11\n",
    "mu_0_group0_cluster000 = 7\n",
    "\n",
    "mu_1_group1_cluster0000 = 11\n",
    "mu_0_group1_cluster0000 = 8\n",
    "mu_1_group0_cluster0000 = 11\n",
    "mu_0_group0_cluster0000 = 8\n",
    "\n",
    "# Standard deviations: sd_1 is used for the '1...' clusters and sd_0 for the\n",
    "# '0...' clusters (one value per cluster family, not per label).\n",
    "sd_1 = 1\n",
    "sd_0 = 1\n",
    "\n",
    "# NOTE(review): pop_prob is not referenced anywhere in this notebook yet.\n",
    "pop_prob = 0.1\n",
    "size = 1200  # draws per (label, group, cluster) combination\n",
    "batch_size = 32\n",
    "# Previously named `shuffle`, which clobbered sklearn.utils.shuffle imported in\n",
    "# the first cell. Nothing else in this notebook reads this flag.\n",
    "shuffle_batches = True\n",
    "\n",
    "# One DataLoader per cluster, appended in cluster order further down this cell.\n",
    "iid_train = []\n",
    "iid_test = []\n",
    "\n",
    "    \n",
    "# Draw the raw 1-D feature samples (generate_data splits them into train/test).\n",
    "# Seed NumPy (these draws) and torch (DataLoader shuffling) so that a\n",
    "# Restart and Run All reproduces the same data and batch order.\n",
    "SEED = 42\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "data_1_group1_cluster1 = np.random.normal(mu_1_group1_cluster1,sd_1,size=size)\n",
    "data_0_group1_cluster1 = np.random.normal(mu_0_group1_cluster1,sd_1,size=size)\n",
    "data_1_group0_cluster1 = np.random.normal(mu_1_group0_cluster1,sd_1,size=size)\n",
    "data_0_group0_cluster1 = np.random.normal(mu_0_group0_cluster1,sd_1,size=size)\n",
    "\n",
    "data_1_group1_cluster11 = np.random.normal(mu_1_group1_cluster11,sd_1,size=size)\n",
    "data_0_group1_cluster11 = np.random.normal(mu_0_group1_cluster11,sd_1,size=size)\n",
    "data_1_group0_cluster11 = np.random.normal(mu_1_group0_cluster11,sd_1,size=size)\n",
    "data_0_group0_cluster11 = np.random.normal(mu_0_group0_cluster11,sd_1,size=size)\n",
    "\n",
    "data_1_group1_cluster111 = np.random.normal(mu_1_group1_cluster111,sd_1,size=size)\n",
    "data_0_group1_cluster111 = np.random.normal(mu_0_group1_cluster111,sd_1,size=size)\n",
    "data_1_group0_cluster111 = np.random.normal(mu_1_group0_cluster111,sd_1,size=size)\n",
    "data_0_group0_cluster111 = np.random.normal(mu_0_group0_cluster111,sd_1,size=size)\n",
    "\n",
    "data_1_group1_cluster1111 = np.random.normal(mu_1_group1_cluster1111,sd_1,size=size)\n",
    "data_0_group1_cluster1111 = np.random.normal(mu_0_group1_cluster1111,sd_1,size=size)\n",
    "data_1_group0_cluster1111 = np.random.normal(mu_1_group0_cluster1111,sd_1,size=size)\n",
    "data_0_group0_cluster1111 = np.random.normal(mu_0_group0_cluster1111,sd_1,size=size)\n",
    "\n",
    "data_1_group1_cluster0 = np.random.normal(mu_1_group1_cluster0,sd_0,size=size)\n",
    "data_0_group1_cluster0 = np.random.normal(mu_0_group1_cluster0,sd_0,size=size)\n",
    "data_1_group0_cluster0 = np.random.normal(mu_1_group0_cluster0,sd_0,size=size)\n",
    "data_0_group0_cluster0 = np.random.normal(mu_0_group0_cluster0,sd_0,size=size)\n",
    "\n",
    "data_1_group1_cluster00 = np.random.normal(mu_1_group1_cluster00,sd_0,size=size)\n",
    "data_0_group1_cluster00 = np.random.normal(mu_0_group1_cluster00,sd_0,size=size)\n",
    "data_1_group0_cluster00 = np.random.normal(mu_1_group0_cluster00,sd_0,size=size)\n",
    "data_0_group0_cluster00 = np.random.normal(mu_0_group0_cluster00,sd_0,size=size)\n",
    "\n",
    "data_1_group1_cluster000 = np.random.normal(mu_1_group1_cluster000,sd_0,size=size)\n",
    "data_0_group1_cluster000 = np.random.normal(mu_0_group1_cluster000,sd_0,size=size)\n",
    "data_1_group0_cluster000 = np.random.normal(mu_1_group0_cluster000,sd_0,size=size)\n",
    "data_0_group0_cluster000 = np.random.normal(mu_0_group0_cluster000,sd_0,size=size)\n",
    "\n",
    "data_1_group1_cluster0000 = np.random.normal(mu_1_group1_cluster0000,sd_0,size=size)\n",
    "data_0_group1_cluster0000 = np.random.normal(mu_0_group1_cluster0000,sd_0,size=size)\n",
    "data_1_group0_cluster0000 = np.random.normal(mu_1_group0_cluster0000,sd_0,size=size)\n",
    "data_0_group0_cluster0000 = np.random.normal(mu_0_group0_cluster0000,sd_0,size=size)\n",
    "\n",
    "def _group_tensors(pos, neg, group):\n",
    "    \"\"\"Stack one sensitive group's samples into (n, 1) float32 tensors X, s, y.\n",
    "\n",
    "    Label-1 features (pos) come first, then label-0 features (neg). Inputs are\n",
    "    flattened, so 1-D arrays as well as (n, 1) / (1, n) arrays are accepted.\n",
    "    s is 1 on every row when group == 1 and 0 otherwise.\n",
    "    \"\"\"\n",
    "    pos = np.ravel(pos)\n",
    "    neg = np.ravel(neg)\n",
    "    X = torch.tensor(np.concatenate((pos, neg)), dtype=torch.float32).view(-1, 1)\n",
    "    if group == 1:\n",
    "        s = torch.ones(len(X), 1, dtype=torch.float32)\n",
    "    else:\n",
    "        s = torch.zeros(len(X), 1, dtype=torch.float32)\n",
    "    y = torch.cat((torch.ones(len(pos), 1, dtype=torch.float32),\n",
    "                   torch.zeros(len(neg), 1, dtype=torch.float32)), dim=0)\n",
    "    return X, s, y\n",
    "\n",
    "\n",
    "def generate_data(data1, data2, group1, data3, data4, group2,\n",
    "                  test_size=400, random_state=42):\n",
    "    \"\"\"Build one dataset made of two sensitive groups and split it into train/test.\n",
    "\n",
    "    Args:\n",
    "        data1, data2: label-1 and label-0 feature samples of the first group.\n",
    "        group1: sensitive attribute of the first group (1 gives s = 1, anything\n",
    "            else gives s = 0).\n",
    "        data3, data4: label-1 and label-0 feature samples of the second group.\n",
    "        group2: sensitive attribute of the second group, encoded like group1.\n",
    "        test_size: passed to train_test_split for each group separately (an int\n",
    "            is an absolute number of test rows per group).\n",
    "        random_state: seed passed to train_test_split for each group.\n",
    "\n",
    "    Returns:\n",
    "        (X_train, X_test, s_train, s_test, y_train, y_test): float32 tensors of\n",
    "        shape (n, 1); rows of the first group precede rows of the second group.\n",
    "    \"\"\"\n",
    "    X_1, s_1, y_1 = _group_tensors(data1, data2, group1)\n",
    "    X_2, s_2, y_2 = _group_tensors(data3, data4, group2)\n",
    "\n",
    "    # Split each group on its own so both groups appear in train and test.\n",
    "    # NOTE(review): the split is not stratified by label.\n",
    "    (X_1_train, X_1_test,\n",
    "     s_1_train, s_1_test,\n",
    "     y_1_train, y_1_test) = train_test_split(X_1, s_1, y_1, test_size=test_size,\n",
    "                                             random_state=random_state)\n",
    "    (X_2_train, X_2_test,\n",
    "     s_2_train, s_2_test,\n",
    "     y_2_train, y_2_test) = train_test_split(X_2, s_2, y_2, test_size=test_size,\n",
    "                                             random_state=random_state)\n",
    "\n",
    "    return (torch.cat((X_1_train, X_2_train), dim=0),\n",
    "            torch.cat((X_1_test, X_2_test), dim=0),\n",
    "            torch.cat((s_1_train, s_2_train), dim=0),\n",
    "            torch.cat((s_1_test, s_2_test), dim=0),\n",
    "            torch.cat((y_1_train, y_2_train), dim=0),\n",
    "            torch.cat((y_1_test, y_2_test), dim=0))\n",
    "\n",
    "# Build the train/test tensors of every cluster. Each call pairs sensitive\n",
    "# group s = 1 (data1/data2) with sensitive group s = 0 (data3/data4).\n",
    "(X1_train_cluster1, X1_test_cluster1,\n",
    " s1_train_cluster1, s1_test_cluster1,\n",
    " y1_train_cluster1, y1_test_cluster1) = generate_data(\n",
    "    data1=data_1_group1_cluster1, data2=data_0_group1_cluster1, group1=1,\n",
    "    data3=data_1_group0_cluster1, data4=data_0_group0_cluster1, group2=0)\n",
    "\n",
    "(X1_train_cluster11, X1_test_cluster11,\n",
    " s1_train_cluster11, s1_test_cluster11,\n",
    " y1_train_cluster11, y1_test_cluster11) = generate_data(\n",
    "    data1=data_1_group1_cluster11, data2=data_0_group1_cluster11, group1=1,\n",
    "    data3=data_1_group0_cluster11, data4=data_0_group0_cluster11, group2=0)\n",
    "\n",
    "(X1_train_cluster111, X1_test_cluster111,\n",
    " s1_train_cluster111, s1_test_cluster111,\n",
    " y1_train_cluster111, y1_test_cluster111) = generate_data(\n",
    "    data1=data_1_group1_cluster111, data2=data_0_group1_cluster111, group1=1,\n",
    "    data3=data_1_group0_cluster111, data4=data_0_group0_cluster111, group2=0)\n",
    "\n",
    "(X1_train_cluster1111, X1_test_cluster1111,\n",
    " s1_train_cluster1111, s1_test_cluster1111,\n",
    " y1_train_cluster1111, y1_test_cluster1111) = generate_data(\n",
    "    data1=data_1_group1_cluster1111, data2=data_0_group1_cluster1111, group1=1,\n",
    "    data3=data_1_group0_cluster1111, data4=data_0_group0_cluster1111, group2=0)\n",
    "\n",
    "(X1_train_cluster0, X1_test_cluster0,\n",
    " s1_train_cluster0, s1_test_cluster0,\n",
    " y1_train_cluster0, y1_test_cluster0) = generate_data(\n",
    "    data1=data_1_group1_cluster0, data2=data_0_group1_cluster0, group1=1,\n",
    "    data3=data_1_group0_cluster0, data4=data_0_group0_cluster0, group2=0)\n",
    "\n",
    "(X1_train_cluster00, X1_test_cluster00,\n",
    " s1_train_cluster00, s1_test_cluster00,\n",
    " y1_train_cluster00, y1_test_cluster00) = generate_data(\n",
    "    data1=data_1_group1_cluster00, data2=data_0_group1_cluster00, group1=1,\n",
    "    data3=data_1_group0_cluster00, data4=data_0_group0_cluster00, group2=0)\n",
    "\n",
    "(X1_train_cluster000, X1_test_cluster000,\n",
    " s1_train_cluster000, s1_test_cluster000,\n",
    " y1_train_cluster000, y1_test_cluster000) = generate_data(\n",
    "    data1=data_1_group1_cluster000, data2=data_0_group1_cluster000, group1=1,\n",
    "    data3=data_1_group0_cluster000, data4=data_0_group0_cluster000, group2=0)\n",
    "\n",
    "(X1_train_cluster0000, X1_test_cluster0000,\n",
    " s1_train_cluster0000, s1_test_cluster0000,\n",
    " y1_train_cluster0000, y1_test_cluster0000) = generate_data(\n",
    "    data1=data_1_group1_cluster0000, data2=data_0_group1_cluster0000, group1=1,\n",
    "    data3=data_1_group0_cluster0000, data4=data_0_group0_cluster0000, group2=0)\n",
    "\n",
    "# Create a dataset and a dataloader\n",
    "def data_loader(X, s, y, batch_size, shuffle=True):\n",
    "    \"\"\"Wrap aligned (X, s, y) tensors in a DataLoader yielding (X, s, y) batches.\n",
    "\n",
    "    Args:\n",
    "        X, s, y: tensors sharing the same first dimension (in this notebook:\n",
    "            features, sensitive attribute, label).\n",
    "        batch_size: number of rows per batch.\n",
    "        shuffle: reshuffle the rows every epoch. Defaults to True; pass False\n",
    "            for a deterministic order, e.g. for evaluation loaders.\n",
    "\n",
    "    Returns:\n",
    "        torch.utils.data.DataLoader over TensorDataset(X, s, y).\n",
    "    \"\"\"\n",
    "    dataset = TensorDataset(X, s, y)\n",
    "    return DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle)\n",
    "\n",
    "\n",
    "# Build one train and one test DataLoader per cluster, appended in the order\n",
    "# cluster0, cluster1, cluster00, cluster11, cluster000, cluster111,\n",
    "# cluster0000, cluster1111 -- so iid_train[k] and iid_test[k] always refer to\n",
    "# the same cluster.\n",
    "# NOTE(review): the test loaders are shuffled as well, because data_loader is\n",
    "# called without disabling shuffling.\n",
    "for n_digits in range(1, 5):\n",
    "    for j in range(2):\n",
    "        suffix = str(j) * n_digits\n",
    "\n",
    "        X_train = globals()[f'X1_train_cluster{suffix}']\n",
    "        s_train = globals()[f's1_train_cluster{suffix}']\n",
    "        y_train = globals()[f'y1_train_cluster{suffix}']\n",
    "        result_train = data_loader(X_train, s_train, y_train, batch_size)\n",
    "\n",
    "        X_test = globals()[f'X1_test_cluster{suffix}']\n",
    "        s_test = globals()[f's1_test_cluster{suffix}']\n",
    "        y_test = globals()[f'y1_test_cluster{suffix}']\n",
    "        result_test = data_loader(X_test, s_test, y_test, batch_size)\n",
    "\n",
    "        iid_train.append(result_train)\n",
    "        iid_test.append(result_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8360039",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
