{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "5ac50309-7b21-4c40-828b-20826b9c9ba0",
   "metadata": {},
   "outputs": [],
   "source": [
    "### Some initializations\n",
    "import sys\n",
    "sys.path.append(\"..\")\n",
    "import numpy as np\n",
    "import math\n",
    "import random\n",
    "import matplotlib.pyplot as plt\n",
    "import copy\n",
    "import lingam\n",
    "import pickle\n",
    "import warnings\n",
    "import itertools\n",
    "import loli\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "ba993231-faff-4b67-962f-c7add495c7cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Via Stack Overflow\n",
    "# https://stackoverflow.com/questions/11130156/suppress-stdout-stderr-print-from-python-functions\n",
    "# Suppressing the output of annoying libraries\n",
    "from contextlib import contextmanager,redirect_stderr,redirect_stdout\n",
    "from os import devnull\n",
    "\n",
    "@contextmanager\n",
    "def suppress_stdout_stderr():\n",
    "    \"\"\"Redirect stdout and stderr to the null device inside the ``with`` block.\n",
    "\n",
    "    The null device is opened for writing on entry and closed again on exit.\n",
    "    Yields the tuple ``(err, out)`` produced by the two redirect context managers.\n",
    "    \"\"\"\n",
    "    with open(devnull, 'w') as sink, redirect_stderr(sink) as err, redirect_stdout(sink) as out:\n",
    "        yield (err, out)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "271e9074-7ca3-4b0d-a978-5334cc1b9a06",
   "metadata": {},
   "source": [
    "### L-ICP in different scenarios"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "75855b7c-40af-47f6-817e-74504bd31bb7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sample size 1 of 7\n",
      "run 0\n",
      "run 10\n",
      "run 20\n",
      "run 30\n",
      "run 40\n",
      "run 50\n",
      "run 60\n",
      "run 70\n",
      "run 80\n",
      "run 90\n",
      "run 100\n",
      "run 110\n",
      "run 120\n",
      "run 130\n",
      "run 140\n",
      "run 150\n",
      "run 160\n",
      "run 170\n",
      "run 180\n",
      "run 190\n",
      "run 200\n",
      "run 210\n",
      "run 220\n",
      "run 230\n",
      "run 240\n",
      "run 250\n",
      "run 260\n",
      "run 270\n",
      "run 280\n",
      "run 290\n",
      "Sample size 2 of 7\n",
      "run 0\n",
      "run 10\n",
      "run 20\n",
      "run 30\n",
      "run 40\n",
      "run 50\n",
      "run 60\n",
      "run 70\n",
      "run 80\n",
      "run 90\n",
      "run 100\n",
      "run 110\n",
      "run 120\n",
      "run 130\n",
      "run 140\n",
      "run 150\n",
      "run 160\n",
      "run 170\n",
      "run 180\n",
      "run 190\n",
      "run 200\n",
      "run 210\n",
      "run 220\n",
      "run 230\n",
      "run 240\n",
      "run 250\n",
      "run 260\n",
      "run 270\n",
      "run 280\n",
      "run 290\n",
      "Sample size 3 of 7\n",
      "run 0\n",
      "run 10\n",
      "run 20\n",
      "run 30\n",
      "run 40\n",
      "run 50\n",
      "run 60\n",
      "run 70\n",
      "run 80\n",
      "run 90\n",
      "run 100\n",
      "run 110\n",
      "run 120\n",
      "run 130\n",
      "run 140\n",
      "run 150\n",
      "run 160\n",
      "run 170\n",
      "run 180\n",
      "run 190\n",
      "run 200\n",
      "run 210\n",
      "run 220\n",
      "run 230\n",
      "run 240\n",
      "run 250\n",
      "run 260\n",
      "run 270\n",
      "run 280\n",
      "run 290\n",
      "Sample size 4 of 7\n",
      "run 0\n",
      "run 10\n",
      "run 20\n",
      "run 30\n",
      "run 40\n",
      "run 50\n",
      "run 60\n",
      "run 70\n",
      "run 80\n",
      "run 90\n",
      "run 100\n",
      "run 110\n",
      "run 120\n",
      "run 130\n",
      "run 140\n",
      "run 150\n",
      "run 160\n",
      "run 170\n",
      "run 180\n",
      "run 190\n",
      "run 200\n",
      "run 210\n",
      "run 220\n",
      "run 230\n",
      "run 240\n",
      "run 250\n",
      "run 260\n",
      "run 270\n",
      "run 280\n",
      "run 290\n",
      "Sample size 5 of 7\n",
      "run 0\n",
      "run 10\n",
      "run 20\n",
      "run 30\n",
      "run 40\n",
      "run 50\n",
      "run 60\n",
      "run 70\n",
      "run 80\n",
      "run 90\n",
      "run 100\n",
      "run 110\n",
      "run 120\n",
      "run 130\n",
      "run 140\n",
      "run 150\n",
      "run 160\n",
      "run 170\n",
      "run 180\n",
      "run 190\n",
      "run 200\n",
      "run 210\n",
      "run 220\n",
      "run 230\n",
      "run 240\n",
      "run 250\n",
      "run 260\n",
      "run 270\n",
      "run 280\n",
      "run 290\n",
      "Sample size 6 of 7\n",
      "run 0\n",
      "run 10\n",
      "run 20\n",
      "run 30\n",
      "run 40\n",
      "run 50\n",
      "run 60\n",
      "run 70\n",
      "run 80\n",
      "run 90\n",
      "run 100\n",
      "run 110\n",
      "run 120\n",
      "run 130\n",
      "run 140\n",
      "run 150\n",
      "run 160\n",
      "run 170\n",
      "run 180\n",
      "run 190\n",
      "run 200\n",
      "run 210\n",
      "run 220\n",
      "run 230\n",
      "run 240\n",
      "run 250\n",
      "run 260\n",
      "run 270\n",
      "run 280\n",
      "run 290\n",
      "Sample size 7 of 7\n",
      "run 0\n",
      "run 10\n",
      "run 20\n",
      "run 30\n",
      "run 40\n",
      "run 50\n",
      "run 60\n",
      "run 70\n",
      "run 80\n",
      "run 90\n",
      "run 100\n",
      "run 110\n",
      "run 120\n",
      "run 130\n",
      "run 140\n",
      "run 150\n",
      "run 160\n",
      "run 170\n",
      "run 180\n",
      "run 190\n",
      "run 200\n",
      "run 210\n",
      "run 220\n",
      "run 230\n",
      "run 240\n",
      "run 250\n",
      "run 260\n",
      "run 270\n",
      "run 280\n",
      "run 290\n"
     ]
    }
   ],
   "source": [
    "### Experiment corresponding to Figures 1,2 and 6\n",
    "# For every sample size n and every run, data from I environments is simulated with\n",
    "# normal, uniform and Student-t noise on the response. loli.gauss is run on each data set\n",
    "# and the false-positive / false-negative rates of the estimated support are accumulated\n",
    "# and finally pickled to 'BaselineB=1000.pkl'.\n",
    "\n",
    "I=30 # Intervals=Environments\n",
    "np.random.seed(1)\n",
    "d=6  #Dimensionality\n",
    "supp=(1,2) #support indices\n",
    "s=len(supp)  #Number of support entries\n",
    "sample=[8,12,16,20,24,28,32]\n",
    "# One error rate per sample size. Suffixes: none = normal noise, trunc = uniform noise,\n",
    "# truncbiga = uniform noise with alpha=0.5, t = Student-t noise with alpha=0.1,\n",
    "# tsmalla = Student-t noise with alpha=0.01.\n",
    "fntrunc=np.zeros((len(sample)))\n",
    "fptrunc=np.zeros((len(sample)))\n",
    "fntruncbiga=np.zeros((len(sample)))\n",
    "fptruncbiga=np.zeros((len(sample)))\n",
    "fnt=np.zeros((len(sample)))\n",
    "fpt=np.zeros((len(sample)))\n",
    "fntsmalla=np.zeros((len(sample)))\n",
    "fptsmalla=np.zeros((len(sample)))\n",
    "fn=np.zeros((len(sample)))\n",
    "fp=np.zeros((len(sample)))\n",
    "runs=300\n",
    "\n",
    "B=1000 # Bootstrap runs\n",
    "a=[list(itertools.combinations(range(d), k)) for k in range(0,d+1)]\n",
    "subsets = [item for sublist in a for item in sublist]\n",
    "dic={}\n",
    "# We first fix the coefficients for all runs, so that they do not depend on the sample size.\n",
    "betas=np.zeros((runs,I,d))\n",
    "for r in range(runs):\n",
    "    for i in range(I):\n",
    "        betas[r,i,supp]=np.random.uniform(low=1,high=5,size=(s))\n",
    "\n",
    "\n",
    "def _tally_errors(plausibleS,supp,runs,fp_rate,fn_rate,o):\n",
    "    \"\"\"Add one run's contribution (1/runs) to the error rates stored at index o.\n",
    "\n",
    "    plausibleS is the return value of loli.gauss; it is used here as a possibly empty\n",
    "    collection of sets (TODO confirm against loli). The estimated support is the\n",
    "    intersection of all plausible sets. A false positive is counted when the estimate\n",
    "    contains an index outside supp, a false negative when an index of supp is missing.\n",
    "    An empty plausibleS counts as a false negative only.\n",
    "    \"\"\"\n",
    "    if plausibleS:\n",
    "        supphat=set.intersection(*plausibleS)\n",
    "        if len(supphat.difference(set(supp)))>0:\n",
    "            fp_rate[o]+=1/runs\n",
    "        if len(set(supp).difference(supphat))>0:\n",
    "            fn_rate[o]+=1/runs\n",
    "    else:\n",
    "        fn_rate[o]+=1/runs\n",
    "\n",
    "\n",
    "for o,n in enumerate(sample):\n",
    "    print('Sample size',o+1,'of',len(sample))\n",
    "    for r in range(runs):\n",
    "        if r%10==0:\n",
    "            print('run',r)\n",
    "        X=[]\n",
    "        Y=[]\n",
    "        Xtrunc=[]\n",
    "        Ytrunc=[]\n",
    "        Xt=[]\n",
    "        Yt=[]\n",
    "        for i in range(I):\n",
    "            x=np.zeros((n,d))\n",
    "            std=np.random.uniform(low=1,high=5,size=(d))\n",
    "            x[:,0]=np.random.normal(scale=std[0],size=(n))\n",
    "            x[:,1]=x[:,0]+np.random.normal(scale=std[1],size=(n))\n",
    "            x[:,2]=0.3*x[:,1]+np.random.normal(scale=std[2],size=(n))\n",
    "            x[:,3]=0.2*x[:,2]+np.random.normal(scale=std[3],size=(n))\n",
    "            # Parts of columns 4 and 5 that do not depend on y; shared by all three noise types.\n",
    "            x4=np.random.normal(scale=std[4],size=(n))+0.1*x[:,1]\n",
    "            x5=np.random.normal(scale=std[5],size=(n))\n",
    "            # Data for normal noise\n",
    "            # NOTE(review): this noise has variance 1.1, whereas the uniform and t noise below\n",
    "            # have variance 1.1**2 -- confirm whether 1.1**2*np.eye(n) was intended.\n",
    "            e=np.random.multivariate_normal(0*np.ones(n),1.1*np.eye(n))\n",
    "            y=x@betas[r,i,:]+e\n",
    "            x[:,4]=x4+0.3*y\n",
    "            x[:,5]=x5+0.5*y\n",
    "\n",
    "            # Store a copy: columns 4 and 5 of x are overwritten below for the other noise\n",
    "            # types, and appending x itself would make all three lists share one array.\n",
    "            X.append(x.copy())\n",
    "            Y.append(y)\n",
    "\n",
    "            # Data for uniform noise\n",
    "            e=np.random.uniform(low=-1/2*12**(1/2)*1.1,high=1/2*12**(1/2)*1.1,size=(n))\n",
    "            y=x@betas[r,i,:]+e\n",
    "            x[:,4]=x4+0.3*y\n",
    "            x[:,5]=x5+0.5*y\n",
    "\n",
    "            Xtrunc.append(x.copy())\n",
    "            Ytrunc.append(y)\n",
    "\n",
    "            # Data for standard t noise\n",
    "            e=np.random.standard_t(2*1.1**(2)/(1.1**(2)-1),size=n)\n",
    "            y=x@betas[r,i,:]+e\n",
    "            x[:,4]=x4+0.3*y\n",
    "            x[:,5]=x5+0.5*y\n",
    "\n",
    "            # x is not modified after this point, so no copy is needed here.\n",
    "            Xt.append(x)\n",
    "            Yt.append(y)\n",
    "\n",
    "        # The order of the calls is kept fixed so that the random stream is consumed identically.\n",
    "        _tally_errors(loli.gauss(X,Y,B=B),supp,runs,fp,fn,o)\n",
    "        _tally_errors(loli.gauss(Xtrunc,Ytrunc,B=B),supp,runs,fptrunc,fntrunc,o)\n",
    "        _tally_errors(loli.gauss(Xtrunc,Ytrunc,alpha=0.5,B=B),supp,runs,fptruncbiga,fntruncbiga,o)\n",
    "        _tally_errors(loli.gauss(Xt,Yt,alpha=0.1,B=B),supp,runs,fpt,fnt,o)\n",
    "        _tally_errors(loli.gauss(Xt,Yt,alpha=0.01,B=B),supp,runs,fptsmalla,fntsmalla,o)\n",
    "\n",
    "Baseline={\n",
    "    'fpt':fpt,\n",
    "    'fnt':fnt,\n",
    "    'fptrunc':fptrunc,\n",
    "    'fntrunc':fntrunc,\n",
    "    'fptruncbiga':fptruncbiga,\n",
    "    'fntruncbiga':fntruncbiga,\n",
    "    'fp':fp,\n",
    "    'fn':fn,\n",
    "    'fptsmalla':fptsmalla,\n",
    "    'fntsmalla':fntsmalla,\n",
    "    'sample':sample,\n",
    "}\n",
    "\n",
    "with open('BaselineB=1000.pkl', 'wb') as f:\n",
    "    pickle.dump(Baseline, f)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e93cfe76-4dfd-4d3d-a405-b47645ee4199",
   "metadata": {},
   "source": [
    "### Comparison experiments with LiNGAM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "fefc35b1-bc0f-4efc-bf57-7ebdec11ea04",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Heterogeneity 1 of 7\n",
      "Heterogeneity 2 of 7\n",
      "Heterogeneity 3 of 7\n",
      "Heterogeneity 4 of 7\n",
      "Heterogeneity 5 of 7\n",
      "Heterogeneity 6 of 7\n",
      "Heterogeneity 7 of 7\n"
     ]
    }
   ],
   "source": [
    "### Experiment corresponding to Figure 3\n",
    "\n",
    "### Changing heterogeneity, uniform noise\n",
    "# Compares loli.gauss with lingam.MultiGroupDirectLiNGAM. The first half of the environments\n",
    "# uses coefficient 1 and noise std 2, the second half coefficient h and noise std 2*h.\n",
    "np.random.seed(1)\n",
    "d=6  #Dimensionality\n",
    "supp=(1,2) #support indices\n",
    "s=len(supp)  #Number of support entries\n",
    "hetero=[1.8,1.7,1.6,1.5,1.4,1.2,1]\n",
    "fntrunc=np.zeros((len(hetero)))\n",
    "fptrunc=np.zeros((len(hetero)))\n",
    "fnlingtrunc=np.zeros((len(hetero)))\n",
    "fplingtrunc=np.zeros((len(hetero)))\n",
    "\n",
    "runs=300\n",
    "n=20\n",
    "B=1000 # Bootstrap runs\n",
    "I=30      #Number of Intervals \n",
    "a=[list(itertools.combinations(range(d), k)) for k in range(0,d+1)]\n",
    "subsets = [item for sublist in a for item in sublist]\n",
    "dic={}\n",
    "\n",
    "betas=np.zeros((runs,I,d))\n",
    "\n",
    "for o,h in enumerate(hetero):\n",
    "    print('Heterogeneity',o+1,'of',len(hetero))\n",
    "    for r in range(runs):\n",
    "        Xtrunc=[]\n",
    "        Ytrunc=[]\n",
    "        X_lingtrunc=[]\n",
    "        for i in range(I):\n",
    "            if i<I/2:\n",
    "                betas[r,i,supp]=np.ones((s))\n",
    "                std=2\n",
    "            else:\n",
    "                betas[r,i,supp]=np.ones((s))*h\n",
    "                std=2*h\n",
    "            # Uniform(-w,w) has standard deviation w/sqrt(3); w=sqrt(12)/2*std gives exactly std.\n",
    "            w=1/2*12**(1/2)*std\n",
    "            x=np.zeros((n,d))\n",
    "            x[:,0]=np.random.uniform(low=-w,high=w,size=(n))\n",
    "            x[:,1]=x[:,0]+np.random.uniform(low=-w,high=w,size=(n))\n",
    "            x[:,2]=0.3*x[:,1]+np.random.uniform(low=-w,high=w,size=(n))\n",
    "            x[:,3]=0.2*x[:,2]+np.random.uniform(low=-w,high=w,size=(n))\n",
    "            # Response noise: uniform with unit standard deviation.\n",
    "            e=np.random.uniform(low=-1/2*12**(1/2),high=1/2*12**(1/2),size=(n))\n",
    "            y=x@betas[r,i,:]+e\n",
    "            x[:,4]=np.random.uniform(low=-w,high=w,size=(n))+0.1*x[:,1]+0.3*y\n",
    "            x[:,5]=np.random.uniform(low=-w,high=w,size=(n))+0.5*y\n",
    "\n",
    "            Xtrunc.append(x)\n",
    "            Ytrunc.append(y)\n",
    "            # LiNGAM gets the covariates and the response in one matrix; y is column d.\n",
    "            X_lingtrunc.append(np.column_stack((x,y)))\n",
    "\n",
    "        with suppress_stdout_stderr():\n",
    "            model=lingam.MultiGroupDirectLiNGAM()\n",
    "            model.fit(X_lingtrunc)\n",
    "\n",
    "        plausibleS=loli.gauss(Xtrunc,Ytrunc,alpha=0.1,B=B)\n",
    "        if plausibleS:\n",
    "            # Estimated support = indices contained in every plausible set.\n",
    "            supphat=set.intersection(*plausibleS)\n",
    "            if len(supphat.difference(set(supp)))>0:\n",
    "                fptrunc[o]+=1/runs\n",
    "            if len(set(supp).difference(supphat))>0:\n",
    "                fntrunc[o]+=1/runs\n",
    "        else:\n",
    "            fntrunc[o]+=1/runs\n",
    "\n",
    "        # Row d of the adjacency matrix belongs to y (stored in column d above). Only the\n",
    "        # matrix of the first environment is inspected; its non-zero entries are taken as\n",
    "        # the LiNGAM estimate of the support (presumably the parents of y -- confirm with lingam docs).\n",
    "        lingmat=model.adjacency_matrices_[0][d,:]\n",
    "        lingsupp=np.where(lingmat!=0)\n",
    "        supphatling=set(lingsupp[0])\n",
    "        if len(supphatling.difference(set(supp)))>0:\n",
    "            fplingtrunc[o]+=1/runs\n",
    "        if len(set(supp).difference(supphatling))>0:\n",
    "            fnlingtrunc[o]+=1/runs\n",
    "\n",
    "\n",
    "ComparisonLingHetero={}\n",
    "ComparisonLingHetero['fptrunc']=fptrunc\n",
    "ComparisonLingHetero['fntrunc']=fntrunc\n",
    "ComparisonLingHetero['fnlingtrunc']=fnlingtrunc\n",
    "ComparisonLingHetero['fplingtrunc']=fplingtrunc\n",
    "ComparisonLingHetero['hetero']=hetero\n",
    "\n",
    "\n",
    "\n",
    "with open('ComparisonLingUniformB=1000.pkl', 'wb') as f:\n",
    "    pickle.dump(ComparisonLingHetero, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "e5c3d12e-36b3-47e3-a97f-5a1191db77dd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Heterogeneity 1 of 5\n",
      "Heterogeneity 2 of 5\n",
      "Heterogeneity 3 of 5\n",
      "Heterogeneity 4 of 5\n",
      "Heterogeneity 5 of 5\n"
     ]
    }
   ],
   "source": [
    "### Experiment corresponding to Figure 4\n",
    "\n",
    "### Changing heterogeneity, normal noise\n",
    "# Compares loli.gauss with lingam.MultiGroupDirectLiNGAM. The first half of the environments\n",
    "# uses coefficient 1 and noise std 2, the second half coefficient h and noise std 2*h.\n",
    "np.random.seed(1)\n",
    "d=6  #Dimensionality\n",
    "supp=(1,2) #support indices\n",
    "s=len(supp)  #Number of support entries\n",
    "hetero=[1,1.2,1.4,1.5,1.6]\n",
    "fnnorm=np.zeros((len(hetero)))\n",
    "fpnorm=np.zeros((len(hetero)))\n",
    "fnlingnorm=np.zeros((len(hetero)))\n",
    "fplingnorm=np.zeros((len(hetero)))\n",
    "\n",
    "runs=300\n",
    "n=20\n",
    "B=1000 # Bootstrap runs\n",
    "I=30      #Number of Intervals \n",
    "a=[list(itertools.combinations(range(d), k)) for k in range(0,d+1)]\n",
    "subsets = [item for sublist in a for item in sublist]\n",
    "dic={}\n",
    "\n",
    "betas=np.zeros((runs,I,d))\n",
    "\n",
    "for o,h in enumerate(hetero):\n",
    "    print('Heterogeneity',o+1,'of',len(hetero))\n",
    "    for r in range(runs):\n",
    "        Xnorm=[]\n",
    "        Ynorm=[]\n",
    "        X_lingnorm=[]\n",
    "        for i in range(I):\n",
    "            if i<I/2:\n",
    "                betas[r,i,supp]=np.ones((s))\n",
    "                std=2\n",
    "            else:\n",
    "                betas[r,i,supp]=np.ones((s))*h\n",
    "                std=2*h\n",
    "            x=np.zeros((n,d))\n",
    "            x[:,0]=np.random.normal(0,std,size=(n))\n",
    "            x[:,1]=x[:,0]+np.random.normal(0,std,size=(n))\n",
    "            x[:,2]=0.3*x[:,1]+np.random.normal(0,std,size=(n))\n",
    "            x[:,3]=0.2*x[:,2]+np.random.normal(0,std,size=(n))\n",
    "            # Response noise: standard normal.\n",
    "            e=np.random.normal(0,1,size=(n))\n",
    "            y=x@betas[r,i,:]+e\n",
    "            x[:,4]=np.random.normal(0,std,size=(n))+0.1*x[:,1]+0.3*y\n",
    "            x[:,5]=np.random.normal(0,std,size=(n))+0.5*y\n",
    "\n",
    "            Xnorm.append(x)\n",
    "            Ynorm.append(y)\n",
    "            # LiNGAM gets the covariates and the response in one matrix; y is column d.\n",
    "            X_lingnorm.append(np.column_stack((x,y)))\n",
    "\n",
    "        with suppress_stdout_stderr():\n",
    "            model=lingam.MultiGroupDirectLiNGAM()\n",
    "            model.fit(X_lingnorm)\n",
    "\n",
    "        plausibleS=loli.gauss(Xnorm,Ynorm,alpha=0.1,B=B)\n",
    "        if plausibleS:\n",
    "            # Estimated support = indices contained in every plausible set.\n",
    "            supphat=set.intersection(*plausibleS)\n",
    "            if len(supphat.difference(set(supp)))>0:\n",
    "                fpnorm[o]+=1/runs\n",
    "            if len(set(supp).difference(supphat))>0:\n",
    "                fnnorm[o]+=1/runs\n",
    "        else:\n",
    "            fnnorm[o]+=1/runs\n",
    "\n",
    "        # Row d of the adjacency matrix belongs to y (stored in column d above). Only the\n",
    "        # matrix of the first environment is inspected; its non-zero entries are taken as\n",
    "        # the LiNGAM estimate of the support (presumably the parents of y -- confirm with lingam docs).\n",
    "        lingmat=model.adjacency_matrices_[0][d,:]\n",
    "        lingsupp=np.where(lingmat!=0)\n",
    "        supphatling=set(lingsupp[0])\n",
    "        if len(supphatling.difference(set(supp)))>0:\n",
    "            fplingnorm[o]+=1/runs\n",
    "        if len(set(supp).difference(supphatling))>0:\n",
    "            fnlingnorm[o]+=1/runs\n",
    "\n",
    "\n",
    "ComparisonLingHetero={}\n",
    "ComparisonLingHetero['fp']=fpnorm\n",
    "ComparisonLingHetero['fn']=fnnorm\n",
    "ComparisonLingHetero['fnling']=fnlingnorm\n",
    "ComparisonLingHetero['fpling']=fplingnorm\n",
    "ComparisonLingHetero['hetero']=hetero\n",
    "\n",
    "\n",
    "\n",
    "with open('ComparisonLingNormalB=1000.pkl', 'wb') as f:\n",
    "    pickle.dump(ComparisonLingHetero, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d14f2e6a-5a97-4ad2-90e4-0ee6e86922af",
   "metadata": {},
   "outputs": [],
   "source": [
    "### Changing heterogeneity, scaled student-t noise\n",
    "# Compares loli.gauss with lingam.MultiGroupDirectLiNGAM. The first half of the environments\n",
    "# uses coefficient 1 and noise scale 2, the second half coefficient h and noise scale 2*h.\n",
    "np.random.seed(1)\n",
    "d=6  #Dimensionality\n",
    "supp=(1,2) #support indices\n",
    "s=len(supp)  #Number of support entries\n",
    "hetero=[1,2,5,8]\n",
    "fnt=np.zeros((len(hetero)))\n",
    "fpt=np.zeros((len(hetero)))\n",
    "fnlingt=np.zeros((len(hetero)))\n",
    "fplingt=np.zeros((len(hetero)))\n",
    "\n",
    "runs=100\n",
    "n=20\n",
    "B=1000 # Bootstrap runs\n",
    "I=30      #Number of Intervals \n",
    "a=[list(itertools.combinations(range(d), k)) for k in range(0,d+1)]\n",
    "t=3 # degrees of freedom for student-t\n",
    "subsets = [item for sublist in a for item in sublist]\n",
    "dic={}\n",
    "\n",
    "betas=np.zeros((runs,I,d))\n",
    "\n",
    "for o,h in enumerate(hetero):\n",
    "    print('Heterogeneity',o+1,'of',len(hetero))\n",
    "    for r in range(runs):\n",
    "        Xt=[]\n",
    "        Yt=[]\n",
    "        X_lingt=[]\n",
    "        for i in range(I):\n",
    "            if i<I/2:\n",
    "                betas[r,i,supp]=np.ones((s))\n",
    "                scale=2\n",
    "            else:\n",
    "                betas[r,i,supp]=np.ones((s))*h\n",
    "                scale=2*h\n",
    "            x=np.zeros((n,d))\n",
    "            x[:,0]=scale*np.random.standard_t(t,size=(n))\n",
    "            # The scale multiplies the noise term, as in every other equation of this cell\n",
    "            # (previously it multiplied x[:,0] instead).\n",
    "            x[:,1]=x[:,0]+scale*np.random.standard_t(t,size=(n))\n",
    "            x[:,2]=0.3*x[:,1]+scale*np.random.standard_t(t,size=(n))\n",
    "            x[:,3]=0.2*x[:,2]+scale*np.random.standard_t(t,size=(n))\n",
    "            # Response noise: unscaled Student-t.\n",
    "            e=np.random.standard_t(t,size=(n))\n",
    "            y=x@betas[r,i,:]+e\n",
    "            x[:,4]=scale*np.random.standard_t(t,size=(n))+0.1*x[:,1]+0.3*y\n",
    "            x[:,5]=scale*np.random.standard_t(t,size=(n))+0.5*y\n",
    "\n",
    "            Xt.append(x)\n",
    "            Yt.append(y)\n",
    "            # LiNGAM gets the covariates and the response in one matrix; y is column d.\n",
    "            X_lingt.append(np.column_stack((x,y)))\n",
    "\n",
    "        with suppress_stdout_stderr():\n",
    "            model=lingam.MultiGroupDirectLiNGAM()\n",
    "            model.fit(X_lingt)\n",
    "\n",
    "        plausibleS=loli.gauss(Xt,Yt,alpha=0.1,B=B)\n",
    "        if plausibleS:\n",
    "            # Estimated support = indices contained in every plausible set.\n",
    "            supphat=set.intersection(*plausibleS)\n",
    "            if len(supphat.difference(set(supp)))>0:\n",
    "                fpt[o]+=1/runs\n",
    "            if len(set(supp).difference(supphat))>0:\n",
    "                fnt[o]+=1/runs\n",
    "        else:\n",
    "            fnt[o]+=1/runs\n",
    "\n",
    "        # Row d of the adjacency matrix belongs to y (stored in column d above). Only the\n",
    "        # matrix of the first environment is inspected; its non-zero entries are taken as\n",
    "        # the LiNGAM estimate of the support (presumably the parents of y -- confirm with lingam docs).\n",
    "        lingmat=model.adjacency_matrices_[0][d,:]\n",
    "        lingsupp=np.where(lingmat!=0)\n",
    "        supphatling=set(lingsupp[0])\n",
    "        if len(supphatling.difference(set(supp)))>0:\n",
    "            fplingt[o]+=1/runs\n",
    "        if len(set(supp).difference(supphatling))>0:\n",
    "            fnlingt[o]+=1/runs\n",
    "\n",
    "\n",
    "ComparisonLingHetero={}\n",
    "ComparisonLingHetero['fpt']=fpt\n",
    "ComparisonLingHetero['fnt']=fnt\n",
    "ComparisonLingHetero['fnlingt']=fnlingt\n",
    "ComparisonLingHetero['fplingt']=fplingt\n",
    "ComparisonLingHetero['hetero']=hetero\n",
    "\n",
    "\n",
    "\n",
    "with open('ComparisonLingStudentT=3B=1000.pkl', 'wb') as f:\n",
    "    pickle.dump(ComparisonLingHetero, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b9e7e5c1-ddf0-443f-802f-bb1b3368cecf",
   "metadata": {},
   "source": [
    "### Experiments generating results from Table 1 (comparison to ICP)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d12dace-c2e4-496d-907e-c36ed25803b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import causalicp\n",
    "\n",
    "# Sparse environment test\n",
    "# 99 environments share the same covariate std (1); a single extra environment uses std s.\n",
    "# Index 0 of fp / fn holds the error rates of ICP (causalicp), index 1 those of loli.gauss.\n",
    "runs=300\n",
    "np.random.seed(0)\n",
    "supp=(1,)\n",
    "n=7\n",
    "s=3\n",
    "fp=np.zeros((2))\n",
    "fn=np.zeros((2))\n",
    "\n",
    "for o in range(runs):\n",
    "    print(o)\n",
    "    data=[]\n",
    "    dataX=[]\n",
    "    dataY=[]\n",
    "    for e in range(99):\n",
    "        X=np.random.normal(0,1,size=(n,3))\n",
    "        # Column 2 is the response; it depends on column 1 only.\n",
    "        X[:,2]=X[:,1]+np.random.normal(0,1,size=n)\n",
    "        data.append(X)\n",
    "        dataX.append(X[:,:2])\n",
    "        dataY.append(X[:,2])\n",
    "    # The one environment with a different covariate scale.\n",
    "    X=np.random.normal(0,s,size=(n,3))\n",
    "    X[:,2]=X[:,1]+np.random.normal(0,1,size=n)\n",
    "    data.append(X)\n",
    "    dataX.append(X[:,:2])\n",
    "    dataY.append(X[:,2])\n",
    "    # ICP on the joint data, with column 2 as the target.\n",
    "    icp=causalicp.fit(data, 2, alpha=0.1, sets=None, precompute=False, verbose=False, color=True)\n",
    "    if icp.estimate:\n",
    "        if len(icp.estimate.difference(set(supp)))>0:\n",
    "            fp[0]+=1/runs\n",
    "        if len(set(supp).difference(icp.estimate))>0:\n",
    "            fn[0]+=1/runs\n",
    "    else:\n",
    "        # A missing or empty estimate cannot contain the true support.\n",
    "        fn[0]+=1/runs\n",
    "\n",
    "    plausibleS=loli.gauss(dataX,dataY,alpha=0.1)\n",
    "    if plausibleS:\n",
    "        # Estimated support = indices contained in every plausible set.\n",
    "        supphat=set.intersection(*plausibleS)\n",
    "        if len(supphat.difference(set(supp)))>0:\n",
    "            fp[1]+=1/runs\n",
    "        if len(set(supp).difference(supphat))>0:\n",
    "            fn[1]+=1/runs\n",
    "    else:\n",
    "        fn[1]+=1/runs\n",
    "\n",
    "print(fp,fn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06988c6a-ed9e-468a-8d93-75f23ee3573d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dense environment test\n",
    "# Each of the 100 environments draws its own covariate std s from Uniform(1,5).\n",
    "# Index 0 of fpdense / fndense holds the error rates of ICP (causalicp), index 1 those of loli.gauss.\n",
    "runs=300\n",
    "np.random.seed(0)\n",
    "supp=(1,)\n",
    "n=7\n",
    "fpdense=np.zeros((2))\n",
    "fndense=np.zeros((2))\n",
    "\n",
    "for o in range(runs):\n",
    "    print(o)\n",
    "    data=[]\n",
    "    dataX=[]\n",
    "    dataY=[]\n",
    "    for e in range(100):\n",
    "        s=np.random.uniform(low=1,high=5)\n",
    "        X=np.random.normal(0,s,size=(n,3))\n",
    "        # Column 2 is the response; it depends on column 1 only.\n",
    "        X[:,2]=X[:,1]+np.random.normal(0,1,size=n)\n",
    "        data.append(X)\n",
    "        dataX.append(X[:,:2])\n",
    "        dataY.append(X[:,2])\n",
    "    # ICP on the joint data, with column 2 as the target.\n",
    "    icp=causalicp.fit(data, 2, alpha=0.1, sets=None, precompute=False, verbose=False, color=True)\n",
    "    if icp.estimate:\n",
    "        if len(icp.estimate.difference(set(supp)))>0:\n",
    "            fpdense[0]+=1/runs\n",
    "        if len(set(supp).difference(icp.estimate))>0:\n",
    "            fndense[0]+=1/runs\n",
    "    else:\n",
    "        # A missing or empty estimate cannot contain the true support.\n",
    "        fndense[0]+=1/runs\n",
    "\n",
    "    plausibleS=loli.gauss(dataX,dataY,alpha=0.1)\n",
    "    if plausibleS:\n",
    "        # Estimated support = indices contained in every plausible set.\n",
    "        supphat=set.intersection(*plausibleS)\n",
    "        if len(supphat.difference(set(supp)))>0:\n",
    "            fpdense[1]+=1/runs\n",
    "        if len(set(supp).difference(supphat))>0:\n",
    "            fndense[1]+=1/runs\n",
    "    else:\n",
    "        fndense[1]+=1/runs\n",
    "\n",
    "print(fpdense,fndense)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d71e565b-2959-40ec-8043-735cf00ddacf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ICP violation test\n",
    "# Each of the 100 environments draws its own covariate std s and its own coefficient beta,\n",
    "# so the coefficient of column 1 changes between environments.\n",
    "# Index 0 of fpviol / fnviol holds the error rates of ICP (causalicp), index 1 those of loli.gauss.\n",
    "\n",
    "runs=300\n",
    "np.random.seed(0)\n",
    "supp=(1,)\n",
    "n=7\n",
    "fpviol=np.zeros((2))\n",
    "fnviol=np.zeros((2))\n",
    "\n",
    "for o in range(runs):\n",
    "    print(o)\n",
    "    data=[]\n",
    "    dataX=[]\n",
    "    dataY=[]\n",
    "    for e in range(100):\n",
    "        s=np.random.uniform(low=1,high=5)\n",
    "        # beta is drawn separately from s. The previous chained assignment 'beta=s=...'\n",
    "        # overwrote the s drawn above, forcing the std to equal the coefficient.\n",
    "        # NOTE(review): confirm against the numbers reported in Table 1.\n",
    "        beta=np.random.uniform(low=1,high=5)\n",
    "        X=np.random.normal(0,s,size=(n,3))\n",
    "        # Column 2 is the response; it depends on column 1 only.\n",
    "        X[:,2]=beta*X[:,1]+np.random.normal(0,1,size=n)\n",
    "        data.append(X)\n",
    "        dataX.append(X[:,:2])\n",
    "        dataY.append(X[:,2])\n",
    "    # ICP on the joint data, with column 2 as the target.\n",
    "    icp=causalicp.fit(data, 2, alpha=0.1, sets=None, precompute=False, verbose=False, color=True)\n",
    "    if icp.estimate:\n",
    "        if len(icp.estimate.difference(set(supp)))>0:\n",
    "            fpviol[0]+=1/runs\n",
    "        if len(set(supp).difference(icp.estimate))>0:\n",
    "            fnviol[0]+=1/runs\n",
    "    else:\n",
    "        # A missing or empty estimate cannot contain the true support.\n",
    "        fnviol[0]+=1/runs\n",
    "\n",
    "    plausibleS=loli.gauss(dataX,dataY,alpha=0.1)\n",
    "    if plausibleS:\n",
    "        # Estimated support = indices contained in every plausible set.\n",
    "        supphat=set.intersection(*plausibleS)\n",
    "        if len(supphat.difference(set(supp)))>0:\n",
    "            fpviol[1]+=1/runs\n",
    "        if len(set(supp).difference(supphat))>0:\n",
    "            fnviol[1]+=1/runs\n",
    "    else:\n",
    "        fnviol[1]+=1/runs\n",
    "\n",
    "print(fpviol,fnviol)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
