main.py에
for loop 전체 변경하자.
# 사전 생성된 split 파일을 기반으로 고정된 train/test 데이터셋으로 실험을 수행하기 위해 (2번코드) # => 주석처리
# 3. for loop 직접 구성 (repeat × fold)
for repeat in range(args.repeat):
fold_aucs, accuracy, cms, recalls, precisions = [], [], [], [], []
iter_count = 0
for fold in range(args.n_splits):
print(f"🔁 Repeat {repeat}, Fold {fold}")
train_path = f"/data/project/kim89/0804_data/repeat_{repeat}/fold_{fold}_train.h5ad"
test_path = f"/data/project/kim89/0804_data/repeat_{repeat}/fold_{fold}_test.h5ad"
train_data = scanpy.read_h5ad(train_path)
test_data = scanpy.read_h5ad(test_path)
train_p_index, train_labels, train_cell_type, patient_id, train_origin = Custom_data_from_loaded(train_data, args)
test_p_index, test_labels, test_cell_type, test_patient_id, test_origin = Custom_data_from_loaded(test_data, args)
labels_ = train_labels
print(f"🔍 Split #{iter_count + 1}")
print(f" → train_p_index 환자 수: {len(train_p_index)}")
print(f" → test_p_index 환자 수: {len(test_p_index)}")
# 실제 환자 ID로 보기
train_ids = [patient_id[idx[0]] for idx in train_p_index]
test_ids = [patient_id[idx[0]] for idx in test_p_index]
print(f" → train 환자 ID: {train_ids}")
print(f" → test 환자 ID: {test_ids}")
# 각 환자의 ID와 label 함께 출력
print(" → train 환자 ID 및 라벨:")
for idxs in train_p_index:
idx = idxs[0]
print(f" ID: {patient_id[idx]}, Label: {train_labels[idx]}")
print(" → test 환자 ID 및 라벨:")
for idxs in test_p_index:
idx = idxs[0]
print(f" ID: {patient_id[idx]}, Label: {test_labels[idx]}")
# if args.n_splits < 0:
# temp_idx = train_p_index
# train_p_index = test_p_index
# test_p_index = temp_idx
label_stat = [labels_[idx[0]] for idx in train_p_index] # train set에 포함된 환자들의 라벨 목록
unique, cts = np.unique(label_stat, return_counts=True)
# 훈련 데이터(train_p_index)에 클래스가 2개 이상 존재해야 학습을 진행한다.
if len(unique) < 2 or (1 in cts):
# 클래스가 하나밖에 없음 → 불균형 → 스킵
# or
# 등장한 클래스 중 한 클래스의 환자 수가 1명밖에 안 됨 → 학습이 불안정해질 가능성이 매우 높기 때문에 skip
continue
# print(dict(zip(unique, cts)))
# 원래 코드에는 test set의 클래스 불균형은 체크하지 않음
# ### ✅ test_p_index 클래스 확인 추가
# test_label_stat = [labels_[idx[0]] for idx in test_p_index]
# if len(set(test_label_stat)) < 2:
# print(f"⚠️ Skipping split: test set has only one class -> {set(test_label_stat)}")
# continue
# train_data에서 환자 단위로, train과 validation 나누기
kk = 0
while True:
train_p_index_, valid_p_index, ty, vy = train_test_split(train_p_index, label_stat, test_size=0.33,
random_state=args.seed + kk)
if len(set(ty)) == 2 and len(set(vy)) == 2:
break
kk += 1
print("train_p_index_",len(train_p_index_))
print("valid_p_index",len(valid_p_index))
print("test_p_index",len(test_p_index))
train_p_index = train_p_index_
len_valid = len(valid_p_index)
# _index = np.concatenate([valid_p_index, test_p_index])
_index = valid_p_index + test_p_index # ✅ 리스트끼리 결합
# train_ids = []
# for i in train_p_index:
# train_ids.append(patient_id.iloc[p_idx[i][0]])
# print(train_ids)
x_train = []
x_test = []
x_valid = []
y_train = []
y_valid = []
y_test = []
id_train = []
id_test = []
id_valid = []
if args.augment_num > 0:
print("data augment 실행 함")
data_augmented, train_p_index_aug, labels_aug, cell_type_aug = mixups(
args, train_origin, train_p_index_, train_labels, train_cell_type
)
if data_augmented is None:
print("⚠️ Skipping due to insufficient classes for mixup")
continue
else:
print("data augment 실행 안 함")
data_augmented = train_origin
train_p_index_aug = train_p_index_
labels_aug = train_labels
cell_type_aug = train_cell_type
# individual_train, individual_test = sampling(args, train_p_index, test_p_index, train_labels, labels_augmented, cell_type_augmented)
# individual_train, individual_test = sampling(args, train_p_index, test_p_index, train_labels, train_labels, train_cell_type)
# 평가용 인덱스는 valid + test로 합쳐서 sampling (scRAT 구조상 하나로 묶어서 sampling)
eval_p_index = valid_p_index + test_p_index
print("eval_p_index len: ",len(eval_p_index))
individual_train, individual_eval = sampling(
args,
train_p_index_aug,
eval_p_index,
train_labels,
labels_aug,
cell_type_aug
)
print("individual_train", len(individual_train))
print("individual_eval", len(individual_eval))
for sample_list in individual_train:
for sample in sample_list:
id, label = sample
x_train.append(id)
y_train.append(label)
id_train.append(id)
n_valid = len(valid_p_index)
for i in range(len(eval_p_index)):
ids, labels = [x[0] for x in individual_eval[i]], [x[1] for x in individual_eval[i]]
if i < n_valid:
x_valid.append(ids)
y_valid.append(labels[0])
id_valid.append(ids)
else:
x_test.append(ids)
y_test.append(labels[0])
id_test.append(ids)
x_train, x_valid, x_test, y_train, y_valid, y_test = x_train, x_valid, x_test, np.array(y_train).reshape([-1, 1]), \
np.array(y_valid).reshape([-1, 1]), np.array(y_test).reshape([-1, 1])
print("train data의 x, y, id 길이:", len(x_train), len(y_train), len(id_train))
print("valid data의 x, y, id 길이:", len(x_valid), len(y_valid), len(id_valid))
print("test data의 x, y, id 길이:", len(x_test), len(y_test), len(id_test))
auc, acc, cm, recall, precision = train(
x_train, x_valid, x_test,
y_train, y_valid, y_test,
id_train, id_test,
data_augmented=train_origin, # numpy array
data=train_origin # or full data if needed
)
fold_aucs.append(auc)
accuracy.append(acc)
cms.append(cm)
recalls.append(recall)
precisions.append(precision)
iter_count += 1
if iter_count == abs(args.n_splits) * args.repeat:
break
print(f"✅ Total valid splits used: {iter_count}")
del data_augmented
# 🔽 Repeat 단위 AUC 출력 추가
print(f"\n📌 Repeat {repeat}: 평균 AUC = {np.mean(fold_aucs):.4f}, 표준편차 = {np.std(fold_aucs):.4f}")
main.py의
train() 함수 내의 Confusion Matrix 분기 설정
* true, pred에 하나의 클래스만 있을 경우, AUC는 의미가 없으며 계산도 불가능하여, AUC=0.0으로 예외 처리해줌.
# Confusion Matrix 및 지표 분기
if output_class == 1:
cm = confusion_matrix(true, pred).ravel()
if len(cm) == 4:
recall = cm[3] / (cm[3] + cm[2]) if (cm[3] + cm[2]) > 0 else 0
precision = cm[3] / (cm[3] + cm[1]) if (cm[3] + cm[1]) > 0 else 0
else:
print("⚠️ Skipping evaluation due to insufficient class diversity")
recall = precision = 0
print("Confusion Matrix: " + str(cm))
dataloader.py에
Custom_data()를 h5ad가 아니라 이미 load된 AnnData 객체를 받아 처리하는 버전으로 하나 추가하면 됩니다.
# dataloader.py에 있는 Custom_data()를 h5ad가 아니라
# 이미 load된 AnnData 객체를 받아 처리하는 버전으로 하나 추가하면 됩니다.
def Custom_data_from_loaded(data, args):
# 1. 라벨 매핑 정의
id_dict = {
'normal': 0,
'COVID-19': 1
# 필요한 경우 클래스 추가 가능
}
# 2. 환자 ID, 라벨, 셀 타입 정보 추출
patient_id = data.obs['patient'] if 'patient' in data.obs else data.obs['donor_id']
labels = data.obs['disease__ontology_label']
cell_type = data.obs['manual_annotation']
# 3. expression 데이터 선택
if args.pca:
origin = data.obsm['X_pca']
else:
origin = data.X.toarray() if not isinstance(data.X, np.ndarray) else data.X
# 4. 라벨을 숫자로 변환
labels_ = np.array(labels.map(id_dict))
# 5. 환자별 인덱스를 구성
indices = np.arange(origin.shape[0])
p_ids = sorted(set(patient_id))
p_idx = []
for i in p_ids:
idx = indices[patient_id == i]
if len(set(labels_[idx])) > 1:
for ii in sorted(set(labels_[idx])):
if ii > -1:
iidx = idx[labels_[idx] == ii]
if len(iidx) < max(args.train_sample_cells, args.test_sample_cells):
continue
p_idx.append(iidx)
else:
if labels_[idx[0]] > -1:
if len(idx) < max(args.train_sample_cells, args.test_sample_cells):
continue
p_idx.append(idx)
# 6. numpy 기반으로 반환
return p_idx, labels_, np.array(cell_type), np.array(patient_id), origin
실행 코드 (sample cells 500일 때.)
log_file="logs_covid_split/sample_cells_500.txt"
echo "▶ Running with: lr=1e-4, heads=4, dropout=0.0, weight_decay=1e-4, emb_dim=8"
echo "▶ Output: ${log_file}"
CUDA_VISIBLE_DEVICES=0 PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128 \
python main.py \
--task custom_covid \
--learning_rate 1e-4 \
--epochs 100 \
--heads 4 \
--dropout 0.0 \
--weight_decay 1e-4 \
--emb_dim 8 \
--pca False \
--repeat 5 \
--train_sample_cells 500 \
--test_sample_cells 500 > "$log_file" 2>&1
결과
<Case 1 ; 19명 대상>
--train_sample_cells 500
--test_sample_cells 500
* true, pred에 하나의 클래스만 있을 경우, AUC는 의미가 없으며 계산도 불가능하여, AUC=0.0으로 예외 처리한 경우 3번 있음
Repeat 0: 평균 AUC = 0.5500
Repeat 1: 평균 AUC = 0.4333
Repeat 2: 평균 AUC = 0.7167
Repeat 3: 평균 AUC = 0.4133
Repeat 4: 평균 AUC = 0.5500
=> mean ± std = 0.5327 ± 0.1123
<Case 2 ; 32명 대상>
--train_sample_cells 300
--test_sample_cells 300
Repeat 0: 평균 AUC = 0.8600
Repeat 1: 평균 AUC = 0.7900
Repeat 2: 평균 AUC = 0.7372
Repeat 3: 평균 AUC = 0.4750
Repeat 4: 평균 AUC = 0.5867
=> mean ± std = 0.6898 ± 0.1415
<Case 3 ; 39명 대상>
--train_sample_cells 100
--test_sample_cells 100
Repeat 0: 평균 AUC = 0.9100
Repeat 1: 평균 AUC = 0.7900
Repeat 2: 평균 AUC = 0.8067
Repeat 3: 평균 AUC = 0.7345
Repeat 4: 평균 AUC = 0.5783
=> mean ± std = 0.7639 ± 0.1175
코드 파일
'AI & Data Analysis > Coding & Programming' 카테고리의 다른 글
| [ScRAT] utils.py _ sampling() (2) | 2025.08.14 |
|---|---|
| Data Size Optimization (0) | 2025.07.07 |
| [ScRAT] customized dataset (0) | 2025.07.01 |
| GSE62452 (microarray) Analysis Results (0) | 2025.05.07 |
| GSE86982 Analysis Summary (0) | 2025.04.30 |