import os
import logging
import subprocess
import urllib.parse
import urllib.request

import logging
import threading
import time
from typing import Callable, Optional, List
from torch.utils.tensorboard import SummaryWriter

import numpy as np


class dotdict(dict):
    """``dict`` subclass whose keys can also be read as attributes (``d.key``).

    Only reading is attribute-style; assignment still goes through
    ``d[key] = value``. A missing key raises ``AttributeError`` (not
    ``KeyError``) so ``hasattr``, ``getattr(obj, name, default)``, ``copy`` and
    ``pickle`` -- which probe optional attributes such as ``__deepcopy__`` --
    behave as they do for ordinary objects.
    """

    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError:
            # Attribute protocol requires AttributeError for missing names.
            raise AttributeError(name) from None


# ----------------------------
# Download / weights utilities
# ----------------------------

def models_dir():
    """Locate the repository-level ``models/`` directory, creating it on demand.

    The directory is the ``models`` folder one level above the directory that
    contains this module. Its absolute path is returned as a ``str``.
    """
    module_dir = os.path.dirname(os.path.abspath(__file__))
    target = os.path.normpath(os.path.join(module_dir, os.pardir, 'models'))
    os.makedirs(target, exist_ok=True)
    return target


def is_url(s):
    """Return True when ``s`` is an http(s) URL that names a host.

    Anything ``urllib.parse.urlparse`` cannot parse counts as "not a URL".
    """
    try:
        parts = urllib.parse.urlparse(s)
    except Exception:
        return False
    if parts.scheme not in ('http', 'https'):
        return False
    return bool(parts.netloc)


def is_scp_like(s):
    """Return True if ``s`` looks like an scp remote (``host:path`` / ``user@host:path``).

    Follows scp's own rule for telling remotes from local names. A ``:`` marks a
    remote only when it comes before any ``/`` and is not the first character.

    These are never treated as remote:

    * existing local paths;
    * explicitly relative paths (``./``, ``../``);
    * Windows drive-letter paths (``C:\\...``, only recognised on Windows).
    """
    if ':' not in s or os.path.exists(s):
        return False
    if s.startswith(('./', '../')):
        return False
    if os.path.splitdrive(s)[0]:
        # A drive prefix such as 'C:' (Windows only; always '' on POSIX).
        return False
    host = s.split(':', 1)[0]
    # '/a:b' or 'dir/a:b' is a local file name containing ':'; ':x' likewise.
    return bool(host) and '/' not in host


def scp_download(remote, dest_dir):
    """Download ``remote`` (``[user@]host:path``) via ``scp`` into ``dest_dir``.

    The file is saved under the basename of the remote path, or
    ``'weights.pt'`` if that basename is empty. The copy goes to a temporary
    ``<name>.part`` file that is renamed into place only after ``scp``
    succeeds. A failed or interrupted transfer therefore never leaves a
    truncated file under the final name, where later lookups would mistake
    it for a cached download.

    Returns the local path on success. Returns ``None`` if ``remote`` has no
    ``:`` or if the transfer fails (non-zero exit, ``scp`` missing, ...).
    """
    try:
        _host, rpath = remote.split(':', 1)
    except ValueError:
        return None
    filename = os.path.basename(rpath.rstrip('/')) or 'weights.pt'
    dest = os.path.join(dest_dir, filename)
    tmp = dest + '.part'
    logging.info("Downloading via scp from %s to %s", remote, dest)
    try:
        # '--' ends option parsing so a remote starting with '-' cannot be
        # interpreted by scp as a command-line option.
        subprocess.run(
            ['scp', '-q', '-o', 'ConnectTimeout=10', '--', remote, tmp],
            check=True,
        )
        os.replace(tmp, dest)
    except Exception as e:
        logging.error("scp failed for %s: %s", remote, e)
        return None
    finally:
        # Best-effort removal of a partial download; after a successful
        # rename the temporary file no longer exists and this is a no-op.
        try:
            os.remove(tmp)
        except OSError:
            pass
    return dest


def http_download(url, dest_dir):
    """Download ``url`` into ``dest_dir``.

    The file is named after the basename of the URL's path component, or
    ``'weights.pt'`` if that is empty. Data is written to ``<name>.part`` and
    renamed into place only once the download completes. A failed or
    interrupted download therefore never leaves a truncated file under the
    final name, where later lookups would mistake it for a cached copy.

    Returns the local path on success, or ``None`` on any failure.
    """
    filename = os.path.basename(urllib.parse.urlparse(url).path) or 'weights.pt'
    dest = os.path.join(dest_dir, filename)
    tmp = dest + '.part'
    logging.info("Downloading from URL %s to %s", url, dest)
    try:
        urllib.request.urlretrieve(url, tmp)
        os.replace(tmp, dest)
    except Exception as e:
        logging.error("HTTP download failed for %s: %s", url, e)
        return None
    finally:
        # Best-effort removal of a partial download; after a successful
        # rename the temporary file no longer exists and this is a no-op.
        try:
            os.remove(tmp)
        except OSError:
            pass
    return dest


def ensure_local_weights(location):
    """Resolve a weights location to a local path, downloading it if needed.

    Resolution order:

    1. ``None``/empty -> returned unchanged.
    2. An existing local path -> returned unchanged.
    3. An entry with the same basename already in ``models/`` -> that path.
       The basename is derived the same way the downloaders name their
       output. For http(s) URLs it is the URL *path* component, ignoring any
       query string or fragment. For scp-style remotes it is the part after
       the first ``:``. Otherwise it is the plain basename.
       An empty basename never matches.
    4. An http(s) URL or scp-style remote -> downloaded into ``models/``.

    If the download fails, or the location matches none of the above, the
    original ``location`` is returned so the caller can handle the error.
    """
    if not location:
        return location
    if os.path.exists(location):
        return location

    cache_dir = models_dir()
    url = is_url(location)
    remote = not url and is_scp_like(location)
    if url:
        name = os.path.basename(urllib.parse.urlparse(location).path)
    elif remote:
        name = os.path.basename(location.split(':', 1)[1].rstrip('/'))
    else:
        name = os.path.basename(location)

    # An empty name (e.g. a location ending in '/') would make the candidate
    # the models/ directory itself, which must never be returned as weights.
    if name:
        cand = os.path.join(cache_dir, name)
        if os.path.exists(cand):
            return cand

    if url:
        return http_download(location, cache_dir) or location
    if remote:
        return scp_download(location, cache_dir) or location
    return location

## logging with tensorboard



class TensorBoardHandler(logging.Handler):
    """Logging handler that writes every formatted record to TensorBoard.

    Each record becomes one ``add_text`` entry under the ``"Logs"`` tag. The
    step used is an internal counter that advances by one per successfully
    written record; it is not tied to any training step.
    """

    def __init__(self, log_dir="runs/logs"):
        super().__init__()
        self.writer = SummaryWriter(log_dir)
        self.step = 0  # per-record counter passed as the TensorBoard step

    def emit(self, record):
        # Per the logging.Handler contract, failures are reported through
        # handleError() instead of propagating into the code that logged.
        try:
            msg = self.format(record)
            self.writer.add_text("Logs", msg, self.step)
        except RecursionError:
            raise
        except Exception:
            self.handleError(record)
            return
        self.step += 1

    def close(self):
        """Close the underlying writer, then always close the handler itself."""
        try:
            self.writer.close()
        finally:
            super().close()

class BatchingTensorBoardHandler(logging.Handler):
    """Logging handler that buffers records and writes them to TensorBoard in batches.

    Formatted records are collected in memory. They are written as a single
    ``add_text`` entry, wrapped in a fenced code block, when any of these
    happens:

    * the buffer reaches ``max_buffer_lines`` records or ``max_buffer_chars``
      characters (checked on every ``emit``);
    * ``flush_interval_s`` seconds pass (background daemon thread);
    * ``flush()`` or ``close()`` is called.

    Args:
        log_dir: directory passed to ``SummaryWriter``.
        tag: TensorBoard text tag used for every batch.
        flush_interval_s: period of the background flush, in seconds.
        max_buffer_lines: flush once this many records are buffered.
        max_buffer_chars: flush once the buffered text reaches this many
            characters (each record counts ``len(msg) + 1``). Unlimited by
            default.
        step_fn: optional callable returning the ``global_step`` for a batch.
            When omitted, an internal counter incremented per batch is used.
        include_timestamp: whether the default formatter includes ``asctime``.
    """

    def __init__(
        self,
        log_dir: str = "runs/logs",
        tag: str = "Logs",
        flush_interval_s: float = 3600.0,
        max_buffer_lines: int = 200,
        max_buffer_chars: float = float("inf"),
        step_fn: Optional[Callable[[], int]] = None,  # returns global_step
        include_timestamp: bool = True,
    ):
        super().__init__()
        self.writer = SummaryWriter(log_dir)
        self.tag = tag
        self.flush_interval_s = flush_interval_s
        self.max_buffer_lines = max_buffer_lines
        self.max_buffer_chars = max_buffer_chars
        self.step_fn = step_fn
        self.include_timestamp = include_timestamp

        self._buf: List[str] = []
        self._buf_chars = 0
        # Guards _buf, _buf_chars and _fallback_step.
        self._lock = threading.Lock()
        self._stop = threading.Event()
        # Step counter used when no step_fn is provided.
        self._fallback_step = 0

        # Reasonable default formatter if none is set.
        if self.formatter is None:
            self.setFormatter(logging.Formatter(
                fmt="%(asctime)s | %(levelname)s | %(name)s: %(message)s" if include_timestamp
                    else "%(levelname)s | %(name)s: %(message)s",
                datefmt="%Y-%m-%d %H:%M:%S",
            ))

        # Start the background flusher last, so it never observes a
        # partially initialised handler.
        self._thread = threading.Thread(target=self._run, name="TBLogFlusher", daemon=True)
        self._thread.start()

    def emit(self, record: logging.LogRecord) -> None:
        # Formatting and early-flush failures are both reported through
        # handleError() rather than propagating into the code that logged.
        try:
            msg = self.format(record)
            with self._lock:
                self._buf.append(msg)
                self._buf_chars += len(msg) + 1  # +1 for the joining newline
                # Flush early once the buffer is large.
                if (len(self._buf) >= self.max_buffer_lines
                        or self._buf_chars >= self.max_buffer_chars):
                    self._flush_locked()
        except RecursionError:
            raise
        except Exception:
            self.handleError(record)

    def _run(self):
        """Background loop: flush the buffer every ``flush_interval_s`` seconds until stopped."""
        # monotonic() is immune to wall-clock adjustments.
        next_tick = time.monotonic() + self.flush_interval_s
        while not self._stop.is_set():
            timeout = max(0.0, next_tick - time.monotonic())
            if self._stop.wait(timeout):
                break
            try:
                with self._lock:
                    if self._buf:
                        self._flush_locked()
            except Exception:
                # Keep the flusher alive. Report the failure the way logging
                # reports handler errors instead of letting the thread die.
                self.handleError(logging.makeLogRecord(
                    {"msg": "periodic TensorBoard log flush failed"}))
            next_tick = time.monotonic() + self.flush_interval_s

    def _flush_locked(self):
        """Write the buffered lines as one TensorBoard text entry and reset the buffer.

        Assumes the caller holds ``self._lock``. The buffer is emptied before
        writing, so a batch whose write fails (writer or ``step_fn`` error)
        is dropped. It is not retried on every later record, which would grow
        the buffer without bound while the writer keeps failing. The
        exception still propagates to the caller.
        """
        if not self._buf:
            return

        lines = self._buf
        self._buf = []
        self._buf_chars = 0

        # Join as a fenced code block to keep formatting readable in TensorBoard.
        payload = "```\n" + "\n".join(lines) + "\n```"
        step = self.step_fn() if self.step_fn is not None else self._fallback_step
        self.writer.add_text(self.tag, payload, global_step=step)
        self.writer.flush()
        if self.step_fn is None:
            self._fallback_step += 1

    def flush(self):
        """Write any buffered records to TensorBoard immediately."""
        with self._lock:
            self._flush_locked()
        super().flush()

    def close(self):
        """Stop the flusher thread, write remaining records, and close the writer."""
        try:
            self._stop.set()
            if self._thread.is_alive():
                self._thread.join(timeout=2.0)
            try:
                self.flush()
            finally:
                # Release the writer even if the final flush failed.
                self.writer.close()
        finally:
            super().close()

if __name__ == "__main__":
    # --- Demo: send the same records to TensorBoard and to the console ---
    demo_logger = logging.getLogger("mylogger")
    demo_logger.setLevel(logging.INFO)

    # TensorBoard text output (uses the handler's default log_dir).
    demo_logger.addHandler(TensorBoardHandler())

    # Plain console output alongside it.
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(logging.Formatter("[%(levelname)s] %(message)s"))
    demo_logger.addHandler(stream_handler)

    # Every record below reaches both handlers.
    for demo_level, demo_text in (
        (logging.INFO, "Training started"),
        (logging.WARNING, "Watch out for overfitting"),
        (logging.ERROR, "Something went wrong"),
    ):
        demo_logger.log(demo_level, demo_text)