Our group, led by Rajesh Ranganath, applies probabilistic approaches to a wide range of fundamental challenges in machine learning. Based at the NYU Courant Institute of Mathematical Sciences (Computer Science), the NYU Center for Data Science, and NYU Langone Health, we focus on areas including, but not limited to, the following:
AI for healthcare and science
- QTNet: Predicting Drug-Induced QT Prolongation With Artificial Intelligence–Enabled Electrocardiograms. Hao Zhang, Constantine Tarabanis, Neil Jethani, Mark Goldstein, Silas Smith, Larry Chinitz, Rajesh Ranganath, Yindalon Aphinyanaphongs, and Lior Jankelson. 2024
Prediction of drug-induced long QT syndrome (diLQTS) is of critical importance given its association with torsades de pointes. There is no reliable method for the outpatient prediction of diLQTS. This study sought to evaluate the use of a convolutional neural network (CNN) applied to electrocardiograms (ECGs) to predict diLQTS in an outpatient population. We identified all adult outpatients newly prescribed a QT-prolonging medication between January 1, 2003, and March 31, 2022, who had a 12-lead sinus ECG in the preceding 6 months. Using risk factor data and the ECG signal as inputs, the CNN QTNet was implemented in TensorFlow to predict diLQTS. Models were evaluated in a held-out test dataset of 44,386 patients (57% female) with a median age of 62 years. Compared with 3 other models relying on risk factors or ECG signal or baseline QTc alone, QTNet achieved the best (P < 0.001) performance with a mean area under the curve of 0.802 (95% CI: 0.786-0.818). In a survival analysis, QTNet also had the highest inverse probability of censorship–weighted area under the receiver-operating characteristic curve at day 2 (0.875; 95% CI: 0.848-0.904) and up to 6 months. In a subgroup analysis, QTNet performed best among males and patients ≤50 years or with baseline QTc <450 ms. In an external validation cohort of solely suburban outpatient practices, QTNet similarly maintained the highest predictive performance. An ECG-based CNN can accurately predict diLQTS in the outpatient setting while maintaining its predictive performance over time. In the outpatient setting, our model could identify higher-risk individuals who would benefit from closer monitoring.
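The headline metric above, area under the ROC curve with a 95% confidence interval, can be illustrated with a short sketch. This is a generic example, not QTNet's evaluation code: the AUC is computed as the probability that a randomly chosen positive patient scores higher than a randomly chosen negative one, and the interval uses a percentile bootstrap, which is our assumption about how such an interval might be formed.

```python
import numpy as np

def roc_auc(scores, labels):
    """AUC as P(score of a random positive > score of a random negative); ties count 1/2."""
    scores = np.asarray(scores, dtype=float)
    labels = np.asarray(labels, dtype=bool)
    diff = scores[labels][:, None] - scores[~labels][None, :]
    return float((diff > 0).mean() + 0.5 * (diff == 0).mean())

def bootstrap_auc_ci(scores, labels, n_boot=1000, alpha=0.05, seed=0):
    """Percentile bootstrap interval for the AUC (resamples patients with replacement)."""
    scores = np.asarray(scores, dtype=float)
    labels = np.asarray(labels, dtype=bool)
    rng = np.random.default_rng(seed)
    stats = []
    for _ in range(n_boot):
        idx = rng.integers(0, len(scores), len(scores))
        if labels[idx].all() or not labels[idx].any():
            continue  # AUC is undefined unless both classes are present
        stats.append(roc_auc(scores[idx], labels[idx]))
    lo, hi = np.quantile(stats, [alpha / 2, 1 - alpha / 2])
    return float(lo), float(hi)

print(roc_auc([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]))  # 0.75
```

Three of the four positive/negative pairs are ordered correctly in the toy example, giving an AUC of 0.75.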
- Quantifying impairment and disease severity using AI models trained on healthy subjects. Boyang Yu, Aakash Kaku, Kangning Liu, Avinash Parnandi, Emily Fokas, Anita Venkatesan, Natasha Pandit, Rajesh Ranganath, Heidi Schambra, and Carlos Fernandez-Granda. npj Digital Medicine 2024
Automatic assessment of impairment and disease severity is a key challenge in data-driven medicine. We propose a framework to address this challenge, which leverages AI models trained exclusively on healthy individuals. The COnfidence-Based chaRacterization of Anomalies (COBRA) score exploits the decrease in confidence of these models when presented with impaired or diseased patients to quantify their deviation from the healthy population. We applied the COBRA score to address a key limitation of current clinical evaluation of upper-body impairment in stroke patients. The gold-standard Fugl-Meyer Assessment (FMA) requires in-person administration by a trained assessor for 30-45 minutes, which restricts monitoring frequency and precludes physicians from adapting rehabilitation protocols to the progress of each patient. The COBRA score, computed automatically in under one minute, is shown to be strongly correlated with the FMA on an independent test cohort for two different data modalities: wearable sensors (ρ = 0.814, 95% CI [0.700, 0.888]) and video (ρ = 0.736, 95% CI [0.584, 0.838]). To demonstrate the generalizability of the approach to other conditions, the COBRA score was also applied to quantify severity of knee osteoarthritis from magnetic-resonance imaging scans, again achieving significant correlation with an independent clinical assessment (ρ = 0.644, 95% CI [0.585, 0.696]).
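A rough sketch of the idea behind a confidence-based score follows. It is not the published COBRA implementation: we assume, for illustration only, that a patient's score is the average top-class probability a healthy-trained classifier assigns across that patient's segments, and we check rank agreement with a clinical scale via Spearman's ρ. The cohort, the `softmax` logits, and the link between impairment and confidence are all hypothetical.

```python
import numpy as np

def softmax(logits):
    z = logits - logits.max(axis=-1, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def confidence_score(probs):
    """Mean top-class probability over a patient's segments; probs has shape (segments, classes)."""
    return float(probs.max(axis=1).mean())

def spearman_rho(a, b):
    """Spearman correlation as the Pearson correlation of ranks (assumes no ties)."""
    rank_a = np.argsort(np.argsort(a)).astype(float)
    rank_b = np.argsort(np.argsort(b)).astype(float)
    return float(np.corrcoef(rank_a, rank_b)[0, 1])

# Synthetic cohort: a hypothetical model trained on healthy motion is less
# confident on more impaired patients (lower clinical score).
clinical = np.array([20.0, 35.0, 48.0, 60.0, 66.0])
scores = []
for c in clinical:
    logits = np.zeros((50, 4))
    logits[:, 0] = c / 10.0  # hypothetical link: less impairment -> sharper predictions
    scores.append(confidence_score(softmax(logits)))
print(spearman_rho(scores, clinical))
```

Because confidence rises monotonically with the clinical score in this toy cohort, the rank correlation is 1 up to floating-point error; real data would of course be noisier.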
- Robust Anomaly Detection for Particle Physics Using Multi-background Representation Learning. Abhijith Gandrakota, Lily H. Zhang, Aahlad Puli, Kyle Cranmer, Jennifer Ngadiuba, Rajesh Ranganath, and Nhan Tran. MLST 2024
Anomaly, or out-of-distribution, detection is a promising tool for aiding discoveries of new particles or processes in particle physics. In this work, we identify and address two overlooked opportunities to improve anomaly detection for high-energy physics. First, rather than train a generative model on the single most dominant background process, we build detection algorithms using representation learning from multiple background types, thus taking advantage of more information to improve estimation of what is relevant for detection. Second, we generalize decorrelation to the multi-background setting, thus directly enforcing a more complete definition of robustness for anomaly detection. We demonstrate the benefit of the proposed robust multi-background anomaly detection algorithms on a high-dimensional dataset of particle decays at the Large Hadron Collider.
- Adaptive Sampling of k-Space in Magnetic Resonance for Rapid Pathology Prediction. Chen-Yu Yen, Raghav Singhal, Umang Sharma, Rajesh Ranganath, Sumit Chopra, and Lerrel Pinto. ICML 2024
Magnetic Resonance (MR) imaging, despite its proven diagnostic utility, remains an inaccessible imaging modality for disease surveillance at the population level. A major factor rendering MR inaccessible is lengthy scan times. An MR scanner collects measurements associated with the underlying anatomy in the Fourier space, also known as the k-space. Creating a high-fidelity image requires collecting large quantities of such measurements, increasing the scan time. Traditionally, to accelerate an MR scan, image reconstruction from under-sampled k-space data is the method of choice. However, recent works show the feasibility of bypassing image reconstruction and learning to detect disease directly from a sparser, learned subset of the k-space measurements. In this work, we propose Adaptive Sampling for MR (ASMR), a sampling method that learns an adaptive policy to sequentially select k-space samples to optimize for target disease detection. On 6 out of 8 pathology classification tasks spanning Knee, Brain, and Prostate MR scans, ASMR reaches within 2% of the performance of a fully sampled classifier while using only 8% of the k-space, while also outperforming prior state-of-the-art work in k-space sampling such as EMRT, LOUPE, and DPS.
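To make the k-space terminology concrete, the NumPy sketch below shows how an image relates to its k-space and what keeping 8% of the phase-encode columns looks like. ASMR itself learns an adaptive, sequential sampling policy; the random, non-adaptive mask and the `sample_fraction` parameter here are our own illustrative stand-ins, not the method.

```python
import numpy as np

def to_kspace(image):
    """Centered 2D Fourier transform: the measurements an MR scanner acquires."""
    return np.fft.fftshift(np.fft.fft2(image))

def column_mask(width, sample_fraction, rng):
    """Boolean mask keeping a random subset of columns (hypothetical non-adaptive baseline)."""
    n_keep = max(1, int(round(sample_fraction * width)))
    keep = rng.choice(width, size=n_keep, replace=False)
    mask = np.zeros(width, dtype=bool)
    mask[keep] = True
    return mask

rng = np.random.default_rng(0)
image = rng.random((64, 64))
kspace = to_kspace(image)
mask = column_mask(64, 0.08, rng)
undersampled = kspace * mask[None, :]  # zero out columns that were never acquired
# A k-space classifier would consume `undersampled` directly; with the full
# k-space, the inverse transform recovers the image exactly.
recon = np.fft.ifft2(np.fft.ifftshift(kspace)).real
print(int(mask.sum()), mask.mean())
```

With a 64-column image, an 8% budget keeps 5 columns; the point of adaptive sampling is to choose those few columns based on what has already been measured.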
- Deep learning models for electrocardiograms are susceptible to adversarial attack. Xintian Han, Yuxuan Hu, Luca Foschini, Larry Chinitz, Lior Jankelson, and Rajesh Ranganath. Nature Medicine 2020
Electrocardiogram (ECG) acquisition is increasingly widespread in medical and commercial devices, necessitating the development of automated interpretation strategies. Recently, deep neural networks have been used to automatically analyze ECG tracings and outperform physicians in detecting certain rhythm irregularities. However, deep learning classifiers are susceptible to adversarial examples, which are created from raw data to fool the classifier such that it assigns the example to the wrong class, but which are undetectable to the human eye. Adversarial examples have also been created for medical-related tasks. However, traditional attack methods to create adversarial examples do not extend directly to ECG signals, as such methods introduce square-wave artefacts that are not physiologically plausible. Here we develop a method to construct smoothed adversarial examples for ECG tracings that are invisible to human expert evaluation and show that a deep learning model for arrhythmia detection from single-lead ECG is vulnerable to this type of attack. Moreover, we provide a general technique for collating and perturbing known adversarial examples to create multiple new ones. The susceptibility of deep learning ECG algorithms to adversarial misclassification implies that care should be taken when evaluating these models on ECGs that may have been altered, particularly when incentives for causing misclassification exist.
- A validated, real-time prediction model for favorable outcomes in hospitalized COVID-19 patients. Narges Razavian, Vincent J. Major, Mukund Sudarshan, Jesse Burk-Rafel, Peter Stella, Hardev Randhawa, Seda Bilaloglu, Ji Chen, Vuthy Nguy, Walter Wang, Hao Zhang, Ilan Reinstein, David Kudlowitz, Cameron Zenger, Meng Cao, Ruina Zhang, Siddhant Dogra, Keerthi B. Harish, Brian Bosworth, Fritz Francois, Leora I. Horwitz, Rajesh Ranganath, Jonathan Austrian, and Yindalon Aphinyanaphongs. 2020
The COVID-19 pandemic has challenged front-line clinical decision-making, leading to numerous published prognostic tools. However, few models have been prospectively validated and none report implementation in practice. Here, we use 3345 retrospective and 474 prospective hospitalizations to develop and validate a parsimonious model to identify patients with favorable outcomes within 96 h of a prediction, based on real-time lab values, vital signs, and oxygen support variables. In retrospective and prospective validation, the model achieves high average precision (88.6% 95% CI: [88.4–88.7] and 90.8% [90.8–90.8]) and discrimination (95.1% [95.1–95.2] and 86.8% [86.8–86.9]) respectively. We implemented and integrated the model into the EHR, achieving a positive predictive value of 93.3% with 41% sensitivity. Preliminary results suggest clinicians are adopting these scores into their clinical workflows.
distribution shift
- When more is less: Incorporating additional datasets can hurt performance by introducing spurious correlations. Rhys Compton, Lily Zhang, Aahlad Puli, and Rajesh Ranganath. MLHC 2023
In machine learning, incorporating more data is often seen as a reliable strategy for improving model performance; this work challenges that notion by demonstrating that the addition of external datasets in many cases can hurt the resulting model’s performance. In a large-scale empirical study across combinations of four different open-source chest x-ray datasets and nine different labels, we demonstrate that in 43% of settings, a model trained on data from two hospitals has poorer worst group accuracy over both hospitals than a model trained on just a single hospital’s data. This surprising result occurs even though the added hospital makes the training distribution more similar to the test distribution. We explain that this phenomenon arises from the spurious correlation that emerges between the disease and hospital, due to hospital-specific image artifacts. We highlight the trade-off one encounters when training on multiple datasets, between the obvious benefit of additional data and the insidious cost of the introduced spurious correlation. In some cases, balancing the dataset can remove the spurious correlation and improve performance, but it is not always an effective strategy. We contextualize our results within the literature on spurious correlations to help explain these outcomes. Our experiments underscore the importance of exercising caution when selecting training data for machine learning models, especially in settings where there is a risk of spurious correlations such as with medical imaging. The risks outlined highlight the need for careful data selection and model evaluation in future research and practice.
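Worst-group accuracy, the metric quoted above, can be computed as in the sketch below. We assume groups are (hospital, label) pairs, which may differ from the paper's exact grouping. The toy predictor simply reads off the hospital, illustrating how a spurious hospital-disease correlation can coexist with a respectable average accuracy and a worst-group accuracy of zero.

```python
import numpy as np

def worst_group_accuracy(y_true, y_pred, hospital):
    """Minimum accuracy over (hospital, label) groups, plus the per-group accuracies."""
    y_true, y_pred, hospital = map(np.asarray, (y_true, y_pred, hospital))
    accs = {}
    for h in np.unique(hospital):
        for y in np.unique(y_true):
            idx = (hospital == h) & (y_true == y)
            if idx.any():
                accs[(h.item(), y.item())] = float((y_pred[idx] == y).mean())
    return min(accs.values()), accs

# Toy data: disease is common at hospital "A" and rare at "B", and the model
# has learned to predict the hospital (e.g. from an image artifact).
hospital = ["A"] * 4 + ["B"] * 4
y_true = [1, 1, 1, 0, 0, 0, 0, 1]
y_pred = [1, 1, 1, 1, 0, 0, 0, 0]
wga, per_group = worst_group_accuracy(y_true, y_pred, hospital)
avg = float(np.mean(np.array(y_true) == np.array(y_pred)))
print(avg, wga)  # 0.75 0.0
```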
- Learning invariant representations with missing data. Mark Goldstein, Jörn-Henrik Jacobsen, Olina Chau, Adriel Saporta, Aahlad Manas Puli, Rajesh Ranganath, and Andrew Miller. Conference on Causal Learning and Reasoning 2022
Spurious correlations allow flexible models to predict well during training but poorly on related test distributions. Recent work has shown that models that satisfy particular independencies involving correlation-inducing nuisance variables have guarantees on their test performance. Enforcing such independencies requires nuisances to be observed during training. However, nuisances, such as demographics or image background labels, are often missing. Enforcing independence on just the observed data does not imply independence on the entire population. Here we derive MMD estimators used for invariance objectives under missing nuisances. On simulations and clinical data, optimizing through these estimates achieves test performance similar to using estimators that make use of the full data.
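One generic way to estimate a kernel MMD penalty when some nuisance values are missing is to reweight the complete cases by the inverse probability of being observed. The sketch below shows that idea with a Gaussian kernel; it assumes the observation probabilities are known and is not necessarily the estimator derived in the paper. The helper names, `bandwidth`, and the synthetic missingness model are our own.

```python
import numpy as np

def gaussian_kernel(a, b, bandwidth=1.0):
    sq = ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)
    return np.exp(-sq / (2 * bandwidth ** 2))

def weighted_mmd2(x, y, wx, wy, bandwidth=1.0):
    """Biased (V-statistic) MMD^2 between weighted samples; weights are normalized to sum to 1."""
    wx = wx / wx.sum()
    wy = wy / wy.sum()
    return float(wx @ gaussian_kernel(x, x, bandwidth) @ wx
                 + wy @ gaussian_kernel(y, y, bandwidth) @ wy
                 - 2 * wx @ gaussian_kernel(x, y, bandwidth) @ wy)

rng = np.random.default_rng(0)
reps = rng.normal(size=(200, 2))                    # hypothetical learned representations
z = rng.integers(0, 2, size=200).astype(float)      # binary nuisance
p_obs = np.where(reps[:, 0] > 0, 0.9, 0.3)          # missingness depends on the data
observed = rng.random(200) < p_obs
w = 1.0 / p_obs                                     # inverse-probability weights
g0, g1 = observed & (z == 0), observed & (z == 1)
penalty = weighted_mmd2(reps[g0], reps[g1], w[g0], w[g1])
```

Computing the penalty on complete cases without the weights would compare distributions that over-represent examples with `reps[:, 0] > 0`, which is the population mismatch the abstract warns about.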
generative modeling
- Improving Large Language Models with Targeted Negative Training. Lily H. Zhang, Rajesh Ranganath, and Arya Tafvizi. TMLR 2024
Generative models of language exhibit impressive capabilities but still place non-negligible probability mass over undesirable outputs. In this work, we address the task of updating a model to avoid unwanted outputs while minimally changing model behavior otherwise, a challenge we refer to as a minimal targeted update. We first formalize the notion of a minimal targeted update and propose a method to achieve such updates using negative examples from a model’s generations. Our proposed Targeted Negative Training (TNT) results in updates that keep the new distribution close to the original, unlike existing losses for negative signal which push down probability but do not control what the updated distribution will be. In experiments, we demonstrate that TNT yields a better trade-off between reducing unwanted behavior and maintaining model generation behavior than baselines, paving the way towards a modeling paradigm based on iterative training updates that constrain models from generating undesirable outputs while preserving their impressive capabilities.
- Preference Learning Algorithms do not Learn Preference Rankings. Angelica Chen, Sadhika Malladi, Lily H. Zhang, Xinyi Chen, Qiuyi Zhang, Rajesh Ranganath, and Kyunghyun Cho. NeurIPS 2024
Preference learning algorithms (e.g., RLHF and DPO) are frequently used to steer LLMs to produce generations that are more preferred by humans, but our understanding of their inner workings is still limited. In this work, we study the conventional wisdom that preference learning trains models to assign higher likelihoods to more preferred outputs than less preferred outputs, measured via ranking accuracy. Surprisingly, we find that most state-of-the-art preference-tuned models achieve a ranking accuracy of less than 60% on common preference datasets. We furthermore derive the idealized ranking accuracy that a preference-tuned LLM would achieve if it optimized the DPO or RLHF objective perfectly. We demonstrate that existing models exhibit a significant alignment gap – i.e., a gap between the observed and idealized ranking accuracies. We attribute this discrepancy to the DPO objective, which is empirically and theoretically ill-suited to fix even mild ranking errors in the reference model, and derive a simple and efficient formula for quantifying the difficulty of learning a given preference datapoint. Finally, we demonstrate that ranking accuracy strongly correlates with the empirically popular win rate metric when the model is close to the reference model used in the objective, shedding further light on the differences between on-policy (e.g., RLHF) and off-policy (e.g., DPO) preference learning algorithms.
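The two quantities at the center of this abstract are easy to state in code. The sketch below assumes per-sequence log-likelihoods are already computed (the arrays are hypothetical). It measures ranking accuracy as the fraction of pairs where the preferred output gets the higher likelihood, and evaluates the standard DPO loss with temperature `beta`.

```python
import numpy as np

def ranking_accuracy(logp_chosen, logp_rejected):
    """Fraction of pairs where the preferred output has the higher sequence log-likelihood."""
    return float(np.mean(np.asarray(logp_chosen) > np.asarray(logp_rejected)))

def dpo_loss(logp_chosen, logp_rejected, ref_chosen, ref_rejected, beta=0.1):
    """Mean DPO loss: -log sigmoid(beta * [(policy - ref) on chosen - (policy - ref) on rejected])."""
    margin = (np.asarray(logp_chosen) - np.asarray(ref_chosen)) - (
        np.asarray(logp_rejected) - np.asarray(ref_rejected))
    return float(np.mean(np.logaddexp(0.0, -beta * margin)))  # log(1 + exp(-x)) = -log sigmoid(x)

# Hypothetical log-likelihoods for three preference pairs.
ref_chosen, ref_rejected = np.array([-10.0, -12.0, -9.0]), np.array([-8.0, -13.0, -11.0])
pol_chosen, pol_rejected = np.array([-9.0, -12.0, -9.0]), np.array([-8.5, -13.0, -11.0])
print(ranking_accuracy(ref_chosen, ref_rejected), ranking_accuracy(pol_chosen, pol_rejected))
print(dpo_loss(pol_chosen, pol_rejected, ref_chosen, ref_rejected) < np.log(2))  # True
```

In this toy case the reference model ranks the first pair incorrectly. The policy improves that pair's DPO margin and lowers the loss below log 2 (its value when the policy equals the reference), yet the pair remains misranked, so ranking accuracy stays at 2/3.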
- What’s the score? Automated Denoising Score Matching for Nonlinear Diffusions. Raghav Singhal, Mark Goldstein, and Rajesh Ranganath. ICML 2024
Reversing a diffusion process by learning its score forms the heart of diffusion-based generative modeling and of methods for estimating properties of scientific systems. The diffusion processes that are tractable center on linear processes with a Gaussian stationary distribution. This limits the kinds of models that can be built to those that target a Gaussian prior or more generally limits the kinds of problems that can be generically solved to those that have conditionally linear score functions. In this work, we introduce a family of tractable denoising score matching objectives, called local-DSM, built using local increments of the diffusion process. We show how local-DSM melded with Taylor expansions enables automated training and score estimation with nonlinear diffusion processes. To demonstrate these ideas, we use automated-DSM to train generative models using non-Gaussian priors on challenging low dimensional distributions and the CIFAR10 image dataset. Additionally, we use automated-DSM to learn the scores for nonlinear processes studied in statistical physics.
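For contrast with the nonlinear case, the sketch below shows standard denoising score matching for the linear Ornstein-Uhlenbeck process dX = -X dt + √2 dW, the tractable Gaussian setting described above. Its transition kernel is X_t | x_0 ~ N(e^{-t} x_0, 1 - e^{-2t}), so the regression target is available in closed form. This is not local-DSM; the score functions and the time `t` are illustrative.

```python
import numpy as np

def dsm_loss(score_fn, x0, t, rng):
    """Monte Carlo DSM loss for the OU process dX = -X dt + sqrt(2) dW at a fixed time t."""
    a = np.exp(-t)
    sigma = np.sqrt(1.0 - np.exp(-2.0 * t))
    eps = rng.standard_normal(x0.shape)
    xt = a * x0 + sigma * eps              # sample from the Gaussian transition kernel
    target = -eps / sigma                  # grad_x log p(x_t | x_0)
    return float(np.mean((score_fn(xt) - target) ** 2))

rng = np.random.default_rng(0)
x0 = rng.standard_normal(100_000)  # data ~ N(0, 1), so the marginal score is -x at every t
loss_true = dsm_loss(lambda x: -x, x0, 0.5, rng)
loss_zero = dsm_loss(lambda x: np.zeros_like(x), x0, 0.5, rng)
print(loss_true, loss_zero)
```

Because the minimizer of the DSM objective is the marginal score, the true score -x attains the lower loss; for this data its expected loss is e^{-2t}/(1 - e^{-2t}), about 0.58 at t = 0.5, versus 1/(1 - e^{-2t}), about 1.58, for the zero function. When the process is nonlinear, the transition kernel is no longer Gaussian and the `target` line above has no closed form, which is the gap the abstract addresses.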
- Stochastic interpolants with data-dependent couplings. Michael S Albergo, Mark Goldstein, Nicholas M Boffi, Rajesh Ranganath, and Eric Vanden-Eijnden. ICML 2024 (Spotlight)
Generative models inspired by dynamical transport of measure, such as flows and diffusions, construct a continuous-time map between two probability densities. Conventionally, one of these is the target density, only accessible through samples, while the other is taken as a simple base density that is data-agnostic. In this work, using the framework of stochastic interpolants, we formalize how to couple the base and the target densities, whereby samples from the base are computed conditionally given samples from the target in a way that is different from (but does not preclude) incorporating information about class labels or continuous embeddings. This enables us to construct dynamical transport maps that serve as conditional generative models. We show that these transport maps can be learned by solving a simple square loss regression problem analogous to the standard independent setting. We demonstrate the usefulness of constructing dependent couplings in practice through experiments in super-resolution and in-painting.
- Where to diffuse, how to diffuse, and how to get back: Automated learning for multivariate diffusions. Raghav Singhal, Mark Goldstein, and Rajesh Ranganath. ICLR 2023
Diffusion-based generative models (DBGMs) perturb data to a target noise distribution and reverse this process to generate samples. The choice of noising process, or inference diffusion process, affects both likelihoods and sample quality. For example, extending the inference process with auxiliary variables leads to improved sample quality. While there are many such multivariate diffusions to explore, each new one requires significant model-specific analysis, hindering rapid prototyping and evaluation. In this work, we study Multivariate Diffusion Models (MDMs). For any number of auxiliary variables, we provide a recipe for maximizing a lower bound on the MDM likelihood without requiring any model-specific analysis. We then demonstrate how to parameterize the diffusion for a specified target noise distribution; these two points together enable optimizing the inference diffusion process. Optimizing the diffusion expands easy experimentation from just a few well-known processes to an automatic search over all linear diffusions. To demonstrate these ideas, we introduce two new specific diffusions as well as learn a diffusion process on the MNIST, CIFAR10, and ImageNet32 datasets. We show that learned MDMs match or surpass fixed choices of diffusions in bits per dimension (BPD) for a given dataset and model architecture.
interpretability
- Don’t be fooled: label leakage in explanation methods and the importance of their quantitative evaluation. Neil Jethani, Adriel Saporta, and Rajesh Ranganath. AISTATS 2023 (notable paper, oral presentation)
Feature attribution methods identify which features of an input most influence a model’s output. Most widely-used feature attribution methods (such as SHAP, LIME, and Grad-CAM) are "class-dependent" methods in that they generate a feature attribution vector as a function of class. In this work, we demonstrate that class-dependent methods can "leak" information about the selected class, making that class appear more likely than it is. Thus, an end user runs the risk of drawing false conclusions when interpreting an explanation generated by a class-dependent method. In contrast, we introduce "distribution-aware" methods, which favor explanations that keep the label’s distribution close to its distribution given all features of the input. We introduce SHAP-KL and FastSHAP-KL, two baseline distribution-aware methods that compute Shapley values. Finally, we perform a comprehensive evaluation of seven class-dependent and three distribution-aware methods on three clinical datasets of different high-dimensional data types: images, biosignals, and text.
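Shapley values, which SHAP-KL and FastSHAP-KL compute, can be illustrated by brute-force enumeration over a handful of features. In the sketch below, `value_fn` is a hypothetical set function. In a class-dependent method it would be something like a class probability given a feature subset; a distribution-aware method would instead score how close the label distribution given the subset stays to its distribution given all features. Exact enumeration needs 2^n evaluations, which is why amortized estimators exist.

```python
from itertools import combinations
from math import factorial

def shapley_values(value_fn, n):
    """Exact Shapley values; value_fn maps a frozenset of feature indices to a float."""
    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for k in range(n):  # size of the coalition that feature i joins
            weight = factorial(k) * factorial(n - k - 1) / factorial(n)
            for subset in combinations(others, k):
                s = frozenset(subset)
                phi[i] += weight * (value_fn(s | {i}) - value_fn(s))
    return phi

w = [1.0, 2.0, 3.0]

def value_fn(s):
    # Hypothetical set function: additive terms plus an interaction between features 0 and 1.
    return sum(w[j] for j in s) + (2.0 if {0, 1} <= s else 0.0)

phi = shapley_values(value_fn, 3)
print(phi)  # approximately [2.0, 3.0, 3.0]
```

The interaction bonus of 2.0 is split evenly between features 0 and 1, and the values sum to value_fn(all features) minus value_fn(no features), the efficiency property of Shapley values.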
- FastSHAP: Real-Time Shapley Value Estimation. Neil Jethani, Mukund Sudarshan, Ian Covert, Su-in Lee, and Rajesh Ranganath. ICLR 2022
- Have We Learned to Explain?: How Interpretability Methods Can Learn to Encode Predictions in their Interpretations. Neil Jethani, Mukund Sudarshan, Yindalon Aphinyanaphongs, and Rajesh Ranganath. AISTATS 2021
While the need for interpretable machine learning has been established, many common approaches are slow, lack fidelity, or are hard to evaluate. Amortized explanation methods reduce the cost of providing interpretations by learning a global selector model that returns feature importances for a single instance of data. The selector model is trained to optimize the fidelity of the interpretations, as evaluated by a predictor model for the target. Popular methods learn the selector and predictor model in concert, which we show allows predictions to be encoded within interpretations. We introduce EVAL-X as a method to quantitatively evaluate interpretations and REAL-X as an amortized explanation method, which learn a predictor model that approximates the true data generating distribution given any subset of the input. We show EVAL-X can detect when predictions are encoded in interpretations and show the advantages of REAL-X through quantitative and radiologist evaluation.
out-of-distribution and anomaly detection
- Quantifying impairment and disease severity using AI models trained on healthy subjects. Boyang Yu, Aakash Kaku, Kangning Liu, Avinash Parnandi, Emily Fokas, Anita Venkatesan, Natasha Pandit, Rajesh Ranganath, Heidi Schambra, and Carlos Fernandez-Granda. npj Digital Medicine 2024
Automatic assessment of impairment and disease severity is a key challenge in data-driven medicine. We propose a framework to address this challenge, which leverages AI models trained exclusively on healthy individuals. The COnfidence-Based chaRacterization of Anomalies (COBRA) score exploits the decrease in confidence of these models when presented with impaired or diseased patients to quantify their deviation from the healthy population. We applied the COBRA score to address a key limitation of current clinical evaluation of upper-body impairment in stroke patients. The gold-standard Fugl-Meyer Assessment (FMA) requires in-person administration by a trained assessor for 30-45 minutes, which restricts monitoring frequency and precludes physicians from adapting rehabilitation protocols to the progress of each patient. The COBRA score, computed automatically in under one minute, is shown to be strongly correlated with the FMA on an independent test cohort for two different data modalities: wearable sensors (ρ = 0.814, 95% CI [0.700, 0.888]) and video (ρ = 0.736, 95% CI [0.584, 0.838]). To demonstrate the generalizability of the approach to other conditions, the COBRA score was also applied to quantify severity of knee osteoarthritis from magnetic-resonance imaging scans, again achieving significant correlation with an independent clinical assessment (ρ = 0.644, 95% CI [0.585, 0.696]).
- Robust Anomaly Detection for Particle Physics Using Multi-background Representation Learning. Abhijith Gandrakota, Lily H. Zhang, Aahlad Puli, Kyle Cranmer, Jennifer Ngadiuba, Rajesh Ranganath, and Nhan Tran. MLST 2024
Anomaly, or out-of-distribution, detection is a promising tool for aiding discoveries of new particles or processes in particle physics. In this work, we identify and address two overlooked opportunities to improve anomaly detection for high-energy physics. First, rather than train a generative model on the single most dominant background process, we build detection algorithms using representation learning from multiple background types, thus taking advantage of more information to improve estimation of what is relevant for detection. Second, we generalize decorrelation to the multi-background setting, thus directly enforcing a more complete definition of robustness for anomaly detection. We demonstrate the benefit of the proposed robust multi-background anomaly detection algorithms on a high-dimensional dataset of particle decays at the Large Hadron Collider.
- Robustness to Spurious Correlations Improves Semantic Out-of-Distribution Detection. Lily H. Zhang and Rajesh Ranganath. AAAI 2023
Methods which utilize the outputs or feature representations of predictive models have emerged as promising approaches for out-of-distribution (OOD) detection of image inputs. However, these methods struggle to detect OOD inputs that share nuisance values (e.g. background) with in-distribution inputs. The detection of shared-nuisance out-of-distribution (SN-OOD) inputs is particularly relevant in real-world applications, as anomalies and in-distribution inputs tend to be captured in the same settings during deployment. In this work, we provide a possible explanation for SN-OOD detection failures and propose nuisance-aware OOD detection to address them. Nuisance-aware OOD detection substitutes a classifier trained via empirical risk minimization and cross-entropy loss with one that 1. is trained under a distribution where the nuisance-label relationship is broken and 2. yields representations that are independent of the nuisance under this distribution, both marginally and conditioned on the label. We can train a classifier to achieve these objectives using Nuisance-Randomized Distillation (NuRD), an algorithm developed for OOD generalization under spurious correlations. Output- and feature-based nuisance-aware OOD detection perform substantially better than their original counterparts, succeeding even when detection based on domain generalization algorithms fails to improve performance.
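The first requirement above, training under a distribution where the nuisance-label relationship is broken, can be illustrated with a reweighting step: weight each example by p(y) / p(y | z), so that under the weighted distribution the label is independent of a discrete nuisance z. This minimal sketch estimates the probabilities from counts. It covers only that one idea; NuRD's second condition (representations independent of the nuisance) is not shown, and the function name is our own.

```python
import numpy as np

def nuisance_randomizing_weights(y, z):
    """Per-example weights p(y) / p(y | z), estimated from counts (y and z discrete)."""
    y, z = np.asarray(y), np.asarray(z)
    w = np.empty(len(y), dtype=float)
    for yv in np.unique(y):
        p_y = np.mean(y == yv)
        for zv in np.unique(z):
            in_z = z == zv
            idx = in_z & (y == yv)
            if idx.any():
                w[idx] = p_y / (idx.sum() / in_z.sum())
    return w

# Toy data: the label is strongly tied to the nuisance (e.g. background).
y = np.array([1] * 8 + [0] * 2 + [1] * 2 + [0] * 8)
z = np.array([0] * 10 + [1] * 10)
w = nuisance_randomizing_weights(y, z)
for yv in (0, 1):
    for zv in (0, 1):
        joint = w[(y == yv) & (z == zv)].sum() / w.sum()
        print(yv, zv, round(joint, 3))  # each weighted cell holds 0.25 = p(y) * p(z)
```

Before weighting, p(y = 1 | z = 0) is 0.8; after weighting, every (y, z) cell carries the product of its marginals, so a classifier trained on the weighted data cannot gain from the nuisance alone.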
- Understanding Failures in Out-of-distribution Detection with Deep Generative Models. Lily H. Zhang, Mark Goldstein, and Rajesh Ranganath. ICML 2021
Deep generative models (DGMs) seem a natural fit for detecting out-of-distribution (OOD) inputs, but such models have been shown to assign higher probabilities or densities to OOD images than images from the training distribution. In this work, we explain why this behavior should be attributed to model misestimation. We first prove that no method can guarantee performance beyond random chance without assumptions on which out-distributions are relevant. We then interrogate the typical set hypothesis, the claim that relevant out-distributions can lie in high likelihood regions of the data distribution, and that OOD detection should be defined based on the data distribution’s typical set. We highlight the consequences implied by assuming support overlap between in- and out-distributions, as well as the arbitrariness of the typical set for OOD detection. Our results suggest that estimation error is a more plausible explanation than the misalignment between likelihood-based OOD detection and out-distributions of interest, and we illustrate how even minimal estimation error can lead to OOD detection failures, yielding implications for future work in deep generative modeling and OOD detection.
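The tension between high likelihood and typicality discussed above is easiest to see with a standard Gaussian. The sketch below is an illustration we supply, not an experiment from the paper. In d = 1000 dimensions the origin has the highest density of any point, yet it lies far outside the typical set {x : |−log p(x)/d − H/d| ≤ ε}, where H is the entropy; the choice ε = 0.1 is arbitrary.

```python
import numpy as np

def gaussian_nll(x):
    """Negative log-density of a standard normal in x.shape[-1] dimensions."""
    d = x.shape[-1]
    return 0.5 * d * np.log(2 * np.pi) + 0.5 * np.sum(x ** 2, axis=-1)

def in_typical_set(x, eps=0.1):
    """Per-dimension NLL within eps of the per-dimension entropy."""
    d = x.shape[-1]
    entropy = 0.5 * d * np.log(2 * np.pi * np.e)
    return np.abs(gaussian_nll(x) / d - entropy / d) <= eps

d = 1000
rng = np.random.default_rng(0)
samples = rng.standard_normal((500, d))
origin = np.zeros(d)
print(gaussian_nll(origin) < gaussian_nll(samples).min())  # True: origin is the mode
print(bool(in_typical_set(origin)), in_typical_set(samples).mean())
```

For this Gaussian the typicality gap equals |‖x‖²/d − 1| / 2, so the origin sits at 0.5 while samples concentrate near 0. Whether an out-distribution of interest lands near the mode or in the typical set is the kind of assumption the abstract argues is arbitrary.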
representation learning
- Set Norm and Equivariant Residual Connections: Putting the Deep in Deep Sets. Lily H. Zhang, Veronica Tozzo, John Higgins, and Rajesh Ranganath. ICML 2022
Permutation invariant neural networks are a promising tool for predictive modeling of set data. We show, however, that existing architectures struggle to perform well when they are deep. In this work, we mathematically and empirically analyze normalization layers and residual connections in the context of deep permutation invariant neural networks. We develop set norm, a normalization tailored for sets, and introduce the “clean path principle” for equivariant residual connections alongside a novel benefit of such connections, the reduction of information loss. Based on our analysis, we propose Deep Sets++ and Set Transformer++, deep models that reach comparable or better performance than their original counterparts on a diverse suite of tasks. We additionally introduce Flow-RBC, a new single-cell dataset and real-world application of permutation invariant prediction. We open-source our data and code here: https://github.com/rajesh-lab/deep_permutation_invariant.
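As we read the abstract, set norm normalizes each set jointly over its elements and features and then applies a per-feature affine transform. The NumPy sketch below assumes that definition (the repository linked above holds the authors' implementation). It checks the property that matters for deep permutation-invariant models: permuting the elements of the input permutes the output the same way.

```python
import numpy as np

def set_norm(x, gamma, beta, eps=1e-5):
    """x: (batch, set_size, features). Normalize each set over its elements and features jointly."""
    mean = x.mean(axis=(1, 2), keepdims=True)
    var = x.var(axis=(1, 2), keepdims=True)
    return gamma * (x - mean) / np.sqrt(var + eps) + beta  # gamma, beta: shape (features,)

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 5, 3))
gamma, beta = np.ones(3), np.zeros(3)
out = set_norm(x, gamma, beta)
perm = rng.permutation(5)
print(np.allclose(set_norm(x[:, perm], gamma, beta), out[:, perm]))  # True: permutation equivariant
```

Statistics pooled within each set, rather than across the batch, keep one set's normalization independent of the other sets in the batch.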
- Learning invariant representations with missing data. Mark Goldstein, Jörn-Henrik Jacobsen, Olina Chau, Adriel Saporta, Aahlad Manas Puli, Rajesh Ranganath, and Andrew Miller. Conference on Causal Learning and Reasoning 2022
Spurious correlations allow flexible models to predict well during training but poorly on related test distributions. Recent work has shown that models that satisfy particular independencies involving correlation-inducing nuisance variables have guarantees on their test performance. Enforcing such independencies requires nuisances to be observed during training. However, nuisances, such as demographics or image background labels, are often missing. Enforcing independence on just the observed data does not imply independence on the entire population. Here we derive MMD estimators used for invariance objectives under missing nuisances. On simulations and clinical data, optimizing through these estimates achieves test performance similar to using estimators that make use of the full data.
survival analysis
- Development and external validation of a dynamic risk score for early prediction of cardiogenic shock in cardiac intensive care units using machine learning. Yuxuan Hu, Albert Lui, Mark Goldstein, Mukund Sudarshan, Andrea Tinsay, Cindy Tsui, Samuel D Maidman, John Medamana, Neil Jethani, Aahlad Puli, and others. European Heart Journal: Acute Cardiovascular Care 2024
Myocardial infarction and heart failure are major cardiovascular diseases that affect millions of people in the USA with morbidity and mortality being highest among patients who develop cardiogenic shock. Early recognition of cardiogenic shock allows prompt implementation of treatment measures. Our objective is to develop a new dynamic risk score, called CShock, to improve early detection of cardiogenic shock in the cardiac intensive care unit (ICU).
- Survival mixture density networks. Xintian Han, Mark Goldstein, and Rajesh Ranganath. Machine Learning for Healthcare Conference 2022
Survival analysis, the art of time-to-event modeling, plays an important role in clinical treatment decisions. Recently, continuous time models built from neural ODEs have been proposed for survival analysis. However, the training of neural ODEs is slow due to the high computational complexity of neural ODE solvers. Here, we propose an efficient alternative for flexible continuous time models, called Survival Mixture Density Networks (Survival MDNs). Survival MDN applies an invertible positive function to the output of Mixture Density Networks (MDNs). While MDNs produce flexible real-valued distributions, the invertible positive function maps the model into the time-domain while preserving a tractable density. Using four datasets, we show that Survival MDN performs better than, or similarly to, continuous and discrete time baselines on concordance, integrated Brier score and integrated binomial log-likelihood. Meanwhile, Survival MDNs are also faster than ODE-based models and circumvent binning issues in discrete models.
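The change of variables behind this construction can be sketched in a few lines. Below, a Gaussian mixture on the real line is pushed through softplus, one choice of invertible positive function (the paper's choice may differ), so the event-time density picks up the Jacobian of the inverse map. The mixture parameters stand in for what a network would output for one patient and are hypothetical.

```python
import numpy as np
from math import erf, sqrt

def inv_softplus(t):
    """Inverse of softplus(z) = log(1 + exp(z)), defined for t > 0."""
    return np.log(np.expm1(t))

def event_time_density(t, weights, means, stds):
    """Density of T = softplus(Z), where Z ~ sum_k weights[k] * N(means[k], stds[k]**2)."""
    t = np.asarray(t, dtype=float)
    z = np.asarray(inv_softplus(t))
    comps = weights * np.exp(-0.5 * ((z[..., None] - means) / stds) ** 2) / (stds * np.sqrt(2 * np.pi))
    jacobian = 1.0 / (-np.expm1(-t))  # dz/dt = exp(t) / (exp(t) - 1)
    return comps.sum(axis=-1) * jacobian

def survival(t, weights, means, stds):
    """S(t) = P(T > t) = P(Z > inv_softplus(t)), since softplus is increasing."""
    z = float(inv_softplus(t))
    cdf = sum(w * 0.5 * (1.0 + erf((z - m) / (s * sqrt(2.0)))) for w, m, s in zip(weights, means, stds))
    return 1.0 - cdf

# Hypothetical network outputs for one patient.
weights, means, stds = np.array([0.3, 0.7]), np.array([0.0, 2.0]), np.array([1.0, 0.5])
grid = np.linspace(0.5, 6.0, 20001)
dens = event_time_density(grid, weights, means, stds)
mass = float(np.sum(0.5 * (dens[1:] + dens[:-1]) * np.diff(grid)))  # trapezoid rule
print(mass, survival(0.5, weights, means, stds) - survival(6.0, weights, means, stds))
```

The two printed numbers agree up to quadrature error: integrating the density over [0.5, 6] gives the same probability as the difference of survival values, which is the "tractable density" the abstract refers to.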
- Inverse-weighted survival games. Xintian Han, Mark Goldstein, Aahlad Puli, Thomas Wies, Adler Perotte, and Rajesh Ranganath. NeurIPS 2021
Deep models trained through maximum likelihood have achieved state-of-the-art results for survival analysis. Despite this training scheme, practitioners evaluate models under other criteria, such as binary classification losses at a chosen set of time horizons, e.g. Brier score (BS) and Bernoulli log likelihood (BLL). Models trained with maximum likelihood may have poor BS or BLL since maximum likelihood does not directly optimize these criteria. Directly optimizing criteria like BS requires inverse-weighting by the censoring distribution. However, estimating the censoring model under these metrics requires inverse-weighting by the failure distribution. The objective for each model requires the other, but neither is known. To resolve this dilemma, we introduce Inverse-Weighted Survival Games. In these games, objectives for each model are built from re-weighted estimates featuring the other model, where the latter is held fixed during training. When the loss is proper, we show that the games always have the true failure and censoring distributions as a stationary point. This means models in the game do not leave the correct distributions once reached. We construct one case where this stationary point is unique. We show that these games optimize BS on simulations and then apply these principles on real world cancer and critically-ill patient data.
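The inverse weighting mentioned above can be made concrete with the usual IPCW Brier score at a horizon t. In the sketch, `surv_at_t` holds a model's predicted S(t | x_i) and `censor_surv` is a censoring survival function G(s) = P(C > s). Both are simply assumed given here, whereas the survival games learn the failure and censoring models together. We also evaluate G at T_i rather than just before it, a simplification.

```python
import numpy as np

def ipcw_brier(t, times, events, surv_at_t, censor_surv):
    """IPCW Brier score at horizon t.

    times: observed times; events: 1 if the event was observed, 0 if censored;
    surv_at_t: predicted S(t | x_i); censor_surv: callable G(s) = P(C > s).
    """
    times = np.asarray(times, dtype=float)
    events = np.asarray(events, dtype=bool)
    surv_at_t = np.asarray(surv_at_t, dtype=float)
    died = (times <= t) & events  # event by t: target 0, weight 1 / G(T_i)
    alive = times > t             # still event-free at t: target 1, weight 1 / G(t)
    g_ti = np.array([censor_surv(s) for s in times])
    terms = np.zeros_like(times)
    terms[died] = surv_at_t[died] ** 2 / g_ti[died]
    terms[alive] = (1.0 - surv_at_t[alive]) ** 2 / censor_surv(t)
    return float(terms.mean())    # points censored before t contribute 0

# Hypothetical step censoring curve and predictions for three patients.
G = lambda s: 0.5 if s >= 1.5 else 1.0
print(ipcw_brier(2.5, [1.0, 2.0, 3.0], [1, 0, 1], [0.2, 0.5, 0.6], G))
```

The patient censored at time 2 drops out of the sum, and the weights 1/G inflate the remaining terms to compensate on average. Computing G itself needs the failure model for the symmetric reason, which is the circularity the games resolve.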
- X-cal: Explicit calibration for survival analysis. Mark Goldstein, Xintian Han, Aahlad Puli, Adler Perotte, and Rajesh Ranganath. NeurIPS 2020
Survival analysis models the distribution of time until an event of interest, such as discharge from the hospital or admission to the ICU. When a model’s predicted number of events within any time interval is similar to the observed number, it is called well-calibrated. A survival model’s calibration can be measured using, for instance, distributional calibration (D-CALIBRATION) (Haider et al., 2020), which computes the squared difference between the observed and predicted number of events within different time intervals. Classically, calibration is addressed in post-training analysis. We develop explicit calibration (X-CAL), which turns D-CALIBRATION into a differentiable objective that can be used in survival modeling alongside maximum likelihood estimation and other objectives. X-CAL allows practitioners to directly optimize calibration and strike a desired balance between predictive power and calibration. In our experiments, we fit a variety of shallow and deep models on simulated data, a survival dataset based on MNIST, on length-of-stay prediction using MIMIC-III data, and on brain cancer data from The Cancer Genome Atlas. We show that the models we study can be miscalibrated. We give experimental evidence on these datasets that X-CAL improves D-CALIBRATION without a large decrease in concordance or likelihood.
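D-CALIBRATION as described above can be sketched for uncensored data. If a model is calibrated, the predicted CDF evaluated at each observed event time, F(T_i | x_i), is uniform on [0, 1], so each of B equal-width bins should hold a 1/B share of the points. The sketch below computes the squared-difference statistic that X-CAL makes differentiable. It omits Haider et al.'s handling of censored points and uses a hard, non-differentiable histogram; both are our simplifications.

```python
import numpy as np

def d_calibration(cdf_at_event, n_bins=10):
    """Sum over bins of (observed fraction - 1/n_bins)^2, for uncensored points only."""
    u = np.clip(np.asarray(cdf_at_event, dtype=float), 0.0, 1.0 - 1e-12)
    counts = np.bincount((u * n_bins).astype(int), minlength=n_bins)
    return float(np.sum((counts / len(u) - 1.0 / n_bins) ** 2))

calibrated = (np.arange(100) + 0.5) / 100  # CDF values spread evenly over [0, 1]
overconfident = np.full(100, 0.05)         # every event falls in the first bin
print(d_calibration(calibrated), d_calibration(overconfident))
```

The evenly spread values give a statistic of 0, while putting every event in the first bin gives 0.9² + 9 × 0.1² = 0.9. X-CAL replaces the hard bin assignment with a soft one so this quantity can be added to a training loss.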
Rajesh Ranganath
Principal Investigator
Aahlad Puli
Postdoctoral Fellow
Yoav Wald
Postdoctoral Fellow
Xiang Gao
PhD Student
Mark Goldstein
PhD Student
Nhi Nguyen
PhD Student
Jatin Prakash
PhD Student
Adriel Saporta
PhD Student
Raghav Singhal
PhD Student
Wanqian Yang
PhD Student
Boyang Yu
PhD Student
Lily Zhang
PhD Student
Hao Zhang
PhD Student
Xintian Han
PhD Student
Neil Jethani
MD/PhD Student
Mukund Sudarshan
PhD Student
Wouter van Amsterdam
MD/PhD Student
Rhys Compton
MS Student
Courant Institute of Mathematical Sciences
New York University
60 Fifth Avenue
New York, NY 10011
rajeshr at cims dot nyu dot edu