import os
import json
from openai import OpenAI
from policies import BasePolicy
from policies.utils.episode_memory import EpisodeMemory
from policies.utils.replay_buffer import ReplayBuffer
# API key is read from the environment; the value is None when the variable is unset.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")

# NOTE: MODELS is assigned a second time further down in this module, so this
# list is replaced by the later one once the module has finished importing.
MODELS = [
    "gpt-3.5-turbo",
    "gpt-4-1106-preview",
    "gpt-4-0125-preview",  # latest gpt4
    "gpt-3.5-turbo-0125",  # latest gpt3.5
]
"""
Three concepts of the policy:
(1) message_history
(2) prompt
(3) experience_buffer
"""

import os
import re
import json
import random
import torch
from rich import print
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, HfArgumentParser, TrainingArguments, pipeline
from policies import BasePolicy
from policies.utils.replay_buffer import ReplayBuffer
from policies.utils.reflection import Reflection, get_obs_message
import policies.prompts as prompts

# NOTE: this rebinds MODELS; the list assigned earlier in this module is discarded.
MODELS = ["meta-llama/Meta-Llama-3-70B-Instruct"]
"""
Three concepts of the policy:
(1) message_history
(2) prompt
(3) replay_buffer
"""
class CommGPTPolicy(BasePolicy):
    def __init__(self,
                 model="gpt-4o",
                 agent_id="",
                 temperature=0.2,
                 adapter=None,
                 device="cuda",
                 comm_only=False,
                 control_only=False,
                 skip_frames=0,
                 batch_size=1,
                 logdir=None,
                 is_focal=False
    ):
        """
        Set up a policy that picks actions by prompting an OpenAI chat model.

        Args:
            model: model name forwarded to the chat-completions call.
            agent_id: identifier stored on the policy.
            temperature: sampling temperature forwarded to the chat call.
            adapter: stored as-is; not used by any method of this class.
            device: stored as-is; not used by any method of this class.
            comm_only: indicates whether the agent is communication only.
            control_only: indicates whether the agent is control only.
            skip_frames: number of initial steps for which act() returns a
                fixed "go" action without querying the model.
            batch_size: number of transitions sampled from the replay buffer
                in learn().
            logdir: directory used by save()/load() for knowledge checkpoints.
            is_focal: stored on the policy; not used by any method of this class.
        """
        # basic policy setup
        self.agent_id = agent_id
        self.decision_frequency = 10 # frames, decision_frequency / frame_rate = seconds per decision
        self.comm_only = comm_only # indicate whether the agent is communication only
        self.control_only = control_only # indicate whether the agent is control only
        self.skip_frames = skip_frames
        self.is_focal = is_focal

        # set up language model
        self.model = model
        self.temperature = temperature
        # These two arguments used to be accepted and then dropped; keep them on
        # the instance so the configuration the caller passed stays inspectable.
        self.adapter = adapter
        self.device = device
        self.openai_api_key = os.environ.get("OPENAI_API_KEY")

        # basic prompts
        self.instruction = prompts.get_instruction(comm_only)
        self.common_sense = prompts.get_common_sense()
        self.message_history = []

        # set up learning module
        self.replay_buffer = ReplayBuffer()
        self.batch_size = batch_size
        # No reflection helper is built here (the construction below is disabled),
        # but learn(), debrief() and internalize() all read self.reflection.
        # Define the attribute explicitly; a helper must be assigned to it
        # before any of those methods is called.
        self.reflection = None
        # self.reflection = Reflection(model=self.model,
        #                              tokenizer=None,
        #                              device=self.device,
        #                              temperature=self.temperature,
        #                              comm_only=self.comm_only,
        #                              control_only=self.control_only,
        #                              is_focal=self.is_focal
        #                              )
        self.iteration = 0
        self.experience = None
        self.current_observation = None

        # episodic metric setup
        self.episode_return = 0
        self.step_count = 0
        self.prev_action = None
        self.plan = None
        self.logdir = logdir
        self.learned_knowledge = ""
        self.cooperative_knowledge = ""

    def reset(self):
        """
        Reset the agent when a new episode starts
        Warning: you should not reset any learning modules here!
        """
        self.step_count = 0
        self.episode_return = 0
        self.current_observation = None
        self.prev_action = None
        self.plan = None
        self.message_history = []
        print("Resetting the agent")

    def observe(self, obs, reward, terminated, truncated, info):
        if obs is None:
            self.current_observation = None
            return
        self.current_observation = obs
        self.episode_return += reward

    def act(self):
        """
        Generate actions using language model
        """
        self.step_count += 1
        if self.step_count <= self.skip_frames: 
            return {"command":"go", "message":""}
        if self.step_count % self.decision_frequency == 1:
            response = self.prompting()
            action = self.parse_action(response)
            self.prev_action = action
            print(self.message_history)
        action = self.prev_action
        return action

    def chat(self, prompt, json_format=False):
        """
        Send a message list to the chat-completions API and return the reply text.

        Args:
            prompt: list of message dicts, passed as `messages`.
            json_format: when True, request a JSON-object response.

        Returns:
            str: content of the first choice, stripped of surrounding
            whitespace; an empty string when the reply has no content.
        """
        request = {
            "model": self.model,
            "messages": prompt,
            "temperature": self.temperature,
            "stream": False,
        }
        # Only send response_format when it is wanted, rather than an explicit None.
        if json_format:
            request["response_format"] = {"type": "json_object"}

        client = OpenAI(api_key=self.openai_api_key)
        completion = client.chat.completions.create(**request)
        # message.content can be None; avoid calling .strip() on it.
        content = completion.choices[0].message.content
        return (content or "").strip()

    def prompting(self)->dict:
        """
        Run the two-step chain-of-thought exchange for the current observation.

        `self.message_history` is rebuilt from scratch on every call: the two
        system prompts, any stored knowledge, the observation message, then a
        reasoning request followed by a decision request (the latter asked for
        in JSON mode). Both model replies are appended to the history.

        Returns:
            dict: {"reasoning": <reply to the first prompt>,
                   "action": <reply to the second prompt>}, both as raw text.
        """
        # System message step
        history = [
            {"role":"system", "content":self.instruction},
            {"role":"system", "content":self.common_sense},
        ]
        self.message_history = history

        # Previous knowledge step (cooperative knowledge first, then learned)
        for knowledge in (self.cooperative_knowledge, self.learned_knowledge):
            if knowledge:
                history.append({"role":"assistant", "content":knowledge})

        # Prompt Observation
        observation = self.current_observation
        obs_message = "No observation" if observation is None else get_obs_message(observation)
        history.append({"role":"user", "content":obs_message})

        # Reasoning step
        cot_prompt1 = prompts.get_cot_prompt_1(self.comm_only, self.control_only)
        history.append({"role":"user", "content":cot_prompt1})
        reasoning = self.chat(history)
        history.append({"role":"assistant", "content":reasoning})

        # Decision step
        cot_prompt2 = prompts.get_cot_prompt_2(self.comm_only, self.control_only)
        history.append({"role":"user", "content":cot_prompt2})
        action = self.chat(history, json_format=True)

        # Record the action generated by model
        history.append({"role":"assistant", "content":action})
        return {"reasoning": reasoning, "action": action}

    def parse_action(self, response):
        """
        Parse the action into a dictionary
        """
        action = {"command":"go", "reasoning":response['reasoning']}
        try:
            response = re.findall(r"\{[^*]*\}", response['action'])[0]
            if response:
                response = json.loads(response)
                if "command" in response:
                    action["command"] = response["command"]
                if "message" in response:
                    action["message"] = response["message"]
        except:
            print("Error in parsing the response", response)
        action = dict(sorted(action.items(), key=lambda item: item[0]))
        return action

    def get_episode_return(self)->float:
        return self.episode_return

    def store_transition(self, transition)->None:
        """
        Store the transition in replay buffer for learning
        """
        if transition.obs is None:
            return
        self.replay_buffer.add(transition)

    def learn(self)->None:
        """
        Learn from the experince
        """
        self.iteration += 1
        batch = self.replay_buffer.sample_batch(batch_size=self.batch_size)
        updated_knowledge = self.reflection.reflect(batch)
        self.learned_knowledge = updated_knowledge
        self.save(self.iteration)
        # Clear replay buffer
        self.replay_buffer.clear()

    def debrief(self, collective_knowledges, learned_knowledges, agent_in_debrief, batch_size, last_round=False):
        """
        Debrief the agent after the training
        """
        # debrief batch size could be smaller than default batch size
        batch = self.replay_buffer.sample_batch(batch_size=batch_size)
        # ask agent to reflect on the batch and propose cooperation and ego knowledge
        collective_knowledge, learned_knowledge = self.reflection.debrief(batch,
                                                                          collective_knowledges,
                                                                          learned_knowledges,
                                                                          agent_in_debrief,
                                                                          last_round=last_round)
        return collective_knowledge, learned_knowledge

    def internalize(self):
        """
        Internalize the knowledge from the debriefing
        """
        if self.reflection.cooperative_knowledge:
            self.cooperative_knowledge = self.reflection.cooperative_knowledge
        if self.reflection.learned_knowledge:
            self.learned_knowledge = self.reflection.learned_knowledge

    def save(self, ckpt_num):
        """
        Save the knowledge for the future training
        """
        assert self.logdir is not None
        if not os.path.exists(self.logdir):
            os.makedirs(self.logdir, exist_ok=True)
        json.dump(
            {"knowledge": self.learned_knowledge,
             "cooperative_knowledge": self.cooperative_knowledge},
            open(os.path.join(self.logdir, f"ckpt-{ckpt_num}.json"), "w")
        )
    
    def load(self, ckpt_num):
        """
        load the knowledge from the previous training
        """
        if self.logdir is None:
            return
        if ckpt_num == -1:
            # load the latest checkpoint
            ckpt_num = max([int(ckpt.split("-")[-1].split(".")[0]) for ckpt in os.listdir(self.logdir)])
        knowledge = json.load(open(os.path.join(
                                                self.logdir,
                                                f"ckpt-{ckpt_num}.json"
                                        ), "r"))
        print(f"Loading checkpoint {ckpt_num}")
        if "knowledge" in knowledge:
            self.learned_knowledge = knowledge["knowledge"]
        if "cooperative_knowledge" in knowledge:
            self.cooperative_knowledge = knowledge["cooperative_knowledge"]
