# Source code for lhp.utils.substitution

"""Enhanced token and secret substitution for LakehousePlumber."""

import logging
import os
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Set

from .error_formatter import ErrorCategory, LHPConfigError, LHPError, LHPValidationError

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Process-wide flag: set to True after the deprecation warning for the bare
# {token} syntax has been logged, so the warning is emitted at most once.
_DEPRECATED_BARE_TOKEN_WARNED = False


class SecretReference:
    """Represents a secret reference with scope and key.

    Instances are hashable and compare equal when both ``scope`` and ``key``
    match, so they can be collected in a set without duplicates.
    """

    def __init__(self, scope: str, key: str):
        self.scope = scope
        self.key = key

    def __repr__(self) -> str:
        return f"{type(self).__name__}(scope={self.scope!r}, key={self.key!r})"

    def __hash__(self) -> int:
        return hash((self.scope, self.key))

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, SecretReference):
            # Let Python try the reflected comparison; the final result for
            # unrelated types is still False.
            return NotImplemented
        return (self.scope, self.key) == (other.scope, other.key)

    def to_dbutils_call(self) -> str:
        """Generate a dbutils.secrets.get() call as a Python expression.

        Single-quoted scope/key arguments so the call is safe to embed inside
        double-quoted string literals (e.g. inside JDBC URL templates) without
        breaking the outer quote nesting. This matches the format the legacy
        post-pass SecretCodeGenerator emitted before inline emission replaced
        it.

        Returns:
            The call expression as source text. Scope and key are inserted
            verbatim (no quoting/escaping is applied).
        """
        return f"dbutils.secrets.get(scope='{self.scope}', key='{self.key}')"
class EnhancedSubstitutionManager:
    """Enhanced substitution manager with YAML and secret support.

    Holds a token -> value mapping (seeded with the reserved tokens
    ``workspace_env`` and ``logical_env``, then filled from an optional
    substitution YAML file) and applies it to strings nested inside
    dicts/lists. ``${secret:...}`` references are replaced by placeholders and
    recorded in ``secret_references``.
    """

    # Regex patterns for token matching
    DEFAULT_TOKEN_PATTERN = re.compile(r"\{(\w+)\}")
    DOLLAR_TOKEN_PATTERN = re.compile(r"\$\{(\w+)\}")
    DOLLAR_TOKEN_SIMPLE_PATTERN = re.compile(r"\$(\w+)")
    SECRET_PATTERN = re.compile(r"\$\{secret:([^}]+)\}")
    UNRESOLVED_TOKEN_PATTERN = re.compile(r"\{(?!dbutils\.)(\w+)\}")
    # A {token} that is NOT part of a ${token}; used only to decide whether
    # the deprecation warning for the bare syntax applies.
    _BARE_TOKEN_PATTERN = re.compile(r"(?<!\$)\{(\w+)\}")

    def __init__(
        self,
        substitution_file: Optional[Path] = None,
        env: str = "dev",
        skip_validation: bool = False,
    ):
        """Build the manager and load substitutions for ``env``.

        Args:
            substitution_file: Optional YAML file with tokens, secret
                configuration and prefix/suffix rules. Ignored when ``None``
                or when the path does not exist.
            env: Environment name; selects the env-specific section of the
                file and seeds the reserved tokens.
            skip_validation: Flag to skip unresolved token validation (stored
                only; this class does not read it).
        """
        self.env = env
        self.skip_validation = skip_validation
        # Values are strings for scalars; nested dicts/lists are kept as-is.
        self.mappings: Dict[str, Any] = {}
        self.prefix_suffix_rules: Dict[str, Dict[str, str]] = {}
        self.secret_scopes: Dict[str, str] = {}
        self.default_secret_scope: Optional[str] = None
        self.secret_references: Set[SecretReference] = set()

        # Add reserved tokens
        self._add_reserved_tokens()

        # Load substitutions and secret configuration
        if substitution_file and substitution_file.exists():
            logger.debug(
                f"Loading substitution file: {substitution_file} for env '{env}'"
            )
            self._load_config_from_file(substitution_file, env)

        # Recursively expand tokens
        self._expand_recursive_tokens()

    def _add_reserved_tokens(self):
        """Add reserved tokens automatically available.

        Both tokens default to ``self.env``; the ``WORKSPACE_ENV`` and
        ``LOGICAL_ENV`` environment variables override them when set.
        """
        self.mappings["workspace_env"] = os.environ.get("WORKSPACE_ENV", self.env)
        self.mappings["logical_env"] = os.environ.get("LOGICAL_ENV", self.env)

    def _load_config_from_file(self, file_path: Path, env: str):
        """Load tokens, secrets, and rules from YAML file.

        Raises:
            LHPError: Re-raised unchanged when the YAML loader raises one.
            LHPConfigError: For any other failure while loading the file.
        """
        try:
            from .yaml_loader import load_yaml_file

            config = load_yaml_file(file_path, error_context="substitution file")
        except LHPError:
            # Re-raise LHPError as-is (it's already well-formatted)
            raise
        except Exception as e:
            raise LHPConfigError(
                category=ErrorCategory.CONFIG,
                code_number="020",
                title="Failed to load substitution file",
                details=f"Error loading substitution file {file_path}: {e}",
                suggestions=[
                    "Check the substitution file for YAML syntax errors",
                    "Ensure the file is readable and not corrupted",
                ],
                context={"File": str(file_path), "Environment": env},
            ) from e

        if not config:
            return
        self._apply_config(config, env)

    @staticmethod
    def _coerce_token_value(value: Any) -> Any:
        """Convert a scalar token value to text; keep dicts/lists as-is."""
        if isinstance(value, bool):
            # Convert booleans to lowercase for YAML compatibility
            return str(value).lower()
        if isinstance(value, (dict, list)):
            # Keep nested structures as-is for prefix_suffix handling
            return value
        return str(value)

    def _apply_config(self, config: Dict[str, Any], env: str) -> None:
        """Merge a parsed substitution config into this manager's state."""
        env_tokens = config.get(env, {})
        global_tokens = config.get("global", {})

        logger.debug(
            f"Loaded {len(env_tokens) if isinstance(env_tokens, dict) else 0} env-specific "
            f"and {len(global_tokens) if isinstance(global_tokens, dict) else 0} global token(s)"
        )

        # Environment-specific tokens override anything already present.
        if isinstance(env_tokens, dict):
            for key, value in env_tokens.items():
                self.mappings[key] = self._coerce_token_value(value)

        # Global tokens only fill in keys that are not already set.
        if isinstance(global_tokens, dict):
            for key, value in global_tokens.items():
                if key not in self.mappings:
                    self.mappings[key] = self._coerce_token_value(value)

        # Load secret configuration
        secrets_config = config.get("secrets", {})
        if isinstance(secrets_config, dict):
            self.default_secret_scope = secrets_config.get("default_scope")
            scopes = secrets_config.get("scopes", {})
            # An empty "scopes:" key parses to None; keep a dict so alias
            # lookups during secret substitution cannot fail.
            self.secret_scopes = scopes if isinstance(scopes, dict) else {}

        # Load prefix/suffix rules
        prefix_suffix = config.get("prefix_suffix_rules", {})
        if isinstance(prefix_suffix, dict):
            self.prefix_suffix_rules = prefix_suffix

    def _expand_recursive_tokens(self):
        """Recursively expand tokens that reference other tokens.

        Repeats substitution over all string values until a pass changes
        nothing, giving up (with a warning) after a fixed number of passes.
        """
        max_iterations = 10
        for _ in range(max_iterations):
            changed = False
            # Only existing keys are reassigned, so the dict size is stable
            # while iterating.
            for token, value in self.mappings.items():
                if not isinstance(value, str):
                    continue
                expanded = self._replace_tokens_in_string(value)
                if expanded != value:
                    self.mappings[token] = expanded
                    changed = True
            if not changed:
                break
        else:
            # Reached max iterations - likely circular reference
            # Log warning but don't fail here - validation will catch it
            logger.warning(
                f"Token expansion reached maximum iterations ({max_iterations}). "
                f"Possible circular reference in substitutions/{self.env}.yaml. "
                f"Unresolved tokens will be caught by validation."
            )

    def substitute_yaml(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Recursively substitute tokens and collect secret references."""
        logger.debug(
            f"Substituting tokens in YAML data ({len(self.mappings)} mapping(s) available)"
        )
        return self._substitute_recursive(data)

    def _substitute_recursive(self, obj: Any) -> Any:
        """Recursively substitute tokens and secrets in any object.

        Strings are processed, dicts and lists are rebuilt with processed
        values (dict keys are left untouched), anything else is returned
        unchanged.
        """
        if isinstance(obj, str):
            return self._process_string(obj)
        if isinstance(obj, dict):
            return {k: self._substitute_recursive(v) for k, v in obj.items()}
        if isinstance(obj, list):
            return [self._substitute_recursive(item) for item in obj]
        return obj

    def _process_string(self, text: str) -> str:
        """Process string for both token and secret substitution.

        Raises:
            LHPValidationError: If a ``${secret:key}`` reference has no scope
                and no default secret scope is configured.
        """
        # First handle regular token substitution
        text = self._replace_tokens_in_string(text)

        # Handle secret references
        def secret_replacer(match):
            secret_ref = match.group(1)
            if "/" in secret_ref:
                scope, key = secret_ref.split("/", 1)
            else:
                scope = self.default_secret_scope
                key = secret_ref

            if not scope:
                raise LHPValidationError(
                    category=ErrorCategory.CONFIG,
                    code_number="008",
                    title="Missing default secret scope",
                    details=f"No default secret scope configured for secret reference: {secret_ref}",
                    suggestions=[
                        "Add a 'secrets.default_scope' to your substitutions YAML file",
                        "Or use the full scope/key format: ${secret:scope/key}",
                    ],
                    example="secrets:\n default_scope: my-scope\n\n# Then use: ${secret:my_key}\n# Or explicit: ${secret:my-scope/my_key}",
                    context={"secret_ref": secret_ref, "env": self.env},
                )

            # Resolve scope alias if it exists
            actual_scope = self.secret_scopes.get(scope, scope)

            # Store reference for validation
            self.secret_references.add(SecretReference(actual_scope, key))

            # Return a placeholder rather than code; a later code-generation
            # step is expected to turn it into a dbutils call.
            return f"__SECRET_{actual_scope}_{key}__"

        return self.SECRET_PATTERN.sub(secret_replacer, text)

    def _replace_tokens_in_string(self, text: str) -> str:
        """Replace all {TOKEN} and ${TOKEN} patterns in a string.

        Tokens without a mapping are left untouched. Tokens mapped to a
        non-string value (nested dict/list) are also left untouched, because
        a regex replacement must be text; they then surface through
        unresolved-token validation instead of raising ``TypeError`` here.
        """
        global _DEPRECATED_BARE_TOKEN_WARNED

        def replacer(match):
            value = self.mappings.get(match.group(1))
            return value if isinstance(value, str) else match.group(0)

        # Apply patterns - dollar pattern first to avoid conflicts
        text = self.DOLLAR_TOKEN_PATTERN.sub(replacer, text)

        # Warn once per process about deprecated {token} syntax. The
        # lookbehind pattern keeps a leftover ${token} from being mistaken
        # for the bare form.
        if not _DEPRECATED_BARE_TOKEN_WARNED and self._BARE_TOKEN_PATTERN.search(
            text
        ):
            logger.warning(
                "The bare {token} substitution syntax is deprecated and will be "
                "removed in v1.0. Use ${token} instead."
            )
            _DEPRECATED_BARE_TOKEN_WARNED = True

        return self.DEFAULT_TOKEN_PATTERN.sub(replacer, text)

    def validate_no_unresolved_tokens(
        self, data: Any, path: str = "config"
    ) -> List[str]:
        """Detect unresolved tokens after substitution.

        Scans configuration for any remaining {token} patterns that weren't
        resolved during substitution, indicating missing values in
        substitutions file. The check only pattern-matches the text; it does
        not consult ``mappings``.

        Args:
            data: Configuration data to validate (dict, list, str, or other)
            path: Current path in config tree for error reporting

        Returns:
            List of error messages describing unresolved tokens with their
            locations; one message per string that contains such tokens.

        Examples:
            >>> mgr = EnhancedSubstitutionManager()
            >>> data = {"path": "s3://{bucket}/{missing}/data"}
            >>> errors = mgr.validate_no_unresolved_tokens(data)
            >>> print(errors[0])
            Unresolved token(s) {bucket}, {missing} found at config.path. Check substitutions/dev.yaml for missing value(s).
        """
        errors: List[str] = []

        if isinstance(data, str):
            # Find all unresolved tokens except dbutils expressions
            matches = self.UNRESOLVED_TOKEN_PATTERN.findall(data)
            if matches:
                # Format all unresolved tokens in this string
                token_list = ", ".join(f"{{{m}}}" for m in matches)
                errors.append(
                    f"Unresolved token(s) {token_list} found at {path}. "
                    f"Check substitutions/{self.env}.yaml for missing value(s)."
                )
        elif isinstance(data, dict):
            for key, value in data.items():
                errors.extend(
                    self.validate_no_unresolved_tokens(value, f"{path}.{key}")
                )
        elif isinstance(data, list):
            for i, item in enumerate(data):
                errors.extend(self.validate_no_unresolved_tokens(item, f"{path}[{i}]"))
        # For other types (int, bool, None, etc.), nothing to validate

        return errors