444 lines
14 KiB
Python
444 lines
14 KiB
Python
import base64
|
|
import collections
|
|
import copy
|
|
import json
|
|
import math
|
|
import os
|
|
import re
|
|
import tempfile
|
|
import uuid
|
|
from contextlib import contextmanager
|
|
from typing import (
|
|
Any,
|
|
Callable,
|
|
Dict,
|
|
Iterable,
|
|
Iterator,
|
|
List,
|
|
Optional,
|
|
Sequence,
|
|
Tuple,
|
|
Type,
|
|
Union,
|
|
)
|
|
from urllib.parse import urlparse, uses_netloc, uses_params, uses_relative
|
|
|
|
import numpy as np
|
|
from branca.element import Element, Figure
|
|
|
|
# import here for backwards compatibility
|
|
from branca.utilities import ( # noqa F401
|
|
_locations_mirror,
|
|
_parse_size,
|
|
none_max,
|
|
none_min,
|
|
write_png,
|
|
)
|
|
|
|
try:
|
|
import pandas as pd
|
|
except ImportError:
|
|
pd = None
|
|
|
|
|
|
# A line is an iterable of coordinate sequences; a multi-line is either a
# single line or an iterable of lines.
TypeLine = Iterable[Sequence[float]]
TypeMultiLine = Union[TypeLine, Iterable[TypeLine]]

# Values accepted as JSON-like options, without and with None.
TypeJsonValueNoNone = Union[str, float, bool, Sequence, dict]
TypeJsonValue = Optional[TypeJsonValueNoNone]

# Values accepted for path options.
TypePathOptions = Optional[Union[bool, str, float]]

# Bounds are a sequence of coordinate sequences.
TypeBounds = Sequence[Sequence[float]]
|
|
|
|
|
|
# Every scheme urllib knows about, minus the empty scheme, plus "data".
# Used to tell URLs apart from other strings.
_VALID_URLS = {
    scheme
    for scheme in (*uses_relative, *uses_netloc, *uses_params, "data")
    if scheme != ""
}
|
|
|
|
|
|
def validate_location(location: Sequence[float]) -> List[float]:
    """Check a single lat/lon pair and return it as a list of two floats.

    Numpy arrays (and pandas DataFrames, when pandas is installed) are
    squeezed and converted to plain Python lists first. The value must then:

    * be sized (have a ``__len__``),
    * have exactly two entries,
    * support indexing with ``[0]`` and ``[1]``,
    * hold two values that are convertible to float,
    * and contain no NaN.

    Raises
    ------
    TypeError
        If the value is not sized or cannot be indexed.
    ValueError
        If it does not hold exactly two values, a value is not convertible
        to float, or a value is NaN.
    """
    is_dataframe = pd is not None and isinstance(location, pd.DataFrame)
    if is_dataframe or isinstance(location, np.ndarray):
        location = np.squeeze(location).tolist()

    if not hasattr(location, "__len__"):
        raise TypeError(
            "Location should be a sized variable, "
            "for example a list or a tuple, instead got "
            f"{location!r} of type {type(location)}."
        )

    if len(location) != 2:
        raise ValueError(
            "Expected two (lat, lon) values for location, "
            f"instead got: {location!r}."
        )

    try:
        first, second = location[0], location[1]
    except (TypeError, KeyError):
        # Sets raise TypeError here, dicts without 0/1 keys raise KeyError.
        raise TypeError(
            "Location should support indexing, like a list or "
            f"a tuple does, instead got {location!r} of type {type(location)}."
        )

    validated: List[float] = []
    for raw in (first, second):
        try:
            number = float(raw)
        except (TypeError, ValueError):
            raise ValueError(
                "Location should consist of two numerical values, "
                f"but {raw!r} of type {type(raw)} is not convertible to float."
            )
        if math.isnan(number):
            raise ValueError("Location values cannot contain NaNs.")
        validated.append(number)
    return validated
|
|
|
|
|
|
def _validate_locations_basics(locations: TypeMultiLine) -> None:
    """Check that `locations` is iterable and holds at least one item.

    Note that the emptiness check pulls one item with ``next()``. For a
    one-shot iterator that item is consumed and not put back.

    Raises
    ------
    TypeError
        If `locations` cannot be iterated.
    ValueError
        If `locations` yields no items.
    """
    try:
        probe = iter(locations)
    except TypeError:
        raise TypeError(
            "Locations should be an iterable with coordinate pairs,"
            f" but instead got {locations!r}."
        )

    try:
        next(probe)
    except StopIteration:
        raise ValueError("Locations is empty.")
|
|
|
|
|
|
def validate_locations(locations: TypeLine) -> List[List[float]]:
    """Validate an iterable with lat/lon coordinate pairs.

    Pandas DataFrames are converted to Numpy arrays first. One-shot
    iterators (for example generators) are materialised into a list before
    validation.

    Returns
    -------
    A list with one ``[lat, lon]`` list of floats per input pair.

    Raises
    ------
    TypeError
        If `locations` is not iterable, or a pair is not sized/indexable.
    ValueError
        If `locations` is empty or a pair is invalid.
    """
    locations = if_pandas_df_convert_to_numpy(locations)
    if isinstance(locations, Iterator):
        # The emptiness check in _validate_locations_basics calls next() on
        # its argument, which would silently drop the first pair of a
        # one-shot iterator. A list can be probed and then fully iterated.
        locations = list(locations)
    _validate_locations_basics(locations)
    return [validate_location(coord_pair) for coord_pair in locations]
|
|
|
|
|
|
def validate_multi_locations(
    locations: TypeMultiLine,
) -> Union[List[List[float]], List[List[List[float]]]]:
    """Validate an iterable with possibly nested lists of coordinate pairs.

    The input is either a single line (an iterable of pairs) or several
    lines (an iterable of iterables of pairs). Pandas DataFrames are
    converted to Numpy arrays first. One-shot iterators, both the outer one
    and any found directly inside a list or tuple, are materialised into
    lists before validation.

    Returns
    -------
    A list of ``[lat, lon]`` lists for a single line, or a list of such
    lists for multiple lines.

    Raises
    ------
    TypeError
        If `locations` is not iterable, or a pair is not sized/indexable.
    ValueError
        If `locations` is empty or a pair is invalid.
    """
    locations = if_pandas_df_convert_to_numpy(locations)
    # The probes below call next() on the outer and the first inner iterable.
    # On one-shot iterators that would consume items which the final
    # comprehension then never sees, so turn iterators into lists up front.
    if isinstance(locations, Iterator):
        locations = list(locations)
    if isinstance(locations, (list, tuple)):
        locations = [
            list(item) if isinstance(item, Iterator) else item for item in locations
        ]
    _validate_locations_basics(locations)
    try:
        # Dig three levels deep: this only yields a number when the input is
        # a list of lists of coordinate pairs.
        float(next(iter(next(iter(next(iter(locations)))))))  # type: ignore
    except (TypeError, StopIteration):
        # locations is a list of coordinate pairs
        return [validate_location(coord_pair) for coord_pair in locations]  # type: ignore
    else:
        # locations is a list of a list of coordinate pairs, recurse
        return [validate_locations(lst) for lst in locations]  # type: ignore
|
|
|
|
|
|
def if_pandas_df_convert_to_numpy(obj: Any) -> Any:
    """Turn a Pandas DataFrame into a Numpy array, pass anything else through.

    Iterating over a DataFrame directly has surprising results (it walks the
    column names), so its underlying values are returned instead. When
    pandas is not installed the input is always returned unchanged.
    """
    if pd is None or not isinstance(obj, pd.DataFrame):
        return obj
    return obj.values
|
|
|
|
|
|
def image_to_url(
    image: Any,
    colormap: Optional[Callable] = None,
    origin: str = "upper",
) -> str:
    """
    Infers the type of an image argument and transforms it into a URL.

    Parameters
    ----------
    image: string, path-like or array-like object
        * If a string with a known URL scheme, it is written directly in
          the output.
        * If any other string, or an `os.PathLike` object, it is treated as
          the path of a local file whose content is embedded as a base64
          data URL. The file extension is used as the image subtype.
        * If an ndarray, it is converted to a PNG with `write_png` and
          embedded as a base64 data URL.
    origin: ['upper' | 'lower'], optional, default 'upper'
        Place the [0, 0] index of the array in the upper left or
        lower left corner of the axes. Only used for arrays.
    colormap: callable, used only for `mono` image.
        Function of the form [x -> (r,g,b)] or [x -> (r,g,b,a)]
        for transforming a mono image into RGB.
        It must output iterables of length 3 or 4, with values between
        0. and 1. You can use colormaps from `matplotlib.cm`.

    Returns
    -------
    The URL as a string, with newlines replaced by spaces.
    """
    if isinstance(image, os.PathLike):
        # pathlib.Path and friends would otherwise end up in the JSON
        # fallback below and fail there; handle them like a string path.
        image = os.fspath(image)

    if isinstance(image, str) and not _is_url(image):
        fileformat = os.path.splitext(image)[-1][1:]
        with open(image, "rb") as f:
            img = f.read()
        b64encoded = base64.b64encode(img).decode("utf-8")
        url = f"data:image/{fileformat};base64,{b64encoded}"
    elif "ndarray" in image.__class__.__name__:
        # Matching on the class name avoids a hard isinstance check here.
        img = write_png(image, origin=origin, colormap=colormap)
        b64encoded = base64.b64encode(img).decode("utf-8")
        url = f"data:image/png;base64,{b64encoded}"
    else:
        # Round-trip to ensure a nice formatted json.
        url = json.loads(json.dumps(image))
    return url.replace("\n", " ")
|
|
|
|
|
|
def _is_url(url: str) -> bool:
    """Tell whether `url` starts with one of the schemes in `_VALID_URLS`.

    Anything that `urlparse` cannot handle counts as "not a URL".
    """
    try:
        scheme = urlparse(url).scheme
    except Exception:
        return False
    return scheme in _VALID_URLS
|
|
|
|
|
|
def mercator_transform(
    data: Any,
    lat_bounds: Tuple[float, float],
    origin: str = "upper",
    height_out: Optional[int] = None,
) -> np.ndarray:
    """
    Transforms an image computed in (longitude,latitude) coordinates into
    a Mercator projection image.

    Parameters
    ----------

    data: numpy array or equivalent list-like object.
        Must be NxM (mono), NxMx3 (RGB) or NxMx4 (RGBA)

    lat_bounds : length 2 tuple
        Minimal and maximal value of the latitude of the image.
        Bounds must be between -85.051128779806589 and 85.051128779806589
        otherwise they will be clipped to those values.

    origin : ['upper' | 'lower'], optional, default 'upper'
        Place the [0,0] index of the array in the upper left or lower left
        corner of the axes.

    height_out : int, default None
        The expected height of the output.
        If None, the height of the input is used.

    Returns
    -------
    A float array of shape (height_out, width, layers). The input is always
    promoted to three dimensions, so a mono image comes back with a
    trailing axis of length one.

    See https://en.wikipedia.org/wiki/Web_Mercator for more details.

    """

    def mercator(x):
        return np.arcsinh(np.tan(x * np.pi / 180.0)) * 180.0 / np.pi

    array = np.atleast_3d(data).copy()
    height, width, nblayers = array.shape

    lat_min = max(lat_bounds[0], -85.051128779806589)
    lat_max = min(lat_bounds[1], 85.051128779806589)
    if height_out is None:
        height_out = height

    # Work with rows ordered by increasing latitude.
    if origin == "upper":
        array = array[::-1, :, :]

    # Latitudes of the input row centres, evenly spaced in latitude.
    lats = lat_min + np.linspace(0.5 / height, 1.0 - 0.5 / height, height) * (
        lat_max - lat_min
    )
    # Projected positions of the output row centres, evenly spaced in
    # Mercator space.
    latslats = mercator(lat_min) + np.linspace(
        0.5 / height_out, 1.0 - 0.5 / height_out, height_out
    ) * (mercator(lat_max) - mercator(lat_min))

    # The projected input positions are the same for every column and layer,
    # so compute them once instead of inside the loops.
    projected_lats = mercator(lats)

    out = np.zeros((height_out, width, nblayers))
    for i in range(width):
        for j in range(nblayers):
            out[:, i, j] = np.interp(latslats, projected_lats, array[:, i, j])

    # Restore the original row orientation.
    if origin == "upper":
        out = out[::-1, :, :]
    return out
|
|
|
|
|
|
def iter_coords(obj: Any) -> Iterator[Tuple[float, ...]]:
    """
    Returns all the coordinate tuples from a geometry or feature.

    Accepted inputs are a (nested) list or tuple of coordinates, or a
    GeoJSON-like mapping: a feature collection (``"features"``), a feature
    (``"geometry"``), a geometry collection (``"geometries"``) or a plain
    geometry (``"coordinates"``). Features and collection members whose
    geometry is empty or None are skipped.

    Every member of a geometry collection is visited, including collections
    nested in features or in other collections.

    Yields
    ------
    One tuple per coordinate, in the order they appear in the input.
    """
    if isinstance(obj, (tuple, list)):
        coords = obj
    elif "features" in obj:
        # Hand over whole geometry objects so the recursion below can deal
        # with both plain geometries and geometry collections.
        coords = [
            feature["geometry"] for feature in obj["features"] if feature["geometry"]
        ]
    elif "geometry" in obj:
        coords = [obj["geometry"]] if obj["geometry"] else []
    elif "geometries" in obj:
        coords = [geometry for geometry in obj["geometries"] if geometry]
    else:
        coords = obj.get("coordinates", obj)
    for coord in coords:
        if isinstance(coord, (float, int)):
            # A number means `coords` itself is a single coordinate.
            yield tuple(coords)
            break
        else:
            yield from iter_coords(coord)
|
|
|
|
|
|
def get_bounds(
    locations: Any,
    lonlat: bool = False,
) -> List[List[Optional[float]]]:
    """
    Computes the bounds of the object in the form
    [[lat_min, lon_min], [lat_max, lon_max]]

    The first two entries of every coordinate produced by `iter_coords` are
    folded into running minima and maxima with `none_min` / `none_max`. If
    there are no coordinates, all four values stay None.

    When `lonlat` is true the result is passed through `_locations_mirror`
    (presumably to swap the order within each pair; confirm against branca).
    """
    first_min: Optional[float] = None
    second_min: Optional[float] = None
    first_max: Optional[float] = None
    second_max: Optional[float] = None

    for point in iter_coords(locations):
        first_min = none_min(first_min, point[0])
        second_min = none_min(second_min, point[1])
        first_max = none_max(first_max, point[0])
        second_max = none_max(second_max, point[1])

    bounds: List[List[Optional[float]]] = [
        [first_min, second_min],
        [first_max, second_max],
    ]
    if lonlat:
        bounds = _locations_mirror(bounds)
    return bounds
|
|
|
|
|
|
def camelize(key: str) -> str:
    """Turn a snake_case name into lowerCamelCase.

    The part before the first underscore is kept untouched; every following
    part is passed through ``str.capitalize``.

    Examples
    --------
    >>> camelize("variable_name")
    'variableName'
    >>> camelize("variableName")
    'variableName'
    """
    head, *rest = key.split("_")
    return head + "".join(word.capitalize() for word in rest)
|
|
|
|
|
|
def compare_rendered(obj1: str, obj2: str) -> bool:
    """
    Return whether two rendered strings are equal once both have been
    passed through `normalize`.

    """
    left = normalize(obj1)
    right = normalize(obj2)
    return left == right
|
|
|
|
|
|
def normalize(rendered: str) -> str:
    """Collapse a rendered string for comparison.

    Each line is stripped, blank lines are dropped, the rest is joined
    without separators and every ``", "`` is shortened to ``","``.
    """
    stripped_lines = (line.strip() for line in rendered.splitlines())
    joined = "".join(filter(None, stripped_lines))
    return joined.replace(", ", ",")
|
|
|
|
|
|
@contextmanager
def temp_html_filepath(data: str) -> Iterator[str]:
    """Yields the path of a temporary HTML file containing data.

    The file is created with a ``folium_`` prefix and ``.html`` suffix.
    String data is written UTF-8 encoded, anything else is written as-is.
    The file is removed again when the context exits, also on error.
    """
    filepath = ""
    try:
        fid, filepath = tempfile.mkstemp(suffix=".html", prefix="folium_")
        # Wrapping the descriptor in a file object closes it even when the
        # write fails, and writes the whole payload rather than whatever a
        # single os.write call happens to accept.
        with os.fdopen(fid, "wb") as handle:
            handle.write(data.encode("utf8") if isinstance(data, str) else data)
        yield filepath
    finally:
        if os.path.isfile(filepath):
            os.remove(filepath)
|
|
|
|
|
|
def deep_copy(item_original: Element) -> Element:
    """Return a recursive deep-copy of item where each copy has a new ID.

    The element itself is copied shallowly and given a fresh ``_id``. If it
    has a ``_children`` mapping, every child is copied the same way,
    re-parented to the new element and stored under its new name.

    The copy always gets its own ``_children`` mapping, also when there are
    no children, so adding a child to the copy never touches the original.
    """
    item = copy.copy(item_original)
    item._id = uuid.uuid4().hex
    if hasattr(item, "_children"):
        # copy.copy leaves both objects pointing at the same mapping, so
        # build a new one unconditionally.
        children_new = collections.OrderedDict()
        for subitem_original in item._children.values():
            subitem = deep_copy(subitem_original)
            subitem._parent = item
            children_new[subitem.get_name()] = subitem
        item._children = children_new
    return item
|
|
|
|
|
|
def get_obj_in_upper_tree(element: Element, cls: Type) -> Element:
    """Walk up the ``_parent`` chain and return the nearest `cls` instance.

    The search starts at the parent of `element`, not at `element` itself.

    Raises
    ------
    ValueError
        If the top of the tree is reached without a match.
    """
    ancestor = element._parent
    while ancestor is not None:
        if isinstance(ancestor, cls):
            return ancestor
        ancestor = ancestor._parent
    raise ValueError(f"The top of the tree was reached without finding a {cls}")
|
|
|
|
|
|
def parse_options(**kwargs: TypeJsonValue) -> Dict[str, TypeJsonValueNoNone]:
    """Build an options dict from keyword arguments.

    Keys are converted with `camelize`; entries whose value is None are
    left out.
    """
    options: Dict[str, TypeJsonValueNoNone] = {}
    for name, value in kwargs.items():
        if value is None:
            continue
        options[camelize(name)] = value
    return options
|
|
|
|
|
|
def escape_backticks(text: str) -> str:
    """Put a backslash before each backtick so text fits in a JS template.

    Backticks that already directly follow a backslash are left alone.
    """
    unescaped_backtick = re.compile(r"(?<!\\)`")
    return unescaped_backtick.sub(lambda _match: r"\`", text)
|
|
|
|
|
|
def escape_double_quotes(text: str) -> str:
    """Put a backslash before every double quote in text."""
    return r"\"".join(text.split('"'))
|
|
|
|
|
|
def javascript_identifier_path_to_array_notation(path: str) -> str:
    """Rewrite a dotted path such as obj1.obj2 as ["obj1"]["obj2"].

    The path is split on dots and double quotes inside each part are
    escaped with `escape_double_quotes`.
    """
    segments = [escape_double_quotes(part) for part in path.split(".")]
    return "".join(f'["{segment}"]' for segment in segments)
|
|
|
|
|
|
def get_and_assert_figure_root(obj: Element) -> Figure:
    """Return the root element of the tree and assert it's a Figure.

    Raises
    ------
    AssertionError
        If the root returned by ``obj.get_root()`` is not a Figure.
    """
    figure = obj.get_root()
    # Raise explicitly instead of using an assert statement: asserts are
    # removed when Python runs with -O, which would let a non-Figure root
    # slip through. The exception type is unchanged.
    if not isinstance(figure, Figure):
        raise AssertionError("You cannot render this Element if it is not in a Figure.")
    return figure
|
|
|
|
|
|
class JsCode:
    """Wrapper around a string of Javascript code."""

    def __init__(self, js_code: Union[str, "JsCode"]):
        # Wrapping an existing JsCode reuses its string instead of nesting.
        source = js_code.js_code if isinstance(js_code, JsCode) else js_code
        self.js_code: str = source

    def __str__(self):
        return self.js_code
|
|
|
|
|
|
def parse_font_size(value: Union[str, int, float]) -> str:
    """Return a font size as a string with a unit.

    Numbers get ``px`` appended. Other values are returned unchanged, but
    only if they end in ``rem``, ``em`` or ``px``.

    Raises
    ------
    ValueError
        If a non-numeric value does not end in one of the accepted units.
    """
    if isinstance(value, (int, float)):
        return f"{value}px"

    has_accepted_unit = value[-3:] == "rem" or value[-2:] in ["em", "px"]
    if not has_accepted_unit:
        raise ValueError("The font size must be expressed in rem, em, or px.")
    return value
|