关于 HMR 的实现


上期 谈到,我在 Python 实现了响应式原语,然后 又实现了 HMR,但是没细讲,这次写一下从响应式原语到 HMR 是怎么实现的。

我也新开了个 仓库,未来放 hmr 生态的东西(如果社区有反响的话),不过由于这算是临时起意中的临时起意(歪楼项目中的歪楼项目),所以代码还是在原仓库中没有拿出来。这个仓库目前就是一个 README 哈哈哈

如果想要直观理解 hmr 对 Python 到底意味着什么,可以看看 这个仓库 里的两个视频,另外我还分别给 FastAPI 和 Flask 开了 Discussion,分别也附上了精致的录屏,欢迎来下面讨论:

FastAPI | Flask (我还准备给 gradio、streamlit、pytest、litestar 也提一下)

不过没时间了,明天我要回家了,所以先不急着搞了

怎么实现的

正如我仓库中所说,HMR 实现可以拆成三部分:

  1. 实现那几个响应式原语
  2. 让加载进来的包,能自动跟踪依赖
    1. 实现一个 Reactive 的 context(按 item 追踪依赖)
    2. 实现一个自定义的 ModuleType,使用这个 Reactive 的 context
    3. 实现一个 Loader 以及一个 MetaPathFinder,并添加到sys.meta_path中,这样新加载进来的模块就会用我们的机制来初始化
  3. 通过watchfiles监听 FS 变化,并 invalidate 对应的 module

其中第 3 部分正如其字面意思,很简单,就是 watchfiles 的基本用法,没什么创意。本文重点讲第二部分

Reactive namespace

我们需要实现一个 MutableMapping,类似 Vue 的 reactive,每个__getitem__会触发对应 item 的 track,而__setitem____delitem__会触发对应 item 的 notify。实现如下:

class Reactive[K, V](Subscribable, MutableMapping[K, V]):
    UNSET: V = object()  # type: ignore

    def __hash__(self):
        return id(self)

    def _null(self):
        return Signal(self.UNSET, self._check_equality)

    def __init__(self, initial: Mapping[K, V] | None = None, check_equality=True):
        super().__init__()
        self._signals = defaultdict[K, Signal[V]](self._null) if initial is None else defaultdict(self._null, {k: Signal(v, check_equality) for k, v in initial.items()})
        self._check_equality = check_equality

    def __getitem__(self, key: K):
        value = self._signals[key].get()
        if value is self.UNSET:
            raise KeyError(key)
        return value

    def __setitem__(self, key: K, value: V):
        with Batch():
            self._signals[key].set(value)
            self.notify()

    def __delitem__(self, key: K):
        state = self._signals[key]
        if state.get(track=False) is self.UNSET:
            raise KeyError(key)
        with Batch():
            state.set(self.UNSET)
            self.notify()

    def __iter__(self):
        self.track()
        return iter(self._signals)

    def __len__(self):
        self.track()
        return len(self._signals)

    def __repr__(self):
        self.track()
        return repr({k: v.get() for k, v in self._signals.items()})

    def items(self):
        self.track()
        return ({k: v.get() for k, v in self._signals.items()}).items()

可以看到我还多做了一些处理:比如这个 mapping 本身也是一个 Subscribable,像是对.items的使用就会订阅它的全部改动(可以这么理解:如果我直接str()了某个 module 的.__dict__,我就依赖于它的所有元素)

一些细节

exec 的时候,globals 必须是 dict 的子类,但是我们这个 Reactive 不能是 dict 的子类,因为一个 Subscribable 必须是能添加到一个 set 里的,而 set 里不能是可变元素,好像这与 dict 冲突了(所以我必须得加一个__hash__函数)。另外后来文档说 globals 甚至不能是 dict 子类,必须得是 dict 严格本身(否则我发现会有一些奇怪的问题,关于在类中的定义域啥的),详见我给 CPython 的唯一一个 Issue:121306(说起来 3.14 有望 revert 这个,不知道他们后面争论的怎么样了)

说回来,globals 必须是严格 dict,所以我们只能把这个 Reactive 作为locals传过去,但是 locals 不会继承到函数里,所以还得把对我们这个Reactive的写入,写回到这个initial里去。

class NamespaceProxy(Reactive[str, Any]):
    def __init__(self, initial: MutableMapping, check_equality=True):
        super().__init__(initial, check_equality)
        self._original = initial

    def __setitem__(self, key, value):
        self._original[key] = value
        return super().__setitem__(key, value)

    def __delitem__(self, key):
        del self._original[key]
        return super().__delitem__(key)

Reactive module

这个就 debug 麻烦。print 的时候还会有各种循环访问的报错,总之就是很痛苦,跑通了就再也不想碰了:

class ReactiveModule(ModuleType):
    def __init__(self, file: Path, namespace: dict, name: str, doc: str | None = None):
        super().__init__(name, doc)
        self.__is_initialized = False
        self.__dict__.update(namespace)
        self.__is_initialized = True

        self.__namespace = namespace
        self.__namespace_proxy = NamespaceProxy(namespace)
        self.__file = file

    @property
    def file(self):
        if is_called_in_this_file():
            return self.__file
        raise AttributeError("file")

    @memoized_method
    def __load(self):
        code = compile(self.__file.read_text("utf-8"), str(self.__file), "exec", dont_inherit=True)
        exec(code, self.__namespace, self.__namespace_proxy)

    @property
    def load(self):
        if is_called_in_this_file():
            return self.__load
        raise AttributeError("load")

    def __dir__(self):
        return iter(self.__namespace_proxy)

    def __getattribute__(self, name: str):
        if name == "__dict__" and self.__is_initialized:
            return self.__namespace
        return super().__getattribute__(name)

    def __getattr__(self, name: str):
        try:
            return self.__namespace_proxy[name]
        except KeyError as e:
            raise AttributeError(*e.args) from e

    def __setattr__(self, name: str, value):
        if is_called_in_this_file():
            return super().__setattr__(name, value)
        self.__namespace_proxy[name] = value

这里用到一个_is_called_in_this_file,其实就是判断调用者的 frame 是不是在这个文件中:

def is_called_in_this_file() -> bool:
    frame = currentframe()  # this function
    assert frame is not None

    frame = frame.f_back  # the function calling this function
    assert frame is not None

    frame = frame.f_back  # the function calling the function calling this function
    assert frame is not None

    return frame.f_globals.get("__file__") == __file__

为什么要这样,因为我这个load必须得在后面的ModuleFinder里调用,所以不能是双下划线大头的__load形式,但是如果是_load或者load都容易与模块的真实的属性冲突,所以我就做了这么一个 hack,实现了类似“package-private 级变量”的效果,我对这个还蛮洋洋得意的,哈哈哈

Module finder & loader

class ReactiveModuleLoader(Loader):
    def __init__(self, file: Path, is_package=False):
        super().__init__()
        self._file = file
        self._is_package = is_package

    def create_module(self, spec: ModuleSpec):
        namespace = {"__file__": str(self._file), "__spec__": spec, "__loader__": self, "__name__": spec.name}
        if self._is_package:
            assert self._file.name == "__init__.py"
            namespace["__path__"] = [str(self._file.parent)]
        return ReactiveModule(self._file, namespace, spec.name)

    def exec_module(self, module: ModuleType):
        assert isinstance(module, ReactiveModule)
        module.load()


class ReactiveModuleFinder(MetaPathFinder):
    def __init__(self, includes: Iterable[str] = ".", excludes: Iterable[str] = ()):
        super().__init__()
        self.includes = [Path(i).resolve() for i in includes]
        self.excludes = [Path(e).resolve() for e in excludes]

    def _accept(self, path: Path):
        return path.is_file() and not is_relative_to_any(path, self.excludes) and is_relative_to_any(path, self.includes)

    def find_spec(self, fullname: str, paths: Sequence[str] | None, _=None):
        if fullname in sys.modules:
            return None

        for p in sys.path:
            directory = Path(p).resolve()
            if directory.is_file():
                continue

            file = directory / f"{fullname.replace('.', '/')}.py"
            if self._accept(file) and (paths is None or is_relative_to_any(file, paths)):
                return spec_from_loader(fullname, ReactiveModuleLoader(file), origin=str(file))
            file = directory / f"{fullname.replace('.', '/')}/__init__.py"
            if self._accept(file) and (paths is None or is_relative_to_any(file, paths)):
                return spec_from_loader(fullname, ReactiveModuleLoader(file, is_package=True), origin=str(file), is_package=True)

很常规,很无聊

Reload

核心是我这个on_events,因为 watchfiles 的 watch 和 awatch yield 出来的 events 都长这样,就只用写一份了。

目前只考虑了 modified 事件。其他的情况是未定义的 哈哈哈

def on_events(self, events: Iterable[tuple[int, str]]):
    from watchfiles import Change

    if not events:
        return

    path2module = get_path_module_map()

    with batch():
        for type, file in events:
            if type is Change.modified:
                path = Path(file).resolve()
                if path.samefile(self.entry):
                    self.run_entry_file.invalidate()
                elif module := path2module.get(path):
                    try:
                        module.load.invalidate()
                    except Exception as e:
                        sys.excepthook(e.__class__, e, e.__traceback__)

        for module in path2module.values():
            try:
                module.load()
            except Exception as e:
                sys.excepthook(e.__class__, e, e.__traceback__)
        self.run_entry_file()

这里用到一个.invalidate(),包括前面也用到一个memoized_method,这个确实还没讲,但其实就是createMemo一样的。

我这里其实是把 memoized 当“惰性的 effect”用了,可以看到,我invalidate 修改的模块,然后每个模块都 call 一下 load,保证它是最新的,其中的缓存啥的就交给机制了,测试了半天最后终于没问题了

Memoized

而其中createMemo其实就是惰性的一个derived这样的意思。我觉得 memoized 更强调了这种惰性感,所以没用 Derived 这个词:

class Memoized[T](Subscribable, BaseComputation[T]):
    def __init__(self, fn: Callable[[], T]):
        super().__init__()
        self.fn = fn
        self.is_stale = True
        self.cached_value: T
        self._recompute = False

    def trigger(self):
        self.track()
        if self._recompute:
            self._recompute = False
            self._before()
            try:
                self.cached_value = self.fn()
                self.is_stale = False
            finally:
                self._after()
        else:
            self.invalidate()

    def __call__(self):
        if self.is_stale:
            self._recompute = True
            self.trigger()
            assert not self._recompute
        return self.cached_value

    def invalidate(self):
        if not self.is_stale:
            del self.cached_value
            self.is_stale = True

memoized_method 其实就是在这外面包裹了一层 descriptor 而已:

class MemoizedMethod[T, I]:
    def __init__(self, method: Callable[[I], T]):
        super().__init__()
        self.method = method
        self.map = WeakKeyDictionary[I, Memoized[T]]()

    @overload
    def __get__(self, instance: None, owner: type[I]) -> Self: ...
    @overload
    def __get__(self, instance: I, owner: type[I]) -> Memoized[T]: ...

    def __get__(self, instance: I | None, owner):
        if instance is None:
            return self
        if memo := self.map.get(instance):
            return memo
        self.map[instance] = memo = Memoized(partial(self.method, instance))
        return memo

唉,实在没什么时间,要回去睡觉了,所以写的很匆忙,简直什么都没写,不管了


结语

正如我在那个 hmr 仓库里写的,生态最重要,我还没找 pytest(因为我不太会做 pytest plugin,所以没法给他们一个很吸引人的 demo),目前找了 FastAPI 和 Flask,其中 FastAPI 底下有个 uvicorn 的维护者评论了我几句,提醒了我我没考虑静态文件的读取:

https://github.com/fastapi/fastapi/discussions/13192#discussioncomment-11796729