본문 바로가기
AI & Data Analysis/Coding & Programming

[ScRAT] customized dataset with CrossValidation

by doraemin_dev 2025. 8. 7.

main.py에

for loop 전체 변경하자.

# 사전 생성된 split 파일을 기반으로 고정된 train/test 데이터셋으로 실험을 수행하기 위해 (2번코드) # => 주석처리

# 3. for loop 직접 구성 (repeat × fold)

for repeat in range(args.repeat):
    fold_aucs, accuracy, cms, recalls, precisions = [], [], [], [], []
    iter_count = 0
    for fold in range(args.n_splits):
        print(f"🔁 Repeat {repeat}, Fold {fold}")
        train_path = f"/data/project/kim89/0804_data/repeat_{repeat}/fold_{fold}_train.h5ad"
        test_path = f"/data/project/kim89/0804_data/repeat_{repeat}/fold_{fold}_test.h5ad"

        train_data = scanpy.read_h5ad(train_path)
        test_data = scanpy.read_h5ad(test_path)

        train_p_index, train_labels, train_cell_type, patient_id, train_origin = Custom_data_from_loaded(train_data, args)
        test_p_index, test_labels, test_cell_type, test_patient_id, test_origin = Custom_data_from_loaded(test_data, args)

        labels_ = train_labels

        print(f"🔍 Split #{iter_count + 1}")
        print(f"  → train_p_index 환자 수: {len(train_p_index)}")
        print(f"  → test_p_index 환자 수: {len(test_p_index)}")

        # 실제 환자 ID로 보기
        train_ids = [patient_id[idx[0]] for idx in train_p_index]
        test_ids = [patient_id[idx[0]] for idx in test_p_index]
        print(f"  → train 환자 ID: {train_ids}")
        print(f"  → test  환자 ID: {test_ids}")

        # 각 환자의 ID와 label 함께 출력
        print("  → train 환자 ID 및 라벨:")
        for idxs in train_p_index:
            idx = idxs[0]
            print(f"    ID: {patient_id[idx]}, Label: {train_labels[idx]}")

        print("  → test 환자 ID 및 라벨:")
        for idxs in test_p_index:
            idx = idxs[0]
            print(f"    ID: {patient_id[idx]}, Label: {test_labels[idx]}")


        # if args.n_splits < 0:
        #     temp_idx = train_p_index
        #     train_p_index = test_p_index
        #     test_p_index = temp_idx


        label_stat = [labels_[idx[0]] for idx in train_p_index] #  train set에 포함된 환자들의 라벨 목록
        unique, cts = np.unique(label_stat, return_counts=True)

        # 훈련 데이터(train_p_index)에 클래스가 2개 이상 존재해야 학습을 진행한다.
        if len(unique) < 2 or (1 in cts): 
            # 클래스가 하나밖에 없음 → 불균형 → 스킵 
            # or 
            # 등장한 클래스 중 한 클래스의 환자 수가 1명밖에 안 됨 → 학습이 불안정해질 가능성이 매우 높기 때문에 skip
            continue
    #     print(dict(zip(unique, cts)))
        
        # 원래 코드에는 test set의 클래스 불균형은 체크하지 않음
        # ### ✅ test_p_index 클래스 확인 추가
        # test_label_stat = [labels_[idx[0]] for idx in test_p_index]
        # if len(set(test_label_stat)) < 2:
        #     print(f"⚠️  Skipping split: test set has only one class -> {set(test_label_stat)}")
        #     continue

        # train_data에서 환자 단위로, train과 validation 나누기
        kk = 0
        while True:
            train_p_index_, valid_p_index, ty, vy = train_test_split(train_p_index, label_stat, test_size=0.33,
                                                                random_state=args.seed + kk)
            if len(set(ty)) == 2 and len(set(vy)) == 2:
                break
            kk += 1

        print("train_p_index_",len(train_p_index_))
        print("valid_p_index",len(valid_p_index))
        print("test_p_index",len(test_p_index))


        train_p_index = train_p_index_
        len_valid = len(valid_p_index)
        # _index = np.concatenate([valid_p_index, test_p_index])
        _index = valid_p_index + test_p_index  # ✅ 리스트끼리 결합


        # train_ids = []
        # for i in train_p_index:
        #     train_ids.append(patient_id.iloc[p_idx[i][0]])

    #     print(train_ids)

        x_train = []
        x_test = []
        x_valid = []
        y_train = []
        y_valid = []
        y_test = []
        id_train = []
        id_test = []
        id_valid = []

        if args.augment_num > 0:
            print("data augment 실행 함")
            data_augmented, train_p_index_aug, labels_aug, cell_type_aug = mixups(
                args, train_origin, train_p_index_, train_labels, train_cell_type
            )
            if data_augmented is None:
                print("⚠️ Skipping due to insufficient classes for mixup")
                continue
        else:
            print("data augment 실행 안 함")
            data_augmented = train_origin
            train_p_index_aug = train_p_index_
            labels_aug = train_labels
            cell_type_aug = train_cell_type


        # individual_train, individual_test = sampling(args, train_p_index, test_p_index, train_labels, labels_augmented, cell_type_augmented)
        # individual_train, individual_test = sampling(args, train_p_index, test_p_index, train_labels, train_labels, train_cell_type)
        
        # 평가용 인덱스는 valid + test로 합쳐서 sampling (scRAT 구조상 하나로 묶어서 sampling)
        eval_p_index = valid_p_index + test_p_index
        print("eval_p_index len: ",len(eval_p_index))

        individual_train, individual_eval = sampling(
            args,
            train_p_index_aug,
            eval_p_index,
            train_labels,
            labels_aug,
            cell_type_aug
        )
        print("individual_train", len(individual_train))
        print("individual_eval", len(individual_eval))

        for sample_list in individual_train:
            for sample in sample_list:
                id, label = sample
                x_train.append(id)
                y_train.append(label)
                id_train.append(id)


        n_valid = len(valid_p_index)
        for i in range(len(eval_p_index)):
            ids, labels = [x[0] for x in individual_eval[i]], [x[1] for x in individual_eval[i]]
            if i < n_valid:
                x_valid.append(ids)
                y_valid.append(labels[0])
                id_valid.append(ids)
            else:
                x_test.append(ids)
                y_test.append(labels[0])
                id_test.append(ids)


        x_train, x_valid, x_test, y_train, y_valid, y_test = x_train, x_valid, x_test, np.array(y_train).reshape([-1, 1]), \
                                                            np.array(y_valid).reshape([-1, 1]), np.array(y_test).reshape([-1, 1])
        print("train data의 x, y, id 길이:", len(x_train), len(y_train), len(id_train))
        print("valid data의 x, y, id 길이:", len(x_valid), len(y_valid), len(id_valid))
        print("test data의 x, y, id 길이:", len(x_test), len(y_test), len(id_test))

        auc, acc, cm, recall, precision = train(
            x_train, x_valid, x_test,
            y_train, y_valid, y_test,
            id_train, id_test,
            data_augmented=train_origin,  # numpy array
            data=train_origin             # or full data if needed
        )

        fold_aucs.append(auc)
        accuracy.append(acc)
        cms.append(cm)
        recalls.append(recall)
        precisions.append(precision)
        iter_count += 1
        if iter_count == abs(args.n_splits) * args.repeat:
            break
        print(f"✅ Total valid splits used: {iter_count}")

        del data_augmented
        
    # 🔽 Repeat 단위 AUC 출력 추가
    print(f"\n📌 Repeat {repeat}: 평균 AUC = {np.mean(fold_aucs):.4f}, 표준편차 = {np.std(fold_aucs):.4f}")

main.py의

train() 함수 내의 Confusion Matrix 분기 설정

* true, pred에 하나의 클래스만 있을 경우, AUC는 의미가 없으며 계산도 불가능하여, AUC=0.0으로 예외 처리해줌.

    # Confusion Matrix 및 지표 분기
    if output_class == 1:
        cm = confusion_matrix(true, pred).ravel()

        if len(cm) == 4:
            recall = cm[3] / (cm[3] + cm[2]) if (cm[3] + cm[2]) > 0 else 0
            precision = cm[3] / (cm[3] + cm[1]) if (cm[3] + cm[1]) > 0 else 0
        else:
            print("⚠️ Skipping evaluation due to insufficient class diversity")
            recall = precision = 0
        print("Confusion Matrix: " + str(cm))

dataloader.py에

Custom_data()를 h5ad가 아니라 이미 load된 AnnData 객체를 받아 처리하는 버전으로 하나 추가하면 됩니다.

# dataloader.py에 있는 Custom_data()를 h5ad가 아니라 
# 이미 load된 AnnData 객체를 받아 처리하는 버전으로 하나 추가하면 됩니다.
def Custom_data_from_loaded(data, args):
    # 1. 라벨 매핑 정의
    id_dict = {
        'normal': 0,
        'COVID-19': 1
        # 필요한 경우 클래스 추가 가능
    }

    # 2. 환자 ID, 라벨, 셀 타입 정보 추출
    patient_id = data.obs['patient'] if 'patient' in data.obs else data.obs['donor_id']
    labels = data.obs['disease__ontology_label']
    cell_type = data.obs['manual_annotation']

    # 3. expression 데이터 선택
    if args.pca:
        origin = data.obsm['X_pca']
    else:
        origin = data.X.toarray() if not isinstance(data.X, np.ndarray) else data.X

    # 4. 라벨을 숫자로 변환
    labels_ = np.array(labels.map(id_dict))

    # 5. 환자별 인덱스를 구성
    indices = np.arange(origin.shape[0])
    p_ids = sorted(set(patient_id))
    p_idx = []

    for i in p_ids:
        idx = indices[patient_id == i]
        if len(set(labels_[idx])) > 1:
            for ii in sorted(set(labels_[idx])):
                if ii > -1:
                    iidx = idx[labels_[idx] == ii]
                    if len(iidx) < max(args.train_sample_cells, args.test_sample_cells):
                        continue
                    p_idx.append(iidx)
        else:
            if labels_[idx[0]] > -1:
                if len(idx) < max(args.train_sample_cells, args.test_sample_cells):
                    continue
                p_idx.append(idx)

    # 6. numpy 기반으로 반환
    return p_idx, labels_, np.array(cell_type), np.array(patient_id), origin

실행 코드 (sample cells 500일 때.) 

log_file="logs_covid_split/sample_cells_500.txt"
echo "▶ Running with: lr=1e-4, heads=4, dropout=0.0, weight_decay=1e-4, emb_dim=8"
echo "▶ Output: ${log_file}"
CUDA_VISIBLE_DEVICES=0 PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128 \
python main.py \
  --task custom_covid \
  --learning_rate 1e-4 \
  --epochs 100 \
  --heads 4 \
  --dropout 0.0 \
  --weight_decay 1e-4 \
  --emb_dim 8 \
  --pca False \
  --repeat 5 \
  --train_sample_cells 500 \
  --test_sample_cells 500 > "$log_file" 2>&1

결과

<Case 1 ; 19명 대상> 
--train_sample_cells 500
--test_sample_cells 500
* true, pred에 하나의 클래스만 있을 경우, AUC는 의미가 없으며 계산도 불가능하여, AUC=0.0으로 예외 처리한 경우 3번 있음

Repeat 0: 평균 AUC = 0.5500
Repeat 1: 평균 AUC = 0.4333
Repeat 2: 평균 AUC = 0.7167
Repeat 3: 평균 AUC = 0.4133
Repeat 4: 평균 AUC = 0.5500
=> mean ± std = 0.5327 ± 0.1123


<Case 2 ; 32명 대상>
--train_sample_cells 300
--test_sample_cells 300

Repeat 0: 평균 AUC = 0.8600
Repeat 1: 평균 AUC = 0.7900
Repeat 2: 평균 AUC = 0.7372
Repeat 3: 평균 AUC = 0.4750
Repeat 4: 평균 AUC = 0.5867
=> mean ± std = 0.6898 ± 0.1415


<Case 3 ; 39명 대상>
--train_sample_cells 100
--test_sample_cells 100

Repeat 0: 평균 AUC = 0.9100
Repeat 1: 평균 AUC = 0.7900
Repeat 2: 평균 AUC = 0.8067
Repeat 3: 평균 AUC = 0.7345
Repeat 4: 평균 AUC = 0.5783
=> mean ± std = 0.7639 ± 0.1175

코드 파일

main.py
0.03MB

 

 

 

 

utils.py
0.01MB
dataloader.py
0.01MB

'AI & Data Analysis > Coding & Programming' 카테고리의 다른 글

[ScRAT] utils.py _ sampling()  (2) 2025.08.14
Data Size Optimization  (0) 2025.07.07
[ScRAT] customized dataset  (0) 2025.07.01
GSE62452 (microarray) Analysis Results  (0) 2025.05.07
GSE86982 Analysis Summary  (0) 2025.04.30