-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy patheval_us.py
More file actions
128 lines (97 loc) · 5.07 KB
/
eval_us.py
File metadata and controls
128 lines (97 loc) · 5.07 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
from dataset.FetalAbdominal import FetalAbdominal_loading
import torch
from unet import UNet3D, UNet2D
import tqdm
import einops
import losses.dice as dice
from network.MedNCA import MedNCA
import nvflare.client as flare
from quantization.dequantization import dequantize_model
from utils.root_path import get_root_path, results_workingdir_path, results_single_path, results_tempdir_path
from network.model import get_model
from dataset.loader import get_data_loader
import tenseal as ts
import os
import numpy as np
class FetalAbdominalValidator:
    """Evaluate a trained segmentation model on one validation client's data.

    The experiment configuration is hard-coded in ``__init__``. It covers the
    dataset, model type, client split, and the federated-learning options
    (quantization, sparsification, homomorphic encryption). It determines both
    the data loader and the results directories that checkpoints are read from.

    NOTE(review): despite the class name, ``data_type`` is currently
    "XRayMimic", not "fetalAbdominal".
    """

    def __init__(self, val_client_name):
        """Build the data loader, the model and the result paths.

        Args:
            val_client_name: Client name forwarded to ``get_data_loader`` to
                select the split to evaluate (e.g. ``'val'``).
        """
        data_type = "XRayMimic"  # alternatives: "fetalAbdominal", "XRay", "XRayMimic"
        model_type = "mednca"  # alternatives: "unet", "mednca", "transunet_b16"
        data_split_seed = 42
        num_clients = 5
        num_test = 0.5
        # batch_size=1 and no shuffling: eval() collects one Dice score per
        # sample, in a fixed order.
        self.ldr = get_data_loader(
            data_type, val_client_name, num_clients, num_test, data_split_seed,
            batch_size=1, shuffle=False,
        )
        self.device = 'cuda:0'
        self.quantize_mode = "none"
        self.sparsification_mode = "none"
        self.sparsification_parameter = 0.01
        self.supervision_scenario = "full"
        self.apply_homomorphic_encryption = True
        # Without sparsification the parameter is irrelevant. Normalising it to
        # 0.0 keeps exp_name stable (presumably matching the name used on the
        # training side -- confirm).
        if self.sparsification_mode == "none":
            self.sparsification_parameter = 0.0
        self.model = get_model(data_type, model_type, self.device)
        self.setup_name = f"{data_type}_{num_clients}_{data_split_seed}_{num_test}_{self.supervision_scenario}"
        self.exp_name = f"exp_{self.quantize_mode}_{self.sparsification_mode}_{self.sparsification_parameter}_{self.apply_homomorphic_encryption}_{model_type}"
        self.workingdir_path = os.path.join(results_workingdir_path, self.setup_name, self.exp_name)
        self.path_single = os.path.join(results_single_path, self.setup_name, f"exp_{model_type}")

    @torch.no_grad()
    def eval(self):
        """Load the federated global model and score it on ``self.ldr``.

        Model outputs are binarised with ``> 0`` before computing Dice
        against the label.

        Returns:
            Tuple ``(mean, std)`` of the per-batch Dice scores as tensors.
            Both are NaN for an empty loader, and the std is NaN when there
            is only one batch.
        """
        dices = []
        # self.load_model(os.path.join(self.path_single, "client-4.pt"))
        self.load_global_model(os.path.join(
            self.workingdir_path, "server/simulate_job/app_server/FL_global_model.pt"))
        self.model.to(self.device)
        self.model.eval()
        # MedNCA returns channels-last output that must be moved to
        # channels-first; other models are assumed to already return b c h w.
        channels_last = isinstance(self.model, MedNCA)
        for img, label in tqdm.tqdm(self.ldr):
            img = img.to(self.device, torch.float32)
            label = label.to(self.device, torch.float32)
            outputs = self.model(img)
            if channels_last:
                outputs = einops.rearrange(outputs, 'b h w c -> b c h w')
            # Positive outputs count as foreground (presumably logits -- confirm).
            outputs = (outputs > 0).float()
            dices.append(dice.DiceLoss.compute_dice(outputs, label))
        dices = torch.tensor(dices)
        return torch.mean(dices), torch.std(dices)

    def load_model(self, model_path):
        """Load a plain state dict from ``model_path`` into ``self.model``."""
        self.model.load_state_dict(torch.load(model_path, weights_only=True))

    def load_global_model(self, model_path):
        """Load the server's global checkpoint into ``self.model``.

        The checkpoint is a dict whose layout depends on the configuration:

        * Homomorphic encryption:
          ``_dict['meta_props']['encrypted_weights']`` maps parameter names
          to serialized CKKS vectors. They are decrypted with the private
          TenSEAL context stored at
          ``results_tempdir_path/<setup_name>/<exp_name>/private_context.seal``.
        * Quantization: ``_dict['model']`` and ``_dict['meta_props']`` are
          wrapped in an ``FLModel`` and dequantized.
        * Otherwise: ``_dict['model']`` is a plain state dict.

        Args:
            model_path: Path of the checkpoint file written by the server.
        """
        _dict = torch.load(model_path)
        if self.apply_homomorphic_encryption:
            # Decryption needs the *private* context, which holds the secret key.
            # The path parts must be joined; passing them to open() separately
            # would feed them in as mode/buffering/encoding arguments.
            context_path = os.path.join(
                results_tempdir_path, self.setup_name, self.exp_name, "private_context.seal")
            with open(context_path, 'rb') as f:
                private_seal_context = ts.context_from(f.read())
            params = _dict['meta_props']['encrypted_weights']
            # Build the state dict once; it is only used for target shapes.
            target_state = self.model.state_dict()
            decrypted_params = {}
            for param_name, encrypted_values in params.items():
                vector = ts.ckks_vector_from(private_seal_context, bytes(encrypted_values))
                values = np.array(vector.decrypt())
                # Decrypted values come back flat; restore the parameter's shape.
                decrypted_params[param_name] = torch.from_numpy(values).reshape(
                    target_state[param_name].shape)
            self.model.load_state_dict(decrypted_params)
        elif self.quantize_mode != 'none':
            m = flare.FLModel(
                params=_dict['model'],
                meta=_dict['meta_props'])
            m = dequantize_model(m, self.quantize_mode)
            self.model.load_state_dict(m.params)
        else:
            self.model.load_state_dict(_dict['model'])
if __name__ == '__main__':
    # Multi-client variant, kept for reference (evaluates val-1 .. val-4 into one file):
    # out_file = None
    # for i in range(1, 5):
    #     val_client_name = f'val-{i}'
    #     validator = FetalAbdominalValidator(val_client_name)
    #     dice_mean, dice_std = validator.eval()
    #     if out_file is None:
    #         out_file = os.path.join(validator.workingdir_path, "dice.txt")
    #         assert not os.path.exists(out_file), f"Output file {out_file} already exists"
    #     with open(out_file, 'a') as f:
    #         f.write(f'{val_client_name}\t {dice_mean}\t {dice_std}\n')
    #     print(f'{val_client_name} Dice score: {dice_mean} \u00b1 {dice_std}')
    validator = FetalAbdominalValidator('val')
    out_file = os.path.join(validator.workingdir_path, "dice.txt")
    # Refuse to overwrite previous results. Check *before* the (expensive)
    # evaluation, and raise rather than assert so the check survives `python -O`.
    if os.path.exists(out_file):
        raise FileExistsError(f"Output file {out_file} already exists")
    dice_mean, dice_std = validator.eval()
    # 'x' mode fails if the file appeared in the meantime instead of appending to it.
    with open(out_file, 'x') as f:
        f.write(f'val\t {dice_mean}\t {dice_std}\n')
    print(f'Dice score: {dice_mean} \u00b1 {dice_std}')