from concurrent.futures import ThreadPoolExecutor, as_completed
from bs4 import BeautifulSoup
from urllib.parse import urljoin, urlparse
import requests
import random
import time
import os
from make_url_dataset.start_urls import start_urls_100, start_urls_168
import pandas as pd
import csv


def get_domain(url):
    """Return the ``netloc`` component of *url*, as parsed by ``urlparse``.

    For a relative URL (no scheme/authority) this is the empty string.
    """
    parsed = urlparse(url)
    return parsed.netloc


def crawl_with_diversity(
    start_urls,
    max_pages=1000,
    max_per_domain=50,
    depth=3,
    delay=0.5,
    exclude_word=None,
    max_length=77,
    output_file="urls.csv",
    save_interval=100,
    max_workers=10,
):
    """Crawl outward from *start_urls* and collect a domain-diverse set of URLs.

    Each start URL is crawled recursively (depth-first) inside a thread pool.
    Every http(s) link found on a fetched page, resolved against the page URL,
    is a candidate. A candidate is recorded when all of these hold:

    - it is not already recorded;
    - it is at most *max_length* characters long;
    - it does not contain *exclude_word*;
    - its domain has fewer than *max_per_domain* recorded URLs.

    Links whose domain is still under the cap are crawled in turn. Crawling
    stops once *depth* is exceeded or *max_pages* URLs have been recorded.

    Recorded URLs are appended to *output_file* as a one-column ("URL") CSV.
    A save happens every *save_interval* new URLs and once more at the end.
    Only URLs not yet written are appended. A header row is written only when
    the file is missing or empty.

    Args:
        start_urls: Iterable of seed URLs (crawled at depth 0).
        max_pages: Stop once this many URLs have been recorded.
        max_per_domain: Maximum number of recorded URLs per domain (netloc).
        depth: Maximum link depth below a seed URL.
        delay: Seconds to sleep after each successful page fetch.
        exclude_word: If truthy, links containing this substring are ignored.
        max_length: Maximum length (in characters) of a recorded URL.
        output_file: CSV path; missing parent directories are created.
        save_interval: Number of newly recorded URLs between incremental saves.
        max_workers: Number of crawler threads.

    Returns:
        pandas.DataFrame with a single "URL" column holding the recorded URLs
        in discovery order.
    """
    import threading  # local import: shared crawl state is guarded by a lock

    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
    }

    # Shared state touched by every worker thread; always accessed under `lock`.
    lock = threading.Lock()
    visited = set()  # pages already fetched (or claimed for fetching)
    domain_counts = {}  # domain -> number of recorded URLs
    collected = []  # recorded URLs, in discovery order
    recorded = set()  # same contents as `collected`, for O(1) membership
    saved_count = 0  # how many entries of `collected` are already on disk
    scraped_count = 0

    out_dir = os.path.dirname(output_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    def flush():
        """Append not-yet-saved URLs to output_file. Caller must hold `lock`."""
        nonlocal saved_count
        pending = collected[saved_count:]
        write_header = (
            not os.path.exists(output_file) or os.path.getsize(output_file) == 0
        )
        if not pending and not write_header:
            return
        pd.DataFrame({"URL": pending}).to_csv(
            output_file, mode="a", header=write_header, index=False, encoding="utf-8"
        )
        saved_count = len(collected)
        print(f"Data saved to {output_file}. Total saved: {saved_count}")

    def extract_links(base_url, soup):
        """Return the http(s) links on *soup*, resolved against *base_url*."""
        links = set()
        for a_tag in soup.find_all("a", href=True):
            try:
                # Resolve first, then check the scheme. Checking the raw href
                # would drop every relative link ("/about", "page.html").
                link = urljoin(base_url, a_tag["href"])
                scheme = urlparse(link).scheme
            except ValueError:  # e.g. malformed IPv6 netloc in an href
                continue
            if scheme in ("http", "https"):
                if not exclude_word or exclude_word not in link:
                    links.add(link)
        return list(links)

    def scrape(url, current_depth):
        """Fetch *url*, record its eligible links, and recurse into them."""
        nonlocal scraped_count

        with lock:
            if current_depth > depth or url in visited or len(collected) >= max_pages:
                return
            visited.add(url)
            scraped_count += 1

        try:
            response = requests.get(url, headers=headers, timeout=10)
            response.raise_for_status()  # don't harvest links from error pages
            soup = BeautifulSoup(response.text, "html.parser")
        except Exception as e:  # best-effort crawl: log and skip this page
            print(f"Failed to crawl {url}: {e}")
            return

        # Space out requests before following this page's links.
        time.sleep(delay)

        page_links = extract_links(url, soup)
        random.shuffle(page_links)

        for link in page_links:
            domain = get_domain(link)
            with lock:
                if len(collected) >= max_pages:
                    return
                if domain_counts.get(domain, 0) >= max_per_domain:
                    continue
                # Dedupe + length filter at insertion time so max_pages and the
                # per-domain cap count only URLs that are actually kept.
                if len(link) <= max_length and link not in recorded:
                    recorded.add(link)
                    collected.append(link)
                    domain_counts[domain] = domain_counts.get(domain, 0) + 1
                    if len(collected) - saved_count >= save_interval:
                        flush()
                    if len(collected) >= max_pages:
                        print(f"Reached max pages: {max_pages}")
                        return
            # Recurse without holding the lock so other threads keep working.
            scrape(link, current_depth + 1)

        print(f"Scraped pages: {scraped_count}, Collected URLs: {len(collected)}")

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(scrape, url, 0) for url in start_urls]
        for future in as_completed(futures):
            try:
                future.result()
            except Exception as e:
                print(f"Error in thread: {e}")

    with lock:
        flush()
        return pd.DataFrame({"URL": list(collected)})


# --------------------------------------------
# Run the crawl only when executed as a script, not as a side effect of import.
if __name__ == "__main__":
    output_file = "./data/urls.csv"

    collected_df = crawl_with_diversity(
        start_urls_100,
        output_file=output_file,
    )

    print(f"Final collected URLs saved to {output_file}. Total URLs: {len(collected_df)}")
