"""
Stores the class for XSectionDisplay.
"""
# Import third party libraries
import matplotlib.pyplot as plt
import numpy as np
try:
import cartopy.crs as ccrs
CARTOPY_AVAILABLE = True
except ImportError:
CARTOPY_AVAILABLE = False
# Import Local Libs
from ..utils import data_utils
from .plot import Display
class XSectionDisplay(Display):
    """
    Plots cross sections of multidimensional datasets. The data
    must be able to be sliced into a 2 dimensional slice using the
    xarray :func:`xarray.Dataset.sel` and :func:`xarray.Dataset.isel` commands.

    This is inherited from the :func:`act.plotting.Display`
    class and therefore has the same attributes as that class.
    See :func:`act.plotting.Display`
    for more information. There are no additional attributes or parameters
    to this class.

    In order to create geographic plots, ACT needs the Cartopy package to be
    installed on your system. For more information about Cartopy, see
    https://scitools.org.uk/cartopy/docs/latest/.

    Examples
    --------
    For example, if you only want to do a cross section through the first
    time period of a 3D dataset called :code:`ir_temperature`, you would
    do the following in xarray:

    .. code-block:: python

        time_slice = my_ds["ir_temperature"].isel(time=0)

    The methods of this class support passing in keyword arguments into
    xarray :func:`xarray.Dataset.sel` and :func:`xarray.Dataset.isel` commands
    so that new datasets do not need to be created when slicing by specific time
    periods or spatial slices. For example, to plot the first time period
    from :code:`my_ds`, simply do:

    .. code-block:: python

        xsection = XSectionDisplay(my_ds, figsize=(15, 8))
        xsection.plot_xsection_map(
            "ir_temperature",
            vmin=220,
            vmax=300,
            cmap="Greys",
            x="longitude",
            y="latitude",
            isel_kwargs={"time": 0},
        )

    Here, the array is sliced by the first time period as specified
    in :code:`isel_kwargs`. The other keyword arguments are standard keyword
    arguments taken by :func:`matplotlib.pyplot.pcolormesh`.
    """

    def __init__(self, ds, subplot_shape=(1,), ds_name=None, **kwargs):
        """Forward all arguments unchanged to :class:`act.plotting.Display`."""
        super().__init__(ds, subplot_shape, ds_name, **kwargs)

    def set_subplot_to_map(self, subplot_index):
        """
        Replace the axes at ``subplot_index`` with a Cartopy axes that uses
        a :class:`cartopy.crs.PlateCarree` projection.

        Parameters
        ----------
        subplot_index : 1 or 2D tuple
            The index of the subplot to convert into a map.

        Raises
        ------
        ImportError
            If Cartopy is not installed.
        """
        if not CARTOPY_AVAILABLE:
            raise ImportError(
                'Cartopy needs to be installed in order to plot cross sections on maps!'
            )
        self.fig.delaxes(self.axes[subplot_index])
        grid_shape = self.axes.shape
        if len(grid_shape) == 2:
            n_rows, n_cols = grid_shape[0], grid_shape[1]
            col = subplot_index[1]
        else:
            # A 1D axes array is treated as a single column of panels.
            n_rows, n_cols = grid_shape[0], 1
            col = 0
        # Matplotlib subplot positions are 1-based and numbered row by row,
        # so the stride between rows is the number of columns.
        position = n_cols * subplot_index[0] + col + 1
        # Add to this display's own figure rather than whichever figure
        # pyplot currently considers "current".
        self.axes[subplot_index] = self.fig.add_subplot(
            n_rows,
            n_cols,
            position,
            projection=ccrs.PlateCarree(),
        )

    def _store_range(self, attr_name, rng, subplot_index):
        """
        Record a 2-element range for one subplot in ``self.<attr_name>``.

        The storage array is created on first use with one 2-element slot per
        subplot, using the dtype of the first range stored.
        """
        rng = np.asarray(rng)
        if not hasattr(self, attr_name):
            setattr(self, attr_name, np.zeros(self.axes.shape + (2,), dtype=rng.dtype))
        # Index with the subplot tuple directly. Writing ``arr[idx, :]`` would
        # nest the tuple and trigger fancy indexing, which for a 2D index
        # (i, j) overwrites whole rows i and j instead of a single slot.
        getattr(self, attr_name)[tuple(np.atleast_1d(subplot_index))] = rng

    def set_xrng(self, xrng, subplot_index=(0,)):
        """
        Sets the x range of the plot.

        Parameters
        ----------
        xrng : sequence of 2 values
            The x limits of the plot.
        subplot_index : 1 or 2D tuple, list, or array
            The index of the subplot to set the x range of.

        Raises
        ------
        RuntimeError
            If no axes have been created yet.
        """
        if self.axes is None:
            raise RuntimeError('set_xrng requires the plot to be displayed.')
        self.axes[subplot_index].set_xlim(xrng)
        self._store_range('xrng', xrng, subplot_index)

    def set_yrng(self, yrng, subplot_index=(0,)):
        """
        Sets the y range of the plot.

        If both limits are equal, the upper limit is raised by one unit so the
        range is not degenerate. The caller's ``yrng`` is never modified.

        Parameters
        ----------
        yrng : sequence of 2 values
            The y limits of the plot.
        subplot_index : 1 or 2D tuple, list, or array
            The index of the subplot to set the y range of.

        Raises
        ------
        RuntimeError
            If no axes have been created yet.
        """
        if self.axes is None:
            raise RuntimeError('set_yrng requires the plot to be displayed.')
        # Copy so that tuples (e.g. from ``Axes.get_ylim``) can be adjusted
        # and caller-owned lists are left untouched.
        yrng = np.array(yrng)
        if yrng[0] == yrng[1]:
            yrng[1] = yrng[1] + 1
        self.axes[subplot_index].set_ylim(yrng[0], yrng[1])
        self._store_range('yrng', yrng, subplot_index)

    def plot_xsection(
        self,
        field,
        dsname=None,
        x=None,
        y=None,
        subplot_index=(0,),
        sel_kwargs=None,
        isel_kwargs=None,
        set_title=None,
        **kwargs,
    ):
        """
        This function plots a cross section whose x and y coordinates are
        specified by the variable names either provided by the user or
        automatically detected by xarray.

        Parameters
        ----------
        field : str
            The name of the variable to plot.
        dsname : str or None
            The name of the datastream to plot from. May be omitted only when
            the display holds a single datastream.
        x : str or None
            The name of the x coordinate variable. Must be given together
            with ``y``, or both left as None for automatic detection.
        y : str or None
            The name of the y coordinate variable.
        subplot_index : tuple
            The index of the subplot to create the plot in.
        sel_kwargs : dict
            The keyword arguments to pass into :py:func:`xarray.Dataset.sel`
            (always called with ``method='nearest'``).
            This is useful when your data is in 3 or more dimensions and you
            want to only view a cross section on a specific x-y plane. For more
            information on how to use xarray's .sel and .isel functionality
            to slice datasets, see the documentation on :func:`xarray.Dataset.sel`.
        isel_kwargs : dict
            The keyword arguments to pass into :py:func:`xarray.Dataset.isel`.
        set_title : str
            Title for the plot. Defaults to the variable's ``long_name``
            attribute when present.
        **kwargs : keyword arguments
            Additional keyword arguments will be passed into
            :func:`xarray.DataArray.plot`.

        Returns
        -------
        ax : matplotlib axis handle
            The matplotlib axis handle corresponding to the plot.

        Raises
        ------
        ValueError
            If ``dsname`` is None while more than one datastream is loaded.
        RuntimeError
            If only one of ``x`` and ``y`` is given.
        """
        if dsname is None and len(self._ds.keys()) > 1:
            raise ValueError(
                'You must choose a datastream when there are 2 '
                'or more datasets in the XSectionDisplay '
                'object.'
            )
        elif dsname is None:
            dsname = list(self._ds.keys())[0]

        # x and y must be supplied together, or both omitted for auto-detection.
        if (x is None) != (y is None):
            raise RuntimeError(
                'Both x and y must be specified if we are '
                'not trying to automatically detect them!'
            )

        temp_ds = self._ds[dsname].copy()
        if sel_kwargs is not None:
            temp_ds = temp_ds.sel(**sel_kwargs, method='nearest')
        if isel_kwargs is not None:
            temp_ds = temp_ds.isel(**isel_kwargs)

        if x is None:
            my_dataarray = temp_ds[field]
            axis_kwargs = {}
        else:
            # Map each named variable to its first dimension so it can be
            # assigned as a coordinate and used as a plot axis.
            coord_list = {x: temp_ds[x].dims[0], y: temp_ds[y].dims[0]}
            my_dataarray = data_utils.assign_coordinates(temp_ds, coord_list)[field]
            coord_keys = list(my_dataarray.coords)
            # X-array will sometimes shorten latitude and longitude variables
            xc = 'lon' if x == 'longitude' and x not in coord_keys else x
            yc = 'lat' if y == 'latitude' and y not in coord_keys else y
            axis_kwargs = {'x': xc, 'y': yc}

        ax = self.axes[subplot_index]
        my_dataarray.plot(ax=ax, **axis_kwargs, **kwargs)

        self.set_xrng(ax.get_xlim(), subplot_index)
        self.set_yrng(ax.get_ylim(), subplot_index)

        if set_title is None:
            set_title = self._ds[dsname][field].attrs.get('long_name')
        # Title the target subplot explicitly; pyplot's "current" axes may be
        # a different panel in multi-subplot displays.
        ax.set_title(set_title)
        return ax

    def plot_xsection_map(
        self,
        field,
        dsname=None,
        subplot_index=(0,),
        coastlines=True,
        background=False,
        set_title=None,
        **kwargs,
    ):
        """
        Plots a cross section of 2D data on a geographical map.

        Parameters
        ----------
        field : str
            The name of the variable to plot.
        dsname : str or None
            The name of the datastream to plot from.
        subplot_index : tuple
            The index of the subplot to plot inside.
        coastlines : bool
            Set to True to plot the coastlines.
        background : bool
            Set to True to plot a stock image background.
        set_title : str
            Title for the plot
        **kwargs : keyword arguments
            Additional keyword arguments will be passed into
            :func:`act.plotting.XSectionDisplay.plot_xsection`

        Returns
        -------
        ax : matplotlib axis handle
            The matplotlib axis handle corresponding to the plot.

        Raises
        ------
        ImportError
            If Cartopy is not installed.
        """
        if not CARTOPY_AVAILABLE:
            raise ImportError(
                'Cartopy needs to be installed in order to plot cross sections on maps!'
            )

        self.set_subplot_to_map(subplot_index)
        ax = self.plot_xsection(
            field, dsname=dsname, subplot_index=subplot_index, set_title=set_title, **kwargs
        )

        index = tuple(np.atleast_1d(subplot_index))
        xlims = self.xrng[index].flatten()
        ylims = self.yrng[index].flatten()
        # Ten evenly spaced ticks between the limits rounded to whole numbers.
        ax.set_xticks(np.linspace(round(xlims[0], 0), round(xlims[1], 0), 10))
        ax.set_yticks(np.linspace(round(ylims[0], 0), round(ylims[1], 0), 10))

        if coastlines:
            ax.coastlines(resolution='10m')
        if background:
            ax.stock_img()
        return ax