Неинтересная цель этой статьи — показать, как можно смержить две свертки пайторча в одну. Если интересна лишь реализация — прошу в конец статьи.

А интересная цель — потыкать непосредственно в веса моделей на примере объединения свёрток. Узнать, как они хранятся и используются конкретно в pytorch, не вдаваясь в хардкорные интересности по типу im2col. Но перед тем, как показывать реализацию, давайте немного вспомним, с чем работаем.

Начнём с терминологии. Нейронная сеть, или граф вычислений — это набор операций (слоёв), которые применяются над входным объектом или над выходами других операций. Нас интересует зрение и свёртки, поэтому входным объектом будет картинка: тензор размера [кол-во каналов X высота X ширина], где количество каналов чаще всего 3 (RGB). Промежуточные выходы будут называться фичемапой или картой признаков. Фичемапа — это в каком-то смысле та же картинка, только количество каналов произвольное, а каждая ячейка тензора называется признаком, а не пикселем. Ну а чтобы уметь за один проход прогонять сразу несколько картинок, у всех этих тензоров появляется дополнительное измерение размером в кол-во обьектов в батче (одном прогоне).

Теперь перейдем к кросс-корреляции. Часто ее путают со сверткой, но для наших целей эти две операции надо различать. Что же такое кросс-корреляция в контексте нейронных сетей? Перво-наперво надо сказать, что кросс-корреляция имеет ядро. Ядро — это тензор с числами размера [кол-во каналов X высота ядра X ширина ядра]. Ну а операция кросс-корреляции — это скольжение ядром по входной фичемапе как на гифке. Каждый раз, когда мы прикладываем ядро к кусочку входа, мы перемножаем соответствующие веса ядра с признаками фичемапы и складываем, получая новый признак в выходном канале.

my alt text
wikipedia.org

Самое время рассказать, что такое свёртка. Свёртка — это почти кросс-корреляция, только ядро отражено относительно горизонтальной и вертикальной оси. Эту разницу часто опускают, потому что для понимания работы сеток это не важно, но для нас эта разница важна, так как мы будем преобразовывать непосредственно веса.

Вы, наверное, заметили, что описанная выше свёртка из множества каналов создает 1. В нейронках же свертки создают из множества каналов другое множество. Достигается это тем, что в слое свертки хранятся сразу несколько описанных выше ядер, каждое из которых порождает один канал, а затем они конкатенируются. Итоговая размерность ядра такой свёртки получается: [количество каналов на выход X количество каналов на вход X высота ядра X ширина ядра]

Следующая особенность свертки в pytorch и во многих других фреймворках — это биас, или смещение. Биас это буквально набор весов по количеству каналов, которые прибавляются к выходу свертки. То есть к каждому признаку канала прибавляется константа.

Самое время вспомнить нашу задачу. Задача — объединить две свёртки в одну. Почему это вообще возможно? Краткий ответ такой — что свёртка, что прибавление биаса — это линейные операции. А комбинация линейных операций — линейная операция. Более подробно я предлагаю разбираться ниже, смотря непосредственно в код.

Начнем с самого простого случая. Одна свёртка произвольного размера, вторая 1х1. Ни у одной нет биаса. Обе преобразуют один канал в один канал. Простым этот случай делает свертка 1х1. Фактически, она означает домножение фичемапы на константу. А значит, можно просто домножить веса первой свёртки на эту самую константу.

import torch
import numpy as np
conv1 = torch.nn.Conv2d(1, 1, (3, 3), bias=False) # [каналы на вход X на выход X размер ядра]
conv2 = torch.nn.Conv2d(1, 1, (1, 1), bias=False) # свёртка 1х1. Если посмотреть на ее веса, там будет ровно одно число.

new_conv = torch.nn.Conv2d(1, 1, 3, bias=False) # Вот эта свёртка объеденит две.
new_conv.weight.data = conv1.weight.data * conv2.weight.data 

# проверим что все сработало
x = torch.randn([1, 1, 6, 6]) # [Размер батча X кол-во каналов X размер по вертикали X размер по горизонтали]
out = conv2(conv1(x))
new_out = new_conv(x)
assert (torch.abs(out - new_out) < 1e-6).min()

Усложним. Пусть теперь первая свертка преобразует произвольное количество каналов в другое произвольное количество. В этом случае наша свёртка 1х1 будет взвешенной суммой каналов промежуточной фичемапы. А это значит, что можно просто взвешенно сложить веса, порождающие эти самые каналы.

conv1 = torch.nn.Conv2d(2, 3, (3, 3), bias=False) # [каналы на вход X на выход X размер ядра]
conv2 = torch.nn.Conv2d(3, 1, (1, 1), bias=False) # свёртка 1х1. Если посмотреть на ее веса, то уже будет 3 числа.

new_conv = torch.nn.Conv2d(2, 1, 3, bias=False) # В итоге мы хотим из двух каналов получить 1.

# Веса свёртки — тензор размера: [каналы на выход X каналы на вход X размер ядра1 X размер ядра2]
# Чтобы домножить по первому измерению веса, отвечающие за создание каждого промежуточного канала с их весом,
# нам придется пермьютнуть размерности второй второй свертки.
# Затем суммируем взвешенные веса и дорисовываем измерение, дабы успешно подменить веса.
new_conv.weight.data = (conv1.weight.data * conv2.weight.data.permute(1, 0, 2, 3)).sum(0)[None, ]

x = torch.randn([1, 2, 6, 6]) 
out = conv2(conv1(x))
new_out = new_conv(x)
assert (torch.abs(out - new_out) < 1e-6).min()

Еще усложним. Пусть теперь обе наши свертки преобразуют произвольное количество каналов в другое произвольное количество. В этом случае наша свёртка 1х1 будет набором взвешенных сумм каналов промежуточной фичемапы. Логика тут та же. Нужно взвесить веса, порождающие промежуточные фичемапы, с весами из второй свертки.

conv1 = torch.nn.Conv2d(2, 3, 3, bias=False)
conv2 = torch.nn.Conv2d(3, 5, 1, bias=False)

new_conv = torch.nn.Conv2d(1, 5, 3, bias=False) # Забавный факт:
                                                # Не важно какие размеры передавать при инициализации.
                                                # Подмена весов все исправит.

# Магия бродкаста в действии. Можно было сделать так и в предыдущем примере, но что-то же должно улучшаться.
new_conv.weight.data = (conv2.weight.data[..., None] * conv1.weight.data[None]).sum(1)

x = torch.randn([1, 2, 6, 6])
out = conv2(conv1(x))
new_out = new_conv(x)
assert (torch.abs(out - new_out) < 1e-6).min()

Пришло время отказаться от ограничения на размер второй свертки. Давайте разбираться, что происходит в этом случае. Для наглядности посмотрим на одномерную свертку. Как будет выглядеть операция k с ядром 2? $$k_1 * x_1 + k_2 * x_2$$ Добавим сюда вторую свертку v, снова с ядром 2: $$v_1 * (k_1 * x_i + k_2 * x_{i+1}) + v_2 * (k_1 * x_{i+1} + k_2 * x_{i+2})$$ Раскроем скобки, объединив вокруг x: $$x_i * v_1 * k_1 + x_{i+1} * (v_1 * k_2 + v_2 * k_1) + x_{i+1} * (k_2 * v_2)$$ Добавим лишних нулей $$x_i * (v_1 * k_1 + 0 * v_2)+ x_{i+1} * (v_1 * k_2 + v_2 * k_1) + x_{i+1} * (0 * v_1 + k_2 * v_2)$$ Что у нас получилось? Получилась свертка с ядром 3. А коэффициенты у этой свертки — результат кросс-корреляции весов первой свертки после паддинга весами второй свертки. Собственно, для других размеров ядер, двумерного случая, многоканального случая работает та же логика. Но останавливаться на выводе для более сложных случаев я не буду, вместо этого покажу как все это кодить.

kernel_size_1 = np.array([3, 3])
kernel_size_2 = np.array([3, 5])
kernel_size_merged = kernel_size_1 + kernel_size_2 - 1

conv1 = torch.nn.Conv2d(2, 3, kernel_size_1, bias=False)
conv2 = torch.nn.Conv2d(3, 5, kernel_size_2, bias=False)


new_conv = torch.nn.Conv2d(2, 5, kernel_size_merged, bias=False)

# Считаем сколько нулей надо дорисовать вокруг первой свёртки
# Паддинг это сколько нулей надо нарисовать вокруг входной фичимапы по вертикали и горизонтали
padding = [kernel_size_2[0]-1, kernel_size_2[1]-1]

new_conv.weight.data = torch.conv2d(conv1.weight.data.permute(1, 0, 2, 3), # Уже знакомый нам трюк. Мотивация та же.
                                    conv2.weight.data.flip(-1, -2), # А вот это нужно чтобы свертку преваритить в кросс-корреляцию.
                                    padding=padding).permute(1, 0, 2, 3)


x = torch.randn([1, 2, 9, 9])
out = conv2(conv1(x))
new_out = new_conv(x)
assert (torch.abs(out - new_out) < 1e-6).min()

Теперь добавим биасы. Начнем с наличия смещения у второй свертки. Напомню, что биас — это вектор размером с кол-во выходных каналов, который потом прибавляется к выходу свертки. Ну а значит, нам надо просто приравнять биас нашей новой свертки к биасу второй.

kernel_size_1 = np.array([3, 3])
kernel_size_2 = np.array([3, 5])
kernel_size_merged = kernel_size_1 + kernel_size_2 - 1

conv1 = torch.nn.Conv2d(2, 3, kernel_size_1, bias=False)
conv2 = torch.nn.Conv2d(3, 5, kernel_size_2, bias=True)


x = torch.randn([1, 2, 9, 9])
out = conv2(conv1(x))

new_conv = torch.nn.Conv2d(2, 5, kernel_size_merged, bias=True)
padding = [kernel_size_2[0]-1, kernel_size_2[1]-1]

new_conv.weight.data = torch.conv2d(conv1.weight.data.permute(1, 0, 2, 3),
                                    conv2.weight.data.flip(-1, -2),
                                    padding=padding).permute(1, 0, 2, 3)

new_conv.bias.data = conv2.bias.data # вот тут интересность

new_out = new_conv(x)
assert (torch.abs(out - new_out) < 1e-6).min()

С биасом первой свёртки чуть сложнее. Поэтому мы пойдем в два этапа. Для начала заметим, что применение биаса у свертки эквивалентно созданию дополнительной фичимапы, в которой признаки у каждого канала будут константы и равны параметрам смещения. Затем прибавим эту фичимапу к выходу свертки.

kernel_size_1 = np.array([3, 3])
kernel_size_2 = np.array([3, 5])
kernel_size_merged = kernel_size_1 + kernel_size_2 - 1

conv1 = torch.nn.Conv2d(2, 3, kernel_size_1, bias=True)
conv2 = torch.nn.Conv2d(3, 5, kernel_size_2, bias=False)


x = torch.randn([1, 2, 9, 9])
out = conv2(conv1(x))



new_conv = torch.nn.Conv2d(2, 5, kernel_size_merged, bias=False)
padding = [kernel_size_2[0]-1, kernel_size_2[1]-1]

new_conv.weight.data = torch.conv2d(conv1.weight.data.permute(1, 0, 2, 3),
                                    conv2.weight.data.flip(-1, -2),
                                    padding=padding).permute(1, 0, 2, 3)


new_out = new_conv(x)

add_x = torch.ones(1, 3, 7, 7) * conv1.bias.data[None, :, None, None] # Тут делается доп фичимапа
new_out += conv2(add_x) 

assert (torch.abs(out - new_out) < 1e-6).min()

Но нам ведь не хочется каждый раз создавать эту лишнию фичепаму, нам хочется как-то изменить параметры свертки. И мы можем это сделать. Достаточно лишь заметить, что после применения свёртки над такой константной фичемапой получится еще одна константная фичемапа. Мотивация простая — куда бы мы не приложили окошко, срез получится одинаковый. Значит, нам достаточно свернуть такую фичемапу всего один раз, посмотреть, какие признаки получились на каждом канале, и записать их в биас.

kernel_size_1 = np.array([3, 3])
kernel_size_2 = np.array([3, 5])
kernel_size_merged = kernel_size_1 + kernel_size_2 - 1

conv1 = torch.nn.Conv2d(2, 3, kernel_size_1, bias=True)
conv2 = torch.nn.Conv2d(3, 5, kernel_size_2, bias=True)


x = torch.randn([1, 2, 9, 9])
out = conv2(conv1(x))



new_conv = torch.nn.Conv2d(2, 5, kernel_size_merged)
padding = [kernel_size_2[0]-1, kernel_size_2[1]-1]

new_conv.weight.data = torch.conv2d(conv1.weight.data.permute(1, 0, 2, 3),
                                    conv2.weight.data.flip(-1, -2),
                                    padding=padding).permute(1, 0, 2, 3)


add_x = torch.ones(1, 3, *kernel_size_2) * conv1.bias.data[None, :, None, None]

# Эта операция одновременно переносит биас из первой свертки и добавляет биас второй.
new_conv.bias.data = conv2(add_x).flatten()
new_out = new_conv(x)

assert (torch.abs(out - new_out) < 1e-6).min()

Осталось просто обернуть это все в функцию, убрав волшебные константы для размерностей.

import torch
import numpy as np

def merge_two_conv(conv1, conv2):
    kernel_size_1 = np.array(conv1.weight.size()[-2:])
    kernel_size_2 = np.array(conv2.weight.size()[-2:])
    kernel_size_merged = kernel_size_1 + kernel_size_2 - 1
    
    
    in_channels = conv1.weight.size()[1]
    out_channels = conv2.weight.size()[0]
    inner_channels = conv1.weight.size()[0]

    new_conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size_merged)
    padding = [kernel_size_2[0]-1, kernel_size_2[1]-1]

    new_conv.weight.data = torch.conv2d(conv1.weight.data.permute(1, 0, 2, 3),
                                        conv2.weight.data.flip(-1, -2),
                                        padding=padding).permute(1, 0, 2, 3)


    add_x = torch.ones(1, inner_channels, *kernel_size_2)
    add_x *= conv1.bias.data[None, :, None, None]
    
    new_conv.bias.data = torch.conv2d(add_x,
                                      conv2.weight.data).flatten()
        
    new_conv.bias.data += conv2.bias.data
    return new_conv

conv1 = torch.nn.Conv2d(2, 3, 3)
conv2 = torch.nn.Conv2d(3, 5, (4, 5))
new_conv = merge_two_conv(conv1, conv2)

x = torch.randn([1, 2, 9, 9])

assert (torch.abs(conv2(conv1(x)) - new_conv(x)) < 1e-6).min()

В конце статьи хочу сказать, что наша функция будет справляться не с каждой сверткой. В этой реализации мы не учли паддинга, диляции, страйдов, групп. Но тем не менее, я показал все, что хотел. Разбирать дальше, наверное, мало смысла, так как задачи по сворачиванию нескольких сверток в одну я видел лишь раз, (тут)., и то там вторая свертка была 1x1.