from selenium.webdriver.common.keys import Keys
from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
from seleniumbase import Driver
from time import sleep
from dotenv import load_dotenv
import json
import os
import argparse
import utils
import re

class ChatGPTClient:
    """Thin Selenium wrapper that drives the ChatGPT web UI through ``self.driver``.

    All element locators are hard-coded against the ChatGPT page markup and
    will need updating if the UI changes.
    """

    # CSS selector for rendered assistant messages (markdown containers).
    _RESPONSE_SELECTOR = "div.markdown.prose.w-full.break-words.dark\\:prose-invert.light"

    def __init__(self, driver):
        self.driver = driver

    def _wait_for_page_load(self, poll_interval=0.1):
        """Block until ``document.readyState`` is ``"complete"``.

        Polls with a short sleep instead of spinning the CPU. There is no
        overall timeout.
        """
        while self.driver.execute_script("return document.readyState") != "complete":
            sleep(poll_interval)

    def clear_conversation(self):
        """Open a fresh temporary chat and wait for the page to finish loading."""
        self.driver.get("https://chat.openai.com/?temporary-chat=true")
        self._wait_for_page_load()

    def _click_continue_button(self):
        """Click the "continue" button if it becomes clickable within 5 seconds.

        Returns:
            True if the button was clicked. False if it was not found or the
            click failed.
        """
        try:
            button = WebDriverWait(self.driver, 5).until(
                EC.element_to_be_clickable((By.CLASS_NAME, "btn.relative.btn-secondary"))
            )
            button.click()
        except Exception:
            print("Continue Button not found.")
            return False
        print("Continue Button clicked.")
        return True

    def _raise_if_error_message(self):
        """Raise if ChatGPT is showing an error banner on the page."""
        error_elements = self.driver.find_elements(By.CSS_SELECTOR, "div.text-sm.text-token-text-error")
        if error_elements:
            print(f"Error from ChatGPT: {error_elements[0].text}")
            raise Exception("ChatGPT returned an error message.")

    def enter_prompt(self, prompt):
        """Write ``prompt`` into the input box, submit it once, and wait for the answer.

        The answer counts as ready once an element with class ``mt-3`` is
        present. While a "continue" button is offered, it is clicked and the
        method waits again. The prompt is NOT re-submitted in that case.

        Raises:
            Exception: if ChatGPT displays an error message afterwards.
            selenium TimeoutException: if the input box or the ready marker
                never appears.
        """
        utils.play_sound()
        input_box = WebDriverWait(self.driver, 60).until(
            EC.element_to_be_clickable((By.ID, "prompt-textarea"))
        )
        # Inject the text via JS instead of typing it with send_keys.
        self.driver.execute_script("arguments[0].innerHTML = arguments[1];", input_box, prompt)
        print("Prompt is written in the text box. (not yet sent)")
        sleep(2)
        # NOTE(review): the Shift+Enter keystroke presumably makes the page register
        # the injected text before the plain Enter submits it -- confirm.
        input_box.send_keys(Keys.SHIFT, Keys.ENTER)
        input_box.send_keys(Keys.ENTER)
        print("Enter is pressed.")
        sleep(5)

        while True:
            print("Waiting for response...")
            WebDriverWait(self.driver, 600).until(
                EC.presence_of_element_located((By.CLASS_NAME, "mt-3"))
            )
            print("Response is ready.")
            if not self._click_continue_button():
                break
            # Give the continued generation time to start before waiting again.
            sleep(5)

        self._raise_if_error_message()

    def get_latest_response(self):
        """Return the stripped text of the last assistant message, or "" if there is none."""
        elements = self.driver.find_elements(By.CSS_SELECTOR, self._RESPONSE_SELECTOR)
        return elements[-1].text.strip() if elements else ""

    def get_all_response(self):
        """Return the stripped text of every assistant message on the page, in page order."""
        elements = self.driver.find_elements(By.CSS_SELECTOR, self._RESPONSE_SELECTOR)
        return [element.text.strip() for element in elements]

    @staticmethod
    def _parse_model_name(aria_label):
        """Extract the model name from the model-selector ``aria-label``.

        Returns 'Unknown' if ``aria_label`` is None or has no
        "current model is ..." part.
        """
        match = re.search(r'current model is (.+)', aria_label or '')
        return match.group(1).strip() if match else 'Unknown'

    def get_model_name(self):
        """Return the model currently selected in the UI, or 'Unknown' if it can't be read."""
        try:
            model_button = self.driver.find_element(By.CSS_SELECTOR, "button[aria-label^='Model selector']")
            return self._parse_model_name(model_button.get_attribute("aria-label"))
        except Exception as e:
            print(f"Could not determine model name: {e}")
            return 'Unknown'

    def switch_account(self, account_credentials, first_switch=False, timeout=600):
        """Log out of the current account and log in with ``account_credentials``.

        Args:
            account_credentials: dict with ``'email'`` and ``'password'`` keys.
                An email containing "google" uses the "Continue with Google"
                button, and the password is not used in that case.
            first_switch: if True, wait 1 second instead of ``timeout``.
            timeout: seconds to sleep before switching (a delay, not a
                deadline).

        Raises:
            ValueError: if the email (or, for non-Google logins, the password)
                is missing. This is checked before logging out.
            selenium TimeoutException: if an expected page element never appears.
        """
        email = account_credentials['email']
        if not email:
            raise ValueError("Account credentials have no email (are the EMAIL_n environment variables set?)")
        use_google = "google" in email
        if not use_google and not account_credentials.get('password'):
            raise ValueError(f"Account credentials for {email} have no password")

        # Resolve the real delay first so the message below is accurate.
        if first_switch:
            timeout = 1
        print(f"Waiting for {timeout} seconds before switching accounts...")
        sleep(timeout)
        print(f"{timeout} seconds have passed. Switching accounts now...")
        print("Switching accounts...")
        self.driver.get("https://chat.openai.com/")
        sleep(5)

        # Open the user menu, then log out.
        user_menu_button = WebDriverWait(self.driver, 60).until(
            EC.element_to_be_clickable((By.XPATH, "//button[contains(@class, 'rounded-full')]"))
        )
        user_menu_button.click()
        logout_button = WebDriverWait(self.driver, 30).until(
            EC.element_to_be_clickable((By.XPATH, "//div[contains(text(), 'Log out')]"))
        )
        logout_button.click()
        sleep(5)

        # Start the login flow from the welcome screen.
        login_button = WebDriverWait(self.driver, 30).until(
            EC.element_to_be_clickable((By.XPATH, "//button[@data-testid='welcome-login-button']"))
        )
        login_button.click()
        sleep(2)

        if use_google:
            google_login_button = WebDriverWait(self.driver, 30).until(
                EC.element_to_be_clickable((By.XPATH, "//button[contains(., 'Continue with Google')]"))
            )
            google_login_button.click()
        else:
            email_input = WebDriverWait(self.driver, 30).until(
                EC.element_to_be_clickable((By.ID, "email-input"))
            )
            email_input.clear()
            email_input.send_keys(email)
            email_input.send_keys(Keys.ENTER)
            sleep(2)

            password_input = WebDriverWait(self.driver, 30).until(
                EC.element_to_be_clickable((By.ID, "password"))
            )
            password_input.clear()
            password_input.send_keys(account_credentials['password'])
            password_input.send_keys(Keys.ENTER)
            sleep(5)

        # Logged in once the chat input box is back.
        WebDriverWait(self.driver, 60).until(
            EC.presence_of_element_located((By.ID, "prompt-textarea"))
        )
        print(f"Logged in as {email}")

class MovieProcessor:
    """Generate scene summaries for one movie by prompting ChatGPT.

    ``data[movie_name]`` must provide ``'duration'`` and ``'events'``. Each
    event is inserted verbatim (wrapped in double quotes) into the prompt.
    Responses accumulate in ``all_responses`` and are written out by
    :meth:`save_results`.
    """

    def __init__(self, data, movie_name, client):
        self.data = data
        self.movie_name = movie_name
        # Must provide enter_prompt / get_latest_response / get_model_name.
        self.client = client
        # Raw text of every response, in prompt order.
        self.all_responses = []
        # Set when a batch response contains no "Timestamp:".
        self.isNoTimestamp = False

    def process_movie(self, all_process=True):
        """Run the pipeline: all events in one prompt, or in batches of 150."""
        if all_process:
            self.process_all_events()
        else:
            self.process_events_in_batches()

    def process_all_events(self):
        """Send every event of the movie in a single batch prompt."""
        events = self.data[self.movie_name]['events']
        self.process_events_in_batches(batch_size=max(len(events), 1))

    def _check_model(self):
        """Return the current model name, raising if the client fell back to '4o mini'.

        Raises:
            Exception: "Model name is not '4o'". main() matches this exact
                text to trigger an account switch.
        """
        model_name = self.client.get_model_name()
        if model_name.strip() == '4o mini':
            print("Model name is not '4o'. Switching accounts.")
            raise Exception("Model name is not '4o'")
        return model_name

    def process_events_in_batches(self, batch_size=150):
        """Send the instruction prompt, then one prompt per batch of ``batch_size`` events.

        Every response is appended to ``all_responses``. Processing stops early,
        with ``isNoTimestamp`` set, when a batch response contains no
        "Timestamp:".

        Raises:
            ValueError: if ``batch_size`` < 1.
            Exception: "Model name is not '4o'", checked before each batch.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer")
        duration = self.data[self.movie_name]['duration']
        events = self.data[self.movie_name]['events']

        initial_prompt = """
        You are an expert in scene separation and scene summarization. You have prepared hundreds of sentences that consist of audio descriptions (AD) and speech transcripts from movies. Audio descriptions provide explanations of visual elements in a narrative manner. Therefore, when the speech timestamps overlap with the appearance times of the visual scenes, the timestamps of the audio descriptions are assigned after the speech, which may differ from the actual timestamps of the movie. The speech information is generated automatically using an ASR model, so there is a possibility of mismatch compared to the real video. You must effectively utilize both the speech information as the only scene context and the audio description information as time-noised and semantic visual information to construct the scenes.

I will provide you with sets of incidents that include timestamps along with visual captions and speech subtitles of events.

Below is instruction for your scene summarization.

**1.How to summarize? 2 step - Group connected incidents and summarize with grouped incidents**: 
First step - Separate the scene with grouping incidents) Scenes should be separated by different important events. Here, an important event refers to an occurrence that involves a single coherent context (e.g., when a single action or interaction lasts for an extended period, when conversations are continuously based on a similar theme, or when scenes in the same category persist at an abstract level, even if the detailed actions differ, etc.).
Second step - Summarize with the grouped incidents) Then you must summarized the scene within the grouped incidents. For example, grouped results are scene1:{[45,55]sentence1, [57,200]sentence2, [200,212]sentence3}, scene 2:{[222,225]sentence5, [240,251]sentence5}. Then you summarize scene 1 with sentence1-3, summarize scene2 with sentence 4 and sentence5. Don’t use other grouped incidents when summarize a scene. 

**2.Use audio descriptions (AD) as the main time reference**: When merging audio descriptions with speech, prioritize the timestamps of the AD to correct any timing mismatches from speech recognition (ASR). Structure scenes based on this timing to create a cohesive narrative.

**3.Keep the summary objective**: Do not include opinions or assumptions. Summaries should be factual and focus on describing the events and actions.

**4.Summarize all scenes concisely:** Avoid overly detailed descriptions and use as few semantically compressed sentences as possible to convey the key events. Keep the focus on the most essential elements needed to understand the scene.

**5.Paraphrase speech**: When referencing speech, summarize or paraphrase the dialogue instead of quoting it exactly. Focus on the meaning or intent behind the words.

**6.Strict Time Boundaries**: Summarize the scene only within the timestamp range provided. Do not project or hint at events that happen beyond the current scene's time range.  When summarizing a scene, stick to what is shown up to the last second of the scene. Refrain from mentioning characters' emotions, intentions, or events that occur after the defined end of the scene.

**7.Avoid redundancy**: If similar scenes are part of the same overarching event or theme, group them as one scene to avoid stemming the tide. Then we can get a continuous single scene. Treat minor variations in action or dialogue as part of the same scene if they share the same context or objective.
Please summarize the scenes for the following events:
        """
        #The duration of each scene should be between 300 and 500 seconds.
        self.client.enter_prompt(initial_prompt)
        self.all_responses.append(self.client.get_latest_response())

        for start_index in range(0, len(events), batch_size):
            self._check_model()
            batch_events = events[start_index:start_index + batch_size]
            self.client.enter_prompt(self._build_batch_prompt(duration, batch_events))
            response = self.client.get_latest_response()
            self.all_responses.append(response)

            if "Timestamp:" not in response:
                print("Error in response. Please check the prompt.")
                self.isNoTimestamp = True
                return

    def _build_batch_prompt(self, duration, batch_events):
        """Return the summarization prompt for one batch of events."""
        events_text = ',\n'.join(f"\"{event}\"" for event in batch_events)
        # NOTE(review): the prompt says "follow the 8 instructions" but lists 7, and
        # its scene-2 example names sentence5 twice while the explanation says
        # "sentence 4 and sentence5". Left unchanged so prompts stay identical for
        # results already generated; fix deliberately if desired.
        return f"""
            Movie Name: {self.movie_name}
Total Movie Duration: {duration}

I will provide you with sets of incidents that include timestamps along with visual captions and speech subtitles of events.

Below is instruction for your scene summarization.

**1.How to summarize? 2 step - Group connected incidents and summarize with grouped incidents**: 
First step - Separate the scene with grouping incidents) Scenes should be separated by different important events. Here, an important event refers to an occurrence that involves a single coherent context (e.g., when a single action or interaction lasts for an extended period, when conversations are continuously based on a similar theme, or when scenes in the same category persist at an abstract level, even if the detailed actions differ, etc.).
Second step - Summarize with the grouped incidents) Then you must summarized the scene within the grouped incidents. For example, grouped results are scene1:[[45,55]sentence1, [57,200]sentence2, [200,212]sentence3], scene 2:[[222,225]sentence5, [240,251]sentence5]. Then you summarize scene 1 with sentence1-3, summarize scene2 with sentence 4 and sentence5. Don’t use other grouped incidents when summarize a scene. 

**2.Use audio descriptions (AD) as the main time reference**: When merging audio descriptions with speech, prioritize the timestamps of the AD to correct any timing mismatches from speech recognition (ASR). Structure scenes based on this timing to create a cohesive narrative.

**3.Keep the summary objective**: Do not include opinions or assumptions. Summaries must be factual and focus on describing the events and actions.

**4.Summarize all scenes concisely:** Avoid overly detailed descriptions and use as few semantically compressed sentences as possible to convey the key events. Keep the focus on the most essential elements needed to understand the scene.

**5.Utilize speech**: When utilize speech, summarize or simply paraphrase the dialogue instead of quoting it exactly. Focus on the meaning or intent behind the words.

**6.Strict Time Boundaries**: Summarize the scene only within the grouped incidents provided. Do not project or hint at events that happen beyond the current scene.  When summarizing a scene, stick to what is shown up to the last second of the scene. Refrain from mentioning characters' emotions, intentions, or events that occur after the defined end of the scene.

**7.Avoid redundancy**: If similar scenes are part of the same overarching event or theme, group them as one scene to avoid stemming the tide. Then we can get a continuous single scene. Treat minor variations in action or dialogue as part of the same scene if they share the same context or objective.

Please summarize the scenes for the following events:
{events_text}

Output format is below:
Step1 - Grouping incidents to Scene : 
<Grouping Start>
[Scene n] 
Grouping Theme1 : 1or2 words 
Time : [Start time,End time]
Grouped Audio Description Index : [Start Index, End Index]
<Grouping End>

Step2 - Scene Summarization:
[Start of Scene n]
Timestamp: [start_time, end_time]
Summary: ~
[End of Scene n]
 You must strictly follow the 8 instructions provided for scene summarization without exception. Especially, think about it step-by-step that you follows the 1,3,4,6 instruction, then answer me.
            """

    def save_results(self, folder_name, txt_file_path=None, json_file_path=None):
        """Write raw responses to a .txt file and the parsed scenes to a .json file.

        Args:
            folder_name: directory used for the default output paths.
            txt_file_path: explicit path for the raw-response dump.
            json_file_path: explicit path for the parsed JSON. Parent
                directories of both output paths are created if missing.

        Raises:
            Exception: "Model name is not '4o'" if the client reports '4o mini'.
                Any error while building or writing the JSON is printed and
                re-raised unchanged.
        """
        if txt_file_path is None:
            txt_file_path = f'{folder_name}/{self.movie_name}_results.txt'
        if json_file_path is None:
            json_file_path = f'{folder_name}/{self.movie_name}_results.json'

        model_name = self._check_model()

        for path in (txt_file_path, json_file_path):
            parent_dir = os.path.dirname(path)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)

        with open(txt_file_path, 'w', encoding='utf-8') as result_file:
            result_file.write(f'Model used: {model_name}\n\n')
            for response in self.all_responses:
                if response.strip():
                    result_file.write(response.strip() + '\n')

        try:
            all_timestamps = []
            all_sentences = []
            for response in self.all_responses:
                for timestamp, summary in self.parse_response(response):
                    all_timestamps.append(timestamp)
                    all_sentences.append(summary)

            full_data = {
                self.movie_name: {
                    "original_sentences": self.data[self.movie_name]['events'],
                    "timestamps": all_timestamps,
                    "sentences": all_sentences
                }
            }

            with open(json_file_path, 'w', encoding='utf-8') as result_file:
                json.dump(full_data, result_file, indent=4, ensure_ascii=False)
        except Exception as e:
            print(f"Error processing results: {e}")
            raise

    @staticmethod
    def _parse_timestamp(line, line_no):
        """Parse a "Timestamp: [start, end]" line into two floats.

        Returns [0.0, 0.0] (and prints a warning) if the value is malformed.
        """
        timestamp_str = line[len("Timestamp:"):].strip().strip('[]')
        try:
            start_time_str, end_time_str = timestamp_str.split(',')
            return [float(start_time_str.strip()), float(end_time_str.strip())]
        except ValueError:
            print(f"Invalid timestamp format at line {line_no}: {line}")
            return [0.0, 0.0]

    def parse_response(self, response_text):
        """Extract (timestamp, summary) pairs from a model response.

        A scene starts at a line beginning with "[Start of Scene" and ends at
        "[End of Scene". A scene also ends at the next "[Start of Scene" when
        the end marker is missing. Within a scene, "Timestamp:" gives the time
        range. "Summary:" starts the summary, which continues over the
        following non-empty lines joined by spaces.

        Returns:
            list of ([start, end], summary) tuples. Scenes lacking a timestamp
            or a summary are skipped with a printed warning.
        """
        scene_markers = ("[End of Scene", "[Start of Scene")
        scenes = []
        lines = [raw.strip() for raw in response_text.splitlines()]
        i = 0
        while i < len(lines):
            if not lines[i].startswith("[Start of Scene"):
                i += 1
                continue
            scene_start = i
            timestamp = None
            summary = ""
            i += 1
            while i < len(lines):
                line = lines[i]
                if line.startswith("[Start of Scene"):
                    # End marker missing: leave i on this line for the outer loop.
                    break
                if line.startswith("[End of Scene"):
                    i += 1
                    break
                if line.startswith("Timestamp:"):
                    timestamp = self._parse_timestamp(line, i)
                    i += 1
                elif line.startswith("Summary:"):
                    parts = [line[len("Summary:"):].strip()]
                    i += 1
                    while i < len(lines) and not lines[i].startswith(scene_markers):
                        if lines[i]:
                            parts.append(lines[i])
                        i += 1
                    summary = " ".join(parts)
                else:
                    i += 1
            if timestamp and summary:
                scenes.append((timestamp, summary))
            else:
                print(f"Missing timestamp or summary for a scene starting at line {scene_start}")
        return scenes
        

def open_json_file(file_name):
    """Read ``file_name`` as UTF-8 text and return the decoded JSON document."""
    with open(file_name, encoding='utf-8') as json_fh:
        return json.loads(json_fh.read())

def arg_parser():
    """Parse the command line: input JSON, output folder, and starting account index."""
    option_specs = (
        ('--file_name', str, 'mad_v2_sub-AD_train.json'),
        ('--output_dir', str, 'HourHDVC_train'),
        ('--current_account_index', int, 2),
    )
    parser = argparse.ArgumentParser()
    for flag, value_type, default_value in option_specs:
        parser.add_argument(flag, type=value_type, default=default_value)
    return parser.parse_args()

NO_TIMESTAMP_FILE = "no_timestamp.txt"


def _load_skipped_movies(path=NO_TIMESTAMP_FILE):
    """Return the set of movie names recorded in ``path``, one per line.

    A missing file means nothing has been skipped yet, so an empty set is
    returned.
    """
    try:
        with open(path, "r", encoding="utf-8") as f:
            return {line.rstrip("\n") for line in f if line.rstrip("\n")}
    except FileNotFoundError:
        return set()


def _mark_skipped(movie_name, skipped, path=NO_TIMESTAMP_FILE):
    """Record ``movie_name`` as skipped, both in ``path`` and in the ``skipped`` set."""
    with open(path, "a", encoding="utf-8") as f:
        f.write(f"{movie_name}\n")
    skipped.add(movie_name)


def _wait_for_document_ready(driver, poll_interval=0.1):
    """Poll until ``document.readyState`` is ``"complete"``, sleeping between checks (no timeout)."""
    while driver.execute_script("return document.readyState") != "complete":
        sleep(poll_interval)


def _load_accounts(count=4):
    """Call load_dotenv() and return accounts built from EMAIL_n / PASSWORD_n, n = 1..count."""
    load_dotenv()
    return [
        {'email': os.getenv(f'EMAIL_{n}'), 'password': os.getenv(f'PASSWORD_{n}')}
        for n in range(1, count + 1)
    ]


def main():
    """Summarize every movie in the input JSON through the ChatGPT web UI.

    Movies are skipped if ``<output_dir>/<movie>_results.json`` already exists
    or if they are listed in ``no_timestamp.txt``.

    When the client reports the '4o mini' model, the next account is logged in
    and the movie is retried; this does not count as a retry. Other errors are
    retried up to ``len(accounts)`` times, after which the movie is recorded in
    ``no_timestamp.txt``.
    """
    args = arg_parser()
    folder_name = args.output_dir
    current_account_index = args.current_account_index
    data = open_json_file(args.file_name)
    # Without the output folder every save fails, and each movie would be
    # retried and then wrongly recorded as skipped.
    os.makedirs(folder_name, exist_ok=True)

    user_data_dir = utils.get_user_data_dir()
    driver = Driver(uc=True, headless=False, user_data_dir=user_data_dir)

    driver.get("https://chat.openai.com/?temporary-chat=true")
    _wait_for_document_ready(driver)

    print("Please log in to ChatGPT in the opened browser window.")
    input("After logging in, press Enter to continue...")
    print("Program has started.")

    client = ChatGPTClient(driver)
    accounts = _load_accounts()
    max_retries = len(accounts)
    account_credentials = accounts[current_account_index]
    client.switch_account(account_credentials, first_switch=True)

    # Loaded once; exact-name matching so "Movie1" does not match "Movie10".
    skipped = _load_skipped_movies()

    for movie_name in data:
        json_check_path = f'{folder_name}/{movie_name}_results.json'
        if os.path.exists(json_check_path):
            print(f"Results for {movie_name} already exist. Skipping this movie...")
            continue
        if movie_name in skipped:
            print(f"{movie_name} has no timestamp. Skipping this movie...")
            continue

        print(f"Processing movie: {movie_name}")
        success = False
        no_timestamp = False
        retry_count = 0

        while not success and retry_count < max_retries:
            try:
                client.clear_conversation()
                movie_processor = MovieProcessor(data, movie_name, client)
                movie_processor.process_movie(all_process=False)
                if movie_processor.isNoTimestamp:
                    print(f"Error in processing {movie_name}. Skipping this movie...")
                    _mark_skipped(movie_name, skipped)
                    no_timestamp = True
                    break
                movie_processor.save_results(folder_name)
                print(f"Results for {movie_name} are saved!\n")
                success = True
            except Exception as e:
                print(f"Error occurred while processing {movie_name}. Error details: {e}")
                if 'Model name is not \'4o\'' in str(e):
                    # Fell back to '4o mini': rotate to the next account.
                    current_account_index = (current_account_index + 1) % len(accounts)
                    account_credentials = accounts[current_account_index]
                    client.switch_account(account_credentials)
                    print(f"Switched to account {account_credentials['email']}")
                else:
                    client.clear_conversation()
                    retry_count += 1

        if no_timestamp:
            # Already recorded above; don't report a failure or duplicate the entry.
            continue
        if not success:
            print(f"Failed to process {movie_name} after {max_retries} retries.")
            _mark_skipped(movie_name, skipped)
            continue

        print(f"\nMovie {movie_name} has been successfully processed.\n")

    print("\nAll movies have been processed.\n")

    for _ in range(10):
        utils.play_sound()

    # Keep the browser open until the operator is done, without busy-waiting.
    input("Press Enter to close the browser and exit...")
    driver.quit()

if __name__ == "__main__":
    # Script entry point: run the pipeline only when executed directly, not on import.
    main()