import numpy as np
import numpyro
from numpyro.infer import MCMC, NUTS
from numpyro import distributions as dist
import jax
from jax import numpy as jnp
from celerite2.jax import GaussianProcess, terms
import arviz
from .utils import floatX, rng_key
from .reflected import reflected_phase_curve
from .thermal import thermal_phase_curve
from .tp import get_Tarr, polynomial_order, element_number, Parr, dParr
from .spectrum import (
exojax_spectrum, res_vis, nus, wav, stellar_spectrum, stellar_spectrum_vis
)
from .hatp7 import (
get_observed_depths, get_planet_params
)
from .lightcurve import (
get_light_curve, eclipse_model, get_filter
)
__all__ = [
'model',
'run_mcmc',
'get_model_kwargs'
]
(all_depths, all_depths_errs, all_wavelengths,
kepler_mean_wl) = get_observed_depths()
def estimate_ellipsoidal_amplitude(mass, rstar, mstar, period):
ellipsoidal_amplitude_estimate = (
mass / 0.077 * rstar ** 3 * mstar ** -2 * period ** -2
)
return ellipsoidal_amplitude_estimate
def estimate_doppler_amplitude(mass, mstar, period):
doppler_amplitude_estimate = (
mass / 0.37 * mstar**(-2/3) * period**(-1/3)
)
return doppler_amplitude_estimate
[docs]def model(
n_temps, phase, time, y, yerr, eclipse_numpy,
filt_wavelength, filt_trans, a_rs, a_rp, T_s, rprs,
mstar, mass, period, rstar, nus=nus, wav=wav, Parr=Parr, dParr=dParr,
stellar_spectrum=stellar_spectrum,
res=res_vis,
predict=False
):
"""
The full joint model passed to numpyro.
Parameters
----------
n_temps : int
Number of temperatures in the T-P profile
phase : numpy.ndarray
Phase of the planetary orbit
time : numpy.ndarray
Time in BJD
y : numpy.ndarray
Normalized flux in ppm
yerr : numpy.ndarray
Normalized flux error in ppm
eclipse_numpy : numpy.ndarray
Normalized eclipse vector
filt_wavelength : numpy.ndarray
Filter transmittance wavelength array
filt_trans : numpy.ndarray
Filter transmittance array
a_rs : float
Semimajor axis normalized by stellar radius
a_rp : float
Semimajor axis normalized by the planetary radius
T_s : float
Stellar effective temperature
rprs : float
Radius ratio
mstar : float
Stellar mass in solar masses
mass : float
Planet mass in Jupiter masses
period : float
Orbital period [d]
rstar : float
Stellar radius in solar radii
nus : numpy.ndarray
Frequencies sampled in the spectrum
wav : numpy.ndarray
Wavelengths sampled in the spectrum
Parr : numpy.ndarray
Pressure array at each temperature in the T-P profile
dParr : numpy.ndarray
Delta pressure array at each temperature in the T-P profile
stellar_spectrum_vis : numpy.ndarray
Spectrum of the star
res : float
Spectral resolution
predict : bool
Turn on or off the gaussian process ``predict`` features
"""
temps = numpyro.sample(
"temperatures", dist.Uniform(low=500, high=5000),
sample_shape=(n_temps,)
)
n_grid_points = 150
phases_grid = jnp.linspace(0 + 0.01, 1 - 0.01, n_grid_points, dtype=floatX)
xi_grid = jnp.linspace(-np.pi + 0.01, np.pi - 0.01, n_grid_points,
dtype=floatX)
xi = 2 * np.pi * (phase - 0.5)
# Define reflected light phase curve model according to
# Heng, Morris & Kitzmann (2021)
omega = numpyro.sample('omega', dist.Uniform(low=0, high=1))
g = numpyro.sample('g',
dist.TwoSidedTruncatedDistribution(
dist.Normal(loc=0, scale=0.01),
low=-0.1, high=0.1)
)
reflected_ppm_grid, A_g = reflected_phase_curve(
phases_grid, omega, g, a_rp
)
# reflected_ppm = interpolate(phases_grid, reflected_ppm_grid, phase)
reflected_ppm = jnp.interp(phase, phases_grid, reflected_ppm_grid)
numpyro.deterministic('A_g', A_g)
ellipsoidal_amp_estimate = estimate_ellipsoidal_amplitude(
mass, rstar, mstar, period
)
doppler_amp_estimate = estimate_doppler_amplitude(
mass, mstar, period
)
# Define the ellipsoidal variation parameterization (simple sinusoid)
ellipsoidal_amp = numpyro.sample(
'ellip_amp',
dist.TwoSidedTruncatedDistribution(
dist.Normal(
loc=ellipsoidal_amp_estimate,
scale=ellipsoidal_amp_estimate/4
), low=0, high=100
)
)
ellipsoidal_model_ppm = - ellipsoidal_amp * jnp.cos(
4 * np.pi * (phase - 0.5)) + ellipsoidal_amp
# Define the doppler variation parameterization (simple sinusoid)
doppler_amp = numpyro.sample(
'doppler_amp',
dist.TwoSidedTruncatedDistribution(
dist.Normal(
loc=doppler_amp_estimate,
scale=doppler_amp_estimate/4
), low=0, high=10
)
)
doppler_model_ppm = doppler_amp * jnp.sin(2 * np.pi * phase)
# Define the thermal emission model according to description in
# Morris et al. (in prep)
n_phi = 150
n_theta = 10
phi = jnp.linspace(-2 * np.pi, 2 * np.pi, n_phi, dtype=floatX)
theta = jnp.linspace(0, np.pi, n_theta, dtype=floatX)
theta2d, phi2d = jnp.meshgrid(theta, phi)
C_11_kepler = 0.2 #numpyro.sample('C_11', dist.Uniform(low=0, high=0.55))
# hml_eps = numpyro.sample('epsilon', dist.Uniform(low=0, high=8 / 5))
hml_f = 0.73 #(2 / 3 - hml_eps * 5 / 12) ** 0.25
delta_phi = 0 #numpyro.sample(
# 'delta_phi',
# dist.TwoSidedTruncatedDistribution(
# dist.Normal(loc=0, scale=0.05),
# low=-np.pi/4, high=np.pi/4
# )
# )
A_B = 0.0
# Compute the thermal phase curve with zero phase offset
thermal_grid, temp_map = thermal_phase_curve(
xi_grid, delta_phi, 4.5, 0.6, C_11_kepler, T_s, a_rs, 1 / a_rp, A_B,
theta2d, phi2d, filt_wavelength, filt_trans, hml_f
)
# thermal = interpolate(xi_grid, 1e6 * thermal_grid, xi)
thermal = jnp.interp(xi, xi_grid, 1e6 * thermal_grid)
# epsilon = 8 * nightside**4 / (3 * dayside**4 + 5 * nightside**4)
# f = (2 / 3 - hml_eps * 5 / 12) ** 0.25
# numpyro.deterministic('f', f)
# numpyro.deterministic('epsilon', epsilon)
# Define the composite phase curve model
flux_norm = (eclipse_numpy *
(reflected_ppm + thermal) + doppler_model_ppm + ellipsoidal_model_ppm
)
flux_norm -= jnp.mean(flux_norm)
sigma = numpyro.sample(
"sigma", dist.TwoSidedTruncatedDistribution(
dist.Normal(loc=y.std(), scale=y.std()/10),
low=0, high=10 * y.std()
)
)
# rho = numpyro.sample(
# "rho", dist.TwoSidedTruncatedDistribution(
# dist.Normal(loc=22, scale=5),
# low=6, high=50
# )
# )
kernel = terms.Matern32Term(sigma=sigma, rho=30)
jitter = numpyro.sample('jitter', dist.Uniform(low=0, high=100))
gp = GaussianProcess(kernel, mean=flux_norm)
gp.compute(time, yerr=jnp.sqrt(yerr ** 2 + jitter ** 2),
check_sorted=False)
if predict:
gp.condition(y)
pred = gp.predict(y)
numpyro.deterministic("therm", thermal)
numpyro.deterministic("ellip", ellipsoidal_model_ppm)
numpyro.deterministic("doppl", doppler_model_ppm)
numpyro.deterministic("refle", reflected_ppm)
numpyro.deterministic("model", flux_norm)
numpyro.deterministic("resid", y - pred)
numpyro.deterministic("pred", pred)
# log_vmr_prod = numpyro.sample('log_vmr_prod',
# dist.Uniform(low=-10, high=-4))
vmr_prod = 1e-6
mmr_TiO = 1e-6 #numpyro.sample("mmr_TiO", dist.Uniform(low=-9, high=-2))
Tarr = get_Tarr(temps, Parr)
Fcgs, _, _ = exojax_spectrum(
temps, vmr_prod, mmr_TiO,
Parr, dParr, nus, wav
)
fpfs_spectrum = rprs ** 2 * Fcgs / stellar_spectrum
interp_depths = jnp.interp(
all_wavelengths, wav / 1000, fpfs_spectrum
)
numpyro.deterministic("FpFs", fpfs_spectrum)
numpyro.deterministic("interp_depths", interp_depths)
numpyro.deterministic("Tarr", Tarr)
# kepler_thermal_eclipse_depth_obs = numpyro.deterministic(
# "kep_depth", jnp.interp(0, xi_grid, thermal_grid)
# )
# kepler_thermal_eclipse_depth_model = jnp.average(
# fpfs_spectrum,
# weights=(jnp.abs(kepler_mean_wl[0] - wav / 1000) < 0.15).astype(int)
# )
numpyro.sample('phase_curve', gp.numpyro_dist(), obs=y)
numpyro.factor('spectrum',
dist.Normal(
loc=interp_depths,
scale=all_depths_errs
).log_prob(all_depths)
)
# kepler_thermal_eclipse_depth_err = 10e-6 # numpyro.sample(
# # "kep_depth_err", dist.Uniform(low=1e-6, high=500e-6)
# # )
# numpyro.factor(
# "obs_spectrum_kepler", dist.Normal(
# loc=kepler_thermal_eclipse_depth_model,
# scale=kepler_thermal_eclipse_depth_err
# ).log_prob(kepler_thermal_eclipse_depth_obs)
# )
[docs]def get_model_kwargs(quarter=None):
phase, time, flux_normed, flux_normed_err = get_light_curve(quarter=quarter)
filt_wavelength, filt_trans = get_filter()
(planet_name, a_rs, a_rp, T_s, rprs, t0, period, eclipse_half_dur, b,
rstar, rho_star, rp_rstar, mstar, mass) = get_planet_params()
model_kwargs = dict(
phase=phase.astype(floatX),
time=(time - time.mean()).astype(floatX),
y=flux_normed.astype(floatX),
yerr=flux_normed_err.astype(floatX),
eclipse_numpy=jnp.array(eclipse_model(quarter=quarter)).astype(floatX),
filt_wavelength=jnp.array(filt_wavelength.astype(floatX)),
filt_trans=jnp.array(filt_trans.astype(floatX)),
a_rs=a_rs, a_rp=a_rp, T_s=T_s, period=period,
mass=mass, mstar=mstar, rstar=rstar,
n_temps=polynomial_order * element_number + 1,
res=res_vis, rprs=rprs
)
return model_kwargs
[docs]def run_mcmc(run_title='tmp', num_warmup=5, num_samples=10, quarter=10):
"""
Run MCMC with the NUTS via numpyro.
Parameters
----------
run_title : str
Name of the run
num_warmup : int
Number of iterations in the burn-in phase
num_samples : int
Number of iterations of the sampler
quarter : int, list of ints
Kepler quarters to fit
"""
model_kwargs = get_model_kwargs(quarter=quarter)
print(f'Start MCMC, n chains = {len(jax.devices())}')
mcmc = MCMC(
sampler=NUTS(
model, dense_mass=True,
),
num_warmup=num_warmup,
num_samples=num_samples,
chain_method='parallel',
num_chains=len(jax.devices()),
)
mcmc.run(
rng_key,
**model_kwargs
)
mcmc.post_warmup_state = mcmc.last_state
print('Save first output')
arviz_mcmc = arviz.from_numpyro(mcmc)
arviz_mcmc.to_netcdf('chains_' + run_title + '0.nc')