import matplotlib.pyplot as plt
import torch
import numpy as np
import matplotlib.patches as patches
def image_histogram_equalization(image, number_bins=255):
    """Histogram-equalize an image.

    Adapted from
    http://www.janeriksolem.net/histogram-equalization-with-python-and.html

    Args:
        image (numpy.ndarray): image to equalize; any shape.
        number_bins (int, optional): number of histogram bins. Defaults to 255.

    Returns:
        tuple: ``(equalized, cdf)``. ``equalized`` has the same shape as
        ``image`` and holds each pixel mapped through ``cdf``. ``cdf`` is the
        cumulative (density) histogram rescaled so that its last entry is
        ``number_bins - 1``; it has ``number_bins`` entries.
    """
    flat_pixels = image.flatten()
    density, edges = np.histogram(flat_pixels, bins=number_bins, density=True)
    running = np.cumsum(density)
    # Rescale the cumulative distribution so that it ends at number_bins - 1.
    scaled_cdf = (number_bins - 1) * running / running[-1]
    # Map every pixel through the CDF, interpolating linearly between the
    # left edges of the histogram bins.
    remapped = np.interp(flat_pixels, edges[:-1], scaled_cdf)
    return remapped.reshape(image.shape), scaled_cdf
def equalize_tensor(raw_granule_tensor, n_std=2):
    """Equalize every band of a tensor for better visualization.

    Each band ``raw_granule_tensor[:, :, n]`` (bands are indexed along the
    third dimension) is processed independently:

    1. Values lower than ``mean - n_std * std`` or higher than
       ``mean + n_std * std`` of that band are clipped to those bounds.
    2. The clipped band is histogram-equalized with ``2**16`` bins via
       ``image_histogram_equalization``, so its values end up within
       ``[0, 2**16 - 1]``.

    Args:
        raw_granule_tensor (torch.tensor): tensor to equalize, indexed as
            ``[:, :, band]``. Band statistics use ``Tensor.mean``/``std``,
            which require a floating-point dtype. The input is not modified.
        n_std (int, optional): number of standard deviations used for the
            clipping bounds. Defaults to 2.

    Returns:
        torch.tensor: equalized tensor with the same shape, dtype and device
        as ``raw_granule_tensor``.
    """
    # detach() so .numpy() below also works on tensors that require grad.
    equalized = raw_granule_tensor.detach().clone()
    for n in range(raw_granule_tensor.shape[2]):
        band = equalized[:, :, n]  # view: in-place edits write into ``equalized``
        band_mean, band_std = band.mean(), band.std()
        lower = band_mean - n_std * band_std
        upper = band_mean + n_std * band_std
        # Histogram clipping: saturate outliers so they do not dominate the histogram.
        band[band < lower] = lower
        band[band > upper] = upper
        # NumPy needs a host tensor; move to CPU (consistent with the plotting helpers).
        band_eq, _ = image_histogram_equalization(band.cpu().numpy(), number_bins=2**16)
        # Bring the float64 NumPy result back to the tensor's own device/dtype.
        equalized[:, :, n] = torch.from_numpy(band_eq).to(
            device=equalized.device, dtype=equalized.dtype
        )
    return equalized
def plot_img1_vs_img2_bands(
    img1_band,
    img2_band,
    img_name_list,
    alert_matrix=None,
    alert_matrix_unregistered=None,
    save_path=None,
):
    """Show two image bands side by side, optionally with alert contours.

    Args:
        img1_band (torch.tensor): first image band (left panel).
        img2_band (torch.tensor): second image band (right panel).
        img_name_list (list): panel titles; ``img_name_list[0]`` is used for
            the left panel and ``img_name_list[1]`` for the right one.
        alert_matrix (torch.tensor, optional): if not None, its contours are
            drawn in red over the first band. Defaults to None.
        alert_matrix_unregistered (torch.tensor, optional): if not None, its
            contours are drawn in red over the second band. Defaults to None.
        save_path (string, optional): if not None, the figure is saved to
            ``save_path``. Defaults to None.
    """
    cmap = "bone"
    fig, (ax1, ax2) = plt.subplots(1, 2)
    panels = (
        (ax1, img1_band, alert_matrix, img_name_list[0]),
        (ax2, img2_band, alert_matrix_unregistered, img_name_list[1]),
    )
    for ax, band, alert, title in panels:
        ax.imshow(band.detach().cpu().numpy(), cmap=cmap)
        if alert is not None:
            ax.contour(alert.detach().cpu().numpy(), colors="r")
        ax.grid(False)
        ax.axis("off")
        ax.title.set_text(title)
    fig.tight_layout()
    # Save BEFORE showing: blocking GUI backends and the notebook inline
    # backend close the figure on plt.show(), so saving afterwards would
    # write an empty figure.
    if save_path is not None:
        fig.savefig(save_path)
    plt.show()
def plot_event(img, img_name, bbox_list, alert_matrix=None, save_path=None):
    """Plot a single image with optional alert contours and bounding boxes.

    The figure is created and optionally saved, but it is not shown or
    closed here.

    Args:
        img (torch.tensor): image to display.
        img_name (string): title of the plot.
        bbox_list (iterable): objects exposing a ``bbox`` attribute laid out as
            ``(min_row, min_col, max_row, max_col)``, such as the regions
            returned by ``skimage.measure.regionprops``. Each bbox is drawn as
            a yellow rectangle.
        alert_matrix (torch.tensor, optional): if not None, its contours are
            drawn in red over the image. Defaults to None.
        save_path (string, optional): if not None, the figure is saved to
            ``save_path``. Defaults to None.
    """
    cmap = "bone"
    fig, ax = plt.subplots()
    ax.imshow(img.detach().cpu().numpy(), cmap=cmap)
    if alert_matrix is not None:
        ax.contour(alert_matrix.detach().cpu().numpy(), colors="r")
    ax.grid(False)
    ax.axis("off")
    ax.title.set_text(img_name)
    for prop in bbox_list:
        # skimage bbox layout: (min_row, min_col, max_row, max_col).
        bbox = prop.bbox
        min_row, min_col, max_row, max_col = bbox[0], bbox[1], bbox[2], bbox[3]
        # Rectangle takes the (x, y) = (column, row) anchor, then width and height.
        rect = patches.Rectangle(
            (min_col, min_row),
            abs(min_col - max_col),
            abs(min_row - max_row),
            linewidth=2,
            edgecolor="y",
            facecolor="none",
        )
        ax.add_patch(rect)
    fig.tight_layout()
    if save_path is not None:
        fig.savefig(save_path)