[Hierarchical MIL] Code ; Train.py
논문
Incorporating Hierarchical Information into Multiple Instance Learning for Patient Phenotype Prediction with scRNA-seq Data
https://www.biorxiv.org/content/10.1101/2025.02.10.637389v1.full.pdf
논문 정리
[Hierarchical MIL] Incorporating Hierarchical Information into Multiple Instance Learning for Patient Phenotype Prediction with
논문Incorporating Hierarchical Information into Multiple Instance Learning for Patient Phenotype Prediction with scRNA-seq Datahttps://www.biorxiv.org/content/10.1101/2025.02.10.637389v1.full.pdf깃허브https://github.com/minhchaudo/hier-mil GitHub -
doraemin.tistory.com
깃허브
https://github.com/minhchaudo/hier-mil
GitHub - minhchaudo/hier-mil
Contribute to minhchaudo/hier-mil development by creating an account on GitHub.
github.com
데이터 개요
[Hierarchical MIL] Exploratory Data (Summary)
[Hierarchical MIL] Exploratory Data (Summary)
논문Incorporating Hierarchical Information into Multiple Instance Learning for Patient Phenotype Prediction with scRNA-seq Datahttps://www.biorxiv.org/content/10.1101/2025.02.10.637389v1.full.pdf 논문 정리2025.03.22 - [AI 및 Data Analysis/Paper]
doraemin.tistory.com
데이터 분석
[Hierarchical MIL] scRNA Analysis
[Hierarchical MIL] scRNA Analysis
논문Incorporating Hierarchical Information into Multiple Instance Learning for Patient Phenotype Prediction with scRNA-seq Datahttps://www.biorxiv.org/content/10.1101/2025.02.10.637389v1.full.pdf 논문 정리2025.03.22 - [AI 및 Data Analysis/Paper]
doraemin.tistory.com
데이터 분석 시 사용했던 코드에 대해 살펴보자.
데이터 분석 시 TASK 가 7가지 존재. 그 중, 가장 MAIN이 되는 TASK 2 에 대한 흐름을 따라가보자.
task 2는 repeated K-fold cross-validation을 통해
모델의 기본적인 예측 성능 (e.g., AUC) 을 안정적으로 평가하는 메인 실험
Task 번호 | 실험 이름 | 목적 |
0 | train_and_tune | 모델 학습 + 하이퍼파라미터 튜닝 |
1 | predict_and_save | 학습된 모델로 예측값 저장 |
2 | repeated_k_fold | 10번 반복된 k-fold 교차검증으로 모델 성능 평균화 |
3 | vary_train_size | 학습 데이터 크기를 줄여가며 성능 비교 (0.25, 0.5, 0.75) |
4 | vary_cell_count | 셀 수를 줄였을 때 성능 변화 분석 |
5 | randomize_cell_annot | 셀 타입 정보를 랜덤으로 섞어서 모델 의존도 확인 |
6 | get_p_val_cell_type | permutation test로 중요한 세포 타입 찾기 (biological insight용) |
train.py 의
repeated_k_fold(df, meta, args)
return np.mean(aucs), np.std(aucs)
* label : 면역항암제 반응 결과 (Combined_outcome) 이진 분류.
"Favourable" → 1(양성), "Unfavourable" → 0(음성)
samples = df[["patient", "label"]].drop_duplicates()
for i in range(args.n_repeats):
skf = StratifiedKFold(args.n_folds, shuffle=True, random_state=i)
for train_idx, test_idx in skf.split(samples, samples["label"]):
...
study = optuna.create_study(direction="maximize", sampler=sampler)
study.optimize(objective_wrapper(…), n_trials=args.n_tune_trials)
best_params = study.best_params
사진의 하나의 루프(repeat) 일 때의 개념.
⇨ n 개의 n-fold (outer fold)로 나뉘고,
⇨ 각 n-fold 내부에서, test와 train으로 나뉨. 이후, 또 다시 m 개의 m-fold (inner fold) 로 나뉨.
이곳에서 p개의 파라미터 경우의 수를 모두 시행(m*p번 시행). 그 중, test fold에서 가장 accuracy성능이 좋았던, 파라미터를 선택. → 해당 파라미터로 해당 outer test fold 예측 수행. accuracy 성능 측정 (*N 가지)➡️ N가지의 결과 accracy 추출.
utils.py 의
get_data(df, all_ct, samples, meta=None, binary=True, attn2=True) return Xs, ys, batches, meta
Xs : 한 명의 환자에 대한 모든 세포의 유전자 발현 행렬
sample_df = df[df["patient"]==sample]
x = sample_df.iloc[:,:df.shape[-1]-3].to_numpy()
Example >
df.columns = ['Gene1', ..., 'Gene824', 'patient', 'cell_type_annotation', 'label']
# 예: patient == "BIOKEY_10"
sample_df = df[df["patient"] == "BIOKEY_10"]
# 유전자 발현만 뽑기 (세포 개수 × 유전자 수)
x = sample_df.iloc[:, :824].to_numpy() # shape: (n_cells_patient, 824)
ys : 환자 샘플들의 진단 결과 (label: 0 or 1)
ys = torch.tensor(samples["label"].to_list(), dtype = …)
batches : 환자 × 셀타입 조합에 해당하는 그룹 번호
batch = [(idx * len(all_ct) + ct_dict[ct]) for … ]
idx: 환자 index
ct_dict[ct]: cell_type을 숫자로 인코딩한 값
batch의 역할
세포 | 환자 | 셀 타입 | batch |
C1 | P01 | T-cell | 0×4+0 = 0 |
C2 | P01 | B-cell | 0×4+1 = 1 |
C3 | P02 | T-cell | 1×4+0 = 4 |
C4 | P02 | B-cell | 1×4+1 = 5 |
이렇게 하면 global_add_pool()이나 attention에서 같은 그룹끼리 처리 가능!
meta : (옵션) 환자 단위의 임상정보 또는 메타데이터
- 어떤 실험은 유전자 발현 정보만으로 분류하고 싶을 때가 있고
- 어떤 실험은 **유전자 + 임상 정보(combined features)**를 쓰고 싶을 때도 있어서
→ meta 입력은 유연하게 on/off 가능하도록 설계된 거야!
adata.obs (세포 단위)
---------------------
cell_id patient age sex label cell_type
cell1 P01 62 M 1 T-cell
cell2 P01 62 M 1 B-cell
cell3 P02 55 F 0 NK
...
↓ meta_cols = ['age', 'sex']
meta (환자 단위)
---------------
patient age sex
P01 62 M
P02 55 F
model.py
1.multi-layer perceptron (MLP)
X = self.lin(X)
# layers.extend([torch.nn.Linear(curr_in, curr_out), torch.nn.ReLU(), torch.nn.Dropout(dropout)])
2.셀 수준 attention or pooling (attn1)
환자 하나가 여러 cell type을 갖고 있고,
각 cell type 안에는 수십~수천 개의 cell(세포)이 있다.
→ 그 중 더 중요한 세포를 구별해서,
하나의 셀타입 벡터로 요약하는 게 attn1의 역할이다.
( 더 자세한 설명과 예시는, 아래의 3번째 '더보기' 참조)
if self.attn1:
w_c = softmax(self.w_c(X).squeeze(), batch)
# 이 모델이 셀 타입 단위 attention까지 할 것인지를 미리 파악해서, 출력 shape을 맞추는 용도
X = global_add_pool(X * w_c.unsqueeze(dim=-1), batch, size=ct_size)
* batches : 환자 × 셀타입
* w_c : batch 별로 softmax로 정규화되어 attention weight가 됨
* 세포 표현 X에 각 attention weight를 곱 적용, batch 단위로 pooling
✅ w_c란?
👉 w_c는 "각 세포의 중요도를 나타내는 attention weight"를 계산하는 작은 신경망 레이어
📦 코드 정의 (model.py)
self.w_c = nn.Sequential(
nn.Linear(curr_in, 1),
nn.Dropout(dropout)
)
- curr_in은 세포 표현 X의 hidden dimension (예: 64)
- 출력은 [N_cells, 1]: 각 세포마다 1개의 스칼라 점수
🧠 왜 w_c가 필요할까?
모델이 많은 세포 중에서 중요한 세포에 더 집중할 수 있도록
각 세포에 대해 "얼마나 중요한가?"라는 weight를 부여하려는 것!
즉, w_c(X)는 다음 역할을 해:
- 모든 세포의 표현 X에 대해:
- score = Linear(hidden_dim → 1)
- 결과적으로, 각 세포의 중요도를 나타내는 스칼라 값이 나옴
- 이후 softmax를 통해 같은 그룹(batch) 내에서 정규화:
- w_c = softmax(w_c(X), batch)
💡 예를 들어서
세포 | hidden vector (X) | w_c(X) raw | softmax (w_c) |
C1 | [0.5, 1.2, ...] | 2.0 | 0.7 |
C2 | [0.1, 0.8, ...] | 1.2 | 0.3 |
→ 결과: C1은 C2보다 더 중요한 세포로 간주됨
🎯 요약 정리
변수 | 의미 | 크기 | 설명 |
w_c | 세포 attention weight | [N_cells] | 각 세포가 그룹 내에서 얼마나 중요한지를 나타내는 값 (softmax 정규화됨) |
self.w_c | 세포 score 계산용 MLP | nn.Sequential(...) | X를 받아 score 출력 |
📈 최종 목적
다양한 세포들로 구성된 환자 샘플에서
어떤 세포가 질병 예측에 더 중요한지를 모델이 스스로 "배울 수 있도록" 하기 위한 attention 메커니즘
혹시 이 w_c를 시각화하거나, 모델이 어떤 세포에 높은 가중치를 줬는지 보고 싶은 경우도 도와줄 수 있어!
다음 단계 (3단계: 셀타입 수준 요약)도 이어서 가볼까? 😎
아주 좋아! 바로 그 "raw 값을 어떻게 구하는가",
즉 w_c(X)가 실제로 어떻게 계산되는가를 정확히 짚고 넘어가면 완전히 이해할 수 있어.
그럼 지금부터 코드 → 수학 → 예시 → 시각화 순서로 풀어볼게. 🔍
✅ 질문 핵심: w_c(X) raw 값은 어떻게 계산되는가?
📦 코드부터 다시 보자:
self.w_c = nn.Sequential(
nn.Linear(hidden_dim, 1),
nn.Dropout(p=dropout)
)
- 여기서 hidden_dim은 예를 들어 64라고 하자 (즉, 세포 표현 벡터 차원)
- 즉, self.w_c는 각 세포 표현 X_i ∈ ℝ⁶⁴에 대해:
score_i = Linear(X_i) = W·X_i + b → scalar
🧠 Step-by-step: 실제 계산 흐름
예: 세포 표현 X가 아래와 같이 생겼다고 해보자.
X = [
[0.2, 0.5, ..., 1.0], # 세포 1
[0.1, 0.0, ..., 0.9], # 세포 2
...
] → shape = (N_cells, 64)
- self.w_c(X)는 내부적으로:
score = X @ W^T + b
- W: shape = (1, 64) — 학습 가능한 파라미터
- b: shape = (1,) — bias
- 결과: [N_cells, 1] 스칼라 벡터
- .squeeze()를 하면 [N_cells]로 바뀜
- 이후 softmax(score, batch)를 통해 같은 batch 내에서 정규화
✏️ 수식으로 표현하면
세포 i의 표현:
Xi∈R64X_i ∈ ℝ^{64}
가중치 벡터:
w∈R64,b∈Rw ∈ ℝ^{64}, b ∈ ℝ
그럼 attention raw score는:
si=w⋅Xi+bs_i = w · X_i + b
→ 이것이 w_c(X)의 "raw 값"
🔍 시각적 예시
Cell | X (vector) | W (learned) | w_c(X) raw score |
C1 | [0.2, 0.1, ...] | [0.5, -0.1, ...] | 1.73 |
C2 | [0.1, 0.0, ...] | [0.5, -0.1, ...] | 0.98 |
C3 | [0.3, 0.4, ...] | [0.5, -0.1, ...] | 2.11 |
✔️ 이렇게 raw score가 만들어지고
✔️ 이후 batch 내 softmax로 정규화되어 attention weight가 됨
💡 쉽게 말하면:
- w_c(X)는 각 세포가 "얼마나 중요한지"에 대한 점수(score)를 계산하는 선형 함수
- 이 점수는 W·X + b로 계산되고,
- 이후 softmax(batch별)로 확률처럼 정규화되어 가중치가 되는 것!
🎯 실전 요약
단계 의미X | 세포 표현 (MLP 통과 후 벡터) |
self.w_c(X) | 각 세포 표현에 대해 점수 계산 (선형 조합) |
.squeeze() | [N_cells, 1] → [N_cells] |
softmax(..., batch) | batch별 확률처럼 정규화된 attention weight 생성 |
✅ attn1: "세포 수준 attention"
환자 하나가 여러 cell type을 갖고 있고,
각 cell type 안에는 수십~수천 개의 **cell(세포)**이 있어.
→ 그 중 더 중요한 세포를 구별해서,
하나의 셀타입 벡터로 요약하는 게 attn1의 역할이야.
🎯 단계별로 보면:
🔹 예: 환자 A의 T-cell 그룹
- 세포 ①: [0.1, 0.8, ..., 0.2]
- 세포 ②: [0.3, 0.5, ..., 0.1]
- 세포 ③: [0.7, 0.2, ..., 0.6]
→ 이 3개 세포를 어떻게 하나의 T-cell 벡터로 만들까?
🧠 Step-by-step 해석
① 임베딩된 X: [n_cells, hidden_dim]
X = self.lin(X) # MLP 임베딩
- 각 세포의 유전자 발현 → 의미 있는 벡터로 변환됨
② 세포별 중요도 스코어 계산
w_c = self.w_c(X).squeeze()
- self.w_c: [hidden_dim → 1] 선형 레이어
- 각 세포마다 **"중요도 점수 (raw score)"**를 하나씩 생성
예:
- Cell A: 1.7
- Cell B: 2.1
- Cell C: 0.3
③ Softmax (batch 단위)
w_c = softmax(w_c, batch)
- 같은 그룹(=환자×셀타입)에 속한 세포들끼리 softmax로 정규화
→ 중요도 가중치처럼 바뀜:
- Cell A: 0.3
- Cell B: 0.6
- Cell C: 0.1
④ attention-weighted pooling
X = global_add_pool(X * w_c.unsqueeze(-1), batch, size=ct_size)
- 각 세포 벡터에 w_c 가중치를 곱해서
- 같은 그룹(batch) 내에서 가중합
- 결과: [num_patient × num_cell_types, hidden_dim]
→ 즉, 환자 A의 T-cell 전체 세포 → 하나의 T-cell 벡터로 요약
📦 구조적 흐름 요약
[세포 표현 X] ← self.lin(X)
↓
[세포 중요도 w_c] ← self.w_c + softmax
↓
[가중합] ← global_add_pool → 셀타입 표현 벡터
↓
[reshape] → (환자 수, 셀타입 수, hidden_dim)
📈 예를 들어 하나의 환자에게
Cell Type 세포 수 중요도 높은 세포 → 셀타입 벡터T-cell | 250개 | top 10~20%에 높은 w_c 부여 |
B-cell | 100개 | 덜 중요한 세포는 무시됨 |
NK-cell | 75개 | attention으로 가중합 |
✅ 요약 정리
요소 의미attn1 | 세포 단위 attention 사용 여부 |
목적 | 같은 셀타입 내에서 더 중요한 세포를 골라내고 요약 |
결과 | (환자 × 셀타입) 벡터 생성 → 셀타입 attention 단계로 넘어감 |
구현 | self.w_c, softmax, global_add_pool |
2단계 이후, X 구조:
X.shape = [n_patients, n_cell_types, hidden_dim]
예: [64, 12, 64] → 환자 64명, 셀 타입 12개, 각 셀 타입은 64차원 벡터
3.셀타입 수준 attention or pooling (attn2)
X = self.lin2(X)
if self.attn2:
w_ct = torch.nn.Softmax(dim=1)(self.w_ct(X))
X = torch.sum(X * w_ct, dim=1)
* w_ct : 각 환자 내 셀 타입 중요도 계산
* 각 환자마다 하나의 벡터로 요약
좋아, 이제 마지막 핵심 조각 중 하나인 **w_ct**를 완전히 이해해보자!
이건 3단계인 **"셀 타입 수준 attention"**에서 등장하는 중요한 개념이야.
✅ Q: w_ct란?
w_ct는 각 환자 안의 셀 타입 벡터들 중에서, 어떤 셀 타입이 더 중요한지를 모델이 판단해서 주는 attention weight야.
즉,
"이 환자에게 있어서 T-cell이 중요할까? B-cell이 중요할까?"
이걸 모델이 스스로 학습해서 각 셀 타입별로 가중치를 주는 거야.
📌 어디서 등장하나?
if self.attn2:
w_ct = torch.nn.Softmax(dim=1)(self.w_ct(X))
X = torch.sum(X * w_ct, dim=1)
💡 변수 구조 설명
변수 shape 의미X | [n_patients, n_cell_types, hidden_dim] | 환자별 셀타입 벡터들 |
self.w_ct(X) | [n_patients, n_cell_types, 1] | 각 셀타입에 대한 점수 (raw attention score) |
w_ct | [n_patients, n_cell_types, 1] | softmax로 정규화된 attention weight |
🧠 w_ct 계산 과정
1. 각 셀 타입 벡터에 대해 score 계산 (선형 레이어)
self.w_ct = nn.Linear(hidden_dim, 1)
예를 들어 환자 P01의 셀 타입 벡터들이 이렇게 생겼다면:
X[P01] = [
[0.2, 0.5, ..., 0.1], # T-cell
[0.9, 0.3, ..., 0.6], # B-cell
...
]
→ self.w_ct(X[P01])는:
[1.2], [0.7], ... → shape: (n_ct, 1)
즉, 각 셀 타입 벡터 → 스칼라 점수로 압축
2. Softmax 정규화
w_ct = torch.nn.Softmax(dim=1)(self.w_ct(X))
- 각 환자 내에서 셀 타입별 점수를 확률처럼 정규화
- sum = 1이 됨
예:
T-cell: 0.6
B-cell: 0.3
NK-cell: 0.1
3. 가중합 → 환자 표현 벡터 만들기
X = torch.sum(X * w_ct, dim=1)
- 각 셀 타입 벡터에 해당 가중치를 곱해서
- 하나의 최종 환자 벡터로 합침
🧬 개념적으로 보면
환자 셀 타입 셀 타입 벡터 중요도 w_ct 결과 (곱)P01 | T-cell | [0.3, 0.5, ...] | 0.6 | [0.18, 0.3, ...] |
B-cell | [0.1, 0.2, ...] | 0.3 | [0.03, 0.06, ...] | |
NK-cell | [0.05, 0.1,...] | 0.1 | [0.005, 0.01, ...] |
↓ 이 벡터들을 모두 더해서 → 환자 표현 하나로 만듦
🤔 왜 필요한가?
- 환자마다 중요한 셀 타입이 다를 수 있음
- P01: T-cell 중심
- P02: B-cell이 더 중요
- w_ct를 통해 "환자 맞춤형" 셀 타입 중요도를 학습할 수 있음
- 이를 통해 더 정밀한 예측이 가능
🧠 요약
항목 설명w_ct란? | 환자 내 셀 타입 벡터들에 대한 attention weight |
어디서 나옴? | self.w_ct(X) → Softmax로 정규화 |
왜 필요함? | 환자별로 중요한 셀 타입을 구별하기 위해 |
결과 | 가중합 → 환자 표현 벡터 완성 |
혹시 실제 w_ct를 시각화해서 "어떤 셀타입이 중요하다고 모델이 판단했는지" 보고 싶다면 decompose_logits() 함수를 활용할 수도 있어!
필요하면 그것도 보여줄게 😎
4. 예측
X = self.lin_out(X)