Source code for xspharm.xspharm

import numpy as np
import xarray as xr
from spharm import Spharmt


[docs] class xspharm: """ Xarray interface to spherical harmonic transforms using pyspharm. Attributes: nlat (int): Number of latitude points. nlon (int): Number of longitude points. fvor (xarray DataArray) : planetary vorticity s (Spharmt): Spharmt object for spherical harmonic operations. Methods: truncate: Truncate a data variable or entire dataset to a specific wavenumber. exp_taper: Apply tapering to spherical harmonic coefficients. uv2sfvp: Convert zonal and meridional wind components to streamfunction and velocity potential. uv2vordiv: Convert zonal and meridional wind components to vorticity and divergence. uv2absvor: Convert zonal and meridional wind components to absolute vorticity. sf2uv: Convert streamfunction to rotational wind components. vp2uv: Convert velocity potential to divergent wind components. sfvp2uv: Convert streamfunction and velocity potential to zonal and meridional wind components. """ def __init__( self, grid_ds, gridtype="regular", rsphere=6.3712e6, omega=7.292e-5, legfunc="stored", ): """ Initialize the XSpharm class with the given dataset and spherical harmonic transform parameters. Args: grid_ds (xarray.Dataset): Dataset containing the grid information. gridtype (str): Type of grid, either 'regular' or 'gaussian'. rsphere (float): Radius of the sphere, defaults to Earth's radius in meters. omega (float): Rotation rate of the planet, defaults to Earth's rotation rate. legfunc (str): Legendre function computation method, either 'stored' or 'computed'. """ self.nlat = len(grid_ds["lat"]) self.nlon = len(grid_ds["lon"]) self.fvor = 2.0 * omega * np.sin(np.deg2rad(grid_ds["lat"])) self.s = Spharmt( self.nlon, self.nlat, gridtype=gridtype, rsphere=rsphere, legfunc=legfunc )
[docs] def truncate(self, ds_or_var, ntrunc=None): """ Truncate a data variable or entire dataset. Args: - ds_or_var (xr.DataArray or xr.Dataset): Input data. - ntrunc (int, optional): Truncation wavenumber. Default is None. Returns: xr.DataArray or xr.Dataset: Truncated data. """ if isinstance(ds_or_var, xr.Dataset): # Create a new dataset by applying truncation to each data variable truncated_data = {} for var_name, var_data in ds_or_var.data_vars.items(): truncated_data[var_name] = self._truncate_var(var_data, ntrunc) return xr.Dataset(truncated_data) elif isinstance(ds_or_var, xr.DataArray): return self._truncate_var(ds_or_var, ntrunc) else: raise ValueError("Input should be an xarray Dataset or DataArray.")
def _truncate_var(self, var_ds, ntrunc=None): other_dims = _get_other_dims(var_ds) var_data = _transpose_to_spharm(var_ds) spec_data = self.s.grdtospec(var_data.values, ntrunc=ntrunc) grid_data = var_data.copy(data=self.s.spectogrd(spec_data)) return _transpose_from_spharm(grid_data, other_dims)
[docs] def exp_taper(self, ds_or_var, ntrunc=None, r=2): """ Taper (filter) the spherical harmonic coefficients using the equation provided. Args: - ds_or_var (xr.DataArray or xr.Dataset): Input data. - ntrunc (int): Truncation wavenumber N. - r (float): User-defined parameter. Returns: xr.DataArray or xr.Dataset: Tapered data. """ if isinstance(ds_or_var, xr.Dataset): tapered_data = {} for var_name, var_data in ds_or_var.data_vars.items(): tapered_data[var_name] = self._exp_taper_var(var_data, ntrunc, r) return xr.Dataset(tapered_data) elif isinstance(ds_or_var, xr.DataArray): return self._exp_taper_var(ds_or_var, ntrunc, r) else: raise ValueError("Input should be an xarray Dataset or DataArray.")
def _exp_taper_var(self, var_ds, ntrunc, r): other_dims = _get_other_dims(var_ds) var_data = _transpose_to_spharm(var_ds) # Compute spherical harmonic coefficients spec_data = self.s.grdtospec(var_data.values) # Compute l and m values based on triangular layout l_values, m_values = [], [] for m in range(self.nlat - 1 + 1): for l in range(m, self.nlat - 1 + 1): l_values.append(l) m_values.append(m) l_values = np.array(l_values, dtype=float) m_values = np.array(m_values, dtype=float) total_wavenumber = np.sqrt(l_values * (l_values + 1) + m_values**2) # Apply tapering based on the given equation taper_filter = np.exp( -( ((total_wavenumber * (total_wavenumber + 1)) / (ntrunc * (ntrunc + 1))) ** r ) ) spec_data *= ( taper_filter[:, np.newaxis] if spec_data.ndim == 2 else taper_filter ) # Convert back to grid data grid_data = var_data.copy(data=self.s.spectogrd(spec_data)) return _transpose_from_spharm(grid_data, other_dims)
[docs] def uv2sfvp(self, u_ds, v_ds, ntrunc=None): """ Streamfunction and velocity potential. Inputs: u, v: zonal and meridional winds ntrunc: Truncation limit (triangular truncation) for the spherical harmonic computation. Returns: sf, vp: The streamfunction and velocity potential respectively. """ other_dims = _get_other_dims(u_ds) u_data = _transpose_to_spharm(u_ds) v_data = _transpose_to_spharm(v_ds) psi_data, chi_data = self.s.getpsichi( u_data.values, v_data.values, ntrunc=ntrunc ) psi_ds = _transpose_from_spharm(u_data.copy(data=psi_data), other_dims) chi_ds = _transpose_from_spharm(u_data.copy(data=chi_data), other_dims) psi_ds.attrs["long_name"] = "streamfunction" psi_ds.attrs["units"] = "m**2/s" chi_ds.attrs["long_name"] = "velocity potential" chi_ds.attrs["units"] = "m**2/s" return xr.Dataset({"sf": psi_ds, "vp": chi_ds})
[docs] def uv2vordiv(self, u_ds, v_ds, ntrunc=None): """ Computes the vorticity and divergence via spherical harmonics, given the u and v wind components Inputs: u, v: zonal and meridional winds ntrunc: Truncation limit (triangular truncation) for the spherical harmonic computation. Returns: vor, div : the vorticity and divergence respectively. """ other_dims = _get_other_dims(u_ds) u_data = _transpose_to_spharm(u_ds) v_data = _transpose_to_spharm(v_ds) vor_spec, div_spec = self.s.getvrtdivspec( u_data.values, v_data.values, ntrunc=ntrunc ) vor_ds = _transpose_from_spharm( u_data.copy(data=self.s.spectogrd(vor_spec)), other_dims ) div_ds = _transpose_from_spharm( u_data.copy(data=self.s.spectogrd(div_spec)), other_dims ) vor_ds.attrs["long_name"] = "Vorticity" vor_ds.attrs["units"] = "1/s" div_ds.attrs["long_name"] = "Divergence" div_ds.attrs["units"] = "1/s" return xr.Dataset({"vor": vor_ds, "div": div_ds})
[docs] def uv2absvor(self, u_ds, v_ds, ntrunc=None): """ Computes the absolute vorticity via spherical harmonics, given the u and v wind components Inputs: u, v: zonal and meridional winds ntrunc: Truncation limit (triangular truncation) for the spherical harmonic computation. Returns: vor: The absolute vorticity """ other_dims = _get_other_dims(u_ds) u_data = _transpose_to_spharm(u_ds) v_data = _transpose_to_spharm(v_ds) vor_spec, _ = self.s.getvrtdivspec(u_data.values, v_data.values, ntrunc=ntrunc) vor_ds = _transpose_from_spharm( u_data.copy(data=self.s.spectogrd(vor_spec)), other_dims ) # add planetary vorticity (Coriolis effect) vor_ds = vor_ds + self.fvor vor_ds.attrs["long_name"] = "absolute vorticity" vor_ds.attrs["units"] = "1/s" return vor_ds
[docs] def sf2uv(self, sf_ds, ntrunc=None): """ Computes the rotational wind components via spherical harmonics, given an array containing streamfunction Inputs: sf_ds streamfunction ntrunc: Truncation limit (triangular truncation) for the spherical harmonic computation. Returns: u, v: zonal and meridional winds """ other_dims = _get_other_dims(sf_ds) sf_data = _transpose_to_spharm(sf_ds) psi_spec = self.s.grdtospec(sf_data, ntrunc=ntrunc) vpsi_data, upsi_data = self.s.getgrad(psi_spec) u_ds = _transpose_from_spharm(sf_data.copy(data=-upsi_data), other_dims) v_ds = _transpose_from_spharm(sf_data.copy(data=vpsi_data), other_dims) u_ds.attrs["long_name"] = "rotational component of U wind" u_ds.attrs["units"] = "m/s" v_ds.attrs["long_name"] = "rotational component of V wind" v_ds.attrs["units"] = "m/s" return xr.Dataset({"u_rot": u_ds, "v_rot": v_ds})
[docs] def vp2uv(self, vp_ds, ntrunc=None): """Computes the divergent wind components via spherical harmonics, given an array containing velocity potential Inputs: vp_ds velocity potential. ntrunc: Truncation limit (triangular truncation) for the spherical harmonic computation. Returns: u, v: zonal and meridional winds """ other_dims = _get_other_dims(vp_ds) vp_data = _transpose_to_spharm(vp_ds) chi_spec = self.s.grdtospec(vp_data, ntrunc=ntrunc) udiv_data, vdiv_data = self.s.getgrad(chi_spec) u_ds = _transpose_from_spharm(vp_data.copy(data=udiv_data), other_dims) v_ds = _transpose_from_spharm(vp_data.copy(data=vdiv_data), other_dims) u_ds.attrs["long_name"] = "divergent component of U wind" u_ds.attrs["units"] = "m/s" v_ds.attrs["long_name"] = "divergent component of V wind" v_ds.attrs["units"] = "m/s" return xr.Dataset({"u_div": u_ds, "v_div": v_ds})
[docs] def sfvp2uv(self, sf_ds, vp_ds, ntrunc=None): """ Computes the wind components via spherical harmonics, given streamfunction and velocity potential. Inputs: sf_ds: streamfunction vp_ds: velocity potential ntrunc: Truncation limit (triangular truncation) for the spherical harmonic computation. Returns: u, v: zonal and meridional winds """ other_dims = _get_other_dims(sf_ds) sf_data = _transpose_to_spharm(sf_ds) vp_data = _transpose_to_spharm(vp_ds) psi_spec = self.s.grdtospec(sf_data, ntrunc=ntrunc) vpsi_data, upsi_data = self.s.getgrad(psi_spec) chi_spec = self.s.grdtospec(vp_data, ntrunc=ntrunc) udiv_data, vdiv_data = self.s.getgrad(chi_spec) u_ds = _transpose_from_spharm( sf_data.copy(data=udiv_data - upsi_data), other_dims ) v_ds = _transpose_from_spharm( sf_data.copy(data=vdiv_data + vpsi_data), other_dims ) u_ds.attrs["long_name"] = "U wind" u_ds.attrs["units"] = "m/s" v_ds.attrs["long_name"] = "V wind" v_ds.attrs["units"] = "m/s" return xr.Dataset({"u": u_ds, "v": v_ds})
def _get_other_dims(input_data): """ Get other dimensions besides 'lat' and 'lon'. """ spatial_dims = ["lat", "lon"] other_dims = [dim for dim in input_data.dims if dim not in spatial_dims] return other_dims def _transpose_to_spharm(input_data): """ Transpose data to fit spharm's expected layout: [nlat, nlon, nt]. """ other_dims = _get_other_dims(input_data) if len(other_dims) == 0: return input_data elif len(other_dims) == 1: return input_data.transpose("lat", "lon", other_dims[0]) elif len(other_dims) >= 2: return input_data.stack(nt=other_dims).transpose("lat", "lon", "nt") else: raise ValueError("Input data dimensions not supported.") def _transpose_from_spharm(sp_data, other_dims): """ Transpose data back from spharm's layout. """ if len(other_dims) == 0: return sp_data elif len(other_dims) == 1: return sp_data.transpose(other_dims[0], "lat", "lon") elif len(other_dims) >= 2: return sp_data.unstack("nt").transpose(*other_dims, "lat", "lon") else: raise ValueError("Output data dimensions not supported.")