from datetime import datetime
import git
import glob
import importlib
import json
import os
from pathlib import Path
import pickle
import random
import re
import string
import sys
import traceback
from types import SimpleNamespace, FunctionType
import warnings
from xml.sax.saxutils import escape as sx_escape


import torch

from . import BASE_PATH, ADAPTER_PATH


def num_parameters(module: torch.nn.Module, requires_grad: bool = None) -> int:
    """Count the scalar elements held in ``module.parameters()``.

    Args:
        module: Module whose parameters are counted.
        requires_grad: ``None`` counts every parameter. ``True``/``False``
            counts only parameters whose ``requires_grad`` flag equals it.

    Returns:
        Sum of ``numel()`` over the selected parameters (0 if none).
    """
    def _selected(param) -> bool:
        # A filter of None means "no filtering on the grad flag".
        return requires_grad is None or param.requires_grad == requires_grad

    return sum(param.numel() for param in module.parameters() if _selected(param))


def find_runs(path: os.PathLike, pattern: str) -> list[str]:
    """Find entries inside the immediate sub-directories of *path* ending in *pattern*.

    The search expression is ``<path>/*/*<pattern>``. ``pattern`` is therefore
    treated as a glob fragment: wildcards such as ``*`` or ``?`` inside it are
    honoured. Callers needing a literal match should apply ``glob.escape``.

    Args:
        path: Root directory to search (str or path-like).
        pattern: Name suffix / glob fragment to match.

    Returns:
        Matching paths as ``str`` (what ``glob.glob`` yields), sorted so the
        result order does not depend on the file system's listing order.
    """
    full_pattern = os.path.join(path, f'*/*{pattern}')
    # recursive=True only has an effect if `pattern` itself contains "**".
    return sorted(glob.glob(full_pattern, recursive=True))


class DualOutput:
    """Tee stream: every write goes both to the terminal and to a log file.

    ``sys.stdout`` is captured once, at construction time, as the terminal
    stream. This class does not install itself as ``sys.stdout``; assign it
    there yourself if ``print`` output should be mirrored.

    Args:
        filename: Path of the log file.
        mode: Mode passed to ``open`` for the log file (``'a'`` appends,
            ``'w'`` truncates).

    The log file stays open until :meth:`close` is called. Using the instance
    as a context manager closes it when the ``with`` block exits.
    """

    def __init__(self, filename, mode='a'):
        self.terminal = sys.stdout
        self.log = open(filename, mode)

    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        # Part of the file-object protocol: e.g. print(..., flush=True) calls
        # flush() on whatever object is acting as the output stream.
        self.terminal.flush()
        self.log.flush()

    def close(self):
        """Close the log file; the terminal stream is left open. Safe to call twice."""
        if not self.log.closed:
            self.log.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returning None lets any exception propagate.
        self.close()


class DualOutputContext:
    """Context manager and file-like object that tees stdout and stderr.

    While active (inside ``with``), both ``sys.stdout`` and ``sys.stderr`` point
    at this object. Every write is forwarded to ``terminal`` (when it is
    truthy) and to the log file, and then both are flushed.

    Args:
        terminal: Stream to echo output to (e.g. the original ``sys.stdout``),
            or ``None`` to only log.
        filename: Log file path, opened in ``"w"`` mode (truncated). When
            omitted, log output goes to ``os.devnull``.

    On exit the original streams are restored and the log file is closed. If
    the ``with`` block raised, the exception type, message and traceback
    frames are written to the log first; the exception still propagates.
    Re-entering the same instance reopens the log in append mode.
    """

    def __init__(self, terminal, filename=None):
        self.terminal = terminal
        self._filename = filename
        self.log = open(filename or os.devnull, "w")

    def __str__(self):
        return f"Terminal: {self.terminal}, Log: {self.log}"

    def write(self, message):
        if self.terminal:
            self.terminal.write(message)
        # A late writer (e.g. something that cached sys.stdout while the
        # context was active) may write after __exit__ closed the log.
        if not self.log.closed:
            self.log.write(message)
        self.flush()

    def flush(self):  # needed for Python 3 compatibility
        if self.terminal:
            self.terminal.flush()
        if not self.log.closed:
            self.log.flush()

    def __enter__(self):
        if self.log.closed:
            # Re-entry after a previous __exit__: append rather than truncate,
            # so output from the earlier session is kept.
            self.log = open(self._filename or os.devnull, "a")
        self.original_stdout = sys.stdout
        self.original_stderr = sys.stderr
        sys.stdout = self  # Redirect stdout to this object (terminal + file)
        sys.stderr = self
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            if exc_type is not None:
                # Only the log receives the details here. Returning None lets
                # the exception propagate to the caller.
                print(str(exc_type.__name__) + ": " + str(exc_val) + "\n", file=self.log)
                traceback.print_tb(exc_tb, file=self.log)
        finally:
            # Always restore the real streams and release the file handle.
            sys.stdout = self.original_stdout
            sys.stderr = self.original_stderr
            self.log.close()


def get_adapter_path(adapter_id: str) -> str:
    """Resolve *adapter_id* to a single on-disk path.

    Empty ids and ids that already name an existing path are returned as-is.
    Otherwise ``find_runs`` is queried under ``BASE_PATH / "checkpoints"`` and
    then under ``ADAPTER_PATH``. Exactly one match across both is required.

    Raises:
        ValueError: if no match, or more than one match, is found.
    """
    if not adapter_id or os.path.exists(adapter_id):
        return adapter_id

    candidates = [
        *find_runs(BASE_PATH / "checkpoints", adapter_id),
        *find_runs(ADAPTER_PATH, adapter_id),
    ]

    if not candidates:
        raise ValueError(f"Adapter not found: {adapter_id}")
    if len(candidates) > 1:
        raise ValueError(f"Multiple adapters found: {candidates}")

    (found,) = candidates
    print(f"Adapter found: {found}")
    return found


def remove_empty(lst: list) -> list:
    """Return a new list with only the truthy items of *lst*, order preserved."""
    return list(filter(None, lst))


def dict_to_simplenamespace(d: dict) -> SimpleNamespace:
    """Recursively convert dicts (and lists containing them) to ``SimpleNamespace``.

    Dicts become ``SimpleNamespace`` objects with the same keys. Lists become
    new lists with each item converted. Any other value is returned unchanged.
    The input is not modified.

    Raises:
        TypeError: if a dict has a non-string key (``SimpleNamespace`` needs
            keyword names).
    """
    if isinstance(d, dict):
        # Build a fresh mapping rather than assigning into `d`, so the
        # caller's nested dicts are not replaced with namespaces.
        converted = {key: dict_to_simplenamespace(value) for key, value in d.items()}
        return SimpleNamespace(**converted)
    if isinstance(d, list):
        return [dict_to_simplenamespace(item) for item in d]
    return d


def random_id(length: int):
    """Return a string of *length* characters drawn from ASCII letters and digits.

    Uses the ``random`` module, so the output is reproducible under
    ``random.seed`` and is not meant for security-sensitive tokens.
    """
    pool = string.ascii_letters + string.digits
    picks = random.choices(pool, k=length)
    return "".join(picks)


def xml_encoder(input_string):
    """Replace ASCII control characters with ``__xx__`` hex escapes.

    Every character in U+0000-U+0009, U+000B-U+001F and U+007F (all C0
    controls except newline, plus DEL) becomes ``__`` + two lowercase hex
    digits + ``__``. Any pre-existing ``__`` is first rewritten to
    ``__underscore__``, so literal double underscores can be told apart
    from the escape codes when decoding.
    """
    hostile_chars = re.compile(r'[\x00-\x09\x0b-\x1F\x7F]')
    underscores_escaped = input_string.replace("__", "__underscore__")
    return hostile_chars.sub(lambda hit: f"__{ord(hit.group()):02x}__", underscores_escaped)


def xml_decoder(encoded_string):
    """Invert ``xml_encoder``.

    Recognised escapes:
      * ``__underscore__`` -> ``__``
      * ``__xx__`` with two hex digits -> ``chr(0xXX)``

    Both escapes are matched in ONE left-to-right pass. A two-pass decode
    (hex codes first, underscores second) is ambiguous. For example, in the
    encoding of ``"__eq__"`` (``"__underscore__eq__underscore__"``) the tail
    of one escape, plus ``eq``, plus the head of the next looks like a code.
    That either corrupts the text or raises ``ValueError`` on non-hex
    digits. A single pass consumes each escape whole, so adjacent escapes
    cannot be misread. Sequences matching neither escape are left unchanged.
    """
    def replace_escape(match):
        hex_code = match.group(1)
        if hex_code is None:
            # The "__underscore__" alternative matched.
            return "__"
        return chr(int(hex_code, 16))

    escape_pattern = re.compile(r'__(?:underscore|([0-9a-fA-F]{2}))__')
    return escape_pattern.sub(replace_escape, encoded_string)


def escape(txt: str) -> str:
    """Make *txt* safe for XML text content.

    Control characters are first replaced by ``xml_encoder`` escapes. Then
    ``&``, ``<`` and ``>`` are entity-escaped by ``xml.sax.saxutils.escape``.
    Quotes are not escaped.
    """
    control_safe = xml_encoder(txt)
    return sx_escape(control_safe)
