import importlib
from typing import Any, Dict, List, Union
import numpy as np
from lfc_toolkit.src.configuration.configuration_reader import ConfigurationReader
from lfc_toolkit.src.quality.rd_result import RDResult
[docs]
def adjust_bd_metric(configuration: Any, metric: str, value: float) -> float:
    """Apply the configured BD adjustment to a single raw metric value.

    The name of the adjustment module is read from
    ``configuration["quality"]["bd-adjusted-metrics"][metric]["adjustment_for_bd"]``.
    That module is imported from the ``lfc_toolkit.src.quality.bd_adjustment``
    package. Its ``adjust_values_for_bd`` callable is invoked with ``value``
    passed as a keyword argument.

    :param configuration: Configuration reader or dict with quality metrics
    :type configuration: Any
    :param metric: Name of the BD-adjusted metric whose adjustment is applied
    :type metric: str
    :param value: Raw metric value to adjust
    :type value: float
    :return: Value returned by the adjustment module's ``adjust_values_for_bd``
    :rtype: float
    :raises KeyError: If ``metric`` (or its ``adjustment_for_bd`` entry) is not configured
    :raises ImportError: If the configured adjustment module cannot be imported
    """
    metric_settings = configuration["quality"]["bd-adjusted-metrics"][metric]
    module_name = metric_settings["adjustment_for_bd"]
    # importlib caches modules in sys.modules, so repeated calls are cheap.
    adjustment_module = importlib.import_module(
        f"lfc_toolkit.src.quality.bd_adjustment.{module_name}"
    )
    return adjustment_module.adjust_values_for_bd(value=value)
[docs]
def resolve_bd_adjusted_metric(
    bpp: float,
    frame_metrics: List[Dict[str, Any]],
    distortion_name: str,
    configuration: ConfigurationReader,
) -> RDResult:
    """Build an RDResult for a BD-adjusted metric from a sequence of frames.

    The raw values come from each frame's ``"metrics"`` dict, using the metric
    named by the configuration's ``"origin"`` entry. The ``"pooling-first"``
    entry decides the order of operations:

    * truthy: compute min/max/mean/stddev of the raw values, then BD-adjust
      each of those four aggregates;
    * falsy: BD-adjust every frame value, then compute min/max/mean/stddev of
      the adjusted values.

    Standard deviation uses ``np.std`` with its default ``ddof=0``.

    :param bpp: Bits per pixel stored on the result
    :type bpp: float
    :param frame_metrics: Frames, each holding a ``"metrics"`` dict
    :type frame_metrics: List[Dict[str, Any]]
    :param distortion_name: Name of the BD-adjusted distortion metric
    :type distortion_name: str
    :param configuration: Configuration reader
    :type configuration: ConfigurationReader
    :return: RDResult with adjusted min, max, mean, stddev
    :rtype: RDResult
    """
    bd_settings = configuration["quality"]["bd-adjusted-metrics"]
    origin_metric = bd_settings[distortion_name]["origin"]
    pool_before_adjusting = bd_settings[distortion_name]["pooling-first"]
    raw_values = [frame["metrics"][origin_metric] for frame in frame_metrics]

    def _adjust(raw):
        # Single place that forwards to the configured BD adjustment.
        return adjust_bd_metric(
            configuration=configuration, metric=distortion_name, value=raw
        )

    if not pool_before_adjusting:
        # Adjust every frame first, then aggregate the adjusted values.
        adjusted_values = [_adjust(raw) for raw in raw_values]
        return RDResult(
            bpp=bpp,
            min=np.min(adjusted_values),
            max=np.max(adjusted_values),
            mean=np.mean(adjusted_values),
            stddev=np.std(adjusted_values),
        )

    # Aggregate the raw values first, then adjust each aggregate.
    # NOTE(review): the pooled stddev is passed through the same adjustment as
    # min/max/mean; whether that is meaningful depends on the adjustment
    # function — confirm with the metric owners.
    pooled_raw = (
        np.min(raw_values),
        np.max(raw_values),
        np.mean(raw_values),
        np.std(raw_values),
    )
    adj_min, adj_max, adj_mean, adj_stddev = (_adjust(raw) for raw in pooled_raw)
    return RDResult(
        bpp=bpp,
        min=adj_min,
        max=adj_max,
        mean=adj_mean,
        stddev=adj_stddev,
    )
[docs]
def resolve_derived_metric(
    frame_metrics: Dict[str, float],
    distortion_name: str,
    configuration: "ConfigurationReader",
) -> float:
    """Resolves a derived (weighted) metric value from one frame's metrics.

    The entry ``configuration["quality"]["weighted-metrics"][distortion_name]``
    must contain one of two keys:

    * ``"metrics"``: a list of metric names; the result is their plain mean
      (duplicate names are counted once);
    * ``"weights"``: a mapping of metric name to weight; the result is the
      weighted average from ``CompoundMetrics.calculate_weighted_metric_value``.

    If both keys are present, ``"metrics"`` takes precedence.

    :param frame_metrics: Frame metrics dictionary (metric name -> value)
    :type frame_metrics: Dict[str, float]
    :param distortion_name: Name of the derived metric
    :type distortion_name: str
    :param configuration: Configuration reader
    :type configuration: ConfigurationReader
    :return: Resolved metric value
    :rtype: float
    :raises ValueError: If the derived metric defines neither ``"metrics"`` nor ``"weights"``
    """
    derived_config = configuration["quality"]["weighted-metrics"][distortion_name]
    if "metrics" in derived_config:
        # A dict keeps each listed metric once, even if the config repeats a name.
        selected_metrics = {
            name: frame_metrics[name] for name in derived_config["metrics"]
        }
        return np.mean(list(selected_metrics.values()))
    if "weights" in derived_config:
        weights = derived_config["weights"]
        selected_metrics = {name: frame_metrics[name] for name in weights}
        return CompoundMetrics.calculate_weighted_metric_value(
            metrics_dict=selected_metrics, weights=weights
        )
    # Previously this fell through and returned None, contradicting ``-> float``.
    raise ValueError(
        f"Weighted metric '{distortion_name}' must define either "
        f"'metrics' or 'weights'."
    )
[docs]
class CompoundMetrics:
    """Helpers for resolving direct, BD-adjusted and weighted/derived metrics.

    All lookups read ``configuration["quality"]`` and its sub-keys
    ``"metrics"``, ``"bd-adjusted-metrics"`` and ``"weighted-metrics"``.
    """

    @staticmethod
    def is_weighted_metric(metric: str, configuration: "ConfigurationReader") -> bool:
        """Checks if the metric is a weighted/derived metric.

        :param metric: Metric name
        :type metric: str
        :param configuration: Configuration reader
        :type configuration: ConfigurationReader
        :return: True if ``metric`` is a key of ``quality.weighted-metrics``
        :rtype: bool
        """
        return metric in configuration["quality"]["weighted-metrics"]

    @staticmethod
    def is_bd_adjusted_metric(
        metric: str, configuration: "ConfigurationReader"
    ) -> bool:
        """Checks if the metric is a BD-adjusted metric.

        :param metric: Metric name
        :type metric: str
        :param configuration: Configuration reader
        :type configuration: ConfigurationReader
        :return: True if ``metric`` is a key of ``quality.bd-adjusted-metrics``
        :rtype: bool
        """
        return metric in configuration["quality"]["bd-adjusted-metrics"]

    @staticmethod
    def calculate_weighted_metric_value(
        metrics_dict: Dict[str, float], weights: Dict[str, float]
    ) -> float:
        """Calculates the weighted average of metrics using the given weights.

        Only metrics named in ``weights`` contribute. A total weight of zero
        (including empty ``weights``) yields ``0`` instead of dividing by zero.

        :param metrics_dict: Dictionary of metric names to values
        :type metrics_dict: Dict[str, float]
        :param weights: Dictionary of metric names to weights
        :type weights: Dict[str, float]
        :return: Weighted metric value
        :rtype: float
        :raises KeyError: If a weighted metric is missing from ``metrics_dict``
        """
        weighted_sum = sum(
            metrics_dict[metric] * weight for metric, weight in weights.items()
        )
        total_weight = sum(weights.values())
        return weighted_sum / total_weight if total_weight != 0 else 0

    @staticmethod
    def resolve_metric_value(
        frame_metrics: Dict[str, Any],
        distortion_name: str,
        configuration: "ConfigurationReader",
        generating_heatmap: bool = False,
    ) -> float:
        """Resolves one frame's value for ``distortion_name``.

        The metric kinds are checked in this order:

        1. direct metrics listed in ``quality.metrics``;
        2. BD-adjusted metrics, resolved as the BD adjustment of the frame's
           ``origin`` metric;
        3. weighted/derived metrics.

        :param frame_metrics: One frame's metrics dictionary (metric name -> value)
        :type frame_metrics: Dict[str, Any]
        :param distortion_name: Name of the distortion metric
        :type distortion_name: str
        :param configuration: Configuration reader
        :type configuration: ConfigurationReader
        :param generating_heatmap: Retained for backward compatibility. The
            per-frame BD-adjusted value is the same with or without it,
            defaults to False
        :type generating_heatmap: bool, optional
        :return: Resolved metric value for this frame
        :rtype: float
        :raises ValueError: If ``distortion_name`` is not configured
        """
        if distortion_name in configuration["quality"]["metrics"]:
            return frame_metrics[distortion_name]
        if CompoundMetrics.is_bd_adjusted_metric(distortion_name, configuration):
            # ``resolve_bd_adjusted_metric`` cannot be used here: it needs
            # ``bpp`` and a list of frames, but this method receives a single
            # frame's metrics dict. For a single value, pooling-first and
            # per-frame adjustment coincide, so one adjustment of the origin
            # value is correct for both modes.
            origin = configuration["quality"]["bd-adjusted-metrics"][
                distortion_name
            ]["origin"]
            return adjust_bd_metric(
                configuration=configuration,
                metric=distortion_name,
                value=frame_metrics[origin],
            )
        if CompoundMetrics.is_weighted_metric(distortion_name, configuration):
            return resolve_derived_metric(
                frame_metrics=frame_metrics,
                distortion_name=distortion_name,
                configuration=configuration,
            )
        # ValueError subclasses Exception, so existing ``except Exception``
        # handlers still catch this.
        raise ValueError(f"Distortion '{distortion_name}' not found in configuration.")

    @staticmethod
    def get_list_frames(
        all_frame_metrics: List[Dict[str, Any]],
        distortion_name: str,
        configuration: "ConfigurationReader",
    ) -> List[float]:
        """Extracts one resolved metric value per frame.

        :param all_frame_metrics: Frames, each holding a ``"metrics"`` dict
        :type all_frame_metrics: List[Dict[str, Any]]
        :param distortion_name: Name of the distortion metric
        :type distortion_name: str
        :param configuration: Configuration reader
        :type configuration: ConfigurationReader
        :return: Resolved metric value for each frame, in input order
        :rtype: List[float]
        """
        return [
            CompoundMetrics.resolve_metric_value(
                frame_metrics=frame["metrics"],
                distortion_name=distortion_name,
                configuration=configuration,
            )
            for frame in all_frame_metrics
        ]