import pandas as pd
import numpy as np

import os
import pickle
import random
import shutil
from retry import retry
from tqdm import tqdm
from io import StringIO
from collections import OrderedDict
from copy import deepcopy as cpy

from utilities import openai_apis

def variable2output(causal_problem, dataset, treatment=None, trt_name=None, response=None, res_name=None, mediator=None, med_name=None,
                    condition=None, condition_name=None, condition_value=None, nodes=None, nodes_name=None):
  """Build the structured answer dictionary for one causal query.

  For every role (treatment, response, mediator, condition) the explicit
  ``*_name`` argument is used verbatim when it is truthy; otherwise the raw
  argument is used with spaces replaced by underscores.

  Args:
    causal_problem: one of 'CSL', 'ATE', 'HTE', 'CPL' or 'MA'.
    dataset: dataset identifier, stored as a one-element list.
    treatment, trt_name: raw / explicit treatment name (ATE, HTE, CPL, MA).
    response, res_name: raw / explicit response name (ATE, HTE, CPL, MA).
    mediator, med_name: raw / explicit mediator name (MA only).
    condition, condition_name: raw / explicit condition name (HTE, CPL).
    condition_value: value paired with the condition name (HTE, CPL).
    nodes, nodes_name: comma-separated node names (CSL only); ``nodes_name``
      is split as-is, ``nodes`` additionally gets spaces replaced by underscores.

  Returns:
    A dict with 'causal_problem' and 'dataset' plus the task-specific keys,
    or an empty dict when ``causal_problem`` is not one of the known codes.

  Raises:
    AttributeError: when a required raw argument is None and its explicit
      name is not given (``.replace`` / ``.split`` is called on None).
  """
  def pick(explicit_name, raw_name):
    # An explicit name wins; otherwise underscore the raw name.
    return explicit_name if explicit_name else raw_name.replace(" ", "_")

  result = {}

  # Causal structure learning only needs the node list.
  if causal_problem == 'CSL':
    result['causal_problem'] = ['CSL', None]
    result['dataset'] = [dataset]
    if nodes_name:
      result['nodes'] = nodes_name.split(',')
    else:
      result['nodes'] = [node.replace(" ", "_") for node in nodes.split(',')]
    return result

  # Remaining tasks share dataset / treatment / response and differ in one extra key.
  if causal_problem == 'ATE':
    label, extra = ['CEL', 'ATE'], None  # "cel-ate" / "ate"
  elif causal_problem == 'HTE':
    label, extra = ['CEL', 'HTE'], 'condition'
  elif causal_problem == 'CPL':
    label, extra = ['CPL', None], 'condition'
  elif causal_problem == 'MA':
    label, extra = ['CEL', 'MA'], 'mediator'
  else:
    # Unknown problem codes yield an empty dictionary.
    return result

  result['causal_problem'] = label
  result['dataset'] = [dataset]
  result['treatment'] = [pick(trt_name, treatment)]
  result['response'] = [pick(res_name, response)]
  if extra == 'condition':
    result['condition'] = [(pick(condition_name, condition), condition_value)]
  elif extra == 'mediator':
    result['mediator'] = [pick(med_name, mediator)]
  return result

## READ HERE: minimal example with explanation
## 1. The input resembles JSON (key-value pairs), but the values are written in natural language.
# format: ordered dictionary
# purpose: usable both as an exemplar input and as a real input (which has no examples)
# note: this module-level name shadows the builtin `input`
input = OrderedDict([("key1", "value 1")])  # real example: OrderedDict([("Dataset name", "example.csv")])
## 2. Question examples are stored in a list
# format: list of strings
# purpose: demonstrations for ChatGPT
question = ["q1", "q2", "q3"]
## 3. examplar is a list of (input, question) tuples
examplar = [(input, question)]
## To generate a prompt, call
# generate_prompt(task = "ate", input_row = csv_row, examplar = examplar)

## for better understanding of following functions, the final prompt is generated by:
## FIXED FUNCTION: generate a complete prompt
def generate_prompt(task, input_row, examplar):
    """Assemble the complete prompt for one task and one input row.

    The prompt consists of, in order: the task header, the bullet list of
    input fields built from ``input_row``, a divider, one demonstration per
    entry of ``examplar`` (numbered from 1), a second divider, and the
    task-specific requirements.

    Args:
        task: task code understood by the task-specific helper functions.
        input_row: row (indexable by column name) describing the real input.
        examplar: iterable of entries whose element 0 is an input dictionary
            and whose element 1 is a list of example questions.

    Returns:
        The full prompt as a single string.
    """
    divider = "-----------------------------------------------------------\n"

    # The helpers are called in the same order in which their output appears.
    pieces = [generate_first_part(task)]
    fields = generate_input_from_row(task, input_row)
    pieces.append(generate_input_part(fields))
    pieces.append(divider)
    pieces.append("Here are some examples:\n")
    for number, demo in enumerate(examplar, start=1):
        pieces.append(generate_demonstration_part(demo[0], demo[1], number))
    pieces.append(divider)
    pieces.append(generate_last_part(task))
    return "".join(pieces)

## ACTION NEEDED: function 2 generates the input dictionary from csv rows
# instruction: edit match-case statement for each task
def generate_input_from_row(task, pd_row):
  out = OrderedDict()
  # edit here: add task descriptions
  task1 = task.lower()
  match task1:
    case "csl":   # causal structure learning
      out["Dataset name"] = pd_row["dataset"]
      vars = [x.strip() for x in pd_row["var"].split(",")]
      var_names = [x.strip() for x in pd_row["var_name"].split(",")]
      out["Interested variable"] = ""
      if vars[0] == "all_variables":
        out["Interested variable"] = "all variables  "
      elif var_names == [""]:
        count = 1
        for var in vars:
          out["Interested variable"] += var + ", "
      else:
        count = 1
        for var, var_name in zip(vars, var_names):
          out["Interested variable"] += var + " (" + var_name + "), "
      out["Interested variable"] = out["Interested variable"][:-2]
    case "ate":   # average treatment effect
      out["Dataset name"] = pd_row["dataset"]
      out["Treatment variable"] = pd_row["treatment"]
      if pd_row["treatment_name"] != "":
        out["Treatment variable"] += " (" + pd_row["treatment_name"] + ")"
      out["Outcome variable"] = pd_row["outcome"]
      if pd_row["outcome_name"] != "":
        out["Outcome variable"] += " (" + pd_row["outcome_name"] + ")"
    case "hte":   # heterogeneous treatment effect
      out["Dataset name"] = pd_row["dataset"]
      out["Treatment variable"] = pd_row["treatment"]
      if pd_row["treatment_name"] != "":
        out["Treatment variable"] += " (" + pd_row["treatment_name"] + ")"
      out["Outcome variable"] = pd_row["outcome"]
      if pd_row["outcome_name"] != "":
        out["Outcome variable"] += " (" + pd_row["outcome_name"] + ")"
      out["Group Variable"] = pd_row["condition"]
      out["Group Condition"] = str(pd_row["condition_value"]) + " (" + pd_row["condition_name"] + "=" + str(pd_row["condition_value"]) + ")"
    case "ma":    # mediation analysis
      out["Dataset name"] = pd_row["dataset"]
      out["Treatment variable"] = pd_row["treatment"]
      if pd_row["treatment_name"] != "":
        out["Treatment variable"] += " (" + pd_row["treatment_name"] + ")"
      out["Outcome variable"] = pd_row["outcome"]
      if pd_row["outcome_name"] != "":
        out["Outcome variable"] += " (" + pd_row["outcome_name"] + ")"
      out["Mediator variable"] = pd_row["mediator"]
      if pd_row["mediator_name"] != "":
        out["Mediator variable"] += " (" + pd_row["mediator_name"] + ")"
    case "s_cpl": # single stage policy learning
      out["Dataset name"] = pd_row["dataset"]
      out["Treatment variable"] = pd_row["treatment"]
      if pd_row["treatment_name"] != "":
        out["Treatment variable"] += " (" + pd_row["treatment_name"] + ")"
      out["Outcome variable"] = pd_row["outcome"]
      if pd_row["outcome_name"] != "":
        out["Outcome variable"] += " (" + pd_row["outcome_name"] + ")"
      out["Condition variable"] = pd_row["condition"]
      if pd_row["condition_name"] != "":
        out["Condition value"] = str(pd_row["condition_value"]) + " (" + pd_row["condition_name"] + "=" + str(pd_row["condition_value"]) + ")"
      else:
        out["Condition value"] = str(pd_row["condition_value"])
    case "m_cpl": # multiple stage policy learning
      out["Dataset name"] = pd_row["dataset"]
      out["Treatment variable"] = pd_row["treatment"]
      if pd_row["treatment_name"] != "":
        out["Treatment variable"] += " (" + pd_row["treatment_name"] + ")"
      out["Outcome variable"] = pd_row["outcome"]
      if pd_row["outcome_name"] != "":
        out["Outcome variable"] += " (" + pd_row["outcome_name"] + ")"
      # specify stage number (stage number must be > 1)
      out["Total Stage number"] = str(max(0,int(pd_row["stage"])) + 2) if int(pd_row["stage"]) <= 1 else str(pd_row["stage"])
      #out["State Description"] = pd_row["condition"]
      if pd_row["condition_name"] != "":
        out["State Description"] = str(pd_row["condition_value"]) + " (" + pd_row["condition_name"] + "=" + str(pd_row["condition_value"]) + ")"
      else:
        out["State Description"] = str(pd_row["condition_value"])
      #if "stage" not in pd_row:
      #  out["Stage number"] = str(3)
      #else:
      #  out["Stage number"] = str(max(0,int(pd_row["stage"])) + 2) if int(pd_row["stage"]) <= 1 else str(pd_row["stage"])
      #if "condition" not in pd_row or pd_row["condition"] == "":
      #  out["State Description"] = "given the state values are the same as the first data observation in the training data"
      #else:
      #  out["State Description"] = pd_row["condition"]
    case "mdp":   # markov decision process
      out["Dataset name"] = pd_row["dataset"]
      out["Treatment variable"] = pd_row["treatment"]
      if pd_row["treatment_name"] != "":
        out["Treatment variable"] += " (" + pd_row["treatment_name"] + ")"
      out["Outcome variable"] = pd_row["outcome"]
      if pd_row["outcome_name"] != "":
        out["Outcome variable"] += " (" + pd_row["outcome_name"] + ")"
      if pd_row["condition_name"] != "":
        out["State Description"] = str(pd_row["condition_value"]) + " (" + pd_row["condition_name"] + "=" + str(pd_row["condition_value"]) + ")"
      else:
        out["State Description"] = str(pd_row["condition_value"])
    case _:
      raise NotImplementedError
  return out

## ACTION NEEDED: function 3 task specific header function
# instruction: edit match-case statement for each task
# note: such content will appear at the top, which is a high level description of the task
def generate_first_part(task):
  """Return the one-line, task-specific header that opens the prompt.

  The header is a fixed lead-in sentence followed by a task-dependent
  continuation and a trailing newline. Every continuation starts with either
  a single space or punctuation so that it attaches cleanly to the lead-in.

  Args:
    task: task code, case-insensitive; one of "csl", "ate", "hte", "ma",
      "s_cpl", "m_cpl" or "mdp".

  Returns:
    The header string, terminated by "\n".

  Raises:
    NotImplementedError: if ``task`` is not one of the supported codes.
  """
  lead_in = "Given the details below, generate five diverse and plain-worded questions"
  # edit here: add task descriptions
  continuations = {
      # causal structure learning
      "csl": " asking about the existence of causal effects between interested variables:",
      # average treatment effect
      "ate": " asking about the impact of the treatment on the outcome:",
      # heterogeneous treatment effect
      "hte": " asking about the impact of the treatment on the outcome under a group condition:",
      # mediation analysis
      "ma": " asking about the impact of the mediator in mediating the treatment's effect on the outcome:",
      # single stage policy learning
      "s_cpl": ". Each question should seek advice on what action a person or subject should take under specific conditions:",
      # multiple stage policy learning
      "m_cpl": " asking about the recommended action from a policy trained on a data with a multiple stage decision making process:",
      # markov decision process
      "mdp": " asking about the recommended action under given state value from a reinforcement learning policy trained on a data from a Markov Decision Process with infinite stages:",
  }
  continuation = continuations.get(task.lower())
  if continuation is None:
    raise NotImplementedError(f"unsupported task: {task!r}")
  return lead_in + continuation + "\n"

## ACTION NEEDED: function 4 task specific requirement (at last)
# instruction:
# 1. only ADD new requirements in the statement dict
# 2. select the required statements by index and the function will do the rest
# note: this part appears at last, which is the requirement
def generate_last_part(task):
  """Return the task-specific requirement lines that close the prompt.

  Requirements live in one shared, numbered catalogue; each task selects a
  sequence of catalogue numbers and the chosen sentences are emitted in that
  order, one per line.

  Args:
    task: task code, case-insensitive; one of "csl", "ate", "hte", "ma",
      "s_cpl", "m_cpl" or "mdp".

  Returns:
    The selected requirement sentences, each terminated by "\n".

  Raises:
    NotImplementedError: if ``task`` is not one of the supported codes.
  """
  # edit here: common requirements for each task (only ADD new entries)
  catalogue = {
      0:  "Ensure that the questions employ a mix of different phrasing and diverse sentence structures.",
      1:  "Ensure that the questions place a strong emphasis on the effect size.",
      2:  "Ensure that the questions do not mention correlation and association.",
      3:  "Ensure that the provided names are integrated naturally into the questions without discarding or altering any part of them.",
      4:  "Ensure that the input names are preserved exactly as given without quotes.",
      5:  "Ensure that the questions place a strong emphasis on the effect existence or the total number.",
      6:  "Ensure that the provided interested variable names are integrated naturally into the questions without discarding or altering any part of them.",
      7:  "Ensure that both the group condition variable name and a value expressed in the parenthsis are inside of the question.",
      8:  "Ensure that all variable names under interested variables field exist in all five quetsion.",
      9:  "Ensure that the questions focus on soliciting recommendations for the best possible action listed in Treatment variable field.",
      10: "Ensure that the stage number is explicitly mentioned and described as a part of the data setting in the question.",
      11: "Ensure that the setting of infinite stage number or ever-going stages (note: please try to paraphrase this concept but not use the same words 'infinite stage number' or 'ever-going stages') is included and paraphrased in the question.",
      12: "Ensure that if Treatment variable input is in the format of treatment variable name (treatment_name), then in the example sentences generated it should also be in this format of treatment variable name (treatment_name), as is shown in the example.",
      13: "Ensure that if Outcome variable input is in the format of outcome variable name (outcome_name), then in the example sentences generated it should also be in this format of outcome variable name (outcome_name), as is shown in the example.",
      14: "Ensure that the questions place a strong emphasis on the size of the mediator effect.",
      15: "Ensure that all of Treatment variable, Outcome variable, and Condition variable are mentioned in the generated query.",
  }
  # edit here: which catalogue entries each task uses, in output order
  picks_by_task = {
      "csl":   (0, 2, 5, 8, 6),                  # causal structure learning
      "ate":   (0, 1, 2, 3, 4),                  # average treatment effect
      "hte":   (0, 1, 2, 3, 4, 6, 7),            # heterogeneous treatment effect
      "ma":    (0, 14, 2, 3, 4),                 # mediation analysis
      "s_cpl": (0, 9, 2, 3, 4, 15),              # single stage policy learning
      "m_cpl": (0, 3, 4, 6, 8, 10, 12, 13),      # multiple stage policy learning
      "mdp":   (0, 3, 4, 6, 8, 11, 12, 13),      # markov decision process
  }
  picks = picks_by_task.get(task.lower())
  if picks is None:
    raise NotImplementedError
  # below should be the same for every task
  return "".join(catalogue[number] + "\n" for number in picks)

## FIXED FUNCTION: input helper function
def generate_input_part(task_dict):
  """Render a field dictionary as bullet lines of the form ``- key: value``.

  Args:
    task_dict: mapping of field labels to values; iterated in key order.

  Returns:
    One line per key, each terminated by "\n" (empty string for no keys).
  """
  return "".join(f"- {label}: {task_dict[label]}\n" for label in task_dict.keys())

## FIXED FUNCTION:  example helper function
def generate_example_part(example_list, N = 3):
  """Render a random selection of example questions as bullet lines.

  Args:
    example_list: sequence of example questions to draw from.
    N: number of examples to draw without replacement. When the list holds
      fewer than ``N`` items, all of them are used.

  Returns:
    One ``- example`` line per drawn item, each terminated by "\n".
  """
  # random.sample raises ValueError when asked for more items than are
  # available, so cap the request at the size of the list.
  draw = min(N, len(example_list))
  return "".join(f"- {example}\n" for example in random.sample(example_list, draw))

## FIXED FUNCTION:  demonstration helper function
def generate_demonstration_part(task_dict, example_list, index, N = 3):
  """Render one numbered demonstration: its input fields, then sample questions.

  Args:
    task_dict: field dictionary of the demonstration input.
    example_list: example questions belonging to that input.
    index: number shown in the "Example <index>" heading.
    N: how many example questions to request from ``generate_example_part``.

  Returns:
    The demonstration text; both the field list and the question list are
    followed by a blank line.
  """
  heading = "Example " + str(index) + ": given the following information:\n"
  # Render the fields before the questions, matching their order in the output.
  field_lines = generate_input_part(task_dict)
  question_lines = generate_example_part(example_list, N)
  return "".join([
      heading,
      field_lines, "\n",
      "The corresponding questions can be:\n",
      question_lines, "\n",
  ])

def split_df_input(df_in):
  """Expand each row's multi-line "input" cell into one row per usable line.

  A line is kept only when it has at least 15 characters and does not start
  with a letter; the first two characters of a kept line are stripped. All
  other columns are copied unchanged into every resulting row.

  Args:
    df_in: DataFrame with a string column named "input".

  Returns:
    A new DataFrame with one row per kept line and a fresh 0..n-1 index.
  """
  expanded = []
  for _, record in df_in.iterrows():
    lines = record["input"].split("\n")
    for position, line in enumerate(lines):
      # Drop short fragments and lines that begin with a letter.
      if len(line) < 15 or line[0].isalpha():
        continue
      # Diagnostic: for kept lines past the fifth, print the line five positions earlier.
      if position > 4:
        print(lines[position - 5])
      clone = cpy(record)
      clone["input"] = line[2:]
      expanded.append(clone)
  return pd.DataFrame(expanded).reset_index(drop=True)
