import os
import random
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from urllib.parse import urljoin, urlparse

import pandas as pd
import requests
from bs4 import BeautifulSoup

from make_url_dataset.start_urls import start_urls_100


def get_domain(url):
    """Return the network-location component (``netloc``) of *url*.

    For ``https://example.com:8080/x`` this is ``example.com:8080``; for a
    URL with no authority part (e.g. a relative path) it is ``""``.
    """
    parsed = urlparse(url)
    return parsed.netloc


def crawl_with_diversity(
    start_urls,
    max_pages=1000000,
    max_per_domain=50,
    depth=3,
    delay=0.5,
    exclude_word=None,
    max_length=77,
    output_file="urls.csv",
    save_interval=1000,
    max_workers=10,
):
    """Crawl outward from *start_urls*, collecting a domain-diverse set of URLs.

    Each seed URL is crawled depth-first on a worker thread. The outgoing
    http(s) links of every fetched page are shuffled and admitted subject to
    a per-domain quota. Admitted links are crawled recursively up to *depth*.
    Admitted links no longer than *max_length* characters are collected once
    each, in admission order, and appended to *output_file* in batches.

    Args:
        start_urls: Iterable of seed URLs (crawled at depth 0).
        max_pages: Stop collecting, and stop fetching new pages, once this
            many URLs have been collected.
        max_per_domain: Maximum number of links admitted per domain (netloc)
            over the whole crawl.
        depth: Maximum link depth to crawl; seeds are depth 0.
        delay: Seconds a worker sleeps after each successfully fetched page.
        exclude_word: If given, links containing this substring are ignored.
        max_length: Admitted URLs longer than this are crawled but not
            collected.
        output_file: CSV file that collected URLs are appended to. Missing
            parent directories are created. A ``URL`` header is written only
            if the file does not exist yet (or is empty).
        save_interval: Flush to *output_file* whenever at least this many
            collected URLs are waiting to be written.
        max_workers: Number of worker threads.

    Returns:
        A ``pandas.DataFrame`` with a single ``URL`` column holding every URL
        collected in this run, in collection order.
    """
    lock = threading.Lock()  # guards all shared crawl state below
    visited = set()  # pages already fetched (or claimed for fetching)
    admitted = set()  # links already admitted under the domain quota
    domain_counts = {}  # netloc -> number of links admitted so far
    collected_urls = []  # collected URLs; the first `saved_rows` are on disk
    saved_rows = 0
    scraped_count = 0
    total_collected_count = 0

    def flush(force=False):
        """Append not-yet-saved URLs to output_file. Caller must hold ``lock``."""
        nonlocal saved_rows
        pending = collected_urls[saved_rows:]
        if not force and len(pending) < save_interval:
            return
        directory = os.path.dirname(output_file)
        if directory:
            os.makedirs(directory, exist_ok=True)
        # Only the very first write to a fresh file carries the header; every
        # later batch is appended as plain rows so the CSV stays well-formed.
        write_header = (
            not os.path.exists(output_file) or os.path.getsize(output_file) == 0
        )
        pd.DataFrame({"URL": pending}).to_csv(
            output_file, mode="a", header=write_header, index=False, encoding="utf-8"
        )
        saved_rows += len(pending)
        print(f"Data saved to {output_file}. Total saved: {saved_rows}")

    def extract_links(page_url, html):
        """Return the distinct absolute http(s) links in *html*, shuffled."""
        soup = BeautifulSoup(html, "html.parser")
        links = set()
        for a_tag in soup.find_all("a", href=True):
            # Resolve relative hrefs first; testing the scheme of the raw href
            # would drop every relative link.
            link = urljoin(page_url, a_tag["href"])
            if urlparse(link).scheme in ("http", "https"):
                links.add(link)
        if exclude_word:
            links = {link for link in links if exclude_word not in link}
        links = list(links)
        random.shuffle(links)
        return links

    def admit(links):
        """Apply the domain quota and record collectable links.

        Returns the newly admitted links (to crawl next). Caller must hold
        ``lock``.
        """
        nonlocal total_collected_count
        new_urls = []
        for link in links:
            if link in admitted:
                continue  # already admitted from another page; don't spend quota twice
            domain = get_domain(link)
            if domain_counts.get(domain, 0) >= max_per_domain:
                continue
            domain_counts[domain] = domain_counts.get(domain, 0) + 1
            admitted.add(link)
            new_urls.append(link)
            if len(link) <= max_length and len(collected_urls) < max_pages:
                collected_urls.append(link)
        total_collected_count += len(new_urls)
        return new_urls

    def scrape(url, current_depth):
        """Fetch *url*, admit its links, then recurse into them depth-first."""
        nonlocal scraped_count
        with lock:
            # Check-and-claim atomically so two threads never fetch one page.
            if (
                current_depth > depth
                or url in visited
                or len(collected_urls) >= max_pages
            ):
                return
            visited.add(url)
            scraped_count += 1

        try:
            response = requests.get(
                url,
                headers={
                    "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
                },
                timeout=10,
            )
            response.raise_for_status()  # don't harvest links from error pages
            links = extract_links(url, response.text)
        except Exception as e:  # best-effort crawl: log the page and move on
            print(f"Failed to crawl {url}: {e}")
            return

        with lock:
            new_urls = admit(links)
            flush()
            reached_max = len(collected_urls) >= max_pages
            print(
                f"Scraped pages: {scraped_count}, Collected URLs: {len(collected_urls)}, Total Collected Count: {total_collected_count}"
            )

        if reached_max:
            print(f"Reached max pages: {max_pages}")
            return

        time.sleep(delay)  # politeness delay between this worker's requests
        for link in new_urls:
            scrape(link, current_depth + 1)

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(scrape, url, 0) for url in start_urls]
        for future in as_completed(futures):
            try:
                future.result()
            except Exception as e:
                print(f"Error in thread: {e}")

    with lock:
        flush(force=True)  # write whatever is left from the last partial batch
        return pd.DataFrame({"URL": list(collected_urls)})


# --------------------------------------------
# Script driver: crawl from the bundled seed list and report the outcome.
# NOTE(review): these statements execute on import as well as when the file is
# run directly; consider an `if __name__ == "__main__":` guard if this module
# is ever imported elsewhere.
output_file = "./data/urls.csv"

collected_df = crawl_with_diversity(start_urls=start_urls_100, output_file=output_file)

url_count = len(collected_df)
print(f"Final collected URLs saved to {output_file}. Total URLs: {url_count}")
