import json
from collections import OrderedDict
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional
import bjontegaard as bd
import numpy as np
import prettytable
from lfc_toolkit.src.configuration.configuration_reader import \
ConfigurationReader
from lfc_toolkit.src.data_handlers.lightfield import LightField
from lfc_toolkit.src.quality.rd_curve import RDCurve
@dataclass(frozen=True)
class BDWrapper:
    """Compute Bjontegaard-delta metrics and turn them into reports.

    The numerical work is delegated to the ``bjontegaard`` package
    (``bd.bd_rate`` / ``bd.bd_psnr``); this class extracts the rate and
    distortion arrays from RD-curve views, stores the results in per-anchor
    JSON reports and renders those reports as tables.

    NOTE(review): ``interpolator`` and ``min_overlap`` are declared but are
    not forwarded to ``bjontegaard`` by any method of this class -- confirm
    the intended mapping before wiring them in.
    """

    # Interpolation method name, forwarded as ``method=`` to ``bjontegaard``.
    method: str = "akima"
    # Currently not read by any method of this class.
    interpolator: bool = False
    # Forwarded as ``require_matching_points=`` to ``bjontegaard``.
    require_matching_points: bool = False
    min_overlap: float = 0.0  # The overlap should be added as a result in the future.

    def _evaluate(self, bd_function: Any, anchor: Any, test: Any) -> Any:
        """Calls a ``bjontegaard`` metric function on two RD curve views.

        :param bd_function: Callable accepting the keyword arguments
            ``rate_anchor``, ``dist_anchor``, ``rate_test``, ``dist_test``,
            ``method`` and ``require_matching_points``
        :type bd_function: Any
        :param anchor: Anchor RD curve view (exposes ``rd_curve.rate`` and
            ``rd_curve.distortion``)
        :type anchor: Any
        :param test: Test RD curve view (same attributes as ``anchor``)
        :type test: Any
        :return: Whatever ``bd_function`` returns
        :rtype: Any
        """
        return bd_function(
            rate_anchor=np.asarray(anchor.rd_curve.rate),
            dist_anchor=np.asarray(anchor.rd_curve.distortion),
            rate_test=np.asarray(test.rd_curve.rate),
            dist_test=np.asarray(test.rd_curve.distortion),
            method=self.method,
            # Honour the configured field instead of a hard-coded False.
            require_matching_points=self.require_matching_points,
        )

    @staticmethod
    def _load_report(json_filename: Path) -> List[Dict[str, Any]]:
        """Loads an existing JSON report, or returns an empty report.

        A missing file, an empty file, undecodable JSON, or JSON whose top
        level is not a list all yield ``[]``.

        :param json_filename: Path of the report file
        :type json_filename: Path
        :return: List of report entries
        :rtype: List[Dict[str, Any]]
        """
        if not json_filename.exists() or json_filename.stat().st_size == 0:
            return []
        try:
            with open(json_filename, "r", encoding="utf-8") as json_file:
                report = json.load(json_file)
        except json.JSONDecodeError:
            return []
        return report if isinstance(report, list) else []

    def compute_bd_rate(self, anchor: Any, test: Any) -> float:
        """Computes the BD-Rate between two RD curves.

        :param anchor: Anchor RD curve view
        :type anchor: Any
        :param test: Test RD curve view
        :type test: Any
        :return: BD-Rate value as returned by ``bd.bd_rate``
        :rtype: float
        """
        return self._evaluate(bd.bd_rate, anchor, test)

    def compute_bd_distortion(self, anchor: Any, test: Any) -> float:
        """Computes the BD-Distortion between two RD curves.

        :param anchor: Anchor RD curve view
        :type anchor: Any
        :param test: Test RD curve view
        :type test: Any
        :return: BD-Distortion value as returned by ``bd.bd_psnr``
        :rtype: float
        """
        return self._evaluate(bd.bd_psnr, anchor, test)

    def create_json_report(
        self,
        anchor: Any,
        rd_curves: List[Any],
        lightfield: "LightField",
        rd_plot_config: Dict[str, Any],
        configuration_reader: "ConfigurationReader"
    ) -> Path:
        """Creates JSON report with BD rates, maintaining lightfield order from configuration.

        The report ``bd_report_for_anchor_<anchor codec>.json`` is read (if it
        exists), updated with the results for ``lightfield`` and written back.
        Before writing, every entry's results are filtered and re-ordered:
        only metrics listed in ``rd_plot_config["metrics"]`` and light fields
        listed in ``configuration_reader["lightfields"]["ctc"]`` are kept.

        :param anchor: Anchor RD curve view
        :type anchor: Any
        :param rd_curves: List of RD curve views to compare
        :type rd_curves: List[Any]
        :param lightfield: Light field data
        :type lightfield: LightField
        :param rd_plot_config: RD plot configuration
        :type rd_plot_config: Dict[str, Any]
        :param configuration_reader: Configuration reader instance
        :type configuration_reader: ConfigurationReader
        :return: Path to the created JSON report file
        :rtype: Path
        """
        lightfield_name = lightfield.name
        lightfields = configuration_reader["lightfields"]["ctc"]
        metrics = rd_plot_config["metrics"]
        anchor_encoder = anchor.rd_curve.codec_name
        anchor_label = anchor.rd_curve.title

        output_path = Path(configuration_reader["bd_reports"]["bd_results_path"])
        output_path.mkdir(parents=True, exist_ok=True)
        json_filename = output_path / f"bd_report_for_anchor_{anchor_encoder}.json"

        report = self._load_report(json_filename)

        for rd_curve in (rc for rc in rd_curves if rc != anchor):
            metric = rd_curve.rd_curve.distortion_name
            test_encoder = rd_curve.rd_curve.codec_name
            test_label = rd_curve.rd_curve.title

            # One entry per (anchor encoder, test encoder) pair; reuse it if
            # the report already has one, otherwise create and register it.
            existing_entry = next(
                (e for e in report
                 if e["anchor-encoder"] == anchor_encoder
                 and e["test-encoder"] == test_encoder),
                None
            )
            if existing_entry is None:
                existing_entry = {
                    "anchor-encoder": anchor_encoder,
                    "anchor-label": anchor_label,
                    "test-encoder": test_encoder,
                    "test-label": test_label,
                    "interpolation-method": self.method,
                    "results": OrderedDict()
                }
                report.append(existing_entry)

            bd_rate = self.compute_bd_rate(anchor=anchor, test=rd_curve)
            bd_distortion = self.compute_bd_distortion(anchor=anchor, test=rd_curve)
            if bd_rate is None or bd_distortion is None:
                continue

            existing_entry["results"].setdefault(metric, OrderedDict())
            existing_entry["results"][metric][lightfield_name] = {
                "bd-rate": bd_rate,
                "overlap-bd-rate": "WIP",
                "bd-distortion": bd_distortion,
                "overlap-bd-distortion": "WIP"
            }

        # Re-order (and filter) every entry so metrics follow the plot
        # configuration and light fields follow the configured order.
        for entry in report:
            entry["results"] = {
                metric: OrderedDict(
                    (lf, entry["results"][metric][lf])
                    for lf in lightfields
                    if lf in entry["results"].get(metric, {})
                )
                for metric in metrics
                if metric in entry["results"]
            }

        with open(json_filename, "w", encoding="utf-8") as json_file:
            json.dump(report, json_file, indent=2)
        return json_filename

    @staticmethod
    def create_bd_rate_tables(configuration_reader: "ConfigurationReader") -> None:
        """Generates BD-rate tables from JSON report files.

        Every ``*.json`` file in the configured results directory is read
        (undecodable files are skipped). For each report entry and metric a
        table is built, optionally printed, and written once per configured
        format to ``<results>/tables/<format>/``. Supported formats are
        ``latex``, ``html``, ``mediawiki`` (rendered as HTML), ``json``,
        ``csv`` and ``text``; any other format triggers a warning and is
        skipped.

        :param configuration_reader: Configuration reader instance
        :type configuration_reader: ConfigurationReader
        :return: None
        :rtype: None
        """
        import warnings  # local import keeps the module-level imports untouched

        # Get paths and configurations
        results_path = Path(configuration_reader["bd_reports"]["bd_results_path"])
        formats = configuration_reader["bd_reports"]["table_formats"]
        print_text = configuration_reader["bd_reports"].get("print_text", True)

        # Process all JSON files in the results directory
        for json_file in results_path.glob("*.json"):
            try:
                with open(json_file, "r", encoding="utf-8") as report_file:
                    data = json.load(report_file)
            except json.JSONDecodeError:
                continue

            for entry in data:
                anchor = entry["anchor-encoder"]
                test = entry["test-encoder"]

                for metric, metric_results in entry["results"].items():
                    # Create and configure table
                    table = prettytable.PrettyTable()
                    table.set_style(prettytable.MSWORD_FRIENDLY)
                    title = f"{metric}-based BD-Rates (w.r.t {anchor})"
                    table.title = title
                    table.field_names = ["Category", "Lightfield", test]

                    # Populate table rows (only light fields known to the configuration)
                    for lf, lf_results in metric_results.items():
                        if lf not in configuration_reader.lightfield_names:
                            continue
                        category = configuration_reader.get_lightfield_configuration(
                            name=lf
                        ).get("category", "None")
                        bd_rate = lf_results["bd-rate"]
                        table.add_row([category, lf, f"{bd_rate:.2f}"])

                    # Print table if enabled
                    if print_text:
                        print('\n' + table.get_string())

                    # Generate output files for each format
                    for fmt in formats:
                        # Generate content based on format
                        if fmt == "latex":
                            # The caption must live outside the tabular
                            # environment, so wrap the tabular in a table float.
                            content = (
                                "\\begin{table}\n"
                                f"\\caption{{{title}}}\n"
                                f"{table.get_latex_string()}\n"
                                "\\end{table}"
                            )
                        elif fmt == "html" or fmt == "mediawiki":
                            content = table.get_html_string()
                        elif fmt == "json":
                            content = table.get_json_string()
                        elif fmt == "csv":
                            content = table.get_csv_string()
                        elif fmt == "text":
                            content = table.get_string()
                        else:
                            # Never write stale/undefined content under an
                            # unknown extension.
                            warnings.warn(
                                f"Unsupported BD table format {fmt!r}; skipping.",
                                stacklevel=2,
                            )
                            continue

                        output_dir = results_path / "tables" / fmt
                        output_dir.mkdir(parents=True, exist_ok=True)
                        # NOTE(review): the filename does not include the test
                        # encoder, so several test encoders compared against
                        # the same anchor overwrite each other's table file.
                        filename = output_dir / f"{metric}_BD_rates_wrt_{anchor}.{fmt}"

                        # Write to file
                        with open(filename, "w", encoding="utf-8") as table_file:
                            table_file.write(content)