import torch
from torch import nn


class SelfAttention(nn.Module):
    """Self-attention over spatial positions, with a learned residual gate."""

    def __init__(self, in_channels):
        super(SelfAttention, self).__init__()
        # 1x1 convolutions project the feature map into query/key/value spaces.
        self.query = nn.Conv2d(in_channels, in_channels // 8, 1)
        self.key = nn.Conv2d(in_channels, in_channels // 8, 1)
        self.value = nn.Conv2d(in_channels, in_channels, 1)
        # gamma starts at zero, so the block initially behaves as an identity mapping.
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        batch_size, C, H, W = x.size()
        # Flatten the spatial dimensions: q is (B, H*W, C//8), k is (B, C//8, H*W).
        q = self.query(x).view(batch_size, -1, H * W).permute(0, 2, 1)
        k = self.key(x).view(batch_size, -1, H * W)
        v = self.value(x).view(batch_size, -1, H * W)
        # (B, H*W, H*W) attention map over all pairs of spatial positions.
        attention = torch.bmm(q, k)
        attention = torch.softmax(attention, dim=-1)
        out = torch.bmm(v, attention.permute(0, 2, 1))
        out = out.view(batch_size, C, H, W)
        # Gated residual connection.
        return self.gamma * out + x

class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers with an identity skip connection."""

    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        residual = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += residual
        out = self.relu(out)
        return out

class aeModel(nn.Module):
    """Convolutional autoencoder: four stride-2 stages down to 1/16 resolution and back,
    with residual blocks throughout and self-attention at the two deepest scales."""

    def __init__(self):
        super(aeModel, self).__init__()
        # Encoder: 3 -> 32 -> 64 -> 128 -> 256 channels, halving H and W at each stage.
        self.encoder = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(3, 32, 3, stride=2, padding=1),
                nn.BatchNorm2d(32),
                nn.ReLU(),
                ResidualBlock(32)
            ),
            nn.Sequential(
                nn.Conv2d(32, 64, 3, stride=2, padding=1),
                nn.BatchNorm2d(64),
                nn.ReLU(),
                ResidualBlock(64)
            ),
            nn.Sequential(
                nn.Conv2d(64, 128, 3, stride=2, padding=1),
                nn.BatchNorm2d(128),
                nn.ReLU(),
                ResidualBlock(128),
                SelfAttention(128)
            ),
            nn.Sequential(
                nn.Conv2d(128, 256, 3, stride=2, padding=1),
                nn.BatchNorm2d(256),
                nn.ReLU(),
                ResidualBlock(256),
                SelfAttention(256)
            )
        ])
        # Decoder mirrors the encoder with transposed convolutions; the final
        # Sigmoid maps the reconstruction into [0, 1].
        self.decoder = nn.ModuleList([
            nn.Sequential(
                nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1),
                nn.BatchNorm2d(128),
                nn.ReLU(),
                ResidualBlock(128),
                SelfAttention(128)
            ),
            nn.Sequential(
                nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),
                nn.BatchNorm2d(64),
                nn.ReLU(),
                ResidualBlock(64)
            ),
            nn.Sequential(
                nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),
                nn.BatchNorm2d(32),
                nn.ReLU(),
                ResidualBlock(32)
            ),
            nn.Sequential(
                nn.ConvTranspose2d(32, 3, 3, stride=2, padding=1, output_padding=1),
                nn.Sigmoid()
            )
        ])
    def forward(self, x):
        for encoder_block in self.encoder:
            x = encoder_block(x)
        for decoder_block in self.decoder:
            x = decoder_block(x)
        return x

    def encode(self, x):
        for encoder_block in self.encoder:
            x = encoder_block(x)
        return x

    def decode(self, x):
        for decoder_block in self.decoder:
            x = decoder_block(x)
        return x
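

# A minimal smoke-test sketch, not part of the original listing: it assumes
# 3-channel inputs whose height and width are multiples of 16 so the four
# stride-2 stages round-trip exactly; the 64x64 size below is illustrative.
if __name__ == "__main__":
    model = aeModel()
    dummy = torch.rand(2, 3, 64, 64)      # hypothetical batch of 2 RGB images
    latent = model.encode(dummy)          # expected shape: (2, 256, 4, 4)
    recon = model.decode(latent)          # expected shape: (2, 3, 64, 64)
    full = model(dummy)                   # full forward pass (encode then decode)
    print(latent.shape, recon.shape, full.shape)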