fasttransform:让可逆管道变得简单

隆重推出 fasttransform,这是一个 Python 库,它通过多重分派(multiple dispatch)的力量,使数据转换可逆且可扩展。
作者

Rens Dimmendaal, Hamel Husain, & Jeremy Howard

发布日期

2025 年 2 月 20 日

fasttransform:让可逆管道变得简单

隆重推出 fasttransform,这是一个 Python 库,它通过多重分派(multiple dispatch)的力量,使数据转换可逆且可扩展。

“这张图片是怎样被错误分类的?”

如果你曾训练过机器学习模型,你就知道接下来会发生什么:试图理解你的模型实际看到了什么,这是一个令人沮丧的过程。你需要深入挖掘层层转换——归一化、缩放、数据增强——结果却发现需要编写逆函数才能再次看到你的数据。这太痛苦了,以至于我们很多人干脆跳过这一步,基于抽象数字而不是实际数据来调试模型。

或者引用 OpenAI 的 Greg Brockman 的话来说

Greg Brockman tweet: “Manual inspection of data has probably the highest value-to-prestige ratio of any activity in machine learning.”

Greg Brockman 的推文:“在机器学习的各种活动中,人工检查数据大概是价值-声望比最高的。”

让我们看看你可能会错过什么。下面是一个使用 fastai 的简单例子

from fastai.vision.all import *
dls = ImageDataLoaders.from_folder(
    Path("./huskies_vs_wolves/"), 
    item_tfms=RandomResizedCrop(128, min_scale=0.35), 
    batch_tfms=Normalize.from_stats(*imagenet_stats)
)
dls.show_batch()  # One line to see our data

show_batch makes it easy to take a look at your data after has been transformed

show_batch 使查看数据在转换后的样子变得很容易
learn = Learner(dls, xresnet34(n_out=2), metrics=accuracy)
learn.fit_one_cycle(5, 0.0015)
learn.show_results()  # One line to see predictions

show_results lets you inspect your predictions immediately after training your model. The prediction labels (0/1) are also automatically transformed back to their string representation.

show_results 允许你在模型训练完成后立即检查预测结果。预测标签 (0/1) 也会自动转换回其字符串表示形式。
# Two lines to see the model's biggest mistakes
interp = Interpretation.from_learner(learn)
interp.plot_top_losses(9)

plot_top_losses visualizes where the model is “most confidently wrong” which teaches us about the most glaring issues.

plot_top_losses 可视化了模型“最自信地犯错”的地方,这向我们揭示了最明显的问题。

仅用这四行代码,我们就发现了一个有趣的事实:我们的“狼检测器”根本不是在检测狼——它是在检测雪!看看训练数据:雪中的狼,森林中的哈士奇。再看看预测结果:模型在背景翻转时就失败了。如果不能轻松地可视化我们的数据,我们可能永远无法发现这个明显的缺陷。

The LIME technique visualizes how the model focuses on snowy backgrounds to make its predictions

LIME 技术可视化了模型如何聚焦于雪景背景进行预测

虽然像 LIME1 这样复杂的解释性技术可以很好地可视化你的模型正在聚焦于图像的哪些部分(如上所示),但通常最有价值的洞察来自于简单地用自己的眼睛查看数据。在这个例子中,快速的视觉检查同样揭示了一个明显的数据集偏差。

fastai 是如何做到这一点的?它使用了 Transform ——这是一个看似简单却功能强大的想法,一直隐藏在 fastcore 的代码库中。今天,我们很高兴地宣布,我们已将其移至自己的库:fasttransform,因为我们相信它的应用可能超越机器学习。

无论你处理的是图像、文本、时间序列还是其他需要处理的数据,fasttransform 都提供一个简单的承诺:如果你能以某种方式转换数据,那么就应该能够同样轻松地将其转换回来。不再需要编写逆函数,不再丢失对数据的可见性。

让我们看看它是如何工作的。

问题 #1:单向转换

你是否曾试图通过查看数据来调试机器学习管道?通常情况是这样的:

  1. 加载数据
  2. 应用一些转换
  3. 试图找出哪里出错了
  4. 意识到你实际上看不到模型看到的东西
  5. 花接下来的一个小时编写逆函数
  6. 放弃并转而使用 print 语句调试

让我们用一个简单的例子来具体说明:使用 PyTorch 归一化一张图片

from torchvision import transforms as T
transforms_pt = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(*imagenet_stats)
])

# Load and transform an image
img = Image.open("./huskies_vs_wolves/train/husky/husky_0.jpeg")
img_transformed = transforms_pt(img)

# Try to look at what we did...
show_image(img_transformed);
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-2.1007793..2.2489083].

归一化是一个关键的预处理步骤,它将像素值缩放到相似的范围(通常是均值=0、标准差=1),这有助于神经网络更有效地训练。

然而,归一化后的图片并不适合人类眼睛进行检查。要解决这个问题,我们需要手动编写一个逆转换

def decode_pt(tensor, mean, std):
    """Decode a normalized PyTorch tensor back to RGB range"""
    out = tensor.clone()  # Clone to avoid modifying original
    for t, m, s in zip(out, mean, std): t.mul_(s).add_(m)  # Denormalize
    out = out.mul(255).clamp(0, 255).byte() # Scale back to RGB
    return out

img_decoded = decode_pt(img_transformed, *imagenet_stats)
show_image(img_decoded);

这并非一个冷门问题。多年来,这一直是许多机器学习实践者的痛点。

Question on Pytorch Discourse with discussion from 2017 to 2022 Google results overview of people trying to undo a Normalize in Pytorch

这还只是简单的归一化。在实际项目中,你可能需要处理:- 需要与图像同步转换的分割掩码 - 需要分词、填充和特殊 token 的文本数据 - 需要滑动窗口、归一化和编码的时间序列数据

每一次转换都增加了一层需要解开的复杂性。最糟糕的是:因为查看转换后的数据如此痛苦,我们很多人干脆……不看了。我们最终基于抽象数字而不是实际数据来调试模型,只是希望我们的转换正在按我们预期的那样工作。

还记得在我们的 fastai 示例中查看模型究竟看到什么有多容易吗?那不是魔术——那是可逆转换的力量。让我们看看 fasttransform 是如何实现这一点的。

更好的方法:可逆管道

以下是 fastai 如何处理与上一节 PyTorch 示例相同的管道

from fastai.vision.all import *

transforms_ft = Pipeline([
   PILImage.create,
   Resize(256,method="squish"),
   Resize(224,method="crop"),
   ToTensor(),
   IntToFloatTensor(),
   Normalize.from_stats(*imagenet_stats)
])

# Transform our image
fpath = Path("./huskies_vs_wolves/train/husky/husky_0.jpeg")
img_transformed = transforms_ft(fpath)
show_image(img_transformed[0]);  # Still looks wrong...
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-2.0836544..2.2317834].

# But now the magic:
img_decoded = transforms_ft.decode(img_transformed)
show_image(img_decoded[0]);  # That's better!

就这样。无需手动编写逆函数。无需记住均值和标准差。只需调用 .decode(),我们就回到了可以实际查看的状态。

fasttransform 将这一相同的功能带到你自己的代码中。关键的洞察在于,对于你想应用的任何转换,你可能已经知道如何撤销它。让我们看看它内部是如何工作的。

工作原理:.decode()

fasttransform 背后的核心思想很简单:将一个转换与其逆操作配对。

以下是如何编写一个可逆的归一化转换

class Normalize(Transform):
    def __init__(self, mean=None, std=None):
        self.mean = mean
        self.std = std
        
    def encodes(self, x): return (x-self.mean) / self.std  # forward transform
    def decodes(self, x): return x*self.std + self.mean    # inverse transform

仅此而已。

通过定义 encodesdecodes,fasttransform 自动知道如何反转你的转换。将此与我们之前的 PyTorch 示例进行比较——我们不是编写单独的正向和逆向函数,而是将它们放在一起,各得其所。

你可能注意到了这个特殊的命名——带有 ‘s’ 的 encodesdecodes。我们稍后会解释原因,但这与 fasttransform 如何自动处理不同类型的数据密切相关。

当你调用 decode() 时,fasttransform 会智能地判断哪些转换需要反转。有些转换,比如加载图片或缩放图片,不需要撤销,你实际上想看到模型看到的样子!其他转换,比如归一化,则需要反转才能人类可读。

如何做到这一点?很简单,只有当转换需要反转时,才定义一个 .decodes 方法!

介绍部分的可视化函数正是使用了这一功能,将转换后的输入恢复到人类可理解的状态。

问题 #2:处理多种类型

我们已经看到了如何通过使转换可逆来更容易地查看数据。但在使用转换时还有另一个挑战:不同类型的数据需要不同的转换。

最常见的情况是输入和标签需要不同的转换。这里也适用同样的原则。我们希望将所有这些转换保存在一起,因为我们希望能够撤销它们。例如,我们将分类标签从字符串转换为整数,然后为了人类可读性再转换回字符串。但我们不希望为输入和输出维护单独的转换管道。

为了理解为什么这是一个问题,让我们看看 PyTorch(最流行的深度学习框架之一)是如何处理这种情况的。以下是教程中的一个示例,展示了典型的自定义数据集

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)  
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform  # <- separate target transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:  # <- separate target transform
            label = self.target_transform(label)
        return image, label

图像和标签的转换是单独定义并提供给数据集类的。这种分离乍一看似乎合理,但它产生了两个问题

  1. 当我们想要反转转换时,必须记住反转两个管道
  2. 当某些转换需要同时应用于输入和目标时,我们必须在两个地方维护它们(例如,图像分割中的图像和掩码缩放)

让我们看看 fasttransform 如何使这一切变得更容易。

更好的方法:输入和输出共用一个管道

这正是 fasttransform 的方法闪耀之处:它不是同时处理单独的管道,而是在一个转换中处理你的图像及其标签。当你向转换传递一个元组时,它只会应用相关的转换。这听起来可能是一件小事,但对于实际的机器学习工作来说,它是一个颠覆性的改变。

让我们看看实际操作。

首先,我们将创建一个同时加载图像和标签的函数

def load_img_and_label(fp): return PILImage.create(fp), parent_label(fp)
load_img_and_label(fpath)
(PILImage mode=RGB size=375x500, 'husky')

现在到了酷炫的部分——我们可以通过一个微小的改动,在我们的转换管道中使用这个函数。看看这有多简洁

transforms_ft = Pipeline([
   load_img_and_label,  # <-- Load img and label as a tuple
   Resize(256,method="squish"),
   Resize(224,method="crop"),
   ToTensor(),
   IntToFloatTensor(),
   Normalize.from_stats(*imagenet_stats)
])

out = transforms_ft(fpath)
print((out[0][0,:2,:2,:2], out[1])) 
(TensorImage([[[-0.2856, -0.2856],
              [-0.2856, -0.2856]],

             [[ 0.5553,  0.5553],
              [ 0.5553,  0.5553]]], device='mps:0'), 'husky')

但我们还没完成!这些字符串标签(“husky”,“wolf”)需要转换为数字,以便我们的模型处理。在 PyTorch 中,我们需要为此建立一个单独的转换管道。使用 fasttransform,我们只需添加另一个只应用于字符串的转换

class StrCategorize(Transform):
    def __init__(self, vocab):
        self.vocab = vocab
        self.s2i = {s:i for i,s in enumerate(vocab)}
        self.i2s = {i:s for i,s in enumerate(vocab)}
    def encodes(self, s:str): return self.s2i[s]
    def decodes(self, i:int): return self.i2s[i]
    
transforms_ft = Pipeline([
   load_img_and_label,
   Resize(256,method="squish"),
   Resize(224,method="crop"),
   ToTensor(),
   IntToFloatTensor(),
   Normalize.from_stats(*imagenet_stats),
   StrCategorize(vocab=['husky','wolf']), # <-- Transform is just for the target label
])

out = transforms_ft(fpath)
print((out[0][0,:2,:2,:2], out[1])) 
(TensorImage([[[-0.2856, -0.2856],
              [-0.2856, -0.2856]],

             [[ 0.5553,  0.5553],
              [ 0.5553,  0.5553]]], device='mps:0'), 0)

你可能会想:“好吧,把转换保存在一个管道里是很好,但这真的那么重要吗?”

嗯,一个好处是现在你可以一次性反转这两个转换

rev = transforms_ft.decode(out)
print((rev[0][0,:2,:2,:2], rev[1])) 
(TensorImage([[[107, 107],
              [107, 107]],

             [[148, 148],
              [148, 148]]], device='mps:0'), 'husky')

接下来我们将展示另一个例子,它说明了为什么将这些转换保存在同一个地方至关重要:图像分割。

在分割任务中,你试图识别图像中的特定区域——比如在照片中找到一只哈士奇。棘手的部分在于:你的输入图像和目标掩码都需要以完全相同的方式进行转换。当你使用随机转换作为数据增强形式时,这就会变得棘手。举例来说,如果你对图像应用随机裁剪,那么你最好也以完全相同的方式裁剪那个掩码!如果它们不对齐,你的整个训练数据就会变成无意义的。

让我们看看这在实践中是什么样子。首先,我们定义一个新函数,它同时加载图像及其对应的掩码

fnames = list(Path("./segment_huskies/img/").glob("*"))
fn = fnames[0]

def load_img_msk(fn): 
    return PILImage.create(fn), PILMask.create(fn.parent.parent / "msk" / fn.name)

img, msk = load_img_msk(fn)

show_images([img,msk])

现在,如果我们想同时随机裁剪图像和掩码(一种常见的增强技术),它们需要以完全相同的方式裁剪。如果它们不对齐,那么你的整个训练数据就会变成无意义的。

fasttransform 是如何处理这个问题的

transforms_ft = Pipeline([
   load_img_msk,  # <-- New load func for img and mask
   RandomResizedCrop(200),  # Applied to both img and mask
   ToTensor(),              # Applied to both img and mask
   IntToFloatTensor(),                   # Only applied to img
  Normalize.from_stats(*imagenet_stats)  # Only applied to img
])

out = transforms_ft(fn)
out
show_images((out[0][0], out[1]))
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-1.8096584..2.64].

瞧,源图像和目标掩码都以完全相同的方式进行了转换。

如果这些转换存储在不同的管道中,那么要使它们保持同步将困难得多。特别是因为转换中存在随机性。

另外,请注意,反转仍然同样容易

rev = transforms_ft.decode(out)
show_images((rev[0][0], rev[1]))

此时你可能会想:“这太棒了——一个管道处理不同类型的数据,只在需要时应用相关的转换。但这究竟是如何工作的呢?”

好吧,接下来我们就深入探讨一下!

工作原理:多重分派

让转换仅应用于相关数据类型的秘密武器叫做多重分派(multiple dispatch)。如果你之前没听过,别担心——这是一个强大的编程概念,在像 Julia2 这样的语言中很流行,但在 Python 中相对不为人知。

你可以将多重分派想象成同一个函数拥有不同的版本,每个版本都设计用于处理特定类型的数据。当你调用函数时,Python 会根据你提供的数据类型自动选择正确的版本。

Python 开箱即用地提供了一个仅限于单参数函数的实现

from functools import singledispatch

@singledispatch 
def greet(x): return "Hello stranger!"

@greet.register
def _(x:str): return f"Hello {x}!"

@greet.register
def _(x:int): return f"Hello number {x}!"

greet(None), greet("Alice"), greet(42)
('Hello stranger!', 'Hello Alice!', 'Hello number 42!')

多重分派将这个想法扩展到具有多个参数的函数。虽然 Python 的内置工具仅处理单参数分派,但 plum 库为任意数量的参数提供了真正的多重分派。下面是一个简单的例子来说明这个概念

from plum import dispatch

class Dog: pass
class Cat: pass

@dispatch
def greet(a: Cat, b: Dog):
    return "Hiss!"

@dispatch
def greet(a: Dog, b: Cat):
    return "Grrrr..."


# Let's try it out
cat, dog = Cat(), Dog()


print(greet(cat, dog))  # "Hiss!"
print(greet(dog, cat))  # Grrrr...
Hiss!
Grrrr...

Transform 在内部使用了 plum 的多重分派功能,但核心思想是一样的:基于它接收的运行时数据类型来调用正确的函数。这就是为什么一个单个管道可以处理图像、标签、掩码和其他类型的数据的原因。

有三种不同的方式可以在转换中定义类型特定的行为,每种方式都适用于不同的情况。让我们逐一看一下。

创建转换最简单的方法是直接传递函数。这非常适合快速实验或一次性转换

# Method 1: Direct functions
def enc_str(x:str): return f"encoded str: {x=}"
def enc_int(x:int): return f"encoded int: {x=}"
my_transform = Transform(enc=(enc_str,enc_int))

my_transform(("hello", 42))
("encoded str: x='hello'", 'encoded int: x=42')

你在原型开发或不需要在代码的其他地方重用转换时可能会使用这种方法。但对于结构更清晰的代码,你可能希望创建一个合适的类……

子类化 Transform 提供了一种更组织化的方式来处理不同类型

# Method 2: Create a Transform subclass
class MyTransform(Transform):
    def encodes(self, x:str): return f"encoded str: {x=}"
    def encodes(self, x:int): return f"encoded int: {x=}"
    
my_transform = MyTransform()
my_transform(("my str", 42))
("encoded str: x='my str'", 'encoded int: x=42')

这里有一件有趣的事情:在普通的 Python 类中,你不能多次定义同一个方法。但当你从 Transform 子类化时,你可以!

encodes 方法自动设置为支持多重分派,因此 Python 知道根据输入类型调用哪个版本。

但还有一种定义转换的方法,当你想要扩展现有转换时,这种方法特别有用……

# Method 3: Extend with decorators
@MyTransform
def encodes(self, x: float): return f"encoded float: {x=}"

# Now our transform handles three types!
my_transform(("hello", 42, 6.28))  
("encoded str: x='hello'", 'encoded int: x=42', 'encoded float: x=6.28')

这种装饰器语法在实际应用中非常有用。

例如,在 fastai 中,Normalize 转换在核心库中定义以处理图像,但其他模块可以扩展它以处理新类型

# In fastai.data.transforms:
class Normalize(Transform): ...  # handles image normalization

# In fastai.tabular.core:
@Normalize
def encodes(self, x: pd.DataFrame): ...  # adds DataFrame support

这种插件式的架构意味着任何人都可以扩展现有转换以处理新类型的数据,而无需修改原始代码。这就是多重分派的力量在发挥作用!

真正的力量体现在 fastai 周边的生态系统中代码被重用和扩展时。像 fastxtend 这样的库可以在不修改原始代码的情况下添加对新数据类型的支持。如果没有多重分派,它们将面临经典的继承问题。相反,使用 fasttransform,它们只需为现有转换注册新的行为。

结论

我们已经看到了 fasttransform 如何解决数据处理中的两个基本问题

  1. 通过配对的 encodes/decodes 方法使转换可逆
  2. 通过多重分派处理不同数据类型

虽然这些想法源于 fastai 的深度学习需求,但它们的应用远不止于此。无论你处理的是图像、文本、时间序列,还是量子态,fasttransform 都提供一个简单的承诺:如果你能以某种方式转换数据,那么就应该能够同样轻松地将其转换回来。

准备好亲自尝试了吗?使用以下命令安装 fasttransform

pip install fasttransform

请查看我们的文档以获取更多示例和详细的 API 参考。如果你之前已在使用 fastcore 的 dispatch 和 transform 模块,那么你可能需要看看我们的迁移指南

我们很想听听你在自己的项目中是如何使用 fasttransform 的!

脚注

  1. 数据集改编自介绍 LIME 技术的学术论文。该数据集经过量身定制,旨在展示其突出雪景背景在识别哈士奇中最重要性的技术。来源:Ribeiro, Marco Tulio, Sameer Singh, and Carlos Guestrin. “‘为什么我应该相信你?’ 解释任何分类器的预测结果。” 第 22 届 ACM SIGKDD 国际知识发现与数据挖掘会议论文集。2016 年。↩︎

  2. 如果你想深入探索多重分派这个“兔子洞”,我们推荐这场由该语言(Julia)共同创建者之一 Stefan Karpinski 带来的演讲,标题为“多重分派的不可思议的有效性”(The Unreasonable Effectiveness of Multiple Dispatch)↩︎