#!/usr/bin/env python3
"""
Generate messy Python code variants from clean reference code.

Applies deterministic corruptions that preserve semantics:
- Variable renaming to uninformative names
- Remove docstrings/comments
- Break formatting (collapse blank lines, drop spaces around '=' and after ',')
- Messy import order
- Inconsistent quotes
"""

import ast
import random
import re
from typing import List, Tuple, Dict, Any


class MessyCodeGenerator:
    """Generate semantically equivalent but stylistically messy code.

    Every random choice is drawn from ``self.rng`` (seeded in the
    constructor) and no step depends on set/hash iteration order, so a given
    seed and input always yield the same output.  The textual passes
    (comments, formatting, imports, quotes) are string-literal aware and never
    alter the value of a literal.
    """

    # Replacement pools for _rename_variables.  Candidates that already occur
    # as identifiers in the code (or are builtins) are skipped, so a rename can
    # never shadow or capture an existing binding.
    _BAD_VAR_NAMES = ('x', 'y', 'z', 'a', 'b', 'tmp', 'val', 'data', 'res', 'r')
    _BAD_PARAM_NAMES = ('p1', 'p2', 'arg', 'v', 'n')
    # How many distinct positional-parameter names may be renamed.
    _MAX_RENAMED_PARAMS = 3

    def __init__(self, seed: int = 42):
        self.rng = random.Random(seed)
        self.var_counter = 0  # kept for backward compatibility; unused in this class

    def generate_messy(self, clean_code: str) -> str:
        """Apply all corruptions to ``clean_code`` and return the messy source.

        The two AST passes (docstring removal, renaming) re-emit the whole
        module via ``ast.unparse``, so they run before the textual passes,
        whose edits they would otherwise undo.
        """
        code = self._remove_docstrings(clean_code)
        code = self._remove_comments(code)
        code = self._rename_variables(code)
        code = self._mess_formatting(code)
        code = self._mess_imports(code)
        return self._mess_quotes(code)

    # ------------------------------------------------------------------
    # Lexical helpers shared by the textual passes
    # ------------------------------------------------------------------

    @staticmethod
    def _segments(code: str) -> List[Tuple[str, str]]:
        """Split ``code`` into ``(kind, text)`` runs that cover it exactly.

        ``kind`` is ``'code'``, ``'string'`` (a quoted literal from its opening
        quote through its closing quote; any r/b/f/u prefix stays in the
        preceding code run) or ``'comment'`` (``#`` up to, not including, the
        newline).  Joining every ``text`` reproduces ``code``.  Backslash
        escapes are honoured inside literals; an unterminated single-quoted
        literal ends at the newline, an unterminated triple-quoted one at EOF.
        """
        segments: List[Tuple[str, str]] = []
        n = len(code)
        start = i = 0
        while i < n:
            ch = code[i]
            if ch == '#':
                end = code.find('\n', i)
                end = n if end == -1 else end
                kind = 'comment'
            elif ch in '\'"':
                quote = ch * 3 if code.startswith(ch * 3, i) else ch
                end = i + len(quote)
                while end < n:
                    if code[end] == '\\':
                        end += 2  # skip the escaped character, whatever it is
                    elif code.startswith(quote, end):
                        end += len(quote)
                        break
                    elif code[end] == '\n' and len(quote) == 1:
                        break
                    else:
                        end += 1
                end = min(end, n)
                kind = 'string'
            else:
                i += 1
                continue
            if start < i:
                segments.append(('code', code[start:i]))
            segments.append((kind, code[i:end]))
            start = i = end
        if start < n:
            segments.append(('code', code[start:]))
        return segments

    @staticmethod
    def _lines_inside_strings(segments: List[Tuple[str, str]]) -> set:
        """Return the 0-based indices of lines that begin inside a string literal."""
        inside = set()
        line = 0
        for kind, text in segments:
            breaks = text.count('\n')
            if kind == 'string':
                inside.update(range(line + 1, line + breaks + 1))
            line += breaks
        return inside

    # ------------------------------------------------------------------
    # Corruption passes
    # ------------------------------------------------------------------

    @staticmethod
    def _starts_with_docstring(body: List[ast.stmt]) -> bool:
        """Return True if ``body`` opens with a bare string-literal expression."""
        return (
            bool(body)
            and isinstance(body[0], ast.Expr)
            and isinstance(body[0].value, ast.Constant)
            and isinstance(body[0].value.value, str)
        )

    def _remove_docstrings(self, code: str) -> str:
        """Remove module, class, and (async) function docstrings.

        The module is round-tripped through ``ast.unparse``, which also drops
        comments and normalizes layout.  Source that cannot be parsed is
        returned unchanged (best effort).
        """
        owners = (ast.Module, ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)
        try:
            tree = ast.parse(code)
            # Materialize the node list first so mutating bodies cannot
            # interfere with the walk.
            for node in [n for n in ast.walk(tree) if isinstance(n, owners)]:
                if self._starts_with_docstring(node.body):
                    # Class and function bodies may not be empty; a module may.
                    node.body = node.body[1:] or (
                        [] if isinstance(node, ast.Module) else [ast.Pass()]
                    )
            ast.fix_missing_locations(tree)
            return ast.unparse(tree)
        except (SyntaxError, ValueError, RecursionError):
            return code

    def _remove_comments(self, code: str) -> str:
        """Remove ``#`` comments and trailing whitespace, leaving literals intact.

        A ``#`` inside a string literal (including one that follows an escaped
        backslash) is never treated as a comment.  A line that ends inside a
        multi-line string keeps its trailing whitespace, because that
        whitespace is part of the string's value.
        """
        kept = [seg for seg in self._segments(code) if seg[0] != 'comment']
        inside = self._lines_inside_strings(kept)
        lines = ''.join(text for _, text in kept).split('\n')
        return '\n'.join(
            line if idx + 1 in inside else line.rstrip()
            for idx, line in enumerate(lines)
        )

    class _BindingCollector(ast.NodeVisitor):
        """Classify identifiers to decide which can be renamed module-wide.

        ``locals`` / ``params`` keep first-appearance order (dicts used as
        ordered sets) so the rename mapping is deterministic.  ``protected``
        holds names that must not be renamed: anything appearing at module or
        class scope, def/class names, import aliases, ``global``/``nonlocal``
        names, keyword-only parameters and match-capture names.  ``taken`` is
        every identifier present, used to avoid collisions.
        """

        def __init__(self) -> None:
            self.locals: Dict[str, None] = {}
            self.params: Dict[str, None] = {}
            self.keywords: set = set()
            self.protected: set = set()
            self.taken: set = set()
            self._scopes: List[str] = ['module']

        def _bind(self, name: str) -> None:
            self.taken.add(name)
            if self._scopes[-1] == 'function':
                self.locals.setdefault(name, None)
            else:
                self.protected.add(name)

        def _protect(self, name: str) -> None:
            self.taken.add(name)
            self.protected.add(name)

        def _visit_function(self, node) -> None:
            self._protect(node.name)
            # Decorators, annotations and defaults run in the enclosing scope.
            for deco in node.decorator_list:
                self.visit(deco)
            if node.returns is not None:
                self.visit(node.returns)
            self.visit(node.args)
            spec = node.args
            for arg in spec.posonlyargs + spec.args:
                self.params.setdefault(arg.arg, None)
            for arg in spec.kwonlyargs:
                # Callers must pass these by name, so renaming would break them.
                self._protect(arg.arg)
            self._scopes.append('function')
            for stmt in node.body:
                self.visit(stmt)
            self._scopes.pop()

        visit_FunctionDef = _visit_function
        visit_AsyncFunctionDef = _visit_function

        def visit_Lambda(self, node) -> None:
            self.visit(node.args)
            self._scopes.append('function')
            self.visit(node.body)
            self._scopes.pop()

        def visit_ClassDef(self, node) -> None:
            self._protect(node.name)
            for child in node.decorator_list + node.bases + node.keywords:
                self.visit(child)
            self._scopes.append('class')
            for stmt in node.body:
                self.visit(stmt)
            self._scopes.pop()

        def visit_Name(self, node) -> None:
            if isinstance(node.ctx, (ast.Store, ast.Del)):
                self._bind(node.id)
            elif self._scopes[-1] == 'function':
                self.taken.add(node.id)
            else:
                self._protect(node.id)

        def visit_arg(self, node) -> None:
            self.taken.add(node.arg)
            self.generic_visit(node)

        def visit_ExceptHandler(self, node) -> None:
            if node.name:
                self._bind(node.name)
            self.generic_visit(node)

        def visit_keyword(self, node) -> None:
            if node.arg is not None:
                self.keywords.add(node.arg)
            self.generic_visit(node)

        def visit_alias(self, node) -> None:
            if node.name != '*':
                self._protect(node.asname or node.name.split('.')[0])

        def visit_Global(self, node) -> None:
            for name in node.names:
                self._protect(name)

        visit_Nonlocal = visit_Global

        def visit_MatchAs(self, node) -> None:
            if node.name:
                self._protect(node.name)
            self.generic_visit(node)

        visit_MatchStar = visit_MatchAs

        def visit_MatchMapping(self, node) -> None:
            if node.rest:
                self._protect(node.rest)
            self.generic_visit(node)

    class _Renamer(ast.NodeTransformer):
        """Apply an old-name -> new-name mapping to every binding and reference."""

        def __init__(self, mapping: Dict[str, str]) -> None:
            self.mapping = mapping

        def visit_Name(self, node):
            node.id = self.mapping.get(node.id, node.id)
            return node

        def visit_arg(self, node):
            node.arg = self.mapping.get(node.arg, node.arg)
            self.generic_visit(node)
            return node

        def visit_ExceptHandler(self, node):
            if node.name:
                node.name = self.mapping.get(node.name, node.name)
            self.generic_visit(node)
            return node

    @staticmethod
    def _fresh_names(pool, prefix: str, used: set):
        """Yield unused names: first from ``pool``, then ``f'{prefix}{k}'``.

        Each yielded name is added to ``used``, so generators sharing ``used``
        never hand out the same name twice.
        """
        for name in pool:
            if name not in used:
                used.add(name)
                yield name
        k = len(pool)
        while True:
            name = f'{prefix}{k}'
            k += 1
            if name not in used:
                used.add(name)
                yield name

    def _rename_variables(self, code: str) -> str:
        """Rename local variables to uninformative names like x, y, tmp10, etc.

        Only identifiers that are bound inside functions and never appear at
        module or class scope are renamed, so module-level names and class
        attributes keep working.  Every occurrence of a renamed identifier is
        rewritten (``Name`` nodes, ``arg`` nodes, ``except ... as`` targets).
        Replacement names never coincide with an identifier already in the
        code or with a builtin.  Up to ``_MAX_RENAMED_PARAMS`` positional
        parameter names are renamed too, skipping any name that is used as a
        keyword argument anywhere in the code.  Source that cannot be parsed
        is returned unchanged (best effort).
        """
        import builtins

        try:
            tree = ast.parse(code)
            collector = self._BindingCollector()
            collector.visit(tree)

            builtin_names = set(dir(builtins))
            blocked = collector.protected | builtin_names
            used = collector.taken | builtin_names
            mapping: Dict[str, str] = {}

            var_names = self._fresh_names(self._BAD_VAR_NAMES, 'tmp', used)
            for name in collector.locals:
                if name not in blocked and name not in collector.params:
                    mapping[name] = next(var_names)

            param_names = self._fresh_names(self._BAD_PARAM_NAMES, 'p', used)
            eligible = [
                p for p in collector.params
                if p not in blocked and p not in collector.keywords
            ]
            for name in eligible[:self._MAX_RENAMED_PARAMS]:
                mapping[name] = next(param_names)

            tree = self._Renamer(mapping).visit(tree)
            ast.fix_missing_locations(tree)
            return ast.unparse(tree)
        except (SyntaxError, ValueError, RecursionError):
            return code

    def _mess_formatting(self, code: str) -> str:
        """Collapse runs of blank lines; randomly drop spaces around '=' and after ','.

        Each surviving line independently (probability 0.3 each) has
        ``' = '`` replaced by ``'='`` and ``', '`` replaced by ``','``.  Only
        code text is edited.  String literals, including blank lines inside
        multi-line strings, are left exactly as they were.
        """
        inside = self._lines_inside_strings(self._segments(code))
        kept: List[str] = []
        for idx, line in enumerate(code.split('\n')):
            if idx not in inside and not line.strip() and kept and not kept[-1].strip():
                continue
            kept.append(line)

        # Two draws per surviving line: the first decides '=', the second ','.
        flags = [(self.rng.random() < 0.3, self.rng.random() < 0.3) for _ in kept]

        out: List[str] = []
        line_no = 0
        for kind, text in self._segments('\n'.join(kept)):
            if kind == 'code':
                pieces = text.split('\n')
                for k, piece in enumerate(pieces):
                    tight_eq, tight_comma = flags[line_no + k]
                    if tight_eq:
                        piece = piece.replace(' = ', '=')
                    if tight_comma:
                        piece = piece.replace(', ', ',')
                    pieces[k] = piece
                text = '\n'.join(pieces)
            out.append(text)
            line_no += text.count('\n')
        return ''.join(out)

    def _mess_imports(self, code: str) -> str:
        """Shuffle each contiguous run of top-level, single-line import statements.

        The following lines are never moved:
        - indented imports (inside a function, ``try`` or ``if``);
        - ``from __future__`` imports, which must stay first;
        - imports continued over several lines;
        - lines inside string literals.
        As a result, an import never lands in a different block.
        """
        lines = code.split('\n')
        inside = self._lines_inside_strings(self._segments(code))

        def movable(idx: int) -> bool:
            line = lines[idx]
            return (
                idx not in inside
                and line.startswith(('import ', 'from '))  # column 0 only
                and not line.startswith('from __future__')
                and line.count('(') == line.count(')')
                and not line.rstrip().endswith('\\')
            )

        idx = 0
        while idx < len(lines):
            if not movable(idx):
                idx += 1
                continue
            end = idx
            while end < len(lines) and movable(end):
                end += 1
            run = lines[idx:end]
            self.rng.shuffle(run)
            lines[idx:end] = run
            idx = end
        return '\n'.join(lines)

    def _mess_quotes(self, code: str) -> str:
        """Make quotes inconsistent by turning some "..." literals into '...'.

        A complete, non-triple double-quoted literal is converted with
        probability 0.5 when all of the following hold:
        - its body contains no ``'``;
        - it is not directly adjacent to another quote character (which could
          otherwise form a ``'''``).
        The body is copied verbatim, so the literal's value is unchanged.
        ``"`` characters inside other literals are never touched.
        """
        out: List[str] = []
        pos = 0
        for kind, text in self._segments(code):
            start, pos = pos, pos + len(text)
            neighbours = code[start - 1:start] + code[pos:pos + 1]
            if (kind == 'string' and len(text) >= 2 and text[0] == '"' == text[-1]
                    and not text.startswith('"""')
                    and '"' not in neighbours and "'" not in neighbours
                    and self.rng.random() < 0.5):
                body = text[1:-1]
                # An odd run of trailing backslashes means the final quote is
                # escaped, i.e. the literal is unterminated; leave it alone.
                escaped_close = (len(body) - len(body.rstrip('\\'))) % 2 == 1
                if "'" not in body and not escaped_close:
                    text = "'" + body + "'"
            out.append(text)
        return ''.join(out)


def generate_messy_dataset(clean_programs: List[Tuple[str, str, str]]) -> List[Dict[str, Any]]:
    """
    Pair every clean program with a messy rewrite of its source.

    One MessyCodeGenerator (default seed) is shared by all programs, so its
    RNG state carries over from one program to the next. A program's messy
    variant therefore depends on its position in ``clean_programs``.

    Args:
        clean_programs: List of (name, clean_code, test_code) tuples

    Returns:
        One dict per input program, in input order, with the keys
        'name', 'clean', 'messy' and 'test'.
    """
    shared_generator = MessyCodeGenerator()
    return [
        {
            'name': program_name,
            'clean': source,
            'messy': shared_generator.generate_messy(source),
            'test': tests,
        }
        for program_name, source, tests in clean_programs
    ]


if __name__ == '__main__':
    # Manual demo: show a small clean function next to its messy rewrite.
    sample_source = '''
def calculate_average(numbers):
    """Calculate the average of a list of numbers."""
    if not numbers:
        return 0.0
    total = sum(numbers)
    count = len(numbers)
    return total / count
'''

    demo_generator = MessyCodeGenerator()
    messy_source = demo_generator.generate_messy(sample_source)
    for banner, text in (("=== Clean ===", sample_source), ("\n=== Messy ===", messy_source)):
        print(banner)
        print(text)
