Dirichlet Prior with Right Censoring

In the last class we derived the conjugacy of the Dirichlet prior for discrete probability mass functions with no censoring. Today we discuss the case with right censoring. Here the posterior is no longer conjugate to the prior, so we use Metropolis-Hastings Markov chain Monte Carlo (MCMC) to draw samples from the posterior distribution.

Let \(X_1,\ldots,X_n \sim_{iid} f_X\) with \(f_X\) some discrete distribution. Without loss of generality we assume the support of \(f_X\) is \(\{1,\ldots,K\}\). Censoring times are \(C_1,\ldots,C_n \sim_{iid} f_C\) with \(X_i \perp C_i\). Our goal is to estimate \(f_X\). We observe the censored sample \((T_i,\delta_i)\) where \(T_i = \min(X_i,C_i)\) and \(\delta_i = \mathbb{1}_{X_i \leq C_i}\).

We use a \(\pi(p) = \text{Dirichlet}(\alpha)\) prior on \(f_X\). The posterior is

\[ \begin{align*} \pi(p | (T,\delta)) &\propto f((T,\delta)|p)\pi(p)\\ &= \prod_{i=1}^n f(T_i|p)^{\delta_i} S(T_i|p)^{1-\delta_i} \pi(p)\\ &\propto \left(\prod_{i=1}^n p_{T_i}^{\delta_i} \left(\sum_{j=T_i + 1}^K p_j\right)^{1-\delta_i}\right) \prod_{j=1}^K p_j^{\alpha_j - 1}\\ &= \prod_{j=1}^K \left(\sum_{l=j+1}^K p_l\right)^{\sum_{i=1}^n \mathbb{1}_{\delta_i=0,T_i=j}}p_j^{\alpha_j + \sum_{i=1}^n \mathbb{1}_{T_i=j,\delta_i=1}-1} \end{align*} \] Defining \(d_j = \alpha_j + \sum_{i=1}^n \mathbb{1}_{T_i=j,\delta_i=1}-1\) and \(c_j = \sum_{i=1}^n \mathbb{1}_{\delta_i=0,T_i=j}\) we have

\[ \pi(p|(T,\delta)) \propto \prod_{j=1}^K \left(\sum_{l=j+1}^K p_l\right)^{c_j}p_j^{d_j} \]
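Note that when no observations are censored, \(c_j = 0\) for all \(j\) and the posterior reduces to \(\text{Dirichlet}(\alpha_1 + \sum_i \mathbb{1}_{T_i=1}, \ldots, \alpha_K + \sum_i \mathbb{1}_{T_i=K})\), recovering the conjugacy from last class. The tail-sum factors \(\left(\sum_{l=j+1}^K p_l\right)^{c_j}\) introduced by censoring are what break conjugacy.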

Metropolis-Hastings: Small Example

The Metropolis-Hastings MCMC sampler allows us to sample from the posterior distribution even when we only know the posterior up to a normalizing constant.
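With a symmetric random walk proposal \(p'\) around the current state \(p\), the sampler accepts \(p'\) with probability

\[ \min\left(1, \frac{\pi(p'|(T,\delta))}{\pi(p|(T,\delta))}\right), \]

so the unknown normalizing constant cancels in the ratio; this is exactly the rule implemented in the sampling loop below. For illustration, we consider a simple discrete distribution: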

set.seed(18042019)
K <- 4
fx <- (K:1)/10
plot(1:K,fx)

We generate a right-censored sample from this distribution.

n <- 50
x <- sample(1:K,replace=TRUE,size=n,prob=fx)  ## event times from fx
ce <- sample(1:(2*K),replace=TRUE,size=n)     ## censoring times, uniform on 1:(2K)
ti <- pmin(x,ce)                              ## observed time
del <- 1*(x <= ce)                            ## death indicator
head(cbind(ti,del),10)
##       ti del
##  [1,]  1   1
##  [2,]  1   1
##  [3,]  2   1
##  [4,]  2   1
##  [5,]  2   1
##  [6,]  4   1
##  [7,]  2   1
##  [8,]  1   0
##  [9,]  1   1
## [10,]  2   1
## dirichlet prior
alpha <- rep(1,K)

## data into cj, dj form
cj <- rep(0,K)
names(cj) <- 1:K
temp <- table(ti[del==0])
cj[names(temp)] <- temp


dj <- alpha - 1
names(dj) <- 1:K
temp <- table(ti[del==1])
dj[names(temp)] <- dj[names(temp)] + temp
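As a quick sanity check (illustrative, not required for the sampler), every observation should be counted exactly once: since \(\alpha_j = 1\) here, \(d_j\) equals the number of deaths at \(j\), so the death and censoring counts should sum to \(n\).

## each observation is either a death or a censoring
sum(dj - (alpha - 1)) + sum(cj) == n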


## proportional to posterior distribution
## dj[j] = alphaj + # deaths at j - 1
## cj[j] = # censoring at j
logpost <- function(p,dj,cj){
  ## p holds the first K-1 probabilities; the log posterior is -Inf
  ## (zero likelihood) if p falls outside the simplex
  sp <- sum(p)
  if(sp > 1 | any(p < 0)){
    return(-Inf)
  }
  p <- c(p,1-sp)
  ## pcs[j] = sum of p_l for l > j; the K-th entry is a placeholder of 1,
  ## since cj[K] is always 0 (censoring at K would require X > K)
  pcs <- c(rev(cumsum(rev(p)))[-1],1)
  return(sum(cj*log(pcs) + dj*log(p)))
}
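As an illustrative check of logpost (not part of the sampler), we can compare the log posterior at the true pmf against a uniform pmf; we would typically expect the truth to score higher.

## evaluate at the true pmf (first K-1 components) and at the uniform pmf
logpost(fx[-K],dj,cj)
logpost(rep(1/K,K-1),dj,cj)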

We run the MCMC:

Niter <- 10000
pchain <- matrix(0,nrow=Niter,ncol=K-1)
pchain[1,] <- 1/K
for(ii in 1:(Niter-1)){
  ## symmetric random walk proposal on the first K-1 probabilities
  prop <- pchain[ii,] + rnorm(K-1,mean=0,sd=0.01)
  ## accept with probability min(1, posterior ratio)
  if((exp(logpost(prop,dj,cj) - logpost(pchain[ii,],dj,cj))) > runif(1)){
    pchain[ii+1,] <- prop
  } else {
    pchain[ii+1,] <- pchain[ii,]
  }
}
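A useful diagnostic is the acceptance rate; for a random walk proposal, rates very near 0 or 1 suggest the proposal standard deviation needs tuning. Since rejected proposals repeat the previous state, we can estimate the rate by counting how often consecutive states differ.

## estimate acceptance rate: fraction of iterations where the chain moved
mean(rowSums(abs(diff(pchain))) > 0)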

The initial elements of the chain are not yet representative draws from the posterior; this portion is known as burn-in and can be seen by plotting the chain. We discard those values.

## view the entire chain
plot(pchain[,1:2],col="#00000030",type='l')
points(fx[1],fx[2],col='red',pch=19,cex=2)
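A trace plot of a single component also makes the burn-in visible:

## trace plot for p1 with the true value in red
plot(pchain[,1],type='l',xlab="iteration",ylab="p1")
abline(h=fx[1],col='red')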

## discard burn in
pchain <- pchain[floor(nrow(pchain)/2):nrow(pchain),]
## append the last probability P(X=K) for each iteration
pchain <- cbind(pchain,1-rowSums(pchain))
phat <- colMeans(pchain)


## only use 1000 samples to keep plots reasonable
ix <- floor(seq(1,nrow(pchain),length.out=1000))
pchain <- pchain[ix,]

We visualize the posterior draws of the pmf, along with the posterior mean (blue) and the true pmf (red).

matplot(1:K,t(pchain),type='l',col="#00000010",lty=1,ylim=c(0,1))
points(1:K,phat,col='blue',lwd=2,type='l')
points(1:K,fx,col='red',lwd=2,type='l')
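The retained draws also give pointwise credible intervals for each \(p_j\); a minimal sketch using the 2.5% and 97.5% posterior quantiles:

## pointwise 95% credible intervals for each pj, with the true pmf in red
ci <- apply(pchain,2,quantile,probs=c(0.025,0.975))
matplot(1:K,t(ci),type='l',lty=2,col='blue',ylim=c(0,1))
points(1:K,fx,col='red',lwd=2,type='l')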

We convert the sampled pmfs to survival functions \(S(t) = P(X > t)\).

St <- rbind(1,1-apply(pchain,1,cumsum))
matplot(0:K,St,type='l',col="#00000010",lty=1,ylim=c(0,1))
points(0:K,c(1,1-cumsum(phat)),col='blue',lwd=2,type='l')
points(0:K,c(1,1-cumsum(fx)),col='red',lwd=2,type='l')
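For comparison, we could also compute the nonparametric Kaplan-Meier estimate; a minimal sketch, assuming the survival package is installed:

## Kaplan-Meier estimate of S(t) (assumes the survival package)
library(survival)
km <- survfit(Surv(ti,del)~1)
plot(km,conf.int=FALSE,ylim=c(0,1))
lines(0:K,c(1,1-cumsum(fx)),col='red',lwd=2)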

Towards Continuous Distributions

In the above example, the random variable had support on only 4 possible values. We consider a discrete random variable with denser support.

K <- 10
fx <- exp(-c(1,1,1,2:(K-2))) + 0.05
fx <- fx / sum(fx)
plot(1:K,fx,ylim=c(0,max(fx)))

n <- 50
x <- sample(1:K,replace=TRUE,size=n,prob=fx)
ce <- sample(1:(2*K),replace=TRUE,size=n)
ti <- pmin(x,ce)
del <- 1*(x <= ce)

## dirichlet prior
alpha <- rep(1,K)

## data into cj, dj form
cj <- rep(0,K)
names(cj) <- 1:K
temp <- table(ti[del==0])
cj[names(temp)] <- temp


dj <- alpha - 1
names(dj) <- 1:K
temp <- table(ti[del==1])
dj[names(temp)] <- dj[names(temp)] + temp

We run the MCMC, with more iterations and a smaller proposal standard deviation than in the small example:

Niter <- 50000
pchain <- matrix(0,nrow=Niter,ncol=K-1)
pchain[1,] <- 1/K
for(ii in 1:(Niter-1)){
  prop <- pchain[ii,] + rnorm(K-1,mean=0,sd=0.005)
  if((exp(logpost(prop,dj,cj) - logpost(pchain[ii,],dj,cj))) > runif(1)){
    pchain[ii+1,] <- prop
  } else {
    pchain[ii+1,] <- pchain[ii,]
  }
}
## discard burn in
pchain <- pchain[floor(nrow(pchain)/2):nrow(pchain),]
## append the last probability P(X=K) for each iteration
pchain <- cbind(pchain,1-rowSums(pchain))
phat <- colMeans(pchain)
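With the denser support the random walk mixes more slowly, so it is worth checking the effective sample size of each component; a minimal sketch, assuming the coda package is installed:

## effective sample size per component (assumes the coda package)
library(coda)
effectiveSize(as.mcmc(pchain))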


## only use 1000 samples to keep plots reasonable
ix <- floor(seq(1,nrow(pchain),length.out=1000))
pchain <- pchain[ix,]

matplot(1:K,t(pchain),type='l',col="#00000010",lty=1,ylim=c(0,1))
points(1:K,phat,col='blue',lwd=2,type='l')
points(1:K,fx,col='red',lwd=2,type='l')

We convert the posterior draws of the pmf to draws of the survival function.

St <- rbind(1,1-apply(pchain,1,cumsum))
matplot(0:K,St,type='l',col="#00000010",lty=1,ylim=c(0,1))
points(0:K,c(1,1-cumsum(phat)),col='blue',lwd=2,type='l')
points(0:K,c(1,1-cumsum(fx)),col='red',lwd=2,type='l')
