Source code for src.quality.heatmap

import json
import os
from pathlib import Path
from typing import Any, Callable, List, Optional, Union

import matplotlib.pyplot as plt
import numpy as np


from lfc_toolkit.src.data_handlers.encoding_orders import get_serpentine_scan_list
from lfc_toolkit.src.data_handlers.formatters import get_formatted_filename_for_lf
from lfc_toolkit.src.quality.compound_metrics import CompoundMetrics
from lfc_toolkit.src.quality.rd_plot import get_all_metrics


class HeatmapFromVMAFReports:
    """Builds heatmap data from VMAF reports.

    On construction the report matching the given lightfield, tool and rate
    is read once; the resulting per-view values are kept in
    ``self._metric_values`` (see :meth:`read_reports` for their layout).
    """

    def __init__(
        self,
        path: Union[str, Path],
        lightfield: Any,
        tool_name: str,
        target_rate: float,
        distortion_name: str,
        distortion_unit: str,
        title: str,
        scan_order: Callable = get_serpentine_scan_list,
        configuration: Optional[Any] = None,
    ) -> None:
        """Initializes HeatmapFromVMAFReports by reading reports from path.

        :param path: Path to report directory
        :type path: Union[str, Path]
        :param lightfield: Lightfield data; ``n_views_width`` and
            ``n_views_height`` are read from it
        :type lightfield: Any
        :param tool_name: Name of the quality tool
        :type tool_name: str
        :param target_rate: Target bits per pixel
        :type target_rate: float
        :param distortion_name: Name of the distortion metric
        :type distortion_name: str
        :param distortion_unit: Unit for the distortion metric
        :type distortion_unit: str
        :param title: Plot title (stored as ``self._title``)
        :type title: str
        :param scan_order: Scan order function, defaults to get_serpentine_scan_list
        :type scan_order: Callable
        :param configuration: Optional configuration, defaults to None
        :type configuration: Optional[Any]
        :return: None
        :rtype: None
        """
        self._distortion_unit = distortion_unit
        self._title = title
        self._metric_values = HeatmapFromVMAFReports.read_reports(
            path,
            lightfield,
            tool_name=tool_name,
            target_rate=target_rate,
            distortion_name=distortion_name,
            scan_order=scan_order,
            configuration=configuration,
        )

    @staticmethod
    def read_reports(
        path: Union[str, Path],
        lightfield: Any,
        tool_name: str,
        target_rate: float,
        distortion_name: str,
        scan_order: Callable = get_serpentine_scan_list,
        configuration: Optional[Any] = None,
    ) -> Any:
        """Reads a VMAF report and returns per-view metric values as a 2D array.

        The report file name is produced by ``get_formatted_filename_for_lf``
        from ``path``, ``lightfield``, ``target_rate`` and ``tool_name`` (with
        a ``json`` extension). The i-th entry of the report's ``"frames"``
        list is resolved to a single value via
        ``CompoundMetrics.resolve_metric_value`` and stored at the i-th view
        position yielded by ``scan_order``.

        :param path: Path to report directory
        :type path: Union[str, Path]
        :param lightfield: Lightfield data; ``n_views_width`` and
            ``n_views_height`` are read from it
        :type lightfield: Any
        :param tool_name: Name of the quality tool
        :type tool_name: str
        :param target_rate: Target bits per pixel
        :type target_rate: float
        :param distortion_name: Name of the distortion metric
        :type distortion_name: str
        :param scan_order: Scan order function, defaults to get_serpentine_scan_list
        :type scan_order: Callable
        :param configuration: Optional configuration, defaults to None
        :type configuration: Optional[Any]
        :return: Array of shape ``(n_views_height, n_views_width)`` (the
            transpose of the ``(n_views_width, n_views_height)`` grid that is
            filled), or an empty list when the report file does not exist
        :rtype: Any
        """
        output_vmaf_report = get_formatted_filename_for_lf(
            path=path,
            lightfield=lightfield,
            bpp=target_rate,
            file_extension="json",
            extra=tool_name,
        )
        if not os.path.isfile(output_vmaf_report):
            # Deliberately non-fatal: an empty result signals "nothing to plot".
            print(f"File not found: {output_vmaf_report}")
            return []

        # JSON text is UTF-8 (RFC 8259); don't rely on the platform locale.
        with open(output_vmaf_report, "r", encoding="utf-8") as f:
            data = json.load(f)

        metric_values = np.zeros((lightfield.n_views_width, lightfield.n_views_height))
        # The trailing arguments are forwarded unchanged to the scan-order
        # function; their meaning is defined there.
        scan_list = scan_order(
            lightfield.n_views_width, lightfield.n_views_height, 0, 0, 1, 1
        )
        for frame_index, position in enumerate(scan_list):
            # Frame i in the report corresponds to the i-th view in scan order.
            frame_metrics = data["frames"][frame_index]["metrics"]
            metric_values[position] = CompoundMetrics.resolve_metric_value(
                frame_metrics=frame_metrics,
                distortion_name=distortion_name,
                configuration=configuration,
                generating_heatmap=True,
            )
        return metric_values.T
class HeatmapMatplotlib:
    """Plots heatmaps using matplotlib from VMAF reports.

    The metric data is loaded through :class:`HeatmapFromVMAFReports` at
    construction time; :meth:`plot` renders it with ``pyplot.imshow``.
    """

    def __init__(
        self,
        path: Union[str, Path],
        lightfield: Any,
        tool_name: str,
        target_rate: float,
        distortion_name: str,
        distortion_unit: str,
        title: str,
        scan_order: Callable = get_serpentine_scan_list,
        configuration: Optional[Any] = None,
    ) -> None:
        """Initializes HeatmapMatplotlib with report path and configuration.

        :param path: Path to report directory
        :type path: Union[str, Path]
        :param lightfield: Lightfield data; its ``name`` is used in the plot title
        :type lightfield: Any
        :param tool_name: Name of the quality tool
        :type tool_name: str
        :param target_rate: Target bits per pixel
        :type target_rate: float
        :param distortion_name: Name of the distortion metric; also the key
            used to look up the colorbar label in ``get_all_metrics``
        :type distortion_name: str
        :param distortion_unit: Unit for the distortion metric
        :type distortion_unit: str
        :param title: Plot title, forwarded to HeatmapFromVMAFReports
            (``plot`` builds its own title from the lightfield name and rate)
        :type title: str
        :param scan_order: Scan order function, defaults to get_serpentine_scan_list
        :type scan_order: Callable
        :param configuration: Optional configuration, defaults to None
        :type configuration: Optional[Any]
        :return: None
        :rtype: None
        """
        self._target_rate = target_rate
        self._lightfield = lightfield
        self._distortion_name = distortion_name
        self._distortion_unit = distortion_unit
        self._configuration = configuration
        self._heatmap = HeatmapFromVMAFReports(
            path=path,
            lightfield=lightfield,
            tool_name=tool_name,
            target_rate=target_rate,
            distortion_name=distortion_name,
            distortion_unit=distortion_unit,
            title=title,
            scan_order=scan_order,
            configuration=configuration,
        )

    def plot(
        self,
        show_values: bool = True,
        filename: Optional[Union[str, Path]] = None,
        codec_name: Optional[str] = None,
        cmap: str = "viridis",
    ) -> None:
        """Plots the heatmap, optionally saving to file.

        Does nothing (and touches no pyplot state) when no metric values were
        loaded, e.g. because the report file was missing.

        :param show_values: Whether to overlay metric values on the heatmap, defaults to True
        :type show_values: bool
        :param filename: Optional path to save the plot (the figure is then
            cleared and closed); when omitted, ``plt.show()`` is called instead
        :type filename: Optional[Union[str, Path]]
        :param codec_name: Optional codec name for the title, defaults to None
        :type codec_name: Optional[str]
        :param cmap: Colormap name, defaults to "viridis"
        :type cmap: str
        :return: None
        :rtype: None
        """
        values = self._heatmap._metric_values
        # Check before any pyplot call: plt.title() would otherwise create a
        # titled, empty figure that is never saved or closed.
        if len(values) == 0:
            return None

        detail_for_title = f"{self._target_rate} bpp"
        if codec_name:
            detail_for_title = f"{codec_name} - {detail_for_title}"
        plt.title(f"{self._lightfield.name} ({detail_for_title})")

        heatmap = plt.imshow(values, cmap=cmap)
        cbar = plt.colorbar(heatmap)
        all_metrics = get_all_metrics(configuration=self._configuration)
        # Surrounding "$" makes matplotlib render the label as mathtext.
        label = f"${all_metrics[self._distortion_name]['label']}$"
        if self._distortion_unit:
            cbar.set_label(f"{label} ({self._distortion_unit})")
        else:
            cbar.set_label(label)

        if show_values:
            n_rows, n_cols = values.shape
            for row in range(n_rows):
                for col in range(n_cols):
                    # imshow maps array columns to x and rows to y.
                    plt.text(
                        col,
                        row,
                        f"{values[row, col]:.2f}",
                        ha="center",
                        va="center",
                        color="b",
                        fontsize=6,
                    )

        if filename:
            plt.savefig(filename, bbox_inches="tight", pad_inches=0)
            plt.clf()
            plt.close()
        else:
            plt.show()