AI 및 Data Analysis/Code

[ScRAT] customized dataset

doraemin_dev 2025. 7. 1. 03:03
  • For customized dataset:

Please pack the dataset in the h5ad format and set the value of args.dataset to the path of the dataset. Also, please set args.task to 'custom'.

Furthermore, in the dataloader.py file, please modify the following lines:

  1. line 178 for label dictionary to map string to integer (default: {})
  2. line 185 for patient id (default: data.obs['patient_id'])
  3. line 187 for label, which is clinical phenotype for prediction (default: data.obs['Outcome'])
  4. line 189 for cell type, which assist for mixup (default: data.obs['cell_type'])

Also in the main.py file, please add elif args.task == 'custom': label_dict = {0: 'XXX', 1: 'XXX'} in line 83.

 

👉 /data/project/kim89/cardio.h5ad 파일을 ScRAT에 입력 데이터로 활용하여 실행
👉 ScRAT 코드 일부를 커스터마이징 (dataloader.py, main.py)


🚀 실행 준비 단계

1️⃣ 데이터 준비

당신의 데이터:
✅ cardio.h5ad
✅ 이 파일에는 최소 아래 컬럼이 필요합니다:

  • obs['patient_id']: 환자 ID
  • obs['Outcome']: 예측할 label (예: 0, 1 / control, disease)
  • obs['cell_type']: cell type 정보 (mixup 용)

👉 만약 컬럼 이름이 다르다면 h5ad를 로드해서 컬럼 이름 확인:

import scanpy as sc
adata = sc.read_h5ad('/data/project/kim89/cardio.h5ad')
print(adata.obs.columns)

필요하면 .obs 컬럼을 rename해서 h5ad 다시 저장하세요.

 

* cardio.h5ad 데이터에서는

더보기
"""
데이터 파일에는 최소 아래 컬럼이 필요합니다:

obs['patient_id']: 환자 ID

obs['Outcome']: 예측할 label (예: 0, 1 / control, disease)

obs['cell_type']: cell type 정보 (mixup 용)

👉 만약 컬럼 이름이 다르다면 h5ad를 로드해서 컬럼 이름 확인:
"""
import scanpy as sc
# 데이터 로드
adata = sc.read_h5ad('/data/project/kim89/cardio.h5ad')

# 컬럼 값 확인
print(adata.obs.columns)
"""
Index(['biosample_id', 'patient', 'cell_type', 'cell_type_annotation', 'sex',
       'cell_type_leiden06', 'sub_cluster', 'n_umi', 'n_genes',
       'cellranger_percent_mito', 'exon_prop', 'entropy', 'doublet_score',
       'species', 'species__ontology_label', 'disease',
       'disease__ontology_label', 'organ', 'organ__ontology_label',
       'library_preparation_protocol',
       'library_preparation_protocol__ontology_label', 'label'],
      dtype='object')
"""

# 각 컬럼 값 확인
print("🔹 biosample_id unique values:")
print(adata.obs['biosample_id'].unique())
"""
🔹 biosample_id unique values:
['LV_1622_2_nf', 'LV_1422_1_hcm', 'LV_1722_2_hcm', 'LV_1462_1_hcm', 'LV_1558_2_nf', ..., 'LV_1472_1_dcm', 'LV_1735_2_hcm', 'LV_1600_2_nf', 'LV_1606_1_dcm', 'LV_1561_2_nf']
Length: 80
Categories (80, object): ['LV_1290_1_dcm', 'LV_1290_2_dcm', 'LV_1300_1_dcm', 'LV_1300_2_dcm', ...,
                          'LV_1726_1_hcm', 'LV_1726_2_hcm', 'LV_1735_1_hcm', 'LV_1735_2_hcm']
"""

print("\n🔹 patient unique values:")
print(adata.obs['patient'].unique())
"""
🔹 patient unique values:
['P1622', 'P1422', 'P1722', 'P1462', 'P1558', ..., 'P1539', 'P1726', 'P1504', 'P1472', 'P1606']
Length: 42
Categories (42, object): ['P1290', 'P1300', 'P1304', 'P1358', ..., 'P1718', 'P1722', 'P1726', 'P1735']
"""

print("\n🔹 disease unique values:")
print(adata.obs['disease'].unique())
"""
🔹 disease unique values:
['PATO_0000461', 'MONDO_0005045', 'MONDO_0005021']
Categories (3, object): ['MONDO_0005021', 'MONDO_0005045', 'PATO_0000461']
"""

print("\n🔹 disease__ontology_label unique values:")
print(adata.obs['disease__ontology_label'].unique())
"""
🔹 disease__ontology_label unique values:
['normal', 'hypertrophic cardiomyopathy', 'dilated cardiomyopathy']
Categories (3, object): ['dilated cardiomyopathy', 'hypertrophic cardiomyopathy', 'normal']
"""

print("\n🔹 label unique values:")
print(adata.obs['label'].unique())
"""
🔹 label unique values:
[0 1 2]
"""

print("\n🔹 cell_type unique values:")
print(adata.obs['cell_type'].unique())
"""
🔹 cell_type unique values:
['CL_0000746', 'CL_0000136', 'CL_0000235', 'CL_0002350', 'CL_2000066', ..., 'CL_0000359', 'CL_0000669', 'CL_0000097', 'CL_0000542', 'CL_0000077']
Length: 13
Categories (13, object): ['CL_0000077', 'CL_0000097', 'CL_0000136', 'CL_0000235', ..., 'CL_0002350',
                          'CL_0010008', 'CL_0010022', 'CL_2000066']
"""

print("\n🔹 cell_type_annotation unique values:")
print(adata.obs['cell_type_annotation'].unique())
"""
🔹 cell_type_annotation unique values:
['cardiac muscle cell', 'fat cell', 'macrophage', 'endocardial cell', 'cardiac ventricle fibroblast', ..., 'vascular associated smooth muscle cell', 'pericyte cell', 'mast cell', 'lymphocyte', 'mesothelial cell']
Length: 13
Categories (13, object): ['cardiac endothelial cell', 'cardiac muscle cell', 'cardiac neuron',
                          'cardiac ventricle fibroblast', ..., 'mast cell', 'mesothelial cell', 'pericyte cell',
                          'vascular associated smooth muscle cell']
"""

 

 


 patient_id로 쓸 컬럼

후보

  • biosample_id: 80개 유니크 (샘플 단위, 샘플+상태 정보 포함: LV_####_#_xxx 형태)
  • patient: 42개 유니크 (환자 단위, P#### 형태)

추천

👉 patient

  • ScRAT은 환자 단위로 split/mixup을 관리하므로 환자 단위 ID가 더 적절
  • biosample_id는 같은 환자에서 여러 샘플이 있을 수 있어 patient가 mixup 방지에도 유리

 Outcome (label)으로 쓸 컬럼

후보

  • disease__ontology_label: ['normal', 'hypertrophic cardiomyopathy', 'dilated cardiomyopathy']
  • label: [0, 1, 2] (아마 disease label encoding)

추천

👉 disease__ontology_label

  • string label이라 코드 가독성/관리 용이
  • label도 쓸 수 있지만, 의미를 확인하기 어려움 → disease__ontology_label을 label_dict로 매핑하면 깔끔

예)

label_dict = {
    'normal': 0,
    'hypertrophic cardiomyopathy': 1,
    'dilated cardiomyopathy': 2
}

 cell_type으로 쓸 컬럼

후보

  • cell_type: CL 코드 (ex. CL_0000746)
  • cell_type_annotation: cell type 이름 (ex. cardiac muscle cell)

추천

👉 cell_type_annotation

  • 사람이 보기에 직관적이고 mixup 시 분석/디버그가 쉬움
  • CL 코드 필요하면 나중에 매핑 가능

🌟 최종 선택

역할 컬럼
patient_id patient
label (Outcome) disease__ontology_label
cell_type cell_type_annotation

 

* covid.h5ad에서는

더보기
import scanpy as sc
# 데이터 로드
adata = sc.read_h5ad('/data/project/kim89/covid.h5ad')
print(adata.obs.columns)

# 후보 컬럼 유니크 값 출력
# 환자 ID (patient_id)
print("🔹 donor_id unique values:")
print(adata.obs['donor_id'].unique())
print()

print("🔹 patient unique values:")
print(adata.obs['patient'].unique())
print()

# Outcome (예측 label)
print("🔹 SARSCoV2_PCR_Status unique values:")
print(adata.obs['SARSCoV2_PCR_Status'].unique())
print()

print("🔹 Cohort_Disease_WHO_Score unique values:")
print(adata.obs['Cohort_Disease_WHO_Score'].unique())
print()

print("🔹 SARSCoV2_PCR_Status_and_WHO_Score unique values:")
print(adata.obs['SARSCoV2_PCR_Status_and_WHO_Score'].unique())
print()

print("🔹 Peak_Respiratory_Support_WHO_Score unique values:")
print(adata.obs['Peak_Respiratory_Support_WHO_Score'].unique())
print()

print("🔹 disease__ontology_label unique values:")
print(adata.obs['disease__ontology_label'].unique())
print()

#  cell_type
print("🔹 cell_type_annotation unique values:")
print(adata.obs['cell_type_annotation'].unique())
print()

print("🔹 Coarse_Cell_Annotations unique values:")
print(adata.obs['Coarse_Cell_Annotations'].unique())
print()

print("🔹 Detailed_Cell_Annotations unique values:")
print(adata.obs['Detailed_Cell_Annotations'].unique())
print()

"""
🔹 donor_id unique values:
['Control_Participant7', 'COVID19_Participant13', 'COVID19_Participant31', 'Control_Participant12', 'COVID19_Participant5', ..., 'COVID19_Participant18', 'COVID19_Participant27', 'COVID19_Participant21', 'COVID19_Participant20', 'COVID19_Participant30']
Length: 50
Categories (50, object): ['COVID19_Participant2', 'COVID19_Participant3', 'COVID19_Participant4',
                          'COVID19_Participant5', ..., 'Control_Participant12', 'Control_Participant13',
                          'Control_Participant14', 'Control_Participant15']

🔹 patient unique values:
['Control_Participant7', 'COVID19_Participant13', 'COVID19_Participant31', 'Control_Participant12', 'COVID19_Participant5', ..., 'COVID19_Participant18', 'COVID19_Participant27', 'COVID19_Participant21', 'COVID19_Participant20', 'COVID19_Participant30']
Length: 50
Categories (50, object): ['COVID19_Participant2', 'COVID19_Participant3', 'COVID19_Participant4',
                          'COVID19_Participant5', ..., 'Control_Participant12', 'Control_Participant13',
                          'Control_Participant14', 'Control_Participant15']

🔹 SARSCoV2_PCR_Status unique values:
['neg', 'pos']
Categories (2, object): ['neg', 'pos']

🔹 Cohort_Disease_WHO_Score unique values:
['Control_WHO_0', 'COVID19_WHO_6-8', 'COVID19_WHO_1-5']
Categories (3, object): ['COVID19_WHO_1-5', 'COVID19_WHO_6-8', 'Control_WHO_0']

🔹 SARSCoV2_PCR_Status_and_WHO_Score unique values:
['neg_0', 'pos_8', 'pos_6', 'pos_5', 'pos_4', 'pos_3', 'pos_1', 'pos_7']
Categories (8, object): ['neg_0', 'pos_1', 'pos_3', 'pos_4', 'pos_5', 'pos_6', 'pos_7', 'pos_8']

🔹 Peak_Respiratory_Support_WHO_Score unique values:
['0', '8', '6', '5', '4', '3', '1', '7']
Categories (8, object): ['0', '1', '3', '4', '5', '6', '7', '8']

🔹 disease__ontology_label unique values:
['normal', 'COVID-19']
Categories (2, object): ['COVID-19', 'normal']

🔹 cell_type_annotation unique values:
['Developing Ciliated Cells', 'Ciliated Cells', 'Secretory Cells', 'Squamous Cells', 'Goblet Cells', ..., 'Developing Secretory and Goblet Cells', 'Plasmacytoid DCs', 'Enteroendocrine Cells', 'Erythroblasts', 'Mast Cells']
Length: 18
Categories (18, object): ['B Cells', 'Basal Cells', 'Ciliated Cells', 'Dendritic Cells', ...,
                          'Plasmacytoid DCs', 'Secretory Cells', 'Squamous Cells', 'T Cells']

🔹 Coarse_Cell_Annotations unique values:
['Developing Ciliated Cells', 'Ciliated Cells', 'Secretory Cells', 'Squamous Cells', 'Goblet Cells', ..., 'Developing Secretory and Goblet Cells', 'Plasmacytoid DCs', 'Enteroendocrine Cells', 'Erythroblasts', 'Mast Cells']
Length: 18
Categories (18, object): ['B Cells', 'Basal Cells', 'Ciliated Cells', 'Dendritic Cells', ...,
                          'Plasmacytoid DCs', 'Secretory Cells', 'Squamous Cells', 'T Cells']

🔹 Detailed_Cell_Annotations unique values:
['Developing Ciliated Cells', 'FOXJ1 high Ciliated Cells', 'BEST4 high Cilia high Ciliated Cells', 'Cilia high Ciliated Cells', 'SERPINB11 high Secretory Cells', ..., 'Enteroendocrine Cells', 'Interferon Responsive Cytotoxic CD8 T Cells', 'Interferon Responsive Secretory Cells', 'Erythroblasts', 'Mast Cells']
Length: 39
Categories (39, object): ['AZGP1 SCGB3A1 LTF high Goblet Cells', 'AZGP1 high Goblet Cells', 'B Cells',
                          'BEST4 high Cilia high Ciliated Cells', ..., 'SCGB1A1 high Goblet Cells',
                          'SERPINB11 high Secretory Cells', 'SPRR2D high Squamous Cells',
                          'VEGFA high Squamous Cells']
"""

 

 


환자 ID (patient_id)

후보 내용
donor_id Control_ParticipantX, COVID19_ParticipantX 같은 형태
patient 동일한 값 (=donor_id)

🔹 추천: donor_id

  • patient와 값이 동일 → donor_id가 더 직관적이고 표준적

Outcome (label)

후보 내용 목적
SARSCoV2_PCR_Status 'neg', 'pos' 이진 분류: 감염 여부
Cohort_Disease_WHO_Score 'Control_WHO_0', 'COVID19_WHO_1-5', 'COVID19_WHO_6-8' severity 분류
disease__ontology_label 'normal', 'COVID-19' 이진 분류 (control vs COVID-19)

🔹 추천:

  • disease__ontology_label

💡 예를 들어 label dict:

id_dict = {'normal': 0, 'COVID-19': 1}

cell type

후보 내용
cell_type_annotation 18개 타입, mixup 용도로 충분히 구체적
Coarse_Cell_Annotations = cell_type_annotation
Detailed_Cell_Annotations 39개, 더 세분화됨 (필요 시 선택 가능)

🔹 추천: cell_type_annotation

  • 구체적이고 mixup 분석하기에 직관적
  • 너무 세분화된 Detailed_Cell_Annotations는 overfitting 가능성 있음

🌟 최종 추천

역할 컬럼 label dict 예시
patient_id donor_id -
Outcome disease__ontology_label {'normal': 0, 'COVID-19': 1}
cell_type cell_type_annotation -

🚀 다음 단계

👉 ScRAT용 코드:

id_dict = {'normal': 0, 'COVID-19': 1}
patient_id = data.obs['donor_id']
labels = data.obs['disease__ontology_label']        
cell_type = data.obs['cell_type_annotation']

severity 분류로 하실 경우 알려주시면 label_dict도 바로 만들어 드리겠습니다.
필요하면 새 h5ad 생성 코드까지 작성해 드릴 수 있습니다! 🚀
어떤 task로 진행할지 결정하시면 말씀해 주세요.


2️⃣ ScRAT 코드 수정

👉 dataloader.py

# line 178 수정
id_dict = {
    'normal': 0,
    'hypertrophic cardiomyopathy': 1,
    'dilated cardiomyopathy': 2
}

# line 185, 187, 189 수정
patient_id = data.obs['patient']
label = data.obs['disease__ontology_label']
cell_type = data.obs['cell_type_annotation']

👉 main.py

# line 83 추가
elif args.task == 'custom':
    label_dict = {0: 'normal', 1: 'hypertrophic cardiomyopathy', 2: 'dilated cardiomyopathy'}


3️⃣ 실행

가상환경 활성화 후 실행:

cd /data/project/kim89/ScRAT
source scrat/bin/activate

python main.py --dataset /data/project/kim89/cardio.h5ad --task custom

# 최소한의 학습으로 코드 돌아가는지 확인하기
# python main.py --dataset /data/project/kim89/cardio_pca.h5ad --task custom_cardio --batch_size 1 --train_sample_cells 50 --test_sample_cells 50 --heads 2 --emb_dim 32 --epochs 3
# python main.py --dataset /data/project/kim89/covid_pca.h5ad --task custom_covid --batch_size 1 --train_sample_cells 50 --test_sample_cells 50 --heads 2 --emb_dim 32 --epochs 3

# gpu 최대로 활용하여 학습
CUDA_VISIBLE_DEVICES=0,1,2,3 \
PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128 \
python main.py \
  --dataset /data/project/kim89/cardio_minimal.h5ad \
  --task custom_cardio \
  --batch_size 4 \
  --train_sample_cells 100 \
  --test_sample_cells 100 \
  --train_num_sample 10 \
  --test_num_sample 20 \
  --heads 2 \
  --emb_dim 16

 

더보기

gpu 최대로 활용하여 학습 했을 때,

 

현재 상태

👉 nvidia-smi 결과에 따르면:

  • GPU 0: 22.7 GiB / 24.5 GiB 사용 (91% 사용률)
  • GPU 1: 20.3 GiB / 24.5 GiB 사용 (93% 사용률)
  • GPU 2: 20.3 GiB / 24.5 GiB 사용 (91% 사용률)
  • GPU 3: 20.3 GiB / 24.5 GiB 사용 (82% 사용률)

👉 Processes 확인:

  • 동일한 Python 프로세스 (PID 1190569)GPU 0,1,2,3 메모리 모두 잡고 있습니다.
  • 즉, 현재 코드가 멀티 GPU 사용 중입니다 (아마 DataParallel).

🌟 현재 코드만으로 더 학습 가능한가?

👉 지금 상황

  • 이미 GPU 메모리의 대부분을 사용하고 있습니다.
  • batch size, 모델 크기 (emb_dim, heads 등) 더 키우면 OOM 날 가능성이 높습니다.

👉 추가 학습 가능성

  • 지금 세팅으로 가능한 최대 규모로 학습 중입니다.
  • ✅ 더 큰 데이터, 더 큰 모델 학습은 현재 구조(예: DataParallel)로는 더 이상 확장 어렵습니다.

👉 왜?

  • DataParallel은 GPU마다 동일 모델의 복사본을 메모리에 올리고 forward/backward pass를 나눔
  • 그래서 GPU 개수 늘려도 메모리 이득은 거의 없고, 단지 연산 속도만 소폭 증가

🚀 그럼 더 학습하려면?

💡 현 구조 한계

  • DataParallel 방식은 메모리를 효율적으로 나누지 못합니다.

💡 대안
1️⃣ DistributedDataParallel (DDP)
→ GPU간 메모리 효율적 분산 학습 (각 GPU가 데이터 subset만 처리)
→ torchrun + DDP로 코드 변경 필요

2️⃣ Gradient Accumulation
→ 작은 batch 여러 번 forward-backward 후 한 번 update
→ batch_size 늘린 효과, 메모리 footprint 유지 가능

3️⃣ 모델 경량화
→ emb_dim, heads, layers 줄이고 학습 반복 횟수 (repeat) 늘려 통계적 의미 확보


📌 결론

✅ 현재 코드와 세팅으로는 GPU를 최대로 사용 중이고, 추가 학습 효과를 보기는 어렵습니다.
DDP + gradient accumulation 으로 넘어가야 더 큰 데이터, 더 큰 모델 학습이 가능해요.


💡 필요하면
제가:

  • DDP 학습용 코드
  • torchrun 실행 스크립트
  • gradient accumulation 적용법

을 작성해줄 수 있어요. 요청만 주세요! 🚀

 

필요하면 추가 인자:

--batch_size 16 --epochs 50 --lr 0.001

코드가 잘 작동된다!

 


* torchmetrics 버전 에러 발생시

더보기

 

💡 에러 원인

ImportError: cannot import name '_compare_version' from 'torchmetrics.utilities.imports'

👉 이 에러는 설치된 pytorch_lightning 버전과 torchmetrics 버전이 호환되지 않기 때문에 발생했습니다.

🔹 pytorch_lightning 1.7.7은 torchmetrics 0.x 또는 초기 1.x 버전과 주로 호환되는데
현재 설치된 torchmetrics는 너무 최신 버전(1.5.x 이상)이라 내부 구조가 달라져 _compare_version을 찾지 못하고 있습니다.


🚀 해결 방법

 1️⃣ torchmetrics 버전을 pytorch_lightning 1.7.7에 맞게 낮추기

pip install torchmetrics==0.10.3

👉 torchmetrics 0.10.3은 pytorch_lightning 1.7.7과 잘 맞습니다.


2️⃣ 설치 확인

설치되면 버전 체크:

pip show torchmetrics

출력에 Version: 0.10.3 나오면 OK.


3️⃣ 다시 실행

python main.py --dataset /data/project/kim89/cardio.h5ad --task custom

💡 참고

ScRAT은 pytorch_lightning 1.x 계열 기반으로 작성되어 torchmetrics 최신 (1.5+)은 호환이 깨집니다.
👉 torchmetrics==0.10.3, pytorch_lightning==1.7.7 조합이 가장 안전합니다.

* KeyError: 'X_pca' 

더보기

💡 에러 원인

KeyError: 'X_pca'

👉 ScRAT 코드 (dataloader.py 187줄)에서

origin = data.obsm['X_pca']

를 시도하는데, 당신의 cardio.h5ad 파일에는 obsm['X_pca']가 존재하지 않기 때문에 발생한 에러입니다.


🔍 왜 이런 일이 생겼을까요?

  • obsm['X_pca']는 보통 PCA 차원 축소 결과를 저장하는 필드
  • ScRAT은 이걸 전제하고 데이터에서 PCA로 축소된 feature를 로드하려고 함
  • 하지만 현재 h5ad는 PCA가 사전에 계산되어 저장되어 있지 않음

🚀 해결 방법

✅ 1️⃣ PCA를 사전에 계산하여 h5ad에 추가

아래와 같은 스크립트로 PCA를 추가하고 새 파일로 저장하세요:

import scanpy as sc

adata = sc.read_h5ad('/data/project/kim89/cardio.h5ad')

# PCA 실행
sc.pp.pca(adata, n_comps=50)  # 필요에 따라 n_comps 조절

# 결과 저장
adata.write('/data/project/kim89/cardio_pca.h5ad')

그 다음 실행 시:

python main.py --dataset /data/project/kim89/cardio_pca.h5ad --task custom

✅ 2️⃣ 또는 ScRAT 코드에서 PCA 수행 부분을 추가 (비추천)

ScRAT 내부에서 PCA를 수행하도록 dataloader.py를 수정할 수도 있습니다:

if 'X_pca' not in data.obsm:
    import scanpy as sc
    sc.pp.pca(data, n_comps=50)

👉 하지만 코드 일관성과 재사용성을 위해 PCA 포함된 h5ad 파일로 준비하는 쪽을 추천드립니다.


🌟 정리

👉 cardio.h5ad에 PCA가 없어 발생한 에러입니다.
👉 PCA 수행 후 h5ad를 다시 저장하고 그 파일을 ScRAT에 입력하세요.

* 현재 설치된 PyTorch 버전이 RTX 4090 GPU를 지원하지 않아서 발생한 에러

더보기

💡 이번 에러의 핵심 원인

RuntimeError: CUDA error: no kernel image is available for execution on the device

👉 현재 설치된 PyTorch 버전이 RTX 4090 GPU를 지원하지 않아서 발생한 에러입니다.
경고에서도 이미 말해주고 있죠:

NVIDIA GeForce RTX 4090 with CUDA capability sm_89 is not compatible with the current PyTorch installation.
The current PyTorch install supports CUDA capabilities sm_37 sm_50 sm_60 sm_70.

즉, RTX 4090 (sm_89)에서 실행할 수 있는 CUDA 커널 코드가 PyTorch 빌드에 포함되어 있지 않다는 뜻입니다.


왜 발생했나?

  • 현재 설치된 PyTorch (1.10.0)는 RTX 4090이 출시되기 전 버전이라 RTX 4090 아키텍처 (Ada Lovelace, sm_89)를 지원하지 않습니다.
  • PyTorch 1.10은 sm_70 (Volta) 정도까지 지원합니다.

🚀 해결 방법

🔹 최신 PyTorch로 업그레이드

👉 RTX 4090을 지원하려면 PyTorch 1.13 이상, 권장 버전은 1.13.1, 2.x 이상입니다.

🔑 예시 (CUDA 11.8, RTX 4090 호환):

pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio==2.0.2 --extra-index-url https://download.pytorch.org/whl/cu118

# 디스크 공간 부족으로 다운로드/설치 실패시,
# 설치를 임시 공간(/data/project/kim89/pip_cache)에 지정
mkdir -p /data/project/kim89/pip_cache
mkdir -p /data/project/kim89/tmp

TMPDIR=/data/project/kim89/tmp \
pip install --cache-dir=/data/project/kim89/pip_cache \
torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio==2.0.2 \
--extra-index-url https://download.pytorch.org/whl/cu118

# 설치 끝나면 /data/project/kim89/pip_cache와 /data/project/kim89/tmp를 지워서 공간을 확보하세요:
# rm -rf /data/project/kim89/pip_cache /data/project/kim89/tmp

또는 최신:

pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

🔹 주의

  • pytorch_lightning==1.7.7이 최신 torch와 호환되는지 확인해야 합니다. 보통 1.13+ torch는 lightning 1.8 이상과 잘 맞습니다.
  • 따라서 pytorch_lightning도 1.8 이상으로 함께 업그레이드하는 것이 안전합니다.

예:

pip install pytorch_lightning==1.9.0


# 설치를 임시 공간(/data/project/kim89/pip_cache)에 지정 시
TMPDIR=/data/project/kim89/tmp \
pip install --cache-dir=/data/project/kim89/pip_cache \
pytorch_lightning==1.9.0

🔹 임시 CPU 실행

만약 GPU를 쓰지 않고 급하게 CPU로만 돌리려면:

CUDA_VISIBLE_DEVICES="" python main.py --dataset ... --task custom

하지만 RTX 4090이 있으면 최신 torch 설치하는 것이 훨씬 낫습니다.


🌟 정리

✅ 에러 원인: PyTorch 1.10.0이 RTX 4090 (sm_89)을 지원하지 않음
✅ 해결책: PyTorch 최신 버전 + CUDA 11.8 이상 버전 설치

 

* torch.cuda.OutOfMemoryError

더보기

💡 에러 핵심

torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 64.80 GiB
(GPU 0; 23.65 GiB total capacity; ...)

👉 현재 모델이 GPU 메모리 64.8 GB를 한 번에 할당하려다 실패했습니다.
하지만 RTX 4090의 실제 GPU 메모리는 약 24GB인데, 그보다 훨씬 큰 메모리를 요청했습니다.


왜 이런 일이 발생했을까?

👉 Transformer 구조에서 scaled_dot_product 연산이 메모리를 많이 소모했기 때문입니다:

attn_logits = torch.matmul(q, k.transpose(-2, -1))

여기서:

  • q, k의 시퀀스 길이가 너무 크면
  • q @ k^T 는 O(seq_len² * dim) 메모리를 필요로 합니다.

즉:
🔹 cell 수나 sequence length가 너무 큽니다
🔹 또는 batch size가 너무 큽니다


🚀 해결 방법

1️⃣ batch size 줄이기

현재 ScRAT 실행 인자에 --batch_size 옵션을 더 작게 주십시오:

python main.py --dataset /data/project/kim89/cardio_pca.h5ad --task custom --batch_size 4

또는 2까지 줄여 보세요.


2️⃣ train_sample_cells, test_sample_cells 줄이기

현재 args 값:

train_sample_cells=500
test_sample_cells=500

➡ 너무 많으면:

--train_sample_cells 200 --test_sample_cells 200

또는 더 줄여 100, 50 시도


3️⃣ heads, emb_dim 줄이기

지금:

heads=8
emb_dim=128

➡ 줄이기:

--heads 4 --emb_dim 64

또는 더 작게


4️⃣ model parallel / gradient checkpointing

ScRAT 코드에 적용은 복잡하지만, PyTorch gradient checkpointing을 쓰면 메모리를 많이 줄일 수 있습니다. 필요하면 방법 알려드릴 수 있어요.


🌟 추천 실행 예시

python main.py --dataset /data/project/kim89/cardio_pca.h5ad --task custom \
--batch_size 2 --train_sample_cells 100 --test_sample_cells 100 --heads 4 --emb_dim 64

Tip

이런 메모리 초과는 Transformer에서 sequence length² 메모리 때문에 생깁니다.

따라서 cell 수, head 수, embedding dim 줄이는 것이 최우선 대응입니다.


 


최종 코드

data_check.py
0.01MB
dataloader.py
0.01MB
main.py
0.02MB

# 실행 방법은 아래와 같습니다.
# 1. Setup
git clone https://github.com/yuzhenmao/ScRAT
cd ScRAT
python -m venv scrat
source scrat/bin/activate
pip install -r requirements.txt

# 2. PCA 추가
# 데이터에 PCA를 추가하는 Python 코드를 실행합니다.
# ※ 데이터 경로는 직접 수정해주셔야 합니다.
pip install scanpy
python data_check.py

# 3.실행 코드 교체
# dataloader.py와 main.py는 보내드린 파일로 교체해주세요.

# 4-1. cardio 데이터 실행
# ※ --dataset 경로는 수정해주셔야 합니다.
python main.py --dataset /data/project/kim89/cardio_pca.h5ad --task custom_cardio

# 4-2. covid 데이터 실행
# ※ --dataset 경로는 수정해주셔야 합니다.
python main.py --dataset /data/project/kim89/covid_pca.h5ad --task custom_covid



### 에러 발생 시 참고
# 'torchmetrics.utilities.imports' 에러 발생 시
pip install torchmetrics==0.10.3

# 'RuntimeError: CUDA error: no kernel image is available for execution on the device' 에러 발생 시
# 현재 설치된 PyTorch 버전이 RTX 4090 GPU를 지원하지 않아서 발생한 에러
pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio==2.0.2 --extra-index-url https://download.pytorch.org/whl/cu118
# pytorch_lightning==1.7.7이 최신 torch와 호환되는지 확인해야 합니다. 보통 1.13+ torch는 lightning 1.8 이상과 잘 맞습니다.
pip install pytorch_lightning==1.9.0