Source code for src.quality.rd_plot

import json
import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import bjontegaard as bd
import matplotlib
import matplotlib.pyplot as plt
import numpy as np

from lfc_toolkit.src.codecwrappers.codec_wrapper import CodecWrapper
from lfc_toolkit.src.configuration.configuration_reader import \
    ConfigurationReader
from lfc_toolkit.src.data_handlers.formatters import (
    get_formatted_filename_for_lf, get_formatted_filename_for_rd_report)
from lfc_toolkit.src.data_handlers.lightfield import LightField
from lfc_toolkit.src.quality.bd_wrapper import BDWrapper
from lfc_toolkit.src.quality.compound_metrics import (
    CompoundMetrics, resolve_bd_adjusted_metric)
from lfc_toolkit.src.quality.rd_curve import RDCurve
from lfc_toolkit.src.quality.rd_result import RDResult


def get_all_metrics(configuration: ConfigurationReader) -> Dict:
    """Merge every metric-definition section of the ``quality`` config.

    Each dict-valued entry of ``configuration["quality"]`` (except the
    ``"wrappers"`` entry) is merged into a single flat mapping. Later
    sections override earlier ones on duplicate keys.

    Args:
        configuration: parsed configuration exposing a ``"quality"`` mapping.

    Returns:
        A new dict combining all metric sections.
    """
    merged: Dict = {}
    quality_section = configuration["quality"]
    for section_name, section in quality_section.items():
        # Non-dict entries and the wrapper definitions are not metric tables.
        if section_name == "wrappers" or not isinstance(section, dict):
            continue
        merged.update(section)
    return merged
class RDCurveFromVMAFLikeReports(RDCurve):
    """RD curve built from VMAF-like per-rate JSON quality reports.

    When the codec's ``sample-codecs-to-run`` entry provides a
    ``samples_path.rd_reports`` directory, rates and distortions are first read
    from a pre-computed sample RD report. Otherwise, or when that sample is
    unusable, they are computed per target rate by ``process_target_rate``.
    """

    def __init__(
        self,
        path,
        tool_short_name: str,
        configuration,
        encoded_path,
        encoded_extension,
        lightfield,
        target_rates,
        distortion_name,
        distortion_unit,
        title,
        codec_name,
        encoded_lightfield=None,
    ):
        # Without a separate encoded light field, encoded-file names are
        # derived from the raw light field itself.
        if not encoded_lightfield:
            encoded_lightfield = lightfield
        rates, distortions = RDCurveFromVMAFLikeReports.read_reports(
            path=path,
            tool_short_name=tool_short_name,
            configuration=configuration,
            codec_name=codec_name,
            encoded_path=encoded_path,
            encoded_extension=encoded_extension,
            lightfield=lightfield,
            encoded_lightfield=encoded_lightfield,
            target_rates=target_rates,
            distortion_name=distortion_name,
        )
        RDCurve.__init__(
            self,
            rates=rates,
            rate_unit="bpp",
            distortions=distortions,
            distortion_name=distortion_name,
            distortion_unit=distortion_unit,
            codec_name=codec_name,
            title=title,
        )

    @staticmethod
    def read_reports(
        path,
        tool_short_name: str,
        configuration,
        codec_name,
        encoded_path,
        encoded_extension,
        lightfield,
        encoded_lightfield,
        target_rates,
        distortion_name,
    ):
        """Collect parallel ``(rates, distortions)`` lists for ``target_rates``.

        Args:
            path: directory holding the per-rate quality reports.
            tool_short_name: suffix identifying the quality tool's report files.
            configuration: parsed configuration (``sample-codecs-to-run`` and
                quality settings are read from it).
            codec_name: codec whose curve is being built.
            encoded_path: directory holding the encoded bitstreams.
            encoded_extension: file extension of the encoded bitstreams.
            lightfield: raw (reference) light field.
            encoded_lightfield: light field used to name the encoded files.
            target_rates: iterable of target bpp values.
            distortion_name: metric whose mean is used as distortion.

        Returns:
            ``(rates, distortions)``, two lists of equal length. A target rate
            whose data cannot be obtained is skipped with a printed warning.
        """
        # Prefer a pre-computed sample RD report when one is configured.
        sample_codec = configuration["sample-codecs-to-run"].get(codec_name, {})
        rd_reports_samples_path = sample_codec.get("samples_path", {}).get("rd_reports", "")
        if rd_reports_samples_path:
            sample_file = get_formatted_filename_for_rd_report(lightfield.name, codec_name)
            sample_file_path = Path(rd_reports_samples_path) / sample_file
            print(f"Using sample file {sample_file_path}")
            rates, distortions = process_sample_file(
                target_rates=target_rates,
                distortion_name=distortion_name,
                file_path=sample_file_path,
            )
            # (None, None) or empty lists mean the sample is unusable: fall
            # through to computing each rate from the reports.
            if rates and distortions:
                return rates, distortions

        ctc_enc = CodecWrapper(
            codec_path=None,
            results_path=encoded_path,
            encoded_extension=encoded_extension,
            repository=None,
        )

        rates = []
        distortions = []
        for bpp in target_rates:
            try:
                result = process_target_rate(
                    bpp=bpp,
                    path=path,
                    tool_short_name=tool_short_name,
                    configuration=configuration,
                    encoded_path=encoded_path,
                    encoded_extension=encoded_extension,
                    lightfield=lightfield,
                    encoded_lightfield=encoded_lightfield,
                    distortion_name=distortion_name,
                    ctc_enc=ctc_enc,
                )
                if result is None:
                    # process_target_rate returns None when the encoded file
                    # for this rate does not exist.
                    print(f"WARNING: rate {bpp} not found.")
                    continue
                rate, distortion = result["bpp"], result["mean"]
            except Exception as e:
                # Best effort: one broken rate must not abort the whole curve,
                # but surface why it was skipped.
                print(f"WARNING: rate {bpp} not found ({type(e).__name__}: {e}).")
                continue
            # Append both together so the two lists always stay aligned.
            rates.append(rate)
            distortions.append(distortion)
        return rates, distortions
class RDCurveMatplotlibView:
    """Matplotlib view of a single RD curve.

    Bundles the curve with its plotting colour and marker. ``db_title`` holds an
    optional BD-rate string that is appended to the legend label when the
    curve is drawn with interpolation.
    """

    def __init__(self, configuration, rd_curve, color, marker):
        self.__configuration = configuration
        self.__rd_curve = rd_curve
        self.__color = color
        self.__marker = marker
        # No BD-rate text until a caller assigns one.
        self.__db_title = None

    @property
    def marker(self):
        return self.__marker

    @property
    def color(self):
        return self.__color

    @property
    def rd_curve(self):
        return self.__rd_curve

    @property
    def db_title(self):
        return self.__db_title

    @db_title.setter
    def db_title(self, db_title):
        self.__db_title = db_title

    def show(self, interpolation=100, interpolation_method="akima"):
        """Draw the curve on the current matplotlib axes.

        Args:
            interpolation: number of points of the smooth curve drawn through
                the measured points. A falsy value disables the smooth curve and
                joins the markers with a solid line instead.
            interpolation_method: method name forwarded to
                ``bjontegaard.bd_rate``.
        """
        curve = self.rd_curve
        measured_rates = np.asarray(curve.rate)
        measured_dists = np.asarray(curve.distortion)

        label = curve.title
        if interpolation and self.db_title:
            label += f" (BDBR: {self.db_title}%)"

        if not interpolation:
            marker_line = "solid"
        else:
            # The third return value is an interpolator evaluated on distortion
            # values; its output is plotted as 10**x, i.e. it appears to yield
            # log10 rates.
            _, _, rate_interp = bd.bd_rate(
                measured_rates,
                measured_dists,
                measured_rates,
                measured_dists,
                method=interpolation_method,
                require_matching_points=False,
                interpolators=True,
            )
            dense_dists = np.linspace(
                measured_dists.min(),
                measured_dists.max(),
                num=interpolation,
                endpoint=True,
            )
            plt.plot(10 ** rate_interp(dense_dists), dense_dists, color=self.color)
            # The smooth curve already joins the points; draw markers only.
            marker_line = "None"

        plt.plot(
            measured_rates,
            measured_dists,
            marker=self.marker,
            label=label,
            color=self.color,
            linestyle=marker_line,
        )
class RDPlotMatplotlib:
    """Plot several RD curves on one matplotlib figure, with optional BD-rates."""

    def __init__(
        self,
        configuration: ConfigurationReader,
        rd_plot_config: Dict[str, Any],
        lightfield: LightField,
    ) -> None:
        self.__configuration = configuration
        # Per-plot options. They are read with .get() (e.g. "show_bd_rates")
        # and forwarded to the BD JSON report.
        self.rd_plot_config = rd_plot_config
        self.bd_wrapper = BDWrapper()
        self.lightfield = lightfield

    def plot(
        self,
        rd_curves,
        title,
        anchor=None,
        bpp_logscale=True,
        target_bpps=None,
        interpolation=100,
        interpolation_method="akima",
        filename=None,
        generate_bd_report=True,
    ):
        """Draw ``rd_curves`` (``RDCurveMatplotlibView`` objects) on a new figure.

        Args:
            rd_curves: non-empty list of curve views. The first curve supplies
                the axis units and the distortion name.
            title: figure title.
            anchor: optional reference view. Every other curve gets a BD-rate
                against it, and its legend entry is listed first.
            bpp_logscale: use a logarithmic x axis.
            target_bpps: if given, x-axis ticks are placed at (a thinned subset
                of) these values.
            interpolation: number of points for each smooth curve (falsy
                disables interpolation).
            interpolation_method: forwarded to each curve's ``show``.
            filename: if given, save the figure there and close it. Otherwise
                call ``plt.show()``.
            generate_bd_report: write the BD JSON report (requires ``anchor``).
        """
        all_metrics = get_all_metrics(configuration=self.__configuration)
        first_curve = rd_curves[0].rd_curve
        distortion_name = first_curve.distortion_name

        # Figure size comes from the first rd_plots entry listing this metric.
        rd_plots = self.__configuration["rd_plots"]
        selected_plot_cfg = next(
            (cfg for cfg in rd_plots if distortion_name in cfg.get("metrics", [])), None
        )
        figure_size = tuple((selected_plot_cfg or {}).get("figure_size", [5, 3]))
        figure = plt.figure(figsize=figure_size)
        plt.title(title)

        show_bd_rates = self.rd_plot_config.get("show_bd_rates", True)
        for rd_curve in rd_curves:
            if anchor and rd_curve != anchor:
                bd_rate = self.bd_wrapper.compute_bd_rate(anchor=anchor, test=rd_curve)
                if show_bd_rates:
                    rd_curve.db_title = f"{bd_rate:.2f}"
            rd_curve.show(
                interpolation=interpolation, interpolation_method=interpolation_method
            )

        # Write the BD JSON report (its return value is not used here).
        if generate_bd_report and anchor:
            self.bd_wrapper.create_json_report(
                anchor=anchor,
                rd_curves=rd_curves,
                lightfield=self.lightfield,
                rd_plot_config=self.rd_plot_config,
                configuration_reader=self.__configuration,
            )

        if bpp_logscale:
            plt.xscale("log")
        if target_bpps:
            self._apply_bpp_ticks(target_bpps, bpp_logscale)

        plt.xlabel(first_curve.rate_unit)
        dist_unit = first_curve.distortion_unit
        y_label_unit = f" ({dist_unit})" if dist_unit else ""
        distortion_label = f"${all_metrics[distortion_name]['label']}$"
        plt.ylabel(f"{distortion_label}{y_label_unit}")

        handles, labels = plt.gca().get_legend_handles_labels()
        if anchor:
            handles, labels = self._anchor_first(handles, labels, anchor.rd_curve.title)
        plt.legend(handles, labels)
        plt.grid(linestyle="--")

        if filename:
            plt.savefig(filename, bbox_inches="tight", pad_inches=0)
            # A new figure is created on every call. Close it rather than only
            # clearing it, or open figures accumulate (and leak memory).
            plt.close(figure)
        else:
            plt.show()

    @staticmethod
    def _apply_bpp_ticks(target_bpps, bpp_logscale):
        """Put the x-axis major ticks on a thinned subset of ``target_bpps``."""
        ax = plt.gca()
        if bpp_logscale:
            ticks = _filter_ticks_log(target_bpps, min_log_gap=0.08)
        else:
            ticks = _filter_ticks_linear(target_bpps, min_gap_ratio=0.05)
        # Thin long tick lists: keep both ends and every other interior tick.
        if len(ticks) > 6:
            ticks = [ticks[0]] + ticks[1:-1][::2] + [ticks[-1]]
        ax.xaxis.set_major_locator(matplotlib.ticker.FixedLocator(ticks))
        ax.xaxis.set_minor_locator(matplotlib.ticker.NullLocator())
        ax.xaxis.set_major_formatter(matplotlib.ticker.FuncFormatter(_bpp_formatter))
        for label in ax.get_xticklabels():
            label.set_ha("center")

    @staticmethod
    def _anchor_first(handles, labels, anchor_label):
        """Move the legend entry labelled ``anchor_label`` to the front, if present."""
        if anchor_label not in labels:
            # E.g. the anchor was not among the plotted curves; keep the order.
            return handles, labels
        idx = labels.index(anchor_label)
        handles.insert(0, handles.pop(idx))
        labels.insert(0, labels.pop(idx))
        return handles, labels
def process_sample_file(
    target_rates: List, distortion_name: str, file_path: Path
) -> Tuple[Optional[List], Optional[List]]:
    """Read rates and mean distortions for ``target_rates`` from a sample RD report.

    The file is parsed as JSON. Its ``"results"`` mapping is looked up with
    each target rate formatted as ``f"{bpp:g}"``. Each entry provides a
    ``"rate"`` and, under ``distortion_name``, a dict with a ``"mean"``.

    Args:
        target_rates: target bpp values to look up.
        distortion_name: metric key whose ``"mean"`` is collected.
        file_path: path of the sample JSON report.

    Returns:
        ``(rates, distortions)`` for the target rates found in the report.
        Missing target rates are skipped with a warning. Returns
        ``(None, None)`` when the file does not exist or when a present rate
        lacks ``distortion_name``, so the caller can fall back to computing
        the values.
    """
    try:
        with open(file_path, "r", encoding="utf-8") as file:
            data = json.load(file)
    except FileNotFoundError:
        print(f"WARNING: Sample file {file_path} not found. Falling back to process_target_rate.")
        return None, None

    results = data.get("results", {})
    rates = []
    distortions = []
    for bpp in target_rates:
        bpp_key = f"{bpp:g}"
        if bpp_key not in results:
            print(f"WARNING: Target rate {bpp_key} not found in results.")
            continue
        rate_data = results[bpp_key]
        distortion = rate_data.get(distortion_name, {}).get("mean")
        if distortion is None:
            print(f"WARNING: Distortion {distortion_name} not found for rate {bpp_key}. Falling back to process_target_rate.")
            return None, None
        rates.append(rate_data.get("rate"))
        distortions.append(distortion)
    return rates, distortions
def process_target_rate(
    bpp: float,
    path: str,
    tool_short_name: str,
    configuration: ConfigurationReader,
    encoded_path: str,
    encoded_extension: str,
    lightfield: LightField,
    encoded_lightfield: LightField,
    distortion_name: str,
    ctc_enc: CodecWrapper,
) -> Optional[RDResult]:
    """Compute the RD point for one target rate.

    The actual bpp is measured from the encoded file with
    ``ctc_enc.compute_bytes_and_bpp``. The distortion is derived from the
    ``"frames"`` list of the quality-tool JSON report for this rate.

    NOTE(review): an earlier comment mentioned a "pooling-first" flag. Nothing
    in this function reads such a flag, so presumably it is honoured inside
    ``CompoundMetrics.get_list_frames`` / ``resolve_bd_adjusted_metric``;
    confirm.

    Args:
        bpp: target bpp (used to build the file names).
        path: directory holding the quality reports.
        tool_short_name: suffix identifying the quality tool's report file.
        configuration: parsed configuration.
        encoded_path: directory holding the encoded bitstreams.
        encoded_extension: file extension of the encoded bitstreams.
        lightfield: raw light field (bpp reference, report naming).
        encoded_lightfield: light field used to name the encoded file.
        distortion_name: metric to evaluate.
        ctc_enc: codec wrapper used only to measure the actual bpp.

    Returns:
        ``None`` if the encoded file does not exist. For BD-adjusted metrics,
        the result of ``resolve_bd_adjusted_metric``. Otherwise an
        ``RDResult`` with min/max/mean/stddev over the per-frame values.

    Raises:
        FileNotFoundError: if the quality report for this rate is missing.
        KeyError: if the report has no ``"frames"`` entry.
    """
    encoded_filename = get_formatted_filename_for_lf(
        path=encoded_path,
        lightfield=encoded_lightfield,
        bpp=bpp,
        file_extension=encoded_extension,
    )
    if not os.path.isfile(encoded_filename):
        print(f"File not found: {encoded_filename}")
        return None

    _, actual_bpp = ctc_enc.compute_bytes_and_bpp(
        encoded_filename=encoded_filename, raw_lightfield=lightfield
    )

    output_vmaf_report = get_formatted_filename_for_lf(
        path=path,
        lightfield=lightfield,
        bpp=bpp,
        file_extension="json",
        extra=tool_short_name,
    )
    with open(output_vmaf_report, "r", encoding="utf-8") as f:
        data = json.load(f)
    frames = data["frames"]

    if CompoundMetrics.is_bd_adjusted_metric(
        metric=distortion_name, configuration=configuration
    ):
        # The per-frame pooling for BD-adjusted metrics is done by the helper.
        return resolve_bd_adjusted_metric(
            bpp=actual_bpp,
            frame_metrics=frames,
            distortion_name=distortion_name,
            configuration=configuration,
        )

    list_frames = CompoundMetrics.get_list_frames(
        all_frame_metrics=frames,
        distortion_name=distortion_name,
        configuration=configuration,
    )
    return RDResult(
        bpp=actual_bpp,
        min=np.min(list_frames),
        max=np.max(list_frames),
        mean=np.mean(list_frames),
        stddev=np.std(list_frames),
    )
def _filter_ticks_log(values: list, min_log_gap: float = 0.08) -> list: if not values: return values sorted_vals = sorted(values) selected = [sorted_vals[0]] for v in sorted_vals[1:-1]: if np.log10(v) - np.log10(selected[-1]) >= min_log_gap: selected.append(v) last = sorted_vals[-1] if np.log10(last) - np.log10(selected[-1]) >= min_log_gap: selected.append(last) elif selected[-1] != last: selected[-1] = last return selected def _filter_ticks_linear(values: list, min_gap_ratio: float = 0.05) -> list: if not values: return values sorted_vals = sorted(values) span = sorted_vals[-1] - sorted_vals[0] min_gap = span * min_gap_ratio selected = [sorted_vals[0]] for v in sorted_vals[1:-1]: if v - selected[-1] >= min_gap: selected.append(v) last = sorted_vals[-1] if last - selected[-1] >= min_gap: selected.append(last) else: selected[-1] = last return selected def _bpp_formatter(x, pos): s = f"{x:.4f}".rstrip("0").rstrip(".") return s if "." in s else s