Hi, I'm currently reproducing a project that uses the PyTorch Lightning Trainer to manage DDP training, but its home-made dataset implementation blows up my CPU memory. The footprint is more than double the data size: the raw data is ~300 GB, yet during training I have to provision more than 720 GB of CPU RAM. So I'm wondering whether HF Datasets could handle this better.
I've attached its dataset implementation below; I think the most suspicious part is the CachedDataset wrapper:
from typing import Any, TypeVar
from multiprocessing import Manager

import torch
from torch.utils.data import Dataset

__all__ = ["CachedDataset"]


class NumpiedTensor:
    def __init__(self, tensor: torch.Tensor) -> None:
        self.array = tensor.numpy()

    def to_tensor(self) -> torch.Tensor:
        return torch.tensor(self.array)


def numpize_sample(sample: Any) -> Any:
    if isinstance(sample, torch.Tensor):
        return NumpiedTensor(sample)
    elif isinstance(sample, tuple):
        return tuple(numpize_sample(s) for s in sample)
    elif isinstance(sample, list):
        return [numpize_sample(s) for s in sample]
    elif isinstance(sample, dict):
        return {k: numpize_sample(v) for k, v in sample.items()}
    else:
        return sample


def tensorize_sample(sample: Any) -> Any:
    if isinstance(sample, NumpiedTensor):
        return sample.to_tensor()
    elif isinstance(sample, tuple):
        return tuple(tensorize_sample(s) for s in sample)
    elif isinstance(sample, list):
        return [tensorize_sample(s) for s in sample]
    elif isinstance(sample, dict):
        return {k: tensorize_sample(v) for k, v in sample.items()}
    else:
        return sample


T_co = TypeVar("T_co", covariant=True)


class CachedDataset(Dataset[T_co]):
    def __init__(self, dataset: Dataset[T_co]) -> None:
        self.dataset = dataset
        self.manager = Manager()          # spawns a separate manager process
        self.cache = self.manager.dict()  # the cache lives inside that process

    def __len__(self) -> int:
        return len(self.dataset)  # type: ignore[arg-type]

    def __getitem__(self, index: int) -> Any:
        if index not in self.cache:
            # first access pickles the numpy-fied sample into the manager process
            self.cache[index] = numpize_sample(self.dataset[index])
        # every access pickles the cached sample back out of the manager process
        return tensorize_sample(self.cache[index])
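To confirm my suspicion about the Manager dict, this is the small check I was planning to run (untested sketch; psutil, the dummy dataset, and poking at the private Manager._process attribute are my own additions, not part of the project):

import psutil
import torch
from torch.utils.data import Dataset

class DummyDataset(Dataset):
    def __len__(self) -> int:
        return 100

    def __getitem__(self, index: int) -> dict:
        return {"x": torch.randn(1024, 1024)}  # ~4 MB per sample

ds = CachedDataset(DummyDataset())
manager_proc = psutil.Process(ds.manager._process.pid)  # the process that holds the cache
for i in range(len(ds)):
    _ = ds[i]  # first touch stores the sample in the manager, later reads pickle it back
print(f"manager RSS after one full pass: {manager_proc.memory_info().rss / 1e6:.0f} MB")

If I read the code right, the manager process should end up holding roughly the whole materialized dataset (~400 MB in this toy case), which would explain why the real ~300 GB dataset blows past 720 GB once everything is decompressed, cached, and round-tripped through pickle.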
which wraps an HDF5Dataset:
from __future__ import annotations
from typing import Any
from pathlib import Path
import pickle as pkl

import torch
from torch.utils.data import Dataset
import h5py as h5

__all__ = [
    "RawHDF5Dataset",
    "HDF5Dataset",
]


class RawHDF5Dataset(Dataset[int]):
    def __init__(self, dataset_path: Path | str, grp_list: Path | str | list[str] | None = None) -> None:
        self.dataset_path = dataset_path
        if grp_list is None:
            with h5.File(self.dataset_path, "r") as f:
                self.grp_list = list(f.keys())
        elif isinstance(grp_list, (str, Path)):
            with open(grp_list, "rb") as f:
                self.grp_list = pkl.load(f)
        elif isinstance(grp_list, list):
            self.grp_list = grp_list
        else:
            raise NotImplementedError()
        self.grp_list.sort()
        self.f: h5.File | None = None

    def __len__(self) -> int:
        return len(self.grp_list)

    def __getitem__(self, index: int) -> dict[str, Any]:
        if self.f is None:
            # opened lazily, so each DataLoader worker gets its own file handle
            self.f = h5.File(self.dataset_path, "r")
        return {k: v[:] for k, v in self.f[self.grp_list[index]].items()}

    def __del__(self) -> None:
        if self.f is not None:
            self.f.close()


class HDF5Dataset(RawHDF5Dataset):
    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        return {k: torch.as_tensor(v) for k, v in super().__getitem__(index).items()}
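On top of the Manager issue, my understanding is that under Lightning DDP every rank builds this whole stack for itself, so the cache would be duplicated once per GPU. A quick check I had in mind (untested sketch, launched with torchrun; the import path and data path are placeholders, and it relies on the private Manager._process attribute):

# torchrun --nproc_per_node=2 check_cache_duplication.py
import os

from project.data import CachedDataset, HDF5Dataset  # placeholder import path

ds = CachedDataset(HDF5Dataset("data.h5"))  # placeholder path
print(
    f"local_rank={os.environ.get('LOCAL_RANK')} "
    f"pid={os.getpid()} "
    f"manager_pid={ds.manager._process.pid}"
)
# A different manager_pid per rank would mean each GPU process keeps its own full
# cache, i.e. up to num_gpus x (materialized dataset size) in CPU RAM.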
I asked ChatGPT and Claude, and both told me the same thing: every GPU (i.e. every DDP rank) creates its own CachedDataset instance and therefore its own copy of the cache, but I'm not sure that's right. Would HF Datasets handle this better?
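For reference, this is roughly what I was thinking of trying with HF Datasets (untested sketch; "data.h5", "data_arrow", and the assumption that every HDF5 group has the same keys are mine): a one-off conversion to an Arrow dataset on disk, which datasets then memory-maps, so samples shouldn't accumulate in CPU RAM and all DDP ranks can read the same files.

import h5py as h5
from datasets import Dataset, load_from_disk

def gen_samples(dataset_path, grp_list):
    with h5.File(dataset_path, "r") as f:
        for grp in grp_list:
            # yield plain numpy arrays; they are written out as Arrow columns
            yield {k: v[:] for k, v in f[grp].items()}

with h5.File("data.h5", "r") as f:
    grp_list = sorted(f.keys())

# one-off conversion to Arrow files on disk
ds = Dataset.from_generator(
    gen_samples,
    gen_kwargs={"dataset_path": "data.h5", "grp_list": grp_list},
)
ds.save_to_disk("data_arrow")

# in the training script: memory-mapped from disk, converted to torch tensors on access
train_ds = load_from_disk("data_arrow").with_format("torch")

Is that the intended way to use it here, or is there a better way to feed HDF5-backed data to the Lightning Trainer without caching everything in RAM?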