import glob
import json
import os
import random
from typing import (
    Tuple,
    List,
    Dict,
)

import pandas as pd
from datasets import load_dataset
from loguru import logger
from tqdm import tqdm

from src.schema import (
    User,
    Session,
    Turn,
    Memory,
    MemoryType,
)
from src.utils.json import read_jsonl_file


class RealDataInjectionStage:
    """Pad each user's memory with real WildChat sessions.

    On construction the stage loads the ``allenai/WildChat-1M`` train split,
    keeps the English, non-redacted conversations, converts them to
    ``Session`` objects (in shuffled order), and loads the pre-computed
    observation / summary / event annotations from the JSONL response files
    found in ``response_dir``.

    ``run`` then appends sessions (and their derived memories) to every user
    until each one reaches a common target session count.
    """

    # Prefix put in front of the WildChat conversation hash to form a session id.
    SESSION_ID_PREFIX = "S_wildchat_"

    # Location of the ``*.jsonl`` response files used when no directory is given.
    DEFAULT_RESPONSE_DIR = "/Users/maum/Desktop/dev/github/PELIQAN/.db/wildchat/response/"

    def __init__(self, response_dir: str = DEFAULT_RESPONSE_DIR) -> None:
        """Load the WildChat sessions and their annotation maps.

        Args:
            response_dir: Directory containing the ``*.jsonl`` response files.
                Defaults to the path that was previously hard-coded.
        """
        logger.info("Initializing RealDataInjectionStage.")
        self.response_dir = response_dir

        dataset = load_dataset("allenai/WildChat-1M", split="train")
        df = dataset.to_pandas()
        # Element-wise comparison on a pandas column, hence ``== False``.
        filtered_df = df[(df["language"] == "English") & (df["redacted"] == False)]  # TODO: check here
        self.available_sessions = self._convert_df_to_sessions(filtered_df)

        self.observation_map, self.summary_map, self.episodic_memory_map = self._get_wildchat_responses()

    def _get_wildchat_responses(self) -> Tuple[Dict, Dict, Dict]:
        """Read every response file and index its annotations by ``custom_id``.

        Each record's message content is expected to be a JSON object
        (optionally wrapped in a Markdown code fence) with the keys
        ``observations``, ``events`` and ``summary``. Records that cannot be
        parsed are kept in all three maps with the value ``None`` so that
        ``_check_availability`` rejects the matching session.

        Returns:
            ``(observation_map, summary_map, episodic_memory_map)``; summaries
            are wrapped in a one-element list.
        """
        observation_map = {}
        summary_map = {}
        episodic_memory_map = {}

        response_dir = getattr(self, "response_dir", self.DEFAULT_RESPONSE_DIR)
        # os.path.join tolerates a directory given with or without a trailing slash.
        files = glob.glob(os.path.join(response_dir, "*.jsonl"))
        wildchat_responses = []
        for file in tqdm(files):
            wildchat_responses.extend(read_jsonl_file(file))

        unparsable = 0
        missing_id = 0
        for resp in wildchat_responses:
            try:
                custom_id = resp["custom_id"]
            except (KeyError, TypeError):
                # Without an id the record cannot be matched to any session.
                missing_id += 1
                continue

            try:
                raw = resp["response"]["body"]["choices"][0]["message"]["content"]
                data = json.loads(raw.replace("```json", "").replace("```", "").strip())
                observations = data["observations"]
                events = data["events"]
                summaries = [data["summary"]]
            except (KeyError, IndexError, TypeError, AttributeError, ValueError):
                # Malformed envelope or invalid JSON (JSONDecodeError is a
                # ValueError): mark the conversation as unusable.
                observations = events = summaries = None
                unparsable += 1

            observation_map[custom_id] = observations
            episodic_memory_map[custom_id] = events
            summary_map[custom_id] = summaries

        if unparsable or missing_id:
            logger.warning(
                "WildChat responses: {} unparsable, {} without custom_id (out of {}).",
                unparsable,
                missing_id,
                len(wildchat_responses),
            )

        return observation_map, summary_map, episodic_memory_map

    @classmethod
    def _conversation_hash(cls, session: Session) -> str:
        """Return the session id with the leading session prefix removed."""
        session_id = session.session_id
        if session_id.startswith(cls.SESSION_ID_PREFIX):
            return session_id[len(cls.SESSION_ID_PREFIX):]
        return session_id

    def _check_availability(self, session: Session) -> bool:
        """Return True only if observations, summary and events all exist.

        All three are required because ``run`` builds one memory from each; a
        missing entry would otherwise fail later when its lines are joined.
        """
        conversation_hash = self._conversation_hash(session)
        return all(
            annotation_map.get(conversation_hash) is not None
            for annotation_map in (
                self.observation_map,
                self.summary_map,
                self.episodic_memory_map,
            )
        )

    def _convert_df_to_sessions(self, df: pd.DataFrame) -> List[Session]:
        """Turn each DataFrame row into a ``Session`` and shuffle the result.

        Reads the ``conversation_hash`` and ``conversation`` columns; every
        conversation entry must provide ``turn_identifier``, ``role`` and
        ``content``.
        """
        sessions = []
        # Zipping the two needed columns avoids building a Series per row,
        # which is what makes ``iterrows`` slow on large frames.
        rows = zip(df["conversation_hash"], df["conversation"])
        for conversation_hash, conversation in tqdm(
            rows, total=len(df), desc="Converting DataFrame to Sessions"
        ):
            turns = [
                Turn(
                    turn_id=f"T_wildchat_{conversation_hash}_{turn['turn_identifier']}",
                    speaker=turn["role"],
                    content=turn["content"],
                )
                for turn in conversation
            ]
            sessions.append(
                Session(
                    session_id=f"{self.SESSION_ID_PREFIX}{conversation_hash}",
                    turns=turns,
                )
            )
        random.shuffle(sessions)
        return sessions

    @staticmethod
    def _build_memory(session: Session, memory_type: MemoryType, lines: List[str]) -> Memory:
        """Create a memory for ``session`` whose content is ``lines`` joined by newlines."""
        return Memory(
            memory_id=f"M_{session.session_id}",
            memory_type=memory_type,
            content="\n".join(lines),
        )

    def _convert_session_to_dialogue(self, session: Session) -> Memory:
        """Build the DIALOGUE memory: one ``speaker: content`` line per turn."""
        lines = [f"{turn.speaker}: {turn.content}" for turn in session.turns]
        return self._build_memory(session, MemoryType.DIALOGUE, lines)

    def _convert_session_to_observation(self, session: Session) -> Memory:
        """Build the OBSERVATION memory from the stored observations."""
        observations = self.observation_map[self._conversation_hash(session)]
        return self._build_memory(session, MemoryType.OBSERVATION, observations)

    def _convert_session_to_summary(self, session: Session) -> Memory:
        """Build the SUMMARY memory from the stored summary."""
        summaries = self.summary_map[self._conversation_hash(session)]
        return self._build_memory(session, MemoryType.SUMMARY, summaries)

    def _convert_session_to_episodic_memory(self, session: Session) -> Memory:
        """Build the EPISODIC_MEMORY memory from the stored events."""
        events = self.episodic_memory_map[self._conversation_hash(session)]
        return self._build_memory(session, MemoryType.EPISODIC_MEMORY, events)

    def run(
        self,
        memory_dataset: List[User],
        session_multiplier: float,
    ) -> List[User]:
        """Append WildChat sessions until every user reaches the target count.

        The target is ``int(max_session_count * session_multiplier)``, where
        ``max_session_count`` is the largest ``raw_dialogue`` length in the
        dataset. Sessions are consumed from ``self.available_sessions``;
        sessions without complete annotations are discarded.

        Args:
            memory_dataset: Users to extend in place.
            session_multiplier: Factor applied to the largest session count.

        Returns:
            The same ``memory_dataset`` list, mutated in place.

        Raises:
            ValueError: If the pool of usable sessions runs out.
        """
        if not memory_dataset:
            # Nothing to extend; ``max`` of an empty sequence would raise.
            return memory_dataset

        max_session_count = max(len(user_sample.raw_dialogue) for user_sample in memory_dataset)
        target_session_count = int(max_session_count * session_multiplier)

        for user_sample in memory_dataset:
            count_diff = target_session_count - len(user_sample.raw_dialogue)
            if count_diff > len(self.available_sessions):
                raise ValueError("Not enough unique WildChat sessions available.")
            count = 0
            while count < count_diff:
                if not self.available_sessions:
                    # The up-front check counts raw sessions only; unusable
                    # ones discarded below can still exhaust the pool.
                    raise ValueError("Not enough unique WildChat sessions available.")
                new_session = self.available_sessions.pop()
                if not self._check_availability(new_session):
                    continue
                user_sample.raw_dialogue.append(new_session)
                user_sample.dialogue.append(self._convert_session_to_dialogue(new_session))
                user_sample.observation.append(self._convert_session_to_observation(new_session))
                user_sample.summary.append(self._convert_session_to_summary(new_session))
                user_sample.episodic_memory.append(self._convert_session_to_episodic_memory(new_session))
                count += 1

        return memory_dataset
