# CausalEM – Ensemble Matching for Causal Inference

> **CausalEM** is a toolbox for multi-arm treatment-effect estimation using stochastic matching and a stacked ensemble of heterogeneous ML models. It supports continuous, binary, and survival outcomes.

---

## Key Features

1. **Stochastic nearest-neighbor (NN) matching** -> Larger effective sample size (ESS) and more accurate treatment-effect (TE) estimates than standard (deterministic) NN matching.
1. **G-computation using a two-stage, stacked ensemble of heterogeneous learners** -> Generalizes the standard G-computation framework to ensemble learning, with cross-fitting of propensity-score and outcome models similar to DoubleML.
1. **Support for multi-arm treatments** -> Improved multi-arm ESS via stochastic matching.
1. **Support for survival outcomes** -> Stacked-ensemble TE estimation for right-censored, time-to-event data, implemented by simulating data from fitted survival outcome models.
1. **Bootstrapped confidence interval (CI) estimation** -> Honest CIs, obtained by running the entire pipeline (matching + TE estimation) inside the bootstrap loop.
1. **Compatible with `scikit-learn`** -> Flexible choice of ML models: `scikit-learn` estimators (and `scikit-survival` estimators for survival outcomes) can be used in the propensity-score, outcome, and meta-learner stages.
1. **Full reproducibility of results** -> Random number generation (RNG) seeding is managed throughout, including inside `scikit-learn` models.
<!-- 1. **Available in Python and R** -> Identical - function-centric - API in both languages using `reticulate`; combined with RNG management, leads to identical, reproducible results across the two platforms. -->
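The idea behind stochastic NN matching can be illustrated with a toy matcher. Instead of always pairing a treated unit with its single nearest control, each draw samples a control with a probability that decays with score distance, so repeated draws spread matches over more controls. The function name `toy_stochastic_match`, the Gaussian-style kernel, and the `scale` parameterization are assumptions chosen for illustration; CausalEM's actual sampling distribution may differ.

```python
import numpy as np


def toy_stochastic_match(score, treatment, scale=1.0, rng=None):
    """Hypothetical illustration, not CausalEM's implementation.

    For each treated unit, draw one control with probability proportional to
    exp(-(|score difference| / scale) ** 2). Controls may be reused across
    treated units. As `scale` shrinks toward 0, every draw collapses onto the
    nearest control, i.e. deterministic NN matching (with replacement).
    """
    rng = np.random.default_rng(rng)
    score = np.asarray(score, dtype=float)
    treatment = np.asarray(treatment)
    treated = np.flatnonzero(treatment == 1)
    control = np.flatnonzero(treatment == 0)
    pairs = []
    for i in treated:
        z = -((np.abs(score[control] - score[i]) / scale) ** 2)
        w = np.exp(z - z.max())  # shift by the max for numerical stability
        pairs.append((i, rng.choice(control, p=w / w.sum())))
    return np.array(pairs)


score = np.array([0.0, 0.1, 0.5, 0.52, 0.9, 1.0])
treatment = np.array([1, 0, 1, 0, 1, 0])

# Tiny scale: behaves like deterministic nearest-neighbor matching
print(toy_stochastic_match(score, treatment, scale=1e-6, rng=0))
# Larger scale: repeated draws can spread matches over more controls
draws = [toy_stochastic_match(score, treatment, scale=0.5, rng=s) for s in range(10)]
print(sorted({int(c) for d in draws for c in d[:, 1]}))
```

Each row of the returned array is a (treated index, control index) pair; pooling several draws is what allows more distinct controls to contribute, which is the mechanism behind the larger ESS.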

---

## API

| Function / class     | Brief description                                          |
| -------------------- | ---------------------------------------------------------- |
| `estimate_te`        | Main pipeline: ensemble matching + meta-learner            |
| `StochasticMatcher`  | 1:1 nearest-neighbor matcher (deterministic or stochastic) |
| `summarize_matching` | Diagnostics: ESS, ASMD, variance ratios, overlap plots     |
| `load_data_lalonde`  | Copy of the Lalonde job-training dataset                   |
| `load_data_tof`      | Simulated ToF dataset (survival or binary outcome)         |

---

## ⚙️ Installation <!--- install -->

```bash
pip install causalem
```

Optional dev extras:

```bash
pip install "causalem[dev]"
```

Requires Python 3.9 or later. Tested on macOS and Windows.

---

## Package Vignette

For a more detailed introduction to `CausalEM`, including the underlying math, see the _package vignette_ [insert link later], available on arXiv.

---

## 🚀 Quick Start <!--- quickstart -->

### Two-arm Analysis

Load the necessary packages:

```python
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression

from causalem import (
  estimate_te,
  load_data_tof,
  stochastic_match,
  summarize_matching
)
```
Load the ToF data with two treatment levels and a binarized outcome:
```python
X, t, y = load_data_tof(
    raw=False,
    treat_levels=["PrP", "SPS"],
    binarize_outcome=True,
)
```
Run stochastic matching on the logit of the propensity score:
```python
lr = LogisticRegression(solver="newton-cg", max_iter=1000)
lr.fit(X, t)
score = lr.predict_proba(X)[:, 1]
logit_score = np.log(score / (1 - score))

cluster = stochastic_match(
    treatment=t,
    score=logit_score,
    nsmp=10,
    scale=1.0,
    random_state=0,
)

diag = summarize_matching(
    cluster, X,
    treatment=t, plot=False
)
print("Combined Effective Sample Size (ESS):", diag.ess["combined"])
print("Absolute standardized mean difference (ASMD) by covariate:\n")
print(diag.summary)
```
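`summarize_matching` reports the ASMD for each covariate. For reference, a minimal sketch of one common ASMD definition follows: the absolute difference in group means divided by the square root of the average of the two group variances. Other denominator conventions exist, and `summarize_matching` may use a different one; the helper name `asmd` is hypothetical.

```python
import numpy as np


def asmd(x, treatment):
    """Absolute standardized mean difference for each column of `x`,
    comparing treatment == 1 with treatment == 0. The denominator is the
    square root of the average of the two group variances (one common convention)."""
    x = np.asarray(x, dtype=float)
    treatment = np.asarray(treatment)
    xt, xc = x[treatment == 1], x[treatment == 0]
    sd = np.sqrt((xt.var(axis=0, ddof=1) + xc.var(axis=0, ddof=1)) / 2)
    return np.abs(xt.mean(axis=0) - xc.mean(axis=0)) / sd


# Toy data: the first covariate is imbalanced, the second is perfectly balanced
x_demo = np.array([[0.0, 5.0], [2.0, 7.0], [1.0, 5.0], [3.0, 7.0]])
t_demo = np.array([1, 1, 0, 0])
print(asmd(x_demo, t_demo))
```

Computed on the matched sample rather than the full sample, values near 0 indicate that matching has balanced the covariate.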
TE estimation:
```python
res = estimate_te(
    X,
    t,
    y,
    outcome_type="binary",
    niter=5,
    matching_scale=1.0,
    matching_is_stochastic=True,
    random_state_master=1,
)
print("Two-arm TE:", res["te"])
```
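`estimate_te` combines matching with G-computation over a stacked ensemble. The core G-computation step, stripped of matching, stacking, and cross-fitting, is: fit an outcome model that includes treatment as a feature, predict every unit's outcome with treatment set to each arm, and average the difference. The toy version below uses a single linear model and simulated data, both assumptions for illustration; it is not CausalEM's estimator.

```python
import numpy as np


def g_computation_ate(X, t, y):
    """Toy G-computation: fit y ~ [1, t, X] by least squares, then average the
    predicted difference between setting t=1 and t=0 for every unit."""
    n = len(y)
    design = np.column_stack([np.ones(n), t, X])
    coef, *_ = np.linalg.lstsq(design, y, rcond=None)
    design_1 = np.column_stack([np.ones(n), np.ones(n), X])   # everyone treated
    design_0 = np.column_stack([np.ones(n), np.zeros(n), X])  # everyone control
    return float(np.mean(design_1 @ coef - design_0 @ coef))


# Simulated data with confounding through X[:, 0] and a true effect of 2.0
rng = np.random.default_rng(0)
X_toy = rng.normal(size=(2000, 2))
t_toy = (rng.random(2000) < 1 / (1 + np.exp(-X_toy[:, 0]))).astype(float)
y_toy = 2.0 * t_toy + X_toy @ np.array([1.0, -0.5]) + rng.normal(scale=0.5, size=2000)
print(g_computation_ate(X_toy, t_toy, y_toy))
```

With a nonlinear or ensemble outcome model the same "predict under each arm, then average" recipe applies; the linear case is chosen only because its answer is easy to check.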

### Multi-arm Analysis

Load the raw ToF data for multi-arm analysis:
```python
df = load_data_tof(
    raw=True,
    binarize_outcome=True,
)
t_all = df["treatment"].to_numpy()
X_all = df[["age", "zscore"]].to_numpy()
y_all = df["outcome"].to_numpy()
```
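Construct propensity scores using multinomial logistic regression:
```python
# With the default lbfgs solver, scikit-learn fits a multinomial model for
# multiclass targets; the `multi_class` argument is deprecated since 1.5.
lr_multi = LogisticRegression(max_iter=1000)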
lr_multi.fit(X_all, t_all)
proba = lr_multi.predict_proba(X_all)
ref = "PrP"
cols = [i for i, c in enumerate(lr_multi.classes_) if c != ref]
logit_multi = np.log(proba[:, cols] / (1 - proba[:, cols]))
```
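The score matrix passed to `stochastic_match` has one column per non-reference arm, holding logit(P(T = k | X)). The same computation can be wrapped in a reusable helper that also clips probabilities away from 0 and 1, which would otherwise produce infinite logits. The helper name `one_vs_rest_logits`, the `eps` value, and the arm labels in the example are hypothetical.

```python
import numpy as np


def one_vs_rest_logits(proba, classes, ref, eps=1e-12):
    """Hypothetical helper: turn an (n, K) class-probability matrix into an
    (n, K-1) matrix of logit(P(T = k | X)) for every arm k != ref.
    Returns the logit matrix and the arm labels of its columns."""
    classes = np.asarray(classes)
    keep = classes != ref
    p = np.clip(proba[:, keep], eps, 1 - eps)  # avoid log(0) and division by 0
    return np.log(p / (1 - p)), classes[keep]


proba_demo = np.array([[0.5, 0.25, 0.25], [0.2, 0.6, 0.2]])
logits, labels = one_vs_rest_logits(proba_demo, ["A", "B", "C"], ref="A")
print(labels, logits.shape)
```

In the example above, the equivalent call would pass `proba`, `lr_multi.classes_`, and `ref`.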
Multi-arm stochastic matching:
```python
cluster_multi = stochastic_match(
    treatment=t_all,
    score=logit_multi,
    nsmp=5,
    scale=1.0,
    ref_group=ref,
    random_state=0,
)
diag_multi = summarize_matching(
    cluster_multi, X_all, treatment=t_all, ref_group=ref, plot=False
)
print("Multi-arm ESS per draw:\n", diag_multi.ess["per_draw"])
```
Multi-arm TE estimation:
```python
res_multi = estimate_te(
    X_all,
    t_all,
    y_all,
    outcome_type="binary",
    ref_group=ref,
    niter=5,
    matching_scale=1.0,
    matching_is_stochastic=True,
    random_state_master=1,
)
print("Multi-arm pairwise effects:\n", res_multi["pairwise"])
```
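With K arms there are K(K-1)/2 pairwise contrasts. The sketch below shows how such contrasts follow from per-arm mean potential outcomes (for a binary outcome, differences in event probability). The helper name, the dict input format, the difference scale, and the numbers are assumptions for illustration; the layout and scale of `res_multi["pairwise"]` may differ.

```python
from itertools import combinations


def pairwise_contrasts(mean_by_arm):
    """Hypothetical helper: every pairwise difference mu[a] - mu[b] between
    arms, given a dict mapping arm label -> estimated mean potential outcome."""
    return {
        (a, b): mean_by_arm[a] - mean_by_arm[b]
        for a, b in combinations(sorted(mean_by_arm), 2)
    }


# Illustrative numbers only, not results from the ToF data
contrasts = pairwise_contrasts({"A": 0.30, "B": 0.10, "C": 0.20})
print(contrasts)
```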

### Confidence-Interval Calculation

Add bootstrap confidence intervals to the two-arm analysis:
```python
res_boot = estimate_te(
    X,
    t,
    y,
    outcome_type="binary",
    niter=5,
    nboot=200,
    matching_scale=1.0,
    matching_is_stochastic=True,
    random_state_master=1,
    random_state_boot=7,
)
print("Bootstrap CI:", res_boot["ci"])
```
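The bootstrap CI is "honest" because rows are resampled before matching, so variability from every stage of the pipeline enters the interval. A generic percentile-bootstrap sketch of that idea follows. The percentile interval, the function names, and the difference-in-means stand-in for the full pipeline are assumptions; `estimate_te` may construct its interval differently.

```python
import numpy as np


def bootstrap_ci(X, t, y, estimator, nboot=200, alpha=0.05, seed=0):
    """Percentile bootstrap CI. Each replicate resamples rows with replacement
    and reruns the *whole* `estimator` (e.g. matching followed by TE
    estimation), so variability from every stage enters the interval."""
    rng = np.random.default_rng(seed)
    n = len(y)
    stats = np.empty(nboot)
    for b in range(nboot):
        idx = rng.integers(0, n, size=n)
        stats[b] = estimator(X[idx], t[idx], y[idx])
    lo, hi = np.quantile(stats, [alpha / 2, 1 - alpha / 2])
    return float(lo), float(hi)


def diff_in_means(X, t, y):
    # Stand-in for a full matching + TE pipeline; X is unused here
    return y[t == 1].mean() - y[t == 0].mean()


rng = np.random.default_rng(1)
t_sim = rng.integers(0, 2, size=500)
y_sim = 1.5 * t_sim + rng.normal(size=500)
X_sim = np.zeros((500, 1))
print(bootstrap_ci(X_sim, t_sim, y_sim, diff_in_means))
```

Bootstrapping only the final estimation step, with the matching held fixed, would ignore matching variability and typically give intervals that are too narrow.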

### Heterogeneous Ensemble

Pass a list of heterogeneous outcome learners and enable stacking:

```python
learners = [
    LogisticRegression(max_iter=1000),
    RandomForestClassifier(n_estimators=200, max_depth=3),
]
res_ensemble = estimate_te(
    X,
    t,
    y,
    outcome_type="binary",
    model_outcome=learners,
    niter=len(learners),
    do_stacking=True,
    matching_scale=1.0,
    matching_is_stochastic=True,
    random_state_master=42,
)
print("Ensemble TE:", res_ensemble["te"])
```
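For readers new to stacking, scikit-learn's generic `StackingClassifier` shows the basic idea: first-stage learners produce out-of-fold predicted probabilities, and a second-stage meta-learner combines them. It is shown only as an analogue on synthetic data; CausalEM's two-stage ensemble (`do_stacking=True`) is its own implementation and may combine learners differently.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier, StackingClassifier
from sklearn.linear_model import LogisticRegression

# First stage: heterogeneous base learners; second stage: logistic meta-learner
# trained on their 5-fold out-of-fold predicted probabilities.
stack = StackingClassifier(
    estimators=[
        ("logit", LogisticRegression(max_iter=1000)),
        ("rf", RandomForestClassifier(n_estimators=200, max_depth=3, random_state=0)),
    ],
    final_estimator=LogisticRegression(max_iter=1000),
    stack_method="predict_proba",
    cv=5,
)

X_stack, y_stack = make_classification(n_samples=500, n_features=5, random_state=0)
stack.fit(X_stack, y_stack)
print(stack.predict_proba(X_stack[:3]))
```

Using out-of-fold predictions for the meta-learner keeps it from simply rewarding whichever base learner overfits the training data most.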

### TE Estimation for Survival Outcomes

Load the two-arm ToF data with its survival outcome (no binarization) and estimate the TE:
```python
X_surv, t_surv, y_surv = load_data_tof(
    raw=False,
    treat_levels=["SPS", "PrP"],
)
res_surv = estimate_te(
    X_surv,
    t_surv,
    y_surv,
    outcome_type="survival",
    niter=5,
    matching_scale=1.0,
    matching_is_stochastic=True,
    random_state_master=0,
)
print("Survival HR:", res_surv["te"])
```
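The survival support relies on simulating data from fitted survival outcome models. One standard way to simulate event times from a predicted survival curve S(t) is inverse-transform sampling: draw U ~ Uniform(0, 1) and take the first time at which S(t) <= U, so that P(T > t) = S(t). Whether CausalEM simulates this way, and the choice to censor draws beyond the last grid time, are assumptions; the function name is hypothetical.

```python
import numpy as np


def sample_event_times(times, surv, n, rng=None):
    """Inverse-transform sampling from a step survival curve S(t) given at
    increasing `times`. Returns (time, event) arrays; draws for which S never
    falls to U are censored (event=False) at the last time point."""
    rng = np.random.default_rng(rng)
    times = np.asarray(times, dtype=float)
    surv = np.asarray(surv, dtype=float)
    u = rng.random(n)
    # S is non-increasing, so -S is non-decreasing and searchsorted applies:
    # it returns the first index with -S >= -u, i.e. S <= u.
    idx = np.searchsorted(-surv, -u, side="left")
    event = idx < len(times)
    time = np.where(event, times[np.minimum(idx, len(times) - 1)], times[-1])
    return time, event


times = np.array([1.0, 2.0, 3.0])
surv = np.array([0.8, 0.5, 0.2])  # illustrative values of S(1), S(2), S(3)
time_draw, event = sample_event_times(times, surv, n=20000, rng=0)
# P(T = 1) = 1 - S(1) = 0.2 and P(censored) = S(3) = 0.2
print(np.mean(time_draw == 1.0), np.mean(~event))
```

Applying such a sampler to each unit's predicted survival curve under every arm yields simulated outcomes that can be compared across arms.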

<!-- ## `CausalEM` in `R`

After installing the Python package, install the R wrapper:
```R
install.packages('CausalEM')
```
-->

## License

This project is licensed under the terms of the MIT License.

## Release Notes

### 0.6.0
- Improved consistency of return data structure when `do_stacking=False` in multi-arm TE estimation.

### 0.5.4
- Added GitHub Action for publishing to PyPI

### 0.5.3
- First public release

### 0.5.1
- Edits to README
- Added GitHub Action for publishing to TestPyPI

### 0.5.0
- First test release
