Source code for utils.raw_utils

import os
import torch
from glob import glob
from xml.dom import minidom
from termcolor import colored
from .database_utils import get_raw_shift_lut, get_event_info, get_event_granule_bb_dict
import numpy as np
import geopy.distance
from tifffile import imread
from tqdm import tqdm
from rasterio.control import GroundControlPoint as GCP
from rasterio.transform import from_gcps
import rasterio

def get_bands_shift(
    bands_list, satellite, detector_number, downsampling=True, cfg_file_dict=None
):
    """Compute the pixel shift of every band relative to the first band of the list.

    The shift is the number of backward pixels of each band with respect to the
    first (reference) band of ``bands_list``; a negative shift means the band
    shall be translated forward. Shifts are accumulated from the raw shift
    look-up table entries lying between the reference band and the target band,
    each rescaled by the ratio between the entry band resolution and the target
    band resolution.

    Args:
        bands_list (list): list of band names. ``bands_list[0]`` is the reference.
        satellite (str): "S2A" or "S2B", respectively for Sentinel-2A and
            Sentinel-2B data.
        detector_number (int): detector number.
        downsampling (bool, optional): if True, shift values for downsampled
            bands of the chosen satellite are used. Otherwise, values for
            upsampled bands are used. Defaults to True.
        cfg_file_dict (dict, optional): dictionary containing paths to the
            different pyraws directories. If None, internal CSV database will
            be parsed. Defaults to None.

    Returns:
        list: one ``[along_track_shift, across_track_shift]`` pair for each band
        in ``bands_list[1:]``, relative to ``bands_list[0]``.
    """
    shift_lut = get_raw_shift_lut(
        satellite, detector_number, downsampling, cfg_file_dict
    )
    lut_bands = list(shift_lut.keys())

    def band_position(band_name):
        # "B02" has no column in the look-up table: it sits just before the
        # first column, i.e. at position -1.
        if band_name == "B02":
            return -1
        return lut_bands.index(band_name)

    reference_position = band_position(bands_list[0])

    # NOTE(review): BAND_SPATIAL_RESOLUTION_DICT does not appear in the visible
    # imports of this module - confirm it is in scope.
    # Across track, 60 m bands are handled with a 20 m resolution.
    across_track_resolution = {
        name: (20 if resolution == 60 else resolution)
        for name, resolution in BAND_SPATIAL_RESOLUTION_DICT.items()
    }

    shifts = []
    for target_band in bands_list[1:]:
        target_position = band_position(target_band)
        moving_forward = reference_position < target_position
        if moving_forward:
            first, last = reference_position, target_position
        else:
            first, last = target_position, reference_position

        along_terms = []
        across_terms = []
        for k in range(first, last):
            lut_band = shift_lut_band = lut_bands[k + 1]
            along_weight = BAND_SPATIAL_RESOLUTION_DICT[lut_band]
            across_weight = across_track_resolution[lut_band]
            if not moving_forward:
                # Walking the table backwards: contributions change sign.
                along_weight = -along_weight
                across_weight = -across_weight
            along_terms.append(
                along_weight
                / BAND_SPATIAL_RESOLUTION_DICT[target_band]
                * shift_lut[shift_lut_band][0]
            )
            across_terms.append(
                across_weight
                / across_track_resolution[target_band]
                * shift_lut[shift_lut_band][1]
            )

        shifts.append([np.array(along_terms).sum(), np.array(across_terms).sum()])
    return shifts
def get_granule_px_length(
    n_stacked_granules,
    satellite,
    detector_number,
    cropped_pixels_along=None,
    cfg_file_dict=None,
):
    """Return the length of a granule stack in pixels at B02 resolution.

    The length is computed as ``n_stacked_granules`` times the size of a B09
    granule (first entry of ``BANDS_RAW_SHAPE_DICT["B09"]``, optionally reduced
    by ``cropped_pixels_along``), plus the absolute along-track shift between
    B02 and B09. The total, expressed in B09 pixels, is then rescaled to B02
    resolution and rounded to the nearest integer.

    Args:
        n_stacked_granules (int): number of stacked granules.
        satellite (str): "S2A" or "S2B", respectively for Sentinel-2A and
            Sentinel-2B data.
        detector_number (int): detector number.
        cropped_pixels_along (int, optional): number of pixels cropped from
            each B09 granule along the along-track direction. Defaults to None
            (no cropping).
        cfg_file_dict (dict, optional): dictionary containing paths to the
            different pyraws directories, forwarded to ``get_bands_shift``. If
            None, internal CSV database will be parsed. Defaults to None.

    Returns:
        int: length of the granule stack in pixels.
    """
    # NOTE(review): assumes index 0 of the raw shape is the along-track size.
    b09_size = BANDS_RAW_SHAPE_DICT["B09"][0]
    if cropped_pixels_along is not None:
        b09_size -= cropped_pixels_along

    # Along-track shift of B09 w.r.t. B02, in B09 pixels (upsampled LUT).
    b02_b09_shift = abs(
        get_bands_shift(
            ["B02", "B09"],
            satellite,
            detector_number,
            downsampling=False,
            cfg_file_dict=cfg_file_dict,
        )[0][0]
    )

    length_in_b09_px = n_stacked_granules * b09_size + b02_b09_shift
    b09_to_b02_ratio = (
        BAND_SPATIAL_RESOLUTION_DICT["B09"] / BAND_SPATIAL_RESOLUTION_DICT["B02"]
    )
    return int(np.round(length_in_b09_px * b09_to_b02_ratio))
def _find_raw_band_files(tif_dir, bands_list):
    """Map each requested band name to the ``.tif`` file holding it in ``tif_dir``.

    The band name is taken from the file name (not the full path), starting
    after the first "_" that precedes a "B" and ending at the extension.

    Raises:
        FileNotFoundError: if some requested band has no matching file.
    """
    requested = set(bands_list)
    band_files = {}
    for file_path in sorted(glob(os.path.join(tif_dir, "*"))):
        stem, extension = os.path.splitext(os.path.basename(file_path))
        if extension != ".tif":
            continue
        band_name = stem[stem.find("_B") + 1 :]
        if band_name in requested:
            band_files[band_name] = file_path

    missing = [band for band in bands_list if band not in band_files]
    if missing:
        raise FileNotFoundError(
            "No .tif file found in " + tif_dir + " for bands: " + str(missing)
        )
    return band_files


def _xml_float(node, tag_name):
    """Return the text of the first ``tag_name`` element below ``node`` as float."""
    return float(node.getElementsByTagName(tag_name)[0].firstChild.data)


def _parse_raw_granule_metadata(metadata_xml_path):
    """Parse polygon vertices and cloud percentage from a granule metadata XML.

    Returns:
        tuple: (list of ``[latitude, longitude]`` vertices, cloud percentage).
    """
    xml_content = minidom.parse(metadata_xml_path)
    polygon_content = xml_content.getElementsByTagName("Geographic_Localization")
    polygon_points = polygon_content[0].getElementsByTagName("Geo_Pnt")
    # Last one is excluded to avoid repetition of the first point.
    polygon_coordinates_list = [
        [_xml_float(point, "LATITUDE"), _xml_float(point, "LONGITUDE")]
        for point in polygon_points[:-1]
    ]
    cloud_percentage = _xml_float(xml_content, "CloudPercentage")
    return polygon_coordinates_list, cloud_percentage


def read_Raw_granule(
    granule_path, bands_list, verbose=True, device=torch.device("cpu")
):
    """Read the bands listed in ``bands_list`` from a single Raw Sentinel-2 granule.

    Band images are read from the granule's "TIF" subdirectory; polygon
    coordinates and cloud percentage are read from its
    "Inventory_Metadata.xml" file.

    Args:
        granule_path (str): Sentinel-2 Raw granule path.
        bands_list (list): bands list.
        verbose (bool, optional): if True, verbose mode is used. Defaults to True.
        device (torch.device, optional): torch device. Defaults to
            torch.device("cpu").

    Raises:
        ValueError: impossible to open the images with the requested bands, or
            impossible to read the granule metadata. The underlying error is
            chained as ``__cause__``.

    Returns:
        list: one float32 torch.tensor per requested band, in ``bands_list``
            order, on ``device``.
        list: ``[latitude, longitude]`` polygon vertices (closing vertex
            excluded).
        float: cloud cover percentage.
    """
    metadata_xml_path = os.path.join(granule_path, "Inventory_Metadata.xml")
    granule_path = os.path.join(granule_path, "TIF")

    try:
        band_name_file_dict = _find_raw_band_files(granule_path, bands_list)
        bands_iterator = (
            tqdm(bands_list, desc="Parsing sentinel bands") if verbose else bands_list
        )
        sentinel_raw_granule = []
        for band in bands_iterator:
            if verbose:
                print("Taking band: " + colored(band, "green"))
            # Keep only the first channel of the image read from the TIF
            # (assumes a 3-D array with channels last - TODO confirm).
            band_k = imread(band_name_file_dict[band])[:, :, 0]
            sentinel_raw_granule.append(
                torch.from_numpy(band_k.astype(np.float32)).to(device)
            )
    except Exception as err:
        raise ValueError(
            colored("Error. ", "red")
            + " impossible to open: "
            + colored(granule_path, "blue")
            + " with the requested bands."
        ) from err

    try:
        polygon_coordinates_list, cloud_percentage = _parse_raw_granule_metadata(
            metadata_xml_path
        )
    except Exception as err:
        # Report the file path: the parsed document may not exist if parsing failed.
        raise ValueError(
            colored("Error. ", "red")
            + " impossible to read: "
            + colored(metadata_xml_path, "blue")
            + " Raw granule metatada."
        ) from err

    return sentinel_raw_granule, polygon_coordinates_list, cloud_percentage
def read_Raw_event_from_path(
    dir_path, bands_list, verbose=True, device=torch.device("cpu")
):
    """Read the bands listed in ``bands_list`` for every granule of a Raw event.

    Every entry found in ``dir_path`` (sorted by name) is read as a granule
    through ``read_Raw_granule``.

    Args:
        dir_path (str): Sentinel-2 Raw event directory.
        bands_list (list): bands list.
        verbose (bool, optional): if True, verbose mode is used. Defaults to True.
        device (torch.device, optional): torch device. Defaults to
            torch.device("cpu").

    Raises:
        ValueError: impossible to open the event with the requested bands. The
            underlying error is chained as ``__cause__``.

    Returns:
        list: list of G granules, each of them a list of the requested bands.
        list: list of granule paths.
        list: list of granule polygon coordinates.
        list: list of cloud coverage percentages.
    """
    try:
        # NOTE(review): a non-existent dir_path yields an empty event, because
        # glob returns no match - confirm this is intended.
        granules_path = sorted(glob(os.path.join(dir_path, "*")))
        granules_list = []
        granules_polygons_coordinates_list = []
        cloud_percentages_list = []

        granules_iterator = (
            tqdm(granules_path, desc="Parsing granules...")
            if verbose
            else granules_path
        )
        for granule_path in granules_iterator:
            granule, polygon_coordinates, cloud_percentage = read_Raw_granule(
                granule_path, bands_list, verbose, device
            )
            granules_list.append(granule)
            granules_polygons_coordinates_list.append(polygon_coordinates)
            cloud_percentages_list.append(cloud_percentage)
    except Exception as err:
        raise ValueError(
            colored("Error. ", "red")
            + " impossible to open: "
            + colored(dir_path, "blue")
            + " with the requested bands."
        ) from err

    return (
        granules_list,
        granules_path,
        granules_polygons_coordinates_list,
        cloud_percentages_list,
    )
def find_granules_names(granules_paths):
    """Extract the name of each granule (last path component) from its path.

    Trailing path separators are ignored, and a bare name without any
    separator is returned unchanged.

    Args:
        granules_paths (list): list of paths to the granules.

    Returns:
        list: list of granule names, in the same order as ``granules_paths``.
    """
    separators = os.sep + (os.altsep or "")
    return [
        os.path.basename(granule_path.rstrip(separators))
        for granule_path in granules_paths
    ]
def read_Raw_event_from_database(
    id_event,
    bands_list,
    cfg_file_dict=None,
    id_raw_l1_dict=None,
    database="THRAWS",
    verbose=True,
    device=torch.device("cpu"),
):
    """Read the bands listed in ``bands_list`` for the Raw Sentinel-2 event ``id_event``.

    The event directory, expected class, and useful/complementary granules are
    retrieved through ``get_event_info``. The granules are then read from disk
    with ``read_Raw_event_from_path``, and the useful-granule bounding boxes are
    retrieved through ``get_event_granule_bb_dict``.

    Args:
        id_event (str): event ID.
        bands_list (list): bands list.
        cfg_file_dict (dict, optional): dictionary containing paths to the
            different pyraws directories. If None, internal CSV database will
            be parsed.
        id_raw_l1_dict (dict, optional): id-raw-l1 dictionary. If None,
            internal CSV database will be parsed.
        database (str, optional): database name. Defaults to "THRAWS".
        verbose (bool, optional): if True, verbose mode is used. Defaults to True.
        device (torch.device, optional): torch device. Defaults to
            torch.device("cpu").

    Raises:
        ValueError: impossible to find information on the event in the
            database (underlying error chained as ``__cause__``), or impossible
            to open the event files.

    Returns:
        list: granules of the event; each granule is a list of torch.tensor,
            one per requested band.
        str: expected class name.
        list: list of granule names.
        list: list of raw data useful granules.
        list: list of complementary raw data granules (to coregister without 0).
        list: list of polygon coordinates for each granule.
        list: list of cloud coverage percentage for each granule.
        dict: {useful granule : bbox}.
    """
    try:
        (
            raw_dir_path,
            _,
            _,
            expected_class,
            raw_useful_granules,
            raw_complementary_granules,
            _,
            _,
        ) = get_event_info(id_event, cfg_file_dict, id_raw_l1_dict, database=database)
    except Exception as err:
        raise ValueError(
            "Impossible to find information on event: "
            + colored(id_event, "blue")
            + ". Check it is included in the database."
        ) from err

    (
        sentinel_raw_img,
        granules_paths,
        granules_polygons_coordinates_list,
        cloud_percentages_list,
    ) = read_Raw_event_from_path(raw_dir_path, bands_list, verbose, device)

    granule_names = find_granules_names(granules_paths)
    useful_granules_bb_dict = get_event_granule_bb_dict(
        id_event, database=database, cfg_file_dict=cfg_file_dict
    )
    return (
        sentinel_raw_img,
        expected_class,
        granule_names,
        raw_useful_granules,
        raw_complementary_granules,
        granules_polygons_coordinates_list,
        cloud_percentages_list,
        useful_granules_bb_dict,
    )
def shift_point_coordinates(point_coordinates, point_distance_y):
    """Move a point north or south by a given distance.

    Args:
        point_coordinates (list): ``[lat, lon]`` coordinates of the point.
        point_distance_y (int): shift in meters. Positive values move the point
            south (bearing 180); zero or negative values move it north
            (bearing 0) by ``-point_distance_y`` meters.

    Returns:
        list: ``[lat, lon]`` coordinates of the shifted point.
    """
    towards_south = point_distance_y > 0
    if towards_south:
        bearing, meters = 180, point_distance_y
    else:
        bearing, meters = 0, -point_distance_y

    start = (point_coordinates[0], point_coordinates[1])
    shifted = geopy.distance.distance(meters=meters).destination(start, bearing=bearing)
    return [shifted[0], shifted[1]]
def swap_latlon(poly):
    """Swap the first two coordinates (latitude, longitude) of every point.

    Args:
        poly (list): list of point coordinates.

    Returns:
        list: new list of ``[second, first]`` pairs, one per input point.
    """
    swapped = []
    for point in poly:
        first, second = point[0], point[1]
        swapped.append([second, first])
    return swapped
def export_band_to_tif(band, coords, save_path):
    """Export a single band to a georeferenced GeoTIFF file.

    The four band corners are tied to ground coordinates with ground control
    points, from which the affine geotransform is estimated.

    Args:
        band (torch.tensor): 2-D band to save.
        coords (list): list of band corner coordinates [UL, BL, BR, UR]. Each
            point is (LON, LAT).
        save_path (str): save path.
    """
    height, width = band[:, :].shape

    # GCP(row, col, x, y): pixel position -> ground coordinates (LON, LAT).
    gcps = [
        GCP(0, 0, *coords[0]),  # UL
        GCP(height, 0, *coords[1]),  # BL
        GCP(height, width, *coords[2]),  # BR
        GCP(0, width, *coords[3]),  # UR
    ]
    transform = from_gcps(gcps)
    kwargs = {
        "driver": "GTiff",
        "crs": {"init": "EPSG:4326"},
        "transform": transform,
        "width": width,
        "height": height,
        "count": 1,
        "dtype": "uint16",
    }
    # NOTE(review): values are cast to uint16 without clipping; values outside
    # [0, 65535] wrap around - confirm the input range.
    with rasterio.open(save_path, "w", **kwargs) as dst:
        dst.write(band.detach().cpu().numpy().astype(rasterio.uint16), 1)