[ScRAT] Code Demo
doraemin_dev
2025. 3. 20. 17:40
In the previous post, I actually ran ScRAT.
2025.03.20 - [AI, Papers, Data Analysis] - [ScRAT] scRNA Analysis
That post covers running scRNA-seq analysis with the ScRAT method from the paper: Phenotype prediction from single-cell RNA-seq data using attention-based neural networks (https://academic.oup.com/bioinformatics/article/40/2/btae067/7613064).
Now let's look at what code actually runs.
run.sh
python main.py \
--model Transformer \
--epochs 100 \
--norm_first False \
--seed 100 \
--task stage \
--all 0 \
--h_dim 128 \
--heads 8 \
--layers 1 \
--dropout 0.3 \
--min_size 10000 \
--inter_only True \
--batch_size 256 \
--same_pheno -1 \
--n_splits 2 \
--augment_num 300 \
--pca True \
--warmup False \
--learning_rate 0.01 \
--alpha 0.5 \
--mix_type 1 \
--repeat 1 \
--train_sample_cells=500 --test_sample_cells=500 --train_num_sample=20 --test_num_sample=50
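run.sh simply calls main.py with one full set of hyperparameters. The sampling flags on the last line control per-patient subsampling: every training patient is drawn --train_num_sample (20) times with --train_sample_cells (500) cells per draw, and every test patient --test_num_sample (50) times with --test_sample_cells (500) cells. A minimal sketch of that idea is below; subsample_patient is an illustrative helper, not part of the repo (the real logic lives in sampling() in utils.py).
import numpy as np

# Sketch of what --train_sample_cells / --train_num_sample mean:
# a patient's cell indices are subsampled num_sample times,
# sample_cells cells per draw, each draw paired with the patient label.
def subsample_patient(cell_indices, sample_cells=500, num_sample=20, label=0):
    draws = []
    for _ in range(num_sample):
        picked = np.random.choice(cell_indices, sample_cells, replace=False)
        draws.append((picked, label))
    return draws

# e.g. a hypothetical patient with 3000 cells labelled "progression" (1)
draws = subsample_patient(np.arange(3000), sample_cells=500, num_sample=20, label=1)
print(len(draws), draws[0][0].shape)  # 20 draws of 500 cell indices each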
main.py
from sklearn import metrics
from sklearn.metrics import accuracy_score
import scipy.stats as st
from torch.optim import Adam
from utils import *
import argparse
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import RepeatedKFold
from sklearn.model_selection import train_test_split
from model_baseline import *
from Transformer import TransformerPredictor
from dataloader import *
def _str2bool(v):
return v.lower() in ("yes", "y", "true", "t", "1")
def int_or_float(x):
try:
return int(x)
except ValueError:
return float(x)
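# Note: _str2bool() treats anything outside {"yes", "y", "true", "t", "1"} as False,
# which is how run.sh can pass string flags such as --pca True or --norm_first False.
# int_or_float() keeps whole numbers as int and falls back to float otherwise
# (e.g. "10" -> 10, "0.5" -> 0.5).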
parser = argparse.ArgumentParser(description='scRNA diagnosis')
parser.add_argument('--seed', type=int, default=240)
parser.add_argument('--batch_size', type=int, default=16)
parser.add_argument('--learning_rate', type=float, default=3e-3)
parser.add_argument('--weight_decay', type=float, default=1e-4)
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument("--task", type=str, default="severity")
parser.add_argument('--emb_dim', type=int, default=128) # embedding dim
parser.add_argument('--h_dim', type=int, default=128) # hidden dim of the model
parser.add_argument('--dropout', type=float, default=0.3) # dropout
parser.add_argument('--layers', type=int, default=1)
parser.add_argument('--heads', type=int, default=8)
parser.add_argument("--train_sample_cells", type=int, default=500,
help="number of cells in one sample in train dataset")
parser.add_argument("--test_sample_cells", type=int, default=500,
help="number of cells in one sample in test dataset")
parser.add_argument("--train_num_sample", type=int, default=20,
help="number of sampled data points in train dataset")
parser.add_argument("--test_num_sample", type=int, default=100,
help="number of sampled data points in test dataset")
parser.add_argument('--model', type=str, default='Transformer')
parser.add_argument('--dataset', type=str, default=None)
parser.add_argument('--inter_only', type=_str2bool, default=False)
parser.add_argument('--same_pheno', type=int, default=0)
parser.add_argument('--augment_num', type=int, default=0)
parser.add_argument('--alpha', type=float, default=1.0)
parser.add_argument('--repeat', type=int, default=3)
parser.add_argument('--all', type=int, default=1)
parser.add_argument('--min_size', type=int, default=6000)
parser.add_argument('--n_splits', type=int, default=5)
parser.add_argument('--pca', type=_str2bool, default=True)
parser.add_argument('--mix_type', type=int, default=1)
parser.add_argument('--norm_first', type=_str2bool, default=False)
parser.add_argument('--warmup', type=_str2bool, default=False)
parser.add_argument('--top_k', type=int, default=1)
args = parser.parse_args()
# print("# of GPUs is", torch.cuda.device_count())
print(args)
torch.manual_seed(args.seed)
np.random.seed(args.seed)
patient_summary = {}
stats = {}
stats_id = {}
if args.task == 'haniffa' or args.task == 'combat':
label_dict = {0: 'Non Covid', 1: 'Covid'}
elif args.task == 'severity':
label_dict = {0: 'mild', 1: 'severe'}
elif args.task == 'stage':
label_dict = {0: 'convalescence', 1: 'progression'}
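# train(): builds train/val/test DataLoaders from the sampled cell-index lists,
# fits the chosen model (Transformer / feedforward / linear / scfeed),
# keeps the checkpoint with the lowest validation loss, and reports
# test AUC / accuracy / confusion matrix using per-sample majority voting.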
def train(x_train, x_valid, x_test, y_train, y_valid, y_test, id_train, id_test, data_augmented, data):
dataset_1 = MyDataset(x_train, x_valid, x_test, y_train, y_valid, y_test, id_train, id_test, fold='train')
dataset_2 = MyDataset(x_train, x_valid, x_test, y_train, y_valid, y_test, id_train, id_test, fold='test')
dataset_3 = MyDataset(x_train, x_valid, x_test, y_train, y_valid, y_test, id_train, id_test, fold='val')
train_loader = torch.utils.data.DataLoader(dataset_1, batch_size=args.batch_size, shuffle=True,
collate_fn=dataset_1.collate)
test_loader = torch.utils.data.DataLoader(dataset_2, batch_size=1, shuffle=False, collate_fn=dataset_2.collate)
valid_loader = torch.utils.data.DataLoader(dataset_3, batch_size=1, shuffle=False, collate_fn=dataset_3.collate)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
input_dim = data_augmented[0].shape[-1]
output_class = 1
if args.model == 'Transformer':
model = TransformerPredictor(input_dim=input_dim, model_dim=args.emb_dim, num_classes=output_class,
num_heads=args.heads, num_layers=args.layers, dropout=args.dropout,
input_dropout=0, pca=args.pca, norm_first=args.norm_first)
elif args.model == 'feedforward':
model = FeedForward(input_dim=input_dim, h_dim=args.emb_dim, cl=output_class, dropout=args.dropout)
elif args.model == 'linear':
model = Linear_Classfier(input_dim=input_dim, cl=output_class)
elif args.model == 'scfeed':
model = scFeedForward(input_dim=input_dim, cl=output_class, model_dim=args.emb_dim, dropout=args.dropout, pca=args.pca)
model = nn.DataParallel(model)
model.to(device)
best_model = model
print(device)
################################################################
# training and evaluation
################################################################
optimizer = Adam(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
if args.warmup:
scheduler = transformers.get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=args.epochs // 10,
num_training_steps=args.epochs)
else:
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5, last_epoch=-1)
sigmoid = torch.nn.Sigmoid().to(device)
max_acc, max_epoch, max_auc, max_loss, max_valid_acc, max_valid_auc = 0, 0, 0, 0, 0, 0
test_accs, valid_aucs, train_losses, valid_losses, train_accs, test_aucs = [], [0.], [], [], [], []
best_valid_loss = float("inf")
wrongs = []
trigger_times = 0
patience = 2
for ep in (range(1, args.epochs + 1)):
model.train()
train_loss = []
for batch in (train_loader):
x_ = torch.from_numpy(data_augmented[batch[0]]).float().to(device)
y_ = batch[1].to(device)
mask_ = batch[3].to(device)
optimizer.zero_grad()
out = model(x_, mask_)
loss = nn.BCELoss()(sigmoid(out), y_)
loss.backward()
optimizer.step()
train_loss.append(loss.item())
scheduler.step()
train_loss = sum(train_loss) / len(train_loss)
train_losses.append(train_loss)
if ep % 1 == 0:
valid_loss = []
model.eval()
pred = []
true = []
with torch.no_grad():
for batch in (valid_loader):
x_ = torch.from_numpy(data[batch[0]]).float().to(device).squeeze(0)
y_ = batch[1].int().to(device)
out = model(x_)
out = sigmoid(out)
loss = nn.BCELoss()(out, y_ * torch.ones(out.shape).to(device))
valid_loss.append(loss.item())
out = out.detach().cpu().numpy()
# majority voting
f = lambda x: 1 if x > 0.5 else 0
func = np.vectorize(f)
out = np.argmax(np.bincount(func(out).reshape(-1))).reshape(-1)
pred.append(out)
y_ = y_.detach().cpu().numpy()
true.append(y_)
# pred = np.concatenate(pred)
# true = np.concatenate(true)
valid_loss = sum(valid_loss) / len(valid_loss)
valid_losses.append(valid_loss)
if (valid_loss < best_valid_loss):
best_model = copy.deepcopy(model)
max_epoch = ep
best_valid_loss = valid_loss
max_loss = train_loss
print("Epoch %d, Train Loss %f, Valid_loss %f" % (ep, train_loss, valid_loss))
# Early stop
if (ep > args.epochs - 50) and ep > 1 and (valid_loss > valid_losses[-2]):
trigger_times += 1
if trigger_times >= patience:
break
else:
trigger_times = 0
best_model.eval()
pred = []
test_id = []
true = []
wrong = []
prob = []
with torch.no_grad():
for batch in (test_loader):
x_ = torch.from_numpy(data[batch[0]]).float().to(device).squeeze(0)
y_ = batch[1].int().numpy()
id_ = batch[2][0]
out = best_model(x_)
out = sigmoid(out)
out = out.detach().cpu().numpy().reshape(-1)
# For attention analysis:
# if args.model == 'Transformer':
# attens = best_model.module.get_attention_maps(x_)[-1]
# for iter in range(len(attens)):
# topK = np.bincount(attens[iter].argsort(-1)[:, :, -args.top_k:].
# cpu().detach().numpy().reshape(-1)).argsort()[-20:][::-1] # 20 is a
# for idd in id_[iter][topK]:
# stats[cell_type_large[idd]] = stats.get(cell_type_large[idd], 0) + 1
# stats_id[idd] = stats_id.get(idd, 0) + 1
y_ = y_[0][0]
true.append(y_)
if args.model != 'Transformer':
prob.append(out[0])
else:
prob.append(out.mean())
# majority voting
f = lambda x: 1 if x > 0.5 else 0
func = np.vectorize(f)
out = np.argmax(np.bincount(func(out).reshape(-1))).reshape(-1)[0]
pred.append(out)
test_id.append(patient_id[batch[2][0][0][0]])
if out != y_:
wrong.append(patient_id[batch[2][0][0][0]])
if len(wrongs) == 0:
wrongs = set(wrong)
else:
wrongs = wrongs.intersection(set(wrong))
test_auc = metrics.roc_auc_score(true, prob)
test_acc = accuracy_score(true, pred)
for idx in range(len(pred)):
print(f"{test_id[idx]} -- true: {label_dict[true[idx]]} -- pred: {label_dict[pred[idx]]}")
test_accs.append(test_acc)
cm = confusion_matrix(true, pred).ravel()
recall = cm[3] / (cm[3] + cm[2])
precision = cm[3] / (cm[3] + cm[1])
if (cm[3] + cm[1]) == 0:
precision = 0
print("Best performance: Epoch %d, Loss %f, Test ACC %f, Test AUC %f, Test Recall %f, Test Precision %f" % (
max_epoch, max_loss, test_acc, test_auc, recall, precision))
print("Confusion Matrix: " + str(cm))
for w in wrongs:
v = patient_summary.get(w, 0)
patient_summary[w] = v + 1
return test_auc, test_acc, cm, recall, precision
if args.model != 'Transformer':
args.repeat = 60
if args.task != 'custom':
p_idx, labels_, cell_type, patient_id, data, cell_type_large = Covid_data(args)
else:
p_idx, labels_, cell_type, patient_id, data, cell_type_large = Custom_data(args)
rkf = RepeatedKFold(n_splits=abs(args.n_splits), n_repeats=args.repeat * 100, random_state=args.seed)
num = np.arange(len(p_idx))
accuracy, aucs, cms, recalls, precisions = [], [], [], [], []
iter_count = 0
for train_index, test_index in rkf.split(num):
if args.n_splits < 0:
temp_idx = train_index
train_index = test_index
test_index = temp_idx
label_stat = []
for idx in train_index:
label_stat.append(labels_[p_idx[idx][0]])
unique, cts = np.unique(label_stat, return_counts=True)
if len(unique) < 2 or (1 in cts):
continue
# print(dict(zip(unique, cts)))
kk = 0
while True:
train_index_, valid_index, ty, vy = train_test_split(train_index, label_stat, test_size=0.33,
random_state=args.seed + kk)
if len(set(ty)) == 2 and len(set(vy)) == 2:
break
kk += 1
train_index = train_index_
len_valid = len(valid_index)
_index = np.concatenate([valid_index, test_index])
train_ids = []
for i in train_index:
train_ids.append(patient_id[p_idx[i][0]])
# print(train_ids)
x_train = []
x_test = []
x_valid = []
y_train = []
y_valid = []
y_test = []
id_train = []
id_test = []
id_valid = []
data_augmented, train_p_idx, labels_augmented, cell_type_augmented = mixups(args, data,
[p_idx[idx] for idx in train_index],
labels_,
cell_type)
individual_train, individual_test = sampling(args, train_p_idx, [p_idx[idx] for idx in _index], labels_,
labels_augmented, cell_type_augmented)
for t in individual_train:
id, label = [id_l[0] for id_l in t], [id_l[1] for id_l in t]
x_train += [ii for ii in id]
y_train += (label)
id_train += (id)
temp_idx = np.arange(len(_index))
for t_idx in temp_idx[len_valid:]:
id, label = [id_l[0] for id_l in individual_test[t_idx]], [id_l[1] for id_l in individual_test[t_idx]]
x_test.append([ii for ii in id])
y_test.append(label[0])
id_test.append(id)
for t_idx in temp_idx[:len_valid]:
id, label = [id_l[0] for id_l in individual_test[t_idx]], [id_l[1] for id_l in individual_test[t_idx]]
x_valid.append([ii for ii in id])
y_valid.append(label[0])
id_valid.append(id)
x_train, x_valid, x_test, y_train, y_valid, y_test = x_train, x_valid, x_test, np.array(y_train).reshape([-1, 1]), \
np.array(y_valid).reshape([-1, 1]), np.array(y_test).reshape(
[-1, 1])
auc, acc, cm, recall, precision = train(x_train, x_valid, x_test, y_train, y_valid, y_test, id_train, id_test,
data_augmented, data)
aucs.append(auc)
accuracy.append(acc)
cms.append(cm)
recalls.append(recall)
precisions.append(precision)
iter_count += 1
if iter_count == abs(args.n_splits) * args.repeat:
break
del data_augmented
print("="*33)
print("=== Final Evaluation (average across all splits) ===")
print("="*33)
print("Best performance: Test ACC %f, Test AUC %f, Test Recall %f, Test Precision %f" % (np.average(accuracy), np.average(aucs), np.average(recalls), np.average(precisions)))
####################################
######## Only for repeat > 1 #######
####################################
# accuracy = np.array(accuracy).reshape([-1, args.repeat]).mean(0)
# aucs = np.array(aucs).reshape([-1, args.repeat]).mean(0)
# recalls = np.array(recalls).reshape([-1, args.repeat]).mean(0)
# precisions = np.array(precisions).reshape([-1, args.repeat]).mean(0)
# ci_1 = st.t.interval(alpha=0.95, df=len(accuracy) - 1, loc=np.mean(accuracy), scale=st.sem(accuracy))[1] - np.mean(accuracy)
# ci_2 = st.t.interval(alpha=0.95, df=len(aucs) - 1, loc=np.mean(aucs), scale=st.sem(aucs))[1] - np.mean(aucs)
# ci_3 = st.t.interval(alpha=0.95, df=len(recalls) - 1, loc=np.mean(recalls), scale=st.sem(recalls))[1] - np.mean(recalls)
# ci_4 = st.t.interval(alpha=0.95, df=len(precisions) - 1, loc=np.mean(precisions), scale=st.sem(precisions))[1] - np.mean(precisions)
# print("ci: ACC ci %f, AUC ci %f, Recall ci %f, Precision ci %f" % (ci_1, ci_2, ci_3, ci_4))
# print(np.average(cms, 0))
# print(patient_summary)
# print(stats)
# print(stats_id)
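One detail worth isolating from the test loop above: the Transformer scores every cell of a patient sample, the per-cell sigmoid outputs are averaged to get the probability used for the AUC, and a majority vote over the thresholded outputs gives the hard sample-level prediction. A minimal standalone sketch of that aggregation (the toy scores below are made up):
import numpy as np

# per-cell sigmoid outputs for one patient sample (illustrative values)
out = np.array([0.91, 0.62, 0.48, 0.77, 0.55])

prob = out.mean()                       # soft score fed to roc_auc_score
votes = (out > 0.5).astype(int)         # per-cell hard decisions
pred = np.argmax(np.bincount(votes))    # majority vote -> sample-level label
print(prob, pred)                       # 0.666 1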
utils.py
2025.03.20 - [AI & Data Analysis/Papers] - [ScRAT] utils.py _ mixup()
(That post analyzes mixup() in detail: it linearly combines two samples x and x_p with a weight lam to create new data, and mixups() applies this mixup across samples.)
import torch
import numpy as np
import copy
from torch.utils.data import Dataset
from tqdm import tqdm
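# MyDataset wraps the pre-sampled cell-index lists; __getitem__ returns
# (sampled cell indices, label, ids) and collate() builds the padding mask
# (index == -1) that lets the model ignore padded positions.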
class MyDataset(Dataset):
def __init__(self, x_train, x_valid, x_test, y_train, y_valid, y_test, id_train, id_test, fold='train'):
fold = fold.lower()
self.train = False
self.test = False
self.val = False
if fold == "train":
self.train = True
elif fold == "test":
self.test = True
elif fold == "val":
self.val = True
else:
raise RuntimeError("Not train-val-test")
self.x_train = x_train
self.x_valid = x_valid
self.x_test = x_test
self.y_train = y_train
self.y_valid = y_valid
self.y_test = y_test
self.id_train = id_train
self.id_test = id_test
def __len__(self):
if self.train:
return len(self.x_train)
elif self.test:
return len(self.x_test)
elif self.val:
return len(self.x_valid)
def __getitem__(self, index):
if self.train:
x, y, cell_id = torch.from_numpy(np.array(self.x_train[index])), \
torch.from_numpy(self.y_train[index]).float(), self.id_train[index]
elif self.test:
x, y, cell_id = torch.from_numpy(np.array(self.x_test[index])), \
torch.from_numpy(self.y_test[index]).float(), self.id_test[index]
elif self.val:
x, y, cell_id = torch.from_numpy(np.array(self.x_valid[index])), \
torch.from_numpy(self.y_valid[index]).float(), []
return x, y, cell_id
def collate(self, batches):
xs = torch.stack([batch[0] for batch in batches if len(batch) > 0])
mask = torch.stack([batch[0] == -1 for batch in batches if len(batch) > 0])
ys = torch.stack([batch[1] for batch in batches if len(batch) > 0])
ids = [batch[2] for batch in batches if len(batch) > 0]
return xs, torch.FloatTensor(ys), ids, mask
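# add_noise(): adds small Gaussian noise (variance 1e-5) to mixed cells; used in
# mixups() when one of the two source patients has no cells of a given cell type.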
def add_noise(cells):
mean = 0
var = 1e-5
sigma = var ** 0.5
gauss = np.random.normal(mean, sigma, cells.shape)
noisy = cells + gauss
return noisy
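# mixup(): linear interpolation of two cell matrices, x_mix = lam * x + (1 - lam) * x_p,
# with lam drawn from Beta(alpha, alpha) when the caller does not supply it.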
def mixup(x, x_p, alpha=1.0, size=1, lam=None):
batch_size = min(x.shape[0], x_p.shape[0])
if lam == None:
lam = np.random.beta(alpha, alpha)
if size > 1:
lam = np.random.beta(alpha, alpha, size=size).reshape([-1, 1])
# x = np.random.permutation(x)
# x_p = np.random.permutation(x_p)
x_mix = lam * x[:batch_size] + (1 - lam) * x_p[:batch_size]
return x_mix, lam
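# mixups(): sample-level mixup. For each of args.augment_num synthetic samples it
# picks two patients (same or different phenotype depending on --same_pheno),
# mixes their cells cell-type by cell-type with a shared lam, and appends the mixed
# cells, soft labels, and cell types to the augmented arrays.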
def mixups(args, data, p_idx, labels_, cell_type):
max_num_cells = data.shape[0]
###################
# check the dataset
for i, pp in enumerate(p_idx):
if len(set(labels_[pp])) > 1:
print(i)
###################
all_ct = {}
for i, ct in enumerate(sorted(set(cell_type))):
all_ct[ct] = i
cell_type_ = np.array(cell_type.map(all_ct))
###################
for idx, i in enumerate(p_idx):
max_num_cells += (max(args.min_size - len(i), 0) + 100)
data_augmented = np.zeros([max_num_cells + (args.min_size + 100) * args.augment_num, data.shape[1]])
data_augmented[:data.shape[0]] = data
last = data.shape[0]
labels_augmented = copy.deepcopy(labels_)
cell_type_augmented = cell_type_
if args.same_pheno != 0:
p_idx_per_pheno = {}
for pp in p_idx:
y = labels_augmented[pp[0]]
if p_idx_per_pheno.get(y, -2) == -2:
p_idx_per_pheno[y] = [pp]
else:
p_idx_per_pheno[y].append(pp)
if args.inter_only and (args.augment_num > 0):
p_idx_augmented = []
else:
p_idx_augmented = copy.deepcopy(p_idx)
if args.augment_num > 0:
print("======= sample mixup ... ============")
for i in tqdm(range(args.augment_num)):
lam = np.random.beta(args.alpha, args.alpha)
if args.same_pheno == 1:
temp_label = np.random.randint(len(p_idx_per_pheno))
id_1, id_2 = np.random.randint(len(p_idx_per_pheno[temp_label]), size=2)
idx_1, idx_2 = p_idx_per_pheno[temp_label][id_1], p_idx_per_pheno[temp_label][id_2]
elif args.same_pheno == -1:
i_1, i_2 = np.random.choice(len(p_idx_per_pheno), 2, replace=False)
id_1 = np.random.randint(len(p_idx_per_pheno[i_1]))
id_2 = np.random.randint(len(p_idx_per_pheno[i_2]))
idx_1, idx_2 = p_idx_per_pheno[i_1][id_1], p_idx_per_pheno[i_2][id_2]
else:
id_1, id_2 = np.random.randint(len(p_idx), size=2)
idx_1, idx_2 = p_idx[id_1], p_idx[id_2]
diff = 0
set_union = sorted(set(cell_type_augmented[idx_1]).union(set(cell_type_augmented[idx_2])))
while diff < (args.min_size // 2):
for ct in set_union:
i_sub_1 = idx_1[cell_type_augmented[idx_1] == ct]
i_sub_2 = idx_2[cell_type_augmented[idx_2] == ct]
diff_sub = max(
int(args.min_size * (lam * len(i_sub_1) / len(idx_1) + (1 - lam) * len(i_sub_2) / len(idx_2))),
1)
diff += diff_sub
if len(i_sub_1) == 0:
sampled_idx_1 = [-1] * diff_sub
sampled_idx_2 = np.random.choice(i_sub_2, diff_sub)
x_mix, _ = mixup(data_augmented[sampled_idx_1], data_augmented[sampled_idx_2], alpha=args.alpha,
lam=lam)
x_mix = add_noise(x_mix)
elif len(i_sub_2) == 0:
sampled_idx_1 = np.random.choice(i_sub_1, diff_sub)
sampled_idx_2 = [-1] * diff_sub
x_mix, _ = mixup(data_augmented[sampled_idx_1], data_augmented[sampled_idx_2], alpha=args.alpha,
lam=lam)
x_mix = add_noise(x_mix)
else:
sampled_idx_1 = np.random.choice(i_sub_1, diff_sub)
sampled_idx_2 = np.random.choice(i_sub_2, diff_sub)
x_mix, _ = mixup(data_augmented[sampled_idx_1], data_augmented[sampled_idx_2], alpha=args.alpha,
lam=lam)
data_augmented[last:(last + x_mix.shape[0])] = x_mix
last += x_mix.shape[0]
cell_type_augmented = np.concatenate([cell_type_augmented, [ct] * diff_sub])
labels_augmented = np.concatenate(
[labels_augmented, [lam * labels_augmented[idx_1[0]] + (1 - lam) * labels_augmented[idx_2[0]]] * diff])
p_idx_augmented.append(np.arange(labels_augmented.shape[0] - diff, labels_augmented.shape[0]))
return data_augmented[:last+1], p_idx_augmented, labels_augmented, cell_type_augmented
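# sampling(): turns patient cell-index lists into fixed-size model inputs. With --all 0
# each patient is subsampled --train_num_sample / --test_num_sample times with
# --train_sample_cells / --test_sample_cells cells per draw (padded with -1 when a
# patient has too few cells); with --all 1 every cell of the patient is used once.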
def sampling(args, train_p_idx, test_p_idx, labels_, labels_augmented, cell_type_augmented):
if args.all == 0:
individual_train = []
individual_test = []
for idx in train_p_idx:
y = labels_augmented[idx[0]]
temp = []
if idx.shape[0] < args.train_sample_cells:
for _ in range(args.train_num_sample):
sample = np.zeros(args.train_sample_cells, dtype=int) - 1
sample[:idx.shape[0]] = idx
temp.append((sample, y))
else:
for _ in range(args.train_num_sample):
sample = np.random.choice(idx, args.train_sample_cells, replace=False)
temp.append((sample, y))
individual_train.append(temp)
for idx in test_p_idx:
y = labels_[idx[0]]
if idx.shape[0] < args.test_sample_cells:
sample_cells = idx.shape[0]
else:
sample_cells = args.test_sample_cells
temp = []
for _ in range(args.test_num_sample):
sample = np.random.choice(idx, sample_cells, replace=False)
temp.append((sample, y))
individual_test.append(temp)
else:
max_length = max([len(tt) for tt in train_p_idx])
individual_train = []
individual_test = []
for idx in train_p_idx:
y = labels_augmented[idx[0]]
temp = []
sample = np.zeros(max_length, dtype=int) - 1
sample[:idx.shape[0]] = idx
temp.append((sample, y))
individual_train.append(temp)
for idx in test_p_idx:
y = labels_[idx[0]]
temp = []
sample = idx
temp.append((sample, y))
individual_test.append(temp)
return individual_train, individual_test
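# stratify(): bins the per-cell scores into `split` intervals, finds the most common
# bin, and returns the mean score of the cells in that bin (not used in the listing above).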
def stratify(out, split=2):
f = lambda x: int(x * split)
func = np.vectorize(f)
majority = np.argmax(np.bincount(func(out))).reshape(-1)[0]
return out[func(out) == majority].mean()
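To see the mixup building block in isolation, here is a tiny toy run of mixup() and add_noise(). The arrays are made up, and this assumes you run it from the ScRAT working directory so that utils.py is importable.
import numpy as np
from utils import mixup, add_noise  # assumes the ScRAT repo is on the path

x = np.ones((4, 3))      # 4 "cells" x 3 "genes" from patient A (toy values)
x_p = np.zeros((4, 3))   # 4 "cells" x 3 "genes" from patient B (toy values)

x_mix, lam = mixup(x, x_p, alpha=0.5, lam=0.3)
print(lam)        # 0.3
print(x_mix[0])   # every entry equals 0.3 * 1 + 0.7 * 0 = 0.3

noisy = add_noise(x_mix)  # small Gaussian perturbation, variance 1e-5
print(np.abs(noisy - x_mix).max() < 0.1)  # True with overwhelming probability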