"""
Base Display class and time-grouping helper shared by the ACT plotting displays.
"""
import warnings
# Import third party libraries
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
import inspect
class Display:
    """
    Base class for all of the other Display object types in ACT.

    This contains the attributes and routines common to the differing
    *Display* classes. We recommend that you use the classes inherited from
    Display for making your plots, such as
    :func:`act.plotting.TimeSeriesDisplay` and
    :func:`act.plotting.WindRoseDisplay`, instead of trying to do so with
    the Display object directly. However, if you add another object to the
    plotting module of ACT, please make it a subclass of Display. Display
    provides basic functionality for handling datasets and subplot
    parameters.

    Attributes
    ----------
    fields : dict
        Maps each datastream name to the ``variables`` mapping of its
        dataset.
    ds : dict
        Maps each datastream name to the string value of the dataset's
        ``_datastream`` attribute, or ``'act_datastream'`` when the dataset
        has no such attribute.
    file_dates : dict
        Maps each datastream name to the dataset's ``_file_dates``
        attribute, for the datasets that have one.
    fig : matplotlib figure handle
        The matplotlib figure the plots are drawn on. It is None until
        :func:`add_subplots` creates a figure.
    axes : numpy.ndarray or None
        The array of axes handles, one per subplot.
    xrng, yrng : numpy.ndarray
        Per-subplot axis ranges, initialised to zeros with a trailing axis of
        length 2.
    plot_vars : list
        The list of variables being plotted.
    cbs : list
        The list of colorbar handles added with :func:`add_colorbar`.

    Parameters
    ----------
    ds : ACT xarray.Dataset, dict, tuple or list
        The ACT xarray dataset to display in the object. If more than one
        dataset is to be specified, a tuple (or list) can be used if all of
        the datasets have a ``datastream`` attribute. Otherwise, supply a dict
        with a key corresponding to the name of each datastream to plot
        multiple datasets.
    subplot_shape : 1 or 2D tuple
        A tuple representing the number of (rows, columns) for the subplots
        in the display. If this is None, the figure and axes will not
        be initialized.
    ds_name : str or None
        The name of the datastream to plot. This is only used if a single
        dataset without a ``datastream`` attribute is passed.
    subplot_kw : dict, optional
        The kwargs to pass into :func:`matplotlib.pyplot.subplots` as
        ``subplot_kw``.
    **kwargs : keyword arguments
        Keyword arguments passed to :func:`matplotlib.pyplot.subplots`.

    Raises
    ------
    TypeError
        If *ds* is not an xarray Dataset, tuple, list or dict.
    """

    def __init__(self, ds, subplot_shape=(1,), ds_name=None, subplot_kw=None, **kwargs):
        if isinstance(ds, xr.Dataset):
            if 'datastream' in ds.attrs:
                self._ds = {ds.attrs['datastream']: ds}
            elif ds_name is not None:
                self._ds = {ds_name: ds}
            else:
                warnings.warn(
                    'Could not discern datastream name and dict or tuple were '
                    'not provided. Using default name of act_datastream!',
                    UserWarning,
                )
                self._ds = {'act_datastream': ds}
        elif isinstance(ds, (tuple, list)):
            # Automatically name each dataset by its 'datastream' attribute.
            self._ds = {multi_ds.attrs['datastream']: multi_ds for multi_ds in ds}
        elif isinstance(ds, dict):
            self._ds = ds
        else:
            raise TypeError(
                'ds must be an xarray.Dataset, a tuple or list of datasets, '
                'or a dict of datasets.'
            )

        self.fields = {}
        self.ds = {}
        self.file_dates = {}
        self.xrng = np.zeros((1, 2))
        self.yrng = np.zeros((1, 2))
        for dsname, dataset in self._ds.items():
            self.fields[dsname] = dataset.variables
            self.ds[dsname] = str(dataset.attrs.get('_datastream', 'act_datastream'))
            if '_file_dates' in dataset.attrs:
                self.file_dates[dsname] = dataset.attrs['_file_dates']
        self.fig = None
        self.axes = None
        self.plot_vars = []
        self.cbs = []
        if subplot_shape is not None:
            self.add_subplots(subplot_shape, subplot_kw=subplot_kw, **kwargs)

    def add_subplots(self, subplot_shape=(1,), secondary_y=False, subplot_kw=None, **kwargs):
        """
        Add subplots to the Display object.

        Any figure already held by the object is closed and replaced. After
        this call ``self.fig`` and ``self.axes`` refer to the new figure and
        axes, and ``self.xrng``/``self.yrng`` are reset to zeros shaped
        ``(*subplot_shape, 2)``.

        Parameters
        ----------
        subplot_shape : 1 or 2D tuple, list, or array
            The structure of the subplots in (rows, cols).
        secondary_y : bool
            Not used by this method; kept for backward compatibility.
        subplot_kw : dict, optional
            The kwargs to pass into :func:`matplotlib.pyplot.subplots` as
            ``subplot_kw``.
        **kwargs : keyword arguments
            Any other keyword arguments that will be passed
            into :func:`matplotlib.pyplot.subplots`. See the matplotlib
            documentation for further details on what keyword
            arguments are available.

        Raises
        ------
        ValueError
            If *subplot_shape* does not have 1 or 2 elements.
        """
        # Validate first so an invalid shape does not destroy the current figure.
        if len(subplot_shape) not in (1, 2):
            raise ValueError(
                'subplot_shape must be a 1 or 2 dimensional tuple, list, or array!'
            )
        if self.fig is not None:
            plt.close(self.fig)
            self.fig = None

        if len(subplot_shape) == 2:
            nrows, ncols = subplot_shape[0], subplot_shape[1]
            fig, ax = plt.subplots(nrows, ncols, subplot_kw=subplot_kw, **kwargs)
            self.xrng = np.zeros((nrows, ncols, 2))
            self.yrng = np.zeros((nrows, ncols, 2))
            if nrows == 1:
                # A single row comes back 1D (or, for a 1x1 grid, as a bare
                # Axes object); restore the (1, ncols) layout so 2-element
                # subplot indices work.
                ax = np.asarray(ax, dtype=object).reshape(1, ncols)
        else:
            nrows = subplot_shape[0]
            fig, ax = plt.subplots(nrows, 1, subplot_kw=subplot_kw, **kwargs)
            if nrows == 1:
                # A single subplot comes back as a bare Axes; wrap it in an array.
                ax = np.array([ax])
            self.xrng = np.zeros((nrows, 2))
            self.yrng = np.zeros((nrows, 2))
        self.fig = fig
        self.axes = ax

    def put_display_in_subplot(self, display, subplot_index):
        """
        Place a Display object into a specific subplot of this Display.

        The display object must only have one subplot. Its own figure is
        closed, and it is re-pointed at this object's figure and at the new
        axis created in *subplot_index*.

        Parameters
        ----------
        display : Display object or subclass
            The Display object to add as a subplot.
        subplot_index : tuple
            Which subplot to add the Display to: ``(row,)`` for a 1D layout
            or ``(row, col)`` for a 2D layout.

        Returns
        -------
        ax : matplotlib axis handle
            The axis handle to the display object being added.

        Raises
        ------
        RuntimeError
            If *display* has more than one subplot.
        """
        # Check the total number of axes, not len(): a (1, N) layout has
        # len 1 but N subplots.
        other_axes = np.asarray(display.axes, dtype=object)
        if other_axes.size > 1:
            raise RuntimeError(
                'Only single plots can be made as subplots of another Display object!'
            )
        my_projection = other_axes.flat[0].name
        plt.close(display.fig)
        display.fig = self.fig
        self.fig.delaxes(self.axes[subplot_index])

        the_shape = self.axes.shape
        nrows = the_shape[0]
        ncols = the_shape[1] if len(the_shape) > 1 else 1
        # Normalise negative indices so the numbering below stays in range.
        row = subplot_index[0] % nrows
        col = subplot_index[1] % ncols if len(the_shape) > 1 else 0
        # matplotlib numbers subplots starting at 1, in row-major order.
        position = row * ncols + col + 1
        self.axes[subplot_index] = self.fig.add_subplot(
            nrows,
            ncols,
            position,
            projection=my_projection,
        )
        display.axes = np.array([self.axes[subplot_index]])
        return display.axes[0]

    def add_colorbar(
        self, mappable, title=None, subplot_index=(0,), pad=None, width=None, **kwargs
    ):
        """
        Add a colorbar to the plot.

        Parameters
        ----------
        mappable : matplotlib mappable
            The mappable to base the colorbar on.
        title : str
            The title of the colorbar. Set to None to have no title.
        subplot_index : 1 or 2D tuple, list, or array
            The index of the subplot to attach the colorbar to.
        pad : float
            Padding (in figure coordinates) to the right of the subplot before
            the colorbar. Defaults to 0.01.
        width : float
            Width of the colorbar in figure coordinates. Defaults to 0.01.
        **kwargs : keyword arguments
            The keyword arguments for :func:`matplotlib.pyplot.colorbar`.

        Returns
        -------
        cbar : matplotlib colorbar handle
            The handle to the matplotlib colorbar. It is also appended to
            ``self.cbs``.

        Raises
        ------
        RuntimeError
            If no axes have been created yet.
        """
        if self.axes is None:
            raise RuntimeError('add_colorbar requires the plot to be displayed.')
        fig = self.fig
        ax = self.axes[subplot_index]
        if pad is None:
            pad = 0.01
        if width is None:
            width = 0.01
        # Give the colorbar its own axis so the 2D plots line up with 1D.
        box = ax.get_position()
        cax = fig.add_axes([box.xmax + pad, box.ymin, width, box.height])
        cbar = plt.colorbar(mappable, cax=cax, **kwargs)
        if title is not None:
            cbar.ax.set_ylabel(title, rotation=270, fontsize=8, labelpad=3)
        cbar.ax.tick_params(labelsize=6)
        self.cbs.append(cbar)
        return cbar

    def group_by(self, units):
        """
        Group the Display by specific units of time.

        Parameters
        ----------
        units : str
            One of: 'year', 'month', 'day', 'hour', 'minute', 'second'.
            Group the plot by this unit of time (year, month, etc.).

        Returns
        -------
        groupby : act.plotting.DisplayGroupby
            The DisplayGroupby object to be returned.
        """
        return DisplayGroupby(self, units)
class DisplayGroupby:
    """
    Plot time-based groups of a Display's datasets into the Display's subplots.

    Objects of this class are returned by :func:`act.plotting.Display.group_by`.
    """

    def __init__(self, display, units):
        """
        Parameters
        ----------
        display : Display
            The Display object to group by time.
        units : str
            The time units to group by. Can be one of:
            'year', 'month', 'day', 'hour', 'minute', 'second'.
        """
        self.display = display
        # datastream name -> result of ``dataset.groupby('time.<units>')``.
        self._groupby = {}
        # plotted dataset name -> subplot index it was drawn into.
        self.mapping = {}
        # plotted dataset name -> (min time, max time); filled for
        # time-series displays only.
        self.xlims = {}
        self.units = units
        # A display exposing ``time_height_scatter`` is treated as a
        # time-series display, which enables x-range handling and the
        # overlaying of multiple years in one subplot.
        self.isTimeSeriesDisplay = hasattr(self.display, 'time_height_scatter')
        for key, dataset in display._ds.items():
            self._groupby[key] = dataset.groupby('time.%s' % units)

    @staticmethod
    def _subplot_index(i, subplot_shape):
        """Convert flat position *i* into a subplot index tuple (row-major)."""
        if len(subplot_shape) == 2:
            return (i // subplot_shape[1], i % subplot_shape[1])
        return (i % subplot_shape[0],)

    def _plot_years_overlaid(self, func, key, k, ds, subplot_index, kwargs):
        """Plot each calendar year of group *k* shifted onto the first year's dates."""
        first_year = int(ds.time.dt.year.values[0])
        first_jan1 = np.datetime64('%04d-01-01' % first_year, 'D')
        for yr, ds1 in ds.groupby('time.year'):
            # Shift by the exact number of days between 1 January of ``yr`` and
            # 1 January of ``first_year``. This counts every intervening leap
            # year; a fixed 365/366-days-per-year factor mis-shifts the data.
            offset = np.datetime64('%04d-01-01' % int(yr), 'D') - first_jan1
            ds1['time'] = ds1.time - offset
            name = key + '%d_%d' % (k, yr)
            self.display._ds[name] = ds1
            func(dsname=name, label=str(yr), **kwargs)
            self.mapping[name] = subplot_index
            self.xlims[name] = (ds1.time.values.min(), ds1.time.values.max())

    def plot_group(self, func_name, dsname=None, **kwargs):
        """
        Plot each group into its own subplot of the display.

        The groups are those created in :func:`act.plotting.Display.group_by`.
        Groups are assigned to subplots in row-major order. If there are more
        groups than subplots, the assignment wraps around to the first
        subplot. Unused subplots are turned off, and legends are removed from
        every subplot except the first. The display's datasets are always
        restored afterwards, even if plotting raises.

        Parameters
        ----------
        func_name : str
            The name of the plotting function in the Display that you are
            grouping.
        dsname : str or None
            The name of the datastream to plot. Defaults to the display's
            first datastream.
        **kwargs
            Additional keyword arguments passed into *func_name*. If that
            function accepts ``subplot_index`` or ``time_rng``, those entries
            are set per group.

        Returns
        -------
        axis : Array of matplotlib axes handles
            The array of matplotlib axes handles that correspond to each
            subplot.

        Raises
        ------
        RuntimeError
            If *func_name* names an attribute that is not callable.
        """
        if dsname is None:
            first_key = list(self.display._ds.keys())[0]
            # Use the full key when it is grouped. Stripping at the first
            # underscore is kept only as a fallback, because it can never
            # match keys such as 'act_datastream'.
            dsname = first_key if first_key in self._groupby else first_key.split('_')[0]
        func = getattr(self.display, func_name)
        if not callable(func):
            raise RuntimeError("The specified string is not a function of " "the Display object.")
        # The accepted argument names are the same for every group; inspect once.
        func_args = inspect.getfullargspec(func).args
        subplot_shape = self.display.axes.shape
        n_subplots = int(np.prod(subplot_shape))
        i = 0
        wrap_around = False
        old_ds = self.display._ds
        try:
            for key in self._groupby:
                if key != dsname:
                    continue
                self.display._ds = {}
                for k, ds in self._groupby[key]:
                    num_years = len(np.unique(ds.time.dt.year))
                    group_name = key + '_%d' % k
                    self.display._ds[group_name] = ds
                    if i >= n_subplots:
                        i = 0
                        wrap_around = True
                    subplot_index = self._subplot_index(i, subplot_shape)
                    if 'subplot_index' in func_args:
                        kwargs['subplot_index'] = subplot_index
                    if 'time_rng' in func_args:
                        kwargs['time_rng'] = (ds.time.values.min(), ds.time.values.max())
                    if num_years > 1 and self.isTimeSeriesDisplay:
                        self._plot_years_overlaid(func, key, k, ds, subplot_index, kwargs)
                        # The per-year entries replace the combined group entry.
                        del self.display._ds[group_name]
                    else:
                        func(dsname=group_name, **kwargs)
                        self.mapping[group_name] = subplot_index
                        if self.isTimeSeriesDisplay:
                            self.xlims[group_name] = (
                                ds.time.values.min(),
                                ds.time.values.max(),
                            )
                    i += 1

            # Hide subplots that received no group.
            if not wrap_around:
                for j in range(i, n_subplots):
                    self.display.axes[self._subplot_index(j, subplot_shape)].axis('off')

            # Keep only the first subplot's legend.
            for j in range(1, n_subplots):
                legend = self.display.axes[self._subplot_index(j, subplot_shape)].get_legend()
                if legend is not None:
                    legend.remove()

            if self.isTimeSeriesDisplay:
                for name in list(self.display._ds.keys()):
                    time_min, time_max = self.xlims[name]
                    self.display.set_xrng([time_min, time_max], self.mapping[name])
        finally:
            self.display._ds = old_ds
        return self.display.axes