import json
import os
from pathlib import Path
from typing import Any, Callable, List, Optional, Union
import matplotlib.pyplot as plt
import numpy as np
from lfc_toolkit.src.data_handlers.encoding_orders import get_serpentine_scan_list
from lfc_toolkit.src.data_handlers.formatters import get_formatted_filename_for_lf
from lfc_toolkit.src.quality.compound_metrics import CompoundMetrics
from lfc_toolkit.src.quality.rd_plot import get_all_metrics
[docs]
class HeatmapFromVMAFReports:
    """Builds heatmap data from VMAF reports.

    The report is read once, at construction time, and the resulting grid of
    metric values is kept in ``self._metric_values`` (see :meth:`read_reports`
    for its layout and for the value used when the report is missing).
    """

    def __init__(
        self,
        path: Union[str, Path],
        lightfield: Any,
        tool_name: str,
        target_rate: float,
        distortion_name: str,
        distortion_unit: str,
        title: str,
        scan_order: Callable = get_serpentine_scan_list,
        configuration: Optional[Any] = None,
    ) -> None:
        """Initializes HeatmapFromVMAFReports by reading reports from path.

        :param path: Path to report directory
        :type path: Union[str, Path]
        :param lightfield: Lightfield data; must expose ``n_views_width`` and
            ``n_views_height``
        :type lightfield: Any
        :param tool_name: Name of the quality tool (used to build the report
            file name)
        :type tool_name: str
        :param target_rate: Target bits per pixel
        :type target_rate: float
        :param distortion_name: Name of the distortion metric
        :type distortion_name: str
        :param distortion_unit: Unit for the distortion metric
        :type distortion_unit: str
        :param title: Plot title
        :type title: str
        :param scan_order: Scan order function, defaults to get_serpentine_scan_list
        :type scan_order: Callable
        :param configuration: Optional configuration, defaults to None
        :type configuration: Optional[Any]
        :return: None
        :rtype: None
        """
        # Stored as attributes only; nothing in this class reads them back.
        self._distortion_unit = distortion_unit
        self._title = title
        self._metric_values = HeatmapFromVMAFReports.read_reports(
            path,
            lightfield,
            tool_name=tool_name,
            target_rate=target_rate,
            distortion_name=distortion_name,
            scan_order=scan_order,
            configuration=configuration,
        )

    @staticmethod
    def read_reports(
        path: Union[str, Path],
        lightfield: Any,
        tool_name: str,
        target_rate: float,
        distortion_name: str,
        scan_order: Callable = get_serpentine_scan_list,
        configuration: Optional[Any] = None,
    ) -> Union[np.ndarray, List[Any]]:
        """Reads a VMAF JSON report and arranges its per-frame metrics as a 2D grid.

        The report file name is produced by ``get_formatted_filename_for_lf``
        from ``path``, ``lightfield``, ``target_rate`` and ``tool_name`` with a
        ``json`` extension. The metric of frame ``i`` of the report is stored
        at the view position given by the ``i``-th entry of the scan list
        returned by ``scan_order``.

        :param path: Path to report directory
        :type path: Union[str, Path]
        :param lightfield: Lightfield data; must expose ``n_views_width`` and
            ``n_views_height``
        :type lightfield: Any
        :param tool_name: Name of the quality tool
        :type tool_name: str
        :param target_rate: Target bits per pixel
        :type target_rate: float
        :param distortion_name: Name of the distortion metric
        :type distortion_name: str
        :param scan_order: Scan order function, defaults to get_serpentine_scan_list
        :type scan_order: Callable
        :param configuration: Optional configuration, defaults to None
        :type configuration: Optional[Any]
        :return: Array of shape ``(n_views_height, n_views_width)`` (the
            ``(width, height)`` grid, transposed), or an empty list when the
            report file does not exist.
        :rtype: Union[numpy.ndarray, list]
        :raises KeyError: if the report lacks a ``"frames"`` or ``"metrics"``
            entry.
        :raises IndexError: if the report holds fewer frames than the scan
            list has positions.
        """
        output_vmaf_report = get_formatted_filename_for_lf(
            path=path,
            lightfield=lightfield,
            bpp=target_rate,
            file_extension="json",
            extra=tool_name,
        )
        if not os.path.isfile(output_vmaf_report):
            # Soft failure: callers detect "no data" via ``len(...) == 0``.
            print(f"File not found: {output_vmaf_report}")
            return []

        # JSON text is UTF-8; don't depend on the platform's locale encoding.
        with open(output_vmaf_report, "r", encoding="utf-8") as f:
            data = json.load(f)

        metric_values = np.zeros((lightfield.n_views_width, lightfield.n_views_height))
        # NOTE(review): the trailing (0, 0, 1, 1) arguments are interpreted by
        # ``scan_order``; presumably start position and step — confirm.
        scan_list = scan_order(
            lightfield.n_views_width, lightfield.n_views_height, 0, 0, 1, 1
        )
        # Report frames are assumed to be in the same order as the scan list.
        for frame_index, position in enumerate(scan_list):
            frame_metrics = data["frames"][frame_index]["metrics"]
            metric_values[position] = CompoundMetrics.resolve_metric_value(
                frame_metrics=frame_metrics,
                distortion_name=distortion_name,
                configuration=configuration,
                generating_heatmap=True,
            )
        return metric_values.T
[docs]
class HeatmapMatplotlib:
    """Plots heatmaps using matplotlib from VMAF reports.

    Drawing goes through the ``matplotlib.pyplot`` state interface, i.e. onto
    pyplot's current figure.
    """

    def __init__(
        self,
        path: Union[str, Path],
        lightfield: Any,
        tool_name: str,
        target_rate: float,
        distortion_name: str,
        distortion_unit: str,
        title: str,
        scan_order: Callable = get_serpentine_scan_list,
        configuration: Optional[Any] = None,
    ) -> None:
        """Initializes HeatmapMatplotlib with report path and configuration.

        The report is loaded immediately by constructing a
        ``HeatmapFromVMAFReports`` with the same arguments.

        :param path: Path to report directory
        :type path: Union[str, Path]
        :param lightfield: Lightfield data; ``plot`` reads its ``name``
        :type lightfield: Any
        :param tool_name: Name of the quality tool
        :type tool_name: str
        :param target_rate: Target bits per pixel
        :type target_rate: float
        :param distortion_name: Name of the distortion metric (key into
            ``get_all_metrics``)
        :type distortion_name: str
        :param distortion_unit: Unit for the distortion metric
        :type distortion_unit: str
        :param title: Plot title
        :type title: str
        :param scan_order: Scan order function, defaults to get_serpentine_scan_list
        :type scan_order: Callable
        :param configuration: Optional configuration, defaults to None
        :type configuration: Optional[Any]
        :return: None
        :rtype: None
        """
        self._target_rate = target_rate
        self._lightfield = lightfield
        self._distortion_name = distortion_name
        self._distortion_unit = distortion_unit
        self._configuration = configuration
        self._heatmap = HeatmapFromVMAFReports(
            path=path,
            lightfield=lightfield,
            tool_name=tool_name,
            target_rate=target_rate,
            distortion_name=distortion_name,
            distortion_unit=distortion_unit,
            title=title,
            scan_order=scan_order,
            configuration=configuration,
        )

    def plot(
        self,
        show_values: bool = True,
        filename: Optional[Union[str, Path]] = None,
        codec_name: Optional[str] = None,
        cmap: str = "viridis",
    ) -> None:
        """Plots the heatmap on pyplot's current figure, then saves or shows it.

        Nothing is drawn, and pyplot is not touched, when no metric values
        were loaded (for example because the report file was missing).

        :param show_values: Whether to overlay metric values on the heatmap, defaults to True
        :type show_values: bool
        :param filename: Optional path to save the plot; when omitted the plot
            is shown with ``plt.show()``. Defaults to None
        :type filename: Optional[Union[str, Path]]
        :param codec_name: Optional codec name for the title, defaults to None
        :type codec_name: Optional[str]
        :param cmap: Colormap name, defaults to "viridis"
        :type cmap: str
        :return: None
        :rtype: None
        :raises KeyError: if ``distortion_name`` is not a key of the metrics
            returned by ``get_all_metrics``.
        """
        values = self._heatmap._metric_values
        # Bail out before any pyplot call: ``plt.title`` would otherwise create
        # a titled, empty axes on the current figure that is never shown,
        # saved or closed, leaking into whatever is plotted next.
        if len(values) == 0:
            return None

        detail_for_title = f"{self._target_rate} bpp"
        if codec_name:
            detail_for_title = f"{codec_name} - {detail_for_title}"
        plt.title(f"{self._lightfield.name} ({detail_for_title})")

        image = plt.imshow(values, cmap=cmap)
        cbar = plt.colorbar(image)
        all_metrics = get_all_metrics(configuration=self._configuration)
        # Wrapped in ``$...$`` so matplotlib renders the label as mathtext.
        label = f"${all_metrics[self._distortion_name]['label']}$"
        if self._distortion_unit:
            cbar.set_label(f"{label} ({self._distortion_unit})")
        else:
            cbar.set_label(label)

        if show_values:
            n_rows, n_cols = values.shape
            for row in range(n_rows):
                for col in range(n_cols):
                    # imshow puts element [row, col] at x=col, y=row.
                    plt.text(
                        col,
                        row,
                        f"{values[row, col]:.2f}",
                        ha="center",
                        va="center",
                        color="b",
                        fontsize=6,
                    )

        if filename:
            plt.savefig(filename, bbox_inches="tight", pad_inches=0)
            # Release the figure so repeated calls don't accumulate open figures.
            plt.clf()
            plt.close()
        else:
            plt.show()