AI & Data Analysis/Code

[ScRAT] Code Demo

doraemin_dev 2025. 3. 20. 17:40

In the previous post, I tried running ScRAT.

2025.03.20 - [AI, Papers, Data Analysis] - [ScRAT] scRNA Analysis

 

[ScRAT] scRNA Analysis

Paper: Phenotype prediction from single-cell RNA-seq data using attention-based neural networks https://academic.oup.com/bioinformatics/article/40/2/btae067/7613064  Performing scRNA analysis with the ScRAT method described in the paper: https://github.com/y

doraemin.tistory.com

 

Let's take a look at what code actually gets run.


run.sh


 

python main.py  \
    --model Transformer \
    --epochs 100 \
    --norm_first False \
    --seed 100 \
    --task stage \
    --all 0 \
    --h_dim 128 \
    --heads 8 \
    --layers 1 \
    --dropout 0.3 \
    --min_size 10000 \
    --inter_only True \
    --batch_size 256 \
    --same_pheno -1 \
    --n_splits 2 \
    --augment_num 300 \
    --pca True \
    --warmup False \
    --learning_rate 0.01 \
    --alpha 0.5 \
    --mix_type 1 \
    --repeat 1 \
    --train_sample_cells 500 \
    --test_sample_cells 500 \
    --train_num_sample 20 \
    --test_num_sample 50
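Note that the boolean flags here (--norm_first, --inter_only, --pca, --warmup) are passed as strings and converted by main.py's _str2bool helper, so any value outside yes/y/true/t/1 (case-insensitive) is treated as False. A minimal sketch of that behaviour (the helper is copied from main.py; the prints are only for illustration):

def _str2bool(v):
    # Same helper as in main.py: only these strings map to True.
    return v.lower() in ("yes", "y", "true", "t", "1")

print(_str2bool("True"))    # True
print(_str2bool("False"))   # False
print(_str2bool("0"))       # False (any unrecognized string also maps to False)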

 

 


main.py


 

from sklearn import metrics
from sklearn.metrics import accuracy_score
import scipy.stats as st
from torch.optim import Adam
import transformers  # provides get_cosine_schedule_with_warmup, used when --warmup is True
from utils import *
import argparse
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import RepeatedKFold
from sklearn.model_selection import train_test_split

from model_baseline import *
from Transformer import TransformerPredictor

from dataloader import *


def _str2bool(v):
    return v.lower() in ("yes", "y", "true", "t", "1")


def int_or_float(x):
    try:
        return int(x)
    except ValueError:
        return float(x)


parser = argparse.ArgumentParser(description='scRNA diagnosis')

parser.add_argument('--seed', type=int, default=240)
parser.add_argument('--batch_size', type=int, default=16)
parser.add_argument('--learning_rate', type=float, default=3e-3)
parser.add_argument('--weight_decay', type=float, default=1e-4)
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument("--task", type=str, default="severity")
parser.add_argument('--emb_dim', type=int, default=128)  # embedding dim
parser.add_argument('--h_dim', type=int, default=128)  # hidden dim of the model
parser.add_argument('--dropout', type=float, default=0.3)  # dropout
parser.add_argument('--layers', type=int, default=1)
parser.add_argument('--heads', type=int, default=8)
parser.add_argument("--train_sample_cells", type=int, default=500,
                    help="number of cells in one sample in train dataset")
parser.add_argument("--test_sample_cells", type=int, default=500,
                    help="number of cells in one sample in test dataset")
parser.add_argument("--train_num_sample", type=int, default=20,
                    help="number of sampled data points in train dataset")
parser.add_argument("--test_num_sample", type=int, default=100,
                    help="number of sampled data points in test dataset")
parser.add_argument('--model', type=str, default='Transformer')
parser.add_argument('--dataset', type=str, default=None)
parser.add_argument('--inter_only', type=_str2bool, default=False)
parser.add_argument('--same_pheno', type=int, default=0)
parser.add_argument('--augment_num', type=int, default=0)
parser.add_argument('--alpha', type=float, default=1.0)
parser.add_argument('--repeat', type=int, default=3)
parser.add_argument('--all', type=int, default=1)
parser.add_argument('--min_size', type=int, default=6000)
parser.add_argument('--n_splits', type=int, default=5)
parser.add_argument('--pca', type=_str2bool, default=True)
parser.add_argument('--mix_type', type=int, default=1)
parser.add_argument('--norm_first', type=_str2bool, default=False)
parser.add_argument('--warmup', type=_str2bool, default=False)
parser.add_argument('--top_k', type=int, default=1)

args = parser.parse_args()

# print("# of GPUs is", torch.cuda.device_count())
print(args)

torch.manual_seed(args.seed)
np.random.seed(args.seed)

patient_summary = {}
stats = {}
stats_id = {}

if args.task == 'haniffa' or args.task == 'combat':
    label_dict = {0: 'Non Covid', 1: 'Covid'}
elif args.task == 'severity':
    label_dict = {0: 'mild', 1: 'severe'}
elif args.task == 'stage':
    label_dict = {0: 'convalescence', 1: 'progression'}


def train(x_train, x_valid, x_test, y_train, y_valid, y_test, id_train, id_test, data_augmented, data):
    dataset_1 = MyDataset(x_train, x_valid, x_test, y_train, y_valid, y_test, id_train, id_test, fold='train')
    dataset_2 = MyDataset(x_train, x_valid, x_test, y_train, y_valid, y_test, id_train, id_test, fold='test')
    dataset_3 = MyDataset(x_train, x_valid, x_test, y_train, y_valid, y_test, id_train, id_test, fold='val')
    train_loader = torch.utils.data.DataLoader(dataset_1, batch_size=args.batch_size, shuffle=True,
                                               collate_fn=dataset_1.collate)
    test_loader = torch.utils.data.DataLoader(dataset_2, batch_size=1, shuffle=False, collate_fn=dataset_2.collate)
    valid_loader = torch.utils.data.DataLoader(dataset_3, batch_size=1, shuffle=False, collate_fn=dataset_3.collate)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    input_dim = data_augmented[0].shape[-1]
    output_class = 1
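    # Single output unit: the phenotype label is binary and trained with BCELoss on a sigmoid output.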

    if args.model == 'Transformer':
        model = TransformerPredictor(input_dim=input_dim, model_dim=args.emb_dim, num_classes=output_class,
                                     num_heads=args.heads, num_layers=args.layers, dropout=args.dropout,
                                     input_dropout=0, pca=args.pca, norm_first=args.norm_first)
    elif args.model == 'feedforward':
        model = FeedForward(input_dim=input_dim, h_dim=args.emb_dim, cl=output_class, dropout=args.dropout)
    elif args.model == 'linear':
        model = Linear_Classfier(input_dim=input_dim, cl=output_class)
    elif args.model == 'scfeed':
        model = scFeedForward(input_dim=input_dim, cl=output_class, model_dim=args.emb_dim, dropout=args.dropout, pca=args.pca)

    model = nn.DataParallel(model)
    model.to(device)
    best_model = model

    print(device)

    ################################################################
    # training and evaluation
    ################################################################
    optimizer = Adam(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
    if args.warmup:
        scheduler = transformers.get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=args.epochs // 10,
                                                                 num_training_steps=args.epochs)
    else:
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5, last_epoch=-1)
    sigmoid = torch.nn.Sigmoid().to(device)

    max_acc, max_epoch, max_auc, max_loss, max_valid_acc, max_valid_auc = 0, 0, 0, 0, 0, 0
    test_accs, valid_aucs, train_losses, valid_losses, train_accs, test_aucs = [], [0.], [], [], [], []
    best_valid_loss = float("inf")
    wrongs = []
    trigger_times = 0
    patience = 2
    for ep in (range(1, args.epochs + 1)):
        model.train()
        train_loss = []


        for batch in (train_loader):
            x_ = torch.from_numpy(data_augmented[batch[0]]).float().to(device)
            y_ = batch[1].to(device)
            mask_ = batch[3].to(device)

            optimizer.zero_grad()

            out = model(x_, mask_)

            loss = nn.BCELoss()(sigmoid(out), y_)
            loss.backward()

            optimizer.step()
            train_loss.append(loss.item())

        scheduler.step()

        train_loss = sum(train_loss) / len(train_loss)
        train_losses.append(train_loss)

        if ep % 1 == 0:
            valid_loss = []
            model.eval()
            pred = []
            true = []
            with torch.no_grad():
                for batch in (valid_loader):
                    x_ = torch.from_numpy(data[batch[0]]).float().to(device).squeeze(0)
                    y_ = batch[1].int().to(device)

                    out = model(x_)
                    out = sigmoid(out)

                    loss = nn.BCELoss()(out, y_ * torch.ones(out.shape).to(device))
                    valid_loss.append(loss.item())

                    out = out.detach().cpu().numpy()

                    # majority voting
                    f = lambda x: 1 if x > 0.5 else 0
                    func = np.vectorize(f)
                    out = np.argmax(np.bincount(func(out).reshape(-1))).reshape(-1)
                    pred.append(out)
                    y_ = y_.detach().cpu().numpy()
                    true.append(y_)
            # pred = np.concatenate(pred)
            # true = np.concatenate(true)

            valid_loss = sum(valid_loss) / len(valid_loss)
            valid_losses.append(valid_loss)

            if (valid_loss < best_valid_loss):
                best_model = copy.deepcopy(model)
                max_epoch = ep
                best_valid_loss = valid_loss
                max_loss = train_loss

            print("Epoch %d, Train Loss %f, Valid_loss %f" % (ep, train_loss, valid_loss))

            # Early stop
            if (ep > args.epochs - 50) and ep > 1 and (valid_loss > valid_losses[-2]):
                trigger_times += 1
                if trigger_times >= patience:
                    break
            else:
                trigger_times = 0

    best_model.eval()
    pred = []
    test_id = []
    true = []
    wrong = []
    prob = []
    with torch.no_grad():
        for batch in (test_loader):
            x_ = torch.from_numpy(data[batch[0]]).float().to(device).squeeze(0)
            y_ = batch[1].int().numpy()
            id_ = batch[2][0]

            out = best_model(x_)
            out = sigmoid(out)
            out = out.detach().cpu().numpy().reshape(-1)

            # For attention analysis:

            # if args.model == 'Transformer':
            #     attens = best_model.module.get_attention_maps(x_)[-1]
            #     for iter in range(len(attens)):
            #         topK = np.bincount(attens[iter].argsort(-1)[:, :, -args.top_k:].
            #                            cpu().detach().numpy().reshape(-1)).argsort()[-20:][::-1]   # 20 is a 
            #         for idd in id_[iter][topK]:
            #             stats[cell_type_large[idd]] = stats.get(cell_type_large[idd], 0) + 1
            #             stats_id[idd] = stats_id.get(idd, 0) + 1

            y_ = y_[0][0]
            true.append(y_)
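            # A test patient may be represented by several cell sub-samples (--test_num_sample);
            # the Transformer's sigmoid outputs are averaged for AUC, and a majority vote over
            # thresholded outputs gives the hard prediction below.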

            if args.model != 'Transformer':
                prob.append(out[0])
            else:
                prob.append(out.mean())

            # majority voting
            f = lambda x: 1 if x > 0.5 else 0
            func = np.vectorize(f)
            out = np.argmax(np.bincount(func(out).reshape(-1))).reshape(-1)[0]
            pred.append(out)
            test_id.append(patient_id[batch[2][0][0][0]])
            if out != y_:
                wrong.append(patient_id[batch[2][0][0][0]])

    if len(wrongs) == 0:
        wrongs = set(wrong)
    else:
        wrongs = wrongs.intersection(set(wrong))

    test_auc = metrics.roc_auc_score(true, prob)

    test_acc = accuracy_score(true, pred)
    for idx in range(len(pred)):
        print(f"{test_id[idx]} -- true: {label_dict[true[idx]]} -- pred: {label_dict[pred[idx]]}")
    test_accs.append(test_acc)

    cm = confusion_matrix(true, pred).ravel()
    recall = cm[3] / (cm[3] + cm[2])
    if (cm[3] + cm[1]) == 0:
        precision = 0
    else:
        precision = cm[3] / (cm[3] + cm[1])

    print("Best performance: Epoch %d, Loss %f, Test ACC %f, Test AUC %f, Test Recall %f, Test Precision %f" % (
    max_epoch, max_loss, test_acc, test_auc, recall, precision))
    print("Confusion Matrix: " + str(cm))
    for w in wrongs:
        v = patient_summary.get(w, 0)
        patient_summary[w] = v + 1

    return test_auc, test_acc, cm, recall, precision


if args.model != 'Transformer':
    args.repeat = 60

if args.task != 'custom':
    p_idx, labels_, cell_type, patient_id, data, cell_type_large = Covid_data(args)
else:
    p_idx, labels_, cell_type, patient_id, data, cell_type_large = Custom_data(args)
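# Covid_data / Custom_data (dataloader.py) return per-patient cell-index arrays (p_idx), per-cell
# phenotype labels, cell-type annotations, patient IDs, the cell-by-feature expression matrix, and
# a coarser cell-type annotation (cell_type_large) used in the commented-out attention analysis.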
rkf = RepeatedKFold(n_splits=abs(args.n_splits), n_repeats=args.repeat * 100, random_state=args.seed)
num = np.arange(len(p_idx))
accuracy, aucs, cms, recalls, precisions = [], [], [], [], []
iter_count = 0

for train_index, test_index in rkf.split(num):
    if args.n_splits < 0:
        temp_idx = train_index
        train_index = test_index
        test_index = temp_idx

    label_stat = []
    for idx in train_index:
        label_stat.append(labels_[p_idx[idx][0]])
    unique, cts = np.unique(label_stat, return_counts=True)
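    # Skip folds whose training patients do not cover both classes (or where one class is a singleton).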
    if len(unique) < 2 or (1 in cts):
        continue
#     print(dict(zip(unique, cts)))

    kk = 0
    while True:
        train_index_, valid_index, ty, vy = train_test_split(train_index, label_stat, test_size=0.33,
                                                             random_state=args.seed + kk)
        if len(set(ty)) == 2 and len(set(vy)) == 2:
            break
        kk += 1

    train_index = train_index_
    len_valid = len(valid_index)
    _index = np.concatenate([valid_index, test_index])

    train_ids = []
    for i in train_index:
        train_ids.append(patient_id[p_idx[i][0]])
#     print(train_ids)

    x_train = []
    x_test = []
    x_valid = []
    y_train = []
    y_valid = []
    y_test = []
    id_train = []
    id_test = []
    id_valid = []
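    # Augment the training patients with sample mixup, then draw fixed-size cell sub-samples
    # for the training, validation and test patients.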
    data_augmented, train_p_idx, labels_augmented, cell_type_augmented = mixups(args, data,
                                                                                [p_idx[idx] for idx in train_index],
                                                                                labels_,
                                                                                cell_type)
    individual_train, individual_test = sampling(args, train_p_idx, [p_idx[idx] for idx in _index], labels_,
                                                 labels_augmented, cell_type_augmented)
    for t in individual_train:
        id, label = [id_l[0] for id_l in t], [id_l[1] for id_l in t]
        x_train += [ii for ii in id]
        y_train += (label)
        id_train += (id)

    temp_idx = np.arange(len(_index))
    for t_idx in temp_idx[len_valid:]:
        id, label = [id_l[0] for id_l in individual_test[t_idx]], [id_l[1] for id_l in individual_test[t_idx]]
        x_test.append([ii for ii in id])
        y_test.append(label[0])
        id_test.append(id)
    for t_idx in temp_idx[:len_valid]:
        id, label = [id_l[0] for id_l in individual_test[t_idx]], [id_l[1] for id_l in individual_test[t_idx]]
        x_valid.append([ii for ii in id])
        y_valid.append(label[0])
        id_valid.append(id)
    y_train = np.array(y_train).reshape([-1, 1])
    y_valid = np.array(y_valid).reshape([-1, 1])
    y_test = np.array(y_test).reshape([-1, 1])
    auc, acc, cm, recall, precision = train(x_train, x_valid, x_test, y_train, y_valid, y_test, id_train, id_test,
                                            data_augmented, data)
    aucs.append(auc)
    accuracy.append(acc)
    cms.append(cm)
    recalls.append(recall)
    precisions.append(precision)
    iter_count += 1
    if iter_count == abs(args.n_splits) * args.repeat:
        break

    del data_augmented

print("="*33)
print("=== Final Evaluation (average across all splits) ===")
print("="*33)

print("Best performance: Test ACC %f,   Test AUC %f,   Test Recall %f,   Test Precision %f" % (np.average(accuracy), np.average(aucs), np.average(recalls), np.average(precisions)))

####################################
######## Only for repeat > 1 #######
####################################
# accuracy = np.array(accuracy).reshape([-1, args.repeat]).mean(0)
# aucs = np.array(aucs).reshape([-1, args.repeat]).mean(0)
# recalls = np.array(recalls).reshape([-1, args.repeat]).mean(0)
# precisions = np.array(precisions).reshape([-1, args.repeat]).mean(0)
# ci_1 = st.t.interval(alpha=0.95, df=len(accuracy) - 1, loc=np.mean(accuracy), scale=st.sem(accuracy))[1] - np.mean(accuracy)
# ci_2 = st.t.interval(alpha=0.95, df=len(aucs) - 1, loc=np.mean(aucs), scale=st.sem(aucs))[1] - np.mean(aucs)
# ci_3 = st.t.interval(alpha=0.95, df=len(recalls) - 1, loc=np.mean(recalls), scale=st.sem(recalls))[1] - np.mean(recalls)
# ci_4 = st.t.interval(alpha=0.95, df=len(precisions) - 1, loc=np.mean(precisions), scale=st.sem(precisions))[1] - np.mean(precisions)
# print("ci: ACC ci %f,   AUC ci %f,   Recall ci %f,   Precision ci %f" % (ci_1, ci_2, ci_3, ci_4))

# print(np.average(cms, 0))
# print(patient_summary)
# print(stats)
# print(stats_id)

utils.py


 

 

2025.03.20 - [AI & Data Analysis/Papers] - [ScRAT] utils.py _ mixup()

 

[ScRAT] utils.py _ mixup()

🔍 Detailed analysis of mixup() 📌 1. Overview of mixup(): the mixup() function is the core routine that linearly combines two samples (x, x_p) according to a given weight (lam) to generate new data. In other words, the mixups() function performs mixup ...

doraemin.tistory.com
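As the linked post explains, mixup forms a convex combination of two samples with a weight lam. Before going through the file, here is a minimal numeric sketch of that idea (a standalone illustration with made-up values, not part of utils.py):

import numpy as np

x   = np.array([[1.0, 2.0, 3.0]])   # cells from patient 1
x_p = np.array([[5.0, 6.0, 7.0]])   # cells from patient 2
lam = 0.7                           # in ScRAT this weight is drawn from Beta(alpha, alpha)

x_mix = lam * x + (1 - lam) * x_p   # -> [[2.2, 3.2, 4.2]]
print(x_mix)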

 

 

import torch
import numpy as np
import copy
from torch.utils.data import Dataset
from tqdm import tqdm


class MyDataset(Dataset):
    def __init__(self, x_train, x_valid, x_test, y_train, y_valid, y_test, id_train, id_test, fold='train'):
        fold = fold.lower()

        self.train = False
        self.test = False
        self.val = False

        if fold == "train":
            self.train = True
        elif fold == "test":
            self.test = True
        elif fold == "val":
            self.val = True
        else:
            raise RuntimeError("Not train-val-test")

        self.x_train = x_train
        self.x_valid = x_valid
        self.x_test = x_test
        self.y_train = y_train
        self.y_valid = y_valid
        self.y_test = y_test
        self.id_train = id_train
        self.id_test = id_test

    def __len__(self):
        if self.train:
            return len(self.x_train)
        elif self.test:
            return len(self.x_test)
        elif self.val:
            return len(self.x_valid)

    def __getitem__(self, index):
        if self.train:
            x, y, cell_id = torch.from_numpy(np.array(self.x_train[index])), \
                            torch.from_numpy(self.y_train[index]).float(), self.id_train[index]
        elif self.test:
            x, y, cell_id = torch.from_numpy(np.array(self.x_test[index])), \
                            torch.from_numpy(self.y_test[index]).float(), self.id_test[index]
        elif self.val:
            x, y, cell_id = torch.from_numpy(np.array(self.x_valid[index])), \
                            torch.from_numpy(self.y_valid[index]).float(), []

        return x, y, cell_id

    def collate(self, batches):
        xs = torch.stack([batch[0] for batch in batches if len(batch) > 0])
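        # Padded positions carry cell index -1; the boolean mask marks them so the model can ignore them.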
        mask = torch.stack([batch[0] == -1 for batch in batches if len(batch) > 0])
        ys = torch.stack([batch[1] for batch in batches if len(batch) > 0])
        ids = [batch[2] for batch in batches if len(batch) > 0]
        return xs, torch.FloatTensor(ys), ids, mask


def add_noise(cells):
    mean = 0
    var = 1e-5
    sigma = var ** 0.5
    gauss = np.random.normal(mean, sigma, cells.shape)
    noisy = cells + gauss
    return noisy


def mixup(x, x_p, alpha=1.0, size=1, lam=None):
    batch_size = min(x.shape[0], x_p.shape[0])
    if lam is None:
        lam = np.random.beta(alpha, alpha)
        if size > 1:
            lam = np.random.beta(alpha, alpha, size=size).reshape([-1, 1])
    # x = np.random.permutation(x)
    # x_p = np.random.permutation(x_p)
    x_mix = lam * x[:batch_size] + (1 - lam) * x_p[:batch_size]
    return x_mix, lam


def mixups(args, data, p_idx, labels_, cell_type):
    max_num_cells = data.shape[0]
    ###################
    # check the dataset
    for i, pp in enumerate(p_idx):
        if len(set(labels_[pp])) > 1:
            print(i)
    ###################
    all_ct = {}
    for i, ct in enumerate(sorted(set(cell_type))):
        all_ct[ct] = i
    cell_type_ = np.array(cell_type.map(all_ct))
    ###################
    for idx, i in enumerate(p_idx):
        max_num_cells += (max(args.min_size - len(i), 0) + 100)
    data_augmented = np.zeros([max_num_cells + (args.min_size + 100) * args.augment_num, data.shape[1]])
    data_augmented[:data.shape[0]] = data
    last = data.shape[0]
    labels_augmented = copy.deepcopy(labels_)
    cell_type_augmented = cell_type_

    if args.same_pheno != 0:
        p_idx_per_pheno = {}
        for pp in p_idx:
            y = labels_augmented[pp[0]]
            if p_idx_per_pheno.get(y, -2) == -2:
                p_idx_per_pheno[y] = [pp]
            else:
                p_idx_per_pheno[y].append(pp)

    if args.inter_only and (args.augment_num > 0):
        p_idx_augmented = []
    else:
        p_idx_augmented = copy.deepcopy(p_idx)
    
    if args.augment_num > 0:
        print("======= sample mixup ... ============")
        for i in tqdm(range(args.augment_num)):
            lam = np.random.beta(args.alpha, args.alpha)
            if args.same_pheno == 1:
                temp_label = np.random.randint(len(p_idx_per_pheno))
                id_1, id_2 = np.random.randint(len(p_idx_per_pheno[temp_label]), size=2)
                idx_1, idx_2 = p_idx_per_pheno[temp_label][id_1], p_idx_per_pheno[temp_label][id_2]
            elif args.same_pheno == -1:
                i_1, i_2 = np.random.choice(len(p_idx_per_pheno), 2, replace=False)
                id_1 = np.random.randint(len(p_idx_per_pheno[i_1]))
                id_2 = np.random.randint(len(p_idx_per_pheno[i_2]))
                idx_1, idx_2 = p_idx_per_pheno[i_1][id_1], p_idx_per_pheno[i_2][id_2]
            else:
                id_1, id_2 = np.random.randint(len(p_idx), size=2)
                idx_1, idx_2 = p_idx[id_1], p_idx[id_2]
            diff = 0
            set_union = sorted(set(cell_type_augmented[idx_1]).union(set(cell_type_augmented[idx_2])))
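            # Mix the two patients cell-type by cell-type: each type contributes cells in proportion to the
            # lam-weighted average of the two patients' type fractions (about min_size cells in total);
            # a type missing from one patient is replaced by index -1 (an all-zero row) plus Gaussian noise.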
            while diff < (args.min_size // 2):
                for ct in set_union:
                    i_sub_1 = idx_1[cell_type_augmented[idx_1] == ct]
                    i_sub_2 = idx_2[cell_type_augmented[idx_2] == ct]
                    diff_sub = max(
                        int(args.min_size * (lam * len(i_sub_1) / len(idx_1) + (1 - lam) * len(i_sub_2) / len(idx_2))),
                        1)
                    diff += diff_sub
                    if len(i_sub_1) == 0:
                        sampled_idx_1 = [-1] * diff_sub
                        sampled_idx_2 = np.random.choice(i_sub_2, diff_sub)
                        x_mix, _ = mixup(data_augmented[sampled_idx_1], data_augmented[sampled_idx_2], alpha=args.alpha,
                                         lam=lam)
                        x_mix = add_noise(x_mix)
                    elif len(i_sub_2) == 0:
                        sampled_idx_1 = np.random.choice(i_sub_1, diff_sub)
                        sampled_idx_2 = [-1] * diff_sub
                        x_mix, _ = mixup(data_augmented[sampled_idx_1], data_augmented[sampled_idx_2], alpha=args.alpha,
                                         lam=lam)
                        x_mix = add_noise(x_mix)
                    else:
                        sampled_idx_1 = np.random.choice(i_sub_1, diff_sub)
                        sampled_idx_2 = np.random.choice(i_sub_2, diff_sub)
                        x_mix, _ = mixup(data_augmented[sampled_idx_1], data_augmented[sampled_idx_2], alpha=args.alpha,
                                         lam=lam)
                    data_augmented[last:(last + x_mix.shape[0])] = x_mix
                    last += x_mix.shape[0]
                    cell_type_augmented = np.concatenate([cell_type_augmented, [ct] * diff_sub])
            labels_augmented = np.concatenate(
                [labels_augmented, [lam * labels_augmented[idx_1[0]] + (1 - lam) * labels_augmented[idx_2[0]]] * diff])
            p_idx_augmented.append(np.arange(labels_augmented.shape[0] - diff, labels_augmented.shape[0]))

    # Keep one extra all-zero row at index `last`: padded cell index -1 then resolves to a zero vector
    # when data_augmented is indexed downstream.
    return data_augmented[:last+1], p_idx_augmented, labels_augmented, cell_type_augmented


def sampling(args, train_p_idx, test_p_idx, labels_, labels_augmented, cell_type_augmented):
    if args.all == 0:
        individual_train = []
        individual_test = []
        for idx in train_p_idx:
            y = labels_augmented[idx[0]]
            temp = []
            if idx.shape[0] < args.train_sample_cells:
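                # Patient has fewer cells than train_sample_cells: keep them all and pad the rest of the index vector with -1.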
                for _ in range(args.train_num_sample):
                    sample = np.zeros(args.train_sample_cells, dtype=int) - 1
                    sample[:idx.shape[0]] = idx
                    temp.append((sample, y))
            else:
                for _ in range(args.train_num_sample):
                    sample = np.random.choice(idx, args.train_sample_cells, replace=False)
                    temp.append((sample, y))
            individual_train.append(temp)
        for idx in test_p_idx:
            y = labels_[idx[0]]
            if idx.shape[0] < args.test_sample_cells:
                sample_cells = idx.shape[0]
            else:
                sample_cells = args.test_sample_cells
            temp = []
            for _ in range(args.test_num_sample):
                sample = np.random.choice(idx, sample_cells, replace=False)
                temp.append((sample, y))
            individual_test.append(temp)
    else:
        max_length = max([len(tt) for tt in train_p_idx])
        individual_train = []
        individual_test = []
        for idx in train_p_idx:
            y = labels_augmented[idx[0]]
            temp = []
            sample = np.zeros(max_length, dtype=int) - 1
            sample[:idx.shape[0]] = idx
            temp.append((sample, y))
            individual_train.append(temp)
        for idx in test_p_idx:
            y = labels_[idx[0]]
            temp = []
            sample = idx
            temp.append((sample, y))
            individual_test.append(temp)

    return individual_train, individual_test


def stratify(out, split=2):
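    # Bin the per-sub-sample probabilities into `split` equal-width bins, find the most populated bin,
    # and return the mean probability within that bin.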
    f = lambda x: int(x * split)
    func = np.vectorize(f)
    majority = np.argmax(np.bincount(func(out))).reshape(-1)[0]
    return out[func(out) == majority].mean()