import json
import os
import time
import yaml

import faiss
import numpy as np
from selenium import webdriver
from selenium.common.exceptions import WebDriverException
from selenium.webdriver.chrome.options import Options
from selenium.webdriver.chrome.service import Service
from selenium.webdriver.common.by import By
from sentence_transformers import SentenceTransformer
from tqdm import tqdm


class KaggleData:
    """Scrape Kaggle competition pages and index their descriptions with FAISS.

    The pipeline has two stages:

    1. ``create_json_competitions`` reads a YAML list of competitions, scrapes
       each competition page (and optionally its solution discussions) with a
       headless Chrome driver and stores the result in a JSON file.
    2. ``create_database`` embeds the stored descriptions, writes a FAISS
       inner-product index and a JSON metadata file next to it.
    """

    # CSS selector of the element whose text is taken as the discussion body.
    # NOTE(review): these look like auto-generated class names and will likely
    # change when the site is rebuilt - confirm before relying on it.
    DISCUSSION_SELECTOR = 'div.sc-kdNOGn.hNHnYp'

    # SentenceTransformer model used to embed descriptions.
    # NOTE(review): the E5 model card recommends "query: " / "passage: "
    # prefixes; none are added here - confirm against the retrieval side.
    EMBEDDING_MODEL_NAME = 'intfloat/multilingual-e5-large'

    def __init__(self, discussion_links_path: str = 'competitions.yml',
                 output_file: str = 'competitions_ideas.json',
                 ideas_metadata: str = "ideas_metadata.json",
                 database_path: str = 'ideas_vectorbase.faiss',
                 parse_ideas: bool = True, top_n_solutions: int = 5,
                 max_competitions: int = -1, sleep_time: float = 5):
        """Store the configuration; no I/O happens here.

        Args:
            discussion_links_path: YAML file with a top-level ``competitions`` list.
            output_file: JSON file the scraped competitions are saved to (and
                resumed from when it already exists).
            ideas_metadata: JSON file written alongside the FAISS index.
            database_path: Path the FAISS index is written to.
            parse_ideas: Whether to scrape the text of solution discussions.
            top_n_solutions: Number of leading solutions kept per competition.
            max_competitions: Maximum number of competitions to process; a
                negative value (default ``-1``) means no limit.
            sleep_time: Seconds to wait after each page load.
        """
        self.discussion_links_file = discussion_links_path
        self.output_file = output_file
        self.ideas_metadata = ideas_metadata
        self.database_path = database_path
        self.parse_ideas = parse_ideas
        self.top_n_solutions = top_n_solutions
        self.max_competitions = max_competitions
        self.sleep_time = sleep_time
        self.retrieve_model = None  # loaded lazily in create_database

    # ------------------------------------------------------------------
    # Entry point
    # ------------------------------------------------------------------
    def parse_create_database(self):
        """Run the full pipeline: scrape competitions, then build the index."""
        self.create_json_competitions()
        self.create_database()

    # ------------------------------------------------------------------
    # Scraping competitions into JSON
    # ------------------------------------------------------------------
    def create_json_competitions(self):
        """Scrape every configured competition and save the results as JSON.

        Results are saved after each competition so that an interrupted run
        can be resumed: entries that already have an overview (and, when
        ``parse_ideas`` is on, at least one solution text) are skipped.
        Competitions whose ``solutions`` entry is missing or ``None`` are
        skipped as well.
        """
        output_file = self.output_file
        parse_ideas = self.parse_ideas

        with open(self.discussion_links_file, 'r', encoding='utf-8') as f:
            # safe_load returns None for an empty document.
            loaded = yaml.safe_load(f) or {}
        competitions = loaded.get('competitions') or []

        if os.path.exists(output_file):
            with open(output_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        else:
            data = {}

        # Any negative limit means "all"; a plain slice already clamps to len().
        limit = self.max_competitions
        if limit is not None and limit >= 0:
            competitions = competitions[:limit]

        for competition in tqdm(competitions, desc="Processing competitions"):
            comp_link = competition['link']

            if self._is_already_parsed(data.get(comp_link), parse_ideas):
                continue

            solutions = competition.get('solutions')
            if solutions is None:
                continue

            solution_links = [self._check_links(s['link'])
                              for s in solutions[:self.top_n_solutions]]
            solution_texts = []
            if parse_ideas:
                solution_texts = [self.parse_discussion(link) for link in solution_links]

            overview_snippet, description_snippet, tags_snippet, _ = self.parse_description(comp_link)

            data[comp_link] = {
                'overview': overview_snippet,
                'description': description_snippet,
                'tags': tags_snippet,
                'solution_links': solution_links,
                'solution_texts': solution_texts,
                # Always stored as None here; nothing in this class fills it in.
                'solution_texts_ready': None,
            }

            # Autosave after processing each competition.
            self._save_json(data, output_file)

    @staticmethod
    def _is_already_parsed(entry, parse_ideas):
        """Return True when a cached entry needs no further scraping.

        An entry is complete when it has a non-empty overview and, if ideas
        are being parsed, a non-empty ``solution_texts`` list. When ideas are
        not being parsed the overview alone is enough, so previously scraped
        solution texts are not overwritten with an empty list.
        """
        if not entry or not entry.get('overview'):
            return False
        return (not parse_ideas) or bool(entry.get('solution_texts'))

    @staticmethod
    def _save_json(payload, path):
        """Write ``payload`` to ``path`` as UTF-8 JSON via a temp file + rename.

        Writing to a sibling file first keeps the previous autosave intact if
        the process dies in the middle of the dump.
        """
        tmp_path = f"{path}.tmp"
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(payload, f, ensure_ascii=False, indent=4)
        os.replace(tmp_path, path)

    # ------------------------------------------------------------------
    # Extracting sections from the page text
    # ------------------------------------------------------------------
    @staticmethod
    def _section_between(lines, is_header, is_terminator):
        """Return the lines after the first header up to the next terminator.

        The terminator is searched only after the header; when there is no
        terminator the section runs to the end. Returns ``[]`` when no line
        satisfies ``is_header``.
        """
        start = next((i for i, line in enumerate(lines) if is_header(line)), None)
        if start is None:
            return []
        end = next((i for i in range(start + 1, len(lines)) if is_terminator(lines[i])), None)
        return lines[start + 1:end]

    def _get_description(self, data):
        """Return the text between the 'description' line and the next 'evaluation' line.

        Lines that are exactly 'link' or 'keyboard_arrow_up' (case-insensitive,
        after stripping) are dropped. Returns '' when the section is absent or
        ``data`` is not a list of strings.
        """
        if not data:
            return ''
        skipped = ('link', 'keyboard_arrow_up')
        try:
            section = self._section_between(
                data,
                lambda line: self._normalize_names(line) == 'description',
                lambda line: 'evaluation' in self._normalize_names(line),
            )
            return "\n".join(line for line in section
                             if self._normalize_names(line) not in skipped)
        except (AttributeError, TypeError):
            # Best effort: malformed input yields an empty section.
            return ''

    def _get_tags(self, data):
        """Return the text between the 'tags' line and the next 'table of contents' line.

        Returns '' when the section is absent or ``data`` is not a list of strings.
        """
        if not data:
            return ''
        try:
            section = self._section_between(
                data,
                lambda line: self._normalize_names(line) == 'tags',
                lambda line: 'table of contents' in self._normalize_names(line),
            )
            return "\n".join(section)
        except (AttributeError, TypeError):
            # Best effort: malformed input yields an empty section.
            return ''

    @staticmethod
    def _get_overview(data):
        """Return the lines between the second "Overview" line and the first "Start" line.

        When fewer than two lines contain "Overview" the slice starts at the
        first line. When no line contains "Start" the result is ''. Input that
        is not a sequence of strings also yields ''.
        """
        try:
            overview_hits = [i for i, line in enumerate(data) if "Overview" in line]
            start = overview_hits[1] if len(overview_hits) > 1 else -1
            end = next((i for i, line in enumerate(data) if "Start" in line), 0)
            return '\n'.join(data[start + 1:end])
        except TypeError:
            return ""

    def parse_description(self, url_overview):
        """Load a competition page and extract overview, description and tags.

        Returns:
            A tuple ``(overview, description, tags, lines)`` where ``lines`` is
            the page body text split on newlines. On a WebDriver error while
            loading or reading the page the error is printed and the values
            collected so far (empty by default) are returned.
        """
        overview_snippet = ""
        description_snippet = ""
        tags_snippet = ""
        lines = []

        driver = self._create_driver()
        try:
            driver.get(url_overview)
            time.sleep(self.sleep_time)

            page_text = driver.find_element(By.TAG_NAME, "body").text
            lines = page_text.split('\n')

            overview_snippet = self._get_overview(lines)
            description_snippet = self._get_description(lines)
            tags_snippet = self._get_tags(lines)

        except WebDriverException as e:
            print(f"Error parsing {url_overview}: {e}")
        finally:
            driver.quit()

        return overview_snippet, description_snippet, tags_snippet, lines

    # ------------------------------------------------------------------
    # Discussions
    # ------------------------------------------------------------------
    def parse_discussion(self, url_discussion):
        """Return the stripped text of a discussion page, or '' on failure.

        The driver is created before the ``try`` so that a failure to start
        the browser surfaces as its own error instead of being masked by the
        cleanup code.
        """
        driver = self._create_driver()
        try:
            driver.get(url_discussion)
            time.sleep(self.sleep_time)
            div_container = driver.find_element(By.CSS_SELECTOR, self.DISCUSSION_SELECTOR)
            text = div_container.text.strip()
        except Exception as e:
            # Best effort per page; KeyboardInterrupt/SystemExit still propagate.
            print(f"Error parsing {url_discussion}: {e}")
            text = ""
        finally:
            driver.quit()

        return text

    def _check_links(self, url_discussion):
        """Return the URL the browser ends up on after loading ``url_discussion``.

        Falls back to the input URL when the page cannot be loaded.
        """
        driver = self._create_driver()
        try:
            driver.get(url_discussion)
            time.sleep(self.sleep_time)
            final_url = driver.current_url
        except Exception as e:
            # Best effort per link; KeyboardInterrupt/SystemExit still propagate.
            print(f"Error resolving {url_discussion}: {e}")
            final_url = url_discussion
        finally:
            driver.quit()

        return final_url

    # ------------------------------------------------------------------
    # Vector database
    # ------------------------------------------------------------------
    def create_database(self):
        """Embed the stored descriptions and write the FAISS index and metadata.

        Embeddings are L2-normalised before being added to an inner-product
        index. The metadata file is written after the index, so the two files
        are produced together.

        Raises:
            ValueError: If the JSON file contains no usable description.
        """
        # metadata_path=None: the metadata file is written once, below.
        descriptions, metadata = self._extract_descriptions(
            self.output_file, create_metadata=True, metadata_path=None)

        if not descriptions:
            raise ValueError(
                f"No descriptions found in {self.output_file}; nothing to index")

        if self.retrieve_model is None:
            self.retrieve_model = SentenceTransformer(self.EMBEDDING_MODEL_NAME)
        embeddings = self.retrieve_model.encode(
            descriptions,
            normalize_embeddings=False
        )
        embeddings = np.array(embeddings, dtype="float32")
        embeddings = self._normalize_vecs(embeddings)

        dim = embeddings.shape[1]
        index = faiss.IndexFlatIP(dim)
        index.add(embeddings)

        faiss.write_index(index, self.database_path)

        with open(self.ideas_metadata, "w", encoding="utf-8") as f:
            json.dump(metadata, f, ensure_ascii=False, indent=2)

    # ------------------------------------------------------------------
    # Utilities
    # ------------------------------------------------------------------
    @staticmethod
    def _normalize_names(s):
        """Return ``s`` stripped of surrounding whitespace and lower-cased."""
        return s.strip().lower()

    @staticmethod
    def _create_driver():
        """Create a headless Chrome WebDriver with GPU and sandbox disabled."""
        options = Options()
        options.add_argument("--headless")
        options.add_argument("--disable-gpu")
        options.add_argument("--no-sandbox")
        service = Service()
        driver = webdriver.Chrome(service=service, options=options)
        return driver

    @staticmethod
    def _normalize_vecs(vecs):
        """Scale each row of ``vecs`` to unit L2 norm.

        Rows with zero norm are returned unchanged (all zeros) instead of
        producing NaN/inf through a division by zero.
        """
        norms = np.linalg.norm(vecs, axis=1, keepdims=True)
        norms[norms == 0] = 1  # norms is a fresh array, safe to edit in place
        return vecs / norms

    @staticmethod
    def _extract_descriptions(json_path, create_metadata=True, metadata_path=None):
        """Collect one description per competition from the scraped JSON file.

        For each entry the stripped ``overview`` is used, falling back to the
        stripped ``description``; entries with neither are skipped.

        Args:
            json_path: Path of the JSON file to read.
            create_metadata: When True, also build a list of
                ``{"description", "link"}`` dicts parallel to the descriptions.
            metadata_path: When given (and ``create_metadata`` is True), the
                metadata list is also written to this path as JSON.

        Returns:
            ``(descriptions, metadata)`` when ``create_metadata`` is True,
            otherwise just ``descriptions``.
        """
        # The file is written as UTF-8 with ensure_ascii=False, so read it the same way.
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        descriptions = []
        metadata = []

        for url, content in data.items():
            # "or ''" tolerates fields stored as null.
            overview = (content.get('overview') or '').strip()
            description = (content.get('description') or '').strip()

            desc = overview if overview else description
            if not desc:
                continue

            descriptions.append(desc)
            if create_metadata:
                metadata.append({
                    "description": desc,
                    "link": url
                })

        if create_metadata:
            if metadata_path:
                with open(metadata_path, 'w', encoding='utf-8') as f:
                    json.dump(metadata, f, ensure_ascii=False, indent=2)
            return descriptions, metadata

        return descriptions
