{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Calculating Seasonal Averages from Time Series of Monthly Means \n",
    "=====\n",
    "\n",
    "Author: [Joe Hamman](https://github.com/jhamman/)\n",
    "\n",
    "The data used for this example can be found in the [xarray-data](https://github.com/pydata/xarray-data) repository. You may need to change the path to `rasm.nc` below.\n",
    "\n",
    "Suppose we have a netCDF or `xarray.Dataset` of monthly mean data and we want to calculate the seasonal average.  To do this properly, we need to calculate the weighted average considering that each month has a different number of days."
   ]
  },
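  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch of why the weighting matters, consider a single DJF season in a `noleap` calendar. The temperatures below are made-up values for illustration only; they are not taken from the RASM data.\n",
    "\n",
    "```python\n",
    "# Made-up monthly means for December, January, February (illustration only)\n",
    "monthly_mean = [-4.0, -6.0, -2.0]\n",
    "# Month lengths in a 'noleap' calendar\n",
    "days = [31, 31, 28]\n",
    "\n",
    "# Each month's weight is its share of the days in the season\n",
    "weights = [d / sum(days) for d in days]\n",
    "\n",
    "weighted = sum(w * x for w, x in zip(weights, monthly_mean))\n",
    "unweighted = sum(monthly_mean) / len(monthly_mean)\n",
    "\n",
    "print(round(weighted, 4), unweighted)\n",
    "```\n",
    "\n",
    "February contributes 28 of the 90 days, so its value counts slightly less than in the plain mean: the weighted result is about -4.07, against -4.0 for the unweighted one."
   ]
  },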
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-11-28T20:51:35.958210Z",
     "start_time": "2018-11-28T20:51:35.936966Z"
    }
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import xarray as xr\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Open the `Dataset`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-11-28T20:51:36.072316Z",
     "start_time": "2018-11-28T20:51:36.016594Z"
    }
   },
   "outputs": [],
   "source": [
    "ds = xr.tutorial.open_dataset(\"rasm\").load()\n",
    "ds"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Now for the heavy lifting:\n",
    "We first have to come up with the weights,\n",
    "- calculate the month length for each monthly data record\n",
    "- calculate weights using `groupby('time.season')`\n",
    "\n",
    "Finally, we just need to multiply our weights by the `Dataset` and sum along the time dimension.  Creating a `DataArray` for the month length is as easy as using the `days_in_month` accessor on the time coordinate.  The calendar type, in this case `'noleap'`, is automatically considered in this operation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "month_length = ds.time.dt.days_in_month\n",
    "month_length"
   ]
  },
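  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To illustrate what taking the calendar into account means, here is a standard-library sketch. The helper `days_in_month` and the table `NOLEAP_DAYS` are hypothetical names made up for this illustration and are not part of xarray. The sketch also assumes that the proleptic Gregorian rules of Python's `calendar` module are an acceptable stand-in for a standard calendar at modern dates.\n",
    "\n",
    "```python\n",
    "import calendar\n",
    "\n",
    "# Fixed month lengths for a 'noleap' (365-day) calendar\n",
    "NOLEAP_DAYS = [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]\n",
    "\n",
    "\n",
    "def days_in_month(year, month, cal='standard'):\n",
    "    # Hypothetical helper: handles only 'standard' and 'noleap'\n",
    "    if cal == 'noleap':\n",
    "        return NOLEAP_DAYS[month - 1]\n",
    "    if cal == 'standard':\n",
    "        # monthrange returns (weekday of the first day, number of days)\n",
    "        return calendar.monthrange(year, month)[1]\n",
    "    raise ValueError('unsupported calendar: ' + cal)\n",
    "\n",
    "\n",
    "print(days_in_month(2000, 2, 'standard'))  # 29, since 2000 is a leap year\n",
    "print(days_in_month(2000, 2, 'noleap'))  # 28\n",
    "```\n",
    "\n",
    "With the `noleap` calendar every February has 28 days, so the DJF weights are the same in every year of the record."
   ]
  },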
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-11-28T20:51:36.132413Z",
     "start_time": "2018-11-28T20:51:36.073708Z"
    }
   },
   "outputs": [],
   "source": [
    "# Calculate the weights by grouping by 'time.season'.\n",
    "weights = (\n",
    "    month_length.groupby(\"time.season\") / month_length.groupby(\"time.season\").sum()\n",
    ")\n",
    "\n",
    "# Test that the sum of the weights for each season is 1.0\n",
    "np.testing.assert_allclose(weights.groupby(\"time.season\").sum().values, np.ones(4))\n",
    "\n",
    "# Calculate the weighted average\n",
    "ds_weighted = (ds * weights).groupby(\"time.season\").sum(dim=\"time\")"
   ]
  },
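  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The grouped division above can be hard to picture. The plain-Python sketch below mimics it for one `noleap` year, assuming one record per month; the names `SEASON`, `season_days` and `season_weight_sum` are made up for this illustration.\n",
    "\n",
    "```python\n",
    "from collections import defaultdict\n",
    "\n",
    "# Month number -> season label, matching the DJF/MAM/JJA/SON grouping\n",
    "SEASON = {12: 'DJF', 1: 'DJF', 2: 'DJF', 3: 'MAM', 4: 'MAM', 5: 'MAM',\n",
    "          6: 'JJA', 7: 'JJA', 8: 'JJA', 9: 'SON', 10: 'SON', 11: 'SON'}\n",
    "# One year of 'noleap' month lengths, January through December\n",
    "month_length = [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]\n",
    "months = list(range(1, 13))\n",
    "\n",
    "# Counterpart of month_length.groupby('time.season').sum()\n",
    "season_days = defaultdict(int)\n",
    "for m, d in zip(months, month_length):\n",
    "    season_days[SEASON[m]] += d\n",
    "\n",
    "# Counterpart of dividing each record by the total of its own season\n",
    "weights = [d / season_days[SEASON[m]] for m, d in zip(months, month_length)]\n",
    "\n",
    "# The weights within each season add up to 1\n",
    "season_weight_sum = defaultdict(float)\n",
    "for m, w in zip(months, weights):\n",
    "    season_weight_sum[SEASON[m]] += w\n",
    "\n",
    "print(dict(season_days))\n",
    "```\n",
    "\n",
    "For a single year the season totals are 90, 92, 92 and 91 days. In the notebook the totals run over every year in the record, so each individual weight is smaller, but the weights of each season still sum to 1."
   ]
  },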
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-11-28T20:51:36.152913Z",
     "start_time": "2018-11-28T20:51:36.133997Z"
    }
   },
   "outputs": [],
   "source": [
    "ds_weighted"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-11-28T20:51:36.190765Z",
     "start_time": "2018-11-28T20:51:36.154416Z"
    }
   },
   "outputs": [],
   "source": [
    "# only used for comparisons\n",
    "ds_unweighted = ds.groupby(\"time.season\").mean(\"time\")\n",
    "ds_diff = ds_weighted - ds_unweighted"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-11-28T20:51:40.264871Z",
     "start_time": "2018-11-28T20:51:36.192467Z"
    }
   },
   "outputs": [],
   "source": [
    "# Quick plot to show the results\n",
    "notnull = pd.notnull(ds_unweighted[\"Tair\"][0])\n",
    "\n",
    "fig, axes = plt.subplots(nrows=4, ncols=3, figsize=(14, 12))\n",
    "for i, season in enumerate((\"DJF\", \"MAM\", \"JJA\", \"SON\")):\n",
    "    ds_weighted[\"Tair\"].sel(season=season).where(notnull).plot.pcolormesh(\n",
    "        ax=axes[i, 0],\n",
    "        vmin=-30,\n",
    "        vmax=30,\n",
    "        cmap=\"Spectral_r\",\n",
    "        add_colorbar=True,\n",
    "        extend=\"both\",\n",
    "    )\n",
    "\n",
    "    ds_unweighted[\"Tair\"].sel(season=season).where(notnull).plot.pcolormesh(\n",
    "        ax=axes[i, 1],\n",
    "        vmin=-30,\n",
    "        vmax=30,\n",
    "        cmap=\"Spectral_r\",\n",
    "        add_colorbar=True,\n",
    "        extend=\"both\",\n",
    "    )\n",
    "\n",
    "    ds_diff[\"Tair\"].sel(season=season).where(notnull).plot.pcolormesh(\n",
    "        ax=axes[i, 2],\n",
    "        vmin=-0.1,\n",
    "        vmax=0.1,\n",
    "        cmap=\"RdBu_r\",\n",
    "        add_colorbar=True,\n",
    "        extend=\"both\",\n",
    "    )\n",
    "\n",
    "    axes[i, 0].set_ylabel(season)\n",
    "    axes[i, 1].set_ylabel(\"\")\n",
    "    axes[i, 2].set_ylabel(\"\")\n",
    "\n",
    "for ax in axes.flat:\n",
    "    ax.axes.get_xaxis().set_ticklabels([])\n",
    "    ax.axes.get_yaxis().set_ticklabels([])\n",
    "    ax.axes.axis(\"tight\")\n",
    "    ax.set_xlabel(\"\")\n",
    "\n",
    "axes[0, 0].set_title(\"Weighted by DPM\")\n",
    "axes[0, 1].set_title(\"Equal Weighting\")\n",
    "axes[0, 2].set_title(\"Difference\")\n",
    "\n",
    "plt.tight_layout()\n",
    "\n",
    "fig.suptitle(\"Seasonal Surface Air Temperature\", fontsize=16, y=1.02)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-11-28T20:51:40.284898Z",
     "start_time": "2018-11-28T20:51:40.266406Z"
    }
   },
   "outputs": [],
   "source": [
    "# Wrap it into a simple function\n",
    "def season_mean(ds, calendar=\"standard\"):\n",
    "    # Make a DataArray with the number of days in each month, size = len(time)\n",
    "    month_length = ds.time.dt.days_in_month\n",
    "\n",
    "    # Calculate the weights by grouping by 'time.season'\n",
    "    weights = (\n",
    "        month_length.groupby(\"time.season\") / month_length.groupby(\"time.season\").sum()\n",
    "    )\n",
    "\n",
    "    # Test that the sum of the weights for each season is 1.0\n",
    "    np.testing.assert_allclose(weights.groupby(\"time.season\").sum().values, np.ones(4))\n",
    "\n",
    "    # Calculate the weighted average\n",
    "    return (ds * weights).groupby(\"time.season\").sum(dim=\"time\")"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": true,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
