import json
import types
from dataclasses import MISSING, fields, is_dataclass
from typing import Any, Dict, List, Union, get_args, get_origin

import yaml
from transformers.hf_argparser import HfArgumentParser


class CustomHFArgumentParser(HfArgumentParser):
    """``HfArgumentParser`` that tolerates dict-typed fields and sectioned config files.

    Differences from the base parser:

    * After parsing, string values of fields whose name ends in ``_kwargs`` or
      whose annotation is a dict type (``dict``, ``Dict[...]``, optionally inside
      ``Optional``/``Union``) are decoded from JSON when they decode to a JSON
      object.
    * JSON/YAML config files may group options into nested sections
      (e.g. ``{"model": {...}, "training": {...}}``); sections are merged into a
      single flat mapping before parsing.
    * Fields for which the base ``_parse_dataclass_field`` raises the
      "Subscripted generics cannot be used with class and instance checks"
      ``TypeError`` are registered as plain string options instead.
    """

    # ``X | Y`` unions (Python 3.10+) report ``types.UnionType`` as their origin,
    # not ``typing.Union``; on older Pythons the getattr falls back to ``Union``.
    _UNION_ORIGINS = (Union, getattr(types, "UnionType", Union))

    @classmethod
    def _is_dict_annotation(cls, annotation) -> bool:
        """Return True if *annotation* is ``dict``/``Dict[...]``, possibly inside a Union."""
        if annotation is dict:  # bare ``dict``: get_origin(dict) is None
            return True
        origin = get_origin(annotation)
        if origin is dict:
            return True
        if origin in cls._UNION_ORIGINS:
            return any(cls._is_dict_annotation(arg) for arg in get_args(annotation))
        return False

    def _postprocess_dict_fields(self, dataclass_obj):
        """Decode JSON-object strings on dict-like fields of *dataclass_obj*, in place.

        A field is dict-like when its name ends with ``_kwargs`` or its annotation
        is a dict type (see ``_is_dict_annotation``). Its value is replaced only if
        it is a string that parses as a JSON *object*; invalid JSON or JSON that is
        not an object leaves the original string untouched.

        Returns:
            *dataclass_obj* itself. Anything that is not a dataclass instance is
            returned unchanged.
        """
        if not is_dataclass(dataclass_obj) or isinstance(dataclass_obj, type):
            return dataclass_obj

        for f in fields(dataclass_obj):
            val = getattr(dataclass_obj, f.name, None)
            if not isinstance(val, str):
                continue
            if not (f.name.endswith("_kwargs") or self._is_dict_annotation(f.type)):
                continue
            try:
                parsed_val = json.loads(val)
            except ValueError:  # json.JSONDecodeError is a ValueError subclass
                continue  # keep the raw string if it isn't valid JSON
            if isinstance(parsed_val, dict):
                setattr(dataclass_obj, f.name, parsed_val)

        return dataclass_obj

    def parse_args_into_dataclasses(self, args=None, return_remaining_strings=False, **kwargs):
        """Parse command-line arguments into dataclasses, then decode dict-like fields.

        ``args``, ``return_remaining_strings`` and any further keyword arguments
        (e.g. ``look_for_args_file``) are forwarded to the base implementation.

        Returns:
            A tuple of dataclass instances. When ``return_remaining_strings`` is
            true, the list of unparsed strings is appended as the last element of
            that flat tuple, which is the shape the base class returns.
        """
        if isinstance(args, bool):
            # This override previously took ``return_remaining_strings`` as its
            # first positional parameter; keep that call form working.
            args, return_remaining_strings = None, args

        result = super().parse_args_into_dataclasses(
            args=args, return_remaining_strings=return_remaining_strings, **kwargs
        )
        if return_remaining_strings:
            *outputs, remaining = result
            return (*(self._postprocess_dict_fields(o) for o in outputs), remaining)
        return tuple(self._postprocess_dict_fields(o) for o in result)

    def _flatten_config_sections(self, data):
        """Merge nested config sections into one flat mapping for ``parse_dict``.

        A top-level key whose value is a dict and which is not the name of a field
        of any registered dataclass is treated as a section, and its items are
        merged into the result. Every other top-level key is kept as-is, including
        scalars, lists and the dict values of dict-typed fields. If a key occurs
        more than once, the later occurrence wins. An empty document (``None``)
        yields ``{}``.

        Raises:
            ValueError: if the document's top level is not a mapping.
        """
        if data is None:
            return {}
        if not isinstance(data, dict):
            raise ValueError(
                f"Expected a mapping at the top level of the config file, got {type(data).__name__}."
            )
        field_names = {f.name for dtype in self.dataclass_types for f in fields(dtype)}
        flat_data = {}
        for key, value in data.items():
            if isinstance(value, dict) and key not in field_names:
                flat_data.update(value)  # a config section
            else:
                flat_data[key] = value
        return flat_data

    def parse_json_file(self, json_file: str, allow_extra_keys: bool = False):
        """Parse a (possibly sectioned) JSON config file into dataclass instances."""
        with open(json_file, "r", encoding="utf-8") as f:
            json_data = json.load(f)
        return self.parse_dict(
            self._flatten_config_sections(json_data), allow_extra_keys=allow_extra_keys
        )

    def parse_yaml_file(self, yaml_file: str, allow_extra_keys: bool = False):
        """Parse a (possibly sectioned) YAML config file into dataclass instances."""
        with open(yaml_file, "r", encoding="utf-8") as f:
            yaml_data = yaml.safe_load(f)
        return self.parse_dict(
            self._flatten_config_sections(yaml_data), allow_extra_keys=allow_extra_keys
        )

    def parse_dict(self, data: Dict[str, Any], allow_extra_keys: bool = False):
        """Build dataclass instances from *data*, then decode dict-like fields.

        ``allow_extra_keys`` is forwarded to the base ``parse_dict`` only when it is
        true, so calls that leave it at the default behave exactly as before.
        """
        extra = {"allow_extra_keys": True} if allow_extra_keys else {}
        args = super().parse_dict(data, **extra)
        return tuple(self._postprocess_dict_fields(a) for a in args)

    def _parse_dataclass_field(self, parser, field):
        """Register *field* on *parser*, falling back to a ``str`` option on typing errors.

        The base implementation is tried first. If it raises the "Subscripted
        generics cannot be used with class and instance checks" ``TypeError``, the
        field is registered as ``--<field name>`` with ``type=str``, unless that
        option already exists. The fallback keeps the dataclass default (or the
        ``default_factory()`` value), so an omitted option does not become ``None``.
        When the field has no default, the option is marked required. Any other
        ``TypeError`` is re-raised.
        """
        try:
            return super()._parse_dataclass_field(parser, field)
        except TypeError as e:
            if "Subscripted generics cannot be used with class and instance checks" not in str(e):
                raise

        arg_name = f"--{field.name}"
        if any(arg_name in action.option_strings for action in parser._actions):
            return None

        add_kwargs = {
            "type": str,
            "help": f"[Patched] Field {field.name} could not be parsed as its annotated type due to Python typing limitations. Parsed as string.",
        }
        # argparse only applies ``type`` to string defaults, so a dict default
        # survives untouched when the option is omitted.
        if field.default is not MISSING:
            add_kwargs["default"] = field.default
        elif field.default_factory is not MISSING:
            add_kwargs["default"] = field.default_factory()
        else:
            add_kwargs["required"] = True
        parser.add_argument(arg_name, **add_kwargs)
        return None