Source code for openrl.configs.utils
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""""""
import os
import re
import tempfile
import yaml
from jinja2 import Environment, Template, meta
from jsonargparse import ActionConfigFile, ArgumentParser
[docs]class ProcessYamlAction(ActionConfigFile):
def __call__(self, parser, cfg, values, option_string=None):
# Read the original YAML file
assert isinstance(values, str) and values.endswith(".yaml")
with open(values, "r") as file:
content = file.read()
# Initialize global variables
global_variables = {}
# Extract globals section using regular expressions if present
globals_match = re.search(
r"^globals:\n((?: [^\n]*\n)*)", content, re.MULTILINE
)
if globals_match:
global_variables_yaml = globals_match.group(1)
global_variables = yaml.safe_load("globals:\n" + global_variables_yaml).get(
"globals", {}
)
# Create a Jinja2 environment
env = Environment()
# Parse original content without rendering to find all variable names
parsed_content = env.parse(content)
all_variables = meta.find_undeclared_variables(parsed_content)
# Check that all variables are defined in the global variables
undefined_variables = all_variables - set(global_variables.keys())
if undefined_variables:
# Iterate through the undefined variables and find their line numbers
error_messages = []
for variable in undefined_variables:
line_number = next(
(
i + 1
for i, line in enumerate(content.splitlines())
if "{{ " + variable + " }}" in line
),
"Unknown",
)
error_messages.append(
f"Undefined global variable: '{variable}' at line {line_number}"
)
raise ValueError("\n".join(error_messages))
# Remove 'globals' section and its variables from the YAML content
content_without_globals = re.sub(
r"^globals:\n((?: [^\n]*\n)*)", "", content, flags=re.MULTILINE
)
# Render content without globals using Jinja2 with global variables
template_without_globals = env.from_string(content_without_globals)
rendered_content = template_without_globals.render(global_variables)
# Load the rendered content as a dictionary
data = yaml.safe_load(rendered_content)
# Write the result to a temporary file. Not work on Windows.
# with tempfile.NamedTemporaryFile("w", delete=True, suffix=".yaml") as temp_file:
# yaml.dump(data, temp_file)
# temp_file.seek(0) # Move to the beginning of the file
# # Use the default behavior of ActionConfigFile to handle the temporary file
# super().__call__(parser, cfg, temp_file.name, option_string)
# Write the result to a temporary file. This works on all platforms.
temp_fd, temp_filename = tempfile.mkstemp(suffix=".yaml")
with os.fdopen(temp_fd, "w") as temp_file:
yaml.dump(data, temp_file)
try:
# Use the default behavior of ActionConfigFile to handle the temporary file
super().__call__(parser, cfg, temp_filename, option_string)
finally:
os.remove(temp_filename)