import json
import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import bjontegaard as bd
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from lfc_toolkit.src.codecwrappers.codec_wrapper import CodecWrapper
from lfc_toolkit.src.configuration.configuration_reader import \
ConfigurationReader
from lfc_toolkit.src.data_handlers.formatters import (
get_formatted_filename_for_lf, get_formatted_filename_for_rd_report)
from lfc_toolkit.src.data_handlers.lightfield import LightField
from lfc_toolkit.src.quality.bd_wrapper import BDWrapper
from lfc_toolkit.src.quality.compound_metrics import (
CompoundMetrics, resolve_bd_adjusted_metric)
from lfc_toolkit.src.quality.rd_curve import RDCurve
from lfc_toolkit.src.quality.rd_result import RDResult
def get_all_metrics(configuration: "ConfigurationReader") -> Dict:
    """Merge every dict-valued section of ``configuration["quality"]`` into one mapping.

    The ``"wrappers"`` section and any non-dict entries are skipped. Sections
    are merged in iteration order, so a metric name that appears in several
    sections ends up with the value from the last one.

    Args:
        configuration: Configuration object (or mapping) with a ``"quality"`` section.

    Returns:
        A new dict mapping metric names to their configuration entries.
    """
    merged: Dict = {}
    for section_name, section in configuration["quality"].items():
        if section_name == "wrappers" or not isinstance(section, dict):
            continue
        merged.update(section)
    return merged
class RDCurveFromVMAFLikeReports(RDCurve):
    """RD curve whose points are read from VMAF-style per-rate JSON reports.

    For each target rate, the rate is the bpp that ``process_target_rate``
    measures on the encoded file, and the distortion is the ``"mean"`` of the
    result it returns. When the codec's configuration names a directory of
    pre-computed RD reports, that sample report is tried first.
    """

    def __init__(
        self,
        path,
        tool_short_name: str,
        configuration,
        encoded_path,
        encoded_extension,
        lightfield,
        target_rates,
        distortion_name,
        distortion_unit,
        title,
        codec_name,
        encoded_lightfield=None,
    ):
        """Read all rate points and initialise the underlying ``RDCurve`` (rate unit "bpp").

        Args:
            path: Directory holding the per-rate quality reports.
            tool_short_name: Tag used in the quality-report filenames.
            configuration: Configuration mapping (``sample-codecs-to-run`` and
                ``quality`` sections are read).
            encoded_path: Directory holding the encoded files.
            encoded_extension: File extension of the encoded files.
            lightfield: Raw light field; names the reports and is the bpp reference.
            target_rates: Target bpp values to look up.
            distortion_name: Metric read from the reports.
            distortion_unit: Unit of the distortion axis.
            title: Curve title.
            codec_name: Codec key, also used to locate sample reports.
            encoded_lightfield: Light field used to name the encoded files;
                defaults to ``lightfield``.
        """
        if encoded_lightfield is None:
            encoded_lightfield = lightfield
        rates, distortions = RDCurveFromVMAFLikeReports.read_reports(
            path=path,
            tool_short_name=tool_short_name,
            configuration=configuration,
            codec_name=codec_name,
            encoded_path=encoded_path,
            encoded_extension=encoded_extension,
            lightfield=lightfield,
            encoded_lightfield=encoded_lightfield,
            target_rates=target_rates,
            distortion_name=distortion_name,
        )
        RDCurve.__init__(
            self,
            rates=rates,
            rate_unit="bpp",
            distortions=distortions,
            distortion_name=distortion_name,
            distortion_unit=distortion_unit,
            codec_name=codec_name,
            title=title,
        )

    @staticmethod
    def read_reports(
        path,
        tool_short_name: str,
        configuration,
        codec_name,
        encoded_path,
        encoded_extension,
        lightfield,
        encoded_lightfield,
        target_rates,
        distortion_name,
    ):
        """Collect ``(rates, distortions)`` for ``target_rates``.

        If ``configuration["sample-codecs-to-run"][codec_name]["samples_path"]["rd_reports"]``
        is set, the pre-computed report for this light field and codec is read
        first and returned when it yields non-empty lists. Otherwise each target
        rate is processed with ``process_target_rate``. Rates whose files are
        missing, or whose processing fails, are reported with a warning and
        skipped. The returned lists may therefore be shorter than
        ``target_rates``, but they always have the same length.

        Returns:
            Tuple ``(rates, distortions)`` of equal-length lists.
        """
        # Load data from sample files when configured.
        sample_codec = configuration["sample-codecs-to-run"].get(codec_name, {})
        rd_reports_samples_path = sample_codec.get("samples_path", {}).get("rd_reports", "")
        if rd_reports_samples_path:
            sample_file = get_formatted_filename_for_rd_report(lightfield.name, codec_name)
            sample_file_path = Path(rd_reports_samples_path) / sample_file
            print(f"Using sample file {sample_file_path}")
            rates, distortions = process_sample_file(
                target_rates=target_rates,
                distortion_name=distortion_name,
                file_path=sample_file_path,
            )
            # (None, None) or empty lists mean "recompute from the reports".
            if rates and distortions:
                return rates, distortions

        ctc_enc = CodecWrapper(
            codec_path=None,
            results_path=encoded_path,
            encoded_extension=encoded_extension,
            repository=None,
        )
        rates = []
        distortions = []
        for bpp in target_rates:
            try:
                result = process_target_rate(
                    bpp=bpp,
                    path=path,
                    tool_short_name=tool_short_name,
                    configuration=configuration,
                    encoded_path=encoded_path,
                    encoded_extension=encoded_extension,
                    lightfield=lightfield,
                    encoded_lightfield=encoded_lightfield,
                    distortion_name=distortion_name,
                    ctc_enc=ctc_enc,
                )
                if result is None:
                    # process_target_rate returns None when the encoded file is missing.
                    print(f"WARNING: rate {bpp} not found.")
                    continue
                # Read both values before appending anything, so a malformed
                # result cannot leave the two lists with different lengths.
                rate, distortion = result["bpp"], result["mean"]
            except Exception as exc:  # best effort: one bad rate must not sink the curve
                print(f"WARNING: rate {bpp} could not be processed: {exc!r}")
                continue
            rates.append(rate)
            distortions.append(distortion)
        return rates, distortions
class RDCurveMatplotlibView:
    """Matplotlib presentation of one RD curve: colour, marker and BD-rate legend suffix."""

    def __init__(self, configuration, rd_curve, color, marker):
        """Store the curve together with its drawing attributes.

        Args:
            configuration: Kept on the instance; not read by this class's methods.
            rd_curve: Curve object exposing ``rate``, ``distortion`` and ``title``.
            color: Matplotlib colour for the markers and the interpolated line.
            marker: Matplotlib marker style for the measured points.
        """
        self.__configuration = configuration
        self.__rd_curve = rd_curve
        self.__color = color
        self.__marker = marker
        # Formatted BD-rate string appended to the legend label; assigned
        # externally through the ``db_title`` setter.
        self.__db_title = None

    @property
    def marker(self):
        """Matplotlib marker style of the measured points."""
        return self.__marker

    @property
    def color(self):
        """Matplotlib colour used for every artist of this curve."""
        return self.__color

    @property
    def rd_curve(self):
        """The wrapped RD curve."""
        return self.__rd_curve

    @property
    def db_title(self):
        """BD-rate text shown as ``(BDBR: <value>%)`` in the legend, or ``None``."""
        return self.__db_title

    @db_title.setter
    def db_title(self, db_title):
        self.__db_title = db_title

    def show(self, interpolation=100, interpolation_method="akima"):
        """Draw the curve on the current matplotlib axes.

        The measured points are always drawn with this view's marker and colour,
        labelled with the curve title. When ``interpolation`` is truthy, a smooth
        line is drawn as well. It comes from the test interpolator that
        ``bd.bd_rate`` returns for the curve compared with itself, evaluated at
        ``interpolation`` evenly spaced distortion values. The markers are then
        drawn without a connecting line. The BD-rate suffix is only added to the
        label when interpolation is enabled.

        Args:
            interpolation: Number of interpolation samples; a falsy value disables it.
            interpolation_method: Method name forwarded to ``bd.bd_rate``.
        """
        rates = np.asarray(self.rd_curve.rate)
        distortions = np.asarray(self.rd_curve.distortion)
        label = self.rd_curve.title
        if self.db_title and interpolation:
            label += f" (BDBR: {self.db_title}%)"

        if not interpolation:
            connector = "solid"
        else:
            _, _, rate_of_distortion = bd.bd_rate(
                rates,
                distortions,
                rates,
                distortions,
                method=interpolation_method,
                require_matching_points=False,
                interpolators=True,
            )
            sample_dists = np.linspace(
                distortions.min(), distortions.max(), num=interpolation, endpoint=True
            )
            sampled_log_rates = rate_of_distortion(sample_dists)
            # The interpolator output is treated as log10(rate); undo it for plotting.
            plt.plot(10**sampled_log_rates, sample_dists, color=self.color)
            connector = "None"

        plt.plot(
            rates,
            distortions,
            marker=self.marker,
            label=label,
            color=self.color,
            linestyle=connector,
        )
class RDPlotMatplotlib:
    """Draw several RD curve views on one matplotlib figure, optionally against an anchor."""

    def __init__(
        self,
        configuration: ConfigurationReader,
        rd_plot_config: Dict[str, Any],
        lightfield: LightField,
    ) -> None:
        """Store the plot inputs.

        Args:
            configuration: Global configuration; ``plot`` reads its ``rd_plots``
                and ``quality`` sections.
            rd_plot_config: Per-plot options. It is read as a mapping (e.g.
                ``show_bd_rates``) and forwarded to the BD JSON report.
            lightfield: Light field the curves belong to; forwarded to the BD report.
        """
        self.__configuration = configuration
        self.rd_plot_config = rd_plot_config
        self.bd_wrapper = BDWrapper()
        self.lightfield = lightfield

    def plot(
        self,
        rd_curves,
        title,
        anchor=None,
        bpp_logscale=True,
        target_bpps=None,
        interpolation=100,
        interpolation_method="akima",
        filename=None,
        generate_bd_report=True,
    ):
        """Draw ``rd_curves``, then save the figure to ``filename`` or show it.

        Args:
            rd_curves: Curve views (``RDCurveMatplotlibView``-like). The first
                one supplies the metric name and the axis units.
            title: Figure title.
            anchor: Optional reference view. Each other curve's BD-rate against
                it is computed and, if ``show_bd_rates`` is enabled, shown in its
                legend label. The anchor's legend entry is listed first.
            bpp_logscale: Use a log-scaled x axis.
            target_bpps: If given, x ticks are placed at a thinned subset of these values.
            interpolation: Forwarded to each view's ``show``.
            interpolation_method: Forwarded to each view's ``show``.
            filename: Save to this path (and close the figure) instead of showing it.
            generate_bd_report: Write the BD JSON report (only when ``anchor`` is given).
        """
        all_metrics = get_all_metrics(configuration=self.__configuration)
        distortion_name = rd_curves[0].rd_curve.distortion_name
        # Use the figure size of the first rd_plots entry that lists this metric.
        rd_plots = self.__configuration["rd_plots"]
        selected_plot_cfg = next(
            (cfg for cfg in rd_plots if distortion_name in cfg.get("metrics", [])), None
        )
        figure_size = tuple((selected_plot_cfg or {}).get("figure_size", [5, 3]))
        plt.figure(figsize=figure_size)
        plt.title(title)

        show_bd_rates = self.rd_plot_config.get("show_bd_rates", True)
        for rd_curve in rd_curves:
            if anchor and rd_curve != anchor:
                bd_rate = self.bd_wrapper.compute_bd_rate(anchor=anchor, test=rd_curve)
                if show_bd_rates:
                    rd_curve.db_title = f"{bd_rate:.2f}"
            rd_curve.show(
                interpolation=interpolation, interpolation_method=interpolation_method
            )

        if generate_bd_report and anchor:
            # Write the BD JSON report; its content is defined by BDWrapper.
            self.bd_wrapper.create_json_report(
                anchor=anchor,
                rd_curves=rd_curves,
                lightfield=self.lightfield,
                rd_plot_config=self.rd_plot_config,
                configuration_reader=self.__configuration,
            )

        if bpp_logscale:
            plt.xscale("log")
        if target_bpps:
            self._set_bpp_ticks(target_bpps, bpp_logscale)

        plt.xlabel(rd_curves[0].rd_curve.rate_unit)
        dist_unit = rd_curves[0].rd_curve.distortion_unit
        y_label_unit = f" ({dist_unit})" if dist_unit else ""
        # Metric labels are wrapped in "$...$" so matplotlib renders them as mathtext.
        distortion_label = f"${all_metrics[distortion_name]['label']}$"
        plt.ylabel(f"{distortion_label}{y_label_unit}")

        self._draw_legend(anchor)
        plt.grid(linestyle="--")
        if filename:
            plt.savefig(filename, bbox_inches="tight", pad_inches=0)
            # Close (not just clear) the figure created above, so repeated
            # calls do not accumulate open figures.
            plt.close()
        else:
            plt.show()

    @staticmethod
    def _set_bpp_ticks(target_bpps, bpp_logscale):
        """Put major x ticks at a thinned subset of ``target_bpps`` and drop minor ticks."""
        ax = plt.gca()
        if bpp_logscale:
            filtered_ticks = _filter_ticks_log(target_bpps, min_log_gap=0.08)
        else:
            filtered_ticks = _filter_ticks_linear(target_bpps, min_gap_ratio=0.05)
        if len(filtered_ticks) > 6:
            # Keep both end points and every other interior tick.
            filtered_ticks = (
                [filtered_ticks[0]] + filtered_ticks[1:-1][::2] + [filtered_ticks[-1]]
            )
        ax.xaxis.set_major_locator(matplotlib.ticker.FixedLocator(filtered_ticks))
        ax.xaxis.set_minor_locator(matplotlib.ticker.NullLocator())
        ax.xaxis.set_major_formatter(matplotlib.ticker.FuncFormatter(_bpp_formatter))
        for label in ax.get_xticklabels():
            label.set_ha("center")

    @staticmethod
    def _draw_legend(anchor):
        """Draw the legend, listing the anchor's entry first when it can be found."""
        handles, labels = plt.gca().get_legend_handles_labels()
        if anchor:
            anchor_title = anchor.rd_curve.title
            # The anchor is normally labelled with its bare title. Skip the
            # reordering, rather than raising, if that label is absent (anchor
            # not plotted, or a stale BD suffix on its label).
            if anchor_title in labels:
                anchor_idx = labels.index(anchor_title)
                handles.insert(0, handles.pop(anchor_idx))
                labels.insert(0, labels.pop(anchor_idx))
        plt.legend(handles, labels)
def process_sample_file(
    target_rates: List, distortion_name: str, file_path: Path
) -> Tuple[Optional[List], Optional[List]]:
    """Read pre-computed rate/distortion pairs from a sample RD report.

    The JSON file is expected to look like
    ``{"results": {"<bpp formatted with :g>": {"rate": ..., "<distortion_name>": {"mean": ...}}}}``
    (shape inferred from the lookups below).

    Args:
        target_rates: Target bpp values; each one is looked up as ``f"{bpp:g}"``.
        distortion_name: Metric whose ``"mean"`` is read for every rate.
        file_path: Path of the sample JSON report.

    Returns:
        ``(rates, distortions)`` in ``target_rates`` order. Target rates absent
        from the file are skipped with a warning. Returns ``(None, None)`` if a
        present rate lacks ``distortion_name``, telling the caller to recompute.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
    """
    with open(file_path, "r", encoding="utf-8") as file:
        data = json.load(file)
    results = data.get("results", {})
    rates = []
    distortions = []
    for bpp in target_rates:
        # e.g. 0.75 -> "0.75", 1.0 -> "1"
        bpp_key = f"{bpp:g}"
        if bpp_key not in results:
            print(f"WARNING: Target rate {bpp_key} not found in results.")
            continue
        rate_data = results[bpp_key]
        distortion = rate_data.get(distortion_name, {}).get("mean")
        if distortion is None:
            print(f"WARNING: Distortion {distortion_name} not found for rate {bpp_key}. Falling back to process_target_rate.")
            return None, None
        # Append as a pair only once both values are known.
        rates.append(rate_data.get("rate"))
        distortions.append(distortion)
    return rates, distortions
def process_target_rate(
    bpp: float,
    path: str,
    tool_short_name: str,
    configuration: ConfigurationReader,
    encoded_path: str,
    encoded_extension: str,
    lightfield: LightField,
    encoded_lightfield: LightField,
    distortion_name: str,
    ctc_enc: CodecWrapper,
) -> Optional[RDResult]:
    """Compute the quality statistics of one target rate.

    The function proceeds in four steps:

    1. Locate the encoded file for ``bpp`` (named from ``encoded_path``,
       ``encoded_lightfield`` and ``encoded_extension``).
    2. Measure its bpp with ``ctc_enc.compute_bytes_and_bpp`` against the raw
       ``lightfield``.
    3. Load the JSON quality report for that rate (named from ``path``,
       ``lightfield`` and ``tool_short_name``).
    4. Pool the per-frame values of ``distortion_name``.

    For BD-adjusted metrics (``CompoundMetrics.is_bd_adjusted_metric``), the
    result is produced by ``resolve_bd_adjusted_metric``. Otherwise, the frame
    values from ``CompoundMetrics.get_list_frames`` are summarised as
    min/max/mean/stddev.

    NOTE(review): an earlier header mentioned a "pooling-first flag". No such
    flag is read here, so presumably it is handled inside those CompoundMetrics
    helpers. Confirm.

    Returns:
        The pooled result, or ``None`` if the encoded file does not exist.

    Raises:
        FileNotFoundError: If the encoded file exists but its quality report does not.
    """
    encoded_filename = get_formatted_filename_for_lf(
        path=encoded_path,
        lightfield=encoded_lightfield,
        bpp=bpp,
        file_extension=encoded_extension,
    )
    if not os.path.isfile(encoded_filename):
        print(f"File not found: {encoded_filename}")
        return None
    # The reported rate is the one measured on the file, not the target ``bpp``.
    _, actual_bpp = ctc_enc.compute_bytes_and_bpp(
        encoded_filename=encoded_filename, raw_lightfield=lightfield
    )
    output_vmaf_report = get_formatted_filename_for_lf(
        path=path,
        lightfield=lightfield,
        bpp=bpp,
        file_extension="json",
        extra=tool_short_name,
    )
    with open(output_vmaf_report, "r", encoding="utf-8") as f:
        data = json.load(f)
    frames = data["frames"]

    if CompoundMetrics.is_bd_adjusted_metric(
        metric=distortion_name, configuration=configuration
    ):
        return resolve_bd_adjusted_metric(
            bpp=actual_bpp,
            frame_metrics=frames,
            distortion_name=distortion_name,
            configuration=configuration,
        )

    list_frames = CompoundMetrics.get_list_frames(
        all_frame_metrics=frames,
        distortion_name=distortion_name,
        configuration=configuration,
    )
    return RDResult(
        bpp=actual_bpp,
        min=np.min(list_frames),
        max=np.max(list_frames),
        mean=np.mean(list_frames),
        stddev=np.std(list_frames),
    )
def _filter_ticks_log(values: list, min_log_gap: float = 0.08) -> list:
if not values:
return values
sorted_vals = sorted(values)
selected = [sorted_vals[0]]
for v in sorted_vals[1:-1]:
if np.log10(v) - np.log10(selected[-1]) >= min_log_gap:
selected.append(v)
last = sorted_vals[-1]
if np.log10(last) - np.log10(selected[-1]) >= min_log_gap:
selected.append(last)
elif selected[-1] != last:
selected[-1] = last
return selected
def _filter_ticks_linear(values: list, min_gap_ratio: float = 0.05) -> list:
if not values:
return values
sorted_vals = sorted(values)
span = sorted_vals[-1] - sorted_vals[0]
min_gap = span * min_gap_ratio
selected = [sorted_vals[0]]
for v in sorted_vals[1:-1]:
if v - selected[-1] >= min_gap:
selected.append(v)
last = sorted_vals[-1]
if last - selected[-1] >= min_gap:
selected.append(last)
else:
selected[-1] = last
return selected
def _bpp_formatter(x, pos):
s = f"{x:.4f}".rstrip("0").rstrip(".")
return s if "." in s else s