AI 및 Data Analysis/Code

[Hierarchical MIL] Code ; Train.py

doraemin_dev 2025. 3. 27. 18:23

논문

Incorporating Hierarchical Information into Multiple Instance Learning for Patient Phenotype Prediction with scRNA-seq Data
https://www.biorxiv.org/content/10.1101/2025.02.10.637389v1.full.pdf

2025.02.10.637389v1.full.pdf
1.55MB

 

 

논문 정리

2025.03.22 - [AI 및 Data Analysis/Paper] - [Hierarchical MIL] Incorporating Hierarchical Information into Multiple Instance Learning for Patient Phenotype Prediction with scRNA-seq Data

 

[Hierarchical MIL] Incorporating Hierarchical Information into Multiple Instance Learning for Patient Phenotype Prediction with

논문Incorporating Hierarchical Information into Multiple Instance Learning for Patient Phenotype Prediction with scRNA-seq Datahttps://www.biorxiv.org/content/10.1101/2025.02.10.637389v1.full.pdf깃허브https://github.com/minhchaudo/hier-mil GitHub -

doraemin.tistory.com


깃허브

https://github.com/minhchaudo/hier-mil

 

GitHub - minhchaudo/hier-mil

Contribute to minhchaudo/hier-mil development by creating an account on GitHub.

github.com


데이터 개요

[Hierarchical MIL] Exploratory Data (Summary)

 

[Hierarchical MIL] Exploratory Data (Summary)

논문Incorporating Hierarchical Information into Multiple Instance Learning for Patient Phenotype Prediction with scRNA-seq Datahttps://www.biorxiv.org/content/10.1101/2025.02.10.637389v1.full.pdf   논문 정리2025.03.22 - [AI 및 Data Analysis/Paper]

doraemin.tistory.com


데이터 분석

[Hierarchical MIL] scRNA Analysis

 

[Hierarchical MIL] scRNA Analysis

논문Incorporating Hierarchical Information into Multiple Instance Learning for Patient Phenotype Prediction with scRNA-seq Datahttps://www.biorxiv.org/content/10.1101/2025.02.10.637389v1.full.pdf 논문 정리2025.03.22 - [AI 및 Data Analysis/Paper]

doraemin.tistory.com


데이터 분석 시 사용했던 코드에 대해 살펴보자.

데이터 분석 시 TASK 가 7가지 존재. 그 중, 가장 MAIN이 되는 TASK 2 에 대한 흐름을 따라가보자.

task 2는 repeated K-fold cross-validation을 통해
모델의 기본적인 예측 성능 (e.g., AUC) 을 안정적으로 평가하는 메인 실험
Task 번호 실험 이름 목적
0 train_and_tune 모델 학습 + 하이퍼파라미터 튜닝
1 predict_and_save 학습된 모델로 예측값 저장
2 repeated_k_fold 10번 반복된 k-fold 교차검증으로 모델 성능 평균화
3 vary_train_size 학습 데이터 크기를 줄여가며 성능 비교 (0.25, 0.5, 0.75)
4 vary_cell_count 셀 수를 줄였을 때 성능 변화 분석
5 randomize_cell_annot 셀 타입 정보를 랜덤으로 섞어서 모델 의존도 확인
6 get_p_val_cell_type permutation test로 중요한 세포 타입 찾기 (biological insight용)

train.py 의 
repeated_k_fold(df, meta, args)
return np.mean(aucs), np.std(aucs)

1. 샘플 목록 만들기 : 각 환자 샘플에 대해 label과 함께 정리

* label : 면역항암제 반응 결과 (Combined_outcome)  이진 분류.
  
"Favourable" → 1(양성), "Unfavourable" → 0(음성)
samples = df[["patient", "label"]].drop_duplicates()
2. 반복 루프 (args.n_repeats번 반복)
각 반복마다 다른 seed설정해서 K-fold를 새롭게 섞음 (Outer Cross Validation, Outer CV)
for i in range(args.n_repeats):
        skf = StratifiedKFold(args.n_folds, shuffle=True, random_state=i)

3.fold에 대해: 튜닝 → 학습 → 평가
fold마다 Optuna를 이용해 하이퍼파라미터 튜닝 수행 (inner CV)
    for train_idx, test_idx in skf.split(samples, samples["label"]):
	 ...
        study = optuna.create_study(direction="maximize", sampler=sampler)
        study.optimize(objective_wrapper(…), n_trials=args.n_tune_trials)
        best_params = study.best_params

사진의 하나의 루프(repeat) 일 때의 개념.

⇨ n 개의 n-fold (outer fold)로 나뉘고,
n-fold 내부에서, testtrain으로 나뉨. 이후, 또 다시 m 개의 m-fold (inner fold) 로 나뉨.

이곳에서 p개의 파라미터 경우의 수를 모두 시행(m*p번 시행). 그 중, test fold에서 가장 accuracy성능이 좋았던, 파라미터를 선택. → 해당 파라미터로 해당 outer test fold 예측 수행. accuracy 성능 측정 (*N 가지)

➡️ N가지의 결과 accracy 추출.


utils.py 의
get_data(df, all_ct, samples, meta=None, binary=True, attn2=True) return Xs, ys, batches, meta

Xs : 한 명의 환자에 대한 모든 세포의 유전자 발현 행렬

sample_df = df[df["patient"]==sample]
        x = sample_df.iloc[:,:df.shape[-1]-3].to_numpy()

 

Example >

df.columns = ['Gene1', ..., 'Gene824', 'patient', 'cell_type_annotation', 'label']

# 예: patient == "BIOKEY_10"
sample_df = df[df["patient"] == "BIOKEY_10"]

# 유전자 발현만 뽑기 (세포 개수 × 유전자 수)
x = sample_df.iloc[:, :824].to_numpy()  # shape: (n_cells_patient, 824)

 

ys : 환자 샘플들의 진단 결과 (label: 0 or 1)

ys = torch.tensor(samples["label"].to_list(), dtype = …)

 

batches : 환자 × 셀타입 조합에 해당하는 그룹 번호

batch = [(idx * len(all_ct) + ct_dict[ct]) for … ]
idx: 환자 index
ct_dict[ct]: cell_type을 숫자로 인코딩한 값

 

batch의 역할

세포 환자 셀 타입 batch
C1 P01 T-cell 0×4+0 = 0
C2 P01 B-cell 0×4+1 = 1
C3 P02 T-cell 1×4+0 = 4
C4 P02 B-cell 1×4+1 = 5

이렇게 하면 global_add_pool()이나 attention에서 같은 그룹끼리 처리 가능!

 

 

meta : (옵션) 환자 단위의 임상정보 또는 메타데이터

 

  • 어떤 실험은 유전자 발현 정보만으로 분류하고 싶을 때가 있고
  • 어떤 실험은 **유전자 + 임상 정보(combined features)**를 쓰고 싶을 때도 있어서
    → meta 입력은 유연하게 on/off 가능하도록 설계된 거야!

 

adata.obs (세포 단위)
---------------------
cell_id   patient   age   sex   label   cell_type
cell1     P01       62    M     1       T-cell
cell2     P01       62    M     1       B-cell
cell3     P02       55    F     0       NK
...

↓ meta_cols = ['age', 'sex']

meta (환자 단위)
---------------
patient   age   sex
P01       62    M
P02       55    F

model.py

1.multi-layer perceptron (MLP)

X = self.lin(X)

# layers.extend([torch.nn.Linear(curr_in, curr_out), torch.nn.ReLU(), torch.nn.Dropout(dropout)])
 

2.셀 수준 attention or pooling (attn1)

환자 하나가 여러 cell type을 갖고 있고,
각 cell type 안에는 수십~수천 개의 cell(세포)이 있다.
→ 그 중 더 중요한 세포를 구별해서,
하나의 셀타입 벡터로 요약하는 게 attn1의 역할이다.

( 더 자세한 설명과 예시는, 아래의 3번째 '더보기' 참조) 

 

if self.attn1:
    w_c = softmax(self.w_c(X).squeeze(), batch)
    # 이 모델이 셀 타입 단위 attention까지 할 것인지를 미리 파악해서, 출력 shape을 맞추는 용도
    X = global_add_pool(X * w_c.unsqueeze(dim=-1), batch, size=ct_size)

* batches : 환자 × 셀타입

* w_c : batch 별로 softmax정규화되어 attention weight가 됨

* 세포 표현 X에 각 attention weight를 곱 적용, batch 단위로 pooling

더보기

✅ w_c란?

👉 w_c는 "각 세포의 중요도를 나타내는 attention weight"를 계산하는 작은 신경망 레이어


📦 코드 정의 (model.py)

self.w_c = nn.Sequential(
    nn.Linear(curr_in, 1),
    nn.Dropout(dropout)
)
  • curr_in은 세포 표현 X의 hidden dimension (예: 64)
  • 출력은 [N_cells, 1]: 각 세포마다 1개의 스칼라 점수

🧠 왜 w_c가 필요할까?

모델이 많은 세포 중에서 중요한 세포에 더 집중할 수 있도록
각 세포에 대해 "얼마나 중요한가?"라는 weight를 부여하려는 것!

즉, w_c(X)는 다음 역할을 해:

  • 모든 세포의 표현 X에 대해:
  • score = Linear(hidden_dim → 1)
  • 결과적으로, 각 세포의 중요도를 나타내는 스칼라 값이 나옴
  • 이후 softmax를 통해 같은 그룹(batch) 내에서 정규화:
  • w_c = softmax(w_c(X), batch)

💡 예를 들어서


세포 hidden vector (X) w_c(X) raw   softmax (w_c)
C1 [0.5, 1.2, ...] 2.0 0.7
C2 [0.1, 0.8, ...] 1.2 0.3

→ 결과: C1은 C2보다 더 중요한 세포로 간주됨


🎯 요약 정리

변수 의미 크기 설명
w_c 세포 attention weight [N_cells] 각 세포가 그룹 내에서 얼마나 중요한지를 나타내는 값 (softmax 정규화됨)
self.w_c 세포 score 계산용 MLP nn.Sequential(...) X를 받아 score 출력

📈 최종 목적

다양한 세포들로 구성된 환자 샘플에서
어떤 세포가 질병 예측에 더 중요한지를 모델이 스스로 "배울 수 있도록" 하기 위한 attention 메커니즘


혹시 이 w_c를 시각화하거나, 모델이 어떤 세포에 높은 가중치를 줬는지 보고 싶은 경우도 도와줄 수 있어!
다음 단계 (3단계: 셀타입 수준 요약)도 이어서 가볼까? 😎

 

더보기

아주 좋아! 바로 그 "raw 값을 어떻게 구하는가",
w_c(X)가 실제로 어떻게 계산되는가를 정확히 짚고 넘어가면 완전히 이해할 수 있어.
그럼 지금부터 코드 → 수학 → 예시 → 시각화 순서로 풀어볼게. 🔍


✅ 질문 핵심: w_c(X) raw 값은 어떻게 계산되는가?

📦 코드부터 다시 보자:

self.w_c = nn.Sequential(
    nn.Linear(hidden_dim, 1),
    nn.Dropout(p=dropout)
)
  • 여기서 hidden_dim은 예를 들어 64라고 하자 (즉, 세포 표현 벡터 차원)
  • 즉, self.w_c는 각 세포 표현 X_i ∈ ℝ⁶⁴에 대해:
score_i = Linear(X_i) = W·X_i + b  → scalar

🧠 Step-by-step: 실제 계산 흐름

예: 세포 표현 X가 아래와 같이 생겼다고 해보자.

X = [
    [0.2, 0.5, ..., 1.0],   # 세포 1
    [0.1, 0.0, ..., 0.9],   # 세포 2
    ...
]  → shape = (N_cells, 64)
  1. self.w_c(X)는 내부적으로:
score = X @ W^T + b
  • W: shape = (1, 64) — 학습 가능한 파라미터
  • b: shape = (1,) — bias
  • 결과: [N_cells, 1] 스칼라 벡터
  1. .squeeze()를 하면 [N_cells]로 바뀜
  2. 이후 softmax(score, batch)를 통해 같은 batch 내에서 정규화

✏️ 수식으로 표현하면

세포 i의 표현:

Xi∈R64X_i ∈ ℝ^{64}

가중치 벡터:

w∈R64,b∈Rw ∈ ℝ^{64}, b ∈ ℝ

그럼 attention raw score는:

si=w⋅Xi+bs_i = w · X_i + b

→ 이것이 w_c(X)의 "raw 값"


🔍 시각적 예시

Cell X (vector) W (learned) w_c(X) raw score
C1 [0.2, 0.1, ...] [0.5, -0.1, ...] 1.73
C2 [0.1, 0.0, ...] [0.5, -0.1, ...] 0.98
C3 [0.3, 0.4, ...] [0.5, -0.1, ...] 2.11

✔️ 이렇게 raw score가 만들어지고
✔️ 이후 batch 내 softmax로 정규화되어 attention weight가 됨


💡 쉽게 말하면:

  • w_c(X)는 각 세포가 "얼마나 중요한지"에 대한 점수(score)를 계산하는 선형 함수
  • 이 점수는 W·X + b로 계산되고,
  • 이후 softmax(batch별)로 확률처럼 정규화되어 가중치가 되는 것!

🎯 실전 요약

단계 의미
X 세포 표현 (MLP 통과 후 벡터)
self.w_c(X) 각 세포 표현에 대해 점수 계산 (선형 조합)
.squeeze() [N_cells, 1] → [N_cells]
softmax(..., batch) batch별 확률처럼 정규화된 attention weight 생성

 

 

더보기

 

✅ attn1: "세포 수준 attention"

환자 하나가 여러 cell type을 갖고 있고,
각 cell type 안에는 수십~수천 개의 **cell(세포)**이 있어.
→ 그 중 더 중요한 세포를 구별해서,
하나의 셀타입 벡터로 요약하는 게 attn1의 역할이야.


🎯 단계별로 보면:

🔹 예: 환자 A의 T-cell 그룹

  • 세포 ①: [0.1, 0.8, ..., 0.2]
  • 세포 ②: [0.3, 0.5, ..., 0.1]
  • 세포 ③: [0.7, 0.2, ..., 0.6]
    → 이 3개 세포를 어떻게 하나의 T-cell 벡터로 만들까?

🧠 Step-by-step 해석

① 임베딩된 X: [n_cells, hidden_dim]

X = self.lin(X)  # MLP 임베딩
  • 각 세포의 유전자 발현 → 의미 있는 벡터로 변환됨

② 세포별 중요도 스코어 계산

w_c = self.w_c(X).squeeze()
  • self.w_c: [hidden_dim → 1] 선형 레이어
  • 각 세포마다 **"중요도 점수 (raw score)"**를 하나씩 생성

예:

  • Cell A: 1.7
  • Cell B: 2.1
  • Cell C: 0.3

③ Softmax (batch 단위)

w_c = softmax(w_c, batch)
  • 같은 그룹(=환자×셀타입)에 속한 세포들끼리 softmax로 정규화

→ 중요도 가중치처럼 바뀜:

  • Cell A: 0.3
  • Cell B: 0.6
  • Cell C: 0.1

④ attention-weighted pooling

X = global_add_pool(X * w_c.unsqueeze(-1), batch, size=ct_size)
  • 각 세포 벡터에 w_c 가중치를 곱해서
  • 같은 그룹(batch) 내에서 가중합
  • 결과: [num_patient × num_cell_types, hidden_dim]

→ 즉, 환자 A의 T-cell 전체 세포 → 하나의 T-cell 벡터로 요약


📦 구조적 흐름 요약

[세포 표현 X]    ← self.lin(X)
      ↓
[세포 중요도 w_c] ← self.w_c + softmax
      ↓
[가중합] ← global_add_pool → 셀타입 표현 벡터
      ↓
[reshape] → (환자 수, 셀타입 수, hidden_dim)

📈 예를 들어 하나의 환자에게

Cell Type 세포 수 중요도 높은 세포 → 셀타입 벡터
T-cell 250개 top 10~20%에 높은 w_c 부여
B-cell 100개 덜 중요한 세포는 무시됨
NK-cell 75개 attention으로 가중합

✅ 요약 정리

요소 의미
attn1 세포 단위 attention 사용 여부
목적 같은 셀타입 내에서 더 중요한 세포를 골라내고 요약
결과 (환자 × 셀타입) 벡터 생성 → 셀타입 attention 단계로 넘어감
구현 self.w_c, softmax, global_add_pool

 

 


2단계 이후, X 구조:

X.shape = [n_patients, n_cell_types, hidden_dim]​


예: [64, 12, 64] → 환자 64명, 셀 타입 12개, 각 셀 타입은 64차원 벡터

 

 

3.셀타입 수준 attention or pooling (attn2)

 

X = self.lin2(X)
if self.attn2:
    w_ct = torch.nn.Softmax(dim=1)(self.w_ct(X))
    X = torch.sum(X * w_ct, dim=1)

* w_ct : 각 환자 내 셀 타입 중요도 계산

* 각 환자마다 하나의 벡터로 요약

더보기

좋아, 이제 마지막 핵심 조각 중 하나인 **w_ct**를 완전히 이해해보자!
이건 3단계인 **"셀 타입 수준 attention"**에서 등장하는 중요한 개념이야.


✅ Q: w_ct란?

w_ct는 각 환자 안의 셀 타입 벡터들 중에서, 어떤 셀 타입이 더 중요한지를 모델이 판단해서 주는 attention weight야.

즉,
"이 환자에게 있어서 T-cell이 중요할까? B-cell이 중요할까?"
이걸 모델이 스스로 학습해서 각 셀 타입별로 가중치를 주는 거야.


📌 어디서 등장하나?

if self.attn2:
    w_ct = torch.nn.Softmax(dim=1)(self.w_ct(X))
    X = torch.sum(X * w_ct, dim=1)

💡 변수 구조 설명

변수 shape 의미
X [n_patients, n_cell_types, hidden_dim] 환자별 셀타입 벡터들
self.w_ct(X) [n_patients, n_cell_types, 1] 각 셀타입에 대한 점수 (raw attention score)
w_ct [n_patients, n_cell_types, 1] softmax로 정규화된 attention weight

🧠 w_ct 계산 과정

1. 각 셀 타입 벡터에 대해 score 계산 (선형 레이어)

self.w_ct = nn.Linear(hidden_dim, 1)

예를 들어 환자 P01의 셀 타입 벡터들이 이렇게 생겼다면:

X[P01] = [
    [0.2, 0.5, ..., 0.1],  # T-cell
    [0.9, 0.3, ..., 0.6],  # B-cell
    ...
]

→ self.w_ct(X[P01])는:

[1.2], [0.7], ...  → shape: (n_ct, 1)

즉, 각 셀 타입 벡터 → 스칼라 점수로 압축


2. Softmax 정규화

w_ct = torch.nn.Softmax(dim=1)(self.w_ct(X))
  • 각 환자 내에서 셀 타입별 점수를 확률처럼 정규화
  • sum = 1이 됨

예:

T-cell: 0.6  
B-cell: 0.3  
NK-cell: 0.1

3. 가중합 → 환자 표현 벡터 만들기

X = torch.sum(X * w_ct, dim=1)
  • 각 셀 타입 벡터에 해당 가중치를 곱해서
  • 하나의 최종 환자 벡터로 합침

🧬 개념적으로 보면

환자 셀 타입 셀 타입 벡터 중요도 w_ct 결과 (곱)
P01 T-cell [0.3, 0.5, ...] 0.6 [0.18, 0.3, ...]
  B-cell [0.1, 0.2, ...] 0.3 [0.03, 0.06, ...]
  NK-cell [0.05, 0.1,...] 0.1 [0.005, 0.01, ...]

↓ 이 벡터들을 모두 더해서 → 환자 표현 하나로 만듦


🤔 왜 필요한가?

  • 환자마다 중요한 셀 타입이 다를 수 있음
    • P01: T-cell 중심
    • P02: B-cell이 더 중요
  • w_ct를 통해 "환자 맞춤형" 셀 타입 중요도를 학습할 수 있음
  • 이를 통해 더 정밀한 예측이 가능

🧠 요약

항목 설명
w_ct란? 환자 내 셀 타입 벡터들에 대한 attention weight
어디서 나옴? self.w_ct(X) → Softmax로 정규화
왜 필요함? 환자별로 중요한 셀 타입을 구별하기 위해
결과 가중합 → 환자 표현 벡터 완성

혹시 실제 w_ct를 시각화해서 "어떤 셀타입이 중요하다고 모델이 판단했는지" 보고 싶다면 decompose_logits() 함수를 활용할 수도 있어!
필요하면 그것도 보여줄게 😎

 

4. 예측

X = self.lin_out(X)