{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div style=\"text-align: center;\">\n",
    "    <img src=\"https://cdn.jsdelivr.net/gh/DishengL/ResearchPics/Causal_inference_LLM.png\" alt=\"Causal Inference LLM\" width=\"350\">\n",
    "</div>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import openai\n",
    "import base64\n",
    "import os\n",
    "from tqdm import tqdm\n",
    "from openai import OpenAI\n",
    "import random\n",
    "import numpy as np\n",
    "import sys\n",
    "from datetime import datetime\n",
    "import csv \n",
    "from utils import info\n",
    "from utils import evaluation\n",
    "import numpy as np\n",
    "from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Functions in scene class:\n",
      "- get_all_scenes\n",
      "- get_scencs_name\n",
      "- get_scene\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "\n",
    "# Make the project's utils directory importable before loading `info`.\n",
    "sys.path.append('/home/lds/github/Causality-informed-Generation/inference/evaluation/utils')\n",
    "import info\n",
    "\n",
    "scene_info = info.scene()\n",
    "\n",
    "# Gather every callable attribute of the scene object, leaving out dunder names.\n",
    "functions = []\n",
    "for attr_name in dir(scene_info):\n",
    "    attr_is_callable = callable(getattr(scene_info, attr_name))\n",
    "    if attr_is_callable and not attr_name.startswith(\"__\"):\n",
    "        functions.append(attr_name)\n",
    "\n",
    "print(\"Functions in scene class:\")\n",
    "for method_name in functions:\n",
    "    print(f\"- {method_name}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['Reflection', 'Spring', 'Seesaw', 'seesaw', 'Magnets', 'Convex', 'Parabola', 'Waterflow', 'Pendulum', 'V2', 'V3_V', 'V3_F', 'V4_V', 'V4_F', 'V5', 'V5_F'])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load the metadata of every scene; the displayed keys are the available scene names.\n",
    "all_scenes = scene_info.get_all_scenes()\n",
    "all_scenes.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Magnets\n",
      "{'variables': {0: 'needle_position_x', 1: 'needle_position_y', 2: 'magnetic_bar_orientation', 3: 'needle_orientation'}, 'adjacency_matrix': array([[0, 0, 0, 1],\n",
      "       [0, 0, 0, 1],\n",
      "       [0, 0, 0, 1],\n",
      "       [0, 0, 0, 0]])}\n",
      "Reflection\n",
      "{'variables': {0: 'incident_degree', 1: 'reflection_degree'}, 'adjacency_matrix': array([[0, 1],\n",
      "       [0, 0]])}\n",
      "Convex\n",
      "{'variables': {0: 'the distance from object to the convex len', 1: 'the distance from image to the convex len', 2: 'the Magnification'}, 'adjacency_matrix': array([[0, 1, 1],\n",
      "       [0, 0, 1],\n",
      "       [0, 0, 0]])}\n",
      "['needle orientation rotate in counter-clockwise direction', 'needle orientation rotate in clockwise direction', 'needle orientation will not change']\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "['A. needle orientation will not change',\n",
       " 'B. needle orientation rotate in counter-clockwise direction',\n",
       " 'C. needle orientation rotate in clockwise direction']"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Print what get_scene returns for each scene that has an entry in question_table.\n",
    "intervention_scenes = [\"Magnets\", \"Reflection\", \"Convex\"]\n",
    "for scene_name in intervention_scenes:\n",
    "    print(scene_info.get_scene(scene_name))\n",
    "    # print(\"\\n\")\n",
    "    \n",
    "# Intervention question bank, keyed by scene name.\n",
    "#   \"x\"       -> {intervened variable: {intervention text: {expected outcome}}}\n",
    "#   \"options\" -> every answer choice offered for that scene\n",
    "question_table = {\n",
    "    \"Magnets\": {\n",
    "        \"x\": {\n",
    "            \"magnetic_bar_orientation\": {\n",
    "                \"slightly rotate the orientation of magnetic bar in clockwise direction\": {\n",
    "                    \"needle orientation rotate in counter-clockwise direction\"\n",
    "                },\n",
    "                \"slightly rotate the orientation of magnetic bar in counter-clockwise direction\": {\n",
    "                    \"needle orientation rotate in clockwise direction\"\n",
    "                },\n",
    "            },\n",
    "        },\n",
    "        \"options\": [\n",
    "            \"needle orientation rotate in counter-clockwise direction\",\n",
    "            \"needle orientation rotate in clockwise direction\",\n",
    "            \"needle orientation will not change\",\n",
    "        ],\n",
    "    },\n",
    "    \"Convex\": {\n",
    "        \"x\": {\n",
    "            \"the distance from object to the convex len\": {\n",
    "                \"move the object closer to the convex lens\": {\n",
    "                    \"the magnification increases\"\n",
    "                },\n",
    "                \"move the object farther from the convex lens\": {\n",
    "                    \"the magnification decreases\"\n",
    "                },\n",
    "            },\n",
    "        },\n",
    "        \"options\": [\n",
    "            \"the magnification increases\",\n",
    "            \"the magnification decreases\",\n",
    "            \"the magnification will not change\",\n",
    "        ],\n",
    "    },\n",
    "    \"Reflection\": {\n",
    "        \"x\": {\n",
    "            \"incident_degree\": {\n",
    "                \"increase the incident angle\": {\n",
    "                    \"the reflected angle increases\"\n",
    "                },\n",
    "                \"decrease the incident angle\": {\n",
    "                    \"the reflected angle decreases\"\n",
    "                },\n",
    "            },\n",
    "        },\n",
    "        \"options\": [\n",
    "            \"the reflected angle increases\",\n",
    "            \"the reflected angle decreases\",\n",
    "            \"the reflected angle will not change\",\n",
    "        ],\n",
    "    },\n",
    "}\n",
    "\n",
    "import random\n",
    "\n",
    "# Preview how the answer choices are presented: shuffled, then labelled A, B, C, ...\n",
    "# random.sample returns a shuffled copy, so the list stored in question_table keeps its\n",
    "# order (random.shuffle would reorder that shared list in place each time this cell runs).\n",
    "options = question_table['Magnets']['options']\n",
    "print(options)\n",
    "options = random.sample(options, k=len(options))\n",
    "options = [f\"{chr(65+i)}. {option}\" for i, option in enumerate(options)]\n",
    "options"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Magnets\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/30 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 30/30 [03:12<00:00,  6.42s/it]\n"
     ]
    }
   ],
   "source": [
    "def format_causality_prompt(system_info, prompt_text, image_base64_list, scene_name = None):\n",
    "    \"\"\"Assemble the chat messages for one intervention question.\n",
    "\n",
    "    `prompt_text` and `scene_name` are handed to prompt_composition; the prompt string it\n",
    "    returns becomes the user text, followed by one base64 PNG data-URL entry per item of\n",
    "    `image_base64_list`.\n",
    "\n",
    "    Returns (messages, text_prompt, question, answer, select), where text_prompt is\n",
    "    `system_info` concatenated with the composed prompt.\n",
    "    \"\"\"\n",
    "    question, answer, select, composed_prompt = prompt_composition(prompt_text, scene_name)\n",
    "\n",
    "    # The text part comes first, then every image in the order given.\n",
    "    user_content = [{\"type\": \"text\", \"text\": composed_prompt}]\n",
    "    for img_base64 in image_base64_list:\n",
    "        data_url = f\"data:image/png;base64,{img_base64}\"\n",
    "        user_content.append({\"type\": \"image_url\", \"image_url\": {\"url\": data_url}})\n",
    "\n",
    "    messages = [\n",
    "        {\"role\": \"system\", \"content\": system_info},\n",
    "        {\"role\": \"user\", \"content\": user_content},\n",
    "    ]\n",
    "    text_prompt = system_info + composed_prompt\n",
    "    return messages, text_prompt, question, answer, select\n",
    "\n",
    "def _extract_answer(text, marker = \"The answer is: \"):\n",
    "    \"\"\"Return the answer token that follows `marker` in a model reply.\n",
    "\n",
    "    The token is the text between the marker and the first period or line break after\n",
    "    it (or the end of the text). An empty string is returned when the marker is absent.\n",
    "    \"\"\"\n",
    "    start = text.find(marker)\n",
    "    if start == -1:\n",
    "        return \"\"\n",
    "    start += len(marker)\n",
    "    # Stop at whichever delimiter comes first; without any, keep the rest of the text.\n",
    "    stops = [pos for pos in (text.find(\".\", start), text.find(\"\\n\", start)) if pos != -1]\n",
    "    end = min(stops) if stops else len(text)\n",
    "    return text[start:end].strip()\n",
    "\n",
    "\n",
    "def causal_intervention(client, image_base64_list, prompt_type = None, prompt_info = None, dump = False,\n",
    "                      multi_view = False, \n",
    "                      scene_name = None,\n",
    "                      view = None,\n",
    "                      images = None,\n",
    "                      syn_background = False):\n",
    "    \"\"\"\n",
    "    Send one multiple-choice intervention question about the images to the chat model.\n",
    "\n",
    "    Args:\n",
    "        client: object exposing chat.completions.create.\n",
    "        image_base64_list: base64-encoded images attached to the user message.\n",
    "        prompt_type: \"explicted\" or \"basic\"; selects the system prompt.\n",
    "        prompt_info: scene metadata forwarded to format_causality_prompt.\n",
    "        dump: when True, append a record of the exchange to a CSV file under\n",
    "            causal_OpenAI_res/{scene_name}_intervention/.\n",
    "        multi_view, syn_background, view: only affect the CSV file name;\n",
    "            syn_background takes precedence over multi_view.\n",
    "        scene_name: scene key used for the question and for the output directory.\n",
    "        images: value stored in the \"images\" column of the record.\n",
    "\n",
    "    Returns:\n",
    "        The text content of the model's first choice.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if prompt_type is not one of the supported values.\n",
    "    \"\"\"\n",
    "    # NOTE(review): both system prompts talk about an adjacency matrix although the user\n",
    "    # prompt asks a multiple-choice question; the wording is kept as-is - confirm intent.\n",
    "    if prompt_type == \"explicted\":\n",
    "        system_info = \"You are a causal discovery expert. Your objective is to analyze the provided images and identify any causal relationships between the variables. Use the identified relationships to complete the causality adjacency matrix and provide a brief explanation supporting your conclusions.\"\n",
    "    elif prompt_type == \"basic\":\n",
    "        system_info = \"Analyze the provided images and identify causal relationships between the variables. Complete the causality adjacency matrix based on the identified relationships and briefly explain your conclusions\"\n",
    "    else:\n",
    "        # Fail with a clear message instead of an UnboundLocalError further down.\n",
    "        raise ValueError(f\"Unsupported prompt_type {prompt_type!r}; expected 'explicted' or 'basic'\")\n",
    "\n",
    "    messages, text_prompt, question, answer, selec = format_causality_prompt(\n",
    "        system_info, prompt_info, image_base64_list, scene_name = scene_name)\n",
    "\n",
    "    # Make the API call\n",
    "    response = client.chat.completions.create(\n",
    "        model=\"gpt-4o\",\n",
    "        messages=messages,\n",
    "        max_tokens=500\n",
    "    )\n",
    "    reply = response.choices[0].message.content\n",
    "\n",
    "    if dump:\n",
    "        out_dir = f\"causal_OpenAI_res/{scene_name}_intervention\"\n",
    "        os.makedirs(out_dir, exist_ok=True)\n",
    "        res = {\n",
    "            \"timestamp\": datetime.now().strftime(\"%Y%m%d_%H%M%S\"),\n",
    "            \"prompt_type\": prompt_type,\n",
    "            \"images\": images,\n",
    "            \"message\": text_prompt,\n",
    "            \"response\": reply,\n",
    "            \"gpt_ans\": _extract_answer(reply or \"\"),\n",
    "            \"ground_truth\": selec\n",
    "        }\n",
    "\n",
    "        # Later flags override earlier ones, so syn_background wins over multi_view.\n",
    "        file_path = f\"{out_dir}/{prompt_type}.csv\"\n",
    "        if multi_view:\n",
    "            file_path = f\"{out_dir}/{prompt_type}_multi_view_{view}.csv\"\n",
    "        if syn_background:\n",
    "            file_path = f\"{out_dir}/{prompt_type}_synbackground_{view}.csv\"\n",
    "        file_exists = os.path.exists(file_path)\n",
    "\n",
    "        with open(file_path, mode=\"a\", newline=\"\", encoding=\"utf-8\") as f:\n",
    "            writer = csv.DictWriter(f, fieldnames=[\"timestamp\", \"prompt_type\", \"images\", \"message\",\n",
    "                                                   \"response\", 'gpt_ans', 'ground_truth'])\n",
    "\n",
    "            # Write header if the file is new\n",
    "            if not file_exists:\n",
    "                writer.writeheader()\n",
    "\n",
    "            # Write the data row\n",
    "            writer.writerow(res)\n",
    "    return reply\n",
    "\n",
    "\n",
    "def find_ans(answer, options):\n",
    "    \"\"\"Return the one-letter label of the option that corresponds to `answer`.\n",
    "\n",
    "    `options` are strings such as \"A. <text>\". An option whose text equals `answer`\n",
    "    exactly is preferred; failing that, the first option that contains `answer` as a\n",
    "    substring is used. Returns -1 when no option matches.\n",
    "    \"\"\"\n",
    "    # Exact match first, so an answer that happens to be a substring of an earlier\n",
    "    # choice cannot be attributed to the wrong letter.\n",
    "    for option in options:\n",
    "        if option.partition(\". \")[2] == answer:\n",
    "            return option[:1]\n",
    "    for option in options:\n",
    "        if answer in option:\n",
    "            return option[:1]\n",
    "    return -1\n",
    "\n",
    "def prompt_composition(scene_info_dict, scene_name):\n",
    "    \"\"\"Build one randomized multiple-choice intervention question for a scene.\n",
    "\n",
    "    Args:\n",
    "        scene_info_dict: scene metadata; only its 'variables' mapping is read here.\n",
    "        scene_name: key into the module-level question_table.\n",
    "\n",
    "    Returns:\n",
    "        (question, answer, select, prompt): the sampled intervention text, the set holding\n",
    "        its expected outcome, the letter of the correct choice (or -1 when find_ans finds\n",
    "        no match), and the full prompt text ending with the labelled choices.\n",
    "    \"\"\"\n",
    "    variable_names = list(scene_info_dict['variables'].values())\n",
    "    variables = \",\".join(variable_names)\n",
    "    adv = \"among\" if len(variable_names) > 2 else \"between\"\n",
    "\n",
    "    # Shuffle a copy: random.shuffle on the table entry would reorder the shared\n",
    "    # question_table in place on every call.\n",
    "    choices = question_table[scene_name]['options']\n",
    "    shuffled = random.sample(choices, k=len(choices))\n",
    "    options = [f\"{chr(65+i)}. {option}\" for i, option in enumerate(shuffled)]\n",
    "\n",
    "    # Only the first intervened variable registered for the scene is used.\n",
    "    interventions = list(question_table[scene_name]['x'].values())[0]\n",
    "    question = random.choice(list(interventions.keys()))\n",
    "    answer = interventions[question]\n",
    "\n",
    "    select = find_ans(list(answer)[0], options)\n",
    "\n",
    "    # NOTE(review): this wording (typos included) is what gets sent to the model; it is\n",
    "    # left untouched here so the experiment's prompt does not silently change.\n",
    "    template = f\"Observing the provided images. Based on the causal relathionship {adv} the variables: {variables} in the image in the images, now, waht will happen after I {question} in the last image?\\n(your answer should be return in this format first: The answer is: xx. \\nExplain: xxx)\\n\"\n",
    "\n",
    "    prompt = template + \"\\n\".join(options)\n",
    "    return question, answer, select, prompt\n",
    "\n",
    "\n",
    "def encode_image(image_path):\n",
    "    \"\"\"Read the file at `image_path` and return its bytes as a base64 text string.\"\"\"\n",
    "    with open(image_path, \"rb\") as handle:\n",
    "        raw_bytes = handle.read()\n",
    "    encoded = base64.b64encode(raw_bytes)\n",
    "    return encoded.decode('utf-8')\n",
    "\n",
    "def get_images(image_dir, img_num, view = -1):\n",
    "    \"\"\"Randomly pick `img_num` entries of `image_dir` and base64-encode them.\n",
    "\n",
    "    With view == -1 every directory entry is a candidate; otherwise only entries whose\n",
    "    path ends with \"_{view}.png\" are. Returns (paths, encoded) in matching order.\n",
    "    \"\"\"\n",
    "    candidates = [os.path.join(image_dir, entry) for entry in os.listdir(image_dir)]\n",
    "    if view != -1:\n",
    "        # Keep only the files rendered from the requested view.\n",
    "        wanted_suffix = f\"_{view}.png\"\n",
    "        candidates = [path for path in candidates if path.endswith(wanted_suffix)]\n",
    "\n",
    "    files = random.sample(candidates, img_num)\n",
    "    return files, [encode_image(path) for path in files]\n",
    "\n",
    "def main(image_dir, scene_name, \n",
    "         strategy,  default_image_size = 10, \n",
    "         multi_view = False, \n",
    "         view = 0, \n",
    "         syn_background = False,\n",
    "         num_trials = 30):\n",
    "    \"\"\"Run `num_trials` intervention queries for one scene, logging each one to CSV.\n",
    "\n",
    "    Args:\n",
    "        image_dir: directory the images are sampled from.\n",
    "        scene_name: scene key passed on to the prompt and to the output path.\n",
    "        strategy: prompt type forwarded to causal_intervention.\n",
    "        default_image_size: number of images sampled once and reused for every trial.\n",
    "        multi_view, view, syn_background: forwarded to causal_intervention; `view` also\n",
    "            selects which files get_images may sample.\n",
    "        num_trials: how many queries to send (30 by default).\n",
    "\n",
    "    Returns:\n",
    "        None. Each exchange is recorded by causal_intervention (called with dump=True).\n",
    "    \"\"\"\n",
    "    # Use a notebook-level `api_key` when one is defined; otherwise pass None so the\n",
    "    # OpenAI client falls back to the OPENAI_API_KEY environment variable.\n",
    "    client = OpenAI(api_key=globals().get(\"api_key\"))\n",
    "    images, image_base64_list = get_images(image_dir, default_image_size, view)\n",
    "    scene_info_dict = info.scene().get_scene(scene_name)\n",
    "    for _ in tqdm(range(num_trials)):\n",
    "        try:\n",
    "            causal_intervention(client, image_base64_list,\n",
    "                                prompt_type=strategy,\n",
    "                                prompt_info=scene_info_dict,\n",
    "                                dump=True,\n",
    "                                scene_name=scene_name,\n",
    "                                multi_view = multi_view,\n",
    "                                view = view,\n",
    "                                images = images,\n",
    "                                syn_background = syn_background)\n",
    "        except Exception as e:\n",
    "            # Best effort: report the failed trial and carry on with the next one.\n",
    "            print(f\"Error: {str(e)}\")\n",
    "\n",
    "    return\n",
    "  \n",
    "  \n",
    "# Magnets scene, one-view image directory. With syn_background=True and view=-1 the rows\n",
    "# are appended to causal_OpenAI_res/Magnets_intervention/basic_synbackground_-1.csv.\n",
    "one_view_mage = \"/home/lds/github/Causality-informed-Generation/code1/database/Real_magnet_v3_one_view\"\n",
    "main(\n",
    "    one_view_mage,\n",
    "    \"Magnets\",\n",
    "    \"basic\",\n",
    "    default_image_size=10,\n",
    "    multi_view=False,\n",
    "    view=-1,\n",
    "    syn_background=True,\n",
    ")\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Magnets\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/30 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 30/30 [03:26<00:00,  6.90s/it]\n"
     ]
    }
   ],
   "source": [
    "# Magnets scene, Real_magnet_v3 image directory. The arguments match the one-view run, so\n",
    "# rows are appended to the same causal_OpenAI_res/Magnets_intervention/basic_synbackground_-1.csv.\n",
    "multi_view_mage = \"/home/lds/github/Causality-informed-Generation/code1/database/Real_magnet_v3\"\n",
    "main(\n",
    "    multi_view_mage,\n",
    "    \"Magnets\",\n",
    "    \"basic\",\n",
    "    default_image_size=10,\n",
    "    multi_view=False,\n",
    "    view=-1,\n",
    "    syn_background=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Reflection\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/30 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 30/30 [03:14<00:00,  6.49s/it]\n"
     ]
    }
   ],
   "source": [
    "# Reflection scene. Rows are appended to\n",
    "# causal_OpenAI_res/Reflection_intervention/basic_synbackground_-1.csv.\n",
    "one_view_mage = \"/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/real/Real_reflection_v2__256P/real_rendered_reflection_256P\"\n",
    "main(\n",
    "    one_view_mage,\n",
    "    \"Reflection\",\n",
    "    \"basic\",\n",
    "    default_image_size=10,\n",
    "    multi_view=False,\n",
    "    view=-1,\n",
    "    syn_background=True,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Convex\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/30 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 30/30 [05:00<00:00, 10.02s/it]\n"
     ]
    }
   ],
   "source": [
    "# Convex scene. Rows are appended to\n",
    "# causal_OpenAI_res/Convex_intervention/basic_synbackground_-1.csv.\n",
    "multi_view_mage = \"/home/lds/github/Causality-informed-Generation/code1/database/convex_len_render_images/\"\n",
    "main(\n",
    "    multi_view_mage,\n",
    "    \"Convex\",\n",
    "    \"basic\",\n",
    "    default_image_size=10,\n",
    "    multi_view=False,\n",
    "    view=-1,\n",
    "    syn_background=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "def cal_acc(df, prediction, GT):\n",
    "    \"\"\"Return the fraction of rows where column `prediction` equals column `GT`.\n",
    "\n",
    "    Side effect: stores the per-row comparison in a 'correct' column of `df`.\n",
    "    \"\"\"\n",
    "    matches = df[prediction] == df[GT]\n",
    "    df['correct'] = matches\n",
    "    n_correct = df['correct'].sum()\n",
    "    n_rows = df.shape[0]\n",
    "    return n_correct/n_rows\n",
    "  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "basic_synbackground_-1.csv\n",
      "0.4666666666666667\n",
      "basic_synbackground_-1.csv\n",
      "0.5\n",
      "basic_synbackground_-1.csv\n",
      "0.9666666666666667\n",
      "basic_synbackground_-1.csv\n",
      "1.0\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Result files to score: the same file name inside each experiment directory.\n",
    "results_root = \"/home/lds/github/Causality-informed-Generation/experiment_chatgpt_api/causal_OpenAI_res\"\n",
    "result_name = \"basic_synbackground_-1.csv\"\n",
    "paths = [\n",
    "    f\"{results_root}/{experiment}/{result_name}\"\n",
    "    for experiment in (\n",
    "        \"Magnets_intervention\",\n",
    "        \"Magnets_intervention_one_view\",\n",
    "        \"Reflection_intervention\",\n",
    "        \"Convex_intervention\",\n",
    "    )\n",
    "]\n",
    "\n",
    "# For each file, print its name and the share of rows where gpt_ans equals ground_truth.\n",
    "for path in paths:\n",
    "    data = pd.read_csv(path)\n",
    "    print(os.path.basename(path))\n",
    "    print(cal_acc(data, 'gpt_ans', 'ground_truth'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "joe",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.21"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
