3D tooth instance segmentation. In the dark, but not alone

3D tooth segmentation, from getting the data to the final result. Almost.





Disclaimer

This article is not educational in any sense of the word and is purely informational. The author is not responsible for the time you spend reading it.





About the author

Greetings, everyone. My name is Andrey (27). I'll try to be brief. Why programming? By education I'm an electromechanical engineer (bachelor's degree), and I know the profession. I worked for 2 years, quite successfully, as a power engineer at a drilling company; instead of taking a promotion, I handed in my resignation: I burned out, and it turned out the job was not for me. I like creating things and finding solutions to complex problems, and I have been inseparable from a PC for as long as I can remember. The choice was obvious. At first (six months ago) I seriously considered enrolling in courses or something similar. I read the reviews, talked to participants and realized that getting the information was not the problem. Then I found a website, got a Python foundation there and started my journey with it (I am now gradually studying everything related to ML there). I was immediately interested in machine learning, CV in particular. I came up with a problem, and here I am (for me, this is an excellent way to learn).





1. Introduction

After several failed attempts, I decided to use 2 lightweight models to get the desired result. The first segments all teeth as a single category [1, 0], and the second divides them into the categories [0, 8]. But let's take things in order.





2. Finding and preparing the data

After spending more than one night searching for data for this work, I concluded that a free jaw scan of good quality and in a usable format (*.stl, *.nrrd, etc.) was not going to turn up. The best I found was a sample dataset in 3D Slicer: the head of a patient after jaw surgery.





Obviously, I don't need the whole head, so I cropped the source in the same program to 163 * 112 * 120 px (in this post {x * y * z = wdh} and 1 px = 0.5 mm), leaving only the teeth and the associated maxillofacial structures.





, - . . , - "autothreshold" , , , , ( ).





- Pixels (slices on the left)? - Keep the image size in mind.

12~14. , 4 . , .





Final version of the mask. Smooth 0.5. (smoothing was not used in training)

, ( ) , . , , N- , random-crop .





import nrrd
import torch
import torchvision.transforms as tf


class DataBuilder:
    def __init__(self,
                 data_path,
                 list_of_categories,
                 num_of_chunks: int = 0,
                 augmentation_coeff: int = 0,
                 num_of_classes: int = 0,
                 normalise: bool = False,
                 fit: bool = True,
                 data_format: int = 0,
                 save_data: bool = False
                 ):
        self.data_path = data_path
        self.number_of_chunks = num_of_chunks
        self.augmentation_coeff = augmentation_coeff
        self.list_of_cats = list_of_categories
        self.num_of_cls = num_of_classes
        self.normalise = normalise
        self.fit = fit
        self.data_format = data_format
        self.save_data = save_data

    def forward(self):
        # Pipeline: load -> pad -> normalise -> augment -> chunk -> regroup labels
        data = self.get_data()
        data = self.fit_data(data) if self.fit else data
        data = self.pre_normalize(data) if self.normalise else data
        data = self.data_augmentation(data, self.augmentation_coeff) if self.augmentation_coeff != 0 else data
        data = self.new_chunks(data, self.number_of_chunks) if self.number_of_chunks != 0 else data
        data = self.category_splitter(data, self.num_of_cls, self.list_of_cats) if self.num_of_cls != 0 else data
        if self.save_data:
            torch.save(data, self.data_path[-14:] + '.pt')

        # add the channel dimension: (N, ...) -> (N, 1, ...)
        return torch.unsqueeze(data, 1)

    def get_data(self):
        # 0: *.nrrd file, 1: tensor saved with torch.save, 2: tensor passed directly as data_path
        if self.data_format == 0:
            return torch.from_numpy(nrrd.read(self.data_path)[0])
        elif self.data_format == 1:
            return torch.load(self.data_path).cpu()
        elif self.data_format == 2:
            return torch.unsqueeze(self.data_path, 0).cpu()
        raise ValueError('Available types are: "nrrd", "tensor" or "self.tensor(w/o load)"')

    @staticmethod
    def fit_data(some_data):
        data = torch.movedim(some_data, (1, 0), (0, -1))
        data_add_x = torch.nn.ZeroPad2d((5, 0, 0, 0))
        data = data_add_x(data)
        data = torch.movedim(data, -1, 0)
        data_add_z = torch.nn.ZeroPad2d((0, 0, 8, 0))

        return data_add_z(data)

    @staticmethod
    def pre_normalize(some_data):
        min_d, max_d = torch.min(some_data), torch.max(some_data)

        return (some_data - min_d) / (max_d - min_d)

    @staticmethod
    def data_augmentation(some_data, aug_n):
        # aug_n copies of the volume, every slice rotated by 0, 20, 40, ... degrees
        torch.manual_seed(17)
        tr_data = []
        for e in range(aug_n):
            transform = tf.RandomRotation(degrees=(20*e, 20*e))
            for image in some_data:
                image = torch.unsqueeze(image, 0)
                image = transform(image)
                tr_data.append(image)

        # one (aug_n * D, H, W) tensor instead of a list, so later steps also work without chunking
        return torch.cat(tr_data, 0)

    @staticmethod
    def new_chunks(some_data, n_ch):
        # split along the slice axis into n_ch blocks of consecutive slices
        chunks = torch.chunk(some_data, n_ch, 0)

        return torch.stack(chunks)

    @staticmethod
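    def category_splitter(some_data, alpha, list_of_categories):
        # Regroup raw labels: voxels below the k-th threshold get the value alpha + k.
        # The alpha offset keeps already assigned voxels from being reassigned,
        # as long as alpha is not smaller than the largest threshold.
        data, value = torch.squeeze(some_data, 1).to(torch.int64), alpha
        for threshold in list_of_categories:
            data = torch.where(data < threshold, value, data)
            value += 1

        return data - alpha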
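
The regrouping in `category_splitter` is compact but easy to misread. The NumPy sketch below ports the same loop and compares it with `np.digitize`; the helper names `category_splitter_np` and `bucketize`, and the thresholds `[1, 5, 9]`, are hypothetical and chosen only for illustration. Assuming `alpha` (here `num_of_classes`) is not smaller than the largest threshold and every label lies below the last threshold, each voxel ends up with the number of thresholds less than or equal to its label; with a smaller `alpha`, voxels that were already assigned get reassigned.

```python
import numpy as np


def category_splitter_np(labels, alpha, thresholds):
    """NumPy port of DataBuilder.category_splitter: same loop, same order."""
    data, value = labels.astype(np.int64), alpha
    for t in thresholds:
        # voxels still below threshold t receive the running value alpha + k
        data = np.where(data < t, value, data)
        value += 1
    return data - alpha


def bucketize(labels, thresholds):
    """Class index = number of (ascending) thresholds that are <= the label."""
    return np.digitize(labels, thresholds)


labels = np.arange(9)        # hypothetical raw labels 0...8
thresholds = [1, 5, 9]       # hypothetical category boundaries
print(category_splitter_np(labels, alpha=9, thresholds=thresholds))  # [0 1 1 1 1 2 2 2 2]
print(bucketize(labels, thresholds))                                 # [0 1 1 1 1 2 2 2 2]
# alpha smaller than the thresholds: already assigned voxels are reassigned
print(category_splitter_np(labels, alpha=2, thresholds=thresholds))  # [2 2 2 2 2 2 2 2 2]
```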

      
      



3D U-net. :





  • ( ).





  • Zero-padding to 168*120*120 (from 163*112*120).





  • Normalization to 0...1 (from ~-2000...16000).





  • N-fold augmentation.





  • ( 1, 1, 72, 120, 120).





  • 28 (. ):





    • a single category (all teeth) for the 1st model;





    • 9 categories (8 + background) for the 2nd model.
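
The padding, normalization and chunking steps listed above can be sketched in NumPy as follows. This is an illustrative re-implementation, not the author's code: `fit_volume`, `min_max` and `to_chunks` are hypothetical names, the random array only stands in for the CT crop, and the rotation augmentation is not reproduced (plain copies stand in for the 9 rotated versions).

```python
import numpy as np


def fit_volume(vol):
    """Zero-pad (163, 112, 120) -> (168, 120, 120) like DataBuilder.fit_data:
    5 zero slices before axis 0 and 8 before axis 1."""
    return np.pad(vol, ((5, 0), (8, 0), (0, 0)))


def min_max(vol):
    """Rescale intensities (about -2000...16000 in this scan) to 0...1."""
    vol = vol.astype(np.float32)
    lo, hi = vol.min(), vol.max()
    return (vol - lo) / (hi - lo)


def to_chunks(slices, depth):
    """(N, H, W) -> (N // depth, 1, depth, H, W), keeping consecutive slices together
    (same grouping as torch.chunk along dim 0 followed by unsqueeze(1))."""
    n, h, w = slices.shape
    if n % depth:
        raise ValueError(f"{n} slices do not split evenly into chunks of {depth}")
    return slices.reshape(n // depth, 1, depth, h, w)


rng = np.random.default_rng(0)
scan = rng.integers(-2254, 14983, size=(163, 112, 120)).astype(np.int16)  # stand-in for the crop
vol = min_max(fit_volume(scan))
print(vol.shape)                        # (168, 120, 120)
stack = np.concatenate([vol] * 9)       # 9 copies stand in for the 9 rotated versions
print(to_chunks(stack, 72).shape)       # (21, 1, 72, 120, 120)
```

As in `DataBuilder.forward`, padding happens before normalization, so the zero padding is rescaled together with the scan.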





Dataloader
import torch.utils.data as tud


class ToothDataset(tud.Dataset):
    def __init__(self, images, masks):
        self.images = images
        self.masks = masks

    def __len__(self): return len(self.images)

    def __getitem__(self, index):
        if self.masks is not None:
            return self.images[index, :, :, :, :],\
                    self.masks[index, :, :, :, :]
        else:
            return self.images[index, :, :, :, :]


def get_loaders(images, masks,
                batch_size: int = 1,
                num_workers: int = 1,
                pin_memory: bool = True):

    train_ds = ToothDataset(images=images,
                            masks=masks)

    data_loader = tud.DataLoader(train_ds,
                                 batch_size=batch_size,
                                 shuffle=False,
                                 num_workers=num_workers,
                                 pin_memory=pin_memory)

    return data_loader

      
      



:









             Semantic                          Instance                          Predictions
Data         (27*, 1, 56*, 120, 120) [0...1]   (27*, 1, 56*, 120, 120) [0, 1]    (1, 1, 168, 120, 120) [0...1]
Masks        (27*, 1, 56*, 120, 120) [0, 1]    (27*, 1, 56*, 120, 120) [0, 8]    -





* , , - .
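
The table implies a cascade: the first (semantic) model outputs values in [0...1], while the second (instance) model is fed masks with values [0, 1] and trained against labels [0, 8]. A minimal NumPy sketch of the glue between the two stages, assuming a 0.5 threshold for the binary mask, an argmax over the 9 output channels and background outside the mask (none of this is stated in the text); `binarize` and `labels_from_scores` are hypothetical names, and random arrays with reduced shapes stand in for the model outputs:

```python
import numpy as np


def binarize(prob, threshold=0.5):
    """Stage 1 output (tooth probability in 0...1) -> binary mask in {0, 1} for stage 2."""
    return (prob >= threshold).astype(np.float32)


def labels_from_scores(scores, mask=None):
    """Stage 2 output (N, 9, D, H, W) per-class scores -> labels 0...8 of shape (N, 1, D, H, W).
    If a binary mask is given, voxels outside it are forced to label 0 (background)."""
    labels = np.argmax(scores, axis=1)[:, np.newaxis]
    if mask is not None:
        labels = labels * (mask > 0)
    return labels


# Reduced stand-in shapes; the real volumes are (1, 1, 168, 120, 120)
rng = np.random.default_rng(0)
prob = rng.random((1, 1, 8, 16, 16), dtype=np.float32)    # pretend output of model 1
mask = binarize(prob)                                      # what model 2 would receive
scores = rng.random((1, 9, 8, 16, 16), dtype=np.float32)  # pretend output of model 2
labels = labels_from_scores(scores, mask)
print(mask.shape, labels.shape)  # (1, 1, 8, 16, 16) (1, 1, 8, 16, 16)
```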





3.

- . U-Net. , .





2D U-Net

, . - Adam, Dice-loss(implement), / 4, [64, 128, 256, 512] (, , - ). 60-80 epochs . Transfer learning .





model.summary()
model = UNet(dim=2, in_channels=1, out_channels=1, n_blocks=4, start_filters=64).to(device)
print(summary(model, (1, 168, 120)))

"""
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 168, 120]             640
              ReLU-2         [-1, 64, 168, 120]               0
       BatchNorm2d-3         [-1, 64, 168, 120]             128
            Conv2d-4         [-1, 64, 168, 120]          36,928
              ReLU-5         [-1, 64, 168, 120]               0
       BatchNorm2d-6         [-1, 64, 168, 120]             128
         MaxPool2d-7           [-1, 64, 84, 60]               0
         DownBlock-8  [[-1, 64, 84, 60], [-1, 64, 168, 120]]  0
            Conv2d-9          [-1, 128, 84, 60]          73,856
             ReLU-10          [-1, 128, 84, 60]               0
      BatchNorm2d-11          [-1, 128, 84, 60]             256
           Conv2d-12          [-1, 128, 84, 60]         147,584
             ReLU-13          [-1, 128, 84, 60]               0
      BatchNorm2d-14          [-1, 128, 84, 60]             256
        MaxPool2d-15          [-1, 128, 42, 30]               0
        DownBlock-16  [[-1, 128, 42, 30], [-1, 128, 84, 60]]  0
           Conv2d-17          [-1, 256, 42, 30]         295,168
             ReLU-18          [-1, 256, 42, 30]               0
      BatchNorm2d-19          [-1, 256, 42, 30]             512
           Conv2d-20          [-1, 256, 42, 30]         590,080
             ReLU-21          [-1, 256, 42, 30]               0
      BatchNorm2d-22          [-1, 256, 42, 30]             512
        MaxPool2d-23          [-1, 256, 21, 15]               0
        DownBlock-24  [[-1, 256, 21, 15], [-1, 256, 42, 30]]  0
           Conv2d-25          [-1, 512, 21, 15]       1,180,160
             ReLU-26          [-1, 512, 21, 15]               0
      BatchNorm2d-27          [-1, 512, 21, 15]           1,024
           Conv2d-28          [-1, 512, 21, 15]       2,359,808
             ReLU-29          [-1, 512, 21, 15]               0
      BatchNorm2d-30          [-1, 512, 21, 15]           1,024
        DownBlock-31  [[-1, 512, 21, 15], [-1, 512, 21, 15]]  0
  ConvTranspose2d-32          [-1, 256, 42, 30]         524,544
             ReLU-33          [-1, 256, 42, 30]               0
      BatchNorm2d-34          [-1, 256, 42, 30]             512
      Concatenate-35          [-1, 512, 42, 30]               0
           Conv2d-36          [-1, 256, 42, 30]       1,179,904
             ReLU-37          [-1, 256, 42, 30]               0
      BatchNorm2d-38          [-1, 256, 42, 30]             512
           Conv2d-39          [-1, 256, 42, 30]         590,080
             ReLU-40          [-1, 256, 42, 30]               0
      BatchNorm2d-41          [-1, 256, 42, 30]             512
          UpBlock-42          [-1, 256, 42, 30]               0
  ConvTranspose2d-43          [-1, 128, 84, 60]         131,200
             ReLU-44          [-1, 128, 84, 60]               0
      BatchNorm2d-45          [-1, 128, 84, 60]             256
      Concatenate-46          [-1, 256, 84, 60]               0
           Conv2d-47          [-1, 128, 84, 60]         295,040
             ReLU-48          [-1, 128, 84, 60]               0
      BatchNorm2d-49          [-1, 128, 84, 60]             256
           Conv2d-50          [-1, 128, 84, 60]         147,584
             ReLU-51          [-1, 128, 84, 60]               0
      BatchNorm2d-52          [-1, 128, 84, 60]             256
          UpBlock-53          [-1, 128, 84, 60]               0
  ConvTranspose2d-54         [-1, 64, 168, 120]          32,832
             ReLU-55         [-1, 64, 168, 120]               0
      BatchNorm2d-56         [-1, 64, 168, 120]             128
      Concatenate-57        [-1, 128, 168, 120]               0
           Conv2d-58         [-1, 64, 168, 120]          73,792
             ReLU-59         [-1, 64, 168, 120]               0
      BatchNorm2d-60         [-1, 64, 168, 120]             128
           Conv2d-61         [-1, 64, 168, 120]          36,928
             ReLU-62         [-1, 64, 168, 120]               0
      BatchNorm2d-63         [-1, 64, 168, 120]             128
          UpBlock-64         [-1, 64, 168, 120]               0
           Conv2d-65          [-1, 1, 168, 120]              65
================================================================
Total params: 7,702,721
Trainable params: 7,702,721
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.08
Forward/backward pass size (MB): 7434.08
Params size (MB): 29.38
Estimated Total Size (MB): 7463.54
"""
      
      



Exp. No. 1: 2D U-Net, images fed slice by slice, plane [x, z]

, - . , . numpy - *.stl 6. , :





From left to right: 1. Not visible: [x, y]. 2. Slightly better: [x, z]. 3. Even better: [y, z]

100% , ? , .





, , , , , .





Exp. No. 2: Cascade of two 2D U-Nets, images fed slice by slice, plane [y, z]

, , :





Exp. No. 3: Cascade of two 2D U-Nets, images fed slice by slice, plane [y, z], training time increased by 50%

3D . , (24*, 120, 120). ? - (~22. ). (1063gtx) .





24*

. :





  • (1512, 120, 120) - 63;





  • batch size (24, 120, 120) - , ;





  • (24) / ( 24/2/2/2=3 3*2*2*2=24, / 2 / 1);





  • , . .summary()
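
Why the chunk depth matters: every pooling halves each spatial size, and the decoder doubles it back before concatenating with the skip connection, so every size must stay even before each pooling. A small pure-Python sketch traces the shapes; `encoder_shapes` is a hypothetical helper, and it assumes `n_blocks` levels with `n_blocks - 1` poolings, which is what the MaxPool3d layers in the summaries suggest.

```python
def encoder_shapes(shape, n_blocks):
    """Spatial shape at every U-Net level, assuming n_blocks levels and n_blocks - 1 poolings."""
    shapes = [tuple(shape)]
    for _ in range(n_blocks - 1):
        current = shapes[-1]
        if any(size % 2 for size in current):
            # MaxPool (kernel 2) would floor an odd size, and the decoder's x2 upsampling
            # could no longer be concatenated with the skip connection of that level
            raise ValueError(f"cannot pool {current} without losing a slice")
        shapes.append(tuple(size // 2 for size in current))
    return shapes


print(encoder_shapes((24, 120, 120), n_blocks=4))
# [(24, 120, 120), (12, 60, 60), (6, 30, 30), (3, 15, 15)]
print(encoder_shapes((72, 120, 120), n_blocks=3))
# [(72, 120, 120), (36, 60, 60), (18, 30, 30)]
```

The first result matches the MaxPool3d outputs in the summary below; a depth of 20, for instance, would fail at the third pooling (20 -> 10 -> 5).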





model.summary()
model = UNet(dim=3, in_channels=1, out_channels=1, n_blocks=4, start_filters=64).to(device)
print(summary(model, (1, 24, 120, 120)))

"""
  ----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv3d-1     [-1, 64, 24, 120, 120]             1,792
              ReLU-2     [-1, 64, 24, 120, 120]                 0
       BatchNorm3d-3     [-1, 64, 24, 120, 120]               128
            Conv3d-4     [-1, 64, 24, 120, 120]           110,656
              ReLU-5     [-1, 64, 24, 120, 120]                 0
       BatchNorm3d-6     [-1, 64, 24, 120, 120]               128
         MaxPool3d-7        [-1, 64, 12, 60, 60]                0
         DownBlock-8  [[-1, 64, 12, 60, 60], [-1, 64, 24, 120, 120]]               0
            Conv3d-9       [-1, 128, 12, 60, 60]          221,312
             ReLU-10       [-1, 128, 12, 60, 60]                0
      BatchNorm3d-11       [-1, 128, 12, 60, 60]              256
           Conv3d-12       [-1, 128, 12, 60, 60]          442,496
             ReLU-13       [-1, 128, 12, 60, 60]                0
      BatchNorm3d-14       [-1, 128, 12, 60, 60]              256
        MaxPool3d-15       [-1, 128, 6, 30, 30]                 0
        DownBlock-16  [[-1, 128, 6, 30, 30], [-1, 128, 12, 60, 60]]               0
           Conv3d-17       [-1, 256, 6, 30, 30]           884,992
             ReLU-18       [-1, 256, 6, 30, 30]                 0
      BatchNorm3d-19       [-1, 256, 6, 30, 30]               512
           Conv3d-20       [-1, 256, 6, 30, 30]         1,769,728
             ReLU-21       [-1, 256, 6, 30, 30]                 0
      BatchNorm3d-22       [-1, 256, 6, 30, 30]               512
        MaxPool3d-23       [-1, 256, 3, 15, 15]                 0
        DownBlock-24  [[-1, 256, 3, 15, 15], [-1, 256, 6, 30, 30]]               0
           Conv3d-25       [-1, 512, 3, 15, 15]         3,539,456
             ReLU-26       [-1, 512, 3, 15, 15]                 0
      BatchNorm3d-27       [-1, 512, 3, 15, 15]             1,024
           Conv3d-28       [-1, 512, 3, 15, 15]         7,078,400
             ReLU-29       [-1, 512, 3, 15, 15]                 0
      BatchNorm3d-30       [-1, 512, 3, 15, 15]             1,024
        DownBlock-31  [[-1, 512, 3, 15, 15], [-1, 512, 3, 15, 15]]               0
  ConvTranspose3d-32       [-1, 256, 6, 30, 30]         1,048,832
             ReLU-33       [-1, 256, 6, 30, 30]                 0
      BatchNorm3d-34       [-1, 256, 6, 30, 30]               512
      Concatenate-35       [-1, 512, 6, 30, 30]                 0
           Conv3d-36       [-1, 256, 6, 30, 30]         3,539,200
             ReLU-37       [-1, 256, 6, 30, 30]                 0
      BatchNorm3d-38       [-1, 256, 6, 30, 30]               512
           Conv3d-39       [-1, 256, 6, 30, 30]         1,769,728
             ReLU-40       [-1, 256, 6, 30, 30]                 0
      BatchNorm3d-41       [-1, 256, 6, 30, 30]               512
          UpBlock-42       [-1, 256, 6, 30, 30]                 0
  ConvTranspose3d-43       [-1, 128, 12, 60, 60]          262,272
             ReLU-44       [-1, 128, 12, 60, 60]                0
      BatchNorm3d-45       [-1, 128, 12, 60, 60]              256
      Concatenate-46       [-1, 256, 12, 60, 60]                0
           Conv3d-47       [-1, 128, 12, 60, 60]          884,864
             ReLU-48       [-1, 128, 12, 60, 60]                0
      BatchNorm3d-49       [-1, 128, 12, 60, 60]              256
           Conv3d-50       [-1, 128, 12, 60, 60]          442,496
             ReLU-51       [-1, 128, 12, 60, 60]                0
      BatchNorm3d-52       [-1, 128, 12, 60, 60]              256
          UpBlock-53       [-1, 128, 12, 60, 60]                0
  ConvTranspose3d-54       [-1, 64, 24, 120, 120]          65,600
             ReLU-55       [-1, 64, 24, 120, 120]               0
      BatchNorm3d-56       [-1, 64, 24, 120, 120]             128
      Concatenate-57      [-1, 128, 24, 120, 120]               0
           Conv3d-58       [-1, 64, 24, 120, 120]         221,248
             ReLU-59       [-1, 64, 24, 120, 120]               0
      BatchNorm3d-60       [-1, 64, 24, 120, 120]             128
           Conv3d-61       [-1, 64, 24, 120, 120]         110,656
             ReLU-62       [-1, 64, 24, 120, 120]               0
      BatchNorm3d-63       [-1, 64, 24, 120, 120]             128
          UpBlock-64       [-1, 64, 24, 120, 120]               0
           Conv3d-65        [-1, 1, 24, 120, 120]              65
================================================================
Total params: 22,400,321
Trainable params: 22,400,321
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.61
Forward/backward pass size (MB): 15974.12
Params size (MB): 85.45
Estimated Total Size (MB): 16060.18
----------------------------------------------------------------
"""
      
      



Exp. No. 4: 3D U-Net, volume feeding, plane [y, z], time × 0.38

~60% (25 epochs) , .





Exp. No. 5: 3D U-Net, volume feeding, plane [y, z], 65 epochs, ~1.5 hours

. , (.№3) - :





Exp. No. 6: 3D U-Net, volume feeding, plane [x, z], 105 epochs, ~2.1 hours

"" . ~400 ( ~22) [18, 32, 64, 128] / 3. RMSProp. (1, 1, 72*, 120, 120). ?





model.summary()
model = UNet(dim=3, in_channels=1, out_channels=1, n_blocks=3, start_filters=18).to(device)
print(summary(model, (1, 72, 120, 120)))

"""
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv3d-1     [-1, 18, 72, 120, 120]             504
              ReLU-2     [-1, 18, 72, 120, 120]               0
       BatchNorm3d-3     [-1, 18, 72, 120, 120]              36
            Conv3d-4     [-1, 18, 72, 120, 120]           8,766
              ReLU-5     [-1, 18, 72, 120, 120]               0
       BatchNorm3d-6     [-1, 18, 72, 120, 120]              36
         MaxPool3d-7       [-1, 18, 36, 60, 60]               0
         DownBlock-8  [[-1, 18, 36, 60, 60], [-1, 18, 24, 120, 120]]               0
            Conv3d-9       [-1, 36, 36, 60, 60]          17,532
             ReLU-10       [-1, 36, 36, 60, 60]               0
      BatchNorm3d-11       [-1, 36, 36, 60, 60]              72
           Conv3d-12       [-1, 36, 36, 60, 60]          35,028
             ReLU-13       [-1, 36, 36, 60, 60]               0
      BatchNorm3d-14       [-1, 36, 36, 60, 60]              72
        MaxPool3d-15        [-1, 36, 18, 30, 30]              0
        DownBlock-16  [[-1, 36, 18, 30, 30], [-1, 36, 36, 60, 60]]               0
           Conv3d-17        [-1, 72, 18, 30, 30]         70,056
             ReLU-18        [-1, 72, 18, 30, 30]              0
      BatchNorm3d-19        [-1, 72, 18, 30, 30]            144
           Conv3d-20        [-1, 72, 18, 30, 30]        140,040
             ReLU-21        [-1, 72, 18, 30, 30]              0
      BatchNorm3d-22        [-1, 72, 18, 30, 30]            144
        DownBlock-23  [[-1, 72, 18, 30, 30], [-1, 72, 18, 30, 30]]               0
  ConvTranspose3d-24       [-1, 36, 36, 60, 60]          20,772
             ReLU-25       [-1, 36, 36, 60, 60]               0
      BatchNorm3d-26       [-1, 36, 36, 60, 60]              72
      Concatenate-27       [-1, 72, 36, 60, 60]               0
           Conv3d-28       [-1, 36, 36, 60, 60]          70,020
             ReLU-29       [-1, 36, 36, 60, 60]               0
      BatchNorm3d-30       [-1, 36, 36, 60, 60]              72
           Conv3d-31       [-1, 36, 36, 60, 60]          35,028
             ReLU-32       [-1, 36, 36, 60, 60]               0
      BatchNorm3d-33       [-1, 36, 36, 60, 60]              72
          UpBlock-34       [-1, 36, 36, 60, 60]               0
  ConvTranspose3d-35     [-1, 18, 72, 120, 120]           5,202
             ReLU-36     [-1, 18, 72, 120, 120]               0
      BatchNorm3d-37     [-1, 18, 72, 120, 120]              36
      Concatenate-38     [-1, 36, 72, 120, 120]               0
           Conv3d-39     [-1, 18, 72, 120, 120]          17,514
             ReLU-40     [-1, 18, 72, 120, 120]               0
      BatchNorm3d-41     [-1, 18, 72, 120, 120]              36
           Conv3d-42     [-1, 18, 72, 120, 120]           8,766
             ReLU-43     [-1, 18, 72, 120, 120]               0
      BatchNorm3d-44     [-1, 18, 72, 120, 120]              36
          UpBlock-45     [-1, 18, 72, 120, 120]               0
           Conv3d-46      [-1, 1, 72, 120, 120]              19
================================================================
Total params: 430,075
Trainable params: 430,075
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 1.32
Forward/backward pass size (MB): 5744.38
Params size (MB): 1.64
Estimated Total Size (MB): 5747.34
----------------------------------------------------------------
"""
      
      



72*

, (168, 120, 120), (72, 120, 120). , . , 2 , . 9 (1512, 120, 120) .. 9 , 21(batch size) (72, 120, 120). 72 , 24*().
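
The numbers in this paragraph follow from simple arithmetic: 168 slices × 9 augmented copies = 1512 slices, which split into 21 chunks of 72 (batch size 21) or 63 chunks of 24; 27 chunks of 56 is also consistent with the 27* and 56* in the table above. The sketch below uses hypothetical helpers `chunk_plan` and `usable_depths` to check a depth and to list depths that both divide 1512 and survive the poolings; with `n_blocks=3` there are 2 poolings, so the depth must be a multiple of 4.

```python
def chunk_plan(n_slices, n_copies, depth):
    """Total slice count after augmentation and the number of (depth, H, W) chunks it yields."""
    total = n_slices * n_copies
    if total % depth:
        raise ValueError(f"{total} slices do not split evenly into chunks of depth {depth}")
    return total, total // depth


def usable_depths(total, n_pools):
    """Chunk depths that split `total` evenly and can be halved n_pools times."""
    step = 2 ** n_pools
    return [d for d in range(step, total + 1, step) if total % d == 0]


print(chunk_plan(168, 9, 72))               # (1512, 21)
print(chunk_plan(168, 9, 24))               # (1512, 63)
print(chunk_plan(168, 9, 56))               # (1512, 27)
print(usable_depths(1512, n_pools=2)[:8])   # [4, 8, 12, 24, 28, 36, 56, 72]
```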





Exp. No. 7: 3D U-Net, volume feeding, plane [x, z], mask (left) and finished segmentation (right), optimized network parameters, training time (65 epochs) ~14 min.

, ( "" ). , . semantic segmentation , .





3D ( ) (1512, 120, 120) --> 21*(1, 72, 120, 120), ~*(30, 30, 30) ( ). 2 : 3- , ( ); , .





, 1 epochs "" ~13, 2 (>80). 1 epochs. , .





. 8 + . loss function .





training loop
import torch
from tqdm import tqdm
from _loss_f import LossFunction


class TrainFunction:
    def __init__(self,
                 data_loader,
                 device_for_training,
                 model_name,
                 model_name_pretrained,
                 model,
                 optimizer,
                 scale,
                 learning_rate: float = 1e-2,
                 num_epochs: int = 1,
                 transfer_learning: bool = False,
                 binary_loss_f: bool = True
                 ):
        self.data_loader = data_loader
        self.device = device_for_training
        self.model_name_pretrained = model_name_pretrained
        self.semantic_binary = binary_loss_f
        self.num_epochs = num_epochs
        self.model_name = model_name
        self.transfer = transfer_learning
        self.optimizer = optimizer
        self.learning_rate = learning_rate
        self.model = model
        self.scale = scale

    def forward(self):
        print('Running on the:', torch.cuda.get_device_name(self.device))
        if self.transfer:
            self.model.load_state_dict(torch.load(self.model_name_pretrained))
        optimizer = self.optimizer(self.model.parameters(), lr=self.learning_rate)
        for epoch in range(self.num_epochs):
            self.train_loop(self.data_loader, self.model, optimizer, self.scale, epoch)
            # checkpoint every 10 epochs
            if (epoch + 1) % 10 == 0:
                torch.save(self.model.state_dict(),
                           'models/' + self.model_name + str(epoch + 1) + '_epoch.pth')

    def train_loop(self, loader, model, optimizer, scales, i):
        loop, epoch_loss = tqdm(loader), 0
        loop.set_description('Epoch %i' % (self.num_epochs - i))
        for batch_idx, (data, targets) in enumerate(loop):
            data, targets = data.to(device=self.device, dtype=torch.float), \
                            targets.to(device=self.device, dtype=torch.long)
            optimizer.zero_grad()
            # mixed precision: forward pass and loss under autocast, backward through the GradScaler
            with torch.cuda.amp.autocast():
                predictions = model(data)
                loss = LossFunction(predictions, targets,
                                    device_for_training=self.device,
                                    semantic_binary=self.semantic_binary
                                    ).forward()
            scales.scale(loss).backward()
            scales.step(optimizer)
            scales.update()
            epoch_loss += (1 - loss.item())*100
            loop.set_postfix(loss=loss.item())
        print('Epoch-acc', round(epoch_loss / (batch_idx+1), 2))

      
      



4. Loss function

Dice-loss , '' [0, 1]. , ( [0, 1]), ( "" "" ) Dice-loss , .





categorical_dice_loss
import torch


class LossFunction:
    def __init__(self,
                 prediction,
                 target,
                 device_for_training,
                 semantic_binary: bool = True,
                 ):
        self.prediction = prediction
        self.device = device_for_training
        self.target = target
        self.semantic_binary = semantic_binary

    def forward(self):
        if self.semantic_binary:
            return self.dice_loss(self.prediction, self.target)
        return self.categorical_dice_loss(self.prediction, self.target)

    @staticmethod
    def dice_loss(predictions, targets, alpha=1e-5):
        intersection = 2. * (predictions * targets).sum()
        denomination = (torch.square(predictions) + torch.square(targets)).sum()
        dice_loss = 1 - torch.mean((intersection + alpha) / (denomination + alpha))

        return dice_loss

    def categorical_dice_loss(self, prediction, target):
        pr, tr = self.prepare_for_multiclass_loss_f(prediction, target)
        target_categories, losses = torch.unique(tr).tolist(), 0
        for num_category in target_categories:
            categorical_target = torch.where(tr == num_category, 1, 0)
            categorical_prediction = pr[num_category]  # prediction channel of this class
            losses += self.dice_loss(categorical_prediction, categorical_target).to(self.device)

        return losses / len(target_categories)

    @staticmethod
    def prepare_for_multiclass_loss_f(prediction, target):
        # assumes batch size 1: (1, C, D, H, W) -> (C, D, H, W) and (1, 1, D, H, W) -> (D, H, W)
        prediction_prepared = torch.squeeze(prediction, 0)
        target_prepared = torch.squeeze(target, 0)
        target_prepared = torch.squeeze(target_prepared, 0)

        return prediction_prepared, target_prepared

      
      



, "categorical_dice_loss":





  • ( );





  • , batch ;





  • "" "" , [0, 1] Dice-loss;





  • , batct. .
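
For checking numbers by hand, here is a NumPy transcription of the same formulas: per class, 1 - (2·Σ p·t + α) / (Σ p² + Σ t² + α), averaged over the classes present in the target. It assumes the predictions are already in [0, 1] (where the activation is applied is not shown in the text); it is a sketch, not a drop-in replacement for the PyTorch version.

```python
import numpy as np


def dice_loss(pred, target, alpha=1e-5):
    """Soft Dice loss with squared terms in the denominator, as in LossFunction.dice_loss."""
    intersection = 2.0 * np.sum(pred * target)
    denominator = np.sum(pred ** 2) + np.sum(target ** 2)
    return 1.0 - (intersection + alpha) / (denominator + alpha)


def categorical_dice_loss(pred, target):
    """pred: (C, D, H, W) scores in [0, 1]; target: (D, H, W) integer labels.
    One binary Dice loss per class present in the target, then the average."""
    classes = np.unique(target)
    losses = [dice_loss(pred[c], (target == c).astype(pred.dtype)) for c in classes]
    return sum(losses) / len(classes)


target = np.array([[[0, 1], [1, 2]]])                                 # (D, H, W) = (1, 2, 2)
perfect = np.stack([(target == c).astype(float) for c in range(3)])  # one-hot, (3, 1, 2, 2)
print(categorical_dice_loss(perfect, target))                        # 0.0
uncertain = np.full(4, 0.5)
print(round(dice_loss(uncertain, np.array([1.0, 1.0, 0.0, 0.0])), 4))  # 0.3333
```

The training loop reports `(1 - loss) * 100` as "Epoch-acc", so that figure is the mean Dice score in percent.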





, , one-hot , ( ), , . , , , . (5).





5.

".. ". *.nrrd .





import nrrd
import numpy as np

data_path = 'some_path.nrrd'
# read the volume (as a numpy array) together with its header
data, meta_data = nrrd.read(data_path)

print(data.shape, np.max(data), np.min(data), meta_data, sep="\n")

# Output:
# (163, 112, 120)
# 14982
# -2254
# OrderedDict([('type', 'short'), ('dimension', 3), ('space', 'left-posterior-superior'), ('sizes', array([163, 112, 120])), ('space directions', array([[-0.5,  0. ,  0. ],
#        [ 0. , -0.5,  0. ],
#        [ 0. ,  0. ,  0.5]])), ('kinds', ['domain', 'domain', 'domain']), ('endian', 'little'), ('encoding', 'gzip'), ('space origin', array([131.57200623,  80.7661972 ,  32.29940033]))])
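
The header above carries the geometry behind the "1 px = 0.5 mm" statement. A small NumPy sketch, with values copied from the printed header and a hypothetical helper `voxel_to_lps`, derives the voxel spacing, the physical extent of the crop and the LPS coordinate of a voxel. It assumes, as in pynrrd's layout, that row i of 'space directions' is the step vector of array axis i.

```python
import numpy as np

# Values copied from the header printed above
sizes = np.array([163, 112, 120])
space_directions = np.array([[-0.5, 0.0, 0.0],
                             [0.0, -0.5, 0.0],
                             [0.0, 0.0, 0.5]])
space_origin = np.array([131.57200623, 80.7661972, 32.29940033])

# Row i is the step (in mm, LPS space) taken when array index i grows by one
spacing = np.linalg.norm(space_directions, axis=1)   # 0.5 mm on every axis
extent_mm = sizes * spacing                           # crop size: 81.5 x 56 x 60 mm


def voxel_to_lps(index):
    """Voxel index (i, j, k) -> left-posterior-superior coordinate in mm."""
    return space_origin + np.asarray(index) @ space_directions


print(spacing, extent_mm)
print(voxel_to_lps((0, 0, 0)), voxel_to_lps((162, 111, 119)))
```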
      
      



- , ? , , , .





, 8 12 . ( ) - ( 3- ) . , , "" -1 , ..





It looks as crazy as it sounds

- , . , . Skimage Stl.





from skimage.measure import marching_cubes
import nrrd
import numpy as np
from stl import mesh

path = 'some_path.nrrd'
data = nrrd.read(path)[0]


def three_d_creator(some_data):
    # marching_cubes returns vertices, faces, normals and values; only the first two are needed
    vertices, faces, _, _ = marching_cubes(some_data)
    # one empty record per triangle
    cube = mesh.Mesh(np.zeros(faces.shape[0], dtype=mesh.Mesh.dtype))
    for i, f in enumerate(faces):
        for j in range(3):
            cube.vectors[i][j] = vertices[f[j]]
    cube.save('name.stl')

    return cube


stl_mesh = three_d_creator(data)
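
The double loop in `three_d_creator` visits every triangle in Python. NumPy fancy indexing gathers the same corner coordinates in one operation; the sketch below (hypothetical helper names, a tetrahedron standing in for the `marching_cubes` output) checks that both give the same array.

```python
import numpy as np


def faces_to_triangles(vertices, faces):
    """(V, 3) vertex coordinates + (F, 3) vertex indices -> (F, 3, 3) triangle corners."""
    return vertices[faces]


def faces_to_triangles_loop(vertices, faces):
    """The same result with the explicit double loop used in three_d_creator."""
    out = np.zeros((len(faces), 3, 3), dtype=vertices.dtype)
    for i, f in enumerate(faces):
        for j in range(3):
            out[i][j] = vertices[f[j]]
    return out


# A tetrahedron as a tiny stand-in for the output of marching_cubes
vertices = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.float32)
faces = np.array([[0, 1, 2], [0, 1, 3], [0, 2, 3], [1, 2, 3]])
print(faces_to_triangles(vertices, faces).shape)  # (4, 3, 3)
```

With numpy-stl, the gathered array could then fill the mesh directly, for example `data = np.zeros(len(faces), dtype=mesh.Mesh.dtype)`, `data['vectors'] = vertices[faces]`, `mesh.Mesh(data)`; this is an untested suggestion, not the author's code.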
      
      



, "" . , , Win 10 3D Builder - . "" 3D . " " .





v3do. , , .





npy stl
from vedo import Volume, show, write
import numpy as np

# the saved prediction: a label volume with 9 classes
prediction = np.load('some_data_path.npy')


def show_save(data, save=False):
    data_multiclass = Volume(data, c='Set2', alpha=(0.1, 1), alphaUnit=0.87, mode=1)
    data_multiclass.addScalarBar3D(nlabels=9)
    show([(data_multiclass, "Multiclass teeth segmentation prediction")], bg='black', N=1, axes=1).close()
    if save:
        # extract an isosurface of the volume and write it as *.stl
        write(data_multiclass.isosurface(), 'some_name_.stl')


show_save(prediction, save=True)
      
      



.





. :





model.summary()
model = UNet(dim=3, in_channels=1, out_channels=9, n_blocks=3, start_filters=9).to(device)
print(summary(model, (1, 168, 120, 120)))
    
"""
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv3d-1      [-1, 9, 168, 120, 120]            252
              ReLU-2      [-1, 9, 168, 120, 120]              0
       BatchNorm3d-3      [-1, 9, 168, 120, 120]             18
            Conv3d-4      [-1, 9, 168, 120, 120]          2,196
              ReLU-5      [-1, 9, 168, 120, 120]              0
       BatchNorm3d-6      [-1, 9, 168, 120, 120]             18
         MaxPool3d-7        [-1, 9, 84, 60, 60]               0
         DownBlock-8  [[-1, 9, 84, 60, 60], [-1, 9, 168, 120, 120]]               0
            Conv3d-9       [-1, 18, 84, 60, 60]           4,392
             ReLU-10       [-1, 18, 84, 60, 60]               0
      BatchNorm3d-11       [-1, 18, 84, 60, 60]              36
           Conv3d-12       [-1, 18, 84, 60, 60]           8,766
             ReLU-13       [-1, 18, 84, 60, 60]               0
      BatchNorm3d-14       [-1, 18, 84, 60, 60]              36
        MaxPool3d-15       [-1, 18, 42, 30, 30]               0
        DownBlock-16  [[-1, 18, 18, 42, 30], [-1, 18, 84, 60, 60]]               0
           Conv3d-17       [-1, 36, 42, 30, 30]          17,532
             ReLU-18       [-1, 36, 42, 30, 30]               0
      BatchNorm3d-19       [-1, 36, 42, 30, 30]              72
           Conv3d-20       [-1, 36, 42, 30, 30]          35,028
             ReLU-21       [-1, 36, 42, 30, 30]               0
      BatchNorm3d-22       [-1, 36, 42, 30, 30]              72
        DownBlock-23  [[-1, 36, 42, 30, 30], [-1, 36, 42, 30, 30]]               0
  ConvTranspose3d-24       [-1, 18, 84, 60, 60]           5,202
             ReLU-25       [-1, 18, 84, 60, 60]               0
      BatchNorm3d-26       [-1, 18, 84, 60, 60]              36
      Concatenate-27       [-1, 36, 84, 60, 60]               0
           Conv3d-28       [-1, 18, 84, 60, 60]          17,514
             ReLU-29       [-1, 18, 84, 60, 60]               0
      BatchNorm3d-30       [-1, 18, 84, 60, 60]              36
           Conv3d-31       [-1, 18, 84, 60, 60]           8,766
             ReLU-32       [-1, 18, 84, 60, 60]               0
      BatchNorm3d-33       [-1, 18, 84, 60, 60]              36
          UpBlock-34       [-1, 18, 84, 60, 60]               0
  ConvTranspose3d-35      [-1, 9, 168, 120, 120]          1,305
             ReLU-36      [-1, 9, 168, 120, 120]              0
      BatchNorm3d-37      [-1, 9, 168, 120, 120]             18
      Concatenate-38     [-1, 18, 168, 120, 120]              0
           Conv3d-39      [-1, 9, 168, 120, 120]          4,383
             ReLU-40      [-1, 9, 168, 120, 120]              0
      BatchNorm3d-41      [-1, 9, 168, 120, 120]             18
           Conv3d-42      [-1, 9, 168, 120, 120]          2,196
             ReLU-43      [-1, 9, 168, 120, 120]              0
      BatchNorm3d-44      [-1, 9, 168, 120, 120]             18
          UpBlock-45      [-1, 9, 168, 120, 120]              0
           Conv3d-46      [-1, 9, 168, 120, 120]             90
================================================================
Total params: 108,036
Trainable params: 108,036
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 3.96
Forward/backward pass size (MB): 12170.30
Params size (MB): 0.41
Estimated Total Size (MB): 12174.66
----------------------------------------------------------------
    """
      
      



* ([9, 18, 36, 72]), - 9*(168, 120, 120)





Exp. No. 8: Intermediate segmentation into 8 categories

, , . ? - "" 8- , . , 12 (GPU) .





Exp. No. 9: Full segmentation

6. Afterword

, , - . . , , 2 , . , ? , , 28 , , "" / ? U-net GCNN Pytorch - Pytorch3D? , , bounding box( 1 ). , , .





An example of an undirected graph for 28 categories with "separators"

Special thanks to my wife, Alena, for her support during this "dive into the dark".





Thank you all for your attention. Constructive criticism and suggestions are welcome, both corrections and ideas for new projects.







