From 5f4b0a9ef004d3c8b41876af3ce0f9700f79b8f0 Mon Sep 17 00:00:00 2001 From: Hamid Shojanazeri Date: Thu, 4 May 2023 18:58:17 +0000 Subject: [PATCH 01/96] update --- examples/inference/hf_generate.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/inference/hf_generate.py b/examples/inference/hf_generate.py index d068aca92..24b044a03 100644 --- a/examples/inference/hf_generate.py +++ b/examples/inference/hf_generate.py @@ -85,7 +85,7 @@ def run_all(pp_ranks, args): prompt = "Hey, are you conscious? Can you talk to me?" input = tokenizer(prompt, return_tensors="pt") input_ids = input["input_ids"].to(args.device) - outputs = model.generate(input_ids, max_length=30) + outputs = model.generate(input_ids, max_new_tokens=30) response = tokenizer.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] print(response) @@ -109,7 +109,7 @@ def run_all(pp_ranks, args): assert args.world_size % args.pp_group_size == 0 - supported_model_categories = ["opt", "gpt", "bloom", "codegen"] + supported_model_categories = ["opt", "gpt", "bloom", "codegen", "llama"] # For example: # "facebook/opt-350m" # "gpt2" From 3edf3ab46c69ecd1ef9cdbbd10a969571ff87051 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Fri, 26 May 2023 18:30:18 -0700 Subject: [PATCH 02/96] Add c10d backend for PiPPy (#800) ## Description - Allow PiPPy to use c10d as communication backend instead of RPC - User program would be in GSPMD style - Make PiPPy and Tensor Parallel more composable --- examples/tp+pp/pippy_tp.py | 147 ++++++++++-------- pippy/PipelineStage.py | 300 +++++++++++++++++++++++++++++++++++++ pippy/__init__.py | 8 +- pippy/compile.py | 95 +++++++++++- test/local_test_c10d.py | 111 +++++--------- test/min_gpt_pp_tp.py | 201 +++++++++++++++++++++++++ 6 files changed, 725 insertions(+), 137 deletions(-) create mode 100644 pippy/PipelineStage.py create mode 100644 test/min_gpt_pp_tp.py diff --git a/examples/tp+pp/pippy_tp.py b/examples/tp+pp/pippy_tp.py index ae507a633..90cc80757 100644 --- a/examples/tp+pp/pippy_tp.py +++ b/examples/tp+pp/pippy_tp.py @@ -7,9 +7,10 @@ import pippy import pippy.fx -from pippy import run_pippy from pippy.IR import pipe_split +from pippy.compile import compile_stage +import torch.distributed as dist from torch.distributed._tensor import ( DeviceMesh, ) @@ -55,75 +56,78 @@ def forward(self, x): return x -def run_all(pp_ranks, args): - chunks = args.pp_group_size - device_type = "cuda" if args.cuda else "cpu" +d_hid = 256 +batch_size_per_chunk = 8 - # Figure out my PP rank - pp_rank = args.rank // args.tp_group_size - print(f"Global rank {args.rank}, pipeline: {pp_ranks}, my rank in pipe: {pp_rank}") - - d_hid = 256 - batch_size_per_chunk = 8 - inp_size = [chunks * batch_size_per_chunk, d_hid] - # Ensure all tp ranks have same input. +def run_all(args): + # The seed here has two purposes: + # - Ensure all TP ranks have same input + # - Ensure the model (ec) created are the same, as if it comes from a + # single, big model before partitioning torch.manual_seed(0) - inp = torch.rand(*inp_size, device=device_type) - """ - # Reference run - # `torchrun --nproc_per_node=2 pippy_tp_mlp.py --world_size=2 --pp_group_size=1` - ec_tp = ExampleCode(d_hid) - ec_tp.to(args.device) + # Create original model + ec = ExampleCode(d_hid) + ec.to(args.device) - start_idx = 0 - device_mesh = DeviceMesh( + # Create input + inp_size = [args.chunks * batch_size_per_chunk, d_hid] + device_type = args.device.type + inp = torch.rand(*inp_size, device=args.device) + + # Create global DeviceMesh + ranks = torch.arange(args.world_size) + rank_mesh = ranks.reshape(args.pp_group_size, args.tp_group_size) + pp_dim = 0 + tp_dim = 1 + dm = DeviceMesh( device_type, - list(range(start_idx, start_idx + args.tp_group_size)), + rank_mesh, ) - print(f"Rank {args.rank} calling parallelize_module with {device_mesh}") - parallelize_module(ec_tp, device_mesh, PairwiseParallel()) - print(f"Rank {args.rank} sharding complete") - ref_out = ec_tp(inp) - print(f"Ref out: {ref_out.size()}") - """ + # Figure out my PP and TP rank + pp_rank = args.rank // args.tp_group_size + tp_rank = args.rank % args.tp_group_size + print(f"Global rank {args.rank}, pp rank: {pp_rank}, tp rank: {tp_rank}") - # PiPPy run - ec = ExampleCode(d_hid) - ec.to(args.device) + # Get pp group + # `tp_rank` can serve as pipeline id + print(f"Rank {args.rank} Instantiating pipeline with ranks {dm.mesh[:, tp_rank]}") + pp_group = dm.get_dim_groups()[pp_dim] - # Get: - # - pipeline driver (for pipeline head rank) - # - stage submodule (for all ranks) - pipe_driver, submod = pippy.all_compile( + # Get stage module (on all pp ranks) + stage = compile_stage( ec, + pp_rank, args.pp_group_size, - chunks, - ranks=pp_ranks, + args.chunks, + args.device, + pp_group, + example_inputs=[inp], ) - # Create TP device mesh - my_device_mesh = None - for stage in range(args.pp_group_size): - start_rank = stage * args.tp_group_size - tp_ranks = list(range(start_rank, start_rank + args.tp_group_size)) - tp_device_mesh = DeviceMesh( - device_type, - tp_ranks, - ) - if stage == pp_rank: - my_device_mesh = tp_device_mesh - # Tensor parallelize submodules - print(f"Rank {args.rank} calling parallelize_module with {my_device_mesh}") - parallelize_module(submod, my_device_mesh, PairwiseParallel()) + print(f"Rank {args.rank} TP-lize submodule with {dm.mesh[pp_rank]}") + parallelize_module(stage.submod, dm, PairwiseParallel(), tp_mesh_dim = tp_dim) if pp_rank == 0: - print(f"Rank {args.rank} Instantiated pipeline with ranks {pp_ranks}") - out = pipe_driver(inp) - print(f"Pipeline {args.rank} output: {out.size()}") + out = stage(inp) + elif pp_rank == args.pp_group_size - 1: + out = stage() + else: + stage() + + dist.barrier() + print(f"Rank {args.rank} completes") + + # Last rank checks result + if pp_rank == args.pp_group_size - 1: + ref_out = ec(inp) + torch.testing.assert_close(out, ref_out) + print( + f"Pipeline {tp_rank} equivalence test passed {torch.sum(out)} ref {torch.sum(ref_out)}" + ) def main(args=None): @@ -131,7 +135,7 @@ def main(args=None): parser.add_argument( "--world_size", type=int, default=int(os.getenv("WORLD_SIZE", 8)) ) - # ExampleCode has two stages + # ExampleCode has 4 stages parser.add_argument( "--pp_group_size", type=int, default=4, ) @@ -155,24 +159,41 @@ def main(args=None): "--master_port", type=str, default=os.getenv("MASTER_PORT", "29500") ) parser.add_argument( - "-s", - "--schedule", - type=str, - default="FillDrain", + "--cuda", type=int, default=int(torch.cuda.is_available()) ) parser.add_argument( - "--cuda", type=int, default=int(torch.cuda.is_available()) + "--chunks", type=int, default=4, ) args = parser.parse_args(args) # Use world size to determine TP group size assert args.world_size % args.pp_group_size == 0 args.tp_group_size = args.world_size // args.pp_group_size - print(f"Using tensor parallel group size: {args.tp_group_size}") + if args.rank == 0: + print( + f"Pipeline parallel size: {args.pp_group_size}\n" + f"Tensor parallel size: {args.tp_group_size}" + ) + + if args.cuda: + dev_id = args.rank % torch.cuda.device_count() + args.device = torch.device(f"cuda:{dev_id}") + # HACK: we need to pin device here because `DeviceMesh` currently does + # an all_gather with device_type only, without device id + # https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/device_mesh.py#L191-L192 + torch.cuda.set_device(args.device) + else: + args.device = torch.device("cpu") + + # Init process group + backend = "nccl" if args.cuda else "gloo" + dist.init_process_group( + backend=backend, + rank=args.rank, + world_size=args.world_size, + ) - # All ranks participate - args.gspmd = 1 - run_pippy(run_all, args) + run_all(args) if __name__ == "__main__": diff --git a/pippy/PipelineStage.py b/pippy/PipelineStage.py new file mode 100644 index 000000000..d0f2d816a --- /dev/null +++ b/pippy/PipelineStage.py @@ -0,0 +1,300 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +import logging +import operator +from typing import Dict, List, Optional + +import torch +import torch.distributed as dist +import pippy + +from pippy.fx.passes import shape_prop +from pippy.IR import Pipe +from pippy.microbatch import merge_chunks, split_args_kwargs_into_chunks + + +def _make_tensor_from_meta( + tensor_meta: shape_prop.TensorMetadata, + device: torch.device, +): + return torch.empty( + tensor_meta.shape, dtype=tensor_meta.dtype, device=device + ) + + +class RecvInfo: + def __init__( + self, + input_name: str, + source: int, + buffer: torch.Tensor, + ): + self.input_name = input_name + self.source = source + self.buffer = buffer + + +class PipelineStage(torch.nn.Module): + def __init__( + self, + pipe: Pipe, + rank: int, + nstages: int, + chunks: int, + device: torch.device, + group: dist.ProcessGroup = None, + return_to_0: bool = False, + args_chunk_spec=None, + kwargs_chunk_spec=None, + output_chunk_spec=None, + ): + super().__init__() + self.pipe = pipe + self.rank = rank + self.nstages = nstages + self.chunks = chunks + self.device = device + self.group = group + self.return_to_0 = return_to_0 + self.args_chunk_spec = args_chunk_spec + self.kwargs_chunk_spec = kwargs_chunk_spec + self.output_chunk_spec = output_chunk_spec + + # Find my submodule + self.split_gm = self.pipe.split_gm + named_children = list(self.split_gm.named_children()) + self.name, self.submod = named_children[rank] + logging.info( + f"[{self.rank}][{self.name}] " + f"Creating PipelineStage:\n" + f"{self.submod}" + ) + + # Find my node in graph + found_node = False + for node in self.split_gm.graph.nodes: + if node.name == self.name: + self.node = node + found_node = True + break + if not found_node: + raise AssertionError(f"Cannot find {self.name} in graph") + + # Create submod name to rank mapping + self.submod_to_rank: Dict[str, int] = {} + for i, (name, _) in enumerate(self.split_gm.named_children()): + self.submod_to_rank.setdefault(name, i) + + # Prepare send/recv infrastructure + self.args_recv_info = [] + self.kwargs_recv_info = [] + for _ in range(chunks): + args_recv, kwargs_recv = self._create_recv_buffers() + self.args_recv_info.append(args_recv) + self.kwargs_recv_info.append(kwargs_recv) + + self._create_send_info() + + def get_rank_of_submod( + self, + submod_name: str, + ): + if submod_name not in self.submod_to_rank: + raise AssertionError(f"Rank of {submod_name} not found") + + return self.submod_to_rank[submod_name] + + def _create_recv_buffers( + self, + ): + def create_recv_tensor( + input_node, + output_idx: Optional[int] = None, + ): + """ + Create a tensor for receiving the `output_idx`-th value from + `input_node` + """ + if input_node.op == "placeholder": + # Do not create buffer for placeholder + return None + + # In case the input is a `getitem` node, we recursively find the + # real source e.g. getitem1 = submod0[1] + # Here `submod0` is args[0], 1 is args[1] + if input_node.target is operator.getitem: + if "tensor_meta" in input_node.meta: + real_input_node = input_node.args[0] + out_idx = input_node.args[1] + return create_recv_tensor(real_input_node, out_idx) + else: + return None + + if output_idx is not None: + # If a node has multiple output values, "tensor_meta" is a list + # of tensor meta + tensor_meta = input_node.meta["tensor_meta"][output_idx] + else: + tensor_meta = input_node.meta["tensor_meta"] + + logging.info( + f"[{self.rank}][{self.name}] " + f"Creating recv buffer for input '{input_node.name}' " + f"value index {output_idx}: {tensor_meta.shape}" + ) + + src_rank = self.get_rank_of_submod(input_node.name) + return RecvInfo( + input_node.name, + src_rank, + _make_tensor_from_meta(tensor_meta, self.device), + ) + + # `args` is a Tuple, hence we will have: + # Tuple[RecvInfo] + args_recv_info = pippy.fx.node.map_arg( + self.node.args, create_recv_tensor + ) + + # `kwargs` is a Dict, hence we will have: + # Dict[keyword, RecvInfo] + kwargs_recv_info = pippy.fx.node.map_arg( + self.node.kwargs, create_recv_tensor + ) + + return args_recv_info, kwargs_recv_info + + def _create_send_info(self): + # Find send destinations + def find_dst_rank(user) -> Optional[int]: + if user.op == "output": + if self.return_to_0: + # Send result back to pp rank 0 + return 0 + else: + return None + else: + # User is a stage (`call_module`) + return self.get_rank_of_submod(user.name) + + # Output index: List of receivers + self.dst_infos: Dict[int, List] = {} + out_idx = 0 + + for user in self.node.users: + if user.target is operator.getitem: + # Recursively find the real destination + gi_dsts = self.dst_infos.setdefault(out_idx, []) + for gi_user in user.users: + gi_dsts.append(find_dst_rank(gi_user)) + # Next `getitem` will point to the next output index + out_idx += 1 + else: + # In case of single output value, `out_idx` will not increase + dsts = self.dst_infos.setdefault(out_idx, []) + dsts.append(find_dst_rank(user)) + + logging.info( + f"[{self.rank}][{self.name}] " f"Send info: {self.dst_infos}" + ) + + def forward(self, *args, **kwargs): + if args or kwargs: + args_split, kwargs_split = split_args_kwargs_into_chunks( + args, + kwargs, + self.chunks, + self.args_chunk_spec, + self.kwargs_chunk_spec, + ) + + # Receive requests of a chunk + recv_reqs: List[dist.Work] = [] + # Send requests of a chunk + send_reqs: List[dist.Work] = [] + + def recv_tensor(info): + if isinstance(info, RecvInfo): + logging.info( + f"[{self.rank}][{self.name}] " + f"Receiving tensor '{info.input_name}' from Rank {info.source}: " + f"{info.buffer.size()}" + ) + # Use async to parallelize recv of tensors + work = dist.irecv( + info.buffer, + info.source + if self.group is None + else dist.get_global_rank(self.group, info.source), + group=self.group, + ) + recv_reqs.append(work) + return info.buffer + else: + return info + + output_chunks = [] + + for chunk in range(self.chunks): + recv_reqs.clear() + if args: + chunk_args = args_split[chunk] + else: + chunk_args = pippy.fx.node.map_aggregate( + self.args_recv_info[chunk], + recv_tensor, + ) + + if kwargs: + chunk_kwargs = kwargs_split[chunk] + else: + chunk_kwargs = pippy.fx.node.map_aggregate( + self.kwargs_recv_info[chunk], + recv_tensor, + ) + + # Wait for all recvs to finish + for work in recv_reqs: + work.wait() + + # Compute + output = self.submod(*chunk_args, **chunk_kwargs) + + # Unify output form to tuple for easy correspondance with + # `dst_infos` + output_tuple = output if type(output) is tuple else (output,) + + for idx, out in enumerate(output_tuple): + dst_ranks = self.dst_infos[idx] + for dst in dst_ranks: + if dst is None: + # If dst is a `output` node, we don't need to send + # (unless `return_to_0` is required) + continue + logging.info( + f"[{self.rank}][{self.name}] " + f"Sending tensor to Rank {dst}: {out.size()}" + ) + work = dist.isend( + out, + dst + if self.group is None + else dist.get_global_rank(self.group, dst), # TODO + group=self.group, + ) + send_reqs.append(work) + + output_chunks.append(output) + + # Wait for all sends to finish + # TODO: okay to delay the sync till completion of all chunks? + for work in send_reqs: + work.wait() + + # Last rank return merged results per original format + if self.rank == self.nstages - 1: + return merge_chunks( + output_chunks, + self.output_chunk_spec, + ) + else: + return None diff --git a/pippy/__init__.py b/pippy/__init__.py index 96e209aa5..d8d6564ed 100644 --- a/pippy/__init__.py +++ b/pippy/__init__.py @@ -11,7 +11,12 @@ from pippy.PipelineDriver import PipelineDriverFillDrain, PipelineDriver1F1B from pippy.ModelSplit import split_on_size_threshold, split_into_equal_size from pippy.utils import run_pippy -from pippy.compile import compile, all_compile, create_default_args +from pippy.compile import ( + compile, + all_compile, + create_default_args, + compile_stage, +) __all__ = [ @@ -30,4 +35,5 @@ "compile", "all_compile", "create_default_args", + "compile_stage", ] diff --git a/pippy/compile.py b/pippy/compile.py index 3a8a54f2d..62a988aa6 100644 --- a/pippy/compile.py +++ b/pippy/compile.py @@ -7,12 +7,20 @@ PipelineDriverFillDrain, PipelineDriverInterleaved1F1B, ) +from pippy.PipelineStage import PipelineStage import pippy.fx as fx from pippy.IR import MultiUseParameterConfig, Pipe -from pippy.microbatch import LossReducer, sum_reducer -from pippy.utils import get_device, get_pp_rank, get_rank +from pippy.fx.passes import shape_prop +from pippy.microbatch import ( + LossReducer, + split_args_kwargs_into_chunks, + sum_reducer, +) +from pippy.utils import get_device, get_pp_rank, get_rank, PIPPY_VERBOSITY import torch +import torch.distributed as dist +from torch._subclasses.fake_tensor import FakeTensorMode PIPELINE_SCHEDULE_DRIVERS = { @@ -204,3 +212,86 @@ def all_compile( _debug_mask_minibatches=_debug_mask_minibatches, **kwargs, ) + + +def compile_stage( + mod: torch.nn.Module, + rank: int, + num_ranks: int, + num_chunks: int, + device: torch.device, + group: dist.ProcessGroup, + example_inputs: List[torch.Tensor], + split_policy: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None, + return_to_0: bool = False, + tracer=None, + args_chunk_spec=None, + kwargs_chunk_spec=None, + output_chunk_spec=None, + **kwargs, +) -> PipelineStage: + # If a param will be used in multiple pipeline stages, we default the strategy to REPLICATE'ing the param across + # stages instead of TRANSMIT'ting it + multi_use_param_spec = MultiUseParameterConfig.REPLICATE + + # Figure out which output is loss from output_chunk_spec + output_loss_value_spec: Any = None + if isinstance(output_chunk_spec, dict): + output_loss_value_spec = { + k: isinstance(v, LossReducer) for k, v in output_chunk_spec.items() + } + + logging.info("[PiPPy] Tracing model ...") + pipe = Pipe.from_tracing( + mod, + multi_use_param_spec=multi_use_param_spec, + tracer=tracer, + output_loss_value_spec=output_loss_value_spec, + split_policy=split_policy, + **kwargs, + ) + + gm = pipe.split_gm + if rank == 0: + logging.info(gm) + if PIPPY_VERBOSITY == "INFO": + gm.graph.print_tabular() + + # Get shape of chunked arguments + args_split, _ = split_args_kwargs_into_chunks( + example_inputs, + {}, # kwargs included in `example_inputs` + num_chunks, + args_chunk_spec, + kwargs_chunk_spec, # TODO: merge into args_chunk_spec + ) + + # Use fake tensor for shape propagation + # Since model itself may have been materialized, we need to use + # `allow_non_fake_inputs` + fake_mode = FakeTensorMode(allow_non_fake_inputs=True) + # In reality, the fake input should be created from shape info (potentially + # broadcast from Rank 0) + fake_args_split = fx.node.map_aggregate( + args_split, lambda a: fake_mode.from_tensor(a) + ) + + # Use 1st chunk of args for shape propagation + chunk0 = fake_args_split[0] + + sp = shape_prop.ShapeProp(gm) + sp.propagate(*chunk0) + + # Create pipeline stage + return PipelineStage( + pipe, + rank, + num_ranks, + num_chunks, + device, + group, + return_to_0, + args_chunk_spec, + kwargs_chunk_spec, + output_chunk_spec, + ) diff --git a/test/local_test_c10d.py b/test/local_test_c10d.py index 9fbba1fc1..5044c50d3 100644 --- a/test/local_test_c10d.py +++ b/test/local_test_c10d.py @@ -4,15 +4,14 @@ import unittest import torch -from torch._subclasses.fake_tensor import FakeTensorMode import torch.distributed as dist -from pippy.fx.passes import shape_prop -from pippy.IR import MultiUseParameterConfig, Pipe, pipe_split +from pippy.compile import compile_stage +from pippy.IR import pipe_split d_hid = 512 -bs = 256 +chunk_size = 256 torch.manual_seed(0) @@ -24,16 +23,17 @@ def __init__(self): self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) self.lin = torch.nn.Linear(d_hid, d_hid) - def forward(self, x): + def forward(self, x, y): x = torch.mm(x, self.mm_param) - # skip_connection = x + skip_connection = x + x = x + y x = torch.relu(x) pipe_split() x = torch.mm(x, self.mm_param) x = self.lin(x) pipe_split() x = torch.relu(x) - # x = x + skip_connection + x = x + skip_connection x = torch.mm(x, self.mm_param2) pipe_split() x = self.lin(x) @@ -41,76 +41,40 @@ def forward(self, x): return x -def make_tensor_from_meta( - tensor_meta: shape_prop.TensorMetadata, - device: torch.device, -): - return torch.empty( - tensor_meta.shape, dtype=tensor_meta.dtype, device=device - ) - - def run_worker(args): ec = ExampleCode() ec.to(args.device) - ec_input = torch.randn(bs, d_hid, device=args.device) - # Trace and cut - ec_pipe = Pipe.from_tracing(ec, MultiUseParameterConfig.REPLICATE) - gm = ec_pipe.split_gm - if args.rank == 0: - print(gm) - gm.graph.print_tabular() - - # Use fake tensor for shape propagation - # Since model itself may have been materialized, we need to use - # `allow_non_fake_inputs` - fake_mode = FakeTensorMode(allow_non_fake_inputs=True) - # In reality, the fake input should be created from shape info (potentially - # broadcast from Rank 0) - fake_input = fake_mode.from_tensor(ec_input) - sp = shape_prop.ShapeProp(gm) - sp.propagate(fake_input) - - # Find my submodule - named_children = list(gm.named_children()) - name, submod = named_children[args.rank] - print(f"{name}: {submod}") - - # Find my input size - for node in gm.graph.nodes: - if node.name == name: - input = node.args[0] # args is a tuple - break - tensor_meta = input.meta["tensor_meta"] - print(tensor_meta) - - # Get input + ec_x = torch.randn(args.chunks * chunk_size, d_hid, device=args.device) + ec_y = torch.randn(args.chunks * chunk_size, d_hid, device=args.device) + + stage = compile_stage( + ec, + args.rank, + args.world_size, + args.chunks, + args.device, + None, + [ec_x, ec_y], + ) + + # Run if args.rank == 0: - x = ec_input + out = stage(ec_x, ec_y) + elif args.rank == args.world_size - 1: + out = stage() else: - x = make_tensor_from_meta(tensor_meta, args.device) - dist.recv(x, args.rank - 1) - - # Compute - y = submod(x) + stage() - # Send to next stage - dist.send(y, (args.rank + 1) % args.world_size) + dist.barrier() + print(f"Rank {args.rank} completes") - # Rank 0 checks result - if args.rank == 0: - # Get final output shape - for node in gm.graph.nodes: - if node.target == "output": - break - tensor_meta = node.meta["tensor_meta"] - z = make_tensor_from_meta(tensor_meta, args.device) - dist.recv(z, args.world_size - 1) - ref_out = ec_pipe(ec_input) - torch.testing.assert_close(z, ref_out) + # Last rank checks result + if args.rank == args.world_size - 1: + ref_out = ec(ec_x, ec_y) + torch.testing.assert_close(out, ref_out) print( - f"equivalence test passed {torch.sum(z)} ref {torch.sum(ref_out)}" + f"equivalence test passed {torch.sum(out)} ref {torch.sum(ref_out)}" ) @@ -129,17 +93,22 @@ def main(args=None): parser.add_argument( "--cuda", type=int, default=int(torch.cuda.is_available()) ) + parser.add_argument( + "--chunks", + type=int, + default=4, + ) args = parser.parse_args(args) if args.cuda: dev_id = args.rank % torch.cuda.device_count() - args.device = f"cuda:{dev_id}" + args.device = torch.device(f"cuda:{dev_id}") else: - args.device = "cpu" + args.device = torch.device("cpu") # Init process group backend = "nccl" if args.cuda else "gloo" - torch.distributed.init_process_group( + dist.init_process_group( backend=backend, rank=args.rank, world_size=args.world_size, diff --git a/test/min_gpt_pp_tp.py b/test/min_gpt_pp_tp.py new file mode 100644 index 000000000..fc7c917a9 --- /dev/null +++ b/test/min_gpt_pp_tp.py @@ -0,0 +1,201 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +import argparse +import os + +import torch +import torch.distributed.tensor.parallel as tp + +import pippy +import pippy.fx +from pippy.IR import PipeSplitWrapper, annotate_split_points + +import torch.distributed as dist +from torch.distributed._tensor import ( + DeviceMesh, +) + +from minGPT.mingpt.model import GPT, GPTConfig +from min_gpt_tracing import AdditionDataset # type: ignore + +pippy.fx.Tracer.proxy_buffer_attributes = True + +batch_size_per_chunk = 8 + +# The seed here has two purposes: +# - Ensure all TP ranks have same input +# - Ensure the model (ec) created are the same, as if it comes from a +# single, big model before partitioning +torch.manual_seed(0) + +ndigit = 2 +train_dataset = AdditionDataset(ndigit=ndigit, split="train") +test_dataset = AdditionDataset(ndigit=ndigit, split="test") + +mconf = GPTConfig( + train_dataset.vocab_size, + train_dataset.block_size, + n_layer=4, + n_head=4, + n_embd=128, +) + +d_hid = 4 + + +def run_all(args): + # initialize a baby GPT model + model = GPT(mconf) + model.eval() + model.to(args.device) + + # Specify split points + sp_spec = { + "blocks.0.mlp.3": PipeSplitWrapper.SplitPoint.END, + "blocks.1.mlp.3": PipeSplitWrapper.SplitPoint.END, + "blocks.2.mlp.3": PipeSplitWrapper.SplitPoint.END, + } + annotate_split_points(model, sp_spec) + + # Create input + x = torch.tensor([1, 2, 3, 4], dtype=torch.long) + batch_size = args.chunks * batch_size_per_chunk + device_type = args.device.type + inp = x.repeat(batch_size, 1).to(args.device) + + # Create global DeviceMesh + ranks = torch.arange(args.world_size) + rank_mesh = ranks.reshape(args.pp_group_size, args.tp_group_size) + pp_dim = 0 + tp_dim = 1 + dev_mesh = DeviceMesh( + device_type, + rank_mesh, + ) + + # Figure out my PP and TP rank + pp_rank = args.rank // args.tp_group_size + tp_rank = args.rank % args.tp_group_size + print( + f"Global rank {args.rank}, pp rank: {pp_rank}, tp rank: {tp_rank}, device: {args.device}" + ) + + # Get pp group + # `tp_rank` can serve as pipeline id + print( + f"Rank {args.rank} Instantiating pipeline with ranks {dev_mesh.mesh[:, tp_rank]}" + ) + pp_group = dev_mesh.get_dim_groups()[pp_dim] + + # Get stage module (on all pp ranks) + stage = pippy.compile_stage( + model, + pp_rank, + args.pp_group_size, + args.chunks, + args.device, + pp_group, + example_inputs=[inp], + concrete_args={"targets": None}, + ) + + # Tensor parallelize submodules + print(f"Rank {args.rank} TP-lize submodule with {dev_mesh.mesh[pp_rank]}") + tp.parallelize_module( + stage.submod, + dev_mesh, + parallelize_plan={ + f"blocks_{pp_rank}_mlp_0": tp.ColwiseParallel(), + f"blocks_{pp_rank}_mlp_2": tp.RowwiseParallel(), + }, + tp_mesh_dim=tp_dim, + ) + + if pp_rank == 0: + out = stage(None, inp) + else: + out = stage() + + dist.barrier() + print(f"Rank {args.rank} completes") + + # Last rank checks result + if pp_rank == args.pp_group_size - 1: + ref_out = model(inp)[0] # [0] is logits, [1] is loss (None) + torch.testing.assert_close(out, ref_out) + print( + f"Pipeline {tp_rank} equivalence test passed {torch.sum(out)} ref {torch.sum(ref_out)}" + ) + + +def main(args=None): + parser = argparse.ArgumentParser() + parser.add_argument( + "--world_size", type=int, default=int(os.getenv("WORLD_SIZE", 8)) + ) + # ExampleCode has 4 stages + parser.add_argument( + "--pp_group_size", + type=int, + default=4, + ) + # in row-major + # TP ranks are contiguous rows of size `args.tp_group_size` + # PP ranks are non-contiguous columns of size `args.pp_group_size` + # + # if tp_group_size = 4 and pp_group_size = 3 + # + # 0 1 2 3 + # 4 5 6 7 + # 8 9 10 11 + # + # TP ranks are [0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11] + # PP ranks are [0, 4, 8], [1, 5, 9], [2, 6, 10], [3, 7, 11] + parser.add_argument("--rank", type=int, default=int(os.getenv("RANK", -1))) + parser.add_argument( + "--master_addr", type=str, default=os.getenv("MASTER_ADDR", "localhost") + ) + parser.add_argument( + "--master_port", type=str, default=os.getenv("MASTER_PORT", "29500") + ) + parser.add_argument( + "--cuda", type=int, default=int(torch.cuda.is_available()) + ) + parser.add_argument( + "--chunks", + type=int, + default=4, + ) + args = parser.parse_args(args) + + # Use world size to determine TP group size + assert args.world_size % args.pp_group_size == 0 + args.tp_group_size = args.world_size // args.pp_group_size + if args.rank == 0: + print( + f"Pipeline parallel size: {args.pp_group_size}\n" + f"Tensor parallel size: {args.tp_group_size}" + ) + + if args.cuda: + dev_id = args.rank % torch.cuda.device_count() + args.device = torch.device(f"cuda:{dev_id}") + # HACK: we need to pin device here because `DeviceMesh` currently does + # an all_gather with device_type only, without device id + # https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/device_mesh.py#L191-L192 + torch.cuda.set_device(args.device) + else: + args.device = torch.device("cpu") + + # Init process group + backend = "nccl" if args.cuda else "gloo" + dist.init_process_group( + backend=backend, + rank=args.rank, + world_size=args.world_size, + ) + + run_all(args) + + +if __name__ == "__main__": + main() From e6efebbdd8e8ea32dc7ebf497c2f45bca9670d6c Mon Sep 17 00:00:00 2001 From: Hamid Shojanazeri Date: Fri, 2 Jun 2023 17:36:55 +0000 Subject: [PATCH 03/96] simple tp tests --- examples/inference/pp_tp.py | 175 ++++++++++++++++++++++++++++++++++++ 1 file changed, 175 insertions(+) create mode 100644 examples/inference/pp_tp.py diff --git a/examples/inference/pp_tp.py b/examples/inference/pp_tp.py new file mode 100644 index 000000000..1c0792bb1 --- /dev/null +++ b/examples/inference/pp_tp.py @@ -0,0 +1,175 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +import argparse +import os + +import torch +import pippy +import pippy.fx +from pippy import run_pippy +from pippy.hf import PiPPyHFTracer, inject_pipeline_forward +from transformers import AutoModelForCausalLM, AutoTokenizer +from torch.distributed._tensor import ( + DeviceMesh, +) +from torch.distributed.tensor.parallel import ( + PairwiseParallel, + parallelize_module, +) + +pippy.fx.Tracer.proxy_buffer_attributes = True + +gigabyte_size = 1024 ** 3 + + +def format_to_gb(item, precision=4): + """quick function to format numbers to gigabyte and round to (default) 4 digit precision""" + metric_num = item / gigabyte_size + metric_num = round(metric_num, ndigits=precision) + return metric_num + + +def print_mem_usage(): + memory_reserved = format_to_gb(torch.cuda.memory_reserved()) + memory_allocated = format_to_gb(torch.cuda.memory_allocated()) + print( + f"memory_reserved: {memory_reserved} GB, " + f"memory_allocated: {memory_allocated} GB" + ) + + +def get_number_of_params(model): + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + +def run_all(pp_ranks, args): + device_type = "cuda" if args.cuda else "cpu" + pp_rank = args.rank // args.tp_group_size + model = args.model + model.eval() + model.config.use_cache = False # don't output `past_key_values` + num_ranks = len(pp_ranks) + + if args.rank == 0: + print(model.config) + print(f"model total number of params = {get_number_of_params(model) // 10 ** 6}M") + + split_policy = pippy.split_into_equal_size(num_ranks) + + # Use default value for kwargs other than `input_ids` + concrete_args = pippy.create_default_args( + model, + except_keys="input_ids", + ) + if 'bloom' in args.model_name: + # Used to avoid a control flow and tracing `len` call in BloomForCausalLM that looks like this: + # `if len(deprecated_arguments) > 0:` + concrete_args.setdefault("deprecated_arguments", {}) + + pipe_driver, stage_mod = pippy.all_compile( + model, + num_ranks, + args.chunks, + split_policy=split_policy, + tracer=PiPPyHFTracer(), + concrete_args=concrete_args, + index_filename=args.index_filename, + checkpoint_prefix=args.checkpoint_prefix, + ) + my_device_mesh = None + for stage in range(args.pp_group_size): + start_rank = stage * args.tp_group_size + tp_ranks = list(range(start_rank, start_rank + args.tp_group_size)) + tp_device_mesh = DeviceMesh( + device_type, + tp_ranks, + ) + if stage == pp_rank: + my_device_mesh = tp_device_mesh + + # Tensor parallelize submodules + print(f"Rank {args.rank} calling parallelize_module with {my_device_mesh}") + parallelize_module(stage_mod, my_device_mesh, PairwiseParallel()) + + params = get_number_of_params(stage_mod) + print(f"submod_{args.rank} {params // 10 ** 6}M params") + + if args.rank != 0: + return + + # Master continues + print_mem_usage() + + # Inject pipeline driver's forward function back to original model to support HF's `generate()` method + inject_pipeline_forward(model, pipe_driver) + + # Generate text based on prompt + tokenizer = AutoTokenizer.from_pretrained(args.model_name) + prompt = "Hey, are you conscious? Can you talk to me?" + input = tokenizer(prompt, return_tensors="pt") + input_ids = input["input_ids"].to(args.device) + outputs = model.generate(input_ids, max_new_tokens=30) + response = tokenizer.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + print(response) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 4))) + parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1))) + parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost')) + parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500')) + parser.add_argument('--model_name', type=str, default='facebook/opt-350m') + parser.add_argument('--batch_size', type=int, default=1) + parser.add_argument('--chunks', type=int, default=1) + parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available())) + parser.add_argument('--tp_group_size', type=int, default=int(os.getenv("WORLD_SIZE", 4))) + parser.add_argument('--pp_group_size', type=int, default=int(os.getenv("WORLD_SIZE", 4))) + parser.add_argument('--dtype', type=str, default="fp32", choices=["fp32", "bf16", "fp16"]) + parser.add_argument('--index_filename', type=str, default=None, help="The director of model's index.json file") + parser.add_argument('--checkpoint_prefix', type=str, default=None, help="Prefix to add to the weight names in checkpoint map back to model structure") + + args = parser.parse_args() + + assert args.world_size % args.pp_group_size == 0 + + supported_model_categories = ["opt", "gpt", "bloom", "codegen", "llama"] + # For example: + # "facebook/opt-350m" + # "gpt2" + # "bigscience/bloom-3b" + # "EleutherAI/gpt-neo-2.7B" + # "Salesforce/codegen-2B-multi" + # "cerebras/Cerebras-GPT-13B" + + if args.dtype == "fp32": + dtype = torch.float32 + elif args.dtype == "fp16": + dtype = torch.float16 + elif args.dtype == "bf16": + dtype = torch.bfloat16 + else: + # Using float32 as default dtype to correspond to the default "fp32" + # value for "--dtype" + print( + f"Unsupported data type {args.dtype}, " + "please submit a PR to support it. Falling back to fp32 now." + ) + dtype = torch.float32 + + # Main process loads model + if any([m in args.model_name + # Some model names use upper case + or m.upper() in args.model_name + for m in supported_model_categories]): + print(f"Loading model {args.model_name}") + if args.index_filename is not None: + with torch.device("meta"): + model = AutoModelForCausalLM.from_pretrained(args.model_name, use_cache=False, torch_dtype=dtype) + else: + model = AutoModelForCausalLM.from_pretrained(args.model_name, use_cache=False, torch_dtype=dtype) + else: + raise ValueError(f"Unsupported model: {args.model_name}") + + args.model = model + args.gspmd = 1 + run_pippy(run_all, args) From 433b702922dbc60fb003d16d479ed0090b2e562c Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Tue, 6 Jun 2023 16:19:48 -0700 Subject: [PATCH 04/96] Pause some HF model tests for cleaner CI signal (#805) Should re-enable after cleaning out tests. For example, we should fix the test to certain models rather than tracking all. --- .github/workflows/pippy_tests.yaml | 182 ++++++++++++++--------------- 1 file changed, 91 insertions(+), 91 deletions(-) diff --git a/.github/workflows/pippy_tests.yaml b/.github/workflows/pippy_tests.yaml index caaea040c..69d91ad77 100644 --- a/.github/workflows/pippy_tests.yaml +++ b/.github/workflows/pippy_tests.yaml @@ -44,33 +44,33 @@ jobs: run: | pytest --cov=pippy --ignore=test/hf_test.py --ignore=test/test_fx.py --ignore=test/test_fx_experimental.py --ignore=test/fx test/ - hf_model_tests: - runs-on: linux.12xlarge - strategy: - matrix: - python-version: ["3.9"] - shard: ["0", "1", "2", "3", "4", "5", "6", "7"] - container: - image: python:${{ matrix.python-version }} + # hf_model_tests: + # runs-on: linux.12xlarge + # strategy: + # matrix: + # python-version: ["3.9"] + # shard: ["0", "1", "2", "3", "4", "5", "6", "7"] + # container: + # image: python:${{ matrix.python-version }} - steps: - - uses: actions/checkout@v2 - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install flake8 pytest pytest-cov pytest-xdist pytest-shard numpy - if [ -f requirements.txt ]; then pip install -r requirements.txt --find-links https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html; fi - - name: Install pavel's huggingface fork - run: pip install git+https://github.com/huggingface/transformers.git@main sentencepiece six sacremoses - - name: Install pippy - run: "python setup.py install" - # Single thread to avoid OOM - - name: Test forward only - run: | - pytest --shard-id=${{ matrix.shard }} --num-shards=8 -k 'not HFModelsForwardBackwardTest' -sv --cov=pippy test/hf_test.py - - name: Test forward and backward - run: | - pytest --shard-id=${{ matrix.shard }} --num-shards=8 -k 'HFModelsForwardBackwardTest' -sv --cov=pippy test/hf_test.py + # steps: + # - uses: actions/checkout@v2 + # - name: Install dependencies + # run: | + # python -m pip install --upgrade pip + # pip install flake8 pytest pytest-cov pytest-xdist pytest-shard numpy + # if [ -f requirements.txt ]; then pip install -r requirements.txt --find-links https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html; fi + # - name: Install pavel's huggingface fork + # run: pip install git+https://github.com/huggingface/transformers.git@main sentencepiece six sacremoses + # - name: Install pippy + # run: "python setup.py install" + # # Single thread to avoid OOM + # - name: Test forward only + # run: | + # pytest --shard-id=${{ matrix.shard }} --num-shards=8 -k 'not HFModelsForwardBackwardTest' -sv --cov=pippy test/hf_test.py + # - name: Test forward and backward + # run: | + # pytest --shard-id=${{ matrix.shard }} --num-shards=8 -k 'HFModelsForwardBackwardTest' -sv --cov=pippy test/hf_test.py integration_test_cpu: runs-on: linux.4xlarge @@ -117,74 +117,74 @@ jobs: - name: Run compile test run: python test/local_test_compile.py -s ${{ matrix.schedule }} --checkpoint ${{ matrix.checkpoint }} - hf_examples_set1: - runs-on: linux.12xlarge - strategy: - matrix: - python-version: ["3.9"] - schedule: ["FillDrain", "1F1B"] - env: - OMP_NUM_THREADS: "1" - container: - image: python:${{ matrix.python-version }} + # hf_examples_set1: + # runs-on: linux.12xlarge + # strategy: + # matrix: + # python-version: ["3.9"] + # schedule: ["FillDrain", "1F1B"] + # env: + # OMP_NUM_THREADS: "1" + # container: + # image: python:${{ matrix.python-version }} - steps: - - uses: actions/checkout@v2 - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install flake8 pytest pytest-cov numpy datasets evaluate scikit-learn sacrebleu - if [ -f requirements.txt ]; then pip install -r requirements.txt --find-links https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html; fi - - name: Install pavel's huggingface fork - run: pip install git+https://github.com/huggingface/transformers.git@main sentencepiece six sacremoses - - name: Install pippy - run: "python setup.py install" - - name: Test min-GPT - run: | - git config --global --add safe.directory /__w/tau/tau - git submodule update --init test/minGPT - python test/min_gpt_tracing.py - - name: Run GPT2 example - run: python examples/hf/gpt2/pippy_gpt2.py -s ${{ matrix.schedule }} - - name: Run BERT example - run: python examples/hf/bert/pippy_bert.py -s ${{ matrix.schedule }} - - name: Run T5 example - run: python examples/hf/t5/pippy_t5.py -s ${{ matrix.schedule }} - - name: "HF Translation: fine-tune T5 model translation English to Romanian" - run: > - python examples/hf/translation/run_translation.py --model_name_or_path t5-small --do_train --source_lang en --target_lang ro --source_prefix "translate English to Romanian: " --dataset_name wmt16 --dataset_config_name ro-en --output_dir /tmp/tst-translation --per_device_train_batch_size=8 --per_device_eval_batch_size=8 --overwrite_output_dir --predict_with_generate --max_steps=10 --dp_group_size=1 --pp_group_size=8 - - name: "HF Translation: fine-tune BART model translation English to Romanian" - run: > - python examples/hf/translation/run_translation.py --model_name_or_path facebook/bart-base --do_train --source_lang en --target_lang ro --source_prefix "translate English to Romanian: " --dataset_name wmt16 --dataset_config_name ro-en --output_dir /tmp/tst-translation --per_device_train_batch_size=8 --per_device_eval_batch_size=8 --overwrite_output_dir --predict_with_generate --max_steps=10 --dp_group_size=2 --pp_group_size=8 + # steps: + # - uses: actions/checkout@v2 + # - name: Install dependencies + # run: | + # python -m pip install --upgrade pip + # pip install flake8 pytest pytest-cov numpy datasets evaluate scikit-learn sacrebleu + # if [ -f requirements.txt ]; then pip install -r requirements.txt --find-links https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html; fi + # - name: Install pavel's huggingface fork + # run: pip install git+https://github.com/huggingface/transformers.git@main sentencepiece six sacremoses + # - name: Install pippy + # run: "python setup.py install" + # - name: Test min-GPT + # run: | + # git config --global --add safe.directory /__w/tau/tau + # git submodule update --init test/minGPT + # python test/min_gpt_tracing.py + # - name: Run GPT2 example + # run: python examples/hf/gpt2/pippy_gpt2.py -s ${{ matrix.schedule }} + # - name: Run BERT example + # run: python examples/hf/bert/pippy_bert.py -s ${{ matrix.schedule }} + # - name: Run T5 example + # run: python examples/hf/t5/pippy_t5.py -s ${{ matrix.schedule }} + # - name: "HF Translation: fine-tune T5 model translation English to Romanian" + # run: > + # python examples/hf/translation/run_translation.py --model_name_or_path t5-small --do_train --source_lang en --target_lang ro --source_prefix "translate English to Romanian: " --dataset_name wmt16 --dataset_config_name ro-en --output_dir /tmp/tst-translation --per_device_train_batch_size=8 --per_device_eval_batch_size=8 --overwrite_output_dir --predict_with_generate --max_steps=10 --dp_group_size=1 --pp_group_size=8 + # - name: "HF Translation: fine-tune BART model translation English to Romanian" + # run: > + # python examples/hf/translation/run_translation.py --model_name_or_path facebook/bart-base --do_train --source_lang en --target_lang ro --source_prefix "translate English to Romanian: " --dataset_name wmt16 --dataset_config_name ro-en --output_dir /tmp/tst-translation --per_device_train_batch_size=8 --per_device_eval_batch_size=8 --overwrite_output_dir --predict_with_generate --max_steps=10 --dp_group_size=2 --pp_group_size=8 - hf_examples_set2: - runs-on: linux.12xlarge - strategy: - matrix: - python-version: ["3.9"] - schedule: ["FillDrain", "1F1B"] - env: - OMP_NUM_THREADS: "1" - container: - image: python:${{ matrix.python-version }} + # hf_examples_set2: + # runs-on: linux.12xlarge + # strategy: + # matrix: + # python-version: ["3.9"] + # schedule: ["FillDrain", "1F1B"] + # env: + # OMP_NUM_THREADS: "1" + # container: + # image: python:${{ matrix.python-version }} - steps: - - uses: actions/checkout@v2 - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install flake8 pytest pytest-cov numpy datasets evaluate scikit-learn sacrebleu - if [ -f requirements.txt ]; then pip install -r requirements.txt --find-links https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html; fi - - name: Install pavel's huggingface fork - run: pip install git+https://github.com/huggingface/transformers.git@main sentencepiece six sacremoses - - name: Install pippy - run: "python setup.py install" - - name: "HF Causal Language Modeling: fine-tune GPT-2 on WikiText-2" - run: python examples/hf/language-modeling/run_clm.py --dp_group_size=2 --pp_group_size=8 --model_name_or_path gpt2 --dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 --per_device_train_batch_size 8 --per_device_eval_batch_size 8 --do_train --do_eval --output_dir /tmp/test-clm --max_steps=3 --overwrite_output_dir - - name: "HF Masked Language Modeling: fine-tune RoBERTa on WikiText-2" - run: python examples/hf/language-modeling/run_mlm.py --dp_group_size=2 --pp_group_size=8 --model_name_or_path roberta-base --dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 --per_device_train_batch_size 8 --per_device_eval_batch_size 8 --do_train --do_eval --output_dir /tmp/test-mlm --max_steps=3 --overwrite_output_dir - - name: "HF Text classification: fine-tune BERT on the GLUE benchmark" - run: python examples/hf/text-classification/run_glue.py --dp_group_size=2 --pp_group_size=8 --model_name_or_path bert-base-cased --task_name mrpc --do_train --do_eval --max_seq_length 128 --per_device_train_batch_size 32 --learning_rate 2e-5 --num_train_epochs 3 --output_dir /tmp/mrpc/ --max_steps=3 --overwrite_output_dir + # steps: + # - uses: actions/checkout@v2 + # - name: Install dependencies + # run: | + # python -m pip install --upgrade pip + # pip install flake8 pytest pytest-cov numpy datasets evaluate scikit-learn sacrebleu + # if [ -f requirements.txt ]; then pip install -r requirements.txt --find-links https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html; fi + # - name: Install pavel's huggingface fork + # run: pip install git+https://github.com/huggingface/transformers.git@main sentencepiece six sacremoses + # - name: Install pippy + # run: "python setup.py install" + # - name: "HF Causal Language Modeling: fine-tune GPT-2 on WikiText-2" + # run: python examples/hf/language-modeling/run_clm.py --dp_group_size=2 --pp_group_size=8 --model_name_or_path gpt2 --dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 --per_device_train_batch_size 8 --per_device_eval_batch_size 8 --do_train --do_eval --output_dir /tmp/test-clm --max_steps=3 --overwrite_output_dir + # - name: "HF Masked Language Modeling: fine-tune RoBERTa on WikiText-2" + # run: python examples/hf/language-modeling/run_mlm.py --dp_group_size=2 --pp_group_size=8 --model_name_or_path roberta-base --dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 --per_device_train_batch_size 8 --per_device_eval_batch_size 8 --do_train --do_eval --output_dir /tmp/test-mlm --max_steps=3 --overwrite_output_dir + # - name: "HF Text classification: fine-tune BERT on the GLUE benchmark" + # run: python examples/hf/text-classification/run_glue.py --dp_group_size=2 --pp_group_size=8 --model_name_or_path bert-base-cased --task_name mrpc --do_train --do_eval --max_seq_length 128 --per_device_train_batch_size 32 --learning_rate 2e-5 --num_train_epochs 3 --output_dir /tmp/mrpc/ --max_steps=3 --overwrite_output_dir integration_test_gpu: runs-on: linux.16xlarge.nvidia.gpu From fc99431c0d62977fd70f02f2e51cdc0885be5b9d Mon Sep 17 00:00:00 2001 From: Yeonju Ro Date: Tue, 6 Jun 2023 23:44:49 -0700 Subject: [PATCH 05/96] print module and code together in min_gpt_tracing (#804) Print module and code together in min_gpt_tracing test code ## Type of change Minor ## Feature/Issue validation/testing Modified debug message in test file --- test/min_gpt_tracing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/min_gpt_tracing.py b/test/min_gpt_tracing.py index 636298cc3..16106d167 100644 --- a/test/min_gpt_tracing.py +++ b/test/min_gpt_tracing.py @@ -132,7 +132,7 @@ def __getitem__(self, idx): "submod_2", ] -print(traced_pipe.split_gm.code) +print(traced_pipe.split_gm) """ def forward(self, idx, targets_1 = None): submod_0 = self.submod_0(idx); idx = None From 8f549f37f96e8c89e4c6dd3e1caf358273ce8032 Mon Sep 17 00:00:00 2001 From: eddogola <64967909+eddogola@users.noreply.github.com> Date: Mon, 12 Jun 2023 17:47:09 -0700 Subject: [PATCH 06/96] Add docstring for load_checkpoint (#807) ## Description Added docstrings to load_checkpoint module and _get_file_to_weight_map function. ## Type of change Please delete options that are not relevant. - [ ] Bug fix (non-breaking change which fixes an issue) --------- Co-authored-by: Eddy --- pippy/LoadModule.py | 30 ++++++++++++++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/pippy/LoadModule.py b/pippy/LoadModule.py index 33bc1e8be..b0a620d8a 100644 --- a/pippy/LoadModule.py +++ b/pippy/LoadModule.py @@ -20,7 +20,22 @@ def load_checkpoint( device: torch.device = None, dtype: torch.dtype = None, checkpoint_prefix: str = None, -): +) -> nn.Module: + """ + Load a checkpoint from a model file. + Args: + model (`torch.nn.Module`): the model to load the checkpoint into + index_filename (`Union[str, os.PathLike]`): path to the checkpoint's index (metadata file) + device (`torch.device`): the device on which to load the checkpoint + dtype (`torch.dtype`): the dtype on which to load the checkpoint + checkpoint_prefix (`str`): the prefix of the checkpoint to load + Returns: + The loaded checkpoint model + Example: + ``` + checkpoint = load_checkpoint(model, index_filename, device, dtype) + ``` + """ checkpoint_folder = os.path.split(index_filename)[0] with open(index_filename, "r") as f: index = json.loads(f.read()) @@ -68,9 +83,20 @@ def load_checkpoint( def _get_file_to_weight_map( model: nn.Module, - index, + index: Dict[str, str], prefix_to_test: List[str], ) -> Dict[str, List[Tuple]]: + """ + A helper function to create a mapping from binary checkpoint filename to parameter names + Args: + model (`torch.nn.Module`): The model to load weights into + index (`Dict[str, str]`): The checkpoint index mapping parameter name to binary checkpoint filename + prefix_to_test (`List[str]`): prefix to try if direct match is not found + Returns: + `Dict[str, List[Tuple]]`: A mapping from binary checkpoint filename to list of tuples of parameter names + Raises: + RuntimeError: if a parameter name is not found in the checkpoint index + """ file_to_weights: Dict[str, List[Tuple]] = {} for iterator in [ From e1dee3b878519a4940b516942a6c108b299cac28 Mon Sep 17 00:00:00 2001 From: eddogola <64967909+eddogola@users.noreply.github.com> Date: Tue, 20 Jun 2023 11:03:10 -0700 Subject: [PATCH 07/96] save checkpoint index file (#812) ## Description As part of progressing towards implementing checkpointing for pippy, `_save_index` writes the checkpoint metadata in a json file. Fixes #(issue) ## Type of change Please delete options that are not relevant. - [x] New feature (non-breaking change which adds functionality) - [x] This change requires a documentation update ## Feature/Issue validation/testing ## Checklist: - [ ] Have you added tests that prove your fix is effective or that this feature works? - [ ] Has code been commented, particularly in hard-to-understand areas? - [ ] Have you made corresponding changes to the documentation? --------- Co-authored-by: Eddy --- pippy/hf/_SaveModule.py | 61 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) create mode 100644 pippy/hf/_SaveModule.py diff --git a/pippy/hf/_SaveModule.py b/pippy/hf/_SaveModule.py new file mode 100644 index 000000000..51bb67529 --- /dev/null +++ b/pippy/hf/_SaveModule.py @@ -0,0 +1,61 @@ +import torch.distributed as dist +import pippy + +import json +import os + +CKPT_INDEX_JSON_FILENAME = "pytorch_model.bin.index.json" + + +def _save_index( + pipe: pippy.fx.GraphModule, + ckpt_index_filename: str = CKPT_INDEX_JSON_FILENAME, + checkpoint_dir: str = "checkpoints", +) -> None: + """ + Saves index file describing location of weights in checkpoint. + + Args: + pipe (pippy.fx.GraphModule): pipeline graph module with weights to save + ckpt_index_filename (str, optional): name of index file. Defaults to "pytorch_model.bin.index.json". + checkpoint_dir (str, optional): directory to save checkpoint to. Defaults to "checkpoints". + """ + index_dict = {} + total_size = 0 + index_dict["metadata"] = {"total_size": total_size} + + weight_map = {} + for idx, (submod_name, submod) in enumerate(pipe.split_gm.named_children()): + for param_name, _ in submod.named_parameters(): + old_name = submod.remap_qualname(param_name) + + binary_filename = _create_binary_filename(idx) + weight_map[old_name] = binary_filename + index_dict["weight_map"] = weight_map + + json_str = json.dumps(index_dict, indent=4) + + filepath = os.path.join(checkpoint_dir, ckpt_index_filename) + + # create checkpoint directory if it doesn't exist + if not os.path.exists(checkpoint_dir): + os.mkdir(checkpoint_dir) + + with open(filepath, "w") as f: + f.write(json_str) + + +def _create_binary_filename(cur_idx: int) -> str: + """ + Gets filename for pytorch checkpoint binary based on current index and world size. + + Args: + cur_idx (int): current device index + + Returns: + str: checkpoint filename + """ + cur_idx = str(cur_idx + 1).zfill(5) + world_size = str(dist.get_world_size()).zfill(5) + + return f"pytorch_model-{cur_idx}-of-{world_size}.bin" From 595d3c53c11d115ec8519be0621e0aa238b5ab7d Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Wed, 21 Jun 2023 11:11:02 -0700 Subject: [PATCH 08/96] Add c10d implementation for backward pass too (#811) ## Description This is a follow-up of PR #800. It adds backward pass implementation in c10d-based PipelineStage. New test: test/local_test_c10d_bwd.py --- pippy/IR.py | 40 ++- pippy/PipelineDriver.py | 28 +-- pippy/PipelineStage.py | 483 ++++++++++++++++++++++++++++-------- pippy/backward.py | 28 ++- pippy/compile.py | 19 +- pippy/debug.py | 28 +++ pippy/utils.py | 51 +++- test/local_test_c10d_bwd.py | 135 ++++++++++ 8 files changed, 660 insertions(+), 152 deletions(-) create mode 100644 pippy/debug.py create mode 100644 test/local_test_c10d_bwd.py diff --git a/pippy/IR.py b/pippy/IR.py index 32190d905..ec84e2b93 100644 --- a/pippy/IR.py +++ b/pippy/IR.py @@ -11,6 +11,7 @@ import torch.fx as torch_fx import pippy.fx from packaging import version +from pippy.fx.passes import shape_prop from pippy.fx.passes.split_module import split_module from pippy.backward import ( stage_backward, @@ -125,7 +126,10 @@ def _find_loss_output( def _insert_stage_symbolic_backward( - g: pippy.fx.Graph, loss_node: pippy.fx.Node, output_node: pippy.fx.Node + g: pippy.fx.Graph, + loss_node: pippy.fx.Node, + output_node: pippy.fx.Node, + return_to_0: bool, ): # Collect metadata about tuple output values. TODO: move this to split_module or FX IR tuples: Dict[pippy.fx.Node, Tuple] = {} @@ -250,10 +254,11 @@ def add_to_live_nodes(n): # guaranteed that all backwards jobs for that micro-batch have been executed. # When all micro-batch pipeline outputs are ready, gradients have been fully # computed and synchronized and the optimizer step can be applied. - barrier_call = g.call_function( - sync_barrier, (output_node.args[0], barrier_tokens, last_grads) - ) - output_node.args = (barrier_call,) + if return_to_0: + barrier_call = g.call_function( + sync_barrier, (output_node.args[0], barrier_tokens, last_grads) + ) + output_node.args = (barrier_call,) return g @@ -689,6 +694,7 @@ def _from_traced( traced: pippy.fx.GraphModule, multi_use_param_spec: Optional[MultiUseParamSpec] = None, output_loss_value_spec=None, + return_to_0: bool = True, ): """ Additionally, the ``output_loss_value_spec`` value can be specified to disambiguate @@ -991,7 +997,10 @@ def move_param_to_callee( ) if loss_node is not None: _insert_stage_symbolic_backward( - split.graph, loss_node, output_node + split.graph, + loss_node, + output_node, + return_to_0, ) split.recompile() has_loss_and_backward = True @@ -1027,6 +1036,7 @@ def from_tracing( split_policy: Optional[ Callable[[pippy.fx.GraphModule], pippy.fx.GraphModule] ] = None, + return_to_0: bool = True, **kwargs, ): # TODO: abstract partitioning policy @@ -1070,6 +1080,7 @@ def remap_vals(n): traced, multi_use_param_spec, output_loss_value_spec=output_loss_value_spec, + return_to_0=return_to_0, ) def __str__(self): @@ -1184,3 +1195,20 @@ def annotate_split_points( mod_to_wrap = getattr(predecessor_module, atoms[-1]) wrapped_mod = PipeSplitWrapper(mod_to_wrap, split_type) setattr(predecessor_module, atoms[-1], wrapped_mod) + + +class PiPPyShapeProp(shape_prop.ShapeProp): + def __init__( + self, module: pippy.fx.GraphModule, garbage_collect_values: bool = True + ): + super().__init__(module, garbage_collect_values) + self.stop_prop = False + + def run_node(self, n: pippy.fx.Node) -> Any: + if (n.op, n.target) == ("call_function", stage_backward): + self.stop_prop = True + + if self.stop_prop: + return None + + return super().run_node(n) diff --git a/pippy/PipelineDriver.py b/pippy/PipelineDriver.py index 104a6aff1..c442f3f94 100644 --- a/pippy/PipelineDriver.py +++ b/pippy/PipelineDriver.py @@ -27,6 +27,7 @@ split_args_kwargs_into_chunks, merge_chunks, ) +from pippy.utils import flatten_args_detach # TODO: Define the strategy for replicating the computation. In particular, we will likely make the assumption # that the operations in the program are batch-wise commutative (my term), i.e. we can guarantee equivalence @@ -396,27 +397,11 @@ def retrieve_value_ref_args_by_idx(a): ) def forward(args, kwargs, no_grad): - flat_args = [] - - def extract_tensor_args(a): - nonlocal flat_args - if isinstance(a, torch.Tensor): - val = a.detach().requires_grad_(a.requires_grad) - flat_args.append(val) - return val - else: - flat_args.append(a) - return a + args, flat_args = flatten_args_detach(args) + kwargs, flat_kwargs = flatten_args_detach(kwargs) + # Contains all tensors from args and kwargs, in flattened form + flat_args += flat_kwargs - def dont_traverse_size(a): - return type(a) != torch.Size - - args = pippy.fx.node.map_aggregate( - args, extract_tensor_args, dont_traverse_size - ) - kwargs = pippy.fx.node.map_aggregate( - kwargs, extract_tensor_args, dont_traverse_size - ) logging.info( f"[{self.rank}] Running forward module for microbatch {work_item.microbatch_id}" # type: ignore[union-attr] ) @@ -437,6 +422,9 @@ def set_requires_grad(a): a.requires_grad_(True) return a + def dont_traverse_size(a): + return type(a) != torch.Size + if no_grad: with torch.no_grad(): out_val = forward_maybe_with_ddp(args, kwargs) diff --git a/pippy/PipelineStage.py b/pippy/PipelineStage.py index d0f2d816a..f589bb2a2 100644 --- a/pippy/PipelineStage.py +++ b/pippy/PipelineStage.py @@ -1,21 +1,24 @@ # Copyright (c) Meta Platforms, Inc. and affiliates import logging import operator -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple import torch import torch.distributed as dist import pippy +from pippy.backward import stage_backward, sync_barrier +from pippy.debug import map_debug_info from pippy.fx.passes import shape_prop from pippy.IR import Pipe from pippy.microbatch import merge_chunks, split_args_kwargs_into_chunks +from pippy.utils import flatten_args def _make_tensor_from_meta( tensor_meta: shape_prop.TensorMetadata, device: torch.device, -): +) -> torch.Tensor: return torch.empty( tensor_meta.shape, dtype=tensor_meta.dtype, device=device ) @@ -32,6 +35,13 @@ def __init__( self.source = source self.buffer = buffer + def __repr__(self): + return f"RecvInfo(input={self.input_name}, source={self.source}, buffer={self.buffer.size()})" + + +class StageArgPlaceholder: + pass + class PipelineStage(torch.nn.Module): def __init__( @@ -42,7 +52,6 @@ def __init__( chunks: int, device: torch.device, group: dist.ProcessGroup = None, - return_to_0: bool = False, args_chunk_spec=None, kwargs_chunk_spec=None, output_chunk_spec=None, @@ -54,7 +63,6 @@ def __init__( self.chunks = chunks self.device = device self.group = group - self.return_to_0 = return_to_0 self.args_chunk_spec = args_chunk_spec self.kwargs_chunk_spec = kwargs_chunk_spec self.output_chunk_spec = output_chunk_spec @@ -69,7 +77,7 @@ def __init__( f"{self.submod}" ) - # Find my node in graph + # Find my forward node in graph found_node = False for node in self.split_gm.graph.nodes: if node.name == self.name: @@ -79,20 +87,65 @@ def __init__( if not found_node: raise AssertionError(f"Cannot find {self.name} in graph") - # Create submod name to rank mapping + # Find my backward node in graph + if self.pipe.has_loss_and_backwards: + found_bwd = False + seen_bwd = -1 + for node in reversed(self.split_gm.graph.nodes): + if (node.op, node.target) == ("call_function", stage_backward): + seen_bwd += 1 + if seen_bwd == self.rank: + found_bwd = True + self.bwd_node = node + break + if not found_bwd: + raise AssertionError( + f"Cannot find backward for {self.name} in graph" + ) + + # Create submod to rank mapping self.submod_to_rank: Dict[str, int] = {} for i, (name, _) in enumerate(self.split_gm.named_children()): self.submod_to_rank.setdefault(name, i) # Prepare send/recv infrastructure - self.args_recv_info = [] - self.kwargs_recv_info = [] - for _ in range(chunks): - args_recv, kwargs_recv = self._create_recv_buffers() - self.args_recv_info.append(args_recv) - self.kwargs_recv_info.append(kwargs_recv) + self._prepare_send_recv_infra() + + def _prepare_send_recv_infra(self): + """ + Create send/recv infrastructures for activations (during forward) and + gradients (during backward) + """ + # chunk : Tuple of arg buffers + self.args_recv_info: Dict[int, Tuple] = {} + # chunk : Dict of kwarg buffers + self.kwargs_recv_info: Dict[int, Dict] = {} + for chunk in range(self.chunks): + ( + self.args_recv_info[chunk], + self.kwargs_recv_info[chunk], + ) = self._create_act_recv_buffers() + + # Send info during forward for each activation + self.act_send_info = self._create_act_send_info() + + if self.pipe.has_loss_and_backwards: + # chunk : List of output grad buffers + # `grad_recv_info` is a mirror of `act_send_info` + self.grad_recv_info: Dict = {} + for chunk in range(self.chunks): + self.grad_recv_info[chunk] = self._create_grad_recv_info( + self.act_send_info + ) - self._create_send_info() + # Send info for input grads during backward + # List of destinations corresponding to input grads + # Can be None if an input has no grad + # `grad_send_info` is a mirror of `args_recv_info` + `kwargs_recv_info` + self.grad_send_info = self._create_grad_send_info( + self.args_recv_info[0], + self.kwargs_recv_info[0], + ) def get_rank_of_submod( self, @@ -103,7 +156,7 @@ def get_rank_of_submod( return self.submod_to_rank[submod_name] - def _create_recv_buffers( + def _create_act_recv_buffers( self, ): def create_recv_tensor( @@ -116,7 +169,7 @@ def create_recv_tensor( """ if input_node.op == "placeholder": # Do not create buffer for placeholder - return None + return StageArgPlaceholder() # In case the input is a `getitem` node, we recursively find the # real source e.g. getitem1 = submod0[1] @@ -127,7 +180,10 @@ def create_recv_tensor( out_idx = input_node.args[1] return create_recv_tensor(real_input_node, out_idx) else: - return None + raise NotImplementedError( + f"getitem gets a non-Tensor value, this is not yet supported. " + f"Node: {input_node.format_node()}" + ) if output_idx is not None: # If a node has multiple output values, "tensor_meta" is a list @@ -143,10 +199,14 @@ def create_recv_tensor( ) src_rank = self.get_rank_of_submod(input_node.name) + buffer = _make_tensor_from_meta(tensor_meta, self.device) + # Enable gradient in training mode + if self.pipe.has_loss_and_backwards: + buffer.requires_grad_(True) return RecvInfo( input_node.name, src_rank, - _make_tensor_from_meta(tensor_meta, self.device), + buffer, ) # `args` is a Tuple, hence we will have: @@ -161,135 +221,360 @@ def create_recv_tensor( self.node.kwargs, create_recv_tensor ) + logging.info( + f"[{self.rank}][{self.name}] " + f"Activation recv info: {args_recv_info}" + ) return args_recv_info, kwargs_recv_info - def _create_send_info(self): - # Find send destinations - def find_dst_rank(user) -> Optional[int]: - if user.op == "output": - if self.return_to_0: - # Send result back to pp rank 0 - return 0 - else: - return None - else: - # User is a stage (`call_module`) - return self.get_rank_of_submod(user.name) + def find_dst_rank( + self, + user: pippy.fx.node, + ) -> Optional[int]: + """ + Find the destination rank of a `user` node. + If the `user` is not a submod, `None` may be returned. + """ + if user.op == "call_module": + # User is a stage (`call_module`) + return self.get_rank_of_submod(user.name) + elif user.target is sync_barrier: + # Send result back to pp rank 0 + return 0 + else: + # - If user.op == "output": + # No need to send back to rank 0 + # - If user.target is stage_backward: + # No need to send assuming submod output is stored locally or + # should be re-calucated in case of activation checkpointing + return None - # Output index: List of receivers - self.dst_infos: Dict[int, List] = {} + def _create_act_send_info(self): + # Output index: List of receiver ranks + act_send_info: Dict[int, List] = {} out_idx = 0 for user in self.node.users: if user.target is operator.getitem: # Recursively find the real destination - gi_dsts = self.dst_infos.setdefault(out_idx, []) + gi_dsts = act_send_info.setdefault(out_idx, []) for gi_user in user.users: - gi_dsts.append(find_dst_rank(gi_user)) + dst_rank = self.find_dst_rank(gi_user) + if dst_rank is not None: + gi_dsts.append(dst_rank) # Next `getitem` will point to the next output index out_idx += 1 else: # In case of single output value, `out_idx` will not increase - dsts = self.dst_infos.setdefault(out_idx, []) - dsts.append(find_dst_rank(user)) + dsts = act_send_info.setdefault(out_idx, []) + dst_rank = self.find_dst_rank(user) + if dst_rank is not None: + dsts.append(dst_rank) logging.info( - f"[{self.rank}][{self.name}] " f"Send info: {self.dst_infos}" + f"[{self.rank}][{self.name}] " f"Send info: {act_send_info}" ) + return act_send_info - def forward(self, *args, **kwargs): - if args or kwargs: - args_split, kwargs_split = split_args_kwargs_into_chunks( - args, - kwargs, - self.chunks, - self.args_chunk_spec, - self.kwargs_chunk_spec, + def _create_grad_recv_info( + self, + act_send_info: Dict, + ) -> Dict[int, RecvInfo]: + # Dict[output_index, RecvInfo] + grad_recv_info: Dict = {} + my_tensor_meta = self.node.meta["tensor_meta"] + + for out_idx, dst_list in act_send_info.items(): + if not dst_list: + # No actual receiver for activation so no grad coming back + continue + + # TODO: clean way + if len(act_send_info) > 1: + tensor_meta = my_tensor_meta[out_idx] + else: + tensor_meta = my_tensor_meta + + # TODO: otherwise needs grad accumulation + assert len(dst_list) == 1 + grad_src = dst_list[0] + grad_recv_info[out_idx] = RecvInfo( + f"{grad_src}", + grad_src, + _make_tensor_from_meta(tensor_meta, self.device), ) + logging.info( + f"[{self.rank}][{self.name}] " f"Grad recv info: {grad_recv_info}" + ) + return grad_recv_info + + def _create_grad_send_info( + self, + args_recv_info: Tuple, + kwargs_recv_info: Dict, + ) -> List[Optional[int]]: + grad_send_info = [] + + def map_recv_to_send(a): + if isinstance(a, RecvInfo): + grad_send_info.append(a.source) + return a.source + else: + grad_send_info.append(None) + return None + + pippy.fx.node.map_aggregate(args_recv_info, map_recv_to_send) + + pippy.fx.node.map_aggregate(kwargs_recv_info, map_recv_to_send) + + logging.info( + f"[{self.rank}][{self.name}] " f"Grad send info: {grad_send_info}" + ) + return grad_send_info + + def _recv_tensor(self, info, recv_reqs): + logging.debug( + f"[{self.rank}][{self.name}] " + f"Receiving tensor '{info.input_name}' from Rank {info.source}: " + f"{info.buffer.size()}" + ) + # Use async to parallelize recv of tensors + work = dist.irecv( + info.buffer, + info.source + if self.group is None + else dist.get_global_rank(self.group, info.source), + group=self.group, + ) + recv_reqs.append(work) + return info.buffer + + def recv_tensor_fn( + self, + reqs, + ): + return lambda info: self._recv_tensor(info, reqs) + + def _recv_and_fill_inputs( + self, + chunk: int, + args_split, + kwargs_split, + ): # Receive requests of a chunk recv_reqs: List[dist.Work] = [] + + act_recv = self.recv_tensor_fn(recv_reqs) + + if args_split: + chunk_args = args_split[chunk] + chunk_args_list = list(chunk_args) + + def recv_args(info): + if isinstance(info, RecvInfo): + return act_recv(info) + else: + return chunk_args_list.pop(0) + + composite_args = pippy.fx.node.map_aggregate( + self.args_recv_info[chunk], + recv_args, + ) + + if kwargs_split: + chunk_kwargs = kwargs_split[chunk] + + def recv_kwargs(info): + if isinstance(info, RecvInfo): + return act_recv(info) + else: + k = next(iter(chunk_kwargs)) + return chunk_kwargs.pop(k) + + composite_kwargs = pippy.fx.node.map_aggregate( + self.kwargs_recv_info[chunk], + recv_kwargs, + ) + + # Wait for all recvs to finish + for work in recv_reqs: + work.wait() + + return composite_args, composite_kwargs + + def _send_activations( + self, + output_tuple, + ) -> List[dist.Work]: # Send requests of a chunk send_reqs: List[dist.Work] = [] - def recv_tensor(info): - if isinstance(info, RecvInfo): - logging.info( + for idx, out in enumerate(output_tuple): + dst_ranks = self.act_send_info[idx] + for dst in dst_ranks: + if dst is None: + continue + logging.debug( + f"[{self.rank}][{self.name}] " + f"Sending tensor to Rank {dst}: {out.size()}" + ) + work = dist.isend( + out, + dst + if self.group is None + else dist.get_global_rank(self.group, dst), # TODO + group=self.group, + ) + send_reqs.append(work) + + return send_reqs + + def _recv_grads( + self, + bwd_chunk, + ): + # Receive requests of a chunk + grad_recv_reqs: List[dist.Work] = [] + + recv_grad = self.recv_tensor_fn(grad_recv_reqs) + + # Receive gradients + grads = pippy.fx.node.map_aggregate( + self.grad_recv_info[bwd_chunk], + recv_grad, + ) + # Wait for all recvs to finish + for work in grad_recv_reqs: + work.wait() + + logging.debug( + f"[{self.rank}][{self.name}] " + f"Received output grads of chunk {bwd_chunk}: {map_debug_info(grads)}" + ) + return grads + + def _send_grads( + self, + grads_input, + ) -> List[dist.Work]: + # Send requests of a chunk + grad_send_reqs: List[dist.Work] = [] + + for grad, grad_receiver in zip(grads_input, self.grad_send_info): + if isinstance(grad, torch.Tensor) and grad_receiver is not None: + logging.debug( f"[{self.rank}][{self.name}] " - f"Receiving tensor '{info.input_name}' from Rank {info.source}: " - f"{info.buffer.size()}" + f"Sending gradient to Rank {grad_receiver}: {grad.size()}" ) - # Use async to parallelize recv of tensors - work = dist.irecv( - info.buffer, - info.source + work = dist.isend( + grad, + grad_receiver if self.group is None - else dist.get_global_rank(self.group, info.source), + else dist.get_global_rank( + self.group, grad_receiver + ), # TODO group=self.group, ) - recv_reqs.append(work) - return info.buffer + grad_send_reqs.append(work) else: - return info + assert grad is None and grad_receiver is None + + return grad_send_reqs + + def forward(self, *args, **kwargs): + # map microbatch ID to list of forward tensor args + fwd_cache: Dict[int, Tuple[Any, List[torch.Tensor]]] = {} + + args_split = None + kwargs_split = None + if args or kwargs: + args_split, kwargs_split = split_args_kwargs_into_chunks( + args, + kwargs, + self.chunks, + self.args_chunk_spec, + self.kwargs_chunk_spec, + ) + + # Activation send requests of all chunk + all_send_reqs: List[dist.Work] = [] output_chunks = [] for chunk in range(self.chunks): - recv_reqs.clear() - if args: - chunk_args = args_split[chunk] - else: - chunk_args = pippy.fx.node.map_aggregate( - self.args_recv_info[chunk], - recv_tensor, - ) - - if kwargs: - chunk_kwargs = kwargs_split[chunk] - else: - chunk_kwargs = pippy.fx.node.map_aggregate( - self.kwargs_recv_info[chunk], - recv_tensor, - ) + composite_args, composite_kwargs = self._recv_and_fill_inputs( + chunk, + args_split, + kwargs_split, + ) - # Wait for all recvs to finish - for work in recv_reqs: - work.wait() + # Compute forward + try: + output = self.submod(*composite_args, **composite_kwargs) - # Compute - output = self.submod(*chunk_args, **chunk_kwargs) + except Exception as e: + exc_msg = f""" + Rank {self.rank} failed to run forward stage: {self.name} + args: {map_debug_info(composite_args)} + kwargs: {map_debug_info(composite_kwargs)} + """ + raise RuntimeError(exc_msg) from e # Unify output form to tuple for easy correspondance with - # `dst_infos` + # `act_send_info` output_tuple = output if type(output) is tuple else (output,) - for idx, out in enumerate(output_tuple): - dst_ranks = self.dst_infos[idx] - for dst in dst_ranks: - if dst is None: - # If dst is a `output` node, we don't need to send - # (unless `return_to_0` is required) - continue - logging.info( - f"[{self.rank}][{self.name}] " - f"Sending tensor to Rank {dst}: {out.size()}" - ) - work = dist.isend( - out, - dst - if self.group is None - else dist.get_global_rank(self.group, dst), # TODO - group=self.group, - ) - send_reqs.append(work) + send_reqs = self._send_activations(output_tuple) + all_send_reqs += send_reqs + # Prepare for final output merge or reduction output_chunks.append(output) + # Save activations and inputs for backward + flat_args = flatten_args(composite_args) + flat_kwargs = flatten_args(composite_kwargs) + flatten_input_tensors = flat_args + flat_kwargs + fwd_cache[chunk] = ( + output_tuple, # stage_output + flatten_input_tensors, # input_values + ) + # Wait for all sends to finish # TODO: okay to delay the sync till completion of all chunks? - for work in send_reqs: + for work in all_send_reqs: work.wait() + if self.pipe.has_loss_and_backwards: + # Backward starts here + # Grad send requests of all chunk + all_grad_send_reqs: List[dist.Work] = [] + + for bwd_chunk in range(self.chunks): + grads = self._recv_grads(bwd_chunk) + + # Pack args for `stage_backward`` + bwd_kwargs = dict(self.bwd_node.kwargs) + ( + bwd_kwargs["stage_output"], + bwd_kwargs["input_values"], + ) = fwd_cache.pop(bwd_chunk) + # (None,) is for `stage_backward` signature + bwd_kwargs["output_grads"] = ( + grads if len(grads) > 0 else (None,) + ) + + # `stage_backward` node does not have `args`, only `kwargs` + grads_input, _ = stage_backward(**bwd_kwargs) + + grad_send_reqs = self._send_grads(grads_input) + all_grad_send_reqs += grad_send_reqs + + # Wait for all sends to finish + # TODO: okay to delay the sync till completion of all chunks? + for work in all_grad_send_reqs: + work.wait() + # Last rank return merged results per original format if self.rank == self.nstages - 1: return merge_chunks( diff --git a/pippy/backward.py b/pippy/backward.py index 1f3765d38..d32c57a3f 100644 --- a/pippy/backward.py +++ b/pippy/backward.py @@ -2,7 +2,7 @@ from typing import List import torch -import pippy.fx +from pippy.debug import map_debug_info def stage_backward( @@ -13,18 +13,12 @@ def stage_backward( outputs_with_grads_idxs: List[int], ): """ - Given the input value(s) and the corresponding gradient for those/that input + Given the input value(s) and the corresponding gradient for the output value(s), compute and accumulate gradients for all parameter values (leaves in the autograd trace) as well as return a list of the gradients for the input values """ - def friendly_debug_info(v): - if isinstance(v, torch.Tensor): - return f"Tensor(size={v.shape})" - else: - return str(v) - try: stage_output_with_grads = [ stage_output[i] for i in outputs_with_grads_idxs @@ -82,12 +76,24 @@ def extract_tensors_with_grads(output_val, grad_val): else: grad_inputs.append(None) + # TODO: use `torch.autograd.grad` + """ + inputs_with_grad = [] + for val in input_values: + if isinstance(val, torch.Tensor) and val.requires_grad: + inputs_with_grad.append(val) + + grad_inputs = torch.autograd.grad( + stage_output_tensors, inputs_with_grad, output_grad_tensors, # type: ignore[arg-type] + ) + """ + except Exception as e: exc_msg = f""" Failed to run backward stage {stage_info} - Stage output value: {pippy.fx.node.map_aggregate(stage_output, friendly_debug_info)} - Output gradient values: {pippy.fx.node.map_aggregate(output_grads, friendly_debug_info)} - Input values: {pippy.fx.node.map_aggregate(input_values, friendly_debug_info)} + Stage output: {map_debug_info(stage_output)} + Output gradient: {map_debug_info(output_grads)} + Input: {map_debug_info(input_values)} """ raise RuntimeError(exc_msg) from e diff --git a/pippy/compile.py b/pippy/compile.py index 62a988aa6..c59f3a3ae 100644 --- a/pippy/compile.py +++ b/pippy/compile.py @@ -9,14 +9,15 @@ ) from pippy.PipelineStage import PipelineStage import pippy.fx as fx -from pippy.IR import MultiUseParameterConfig, Pipe -from pippy.fx.passes import shape_prop +from pippy.IR import MultiUseParameterConfig, Pipe, PiPPyShapeProp from pippy.microbatch import ( LossReducer, + gen_output_chunk_spec, split_args_kwargs_into_chunks, sum_reducer, ) -from pippy.utils import get_device, get_pp_rank, get_rank, PIPPY_VERBOSITY +from pippy.utils import get_device, get_pp_rank, get_rank +from pippy.debug import PIPPY_VERBOSITY import torch import torch.distributed as dist @@ -225,6 +226,7 @@ def compile_stage( split_policy: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None, return_to_0: bool = False, tracer=None, + loss_reducer: LossReducer = sum_reducer, args_chunk_spec=None, kwargs_chunk_spec=None, output_chunk_spec=None, @@ -248,6 +250,7 @@ def compile_stage( tracer=tracer, output_loss_value_spec=output_loss_value_spec, split_policy=split_policy, + return_to_0=return_to_0, **kwargs, ) @@ -279,9 +282,16 @@ def compile_stage( # Use 1st chunk of args for shape propagation chunk0 = fake_args_split[0] - sp = shape_prop.ShapeProp(gm) + sp = PiPPyShapeProp(gm) sp.propagate(*chunk0) + # Prepare output chunk/reduce spec for merging/reducing final outputs + output_chunk_spec = ( + output_chunk_spec + if output_chunk_spec + else gen_output_chunk_spec(pipe.loss_spec, loss_reducer) + ) + # Create pipeline stage return PipelineStage( pipe, @@ -290,7 +300,6 @@ def compile_stage( num_chunks, device, group, - return_to_0, args_chunk_spec, kwargs_chunk_spec, output_chunk_spec, diff --git a/pippy/debug.py b/pippy/debug.py new file mode 100644 index 000000000..b06751ad2 --- /dev/null +++ b/pippy/debug.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +import logging +import os +import torch +import pippy.fx + + +PIPPY_VERBOSITY = os.environ.get("PIPPY_VERBOSITY", "OFF") + +if PIPPY_VERBOSITY == "DEBUG": + logging.getLogger().setLevel(logging.DEBUG) +elif PIPPY_VERBOSITY == "INFO": + logging.getLogger().setLevel(logging.INFO) +elif PIPPY_VERBOSITY == "OFF": + pass +else: + print(f"Unsupported PIPPY_VERBOSITY level: {PIPPY_VERBOSITY}") + + +def friendly_debug_info(v): + if isinstance(v, torch.Tensor): + return f"Tensor({v.shape}, grad={v.requires_grad})" + else: + return str(v) + + +def map_debug_info(a): + return pippy.fx.node.map_aggregate(a, friendly_debug_info) diff --git a/pippy/utils.py b/pippy/utils.py index c9db5ad14..c2b64a9c3 100644 --- a/pippy/utils.py +++ b/pippy/utils.py @@ -33,17 +33,7 @@ import torch.multiprocessing as mp import torch.distributed.rpc as rpc - -PIPPY_VERBOSITY = os.environ.get("PIPPY_VERBOSITY", "OFF") - -if PIPPY_VERBOSITY == "DEBUG": - logging.getLogger().setLevel(logging.DEBUG) -elif PIPPY_VERBOSITY == "INFO": - logging.getLogger().setLevel(logging.INFO) -elif PIPPY_VERBOSITY == "OFF": - pass -else: - print(f"Unsupported PIPPY_VERBOSITY level: {PIPPY_VERBOSITY}") +import pippy.fx def get_rank() -> int: @@ -280,3 +270,42 @@ def _pp_group_barrier(): run_func(my_pp_ranks, args, *extra_args) rpc.shutdown() + + +def flatten_args_detach(args): + flat_detached_args = [] + + def extract_tensor_args(a): + nonlocal flat_detached_args + if isinstance(a, torch.Tensor): + val = a.detach().requires_grad_(a.requires_grad) + flat_detached_args.append(val) + return val + else: + flat_detached_args.append(a) + return a + + def dont_traverse_size(a): + return type(a) != torch.Size + + new_args = pippy.fx.node.map_aggregate( + args, extract_tensor_args, dont_traverse_size + ) + + return new_args, flat_detached_args + + +def flatten_args(args): + flat_args = [] + + def extract_tensor_args(a): + nonlocal flat_args + flat_args.append(a) + return a + + def dont_traverse_size(a): + return type(a) != torch.Size + + pippy.fx.node.map_aggregate(args, extract_tensor_args, dont_traverse_size) + + return flat_args diff --git a/test/local_test_c10d_bwd.py b/test/local_test_c10d_bwd.py new file mode 100644 index 000000000..5cdd897e9 --- /dev/null +++ b/test/local_test_c10d_bwd.py @@ -0,0 +1,135 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +import argparse +import os +import unittest + +import torch +import torch.distributed as dist + +from pippy.compile import compile_stage +from pippy.IR import pipe_split + + +d_hid = 512 +chunk_size = 256 + +torch.manual_seed(0) + + +class ExampleCode(torch.nn.Module): + def __init__(self): + super().__init__() + self.mm_param = torch.nn.Parameter(torch.randn(d_hid, d_hid)) + self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) + self.lin = torch.nn.Linear(d_hid, d_hid) + self.mse_loss = torch.nn.MSELoss(reduction="sum") + + def forward(self, x, target): + x = torch.mm(x, self.mm_param) + skip_connection = x + x = torch.relu(x) + pipe_split() + x = torch.mm(x, self.mm_param) + x = self.lin(x) + pipe_split() + x = torch.relu(x) + x = x + skip_connection + x = torch.mm(x, self.mm_param2) + pipe_split() + x = self.lin(x) + x = torch.relu(x) + loss = self.mse_loss(x, target) + return {"loss": loss} + + +def run_worker(args): + ec = ExampleCode() + ec.to(args.device) + ec.train() + + ec_x = torch.randn(args.chunks * chunk_size, d_hid, device=args.device) + target = torch.randn(args.chunks * chunk_size, d_hid, device=args.device) + + stage = compile_stage( + ec, + args.rank, + args.world_size, + args.chunks, + args.device, + None, + [ec_x, target], + ) + + # Run + if args.rank == 0: + out = stage(ec_x) + elif args.rank == args.world_size - 1: + out = stage(target) + else: + stage() + + dist.barrier() + print(f"Rank {args.rank} completes") + + # Last rank checks result + if args.rank == args.world_size - 1: + ref_out = ec(ec_x, target) + torch.testing.assert_close(out["loss"], ref_out["loss"]) + print( + f"equivalence test passed, loss = {out['loss']}, ref loss = {ref_out['loss']}" + ) + + +def main(args=None): + parser = argparse.ArgumentParser() + parser.add_argument( + "--world_size", type=int, default=int(os.getenv("WORLD_SIZE", 4)) + ) + parser.add_argument("--rank", type=int, default=int(os.getenv("RANK", -1))) + parser.add_argument( + "--master_addr", type=str, default=os.getenv("MASTER_ADDR", "localhost") + ) + parser.add_argument( + "--master_port", type=str, default=os.getenv("MASTER_PORT", "29500") + ) + parser.add_argument( + "--cuda", type=int, default=int(torch.cuda.is_available()) + ) + parser.add_argument( + "--chunks", + type=int, + default=4, + ) + args = parser.parse_args(args) + + if args.cuda: + dev_id = args.rank % torch.cuda.device_count() + args.device = torch.device(f"cuda:{dev_id}") + else: + args.device = torch.device("cpu") + + # Init process group + backend = "nccl" if args.cuda else "gloo" + dist.init_process_group( + backend=backend, + rank=args.rank, + world_size=args.world_size, + ) + + run_worker(args) + + +if __name__ == "__main__": + main() + + +class LocalTestC10DBwdTest(unittest.TestCase): + def test_c10d_bwd(self): + import random + + port = random.randint(29500, 30000) + args = [ + "--master_port", + str(port), + ] + main(args) From 09be9beafe1eb8a5ce2b5b6285031165061892bf Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Wed, 21 Jun 2023 13:45:25 -0700 Subject: [PATCH 09/96] Use stable pytorch version for code quality check (#816) ## Description mypy fails with internal error when it comes to checks involving dynamo tensor. Use pytorch stable instead of nightly to see if we can avoid this error. --- .github/workflows/code-quality.yml | 4 ++-- pippy/PipelineStage.py | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/.github/workflows/code-quality.yml b/.github/workflows/code-quality.yml index adc21257e..a8f34082e 100644 --- a/.github/workflows/code-quality.yml +++ b/.github/workflows/code-quality.yml @@ -23,8 +23,8 @@ jobs: pip install --upgrade pip pip install -r docs/requirements.txt pip install types-docutils types-setuptools tqdm types-tabulate - if [ -f requirements.txt ]; then pip install -r requirements.txt --find-links https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html; fi - pip install --pre torchvision -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html + if [ -f requirements.txt ]; then pip install -r requirements.txt --index-url https://download.pytorch.org/whl/cpu; fi + pip install torchvision --index-url https://download.pytorch.org/whl/cpu pip install git+https://github.com/pbelevich/transformers.git@compatible_with_pt_master pip install "black<23" pylint==v3.0.0a5 mypy==v0.960 flake8==3.8.2 pyre-check==0.9.15 - name: Static Analysis Checks diff --git a/pippy/PipelineStage.py b/pippy/PipelineStage.py index f589bb2a2..58db3c4f8 100644 --- a/pippy/PipelineStage.py +++ b/pippy/PipelineStage.py @@ -6,6 +6,7 @@ import torch import torch.distributed as dist import pippy +import pippy.fx from pippy.backward import stage_backward, sync_barrier from pippy.debug import map_debug_info @@ -229,7 +230,7 @@ def create_recv_tensor( def find_dst_rank( self, - user: pippy.fx.node, + user: pippy.fx.Node, ) -> Optional[int]: """ Find the destination rank of a `user` node. @@ -314,7 +315,7 @@ def _create_grad_send_info( args_recv_info: Tuple, kwargs_recv_info: Dict, ) -> List[Optional[int]]: - grad_send_info = [] + grad_send_info: List[Optional[int]] = [] def map_recv_to_send(a): if isinstance(a, RecvInfo): From c2c5a715b31e240bacfc0a13a93ad0e6bfe33596 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Thu, 22 Jun 2023 11:42:07 -0700 Subject: [PATCH 10/96] Add test for optimizer (#815) Attach optimizer as follows: ``` # Create an optimizer for stage submodule's parameters optimizer = optim.SGD(stage.submod.parameters(), lr=1e-3, momentum=0.9) ``` --- test/local_test_optim.py | 139 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 139 insertions(+) create mode 100644 test/local_test_optim.py diff --git a/test/local_test_optim.py b/test/local_test_optim.py new file mode 100644 index 000000000..3f61db9d3 --- /dev/null +++ b/test/local_test_optim.py @@ -0,0 +1,139 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +import argparse +import os +import unittest + +import torch +import torch.distributed as dist +import torch.optim as optim + +from pippy.compile import compile_stage +from pippy.IR import pipe_split, TrivialLossWrapper + + +d_hid = 512 +chunk_size = 256 + +torch.manual_seed(0) + + +class ExampleCode(torch.nn.Module): + def __init__(self): + super().__init__() + self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) + self.mm_param1 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) + self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) + self.lin0 = torch.nn.Linear(d_hid, d_hid) + self.lin1 = torch.nn.Linear(d_hid, d_hid) + + def forward(self, x): + x = torch.mm(x, self.mm_param0) + skip_connection = x + x = torch.relu(x) + pipe_split() + x = torch.mm(x, self.mm_param1) + x = self.lin0(x) + pipe_split() + x = torch.relu(x) + x = x + skip_connection + x = torch.mm(x, self.mm_param2) + pipe_split() + x = self.lin1(x) + x = torch.relu(x) + return x + + +def run_worker(args): + ec = ExampleCode() + loss_fn = torch.nn.MSELoss(reduction="sum") + ec_with_loss = TrivialLossWrapper(ec, loss_fn) + ec_with_loss.to(args.device) + + ec_x = torch.randn(args.chunks * chunk_size, d_hid, device=args.device) + target = torch.randn(args.chunks * chunk_size, d_hid, device=args.device) + + stage = compile_stage( + ec_with_loss, + args.rank, + args.world_size, + args.chunks, + args.device, + None, + [ec_x, target], + ) + + # Create an optimizer for stage submodule's parameters + optimizer = optim.SGD(stage.submod.parameters(), lr=1e-3, momentum=0.9) + + for _ in range(2): + # Zero gradients + optimizer.zero_grad() + + # Run + if args.rank == 0: + stage(ec_x) + elif args.rank == args.world_size - 1: + stage(target) + else: + stage() + + # Take an optimization step + optimizer.step() + + dist.barrier() + print(f"Rank {args.rank} completes") + + +def main(args=None): + parser = argparse.ArgumentParser() + parser.add_argument( + "--world_size", type=int, default=int(os.getenv("WORLD_SIZE", 4)) + ) + parser.add_argument("--rank", type=int, default=int(os.getenv("RANK", -1))) + parser.add_argument( + "--master_addr", type=str, default=os.getenv("MASTER_ADDR", "localhost") + ) + parser.add_argument( + "--master_port", type=str, default=os.getenv("MASTER_PORT", "29500") + ) + parser.add_argument( + "--cuda", type=int, default=int(torch.cuda.is_available()) + ) + parser.add_argument( + "--chunks", + type=int, + default=4, + ) + args = parser.parse_args(args) + + if args.cuda: + dev_id = args.rank % torch.cuda.device_count() + args.device = torch.device(f"cuda:{dev_id}") + else: + args.device = torch.device("cpu") + + # Init process group + backend = "nccl" if args.cuda else "gloo" + dist.init_process_group( + backend=backend, + rank=args.rank, + world_size=args.world_size, + ) + + run_worker(args) + + +if __name__ == "__main__": + main() + + +class LocalTestOptimTest(unittest.TestCase): + def test_optim(self): + import random + + port = random.randint(29500, 30000) + args = [ + "--master_port", + str(port), + ] + main(args) From c7d89746e6a756ed079a243c53135d1f97c1b23f Mon Sep 17 00:00:00 2001 From: eddogola <64967909+eddogola@users.noreply.github.com> Date: Thu, 22 Jun 2023 14:07:20 -0700 Subject: [PATCH 11/96] Save module (#814) ## Description Replaced the initial mechanism of writing to file from a direct file.write to a more atomic write, using tempfile, to prevent data corruption if the data to write is large. Fixes #(issue) ## Type of change Please delete options that are not relevant. - [x] New feature (non-breaking change which adds functionality) - [x] This change requires a documentation update ## Feature/Issue validation/testing Working on a test on another PR ## Checklist: - [x] Has code been commented, particularly in hard-to-understand areas? - [x] Have you made corresponding changes to the documentation? --------- Co-authored-by: Eddy --- pippy/hf/_SaveModule.py | 75 ++++++++++++++++++++++++++++++++--------- 1 file changed, 59 insertions(+), 16 deletions(-) diff --git a/pippy/hf/_SaveModule.py b/pippy/hf/_SaveModule.py index 51bb67529..cdd9d4ed5 100644 --- a/pippy/hf/_SaveModule.py +++ b/pippy/hf/_SaveModule.py @@ -1,5 +1,9 @@ import torch.distributed as dist -import pippy +from pippy.IR import Pipe + +from itertools import chain +import tempfile +import logging import json import os @@ -7,8 +11,36 @@ CKPT_INDEX_JSON_FILENAME = "pytorch_model.bin.index.json" +def _atomic_write(file_contents: str, target_file_path: str, mode="w") -> None: + """ + Atomically writes `file_contents` into `target_file_path`. + Args: + file_contents (str): contents to write to file + target_file_path (str): path to write to + mode (str, optional): mode to write file with. Defaults to "w". Only "w" and "a" are supported. + """ + # create tempfile as `move` ops aren't guaranteed to be atomic when between different file systems + temp_file = tempfile.NamedTemporaryFile( + delete=False, + dir=os.path.dirname(target_file_path), + ) + try: + with open(temp_file.name, mode) as f: + f.write(file_contents) + # sync in-memory state with storage device + f.flush() + os.fsync(f.fileno()) + os.replace(temp_file.name, target_file_path) + finally: + if os.path.exists(temp_file.name): + try: + os.unlink(temp_file.name) + except Exception: + raise RuntimeError(f"Failed to delete {temp_file.name}") + + def _save_index( - pipe: pippy.fx.GraphModule, + pipe: Pipe, ckpt_index_filename: str = CKPT_INDEX_JSON_FILENAME, checkpoint_dir: str = "checkpoints", ) -> None: @@ -16,23 +48,32 @@ def _save_index( Saves index file describing location of weights in checkpoint. Args: - pipe (pippy.fx.GraphModule): pipeline graph module with weights to save + pipe (Pipe): pipeline graph module with weights to save ckpt_index_filename (str, optional): name of index file. Defaults to "pytorch_model.bin.index.json". checkpoint_dir (str, optional): directory to save checkpoint to. Defaults to "checkpoints". """ - index_dict = {} - total_size = 0 - index_dict["metadata"] = {"total_size": total_size} + index_dict = { + "metadata": { + "total_size": 0, + }, + "weight_map": {}, + } weight_map = {} - for idx, (submod_name, submod) in enumerate(pipe.split_gm.named_children()): - for param_name, _ in submod.named_parameters(): - old_name = submod.remap_qualname(param_name) - - binary_filename = _create_binary_filename(idx) + for idx, (_, submod) in enumerate(pipe.split_gm.named_children()): # type: ignore + # chain both params and buffers generators together + params_buffers = chain( + submod.named_parameters(), submod.named_buffers() + ) + for param_name, _ in params_buffers: + old_name = submod.remap_qualname(param_name) # type: ignore + + binary_filename = _get_binary_filename(idx) weight_map[old_name] = binary_filename + index_dict["weight_map"] = weight_map + # serialize json json_str = json.dumps(index_dict, indent=4) filepath = os.path.join(checkpoint_dir, ckpt_index_filename) @@ -41,11 +82,13 @@ def _save_index( if not os.path.exists(checkpoint_dir): os.mkdir(checkpoint_dir) - with open(filepath, "w") as f: - f.write(json_str) + # write index file atomically to avoid partial/corrupted writes + _atomic_write(json_str, filepath) + + logging.info(f"Saved index file to {filepath}") -def _create_binary_filename(cur_idx: int) -> str: +def _get_binary_filename(cur_idx: int) -> str: # type: ignore[valid-type] """ Gets filename for pytorch checkpoint binary based on current index and world size. @@ -55,7 +98,7 @@ def _create_binary_filename(cur_idx: int) -> str: Returns: str: checkpoint filename """ - cur_idx = str(cur_idx + 1).zfill(5) + idx = str(cur_idx + 1).zfill(5) world_size = str(dist.get_world_size()).zfill(5) - return f"pytorch_model-{cur_idx}-of-{world_size}.bin" + return f"pytorch_model-{idx}-of-{world_size}.bin" From 3b09bdf5775cefb7211c3de7ad9e844227bd68c7 Mon Sep 17 00:00:00 2001 From: eddogola <64967909+eddogola@users.noreply.github.com> Date: Thu, 22 Jun 2023 16:29:50 -0700 Subject: [PATCH 12/96] Save checkpoint size to metadata index file (#817) ## Description Adding the params total size to the index file. ## Type of change - [x] New feature (non-breaking change which adds functionality) - [x] This change requires a documentation update ## Checklist: - [x] Has code been commented, particularly in hard-to-understand areas? - [x] Have you made corresponding changes to the documentation? --------- Co-authored-by: Eddy --- pippy/hf/_SaveModule.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/pippy/hf/_SaveModule.py b/pippy/hf/_SaveModule.py index cdd9d4ed5..f50112568 100644 --- a/pippy/hf/_SaveModule.py +++ b/pippy/hf/_SaveModule.py @@ -1,5 +1,6 @@ import torch.distributed as dist from pippy.IR import Pipe +import torch from itertools import chain import tempfile @@ -10,6 +11,16 @@ CKPT_INDEX_JSON_FILENAME = "pytorch_model.bin.index.json" +DTYPE_SIZES = { + torch.float32: 4, + torch.float16: 2, + torch.bfloat16: 2, +} + + +def _get_param_size(param: torch.Tensor) -> int: + return param.numel() * DTYPE_SIZES[param.dtype] + def _atomic_write(file_contents: str, target_file_path: str, mode="w") -> None: """ @@ -65,12 +76,14 @@ def _save_index( params_buffers = chain( submod.named_parameters(), submod.named_buffers() ) - for param_name, _ in params_buffers: + for param_name, param in params_buffers: old_name = submod.remap_qualname(param_name) # type: ignore binary_filename = _get_binary_filename(idx) weight_map[old_name] = binary_filename + index_dict["metadata"]["total_size"] += _get_param_size(param) # type: ignore + index_dict["weight_map"] = weight_map # serialize json From e427095f7c8f9bd281a499e5e796ca125b179ea4 Mon Sep 17 00:00:00 2001 From: eddogola <64967909+eddogola@users.noreply.github.com> Date: Fri, 23 Jun 2023 15:53:50 -0700 Subject: [PATCH 13/96] update metadata total size at once (#819) ## Description continually increment variable in loop to update the index file's `total_size` field at once after the loop. ## Checklist: - [x] Have you added tests that prove your fix is effective or that this feature works? - [x] Has code been commented, particularly in hard-to-understand areas? - [x] Have you made corresponding changes to the documentation? --------- Co-authored-by: Eddy --- pippy/hf/_SaveModule.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pippy/hf/_SaveModule.py b/pippy/hf/_SaveModule.py index f50112568..a1963303d 100644 --- a/pippy/hf/_SaveModule.py +++ b/pippy/hf/_SaveModule.py @@ -71,6 +71,7 @@ def _save_index( } weight_map = {} + total_size = 0 for idx, (_, submod) in enumerate(pipe.split_gm.named_children()): # type: ignore # chain both params and buffers generators together params_buffers = chain( @@ -82,9 +83,10 @@ def _save_index( binary_filename = _get_binary_filename(idx) weight_map[old_name] = binary_filename - index_dict["metadata"]["total_size"] += _get_param_size(param) # type: ignore + total_size += _get_param_size(param) index_dict["weight_map"] = weight_map + index_dict["metadata"]["total_size"] = total_size # type: ignore # serialize json json_str = json.dumps(index_dict, indent=4) From f9b84c0093f00192269cc70a822f7d510b7fbf8a Mon Sep 17 00:00:00 2001 From: eddogola <64967909+eddogola@users.noreply.github.com> Date: Tue, 27 Jun 2023 10:55:49 -0700 Subject: [PATCH 14/96] add test to check index file output (#818) ## Description Add single test to check if the contests of the ckpt metadata(index) file are accurate. ## Feature/Issue validation/testing - [x] Test LocalIndexMetadataTest ## Checklist: - [x] Have you added tests that prove your fix is effective or that this feature works? - [x] Has code been commented, particularly in hard-to-understand areas? --------- Co-authored-by: Eddy --- test/local_test_ckpt_index_file.py | 175 +++++++++++++++++++++++++++++ 1 file changed, 175 insertions(+) create mode 100644 test/local_test_ckpt_index_file.py diff --git a/test/local_test_ckpt_index_file.py b/test/local_test_ckpt_index_file.py new file mode 100644 index 000000000..e69d2b1ad --- /dev/null +++ b/test/local_test_ckpt_index_file.py @@ -0,0 +1,175 @@ +from typing import List +import unittest +import argparse +import shutil +import json +import os + +from pippy.IR import pipe_split, TrivialLossWrapper +from pippy.hf._SaveModule import _save_index +from pippy.compile import compile_stage +import torch.distributed as dist +import torch.optim as optim +import torch + +DEFAULT_FILENAME = "pytorch_model.bin.index.json" +CKPT_DIR = "test_ckpts" +WEIGHT_MAP = set( + [ + "module.mm_param0", + "module.mm_param1", + "module.mm_param2", + "module.lin0.weight", + "module.lin0.bias", + "module.lin1.weight", + "module.lin1.bias", + ] +) +D_HID = 512 +CHUNK_SIZE = 256 + + +class ExampleCode(torch.nn.Module): + def __init__(self): + super().__init__() + self.mm_param0 = torch.nn.Parameter(torch.randn(D_HID, D_HID)) + self.mm_param1 = torch.nn.Parameter(torch.randn(D_HID, D_HID)) + self.mm_param2 = torch.nn.Parameter(torch.randn(D_HID, D_HID)) + self.lin0 = torch.nn.Linear(D_HID, D_HID) + self.lin1 = torch.nn.Linear(D_HID, D_HID) + + def forward(self, x): + x = torch.mm(x, self.mm_param0) + skip_connection = x + x = torch.relu(x) + pipe_split() + x = torch.mm(x, self.mm_param1) + x = self.lin0(x) + pipe_split() + x = torch.relu(x) + x = x + skip_connection + x = torch.mm(x, self.mm_param2) + pipe_split() + x = self.lin1(x) + x = torch.relu(x) + return x + + +def run_worker(args: List[str | int]) -> None: + ec = ExampleCode() + loss_fn = torch.nn.MSELoss(reduction="sum") + ec_with_loss = TrivialLossWrapper(ec, loss_fn) + ec_with_loss.to(args.device) + + ec_x = torch.randn(args.chunks * CHUNK_SIZE, D_HID, device=args.device) + target = torch.randn(args.chunks * CHUNK_SIZE, D_HID, device=args.device) + + stage = compile_stage( + ec_with_loss, + args.rank, + args.world_size, + args.chunks, + args.device, + None, + [ec_x, target], + ) + + # Create an optimizer for stage submodule's parameters + optimizer = optim.SGD(stage.submod.parameters(), lr=1e-3, momentum=0.9) + + for _ in range(2): + # Zero gradients + optimizer.zero_grad() + + # Run + if args.rank == 0: + stage(ec_x) + elif args.rank == args.world_size - 1: + stage(target) + else: + stage() + + # Take an optimization step + optimizer.step() + + dist.barrier() + print(f"Rank {args.rank} completes") + + # save index file in rank 0 + if args.rank == 0: + _save_index(stage, checkpoint_dir=CKPT_DIR) + + filepath = os.path.join(CKPT_DIR, DEFAULT_FILENAME) + with open(filepath) as f: + content = f.read() + data = json.loads(content) + + # check file written on disk to given location + assert os.path.exists(filepath) + + # check total_size correct + size_calc = sum(param.numel() for param in ec.parameters()) * 4 + assert size_calc == data["metadata"]["total_size"] + + # check all params present + assert len(data["weight_map"]) == 7 + for param in WEIGHT_MAP: + assert param in data["weight_map"] + + # remove test directory + shutil.rmtree(CKPT_DIR) + + +def main(args: List[str | int] = None) -> None: + parser = argparse.ArgumentParser() + parser.add_argument( + "--world_size", type=int, default=int(os.getenv("WORLD_SIZE", 4)) + ) + parser.add_argument("--rank", type=int, default=int(os.getenv("RANK", -1))) + parser.add_argument( + "--master_addr", type=str, default=os.getenv("MASTER_ADDR", "localhost") + ) + parser.add_argument( + "--master_port", type=str, default=os.getenv("MASTER_PORT", "29500") + ) + parser.add_argument( + "--cuda", type=int, default=int(torch.cuda.is_available()) + ) + parser.add_argument( + "--chunks", + type=int, + default=4, + ) + args = parser.parse_args(args) + + if args.cuda: + dev_id = args.rank % torch.cuda.device_count() + args.device = torch.device(f"cuda:{dev_id}") + else: + args.device = torch.device("cpu") + + # init process group + backend = "nccl" if args.cuda else "gloo" + dist.init_process_group( + backend=backend, + rank=args.rank, + world_size=args.world_size, + ) + + run_worker(args) + + +if __name__ == "__main__": + main() + + +class LocalIndexMetadataTest(unittest.TestCase): + def test_index_file(self): + import random + + port = random.randint(29500, 30000) + args = [ + "--master_port", + str(port), + ] + main(args) From d61e12567cc8c75d1914a9a8730caddbe97ebc78 Mon Sep 17 00:00:00 2001 From: eddogola <64967909+eddogola@users.noreply.github.com> Date: Tue, 27 Jun 2023 19:15:51 -0700 Subject: [PATCH 15/96] write actual weights to files on disk (#820) ## Description save model parameters(params+buffers) to disk, following the huggingface format, in multiple binary files depending on world size. Fixes #(issue) ## Type of change - [x] New feature (non-breaking change which adds functionality) - [x] This change requires a documentation update ## Checklist: - [x] Has code been commented, particularly in hard-to-understand areas? - [x] Have you made corresponding changes to the documentation? --------- Co-authored-by: Eddy Co-authored-by: Iris <31293777+wz337@users.noreply.github.com> --- .gitignore | 12 ----- pippy/hf/_SaveModule.py | 45 ++++++++++++++----- test/local_test_ckpt_index_file.py | 70 ++++++++++++++++++++++-------- 3 files changed, 84 insertions(+), 43 deletions(-) delete mode 100644 .gitignore diff --git a/.gitignore b/.gitignore deleted file mode 100644 index 2bbda701f..000000000 --- a/.gitignore +++ /dev/null @@ -1,12 +0,0 @@ -docs/build -__pycache__ -build -pippy.egg-info -pippy/version.py -dist -.idea/ -.pyre/ -.watchmanconfig -**/*.json -**/*.out -**/.DS_STORE diff --git a/pippy/hf/_SaveModule.py b/pippy/hf/_SaveModule.py index a1963303d..f35630c6e 100644 --- a/pippy/hf/_SaveModule.py +++ b/pippy/hf/_SaveModule.py @@ -2,6 +2,7 @@ from pippy.IR import Pipe import torch +from typing import Dict from itertools import chain import tempfile import logging @@ -63,15 +64,10 @@ def _save_index( ckpt_index_filename (str, optional): name of index file. Defaults to "pytorch_model.bin.index.json". checkpoint_dir (str, optional): directory to save checkpoint to. Defaults to "checkpoints". """ - index_dict = { - "metadata": { - "total_size": 0, - }, - "weight_map": {}, - } - - weight_map = {} + index_dict = {} total_size = 0 + + weight_map: Dict[str, str] = {} for idx, (_, submod) in enumerate(pipe.split_gm.named_children()): # type: ignore # chain both params and buffers generators together params_buffers = chain( @@ -81,12 +77,15 @@ def _save_index( old_name = submod.remap_qualname(param_name) # type: ignore binary_filename = _get_binary_filename(idx) - weight_map[old_name] = binary_filename - total_size += _get_param_size(param) + # add ckpt size once + if old_name not in weight_map: + total_size += _get_param_size(param) # type: ignore - index_dict["weight_map"] = weight_map - index_dict["metadata"]["total_size"] = total_size # type: ignore + weight_map[old_name] = binary_filename + + index_dict["metadata"] = {"total_size": total_size} # type: ignore + index_dict["weight_map"] = weight_map # type: ignore # serialize json json_str = json.dumps(index_dict, indent=4) @@ -117,3 +116,25 @@ def _get_binary_filename(cur_idx: int) -> str: # type: ignore[valid-type] world_size = str(dist.get_world_size()).zfill(5) return f"pytorch_model-{idx}-of-{world_size}.bin" + + +def _save_checkpoint(submod: Pipe, checkpoint_dir: str) -> None: + """ + writes `module`'s parameters and buffers to disk. + + Args: + submod(`Pipe`): a submodule of the model's graph + checkpoint_dir(`str`): where to keep the checkpoint binaries + """ + if not os.path.exists(checkpoint_dir): + os.mkdir(checkpoint_dir) + filepath = os.path.join( + checkpoint_dir, _get_binary_filename(dist.get_rank()) + ) + torch.save( + { + submod.remap_qualname(param_name): param + for param_name, param in submod.state_dict().items() + }, + filepath, + ) diff --git a/test/local_test_ckpt_index_file.py b/test/local_test_ckpt_index_file.py index e69d2b1ad..6a9de4b88 100644 --- a/test/local_test_ckpt_index_file.py +++ b/test/local_test_ckpt_index_file.py @@ -1,3 +1,4 @@ +from copy import deepcopy from typing import List import unittest import argparse @@ -5,8 +6,12 @@ import json import os +from pippy.hf._SaveModule import ( + _save_index, + _save_checkpoint, +) from pippy.IR import pipe_split, TrivialLossWrapper -from pippy.hf._SaveModule import _save_index +from pippy.LoadModule import load_checkpoint from pippy.compile import compile_stage import torch.distributed as dist import torch.optim as optim @@ -77,23 +82,17 @@ def run_worker(args: List[str | int]) -> None: # Create an optimizer for stage submodule's parameters optimizer = optim.SGD(stage.submod.parameters(), lr=1e-3, momentum=0.9) - for _ in range(2): - # Zero gradients - optimizer.zero_grad() - - # Run - if args.rank == 0: - stage(ec_x) - elif args.rank == args.world_size - 1: - stage(target) - else: - stage() - - # Take an optimization step - optimizer.step() + # first run + # Zero gradients + optimizer.zero_grad() - dist.barrier() - print(f"Rank {args.rank} completes") + # Run + if args.rank == 0: + stage(ec_x) + elif args.rank == args.world_size - 1: + stage(target) + else: + stage() # save index file in rank 0 if args.rank == 0: @@ -107,7 +106,7 @@ def run_worker(args: List[str | int]) -> None: # check file written on disk to given location assert os.path.exists(filepath) - # check total_size correct + # check total_size is correct size_calc = sum(param.numel() for param in ec.parameters()) * 4 assert size_calc == data["metadata"]["total_size"] @@ -116,7 +115,40 @@ def run_worker(args: List[str | int]) -> None: for param in WEIGHT_MAP: assert param in data["weight_map"] - # remove test directory + # Take an optimization step + optimizer.step() + ref = deepcopy(stage.submod.state_dict()) + _save_checkpoint(stage.submod, CKPT_DIR) + + # second run + # Zero gradients + optimizer.zero_grad() + + # Run + if args.rank == 0: + stage(ec_x) + elif args.rank == args.world_size - 1: + stage(target) + else: + stage() + + # Take an optimization step + optimizer.step() + + # load ckpt + mod = load_checkpoint( + stage.submod, + os.path.join(CKPT_DIR, "pytorch_model.bin.index.json"), + args.device, + ) + + torch.testing.assert_close(mod.state_dict(), ref) + + dist.barrier() + print(f"Rank {args.rank} completes") + + # remove test ckpt directory in last rank + if args.rank == args.world_size - 1: shutil.rmtree(CKPT_DIR) From b5d4c751e15314cc72a4e97e06617fdf955a04ee Mon Sep 17 00:00:00 2001 From: eddogola <64967909+eddogola@users.noreply.github.com> Date: Tue, 27 Jun 2023 19:22:36 -0700 Subject: [PATCH 16/96] restore gitignore (#822) ## Description the previous merged PR got rid of .gitignore because of its '.watchmanconfig' addition. Here, I restore the .gitignore to its initial state Co-authored-by: Eddy --- .gitignore | 11 +++++++++++ 1 file changed, 11 insertions(+) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000..7dca738f4 --- /dev/null +++ b/.gitignore @@ -0,0 +1,11 @@ +docs/build +__pycache__ +build +pippy.egg-info +pippy/version.py +dist +.idea/ +.pyre/ +**/*.json +**/*.out +**/.DS_STORE From f73076154ddffbb5476ca2ce6ae576c9efce6b13 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Wed, 28 Jun 2023 15:44:27 -0700 Subject: [PATCH 17/96] [BE] Apply ufmt to all Python files except for pippy/fx/*.py and change the check.sh to use ufmt (#824) Apply ufmt to all Python files except for pippy/fx/*.py and change the check.sh to use ufmt. This PR also fixes some failing tests. --------- Co-authored-by: Ke Wen --- .github/workflows/code-quality.yml | 2 +- .github/workflows/pippy_tests.yaml | 5 +- check.sh | 2 +- pippy/IR.py | 15 +- pippy/LoadModule.py | 6 +- pippy/ModelSplit.py | 1 + pippy/PipelineDriver.py | 19 +- pippy/PipelineStage.py | 1 + pippy/__init__.py | 24 +- pippy/auto_parallelization.py | 3 +- pippy/backward.py | 1 + pippy/compile.py | 25 +- pippy/debug.py | 2 + pippy/events.py | 2 +- pippy/hf/_SaveModule.py | 17 +- pippy/hf/__init__.py | 6 +- pippy/hf/utils.py | 9 +- pippy/microbatch.py | 7 +- pippy/utils.py | 4 +- pippy/visualizer.py | 4 +- test/hf_test.py | 2 +- test/local_test_autosplit.py | 9 +- test/local_test_ckpt_index_file.py | 22 +- test/local_test_compile.py | 5 +- test/local_test_ddp.py | 8 +- test/local_test_forward.py | 6 +- test/local_test_forward_auto_parallel.py | 8 +- test/local_test_forward_backward.py | 12 +- test/local_test_forward_hf_bert.py | 12 +- test/local_test_forward_hf_gpt2.py | 12 +- test/local_test_null_coalesce_accumulate.py | 10 +- test/local_test_visualizer.py | 17 +- test/min_gpt_pp_tp.py | 14 +- test/min_gpt_tracing.py | 6 +- test/test_fx.py | 1536 ++++++++++++------- test/test_fx_experimental.py | 222 ++- test/test_ir.py | 20 +- 37 files changed, 1299 insertions(+), 777 deletions(-) diff --git a/.github/workflows/code-quality.yml b/.github/workflows/code-quality.yml index a8f34082e..3008e2e18 100644 --- a/.github/workflows/code-quality.yml +++ b/.github/workflows/code-quality.yml @@ -26,7 +26,7 @@ jobs: if [ -f requirements.txt ]; then pip install -r requirements.txt --index-url https://download.pytorch.org/whl/cpu; fi pip install torchvision --index-url https://download.pytorch.org/whl/cpu pip install git+https://github.com/pbelevich/transformers.git@compatible_with_pt_master - pip install "black<23" pylint==v3.0.0a5 mypy==v0.960 flake8==3.8.2 pyre-check==0.9.15 + pip install "black<23" pylint==v3.0.0a5 mypy==v0.960 flake8==3.8.2 pyre-check==0.9.15 ufmt==2.1.0 - name: Static Analysis Checks if: always() run: ./check.sh --keep-going diff --git a/.github/workflows/pippy_tests.yaml b/.github/workflows/pippy_tests.yaml index 69d91ad77..26c555fe3 100644 --- a/.github/workflows/pippy_tests.yaml +++ b/.github/workflows/pippy_tests.yaml @@ -25,7 +25,7 @@ jobs: runs-on: linux.4xlarge strategy: matrix: - python-version: ["3.7", "3.8", "3.9"] + python-version: ["3.8", "3.9"] container: image: python:${{ matrix.python-version }} @@ -76,7 +76,7 @@ jobs: runs-on: linux.4xlarge strategy: matrix: - python-version: ["3.7", "3.8", "3.9"] + python-version: ["3.8", "3.9"] replicate: ["0", "1"] schedule: ["FillDrain", "1F1B"] checkpoint: [ "0", "1" ] @@ -190,6 +190,7 @@ jobs: runs-on: linux.16xlarge.nvidia.gpu strategy: matrix: + python-version: ["3.8"] replicate: ["0", "1"] schedule: ["FillDrain", "1F1B"] env: diff --git a/check.sh b/check.sh index 00c75a6dc..6f74be023 100755 --- a/check.sh +++ b/check.sh @@ -49,7 +49,7 @@ RETVAL=0 if (( SKIP_FORMAT == 0 )); then echo; echo "Running format check ..." - ./format.sh --check + ufmt diff pippy/*.py pippy/hf/*.py test/*.py (( RETVAL |= $? )) fi diff --git a/pippy/IR.py b/pippy/IR.py index ec84e2b93..bd64cf37f 100644 --- a/pippy/IR.py +++ b/pippy/IR.py @@ -2,22 +2,23 @@ import copy import logging import operator -from enum import Enum import os import threading +from enum import Enum from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch import torch.fx as torch_fx -import pippy.fx from packaging import version -from pippy.fx.passes import shape_prop -from pippy.fx.passes.split_module import split_module + +import pippy.fx from pippy.backward import ( + _null_coalesce_accumulate, stage_backward, sync_barrier, - _null_coalesce_accumulate, ) +from pippy.fx.passes import shape_prop +from pippy.fx.passes.split_module import split_module from pippy.LoadModule import load_checkpoint # Because split_module with 4 arguments is available only in PT 1.12+ @@ -638,7 +639,7 @@ def throw(self, *args, **kwargs): def forward(self, *args, **kwargs): executor_args = args if len(kwargs) > 0: - from inspect import Signature, Parameter + from inspect import Parameter, Signature parameters = [] for node in self.split_gm.graph.nodes: @@ -840,7 +841,7 @@ def move_param_to_callee( to_delete.append((mod_itr, atoms)) # deferral deletion - for (mod_itr, atoms) in to_delete: + for mod_itr, atoms in to_delete: delattr(mod_itr, atoms[-1]) split.graph.lint() diff --git a/pippy/LoadModule.py b/pippy/LoadModule.py index b0a620d8a..e3f04f986 100644 --- a/pippy/LoadModule.py +++ b/pippy/LoadModule.py @@ -1,8 +1,8 @@ -import os -import json import gc -from typing import Dict, List, Optional, Tuple, Union +import json import logging +import os +from typing import Dict, List, Optional, Tuple, Union import torch from torch import nn diff --git a/pippy/ModelSplit.py b/pippy/ModelSplit.py index f2896d85d..74490b86c 100644 --- a/pippy/ModelSplit.py +++ b/pippy/ModelSplit.py @@ -3,6 +3,7 @@ from typing import Callable, Dict, List, Tuple import torch + import pippy.fx from pippy.IR import pipe_split diff --git a/pippy/PipelineDriver.py b/pippy/PipelineDriver.py index c442f3f94..26dcda518 100644 --- a/pippy/PipelineDriver.py +++ b/pippy/PipelineDriver.py @@ -6,26 +6,27 @@ import warnings from enum import Enum from inspect import Parameter, Signature -from typing import Any, Callable, Dict, List, Tuple, Optional +from typing import Any, Callable, Dict, List, Optional, Tuple import torch import torch.distributed.rpc as rpc -import pippy.fx -from pippy.fx.passes import shape_prop -from pippy.IR import Pipe +import pippy.fx from pippy.backward import ( + _null_coalesce_accumulate, stage_backward, sync_barrier, - _null_coalesce_accumulate, ) -from pippy.events import EventRecorder, EventsContext, Event, Allocator +from pippy.events import Allocator, Event, EventRecorder, EventsContext +from pippy.fx.passes import shape_prop + +from pippy.IR import Pipe from pippy.microbatch import ( - LossReducer, - sum_reducer, gen_output_chunk_spec, - split_args_kwargs_into_chunks, + LossReducer, merge_chunks, + split_args_kwargs_into_chunks, + sum_reducer, ) from pippy.utils import flatten_args_detach diff --git a/pippy/PipelineStage.py b/pippy/PipelineStage.py index 58db3c4f8..96fa2166a 100644 --- a/pippy/PipelineStage.py +++ b/pippy/PipelineStage.py @@ -5,6 +5,7 @@ import torch import torch.distributed as dist + import pippy import pippy.fx from pippy.backward import stage_backward, sync_barrier diff --git a/pippy/__init__.py b/pippy/__init__.py index d8d6564ed..49dc0493a 100644 --- a/pippy/__init__.py +++ b/pippy/__init__.py @@ -1,22 +1,22 @@ # Copyright (c) Meta Platforms, Inc. and affiliates +from pippy.compile import ( + all_compile, + compile, + compile_stage, + create_default_args, +) from pippy.IR import ( - PipeSequential, + annotate_split_points, LossWrapper, - TrivialLossWrapper, - pipe_split, Pipe, + pipe_split, + PipeSequential, PipeSplitWrapper, - annotate_split_points, + TrivialLossWrapper, ) -from pippy.PipelineDriver import PipelineDriverFillDrain, PipelineDriver1F1B -from pippy.ModelSplit import split_on_size_threshold, split_into_equal_size +from pippy.ModelSplit import split_into_equal_size, split_on_size_threshold +from pippy.PipelineDriver import PipelineDriver1F1B, PipelineDriverFillDrain from pippy.utils import run_pippy -from pippy.compile import ( - compile, - all_compile, - create_default_args, - compile_stage, -) __all__ = [ diff --git a/pippy/auto_parallelization.py b/pippy/auto_parallelization.py index b3d4ca973..94e026807 100644 --- a/pippy/auto_parallelization.py +++ b/pippy/auto_parallelization.py @@ -13,11 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass +from enum import Enum from typing import List, Tuple import numpy as np + import pippy.fx -from enum import Enum from pippy import pipe_split diff --git a/pippy/backward.py b/pippy/backward.py index d32c57a3f..6dfc6309b 100644 --- a/pippy/backward.py +++ b/pippy/backward.py @@ -2,6 +2,7 @@ from typing import List import torch + from pippy.debug import map_debug_info diff --git a/pippy/compile.py b/pippy/compile.py index c59f3a3ae..7b76513cb 100644 --- a/pippy/compile.py +++ b/pippy/compile.py @@ -2,26 +2,27 @@ import inspect import logging from typing import Any, Callable, List, Optional -from pippy.PipelineDriver import ( - PipelineDriver1F1B, - PipelineDriverFillDrain, - PipelineDriverInterleaved1F1B, -) -from pippy.PipelineStage import PipelineStage + +import torch +import torch.distributed as dist +from torch._subclasses.fake_tensor import FakeTensorMode + import pippy.fx as fx +from pippy.debug import PIPPY_VERBOSITY from pippy.IR import MultiUseParameterConfig, Pipe, PiPPyShapeProp from pippy.microbatch import ( - LossReducer, gen_output_chunk_spec, + LossReducer, split_args_kwargs_into_chunks, sum_reducer, ) +from pippy.PipelineDriver import ( + PipelineDriver1F1B, + PipelineDriverFillDrain, + PipelineDriverInterleaved1F1B, +) +from pippy.PipelineStage import PipelineStage from pippy.utils import get_device, get_pp_rank, get_rank -from pippy.debug import PIPPY_VERBOSITY - -import torch -import torch.distributed as dist -from torch._subclasses.fake_tensor import FakeTensorMode PIPELINE_SCHEDULE_DRIVERS = { diff --git a/pippy/debug.py b/pippy/debug.py index b06751ad2..c393581a9 100644 --- a/pippy/debug.py +++ b/pippy/debug.py @@ -1,7 +1,9 @@ # Copyright (c) Meta Platforms, Inc. and affiliates import logging import os + import torch + import pippy.fx diff --git a/pippy/events.py b/pippy/events.py index 7d054fd14..96b2669a3 100644 --- a/pippy/events.py +++ b/pippy/events.py @@ -4,7 +4,7 @@ from collections import defaultdict from copy import deepcopy from dataclasses import dataclass, field -from typing import Optional, Any, List, Dict +from typing import Any, Dict, List, Optional @dataclass diff --git a/pippy/hf/_SaveModule.py b/pippy/hf/_SaveModule.py index f35630c6e..594a1a3f5 100644 --- a/pippy/hf/_SaveModule.py +++ b/pippy/hf/_SaveModule.py @@ -1,14 +1,15 @@ -import torch.distributed as dist -from pippy.IR import Pipe -import torch +import json +import logging +import os +import tempfile +from itertools import chain from typing import Dict -from itertools import chain -import tempfile -import logging -import json -import os +import torch +import torch.distributed as dist + +from pippy.IR import Pipe CKPT_INDEX_JSON_FILENAME = "pytorch_model.bin.index.json" diff --git a/pippy/hf/__init__.py b/pippy/hf/__init__.py index 39e613001..9c93e2faf 100644 --- a/pippy/hf/__init__.py +++ b/pippy/hf/__init__.py @@ -1,11 +1,11 @@ # Copyright (c) Meta Platforms, Inc. and affiliates from pippy.hf.utils import ( + inject_pipeline_forward, PiPPyHFTracer, - PiPPyTrainingArguments, + PiPPySeq2SeqTrainer, PiPPySeq2SeqTrainingArguments, PiPPyTrainer, - PiPPySeq2SeqTrainer, - inject_pipeline_forward, + PiPPyTrainingArguments, ) __all__ = [ diff --git a/pippy/hf/utils.py b/pippy/hf/utils.py index f05b7abb9..aa5072d79 100644 --- a/pippy/hf/utils.py +++ b/pippy/hf/utils.py @@ -11,16 +11,13 @@ import transformers import transformers.utils.fx as fx from transformers import ( - TrainingArguments, + Seq2SeqTrainer, Seq2SeqTrainingArguments, Trainer, - Seq2SeqTrainer, + TrainingArguments, ) from transformers.modeling_utils import ModuleUtilsMixin -from transformers.utils import ( - is_torch_available, -) -from transformers.utils import cached_property +from transformers.utils import cached_property, is_torch_available from pippy.PipelineDriver import PipelineDriverBase diff --git a/pippy/microbatch.py b/pippy/microbatch.py index 6f22c1c5b..eb2cace9a 100644 --- a/pippy/microbatch.py +++ b/pippy/microbatch.py @@ -1,12 +1,14 @@ # Copyright (c) Meta Platforms, Inc. and affiliates import logging -from typing import Any import warnings -from pippy.IR import TrivialLossWrapper +from typing import Any + import torch from torch.utils._pytree import tree_flatten, tree_unflatten +from pippy.IR import TrivialLossWrapper + class CustomReducer: def __init__(self, init_value, reduce_fn): @@ -22,6 +24,7 @@ class LossReducer(CustomReducer): DEFAULT_CHUNK_DIM = 0 + # Class used to specify chunking of inputs class TensorChunkSpec: def __init__(self, split_dim): diff --git a/pippy/utils.py b/pippy/utils.py index c2b64a9c3..ac76372f1 100644 --- a/pippy/utils.py +++ b/pippy/utils.py @@ -1,7 +1,7 @@ # Copyright (c) Meta Platforms, Inc. and affiliates +import logging import os import socket -import logging from typing import List @@ -30,8 +30,8 @@ import torch -import torch.multiprocessing as mp import torch.distributed.rpc as rpc +import torch.multiprocessing as mp import pippy.fx diff --git a/pippy/visualizer.py b/pippy/visualizer.py index ac9ac83bf..6444ac2ad 100644 --- a/pippy/visualizer.py +++ b/pippy/visualizer.py @@ -1,12 +1,12 @@ # Copyright (c) Meta Platforms, Inc. and affiliates -from typing import Dict, List, Any +from typing import Any, Dict, List from pippy.events import ( + Allocator, Event, EventDependency, EventsContext, MemDumpEvent, - Allocator, ) diff --git a/test/hf_test.py b/test/hf_test.py index f707b8d81..753780d8c 100644 --- a/test/hf_test.py +++ b/test/hf_test.py @@ -10,9 +10,9 @@ import transformers.utils.fx as fx from pippy.IR import ( + annotate_split_points, MultiUseParameterConfig, Pipe, - annotate_split_points, PipeSplitWrapper, stage_backward, ) diff --git a/test/local_test_autosplit.py b/test/local_test_autosplit.py index 6ceaa0105..db1c6a28b 100644 --- a/test/local_test_autosplit.py +++ b/test/local_test_autosplit.py @@ -3,17 +3,17 @@ import os import unittest -import torch -import torch.autograd.profiler_legacy - import pippy.fx import pippy.ModelSplit + +import torch +import torch.autograd.profiler_legacy from pippy import run_pippy from pippy.IR import MultiUseParameterConfig, Pipe from pippy.PipelineDriver import ( + PipelineDriver1F1B, PipelineDriverBase, PipelineDriverFillDrain, - PipelineDriver1F1B, PipelineDriverInterleaved1F1B, ) @@ -80,7 +80,6 @@ def inspect_split_module( # Common function to run pipeline with input and check equivalence def run_pipe_driver(ec_pipe, args): - nstages = len(list(ec_pipe.split_gm.children())) pipe_driver: PipelineDriverBase = schedules[args.schedule]( diff --git a/test/local_test_ckpt_index_file.py b/test/local_test_ckpt_index_file.py index 6a9de4b88..95539f4c0 100644 --- a/test/local_test_ckpt_index_file.py +++ b/test/local_test_ckpt_index_file.py @@ -1,21 +1,19 @@ -from copy import deepcopy -from typing import List -import unittest import argparse -import shutil import json import os +import shutil +import unittest +from copy import deepcopy +from typing import List -from pippy.hf._SaveModule import ( - _save_index, - _save_checkpoint, -) -from pippy.IR import pipe_split, TrivialLossWrapper -from pippy.LoadModule import load_checkpoint -from pippy.compile import compile_stage +import torch import torch.distributed as dist import torch.optim as optim -import torch +from pippy.compile import compile_stage + +from pippy.hf._SaveModule import _save_checkpoint, _save_index +from pippy.IR import pipe_split, TrivialLossWrapper +from pippy.LoadModule import load_checkpoint DEFAULT_FILENAME = "pytorch_model.bin.index.json" CKPT_DIR = "test_ckpts" diff --git a/test/local_test_compile.py b/test/local_test_compile.py index 7bc0045ef..6ee079912 100644 --- a/test/local_test_compile.py +++ b/test/local_test_compile.py @@ -2,11 +2,12 @@ import argparse import os import unittest + import pippy -from pippy import run_pippy -from pippy.IR import pipe_split import torch +from pippy import run_pippy +from pippy.IR import pipe_split d_hid = 512 bs = 256 diff --git a/test/local_test_ddp.py b/test/local_test_ddp.py index 5f5fb3e1d..1b8e60b08 100644 --- a/test/local_test_ddp.py +++ b/test/local_test_ddp.py @@ -4,21 +4,21 @@ import os import unittest +import pippy.fx + import torch import torch.distributed.rpc as rpc - -import pippy.fx from pippy import run_pippy from pippy.IR import ( MultiUseParameterConfig, Pipe, - TrivialLossWrapper, pipe_split, + TrivialLossWrapper, ) from pippy.PipelineDriver import ( - PipelineDriverFillDrain, PipelineDriver1F1B, PipelineDriverBase, + PipelineDriverFillDrain, PipelineDriverInterleaved1F1B, ) diff --git a/test/local_test_forward.py b/test/local_test_forward.py index 509899fd7..924759cba 100644 --- a/test/local_test_forward.py +++ b/test/local_test_forward.py @@ -3,16 +3,16 @@ import os import unittest +import pippy.fx + import torch import torch.autograd.profiler_legacy - -import pippy.fx from pippy import run_pippy from pippy.IR import MultiUseParameterConfig, Pipe, pipe_split from pippy.PipelineDriver import ( + PipelineDriver1F1B, PipelineDriverBase, PipelineDriverFillDrain, - PipelineDriver1F1B, PipelineDriverInterleaved1F1B, ) diff --git a/test/local_test_forward_auto_parallel.py b/test/local_test_forward_auto_parallel.py index cc28902b9..19f0edd01 100644 --- a/test/local_test_forward_auto_parallel.py +++ b/test/local_test_forward_auto_parallel.py @@ -3,19 +3,19 @@ import os import unittest +import pippy.fx + import torch import torch.autograd.profiler_legacy - -import pippy.fx from pippy import run_pippy +from pippy.auto_parallelization import AutoParallelConfig, dp_auto_parallel from pippy.IR import MultiUseParameterConfig, Pipe from pippy.PipelineDriver import ( + PipelineDriver1F1B, PipelineDriverBase, PipelineDriverFillDrain, - PipelineDriver1F1B, PipelineDriverInterleaved1F1B, ) -from pippy.auto_parallelization import AutoParallelConfig, dp_auto_parallel PROFILING_ENABLED = True CHECK_NUMERIC_EQUIVALENCE = True diff --git a/test/local_test_forward_backward.py b/test/local_test_forward_backward.py index 0e63d18f3..43c177a9b 100644 --- a/test/local_test_forward_backward.py +++ b/test/local_test_forward_backward.py @@ -4,26 +4,24 @@ import os import unittest +import pippy.fx + import torch import torch.distributed.rpc as rpc - -import pippy.fx from pippy import run_pippy from pippy.IR import ( MultiUseParameterConfig, Pipe, - TrivialLossWrapper, pipe_split, + TrivialLossWrapper, ) +from pippy.microbatch import split_args_kwargs_into_chunks from pippy.PipelineDriver import ( + PipelineDriver1F1B, PipelineDriverBase, PipelineDriverFillDrain, - PipelineDriver1F1B, PipelineDriverInterleaved1F1B, ) -from pippy.microbatch import ( - split_args_kwargs_into_chunks, -) # TODOs for implementing forward/backward/loss with schedules: # * ability to switch between full-batch loss vs. per-microbatch loss. shen mentioned diff --git a/test/local_test_forward_hf_bert.py b/test/local_test_forward_hf_bert.py index d34977828..3b5d99563 100644 --- a/test/local_test_forward_hf_bert.py +++ b/test/local_test_forward_hf_bert.py @@ -3,24 +3,24 @@ import inspect import os +import pippy.fx + import torch import torch.autograd.profiler_legacy -from transformers import BertModel, BertConfig - -import pippy.fx from pippy import run_pippy +from pippy.hf import PiPPyHFTracer from pippy.IR import ( + annotate_split_points, MultiUseParameterConfig, Pipe, PipeSplitWrapper, - annotate_split_points, ) from pippy.PipelineDriver import ( - PipelineDriverFillDrain, PipelineDriver1F1B, PipelineDriverBase, + PipelineDriverFillDrain, ) -from pippy.hf import PiPPyHFTracer +from transformers import BertConfig, BertModel PROFILING_ENABLED = True CHECK_NUMERIC_EQUIVALENCE = True diff --git a/test/local_test_forward_hf_gpt2.py b/test/local_test_forward_hf_gpt2.py index a6d612ebf..a8beb677e 100644 --- a/test/local_test_forward_hf_gpt2.py +++ b/test/local_test_forward_hf_gpt2.py @@ -3,24 +3,24 @@ import inspect import os +import pippy.fx + import torch import torch.autograd.profiler_legacy -from transformers import GPT2Model, GPT2Config - -import pippy.fx from pippy import run_pippy +from pippy.hf import PiPPyHFTracer from pippy.IR import ( + annotate_split_points, MultiUseParameterConfig, Pipe, PipeSplitWrapper, - annotate_split_points, ) from pippy.PipelineDriver import ( - PipelineDriverFillDrain, PipelineDriver1F1B, PipelineDriverBase, + PipelineDriverFillDrain, ) -from pippy.hf import PiPPyHFTracer +from transformers import GPT2Config, GPT2Model PROFILING_ENABLED = True CHECK_NUMERIC_EQUIVALENCE = True diff --git a/test/local_test_null_coalesce_accumulate.py b/test/local_test_null_coalesce_accumulate.py index e4797550f..33bed4895 100644 --- a/test/local_test_null_coalesce_accumulate.py +++ b/test/local_test_null_coalesce_accumulate.py @@ -3,20 +3,20 @@ import os import unittest -import torch - import pippy.fx + +import torch from pippy import run_pippy from pippy.IR import ( + _null_coalesce_accumulate, Pipe, - TrivialLossWrapper, pipe_split, - _null_coalesce_accumulate, + TrivialLossWrapper, ) from pippy.PipelineDriver import ( + PipelineDriver1F1B, PipelineDriverBase, PipelineDriverFillDrain, - PipelineDriver1F1B, ) PROFILING_ENABLED = True diff --git a/test/local_test_visualizer.py b/test/local_test_visualizer.py index 0d63a35b4..95eacbe20 100644 --- a/test/local_test_visualizer.py +++ b/test/local_test_visualizer.py @@ -5,14 +5,14 @@ import unittest from collections import defaultdict from functools import reduce -from typing import List, Dict, Any +from typing import Any, Dict, List + +import pippy.fx import torch import torch.nn as nn -from torch.autograd import Function - -import pippy.fx from pippy import run_pippy +from pippy.events import Event from pippy.IR import ( MultiUseParameterConfig, Pipe, @@ -20,15 +20,15 @@ TrivialLossWrapper, ) from pippy.PipelineDriver import ( - PipelineDriverFillDrain, - PipelineDriver1F1B, + EventsContext, Phase, + PipelineDriver1F1B, PipelineDriverBase, - EventsContext, + PipelineDriverFillDrain, PipelineDriverInterleaved1F1B, ) -from pippy.events import Event from pippy.visualizer import events_to_json +from torch.autograd import Function PROFILING_ENABLED = True CHECK_NUMERIC_EQUIVALENCE = True @@ -55,7 +55,6 @@ def forward(self, input, target): # Inherit from Function class MyLinearFunction(Function): - # Note that both forward and backward are @staticmethods @staticmethod # bias is an optional argument diff --git a/test/min_gpt_pp_tp.py b/test/min_gpt_pp_tp.py index fc7c917a9..6491d3254 100644 --- a/test/min_gpt_pp_tp.py +++ b/test/min_gpt_pp_tp.py @@ -2,20 +2,18 @@ import argparse import os -import torch -import torch.distributed.tensor.parallel as tp - import pippy import pippy.fx -from pippy.IR import PipeSplitWrapper, annotate_split_points + +import torch import torch.distributed as dist -from torch.distributed._tensor import ( - DeviceMesh, -) +import torch.distributed.tensor.parallel as tp +from min_gpt_tracing import AdditionDataset # type: ignore from minGPT.mingpt.model import GPT, GPTConfig -from min_gpt_tracing import AdditionDataset # type: ignore +from pippy.IR import annotate_split_points, PipeSplitWrapper +from torch.distributed._tensor import DeviceMesh pippy.fx.Tracer.proxy_buffer_attributes = True diff --git a/test/min_gpt_tracing.py b/test/min_gpt_tracing.py index 16106d167..90b0d9a96 100644 --- a/test/min_gpt_tracing.py +++ b/test/min_gpt_tracing.py @@ -1,9 +1,9 @@ # Copyright (c) Meta Platforms, Inc. and affiliates -from pippy.IR import Pipe, annotate_split_points, PipeSplitWrapper -import pippy.fx - import logging +import pippy.fx +from pippy.IR import annotate_split_points, Pipe, PipeSplitWrapper + logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", diff --git a/test/test_fx.py b/test/test_fx.py index 2aa82f0f1..b366395ec 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -22,145 +22,190 @@ from copy import deepcopy from math import sqrt +from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Union + +import pippy +import pippy.fx._pytree as fx_pytree + import torch import torch.nn.utils._stateless as _stateless import torch.utils._pytree as pytree -from torch.multiprocessing import Process -from torch.testing import FileCheck -from torch.testing._internal.common_device_type import ops, onlyCPU, instantiate_device_type_tests -from torch.testing._internal.common_methods_invocations import op_db -import pippy -import pippy.fx._pytree as fx_pytree -from pippy.fx import symbolic_trace, Proxy, Node, GraphModule, Interpreter, Tracer, Transformer, Graph, wrap, PH, \ - CodeGen -from pippy.fx._compatibility import _BACK_COMPAT_OBJECTS, _MARKED_WITH_COMATIBLITY +from fx.named_tup import MyNamedTup +from pippy.fx import ( + CodeGen, + Graph, + GraphModule, + Interpreter, + Node, + PH, + Proxy, + symbolic_trace, + Tracer, + Transformer, + wrap, +) +from pippy.fx._compatibility import ( + _BACK_COMPAT_OBJECTS, + _MARKED_WITH_COMATIBLITY, +) from pippy.fx.experimental.rewriter import RewritingTracer from pippy.fx.immutable_collections import immutable_dict, immutable_list -from pippy.fx.node import Target, Argument, _format_arg +from pippy.fx.node import _format_arg, Argument, Target from pippy.fx.operator_schemas import get_signature_for_torch_op from pippy.fx.passes import shape_prop from pippy.fx.proxy import TraceError - -from typing import Any, Callable, Dict, NamedTuple, List, Optional, Tuple, Union +from torch.multiprocessing import Process +from torch.testing import FileCheck +from torch.testing._internal.common_device_type import ( + instantiate_device_type_tests, + onlyCPU, + ops, +) +from torch.testing._internal.common_methods_invocations import op_db from torch.testing._internal.common_utils import ( + find_library_location, IS_FBCODE, IS_MACOS, IS_WINDOWS, - find_library_location, run_tests, skipIfSlowGradcheckEnv, ) from torch.testing._internal.jit_utils import JitTestCase -from fx.named_tup import MyNamedTup - try: from torchvision import models as torchvision_models + HAS_TORCHVISION = True except ImportError: HAS_TORCHVISION = False skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision") + class SimpleTest(torch.nn.Module): def forward(self, x): return torch.relu(x + 3.0) + def a_non_torch_leaf(a, b): return a + b + # Used for test_autowrap_function. Autowrapped functions need to be global def fx_int(x: float) -> int: return int(x) + def fx_int_x2(x: float) -> int: return int(x) * 2 + # used in test_pytree. It's all the way out here because pickling a GraphModule # that uses Point errors out if Point is local to the function -Point = namedtuple('Point', ['x', 'y']) +Point = namedtuple("Point", ["x", "y"]) + # Test wrap() passing both a function name as well as a function # directly def a_lifted_leaf(a, b): return a[0] + a[1] + b -wrap('a_lifted_leaf') + +wrap("a_lifted_leaf") # Test wrapping twice doesn't break anything -wrap('a_lifted_leaf') +wrap("a_lifted_leaf") + def a_lifted_leaf2(a, b): return a[0] + a[1] + b + wrap(a_lifted_leaf2) -wrap('len') +wrap("len") + +wrap("getattr") -wrap('getattr') def wrapped_named_tup(p1, *, p2): return p1.x + p2.y + wrap(wrapped_named_tup) + @wrap def wrapped_via_decorator(a): return a + 1 -wrap('wrapped_with_submodule') + +wrap("wrapped_with_submodule") + def wrapped_with_submodule(x: torch.Tensor, batchnorm1d: torch.nn.BatchNorm1d): return batchnorm1d(x) + def my_decorator(f): @functools.wraps(f) def wrapper_inside_decorator(*args, **kwargs): return f(*args, **kwargs) + return wrapper_inside_decorator + @wrap @my_decorator def wrapped_decorated_fn(x): return x + real_wrapped_via_decorator = wrapped_via_decorator real_a_lifed_leaf = a_lifted_leaf real_a_lifed_leaf2 = a_lifted_leaf2 _sqrt = sqrt -wrap('wrapper_fn') +wrap("wrapper_fn") + def wrapper_fn(x): return torch.foo(x) + class Pair(NamedTuple): - x : torch.Tensor - y : torch.Tensor + x: torch.Tensor + y: torch.Tensor def _custom_fx_repr_fn(self) -> str: return f"Pair(x={_format_arg(self.x)}, y={_format_arg(self.y)})" + # for testing pytrees class Foo(object): # noqa: B209 def __init__(self, a, b): self.a = a self.b = b + class TestFX(JitTestCase): def setUp(self): super().setUp() # Checking for mutable operations whil tracing is feature flagged # Enable it in testing but not by default - self.orig_tracer_mutable_flag = pippy.fx.proxy.TracerBase.check_mutable_operations + self.orig_tracer_mutable_flag = ( + pippy.fx.proxy.TracerBase.check_mutable_operations + ) pippy.fx.proxy.TracerBase.check_mutable_operations = True if not (IS_FBCODE or IS_WINDOWS or IS_MACOS): - lib_file_path = find_library_location('libtorchbind_test.so') + lib_file_path = find_library_location("libtorchbind_test.so") torch.ops.load_library(str(lib_file_path)) def tearDown(self): super().tearDown() - pippy.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag + pippy.fx.proxy.TracerBase.check_mutable_operations = ( + self.orig_tracer_mutable_flag + ) def checkGraphModule(self, m: torch.nn.Module, args, kwargs=None): """Check that an nn.Module's results match the GraphModule version @@ -191,7 +236,16 @@ def __init__(self): def forward(self, A, B, c): t = torch.sigmoid(A) + self.lin(c) - return self.sub_mod(t.data + self.w + t + 1 - A + B // A + -A + A.add(B, alpha=3)) + return self.sub_mod( + t.data + + self.w + + t + + 1 + - A + + B // A + + -A + + A.add(B, alpha=3) + ) m = MyModule() gm = symbolic_trace(m) @@ -207,9 +261,8 @@ def forward(self, A): gm2 = symbolic_trace(m2) class T(torch.nn.Module): - def forward(self, A, b=4, *args, c=5, **kwargs): - x = A + 1 + args[0] + kwargs['3'] + x = A + 1 + args[0] + kwargs["3"] return x t = T() @@ -230,8 +283,8 @@ def forward(self, x): def test_custom_import(self): graph = pippy.fx.Graph() - a = graph.placeholder('x') - b = graph.placeholder('y') + a = graph.placeholder("x") + b = graph.placeholder("y") c = graph.call_function(a_non_torch_leaf, (a, b)) d = graph.call_function(torch.sin, (c,)) graph.output(d) @@ -242,11 +295,13 @@ def test_custom_import(self): def test_args_kwargs(self): class T(torch.nn.Module): def forward(self, *args, **kwargs): - x = args[0] + kwargs['foo'] + x = args[0] + kwargs["foo"] return x t = T() - self.checkGraphModule(t, (torch.rand(1), torch.rand(1)), {'foo': torch.rand(1)}) + self.checkGraphModule( + t, (torch.rand(1), torch.rand(1)), {"foo": torch.rand(1)} + ) def test_args_kwargs_no_self(self): class T(torch.nn.Module): @@ -255,8 +310,12 @@ def forward(*args, **kwargs): # noqa: B902 return torch.relu(args[1]) t = T() - with self.assertRaisesRegex(RuntimeError, r'cannot be part of \*args expansion'): - self.checkGraphModule(t, (torch.rand(1), torch.rand(1)), {'foo': torch.rand(1)}) + with self.assertRaisesRegex( + RuntimeError, r"cannot be part of \*args expansion" + ): + self.checkGraphModule( + t, (torch.rand(1), torch.rand(1)), {"foo": torch.rand(1)} + ) def test_fx_shifts(self): class MyModule(torch.nn.Module): @@ -281,9 +340,9 @@ def forward(self, x): def test_dict(self): class MyDictMod(torch.nn.Module): def forward(self, d): - return d['3'].relu(), {'4' : d['3'].neg()} + return d["3"].relu(), {"4": d["3"].neg()} - input_dict = {'3': torch.rand(3, 4)} + input_dict = {"3": torch.rand(3, 4)} m = MyDictMod() self.checkGraphModule(m, (input_dict,)) @@ -305,16 +364,25 @@ def rmatmul_f(x): inp = torch.randn(3) self.assertEqual(mod(inp), rmatmul_f(inp)) - def test_disallow_override(self): # Custom delegate to disallow in-place tensor operations class NoMutableCallTracer(Tracer): - def create_node(self, kind : str, target : Union[str, Callable], - args : Tuple[Argument, ...], kwargs : Dict[str, Any], name : Optional[str] = None, - type_expr : Optional[Any] = None) -> Node: - name = target if isinstance(target, str) else torch.typename(target) - if name[-1] == '_': - raise RuntimeError('In-place operations are not supported') + def create_node( + self, + kind: str, + target: Union[str, Callable], + args: Tuple[Argument, ...], + kwargs: Dict[str, Any], + name: Optional[str] = None, + type_expr: Optional[Any] = None, + ) -> Node: + name = ( + target + if isinstance(target, str) + else torch.typename(target) + ) + if name[-1] == "_": + raise RuntimeError("In-place operations are not supported") return super().create_node(kind, target, args, kwargs, name) # Test method @@ -325,7 +393,7 @@ def forward(self, x): m = MyInplaceMod() - with self.assertRaisesRegex(RuntimeError, 'In-place operations'): + with self.assertRaisesRegex(RuntimeError, "In-place operations"): NoMutableCallTracer().trace(m) # Test free function @@ -333,8 +401,9 @@ class MyInplaceMod2(torch.nn.Module): def forward(self, x): torch.log_(x) return x + m2 = MyInplaceMod2() - with self.assertRaisesRegex(RuntimeError, 'In-place operations'): + with self.assertRaisesRegex(RuntimeError, "In-place operations"): NoMutableCallTracer().trace(m2) # Test symbolic node as an arg @@ -343,8 +412,9 @@ def forward(self, x): y = torch.ones(3, 4) y.add_(x) return x + m3 = MyInplaceMod3() - with self.assertRaisesRegex(RuntimeError, 'In-place operations'): + with self.assertRaisesRegex(RuntimeError, "In-place operations"): NoMutableCallTracer().trace(m3) def test_leaf_module(self): @@ -365,17 +435,21 @@ def forward(self, x): mrm = MyReluMod() sym = NoLeafModulesTracer().trace(mrm) for node in sym.nodes: - self.assertNotEqual(node.op, 'call_module') + self.assertNotEqual(node.op, "call_module") sym.lint() def test_wrap(self): self.assertEqual(3 + 4 + 5, a_lifted_leaf((3, 4), 5)) def to_trace(y): - return a_lifted_leaf((4, y), 3) + a_lifted_leaf((3, 4), 5) + a_lifted_leaf((y, y), y) + return ( + a_lifted_leaf((4, y), 3) + + a_lifted_leaf((3, 4), 5) + + a_lifted_leaf((y, y), y) + ) m = symbolic_trace(to_trace) - self.assertIn('a_lifted_leaf', m.code) + self.assertIn("a_lifted_leaf", m.code) self.assertEqual(27, m(2)) self.assertIs(a_lifted_leaf, real_a_lifed_leaf) @@ -383,10 +457,14 @@ def test_wrap_fn_directly(self): self.assertEqual(3 + 4 + 5, a_lifted_leaf2((3, 4), 5)) def to_trace(y): - return a_lifted_leaf2((4, y), 3) + a_lifted_leaf2((3, 4), 5) + a_lifted_leaf2((y, y), y) + return ( + a_lifted_leaf2((4, y), 3) + + a_lifted_leaf2((3, 4), 5) + + a_lifted_leaf2((y, y), y) + ) m = symbolic_trace(to_trace) - self.assertIn('a_lifted_leaf2', m.code) + self.assertIn("a_lifted_leaf2", m.code) self.assertEqual(27, m(2)) self.assertIs(a_lifted_leaf2, real_a_lifed_leaf2) @@ -397,7 +475,7 @@ def to_trace(y): return wrapped_via_decorator(y) m = symbolic_trace(to_trace) - self.assertIn('wrapped_via_decorator', m.code) + self.assertIn("wrapped_via_decorator", m.code) self.assertEqual(m(0), 1) self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator) self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched")) @@ -409,19 +487,18 @@ def to_trace(y): return wrapped_via_decorator(y) m = symbolic_trace(to_trace) - self.assertIn('wrapped_via_decorator', m.code) + self.assertIn("wrapped_via_decorator", m.code) self.assertEqual(m(0), 1) self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator) self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched")) transformed = pippy.fx.Transformer(m).transform() - self.assertIn('wrapped_via_decorator', transformed.code) + self.assertIn("wrapped_via_decorator", transformed.code) self.assertEqual(transformed(0), 1) self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator) self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched")) def test_wrap_with_submodule(self): - class M(torch.nn.Module): def __init__(self): super(M, self).__init__() @@ -443,11 +520,11 @@ def to_trace(y): return wrapped_via_decorator(y) m = symbolic_trace(to_trace) - self.assertIn('wrapped_via_decorator', m.code) + self.assertIn("wrapped_via_decorator", m.code) self.assertEqual(m(0), 1) retraced = symbolic_trace(m) - self.assertIn('wrapped_via_decorator', retraced.code) + self.assertIn("wrapped_via_decorator", retraced.code) self.assertEqual(retraced(0), 1) def test_wrap_decorated_function(self): @@ -455,17 +532,18 @@ def to_trace(y): return wrapped_decorated_fn(y) m = symbolic_trace(to_trace) - self.assertIn('wrapped_decorated_fn', m.code) + self.assertIn("wrapped_decorated_fn", m.code) self.assertEqual(m(1), 1) def test_graph_edit_with_proxy(self): class M(torch.nn.Module): def forward(self, a, b): return a + b + m = M() g = symbolic_trace(m).graph new_g = pippy.fx.Graph() - val_map : Dict[Node, Node] = {} + val_map: Dict[Node, Node] = {} output_val = new_g.graph_copy(g, val_map) t = Proxy(output_val) # test that we can use proxy objects to generate more graph code later for things that do not need to work with modules. @@ -480,8 +558,10 @@ def forward(self, x, val=None): return x if val is None else x + val f = Foo() - traced = pippy.fx.symbolic_trace(f, concrete_args={'val' : None}) - with self.assertRaisesRegex(AssertionError, 'val has been specialized to have value None'): + traced = pippy.fx.symbolic_trace(f, concrete_args={"val": None}) + with self.assertRaisesRegex( + AssertionError, "val has been specialized to have value None" + ): traced(torch.randn(5), torch.randn(5)) x = torch.randn(5) @@ -504,7 +584,9 @@ def multiply_forward(self, x, y): print(torch.__version__) tracer = Tracer() - torch.testing.assert_close(GraphModule(f, tracer.trace(f))(x, y), f(x, y)) + torch.testing.assert_close( + GraphModule(f, tracer.trace(f))(x, y), f(x, y) + ) tracer.traced_func_name = "minus_forward" torch.testing.assert_close( @@ -522,21 +604,21 @@ def multiply_forward(self, x, y): with self.assertRaisesRegex(AssertionError, "doesn't exist in"): tracer.trace(f) - def test_graph_unique_names(self): class M(torch.nn.Module): def forward(self, a, b): return a + b + m = M() g = symbolic_trace(m).graph new_g = pippy.fx.Graph() - val_map : Dict[Node, Node] = {} + val_map: Dict[Node, Node] = {} output_val = new_g.graph_copy(g, val_map) t = Proxy(output_val) # test that we can use proxy objects to generate more graph code later for things that do not need to work with modules. new_g.output((t + t).node) gm = GraphModule(m, new_g) - seen_names : Set[str] = set() + seen_names: Set[str] = set() for node in gm.graph.nodes: assert node.name not in seen_names seen_names.add(node.name) @@ -553,15 +635,15 @@ def forward(self, a, b): # saving the original list because we will insert new nodes as a part of a test orig_graph_nodes = list(graph.nodes) for node in orig_graph_nodes: - if node.op == 'output': + if node.op == "output": continue self.assertTrue(node.stack_trace is not None) - assert 'test_fx.py' in node.stack_trace + assert "test_fx.py" in node.stack_trace # verify that copying the node does not lose the stack trace new_node = graph.node_copy(node) self.assertTrue(new_node.stack_trace is not None) - assert 'test_fx.py' in new_node.stack_trace + assert "test_fx.py" in new_node.stack_trace def test_stack_traces_with_transformer(self): class M(torch.nn.Module): @@ -577,22 +659,26 @@ def forward(self, a, b): # nodes after Transformer should still preserve the original node's stack trace for node in new_gm.graph.nodes: - if node.op in {'placeholder', 'output'}: + if node.op in {"placeholder", "output"}: continue self.assertTrue(node.stack_trace is not None) - assert 'test_fx.py' in node.stack_trace + assert "test_fx.py" in node.stack_trace def test_graph_unique_names_manual(self): - graph : pippy.fx.Graph = pippy.fx.Graph() - a : pippy.fx.Node = graph.create_node('placeholder', 'x') - b : pippy.fx.Node = graph.create_node('call_module', 'linear_mod', args=(a,), name='foo_1_1') - c : pippy.fx.Node = graph.create_node('get_attr', 'y_attr', name='foo_1') - d : pippy.fx.Node = graph.create_node('call_function', operator.add, args=(b, c)) + graph: pippy.fx.Graph = pippy.fx.Graph() + a: pippy.fx.Node = graph.create_node("placeholder", "x") + b: pippy.fx.Node = graph.create_node( + "call_module", "linear_mod", args=(a,), name="foo_1_1" + ) + c: pippy.fx.Node = graph.create_node("get_attr", "y_attr", name="foo_1") + d: pippy.fx.Node = graph.create_node( + "call_function", operator.add, args=(b, c) + ) graph.output(d) graph2 = pippy.fx.Graph() - val_map : Dict[Node, Node] = {} + val_map: Dict[Node, Node] = {} graph2.graph_copy(graph, val_map) - seen_names : Set[str] = set() + seen_names: Set[str] = set() for node in graph2.nodes: assert node.name not in seen_names seen_names.add(node.name) @@ -610,7 +696,9 @@ def forward(self, a, b): def test_native_callable(self): if IS_FBCODE or IS_WINDOWS or IS_MACOS: - raise unittest.SkipTest("non-portable load_library call used in test") + raise unittest.SkipTest( + "non-portable load_library call used in test" + ) # This test exercises the case where we use FX to translate from Python # code to some native callable object # @@ -639,7 +727,9 @@ def forward(self, x): # a valid nn.Module, symbolically traces it, lowers the Module to some # representation, and wraps that representation up into another # nn.Module instance that handles dispatch to the compiled/lowered code. - def lower_to_elementwise_interpreter(orig_mod : torch.nn.Module) -> torch.nn.Module: + def lower_to_elementwise_interpreter( + orig_mod: torch.nn.Module, + ) -> torch.nn.Module: # ===== Stage 1: Symbolic trace the module ===== mod = symbolic_trace(orig_mod) @@ -650,12 +740,9 @@ def lower_to_elementwise_interpreter(orig_mod : torch.nn.Module) -> torch.nn.Mod constants = {} fn_input_names = [] - target_to_name = { - operator.add : "add", - operator.mul : "mul" - } + target_to_name = {operator.add: "add", operator.mul: "mul"} - output_node : Optional[Node] = None + output_node: Optional[Node] = None # For each instruction, create a triple # (instruction_name : str, inputs : List[str], output : str) # to feed into the C++ interpreter @@ -663,33 +750,42 @@ def lower_to_elementwise_interpreter(orig_mod : torch.nn.Module) -> torch.nn.Mod target, args, out_name = n.target, n.args, n.name assert len(n.kwargs) == 0, "kwargs currently not supported" - if n.op == 'placeholder': + if n.op == "placeholder": # Placeholders specify function argument names. Save these # for later when we generate the wrapper GraphModule fn_input_names.append(target) - elif n.op == 'call_function': - assert target in target_to_name, "Unsupported call target " + target + elif n.op == "call_function": + assert target in target_to_name, ( + "Unsupported call target " + target + ) arg_names = [] for arg in args: if not isinstance(arg, Node): # Pull out constants. These constants will later be # fed to the interpreter C++ object via add_constant() - arg_name = f'constant_{constant_idx}' + arg_name = f"constant_{constant_idx}" constants[arg_name] = torch.tensor( - [arg] if isinstance(arg, numbers.Number) else arg) + [arg] + if isinstance(arg, numbers.Number) + else arg + ) arg_names.append(arg_name) constant_idx += 1 else: arg_names.append(arg.name) - instructions.append((target_to_name[target], arg_names, out_name)) - elif n.op == 'output': + instructions.append( + (target_to_name[target], arg_names, out_name) + ) + elif n.op == "output": if output_node is not None: - raise RuntimeError('Multiple output nodes!') + raise RuntimeError("Multiple output nodes!") output_node = n else: - raise RuntimeError('Unsupported opcode ' + n.op) + raise RuntimeError("Unsupported opcode " + n.op) - interpreter = torch.classes._TorchScriptTesting._ElementwiseInterpreter() + interpreter = ( + torch.classes._TorchScriptTesting._ElementwiseInterpreter() + ) # Load constants for k, v in constants.items(): interpreter.add_constant(k, v) @@ -722,14 +818,17 @@ def __init__(self, interpreter): # Add placeholders for fn inputs placeholder_nodes = [] for name in fn_input_names: - placeholder_nodes.append(graph.create_node('placeholder', name)) + placeholder_nodes.append(graph.create_node("placeholder", name)) # Get the interpreter object - interpreter_node = graph.create_node('get_attr', 'interpreter') + interpreter_node = graph.create_node("get_attr", "interpreter") # Add a node to call the interpreter instance output_node = graph.create_node( - op='call_method', target='__call__', args=(interpreter_node, placeholder_nodes)) + op="call_method", + target="__call__", + args=(interpreter_node, placeholder_nodes), + ) # Register output graph.output(output_node) @@ -739,7 +838,6 @@ def __init__(self, interpreter): # Return final GraphModule!!! return GraphModule(wrapper, graph) - # Lower GraphModule to C++ interpreter lowered = lower_to_elementwise_interpreter(msm) @@ -761,6 +859,7 @@ def __init__(self, interpreter): def test_reserved_getattr(self): """Ensure that we do not name any nodes with a reserved builtin like `getattr`""" + class M(torch.nn.Module): def forward(self, a): return a.foo.bar.baz @@ -781,13 +880,13 @@ def __init__(self): self.mm_param = torch.nn.Parameter(torch.randn(d_hid, d_hid)) self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) self.lin = torch.nn.Linear(d_hid, d_hid) - self.register_buffer('buffer', torch.randn(bs + 100, d_hid)) + self.register_buffer("buffer", torch.randn(bs + 100, d_hid)) def forward(self, x): x = torch.mm(x, self.mm_param) skip_connection = x x = torch.relu(x) - x = torch.mm(x, self.mm_param) + self.buffer[:x.shape[0]] + x = torch.mm(x, self.mm_param) + self.buffer[: x.shape[0]] x = self.lin(x) x = torch.relu(x) x = x + skip_connection @@ -795,7 +894,6 @@ def forward(self, x): x = self.lin(x) return x - ec = ExampleCode() traced = pippy.fx.symbolic_trace(ec) @@ -803,14 +901,19 @@ def forward(self, x): x = torch.randn(bs, d_hid) torch.testing.assert_allclose(ec(x), traced(x)) - def test_node_tagging(self): class TaggingTracer(Tracer): - def create_node(self, kind : str, target : Union[str, Callable], - args : Tuple[Argument, ...], kwargs : Dict[str, Any], name : Optional[str] = None, - type_expr : Optional[Any] = None) -> Node: + def create_node( + self, + kind: str, + target: Union[str, Callable], + args: Tuple[Argument, ...], + kwargs: Dict[str, Any], + name: Optional[str] = None, + type_expr: Optional[Any] = None, + ) -> Node: n = super().create_node(kind, target, args, kwargs, name) - n.tag = 'foo' + n.tag = "foo" return n class M(torch.nn.Module): @@ -821,8 +924,8 @@ def forward(self, a, b): g = TaggingTracer().trace(m) g.lint() for n in g.nodes: - self.assertTrue(hasattr(n, 'tag')) - self.assertEqual(n.tag, 'foo') + self.assertTrue(hasattr(n, "tag")) + self.assertEqual(n.tag, "foo") def test_tensor_attribute(self): class TensorAttribute(torch.nn.Module): @@ -851,11 +954,10 @@ def forward(self, x): traced2(torch.rand(4, 4)) def test_tensor_attribute_coalseced(self): - def count_attrs(fx_module): targets = set() for node in traced.graph.nodes: - if node.op == 'get_attr': + if node.op == "get_attr": targets.add(node.target) return len(targets) @@ -863,6 +965,7 @@ def count_attrs(fx_module): def f(x): return x + val + val + traced = symbolic_trace(f) traced.graph.lint() self.assertEqual(count_attrs(traced), 1) @@ -877,17 +980,12 @@ def f(x): traced.graph.lint() self.assertEqual(count_attrs(traced), 2) - def test_symbolic_trace_sequential(self): class Simple(torch.nn.Module): def forward(self, x): return torch.neg(x) - seq = torch.nn.Sequential( - Simple(), - Simple(), - Simple() - ) + seq = torch.nn.Sequential(Simple(), Simple(), Simple()) traced = symbolic_trace(seq) traced.graph.lint() x = torch.rand(3, 4) @@ -923,8 +1021,8 @@ def forward(self, x): def test_pickle_custom_import(self): graph = pippy.fx.Graph() - a = graph.placeholder('x') - b = graph.placeholder('y') + a = graph.placeholder("x") + b = graph.placeholder("y") c = graph.call_function(a_non_torch_leaf, (a, b)) d = graph.call_function(torch.sin, (c,)) graph.output(d) @@ -936,12 +1034,12 @@ def test_pickle_custom_import(self): self.assertEqual(loaded(x, y), gm(x, y)) def test_all_input_nodes(self): - graph : pippy.fx.Graph = pippy.fx.Graph() - a : pippy.fx.Node = graph.placeholder('x') - b : pippy.fx.Node = graph.call_module('linear_mod', args=(a,)) - c : pippy.fx.Node = graph.get_attr('y_attr') - d : pippy.fx.Node = graph.call_function(operator.add, args=(b, c)) - e : pippy.fx.Node = graph.call_function(torch.unsqueeze, args=(d, 0)) + graph: pippy.fx.Graph = pippy.fx.Graph() + a: pippy.fx.Node = graph.placeholder("x") + b: pippy.fx.Node = graph.call_module("linear_mod", args=(a,)) + c: pippy.fx.Node = graph.get_attr("y_attr") + d: pippy.fx.Node = graph.call_function(operator.add, args=(b, c)) + e: pippy.fx.Node = graph.call_function(torch.unsqueeze, args=(d, 0)) graph.output(e) graph.lint() @@ -957,12 +1055,14 @@ def test_deepcopy_graphmodule_with_transform(self): def transform(traced): new_graph = pippy.fx.Graph() - val_map : Dict[Node, Node] = {} + val_map: Dict[Node, Node] = {} output_value = new_graph.graph_copy(traced.graph, val_map) relu_out = new_graph.create_node( - op='call_method', target='neg', args=(output_value,), kwargs={}) + op="call_method", target="neg", args=(output_value,), kwargs={} + ) new_graph.output(relu_out) return GraphModule(traced, new_graph) + transformed = transform(traced) transformed.graph.lint() copied = copy.deepcopy(transformed) @@ -1017,11 +1117,13 @@ def __init__(self): super().__init__() self.sa = SomeArgs() - def forward(self, x : list): + def forward(self, x: list): return self.sa(*x) ul = UnpacksList() - with self.assertRaisesRegex(TraceError, 'Proxy object cannot be iterated.'): + with self.assertRaisesRegex( + TraceError, "Proxy object cannot be iterated." + ): symbolic_trace(ul) def test_unpack_dict_better_error(self): @@ -1034,11 +1136,13 @@ def __init__(self): super().__init__() self.sk = SomeKwargs() - def forward(self, x : dict): + def forward(self, x: dict): return self.sk(**x) ud = UnpacksDict() - with self.assertRaisesRegex(TraceError, 'Proxy object cannot be iterated.'): + with self.assertRaisesRegex( + TraceError, "Proxy object cannot be iterated." + ): symbolic_trace(ud) def test_pretty_print_targets(self): @@ -1051,16 +1155,17 @@ def forward(self, x): traced = symbolic_trace(SomeMod()) graph_str = str(traced.graph) - self.assertIn('builtins.getattr', graph_str) - self.assertIn('operator.add', graph_str) - self.assertIn('torch.add', graph_str) + self.assertIn("builtins.getattr", graph_str) + self.assertIn("operator.add", graph_str) + self.assertIn("torch.add", graph_str) def test_pretty_print_node(self): class M(torch.nn.Module): def __init__(self): super().__init__() self.param: torch.nn.Parameter = torch.nn.Parameter( - torch.rand(3, 4)) + torch.rand(3, 4) + ) self.linear = torch.nn.Linear(4, 5) def forward(self, x: torch.Tensor, y: int = 2): @@ -1070,14 +1175,25 @@ def forward(self, x: torch.Tensor, y: int = 2): all_formatted = "\n".join([n.format_node() for n in traced.graph.nodes]) - FileCheck().check("x").check("placeholder") \ - .check("y").check("placeholder") \ - .check("getitem").check("call_function") \ - .check("param").check("get_attr") \ - .check("add").check("call_function") \ - .check("linear").check("call_module") \ - .check("clamp").check("call_method") \ - .run(all_formatted) + FileCheck().check("x").check("placeholder").check("y").check( + "placeholder" + ).check("getitem").check("call_function").check("param").check( + "get_attr" + ).check( + "add" + ).check( + "call_function" + ).check( + "linear" + ).check( + "call_module" + ).check( + "clamp" + ).check( + "call_method" + ).run( + all_formatted + ) def test_script_tensor_constant(self): # TorchScript seems to ignore attributes that start with `__`. @@ -1104,7 +1220,7 @@ def forward(self, x): # `int` would normally throw a TypeError as argument can't be `Proxy` tracer = Tracer(autowrap_functions=(fx_int,)) graph = tracer.trace(AutowrapFnTest()) - traced = GraphModule(tracer.root, graph, 'test') + traced = GraphModule(tracer.root, graph, "test") tracer_2 = Tracer(autowrap_functions=(fx_int, fx_int_x2)) tracer_2.trace(AutowrapFnTest2()) @@ -1113,7 +1229,7 @@ def forward(self, x): self.assertEqual(traced_scripted(torch.rand(4)), 2) def test_tuple_no_subscript(self): - def foo(x : Tuple): + def foo(x: Tuple): return x[0] traced = pippy.fx.symbolic_trace(foo) @@ -1162,7 +1278,7 @@ def forward(self, x): def test_torch_fx_getattr(self): class FXGetattrTest(torch.nn.Module): def forward(self, x): - return getattr(x, 'nonexistent_attr', torch.Tensor([2, 3])) + return getattr(x, "nonexistent_attr", torch.Tensor([2, 3])) traced = symbolic_trace(FXGetattrTest()) self.assertEqual(traced(torch.rand(3, 4)), torch.Tensor([2, 3])) @@ -1192,6 +1308,7 @@ def forward(self, a): b = torch.ops.aten.sigmoid(a) c = torch.ops.aten.cat([a, b]) return torch.ops.aten.cat((c, c)) + m = M() input = torch.randn(3) ref_out = m(input) @@ -1205,6 +1322,7 @@ class M(torch.nn.Module): def forward(self, a): b = torch.ops.aten.add.Tensor(a, a) return b + m = M() input = torch.randn(3) ref_out = m(input) @@ -1214,9 +1332,9 @@ def forward(self, a): self.assertEqual(out, ref_out) for node in gm.graph.nodes: - if node.op == 'call_function': + if node.op == "call_function": assert isinstance(node.target, torch._ops.OpOverload) - assert node.target.__name__ == 'add.Tensor' + assert node.target.__name__ == "add.Tensor" def test_pickle_torch_custom_ops(self): class M(torch.nn.Module): @@ -1224,6 +1342,7 @@ def forward(self, a): b = torch.ops.aten.sigmoid(a) c = torch.ops.aten.cat([a, b]) return torch.ops.aten.cat((c, c)) + m = M() input = torch.randn(3) ref_out = m(input) @@ -1238,18 +1357,19 @@ def test_pretty_print(self): traced = symbolic_trace(st) traced.graph.lint() printed = str(traced) - assert 'SimpleTest()' in printed - assert 'torch.relu' in printed + assert "SimpleTest()" in printed + assert "torch.relu" in printed def test_pretty_print_graph(self): class KwargPrintTest(torch.nn.Module): def forward(self, x): return torch.squeeze(x + 3.0, dim=2) + st = KwargPrintTest() traced = symbolic_trace(st) traced.graph.lint() stringed = str(traced.graph) - for s in ['args', 'kwargs', '#users']: + for s in ["args", "kwargs", "#users"]: assert s in stringed def test_custom_proxy_type(self): @@ -1267,7 +1387,7 @@ def mul(self, other): r = self.right * other.right return TensorPair(l, r) - def use_tensor_pair(x : TensorPair, y : TensorPair): + def use_tensor_pair(x: TensorPair, y: TensorPair): s = x.add(y) return s.mul(x) @@ -1297,7 +1417,7 @@ def mul(self, other): r = self.right * other.right return TensorPair(l, r) - def use_tensor_pair_literal(x : TensorPair): + def use_tensor_pair_literal(x: TensorPair): s = x.add(TensorPair(torch.zeros(5, 3), torch.zeros(5, 3))) return s.mul(x) @@ -1326,7 +1446,7 @@ def mul(self, other): r = self.right * other.right return TensorPair(l, r) - def use_tensor_pair_ctor(x : TensorPair, y : torch.Tensor): + def use_tensor_pair_ctor(x: TensorPair, y: torch.Tensor): s = x.add(TensorPair(y, y)) return s.mul(x) @@ -1356,7 +1476,7 @@ def add(self, other): elif other.is_zero: return self - def use_zero_tensor(x : torch.Tensor, y : torch.Tensor): + def use_zero_tensor(x: torch.Tensor, y: torch.Tensor): return ZeroTensor(x + y) x, y = torch.randn(5, 3), torch.randn(5, 3) @@ -1372,10 +1492,10 @@ def use_zero_tensor(x : torch.Tensor, y : torch.Tensor): def test_graph_fns(self): g = Graph() - a = g.placeholder('a') - b = g.call_module('linear', (a,)) - c = g.get_attr('bias') - d = g.call_method('add', (b, c)) + a = g.placeholder("a") + b = g.call_module("linear", (a,)) + c = g.get_attr("bias") + d = g.call_method("add", (b, c)) e = g.call_function(torch.sin, (d,)) g.output(e) mod = torch.nn.Module() @@ -1389,10 +1509,10 @@ def test_graph_fns(self): self.assertEqual(r, ref) def test_remove_uses(self): - g : pippy.fx.Graph = Graph() - x : pippy.fx.Node = g.placeholder('x') - relu : pippy.fx.Node = g.call_function(torch.relu, (x,)) - neg : pippy.fx.Node = g.call_function(torch.neg, (relu,)) + g: pippy.fx.Graph = Graph() + x: pippy.fx.Node = g.placeholder("x") + relu: pippy.fx.Node = g.call_function(torch.relu, (x,)) + neg: pippy.fx.Node = g.call_function(torch.neg, (relu,)) g.output(neg) neg.replace_all_uses_with(relu) @@ -1401,23 +1521,22 @@ def test_remove_uses(self): self.assertTrue(neg not in relu.users) def test_remove_uses_with_custom_filter(self): - g : pippy.fx.Graph = Graph() - x : pippy.fx.Node = g.placeholder('x') - relu : pippy.fx.Node = g.call_function(torch.relu, (x,)) - neg : pippy.fx.Node = g.call_function(torch.neg, (relu,)) + g: pippy.fx.Graph = Graph() + x: pippy.fx.Node = g.placeholder("x") + relu: pippy.fx.Node = g.call_function(torch.relu, (x,)) + neg: pippy.fx.Node = g.call_function(torch.neg, (relu,)) g.output(neg) neg.replace_all_uses_with(relu, lambda x: x != neg) self.assertTrue(neg in relu.users) - def test_nonetype_annotation(self): eb = torch.nn.EmbeddingBag(3, 4) symbolic_trace(eb) def test_pickle_nonetype_annotation(self): - eb = torch.nn.EmbeddingBag(10, 3, mode='sum') + eb = torch.nn.EmbeddingBag(10, 3, mode="sum") traced = symbolic_trace(eb) pickled = pickle.dumps(traced) loaded = pickle.loads(pickled) @@ -1428,37 +1547,42 @@ def test_pickle_nonetype_annotation(self): def test_return_tuple(self): class M(torch.nn.Module): - def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def forward( + self, x: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: return (x, x + x) - original = M() traced = symbolic_trace(original) self.assertEqual(traced(torch.ones(1)), original.forward(torch.ones(1))) def test_construct_root_dict(self): - graph : pippy.fx.Graph = pippy.fx.Graph() - a : pippy.fx.Node = graph.create_node('placeholder', 'x') - b : pippy.fx.Node = graph.create_node('call_module', 'foo.bar.baz', args=(a,)) - c : pippy.fx.Node = graph.create_node('get_attr', 'zip.zap.zam') - d : pippy.fx.Node = graph.create_node('call_function', operator.add, args=(b, c)) + graph: pippy.fx.Graph = pippy.fx.Graph() + a: pippy.fx.Node = graph.create_node("placeholder", "x") + b: pippy.fx.Node = graph.create_node( + "call_module", "foo.bar.baz", args=(a,) + ) + c: pippy.fx.Node = graph.create_node("get_attr", "zip.zap.zam") + d: pippy.fx.Node = graph.create_node( + "call_function", operator.add, args=(b, c) + ) graph.output(d) - linear_mod : torch.nn.Module = torch.nn.Linear(3, 4) - add_param : torch.Tensor = torch.rand(3, 4) - gm : pippy.fx.GraphModule = pippy.fx.GraphModule( - {'foo.bar.baz': linear_mod, 'zip.zap.zam' : add_param}, graph) + linear_mod: torch.nn.Module = torch.nn.Linear(3, 4) + add_param: torch.Tensor = torch.rand(3, 4) + gm: pippy.fx.GraphModule = pippy.fx.GraphModule( + {"foo.bar.baz": linear_mod, "zip.zap.zam": add_param}, graph + ) gm.graph.lint() - assert 'self.foo.bar.baz' in gm.code + assert "self.foo.bar.baz" in gm.code - x : torch.Tensor = torch.rand(3, 3) - out : torch.Tensor = gm(x) - ref_out : torch.Tensor = linear_mod(x) + add_param + x: torch.Tensor = torch.rand(3, 3) + out: torch.Tensor = gm(x) + ref_out: torch.Tensor = linear_mod(x) + add_param self.assertEqual(out, ref_out) def test_symbolic_trace_assert(self): - class AssertsTensorShape(torch.nn.Module): def forward(self, x): torch._assert(x.shape[1] > 4, "assert_foobar") @@ -1536,26 +1660,40 @@ def test_copy_no_remap(self): copied = pippy.fx.Graph() for node in g.nodes: copied.node_copy(node) - with self.assertRaisesRegex(RuntimeError, 'does not belong to this Graph'): + with self.assertRaisesRegex( + RuntimeError, "does not belong to this Graph" + ): copied.lint() def test_wrong_topo(self): - graph : pippy.fx.Graph = pippy.fx.Graph() - a : pippy.fx.Node = graph.create_node('placeholder', 'x') - b : pippy.fx.Node = graph.create_node('call_module', 'foo.bar.baz', args=(a,)) - c : pippy.fx.Node = graph.create_node('get_attr', 'zip.zap.zam') - d : pippy.fx.Node = graph.create_node('call_function', operator.add, args=(b, c)) + graph: pippy.fx.Graph = pippy.fx.Graph() + a: pippy.fx.Node = graph.create_node("placeholder", "x") + b: pippy.fx.Node = graph.create_node( + "call_module", "foo.bar.baz", args=(a,) + ) + c: pippy.fx.Node = graph.create_node("get_attr", "zip.zap.zam") + d: pippy.fx.Node = graph.create_node( + "call_function", operator.add, args=(b, c) + ) graph.output(d) nodes = list(graph.nodes) nodes[3].append(nodes[2]) - with self.assertRaisesRegex(RuntimeError, 'was used before it has been defined'): + with self.assertRaisesRegex( + RuntimeError, "was used before it has been defined" + ): graph.lint() def test_wrong_target_type(self): - graph : pippy.fx.Graph = pippy.fx.Graph() + graph: pippy.fx.Graph = pippy.fx.Graph() with self.assertRaises(ValueError): - n = pippy.fx.Node(graph=graph, name='foo', op='call_function', target='foo', - args=(), kwargs={}) + n = pippy.fx.Node( + graph=graph, + name="foo", + op="call_function", + target="foo", + args=(), + kwargs={}, + ) def test_example_shape_prop(self): class TestCase(torch.nn.Module): @@ -1566,6 +1704,7 @@ def __init__(self): def forward(self, x): return torch.neg(self.submod(x.relu() + self.attr)) + tc = TestCase() tc_traced = symbolic_trace(tc) ref_out = tc_traced(torch.rand(3, 4)) @@ -1573,15 +1712,26 @@ def forward(self, x): # Make sure we're testing all opcodes opcodes = set() - output_shape : Optional[torch.Shape] = None - output_stride : Optional[Tuple[int]] = None + output_shape: Optional[torch.Shape] = None + output_stride: Optional[Tuple[int]] = None for node in tc_traced.graph.nodes: opcodes.add(node.op) - if node.op == 'output': - output_shape = node.args[0].meta['tensor_meta'].shape - output_stride = node.args[0].meta['tensor_meta'].stride - self.assertEqual(opcodes, set(['placeholder', 'get_attr', 'call_function', 'call_method', - 'call_module', 'output'])) + if node.op == "output": + output_shape = node.args[0].meta["tensor_meta"].shape + output_stride = node.args[0].meta["tensor_meta"].stride + self.assertEqual( + opcodes, + set( + [ + "placeholder", + "get_attr", + "call_function", + "call_method", + "call_module", + "output", + ] + ), + ) # Test shape propagation and make sure results match actual self.assertEqual(output_shape, ref_out.shape) @@ -1602,8 +1752,10 @@ def forward(self, x): x = torch.randn(5, 5, 224, 224) shape_prop.ShapeProp(traced).propagate(x) - assert(all(node.meta['tensor_meta'].memory_format is torch.contiguous_format - for node in traced.graph.nodes)) + assert all( + node.meta["tensor_meta"].memory_format is torch.contiguous_format + for node in traced.graph.nodes + ) x_channels_last = x.contiguous(memory_format=torch.channels_last) traced.to(memory_format=torch.channels_last) @@ -1612,8 +1764,10 @@ def forward(self, x): # NB: the implementation of conv may not preserve the memory format, # unfortunately. The best we can do is just check that the placeholder # node is channels-last - if node.op in {'placeholder'}: - self.assertEqual(node.meta['tensor_meta'].memory_format, torch.channels_last) + if node.op in {"placeholder"}: + self.assertEqual( + node.meta["tensor_meta"].memory_format, torch.channels_last + ) def test_shape_prop_aggregate(self): class ReturnTwo(torch.nn.Module): @@ -1640,9 +1794,9 @@ def is_leaf_module(self, m, module_qualified_name): shape_prop.ShapeProp(mod).propagate(torch.rand(3, 4)) for node in mod.graph.nodes: - if node.op == 'call_module': - assert 'tensor_meta' in node.meta - tensor_meta = node.meta['tensor_meta'] + if node.op == "call_module": + assert "tensor_meta" in node.meta + tensor_meta = node.meta["tensor_meta"] assert tensor_meta[0] == 3 assert tensor_meta[1].shape == torch.Size([]) @@ -1659,18 +1813,25 @@ def forward(self, x): traced_3d = symbolic_trace(test_mod_3d) x_3d = torch.randn(5, 5, 224, 224, 15) shape_prop.ShapeProp(traced_3d).propagate(x_3d) - assert(all(node.meta['tensor_meta'].memory_format is torch.contiguous_format - for node in traced_3d.graph.nodes)) + assert all( + node.meta["tensor_meta"].memory_format is torch.contiguous_format + for node in traced_3d.graph.nodes + ) - x_channels_last_3d = x_3d.contiguous(memory_format=torch.channels_last_3d) + x_channels_last_3d = x_3d.contiguous( + memory_format=torch.channels_last_3d + ) traced_3d.to(memory_format=torch.channels_last_3d) shape_prop.ShapeProp(traced_3d).propagate(x_channels_last_3d) for node in traced_3d.graph.nodes: # NB: the implementation of conv may not preserve the memory format, # unfortunately. The best we can do is just check that the placeholder # node is channels-last - if node.op in {'placeholder'}: - self.assertEqual(node.meta['tensor_meta'].memory_format, torch.channels_last_3d) + if node.op in {"placeholder"}: + self.assertEqual( + node.meta["tensor_meta"].memory_format, + torch.channels_last_3d, + ) def test_interpreter(self): class MyModule(torch.nn.Module): @@ -1707,7 +1868,7 @@ class RunNodeInterpreter(Interpreter): def __init__(self, module): super().__init__(module) - def run_node(self, n : Node) -> Any: + def run_node(self, n: Node) -> Any: result = super().run_node(n) n.cached_value = result return result @@ -1715,23 +1876,26 @@ def run_node(self, n : Node) -> Any: input = torch.randn(3, 4) RunNodeInterpreter(gm).run(input) for node in gm.graph.nodes: - assert hasattr(node, 'cached_value') + assert hasattr(node, "cached_value") def test_interpreter_onthefly_swap(self): - def fn(x): return torch.sigmoid(x).neg() gm = pippy.fx.symbolic_trace(fn) class NegSigmSwapInterpreter(Interpreter): - def call_function(self, target : Target, args : Tuple, kwargs : Dict) -> Any: + def call_function( + self, target: Target, args: Tuple, kwargs: Dict + ) -> Any: if target == torch.sigmoid: return torch.neg(*args, **kwargs) return super().call_function(n) - def call_method(self, target : Target, args : Tuple, kwargs : Dict) -> Any: - if target == 'neg': + def call_method( + self, target: Target, args: Tuple, kwargs: Dict + ) -> Any: + if target == "neg": call_self, *args_tail = args return call_self.sigmoid(*args_tail, **kwargs) return super().call_method(n) @@ -1754,13 +1918,15 @@ def forward(self, x): interp = Interpreter(gm) env = {} for node in gm.graph.nodes: - if node.op == 'call_module' and node.target == 'linear': + if node.op == "call_module" and node.target == "linear": env[node] = torch.arange(0, 12, 1).reshape(3, 4) - 6.0 break assert len(env) == 1 x = torch.randn(3, 4) result = interp.run(x, initial_env=env) - self.assertEqual(result, (torch.arange(0, 12, 1).reshape(3, 4) - 6.0).clamp(0.0, 1.0)) + self.assertEqual( + result, (torch.arange(0, 12, 1).reshape(3, 4) - 6.0).clamp(0.0, 1.0) + ) def test_interpreter_star_args(self): def with_star_args(x, *args): @@ -1768,7 +1934,9 @@ def with_star_args(x, *args): gm = pippy.fx.symbolic_trace(with_star_args) interp = Interpreter(gm) - result = interp.run(torch.ones(3, 4), torch.ones(3, 4), torch.rand(3, 4)) + result = interp.run( + torch.ones(3, 4), torch.ones(3, 4), torch.rand(3, 4) + ) self.assertEqual(result, torch.ones(3, 4) * 2.0) @skipIfNoTorchVision @@ -1785,7 +1953,7 @@ def test_interpreter_gc_values(self): inp = torch.rand(5, 3, 224, 224) out = interp.run(inp) env_key_names = set(n.name for n in interp.env.keys()) - self.assertEqual(env_key_names, set(['output'])) + self.assertEqual(env_key_names, set(["output"])) def test_interpreter_default_args(self): class Model(torch.nn.Module): @@ -1810,8 +1978,10 @@ def forward(self, x, y): interp = Interpreter(gm) x = torch.randn(5, 3) - with self.assertRaisesRegex(RuntimeError, - 'Expected positional argument for parameter y, but one was not passed in'): + with self.assertRaisesRegex( + RuntimeError, + "Expected positional argument for parameter y, but one was not passed in", + ): out = interp.run(x) def test_transformer_noop(self): @@ -1833,20 +2003,23 @@ def forward(self, x): self.assertEqual(new_gm(input), gm(input)) def test_transformer_op_swap(self): - def fn(x): return torch.sigmoid(x).neg() gm = pippy.fx.symbolic_trace(fn) class NegSigmSwapXformer(Transformer): - def call_function(self, target : Target, args : Tuple, kwargs : Dict) -> Any: + def call_function( + self, target: Target, args: Tuple, kwargs: Dict + ) -> Any: if target == torch.sigmoid: return torch.neg(*args, **kwargs) return super().call_function(n) - def call_method(self, target : Target, args : Tuple, kwargs : Dict) -> Any: - if target == 'neg': + def call_method( + self, target: Target, args: Tuple, kwargs: Dict + ) -> Any: + if target == "neg": call_self, *args_tail = args return call_self.sigmoid(*args_tail, **kwargs) return super().call_method(n) @@ -1877,8 +2050,10 @@ def forward(self, x): def test_fn_type_annotations(self): class Foo(torch.nn.Module): - def forward(self, p : Pair, z : torch.Tensor, i : int) -> Dict[str, torch.Tensor]: - return {'a': p.x + p.y + z + i} + def forward( + self, p: Pair, z: torch.Tensor, i: int + ) -> Dict[str, torch.Tensor]: + return {"a": p.x + p.y + z + i} foo_scripted = torch.jit.script(Foo()) foo_scripted(Pair(torch.rand(5), torch.rand(5)), torch.rand(5), 3) @@ -1888,8 +2063,9 @@ def forward(self, p : Pair, z : torch.Tensor, i : int) -> Dict[str, torch.Tensor fxed_scripted(Pair(torch.rand(5), torch.rand(5)), torch.rand(5), 3) def test_fn_type_annotation_empty(self): - def forward(a : List[torch.Tensor]): + def forward(a: List[torch.Tensor]): return a[0] + torch.jit.script(symbolic_trace(forward)) def test_wrapped_method(self): @@ -1897,6 +2073,7 @@ def wrap_with_relu(fn): @functools.wraps(fn) def wrapper(*args, **kwargs): return torch.relu(fn(*args, **kwargs)) + return wrapper class Foo(torch.nn.Module): @@ -1936,13 +2113,14 @@ def forward(self, x): self.checkGraphModule(m, (torch.rand(3, 4),)) def test_typename_print(self): - graph : pippy.fx.Graph = pippy.fx.Graph() - x : pippy.fx.Node = graph.create_node('placeholder', 'x') - b : pippy.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,), - type_expr=List[float]) - output : pippy.fx.Node = graph.output(b) + graph: pippy.fx.Graph = pippy.fx.Graph() + x: pippy.fx.Node = graph.create_node("placeholder", "x") + b: pippy.fx.Node = graph.create_node( + "call_function", target=torch.relu, args=(x,), type_expr=List[float] + ) + output: pippy.fx.Node = graph.output(b) - self.assertTrue('typing.List[float]' in str(graph)) + self.assertTrue("typing.List[float]" in str(graph)) def test_layout(self): class M(torch.nn.Module): @@ -1950,7 +2128,9 @@ def __init__(self): super().__init__() def forward(self, x): - return torch.empty_like(x, layout=torch.strided, pin_memory=False).fill_(0) + return torch.empty_like( + x, layout=torch.strided, pin_memory=False + ).fill_(0) traced = symbolic_trace(M()) x = torch.rand(5, 9, 3, 4) @@ -1971,27 +2151,31 @@ def forward(self, x, y): def test_inf_nan(self): class FooMod(torch.nn.Module): def forward(self, x): - return x + float('inf'), x + float('-inf'), x + float('nan') + return x + float("inf"), x + float("-inf"), x + float("nan") fm = FooMod() self.checkGraphModule(fm, (torch.rand(3, 4),)) def test_inf_nan_kwds(self): - graph : pippy.fx.Graph = pippy.fx.Graph() - x : pippy.fx.Node = graph.create_node('placeholder', 'x') - b : pippy.fx.Node = graph.create_node('call_function', operator.add, (x, float('inf')), {}, name='inf') - c : pippy.fx.Node = graph.create_node('call_function', operator.add, (x, float('nan')), {}, name='nan') + graph: pippy.fx.Graph = pippy.fx.Graph() + x: pippy.fx.Node = graph.create_node("placeholder", "x") + b: pippy.fx.Node = graph.create_node( + "call_function", operator.add, (x, float("inf")), {}, name="inf" + ) + c: pippy.fx.Node = graph.create_node( + "call_function", operator.add, (x, float("nan")), {}, name="nan" + ) graph.output((b, c)) gm = pippy.fx.GraphModule(torch.nn.Module(), graph) x = torch.rand(3, 4) - self.assertEqual(gm(x), (x + float('inf'), x + float('nan'))) + self.assertEqual(gm(x), (x + float("inf"), x + float("nan"))) def test_deepcopy_recursion_depth(self): depth = sys.getrecursionlimit() + 20 g = pippy.fx.Graph() - x = g.placeholder('x') + x = g.placeholder("x") for i in range(depth): x = g.call_function(torch.relu, (x,)) g.output(x) @@ -2013,7 +2197,7 @@ def test_replace_uses(self): rn18 = torchvision_models.resnet18() class LowerReluTracer(pippy.fx.Tracer): - def is_leaf_module(self, m : torch.nn.Module, qualname : str): + def is_leaf_module(self, m: torch.nn.Module, qualname: str): if isinstance(m, torch.nn.ReLU): return False return super().is_leaf_module(m, qualname) @@ -2022,26 +2206,33 @@ def is_leaf_module(self, m : torch.nn.Module, qualname : str): to_erase = [] for node in rn18_traced.graph.nodes: - if node.op == 'call_function' and node.target in [torch.relu, torch.nn.functional.relu]: + if node.op == "call_function" and node.target in [ + torch.relu, + torch.nn.functional.relu, + ]: kwargs = node.kwargs.copy() # Neg doesn't have in-place - kwargs.pop('inplace') + kwargs.pop("inplace") with rn18_traced.graph.inserting_before(node): new_node = rn18_traced.graph.call_function( - the_function=torch.neg, args=node.args, kwargs=node.kwargs) + the_function=torch.neg, + args=node.args, + kwargs=node.kwargs, + ) node.replace_all_uses_with(replace_with=new_node) to_erase.append(node) for node in to_erase: rn18_traced.graph.erase_node(node) - def test_replace_input(self): - graph : pippy.fx.Graph = pippy.fx.Graph() - x : pippy.fx.Node = graph.create_node('placeholder', 'x') - y : pippy.fx.Node = graph.create_node('placeholder', 'y') - b : pippy.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,)) - output : pippy.fx.Node = graph.output(b) + graph: pippy.fx.Graph = pippy.fx.Graph() + x: pippy.fx.Node = graph.create_node("placeholder", "x") + y: pippy.fx.Node = graph.create_node("placeholder", "y") + b: pippy.fx.Node = graph.create_node( + "call_function", target=torch.relu, args=(x,) + ) + output: pippy.fx.Node = graph.output(b) b.replace_input_with(x, y) @@ -2052,13 +2243,17 @@ def test_replace_input(self): self.assertEqual(gm(input_x, input_y), torch.relu(input_y)) def test_insertion_point(self): - graph : pippy.fx.Graph = pippy.fx.Graph() - x : pippy.fx.Node = graph.create_node('placeholder', 'x') - b : pippy.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,)) - output : pippy.fx.Node = graph.output(b) + graph: pippy.fx.Graph = pippy.fx.Graph() + x: pippy.fx.Node = graph.create_node("placeholder", "x") + b: pippy.fx.Node = graph.create_node( + "call_function", target=torch.relu, args=(x,) + ) + output: pippy.fx.Node = graph.output(b) with graph.inserting_before(b): - neg : pippy.fx.Node = graph.call_function(the_function=torch.neg, args=(x,)) + neg: pippy.fx.Node = graph.call_function( + the_function=torch.neg, args=(x,) + ) _, *relu_args = b.args b.args = (neg, *relu_args) @@ -2068,34 +2263,36 @@ def test_insertion_point(self): self.assertEqual(gm(input), torch.relu(torch.neg(input))) def test_update_args_api(self): - graph : pippy.fx.Graph = pippy.fx.Graph() - x : pippy.fx.Node = graph.create_node('placeholder', 'x') - y : pippy.fx.Node = graph.create_node('placeholder', 'y') - b : pippy.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,)) - output : pippy.fx.Node = graph.output(b) + graph: pippy.fx.Graph = pippy.fx.Graph() + x: pippy.fx.Node = graph.create_node("placeholder", "x") + y: pippy.fx.Node = graph.create_node("placeholder", "y") + b: pippy.fx.Node = graph.create_node( + "call_function", target=torch.relu, args=(x,) + ) + output: pippy.fx.Node = graph.output(b) orig_gm = pippy.fx.GraphModule(torch.nn.Module(), graph) inp_x, inp_y = torch.randn(5, 3), torch.randn(3, 5) self.assertEqual(orig_gm(inp_x, inp_y), torch.relu(inp_x)) - b.update_arg(0, y) new_gm = pippy.fx.GraphModule(torch.nn.Module(), graph) self.assertEqual(new_gm(inp_x, inp_y), torch.relu(inp_y)) def test_update_kwargs_api(self): - graph : pippy.fx.Graph = pippy.fx.Graph() - x : pippy.fx.Node = graph.create_node('placeholder', 'x') - y : pippy.fx.Node = graph.create_node('placeholder', 'y') - b : pippy.fx.Node = graph.create_node('call_function', target=torch.relu, kwargs={'input': x}) - output : pippy.fx.Node = graph.output(b) + graph: pippy.fx.Graph = pippy.fx.Graph() + x: pippy.fx.Node = graph.create_node("placeholder", "x") + y: pippy.fx.Node = graph.create_node("placeholder", "y") + b: pippy.fx.Node = graph.create_node( + "call_function", target=torch.relu, kwargs={"input": x} + ) + output: pippy.fx.Node = graph.output(b) orig_gm = pippy.fx.GraphModule(torch.nn.Module(), graph) inp_x, inp_y = torch.randn(5, 3), torch.randn(3, 5) self.assertEqual(orig_gm(inp_x, inp_y), torch.relu(inp_x)) - - b.update_kwarg('input', y) + b.update_kwarg("input", y) new_gm = pippy.fx.GraphModule(torch.nn.Module(), graph) self.assertEqual(new_gm(inp_x, inp_y), torch.relu(inp_y)) @@ -2112,7 +2309,7 @@ def test_immutable_list_pytree_ops(self): def test_immutable_dict_pytree_ops(self): rand_tensor = torch.randn(5, 3) - d = immutable_dict({'a': 3, 'b': [rand_tensor, 42]}) + d = immutable_dict({"a": 3, "b": [rand_tensor, 42]}) flattened, spec = pytree.tree_flatten(d) assert flattened == [3, rand_tensor, 42] @@ -2122,12 +2319,16 @@ def test_immutable_dict_pytree_ops(self): assert isinstance(unflattened, immutable_dict) def test_move_before(self): - graph : pippy.fx.Graph = pippy.fx.Graph() - x : pippy.fx.Node = graph.create_node('placeholder', 'x') - b : pippy.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,)) - output : pippy.fx.Node = graph.output(b) + graph: pippy.fx.Graph = pippy.fx.Graph() + x: pippy.fx.Node = graph.create_node("placeholder", "x") + b: pippy.fx.Node = graph.create_node( + "call_function", target=torch.relu, args=(x,) + ) + output: pippy.fx.Node = graph.output(b) - neg : pippy.fx.Node = graph.call_function(the_function=torch.neg, args=(x,)) + neg: pippy.fx.Node = graph.call_function( + the_function=torch.neg, args=(x,) + ) _, *relu_args = b.args b.args = (neg, *relu_args) b.prepend(neg) @@ -2138,10 +2339,12 @@ def test_move_before(self): self.assertEqual(gm(input), torch.relu(torch.neg(input))) def test_prepend_self(self): - graph : pippy.fx.Graph = pippy.fx.Graph() - x : pippy.fx.Node = graph.create_node('placeholder', 'x') - b : pippy.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,)) - output : pippy.fx.Node = graph.output(b) + graph: pippy.fx.Graph = pippy.fx.Graph() + x: pippy.fx.Node = graph.create_node("placeholder", "x") + b: pippy.fx.Node = graph.create_node( + "call_function", target=torch.relu, args=(x,) + ) + output: pippy.fx.Node = graph.output(b) b.prepend(b) x.append(b) @@ -2154,7 +2357,9 @@ def test_erase_node_error(self): for node in traced.graph.nodes: # Test deleting with uses both in another Node and at the output if node.target in [operator.add, torch.relu]: - with self.assertRaisesRegex(RuntimeError, 'but it still had .* users in the graph'): + with self.assertRaisesRegex( + RuntimeError, "but it still had .* users in the graph" + ): traced.graph.erase_node(node) def test_copy_it(self): @@ -2172,7 +2377,7 @@ def test_get_torch_func_signature(self): def test_find_uses(self): graph = pippy.fx.Graph() - x = pippy.fx.Proxy(graph.placeholder('x')) + x = pippy.fx.Proxy(graph.placeholder("x")) y = torch.relu(x) z = x + x @@ -2182,7 +2387,7 @@ def test_find_uses(self): users_of_x = x.node.users self.assertEqual(len(users_of_x), 3) - expected_ops = set(['relu', 'add', 'neg']) + expected_ops = set(["relu", "add", "neg"]) for use in users_of_x: assert any(use.name.startswith(prefix) for prefix in expected_ops) @@ -2202,20 +2407,22 @@ def forward(self, x): output_node = combined_graph.graph_copy(inline_into.graph, {}) input_node = list(to_inline.graph.nodes)[0] - assert input_node and input_node.op == 'placeholder' + assert input_node and input_node.op == "placeholder" - val_map = {input_node : output_node} + val_map = {input_node: output_node} output = combined_graph.graph_copy(to_inline.graph, val_map) combined_graph.output(output) - combined_module = pippy.fx.GraphModule(torch.nn.Module(), combined_graph) + combined_module = pippy.fx.GraphModule( + torch.nn.Module(), combined_graph + ) input = torch.rand(3, 4) self.assertEqual(combined_module(input), input.relu().neg()) def test_multi_insert_point(self): graph = pippy.fx.Graph() - x = pippy.fx.Proxy(graph.placeholder('x')) + x = pippy.fx.Proxy(graph.placeholder("x")) relu = torch.relu(x) with graph.inserting_before(relu.node): @@ -2225,13 +2432,13 @@ def test_multi_insert_point(self): graph.output((relu.node, z.node)) graph.lint() - expected_ops = ['x', 'neg', 'tanh', 'relu'] + expected_ops = ["x", "neg", "tanh", "relu"] for node, expected in zip(graph.nodes, expected_ops): assert expected in node.name def test_reassign_args_kwargs_uses(self): graph = pippy.fx.Graph() - x, y = Proxy(graph.placeholder('x')), Proxy(graph.placeholder('y')) + x, y = Proxy(graph.placeholder("x")), Proxy(graph.placeholder("y")) z = x + y zed = z + z + z graph.output(zed.node) @@ -2254,7 +2461,7 @@ def foo(x, y): def test_trace_dict_int_keys(self): class ModWithDictArg(torch.nn.Module): - def forward(self, d : Dict[int, torch.Tensor]): + def forward(self, d: Dict[int, torch.Tensor]): return d[42] class CallsModWithDict(torch.nn.Module): @@ -2266,14 +2473,16 @@ def forward(self, x): return self.m({42: x}) class MyTracer(pippy.fx.Tracer): - def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool: + def is_leaf_module( + self, m: torch.nn.Module, module_qualified_name: str + ) -> bool: return isinstance(m, ModWithDictArg) traced_graph = MyTracer().trace(CallsModWithDict()) def test_trace_dict_proxy_keys(self): class ModWithDictArg(torch.nn.Module): - def forward(self, d : Dict[torch.Tensor, torch.Tensor]): + def forward(self, d: Dict[torch.Tensor, torch.Tensor]): return d[42] class CallsModWithDict(torch.nn.Module): @@ -2285,10 +2494,12 @@ def forward(self, x): return self.m({x: x}) class MyTracer(pippy.fx.Tracer): - def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool: + def is_leaf_module( + self, m: torch.nn.Module, module_qualified_name: str + ) -> bool: return isinstance(m, ModWithDictArg) - with self.assertRaisesRegex(RuntimeError, 'cannot contain a Node'): + with self.assertRaisesRegex(RuntimeError, "cannot contain a Node"): traced_graph = MyTracer().trace(CallsModWithDict()) def test_module_deepcopy_edit_nodes(self): @@ -2328,7 +2539,7 @@ def forward(self, x): return self.a.b, self.a.b.t(), self.a.b.view(12) traced = pippy.fx.symbolic_trace(Foo()) - assert(all('constant' not in node.target for node in traced.graph.nodes)) + assert all("constant" not in node.target for node in traced.graph.nodes) def test_single_default_arg(self): class M(torch.nn.Module): @@ -2390,17 +2601,21 @@ def forward(self, x): def test_update_args_kwargs_yells_at_you(self): symtraced = symbolic_trace(SimpleTest()) node = next(iter(symtraced.graph.nodes)) - with self.assertRaisesRegex(AttributeError, '__update_args_kwargs'): + with self.assertRaisesRegex(AttributeError, "__update_args_kwargs"): node.__update_args_kwargs((), {}) def test_torchbind_class_attribute_in_fx(self): if IS_FBCODE or IS_WINDOWS or IS_MACOS: - self.skipTest("torch.classes._TorchScriptTesting._StackString is registered, skipping") + self.skipTest( + "torch.classes._TorchScriptTesting._StackString is registered, skipping" + ) class FooBar1234(torch.nn.Module): def __init__(self): super(FooBar1234, self).__init__() - self.f = torch.classes._TorchScriptTesting._StackString(["3", "4"]) + self.f = torch.classes._TorchScriptTesting._StackString( + ["3", "4"] + ) def forward(self): return self.f.top() @@ -2410,7 +2625,9 @@ def forward(self): def test_torchbind_class_attribute_in_fx_tensor_arg(self): if IS_FBCODE or IS_WINDOWS or IS_MACOS: - self.skipTest("torch.classes._TorchScriptTesting._ReLUClass is registered, skipping") + self.skipTest( + "torch.classes._TorchScriptTesting._ReLUClass is registered, skipping" + ) class FooBar2341(torch.nn.Module): def __init__(self): @@ -2426,7 +2643,7 @@ def forward(self, x): input = torch.randn(3, 4) self.assertEqual(traced(input), m(input)) - self.assertTrue(any(n.op == 'call_method' for n in traced.graph.nodes)) + self.assertTrue(any(n.op == "call_method" for n in traced.graph.nodes)) def test_script_method_trace(self): class Scripted(torch.nn.Module): @@ -2446,7 +2663,7 @@ def forward(self, x): input = torch.randn(3, 4) self.assertEqual(traced(input), h(input)) - self.assertTrue(any(n.op == 'call_method' for n in traced.graph.nodes)) + self.assertTrue(any(n.op == "call_method" for n in traced.graph.nodes)) def test_namedtuple_return_trace(self): class NamedTupReturn(torch.nn.Module): @@ -2475,7 +2692,9 @@ def forward(self, inp): for node in traced.graph.nodes: if node.op == "placeholder": ph = node - elif node.op == "call_function" and node.target == wrapped_named_tup: + elif ( + node.op == "call_function" and node.target == wrapped_named_tup + ): node.update_arg(0, Pair(ph, 1.2)) node.update_kwarg("p2", Pair(3.4, ph)) call_func = node @@ -2484,7 +2703,9 @@ def forward(self, inp): self.assertTrue(isinstance(call_func.args[0], Pair)) self.assertTrue(isinstance(call_func.kwargs["p2"], Pair)) self.assertEqual(_format_arg(call_func.args[0]), "Pair(x=%inp, y=1.2)") - self.assertEqual(_format_arg(call_func.kwargs["p2"]), "Pair(x=3.4, y=%inp)") + self.assertEqual( + _format_arg(call_func.kwargs["p2"]), "Pair(x=3.4, y=%inp)" + ) traced.graph.eliminate_dead_code() traced.recompile() @@ -2508,11 +2729,11 @@ def getitem_inner(self): class GetItemBase(torch.nn.Module): def __init__(self): super().__init__() - self.register_buffer('pe', torch.randn(8, 8)) + self.register_buffer("pe", torch.randn(8, 8)) class GetItem1(GetItemBase): def forward(self, x): - return self.pe[:, :x.size(0)] + return self.pe[:, : x.size(0)] class GetItem2(GetItemBase): def forward(self, x): @@ -2526,8 +2747,10 @@ def forward(self, x): self.checkGraphModule(GetItem2(), [torch.zeros(4)]) self.checkGraphModule(GetItem3(), [torch.zeros(4)]) - @unittest.skipUnless(os.environ.get("FX_PATCH_GETITEM") == "1", - "Will be checked in test_getitem_subproc") + @unittest.skipUnless( + os.environ.get("FX_PATCH_GETITEM") == "1", + "Will be checked in test_getitem_subproc", + ) def test_getitem(self): self.getitem_inner() @@ -2539,16 +2762,18 @@ def test_getitem_subproc(self): proc.join() self.assertEqual(proc.exitcode, 0) - def test_user_friendly_call_provenance_with_function(self): def fn(x): return wrapper_fn(x) traced = pippy.fx.symbolic_trace(fn) - with self.assertRaisesRegex(RuntimeError, "'wrapper_fn' is " - "being compiled since it was called" - " from 'fn.forward'"): + with self.assertRaisesRegex( + RuntimeError, + "'wrapper_fn' is " + "being compiled since it was called" + " from 'fn.forward'", + ): scripted = torch.jit.script(traced) def test_user_friendly_call_provenance_with_module(self): @@ -2558,20 +2783,25 @@ def forward(self, x): traced = pippy.fx.symbolic_trace(M()) - with self.assertRaisesRegex(RuntimeError, "'wrapper_fn' is " - "being compiled since it was called" - " from 'M.forward'"): + with self.assertRaisesRegex( + RuntimeError, + "'wrapper_fn' is " + "being compiled since it was called" + " from 'M.forward'", + ): scripted = torch.jit.script(traced) def test_snake_case(self): class M(torch.nn.Module): def __init__(self): super(M, self).__init__() - self.activations = torch.nn.ModuleDict([ - ["snake_case", torch.nn.ReLU()], - ["PascalCase", torch.nn.LeakyReLU()], - ["ALL_CAPS", torch.nn.PReLU()] - ]) + self.activations = torch.nn.ModuleDict( + [ + ["snake_case", torch.nn.ReLU()], + ["PascalCase", torch.nn.LeakyReLU()], + ["ALL_CAPS", torch.nn.PReLU()], + ] + ) def forward(self, x): a = self.activations["snake_case"](x) @@ -2584,7 +2814,7 @@ def forward(self, x): check = [ ("activations_snake_case", "activations.snake_case"), ("activations_pascal_case", "activations.PascalCase"), - ("activations_all_caps", "activations.ALL_CAPS") + ("activations_all_caps", "activations.ALL_CAPS"), ] i = 0 @@ -2600,6 +2830,7 @@ def forward(self, x): def test_no_mutation(self): from pippy.fx.immutable_collections import immutable_list + x = immutable_list([3, 4]) with self.assertRaisesRegex(NotImplementedError, "new_args"): x[0] = 4 @@ -2611,12 +2842,13 @@ def forward(self, x, y): return 2 * x else: return x + mod = Foo() - mod_true = symbolic_trace(mod, concrete_args={'y': True}) - mod_false = symbolic_trace(mod, concrete_args={'y': False}) + mod_true = symbolic_trace(mod, concrete_args={"y": True}) + mod_false = symbolic_trace(mod, concrete_args={"y": False}) self.assertEqual(mod_true(3, True), 6) print(mod_true.code) - assert(any([i.target == torch._assert for i in mod_true.graph.nodes])) + assert any([i.target == torch._assert for i in mod_true.graph.nodes]) with self.assertRaises(AssertionError): mod_true(3, False) self.assertEqual(mod_false(3, False), 3) @@ -2626,7 +2858,7 @@ def forward(self, x, y): def f_higher(a, f): return f(a) - nf = symbolic_trace(f_higher, concrete_args={'f': lambda x: x * 2}) + nf = symbolic_trace(f_higher, concrete_args={"f": lambda x: x * 2}) self.assertEqual(nf(3, lambda x: x * 2), 6) def test_custom_traceback_raised_when_exception_source_is_graphmodule(self): @@ -2642,8 +2874,9 @@ def forward(self, x): out = [n for n in traced.graph.nodes if n.op == "output"][-1] with traced.graph.inserting_before(out): - relu_out = traced.graph.call_method(method_name='relu', - args=(out.args[0],)) + relu_out = traced.graph.call_method( + method_name="relu", args=(out.args[0],) + ) out.args = (relu_out,) traced.recompile() @@ -2652,11 +2885,15 @@ def forward(self, x): with self.assertRaises(TypeError): traced(5) - self.assertRegex(captured[0], - r"Call using an FX-traced Module, line .* of the " - r"traced Module's generated forward function:") + self.assertRegex( + captured[0], + r"Call using an FX-traced Module, line .* of the " + r"traced Module's generated forward function:", + ) - def test_custom_traceback_not_raised_when_exception_source_is_submodule(self): + def test_custom_traceback_not_raised_when_exception_source_is_submodule( + self, + ): class M(torch.nn.Module): def __init__(self): super().__init__() @@ -2674,9 +2911,11 @@ def forward(self, x): except RuntimeError: captured = traceback.format_exc() - self.assertNotRegex(captured, - r"Call using an FX-traced Module, line .* of the " - r"traced Module's generated forward function:") + self.assertNotRegex( + captured, + r"Call using an FX-traced Module, line .* of the " + r"traced Module's generated forward function:", + ) def test_graph_module_replicate_for_dp(self): class Foo(torch.nn.Module): @@ -2727,7 +2966,9 @@ class MyTracer(pippy.fx.Tracer): check_mutable_operations = True tracer = MyTracer() - with self.assertRaisesRegex(RuntimeError, 'mutable operation aten::sigmoid.out'): + with self.assertRaisesRegex( + RuntimeError, "mutable operation aten::sigmoid.out" + ): traced_graph = tracer.trace(foo) def test_ast_rewriter_reassigns_submodules(self): @@ -2783,27 +3024,35 @@ def to_trace(y): def test_profiler_ranges_side_effect(self): g = pippy.fx.Graph() - handle = g.call_function(torch.ops.profiler._record_function_enter, ('test_range',)) + handle = g.call_function( + torch.ops.profiler._record_function_enter, ("test_range",) + ) g.call_function(torch.ops.profiler._record_function_exit, (handle,)) g.output(None) found_targets = {} for node in g.nodes: - if node.op == 'call_function': + if node.op == "call_function": found_targets.setdefault(node.target) self.assertEqual( list(found_targets.keys()), - [torch.ops.profiler._record_function_enter, torch.ops.profiler._record_function_exit] + [ + torch.ops.profiler._record_function_enter, + torch.ops.profiler._record_function_exit, + ], ) g.eliminate_dead_code() found_targets = {} for node in g.nodes: - if node.op == 'call_function': + if node.op == "call_function": found_targets.setdefault(node.target) self.assertEqual( list(found_targets.keys()), - [torch.ops.profiler._record_function_enter, torch.ops.profiler._record_function_exit] + [ + torch.ops.profiler._record_function_enter, + torch.ops.profiler._record_function_exit, + ], ) def test_ast_rewriter_wrapped_via_decorator(self): @@ -2896,8 +3145,9 @@ def forward(self, x): conv = [n for n in a.graph.nodes if n.target == "net_b.net_c.conv"][-1] with a.graph.inserting_before(conv): with warnings.catch_warnings(record=True) as w: - dropout = a.graph.call_module(module_name="net_b.net_c.dropout", - args=conv.args) + dropout = a.graph.call_module( + module_name="net_b.net_c.dropout", args=conv.args + ) self.assertEqual(len(w), 0) conv.replace_all_uses_with(dropout) @@ -2908,12 +3158,14 @@ def module_exists(gm: GraphModule, path: str) -> bool: return any(path == name for name, _ in gm.named_modules()) def parameter_exists(gm: GraphModule, path: str) -> bool: - return (any(path == name for name, _ in gm.named_parameters()) - and any(path == name for name in gm.state_dict().keys())) + return any( + path == name for name, _ in gm.named_parameters() + ) and any(path == name for name in gm.state_dict().keys()) def buffer_exists(gm: GraphModule, path: str) -> bool: - return (any(path == name for name, _ in gm.named_buffers()) - and any(path == name for name in gm.state_dict().keys())) + return any(path == name for name, _ in gm.named_buffers()) and any( + path == name for name in gm.state_dict().keys() + ) # Test that we added the "dropout" submodule self.assertTrue(module_exists(a, "net_b.net_c.dropout")) @@ -2937,23 +3189,26 @@ def buffer_exists(gm: GraphModule, path: str) -> bool: self.assertFalse(module_exists(a, "net_b.net_c.conv")) # Test `get_submodule` with a deleted submodule - with self.assertRaisesRegex(AttributeError, "has no attribute " - "`conv`"): + with self.assertRaisesRegex( + AttributeError, "has no attribute " "`conv`" + ): self.assertIsNone(a.get_submodule("net_b.net_c.conv")) # Test `get_attr` warnings cat = [n for n in a.graph.nodes if n.target == torch.cat][-1] with a.graph.inserting_before(cat): - with warnings.catch_warnings(record=True) as w: param = a.graph.get_attr(qualified_name="net_b.net_c.param") self.assertEqual(len(w), 0) - with self.assertWarnsRegex(UserWarning, "Attempted to " - "insert a get_attr Node with no " - "underlying reference in the " - "owning GraphModule"): + with self.assertWarnsRegex( + UserWarning, + "Attempted to " + "insert a get_attr Node with no " + "underlying reference in the " + "owning GraphModule", + ): bad_param = a.graph.get_attr(qualified_name="net_b.param") a.graph.erase_node(bad_param) @@ -2965,20 +3220,22 @@ def buffer_exists(gm: GraphModule, path: str) -> bool: # Test `get_parameter` a.get_parameter("net_b.net_c.param") - with self.assertRaisesRegex(AttributeError, "is not an " - "nn.Parameter"): + with self.assertRaisesRegex( + AttributeError, "is not an " "nn.Parameter" + ): a.get_parameter("net_b.buf") - with self.assertRaisesRegex(AttributeError, "has no attribute " - "`param`"): + with self.assertRaisesRegex( + AttributeError, "has no attribute " "`param`" + ): a.get_parameter("net_b.param") # Test `get_buffer` a.get_buffer("net_b.buf") - with self.assertRaisesRegex(AttributeError, "is not a " - "buffer"): + with self.assertRaisesRegex(AttributeError, "is not a " "buffer"): a.get_buffer("net_b.net_c.param") - with self.assertRaisesRegex(AttributeError, "has no attribute " - "`buf`"): + with self.assertRaisesRegex( + AttributeError, "has no attribute " "`buf`" + ): a.get_buffer("net_b.net_c.buf") # Test non-nested attributes @@ -3030,7 +3287,9 @@ def forward(self, x): model = Model() class MyCustomTracer(pippy.fx.Tracer): - def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool: + def is_leaf_module( + self, m: torch.nn.Module, module_qualified_name: str + ) -> bool: return module_qualified_name == "submod" inputs = torch.randn(1, 10) @@ -3044,7 +3303,7 @@ class MockModule(torch.nn.Module): def __init__(self): super().__init__() self.l1 = torch.nn.Linear(1, 1) - self.register_buffer('buffer', torch.ones(1)) + self.register_buffer("buffer", torch.ones(1)) def forward(self, x): return self.l1(x) + self.buffer @@ -3054,9 +3313,7 @@ def forward(self, x): weight = torch.tensor([[1.0]], requires_grad=True) bias = torch.tensor([0.0], requires_grad=True) buffer = torch.tensor([0.0]) - parameters = {'l1.weight': weight, - 'l1.bias': bias, - 'buffer': buffer} + parameters = {"l1.weight": weight, "l1.bias": bias, "buffer": buffer} fx_module = pippy.fx.symbolic_trace(module) res = _stateless.functional_call(fx_module, parameters, x) res.backward() @@ -3125,7 +3382,11 @@ def is_leaf_module(self, module, name): # Test graphmodule/submodule a is not inlined. self.assertTrue(isinstance(gm.get_submodule("a"), GraphModule)) - match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "a"] + match = [ + n + for n in gm.graph.nodes + if n.op == "call_module" and n.target == "a" + ] self.assertTrue(len(match) == 1) # Test submodule b is not treated as leaf. @@ -3154,12 +3415,20 @@ def is_leaf_module(self, module, name): # Test graphmodule/submodule a is not inlined. self.assertTrue(isinstance(gm.get_submodule("a"), GraphModule)) - match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "a"] + match = [ + n + for n in gm.graph.nodes + if n.op == "call_module" and n.target == "a" + ] self.assertTrue(len(match) == 1) # Test submodule b is leaf: self.assertTrue(isinstance(gm.get_submodule("b"), torch.nn.Module)) - match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "b"] + match = [ + n + for n in gm.graph.nodes + if n.op == "call_module" and n.target == "b" + ] self.assertTrue(len(match) == 1) # Test b.__call__ was run @@ -3180,11 +3449,19 @@ def is_leaf_module(self, module, name): gm.recompile() self.assertTrue(isinstance(gm.get_submodule("a"), GraphModule)) - match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "a"] + match = [ + n + for n in gm.graph.nodes + if n.op == "call_module" and n.target == "a" + ] self.assertTrue(len(match) == 1) self.assertTrue(isinstance(gm.get_submodule("b"), torch.nn.Module)) - match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "b"] + match = [ + n + for n in gm.graph.nodes + if n.op == "call_module" and n.target == "b" + ] self.assertTrue(len(match) == 1) def _test_graph_module_init_buffer_param_copied(self, use_dict_init: bool): @@ -3206,7 +3483,9 @@ def forward(self, x): orig_buff = mod_traced.get_buffer("my_buff") orig_param = mod_traced.get_parameter("my_param") mod_traced_new = GraphModule( - {"my_buff": orig_buff, "my_param": orig_param} if use_dict_init else mod, + {"my_buff": orig_buff, "my_param": orig_param} + if use_dict_init + else mod, mod_traced.graph, ) @@ -3252,12 +3531,14 @@ def __call__(self, x: torch.Tensor): return torch.add(x, x) class M(torch.nn.Module): - def forward(self, x: 'torch.Tensor', a: 'A') -> 'torch.Tensor': + def forward(self, x: "torch.Tensor", a: "A") -> "torch.Tensor": return a(x) self.checkGraphModule(M(), (torch.rand(2, 3), A()), kwargs=None) - def test_annotations_with_non_torch_reference_and_no_internal_forward_references(self): + def test_annotations_with_non_torch_reference_and_no_internal_forward_references( + self, + ): class A: def __call__(self, x: torch.Tensor): return torch.add(x, x) @@ -3268,22 +3549,26 @@ def forward(self, x: List[torch.Tensor], a: A) -> torch.Tensor: self.checkGraphModule(M(), (torch.rand(2, 3), A()), kwargs=None) - def test_annotations_with_non_torch_reference_and_internal_forward_references(self): + def test_annotations_with_non_torch_reference_and_internal_forward_references( + self, + ): class A: def __call__(self, x: torch.Tensor): return torch.add(x, x) class M(torch.nn.Module): - def forward(self, x: List['torch.Tensor'], a: A) -> 'torch.Tensor': + def forward(self, x: List["torch.Tensor"], a: A) -> "torch.Tensor": return a(x)[0] self.checkGraphModule(M(), (torch.rand(2, 3), A()), kwargs=None) - @unittest.skipIf(sys.version_info < (3, 7), "`__future__` feature " - "`annotations` is not defined in Python <3.7") + @unittest.skipIf( + sys.version_info < (3, 7), + "`__future__` feature " "`annotations` is not defined in Python <3.7", + ) def test_annotation_with_future(self): try: - import fx.test_future # noqa: F401 + import fx.test_future # noqa: F401 finally: del sys.modules["__future__"] @@ -3299,24 +3584,29 @@ def forward(self, x: Tuple[()], y: Tuple[str, Tuple[()]]): traced(x, y) - FileCheck().check("_Tuple[()]") \ - .check("typing_Tuple[str,typing_Tuple[()]]") \ - .run(traced.code) + FileCheck().check("_Tuple[()]").check( + "typing_Tuple[str,typing_Tuple[()]]" + ).run(traced.code) scripted = torch.jit.script(traced) scripted(x, y) - FileCheck().check("Tuple[()]") \ - .check("Tuple[str, Tuple[()]]") \ - .run(scripted.code) + FileCheck().check("Tuple[()]").check("Tuple[str, Tuple[()]]").run( + scripted.code + ) - @unittest.skipIf(IS_WINDOWS, "Python Windows bug? https://bugs.python.org/issue45108") - @unittest.skipIf(sys.version_info >= (3, 10), "Does not work on Python-3.10") + @unittest.skipIf( + IS_WINDOWS, "Python Windows bug? https://bugs.python.org/issue45108" + ) + @unittest.skipIf( + sys.version_info >= (3, 10), "Does not work on Python-3.10" + ) def test_assert(self): def f(x): assert x > 1 return x + 1 + try: pippy.fx.proxy.TracerBase.trace_asserts = True traced = symbolic_trace(f) @@ -3344,7 +3634,7 @@ def f_dict_list_map(x): return new_dict def f_dict_add(x): - return x['a'] + sum(x['z']) + return x["a"] + sum(x["z"]) def f_namedtuple_add(x): return x.x + x.y @@ -3368,42 +3658,55 @@ def f_return_custom(x): tests = [ (f_sum, [PH, PH, PH]), (f_sum, []), - (f_sum_dict, {'a': PH, 'b': PH, 'c': PH}), - (f_dict_list_map, {'a': (PH, PH), 'b': [PH], 'c': []}), + (f_sum_dict, {"a": PH, "b": PH, "c": PH}), + (f_dict_list_map, {"a": (PH, PH), "b": [PH], "c": []}), (f_dict_list_map, {5: (PH, PH, PH)}), - (f_dict_add, {'a': PH, 'z': (PH, PH, PH)}), - (f_dict_add, {'a': PH, 'z': []}), + (f_dict_add, {"a": PH, "z": (PH, PH, PH)}), + (f_dict_add, {"a": PH, "z": []}), (f_custom, Foo(PH, PH)), (f_custom, Foo(PH, 3)), - (f_custom_dict, Foo({'a': PH, 'b': PH}, PH)), + (f_custom_dict, Foo({"a": PH, "b": PH}, PH)), # (f_return_custom, Foo(PH, PH)), # Don't currently support output pytrees (f_namedtuple_add, Point(PH, PH)), ] def verify_pytree(f, inp): - val = pytree.tree_map(lambda x: torch.randn(3) if x == PH else x, inp) + val = pytree.tree_map( + lambda x: torch.randn(3) if x == PH else x, inp + ) num_flat_args = len([i == PH for i in pytree.tree_flatten(inp)[0]]) orig_out = f(val) - nf = symbolic_trace(f, concrete_args={'x': inp}) + nf = symbolic_trace(f, concrete_args={"x": inp}) self.assertEqual(nf(val), orig_out) bare_fx = GraphModule({}, copy.deepcopy(nf.graph)) bare_fx.graph.set_codegen(CodeGen()) bare_fx.recompile() - self.assertEqual(nf.graph.process_outputs(bare_fx(*nf.graph.process_inputs(val))), orig_out) + self.assertEqual( + nf.graph.process_outputs( + bare_fx(*nf.graph.process_inputs(val)) + ), + orig_out, + ) assert num_flat_args == 0 or "tree_flatten_spec" in nf.code - assert(sum([i.op == 'placeholder' for i in nf.graph.nodes]) == num_flat_args) + assert ( + sum([i.op == "placeholder" for i in nf.graph.nodes]) + == num_flat_args + ) nf = symbolic_trace(nf) self.assertEqual(nf(val), orig_out) assert "tree_flatten_spec" not in nf.code - assert(sum([i.op == 'placeholder' for i in nf.graph.nodes]) == 1) + assert sum([i.op == "placeholder" for i in nf.graph.nodes]) == 1 - nf = symbolic_trace(nf, concrete_args={'x': inp}) + nf = symbolic_trace(nf, concrete_args={"x": inp}) self.assertEqual(nf(val), orig_out) assert num_flat_args == 0 or "tree_flatten_spec" in nf.code - assert(sum([i.op == 'placeholder' for i in nf.graph.nodes]) == num_flat_args) + assert ( + sum([i.op == "placeholder" for i in nf.graph.nodes]) + == num_flat_args + ) pickled = pickle.dumps(nf) nf = pickle.loads(pickled) @@ -3415,11 +3718,11 @@ def verify_pytree(f, inp): def test_pytree_concrete(self): def f(b, a): if b: - return a['a'] + return a["a"] else: - return a['z'] + return a["z"] - inp = {'a': {'a': PH, 'z': PH}, 'b': True} + inp = {"a": {"a": PH, "z": PH}, "b": True} nf = symbolic_trace(f, concrete_args=inp) val = pytree.tree_map(lambda x: torch.randn(3) if x == PH else x, inp) self.assertEqual(nf(**val), f(**val)) @@ -3436,10 +3739,10 @@ def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}: return lst_unpack def additional_globals(self): - return [('List', typing.List)] + return [("List", typing.List)] def process_inputs(self, *inputs): - assert(len(inputs) == 1) + assert len(inputs) == 1 return inputs[0] def f(a, b): @@ -3457,7 +3760,10 @@ def f(a, b): bare_fx.recompile() self.assertEqual(nf(vals), f(*vals)) - self.assertEqual(nf.graph.process_outputs(bare_fx(*nf.graph.process_inputs(vals))), f(*vals)) + self.assertEqual( + nf.graph.process_outputs(bare_fx(*nf.graph.process_inputs(vals))), + f(*vals), + ) ts_f = torch.jit.script(nf) self.assertEqual(nf(vals), ts_f(vals)) @@ -3471,10 +3777,10 @@ def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}: return lst_unpack def additional_globals(self): - return [('List', typing.List)] + return [("List", typing.List)] def process_inputs(self, *inputs): - assert(len(inputs) == 1) + assert len(inputs) == 1 return inputs[0] def f(a, b): @@ -3500,14 +3806,14 @@ def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}: return lst_unpack def additional_globals(self): - return [('List', typing.List)] + return [("List", typing.List)] def process_inputs(self, *inputs): - assert(len(inputs) == 1) + assert len(inputs) == 1 return inputs[0] def generate_output(self, output_args): - return f'return list({repr(output_args)})' + return f"return list({repr(output_args)})" def process_outputs(self, outputs): return list(outputs) @@ -3544,10 +3850,15 @@ def fn(x, y): tracer_after = copy.deepcopy(tracer) self.assertEqual(str(tracer.graph), str(tracer_after.graph)) - self.assertTrue(not hasattr(tracer_before, 'graph') or str(tracer.graph) != str(tracer_before.graph)) + self.assertTrue( + not hasattr(tracer_before, "graph") + or str(tracer.graph) != str(tracer_before.graph) + ) + def run_getitem_target(): from pippy.fx._symbolic_trace import _wrapped_methods_to_patch + _wrapped_methods_to_patch.append((torch.Tensor, "__getitem__")) try: TestFX().getitem_inner() @@ -3559,34 +3870,46 @@ class TestOperatorSignatures(JitTestCase): def setUp(self): # Checking for mutable operations whil tracing is feature flagged # Enable it in testing but not by default - self.orig_tracer_mutable_flag = pippy.fx.proxy.TracerBase.check_mutable_operations + self.orig_tracer_mutable_flag = ( + pippy.fx.proxy.TracerBase.check_mutable_operations + ) pippy.fx.proxy.TracerBase.check_mutable_operations = True def tearDown(self): - pippy.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag + pippy.fx.proxy.TracerBase.check_mutable_operations = ( + self.orig_tracer_mutable_flag + ) @onlyCPU @ops(op_db, allowed_dtypes=(torch.float,)) def test_get_torch_func_signature_exhaustive(self, device, dtype, op): if not isinstance(op.op, types.BuiltinFunctionType): - raise unittest.SkipTest("This path doesn't work on Python functions") + raise unittest.SkipTest( + "This path doesn't work on Python functions" + ) sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False) schemas = get_signature_for_torch_op(op.op) if not schemas: - raise RuntimeError('No Schemas Returned') + raise RuntimeError("No Schemas Returned") for sample_input in sample_inputs_itr: # Iterate through overloads until we hit a match. If we exit this # loop via `else`, we haven't found a match for schema in schemas: try: - bound_args = schema.bind(sample_input.input, *sample_input.args, **sample_input.kwargs) + bound_args = schema.bind( + sample_input.input, + *sample_input.args, + **sample_input.kwargs, + ) bound_args.apply_defaults() op(*bound_args.args, **bound_args.kwargs) break except TypeError as e: pass else: - raise RuntimeError(f'Did not match any schemas for op {op.name}!') + raise RuntimeError( + f"Did not match any schemas for op {op.name}!" + ) class TestFXAPIBackwardCompatibility(JitTestCase): @@ -3596,13 +3919,16 @@ def setUp(self): # Checking for mutable operations whil tracing is feature flagged # Enable it in testing but not by default - self.orig_tracer_mutable_flag = pippy.fx.proxy.TracerBase.check_mutable_operations + self.orig_tracer_mutable_flag = ( + pippy.fx.proxy.TracerBase.check_mutable_operations + ) pippy.fx.proxy.TracerBase.check_mutable_operations = True def tearDown(self): super().tearDown() - pippy.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag - + pippy.fx.proxy.TracerBase.check_mutable_operations = ( + self.orig_tracer_mutable_flag + ) def _fn_to_stable_annotation_str(self, obj): """ @@ -3614,88 +3940,102 @@ def _fn_to_stable_annotation_str(self, obj): signature = inspect.signature(obj) - sig_str = f'{fn_name}{signature}' + sig_str = f"{fn_name}{signature}" arg_strs = [] for k, v in signature.parameters.items(): - maybe_type_annotation = f': {self._annotation_type_to_stable_str(v.annotation, sig_str)}'\ - if v.annotation is not inspect.Signature.empty else '' + maybe_type_annotation = ( + f": {self._annotation_type_to_stable_str(v.annotation, sig_str)}" + if v.annotation is not inspect.Signature.empty + else "" + ) def default_val_str(val): if isinstance(val, (tuple, list)): - str_pieces = ['(' if isinstance(val, tuple) else '['] - str_pieces.append(', '.join(default_val_str(v) for v in val)) + str_pieces = ["(" if isinstance(val, tuple) else "["] + str_pieces.append( + ", ".join(default_val_str(v) for v in val) + ) if isinstance(val, tuple) and len(str_pieces) == 2: - str_pieces.append(',') - str_pieces.append(')' if isinstance(val, tuple) else ']') - return ''.join(str_pieces) + str_pieces.append(",") + str_pieces.append(")" if isinstance(val, tuple) else "]") + return "".join(str_pieces) # Need to fix up some default value strings. # First case: modules. Default module `repr` contains the FS path of the module. # Don't leak that if isinstance(val, types.ModuleType): - return f'' + return f"" # Second case: callables. Callables (such as lambdas) encode their address in # their string repr. Don't do that if callable(val): - return f'' + return f"" return str(val) if v.default is not inspect.Signature.empty: - default_val_str = default_val_str(v.default) if not isinstance(v.default, str) else f"'{v.default}'" - maybe_default = f' = {default_val_str}' + default_val_str = ( + default_val_str(v.default) + if not isinstance(v.default, str) + else f"'{v.default}'" + ) + maybe_default = f" = {default_val_str}" else: - maybe_default = '' - maybe_stars = '' + maybe_default = "" + maybe_stars = "" if v.kind == inspect.Parameter.VAR_POSITIONAL: - maybe_stars = '*' + maybe_stars = "*" elif v.kind == inspect.Parameter.VAR_KEYWORD: - maybe_stars = '**' - arg_strs.append(f'{maybe_stars}{k}{maybe_type_annotation}{maybe_default}') + maybe_stars = "**" + arg_strs.append( + f"{maybe_stars}{k}{maybe_type_annotation}{maybe_default}" + ) - return_annot = f' -> {self._annotation_type_to_stable_str(signature.return_annotation, sig_str)}'\ - if signature.return_annotation is not inspect.Signature.empty else '' + return_annot = ( + f" -> {self._annotation_type_to_stable_str(signature.return_annotation, sig_str)}" + if signature.return_annotation is not inspect.Signature.empty + else "" + ) return f'{fn_name}({", ".join(arg_strs)}){return_annot}' def _annotation_type_to_stable_str(self, t, sig_str): if t is inspect.Signature.empty: - return '' + return "" # Forward ref if isinstance(t, str): return f"'{t}'" - if hasattr(typing, 'ForwardRef') and isinstance(t, typing.ForwardRef): + if hasattr(typing, "ForwardRef") and isinstance(t, typing.ForwardRef): return t.__forward_arg__ - if hasattr(typing, '_ForwardRef') and isinstance(t, typing._ForwardRef): + if hasattr(typing, "_ForwardRef") and isinstance(t, typing._ForwardRef): return t.__forward_arg__ trivial_mappings = { - str : 'str', - int : 'int', - float: 'float', - bool: 'bool', - torch.dtype: 'torch.dtype', - torch.Tensor: 'torch.Tensor', - torch.device: 'torch.device', - torch.memory_format: 'torch.memory_format', - slice: 'slice', - torch.nn.Module: 'torch.nn.modules.module.Module', - pippy.fx.Graph : 'pippy.fx.graph.Graph', - pippy.fx.Node : 'pippy.fx.node.Node', - pippy.fx.Proxy : 'pippy.fx.proxy.Proxy', - pippy.fx.node.Target : 'pippy.fx.node.Target', - pippy.fx.node.Argument : 'pippy.fx.node.Argument', - pippy.fx.graph.PythonCode : 'pippy.fx.graph.PythonCode', - pippy.fx.graph_module.GraphModule: 'pippy.fx.graph_module.GraphModule', - pippy.fx.subgraph_rewriter.Match: 'pippy.fx.subgraph_rewriter.Match', - Ellipsis : '...', - typing.Any: 'Any', - type(None): 'NoneType', - None: 'None', - typing.Iterator: 'Iterator', + str: "str", + int: "int", + float: "float", + bool: "bool", + torch.dtype: "torch.dtype", + torch.Tensor: "torch.Tensor", + torch.device: "torch.device", + torch.memory_format: "torch.memory_format", + slice: "slice", + torch.nn.Module: "torch.nn.modules.module.Module", + pippy.fx.Graph: "pippy.fx.graph.Graph", + pippy.fx.Node: "pippy.fx.node.Node", + pippy.fx.Proxy: "pippy.fx.proxy.Proxy", + pippy.fx.node.Target: "pippy.fx.node.Target", + pippy.fx.node.Argument: "pippy.fx.node.Argument", + pippy.fx.graph.PythonCode: "pippy.fx.graph.PythonCode", + pippy.fx.graph_module.GraphModule: "pippy.fx.graph_module.GraphModule", + pippy.fx.subgraph_rewriter.Match: "pippy.fx.subgraph_rewriter.Match", + Ellipsis: "...", + typing.Any: "Any", + type(None): "NoneType", + None: "None", + typing.Iterator: "Iterator", } mapping = trivial_mappings.get(t, None) @@ -3703,7 +4043,7 @@ def _annotation_type_to_stable_str(self, t, sig_str): return mapping # Handle types with contained types - contained = getattr(t, '__args__', None) or [] + contained = getattr(t, "__args__", None) or [] # Callables contain a bare List for arguments contained = t if isinstance(t, list) else contained @@ -3712,39 +4052,63 @@ def _annotation_type_to_stable_str(self, t, sig_str): if all(isinstance(ct, typing.TypeVar) for ct in contained): contained = [] - contained_type_annots = [self._annotation_type_to_stable_str(ct, sig_str) for ct in contained] - contained_type_str = f'[{", ".join(contained_type_annots)}]' if len(contained_type_annots) > 0 else '' - + contained_type_annots = [ + self._annotation_type_to_stable_str(ct, sig_str) for ct in contained + ] + contained_type_str = ( + f'[{", ".join(contained_type_annots)}]' + if len(contained_type_annots) > 0 + else "" + ) - origin = getattr(t, '__origin__', None) + origin = getattr(t, "__origin__", None) if origin is None: # Unbound types don't have `__origin__` in some Python versions, so fix that up here. - origin = t if t in {typing.Tuple, typing.Union, typing.Dict, typing.List, typing.Type, typing.Callable} else origin + origin = ( + t + if t + in { + typing.Tuple, + typing.Union, + typing.Dict, + typing.List, + typing.Type, + typing.Callable, + } + else origin + ) if origin in {tuple, typing.Tuple}: - return f'Tuple{contained_type_str}' + return f"Tuple{contained_type_str}" if origin in {typing.Union}: # Annoying hack to detect Optional - if len(contained) == 2 and (contained[0] is type(None)) ^ (contained[1] is type(None)): - not_none_param = contained[0] if contained[0] is not type(None) else contained[1] - return f'Optional[{self._annotation_type_to_stable_str(not_none_param, sig_str)}]' - return f'Union{contained_type_str}' + if len(contained) == 2 and (contained[0] is type(None)) ^ ( + contained[1] is type(None) + ): + not_none_param = ( + contained[0] + if contained[0] is not type(None) + else contained[1] + ) + return f"Optional[{self._annotation_type_to_stable_str(not_none_param, sig_str)}]" + return f"Union{contained_type_str}" if origin in {dict, typing.Dict}: - return f'Dict{contained_type_str}' + return f"Dict{contained_type_str}" if origin in {list, typing.List}: - return f'List{contained_type_str}' + return f"List{contained_type_str}" if origin in {type, typing.Type}: - return f'Type{contained_type_str}' + return f"Type{contained_type_str}" if isinstance(t, typing.Callable): if len(contained) > 0 and contained[0] is not Ellipsis: return f'Callable[[{", ".join(contained_type_annots[:-1])}], {contained_type_annots[-1]}]' else: - return f'Callable{contained_type_str}' - - raise RuntimeError(f'Unrecognized type {t} used in BC-compatible type signature {sig_str}.' - f'Please add support for this type and confirm with the ' - f'FX team that your signature change is valid.') + return f"Callable{contained_type_str}" + raise RuntimeError( + f"Unrecognized type {t} used in BC-compatible type signature {sig_str}." + f"Please add support for this type and confirm with the " + f"FX team that your signature change is valid." + ) def test_function_back_compat(self): """ @@ -3764,14 +4128,19 @@ def test_function_back_compat(self): signature_strs.sort() try: - self.assertExpected('\n'.join(signature_strs) + '\n', 'fx_backcompat_function_signatures') + self.assertExpected( + "\n".join(signature_strs) + "\n", + "fx_backcompat_function_signatures", + ) except AssertionError as e: - msg = f"{e}\n****** ERROR ******\nAn FX function that has been marked " \ - f"as backwards-compatible has experienced a signature change. See the " \ - f"above exception context for more information. If this change was " \ - f"unintended, please revert it. If it was intended, check with the FX " \ - f"team to ensure that the proper deprecation protocols have been followed " \ - f"and subsequently --accept the change." + msg = ( + f"{e}\n****** ERROR ******\nAn FX function that has been marked " + f"as backwards-compatible has experienced a signature change. See the " + f"above exception context for more information. If this change was " + f"unintended, please revert it. If it was intended, check with the FX " + f"team to ensure that the proper deprecation protocols have been followed " + f"and subsequently --accept the change." + ) raise AssertionError(msg) def test_class_member_back_compat(self): @@ -3784,34 +4153,42 @@ def test_class_member_back_compat(self): for obj in _BACK_COMPAT_OBJECTS: if isinstance(obj, type): - public_members = [name for name in obj.__dict__ if not name.startswith('_')] - class_method_strs.append(f'{torch.typename(obj)} {sorted(public_members)}') + public_members = [ + name for name in obj.__dict__ if not name.startswith("_") + ] + class_method_strs.append( + f"{torch.typename(obj)} {sorted(public_members)}" + ) class_method_strs.sort() try: - self.assertExpected('\n'.join(class_method_strs), 'fx_backcompat_class_members') + self.assertExpected( + "\n".join(class_method_strs), "fx_backcompat_class_members" + ) except AssertionError as e: - msg = f"{e}\n****** ERROR ******\nAn FX class that has been marked " \ - f"as backwards-compatible has experienced change in its public members. See the " \ - f"above exception context for more information. If this change was " \ - f"unintended, please revert it. If it was intended, check with the FX " \ - f"team to ensure that the proper deprecation protocols have been followed " \ - f"and subsequently --accept the change." + msg = ( + f"{e}\n****** ERROR ******\nAn FX class that has been marked " + f"as backwards-compatible has experienced change in its public members. See the " + f"above exception context for more information. If this change was " + f"unintended, please revert it. If it was intended, check with the FX " + f"team to ensure that the proper deprecation protocols have been followed " + f"and subsequently --accept the change." + ) raise AssertionError(msg) def test_public_api_surface(self): non_back_compat_objects = {} def check_symbols_have_bc_designation(m, prefix): - if not m.__name__.startswith('pippy.fx'): + if not m.__name__.startswith("pippy.fx"): return - if m.__name__.startswith('pippy.fx.experimental'): + if m.__name__.startswith("pippy.fx.experimental"): return for k, v in m.__dict__.items(): if v is m: continue - if k.startswith('_'): + if k.startswith("_"): continue if isinstance(v, types.ModuleType): check_symbols_have_bc_designation(v, prefix + [k]) @@ -3819,50 +4196,83 @@ def check_symbols_have_bc_designation(m, prefix): if v not in _MARKED_WITH_COMATIBLITY: non_back_compat_objects.setdefault(v) - check_symbols_have_bc_designation(pippy.fx, ['torch', 'fx']) - check_symbols_have_bc_designation(pippy.fx.passes, ['torch', 'fx', 'passes']) + check_symbols_have_bc_designation(pippy.fx, ["torch", "fx"]) + check_symbols_have_bc_designation( + pippy.fx.passes, ["torch", "fx", "passes"] + ) - non_back_compat_strs = [torch.typename(obj) for obj in non_back_compat_objects.keys()] + non_back_compat_strs = [ + torch.typename(obj) for obj in non_back_compat_objects.keys() + ] # Only want objects in pippy.fx non_back_compat_strs = [ - s for s in non_back_compat_strs if s.startswith('pippy.fx') and not s.startswith('pippy.fx.experimental')] + s + for s in non_back_compat_strs + if s.startswith("pippy.fx") + and not s.startswith("pippy.fx.experimental") + ] # Only want objects in public namespaces non_back_compat_strs = [ - s for s in non_back_compat_strs if all(not atom.startswith('_') for atom in s.split('.'))] + s + for s in non_back_compat_strs + if all(not atom.startswith("_") for atom in s.split(".")) + ] non_back_compat_strs.sort() if len(non_back_compat_strs) != 0: - raise AssertionError(f"Public FX API(s) {non_back_compat_strs} introduced but not given a " - f"backwards-compatibility classification! Please decorate these " - f"API(s) with `@pippy.fx._compatibility.compatibility` to specify " - f"BC guarantees.") + raise AssertionError( + f"Public FX API(s) {non_back_compat_strs} introduced but not given a " + f"backwards-compatibility classification! Please decorate these " + f"API(s) with `@pippy.fx._compatibility.compatibility` to specify " + f"BC guarantees." + ) + class TestFunctionalTracing(JitTestCase): def setUp(self): super().setUp() # Checking for mutable operations whil tracing is feature flagged # Enable it in testing but not by default - self.orig_tracer_mutable_flag = pippy.fx.proxy.TracerBase.check_mutable_operations + self.orig_tracer_mutable_flag = ( + pippy.fx.proxy.TracerBase.check_mutable_operations + ) pippy.fx.proxy.TracerBase.check_mutable_operations = True def tearDown(self): super().tearDown() - pippy.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag + pippy.fx.proxy.TracerBase.check_mutable_operations = ( + self.orig_tracer_mutable_flag + ) - IGNORE_FUNCS = ("has_torch_function", "has_torch_function_unary", - "has_torch_function_variadic", "handle_torch_function", - "boolean_dispatch") - TO_PATCH = {"has_torch_function": None, - "has_torch_function_unary": None, - "has_torch_function_variadic": None} + IGNORE_FUNCS = ( + "has_torch_function", + "has_torch_function_unary", + "has_torch_function_variadic", + "handle_torch_function", + "boolean_dispatch", + ) + TO_PATCH = { + "has_torch_function": None, + "has_torch_function_unary": None, + "has_torch_function_variadic": None, + } BUILT_IN_FUNC = (AssertionError, "") PROXY_ITERABLE = (TypeError, r"argument of type 'Proxy' is not iterable") PROXY_ITERATED = (TraceError, r"Proxy object cannot be iterated") - LEN_ERROR = (RuntimeError, r"'len' is not supported in symbolic tracing by default") + LEN_ERROR = ( + RuntimeError, + r"'len' is not supported in symbolic tracing by default", + ) ARG_TYPE_MISMATCH = (TypeError, r", not Proxy$") - CONTROL_FLOW = (TraceError, r"symbolically traced variables cannot be used as inputs to control flow") - INTERPOLATE_ARGS_CONFLICT = (ValueError, r"only one of size or scale_factor should be defined") + CONTROL_FLOW = ( + TraceError, + r"symbolically traced variables cannot be used as inputs to control flow", + ) + INTERPOLATE_ARGS_CONFLICT = ( + ValueError, + r"only one of size or scale_factor should be defined", + ) MUTABLE = (RuntimeError, r"Tried to trace mutable operation") UNTRACEABLE_FUNCTIONALS = { @@ -3902,13 +4312,11 @@ def tearDown(self): "softplus": BUILT_IN_FUNC, "softshrink": BUILT_IN_FUNC, "threshold_": BUILT_IN_FUNC, - "adaptive_avg_pool2d": LEN_ERROR, "adaptive_avg_pool3d": LEN_ERROR, "adaptive_max_pool2d_with_indices": LEN_ERROR, "adaptive_max_pool3d_with_indices": LEN_ERROR, "instance_norm": CONTROL_FLOW, - "adaptive_max_pool1d": PROXY_ITERABLE, "adaptive_max_pool2d": PROXY_ITERABLE, "adaptive_max_pool3d": PROXY_ITERABLE, @@ -3917,19 +4325,16 @@ def tearDown(self): "max_pool1d": PROXY_ITERABLE, "max_pool2d": PROXY_ITERABLE, "max_pool3d": PROXY_ITERABLE, - "group_norm": PROXY_ITERATED, "lp_pool2d": PROXY_ITERATED, "max_unpool1d": PROXY_ITERATED, "max_unpool2d": PROXY_ITERATED, "max_unpool3d": PROXY_ITERATED, - "adaptive_max_pool1d_with_indices": ARG_TYPE_MISMATCH, "fractional_max_pool2d_with_indices": ARG_TYPE_MISMATCH, "fractional_max_pool3d_with_indices": ARG_TYPE_MISMATCH, "layer_norm": ARG_TYPE_MISMATCH, "lp_pool1d": ARG_TYPE_MISMATCH, - "affine_grid": CONTROL_FLOW, "alpha_dropout": CONTROL_FLOW, "batch_norm": CONTROL_FLOW, @@ -3986,7 +4391,6 @@ def tearDown(self): "triplet_margin_with_distance_loss": CONTROL_FLOW, "unfold": CONTROL_FLOW, "upsample": CONTROL_FLOW, - "upsample_bilinear": INTERPOLATE_ARGS_CONFLICT, "upsample_nearest": INTERPOLATE_ARGS_CONFLICT, } @@ -4021,8 +4425,7 @@ def tearDown(self): "max_pool1d": PROXY_ITERATED, "max_pool2d": PROXY_ITERATED, "max_pool3d": PROXY_ITERATED, - - "group_norm": LEN_ERROR + "group_norm": LEN_ERROR, } @classmethod @@ -4032,7 +4435,7 @@ def _get_functional(cls): if not f.islower(): continue # Ignore internal functions - if f.startswith('_'): + if f.startswith("_"): continue # Ignore supporting functions if f in cls.IGNORE_FUNCS: @@ -4046,7 +4449,9 @@ def _get_functional(cls): sig = inspect.signature(fn) has_tensor_arg = False for arg, param in sig.parameters.items(): - if isinstance(param.annotation, type) and issubclass(param.annotation, torch.Tensor): + if isinstance(param.annotation, type) and issubclass( + param.annotation, torch.Tensor + ): has_tensor_arg = True if not has_tensor_arg: continue @@ -4058,10 +4463,12 @@ def _get_functional(cls): @classmethod def generate_test_func(cls, func_name, fn): - def functional_test(self): - if func_name in self.UNTRACEABLE_FUNCTIONALS_PY38 and \ - sys.version_info >= (3, 8) and sys.version_info < (3, 11): + if ( + func_name in self.UNTRACEABLE_FUNCTIONALS_PY38 + and sys.version_info >= (3, 8) + and sys.version_info < (3, 11) + ): exc, err = self.UNTRACEABLE_FUNCTIONALS_PY38[func_name] with self.assertRaisesRegex(exc, err): symbolic_trace(fn) @@ -4071,6 +4478,7 @@ def functional_test(self): symbolic_trace(fn) else: symbolic_trace(fn) + return functional_test @classmethod @@ -4083,7 +4491,6 @@ def generate_tests(cls): @classmethod def setUpClass(cls): - def no(*args, **kwargs): return False @@ -4096,27 +4503,33 @@ def tearDownClass(cls): for name in cls.TO_PATCH.keys(): setattr(torch.nn.functional, name, cls.TO_PATCH[name]) + TestFunctionalTracing.generate_tests() instantiate_device_type_tests(TestOperatorSignatures, globals()) + @skipIfNoTorchVision @skipIfSlowGradcheckEnv class TestVisionTracing(JitTestCase): def setUp(self): # Checking for mutable operations while tracing is feature flagged # Enable it in testing but not by default - self.orig_tracer_mutable_flag = pippy.fx.proxy.TracerBase.check_mutable_operations + self.orig_tracer_mutable_flag = ( + pippy.fx.proxy.TracerBase.check_mutable_operations + ) pippy.fx.proxy.TracerBase.check_mutable_operations = True def tearDown(self): - pippy.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag + pippy.fx.proxy.TracerBase.check_mutable_operations = ( + self.orig_tracer_mutable_flag + ) PROXY_ITERATED = (TraceError, r"Proxy object cannot be iterated") INCONSISTENT_TYPE = ( RuntimeError, - r"Return value was annotated as having type __torch__.torchvision.models[.\w]+ but is actually of type Tensor" + r"Return value was annotated as having type __torch__.torchvision.models[.\w]+ but is actually of type Tensor", ) UNTRACEABLE_MODELS = { @@ -4164,7 +4577,7 @@ def run_test(self): graph = symbolic_trace(model) else: out_transform = self.output_transform.get(name, lambda x: x) - graph : pippy.fx.GraphModule = symbolic_trace(model) + graph: pippy.fx.GraphModule = symbolic_trace(model) a = out_transform(model(x)) b = out_transform(graph(x)) self.assertEqual(a, b) @@ -4183,16 +4596,22 @@ def run_test(self): @classmethod def generate_classification_tests(cls): for k in torchvision_models.list_models(module=torchvision_models): - test_name = 'test_torchvision_models_' + k - x = torch.rand(1, 3, 299, 299) if k in ['inception_v3'] else torch.rand(1, 3, 224, 224) + test_name = "test_torchvision_models_" + k + x = ( + torch.rand(1, 3, 299, 299) + if k in ["inception_v3"] + else torch.rand(1, 3, 224, 224) + ) kwargs = dict(num_classes=50) model_test = cls.generate_test_fn(k, x, kwargs) setattr(cls, test_name, model_test) @classmethod def generate_segmentation_tests(cls): - for k in torchvision_models.list_models(module=torchvision_models.segmentation): - test_name = 'test_torchvision_models_segmentation_' + k + for k in torchvision_models.list_models( + module=torchvision_models.segmentation + ): + test_name = "test_torchvision_models_segmentation_" + k x = torch.rand(1, 3, 32, 32) kwargs = dict(num_classes=10, pretrained_backbone=False) model_test = cls.generate_test_fn(k, x, kwargs) @@ -4200,8 +4619,10 @@ def generate_segmentation_tests(cls): @classmethod def generate_detection_tests(cls): - for k in torchvision_models.list_models(module=torchvision_models.detection): - test_name = 'test_torchvision_models_detection_' + k + for k in torchvision_models.list_models( + module=torchvision_models.detection + ): + test_name = "test_torchvision_models_detection_" + k x = [torch.rand(3, 300, 300)] kwargs = dict(num_classes=10, pretrained_backbone=False) model_test = cls.generate_test_fn(k, x, kwargs) @@ -4209,8 +4630,10 @@ def generate_detection_tests(cls): @classmethod def generate_video_tests(cls): - for k in torchvision_models.list_models(module=torchvision_models.video): - test_name = 'test_torchvision_models_video_' + k + for k in torchvision_models.list_models( + module=torchvision_models.video + ): + test_name = "test_torchvision_models_video_" + k x = ( torch.rand(1, 3, 4, 112, 112) if k not in {"mvit_v1_b", "mvit_v2_s", "s3d"} @@ -4227,8 +4650,9 @@ def generate_tests(cls): cls.generate_segmentation_tests() cls.generate_video_tests() + if HAS_TORCHVISION: TestVisionTracing.generate_tests() -if __name__ == '__main__': +if __name__ == "__main__": run_tests() diff --git a/test/test_fx_experimental.py b/test/test_fx_experimental.py index 46de1f220..0ba3f2b8e 100644 --- a/test/test_fx_experimental.py +++ b/test/test_fx_experimental.py @@ -8,43 +8,44 @@ import sys import tempfile import unittest -from typing import Callable, Dict, Union, List, Optional from types import BuiltinFunctionType +from typing import Callable, Dict, List, Optional, Union -import torch +import pippy.fx.experimental.meta_tracer import pippy.fx.experimental.optimization as optimization + +import torch from pippy.fx._symbolic_trace import symbolic_trace from pippy.fx.experimental import merge_matmul from pippy.fx.experimental.accelerator_partitioner import Partitioner -from pippy.fx.experimental.normalize import NormalizeOperators, NormalizeArgs -from pippy.fx.passes import graph_manipulation -from pippy.fx.passes.param_fetch import lift_lowering_attrs_to_nodes +from pippy.fx.experimental.normalize import NormalizeArgs, NormalizeOperators from pippy.fx.experimental.partitioner_utils import ( - NodeLatency, - get_partition_to_latency_mapping, - get_latency_of_partitioned_graph, Device, + get_latency_of_partitioned_graph, + get_partition_to_latency_mapping, + NodeLatency, PartitionerConfig, PartitionMode, ) from pippy.fx.experimental.rewriter import RewritingTracer from pippy.fx.experimental.schema_type_annotation import AnnotateTypesWithSchema -import pippy.fx.experimental.meta_tracer from pippy.fx.graph_module import GraphModule from pippy.fx.node import Node from pippy.fx.operator_schemas import ( _torchscript_type_to_python_type, + create_type_hint, normalize_function, normalize_module, type_matches, - create_type_hint, ) +from pippy.fx.passes import graph_manipulation +from pippy.fx.passes.param_fetch import lift_lowering_attrs_to_nodes from pippy.fx.passes.shape_prop import ShapeProp from pippy.fx.passes.split_module import split_module from torch.testing._internal.common_device_type import ( - ops, - onlyCPU, instantiate_device_type_tests, + onlyCPU, + ops, ) from torch.testing._internal.common_methods_invocations import op_db from torch.testing._internal.common_nn import module_tests, new_module_tests @@ -60,12 +61,16 @@ HAS_TORCHVISION = False skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision") skipIfNoMkldnn = unittest.skipIf( - not (torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available()), + not ( + torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available() + ), "no MKLDNN", ) -def symbolic_trace_with_rewrite(root: Union[torch.nn.Module, Callable]) -> GraphModule: +def symbolic_trace_with_rewrite( + root: Union[torch.nn.Module, Callable] +) -> GraphModule: return GraphModule( root if isinstance(root, torch.nn.Module) else torch.nn.Module(), RewritingTracer().trace(root), @@ -108,7 +113,9 @@ def forward(self, a, b): graph_manipulation.get_size_of_all_nodes(traced, [a, b]) partitioner = Partitioner() devices = [Device("dev_0", 4, 0), Device("dev_1", 4, 1)] - partitioner_config = PartitionerConfig(devices, PartitionMode.size_based) + partitioner_config = PartitionerConfig( + devices, PartitionMode.size_based + ) catch_runtime_error = False try: ret = partitioner.partition_graph(traced, m, partitioner_config) @@ -139,7 +146,9 @@ def forward(self, a): Device("dev_3", 40, 0), Device("dev_4", 40, 0), ] - partitioner_config = PartitionerConfig(devices, PartitionMode.size_based) + partitioner_config = PartitionerConfig( + devices, PartitionMode.size_based + ) catch_runtime_error = False try: ret = partitioner.partition_graph(traced, m, partitioner_config) @@ -197,7 +206,9 @@ def forward(self, a, b): Device("dev_1", 125, 1), Device("dev_2", 125, 2), ] - partitioner_config = PartitionerConfig(devices, PartitionMode.size_based) + partitioner_config = PartitionerConfig( + devices, PartitionMode.size_based + ) ret = partitioner.partition_graph(traced, m, partitioner_config) module_with_submodules = ret.module_with_submodules dag = ret.dag @@ -225,7 +236,9 @@ def forward(self, a): graph_manipulation.get_size_of_all_nodes(traced, [a]) partitioner = Partitioner() devices = [Device("dev_0", 120, 0), Device("dev_1", 160, 1)] - partitioner_config = PartitionerConfig(devices, PartitionMode.size_based) + partitioner_config = PartitionerConfig( + devices, PartitionMode.size_based + ) ret = partitioner.partition_graph(traced, m, partitioner_config) module_with_submodules = ret.module_with_submodules dag = ret.dag @@ -238,7 +251,9 @@ def forward(self, a): def test_sparse_nn_partition(self): class MyRecommendationModule(torch.nn.Module): - def create_mlp(self, num_of_layers: int, input_size: int, output_size: int): + def create_mlp( + self, num_of_layers: int, input_size: int, output_size: int + ): layers = torch.nn.ModuleList() for _ in range(num_of_layers): ll = torch.nn.Linear(input_size, output_size) @@ -256,7 +271,9 @@ def __init__(self): el = torch.nn.EmbeddingBag(500000, 4, mode="sum", sparse=True) self.embedding_layers.append(el) for i in range(3): - el = torch.nn.EmbeddingBag(1000000, 4, mode="sum", sparse=True) + el = torch.nn.EmbeddingBag( + 1000000, 4, mode="sum", sparse=True + ) self.embedding_layers.append(el) el = torch.nn.EmbeddingBag(500000, 4, mode="sum", sparse=True) self.embedding_layers.append(el) @@ -273,7 +290,9 @@ def forward(self, a, b, offset): y.append(self.embedding_layers[i](c[i], offset)) else: y.append( - self.embedding_layers[i](torch.randint(10, (8,)), offset) + self.embedding_layers[i]( + torch.randint(10, (8,)), offset + ) ) z = torch.cat([x] + y, dim=1) p = self.top_layers(z) @@ -295,7 +314,9 @@ def forward(self, a, b, offset): ret = partitioner.partition_graph(traced, m, partitioner_config) module_with_submodules = ret.module_with_submodules dag = ret.dag - self.assertEqual(traced(a, b, offset), module_with_submodules(a, b, offset)) + self.assertEqual( + traced(a, b, offset), module_with_submodules(a, b, offset) + ) assert len(module_with_submodules.graph.nodes) == 24 def test_partition_latency(self): @@ -319,13 +340,18 @@ def get_node_to_latency_mapping(fx_module: GraphModule): node_to_latency_mapping: Dict[Node, NodeLatency] = {} for node in fx_module.graph.nodes: if node.op not in {"output", "placeholder", "get_attr"}: - if node.size_bytes.total_size == node.size_bytes.output_size: + if ( + node.size_bytes.total_size + == node.size_bytes.output_size + ): node_to_latency_mapping[node] = NodeLatency( - node.size_bytes.total_size, 2.0 * node.size_bytes.total_size + node.size_bytes.total_size, + 2.0 * node.size_bytes.total_size, ) else: node_to_latency_mapping[node] = NodeLatency( - node.size_bytes.total_size, node.size_bytes.output_size + node.size_bytes.total_size, + node.size_bytes.output_size, ) return node_to_latency_mapping @@ -351,7 +377,9 @@ def get_node_to_latency_mapping(fx_module: GraphModule): assert partition_to_latency_mapping[p] == (16.0, 32.0, 32.0) transfer_rate_bytes_per_sec = 2 critical_path_latency_sec = get_latency_of_partitioned_graph( - partitions, partition_to_latency_mapping, transfer_rate_bytes_per_sec + partitions, + partition_to_latency_mapping, + transfer_rate_bytes_per_sec, ) assert critical_path_latency_sec == 208.0 @@ -374,13 +402,17 @@ def get_node_to_latency_mapping(fx_module: GraphModule): node_to_latency_mapping: Dict[Node, NodeLatency] = {} for node in fx_module.graph.nodes: if node.op not in {"output", "placeholder", "get_attr"}: - if node.size_bytes.total_size == node.size_bytes.output_size: + if ( + node.size_bytes.total_size + == node.size_bytes.output_size + ): node_to_latency_mapping[node] = NodeLatency( node.size_bytes.total_size, 1 ) else: node_to_latency_mapping[node] = NodeLatency( - node.size_bytes.total_size, node.size_bytes.output_size + node.size_bytes.total_size, + node.size_bytes.output_size, ) return node_to_latency_mapping @@ -524,7 +556,9 @@ def test_conv_bn_fusion(self): fused = optimization.fuse(traced) self.assertTrue( - all(not isinstance(m, torch.nn.BatchNorm2d) for m in fused.modules()) + all( + not isinstance(m, torch.nn.BatchNorm2d) for m in fused.modules() + ) ) N, C, H, W = 20, 3, 224, 224 @@ -537,7 +571,13 @@ class M(torch.nn.Module): def __init__(self): super(M, self).__init__() self.conv = torch.nn.Conv2d(32, 64, 3, stride=2) - self.bn = torch.nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False) + self.bn = torch.nn.BatchNorm2d( + 64, + eps=1e-05, + momentum=0.1, + affine=True, + track_running_stats=False, + ) def forward(self, x): x = self.conv(x) @@ -588,33 +628,40 @@ def test_meta_tracer(self): class MetaTracerTestModule(torch.nn.Module): def __init__(self): super().__init__() - self.emb = torch.nn.Embedding(num_embeddings=42, embedding_dim=16) + self.emb = torch.nn.Embedding( + num_embeddings=42, embedding_dim=16 + ) self.layernorm = torch.nn.LayerNorm(16) def forward(self, x): emb = self.emb(x) - emb = emb + torch.arange(emb.shape[-1], dtype=torch.float, device=emb.device) + emb = emb + torch.arange( + emb.shape[-1], dtype=torch.float, device=emb.device + ) lol = self.layernorm(emb) - return torch.relu(lol) if lol.shape[0] < 30 else torch.sigmoid(lol) + return ( + torch.relu(lol) if lol.shape[0] < 30 else torch.sigmoid(lol) + ) mttm = MetaTracerTestModule() for BS in [15, 35]: x = torch.zeros(BS, dtype=torch.long).random_(42) - meta_args = {'x' : x.to(device='meta')} - gm = pippy.fx.experimental.meta_tracer.symbolic_trace(mttm, meta_args=meta_args) + meta_args = {"x": x.to(device="meta")} + gm = pippy.fx.experimental.meta_tracer.symbolic_trace( + mttm, meta_args=meta_args + ) torch.testing.assert_close(gm(x), mttm(x)) # Test serialization/deserialization with tempfile.TemporaryDirectory() as tmp_dir: - with open(f'{tmp_dir}/meta_module.pkl', 'wb') as f: + with open(f"{tmp_dir}/meta_module.pkl", "wb") as f: pickle.dump(gm, f) - with open(f'{tmp_dir}/meta_module.pkl', 'rb') as f: + with open(f"{tmp_dir}/meta_module.pkl", "rb") as f: loaded = pickle.load(f) torch.testing.assert_close(loaded(x), mttm(x)) - def test_call_to_assert_with_msg(self): class M(torch.nn.Module): def forward(self, a, b): @@ -765,7 +812,7 @@ def mod_partition(node: Node): def test_split_module_kwargs_expansion(self): class ModuleWithKwargsExpansion(torch.nn.Module): def forward(self, x, **kwargs): - return x + kwargs['foo'] + return x + kwargs["foo"] mod = ModuleWithKwargsExpansion() traced = pippy.fx.symbolic_trace(mod) @@ -810,7 +857,7 @@ def forward(self, x, targets=None): return x mtt = ModelToTrace() - traced = pippy.fx.symbolic_trace(mtt, concrete_args={'targets': None}) + traced = pippy.fx.symbolic_trace(mtt, concrete_args={"targets": None}) split = split_module(traced, mtt, lambda node: 0) @@ -971,7 +1018,11 @@ def forward(self, {params}): # These Modules have an RNG in their forward, so testing # correctness by comparing outputs is not correct. Skip that # check for these - stochastic_modules = {"FractionalMaxPool2d", "FractionalMaxPool3d", "RReLU"} + stochastic_modules = { + "FractionalMaxPool2d", + "FractionalMaxPool3d", + "RReLU", + } if mod.__class__.__name__ not in stochastic_modules: self.assertEqual(traced(*inputs), mod(*inputs)) @@ -1033,7 +1084,9 @@ def test_annotate_returns_with_schema(self): m = resnet18() traced_modules = symbolic_trace(m) - traced_modules_annotated = AnnotateTypesWithSchema(traced_modules).transform() + traced_modules_annotated = AnnotateTypesWithSchema( + traced_modules + ).transform() for node in traced_modules_annotated.graph.nodes: if node.type is None: check = (node.op, node.target) @@ -1045,7 +1098,7 @@ def test_annotate_returns_with_schema(self): ("call_function", operator.add), ("call_function", torch.flatten), ("output", "output"), - } + }, ) # Smoke test torchscript compilation since now we're emitting type annotations @@ -1060,7 +1113,9 @@ def is_leaf_module( leaves = set([torch.nn.BatchNorm2d]) return type(m) in leaves - traced_functionals = pippy.fx.GraphModule(m, FunctionalTracer().trace(m)) + traced_functionals = pippy.fx.GraphModule( + m, FunctionalTracer().trace(m) + ) traced_functionals_annotated = AnnotateTypesWithSchema( traced_functionals @@ -1134,20 +1189,18 @@ def forward(self, x): part_idx = 0 - def split_callback(n : pippy.fx.Node): + def split_callback(n: pippy.fx.Node): nonlocal part_idx - if (n.op, n.target) == ('call_module', 'lin'): + if (n.op, n.target) == ("call_module", "lin"): part_idx += 1 return part_idx # split module in module with submodules - qualname_map : Dict[str, str] = {} + qualname_map: Dict[str, str] = {} module_with_submodules = split_module( my_module_traced, my_module, split_callback, qualname_map ) - expected_qualname_map = { - 'submod_1.lin': 'lin', 'submod_2.lin': 'lin' - } + expected_qualname_map = {"submod_1.lin": "lin", "submod_2.lin": "lin"} self.assertEqual(qualname_map, expected_qualname_map) def test_traceable_function_with_nonstandard_name(self): @@ -1168,7 +1221,9 @@ def __init__(self): self.register_buffer("attr3", torch.ones(2, dtype=torch.int32)) def forward(self, x): - return self.linear(self.seq(self.W + self.attr + self.attr2 + self.attr3 + x)) + return self.linear( + self.seq(self.W + self.attr + self.attr2 + self.attr3 + x) + ) mod = symbolic_trace(Test()) module_name = "Foo" @@ -1242,6 +1297,7 @@ def test_merge_matmuls(self): A collection of test cases for pippy.fx.experimental.merge_matmul, a graph transformation that merges matrix multiplication operations. """ + # Utility function for counting matmuls for test assertions. def _count_matmuls(mod): gm = pippy.fx.symbolic_trace(mod) @@ -1371,22 +1427,40 @@ def test_type_matches(self): (List[int], type(5)), (List[int], create_type_hint([int, int])), (List[int], create_type_hint((int, int))), - (List[torch.Tensor], create_type_hint([torch.Tensor, torch.Tensor])), + ( + List[torch.Tensor], + create_type_hint([torch.Tensor, torch.Tensor]), + ), ( List[torch.Tensor], create_type_hint([torch.nn.Parameter, torch.nn.Parameter]), ), (torch.Tensor, torch.nn.Parameter), - (List[torch.Tensor], create_type_hint([torch.nn.Parameter, torch.Tensor])), - (List[torch.Tensor], create_type_hint([torch.Tensor, torch.nn.Parameter])), - (List[torch.Tensor], create_type_hint((torch.Tensor, torch.Tensor))), + ( + List[torch.Tensor], + create_type_hint([torch.nn.Parameter, torch.Tensor]), + ), + ( + List[torch.Tensor], + create_type_hint([torch.Tensor, torch.nn.Parameter]), + ), + ( + List[torch.Tensor], + create_type_hint((torch.Tensor, torch.Tensor)), + ), ( List[torch.Tensor], create_type_hint((torch.nn.Parameter, torch.nn.Parameter)), ), (torch.Tensor, torch.nn.Parameter), - (List[torch.Tensor], create_type_hint((torch.nn.Parameter, torch.Tensor))), - (List[torch.Tensor], create_type_hint((torch.Tensor, torch.nn.Parameter))), + ( + List[torch.Tensor], + create_type_hint((torch.nn.Parameter, torch.Tensor)), + ), + ( + List[torch.Tensor], + create_type_hint((torch.Tensor, torch.nn.Parameter)), + ), (Optional[List[torch.Tensor]], List[torch.Tensor]), (Optional[List[int]], List[int]), ] @@ -1425,7 +1499,7 @@ def __init__(self): def forward(self, x): return self.model(x) + self.model2(x) - N, C, H, W, = ( + (N, C, H, W) = ( 1, 3, 224, @@ -1458,7 +1532,7 @@ def test_optimize_for_inference_cpu_torchvision(self): with torch.no_grad(): for model_type in models: model = model_type() - C, H, W, = ( + (C, H, W) = ( 3, 224, 224, @@ -1467,7 +1541,9 @@ def test_optimize_for_inference_cpu_torchvision(self): model(inp) model.eval() inp = torch.randn(1, C, H, W) - heuristic = optimization.gen_mkl_autotuner(inp, iters=0, warmup=0) + heuristic = optimization.gen_mkl_autotuner( + inp, iters=0, warmup=0 + ) optimized_model = optimization.optimize_for_inference(model) orig_out = model(inp) @@ -1480,7 +1556,14 @@ class TestNormalizeOperators(JitTestCase): @ops(op_db, allowed_dtypes=(torch.float,)) def test_normalize_operator_exhaustive(self, device, dtype, op): # These ops currently don't trace in FX for various reasons (i.e. they take a list of tensors) - fx_fail = {"cat", "stack", "hstack", "vstack", "dstack", "linalg.multi_dot"} + fx_fail = { + "cat", + "stack", + "hstack", + "vstack", + "dstack", + "linalg.multi_dot", + } sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False) if isinstance(op.op, torch._ops.OpOverload): self.skipTest("normalize operator doesn't work on torch.ops") @@ -1530,7 +1613,9 @@ def jit_infer_type(v): - otherwise, file an issue """ ) - test_out = op.op(*norm_args_and_kwargs.args, **norm_args_and_kwargs.kwargs) + test_out = op.op( + *norm_args_and_kwargs.args, **norm_args_and_kwargs.kwargs + ) self.assertEqual(test_out, ref_out) # Test normalized_arguments as part of FX @@ -1610,13 +1695,22 @@ def test_normalize_quantized_eb(self): self.assertEqual(norm_args_and_kwargs.args, tuple()) def test_normalize_args_op_overload(self): - for target in [torch.ops.aten.resize_as_.default, torch.ops.aten.resize_as_]: + for target in [ + torch.ops.aten.resize_as_.default, + torch.ops.aten.resize_as_, + ]: inp1 = torch.rand([1]) inp2 = torch.rand([4]) - args, kwargs = normalize_function(target, (inp1,), {"the_template": inp2}, normalize_to_only_use_kwargs=True) + args, kwargs = normalize_function( + target, + (inp1,), + {"the_template": inp2}, + normalize_to_only_use_kwargs=True, + ) self.assertIs(kwargs["input"], inp1) self.assertIs(kwargs["the_template"], inp2) + instantiate_device_type_tests(TestNormalizeOperators, globals()) if __name__ == "__main__": diff --git a/test/test_ir.py b/test/test_ir.py index baebbc7d2..768dec27f 100644 --- a/test/test_ir.py +++ b/test/test_ir.py @@ -2,29 +2,29 @@ import copy import pickle import tempfile -import torch import unittest from typing import NamedTuple +import pippy.fx +import torch + from pippy.IR import ( + _null_coalesce_accumulate, + annotate_split_points, + MultiUseParameterConfig, Pipe, - PipeSequential, - TrivialLossWrapper, pipe_split, - MultiUseParameterConfig, - annotate_split_points, + PipeSequential, PipeSplitWrapper, - _null_coalesce_accumulate, + TrivialLossWrapper, ) from pippy.microbatch import ( - TensorChunkSpec, + merge_chunks, Replicate, split_args_kwargs_into_chunks, - merge_chunks, + TensorChunkSpec, ) -import pippy.fx - @pippy.fx.wrap def arange_wrapper(*args, **kwargs): From 1690ee0a5faf042f46d5670b4b7affcb86234f26 Mon Sep 17 00:00:00 2001 From: eddogola <64967909+eddogola@users.noreply.github.com> Date: Wed, 28 Jun 2023 16:42:13 -0700 Subject: [PATCH 18/96] combine index-file-saving and params-saving into one API (#823) ## Description bring together `_save_index` and `_save_params` into one function `save_checkpoint`. The entire module is used to write the index file, only in rank 0, while submodule parameters are saved to a file in each rank. The associated ckpt test is updated to reflect this change. ## Type of change Please delete options that are not relevant. - [x] New feature (non-breaking change which adds functionality) - [x] This change requires a documentation update ## Feature/Issue validation/testing - [x] Test LocalIndexMetadataTest ## Checklist: - [x] Have you added tests that prove your fix is effective or that this feature works? - [x] Has code been commented, particularly in hard-to-understand areas? - [x] Have you made corresponding changes to the documentation? --------- Co-authored-by: Eddy --- .gitignore | 1 - pippy/hf/_SaveModule.py | 21 +++++++++++++++++++-- test/local_test_ckpt_index_file.py | 16 ++++++++-------- 3 files changed, 27 insertions(+), 11 deletions(-) diff --git a/.gitignore b/.gitignore index 7dca738f4..1dededf6a 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,3 @@ -docs/build __pycache__ build pippy.egg-info diff --git a/pippy/hf/_SaveModule.py b/pippy/hf/_SaveModule.py index 594a1a3f5..36f2b8bc4 100644 --- a/pippy/hf/_SaveModule.py +++ b/pippy/hf/_SaveModule.py @@ -119,7 +119,7 @@ def _get_binary_filename(cur_idx: int) -> str: # type: ignore[valid-type] return f"pytorch_model-{idx}-of-{world_size}.bin" -def _save_checkpoint(submod: Pipe, checkpoint_dir: str) -> None: +def _save_params(submod: torch.nn.Module, checkpoint_dir: str) -> None: """ writes `module`'s parameters and buffers to disk. @@ -134,8 +134,25 @@ def _save_checkpoint(submod: Pipe, checkpoint_dir: str) -> None: ) torch.save( { - submod.remap_qualname(param_name): param + submod.remap_qualname(param_name): param # type: ignore for param_name, param in submod.state_dict().items() }, filepath, ) + + +def save_checkpoint(stage: Pipe, checkpoint_dir: str) -> None: + """ + Save the entire model's(`stage`) metadata in an index file and the `submod` + parameters in `checkpoint_dir` + + Args: + stage(`Pipe`): model pipeline graph + submod(`torch.nn.Module`): submod whose params are to be saved + checkpoin_dir(`str`): directory where to save the index file and params binaries + """ + # write index file in rank 0 + if dist.get_rank() == 0: + _save_index(stage, checkpoint_dir=checkpoint_dir) + + _save_params(stage.submod, checkpoint_dir) # type: ignore diff --git a/test/local_test_ckpt_index_file.py b/test/local_test_ckpt_index_file.py index 95539f4c0..b8f26a2e9 100644 --- a/test/local_test_ckpt_index_file.py +++ b/test/local_test_ckpt_index_file.py @@ -7,14 +7,16 @@ from typing import List import torch + import torch.distributed as dist import torch.optim as optim from pippy.compile import compile_stage -from pippy.hf._SaveModule import _save_checkpoint, _save_index +from pippy.hf._SaveModule import save_checkpoint from pippy.IR import pipe_split, TrivialLossWrapper from pippy.LoadModule import load_checkpoint + DEFAULT_FILENAME = "pytorch_model.bin.index.json" CKPT_DIR = "test_ckpts" WEIGHT_MAP = set( @@ -92,10 +94,13 @@ def run_worker(args: List[str | int]) -> None: else: stage() + # Take an optimization step + optimizer.step() + ref = deepcopy(stage.submod.state_dict()) + save_checkpoint(stage, CKPT_DIR) + # save index file in rank 0 if args.rank == 0: - _save_index(stage, checkpoint_dir=CKPT_DIR) - filepath = os.path.join(CKPT_DIR, DEFAULT_FILENAME) with open(filepath) as f: content = f.read() @@ -113,11 +118,6 @@ def run_worker(args: List[str | int]) -> None: for param in WEIGHT_MAP: assert param in data["weight_map"] - # Take an optimization step - optimizer.step() - ref = deepcopy(stage.submod.state_dict()) - _save_checkpoint(stage.submod, CKPT_DIR) - # second run # Zero gradients optimizer.zero_grad() From 62f84bb44257be6276ff7c639ddbc149210604fc Mon Sep 17 00:00:00 2001 From: eddogola <64967909+eddogola@users.noreply.github.com> Date: Wed, 28 Jun 2023 17:23:01 -0700 Subject: [PATCH 19/96] rename ckpt test (#825) ## Description Rename the previous ckpt index file test to ckpt test to reflect its covering of both index file and params saving. - [x] Has code been commented, particularly in hard-to-understand areas? - [x] Have you made corresponding changes to the documentation? Co-authored-by: Eddy --- .../{local_test_ckpt_index_file.py => local_test_checkpoint.py} | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename test/{local_test_ckpt_index_file.py => local_test_checkpoint.py} (99%) diff --git a/test/local_test_ckpt_index_file.py b/test/local_test_checkpoint.py similarity index 99% rename from test/local_test_ckpt_index_file.py rename to test/local_test_checkpoint.py index b8f26a2e9..8df259314 100644 --- a/test/local_test_ckpt_index_file.py +++ b/test/local_test_checkpoint.py @@ -193,7 +193,7 @@ def main(args: List[str | int] = None) -> None: main() -class LocalIndexMetadataTest(unittest.TestCase): +class LocalCheckpointTest(unittest.TestCase): def test_index_file(self): import random From c0186a26cbc7267fe8239219aadf8bf19b57af4a Mon Sep 17 00:00:00 2001 From: eddogola <64967909+eddogola@users.noreply.github.com> Date: Thu, 29 Jun 2023 00:44:11 -0700 Subject: [PATCH 20/96] save optim state dict (#826) ## Description Save optimizer's state into a diff binary file such that for each rank, there are two state bin files: submod params and optim state. Fixes #(issue) ## Type of change Please delete options that are not relevant. - [x] New feature (non-breaking change which adds functionality) - [x] This change requires a documentation update ## Feature/Issue validation/testing Will update checkpoint test to reflect change in a different PR. ## Checklist: - [x] Has code been commented, particularly in hard-to-understand areas? - [x] Have you made corresponding changes to the documentation? --------- Co-authored-by: Eddy --- pippy/hf/_SaveModule.py | 32 +++++++++++++++++++++++++++++--- 1 file changed, 29 insertions(+), 3 deletions(-) diff --git a/pippy/hf/_SaveModule.py b/pippy/hf/_SaveModule.py index 36f2b8bc4..94fb07082 100644 --- a/pippy/hf/_SaveModule.py +++ b/pippy/hf/_SaveModule.py @@ -103,7 +103,7 @@ def _save_index( logging.info(f"Saved index file to {filepath}") -def _get_binary_filename(cur_idx: int) -> str: # type: ignore[valid-type] +def _get_binary_filename(cur_idx: int, is_optim: bool = False) -> str: # type: ignore[valid-type] """ Gets filename for pytorch checkpoint binary based on current index and world size. @@ -116,7 +116,9 @@ def _get_binary_filename(cur_idx: int) -> str: # type: ignore[valid-type] idx = str(cur_idx + 1).zfill(5) world_size = str(dist.get_world_size()).zfill(5) - return f"pytorch_model-{idx}-of-{world_size}.bin" + state_type = "optim" if is_optim else "model" + + return f"pytorch_{state_type}-{idx}-of-{world_size}.bin" def _save_params(submod: torch.nn.Module, checkpoint_dir: str) -> None: @@ -141,7 +143,28 @@ def _save_params(submod: torch.nn.Module, checkpoint_dir: str) -> None: ) -def save_checkpoint(stage: Pipe, checkpoint_dir: str) -> None: +def _save_optim_state( + optimizer: torch.optim.Optimizer, checkpoint_dir: str +) -> None: + """ + saves `optimizer`'s state_dict to disk. + + Args: + optimizer(`torch.optim.Optimizer`): pytorch optimizer + checkpoint_dir(`str`): where to keep the checkpoint binaries + """ + if not os.path.exists(checkpoint_dir): + os.mkdir(checkpoint_dir) + filepath = os.path.join( + checkpoint_dir, _get_binary_filename(dist.get_rank(), is_optim=True) + ) + # save optimizer state directly + torch.save(optimizer.state_dict(), filepath) + + +def save_checkpoint( + stage: Pipe, checkpoint_dir: str, optimizer: torch.optim.Optimizer = None +) -> None: """ Save the entire model's(`stage`) metadata in an index file and the `submod` parameters in `checkpoint_dir` @@ -156,3 +179,6 @@ def save_checkpoint(stage: Pipe, checkpoint_dir: str) -> None: _save_index(stage, checkpoint_dir=checkpoint_dir) _save_params(stage.submod, checkpoint_dir) # type: ignore + # save optimizer state, if passed + if optimizer: + _save_optim_state(optimizer, checkpoint_dir) # type: ignore From c05496d058698de818a41d538127076df372da00 Mon Sep 17 00:00:00 2001 From: eddogola <64967909+eddogola@users.noreply.github.com> Date: Thu, 29 Jun 2023 23:08:49 -0700 Subject: [PATCH 21/96] test optimizer using torch.load (#827) ## Description Since we're also saving the optimizer's state dict (in a different binary file for each rank), we can't use load_checkpoint to check that it works. Here, I use torch.load to check that a certain reference optimizer state dict is similar/close to the saved optimizer's state dict, after a `step()` run. ## Type of change - [x] New feature (non-breaking change which adds functionality) - [x] This change requires a documentation update ## Checklist: - [x] Has code been commented, particularly in hard-to-understand areas? --------- Co-authored-by: Eddy --- test/local_test_checkpoint.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/test/local_test_checkpoint.py b/test/local_test_checkpoint.py index 8df259314..8b0dce6ac 100644 --- a/test/local_test_checkpoint.py +++ b/test/local_test_checkpoint.py @@ -12,7 +12,7 @@ import torch.optim as optim from pippy.compile import compile_stage -from pippy.hf._SaveModule import save_checkpoint +from pippy.hf._SaveModule import _get_binary_filename, save_checkpoint from pippy.IR import pipe_split, TrivialLossWrapper from pippy.LoadModule import load_checkpoint @@ -96,8 +96,9 @@ def run_worker(args: List[str | int]) -> None: # Take an optimization step optimizer.step() - ref = deepcopy(stage.submod.state_dict()) - save_checkpoint(stage, CKPT_DIR) + submod_ref = deepcopy(stage.submod.state_dict()) + optim_ref = deepcopy(optimizer.state_dict()) + save_checkpoint(stage, CKPT_DIR, optimizer) # save index file in rank 0 if args.rank == 0: @@ -139,8 +140,17 @@ def run_worker(args: List[str | int]) -> None: os.path.join(CKPT_DIR, "pytorch_model.bin.index.json"), args.device, ) + # load optim + optimizer.load_state_dict( + torch.load( + os.path.join( + CKPT_DIR, _get_binary_filename(dist.get_rank(), is_optim=True) + ) + ) + ) - torch.testing.assert_close(mod.state_dict(), ref) + torch.testing.assert_close(mod.state_dict(), submod_ref) + torch.testing.assert_close(optimizer.state_dict(), optim_ref) dist.barrier() print(f"Rank {args.rank} completes") From d6b24111cd30adc02bd08e0f38c600970edad80c Mon Sep 17 00:00:00 2001 From: eddogola <64967909+eddogola@users.noreply.github.com> Date: Fri, 30 Jun 2023 11:50:12 -0700 Subject: [PATCH 22/96] update docstrings (#829) ## Description update _SaveModule.py funcs docstrings to reflect changes to function signatures. ## Checklist: - [x] Has code been commented, particularly in hard-to-understand areas? - [x] Have you made corresponding changes to the documentation? Co-authored-by: Eddy --- pippy/hf/_SaveModule.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/pippy/hf/_SaveModule.py b/pippy/hf/_SaveModule.py index 94fb07082..7c440378f 100644 --- a/pippy/hf/_SaveModule.py +++ b/pippy/hf/_SaveModule.py @@ -21,12 +21,19 @@ def _get_param_size(param: torch.Tensor) -> int: + """ + Returns a tensor's size in bytes + + Args: + param(`torch.Tensor`): torch tensor + """ return param.numel() * DTYPE_SIZES[param.dtype] def _atomic_write(file_contents: str, target_file_path: str, mode="w") -> None: """ Atomically writes `file_contents` into `target_file_path`. + Args: file_contents (str): contents to write to file target_file_path (str): path to write to @@ -109,6 +116,8 @@ def _get_binary_filename(cur_idx: int, is_optim: bool = False) -> str: # type: Args: cur_idx (int): current device index + is_optim (bool): True if generating binary filename for optimizer, + False otherwise Returns: str: checkpoint filename @@ -163,7 +172,9 @@ def _save_optim_state( def save_checkpoint( - stage: Pipe, checkpoint_dir: str, optimizer: torch.optim.Optimizer = None + stage: Pipe, + checkpoint_dir: str = "checkpoints", + optimizer: torch.optim.Optimizer = None, ) -> None: """ Save the entire model's(`stage`) metadata in an index file and the `submod` @@ -171,8 +182,9 @@ def save_checkpoint( Args: stage(`Pipe`): model pipeline graph - submod(`torch.nn.Module`): submod whose params are to be saved - checkpoin_dir(`str`): directory where to save the index file and params binaries + checkpoint_dir(`str`): directory where to save the index file and params binaries + defaults to `checkpoints` + optimizer(`torch.optim.Optimizer`): optimizer whose state dict is to be saved """ # write index file in rank 0 if dist.get_rank() == 0: From d8670242f12b1c66626fad55dbed3e5ee17734ca Mon Sep 17 00:00:00 2001 From: Rohan Varma Date: Thu, 6 Jul 2023 15:31:57 -0700 Subject: [PATCH 23/96] Fix PiPPy README typos for inference (#834) ## Description Please read our [CONTRIBUTING.md](https://github.com/pytorch/PiPPy/blob/main/CONTRIBUTING.md) prior to creating your first pull request. Please include a summary of the feature or issue being fixed. Please also include relevant motivation and context. List any dependencies that are required for this change. Fixes #(issue) ## Type of change Please delete options that are not relevant. - [ ] Bug fix (non-breaking change which fixes an issue) - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) - [ ] New feature (non-breaking change which adds functionality) - [ ] This change requires a documentation update ## Feature/Issue validation/testing Please describe the Unit or Integration tests that you ran to verify your changes and relevant result summary. Provide instructions so it can be reproduced. Please also list any relevant details for your test configuration. - [ ] Test A Logs for Test A - [ ] Test B Logs for Test B ## Checklist: - [ ] Have you added tests that prove your fix is effective or that this feature works? - [ ] Has code been commented, particularly in hard-to-understand areas? - [ ] Have you made corresponding changes to the documentation? --- examples/inference/README.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/inference/README.md b/examples/inference/README.md index 9b1af668c..f4fbeea38 100644 --- a/examples/inference/README.md +++ b/examples/inference/README.md @@ -1,12 +1,12 @@ # PiPPy (Pipline Parallelism for PyTorch) Distributed Inference for Large Models PiPPy helps to run very large models for inference by splitting the model into mutliple stages running on multiple GPUs. -PiPPy make this easier by providing a auto split API that automates this process for user. +PiPPy make this easier by providing an auto split API that automates this process for user. ## How It Works -PiPPy splits your model into multiple stages, each stage loaded on one gpu then the input batch will be furhter divided into micro-batches and run through the splits from -rank0..rankN. Results are being returned to rank0 as its runing the PipelineDriver. Please read more on pipleines [here](https://github.com/pytorch/tau/blob/main/README.md) +PiPPy splits your model into multiple stages, each stage loaded on one gpu then the input batch will be further divided into micro-batches and run through the splits from +rank0..rankN. Results are returned to rank0 as rank 0 is running the PipelineDriver. Please read more on pipleines [here](https://github.com/pytorch/tau/blob/main/README.md) The flowchart below helps to visualize the process in high level as well. @@ -14,14 +14,14 @@ The flowchart below helps to visualize the process in high level as well. ## PiPPy Supports Arbitary Model Partitioning -Unlike most of the available solutions that they need to know the model architecture beforehand, PiPPy supports arbitary PyTorch models. +Unlike most of the available solutions that need to know the model architecture beforehand, PiPPy supports arbitary PyTorch models. * PiPPy supports both manual splitting and auto split. * Auto split uses `split_policy` and support both `equal_size` and `threshod` policies, the name are self-explanatory. * PiPPy use FX to trace and split the model. ## Settings To Care About -* **world_size** specifies your availble number of gpus for paritioning your model +* **world_size** specifies your availble number of gpus for partitioning your model * **split_policy** it can be either `equal_size`, `split_into_equal_size(number_of_workers)` or `threshod`, `split_on_size_threshold(#some number)` @@ -151,4 +151,4 @@ git clone https://huggingface.co/bigscience/bloom-7b1 torchrun --nproc_per_node 4 hf_generate.py --world_size 4 --model_name ./bloom-7b1 --index_filename bloom-7b1/pytorch_model.bin.index.json ``` -In this case, each rank will only load a part of the model. \ No newline at end of file +In this case, each rank will only load a part of the model. From d5502af31f2bbd15452506d7fddad75396605004 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Thu, 6 Jul 2023 20:14:05 -0400 Subject: [PATCH 24/96] Internal lints (#833) Mapping internal lint fixes to GitHub --- .pyre_configuration | 0 format.sh | 36 ++++++++++++++++++------------------ requirements.txt | 2 +- 3 files changed, 19 insertions(+), 19 deletions(-) delete mode 100644 .pyre_configuration diff --git a/.pyre_configuration b/.pyre_configuration deleted file mode 100644 index e69de29bb..000000000 diff --git a/format.sh b/format.sh index a4c7b0bd7..2755578ba 100755 --- a/format.sh +++ b/format.sh @@ -8,29 +8,29 @@ DEFAULT_TARGETS=() for f in $(git ls-files | grep '\.py$'); do case "$f" in 'pippy/fx/'*) - # ignore - ;; + # ignore + ;; 'pippy/'*) - DEFAULT_TARGETS+=( "$f" ) - ;; + DEFAULT_TARGETS+=( "$f" ) + ;; 'examples/'*) - # ignore - ;; + # ignore + ;; 'docs/'*) - # ignore - ;; + # ignore + ;; 'test/'*fx*) - # ignore - ;; + # ignore + ;; *) - # include - DEFAULT_TARGETS+=( "$f" ) - ;; + # include + DEFAULT_TARGETS+=( "$f" ) + ;; esac done @@ -100,10 +100,10 @@ function main() { for x in "$@"; do case "$x" in '--show-targets') - for f in ${DEFAULT_TARGETS[@]}; do - echo $f; - done - exit 0; + for f in "${DEFAULT_TARGETS[@]}"; do + echo $f; + done + exit 0; ;; '--check') @@ -129,7 +129,7 @@ function main() { case "$x" in *.py) PY_TARGETS+=( "$x" ); - ;; + ;; esac fi done diff --git a/requirements.txt b/requirements.txt index 61f9dc0d4..ac954cf5d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ torch >= 1.13.0 -packaging >= 21.3 \ No newline at end of file +packaging >= 21.3 From edc0452daab754673d8cd768a568ba726fc4dbb0 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Mon, 10 Jul 2023 13:12:18 -0400 Subject: [PATCH 25/96] Disable bert CI (#836) ## Description Disable BERT test in CI. Seems to be related to PyTorch nightly change. First seen Jun 30, 2023: https://github.com/pytorch/PiPPy/actions/runs/5426074544/jobs/9867645401 ``` -- Process 0 terminated with the following error: Traceback (most recent call last): File "/usr/local/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap fn(i, *args) File "/usr/local/lib/python3.8/site-packages/torchpippy-0.1.1-py3.8.egg/pippy/utils.py", line 268, in run_worker run_func(my_pp_ranks, args, *extra_args) File "/__w/PiPPy/PiPPy/test/local_test_forward_hf_bert.py", line 74, in run_master bert_pipe = Pipe.from_tracing( File "/usr/local/lib/python3.8/site-packages/torchpippy-0.1.1-py3.8.egg/pippy/IR.py", line 10[54](https://github.com/pytorch/PiPPy/actions/runs/5509919435/jobs/10043352906?pr=835#step:12:55), in from_tracing graph = _pipeline_tracer.trace(mod, **kwargs) File "/usr/local/lib/python3.8/site-packages/torchpippy-0.1.1-py3.8.egg/pippy/hf/utils.py", line 266, in trace graph = super().trace(*args, **kwargs) File "/usr/local/lib/python3.8/site-packages/transformers/utils/fx.py", line 1088, in trace self.graph = super().trace(root, concrete_args=concrete_args) File "/usr/local/lib/python3.8/site-packages/torch/fx/_symbolic_trace.py", line 778, in trace (self.create_arg(fn(*args)),), File "/usr/local/lib/python3.8/site-packages/transformers/models/bert/modeling_bert.py", line 970, in forward self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) File "/usr/local/lib/python3.8/site-packages/transformers/modeling_utils.py", line 3488, in warn_if_padding_and_no_attention_mask if self.config.pad_token_id in input_ids[:, [-1, 0]]: File "/usr/local/lib/python3.8/site-packages/transformers/utils/fx.py", line [64](https://github.com/pytorch/PiPPy/actions/runs/5509919435/jobs/10043352906?pr=835#step:12:65)6, in __contains__ return key in self._metadata File "/usr/local/lib/python3.8/site-packages/torch/_tensor.py", line 997, in __contains__ return (element == self).any().item() # type: ignore[union-attr] NotImplementedError: Could not run 'aten::_local_scalar_dense' with arguments from the 'Meta' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::_local_scalar_dense' is only available for these backends: [CPU, CUDA, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradHIP, AutogradXLA, AutogradMPS, AutogradIPU, AutogradXPU, AutogradHPU, AutogradVE, AutogradLazy, AutogradMeta, AutogradMTIA, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, AutogradNestedTensor, Tracer, AutocastCPU, AutocastCUDA, FuncTorchBatched, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PythonDispatcher]. ``` --- .github/workflows/pippy_tests.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/pippy_tests.yaml b/.github/workflows/pippy_tests.yaml index 26c555fe3..a0f9978c2 100644 --- a/.github/workflows/pippy_tests.yaml +++ b/.github/workflows/pippy_tests.yaml @@ -106,8 +106,8 @@ jobs: run: python test/local_test_null_coalesce_accumulate.py --replicate ${{ matrix.replicate }} -s ${{ matrix.schedule }} - name: Run PP + DDP test run: python test/local_test_ddp.py --replicate ${{ matrix.replicate }} -s ${{ matrix.schedule }} --checkpoint ${{ matrix.checkpoint }} - - name: Run HF BERT forward-only integration test - run: python test/local_test_forward_hf_bert.py --replicate ${{ matrix.replicate }} -s ${{ matrix.schedule }} --checkpoint ${{ matrix.checkpoint }} + #- name: Run HF BERT forward-only integration test + # run: python test/local_test_forward_hf_bert.py --replicate ${{ matrix.replicate }} -s ${{ matrix.schedule }} --checkpoint ${{ matrix.checkpoint }} - name: Run HF GPT2 forward-only integration test run: python test/local_test_forward_hf_gpt2.py --replicate ${{ matrix.replicate }} -s ${{ matrix.schedule }} --checkpoint ${{ matrix.checkpoint }} - name: Run visualizer test From 11b2c4910523cb195923834a714ab0f59efb221b Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Mon, 10 Jul 2023 15:53:02 -0400 Subject: [PATCH 26/96] Fix generation of loss spec from output spec (#835) ## Description In PiPPy we automatically generate loss spec from output chunk spec. Terminology: "loss spec": a one-hot map indicating which output value is a loss. It consists of True, False values. "output chunk spec": a data structure corresponding to output format describing which (chunked) value should be merged and which should be reduced (such as loss). ## Issue In previous code, we only considered the case where the output chunk spec is a dictionary. But in cases such as `return logits, loss`, the output chunk spec is a tuple, i.e. `(TensorChunkSpec(0), sum_reducer)`. ## Fix Use `fx.node.map_aggregate` to generalize the auto generation. It takes a lambda function that outputs True/False. ## Testing torchrun --nproc-per-node 4 local_test_chunkspec.py --- pippy/compile.py | 16 ++--- test/local_test_chunkspec.py | 136 +++++++++++++++++++++++++++++++++++ 2 files changed, 144 insertions(+), 8 deletions(-) create mode 100644 test/local_test_chunkspec.py diff --git a/pippy/compile.py b/pippy/compile.py index 7b76513cb..3c3564037 100644 --- a/pippy/compile.py +++ b/pippy/compile.py @@ -81,10 +81,10 @@ def _compile( # Figure out which output is loss from output_chunk_spec output_loss_value_spec: Any = None - if isinstance(output_chunk_spec, dict): - output_loss_value_spec = { - k: isinstance(v, LossReducer) for k, v in output_chunk_spec.items() - } + if output_chunk_spec is not None: + output_loss_value_spec = fx.node.map_aggregate( + output_chunk_spec, lambda v: isinstance(v, LossReducer) + ) logging.info("[PiPPy] Tracing model ...") pipe_model = Pipe.from_tracing( @@ -239,10 +239,10 @@ def compile_stage( # Figure out which output is loss from output_chunk_spec output_loss_value_spec: Any = None - if isinstance(output_chunk_spec, dict): - output_loss_value_spec = { - k: isinstance(v, LossReducer) for k, v in output_chunk_spec.items() - } + if output_chunk_spec is not None: + output_loss_value_spec = fx.node.map_aggregate( + output_chunk_spec, lambda v: isinstance(v, LossReducer) + ) logging.info("[PiPPy] Tracing model ...") pipe = Pipe.from_tracing( diff --git a/test/local_test_chunkspec.py b/test/local_test_chunkspec.py new file mode 100644 index 000000000..fce4f97c8 --- /dev/null +++ b/test/local_test_chunkspec.py @@ -0,0 +1,136 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +import argparse +import os +import unittest + +import torch +import torch.distributed as dist + +from pippy.compile import compile_stage +from pippy.IR import pipe_split +from pippy.microbatch import sum_reducer, TensorChunkSpec + +d_hid = 512 +chunk_size = 256 + +torch.manual_seed(0) + + +class ExampleCode(torch.nn.Module): + def __init__(self): + super().__init__() + self.mm_param = torch.nn.Parameter(torch.randn(d_hid, d_hid)) + self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) + self.lin = torch.nn.Linear(d_hid, d_hid) + self.mse_loss = torch.nn.MSELoss(reduction="sum") + + def forward(self, x, target): + x = torch.mm(x, self.mm_param) + skip_connection = x + x = torch.relu(x) + pipe_split() + x = torch.mm(x, self.mm_param) + x = self.lin(x) + pipe_split() + x = torch.relu(x) + x = x + skip_connection + x = torch.mm(x, self.mm_param2) + pipe_split() + x = self.lin(x) + x = torch.relu(x) + loss = self.mse_loss(x, target) + return loss, x + + +def run_worker(args): + ec = ExampleCode() + ec.to(args.device) + ec.train() + + ec_x = torch.randn(args.chunks * chunk_size, d_hid, device=args.device) + target = torch.randn(args.chunks * chunk_size, d_hid, device=args.device) + + stage = compile_stage( + ec, + args.rank, + args.world_size, + args.chunks, + args.device, + None, + [ec_x, target], + output_chunk_spec=(sum_reducer, TensorChunkSpec(0)), + ) + + # Run + if args.rank == 0: + out = stage(ec_x) + elif args.rank == args.world_size - 1: + out = stage(target) + else: + stage() + + dist.barrier() + print(f"Rank {args.rank} completes") + + # Last rank checks result + if args.rank == args.world_size - 1: + ref_out = ec(ec_x, target) + torch.testing.assert_close(out, ref_out) + print( + f"equivalence test passed, loss = {out[0]}, ref loss = {ref_out[0]}" + ) + + +def main(args=None): + parser = argparse.ArgumentParser() + parser.add_argument( + "--world_size", type=int, default=int(os.getenv("WORLD_SIZE", 4)) + ) + parser.add_argument("--rank", type=int, default=int(os.getenv("RANK", -1))) + parser.add_argument( + "--master_addr", type=str, default=os.getenv("MASTER_ADDR", "localhost") + ) + parser.add_argument( + "--master_port", type=str, default=os.getenv("MASTER_PORT", "29500") + ) + parser.add_argument( + "--cuda", type=int, default=int(torch.cuda.is_available()) + ) + parser.add_argument( + "--chunks", + type=int, + default=4, + ) + args = parser.parse_args(args) + + if args.cuda: + dev_id = args.rank % torch.cuda.device_count() + args.device = torch.device(f"cuda:{dev_id}") + else: + args.device = torch.device("cpu") + + # Init process group + backend = "nccl" if args.cuda else "gloo" + dist.init_process_group( + backend=backend, + rank=args.rank, + world_size=args.world_size, + ) + + run_worker(args) + + +if __name__ == "__main__": + main() + + +class LocalTestC10DBwdTest(unittest.TestCase): + def test_c10d_bwd(self): + import random + + port = random.randint(29500, 30000) + args = [ + "--master_port", + str(port), + ] + main(args) From bf63a64d22b5265aa359ca0736677cbdf07c1254 Mon Sep 17 00:00:00 2001 From: eddogola <64967909+eddogola@users.noreply.github.com> Date: Wed, 12 Jul 2023 16:40:09 -0700 Subject: [PATCH 27/96] rename reference state dicts (#838) ## Description Use clearer variable names for reference state dicts in the checkpoint test. ## Checklist: - [x] Have you added tests that prove your fix is effective or that this feature works? - [x] Has code been commented, particularly in hard-to-understand areas? Co-authored-by: Eddy --- test/local_test_checkpoint.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/local_test_checkpoint.py b/test/local_test_checkpoint.py index 8b0dce6ac..9d11d1f42 100644 --- a/test/local_test_checkpoint.py +++ b/test/local_test_checkpoint.py @@ -96,8 +96,8 @@ def run_worker(args: List[str | int]) -> None: # Take an optimization step optimizer.step() - submod_ref = deepcopy(stage.submod.state_dict()) - optim_ref = deepcopy(optimizer.state_dict()) + ref_state_dict = deepcopy(stage.submod.state_dict()) + ref_optim_state_dict = deepcopy(optimizer.state_dict()) save_checkpoint(stage, CKPT_DIR, optimizer) # save index file in rank 0 From a2ac4a2beb38f7786c58c37c132d4f41b751f2c6 Mon Sep 17 00:00:00 2001 From: eddogola <64967909+eddogola@users.noreply.github.com> Date: Thu, 13 Jul 2023 22:35:34 -0700 Subject: [PATCH 28/96] place optim state_dict loading inside `load_checkpoint` (#839) ## Description To bring all the checkpoint loading operations into one single API, I place the optim load_state_dict code inside load_checkpoint. ## Type of change Please delete options that are not relevant. - [x] New feature (non-breaking change which adds functionality) - [x] This change requires a documentation update ## Checklist: - [x] Have you added tests that prove your fix is effective or that this feature works? - [x] Has code been commented, particularly in hard-to-understand areas? - [x] Have you made corresponding changes to the documentation? --------- Co-authored-by: Eddy --- pippy/LoadModule.py | 25 ++++++++++++++++++++++--- pippy/hf/_SaveModule.py | 21 +-------------------- pippy/utils.py | 22 ++++++++++++++++++++++ test/local_test_checkpoint.py | 15 ++++----------- 4 files changed, 49 insertions(+), 34 deletions(-) diff --git a/pippy/LoadModule.py b/pippy/LoadModule.py index e3f04f986..95b4d3c84 100644 --- a/pippy/LoadModule.py +++ b/pippy/LoadModule.py @@ -5,8 +5,11 @@ from typing import Dict, List, Optional, Tuple, Union import torch +import torch.distributed as dist from torch import nn +from pippy.utils import _get_binary_filename + TYPICAL_PREFIXES = [ "model", # facebook/opt-6.7b @@ -17,26 +20,30 @@ def load_checkpoint( model: nn.Module, index_filename: Union[str, os.PathLike], + optim: torch.optim.Optimizer = None, device: torch.device = None, dtype: torch.dtype = None, checkpoint_prefix: str = None, -) -> nn.Module: +): """ - Load a checkpoint from a model file. + Load a checkpoint from a model (and optimizer) file. Args: model (`torch.nn.Module`): the model to load the checkpoint into index_filename (`Union[str, os.PathLike]`): path to the checkpoint's index (metadata file) + optim (`torch.optim.Optimizer`): optimizer object to load ckpt state dict into device (`torch.device`): the device on which to load the checkpoint dtype (`torch.dtype`): the dtype on which to load the checkpoint checkpoint_prefix (`str`): the prefix of the checkpoint to load Returns: - The loaded checkpoint model + The loaded checkpoint model, or, if an optimizer is passed as an argument, + both the loaded checkpoint model and a optimizer Example: ``` checkpoint = load_checkpoint(model, index_filename, device, dtype) ``` """ checkpoint_folder = os.path.split(index_filename)[0] + with open(index_filename, "r") as f: index = json.loads(f.read()) if "weight_map" in index: @@ -78,6 +85,18 @@ def load_checkpoint( del checkpoint gc.collect() + if optim: + optim.load_state_dict( + torch.load( + os.path.join( + checkpoint_folder, + _get_binary_filename(dist.get_rank(), is_optim=True), + ) + ) + ) + + if optim: + return model, optim return model diff --git a/pippy/hf/_SaveModule.py b/pippy/hf/_SaveModule.py index 7c440378f..e046534a6 100644 --- a/pippy/hf/_SaveModule.py +++ b/pippy/hf/_SaveModule.py @@ -10,6 +10,7 @@ import torch.distributed as dist from pippy.IR import Pipe +from pippy.utils import _get_binary_filename CKPT_INDEX_JSON_FILENAME = "pytorch_model.bin.index.json" @@ -110,26 +111,6 @@ def _save_index( logging.info(f"Saved index file to {filepath}") -def _get_binary_filename(cur_idx: int, is_optim: bool = False) -> str: # type: ignore[valid-type] - """ - Gets filename for pytorch checkpoint binary based on current index and world size. - - Args: - cur_idx (int): current device index - is_optim (bool): True if generating binary filename for optimizer, - False otherwise - - Returns: - str: checkpoint filename - """ - idx = str(cur_idx + 1).zfill(5) - world_size = str(dist.get_world_size()).zfill(5) - - state_type = "optim" if is_optim else "model" - - return f"pytorch_{state_type}-{idx}-of-{world_size}.bin" - - def _save_params(submod: torch.nn.Module, checkpoint_dir: str) -> None: """ writes `module`'s parameters and buffers to disk. diff --git a/pippy/utils.py b/pippy/utils.py index ac76372f1..7e3d4d7d4 100644 --- a/pippy/utils.py +++ b/pippy/utils.py @@ -4,6 +4,8 @@ import socket from typing import List +import torch.distributed as dist + # Pinning process to a separate GPU if not yet done by launch script # Notes: @@ -309,3 +311,23 @@ def dont_traverse_size(a): pippy.fx.node.map_aggregate(args, extract_tensor_args, dont_traverse_size) return flat_args + + +def _get_binary_filename(cur_idx: int, is_optim: bool = False) -> str: # type: ignore[valid-type] + """ + Gets filename for pytorch checkpoint binary based on current index and world size. + + Args: + cur_idx (int): current device index + is_optim (bool): True if generating binary filename for optimizer, + False otherwise + + Returns: + str: checkpoint filename + """ + idx = str(cur_idx + 1).zfill(5) + world_size = str(dist.get_world_size()).zfill(5) + + state_type = "optim" if is_optim else "model" + + return f"pytorch_{state_type}-{idx}-of-{world_size}.bin" diff --git a/test/local_test_checkpoint.py b/test/local_test_checkpoint.py index 9d11d1f42..dbfb7a98a 100644 --- a/test/local_test_checkpoint.py +++ b/test/local_test_checkpoint.py @@ -12,7 +12,7 @@ import torch.optim as optim from pippy.compile import compile_stage -from pippy.hf._SaveModule import _get_binary_filename, save_checkpoint +from pippy.hf._SaveModule import save_checkpoint from pippy.IR import pipe_split, TrivialLossWrapper from pippy.LoadModule import load_checkpoint @@ -134,19 +134,12 @@ def run_worker(args: List[str | int]) -> None: # Take an optimization step optimizer.step() - # load ckpt - mod = load_checkpoint( + # new api + mod, optimizer = load_checkpoint( stage.submod, os.path.join(CKPT_DIR, "pytorch_model.bin.index.json"), args.device, - ) - # load optim - optimizer.load_state_dict( - torch.load( - os.path.join( - CKPT_DIR, _get_binary_filename(dist.get_rank(), is_optim=True) - ) - ) + optim=optimizer, ) torch.testing.assert_close(mod.state_dict(), submod_ref) From c95d2e6f602e68611e004bbf6e2a00f954b73e9d Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Fri, 14 Jul 2023 12:41:13 -0400 Subject: [PATCH 29/96] Enable pipeline + DDP (c10d version) (#828) This is a mirror of #108, implemented in the c10d context. User can use DDP to wrap the pipeline stage submodule compiled by PiPPy, achieving composability compared to the previous `init_data_parallel` API offered by PiPPy. For each pipeline stage, we omit DDP synchronization (i.e. gradient all-reduce) on all except the last backward micro-batch. A `local_test_c10d_ddp.py` has been added too. **Update 07/14:** Added gradient equivalence test. Test command: ``` torchrun --nproc-per-node 8 local_test_c10d_ddp.py ``` Test output: ``` PP group size = 4, DP group size = 2 ... Output equivalence test passed Output equivalence test passed Checking gradient of lin1.weight Checking gradient of lin1.bias Gradient equivalence test passed Gradient equivalence test passed Checking gradient of mm_param2 Gradient equivalence test passed Gradient equivalence test passed Checking gradient of mm_param1 Checking gradient of lin0.weight Checking gradient of mm_param0 Checking gradient of lin0.bias Gradient equivalence test passed Gradient equivalence test passed Gradient equivalence test passed Gradient equivalence test passed ``` --- pippy/PipelineStage.py | 40 ++++++- test/local_test_c10d_ddp.py | 228 ++++++++++++++++++++++++++++++++++++ 2 files changed, 266 insertions(+), 2 deletions(-) create mode 100644 test/local_test_c10d_ddp.py diff --git a/pippy/PipelineStage.py b/pippy/PipelineStage.py index 96fa2166a..53514ae5f 100644 --- a/pippy/PipelineStage.py +++ b/pippy/PipelineStage.py @@ -5,6 +5,7 @@ import torch import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel import pippy import pippy.fx @@ -485,6 +486,36 @@ def _send_grads( return grad_send_reqs def forward(self, *args, **kwargs): + # If submod is wrapped with DDP, we use the `no_sync` context manager to + # avoid gradient all-reduce per microbatch + def forward_maybe_with_nosync(*args, **kwargs): + if isinstance(self.submod, DistributedDataParallel): + with self.submod.no_sync(): # type: ignore[operator] + out_val = self.submod(*args, **kwargs) + else: + out_val = self.submod(*args, **kwargs) + return out_val + + def backward_maybe_with_nosync(bwd_kwargs: Dict, is_last_chunk: bool): + if isinstance(self.submod, DistributedDataParallel): + if is_last_chunk: + # HACK: reaching into DDP implementation details here. Is there a better way? + self.submod.reducer.prepare_for_backward( # type: ignore[union-attr, operator] + list( + torch.nn.parallel.distributed._find_tensors( # type: ignore[attr-defined] + bwd_kwargs["stage_output"] + ) + ) + ) + grads_input, _ = stage_backward(**bwd_kwargs) + else: + with self.submod.no_sync(): # type: ignore[operator] + grads_input, _ = stage_backward(**bwd_kwargs) + else: + # Non-DDP submodule, regular backward + grads_input, _ = stage_backward(**bwd_kwargs) + return grads_input + # map microbatch ID to list of forward tensor args fwd_cache: Dict[int, Tuple[Any, List[torch.Tensor]]] = {} @@ -513,7 +544,9 @@ def forward(self, *args, **kwargs): # Compute forward try: - output = self.submod(*composite_args, **composite_kwargs) + output = forward_maybe_with_nosync( + *composite_args, **composite_kwargs + ) except Exception as e: exc_msg = f""" @@ -567,7 +600,10 @@ def forward(self, *args, **kwargs): ) # `stage_backward` node does not have `args`, only `kwargs` - grads_input, _ = stage_backward(**bwd_kwargs) + grads_input = backward_maybe_with_nosync( + bwd_kwargs, + bwd_chunk == self.chunks - 1, + ) grad_send_reqs = self._send_grads(grads_input) all_grad_send_reqs += grad_send_reqs diff --git a/test/local_test_c10d_ddp.py b/test/local_test_c10d_ddp.py new file mode 100644 index 000000000..1fe163ff3 --- /dev/null +++ b/test/local_test_c10d_ddp.py @@ -0,0 +1,228 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +import argparse +import os +import unittest + +import pippy + +import torch +import torch.distributed as dist +from pippy.IR import pipe_split +from torch.nn.parallel import DistributedDataParallel + + +d_hid = 512 +chunk_size = 256 + + +class ExampleCode(torch.nn.Module): + def __init__(self): + super().__init__() + self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) + self.mm_param1 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) + self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) + self.lin0 = torch.nn.Linear(d_hid, d_hid) + self.lin1 = torch.nn.Linear(d_hid, d_hid) + self.loss_fn = torch.nn.MSELoss(reduction="sum") + + def forward(self, x, target): + x = torch.mm(x, self.mm_param0) + skip_connection = x + x = torch.relu(x) + pipe_split() + x = torch.mm(x, self.mm_param1) + x = self.lin0(x) + pipe_split() + x = torch.relu(x) + x = x + skip_connection + x = torch.mm(x, self.mm_param2) + pipe_split() + x = self.lin1(x) + x = torch.relu(x) + loss = self.loss_fn(x, target) + return {"loss": loss} + + +def create_model() -> torch.nn.Module: + # Fix a seed such that models are created the same + torch.manual_seed(42) + ec = ExampleCode() + return ec + + +# Get process group for ranks in a pipeline +def get_pp_subgroup(args): + my_pp_rank = args.rank // args.dp_group_size + my_dp_rank = args.rank % args.dp_group_size + for dp_rank in range(0, args.dp_group_size): + pp_group_ranks = list( + range(dp_rank, args.world_size, args.dp_group_size) + ) + pp_group = dist.new_group(ranks=pp_group_ranks) + if dp_rank == my_dp_rank: + my_pp_group = pp_group + print(f"Rank {args.rank} done getting pp group") + return my_pp_group, my_pp_rank + + +# Get DP process group for ranks with the same stage +def get_dp_subgroup(args): + my_pp_rank = args.rank // args.dp_group_size + my_dp_rank = args.rank % args.dp_group_size + for pp_rank in range(0, args.pp_group_size): + dp_group_ranks = list( + range( + pp_rank * args.dp_group_size, (pp_rank + 1) * args.dp_group_size + ) + ) + dp_group = dist.new_group(ranks=dp_group_ranks) + if pp_rank == my_pp_rank: + my_dp_group = dp_group + print(f"Rank {args.rank} done getting dp group") + return my_dp_group, my_dp_rank + + +# Main program +def run_worker(args): + ec_with_loss = create_model() + ec_with_loss.to(args.device) + + input = torch.randn(args.chunks * chunk_size, d_hid, device=args.device) + target = torch.randn(args.chunks * chunk_size, d_hid, device=args.device) + + # Get DP and PP sub process groups + dp_group, dp_rank = get_dp_subgroup(args) + pp_group, pp_rank = get_pp_subgroup(args) + + stage = pippy.compile_stage( + ec_with_loss, + pp_rank, + args.pp_group_size, + args.chunks, + args.device, + pp_group, + [input, target], + ) + + # Form a map from original qualname to param for equivalence check later + pipe_params = {} + for qualname, param in stage.submod.named_parameters(): + origin_name = stage.submod.remap_qualname(qualname) + pipe_params[origin_name] = param + + # Wrap stage module with DDP + stage.submod = DistributedDataParallel( + stage.submod, + process_group=dp_group, + ) + + # Run + if pp_rank == 0: + stage(input) + elif pp_rank == args.pp_group_size - 1: + pipe_out = stage(target) + else: + stage() + + # Form a map from original qualname to gradient for equivalence check later + pipe_grads = {} + for origin_name, pipe_param in pipe_params.items(): + pipe_grads[origin_name] = pipe_param.grad + + # DDP reference model + ref_mod = create_model() + ref_mod.to(args.device) + ddp_ref_mod = DistributedDataParallel( + ref_mod, + process_group=dp_group, + ) + + # DDP forward and backward + ddp_out = ddp_ref_mod(input, target) + ddp_out["loss"].backward() + + # Compare pipeline output and DDP output + if pp_rank == args.pp_group_size - 1: + torch.testing.assert_close(pipe_out, ddp_out) + print("Output equivalence test passed") + + # Compare pipeline gradient and DDP gradient + for origin_name, pipe_grad in pipe_grads.items(): + ddp_param = ddp_ref_mod.module.get_parameter(origin_name) + if dp_rank == 0: + print(f"Checking gradient of {origin_name}") + # Since we use synthetic input and output, the gradients generated are + # large. Hence we need to manually set relative tolerance + torch.testing.assert_close( + pipe_grad, + ddp_param.grad, + rtol=7e-2, + atol=1e-5, + ) + + print("Gradient equivalence test passed") + + +def main(args=None): + parser = argparse.ArgumentParser() + parser.add_argument( + "--world_size", type=int, default=int(os.getenv("WORLD_SIZE", 8)) + ) + parser.add_argument("--rank", type=int, default=int(os.getenv("RANK", -1))) + parser.add_argument( + "--master_addr", type=str, default=os.getenv("MASTER_ADDR", "localhost") + ) + parser.add_argument( + "--master_port", type=str, default=os.getenv("MASTER_PORT", "29500") + ) + parser.add_argument( + "--cuda", type=int, default=int(torch.cuda.is_available()) + ) + parser.add_argument( + "--chunks", + type=int, + default=4, + ) + args = parser.parse_args(args) + + # pp group size must match with pipe_split's in model + args.pp_group_size = 4 + # world size must be multiple of pp group size + assert args.world_size % args.pp_group_size == 0 + args.dp_group_size = args.world_size // args.pp_group_size + if args.rank == 0: + print( + f"PP group size = {args.pp_group_size}, DP group size = {args.dp_group_size}" + ) + + if args.cuda: + dev_id = args.rank % torch.cuda.device_count() + args.device = torch.device(f"cuda:{dev_id}") + else: + args.device = torch.device("cpu") + + # Init process group + backend = "nccl" if args.cuda else "gloo" + dist.init_process_group( + backend=backend, + rank=args.rank, + world_size=args.world_size, + ) + + run_worker(args) + + +if __name__ == "__main__": + main() + + +class LocalTestC10dDDPTest(unittest.TestCase): + def test_c10d_ddp(self): + import random + + port = random.randint(29500, 30000) + args = [ + "--master_port", + str(port), + ] + main(args) From a6f99973347be65c892b9099d0f14c5da505743c Mon Sep 17 00:00:00 2001 From: Yeonju Ro Date: Fri, 14 Jul 2023 10:14:51 -0700 Subject: [PATCH 30/96] selective 2d api/example added for fine-grained tp/pp demo (#830) ## Description added 2d parallelism (tp+pp) API and example for fine-grained tp/pp ## Checklist: - [v] Has code been commented, particularly in hard-to-understand areas? - [v] Have you made corresponding changes to the documentation? --- examples/selective2d/2d_train.py | 478 +++++++++++++++++++++++++++ examples/selective2d/model.py | 540 +++++++++++++++++++++++++++++++ 2 files changed, 1018 insertions(+) create mode 100644 examples/selective2d/2d_train.py create mode 100644 examples/selective2d/model.py diff --git a/examples/selective2d/2d_train.py b/examples/selective2d/2d_train.py new file mode 100644 index 000000000..175a54c85 --- /dev/null +++ b/examples/selective2d/2d_train.py @@ -0,0 +1,478 @@ +""" +This training script updates NanoGPT to run with either TP, PP, or TP+PP (2D). +Usage: +gpurun4 torchrun --nproc-per-node 4 2d_train.py +""" + +import argparse +import os +import time + +import torch +import torch.distributed as dist + +from model import GPT, GPTConfig +from pippy.compile import compile_stage + +from pippy.IR import annotate_split_points, PipeSplitWrapper +from pippy.microbatch import sum_reducer, TensorChunkSpec + +from torch.distributed._tensor import DeviceMesh +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + PairwiseParallel, + parallelize_module, + RowwiseParallel, +) + + +def get_args(): + # default config values designed to train a gpt2 (124M) on OpenWebText + + def str_to_bool(v): + if isinstance(v, bool): + return v + if v.lower() in ("true", "t", "1"): + return True + elif v.lower() in ("false", "f", "0"): + return False + else: + raise ArgumentTypeError("Boolean expected.") + + # I/O + parser = argparse.ArgumentParser() + parser.add_argument("--out_dir", type=str, default="out") + parser.add_argument("--eval_interval", type=int, default=2000) + parser.add_argument("--log_interval", type=int, default=2) + parser.add_argument("--eval_iters", type=int, default=200) + parser.add_argument( + "--eval_only", type=str_to_bool, default=False + ) # if True, script exits right after the first eval + parser.add_argument( + "--always_save_checkpoint", type=str_to_bool, default=True + ) # if True, always save a checkpoint after each eval + parser.add_argument( + "--init_from", type=str, default="scratch" + ) # 'scratch', 'resume', 'gpt2*' + parser.add_argument("--train_iters", type=int, default=200000) + parser.add_argument("--seed", type=int, default=1337) + + # data + parser.add_argument( + "--dataset", type=str, default="shakespeare_char" + ) # "openwebtext" + parser.add_argument( + "--gradient_accumulation_steps", type=int, default=1 + ) # used to simulate larger batch sizes + parser.add_argument( + "--batch_size", type=int, default=12 + ) # if gradient_accumulation_steps > 1, this is the micro-batch size + parser.add_argument("--block_size", type=int, default=1024) + parser.add_argument("--vocab_size", type=int, default=50304) + + # model + parser.add_argument("--n_layer", type=int, default=12) + parser.add_argument("--n_head", type=int, default=12) + parser.add_argument("--n_embd", type=int, default=768) + parser.add_argument( + "--dropout", type=float, default=0.0 + ) # for pretraining 0 is good, for finetuning try 0.1+ + parser.add_argument("--bias", type=str_to_bool, default=False) + + # adamw optimizer + parser.add_argument( + "--learning_rate", type=float, default=4e-4 + ) # max learning rate + parser.add_argument( + "--max_iters", type=int, default=600000 + ) # total number of training iterations + parser.add_argument("--weight_decay", type=float, default=1e-2) + parser.add_argument("--beta1", type=float, default=0.9) + parser.add_argument("--beta2", type=float, default=0.95) + parser.add_argument( + "--grad_clip", type=float, default=1.0 + ) # clip gradients at this value, or disable if == 0.0 + parser.add_argument( + "--decay_lr", type=str_to_bool, default=True + ) # whether to decay the learning rate + parser.add_argument("--warmup_iters", type=int, default=2000) + parser.add_argument("--lr_decay_iters", type=int, default=600000) + parser.add_argument( + "--min_lr", type=float, default=6e-5 + ) # minimum learning rate + + # distributed + parser.add_argument( + "--backend", type=str, default="nccl" + ) # 'nccl', 'gloo', etc. + parser.add_argument( + "--compile", type=str_to_bool, default=False + ) # use PyTorch 2.0 to compile the model to be faster + parser.add_argument("--rank", type=int, default=int(os.environ["RANK"])) + parser.add_argument( + "--local_rank", type=int, default=int(os.environ["LOCAL_RANK"]) + ) + parser.add_argument( + "--world_size", type=int, default=int(os.environ["WORLD_SIZE"]) + ) + parser.add_argument( + "--device", type=str, default=f"cuda:{os.environ['LOCAL_RANK']}" + ) + parser.add_argument( + "--master_process", + type=str_to_bool, + default=bool(os.environ["RANK"] == 0), + ) + parser.add_argument("--tp_size", type=int, default=2) + parser.add_argument("--pp_size", type=int, default=2) + + parser.add_argument("--debug", dest="debug", action="store_true") + + args = parser.parse_args() + + return args + + +def rank_print(x): + _rank = os.getenv("RANK") + if _rank == "0": + print(x) + + +def get_rand(args): + x = torch.randint( + 0, + args.vocab_size, + (args.batch_size, args.block_size), + device=args.device, + ) + y = torch.randint( + 0, + args.vocab_size, + (args.batch_size, args.block_size), + device=args.device, + ) + return x, y + + +def tp_attention(model, name, mesh, tp_dim=0, q="q", k="k", v="v", o="c_proj"): + layer = model.get_submodule(name) + parallelize_module( + layer, + mesh, + { + q: ColwiseParallel(), + k: ColwiseParallel(), + v: ColwiseParallel(), + o: RowwiseParallel(), + }, + tp_mesh_dim=tp_dim, + ) + + return model + + +def tp_mlp(model, name, mesh, tp_dim=0, mlp="mlp"): + layer = model.get_submodule(name) + parallelize_module( + layer, mesh, {mlp: PairwiseParallel()}, tp_mesh_dim=tp_dim + ) + + return model + + +def tp(model, n_layer, mesh, offset=0, tp_dim=0): + for i in range(n_layer): + block = model.get_submodule(f"transformer.h.{i + offset}") + parallelize_module( + block, + mesh, + { + "attn.q": ColwiseParallel(), + "attn.k": ColwiseParallel(), + "attn.v": ColwiseParallel(), + "attn.c_proj": RowwiseParallel(), + "mlp": PairwiseParallel(), + }, + tp_mesh_dim=tp_dim, + ) + + return model + + +def pp(model, pp_device_mesh, args): + pp_chunks = args.world_size + pp_groups = pp_device_mesh.get_dim_groups()[0] + + output_chunk_spec = (TensorChunkSpec(0), sum_reducer) + stage = compile_stage( + model, + args.rank, + args.world_size, + pp_chunks, + pp_device_mesh, + pp_groups, + example_inputs=[X, Y], + output_chunk_spec=output_chunk_spec, + ) + + print(f"[Rank{_rank}] {stage.submod.print_readable()}") + return model, stage + + +def pp_and_tp(model, mesh, args): + """ + Apply TP and PP to all layers in a model + This function assumes the model is already cut manually + """ + pp_dim, tp_dim = 0, 1 + pp_rank, tp_rank = args.rank // args.tp_size, args.rank % args.tp_size + pp_groups = mesh.get_dim_groups()[pp_dim] + + # TP + tp(model, args.n_layer, mesh, 0, tp_dim) + + X, Y = get_rand(args) + + # PP + stage = compile_stage( + model, + pp_rank, + args.world_size, + args.pp_size, + args.device, + pp_groups, + example_inputs=[X, Y], + ) + + return model, stage + + +def even_cut(model, args, pp_size): + """ + Evenly cut a model into pp_size stages + """ + cut = {} + cutpoint = args.n_layer // pp_size + for i in range(args.n_layer): + name = f"transformer.h.{i}" + if i > 0 and i % cutpoint == 0: + cut[name] = PipeSplitWrapper.SplitPoint.BEGINNING # or END + + annotate_split_points(model, cut) + + +def after_ar_cut(model, args, pp_size): + """ + Cut a model right after AllReduce happens + """ + cut = {} + cutpoint = args.n_layer // pp_size + for i in range(args.n_layer): + name = f"transformer.h.{i}" + if i != args.n_layer - 1 and i % cutpoint == cutpoint - 1: + cut[f"{name}.mlp.dropout"] = PipeSplitWrapper.SplitPoint.BEGINNING + + annotate_split_points(model, cut) + + +def pp_and_tp_selective( + model, mesh, args, tp_attn_layers=None, tp_mlp_layers=None, cut_fn=even_cut +): + """ + Apply pipeline parallelism and tensor parallelism to a model. + """ + + pp_dim, tp_dim = 0, 1 + pp_rank, tp_rank = args.rank // args.tp_size, args.rank % args.tp_size + pp_groups = mesh.get_dim_groups()[pp_dim] + + # TP + # Apply TP to layers if layer_id is in tp_attn / tp_mlp + tp_attn_layers = ( + list(range(args.n_layer)) if tp_attn_layers is None else tp_attn_layers + ) + tp_mlp_layers = ( + list(range(args.n_layer)) if tp_mlp_layers is None else tp_mlp_layers + ) + for i in range(args.n_layer): + name = f"transformer.h.{i}" + att = tp_attention(model, f"{name}.attn", mesh, tp_dim) + mlp = tp_mlp(model, f"{name}", mesh, tp_dim) + + X, Y = get_rand(args) + + # PP + cut_fn(model, args, args.pp_size) + stage = compile_stage( + model, + pp_rank, + args.world_size, + args.pp_size, + args.device, + pp_groups, + example_inputs=[X, Y], + ) + + return model, stage + + +def pp_tp_train(stage, mesh, args): + pp_dim, tp_dim = 0, 1 + pp_rank, tp_rank = args.rank // args.tp_size, args.rank % args.tp_size + pp_groups = mesh.get_dim_groups()[pp_dim] + + train_iters = 10 if args.debug else args.train_iters + optimizer = torch.optim.AdamW( + stage.submod.parameters(), lr=args.learning_rate + ) + local_iter_num = 0 + iter_time = 0.0 + while local_iter_num < train_iters: + optimizer.zero_grad() + t0 = time.perf_counter() + X, Y = get_rand(args) + if pp_rank == 0: + out = stage(X) + elif pp_rank == args.pp_size - 1: + out = stage(Y) + else: + out = stage() + optimizer.step() + t1 = time.perf_counter() + dt = t1 - t0 + local_iter_num += 1 + iter_time += dt + + return local_iter_num, iter_time + + +def pp_train(stage, args): + train_iters = 10 if args.debug else args.train_iters + optimizer = torch.optim.AdamW( + stage.submod.parameters(), lr=args.learning_rate + ) + local_iter_num = 0 + iter_time = 0.0 + while local_iter_num < train_iters: + optimizer.zero_grad() + t0 = time.perf_counter() + X, Y = get_rand(args) + if args.rank == 0: + out = stage(X) + elif args.rank == args.world_size - 1: + out = stage(Y) + else: + out = stage() + optimizer.step() + t1 = time.perf_counter() + dt = t1 - t0 + local_iter_num += 1 + iter_time += dt + + return local_iter_num, iter_time + + +def tp_train(): + local_iter_num = 0 + iter_time = 0.0 + optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate) + while local_iter_num < train_iters: + optimizer.zero_grad(set_to_none=True) + t0 = time.perf_counter() + X, Y = get_rand(args) + logits, loss = model(X, Y) + loss.backward() + optimizer.step() + torch.distributed.barrier() + t1 = time.perf_counter() + dt = t1 - t0 + lossf = loss.item() + rank_print( + f"iter {local_iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms" + ) + local_iter_num += 1 + iter_time += dt + + return local_iter_num, iter_time + + +if __name__ == "__main__": + _multi_gpu = int(os.environ.get("RANK", -1)) != -1 # verify distributed run + assert ( + _multi_gpu + ), "this config assumes distributed setup - multi-gpu not ready here." + + args = get_args() + + device_type = ( + "cuda" if "cuda" in args.device else "cpu" + ) # for later use in torch.autocast + torch.cuda.set_device(args.device) + + dist.init_process_group( + backend=args.backend, rank=args.rank, world_size=args.world_size + ) + + if args.master_process: + os.makedirs(args.out_dir, exist_ok=True) + + torch.manual_seed(args.seed) + torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul + torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn + torch.backends.cuda.enable_mem_efficient_sdp(enabled=False) + + # init these up here, can override if init_from='resume' (i.e. from a checkpoint) + iter_num = 0 + best_val_loss = 1e9 + + # model init + model_args = dict( + n_layer=args.n_layer, + n_head=args.n_head, + n_embd=args.n_embd, + block_size=args.block_size, + bias=args.bias, + vocab_size=None, + dropout=args.dropout, + ) # start with model_args from command line + + # init a new model from scratch + rank_print("Initializing a new model from scratch") + + oned_mesh = DeviceMesh(device_type, list(range(args.world_size))) + twod_mesh = DeviceMesh( + device_type=device_type, + mesh=torch.arange(0, args.world_size).view(-1, args.tp_size), + ) + + model_args["vocab_size"] = args.vocab_size + + gptconf = GPTConfig(**model_args) + model = GPT(twod_mesh, gptconf, args.device, args.pp_size) + model.to(args.device) + + _current_model_params = model.get_num_params() / 1e6 + + # model = tp(model, args.n_layer, oned_mesh) + # model, stage = pp(model, oned_mesh, args) + # model, stage = pp_and_tp(model, twod_mesh, args) + model, stage = pp_and_tp_selective(model, twod_mesh, args) + + # iter_count, iter_time = pp_train(stage, args) + iter_count, iter_time = pp_tp_train(stage, twod_mesh, args) + + # display run stats + rank_print(f"\nTraining completed.\n") + + gpu_type = torch.cuda.get_device_name(0) + gpu_count = dist.get_world_size() + rank_print(f"\n----- Performance Stats --------\n") + rank_print(f"\nModel Size: {_current_model_params:.2f}M") + rank_print(f"Run completed with {gpu_count} gpus, of type {gpu_type}") + iter_avg = round(iter_time / iter_count, 4) + rank_print( + f"Avg iter speed (in seconds): {iter_avg}, with {iter_count} iterations averaged.\n" + ) + + dist.destroy_process_group() diff --git a/examples/selective2d/model.py b/examples/selective2d/model.py new file mode 100644 index 000000000..3bfd07d25 --- /dev/null +++ b/examples/selective2d/model.py @@ -0,0 +1,540 @@ +# Original code from https://github.com/karpathy/nanoGPT +""" +Full definition of a GPT Language Model, all of it in this single file. +References: +1) the official GPT-2 TensorFlow implementation released by OpenAI: +https://github.com/openai/gpt-2/blob/master/src/model.py +2) huggingface/transformers PyTorch implementation: +https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py +""" +import inspect + +import math +import os +from dataclasses import dataclass + +import torch +import torch.nn as nn +from torch.nn import functional as F + + +# @torch.jit.script # good to enable when not using torch.compile, disable when using (our default) +def new_gelu(x): + """ + Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). + Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415 + """ + return ( + 0.5 + * x + * ( + 1.0 + + torch.tanh( + math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)) + ) + ) + ) + + +class LayerNorm(nn.Module): + """LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False""" + + def __init__(self, mesh, ndim, bias): + super().__init__() + self.weight = nn.Parameter(torch.ones(ndim)) + self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None + self.mesh = mesh + + def forward(self, input): + return F.layer_norm( + input, self.weight.shape, self.weight, self.bias, 1e-5 + ) + + +class CausalSelfAttention(nn.Module): + def __init__(self, mesh, config): + super().__init__() + tp_size = mesh.mesh.size(0) + assert config.n_head % tp_size == 0 + assert config.n_embd % config.n_head == 0 + self.mesh = mesh + self.tp_size = tp_size + # key, query, value projections for all heads, but in a batch + self.c_attn = nn.Linear( + config.n_embd, 3 * config.n_embd, bias=config.bias + ) + self.q = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) + self.k = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) + self.v = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) + # output projection + self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) + # regularization + self.attn_dropout = nn.Dropout(config.dropout) + self.resid_dropout = nn.Dropout(config.dropout) + self.n_head = config.n_head + self.n_embd = config.n_embd + self.dropout = config.dropout + # flash attention make GPU go brrrrr but support is only in PyTorch nightly and still a bit scary + self.flash = ( + hasattr(torch.nn.functional, "scaled_dot_product_attention") + and self.dropout == 0.0 + ) + + if not self.flash: + print( + "WARNING: using slow attention. Flash Attention atm needs PyTorch nightly and dropout=0.0" + ) + # causal mask to ensure that attention is only applied to the left in the input sequence + self.block_size = config.block_size + self.register_buffer( + "bias", + torch.tril( + torch.ones(config.block_size, config.block_size) + ).view(1, 1, config.block_size, config.block_size), + ) + + def forward(self, x): + ( + B, + T, + C, + ) = ( + x.size() + ) # batch size, sequence length, embedding dimensionality (n_embd) + + def print0(msg): + if os.getenv("RANK") == "0": + print(msg) + + channel_head_size = C // self.n_head + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + q = ( + self.q(x) + .split(self.n_embd // self.tp_size, dim=2)[0] + .view(B, T, self.n_head // self.tp_size, C // self.n_head) + .transpose(1, 2) + ) # (B, nh, T, hs) + k = ( + self.k(x) + .split(self.n_embd // self.tp_size, dim=2)[0] + .view(B, T, self.n_head // self.tp_size, C // self.n_head) + .transpose(1, 2) + ) # (B, nh, T, hs) + v = ( + self.v(x) + .split(self.n_embd // self.tp_size, dim=2)[0] + .view(B, T, self.n_head // self.tp_size, C // self.n_head) + .transpose(1, 2) + ) # (B, nh, T, hs) + + # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + if self.flash: + # efficient attention using Flash Attention CUDA kernels + y = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True + ) + else: + # manual implementation of attention + from torch.distributed._tensor import ( + DeviceMesh, + distribute_tensor, + Replicate, + Shard, + ) + + mesh = DeviceMesh("cuda", list(range(2))) + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf")) + att = F.softmax(att, dim=-1) + att = self.attn_dropout(att) + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + y = ( + y.transpose(1, 2).contiguous().view(B, T, C // self.tp_size) + ) # re-assemble all head outputs side by side + + # output projection + y = self.resid_dropout(self.c_proj(y)) + return y + + +class MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.c_fc = nn.Linear( + config.n_embd, 4 * config.n_embd, bias=config.bias + ) + self.gelu = nn.GELU() + self.c_proj = nn.Linear( + 4 * config.n_embd, config.n_embd, bias=config.bias + ) + self.dropout = nn.Dropout(config.dropout) + + def forward(self, x): + x = self.c_fc(x) + x = self.gelu(x) + x = self.c_proj(x) + x = self.dropout(x) + return x + + +class Block(nn.Module): + def __init__(self, mesh, config): + super().__init__() + self.ln_1 = LayerNorm(mesh, config.n_embd, bias=config.bias) + self.attn = CausalSelfAttention(mesh, config) + self.ln_2 = LayerNorm(mesh, config.n_embd, bias=config.bias) + self.mlp = MLP(config) + self.mesh = mesh + + def forward(self, x): + x = x + self.attn(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +@dataclass +class GPTConfig: + block_size: int = 1024 + vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency + n_layer: int = 12 + n_head: int = 12 + n_embd: int = 768 + dropout: float = 0.0 + bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster + + +class GPT(nn.Module): + def __init__(self, mesh, config, device, pp_size=2): + super().__init__() + assert config.vocab_size is not None + assert config.block_size is not None + self.config = config + self.mesh = mesh + self.pp_size = pp_size + self.device = device + + self.transformer = nn.ModuleDict( + dict( + wte=nn.Embedding(config.vocab_size, config.n_embd), + wpe=nn.Embedding(config.block_size, config.n_embd), + drop=nn.Dropout(config.dropout), + h=nn.ModuleList( + [Block(mesh, config) for _ in range(config.n_layer)] + ), + ln_f=LayerNorm(mesh, config.n_embd, bias=config.bias), + ) + ) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + # with weight tying when using torch.compile() some warnings get generated: + # "UserWarning: functional_call was passed multiple values for tied weights. + # This behavior is deprecated and will be an error in future versions" + # not 100% sure what this is, so far seems to be harmless. TODO investigate + self.transformer.wte.weight = ( + self.lm_head.weight + ) # https://paperswithcode.com/method/weight-tying + + # init all weights + self.apply(self._init_weights) + # apply special scaled init to the residual projections, per GPT-2 paper + for pn, p in self.named_parameters(): + if pn.endswith("c_proj.weight"): + torch.nn.init.normal_( + p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer) + ) + + # report number of parameters + print("number of parameters: %.2fM" % (self.get_num_params() / 1e6,)) + + def get_num_params(self, non_embedding=True): + """ + Return the number of parameters in the model. + For non-embedding count (default), the position embeddings get subtracted. + The token embeddings would too, except due to the parameter sharing these + params are actually used as weights in the final layer, so we include them. + """ + n_params = sum(p.numel() for p in self.parameters()) + if non_embedding: + n_params -= self.transformer.wpe.weight.numel() + return n_params + + def _init_weights(self, module): + if isinstance(module, nn.Linear): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + if module.bias is not None: + torch.nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + + def forward(self, idx, targets=None): + # device = idx.device + # b, t = idx.size() + # assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" + + # WARNING: t needs to actual sequence length, shape should be (1,t) + pos = torch.arange( + 0, self.config.block_size, dtype=torch.long, device=self.device + ).unsqueeze(0) + + # forward the GPT model itself + tok_emb = self.transformer.wte( + idx + ) # token embeddings of shape (b, t, n_embd) + pos_emb = self.transformer.wpe( + pos + ) # position embeddings of shape (1, t, n_embd) + x = self.transformer.drop(tok_emb + pos_emb) + for block in self.transformer.h: + x = block(x) + x = self.transformer.ln_f(x) + + if targets is not None: + # if we are given some desired targets also calculate the loss + logits = self.lm_head(x) + loss = F.cross_entropy( + logits.view(-1, logits.size(-1)), + targets.view(-1), + ignore_index=-1, + ) + else: + # inference-time mini-optimization: only forward the lm_head on the very last position + logits = self.lm_head( + x[:, [-1], :] + ) # note: using list [-1] to preserve the time dim + loss = None + + return logits, loss + + def crop_block_size(self, block_size): + # model surgery to decrease the block size if necessary + # e.g. we may load the GPT2 pretrained model checkpoint (block size 1024) + # but want to use a smaller block size for some smaller, simpler model + assert block_size <= self.config.block_size + self.config.block_size = block_size + self.transformer.wpe.weight = nn.Parameter( + self.transformer.wpe.weight[:block_size] + ) + for block in self.transformer.h: + block.attn.bias = block.attn.bias[:, :, :block_size, :block_size] + + @classmethod + def from_pretrained(cls, model_type, override_args=None): + assert model_type in {"gpt2", "gpt2-medium", "gpt2-large", "gpt2-xl"} + override_args = override_args or {} # default to empty dict + # only dropout can be overridden see more notes below + assert all(k == "dropout" for k in override_args) + from transformers import GPT2LMHeadModel + + print("loading weights from pretrained gpt: %s" % model_type) + + # n_layer, n_head and n_embd are determined from model_type + config_args = { + "gpt2": dict(n_layer=12, n_head=12, n_embd=768), # 124M params + "gpt2-medium": dict( + n_layer=24, n_head=16, n_embd=1024 + ), # 350M params + "gpt2-large": dict( + n_layer=36, n_head=20, n_embd=1280 + ), # 774M params + "gpt2-xl": dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params + }[model_type] + print("forcing vocab_size=50257, block_size=1024, bias=True") + config_args[ + "vocab_size" + ] = 50257 # always 50257 for GPT model checkpoints + config_args[ + "block_size" + ] = 1024 # always 1024 for GPT model checkpoints + config_args["bias"] = True # always True for GPT model checkpoints + # we can override the dropout rate, if desired + if "dropout" in override_args: + print(f"overriding dropout rate to {override_args['dropout']}") + config_args["dropout"] = override_args["dropout"] + # create a from-scratch initialized minGPT model + config = GPTConfig(**config_args) + model = GPT(config) + sd = model.state_dict() + sd_keys = sd.keys() + sd_keys = [ + k for k in sd_keys if not k.endswith(".attn.bias") + ] # discard this mask / buffer, not a param + + # init a huggingface/transformers model + model_hf = GPT2LMHeadModel.from_pretrained(model_type) + sd_hf = model_hf.state_dict() + + # copy while ensuring all of the parameters are aligned and match in names and shapes + sd_keys_hf = sd_hf.keys() + sd_keys_hf = [ + k for k in sd_keys_hf if not k.endswith(".attn.masked_bias") + ] # ignore these, just a buffer + sd_keys_hf = [ + k for k in sd_keys_hf if not k.endswith(".attn.bias") + ] # same, just the mask (buffer) + transposed = [ + "attn.c_attn.weight", + "attn.c_proj.weight", + "mlp.c_fc.weight", + "mlp.c_proj.weight", + ] + # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear + # this means that we have to transpose these weights when we import them + assert len(sd_keys_hf) == len( + sd_keys + ), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}" + for k in sd_keys_hf: + if any(k.endswith(w) for w in transposed): + # special treatment for the Conv1D weights we need to transpose + assert sd_hf[k].shape[::-1] == sd[k].shape + with torch.no_grad(): + sd[k].copy_(sd_hf[k].t()) + else: + # vanilla copy over the other parameters + assert sd_hf[k].shape == sd[k].shape + with torch.no_grad(): + sd[k].copy_(sd_hf[k]) + + return model + + def configure_optimizers( + self, weight_decay, learning_rate, betas, device_type + ): + """ + This long function is unfortunately doing something very simple and is being very defensive: + We are separating out all parameters of the model into two buckets: those that will experience + weight decay for regularization and those that won't (biases, and layernorm/embedding weights). + We are then returning the PyTorch optimizer object. + """ + + # separate out all parameters to those that will and won't experience regularizing weight decay + decay = set() + no_decay = set() + whitelist_weight_modules = (torch.nn.Linear,) + blacklist_weight_modules = ( + torch.nn.LayerNorm, + LayerNorm, + torch.nn.Embedding, + ) + for mn, m in self.named_modules(): + for pn, p in m.named_parameters(): + fpn = "%s.%s" % (mn, pn) if mn else pn # full param name + # random note: because named_modules and named_parameters are recursive + # we will see the same tensors p many many times. but doing it this way + # allows us to know which parent module any tensor p belongs to... + if pn.endswith("bias"): + # all biases will not be decayed + no_decay.add(fpn) + elif pn.endswith("weight") and isinstance( + m, whitelist_weight_modules + ): + # weights of whitelist modules will be weight decayed + decay.add(fpn) + elif pn.endswith("weight") and isinstance( + m, blacklist_weight_modules + ): + # weights of blacklist modules will NOT be weight decayed + no_decay.add(fpn) + + # subtle: 'transformer.wte.weight' and 'lm_head.weight' are tied, so they + # will appear in the no_decay and decay sets respectively after the above. + # In addition, because named_parameters() doesn't return duplicates, it + # will only return the first occurence, key'd by 'transformer.wte.weight', below. + # so let's manually remove 'lm_head.weight' from decay set. This will include + # this tensor into optimization via transformer.wte.weight only, and not decayed. + decay.remove("lm_head.weight") + + # validate that we considered every parameter + param_dict = {pn: p for pn, p in self.named_parameters()} + inter_params = decay & no_decay + union_params = decay | no_decay + assert ( + len(inter_params) == 0 + ), "parameters %s made it into both decay/no_decay sets!" % ( + str(inter_params), + ) + assert len(param_dict.keys() - union_params) == 0, ( + "parameters %s were not separated into either decay/no_decay set!" + % (str(param_dict.keys() - union_params),) + ) + + # create the pytorch optimizer object + optim_groups = [ + { + "params": [param_dict[pn] for pn in sorted(list(decay))], + "weight_decay": weight_decay, + }, + { + "params": [param_dict[pn] for pn in sorted(list(no_decay))], + "weight_decay": 0.0, + }, + ] + # new PyTorch nightly has a new 'fused' option for AdamW that is much faster + use_fused = (device_type == "cuda") and ( + "fused" in inspect.signature(torch.optim.AdamW).parameters + ) + use_fused = False # YEONJU + print(f"using fused AdamW: {use_fused}") + extra_args = dict(fused=True) if use_fused else dict() + optimizer = torch.optim.AdamW( + optim_groups, lr=learning_rate, betas=betas, **extra_args + ) + + return optimizer + + def estimate_mfu(self, num_params, fwdbwd_per_iter, dt): + """estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS""" + # first estimate the number of flops we do per iteration. + # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311 + # N = self.get_num_params() + N = num_params + cfg = self.config + tp_size = 2 + actual_head = cfg.n_head // tp_size + # L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size + L, H, Q, T = ( + cfg.n_layer, + actual_head, + cfg.n_embd // actual_head, + cfg.block_size, + ) + flops_per_token = 6 * N + 12 * L * H * Q * T + flops_per_fwdbwd = flops_per_token * T + flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter + # express our flops throughput as ratio of A100 bfloat16 peak flops + flops_achieved = flops_per_iter * (1.0 / dt) # per second + # flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS + # mfu = flops_achieved / flops_promised + flops_promised = 125e12 # A10 TFlops .... 312e12 A100 GPU bfloat16 peak flops is 312 TFLOPS + mfu = (flops_achieved / flops_promised) / tp_size + return mfu + + @torch.no_grad() + def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None): + """ + Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete + the sequence max_new_tokens times, feeding the predictions back into the model each time. + Most likely you'll want to make sure to be in model.eval() mode of operation for this. + """ + for _ in range(max_new_tokens): + # if the sequence context is growing too long we must crop it at block_size + idx_cond = ( + idx + if idx.size(1) <= self.config.block_size + else idx[:, -self.config.block_size :] + ) + # forward the model to get the logits for the index in the sequence + logits, _ = self(idx_cond) + # pluck the logits at the final step and scale by desired temperature + logits = logits[:, -1, :] / temperature + # optionally crop the logits to only the top k options + if top_k is not None: + v, _ = torch.topk(logits, min(top_k, logits.size(-1))) + logits[logits < v[:, [-1]]] = -float("Inf") + # apply softmax to convert logits to (normalized) probabilities + probs = F.softmax(logits, dim=-1) + # sample from the distribution + idx_next = torch.multinomial(probs, num_samples=1) + # append sampled index to the running sequence and continue + idx = torch.cat((idx, idx_next), dim=1) + + return idx From e60ebeaca13d0a52361f92cccfbcb594e1886ca2 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Fri, 14 Jul 2023 18:19:02 -0400 Subject: [PATCH 31/96] Set output_grads correctly (#840) ## Issue In case there are multiple output values and one of them is loss, some user reported the following error: ``` output_grads[i] for i in outputs_with_grads_idxs IndexError: tuple index out of range ... RuntimeError: Failed to run backward stage stage_backward for stage %submod_7 : [ = call_module[target=submod_7](args = (%submod_6, %_inputs), kwargs = {}) Stage output: ('Tensor(torch.Size([100, 20, 4096]), grad=False)', 'Tensor(torch.Size([100, 4096]), grad=False)', 'Tensor(torch.Size([100, 4096]), grad=False)', 'Tensor(torch.Size([]), grad=True)', 'Tensor(torch.Size([100]), grad=False)', 'Tensor(torch.Size([100]), grad=False)') Output gradient: ('None',) Input: ['Tensor(torch.Size([100, 20, 4096]), grad=True)', 'Tensor(torch.Size([100, 20, 4096]), grad=False)', 'Tensor(torch.Size([100]), grad=False)', 'Tensor(torch.Size([100]), grad=False)'] ``` Note this part: `Output gradient: ('None',)` I can repro the issue in local_test_c10d_bwd.py, if I change the output to: ``` - return {"loss": loss} + return {"logits": x, "loss": loss} ``` ## Cause The above issue is caused by the fixed setting in the else case: ``` # (None,) is for `stage_backward` signature bwd_kwargs["output_grads"] = ( grads if len(grads) > 0 else (None,) ) ``` This tuple `(None,)` should have variable length depending on the output. ## Fix Only update `bwd_kwargs["output_grads"]` when we have actually received gradients; otherwise, use the tuple prepared during IR phase, i.e. `bwd_node.kwargs["output_grads"]`, which may look like `(None, None)` if there are two outputs. --- pippy/PipelineStage.py | 11 +++++++---- test/local_test_c10d_bwd.py | 4 ++-- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/pippy/PipelineStage.py b/pippy/PipelineStage.py index 53514ae5f..cb3a311a1 100644 --- a/pippy/PipelineStage.py +++ b/pippy/PipelineStage.py @@ -594,10 +594,13 @@ def backward_maybe_with_nosync(bwd_kwargs: Dict, is_last_chunk: bool): bwd_kwargs["stage_output"], bwd_kwargs["input_values"], ) = fwd_cache.pop(bwd_chunk) - # (None,) is for `stage_backward` signature - bwd_kwargs["output_grads"] = ( - grads if len(grads) > 0 else (None,) - ) + # Fill actual gradients received for outputs + # If nothing received, as in the case of last stage, then we + # would use the default `output_grads` prepared in the IR phase, + # i.e. from `bwd_node.kwargs`. For example, it may look like + # this if there are two outputs: ('None', 'None') + if len(grads) > 0: + bwd_kwargs["output_grads"] = grads # `stage_backward` node does not have `args`, only `kwargs` grads_input = backward_maybe_with_nosync( diff --git a/test/local_test_c10d_bwd.py b/test/local_test_c10d_bwd.py index 5cdd897e9..8bfd2a05c 100644 --- a/test/local_test_c10d_bwd.py +++ b/test/local_test_c10d_bwd.py @@ -39,7 +39,7 @@ def forward(self, x, target): x = self.lin(x) x = torch.relu(x) loss = self.mse_loss(x, target) - return {"loss": loss} + return {"logits": x, "loss": loss} def run_worker(args): @@ -74,7 +74,7 @@ def run_worker(args): # Last rank checks result if args.rank == args.world_size - 1: ref_out = ec(ec_x, target) - torch.testing.assert_close(out["loss"], ref_out["loss"]) + torch.testing.assert_close(out, ref_out) print( f"equivalence test passed, loss = {out['loss']}, ref loss = {ref_out['loss']}" ) From 9a7e1e92ded7fc070a108536253530bcd39135a2 Mon Sep 17 00:00:00 2001 From: eddogola <64967909+eddogola@users.noreply.github.com> Date: Fri, 14 Jul 2023 17:59:04 -0700 Subject: [PATCH 32/96] mkdir -p when creating ckpt dir (#841) ## Description Create full path, if not initially present, when saving checkpoints. ## Type of change - [x] New feature (non-breaking change which adds functionality) ## Checklist: - [x] Have you added tests that prove your fix is effective or that this feature works? - [x] Has code been commented, particularly in hard-to-understand areas? --------- Co-authored-by: Eddy --- examples/checkpoint/toy_model.py | 182 +++++++++++++++++++++++++++++++ pippy/hf/_SaveModule.py | 10 +- 2 files changed, 189 insertions(+), 3 deletions(-) create mode 100644 examples/checkpoint/toy_model.py diff --git a/examples/checkpoint/toy_model.py b/examples/checkpoint/toy_model.py new file mode 100644 index 000000000..b71fae11d --- /dev/null +++ b/examples/checkpoint/toy_model.py @@ -0,0 +1,182 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +import argparse +import os + +import torch +from torch.utils.data import Dataset, random_split +import torch.distributed as dist +import torch.optim as optim + +from pippy.hf._SaveModule import save_checkpoint +from pippy.compile import compile_stage +from pippy.IR import pipe_split + + +d_hid = 512 +chunk_size = 256 + +torch.manual_seed(0) + + +class RandomCustomDataset(Dataset): + def __init__(self, chunks=1, size=100): # TODO: reset size to 10000 + self.samples = [torch.randn(chunks * chunk_size, d_hid) for _ in range(size)] + self.targets = [torch.randn(chunks * chunk_size, d_hid) for _ in range(size)] + + def __len__(self): + return len(self.samples) + + def __getitem__(self, idx): + return self.samples[idx], self.targets[idx] + + +class ExampleCode(torch.nn.Module): + def __init__(self): + super().__init__() + self.mm_param = torch.nn.Parameter(torch.randn(d_hid, d_hid)) + self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) + self.lin = torch.nn.Linear(d_hid, d_hid) + self.mse_loss = torch.nn.MSELoss(reduction="sum") + + def forward(self, x, target): + x = torch.mm(x, self.mm_param) + skip_connection = x + x = torch.relu(x) + pipe_split() + x = torch.mm(x, self.mm_param) + x = self.lin(x) + pipe_split() + x = torch.relu(x) + x = x + skip_connection + x = torch.mm(x, self.mm_param2) + pipe_split() + x = self.lin(x) + x = torch.relu(x) + loss = self.mse_loss(x, target) + return {"loss": loss} + + +def run_worker(args): + ec = ExampleCode() + ec.to(args.device) + ec.train() + + ec_x = torch.randn(args.chunks * chunk_size, d_hid, device=args.device) + target = torch.randn(args.chunks * chunk_size, d_hid, device=args.device) + + ds = RandomCustomDataset(chunks=args.chunks) + train_size = int(0.7*len(ds)) + test_size = len(ds) - train_size + train_ds, test_ds = random_split(ds, [train_size, test_size]) + datasets = { + "train": train_ds, + "test": test_ds, + } + + stage = compile_stage( + ec, + args.rank, + args.world_size, + args.chunks, + args.device, + None, + [ec_x, target], + ) + + # Create an optimizer for stage submodule's parameters + optimizer = optim.SGD(stage.submod.parameters(), lr=1e-3, momentum=0.9) + + for epoch in range(args.epochs): # change to no. epochs + print(f"Epoch: {epoch + 1}") + + # save checkpoint + if (epoch + 1) % args.checkpoint_epochs == 0: # save ckpt after every `args.checkpoint_epochs` epochs + print("RUuuuuuuuuuuunnnnnnnnnninngggggggggg") + save_checkpoint(stage, checkpoint_dir=os.path.join("checkpoints", str(epoch + 1)), optimizer=optimizer) + print("Doooooooooooonnnnnnnnnnnnnneeeeeeeeeee") + + for k, dataset in datasets.items(): + epoch_correct = 0 + epoch_all = 0 + for i, (x, y) in enumerate(dataset): + x = x.to(args.device) + y = y.to(args.device) + if k == "train": + # Zero gradients + optimizer.zero_grad() + # Run + if args.rank == 0: + out = stage(x) + elif args.rank == args.world_size - 1: + out = stage(target) + out_tensor = out['loss'] + preds = out_tensor.argmax(-1) + correct = (preds == y).sum() + epoch_all += len(y) + epoch_correct += correct.item() + # Take an optimization step + optimizer.step() + else: + stage() + else: + stage.eval() + with torch.no_grad(): + if args.rank == 0: + out = stage(x) + elif args.rank == args.world_size - 1: + out = stage(x)['loss'] + preds = out.argmax(-1) + correct = (preds == y).sum() + epoch_all += len(y) + epoch_correct += correct.item() + else: + stage(x) + # print(f"Loader: {k}. Accuracy: {epoch_correct / epoch_all}") + + dist.barrier() + print(f"Rank {args.rank} completes") + + +def main(args=None): + parser = argparse.ArgumentParser() + parser.add_argument( + "--world_size", type=int, default=int(os.getenv("WORLD_SIZE", 4)) + ) + parser.add_argument("--rank", type=int, default=int(os.getenv("RANK", -1))) + parser.add_argument( + "--master_addr", type=str, default=os.getenv("MASTER_ADDR", "localhost") + ) + parser.add_argument( + "--master_port", type=str, default=os.getenv("MASTER_PORT", "29500") + ) + parser.add_argument( + "--cuda", type=int, default=int(torch.cuda.is_available()) + ) + parser.add_argument("--epochs", type=int, default=2) + parser.add_argument( + "--chunks", + type=int, + default=4, + ) + parser.add_argument("--checkpoint_epochs", type=int, default=1) + args = parser.parse_args(args) + + if args.cuda: + dev_id = args.rank % torch.cuda.device_count() + args.device = torch.device(f"cuda:{dev_id}") + else: + args.device = torch.device("cpu") + + # Init process group + backend = "nccl" if args.cuda else "gloo" + dist.init_process_group( + backend=backend, + rank=args.rank, + world_size=args.world_size, + ) + + run_worker(args) + + +if __name__ == "__main__": + main() diff --git a/pippy/hf/_SaveModule.py b/pippy/hf/_SaveModule.py index e046534a6..9110f3298 100644 --- a/pippy/hf/_SaveModule.py +++ b/pippy/hf/_SaveModule.py @@ -1,6 +1,7 @@ import json import logging import os +import pathlib import tempfile from itertools import chain @@ -103,7 +104,8 @@ def _save_index( # create checkpoint directory if it doesn't exist if not os.path.exists(checkpoint_dir): - os.mkdir(checkpoint_dir) + pathlib.Path(checkpoint_dir).mkdir(parents=True, exist_ok=True) + # os.mkdir(checkpoint_dir) # write index file atomically to avoid partial/corrupted writes _atomic_write(json_str, filepath) @@ -120,7 +122,8 @@ def _save_params(submod: torch.nn.Module, checkpoint_dir: str) -> None: checkpoint_dir(`str`): where to keep the checkpoint binaries """ if not os.path.exists(checkpoint_dir): - os.mkdir(checkpoint_dir) + pathlib.Path(checkpoint_dir).mkdir(parents=True, exist_ok=True) + # os.mkdir(checkpoint_dir) filepath = os.path.join( checkpoint_dir, _get_binary_filename(dist.get_rank()) ) @@ -144,7 +147,8 @@ def _save_optim_state( checkpoint_dir(`str`): where to keep the checkpoint binaries """ if not os.path.exists(checkpoint_dir): - os.mkdir(checkpoint_dir) + pathlib.Path(checkpoint_dir).mkdir(parents=True, exist_ok=True) + # os.mkdir(checkpoint_dir) filepath = os.path.join( checkpoint_dir, _get_binary_filename(dist.get_rank(), is_optim=True) ) From b9fc2f804291e2955f3c407592a934be79552d4d Mon Sep 17 00:00:00 2001 From: eddogola <64967909+eddogola@users.noreply.github.com> Date: Mon, 17 Jul 2023 16:48:56 -0700 Subject: [PATCH 33/96] setup model trainer (#837) ## Description Please read our [CONTRIBUTING.md](https://github.com/pytorch/PiPPy/blob/main/CONTRIBUTING.md) prior to creating your first pull request. Please include a summary of the feature or issue being fixed. Please also include relevant motivation and context. List any dependencies that are required for this change. Fixes #(issue) ## Type of change Please delete options that are not relevant. - [ ] Bug fix (non-breaking change which fixes an issue) - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) - [ ] New feature (non-breaking change which adds functionality) - [ ] This change requires a documentation update ## Feature/Issue validation/testing Please describe the Unit or Integration tests that you ran to verify your changes and relevant result summary. Provide instructions so it can be reproduced. Please also list any relevant details for your test configuration. - [ ] Test A Logs for Test A - [ ] Test B Logs for Test B ## Checklist: - [ ] Have you added tests that prove your fix is effective or that this feature works? - [ ] Has code been commented, particularly in hard-to-understand areas? - [ ] Have you made corresponding changes to the documentation? --------- Co-authored-by: Eddy Co-authored-by: Ke Wen Co-authored-by: Yeonju Ro --- examples/checkpoint/toy_model.py | 66 +++++++++++++++++++++++--------- 1 file changed, 48 insertions(+), 18 deletions(-) diff --git a/examples/checkpoint/toy_model.py b/examples/checkpoint/toy_model.py index b71fae11d..10a848ed7 100644 --- a/examples/checkpoint/toy_model.py +++ b/examples/checkpoint/toy_model.py @@ -3,13 +3,13 @@ import os import torch -from torch.utils.data import Dataset, random_split import torch.distributed as dist import torch.optim as optim +from pippy.compile import compile_stage from pippy.hf._SaveModule import save_checkpoint -from pippy.compile import compile_stage from pippy.IR import pipe_split +from torch.utils.data import Dataset, random_split d_hid = 512 @@ -19,9 +19,17 @@ class RandomCustomDataset(Dataset): + """ + Setup random inputs and outputs for a desired dataset size. + """ + def __init__(self, chunks=1, size=100): # TODO: reset size to 10000 - self.samples = [torch.randn(chunks * chunk_size, d_hid) for _ in range(size)] - self.targets = [torch.randn(chunks * chunk_size, d_hid) for _ in range(size)] + self.samples = [ + torch.randn(chunks * chunk_size, d_hid) for _ in range(size) + ] + self.targets = [ + torch.randn(chunks * chunk_size, d_hid) for _ in range(size) + ] def __len__(self): return len(self.samples) @@ -31,6 +39,13 @@ def __getitem__(self, idx): class ExampleCode(torch.nn.Module): + """ + A normal pytorch model(nn.Module) with a loss. + The loss is calculated in the `forward` function which lets pippy + automatically run a .backward(). Pippy handles this backward call + because of the nontrivial structure(FillDrain schedule) of the model pipeline. + """ + def __init__(self): super().__init__() self.mm_param = torch.nn.Parameter(torch.randn(d_hid, d_hid)) @@ -52,20 +67,24 @@ def forward(self, x, target): pipe_split() x = self.lin(x) x = torch.relu(x) - loss = self.mse_loss(x, target) + loss = self.mse_loss( + x, target + ) # loss called here in forward, triggers backward call return {"loss": loss} def run_worker(args): ec = ExampleCode() ec.to(args.device) - ec.train() + # ec.train() + # sample input and output for compile_stage func call ec_x = torch.randn(args.chunks * chunk_size, d_hid, device=args.device) target = torch.randn(args.chunks * chunk_size, d_hid, device=args.device) + # setup data ds = RandomCustomDataset(chunks=args.chunks) - train_size = int(0.7*len(ds)) + train_size = int(0.7 * len(ds)) test_size = len(ds) - train_size train_ds, test_ds = random_split(ds, [train_size, test_size]) datasets = { @@ -89,31 +108,37 @@ def run_worker(args): for epoch in range(args.epochs): # change to no. epochs print(f"Epoch: {epoch + 1}") - # save checkpoint - if (epoch + 1) % args.checkpoint_epochs == 0: # save ckpt after every `args.checkpoint_epochs` epochs - print("RUuuuuuuuuuuunnnnnnnnnninngggggggggg") - save_checkpoint(stage, checkpoint_dir=os.path.join("checkpoints", str(epoch + 1)), optimizer=optimizer) - print("Doooooooooooonnnnnnnnnnnnnneeeeeeeeeee") + # save checkpoints + if (epoch + 1) % args.checkpoint_epochs == 0: + save_checkpoint( + stage, + checkpoint_dir=os.path.join("checkpoints", f"{epoch + 1}"), + optimizer=optimizer, + ) for k, dataset in datasets.items(): epoch_correct = 0 epoch_all = 0 + for i, (x, y) in enumerate(dataset): x = x.to(args.device) y = y.to(args.device) + if k == "train": # Zero gradients optimizer.zero_grad() + # Run if args.rank == 0: out = stage(x) elif args.rank == args.world_size - 1: out = stage(target) - out_tensor = out['loss'] + out_tensor = out["loss"] preds = out_tensor.argmax(-1) correct = (preds == y).sum() - epoch_all += len(y) epoch_correct += correct.item() + epoch_all += len(y) + # Take an optimization step optimizer.step() else: @@ -124,11 +149,12 @@ def run_worker(args): if args.rank == 0: out = stage(x) elif args.rank == args.world_size - 1: - out = stage(x)['loss'] - preds = out.argmax(-1) + out = stage(x) + out_tensor = out["loss"] + preds = out_tensor.argmax(-1) correct = (preds == y).sum() - epoch_all += len(y) epoch_correct += correct.item() + epoch_all += len(y) else: stage(x) # print(f"Loader: {k}. Accuracy: {epoch_correct / epoch_all}") @@ -152,12 +178,16 @@ def main(args=None): parser.add_argument( "--cuda", type=int, default=int(torch.cuda.is_available()) ) - parser.add_argument("--epochs", type=int, default=2) parser.add_argument( "--chunks", type=int, default=4, ) + parser.add_argument( + "--epochs", + type=int, + default=2, + ) parser.add_argument("--checkpoint_epochs", type=int, default=1) args = parser.parse_args(args) From 98dc7704f047e10cff6b65789d4678900293fede Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Sat, 22 Jul 2023 23:13:01 -0400 Subject: [PATCH 34/96] Add profiling example based on MLPs (#843) Adding an example for us to evaluate PiPPy's trace. The model consists of 4 MLP layers. And is divided into 4 stages. To run: $ torchrun --nproc-per-node 4 mlp_profiling.py --- examples/profiling/mlp_profiling.py | 144 ++++++++++++++++++++++++++++ 1 file changed, 144 insertions(+) create mode 100644 examples/profiling/mlp_profiling.py diff --git a/examples/profiling/mlp_profiling.py b/examples/profiling/mlp_profiling.py new file mode 100644 index 000000000..bd1362add --- /dev/null +++ b/examples/profiling/mlp_profiling.py @@ -0,0 +1,144 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +# Run command: +# torchrun --nproc-per-node 4 mlp_profiling.py + +import argparse +import os + +import torch +import torch.distributed as dist +from torch.profiler import profile, ProfilerActivity + +from pippy.compile import compile_stage +from pippy.IR import pipe_split + + +d_hid = 1024 +chunk_size = 1024 + +torch.manual_seed(0) + + +class MLPModule(torch.nn.Module): + def __init__(self, d_hid): + super(MLPModule, self).__init__() + self.net1 = torch.nn.Linear(d_hid, d_hid) + self.relu = torch.nn.ReLU() + self.net2 = torch.nn.Linear(d_hid, d_hid) + + def forward(self, x): + x = self.net1(x) + x = self.relu(x) + x = self.net2(x) + return x + + +class ExampleCode(torch.nn.Module): + def __init__(self, d_hid): + super().__init__() + self.mlp0 = MLPModule(d_hid) + self.mlp1 = MLPModule(d_hid) + self.mlp2 = MLPModule(d_hid) + self.mlp3 = MLPModule(d_hid) + self.mse_loss = torch.nn.MSELoss(reduction="sum") + + def forward(self, x, target): + x = self.mlp0(x) + pipe_split() + x = self.mlp1(x) + pipe_split() + x = self.mlp2(x) + pipe_split() + x = self.mlp3(x) + loss = self.mse_loss(x, target) + return {"logits": x, "loss": loss} + + +def run_worker(args): + ec = ExampleCode(d_hid) + ec.to(args.device) + ec.train() + + ec_x = torch.randn(args.chunks * chunk_size, d_hid, device=args.device) + target = torch.randn(args.chunks * chunk_size, d_hid, device=args.device) + + stage = compile_stage( + ec, + args.rank, + args.world_size, + args.chunks, + args.device, + None, + [ec_x, target], + ) + + # Run + for _ in range(10): + if args.rank == 0: + out = stage(ec_x) + elif args.rank == args.world_size - 1: + out = stage(target) + else: + stage() + + dist.barrier() + print(f"Rank {args.rank} warmup completes") + + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + ) as prof: + for _ in range(20): + if args.rank == 0: + out = stage(ec_x) + elif args.rank == args.world_size - 1: + out = stage(target) + else: + stage() + + print(f"Rank {args.rank} profiling run completed") + prof.export_chrome_trace( + f"{os.path.splitext(os.path.basename(__file__))[0]}_{args.rank}.json" + ) + + +def main(args=None): + parser = argparse.ArgumentParser() + parser.add_argument( + "--world_size", type=int, default=int(os.getenv("WORLD_SIZE", 4)) + ) + parser.add_argument("--rank", type=int, default=int(os.getenv("RANK", -1))) + parser.add_argument( + "--master_addr", type=str, default=os.getenv("MASTER_ADDR", "localhost") + ) + parser.add_argument( + "--master_port", type=str, default=os.getenv("MASTER_PORT", "29500") + ) + parser.add_argument( + "--cuda", type=int, default=int(torch.cuda.is_available()) + ) + parser.add_argument( + "--chunks", + type=int, + default=4, + ) + args = parser.parse_args(args) + + if args.cuda: + dev_id = args.rank % torch.cuda.device_count() + args.device = torch.device(f"cuda:{dev_id}") + else: + args.device = torch.device("cpu") + + # Init process group + backend = "nccl" if args.cuda else "gloo" + dist.init_process_group( + backend=backend, + rank=args.rank, + world_size=args.world_size, + ) + + run_worker(args) + + +if __name__ == "__main__": + main() From 751b19df888d4b1f1adff256e6ddf158516b5e1c Mon Sep 17 00:00:00 2001 From: eddogola <64967909+eddogola@users.noreply.github.com> Date: Mon, 24 Jul 2023 10:51:53 -0700 Subject: [PATCH 35/96] mnist example update to new compile stage API (#831) ## Description Update the mnist example to use the new compile_stage api ## Checklist: - [x] Has code been commented, particularly in hard-to-understand areas? --------- Co-authored-by: Eddy Co-authored-by: Eddy Ogola Onyango --- examples/mnist/pippy_mnist.py | 289 ++++++++++++++++++++-------------- 1 file changed, 172 insertions(+), 117 deletions(-) diff --git a/examples/mnist/pippy_mnist.py b/examples/mnist/pippy_mnist.py index f1729190e..4a816f02f 100644 --- a/examples/mnist/pippy_mnist.py +++ b/examples/mnist/pippy_mnist.py @@ -1,70 +1,93 @@ # Copyright (c) Meta Platforms, Inc. and affiliates import argparse import os -from functools import reduce + +import pippy +import pippy.fx import torch +import torch.distributed as dist + +from pippy.hf._SaveModule import save_checkpoint +from pippy.IR import LossWrapper, PipeSplitWrapper +from pippy.microbatch import sum_reducer, TensorChunkSpec + from torch import nn, optim from torch.nn.functional import cross_entropy from torch.utils.data import DistributedSampler from torchvision import datasets, transforms # type: ignore from tqdm import tqdm # type: ignore -import pippy.fx -from pippy import run_pippy -from pippy.IR import MultiUseParameterConfig, Pipe, PipeSplitWrapper, LossWrapper -from pippy.PipelineDriver import PipelineDriverFillDrain, PipelineDriver1F1B, PipelineDriverInterleaved1F1B, \ - PipelineDriverBase -from pippy.events import EventsContext -from pippy.microbatch import sum_reducer, TensorChunkSpec -from pippy.visualizer import events_to_json - -PROFILING_ENABLED = True -CHECK_NUMERIC_EQUIVALENCE = True - -schedules = { - 'FillDrain': PipelineDriverFillDrain, - '1F1B': PipelineDriver1F1B, - 'Interleaved1F1B': PipelineDriverInterleaved1F1B, -} - -LR_VERBOSE = 0 pippy.fx.Tracer.proxy_buffer_attributes = True -USE_TQDM = bool(int(os.getenv('USE_TQDM', '1'))) - - -def resolve_pg_per_stage(pp_rank): - assert pippy.utils.dp_pg_per_pp_rank - return pippy.utils.dp_pg_per_pp_rank[pp_rank + 1] # exclude master - - -def run_master(pp_ranks, args): +USE_TQDM = bool(int(os.getenv("USE_TQDM", "1"))) + + +# Get process group for ranks in a pipeline +def get_pp_subgroup(args): + my_pp_rank = args.rank // args.dp_group_size + my_dp_rank = args.rank % args.dp_group_size + for dp_rank in range(0, args.dp_group_size): + pp_group_ranks = list( + range(dp_rank, args.world_size, args.dp_group_size) + ) + pp_group = dist.new_group(ranks=pp_group_ranks) + if dp_rank == my_dp_rank: + my_pp_group = pp_group + print(f"Rank {args.rank} done getting pp group") + return my_pp_group, my_pp_rank + + +# Get DP process group for ranks with the same stage +def get_dp_subgroup(args): + my_pp_rank = args.rank // args.dp_group_size + my_dp_rank = args.rank % args.dp_group_size + for pp_rank in range(0, args.pp_group_size): + dp_group_ranks = list( + range( + pp_rank * args.dp_group_size, (pp_rank + 1) * args.dp_group_size + ) + ) + dp_group = dist.new_group(ranks=dp_group_ranks) + if pp_rank == my_pp_rank: + my_dp_group = dp_group + print(f"Rank {args.rank} done getting dp group") + return my_dp_group, my_dp_rank + + +def run_worker(args): torch.manual_seed(42) - MULTI_USE_PARAM_CONFIG = MultiUseParameterConfig.REPLICATE if args.replicate else MultiUseParameterConfig.TRANSMIT - print(f'REPLICATE config: {args.replicate} -> {MULTI_USE_PARAM_CONFIG}') - print("Using schedule:", args.schedule) - print("Using device:", args.device) - number_of_workers = 3 - all_worker_ranks = pp_ranks[1:1 + number_of_workers] # exclude master - chunks = len(all_worker_ranks) - batch_size = args.batch_size * chunks + # Get DP and PP sub process groups + dp_group, dp_rank = get_dp_subgroup(args) + pp_group, pp_rank = get_pp_subgroup(args) - transform = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.1307,), (0.3081,)) - ]) + batch_size = args.batch_size * args.chunks - train_data = datasets.MNIST('./data', train=True, download=True, transform=transform) - valid_data = datasets.MNIST('./data', train=False, transform=transform) + transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] + ) - train_sampler = DistributedSampler(train_data, num_replicas=args.dp_group_size, rank=args.rank, shuffle=False, - drop_last=False) + train_data = datasets.MNIST( + "./data", train=True, download=True, transform=transform + ) + valid_data = datasets.MNIST("./data", train=False, transform=transform) + + train_sampler = DistributedSampler( + train_data, + num_replicas=args.dp_group_size, + rank=dp_rank, + shuffle=False, + drop_last=False, + ) - train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, sampler=train_sampler) - valid_dataloader = torch.utils.data.DataLoader(valid_data, batch_size=batch_size) + train_dataloader = torch.utils.data.DataLoader( + train_data, batch_size=batch_size, sampler=train_sampler + ) + valid_dataloader = torch.utils.data.DataLoader( + valid_data, batch_size=batch_size + ) class OutputLossWrapper(LossWrapper): def __init__(self, module, loss_fn): @@ -78,110 +101,142 @@ def forward(self, input, target): nn.Flatten(), nn.Linear(28 * 28, 128), nn.ReLU(), - PipeSplitWrapper(nn.Sequential( - nn.Linear(128, 64), - nn.ReLU(), - )), - PipeSplitWrapper(nn.Linear(64, 10)) + PipeSplitWrapper( + nn.Sequential( + nn.Linear(128, 64), + nn.ReLU(), + ) + ), + PipeSplitWrapper(nn.Linear(64, 10)), ) wrapper = OutputLossWrapper(model, cross_entropy) - pipe = Pipe.from_tracing(wrapper, MULTI_USE_PARAM_CONFIG, output_loss_value_spec=(False, True)) - pipe.to(args.device) + wrapper.to(args.device) output_chunk_spec = (TensorChunkSpec(0), sum_reducer) - pipe_driver: PipelineDriverBase = schedules[args.schedule](pipe, chunks, - world_size=len(all_worker_ranks), - all_ranks=all_worker_ranks, - output_chunk_spec=output_chunk_spec, - _record_mem_dumps=bool(args.record_mem_dumps), - checkpoint=bool(args.checkpoint)) - pipe_driver.init_data_parallel(dp_group_size=args.dp_group_size, dp_pg_cb=resolve_pg_per_stage) - - optimizer = pipe_driver.instantiate_optimizer(optim.Adam, lr=1e-3, betas=(0.9, 0.999), eps=1e-8) - lr_sched = pipe_driver.instantiate_lr_scheduler(optim.lr_scheduler.LinearLR, verbose=LR_VERBOSE) + # sample input + x = torch.randint(0, 5, (batch_size, 28, 28), device=args.device) + target = torch.randint(0, 9, (batch_size,), device=args.device) + + stage = pippy.compile_stage( + wrapper, + pp_rank, + args.pp_group_size, + args.chunks, + args.device, + pp_group, + [x, target], + output_chunk_spec=output_chunk_spec, + ) - loaders = { - "train": train_dataloader, - "valid": valid_dataloader - } + optimizer = optim.Adam( + stage.submod.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-8 + ) + # TODO: add back LR scheduler + # lr_sched = pipe_driver.instantiate_lr_scheduler(optim.lr_scheduler.LinearLR, verbose=LR_VERBOSE) - this_file_name = os.path.splitext(os.path.basename(__file__))[0] - pipe_visualized_filename = f"{this_file_name}_visualized_{args.rank}.json" - batches_events_contexts = [] + loaders = {"train": train_dataloader, "valid": valid_dataloader} for epoch in range(args.max_epochs): print(f"Epoch: {epoch + 1}") for k, dataloader in loaders.items(): epoch_correct = 0 epoch_all = 0 - for i, (x_batch, y_batch) in enumerate(tqdm(dataloader) if USE_TQDM else dataloader): + for i, (x_batch, y_batch) in enumerate( + tqdm(dataloader) if USE_TQDM else dataloader + ): x_batch = x_batch.to(args.device) y_batch = y_batch.to(args.device) if k == "train": - pipe_driver.train() + outp = None optimizer.zero_grad() - outp, _ = pipe_driver(x_batch, y_batch) - preds = outp.argmax(-1) - correct = (preds == y_batch).sum() - all = len(y_batch) - epoch_correct += correct.item() - epoch_all += all + if pp_rank == 0: + stage(x_batch) + elif pp_rank == args.pp_group_size - 1: + outp, _ = stage(y_batch) + else: + stage() optimizer.step() - else: - pipe_driver.eval() - with torch.no_grad(): - outp, _ = pipe_driver(x_batch, y_batch) + + if outp is not None: preds = outp.argmax(-1) correct = (preds == y_batch).sum() all = len(y_batch) epoch_correct += correct.item() epoch_all += all - if args.visualize: - batches_events_contexts.append(pipe_driver.retrieve_events()) - print(f"Loader: {k}. Accuracy: {epoch_correct / epoch_all}") + # save checkpoint - after training epoch + if (epoch + 1) % args.checkpoint_epochs == 0: + save_checkpoint( + stage, + checkpoint_dir=os.path.join( + "checkpoints", str(epoch + 1) + ), + optimizer=optimizer, + ) + else: + # TODO: add evaluation support in PiPPy + pass + + if pp_rank == args.pp_group_size - 1 and epoch_all > 0: + print(f"Loader: {k}. Accuracy: {epoch_correct / epoch_all}") - if k == "train": - lr_sched.step() - if LR_VERBOSE: - print(f"Pipe {pp_ranks} last_lr: {lr_sched.get_last_lr()}") - print(f"Pipe {pp_ranks} state_dict: {lr_sched.state_dict()}") + # if k == "train": + # lr_sched.step() - if args.visualize: - all_events_contexts: EventsContext = reduce(lambda c1, c2: EventsContext().update(c1).update(c2), - batches_events_contexts, EventsContext()) - with open(pipe_visualized_filename, "w") as f: - f.write(events_to_json(all_events_contexts)) - print(f"Saved {pipe_visualized_filename}") - print('Finished') + print("Finished") if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 4))) - parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1))) - parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost')) - parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500')) - - parser.add_argument('--max_epochs', type=int, default=10) - parser.add_argument('--batch_size', type=int, default=10) - - parser.add_argument('-s', '--schedule', type=str, default=list(schedules.keys())[0], choices=schedules.keys()) - parser.add_argument('--replicate', type=int, default=int(os.getenv("REPLICATE", '0'))) - parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available())) - parser.add_argument('--visualize', type=int, default=0, choices=[0, 1]) - parser.add_argument('--record_mem_dumps', type=int, default=0, choices=[0, 1]) - parser.add_argument('--checkpoint', type=int, default=0, choices=[0, 1]) - parser.add_argument('--exclude_master', type=int, default=0, choices=[0, 1]) - args = parser.parse_args() + parser.add_argument( + "--world_size", type=int, default=int(os.getenv("WORLD_SIZE", 3)) + ) + parser.add_argument("--rank", type=int, default=int(os.getenv("RANK", -1))) + parser.add_argument( + "--master_addr", type=str, default=os.getenv("MASTER_ADDR", "localhost") + ) + parser.add_argument( + "--master_port", type=str, default=os.getenv("MASTER_PORT", "29500") + ) - args.pp_group_size = 4 + parser.add_argument("--max_epochs", type=int, default=10) + parser.add_argument("--batch_size", type=int, default=10) - assert args.world_size % args.pp_group_size == 0 + parser.add_argument( + "--replicate", type=int, default=int(os.getenv("REPLICATE", "0")) + ) + parser.add_argument( + "--cuda", type=int, default=int(torch.cuda.is_available()) + ) + parser.add_argument("--visualize", type=int, default=0, choices=[0, 1]) + parser.add_argument("--checkpoint", type=int, default=0, choices=[0, 1]) + parser.add_argument("--checkpoint_epochs", type=int, default=5) + parser.add_argument( + "--chunks", + type=int, + default=4, + ) + args = parser.parse_args() + args.pp_group_size = 3 + assert args.world_size % args.pp_group_size == 0 args.dp_group_size = args.world_size // args.pp_group_size - run_pippy(run_master, args) + if args.cuda: + dev_id = args.rank % torch.cuda.device_count() + args.device = torch.device(f"cuda:{dev_id}") + else: + args.device = torch.device("cpu") + + # Init process group + backend = "nccl" if args.cuda else "gloo" + dist.init_process_group( + backend=backend, + rank=args.rank, + world_size=args.world_size, + ) + + run_worker(args) From cecc4fc4b015843076b688560c354e14eac2e7c1 Mon Sep 17 00:00:00 2001 From: eddogola <64967909+eddogola@users.noreply.github.com> Date: Thu, 27 Jul 2023 11:14:34 -0700 Subject: [PATCH 36/96] Create ckpt dir only once in _SaveModule (#845) ## Description I was initially checking whether the ckpt dir existed, and writing it, when writing the mod state_dict, index file, and optim state dicts. Here, I do that only once. ## Feature/Issue validation/testing The test/local_test_checkpoint.py test ensures things are running smoothly. ## Checklist: - [x] Has code been commented, particularly in hard-to-understand areas? --------- Co-authored-by: Eddy Ogola Onyango --- pippy/hf/_SaveModule.py | 15 ++++----------- test/local_test_checkpoint.py | 21 ++++++++++++--------- 2 files changed, 16 insertions(+), 20 deletions(-) diff --git a/pippy/hf/_SaveModule.py b/pippy/hf/_SaveModule.py index 9110f3298..966cc0b4a 100644 --- a/pippy/hf/_SaveModule.py +++ b/pippy/hf/_SaveModule.py @@ -102,11 +102,6 @@ def _save_index( filepath = os.path.join(checkpoint_dir, ckpt_index_filename) - # create checkpoint directory if it doesn't exist - if not os.path.exists(checkpoint_dir): - pathlib.Path(checkpoint_dir).mkdir(parents=True, exist_ok=True) - # os.mkdir(checkpoint_dir) - # write index file atomically to avoid partial/corrupted writes _atomic_write(json_str, filepath) @@ -121,9 +116,6 @@ def _save_params(submod: torch.nn.Module, checkpoint_dir: str) -> None: submod(`Pipe`): a submodule of the model's graph checkpoint_dir(`str`): where to keep the checkpoint binaries """ - if not os.path.exists(checkpoint_dir): - pathlib.Path(checkpoint_dir).mkdir(parents=True, exist_ok=True) - # os.mkdir(checkpoint_dir) filepath = os.path.join( checkpoint_dir, _get_binary_filename(dist.get_rank()) ) @@ -146,9 +138,6 @@ def _save_optim_state( optimizer(`torch.optim.Optimizer`): pytorch optimizer checkpoint_dir(`str`): where to keep the checkpoint binaries """ - if not os.path.exists(checkpoint_dir): - pathlib.Path(checkpoint_dir).mkdir(parents=True, exist_ok=True) - # os.mkdir(checkpoint_dir) filepath = os.path.join( checkpoint_dir, _get_binary_filename(dist.get_rank(), is_optim=True) ) @@ -171,6 +160,10 @@ def save_checkpoint( defaults to `checkpoints` optimizer(`torch.optim.Optimizer`): optimizer whose state dict is to be saved """ + # create checkpoint directory if it doesn't exist + if not os.path.exists(checkpoint_dir): + pathlib.Path(checkpoint_dir).mkdir(parents=True, exist_ok=True) + # write index file in rank 0 if dist.get_rank() == 0: _save_index(stage, checkpoint_dir=checkpoint_dir) diff --git a/test/local_test_checkpoint.py b/test/local_test_checkpoint.py index dbfb7a98a..ca173c2ae 100644 --- a/test/local_test_checkpoint.py +++ b/test/local_test_checkpoint.py @@ -98,6 +98,7 @@ def run_worker(args: List[str | int]) -> None: optimizer.step() ref_state_dict = deepcopy(stage.submod.state_dict()) ref_optim_state_dict = deepcopy(optimizer.state_dict()) + save_checkpoint(stage, CKPT_DIR, optimizer) # save index file in rank 0 @@ -135,15 +136,17 @@ def run_worker(args: List[str | int]) -> None: optimizer.step() # new api - mod, optimizer = load_checkpoint( - stage.submod, - os.path.join(CKPT_DIR, "pytorch_model.bin.index.json"), - args.device, - optim=optimizer, - ) - - torch.testing.assert_close(mod.state_dict(), submod_ref) - torch.testing.assert_close(optimizer.state_dict(), optim_ref) + # after index file has been written, load_checkpoint will read it + if os.path.exists(os.path.join(CKPT_DIR, DEFAULT_FILENAME)): + mod, optimizer = load_checkpoint( + stage.submod, + os.path.join(CKPT_DIR, DEFAULT_FILENAME), + optim=optimizer, + device=args.device, + ) + + torch.testing.assert_close(mod.state_dict(), ref_state_dict) + torch.testing.assert_close(optimizer.state_dict(), ref_optim_state_dict) dist.barrier() print(f"Rank {args.rank} completes") From 8489afd5608ba703592f3bb763aa8c0a1dcfba93 Mon Sep 17 00:00:00 2001 From: Yeonju Ro Date: Sun, 30 Jul 2023 18:03:54 -1000 Subject: [PATCH 37/96] profiler for example selective-tp (#847) ## Type of change profiler added ## Checklist: - [v] Has code been commented, particularly in hard-to-understand areas? --- examples/selective2d/2d_train.py | 52 ++++++++++++++++++++------------ 1 file changed, 33 insertions(+), 19 deletions(-) diff --git a/examples/selective2d/2d_train.py b/examples/selective2d/2d_train.py index 175a54c85..4a416c7d6 100644 --- a/examples/selective2d/2d_train.py +++ b/examples/selective2d/2d_train.py @@ -25,6 +25,8 @@ RowwiseParallel, ) +from torch.profiler import profile, ProfilerActivity + def get_args(): # default config values designed to train a gpt2 (124M) on OpenWebText @@ -125,6 +127,7 @@ def str_to_bool(v): ) parser.add_argument("--tp_size", type=int, default=2) parser.add_argument("--pp_size", type=int, default=2) + parser.add_argument("--n_chunks", type=int, default=2) parser.add_argument("--debug", dest="debug", action="store_true") @@ -223,7 +226,7 @@ def pp(model, pp_device_mesh, args): def pp_and_tp(model, mesh, args): """ Apply TP and PP to all layers in a model - This function assumes the model is already cut manually + This function assumes the model is already cut manually """ pp_dim, tp_dim = 0, 1 pp_rank, tp_rank = args.rank // args.tp_size, args.rank % args.tp_size @@ -239,7 +242,7 @@ def pp_and_tp(model, mesh, args): model, pp_rank, args.world_size, - args.pp_size, + args.n_chunks, args.device, pp_groups, example_inputs=[X, Y], @@ -308,7 +311,7 @@ def pp_and_tp_selective( model, pp_rank, args.world_size, - args.pp_size, + args.n_chunks, args.device, pp_groups, example_inputs=[X, Y], @@ -328,21 +331,30 @@ def pp_tp_train(stage, mesh, args): ) local_iter_num = 0 iter_time = 0.0 - while local_iter_num < train_iters: - optimizer.zero_grad() - t0 = time.perf_counter() - X, Y = get_rand(args) - if pp_rank == 0: - out = stage(X) - elif pp_rank == args.pp_size - 1: - out = stage(Y) - else: - out = stage() - optimizer.step() - t1 = time.perf_counter() - dt = t1 - t0 - local_iter_num += 1 - iter_time += dt + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + schedule=torch.profiler.schedule( + skip_first=5, wait=0, warmup=4, active=1, repeat=1 + ), + ) as prof: + while local_iter_num < train_iters: + optimizer.zero_grad() + t0 = time.perf_counter() + X, Y = get_rand(args) + if pp_rank == 0: + out = stage(X) + elif pp_rank == args.pp_size - 1: + out = stage(Y) + else: + out = stage() + optimizer.step() + t1 = time.perf_counter() + dt = t1 - t0 + local_iter_num += 1 + iter_time += dt + prof.step() + + prof.export_chrome_trace(f"trace_rank{args.rank}.json") return local_iter_num, iter_time @@ -457,7 +469,9 @@ def tp_train(): # model = tp(model, args.n_layer, oned_mesh) # model, stage = pp(model, oned_mesh, args) # model, stage = pp_and_tp(model, twod_mesh, args) - model, stage = pp_and_tp_selective(model, twod_mesh, args) + model, stage = pp_and_tp_selective( + model, twod_mesh, args, cut_fn=after_ar_cut + ) # iter_count, iter_time = pp_train(stage, args) iter_count, iter_time = pp_tp_train(stage, twod_mesh, args) From a1ee78d2521c1679e439cc2e8388765c87915182 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Mon, 31 Jul 2023 11:19:25 -0400 Subject: [PATCH 38/96] Add 1F1B schedule (#844) ## Description c10d version 1F1B schedule. We use `nstages` chunks to warm up the pipeline, then we enter a 1F1B stable phase, and lastly use `nstages` to cool down the pipeline. Ref: https://arxiv.org/pdf/2104.04473.pdf. ## API ``` stage = compile_stage( model, ... , schedule="1F1B", ) ``` ## Implementation Details To avoid writing duplicated code, we modularize the original code in FillDrain implementation into `forward_one_chunk` and `backward_one_chunk`. Then the two different schedules can share these modular code and just focus on the schedule itself. ## Test Added "schedule" option to test: ``` $ torchrun --nproc-per-node 4 local_test_c10d_bwd.py --schedule=1F1B --chunks=16 ``` --- pippy/PipelineStage.py | 291 +++++++++++++++++++++++++----------- pippy/compile.py | 40 +++-- test/local_test_c10d_bwd.py | 12 ++ 3 files changed, 243 insertions(+), 100 deletions(-) diff --git a/pippy/PipelineStage.py b/pippy/PipelineStage.py index cb3a311a1..f7c525d56 100644 --- a/pippy/PipelineStage.py +++ b/pippy/PipelineStage.py @@ -485,37 +485,110 @@ def _send_grads( return grad_send_reqs - def forward(self, *args, **kwargs): + def forward_maybe_with_nosync(self, *args, **kwargs): # If submod is wrapped with DDP, we use the `no_sync` context manager to # avoid gradient all-reduce per microbatch - def forward_maybe_with_nosync(*args, **kwargs): - if isinstance(self.submod, DistributedDataParallel): - with self.submod.no_sync(): # type: ignore[operator] - out_val = self.submod(*args, **kwargs) - else: + if isinstance(self.submod, DistributedDataParallel): + with self.submod.no_sync(): # type: ignore[operator] out_val = self.submod(*args, **kwargs) - return out_val - - def backward_maybe_with_nosync(bwd_kwargs: Dict, is_last_chunk: bool): - if isinstance(self.submod, DistributedDataParallel): - if is_last_chunk: - # HACK: reaching into DDP implementation details here. Is there a better way? - self.submod.reducer.prepare_for_backward( # type: ignore[union-attr, operator] - list( - torch.nn.parallel.distributed._find_tensors( # type: ignore[attr-defined] - bwd_kwargs["stage_output"] - ) + else: + out_val = self.submod(*args, **kwargs) + return out_val + + def backward_maybe_with_nosync(self, bwd_kwargs: Dict, is_last_chunk: bool): + if isinstance(self.submod, DistributedDataParallel): + if is_last_chunk: + # HACK: reaching into DDP implementation details here. Is there a better way? + self.submod.reducer.prepare_for_backward( # type: ignore[union-attr, operator] + list( + torch.nn.parallel.distributed._find_tensors( # type: ignore[attr-defined] + bwd_kwargs["stage_output"] ) ) - grads_input, _ = stage_backward(**bwd_kwargs) - else: - with self.submod.no_sync(): # type: ignore[operator] - grads_input, _ = stage_backward(**bwd_kwargs) - else: - # Non-DDP submodule, regular backward + ) grads_input, _ = stage_backward(**bwd_kwargs) - return grads_input + else: + with self.submod.no_sync(): # type: ignore[operator] + grads_input, _ = stage_backward(**bwd_kwargs) + else: + # Non-DDP submodule, regular backward + grads_input, _ = stage_backward(**bwd_kwargs) + return grads_input + def forward_one_chunk( + self, + chunk: int, + args_split, + kwargs_split, + fwd_cache: Dict[int, Any], + ): + composite_args, composite_kwargs = self._recv_and_fill_inputs( + chunk, + args_split, + kwargs_split, + ) + + # Compute forward + try: + output = self.forward_maybe_with_nosync( + *composite_args, **composite_kwargs + ) + + except Exception as e: + exc_msg = f""" + Rank {self.rank} failed to run forward stage: {self.name} + args: {map_debug_info(composite_args)} + kwargs: {map_debug_info(composite_kwargs)} + """ + raise RuntimeError(exc_msg) from e + + # Unify output form to tuple for easy correspondance with + # `act_send_info` + output_tuple = output if type(output) is tuple else (output,) + send_reqs = self._send_activations(output_tuple) + + # Save activations and inputs for backward + flat_args = flatten_args(composite_args) + flat_kwargs = flatten_args(composite_kwargs) + flatten_input_tensors = flat_args + flat_kwargs + fwd_cache[chunk] = ( + output_tuple, # stage_output + flatten_input_tensors, # input_values + ) + + return output, send_reqs + + def backward_one_chunk( + self, + bwd_chunk: int, + fwd_cache: Dict[int, Any], + ): + grads = self._recv_grads(bwd_chunk) + + # Pack args for `stage_backward`` + bwd_kwargs = dict(self.bwd_node.kwargs) + ( + bwd_kwargs["stage_output"], + bwd_kwargs["input_values"], + ) = fwd_cache.pop(bwd_chunk) + # Fill actual gradients received for outputs + # If nothing received, as in the case of last stage, then we + # would use the default `output_grads` prepared in the IR phase, + # i.e. from `bwd_node.kwargs`. For example, it may look like + # this if there are two outputs: ('None', 'None') + if len(grads) > 0: + bwd_kwargs["output_grads"] = grads + + # `stage_backward` node does not have `args`, only `kwargs` + grads_input = self.backward_maybe_with_nosync( + bwd_kwargs, + bwd_chunk == self.chunks - 1, + ) + + grad_send_reqs = self._send_grads(grads_input) + return grad_send_reqs + + def forward(self, *args, **kwargs): # map microbatch ID to list of forward tensor args fwd_cache: Dict[int, Tuple[Any, List[torch.Tensor]]] = {} @@ -535,86 +608,130 @@ def backward_maybe_with_nosync(bwd_kwargs: Dict, is_last_chunk: bool): output_chunks = [] + # Forward pass of all chunks for chunk in range(self.chunks): - composite_args, composite_kwargs = self._recv_and_fill_inputs( - chunk, - args_split, - kwargs_split, + output, send_reqs = self.forward_one_chunk( + chunk, args_split, kwargs_split, fwd_cache ) + all_send_reqs += send_reqs + # Prepare for final output merge or reduction + output_chunks.append(output) - # Compute forward - try: - output = forward_maybe_with_nosync( - *composite_args, **composite_kwargs - ) + # Wait for all sends to finish + # TODO: okay to delay the sync till completion of all chunks? + for work in all_send_reqs: + work.wait() + + # Backward starts here + # Grad send requests of all chunk + all_grad_send_reqs: List[dist.Work] = [] + + for bwd_chunk in range(self.chunks): + if self.pipe.has_loss_and_backwards: + grad_send_reqs = self.backward_one_chunk(bwd_chunk, fwd_cache) + all_grad_send_reqs += grad_send_reqs + + # Wait for all sends to finish + # TODO: okay to delay the sync till completion of all chunks? + for work in all_grad_send_reqs: + work.wait() + + # Last rank return merged results per original format + if self.rank == self.nstages - 1: + return merge_chunks( + output_chunks, + self.output_chunk_spec, + ) + else: + return None - except Exception as e: - exc_msg = f""" - Rank {self.rank} failed to run forward stage: {self.name} - args: {map_debug_info(composite_args)} - kwargs: {map_debug_info(composite_kwargs)} - """ - raise RuntimeError(exc_msg) from e - # Unify output form to tuple for easy correspondance with - # `act_send_info` - output_tuple = output if type(output) is tuple else (output,) +class PipelineStage1F1B(PipelineStage): + def __init__( + self, + pipe: Pipe, + rank: int, + nstages: int, + chunks: int, + device: torch.device, + group: dist.ProcessGroup = None, + args_chunk_spec=None, + kwargs_chunk_spec=None, + output_chunk_spec=None, + ): + super().__init__( + pipe, + rank, + nstages, + chunks, + device, + group=group, + args_chunk_spec=args_chunk_spec, + kwargs_chunk_spec=kwargs_chunk_spec, + output_chunk_spec=output_chunk_spec, + ) + + def forward(self, *args, **kwargs): + # map microbatch ID to list of forward tensor args + fwd_cache: Dict[int, Tuple[Any, List[torch.Tensor]]] = {} - send_reqs = self._send_activations(output_tuple) + args_split = None + kwargs_split = None + if args or kwargs: + args_split, kwargs_split = split_args_kwargs_into_chunks( + args, + kwargs, + self.chunks, + self.args_chunk_spec, + self.kwargs_chunk_spec, + ) + + # Activation send requests of all chunk + all_send_reqs: List[dist.Work] = [] + # Grad send requests of all chunk + all_grad_send_reqs: List[dist.Work] = [] + # Caching chunk outputs for final output merge or reduction + output_chunks = [] + + warmup_chunks = cooldown_chunks = self.nstages + + # Warm-up phase: forward number of chunks equal to pipeline depth. + for chunk in range(warmup_chunks): + output, send_reqs = self.forward_one_chunk( + chunk, args_split, kwargs_split, fwd_cache + ) all_send_reqs += send_reqs + output_chunks.append(output) + + # 1F1B phase + for bwd_chunk in range(0, self.chunks - cooldown_chunks): + # Schedule backward for one warmed up chunk + if self.pipe.has_loss_and_backwards: + grad_send_reqs = self.backward_one_chunk(bwd_chunk, fwd_cache) + all_grad_send_reqs += grad_send_reqs + # Schedule forward for one new chunk + fwd_chunk = bwd_chunk + warmup_chunks + output, send_reqs = self.forward_one_chunk( + fwd_chunk, args_split, kwargs_split, fwd_cache + ) + all_send_reqs += send_reqs # Prepare for final output merge or reduction output_chunks.append(output) - # Save activations and inputs for backward - flat_args = flatten_args(composite_args) - flat_kwargs = flatten_args(composite_kwargs) - flatten_input_tensors = flat_args + flat_kwargs - fwd_cache[chunk] = ( - output_tuple, # stage_output - flatten_input_tensors, # input_values - ) + # Cool-down phase: backward for the rest of the chunks + for bwd_chunk in range(self.chunks - cooldown_chunks, self.chunks): + if self.pipe.has_loss_and_backwards: + grad_send_reqs = self.backward_one_chunk(bwd_chunk, fwd_cache) + all_grad_send_reqs += grad_send_reqs # Wait for all sends to finish # TODO: okay to delay the sync till completion of all chunks? for work in all_send_reqs: work.wait() - if self.pipe.has_loss_and_backwards: - # Backward starts here - # Grad send requests of all chunk - all_grad_send_reqs: List[dist.Work] = [] - - for bwd_chunk in range(self.chunks): - grads = self._recv_grads(bwd_chunk) - - # Pack args for `stage_backward`` - bwd_kwargs = dict(self.bwd_node.kwargs) - ( - bwd_kwargs["stage_output"], - bwd_kwargs["input_values"], - ) = fwd_cache.pop(bwd_chunk) - # Fill actual gradients received for outputs - # If nothing received, as in the case of last stage, then we - # would use the default `output_grads` prepared in the IR phase, - # i.e. from `bwd_node.kwargs`. For example, it may look like - # this if there are two outputs: ('None', 'None') - if len(grads) > 0: - bwd_kwargs["output_grads"] = grads - - # `stage_backward` node does not have `args`, only `kwargs` - grads_input = backward_maybe_with_nosync( - bwd_kwargs, - bwd_chunk == self.chunks - 1, - ) - - grad_send_reqs = self._send_grads(grads_input) - all_grad_send_reqs += grad_send_reqs - - # Wait for all sends to finish - # TODO: okay to delay the sync till completion of all chunks? - for work in all_grad_send_reqs: - work.wait() + for work in all_grad_send_reqs: + work.wait() # Last rank return merged results per original format if self.rank == self.nstages - 1: diff --git a/pippy/compile.py b/pippy/compile.py index 3c3564037..df3bd0d3f 100644 --- a/pippy/compile.py +++ b/pippy/compile.py @@ -21,7 +21,7 @@ PipelineDriverFillDrain, PipelineDriverInterleaved1F1B, ) -from pippy.PipelineStage import PipelineStage +from pippy.PipelineStage import PipelineStage, PipelineStage1F1B from pippy.utils import get_device, get_pp_rank, get_rank @@ -231,6 +231,7 @@ def compile_stage( args_chunk_spec=None, kwargs_chunk_spec=None, output_chunk_spec=None, + schedule="FillDrain", **kwargs, ) -> PipelineStage: # If a param will be used in multiple pipeline stages, we default the strategy to REPLICATE'ing the param across @@ -293,15 +294,28 @@ def compile_stage( else gen_output_chunk_spec(pipe.loss_spec, loss_reducer) ) - # Create pipeline stage - return PipelineStage( - pipe, - rank, - num_ranks, - num_chunks, - device, - group, - args_chunk_spec, - kwargs_chunk_spec, - output_chunk_spec, - ) + # Create pipeline stage based on schedule + if schedule == "1F1B": + return PipelineStage1F1B( + pipe, + rank, + num_ranks, + num_chunks, + device, + group, + args_chunk_spec, + kwargs_chunk_spec, + output_chunk_spec, + ) + else: + return PipelineStage( + pipe, + rank, + num_ranks, + num_chunks, + device, + group, + args_chunk_spec, + kwargs_chunk_spec, + output_chunk_spec, + ) diff --git a/test/local_test_c10d_bwd.py b/test/local_test_c10d_bwd.py index 8bfd2a05c..21db17240 100644 --- a/test/local_test_c10d_bwd.py +++ b/test/local_test_c10d_bwd.py @@ -10,6 +10,11 @@ from pippy.IR import pipe_split +schedules = [ + "FillDrain", + "1F1B", +] + d_hid = 512 chunk_size = 256 @@ -58,6 +63,7 @@ def run_worker(args): args.device, None, [ec_x, target], + schedule=args.schedule, ) # Run @@ -100,6 +106,12 @@ def main(args=None): type=int, default=4, ) + parser.add_argument( + "--schedule", + type=str, + default="FillDrain", + choices=schedules, + ) args = parser.parse_args(args) if args.cuda: From 95c14198761513bde583ce583c3e97f2c96c3e9f Mon Sep 17 00:00:00 2001 From: Yeonju Ro Date: Mon, 31 Jul 2023 09:33:55 -1000 Subject: [PATCH 39/96] adding dimension solver submodule (#848) ## Description add dimension solver submodule for selective2d example and training script that uses dim solver ## Checklist: - [v] Have you added tests that prove your fix is effective or that this feature works? - [v] Has code been commented, particularly in hard-to-understand areas? - [v] Have you made corresponding changes to the documentation? --- .gitmodules | 3 +++ examples/selective2d/2d_train.py | 1 + examples/selective2d/dim_solver | 1 + examples/selective2d/run.sh | 20 ++++++++++++++++++++ 4 files changed, 25 insertions(+) create mode 160000 examples/selective2d/dim_solver create mode 100644 examples/selective2d/run.sh diff --git a/.gitmodules b/.gitmodules index 9778ca341..3161f1f29 100644 --- a/.gitmodules +++ b/.gitmodules @@ -4,3 +4,6 @@ [submodule "docs/src/pytorch-sphinx-theme"] path = docs/src/pytorch-sphinx-theme url = https://github.com/pytorch/pytorch_sphinx_theme.git +[submodule "examples/selective2d/dim_solver"] + path = examples/selective2d/dim_solver + url = https://github.com/moonbucks/dim_solver.git diff --git a/examples/selective2d/2d_train.py b/examples/selective2d/2d_train.py index 4a416c7d6..2540e1228 100644 --- a/examples/selective2d/2d_train.py +++ b/examples/selective2d/2d_train.py @@ -127,6 +127,7 @@ def str_to_bool(v): ) parser.add_argument("--tp_size", type=int, default=2) parser.add_argument("--pp_size", type=int, default=2) + parser.add_argument("--i_stage", type=int, default=1) parser.add_argument("--n_chunks", type=int, default=2) parser.add_argument("--debug", dest="debug", action="store_true") diff --git a/examples/selective2d/dim_solver b/examples/selective2d/dim_solver new file mode 160000 index 000000000..342e002e5 --- /dev/null +++ b/examples/selective2d/dim_solver @@ -0,0 +1 @@ +Subproject commit 342e002e5e9c6d0e593ed0738e7dd7bb8da51a6e diff --git a/examples/selective2d/run.sh b/examples/selective2d/run.sh new file mode 100644 index 000000000..0f713af2f --- /dev/null +++ b/examples/selective2d/run.sh @@ -0,0 +1,20 @@ +#!/bin/bash +NUM_GPUS=8 + +# 1. run dimension solver +python3 ./dim_solver/main.py --num-gpu $NUM_GPUS --opt-ar1 + +# 2. process output file +# output file name: solver.out +# output data: pp_group_size tp_group_size size_microbatch i_stage n_chunks + +pp_group_size=$(cat ./solver.out | awk '{print $1}') +tp_group_size=$(cat ./solver.out | awk '{print $2}') +size_microbatch=$(cat ./solver.out | awk '{print $3}') +i_stage=$(cat ./solver.out | awk '{print $4}') +n_chunks=$(cat ./solver.out | awk '{print $5}') + +batch_size=$((size_microbatch * n_chunks)) + +# 3. run training with optimal configuration +torchrun --nproc-per-node=$NUM_GPUS 2d_train.py --batch_size $size_microbatch --n_chunks $n_chunks --i_stage $i_stage From b7f60b831c74d53149350ddaf53337ea2e485831 Mon Sep 17 00:00:00 2001 From: eddogola <64967909+eddogola@users.noreply.github.com> Date: Tue, 1 Aug 2023 16:33:01 -0700 Subject: [PATCH 40/96] Make SaveModule.py public (#850) ## Description Make SaveModule.py module public and place it at the same level as LoadModule.py ## Checklist: - [x] Have you added tests that prove your fix is effective or that this feature works? - [x] Has code been commented, particularly in hard-to-understand areas? --- examples/checkpoint/toy_model.py | 2 +- examples/mnist/pippy_mnist.py | 2 +- pippy/{hf/_SaveModule.py => SaveModule.py} | 0 test/local_test_checkpoint.py | 3 ++- 4 files changed, 4 insertions(+), 3 deletions(-) rename pippy/{hf/_SaveModule.py => SaveModule.py} (100%) diff --git a/examples/checkpoint/toy_model.py b/examples/checkpoint/toy_model.py index 10a848ed7..b08ca31fc 100644 --- a/examples/checkpoint/toy_model.py +++ b/examples/checkpoint/toy_model.py @@ -7,7 +7,7 @@ import torch.optim as optim from pippy.compile import compile_stage -from pippy.hf._SaveModule import save_checkpoint +from pippy.SaveModule import save_checkpoint from pippy.IR import pipe_split from torch.utils.data import Dataset, random_split diff --git a/examples/mnist/pippy_mnist.py b/examples/mnist/pippy_mnist.py index 4a816f02f..852c30c78 100644 --- a/examples/mnist/pippy_mnist.py +++ b/examples/mnist/pippy_mnist.py @@ -8,7 +8,7 @@ import torch import torch.distributed as dist -from pippy.hf._SaveModule import save_checkpoint +from pippy.SaveModule import save_checkpoint from pippy.IR import LossWrapper, PipeSplitWrapper from pippy.microbatch import sum_reducer, TensorChunkSpec diff --git a/pippy/hf/_SaveModule.py b/pippy/SaveModule.py similarity index 100% rename from pippy/hf/_SaveModule.py rename to pippy/SaveModule.py diff --git a/test/local_test_checkpoint.py b/test/local_test_checkpoint.py index ca173c2ae..cdeb81c46 100644 --- a/test/local_test_checkpoint.py +++ b/test/local_test_checkpoint.py @@ -12,10 +12,11 @@ import torch.optim as optim from pippy.compile import compile_stage -from pippy.hf._SaveModule import save_checkpoint from pippy.IR import pipe_split, TrivialLossWrapper from pippy.LoadModule import load_checkpoint +from pippy.SaveModule import save_checkpoint + DEFAULT_FILENAME = "pytorch_model.bin.index.json" CKPT_DIR = "test_ckpts" From 491e495e710c3a9fc6e6310edefe6a32e4dc8cae Mon Sep 17 00:00:00 2001 From: Yeonju Ro Date: Wed, 2 Aug 2023 13:35:30 -0700 Subject: [PATCH 41/96] 2d training bug fixes (#853) ## Description Fixes (1) tp size set wrong (2) specified output chunk spec explicitly ## Type of change - [v] Bug fix (non-breaking change which fixes an issue) ## Feature/Issue validation/testing - [v] Local test with command torchrun --nproc-per-node 8 2d_train.py --pp_size 4 --tp_size 2 --debug passed ## Checklist: - [v] Have you added tests that prove your fix is effective or that this feature works? - [v] Has code been commented, particularly in hard-to-understand areas? - [v] Have you made corresponding changes to the documentation? --- examples/selective2d/2d_train.py | 6 +++++- examples/selective2d/model.py | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/examples/selective2d/2d_train.py b/examples/selective2d/2d_train.py index 2540e1228..c20af5938 100644 --- a/examples/selective2d/2d_train.py +++ b/examples/selective2d/2d_train.py @@ -239,6 +239,7 @@ def pp_and_tp(model, mesh, args): X, Y = get_rand(args) # PP + output_chunk_spec = (TensorChunkSpec(0), sum_reducer) stage = compile_stage( model, pp_rank, @@ -247,6 +248,7 @@ def pp_and_tp(model, mesh, args): args.device, pp_groups, example_inputs=[X, Y], + output_chunk_spec=output_chunk_spec, ) return model, stage @@ -308,14 +310,16 @@ def pp_and_tp_selective( # PP cut_fn(model, args, args.pp_size) + output_chunk_spec = (TensorChunkSpec(0), sum_reducer) stage = compile_stage( model, pp_rank, - args.world_size, + args.pp_size, args.n_chunks, args.device, pp_groups, example_inputs=[X, Y], + output_chunk_spec=output_chunk_spec, ) return model, stage diff --git a/examples/selective2d/model.py b/examples/selective2d/model.py index 3bfd07d25..e202c45b5 100644 --- a/examples/selective2d/model.py +++ b/examples/selective2d/model.py @@ -54,7 +54,7 @@ def forward(self, input): class CausalSelfAttention(nn.Module): def __init__(self, mesh, config): super().__init__() - tp_size = mesh.mesh.size(0) + tp_size = mesh.mesh.size(1) assert config.n_head % tp_size == 0 assert config.n_embd % config.n_head == 0 self.mesh = mesh From 15dfcd8ea1f445a30627693334ac0b9be160d01b Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Tue, 15 Aug 2023 09:02:35 -0400 Subject: [PATCH 42/96] Modularize microbatch fwd bwd (#861) ## Description 1. Modularize code into `forward_one_chunk` and `backward_one_chunk`, for the purpose of easy schedule implementation. 2. Rename `rank` in `PipelineStage` to `stage_index`, as stage is a virtual concept. For example, in Interleaved 1F1B, one rank can hold multiple stages. --- pippy/PipelineStage.py | 274 +++++++++++++++++++++-------------------- pippy/compile.py | 30 ++--- setup.py | 1 - 3 files changed, 153 insertions(+), 152 deletions(-) diff --git a/pippy/PipelineStage.py b/pippy/PipelineStage.py index f7c525d56..5f5293af2 100644 --- a/pippy/PipelineStage.py +++ b/pippy/PipelineStage.py @@ -50,7 +50,7 @@ class PipelineStage(torch.nn.Module): def __init__( self, pipe: Pipe, - rank: int, + stage_index: int, nstages: int, chunks: int, device: torch.device, @@ -61,7 +61,7 @@ def __init__( ): super().__init__() self.pipe = pipe - self.rank = rank + self.stage_index = stage_index self.nstages = nstages self.chunks = chunks self.device = device @@ -70,12 +70,28 @@ def __init__( self.kwargs_chunk_spec = kwargs_chunk_spec self.output_chunk_spec = output_chunk_spec + # `group_rank` is rank in process group `group`. + self.group_rank = dist.get_rank(group) + + # Run time states + # map microbatch ID to list of forward tensor args + self.fwd_cache: Dict[int, Tuple[Any, List[torch.Tensor]]] = {} + # Split input chunks + self.args_split = None + self.kwargs_split = None + # Activation send requests of all chunk + self.all_act_send_reqs: List[dist.Work] = [] + # Grad send requests of all chunk + self.all_grad_send_reqs: List[dist.Work] = [] + # Caching chunk outputs for final output merge or reduction + self.output_chunks: List[Any] = [] + # Find my submodule self.split_gm = self.pipe.split_gm named_children = list(self.split_gm.named_children()) - self.name, self.submod = named_children[rank] + self.name, self.submod = named_children[stage_index] logging.info( - f"[{self.rank}][{self.name}] " + f"[{self.group_rank}][{self.name}] " f"Creating PipelineStage:\n" f"{self.submod}" ) @@ -97,7 +113,7 @@ def __init__( for node in reversed(self.split_gm.graph.nodes): if (node.op, node.target) == ("call_function", stage_backward): seen_bwd += 1 - if seen_bwd == self.rank: + if seen_bwd == self.stage_index: found_bwd = True self.bwd_node = node break @@ -107,13 +123,28 @@ def __init__( ) # Create submod to rank mapping - self.submod_to_rank: Dict[str, int] = {} + self.submod_to_stage_index: Dict[str, int] = {} for i, (name, _) in enumerate(self.split_gm.named_children()): - self.submod_to_rank.setdefault(name, i) + self.submod_to_stage_index.setdefault(name, i) + + # Create stage id to group rank mapping + # In interleaved case, `group_rank` is stage index % group size. + self.stage_index_to_group_rank: Dict[int, int] = {} + pg_world_size = dist.get_world_size(group) + for i in range(nstages): + # We only support wrapped-around interleaving + peer_rank = i % pg_world_size + self.stage_index_to_group_rank.setdefault(i, peer_rank) # Prepare send/recv infrastructure self._prepare_send_recv_infra() + def is_first(self): + return self.stage_index == 0 + + def is_last(self): + return self.stage_index == self.nstages - 1 + def _prepare_send_recv_infra(self): """ Create send/recv infrastructures for activations (during forward) and @@ -150,14 +181,14 @@ def _prepare_send_recv_infra(self): self.kwargs_recv_info[0], ) - def get_rank_of_submod( + def get_stage_index_of_submod( self, submod_name: str, ): - if submod_name not in self.submod_to_rank: - raise AssertionError(f"Rank of {submod_name} not found") + if submod_name not in self.submod_to_stage_index: + raise AssertionError(f"Stage id of {submod_name} not found") - return self.submod_to_rank[submod_name] + return self.submod_to_stage_index[submod_name] def _create_act_recv_buffers( self, @@ -196,12 +227,12 @@ def create_recv_tensor( tensor_meta = input_node.meta["tensor_meta"] logging.info( - f"[{self.rank}][{self.name}] " + f"[{self.group_rank}][{self.name}] " f"Creating recv buffer for input '{input_node.name}' " f"value index {output_idx}: {tensor_meta.shape}" ) - src_rank = self.get_rank_of_submod(input_node.name) + src_rank = self.get_stage_index_of_submod(input_node.name) buffer = _make_tensor_from_meta(tensor_meta, self.device) # Enable gradient in training mode if self.pipe.has_loss_and_backwards: @@ -225,7 +256,7 @@ def create_recv_tensor( ) logging.info( - f"[{self.rank}][{self.name}] " + f"[{self.group_rank}][{self.name}] " f"Activation recv info: {args_recv_info}" ) return args_recv_info, kwargs_recv_info @@ -240,7 +271,7 @@ def find_dst_rank( """ if user.op == "call_module": # User is a stage (`call_module`) - return self.get_rank_of_submod(user.name) + return self.get_stage_index_of_submod(user.name) elif user.target is sync_barrier: # Send result back to pp rank 0 return 0 @@ -275,7 +306,7 @@ def _create_act_send_info(self): dsts.append(dst_rank) logging.info( - f"[{self.rank}][{self.name}] " f"Send info: {act_send_info}" + f"[{self.group_rank}][{self.name}] " f"Send info: {act_send_info}" ) return act_send_info @@ -308,7 +339,8 @@ def _create_grad_recv_info( ) logging.info( - f"[{self.rank}][{self.name}] " f"Grad recv info: {grad_recv_info}" + f"[{self.group_rank}][{self.name}] " + f"Grad recv info: {grad_recv_info}" ) return grad_recv_info @@ -332,22 +364,24 @@ def map_recv_to_send(a): pippy.fx.node.map_aggregate(kwargs_recv_info, map_recv_to_send) logging.info( - f"[{self.rank}][{self.name}] " f"Grad send info: {grad_send_info}" + f"[{self.group_rank}][{self.name}] " + f"Grad send info: {grad_send_info}" ) return grad_send_info def _recv_tensor(self, info, recv_reqs): logging.debug( - f"[{self.rank}][{self.name}] " + f"[{self.group_rank}][{self.name}] " f"Receiving tensor '{info.input_name}' from Rank {info.source}: " f"{info.buffer.size()}" ) # Use async to parallelize recv of tensors + peer_rank = self.stage_index_to_group_rank[info.source] work = dist.irecv( info.buffer, - info.source + peer_rank if self.group is None - else dist.get_global_rank(self.group, info.source), + else dist.get_global_rank(self.group, peer_rank), group=self.group, ) recv_reqs.append(work) @@ -359,41 +393,51 @@ def recv_tensor_fn( ): return lambda info: self._recv_tensor(info, reqs) + def split_inputs(self, args, kwargs): + self.args_split = None + self.kwargs_split = None + if args or kwargs: + self.args_split, self.kwargs_split = split_args_kwargs_into_chunks( + args, + kwargs, + self.chunks, + self.args_chunk_spec, + self.kwargs_chunk_spec, + ) + def _recv_and_fill_inputs( self, chunk: int, - args_split, - kwargs_split, ): # Receive requests of a chunk recv_reqs: List[dist.Work] = [] act_recv = self.recv_tensor_fn(recv_reqs) - if args_split: - chunk_args = args_split[chunk] + if self.args_split: + chunk_args = self.args_split[chunk] chunk_args_list = list(chunk_args) def recv_args(info): if isinstance(info, RecvInfo): return act_recv(info) else: - return chunk_args_list.pop(0) + return chunk_args_list.pop(0) # type: ignore[has-type] composite_args = pippy.fx.node.map_aggregate( self.args_recv_info[chunk], recv_args, ) - if kwargs_split: - chunk_kwargs = kwargs_split[chunk] + if self.kwargs_split: + chunk_kwargs = self.kwargs_split[chunk] def recv_kwargs(info): if isinstance(info, RecvInfo): return act_recv(info) else: - k = next(iter(chunk_kwargs)) - return chunk_kwargs.pop(k) + k = next(iter(chunk_kwargs)) # type: ignore[has-type] + return chunk_kwargs.pop(k) # type: ignore[has-type] composite_kwargs = pippy.fx.node.map_aggregate( self.kwargs_recv_info[chunk], @@ -414,19 +458,20 @@ def _send_activations( send_reqs: List[dist.Work] = [] for idx, out in enumerate(output_tuple): - dst_ranks = self.act_send_info[idx] - for dst in dst_ranks: + dst_stages = self.act_send_info[idx] + for dst in dst_stages: if dst is None: continue logging.debug( - f"[{self.rank}][{self.name}] " + f"[{self.group_rank}][{self.name}] " f"Sending tensor to Rank {dst}: {out.size()}" ) + peer_rank = self.stage_index_to_group_rank[dst] work = dist.isend( out, - dst + peer_rank if self.group is None - else dist.get_global_rank(self.group, dst), # TODO + else dist.get_global_rank(self.group, peer_rank), # TODO group=self.group, ) send_reqs.append(work) @@ -452,7 +497,7 @@ def _recv_grads( work.wait() logging.debug( - f"[{self.rank}][{self.name}] " + f"[{self.group_rank}][{self.name}] " f"Received output grads of chunk {bwd_chunk}: {map_debug_info(grads)}" ) return grads @@ -464,24 +509,23 @@ def _send_grads( # Send requests of a chunk grad_send_reqs: List[dist.Work] = [] - for grad, grad_receiver in zip(grads_input, self.grad_send_info): - if isinstance(grad, torch.Tensor) and grad_receiver is not None: + for grad, grad_recv_stage in zip(grads_input, self.grad_send_info): + if isinstance(grad, torch.Tensor) and grad_recv_stage is not None: logging.debug( - f"[{self.rank}][{self.name}] " - f"Sending gradient to Rank {grad_receiver}: {grad.size()}" + f"[{self.group_rank}][{self.name}] " + f"Sending gradient to Rank {grad_recv_stage}: {grad.size()}" ) + peer_rank = self.stage_index_to_group_rank[grad_recv_stage] work = dist.isend( grad, - grad_receiver + peer_rank if self.group is None - else dist.get_global_rank( - self.group, grad_receiver - ), # TODO + else dist.get_global_rank(self.group, peer_rank), # TODO group=self.group, ) grad_send_reqs.append(work) else: - assert grad is None and grad_receiver is None + assert grad is None and grad_recv_stage is None return grad_send_reqs @@ -518,15 +562,8 @@ def backward_maybe_with_nosync(self, bwd_kwargs: Dict, is_last_chunk: bool): def forward_one_chunk( self, chunk: int, - args_split, - kwargs_split, - fwd_cache: Dict[int, Any], ): - composite_args, composite_kwargs = self._recv_and_fill_inputs( - chunk, - args_split, - kwargs_split, - ) + composite_args, composite_kwargs = self._recv_and_fill_inputs(chunk) # Compute forward try: @@ -536,7 +573,7 @@ def forward_one_chunk( except Exception as e: exc_msg = f""" - Rank {self.rank} failed to run forward stage: {self.name} + Rank {self.group_rank} failed to run forward stage: {self.name} args: {map_debug_info(composite_args)} kwargs: {map_debug_info(composite_kwargs)} """ @@ -545,24 +582,29 @@ def forward_one_chunk( # Unify output form to tuple for easy correspondance with # `act_send_info` output_tuple = output if type(output) is tuple else (output,) + # Prepare for final output merge or reduction + self.output_chunks.append(output) + + # Send activations send_reqs = self._send_activations(output_tuple) + self.all_act_send_reqs += send_reqs # Save activations and inputs for backward flat_args = flatten_args(composite_args) flat_kwargs = flatten_args(composite_kwargs) flatten_input_tensors = flat_args + flat_kwargs - fwd_cache[chunk] = ( + self.fwd_cache[chunk] = ( output_tuple, # stage_output flatten_input_tensors, # input_values ) - return output, send_reqs - def backward_one_chunk( self, bwd_chunk: int, - fwd_cache: Dict[int, Any], ): + if not self.pipe.has_loss_and_backwards: + return None + grads = self._recv_grads(bwd_chunk) # Pack args for `stage_backward`` @@ -570,7 +612,7 @@ def backward_one_chunk( ( bwd_kwargs["stage_output"], bwd_kwargs["input_values"], - ) = fwd_cache.pop(bwd_chunk) + ) = self.fwd_cache.pop(bwd_chunk) # Fill actual gradients received for outputs # If nothing received, as in the case of last stage, then we # would use the default `output_grads` prepared in the IR phase, @@ -586,62 +628,53 @@ def backward_one_chunk( ) grad_send_reqs = self._send_grads(grads_input) - return grad_send_reqs + self.all_grad_send_reqs += grad_send_reqs - def forward(self, *args, **kwargs): + def clear_runtime_states(self): # map microbatch ID to list of forward tensor args - fwd_cache: Dict[int, Tuple[Any, List[torch.Tensor]]] = {} + self.fwd_cache.clear() + # Activation send requests of all chunk + self.all_act_send_reqs.clear() + # Grad send requests of all chunk + self.all_grad_send_reqs.clear() + # Caching chunk outputs for final output merge or reduction + self.output_chunks.clear() - args_split = None - kwargs_split = None - if args or kwargs: - args_split, kwargs_split = split_args_kwargs_into_chunks( - args, - kwargs, - self.chunks, - self.args_chunk_spec, - self.kwargs_chunk_spec, - ) + def merge_output_chunks(self): + return merge_chunks( + self.output_chunks, + self.output_chunk_spec, + ) - # Activation send requests of all chunk - all_send_reqs: List[dist.Work] = [] + def forward(self, *args, **kwargs): + # Clean per iteration + self.clear_runtime_states() - output_chunks = [] + # Split inputs into chunks + self.split_inputs(args, kwargs) # Forward pass of all chunks for chunk in range(self.chunks): - output, send_reqs = self.forward_one_chunk( - chunk, args_split, kwargs_split, fwd_cache - ) - all_send_reqs += send_reqs - # Prepare for final output merge or reduction - output_chunks.append(output) + self.forward_one_chunk(chunk) # Wait for all sends to finish # TODO: okay to delay the sync till completion of all chunks? - for work in all_send_reqs: + for work in self.all_act_send_reqs: work.wait() # Backward starts here - # Grad send requests of all chunk - all_grad_send_reqs: List[dist.Work] = [] for bwd_chunk in range(self.chunks): - if self.pipe.has_loss_and_backwards: - grad_send_reqs = self.backward_one_chunk(bwd_chunk, fwd_cache) - all_grad_send_reqs += grad_send_reqs + self.backward_one_chunk(bwd_chunk) # Wait for all sends to finish # TODO: okay to delay the sync till completion of all chunks? - for work in all_grad_send_reqs: + for work in self.all_grad_send_reqs: work.wait() # Last rank return merged results per original format - if self.rank == self.nstages - 1: - return merge_chunks( - output_chunks, - self.output_chunk_spec, - ) + if self.is_last(): + return self.merge_output_chunks() else: return None @@ -672,72 +705,41 @@ def __init__( ) def forward(self, *args, **kwargs): - # map microbatch ID to list of forward tensor args - fwd_cache: Dict[int, Tuple[Any, List[torch.Tensor]]] = {} + # Clean per iteration + self.clear_runtime_states() - args_split = None - kwargs_split = None - if args or kwargs: - args_split, kwargs_split = split_args_kwargs_into_chunks( - args, - kwargs, - self.chunks, - self.args_chunk_spec, - self.kwargs_chunk_spec, - ) - - # Activation send requests of all chunk - all_send_reqs: List[dist.Work] = [] - # Grad send requests of all chunk - all_grad_send_reqs: List[dist.Work] = [] - # Caching chunk outputs for final output merge or reduction - output_chunks = [] + # Split inputs into chunks + self.split_inputs(args, kwargs) warmup_chunks = cooldown_chunks = self.nstages # Warm-up phase: forward number of chunks equal to pipeline depth. for chunk in range(warmup_chunks): - output, send_reqs = self.forward_one_chunk( - chunk, args_split, kwargs_split, fwd_cache - ) - all_send_reqs += send_reqs - output_chunks.append(output) + self.forward_one_chunk(chunk) # 1F1B phase for bwd_chunk in range(0, self.chunks - cooldown_chunks): # Schedule backward for one warmed up chunk - if self.pipe.has_loss_and_backwards: - grad_send_reqs = self.backward_one_chunk(bwd_chunk, fwd_cache) - all_grad_send_reqs += grad_send_reqs + self.backward_one_chunk(bwd_chunk) # Schedule forward for one new chunk fwd_chunk = bwd_chunk + warmup_chunks - output, send_reqs = self.forward_one_chunk( - fwd_chunk, args_split, kwargs_split, fwd_cache - ) - all_send_reqs += send_reqs - # Prepare for final output merge or reduction - output_chunks.append(output) + self.forward_one_chunk(fwd_chunk) # Cool-down phase: backward for the rest of the chunks for bwd_chunk in range(self.chunks - cooldown_chunks, self.chunks): - if self.pipe.has_loss_and_backwards: - grad_send_reqs = self.backward_one_chunk(bwd_chunk, fwd_cache) - all_grad_send_reqs += grad_send_reqs + self.backward_one_chunk(bwd_chunk) # Wait for all sends to finish # TODO: okay to delay the sync till completion of all chunks? - for work in all_send_reqs: + for work in self.all_act_send_reqs: work.wait() - for work in all_grad_send_reqs: + for work in self.all_grad_send_reqs: work.wait() # Last rank return merged results per original format - if self.rank == self.nstages - 1: - return merge_chunks( - output_chunks, - self.output_chunk_spec, - ) + if self.is_last(): + return self.merge_output_chunks() else: return None diff --git a/pippy/compile.py b/pippy/compile.py index df3bd0d3f..5ddf074c6 100644 --- a/pippy/compile.py +++ b/pippy/compile.py @@ -218,8 +218,8 @@ def all_compile( def compile_stage( mod: torch.nn.Module, - rank: int, - num_ranks: int, + stage_index: int, + num_stages: int, num_chunks: int, device: torch.device, group: dist.ProcessGroup, @@ -257,7 +257,7 @@ def compile_stage( ) gm = pipe.split_gm - if rank == 0: + if stage_index == 0: logging.info(gm) if PIPPY_VERBOSITY == "INFO": gm.graph.print_tabular() @@ -298,24 +298,24 @@ def compile_stage( if schedule == "1F1B": return PipelineStage1F1B( pipe, - rank, - num_ranks, + stage_index, + num_stages, num_chunks, device, - group, - args_chunk_spec, - kwargs_chunk_spec, - output_chunk_spec, + group=group, + args_chunk_spec=args_chunk_spec, + kwargs_chunk_spec=kwargs_chunk_spec, + output_chunk_spec=output_chunk_spec, ) else: return PipelineStage( pipe, - rank, - num_ranks, + stage_index, + num_stages, num_chunks, device, - group, - args_chunk_spec, - kwargs_chunk_spec, - output_chunk_spec, + group=group, + args_chunk_spec=args_chunk_spec, + kwargs_chunk_spec=kwargs_chunk_spec, + output_chunk_spec=output_chunk_spec, ) diff --git a/setup.py b/setup.py index f445bdfe2..f5229da0c 100644 --- a/setup.py +++ b/setup.py @@ -63,7 +63,6 @@ def write_version_file(): class clean(distutils.command.clean.clean): # type: ignore def run(self): - with open(".gitignore", "r") as f: ignores = f.read() for wildcard in filter(None, ignores.split("\n")): From 83a2308f4a53ae36eba2f0c1b2b262d5d697d37b Mon Sep 17 00:00:00 2001 From: Ikko Eltociear Ashimine Date: Sat, 11 Nov 2023 14:47:15 +0900 Subject: [PATCH 43/96] Update README.md (#870) ## Description adition -> addition --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 8a89b9d1a..e212dbe97 100644 --- a/README.md +++ b/README.md @@ -512,7 +512,7 @@ Note: in cases where a model's parameters do not fit into the memory of a single distributed initialization which only materializes a pipeline stage on its corresponding GPU worker. For details, please see PiPPy's `Pipe.defer_stage_init` API. -In adition, some backend options need to be passed to RPC initialization. RPC by default uses the TensorPipe backend +In addition, some backend options need to be passed to RPC initialization. RPC by default uses the TensorPipe backend that supports point-to-point communication in an asynchronous manner. Configurations for TensorPipe can be specified with a `TensorPipeRpcBackendOptions` object. Here is an example: From 619d535d1106d5b5d70262119ceb4c509189c83c Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Wed, 29 Nov 2023 16:44:13 -0500 Subject: [PATCH 44/96] Migrate from FX tracer to _export_to_torch_ir (#873) ## Changes - Use `torch._export._export_to_torch_ir` as new tracer (to replace FX symbolic tracing) - Remove pippy.fx and use torch.fx - Remove `PipelineDriver` (rpc backend) Old APIs such as `PipelineDriver` are no longer supported. And signature of the `from_tracing` API is changed to be in line with that of export. ## Major API usage ``` pipe = Pipe.from_tracing( mod, args.chunks, example_args=(x, y), ) stage = PipelineStage( pipe, args.rank, device=args.device, ) # Run if args.rank == 0: stage(x) elif args.rank == args.world_size - 1: out = stage(y) else: stage() ``` --- .github/workflows/code-quality.yml | 8 +- .github/workflows/pippy_gpu_tests.sh | 44 - .github/workflows/pippy_tests.yaml | 170 +- check.sh | 7 +- examples/hf/gpt2/pippy_gpt2.py | 12 +- pippy/IR.py | 337 +- pippy/ModelSplit.py | 28 +- pippy/PipelineDriver.py | 2281 -------- pippy/PipelineStage.py | 154 +- pippy/__init__.py | 15 - pippy/auto_parallelization.py | 4 +- pippy/compile.py | 321 -- pippy/debug.py | 12 +- pippy/fx/OVERVIEW.md | 134 - pippy/fx/__init__.py | 90 - pippy/fx/__init__.pyi | 8 - pippy/fx/_compatibility.py | 35 - pippy/fx/_pytree.py | 43 - pippy/fx/_symbolic_trace.py | 1080 ---- pippy/fx/annotate.py | 22 - pippy/fx/experimental/__init__.py | 1 - .../experimental/accelerator_partitioner.py | 1083 ---- pippy/fx/experimental/const_fold.py | 294 -- pippy/fx/experimental/debug.py | 32 - .../experimental/graph_gradual_typechecker.py | 927 ---- pippy/fx/experimental/merge_matmul.py | 172 - pippy/fx/experimental/meta_tracer.py | 269 - .../migrate_gradual_types/__init__.py | 1 - .../migrate_gradual_types/constraint.py | 559 -- .../constraint_generator.py | 1282 ----- .../constraint_transformation.py | 1041 ---- .../migrate_gradual_types/operation.py | 16 - .../migrate_gradual_types/transform_to_z3.py | 349 -- .../migrate_gradual_types/util.py | 53 - .../migrate_gradual_types/z3_types.py | 30 - pippy/fx/experimental/normalize.py | 163 - pippy/fx/experimental/optimization.py | 406 -- pippy/fx/experimental/partitioner_utils.py | 320 -- pippy/fx/experimental/proxy_tensor.py | 683 --- pippy/fx/experimental/refinement_types.py | 17 - pippy/fx/experimental/rewriter.py | 122 - .../fx/experimental/schema_type_annotation.py | 112 - pippy/fx/experimental/symbolic_shapes.py | 472 -- pippy/fx/experimental/unification/LICENSE.txt | 28 - pippy/fx/experimental/unification/__init__.py | 5 - pippy/fx/experimental/unification/core.py | 119 - pippy/fx/experimental/unification/dispatch.py | 7 - pippy/fx/experimental/unification/match.py | 123 - pippy/fx/experimental/unification/more.py | 117 - .../unification/multipledispatch/__init__.py | 4 - .../unification/multipledispatch/conflict.py | 120 - .../unification/multipledispatch/core.py | 82 - .../multipledispatch/dispatcher.py | 433 -- .../unification/multipledispatch/utils.py | 126 - .../unification/multipledispatch/variadic.py | 93 - .../unification/unification_tools.py | 393 -- pippy/fx/experimental/unification/utils.py | 106 - pippy/fx/experimental/unification/variable.py | 81 - pippy/fx/experimental/unify_refinements.py | 121 - pippy/fx/graph.py | 1507 ------ pippy/fx/graph_module.py | 759 --- pippy/fx/immutable_collections.py | 53 - pippy/fx/interpreter.py | 481 -- pippy/fx/node.py | 627 --- pippy/fx/operator_schemas.py | 409 -- pippy/fx/passes/README.md | 20 - pippy/fx/passes/__init__.py | 12 - pippy/fx/passes/backends/__init__.py | 1 - pippy/fx/passes/backends/cudagraphs.py | 56 - pippy/fx/passes/backends/nvfuser.py | 287 - pippy/fx/passes/dialect/__init__.py | 0 pippy/fx/passes/dialect/common/__init__.py | 0 pippy/fx/passes/dialect/common/cse_pass.py | 114 - pippy/fx/passes/fake_tensor_prop.py | 30 - pippy/fx/passes/graph_drawer.py | 328 -- pippy/fx/passes/graph_manipulation.py | 111 - pippy/fx/passes/infra/__init__.py | 2 - pippy/fx/passes/infra/partitioner.py | 228 - pippy/fx/passes/infra/pass_base.py | 79 - pippy/fx/passes/infra/pass_manager.py | 308 -- pippy/fx/passes/net_min_base.py | 619 --- pippy/fx/passes/operator_support.py | 207 - pippy/fx/passes/param_fetch.py | 67 - pippy/fx/passes/pass_manager.py | 242 - pippy/fx/passes/reinplace.py | 663 --- pippy/fx/passes/shape_prop.py | 153 - pippy/fx/passes/split_module.py | 327 -- pippy/fx/passes/split_utils.py | 278 - pippy/fx/passes/splitter_base.py | 854 --- pippy/fx/passes/tests/__init__.py | 1 - pippy/fx/passes/tests/test_pass_manager.py | 37 - pippy/fx/passes/tools_common.py | 254 - pippy/fx/passes/utils/__init__.py | 2 - pippy/fx/passes/utils/common.py | 84 - pippy/fx/passes/utils/fuser_utils.py | 214 - pippy/fx/passes/utils/matcher_utils.py | 309 -- pippy/fx/proxy.py | 419 -- pippy/fx/subgraph_rewriter.py | 255 - pippy/fx/tensor_type.py | 105 - pippy/fx/traceback.py | 62 - pippy/microbatch.py | 30 +- pippy/utils.py | 288 +- requirements.txt | 2 +- setup.py | 2 +- ..._compat-fx_backcompat_class_members.expect | 19 - ...t-fx_backcompat_function_signatures.expect | 74 - test/fx/named_tup.py | 8 - test/fx/quantization.py | 325 -- test/fx/test_common_passes.py | 116 - test/fx/test_cse_pass.py | 234 - test/fx/test_dce_pass.py | 185 - test/fx/test_future.py | 51 - test/fx/test_fx_const_fold.py | 712 --- test/fx/test_fx_param_shape_control_flow.py | 156 - test/fx/test_gradual_type.py | 1017 ---- test/fx/test_pass_infra.py | 176 - test/fx/test_subgraph_rewriter.py | 777 --- test/fx/test_z3_gradual_types.py | 2481 --------- test/local_test_compile.py | 110 - test/local_test_forward.py | 168 - test/local_test_visualizer.py | 365 -- test/{local_test_c10d_bwd.py => test_bwd.py} | 65 +- test/{local_test_c10d.py => test_fwd.py} | 41 +- test/test_fx.py | 4658 ----------------- test/test_fx_experimental.py | 1717 ------ test/test_pipe.py | 113 + test/test_pipe_bwd.py | 123 + 127 files changed, 628 insertions(+), 39116 deletions(-) delete mode 100755 .github/workflows/pippy_gpu_tests.sh delete mode 100644 pippy/PipelineDriver.py delete mode 100644 pippy/compile.py delete mode 100644 pippy/fx/OVERVIEW.md delete mode 100644 pippy/fx/__init__.py delete mode 100644 pippy/fx/__init__.pyi delete mode 100644 pippy/fx/_compatibility.py delete mode 100644 pippy/fx/_pytree.py delete mode 100644 pippy/fx/_symbolic_trace.py delete mode 100644 pippy/fx/annotate.py delete mode 100644 pippy/fx/experimental/__init__.py delete mode 100644 pippy/fx/experimental/accelerator_partitioner.py delete mode 100644 pippy/fx/experimental/const_fold.py delete mode 100644 pippy/fx/experimental/debug.py delete mode 100644 pippy/fx/experimental/graph_gradual_typechecker.py delete mode 100644 pippy/fx/experimental/merge_matmul.py delete mode 100644 pippy/fx/experimental/meta_tracer.py delete mode 100644 pippy/fx/experimental/migrate_gradual_types/__init__.py delete mode 100644 pippy/fx/experimental/migrate_gradual_types/constraint.py delete mode 100644 pippy/fx/experimental/migrate_gradual_types/constraint_generator.py delete mode 100644 pippy/fx/experimental/migrate_gradual_types/constraint_transformation.py delete mode 100644 pippy/fx/experimental/migrate_gradual_types/operation.py delete mode 100644 pippy/fx/experimental/migrate_gradual_types/transform_to_z3.py delete mode 100644 pippy/fx/experimental/migrate_gradual_types/util.py delete mode 100644 pippy/fx/experimental/migrate_gradual_types/z3_types.py delete mode 100644 pippy/fx/experimental/normalize.py delete mode 100644 pippy/fx/experimental/optimization.py delete mode 100644 pippy/fx/experimental/partitioner_utils.py delete mode 100644 pippy/fx/experimental/proxy_tensor.py delete mode 100644 pippy/fx/experimental/refinement_types.py delete mode 100644 pippy/fx/experimental/rewriter.py delete mode 100644 pippy/fx/experimental/schema_type_annotation.py delete mode 100644 pippy/fx/experimental/symbolic_shapes.py delete mode 100644 pippy/fx/experimental/unification/LICENSE.txt delete mode 100644 pippy/fx/experimental/unification/__init__.py delete mode 100644 pippy/fx/experimental/unification/core.py delete mode 100644 pippy/fx/experimental/unification/dispatch.py delete mode 100644 pippy/fx/experimental/unification/match.py delete mode 100644 pippy/fx/experimental/unification/more.py delete mode 100644 pippy/fx/experimental/unification/multipledispatch/__init__.py delete mode 100644 pippy/fx/experimental/unification/multipledispatch/conflict.py delete mode 100644 pippy/fx/experimental/unification/multipledispatch/core.py delete mode 100644 pippy/fx/experimental/unification/multipledispatch/dispatcher.py delete mode 100644 pippy/fx/experimental/unification/multipledispatch/utils.py delete mode 100644 pippy/fx/experimental/unification/multipledispatch/variadic.py delete mode 100644 pippy/fx/experimental/unification/unification_tools.py delete mode 100644 pippy/fx/experimental/unification/utils.py delete mode 100644 pippy/fx/experimental/unification/variable.py delete mode 100644 pippy/fx/experimental/unify_refinements.py delete mode 100644 pippy/fx/graph.py delete mode 100644 pippy/fx/graph_module.py delete mode 100644 pippy/fx/immutable_collections.py delete mode 100644 pippy/fx/interpreter.py delete mode 100644 pippy/fx/node.py delete mode 100644 pippy/fx/operator_schemas.py delete mode 100644 pippy/fx/passes/README.md delete mode 100644 pippy/fx/passes/__init__.py delete mode 100644 pippy/fx/passes/backends/__init__.py delete mode 100644 pippy/fx/passes/backends/cudagraphs.py delete mode 100644 pippy/fx/passes/backends/nvfuser.py delete mode 100644 pippy/fx/passes/dialect/__init__.py delete mode 100644 pippy/fx/passes/dialect/common/__init__.py delete mode 100644 pippy/fx/passes/dialect/common/cse_pass.py delete mode 100644 pippy/fx/passes/fake_tensor_prop.py delete mode 100644 pippy/fx/passes/graph_drawer.py delete mode 100644 pippy/fx/passes/graph_manipulation.py delete mode 100644 pippy/fx/passes/infra/__init__.py delete mode 100644 pippy/fx/passes/infra/partitioner.py delete mode 100644 pippy/fx/passes/infra/pass_base.py delete mode 100644 pippy/fx/passes/infra/pass_manager.py delete mode 100644 pippy/fx/passes/net_min_base.py delete mode 100644 pippy/fx/passes/operator_support.py delete mode 100644 pippy/fx/passes/param_fetch.py delete mode 100644 pippy/fx/passes/pass_manager.py delete mode 100644 pippy/fx/passes/reinplace.py delete mode 100644 pippy/fx/passes/shape_prop.py delete mode 100644 pippy/fx/passes/split_module.py delete mode 100644 pippy/fx/passes/split_utils.py delete mode 100644 pippy/fx/passes/splitter_base.py delete mode 100644 pippy/fx/passes/tests/__init__.py delete mode 100644 pippy/fx/passes/tests/test_pass_manager.py delete mode 100644 pippy/fx/passes/tools_common.py delete mode 100644 pippy/fx/passes/utils/__init__.py delete mode 100644 pippy/fx/passes/utils/common.py delete mode 100644 pippy/fx/passes/utils/fuser_utils.py delete mode 100644 pippy/fx/passes/utils/matcher_utils.py delete mode 100644 pippy/fx/proxy.py delete mode 100644 pippy/fx/subgraph_rewriter.py delete mode 100644 pippy/fx/tensor_type.py delete mode 100644 pippy/fx/traceback.py delete mode 100644 test/expect/TestFXAPIBackwardCompatibility.test_class_member_back_compat-fx_backcompat_class_members.expect delete mode 100644 test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect delete mode 100644 test/fx/named_tup.py delete mode 100644 test/fx/quantization.py delete mode 100644 test/fx/test_common_passes.py delete mode 100644 test/fx/test_cse_pass.py delete mode 100644 test/fx/test_dce_pass.py delete mode 100644 test/fx/test_future.py delete mode 100644 test/fx/test_fx_const_fold.py delete mode 100644 test/fx/test_fx_param_shape_control_flow.py delete mode 100644 test/fx/test_gradual_type.py delete mode 100644 test/fx/test_pass_infra.py delete mode 100644 test/fx/test_subgraph_rewriter.py delete mode 100644 test/fx/test_z3_gradual_types.py delete mode 100644 test/local_test_compile.py delete mode 100644 test/local_test_forward.py delete mode 100644 test/local_test_visualizer.py rename test/{local_test_c10d_bwd.py => test_bwd.py} (72%) rename test/{local_test_c10d.py => test_fwd.py} (82%) delete mode 100644 test/test_fx.py delete mode 100644 test/test_fx_experimental.py create mode 100644 test/test_pipe.py create mode 100644 test/test_pipe_bwd.py diff --git a/.github/workflows/code-quality.yml b/.github/workflows/code-quality.yml index 3008e2e18..1f52b1cb6 100644 --- a/.github/workflows/code-quality.yml +++ b/.github/workflows/code-quality.yml @@ -23,10 +23,8 @@ jobs: pip install --upgrade pip pip install -r docs/requirements.txt pip install types-docutils types-setuptools tqdm types-tabulate - if [ -f requirements.txt ]; then pip install -r requirements.txt --index-url https://download.pytorch.org/whl/cpu; fi - pip install torchvision --index-url https://download.pytorch.org/whl/cpu - pip install git+https://github.com/pbelevich/transformers.git@compatible_with_pt_master - pip install "black<23" pylint==v3.0.0a5 mypy==v0.960 flake8==3.8.2 pyre-check==0.9.15 ufmt==2.1.0 + if [ -f requirements.txt ]; then pip install -r requirements.txt --find-links https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html; fi + pip install "black<23" pylint==v3.0.0a5 mypy==v0.981 flake8==3.8.2 pyre-check==0.9.15 ufmt==2.1.0 - name: Static Analysis Checks if: always() - run: ./check.sh --keep-going + run: ./check.sh diff --git a/.github/workflows/pippy_gpu_tests.sh b/.github/workflows/pippy_gpu_tests.sh deleted file mode 100755 index 33f9d487f..000000000 --- a/.github/workflows/pippy_gpu_tests.sh +++ /dev/null @@ -1,44 +0,0 @@ -#!/bin/bash - -set -x - -# Print test options -echo "REPLICATE: ${REPLICATE}" -echo "SCHEDULE: ${SCHEDULE}" - -nvidia-smi -nvcc --version -which python3 -python3 --version -which pip3 -pip3 --version - -# Install git -apt-get update -apt-get install git -y - -# Install dependencies -# Turn off progress bar to save logs -pip3 config set global.progress_bar off -pip3 install flake8 pytest pytest-cov numpy -if [ -f requirements.txt ]; then pip3 install -r requirements.txt --find-links https://download.pytorch.org/whl/nightly/cu102/torch_nightly.html; fi - -# Install pavel's huggingface fork -pip3 install git+https://github.com/huggingface/transformers.git@main sentencepiece - -# Install pippy -python3 setup.py install - -set -ex - -# Run all integration tests -python3 test/local_test_forward.py --replicate ${REPLICATE} -s ${SCHEDULE} -python3 test/local_test_forward_backward.py --replicate ${REPLICATE} -s ${SCHEDULE} -python3 test/local_test_compile.py -s ${SCHEDULE} -python3 examples/hf/gpt2/pippy_gpt2.py --replicate ${REPLICATE} -s ${SCHEDULE} -python3 examples/gspmd/pippy_gspmd.py -s ${SCHEDULE} - -# Run flaky integration tests -python3 test/local_test_ddp.py --replicate ${REPLICATE} -s ${SCHEDULE} -python3 test/local_test_forward_hf_gpt2.py --replicate ${REPLICATE} -s ${SCHEDULE} -python3 test/local_test_forward_hf_bert.py --replicate ${REPLICATE} -s ${SCHEDULE} diff --git a/.github/workflows/pippy_tests.yaml b/.github/workflows/pippy_tests.yaml index a0f9978c2..b79122078 100644 --- a/.github/workflows/pippy_tests.yaml +++ b/.github/workflows/pippy_tests.yaml @@ -21,28 +21,26 @@ concurrency: jobs: - pytest_tests: - runs-on: linux.4xlarge - strategy: - matrix: - python-version: ["3.8", "3.9"] - container: - image: python:${{ matrix.python-version }} + # pytest_tests: + # runs-on: linux.4xlarge + # strategy: + # matrix: + # python-version: ["3.8", "3.9"] + # container: + # image: python:${{ matrix.python-version }} - steps: - - uses: actions/checkout@v2 - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install flake8 pytest pytest-cov pytest-xdist numpy - if [ -f requirements.txt ]; then pip install -r requirements.txt --find-links https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html; fi - - name: Install pavel's huggingface fork - run: pip install git+https://github.com/huggingface/transformers.git@main sentencepiece six sacremoses - - name: Install pippy - run: "python setup.py install" - - name: Test with pytest - run: | - pytest --cov=pippy --ignore=test/hf_test.py --ignore=test/test_fx.py --ignore=test/test_fx_experimental.py --ignore=test/fx test/ + # steps: + # - uses: actions/checkout@v2 + # - name: Install dependencies + # run: | + # python -m pip install --upgrade pip + # pip install flake8 pytest pytest-cov pytest-xdist numpy + # if [ -f requirements.txt ]; then pip install -r requirements.txt --find-links https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html; fi + # - name: Install pippy + # run: "python setup.py install" + # - name: Test with pytest + # run: | + # pytest --cov=pippy test/ # hf_model_tests: # runs-on: linux.12xlarge @@ -76,10 +74,8 @@ jobs: runs-on: linux.4xlarge strategy: matrix: - python-version: ["3.8", "3.9"] - replicate: ["0", "1"] - schedule: ["FillDrain", "1F1B"] - checkpoint: [ "0", "1" ] + python-version: ["3.9"] + schedule: ["FillDrain"] env: OMP_NUM_THREADS: "1" container: @@ -92,30 +88,26 @@ jobs: python -m pip install --upgrade pip pip install flake8 pytest pytest-cov numpy datasets evaluate scikit-learn sacrebleu if [ -f requirements.txt ]; then pip install -r requirements.txt --find-links https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html; fi - - name: Install pavel's huggingface fork - run: pip install git+https://github.com/huggingface/transformers.git@main sentencepiece six sacremoses - name: Install pippy run: "python setup.py install" + - name: Test forward pipe generation + run: python test/test_pipe.py + - name: Test backward pipe generation + run: python test/test_pipe_bwd.py - name: Run forward-only integration test - run: python test/local_test_forward.py --replicate ${{ matrix.replicate }} -s ${{ matrix.schedule }} --checkpoint ${{ matrix.checkpoint }} - - name: Run forward-only-auto-parallel integration test - run: python test/local_test_forward_auto_parallel.py --replicate ${{ matrix.replicate }} -s ${{ matrix.schedule }} --checkpoint ${{ matrix.checkpoint }} + run: torchrun --nproc-per-node 4 test/test_fwd.py - name: Run forward-loss-backward integration test - run: python test/local_test_forward_backward.py --replicate ${{ matrix.replicate }} -s ${{ matrix.schedule }} --checkpoint ${{ matrix.checkpoint }} - - name: Run null_coalesce_accumulate integration test - run: python test/local_test_null_coalesce_accumulate.py --replicate ${{ matrix.replicate }} -s ${{ matrix.schedule }} - - name: Run PP + DDP test - run: python test/local_test_ddp.py --replicate ${{ matrix.replicate }} -s ${{ matrix.schedule }} --checkpoint ${{ matrix.checkpoint }} + run: torchrun --nproc-per-node 4 test/test_bwd.py --schedule ${{ matrix.schedule }} + # - name: Run null_coalesce_accumulate integration test + # run: python test/local_test_null_coalesce_accumulate.py --replicate ${{ matrix.replicate }} --schedule ${{ matrix.schedule }} + # - name: Run PP + DDP test + # run: python test/local_test_ddp.py --replicate ${{ matrix.replicate }} --schedule ${{ matrix.schedule }} --checkpoint ${{ matrix.checkpoint }} #- name: Run HF BERT forward-only integration test - # run: python test/local_test_forward_hf_bert.py --replicate ${{ matrix.replicate }} -s ${{ matrix.schedule }} --checkpoint ${{ matrix.checkpoint }} - - name: Run HF GPT2 forward-only integration test - run: python test/local_test_forward_hf_gpt2.py --replicate ${{ matrix.replicate }} -s ${{ matrix.schedule }} --checkpoint ${{ matrix.checkpoint }} - - name: Run visualizer test - run: python test/local_test_visualizer.py --replicate ${{ matrix.replicate }} -s ${{ matrix.schedule }} - - name: Run auto-split test - run: python test/local_test_autosplit.py --replicate ${{ matrix.replicate }} -s ${{ matrix.schedule }} - - name: Run compile test - run: python test/local_test_compile.py -s ${{ matrix.schedule }} --checkpoint ${{ matrix.checkpoint }} + # run: python test/local_test_forward_hf_bert.py --replicate ${{ matrix.replicate }} --schedule ${{ matrix.schedule }} --checkpoint ${{ matrix.checkpoint }} + # - name: Run HF GPT2 forward-only integration test + # run: python test/local_test_forward_hf_gpt2.py --replicate ${{ matrix.replicate }} --schedule ${{ matrix.schedule }} --checkpoint ${{ matrix.checkpoint }} + # - name: Run auto-split test + # run: python test/local_test_autosplit.py --replicate ${{ matrix.replicate }} --schedule ${{ matrix.schedule }} # hf_examples_set1: # runs-on: linux.12xlarge @@ -145,11 +137,11 @@ jobs: # git submodule update --init test/minGPT # python test/min_gpt_tracing.py # - name: Run GPT2 example - # run: python examples/hf/gpt2/pippy_gpt2.py -s ${{ matrix.schedule }} + # run: python examples/hf/gpt2/pippy_gpt2.py --schedule ${{ matrix.schedule }} # - name: Run BERT example - # run: python examples/hf/bert/pippy_bert.py -s ${{ matrix.schedule }} + # run: python examples/hf/bert/pippy_bert.py --schedule ${{ matrix.schedule }} # - name: Run T5 example - # run: python examples/hf/t5/pippy_t5.py -s ${{ matrix.schedule }} + # run: python examples/hf/t5/pippy_t5.py --schedule ${{ matrix.schedule }} # - name: "HF Translation: fine-tune T5 model translation English to Romanian" # run: > # python examples/hf/translation/run_translation.py --model_name_or_path t5-small --do_train --source_lang en --target_lang ro --source_prefix "translate English to Romanian: " --dataset_name wmt16 --dataset_config_name ro-en --output_dir /tmp/tst-translation --per_device_train_batch_size=8 --per_device_eval_batch_size=8 --overwrite_output_dir --predict_with_generate --max_steps=10 --dp_group_size=1 --pp_group_size=8 @@ -186,84 +178,6 @@ jobs: # - name: "HF Text classification: fine-tune BERT on the GLUE benchmark" # run: python examples/hf/text-classification/run_glue.py --dp_group_size=2 --pp_group_size=8 --model_name_or_path bert-base-cased --task_name mrpc --do_train --do_eval --max_seq_length 128 --per_device_train_batch_size 32 --learning_rate 2e-5 --num_train_epochs 3 --output_dir /tmp/mrpc/ --max_steps=3 --overwrite_output_dir - integration_test_gpu: - runs-on: linux.16xlarge.nvidia.gpu - strategy: - matrix: - python-version: ["3.8"] - replicate: ["0", "1"] - schedule: ["FillDrain", "1F1B"] - env: - DOCKER_IMAGE: qts8n/cuda-python:devel - PIPPY_ROOT: /PiPPy - OMP_NUM_THREADS: "1" - REPLICATE: ${{ matrix.replicate }} - SCHEDULE: ${{ matrix.schedule }} - - steps: - - name: Clean working directory - shell: bash - run: | - sudo rm -rf /home/ec2-user/actions-runner/_work/PiPPy/PiPPy/* || true - - uses: actions/checkout@v2 - - name: Install nvidia driver, nvidia-docker runtime, set GPU_FLAG - uses: pytorch/test-infra/.github/actions/setup-nvidia@main - - name: Pull Docker image - run: | - retry () { - "$@" || (sleep 1 && "$@") || (sleep 2 && "$@") - } - retry docker pull "${DOCKER_IMAGE}" - - name: Test docker run - run: | - set -x - # shellcheck disable=SC2086,SC2090 - container_name=$(docker run \ - --gpus all \ - --shm-size=1g --ulimit memlock=-1 \ - -e OMP_NUM_THREADS \ - -e REPLICATE \ - -e SCHEDULE \ - --tty \ - --detach \ - -v "$(pwd):${PIPPY_ROOT}" \ - -w "${PIPPY_ROOT}" \ - "${DOCKER_IMAGE}" - ) - # Run GPU tests and return error signal from docker - docker exec -t -w "${PIPPY_ROOT}" "${container_name}" bash -c "bash .github/workflows/pippy_gpu_tests.sh; exit \$?" - - name: Chown workspace - if: always() - run: | - # Ensure the working directory gets chowned back to the current user - docker run --rm -v "$(pwd):${PIPPY_ROOT}" -w "${PIPPY_ROOT}" "${DOCKER_IMAGE}" chown -R "$(id -u):$(id -g)" . - - name: Kill containers, clean up images - if: always() - run: | - # ignore expansion of "docker ps -q" since it could be empty - # shellcheck disable=SC2046 - docker stop $(docker ps -q) || true - # Prune all of the docker images - docker system prune -af - - programming_model_tests: - runs-on: linux.4xlarge - strategy: - matrix: - python-version: ["3.9"] - container: - image: python:${{ matrix.python-version }} - - steps: - - uses: actions/checkout@v2 - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install numpy datasets evaluate scikit-learn sacrebleu - if [ -f requirements.txt ]; then pip install --pre -r requirements.txt --find-links https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html; fi - - name: Install pippy - run: "python setup.py install" - - name: Test PiPPy + Dynamo example - run: python examples/TorchDynamo/pippy_dynamo.py - - name: Run PiPPy in GSPMD style - run: python examples/gspmd/pippy_gspmd.py + # TODO: + # Update GPU test to use template in: + # https://github.com/pytorch/test-infra/wiki/Writing-generic-CI-jobs diff --git a/check.sh b/check.sh index 6f74be023..c20c968bf 100755 --- a/check.sh +++ b/check.sh @@ -4,7 +4,7 @@ function usage() { echo 2>&1 < torch.fx.GraphModule: + logger.info("[PiPPy] Tracing model ...") + try: + torch._dynamo.allow_in_graph(pipe_split) + traced: torch.fx.GraphModule = torch._export._export_to_torch_ir( + mod, + example_args, + example_kwargs, + constraints, + ) + if split_policy is not None: + traced = split_policy(traced) + finally: + torch._dynamo.disallow_in_graph(pipe_split) + return traced + @staticmethod def from_tracing( mod: torch.nn.Module, - multi_use_param_spec: Optional[MultiUseParamSpec] = None, - tracer=None, - output_loss_value_spec=None, - deep_copy_module=False, + num_chunks: int, + example_args: Tuple[Any, ...], + example_kwargs: Optional[Dict[str, Any]] = None, split_policy: Optional[ - Callable[[pippy.fx.GraphModule], pippy.fx.GraphModule] + Callable[[fx.GraphModule], fx.GraphModule] ] = None, - return_to_0: bool = True, - **kwargs, + args_chunk_spec: Optional[Tuple[Any, ...]] = None, + kwargs_chunk_spec: Optional[Dict[str, Any]] = None, + output_chunk_spec=None, + constraints: Optional[List[Constraint]] = None, ): - # TODO: abstract partitioning policy + # If a param will be used in multiple pipeline stages, we default the strategy to REPLICATE'ing the param across + # stages instead of TRANSMIT'ting it + multi_use_param_spec = MultiUseParameterConfig.REPLICATE + + # Figure out which output is loss from output_chunk_spec + output_loss_value_spec: Any = None + if output_chunk_spec is not None: + output_loss_value_spec = fx.node.map_aggregate( + output_chunk_spec, lambda v: isinstance(v, LossReducer) + ) - global _pipeline_tracer - old__pipeline_tracer = _pipeline_tracer - _pipeline_tracer = tracer or pippy.fx.Tracer() - try: - # TODO: tracing policy - if deep_copy_module: - mod = copy.deepcopy( - mod - ) # because further pipe building activities can modify mod - graph = _pipeline_tracer.trace(mod, **kwargs) - if isinstance(graph, torch_fx.Graph): - # HACK to convert torch.fx.Graph to pippy.fx.Graph - g_new = pippy.fx.Graph() - val_map: Dict[pippy.fx.Node, pippy.fx.Node] = {} - out = g_new.graph_copy(graph, val_map, False) - g_new.output(out) - - # `pippy.fx.map_arg` doesn't work on torch.fx.Node instances; - # do it here - def remap_vals(n): - return val_map[n] - - for node in g_new.nodes: - node.args = torch_fx.map_arg(node.args, remap_vals) - node.kwargs = torch_fx.map_arg(node.kwargs, remap_vals) - graph = g_new - - traced = pippy.fx.GraphModule(mod, graph) - finally: - _pipeline_tracer = old__pipeline_tracer + # Get split example inputs + if example_kwargs is None: + # Needed by `split_args_kwargs_into_chunks` + example_kwargs = {} + + args_split, kwargs_split = split_args_kwargs_into_chunks( + example_args, + example_kwargs, + num_chunks, + args_chunk_spec, + kwargs_chunk_spec, # TODO: merge into args_chunk_spec + ) - if split_policy is not None: - traced = split_policy(traced) + # Trace with export + traced = Pipe._trace_with_export( + mod, + example_args=args_split[0], + example_kwargs=kwargs_split[0], + constraints=constraints, + split_policy=split_policy, + ) - return Pipe._from_traced( + pipe = Pipe._from_traced( mod, traced, multi_use_param_spec, output_loss_value_spec=output_loss_value_spec, - return_to_0=return_to_0, ) + logger.info(pipe.split_gm) + if PIPPY_VERBOSITY == "DEBUG": + pipe.split_gm.graph.print_tabular() + + pipe.num_chunks = num_chunks + pipe.args_chunk_spec = args_chunk_spec + pipe.kwargs_chunk_spec = kwargs_chunk_spec + pipe.output_chunk_spec = output_chunk_spec + + # Shape propagation to get shapes of all tensors + PipeFakeTensorProp(pipe.split_gm).run() + for node in pipe.split_gm.graph.nodes: + logger.debug( + f"{node.name}, " + f"{node.meta['example_value'] if 'example_value' in node.meta else 'None'}", + ) + + return pipe + def __str__(self): return self.split_gm.__str__() def __repr__(self): return self.split_gm.__repr__() - # Conditoinal variable to ensure `defer_stage_init` is called before other callers call `materialize_stage` - # TODO: cleaner approach - _stage_init_lock = threading.Lock() - stage_init_cv = threading.Condition(_stage_init_lock) - - def defer_stage_init( - self, - device: torch.device, - index_filename: Union[str, os.PathLike] = None, - dtype: torch.dtype = None, - checkpoint_prefix: str = None, - ): - def materialize_stage(target: str) -> torch.nn.Module: - logging.info(f"Materializing {target} on {device}") - submodule = self.split_gm.get_submodule(target) - if index_filename is not None: - submodule = load_checkpoint( - model=submodule, - index_filename=index_filename, - device=device, - dtype=dtype, - checkpoint_prefix=checkpoint_prefix, - ) - try: - submodule.to(device) - except Exception: - # Usually `to(device)` fails because there is still some meta - # tensor in submodule, potentially because the checkpoint load - # did not cover that parameter. And the reason is often that - # that parameter shares weight with another parameter. - for name, param in submodule.named_parameters(): - if param.device == torch.device("meta"): - logging.warning(f"{name} is a meta tensor") - # Re-throw the original exception - raise - return submodule - - with Pipe.stage_init_cv: - setattr(Pipe, "materialize_stage", materialize_stage) - Pipe.stage_init_cv.notify() - - @staticmethod - def is_stage_init_deferred(): - return hasattr(Pipe, "materialize_stage") - - def export(self, stage_id: int) -> torch.nn.Module: - split_gm_children = list(self.split_gm.children()) - submod = split_gm_children[stage_id] - - # HACK: reusing defer init path in PipelineDriver - def exported_stage(target: str) -> torch.nn.Module: - logging.info(f"Retrieving exported {target}") - assert self.split_gm.get_submodule(target) is submod - return submod - - with Pipe.stage_init_cv: - if not hasattr(Pipe, "materialize_stage"): - setattr(Pipe, "materialize_stage", exported_stage) - Pipe.stage_init_cv.notify() - - return submod - class PipeSplitWrapper(torch.nn.Module): class SplitPoint(Enum): @@ -1198,18 +1157,30 @@ def annotate_split_points( setattr(predecessor_module, atoms[-1], wrapped_mod) -class PiPPyShapeProp(shape_prop.ShapeProp): +class PipeFakeTensorProp(Interpreter): def __init__( - self, module: pippy.fx.GraphModule, garbage_collect_values: bool = True + self, module: fx.GraphModule, garbage_collect_values: bool = True ): super().__init__(module, garbage_collect_values) self.stop_prop = False - def run_node(self, n: pippy.fx.Node) -> Any: - if (n.op, n.target) == ("call_function", stage_backward): + def run(self): + inp = tuple( + node.meta["val"] + for node in self.module.graph.nodes + if node.op == "placeholder" + ) + super().run(*inp) + + def run_node(self, node): + # Do not propagate through the stage backward call because it won't work + if (node.op, node.target) == ("call_function", stage_backward): self.stop_prop = True if self.stop_prop: return None - return super().run_node(n) + res = super().run_node(node) + node.meta["example_value"] = res + node.meta["val"] = res + return res diff --git a/pippy/ModelSplit.py b/pippy/ModelSplit.py index 74490b86c..1fd5e32fe 100644 --- a/pippy/ModelSplit.py +++ b/pippy/ModelSplit.py @@ -3,8 +3,8 @@ from typing import Callable, Dict, List, Tuple import torch +import torch.fx as fx -import pippy.fx from pippy.IR import pipe_split """ @@ -16,13 +16,13 @@ def _analyze_node_size( - gm: pippy.fx.GraphModule, -) -> Dict[pippy.fx.Node, Dict[str, int]]: + gm: fx.GraphModule, +) -> Dict[fx.Node, Dict[str, int]]: # state_dict helps us to get parameter sizes state_dict = gm.state_dict() # Function Parameter Usage - node_param_sizes: Dict[pippy.fx.Node, Dict[str, int]] = {} + node_param_sizes: Dict[fx.Node, Dict[str, int]] = {} for node in gm.graph.nodes: if node.op == "get_attr": # a parameter node param_name = node.target @@ -53,7 +53,7 @@ def _analyze_node_size( """ Split a model based on a maximum number of parameter and buffer elements a pipeline stage can have Input: - gm: `pippy.fx.GraphModule` to split + gm: `fx.GraphModule` to split threshold: maximum number of parameter and buffer elements a stage can have max_stages: maximum number of stages; default = -1, no limit Output: @@ -64,15 +64,15 @@ def _analyze_node_size( def _split_on_size_threshold_with_max_stages( - gm: pippy.fx.GraphModule, + gm: fx.GraphModule, threshold: int, max_stages: int = -1, -) -> Tuple[pippy.fx.GraphModule, int]: +) -> Tuple[fx.GraphModule, int]: # Analyze size of parameters/buffers used by each node in the graph node_param_sizes = _analyze_node_size(gm) # Record split positions - insert_before_nodes: List[pippy.fx.Node] = [] + insert_before_nodes: List[fx.Node] = [] def new_stage_before(node): insert_before_nodes.append(node) @@ -150,10 +150,10 @@ def new_stage_before(node): def split_on_size_threshold( threshold: int, -) -> Callable[[pippy.fx.GraphModule], pippy.fx.GraphModule]: +) -> Callable[[fx.GraphModule], fx.GraphModule]: def _split_on_size_threshold( - gm: pippy.fx.GraphModule, - ) -> pippy.fx.GraphModule: + gm: fx.GraphModule, + ) -> fx.GraphModule: gm, _ = _split_on_size_threshold_with_max_stages(gm, threshold) return gm @@ -172,10 +172,10 @@ def _split_on_size_threshold( def split_into_equal_size( nstages: int = 1, -) -> Callable[[pippy.fx.GraphModule], pippy.fx.GraphModule]: +) -> Callable[[fx.GraphModule], fx.GraphModule]: def _split_into_nstages_equal_size( - gm: pippy.fx.GraphModule, - ) -> pippy.fx.GraphModule: + gm: fx.GraphModule, + ) -> fx.GraphModule: param_size = 0 for param in gm.parameters(): param_size += param.numel() diff --git a/pippy/PipelineDriver.py b/pippy/PipelineDriver.py deleted file mode 100644 index 26dcda518..000000000 --- a/pippy/PipelineDriver.py +++ /dev/null @@ -1,2281 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import logging -import operator -import threading -import time -import warnings -from enum import Enum -from inspect import Parameter, Signature -from typing import Any, Callable, Dict, List, Optional, Tuple - -import torch -import torch.distributed.rpc as rpc - -import pippy.fx -from pippy.backward import ( - _null_coalesce_accumulate, - stage_backward, - sync_barrier, -) -from pippy.events import Allocator, Event, EventRecorder, EventsContext -from pippy.fx.passes import shape_prop - -from pippy.IR import Pipe -from pippy.microbatch import ( - gen_output_chunk_spec, - LossReducer, - merge_chunks, - split_args_kwargs_into_chunks, - sum_reducer, -) -from pippy.utils import flatten_args_detach - -# TODO: Define the strategy for replicating the computation. In particular, we will likely make the assumption -# that the operations in the program are batch-wise commutative (my term), i.e. we can guarantee equivalence -# with splitting up the operation along the batch dimension, applying the computation to those sub-batches, -# then merging them back together via concatenation. We should provide a crisp contract surrounding this - -# ===== Questions to Answer ===== -# 1. When does each stage happen? -# micro-batch splitting: per-invocation or with one fixed chunk size? -# physical compilation: this depends on micro-batch splitting (for e.g. scheduling -# so it would have to be ordered after micro-batch splitting -# runtime: obviously needs to happen at runtime -# -# Conceptually: -# -# replicated_programs : List[IR] = replicate(chunks) -# schedule : List[IR] = schedule(replicated_programs) -# for device_schedule in schedule: -# for instruction in device_schedule: -# invoke(rank, instruction) -# -# `chunks` is the only external dependency that could potentially be used per-invocation. -# Do we want to: -# a) Take it as a per-invocation parameter and re-do compilation each time? (-overhead) -# b) Take it as a one-time initialization parameter and consistently split each -# batch into a single `chunks` value (-flexibility) -# c) Allow it to be dynamic but cache compiled policies? -# -# Decision: We can easily convert (a) to (c), so let's go with (a). - -DEBUG = False - - -class Phase(Enum): - FORWARD = 0 - BACKWARD = 1 - ACCUMULATE_GRAD = 2 - SYNC_BARRIER = 3 - - -# TODO: do we need this? -class SchedState(Enum): - WAITING = 0 - READY = 1 - RUNNING = 2 - DONE = 3 - - -def event_name(ph, stage_id, mbid): - phase_to_short_str = { - Phase.FORWARD: "F", - Phase.BACKWARD: "B", - Phase.ACCUMULATE_GRAD: "A", - Phase.SYNC_BARRIER: "S", - } - return f"{phase_to_short_str[ph]}_{stage_id},{mbid}" - - -def event_id(ph, stage_id, mbid, bid): - return f"{event_name(ph, stage_id, mbid)},{bid}" - - -def prev_event_name(ph: Any, all_stages: List[int], stage_id: int, mbid: Any): - i = all_stages.index(stage_id) - if ph == Phase.FORWARD and i > 0: - prev_stage = all_stages[i - 1] - return event_name(ph, prev_stage, mbid) - elif ph == Phase.BACKWARD and i < len(all_stages) - 1: - next_stage = all_stages[i + 1] - return event_name(ph, next_stage, mbid) - else: - return None - - -def next_event_name(ph: Any, all_stages: List[int], stage_id: int, mbid: Any): - i = all_stages.index(stage_id) - if ph == Phase.FORWARD and i < len(all_stages) - 1: - next_stage = all_stages[i + 1] - return event_name(ph, next_stage, mbid) - elif ph == Phase.BACKWARD and i > 0: - prev_stage = all_stages[i - 1] - return event_name(ph, prev_stage, mbid) if stage_id > 0 else None - else: - return None - - -class WorkItem: - def __init__( - self, - stage_id, - phase, - args, - kwargs, - future, - microbatch_id, - blocked_args_count, - ready_args, - batch_id, - num_microbatches, - state=SchedState.WAITING, - debug_str="", - ): - args_to_fwd = [ - "stage_id", - "phase", - "args", - "kwargs", - "future", - "microbatch_id", - "blocked_args_count", - "ready_args", - "batch_id", - "num_microbatches", - "state", - "debug_str", - ] - - for arg in args_to_fwd: - setattr(self, arg, locals()[arg]) - - stage_id: int - phase: Phase - args: Tuple[Any] - kwargs: Dict[str, Any] - future: torch.futures.Future - microbatch_id: int - - blocked_args_count: int - ready_args: Dict[int, Any] - state: SchedState - debug_str: str - - batch_id: int - num_microbatches: int - - def __str__(self): - return f"WorkItem({self.debug_str})" - - -class ValueReference: - def __init__(self, stage_id, unique_key): - self.stage_id = stage_id - self.unique_key = unique_key - self.meta: Dict[str, Any] = {} - - stage_id: int - unique_key: str - - def __repr__(self): - return f"ValueReference({self.stage_id}, {self.unique_key})" - - -class RefcountedFuture: - future: torch.futures.Future - refcount: int - - def __init__(self, future, refcount): - self.future, self.refcount = future, refcount - - def release(self): - """ - Decrement refcount by 1. Return True if this instance should be freed - """ - assert ( - self.refcount != 0 - ), "Detected reference counting inconsistency. Please report a bug to PiPPy" - self.refcount -= 1 - return self.refcount == 0 - - -class RankWorker(EventRecorder): - """ - RankWorker is the underlying WorkItem processing engine for pipeline stages - resident on this rank. WorkItems of multiple stages would share the same - queue in the RankWorker. RankWorker will also maintain states like the - number of outstanding WorkItems. - - * TODO: in-order execution - * Queueing of jobs and execution schedule, e.g. - * Static Schedules - * Fill-drain (GPipe) pipeline by serializing jobs - * TODO: 1F1B scheduling by serializing jobs and stalling for a specific - phase to come through - * TODO: Interleaved 1F1B (TODO: how to set up these data dependencies) - * Dynamic Schedules - * TODO: Varuna dynamic schedule - * TODO: dynamic scheduling via registers and back-pressure (TODO: how to - specify resource limits and how to implement backpressure?) - """ - - def __init__( - self, - rank, - all_stages, - max_outstanding=None, - pp_rank=None, - _record_mem_dumps=False, - checkpoint=False, - ): - logging.info(f"[{rank}] Instantiating RankWorker") - self.rank = rank - self.all_stages = all_stages - self.rank = rank - self.pp_rank = pp_rank - self._record_mem_dumps = _record_mem_dumps - self.checkpoint = checkpoint - - # Maximum outstanding micro-batches of the pipeline schedule - self.max_outstanding = max_outstanding - # Keeps track of the outstanding micro-batches in current rank executor - self.outstanding = 0 - self.stage_executors: Dict[int, PipeStageExecutor] = {} - self.events: List[Event] = [] - - self.waiting_runlist_lock = threading.Lock() - # self.waiting_runlist (*and the contained WorkItems*) are guarded by - # self.waiting_runlist_lock - self.waiting_runlist: Dict[str, WorkItem] = {} - - self.ready_runlist_lock = threading.Lock() - self.ready_runlist_cv = threading.Condition(self.ready_runlist_lock) - self.ready_runlist: Dict[str, WorkItem] = {} - - self.worker_thread = threading.Thread( - target=self.worker_loop, name=f"worker_{self.rank}", daemon=True - ) - self.worker_thread.start() - - def create_stage_executor(self, stage_id, mod, mod_name): - if stage_id in self.stage_executors: - raise AssertionError( - f"Rank {self.rank} already has stage {stage_id}" - ) - - assert ( - mod is not None or mod_name is not None - ), "PipeStageExecutor requires mod or mod_name" - - if mod is None: - with Pipe.stage_init_cv: - defer_called = Pipe.stage_init_cv.wait_for( - Pipe.is_stage_init_deferred, - timeout=100, # stop waiting after 100s - ) - if not defer_called: - raise AssertionError( - f"Rank {self.rank} did not defer stage {stage_id} initialization " - f"though pipeline driver expect it to do so." - ) - - self.stage_executors[stage_id] = PipeStageExecutor( - stage_id=stage_id, - mod=mod or Pipe.materialize_stage(mod_name), # type: ignore[attr-defined] - rank_worker=self, - _record_mem_dumps=self._record_mem_dumps, - ) - return self.stage_executors[stage_id] - - def enqueue_ready_runlist(self, unique_key, work_item): - with self.ready_runlist_cv: - logging.debug( - f"[{self.rank}] Current ready runlist keys: {self.ready_runlist.keys()}" - ) - self.ready_runlist[unique_key] = work_item - self.ready_runlist_cv.notify() - - def enqueue_waiting_runlist(self, unique_key, work_item): - with self.waiting_runlist_lock: - logging.debug( - f"[{self.rank}] Current waiting runlist keys: {self.waiting_runlist.keys()}" - ) - assert ( - unique_key not in self.waiting_runlist - ), f"key {unique_key} already in waiting runlist {self.waiting_runlist}" - self.waiting_runlist[unique_key] = work_item - - def worker_loop(self): - batch_id_to_remaining_backward_microbatches: Dict[int, int] = {} - while True: - work_item = None - with self.ready_runlist_cv: - while len(self.ready_runlist) == 0: - self.ready_runlist_cv.wait() - - logging.debug( - f"[{self.rank}] Dequeueing workitem from set of {len(self.ready_runlist)}" - ) - # TODO: extra priorities - for key in iter(self.ready_runlist.keys()): - # Skip forward work items if we hit the max outstanding limit - # If there are no other READY WorkItems, the runloop wraps around to the beginning and blocks again, - # waiting for another scheduled WorkItem to wake it back up. This works because the only condition - # that can schedule a WAITING Workitem is if another backward WorkItem executes and reduces the number - # of outstanding mciro-batches; - # If there are other READY WorkItems, the runloop executes as normally processing those - if ( - self.ready_runlist[key].phase == Phase.FORWARD - and self.max_outstanding is not None - and self.outstanding >= self.max_outstanding - ): - continue - work_item = self.ready_runlist.pop(key) - break - - # We may not fetch any actionable work item in the above loop, go - # back to the loop in this case - if work_item is None: - continue - logging.debug( - f"[{self.rank}][{work_item.microbatch_id}] Got WorkItem {work_item}" - ) - - work_item.state = SchedState.RUNNING - args_value_refs = work_item.args - kwargs_value_refs = work_item.kwargs - future = work_item.future - microbatch_id = work_item.microbatch_id - ready_args = work_item.ready_args - phase = work_item.phase - try: - stage_executor = self.stage_executors[work_item.stage_id] - except KeyError: - raise RuntimeError( - f"Rank {self.rank} does not have stage {work_item.stage_id}" - f"Current keys {self.stage_executors.keys()}" - ) - - batch_id = work_item.batch_id - num_microbatches = work_item.num_microbatches - - if batch_id not in batch_id_to_remaining_backward_microbatches: - batch_id_to_remaining_backward_microbatches[ - batch_id - ] = num_microbatches - - start_ts = time.time() - name = event_name( - work_item.phase, work_item.stage_id, work_item.microbatch_id - ) - id = event_id( - work_item.phase, - work_item.stage_id, - work_item.microbatch_id, - work_item.batch_id, - ) - if self._record_mem_dumps: - stage_executor._record_dumps_on_all_peer_executors( - f"M{id}_start", start_ts - ) - - value_ref_arg_idx = 0 - - def retrieve_value_ref_args_by_idx(a): - if isinstance(a, ValueReference) and a.unique_key != "noop": - nonlocal value_ref_arg_idx - val = ready_args[value_ref_arg_idx] - value_ref_arg_idx += 1 - return val - else: - return a - - args = pippy.fx.node.map_aggregate( - args_value_refs, retrieve_value_ref_args_by_idx - ) - kwargs = pippy.fx.node.map_aggregate( - kwargs_value_refs, retrieve_value_ref_args_by_idx - ) - - def forward(args, kwargs, no_grad): - args, flat_args = flatten_args_detach(args) - kwargs, flat_kwargs = flatten_args_detach(kwargs) - # Contains all tensors from args and kwargs, in flattened form - flat_args += flat_kwargs - - logging.info( - f"[{self.rank}] Running forward module for microbatch {work_item.microbatch_id}" # type: ignore[union-attr] - ) - - def forward_maybe_with_ddp(args, kwargs): - if isinstance( - stage_executor.mod, - torch.nn.parallel.distributed.DistributedDataParallel, - ): - with stage_executor.mod.no_sync(): # type: ignore[operator] - out_val = stage_executor.mod(*args, **kwargs) - else: - out_val = stage_executor.mod(*args, **kwargs) - return out_val - - def set_requires_grad(a): - if isinstance(a, torch.Tensor) and a.is_floating_point(): - a.requires_grad_(True) - return a - - def dont_traverse_size(a): - return type(a) != torch.Size - - if no_grad: - with torch.no_grad(): - out_val = forward_maybe_with_ddp(args, kwargs) - out_val = pippy.fx.node.map_aggregate( - out_val, set_requires_grad, dont_traverse_size - ) - else: - with torch.enable_grad(): - out_val = forward_maybe_with_ddp(args, kwargs) - - return out_val, flat_args - - if phase == Phase.BACKWARD: - if self.checkpoint: - logging.debug( - f"[{self.rank}][{work_item.microbatch_id}] Running backward phase. " - f"Rerunning forward because of checkpointing" - ) - f_args, f_kwargs = stage_executor.fwd_cache.pop( - microbatch_id - ) - out_val, flat_tensor_args = forward( - f_args, f_kwargs, no_grad=False - ) - kwargs = dict(kwargs) - kwargs["stage_output"], kwargs["input_values"] = ( - out_val if isinstance(out_val, tuple) else (out_val,), - flat_tensor_args, - ) - else: - logging.debug( - f"[{self.rank}][{work_item.microbatch_id}] Running backward phase. " - f"Retrieving stashed values" - ) - # HACK: here we are directly accessing the saved tensor outputs - # for closed-over outputs so that they still have the grad_fn - # from local autograd. Can we solve this more elegantly? - kwargs = dict(kwargs) - ( - kwargs["stage_output"], - kwargs["input_values"], - ) = stage_executor.fwd_cache.pop(microbatch_id) - - if work_item.phase == Phase.FORWARD: - self.outstanding += 1 - out_val, flat_tensor_args = forward( - args, kwargs, no_grad=self.checkpoint - ) - if self.checkpoint: - stage_executor.fwd_cache[microbatch_id] = args, kwargs - else: - stage_executor.fwd_cache[microbatch_id] = ( - out_val if isinstance(out_val, tuple) else (out_val,), - flat_tensor_args, - ) - - elif work_item.phase == Phase.BACKWARD: - logging.info( - f"[{self.rank}] Running backward for microbatch {work_item.microbatch_id}" - ) - - batch_id_to_remaining_backward_microbatches[batch_id] -= 1 - - if ( - isinstance( - stage_executor.mod, - torch.nn.parallel.distributed.DistributedDataParallel, - ) - and batch_id_to_remaining_backward_microbatches[batch_id] - == 0 - ): - # HACK: reaching into DDP implementation details here. Is there a better way? - stage_executor.mod.reducer.prepare_for_backward( # type: ignore[union-attr, operator] - list( - torch.nn.parallel.distributed._find_tensors( # type: ignore[attr-defined] - kwargs["stage_output"] - ) - ) - ) - - out_val = stage_backward(*args, **kwargs) - - # Schedule forward stage of a new micro-batch - self.outstanding -= 1 - elif work_item.phase == Phase.ACCUMULATE_GRAD: - logging.debug( - f"[{self.rank}][{work_item.microbatch_id}] Running accumulate grad" - ) - out_val = _null_coalesce_accumulate(*args, **kwargs) - elif work_item.phase == Phase.SYNC_BARRIER: - logging.debug( - f"[{self.rank}][{work_item.microbatch_id}] Running sync_barrier" - ) - out_val = sync_barrier(*args, **kwargs) - else: - assert ( - False - ), f"Unrecognized phase {work_item.phase} encountered in execution" - - logging.debug( - f"[{self.rank}][{work_item.microbatch_id}] Populating result of type {type(out_val)} " - f"for {key}" - ) - future.set_result(out_val) - work_item.state = SchedState.DONE - - prev_name = prev_event_name( - work_item.phase, - self.all_stages, - work_item.stage_id, - work_item.microbatch_id, - ) - next_name = next_event_name( - work_item.phase, - self.all_stages, - work_item.stage_id, - work_item.microbatch_id, - ) - finish_ts = time.time() - self.record_event( - rank=self.rank, - start_ts=start_ts, - finish_ts=finish_ts, - id=id, - name=name, - type=work_item.phase, - mbid=work_item.microbatch_id, - ) - self.record_event_dependency( - from_id=prev_name, to_id=name, type="transfer" - ) - self.record_event_dependency( - from_id=name, to_id=next_name, type="transfer" - ) - - if self._record_mem_dumps: - stage_executor._record_dumps_on_all_peer_executors( - f"M{id}_finish", finish_ts - ) - - # For work item marked with runlist_key, update its operand list with value - def update_run_list(self, runlist_key, arg_idx, value): - with self.waiting_runlist_lock: - work_item = self.waiting_runlist[runlist_key] - work_item.ready_args[arg_idx] = value - work_item.blocked_args_count -= 1 - if work_item.blocked_args_count == 0: - with self.ready_runlist_cv: - work_item.state = SchedState.READY - self.ready_runlist[runlist_key] = self.waiting_runlist.pop( - runlist_key - ) - self.ready_runlist_cv.notify() - logging.debug( - f"[{self.rank}][{work_item.microbatch_id}] all operands ready: {runlist_key}" - ) - - -class PipeStageExecutor(EventRecorder): - """ - PipeStageExecutor encapsulates the execution semantics of a fragment of - code on a pipeline stage. PipeStageExecutor handles: - - * Ownership of the stage's module and its recursive submodules/parameters - * Serving as an entrypoint for the driver to push jobs into RankWorker's queue - * TODO: gradient checkpointing - """ - - def __init__(self, stage_id, mod, rank_worker, _record_mem_dumps=False): - logging.info(f"Instantiating PipeStageExecutor for stage {stage_id}") - self.stage_id = stage_id - self.mod = mod - self.rank_worker = rank_worker - # map microbatch ID to list of forward tensor args - self.fwd_cache: Dict[int, Tuple[Any, List[torch.Tensor]]] = {} - - self.value_store_lock = threading.Lock() - self.value_store_cv = threading.Condition(self.value_store_lock) - self.value_store: Dict[str, RefcountedFuture] = {} - - self.peer_executors: Dict[int, torch._C._distributed_rpc.PyRRef] = None # type: ignore[assignment] - self._record_mem_dumps = _record_mem_dumps - - self.optimizer = None - # Used to ensure optimizer is created before we create learning rate scheduler - self.optim_init_lock = threading.Lock() - self.optim_init_cv = threading.Condition(self.optim_init_lock) - - self.lr_scheduler = None - self.device = self._find_mod_device() - - # Send/recv order normalization - self.callee_send_tag: Dict[int, int] = {} # callee stage: tag seq num - self.caller_recv_tag: Dict[int, int] = {} # caller stage: tag seq num - self.callee_send_tag_lock = threading.Lock() - self.caller_recv_tag_lock = threading.Lock() - self.caller_recv_tag_cv = threading.Condition(self.caller_recv_tag_lock) - - def _find_mod_device(self): - # We assume that all parameters in the module are on the same device - # HACK: we assume the module has at least one parameter - param = next(self.mod.parameters(), None) - buffer = next(self.mod.buffers(), None) - if param is not None: - device = param.device - elif buffer is not None: - device = buffer.device - else: - logging.warning( - f"Module of stage {self.stage_id} has no parameter or buffer, " - f"cannot figure out device. Setting it to cpu" - ) - device = torch.device("cpu") - return device - - def __getstate__(self): - # Adding an empty __getstate__ function here to work around the DDP pickling issue (#153) that occurs when the - # PipelineDiver asks PipeStageExecutors to install_peer_executor(a list of RRefs) - # More elegant solution is needed in CUDAFuture or RPC to avoid pickling when users do not need to transfer - # tensors - pass - - def install_peer_executors(self, peer_executors): - assert self.peer_executors is None - self.peer_executors = peer_executors - return None - - def init_data_parallel(self, n_stages, dp_group_size, dp_pg_cb=None): - worker_rank = self.rank_worker.rank - if dp_pg_cb is not None: - logging.info( - f"Rank[{worker_rank}] stage[{self.stage_id}] Initializing data parallel: " - f"using DP process groups provided by user" - ) - self.mod = torch.nn.parallel.DistributedDataParallel( - self.mod, process_group=dp_pg_cb(self.stage_id) - ) - return - - logging.debug( - f"Rank[{worker_rank}] stage[{self.stage_id}] Initializing data parallel: " - f"creating DP process groups internally" - ) - # Discover DP peers via Store - # HACK: using the Store coming with the default process group - _store = torch.distributed.distributed_c10d._get_default_store() - # Wrap default store by adding a prefix to each key inserted so as not to step into default store's space - store = torch.distributed.PrefixStore("PiPPy", _store) - # TODO: figure out the unique global "stage rank" for Interleaved 1F1B - my_rank = str(worker_rank) - my_stage = str(self.stage_id) - # Each stage rank checks in with their stage id in respective pipe - store.set(my_rank, my_stage) - - # Create a mapping from stage id to DP ranks - stage_to_dp_ranks: Dict[int, List[int]] = {} - for stage in range(n_stages): - stage_to_dp_ranks.setdefault(stage, []) - - # Wait for all stages to check in - world_size = n_stages * dp_group_size - all_ranks = [str(i) for i in range(world_size)] - store.wait(all_ranks) - logging.debug( - f"Rank[{worker_rank}] stage[{self.stage_id}] Initializing data parallel: all stages have checked in" - ) - - # Fill the mapping - for rank in all_ranks: - stage = store.get(rank) # type: ignore[assignment] - stage_to_dp_ranks[int(stage)].append(int(rank)) - - # Create DP process group for each stage - # Note: even if a rank is not in the DP group of another stage, it must still participate in the new_group call of - # that stage; this is required by c10d - for stage in range(n_stages): - dp_group_ranks = stage_to_dp_ranks[stage] - dp_pg_for_stage = torch.distributed.new_group(dp_group_ranks) - if stage == self.stage_id: - logging.info( - f"Rank[{worker_rank}] stage[{self.stage_id}] " - f"DP group {dp_group_ranks} -- init complete" - ) - - # Wrap stage module with DDP using the DP group corresponding to own stage - if self.stage_id == stage: - self.mod = torch.nn.parallel.DistributedDataParallel( - self.mod, process_group=dp_pg_for_stage - ) - - def create_future(self): - # Future constructor does not accept CPU device, must set to None - return torch.futures.Future( - devices=None if self.device.type == "cpu" else [self.device] - ) - - def invoke( - self, - output_unique_key: str, - phase: Phase, - args, - kwargs, - cur_microbatch: int, - debug_str: str, - output_refcount: int, - batch_id: int, - num_microbatches: int, - ): - start_ts = time.time() - target_name = event_name(phase, self.stage_id, cur_microbatch) - target_id = event_id(phase, self.stage_id, cur_microbatch, batch_id) - name = f"R{target_name}" - id = f"R{target_id}" - if self._record_mem_dumps: - self._record_dumps_on_all_peer_executors(f"M{id}_invoke", start_ts) - # TODO: do we need to serialize calls to invoke() to preserve the order in which WorkItems appear for - # static schedules? - - logging.debug( - f"[{self.stage_id}][{cur_microbatch}] Received invoke call for {debug_str}" - ) - # Extract all ValueRef arguments so we can spawn asynchronous data transfers - # for each of them - value_ref_args: List[ValueReference] = [] - - def extract_value_ref_args(arg): - if isinstance(arg, ValueReference) and arg.unique_key != "noop": - value_ref_args.append(arg) - - pippy.fx.node.map_aggregate(args, extract_value_ref_args) - pippy.fx.node.map_aggregate(kwargs, extract_value_ref_args) - - logging.debug( - f"[{self.stage_id}][{cur_microbatch}] Invoke call found {len(value_ref_args)} ValueReference arguments" - ) - - # Construct WorkItem for this microbatch+phase and record it in the - # waiting runlist - - # We provide device to the Future constructor so that between - # future.set_result() and future.wait() correct dependencies can be - # captured - # We assume the output value is on the same device as the stage's parameters - - # Future constructor does not accept CPU device, must set to None - future: torch.futures.Future = self.create_future() - - # TODO: increase blocked_args_count for extra things like scheduling - work_item = WorkItem( - stage_id=self.stage_id, - phase=phase, - args=args, - kwargs=kwargs, - future=future, - microbatch_id=cur_microbatch, - blocked_args_count=len(value_ref_args), - ready_args={}, - batch_id=batch_id, - num_microbatches=num_microbatches, - debug_str=debug_str, - ) - logging.debug( - f"[{self.stage_id}][{cur_microbatch}] Invoke instantiated WorkItem {work_item} with key {output_unique_key}" - ) - if len(value_ref_args) == 0: - # TODO: convert initial input into ValueRef? - # We always put this work item into the ready queue, though we mark - # it with different state flags depending on whether the schedule - # would hold it based on max outstanding allowed - work_item.state = SchedState.READY - logging.debug( - f"[{self.stage_id}][{cur_microbatch}] No RRef arguments. " - f"Scheduling directly as READY workitem" - ) - self.rank_worker.enqueue_ready_runlist(output_unique_key, work_item) - else: - logging.debug( - f"[{self.stage_id}][{cur_microbatch}] Scheduling WorkItem as WAITING workitem" - ) - work_item.state = SchedState.WAITING - self.rank_worker.enqueue_waiting_runlist( - output_unique_key, work_item - ) - - # Group Value Ref Args based on source stage - # `callee_stage_dict` has the following structure: - # Dict[callee_stage, Dict[my_arg_idx, value_ref]] - callee_stage_dict: Dict[int, Dict[int, ValueReference]] = {} - for arg_idx, value_ref_arg in enumerate(value_ref_args): - # Check if the ValRef corresponds to a tensor - if "tensor_meta" in value_ref_arg.meta: - callee_stage = value_ref_arg.stage_id - batch_refs = callee_stage_dict.setdefault(callee_stage, {}) - batch_refs[arg_idx] = value_ref_arg - else: - # For non-tensor (e.g. a value or a size vector), we use RPC to spawn asynchronous data transfer - logging.debug( - f"[{self.stage_id}][{cur_microbatch}] Launching RPC data transfer for " - f"ValueReference {arg_idx} {value_ref_arg}" - ) - self.async_transfer( - cur_microbatch, value_ref_arg, arg_idx, output_unique_key - ) - - # For tensors, we use c10d two-sided send/recv - # Batch call per source stage to reduce number of RPC threads - with self.callee_send_tag_lock: - for callee_stage, batch_refs in callee_stage_dict.items(): - value_ref_executor_rref = self.peer_executors[callee_stage] - tag = self.callee_send_tag.setdefault(callee_stage, 0) - self.callee_send_tag[callee_stage] += 1 - value_ref_executor_rref.rpc_async().batch_send( - self.stage_id, - output_unique_key, - cur_microbatch, - batch_refs, - tag, - ) - self.batch_recv( - cur_microbatch, - output_unique_key, - callee_stage, - batch_refs, - tag, - ) - - with self.value_store_cv: - assert output_unique_key not in self.value_store, ( - f"[{self.stage_id}] Output key {output_unique_key} " - f"already exists or is not consumed from previous batch" - ) - self.value_store[output_unique_key] = RefcountedFuture( - future, output_refcount - ) - self.value_store_cv.notify_all() - - finish_ts = time.time() - self.record_event( - rank=self.rank_worker.rank, - start_ts=start_ts, - finish_ts=finish_ts, - id=id, - name=name, - type="received", - mbid=cur_microbatch, - ) - self.record_event_dependency( - from_id=name, to_id=target_name, type="waiting" - ) - - return ValueReference(self.stage_id, output_unique_key) - - def coalesced_index_value( - self, indices: List[Tuple[str, int, ValueReference, int]] - ): - for index_tuple in indices: - # `output_unique_key` is the key for the indexed output value (single) - # `value_ref.unique_key` is the key for the overall output of current stage (can have multiple values) - (output_unique_key, output_refcount, value_ref, idx) = index_tuple - logging.debug( - f"[{self.stage_id}] Received getitem call: {(output_unique_key, output_refcount, value_ref, idx)}" - ) - with self.value_store_cv: - # TODO: investigate why value reference in the last batch has not been fully consumed - if output_unique_key in self.value_store: - logging.debug( - f"[{self.stage_id}] Indexed value already in store: {(output_unique_key, output_refcount, value_ref, idx)}" - ) - # raise RuntimeError(f'Repeated index value call detected, potentially due to getitem calls not consumed in previous batch') - - # Wait for the future representing the stage output to be created - while value_ref.unique_key not in self.value_store: - self.value_store_cv.wait() - # Now the stage output future is created - refcounted_future = self.value_store[value_ref.unique_key] - - # For the purposes of refcounting, decrement this use - if refcounted_future.release(): - self.value_store.pop(value_ref.unique_key) - - # Create an indexed future that represents a specific output arg - # Here we use an attach functon so that the index passed to the lambda changes in every loop - def attach_index(fut, index): - indexed_fut = fut.then(lambda f: f.value()[index]) - return indexed_fut - - indexed = attach_index(refcounted_future.future, idx) - - # Enqueue the indexed future - # And notify places that may be waiting for it to be created, such as get_value - self.value_store[output_unique_key] = RefcountedFuture( - indexed, output_refcount - ) - self.value_store_cv.notify_all() - - def get_value( - self, - caller_stage, - runlist_key, - microbatch, - value_ref_arg, - ): - callee_stage = value_ref_arg.stage_id - logging.debug( - f"[{callee_stage}][{microbatch}] Executing transfer of value " - f"{value_ref_arg} initiated by stage {caller_stage} for {runlist_key}" - ) - assert ( - callee_stage == self.stage_id - ), "Mismatch between ValueRef and stage executor" - - with self.value_store_cv: - # Waiting for the indexed future for this arg to be created - while value_ref_arg.unique_key not in self.value_store: - self.value_store_cv.wait() - # Now the indexed future is created - refcounted_future = self.value_store[value_ref_arg.unique_key] - - value = refcounted_future.future.wait() - - with self.value_store_lock: - if refcounted_future.release(): - self.value_store.pop(value_ref_arg.unique_key) - - return value - - def async_transfer(self, microbatch, value_ref_arg, arg_idx, runlist_key): - logging.debug( - f"[{self.stage_id}][{microbatch}] Requesting transfer of value {value_ref_arg} " - f"for runlist item {runlist_key} arg_idx {arg_idx}" - ) - callee_stage = value_ref_arg.stage_id - value_ref_executor_rref = self.peer_executors[callee_stage] - - fut = value_ref_executor_rref.rpc_async().get_value( - self.stage_id, - runlist_key, - microbatch, - value_ref_arg, - ) - - def bottom_half(fut): - logging.debug( - f"[{self.stage_id}][{microbatch}] Completing transfer of value {value_ref_arg} " - f"for runlist item {runlist_key} arg_idx {arg_idx}" - ) - value = fut.value() - self.rank_worker.update_run_list(runlist_key, arg_idx, value) - - return fut.then(bottom_half) - - def batch_send( - self, - caller_stage, - runlist_key, - microbatch, - batch_refs, - tag, - ): - # Wait till this batch's turn to send - with self.caller_recv_tag_cv: - self.caller_recv_tag.setdefault(caller_stage, 0) - while self.caller_recv_tag[caller_stage] < tag: - self.caller_recv_tag_cv.wait() - - logging.debug( - f"[{self.stage_id}][{microbatch}] Sending batch {tag} of " - f"{len(batch_refs)} values initiated by stage {caller_stage} for {runlist_key}" - ) - - for _, value_ref_arg in batch_refs.items(): - with self.value_store_cv: - # Waiting for the indexed future for this arg to be created - while value_ref_arg.unique_key not in self.value_store: - self.value_store_cv.wait() - # Now the indexed future is created - refcounted_future = self.value_store[value_ref_arg.unique_key] - - value = refcounted_future.future.wait() - - with self.value_store_lock: - if refcounted_future.release(): - self.value_store.pop(value_ref_arg.unique_key) - - # Instead of return value let's do a send call - if torch.distributed.get_backend() == "gloo": - # Gloo P2P does not support work.get_future, so we use send instead - torch.distributed.send(value, caller_stage, tag=tag) - else: - torch.distributed.isend(value, caller_stage, tag=tag) - - # Notify next send that's potentially waiting - with self.caller_recv_tag_cv: - self.caller_recv_tag[caller_stage] += 1 - self.caller_recv_tag_cv.notify_all() - - def batch_recv( - self, microbatch, runlist_key, callee_stage, batch_refs, tag - ): - logging.debug( - f"[{self.stage_id}][{microbatch}] Receiving batch {tag} of {len(batch_refs)} values " - f"for runlist item {runlist_key} from stage {callee_stage}" - ) - futures = [] - - for arg_idx, value_ref_arg in batch_refs.items(): - tm = value_ref_arg.meta["tensor_meta"] - recv_buff = torch.empty( - tm.shape, dtype=tm.dtype, device=self.device - ) - - if torch.distributed.get_backend() == "gloo": - # Gloo P2P does not support work.get_future, so we need to: - # - manually create the Future, - # - use recv instead, and - # - manually set_result to the Future - fut: torch.futures.Future = self.create_future() - torch.distributed.recv(recv_buff, callee_stage, tag=tag) - fut.set_result(recv_buff) - else: - work = torch.distributed.irecv(recv_buff, callee_stage, tag=tag) - fut = work.get_future() # type: ignore[attr-defined] - - def bottom_half(fut): - logging.debug( - f"[{self.stage_id}][{microbatch}] Completing transfer of value {value_ref_arg} " - f"for runlist item {runlist_key} arg_idx {arg_idx}" - ) - value = fut.value() - # It is awkward that the Work class in PyTorch fixes the result return to a List: - # def result(self) -> List[Tensor]: ... - # See torch/_C/_distributed_c10d.pyi - # We don't expect P2P operations to actually result in a List, hence unpacking and getting the first and - # only tensor out - if isinstance(value, List): - value = value[0] - self.rank_worker.update_run_list(runlist_key, arg_idx, value) - - futures.append(fut.then(bottom_half)) - - return futures - - def get_grad(self, qualname): - mod = self.mod - if isinstance(mod, torch.nn.parallel.DistributedDataParallel): - mod = mod.module - return mod.get_parameter(qualname).grad - - def set_grad(self, qualname, value): - mod = self.mod - if isinstance(mod, torch.nn.parallel.DistributedDataParallel): - mod = mod.module - param = mod.get_parameter(qualname) - param.grad = value - - def train(self, mode=True): - self.mod.train(mode=mode) - - def _should_instantiate_optim(self): - return len(list(self.mod.parameters())) > 0 - - def instantiate_optimizer(self, optim_class, *args, **kwargs): - assert self._should_instantiate_optim() - with self.optim_init_cv: - self.optimizer = optim_class(self.mod.parameters(), *args, **kwargs) - self.optim_init_cv.notify() - return self.optimizer - - def instantiate_lr_scheduler(self, lr_sched_class, *args, **kwargs): - # Make sure optimizer has been created - with self.optim_init_cv: - while self.optimizer is None: - self.optim_init_cv.wait() - - logging.info(f"[{self.stage_id}] Creating learning rate scheduler") - self.lr_scheduler = lr_sched_class(self.optimizer, *args, **kwargs) - return self.lr_scheduler - - def step_lr_scheduler(self, *args, **kwargs): - self.lr_scheduler.step(*args, **kwargs) # type: ignore[union-attr] - - def _check_cleanup(self) -> bool: - if len(self.value_store): - logging.warning( - f"[{self.stage_id}] Unclean value store: {self.value_store}" - ) - return False - return True - - def _record_dump(self, dump_id, ts): - first_param = next(self.mod.parameters(), None) - device: torch.device = ( - first_param.device - if first_param is not None - else torch.device("cpu") - ) - if device.type == "cuda": - alloc = torch.cuda.memory_allocated() - max_alloc = torch.cuda.max_memory_allocated() - rsrvd = torch.cuda.memory_reserved() - max_rsrvd = torch.cuda.max_memory_reserved() - assert ( - alloc <= max_alloc - ), f"alloc = {alloc} max_alloc = {max_alloc}" - assert ( - rsrvd <= max_rsrvd - ), f"rsrvd = {rsrvd} max_rsrvd = {max_rsrvd}" - assert ( - max_alloc <= max_rsrvd - ), f"max_alloc = {max_alloc} max_rsrvd = {max_rsrvd}" - self.record_dump( - rank=self.rank_worker.rank, - ts=ts, - id=dump_id, - name=dump_id, - type="dump", - allocators={ - "cuda.4.alloc": Allocator( - f"alloc_{self.rank_worker.rank}", - { - "size": alloc, - }, - ), - "cuda.3.max_alloc-alloc": Allocator( - f"max_alloc-alloc_{self.rank_worker.rank}", - { - "size": max_alloc - alloc, - }, - ), - "cuda.2.rsrvd-max_alloc": Allocator( - f"rsrvd-max_alloc_{self.rank_worker.rank}", - { - "size": max(rsrvd - max_alloc, 0), - }, - ), - "cuda.1.max_rsrvd-max_alloc_or_rsrvd": Allocator( - f"max_rsrvd-max_alloc_or_rsrvd_{self.rank_worker.rank}", - { - "size": max_rsrvd - - (max_alloc if max_alloc > rsrvd else rsrvd), - }, - ), - }, - ) - - def _record_dumps_on_all_peer_executors(self, id, ts): - for peer_executor_rref in self.peer_executors.values(): - peer_executor_rref.rpc_sync()._record_dump(f"{id}", ts) - - -def _wait_for_all(rpc_futs): - # Stolen from DistributedOptimizer implementation - # TODO: improve error propagation - exception = None - results = [] - for fut in rpc_futs: - try: - results.append(fut.wait()) - except Exception as e: - results.append(e) - exception = e - if exception is not None: - raise exception - return results - - -class PipelineOptimizer(torch.optim.Optimizer): - def __init__(self, remote_optims): - self.remote_optims = remote_optims - - # TODO: enable this - # self._hook_for_profile() - - # TODO: enable this - # self.state = defaultdict(dict) - - self.param_groups = [] - - # Collect RRefs to remote parameters - param_group = {"params": []} # type: ignore[var-annotated] - - for optim in self.remote_optims: - remote_state = optim.rpc_sync().__getstate__() - assert isinstance(remote_state, dict) - for group in remote_state["param_groups"]: - param_group["params"].extend(group["params"]) - for k in group: - if k != "params": - param_group.setdefault(k, group[k]) - - self.param_groups = [param_group] - - def __getstate__(self): - raise NotImplementedError() - - def __setstate__(self, state): - raise NotImplementedError() - - def _hook_for_profile(self): - raise NotImplementedError() - - def state_dict(self): - raise NotImplementedError() - - def load_state_dict(self, state_dict): - raise NotImplementedError() - - # PyTorch type annotation for this function is wrong. See - # https://github.com/pytorch/pytorch/pull/76998 for proposed fix - def zero_grad(self, set_to_none: bool = False): # type: ignore - futs = [] - for optim in self.remote_optims: - futs.append(optim.rpc_async().zero_grad(set_to_none)) - _wait_for_all(futs) - - def step(self, closure=None): - futs = [] - for optim in self.remote_optims: - futs.append(optim.rpc_async().step(closure)) - _wait_for_all(futs) - - def add_param_group(self, param_group): - raise NotImplementedError() - - -class PipelineLRScheduler(torch.optim.lr_scheduler._LRScheduler): - def __init__(self, stage_to_scheds, stage_to_executor): - # A dict from stage id to LR schedulers - self.stage_to_scheds = stage_to_scheds - self.stage_to_executor = stage_to_executor - self.new_step_called = False - self.last_lr = [] - - def step(self, *args, **kwargs): - futs = [] - # Step all remote LR schedulers - - # We use the executor block below because calling scheduler.step() - # remotely might cause pickling nested functions, where these nested - # functions are usually defined inside user's lr scheduler constructor - # as lambda functions to be used by the lr scheduler - # See https://github.com/pytorch/PiPPy/issues/404 - """ - for scheduler in self.stage_to_scheds.values(): - futs.append(scheduler.rpc_async().step(*args, **kwargs)) - """ - for executor in self.stage_to_executor.values(): - futs.append(executor.rpc_async().step_lr_scheduler(*args, **kwargs)) - - _wait_for_all(futs) - # Mark new step (invalidates last_lr) - self.new_step_called = True - - def get_last_lr(self): - """Return last computed learning rate by remote schedulers.""" - # No need to involve remote schedulers if no new step calls - if not self.new_step_called: - return self.last_lr - - # Ask LR scheduler of stage 0 to return new learning rate as representation of all stages, because: - # (i) we do not support multiple parameter groups yet (neither PipelineOptimizer nor PipelineLRScheduler does), - # so there are not param group specific LR's; and - # (ii) current LRS implementations do not relies on state within the optimizer, so the LR's of different stages - # will not diverge - assert self.stage_to_scheds, "No learning rate scheduler" - self.last_lr = self.stage_to_scheds[0].remote().get_last_lr().to_here() - self.new_step_called = False - return self.last_lr - - def state_dict(self): - """Returns the state of the remote schedulers as a :class:`dict`""" - # Ask LR scheduler of stage 0 to return state_dict as representation of all stages, for the same reason as - # stated in get_last_lr() - rv: Dict = {} - assert self.stage_to_scheds, "No learning rate scheduler" - rv = self.stage_to_scheds[0].remote().state_dict().to_here() - return rv - - def load_state_dict(self, state_dict): - """Loads the scheduler state. - Args: - state_dict (dict): scheduler state. Should be an object returned - from a call to :meth:`state_dict`. - """ - futs = [] - for scheduler in self.stage_to_scheds.values(): - futs.append(scheduler.rpc_async().load_state_dict(state_dict)) - - _wait_for_all(futs) - - def get_lr(self): - # Even in single scheduler setting, get_lr is more of an internal method to be called by step() - # See: pytorch/torch/optim/lr_scheduler.py - warnings.warn( - "To get the last learning rate computed by the scheduler, " - "please use `get_last_lr()`." - ) - raise NotImplementedError - - def print_lr(self, is_verbose, group, lr, epoch=None): - """Display the current learning rate.""" - # This is more of an internal method of native scheduler - # See: pytorch/torch/optim/lr_scheduler.py - raise NotImplementedError - - -class PipelineDriverBase(torch.nn.Module): - def __init__( - self, - pipe: Pipe, - chunks: int, - world_size: int, - all_ranks: List[int] = None, - args_chunk_spec=None, - kwargs_chunk_spec=None, - output_chunk_spec=None, - _debug_mask_minibatches: bool = False, - max_outstanding=None, - interleave_stages=False, - _record_mem_dumps=False, - checkpoint=False, - use_c10d=False, - loss_reducer: LossReducer = sum_reducer, - ): - super().__init__() - self.pipe = pipe - self.chunks = chunks - self.world_size = world_size - self.all_ranks = all_ranks - self.args_chunk_spec = args_chunk_spec - self.kwargs_chunk_spec = kwargs_chunk_spec - self.loss_reducer = loss_reducer - self.output_chunk_spec = ( - output_chunk_spec - if output_chunk_spec - else gen_output_chunk_spec(pipe.loss_spec, loss_reducer) - ) - - # Maximum outstanding micro-batches allowed by the pipeline schedule - # None means no limit - self.max_outstanding: Optional[int] = max_outstanding - self._debug_mask_minibatches = _debug_mask_minibatches - self.interleave_stages = interleave_stages - - self.microbatch_interpreters: List[RemoteInterpreter] = [] - self.batch_id = 0 - self._record_mem_dumps = _record_mem_dumps - self.optimizer_inited = False - self.checkpoint = checkpoint - self.use_c10d = use_c10d - - def _init_remote_executors(self): - self.rank_worker_rrefs: Dict[int, torch.distributed.rpc.RRef] = {} - self.remote_stage_executor_rrefs: Dict[ # type: ignore[syntax] - str, (int, torch.distributed.rpc.RRef) - ] = {} - - if self.all_ranks is not None: - assert ( - len(self.all_ranks) == self.world_size - ), "Explicitly specified ranks must match world_size" - else: - self.all_ranks = list(range(self.world_size)) - logging.info( - f"[root] Creating pipeline driver with {self.world_size} workers: {self.all_ranks}" - ) - - class ExecutorDescriptor: - name: str - mod: Optional[torch.nn.Module] - has_backward: bool = False - - split_gm = self.pipe.split_gm - - executor_descriptors = [] - bw_idx = -1 - for node in split_gm.graph.nodes: - if node.op == "call_module": - descr = ExecutorDescriptor() - descr.name = node.target - if Pipe.is_stage_init_deferred(): - descr.mod = None - else: - descr.mod = split_gm.get_submodule(node.target) - executor_descriptors.append(descr) - elif (node.op, node.target) == ("call_function", stage_backward): - executor_descriptors[bw_idx].has_backward = True - node.meta["fw_stage"] = executor_descriptors[bw_idx].name - bw_idx -= 1 - elif (node.op, node.target) == ( - "call_function", - _null_coalesce_accumulate, - ): - node.meta["fw_stage"] = executor_descriptors[bw_idx].name - - assert all(d.has_backward for d in executor_descriptors) or all( - not d.has_backward for d in executor_descriptors - ) - - if len(executor_descriptors) > self.world_size: - if not self.interleave_stages: - raise RuntimeError( - f"Tried to run pipeline with {len(executor_descriptors)} stages with a world size of " - f"{self.world_size}. Please ensure world_size is large enough to accommodate your pipeline." - ) - - ranks_to_launch = self.world_size - n_stages = len(executor_descriptors) - if n_stages < self.world_size: - ranks_to_launch = n_stages - warnings.warn( - f"Running pipeline with {n_stages} stages on world_size of {self.world_size}. " - f"Remaining ranks will be idle." - ) - - if self.interleave_stages and n_stages <= ranks_to_launch: - self.interleave_stages = False - warnings.warn( - "Falling back from Interleaved 1F1B to 1F1B " - "since there are enough ranks to support one stage per rank" - ) - - # Fire up rank workers - all_stages = list(range(n_stages)) - pp_rank = 0 - for rank in self.all_ranks[:ranks_to_launch]: - kwargs = { - "rank": rank, - "all_stages": all_stages, - "max_outstanding": self.max_outstanding, - "pp_rank": pp_rank, - "_record_mem_dumps": self._record_mem_dumps, - "checkpoint": self.checkpoint, - } - self.rank_worker_rrefs[rank] = rpc.remote( - rank, RankWorker, args=(), kwargs=kwargs - ) - pp_rank += 1 - - self.stage_to_executor: Dict = {} - - # Ask each RankWorker to create stage thereon - # This can involve checkpoint loading in deferred init case - for stage_id, descr in enumerate(executor_descriptors): - # Assign stages to rank workers in a round-robin fashion - rank = self.all_ranks[stage_id % self.world_size] - logging.debug(f"[root] Sending stage_id = {stage_id} mod to worker") - self.remote_stage_executor_rrefs[descr.name] = ( - stage_id, - self.rank_worker_rrefs[rank] - .remote() - .create_stage_executor( - stage_id=stage_id, - mod=descr.mod, - mod_name=descr.name, - ), - ) - - # Check that each RankWorker has completed stage init - for stage_id, descr in enumerate(executor_descriptors): - logging.debug( - f"[root] Waiting stage_id = {stage_id} mod to be confirmed by worker" - ) - while not self.remote_stage_executor_rrefs[descr.name][ - 1 - ].confirmed_by_owner(): - pass - - self.stage_to_executor[stage_id] = self.remote_stage_executor_rrefs[ - descr.name - ][1] - - # Inform executors of their peers - for stage_id, executor in self.stage_to_executor.items(): - executor.rpc_sync().install_peer_executors(self.stage_to_executor) - - """ - Method for creating a data parallel clique for each stage, across multiple pipelines - dp_group_size: size of each data parallel group, equals to the number of pipelines - dp_pg_cb: optional Callable taking pipeline stage as argument and returning corresponding data parallel group; - user can use this Callable to pass in prepared data parallel groups - """ - - def init_data_parallel(self, dp_group_size, dp_pg_cb=None): - if dp_group_size <= 1: - logging.info( - "[root] Data parallel group size <= 1, skipping data parallel initialization" - ) - return - - n_stages = len(self.stage_to_executor) - logging.info( - f"[root] Initializing {n_stages} data parallel groups, each of size {dp_group_size}" - ) - futs = [] - # Asks all stage executors to participate in DP process group init - # These must be async calls because otherwise there will be deadlocks - for executor in self.stage_to_executor.values(): - futs.append( - executor.rpc_async().init_data_parallel( - n_stages, dp_group_size, dp_pg_cb - ) - ) - - # Here we wait for all DP process groups to be initialized before the user can ask the PipeDriver to run - _wait_for_all(futs) - - def forward(self, *args, **kwargs): - raise NotImplementedError( - "PipelineDriverBase is an abstract base class, please use a concrete " - "implementation class." - ) - - def train(self, mode=True): - for executor in self.stage_to_executor.values(): - executor.rpc_sync().train(mode=mode) - - def eval(self): - self.train(mode=False) - - def instantiate_optimizer(self, optim_class, *args, **kwargs): - remote_optims = [] - # Keeps track of stage to optimizer mapping - self.stage_to_optim: Dict = {} - for stage, executor in self.stage_to_executor.items(): - if executor.rpc_sync()._should_instantiate_optim(): - remote_optim = executor.remote().instantiate_optimizer( - optim_class, *args, **kwargs - ) - remote_optims.append(remote_optim) - self.stage_to_optim[stage] = remote_optim - - self.optimizer_inited = True - return PipelineOptimizer( - [optim for optim in remote_optims if optim is not None] - ) - - """ - Create learning rate scheduler for the optimizer of the pipeline. - Note: this API cannot be called before instantiate_optimizer is called. - """ - - def instantiate_lr_scheduler(self, lr_sched_class, *args, **kwargs): - if not self.optimizer_inited: - raise RuntimeError( - "[root] instantiate_optimizer must be called before instantiate_lr_scheduler" - ) - - stage_to_scheds: Dict = {} - for stage, optim in self.stage_to_optim.items(): - if optim is not None: - executor = self.stage_to_executor[stage] - remote_lr_sched = executor.remote().instantiate_lr_scheduler( - lr_sched_class, *args, **kwargs - ) - stage_to_scheds[stage] = remote_lr_sched - - return PipelineLRScheduler(stage_to_scheds, self.stage_to_executor) - - def _sync_replicated_params(self): - logging.debug( - f"[root] Synchronizing gradients for {len(self.pipe.replicated_params)} sets of replicated parameters" - ) - for param_set in self.pipe.replicated_params: - grad_values = [] - for module_name, param_qualname in param_set.items(): - assert module_name in self.remote_stage_executor_rrefs - stage_id, module_rref = self.remote_stage_executor_rrefs[ - module_name - ] - grad_value = module_rref.rpc_sync().get_grad(param_qualname) - grad_values.append(grad_value) - - synced_value = torch.sum(torch.stack(grad_values), dim=0) - - for module_name, param_qualname in param_set.items(): - assert module_name in self.remote_stage_executor_rrefs - stage_id, module_rref = self.remote_stage_executor_rrefs[ - module_name - ] - module_rref.rpc_sync().set_grad(param_qualname, synced_value) - - def _retrieve_output_values(self, microbatch_interpreters, last_nodes): - logging.debug( - f"[root] Retrieving output values from {len(microbatch_interpreters)} chunks" - ) - output_vals = [] - for interp, last_node in zip(microbatch_interpreters, last_nodes): - interp.run_until(lambda n: False) - output_vals.append(interp.env[last_node]) - - # First kick of async transfers to retrieve ValueReference values - def initiate_async_transfer(a): - if isinstance(a, ValueReference): - value_ref_executor_rref = self.stage_to_executor[a.stage_id] - return value_ref_executor_rref.rpc_async().get_value( - "root", "collect", -1, a - ) - else: - return a - - output_vals = pippy.fx.node.map_aggregate( - output_vals, initiate_async_transfer - ) - - # Then wait for futures to be ready - return pippy.fx.node.map_aggregate( - output_vals, - lambda a: a.wait() if isinstance(a, torch._C.Future) else a, - ) - - def retrieve_events(self) -> EventsContext: - events_context = EventsContext() - for rank, worker_rref in self.rank_worker_rrefs.items(): - events_context.update(worker_rref.rpc_sync().retrieve_events()) - for interp in self.microbatch_interpreters: - events_context.update(interp.retrieve_events()) - for _, executor_rref in self.remote_stage_executor_rrefs.values(): - events_context.update(executor_rref.rpc_sync().retrieve_events()) - events_context.events.sort(key=lambda e: e.start_ts) - return events_context - - def _check_stages_cleanup(self) -> bool: - clean = True - for executor in self.stage_to_executor.values(): - clean &= executor.rpc_sync()._check_cleanup() - return clean - - -class RemoteInterpreter(pippy.fx.Interpreter, EventRecorder): - def __init__( - self, - remote_stage_executor_rrefs, - stage_to_executor, - module, - cur_microbatch: int, - args, - kwargs, - batch_id: int, - num_microbatches: int, - garbage_collect_values=True, - ): - super().__init__(module, garbage_collect_values) - self.remote_stage_executor_rrefs = remote_stage_executor_rrefs - self.stage_to_executor = stage_to_executor - self.cur_microbatch = cur_microbatch - self.pc = 0 - self.node_list = list(self.module.graph.nodes) - logging.debug( - f"[root] RemoteInterpreter created with {len(self.node_list)} nodes" - ) - - # Process args/kwargs - - # TODO: replace this with GraphModule.signature() when it lands - parameters = [] - for node in self.module.graph.nodes: - if node.op != "placeholder": - continue - default = next(iter(node.args)) if node.args else Parameter.empty - parameters.append( - Parameter( - node.name, Parameter.POSITIONAL_OR_KEYWORD, default=default - ) - ) - - # We are building a safety net here in case user passes in extra arguments than those defined as variable - # arguments (i.e. non-concrete args) at the tracing phase - # TODO: Remove this safety net - traced_args = [p.name for p in parameters] - filtered_kwargs = {k: v for k, v in kwargs.items() if k in traced_args} - if len(filtered_kwargs) != len(kwargs): - extra_args = kwargs.keys() - filtered_kwargs.keys() - warnings.warn( - f"Received extra arguments: {extra_args}. " - f"They might have already been given a concrete value during pipeline compilation via `concrete_args`. " - f"We will ignore the current inputs and use the values given during compilation." - ) - - sig = Signature(parameters) - bound_args = sig.bind(*args, **filtered_kwargs) - bound_args.apply_defaults() - self.args = bound_args.args - self.args_iter = iter(self.args) - self.batch_id = batch_id - self.num_microbatches = num_microbatches - # Dict from stage id to a list holding the coalesced getitem indices - self.stage_output_indices: Dict[ - int, List[Tuple[str, int, ValueReference, int]] - ] = {} - - def call_module(self, target, args, kwargs): - assert isinstance(target, str) - node = self.node_list[self.pc] - # if PipelineDriver is running inside `torch.no_grad()` context manager then `stage_backward*` nodes - # are excluded from execution, so we need exclude `stage_backward*` from reference count, otherwise - # it will cause memory leak. - users = ( - list( - filter( - lambda user: not user.name.startswith("stage_backward"), - node.users.keys(), - ) - ) - if not torch.is_grad_enabled() - else node.users.keys() - ) - if target in self.remote_stage_executor_rrefs: - stage_id, stage_executor = self.remote_stage_executor_rrefs[target] - logging.debug( - f"[root][{self.cur_microbatch}] Issuing {Phase.FORWARD} " - f"invocation for target {target} on stage {stage_id}" - ) - invocation_key = f"{self.cur_microbatch}_{node.name}" - start_ts = time.time() - forward_name = event_name( - Phase.FORWARD, stage_id, self.cur_microbatch - ) - forward_id = event_id( - Phase.FORWARD, stage_id, self.cur_microbatch, self.batch_id - ) - name = f"I{forward_name}" - id = f"I{forward_id}" - stage_executor.rpc_async().invoke( - invocation_key, - Phase.FORWARD, - args, - kwargs, - self.cur_microbatch, - debug_str=node.format_node(), - output_refcount=len(users), - batch_id=self.batch_id, - num_microbatches=self.num_microbatches, - ) - finish_ts = time.time() - self.record_event( - rank=0, - start_ts=start_ts, - finish_ts=finish_ts, - id=id, - name=name, - type="invoke", - mbid=self.cur_microbatch, - ) - self.record_event_dependency( - from_id=name, to_id=f"R{forward_name}", type="invoke" - ) - return ValueReference(stage_id, invocation_key) - else: - logging.debug( - f"[root][{self.cur_microbatch}] Running local operation {target} from driver" - ) - return super().call_module(target, args, kwargs) - - def call_function(self, target, args, kwargs): - node = self.node_list[self.pc] - invocation_key = f"{self.cur_microbatch}_{node.name}" - # if PipelineDriver is running inside `torch.no_grad()` context manager then `stage_backward*` nodes - # are excluded from execution, so we need exclude `stage_backward*` from reference count, otherwise - # it will cause memory leak. - users = ( - list( - filter( - lambda user: not user.name.startswith("stage_backward"), - node.users.keys(), - ) - ) - if not torch.is_grad_enabled() - else node.users.keys() - ) - if target is operator.getitem and isinstance(args[0], ValueReference): - val_ref = args[0] - stage_id = val_ref.stage_id - num_users = len(users) - if not torch.is_grad_enabled() and val_ref.unique_key == "noop": - return ValueReference(stage_id, "noop") - elif num_users == 0: - # TODO: investigate why there are getitem calls with 0 users - return ValueReference(stage_id, "noop") - else: - indices = self.stage_output_indices.setdefault(stage_id, []) - arg_idx = args[1] - index_tuple = (invocation_key, num_users, val_ref, arg_idx) - logging.debug( - f"[root][{self.cur_microbatch}] Appending getitem tuple to stage {stage_id}: {index_tuple}" - ) - indices.append(index_tuple) - return ValueReference(stage_id, invocation_key) - elif target is stage_backward: - assert "fw_stage" in node.meta - stage_id, stage_executor = self.remote_stage_executor_rrefs[ - node.meta["fw_stage"] - ] - if torch.is_grad_enabled(): - logging.debug( - f"[root][{self.cur_microbatch}] Issuing BW invocation " - f'for target {node.meta["fw_stage"]} on stage {stage_id}' - ) - start_ts = time.time() - backward_name = event_name( - Phase.BACKWARD, stage_id, self.cur_microbatch - ) - backward_id = event_id( - Phase.BACKWARD, stage_id, self.cur_microbatch, self.batch_id - ) - name = f"I{backward_name}" - id = f"I{backward_id}" - stage_executor.rpc_async().invoke( - invocation_key, - Phase.BACKWARD, - args, - kwargs, - self.cur_microbatch, - debug_str=node.format_node(), - output_refcount=len(users), - batch_id=self.batch_id, - num_microbatches=self.num_microbatches, - ) - finish_ts = time.time() - self.record_event( - rank=0, - start_ts=start_ts, - finish_ts=finish_ts, - id=id, - name=name, - type="invoke", - mbid=self.cur_microbatch, - ) - self.record_event_dependency( - from_id=name, to_id=backward_name, type="invoke" - ) - return ValueReference(stage_id, invocation_key) - else: - return ValueReference(stage_id, "noop") - elif target is sync_barrier: - executor_keys = list(self.remote_stage_executor_rrefs.keys()) - stage_id, stage_executor = self.remote_stage_executor_rrefs[ - executor_keys[0] - ] - logging.debug( - f"[root][{self.cur_microbatch}] Issuing sync invocation " - f"on stage {stage_id}" - ) - stage_executor.rpc_async().invoke( - invocation_key, - Phase.SYNC_BARRIER, - args, - kwargs, - self.cur_microbatch, - debug_str=node.format_node(), - output_refcount=len(users), - batch_id=self.batch_id, - num_microbatches=self.num_microbatches, - ) - return ValueReference(stage_id, invocation_key) - elif target is _null_coalesce_accumulate: - assert "fw_stage" in node.meta - stage_id, stage_executor = self.remote_stage_executor_rrefs[ - node.meta["fw_stage"] - ] - if torch.is_grad_enabled(): - logging.debug( - f"[root][{self.cur_microbatch}] Issuing accumulate grad invocation " - f'for target {node.meta["fw_stage"]} on stage {stage_id}' - ) - stage_executor.rpc_async().invoke( - invocation_key, - Phase.ACCUMULATE_GRAD, - args, - kwargs, - self.cur_microbatch, - debug_str=node.format_node(), - output_refcount=len(users), - batch_id=self.batch_id, - num_microbatches=self.num_microbatches, - ) - return ValueReference(stage_id, invocation_key) - else: - return ValueReference(stage_id, "noop") - else: - raise AssertionError(f"Unknown operator {torch.typename(target)}") - - def issue_coalesced_getitem_calls(self): - if len(self.stage_output_indices) == 0: - return - logging.debug( - f"[root][{self.cur_microbatch}] Issuing getitem calls to stage: {self.stage_output_indices.keys()}" - ) - - for stage_id in self.stage_output_indices: - stage_executor = self.stage_to_executor[stage_id] - name = f"G{stage_id},{self.cur_microbatch}" - id = f"{name},{self.batch_id}" - start_ts = time.time() - stage_executor.rpc_async().coalesced_index_value( - self.stage_output_indices[stage_id] - ) - finish_ts = time.time() - self.record_event( - rank=0, - start_ts=start_ts, - finish_ts=finish_ts, - id=id, - name=name, - type="invoke", - mbid=self.cur_microbatch, - ) - - self.stage_output_indices.clear() - - def run_until( - self, predicate: Callable[[pippy.fx.Node], bool] - ) -> Optional[pippy.fx.Node]: - while self.pc < len(self.node_list): - node = self.node_list[self.pc] - - if predicate(node): - # Issue coalesced getitem calls as we pause issuing stage calls - self.issue_coalesced_getitem_calls() - return node - - self.run_one(node) - - # Have run through the entire node_list, using None to mean no node left to run - return None - - def run_one(self, node): - # TODO: hoist run() implementation - logging.debug( - f"[{self.cur_microbatch}] Issue command to run {node.format_node()}" - ) - self.env[node] = super().run_node(node) - - # TODO: we could potentially move this waiting to the use sites for an RRef - # (i.e. during Interpreter.map_nodes_to_values or when we pass args/kwargs - # to the callees) as an optimization - # TODO: is it possible for there to be a blocking version of this API? - def wait_for_confirmation(n): - # The following if will not be true as we are using our own ValueRef - # instead of RPC's RRef - if isinstance(n, torch._C._distributed_rpc.PyRRef): - while not n.confirmed_by_owner(): - pass - - pippy.fx.node.map_aggregate(self.env[node], wait_for_confirmation) - - if DEBUG and isinstance( - self.env[node], torch._C._distributed_rpc.PyRRef - ): - print(node, self.env[node]) - self.env[node].to_here() - - # Insert tensor meta to ValueReference returned by node call - # TODO: there is some problem with "call_function", disabling for now - if node.op == "call_module": - if "tensor_meta" in node.meta and isinstance( - node.meta["tensor_meta"], - shape_prop.TensorMetadata, - ): - val_ref: ValueReference = self.env[node] - val_ref.meta.setdefault("tensor_meta", node.meta["tensor_meta"]) - - self.pc += 1 - return node - - def propagate_shape(self, args, kwargs): - logging.info("Propagating shape across split GraphModule") - sp = shape_prop.ShapeProp(self.module) - # Not sure why FX's propagate API takes only args. Hence we unpack kwargs.values() without keys here - sp.propagate(*args, *kwargs.values()) - for node in self.node_list: - logging.debug(f"Node: {node.name}, outputs: ") - if "tensor_meta" in node.meta: - if isinstance( - node.meta["tensor_meta"], shape_prop.TensorMetadata - ): - logging.debug(f"- {node.meta['tensor_meta']}") - else: - # Multiple output tensors - for t_meta in node.meta["tensor_meta"]: - logging.debug(f"- {t_meta}") - - -class _run_until_criteria: - def __init__(self): - self.seen_stages = 0 - - # Run the node we start with including all nodes that are tuple - # indexing, then stop - def hitting_next_stage(self, node): - if node.op == "output": - return True - - if ( - node.target != operator.getitem - and node.target != _null_coalesce_accumulate - ): - self.seen_stages += 1 - - if self.seen_stages > 1: - return True - elif self.seen_stages == 1 and node.target == _null_coalesce_accumulate: - # We are hitting the accumulate call of the next (backward) stage, stop - return True - else: - return False - - -class PipelineDriverFillDrain(PipelineDriverBase): - def __init__( - self, - pipe: Pipe, - chunks: int, - world_size: int, - all_ranks: List[int] = None, - args_chunk_spec=None, - kwargs_chunk_spec=None, - output_chunk_spec=None, - single_loss: bool = False, - _debug_mask_minibatches: bool = False, - max_outstanding=None, - interleave_stages=False, - _record_mem_dumps=False, - checkpoint=False, - use_c10d=False, - loss_reducer: LossReducer = sum_reducer, - ): - super().__init__( - pipe, - chunks, - world_size, - all_ranks, - args_chunk_spec, - kwargs_chunk_spec, - output_chunk_spec, - _debug_mask_minibatches, - max_outstanding=max_outstanding, - interleave_stages=interleave_stages, - _record_mem_dumps=_record_mem_dumps, - checkpoint=checkpoint, - use_c10d=use_c10d, - loss_reducer=loss_reducer, - ) - self.single_loss = single_loss - - self.last_grads = None - - self._init_remote_executors() - - def forward(self, *args, **kwargs): - if self.single_loss: - raise NotImplementedError("Single minibatch loss not implemented") - - # Roadmap: - # 1) Micro-batch splitting - divide input arguments out into concrete chunk values - # 2) Interpreter tiling - one interpreter per micro-batch - # 3) Scheduling - Use control logic to advance interpreters to issue round-robin - # forward work items, then round-robin losses, then round-robin backwards - - args_split, kwargs_split = split_args_kwargs_into_chunks( - args, - kwargs, - self.chunks, - self.args_chunk_spec, - self.kwargs_chunk_spec, - self._debug_mask_minibatches, - ) - - real_num_chunks = self.chunks - if len(args_split) < self.chunks: - real_num_chunks = len(args_split) - warnings.warn( - f"Reducing micro-batch numbers from {self.chunks} to " - f"{real_num_chunks}." - ) - - logging.info( - f"[root] Running pipeline with {real_num_chunks} micro-batches" - ) - - self.microbatch_interpreters = [] - - batch_id = self.batch_id - self.batch_id += 1 - - for chunk in range(real_num_chunks): - logging.debug( - f"[root] Instantiating microbatch interpreter for chunk {chunk}" - ) - interp = RemoteInterpreter( - remote_stage_executor_rrefs=self.remote_stage_executor_rrefs, - stage_to_executor=self.stage_to_executor, - module=self.pipe.split_gm, - cur_microbatch=chunk, - args=args_split[chunk], - kwargs=kwargs_split[chunk], - batch_id=batch_id, - num_microbatches=real_num_chunks, - ) - # If user wants to use c10d for P2P, we would perform the shape propagation here. The shape prop is - # performed per batch, thus supporting dynamic shape in batch dimension. Dynamic shape in microbatch - # dimension is not yet supported, because all RemoteInterpreters share the same shape info (since they share - # the same split_gm) - if self.use_c10d and chunk == 0: - interp.propagate_shape(args_split[chunk], kwargs_split[chunk]) - - self.microbatch_interpreters.append(interp) - - logging.debug( - f"[root] {len(self.microbatch_interpreters)} instantiated" - ) - - # Deterministic clock cycle - see torchgpipe paper section 3.2.1 for details - - # Advance past placeholders - for interp in self.microbatch_interpreters: - interp.run_until(lambda n: n.op != "placeholder") - - # Ramp-up, admit diagonal wavefront until we get to a full diagonal - # location in the matrix - - for ramp_up_idx in range(len(self.microbatch_interpreters)): - for i in range(ramp_up_idx + 1): - interp = self.microbatch_interpreters[i] - criteria = _run_until_criteria() - interp.run_until(criteria.hitting_next_stage) - - # Steady-state. We have a full diagonal in the matrix; keep dispatching - # across the diagonal - - any_valid = True - while any_valid: - any_valid = False - for interp in self.microbatch_interpreters: - start_node = interp.node_list[ - min(interp.pc, len(interp.node_list) - 1) - ] - criteria = _run_until_criteria() - interp.run_until(criteria.hitting_next_stage) - - any_valid |= interp.node_list[interp.pc] != start_node - - last_nodes = [ - interp.node_list[interp.pc] - for interp in self.microbatch_interpreters - ] - assert all(node.op == "output" for node in last_nodes) - - local_results_and_last_grads = self._retrieve_output_values( - self.microbatch_interpreters, last_nodes - ) - - if self.pipe.has_loss_and_backwards: - # Shared parameter sync - # At this point, all of the gradient jobs should have been run - # (by way of the synchronization dependency earlier) - self._sync_replicated_params() - - if DEBUG: - self._check_stages_cleanup() - - if self.pipe.has_loss_and_backwards: - local_results = [] - last_grads = [] - for local_result in local_results_and_last_grads: - local_results.append(local_result[0]) - last_grads.append(local_result[1]) - - self.last_grads = last_grads # type: ignore[assignment] - else: - local_results = local_results_and_last_grads - - return merge_chunks( - local_results, self.output_chunk_spec, self._debug_mask_minibatches - ) - - -class PipelineDriver1F1B(PipelineDriverFillDrain): - def __init__( - self, - pipe: Pipe, - chunks: int, - world_size: int, - all_ranks: List[int] = None, - args_chunk_spec=None, - kwargs_chunk_spec=None, - output_chunk_spec=None, - single_loss: bool = False, - _debug_mask_minibatches: bool = False, - interleave_stages=False, - _record_mem_dumps=False, - checkpoint=False, - use_c10d=False, - loss_reducer: LossReducer = sum_reducer, - ): - # In 1F1B with backward stages, the maximum number of outstanding - # micro-batches equals the number of pipeline stages - max_outstanding = ( - pipe.num_stages if pipe.has_loss_and_backwards else None - ) - - super().__init__( - pipe, - chunks, - world_size, - all_ranks, - args_chunk_spec, - kwargs_chunk_spec, - output_chunk_spec, - single_loss, - _debug_mask_minibatches, - max_outstanding=max_outstanding, - interleave_stages=interleave_stages, - _record_mem_dumps=_record_mem_dumps, - checkpoint=checkpoint, - use_c10d=use_c10d, - loss_reducer=loss_reducer, - ) - - -class PipelineDriverInterleaved1F1B(PipelineDriver1F1B): - def __init__( - self, - pipe: Pipe, - chunks: int, - world_size: int, - all_ranks: List[int] = None, - args_chunk_spec=None, - kwargs_chunk_spec=None, - output_chunk_spec=None, - single_loss: bool = False, - _debug_mask_minibatches: bool = False, - _record_mem_dumps=False, - checkpoint=False, - use_c10d=False, - loss_reducer: LossReducer = sum_reducer, - ): - super().__init__( - pipe, - chunks, - world_size, - all_ranks, - args_chunk_spec, - kwargs_chunk_spec, - output_chunk_spec, - single_loss, - _debug_mask_minibatches, - interleave_stages=True, - _record_mem_dumps=_record_mem_dumps, - checkpoint=checkpoint, - use_c10d=use_c10d, - loss_reducer=loss_reducer, - ) diff --git a/pippy/PipelineStage.py b/pippy/PipelineStage.py index 5f5293af2..a3b1903e9 100644 --- a/pippy/PipelineStage.py +++ b/pippy/PipelineStage.py @@ -5,25 +5,26 @@ import torch import torch.distributed as dist +import torch.fx as fx +from torch._subclasses.fake_tensor import FakeTensor from torch.nn.parallel import DistributedDataParallel -import pippy -import pippy.fx -from pippy.backward import stage_backward, sync_barrier +from pippy.backward import stage_backward from pippy.debug import map_debug_info - -from pippy.fx.passes import shape_prop from pippy.IR import Pipe from pippy.microbatch import merge_chunks, split_args_kwargs_into_chunks from pippy.utils import flatten_args +logger = logging.getLogger(__name__) + + def _make_tensor_from_meta( - tensor_meta: shape_prop.TensorMetadata, + example_value: FakeTensor, device: torch.device, ) -> torch.Tensor: return torch.empty( - tensor_meta.shape, dtype=tensor_meta.dtype, device=device + example_value.size(), dtype=example_value.dtype, device=device ) @@ -39,7 +40,7 @@ def __init__( self.buffer = buffer def __repr__(self): - return f"RecvInfo(input={self.input_name}, source={self.source}, buffer={self.buffer.size()})" + return f"RecvInfo(input={self.input_name}, source={self.source}, shape={self.buffer.size()})" class StageArgPlaceholder: @@ -51,24 +52,20 @@ def __init__( self, pipe: Pipe, stage_index: int, - nstages: int, - chunks: int, device: torch.device, group: dist.ProcessGroup = None, - args_chunk_spec=None, - kwargs_chunk_spec=None, - output_chunk_spec=None, ): super().__init__() self.pipe = pipe self.stage_index = stage_index - self.nstages = nstages - self.chunks = chunks + self.nstages = pipe.num_stages + self.chunks = pipe.num_chunks self.device = device self.group = group - self.args_chunk_spec = args_chunk_spec - self.kwargs_chunk_spec = kwargs_chunk_spec - self.output_chunk_spec = output_chunk_spec + if dist.get_world_size(self.group) > self.nstages: + raise RuntimeError( + "Number of ranks is larger than number of stages, some ranks are unused" + ) # `group_rank` is rank in process group `group`. self.group_rank = dist.get_rank(group) @@ -90,8 +87,8 @@ def __init__( self.split_gm = self.pipe.split_gm named_children = list(self.split_gm.named_children()) self.name, self.submod = named_children[stage_index] - logging.info( - f"[{self.group_rank}][{self.name}] " + logger.info( + f"[{self.group_rank}] " f"Creating PipelineStage:\n" f"{self.submod}" ) @@ -131,7 +128,7 @@ def __init__( # In interleaved case, `group_rank` is stage index % group size. self.stage_index_to_group_rank: Dict[int, int] = {} pg_world_size = dist.get_world_size(group) - for i in range(nstages): + for i in range(self.nstages): # We only support wrapped-around interleaving peer_rank = i % pg_world_size self.stage_index_to_group_rank.setdefault(i, peer_rank) @@ -209,7 +206,7 @@ def create_recv_tensor( # real source e.g. getitem1 = submod0[1] # Here `submod0` is args[0], 1 is args[1] if input_node.target is operator.getitem: - if "tensor_meta" in input_node.meta: + if "example_value" in input_node.meta: real_input_node = input_node.args[0] out_idx = input_node.args[1] return create_recv_tensor(real_input_node, out_idx) @@ -220,20 +217,20 @@ def create_recv_tensor( ) if output_idx is not None: - # If a node has multiple output values, "tensor_meta" is a list + # If a node has multiple output values, "example_value" is a list # of tensor meta - tensor_meta = input_node.meta["tensor_meta"][output_idx] + example_value = input_node.meta["example_value"][output_idx] else: - tensor_meta = input_node.meta["tensor_meta"] + example_value = input_node.meta["example_value"] - logging.info( - f"[{self.group_rank}][{self.name}] " + logger.info( + f"[{self.group_rank}] " f"Creating recv buffer for input '{input_node.name}' " - f"value index {output_idx}: {tensor_meta.shape}" + f"value index {output_idx}: {example_value.size()}" ) src_rank = self.get_stage_index_of_submod(input_node.name) - buffer = _make_tensor_from_meta(tensor_meta, self.device) + buffer = _make_tensor_from_meta(example_value, self.device) # Enable gradient in training mode if self.pipe.has_loss_and_backwards: buffer.requires_grad_(True) @@ -245,25 +242,20 @@ def create_recv_tensor( # `args` is a Tuple, hence we will have: # Tuple[RecvInfo] - args_recv_info = pippy.fx.node.map_arg( - self.node.args, create_recv_tensor - ) + args_recv_info = fx.node.map_arg(self.node.args, create_recv_tensor) # `kwargs` is a Dict, hence we will have: # Dict[keyword, RecvInfo] - kwargs_recv_info = pippy.fx.node.map_arg( - self.node.kwargs, create_recv_tensor - ) + kwargs_recv_info = fx.node.map_arg(self.node.kwargs, create_recv_tensor) - logging.info( - f"[{self.group_rank}][{self.name}] " - f"Activation recv info: {args_recv_info}" + logger.info( + f"[{self.group_rank}] " f"Activation recv info: {args_recv_info}" ) return args_recv_info, kwargs_recv_info def find_dst_rank( self, - user: pippy.fx.Node, + user: fx.Node, ) -> Optional[int]: """ Find the destination rank of a `user` node. @@ -272,9 +264,6 @@ def find_dst_rank( if user.op == "call_module": # User is a stage (`call_module`) return self.get_stage_index_of_submod(user.name) - elif user.target is sync_barrier: - # Send result back to pp rank 0 - return 0 else: # - If user.op == "output": # No need to send back to rank 0 @@ -305,9 +294,7 @@ def _create_act_send_info(self): if dst_rank is not None: dsts.append(dst_rank) - logging.info( - f"[{self.group_rank}][{self.name}] " f"Send info: {act_send_info}" - ) + logger.info(f"[{self.group_rank}] " f"Send info: {act_send_info}") return act_send_info def _create_grad_recv_info( @@ -316,7 +303,7 @@ def _create_grad_recv_info( ) -> Dict[int, RecvInfo]: # Dict[output_index, RecvInfo] grad_recv_info: Dict = {} - my_tensor_meta = self.node.meta["tensor_meta"] + my_example_value = self.node.meta["example_value"] for out_idx, dst_list in act_send_info.items(): if not dst_list: @@ -325,9 +312,9 @@ def _create_grad_recv_info( # TODO: clean way if len(act_send_info) > 1: - tensor_meta = my_tensor_meta[out_idx] + example_value = my_example_value[out_idx] else: - tensor_meta = my_tensor_meta + example_value = my_example_value # TODO: otherwise needs grad accumulation assert len(dst_list) == 1 @@ -335,13 +322,10 @@ def _create_grad_recv_info( grad_recv_info[out_idx] = RecvInfo( f"{grad_src}", grad_src, - _make_tensor_from_meta(tensor_meta, self.device), + _make_tensor_from_meta(example_value, self.device), ) - logging.info( - f"[{self.group_rank}][{self.name}] " - f"Grad recv info: {grad_recv_info}" - ) + logger.info(f"[{self.group_rank}] " f"Grad recv info: {grad_recv_info}") return grad_recv_info def _create_grad_send_info( @@ -359,19 +343,16 @@ def map_recv_to_send(a): grad_send_info.append(None) return None - pippy.fx.node.map_aggregate(args_recv_info, map_recv_to_send) + fx.node.map_aggregate(args_recv_info, map_recv_to_send) - pippy.fx.node.map_aggregate(kwargs_recv_info, map_recv_to_send) + fx.node.map_aggregate(kwargs_recv_info, map_recv_to_send) - logging.info( - f"[{self.group_rank}][{self.name}] " - f"Grad send info: {grad_send_info}" - ) + logger.info(f"[{self.group_rank}] " f"Grad send info: {grad_send_info}") return grad_send_info def _recv_tensor(self, info, recv_reqs): - logging.debug( - f"[{self.group_rank}][{self.name}] " + logger.debug( + f"[{self.group_rank}] " f"Receiving tensor '{info.input_name}' from Rank {info.source}: " f"{info.buffer.size()}" ) @@ -401,8 +382,8 @@ def split_inputs(self, args, kwargs): args, kwargs, self.chunks, - self.args_chunk_spec, - self.kwargs_chunk_spec, + self.pipe.args_chunk_spec, + self.pipe.kwargs_chunk_spec, ) def _recv_and_fill_inputs( @@ -424,7 +405,7 @@ def recv_args(info): else: return chunk_args_list.pop(0) # type: ignore[has-type] - composite_args = pippy.fx.node.map_aggregate( + composite_args = fx.node.map_aggregate( self.args_recv_info[chunk], recv_args, ) @@ -439,7 +420,7 @@ def recv_kwargs(info): k = next(iter(chunk_kwargs)) # type: ignore[has-type] return chunk_kwargs.pop(k) # type: ignore[has-type] - composite_kwargs = pippy.fx.node.map_aggregate( + composite_kwargs = fx.node.map_aggregate( self.kwargs_recv_info[chunk], recv_kwargs, ) @@ -462,8 +443,8 @@ def _send_activations( for dst in dst_stages: if dst is None: continue - logging.debug( - f"[{self.group_rank}][{self.name}] " + logger.debug( + f"[{self.group_rank}] " f"Sending tensor to Rank {dst}: {out.size()}" ) peer_rank = self.stage_index_to_group_rank[dst] @@ -488,7 +469,7 @@ def _recv_grads( recv_grad = self.recv_tensor_fn(grad_recv_reqs) # Receive gradients - grads = pippy.fx.node.map_aggregate( + grads = fx.node.map_aggregate( self.grad_recv_info[bwd_chunk], recv_grad, ) @@ -496,8 +477,8 @@ def _recv_grads( for work in grad_recv_reqs: work.wait() - logging.debug( - f"[{self.group_rank}][{self.name}] " + logger.debug( + f"[{self.group_rank}] " f"Received output grads of chunk {bwd_chunk}: {map_debug_info(grads)}" ) return grads @@ -511,8 +492,8 @@ def _send_grads( for grad, grad_recv_stage in zip(grads_input, self.grad_send_info): if isinstance(grad, torch.Tensor) and grad_recv_stage is not None: - logging.debug( - f"[{self.group_rank}][{self.name}] " + logger.debug( + f"[{self.group_rank}] " f"Sending gradient to Rank {grad_recv_stage}: {grad.size()}" ) peer_rank = self.stage_index_to_group_rank[grad_recv_stage] @@ -579,6 +560,11 @@ def forward_one_chunk( """ raise RuntimeError(exc_msg) from e + if type(output) is list: + # HACK: this is a hacky workaround for the fact that export creates + # output in list format + output = tuple(output) + logger.debug(map_debug_info(output)) # Unify output form to tuple for easy correspondance with # `act_send_info` output_tuple = output if type(output) is tuple else (output,) @@ -643,7 +629,7 @@ def clear_runtime_states(self): def merge_output_chunks(self): return merge_chunks( self.output_chunks, - self.output_chunk_spec, + self.pipe.output_chunk_spec, ) def forward(self, *args, **kwargs): @@ -656,16 +642,18 @@ def forward(self, *args, **kwargs): # Forward pass of all chunks for chunk in range(self.chunks): self.forward_one_chunk(chunk) - - # Wait for all sends to finish - # TODO: okay to delay the sync till completion of all chunks? - for work in self.all_act_send_reqs: - work.wait() + logger.debug(f"[{self.group_rank}] Forwarded chunk {chunk}") # Backward starts here for bwd_chunk in range(self.chunks): self.backward_one_chunk(bwd_chunk) + logger.debug(f"[{self.group_rank}] Backwarded chunk {bwd_chunk}") + + # Wait for all sends to finish + # TODO: okay to delay the sync till completion of all chunks? + for work in self.all_act_send_reqs: + work.wait() # Wait for all sends to finish # TODO: okay to delay the sync till completion of all chunks? @@ -684,24 +672,14 @@ def __init__( self, pipe: Pipe, rank: int, - nstages: int, - chunks: int, device: torch.device, group: dist.ProcessGroup = None, - args_chunk_spec=None, - kwargs_chunk_spec=None, - output_chunk_spec=None, ): super().__init__( pipe, rank, - nstages, - chunks, device, group=group, - args_chunk_spec=args_chunk_spec, - kwargs_chunk_spec=kwargs_chunk_spec, - output_chunk_spec=output_chunk_spec, ) def forward(self, *args, **kwargs): diff --git a/pippy/__init__.py b/pippy/__init__.py index 49dc0493a..75de5f7bd 100644 --- a/pippy/__init__.py +++ b/pippy/__init__.py @@ -1,10 +1,4 @@ # Copyright (c) Meta Platforms, Inc. and affiliates -from pippy.compile import ( - all_compile, - compile, - compile_stage, - create_default_args, -) from pippy.IR import ( annotate_split_points, LossWrapper, @@ -15,8 +9,6 @@ TrivialLossWrapper, ) from pippy.ModelSplit import split_into_equal_size, split_on_size_threshold -from pippy.PipelineDriver import PipelineDriver1F1B, PipelineDriverFillDrain -from pippy.utils import run_pippy __all__ = [ @@ -25,15 +17,8 @@ "TrivialLossWrapper", "Pipe", "pipe_split", - "run_pippy", "PipeSplitWrapper", "annotate_split_points", - "PipelineDriverFillDrain", - "PipelineDriver1F1B", "split_into_equal_size", "split_on_size_threshold", - "compile", - "all_compile", - "create_default_args", - "compile_stage", ] diff --git a/pippy/auto_parallelization.py b/pippy/auto_parallelization.py index 94e026807..a22cc81f1 100644 --- a/pippy/auto_parallelization.py +++ b/pippy/auto_parallelization.py @@ -18,7 +18,7 @@ import numpy as np -import pippy.fx +from torch import fx from pippy import pipe_split @@ -272,7 +272,7 @@ class AutoParallelConfig: def dp_auto_parallel(config: AutoParallelConfig): - def _dp_auto_parallel(fx_mod: pippy.fx.GraphModule): + def _dp_auto_parallel(fx_mod: fx.GraphModule): n_graph_nodes = len(fx_mod.graph.nodes) submesh_shapes = get_possible_submesh_shapes( n_compute_nodes=config.n_compute_nodes, diff --git a/pippy/compile.py b/pippy/compile.py deleted file mode 100644 index 5ddf074c6..000000000 --- a/pippy/compile.py +++ /dev/null @@ -1,321 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import inspect -import logging -from typing import Any, Callable, List, Optional - -import torch -import torch.distributed as dist -from torch._subclasses.fake_tensor import FakeTensorMode - -import pippy.fx as fx -from pippy.debug import PIPPY_VERBOSITY -from pippy.IR import MultiUseParameterConfig, Pipe, PiPPyShapeProp -from pippy.microbatch import ( - gen_output_chunk_spec, - LossReducer, - split_args_kwargs_into_chunks, - sum_reducer, -) -from pippy.PipelineDriver import ( - PipelineDriver1F1B, - PipelineDriverFillDrain, - PipelineDriverInterleaved1F1B, -) -from pippy.PipelineStage import PipelineStage, PipelineStage1F1B -from pippy.utils import get_device, get_pp_rank, get_rank - - -PIPELINE_SCHEDULE_DRIVERS = { - "FillDrain": PipelineDriverFillDrain, - "1F1B": PipelineDriver1F1B, - "Interleaved1F1B": PipelineDriverInterleaved1F1B, -} - - -def create_default_args( - mod: torch.nn.Module, - except_keys: List = None, -): - if except_keys is None: - except_keys = [] - sig = inspect.signature(mod.forward) - default_kwargs = { - p.name: p.default - for p in sig.parameters.values() - if p.name not in except_keys and p.default is not inspect._empty - } - return default_kwargs - - -def _compile( - all_compile: bool, - mod: torch.nn.Module, - num_ranks: int, - num_chunks: int, - schedule: Optional[str] = "FillDrain", - split_policy: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None, - ranks: List[int] = None, - tracer=None, - loss_reducer: LossReducer = sum_reducer, - args_chunk_spec=None, - kwargs_chunk_spec=None, - output_chunk_spec=None, - checkpoint=False, - _debug_mask_minibatches: bool = False, - index_filename=None, - checkpoint_prefix: str = None, - **kwargs, -): - if ranks is None: - ranks = list(range(num_ranks)) - - if all_compile: - rank = get_rank() - pp_rank = get_pp_rank(rank, ranks) - else: - pp_rank = 0 - - # If a param will be used in multiple pipeline stages, we default the strategy to REPLICATE'ing the param across - # stages instead of TRANSMIT'ting it - multi_use_param_spec = MultiUseParameterConfig.REPLICATE - - # Figure out which output is loss from output_chunk_spec - output_loss_value_spec: Any = None - if output_chunk_spec is not None: - output_loss_value_spec = fx.node.map_aggregate( - output_chunk_spec, lambda v: isinstance(v, LossReducer) - ) - - logging.info("[PiPPy] Tracing model ...") - pipe_model = Pipe.from_tracing( - mod, - multi_use_param_spec=multi_use_param_spec, - tracer=tracer, - output_loss_value_spec=output_loss_value_spec, - split_policy=split_policy, - **kwargs, - ) - - # In all_compile mode, each rank calls pippy.all_compile, hence they will all have the pipe. - # We can hence ask each rank to get its own stage from the pipe, and materialize it locally. - if all_compile: - device = get_device() - - # `None` means self.dtype, i.e. no change - dtype = None - # TODO: generalize this - if hasattr(mod, "config") and hasattr(mod.config, "torch_dtype"): - dtype = mod.config.torch_dtype # type: ignore[union-attr] - - pipe_model.defer_stage_init( - device, - index_filename, - dtype, - checkpoint_prefix, - ) - stage_mod = pipe_model.export(pp_rank) - - if pp_rank == 0: - logging.info(pipe_model.split_gm) - - logging.info("[PiPPy] Creating pipeline driver ...") - if schedule not in PIPELINE_SCHEDULE_DRIVERS: - raise ValueError( - f"Unknown pipeline schedule: {schedule}. " - f"Please select from {PIPELINE_SCHEDULE_DRIVERS.keys()}" - ) - pipeline_driver = PIPELINE_SCHEDULE_DRIVERS[schedule]( - pipe_model, - num_chunks, - num_ranks, - all_ranks=ranks, - args_chunk_spec=args_chunk_spec, - kwargs_chunk_spec=kwargs_chunk_spec, - output_chunk_spec=output_chunk_spec, - checkpoint=checkpoint, - loss_reducer=loss_reducer, - _debug_mask_minibatches=_debug_mask_minibatches, - ) - - if not all_compile: - return pipeline_driver - - if pp_rank == 0: - return pipeline_driver, stage_mod - else: - return None, stage_mod - - -def compile( - mod: torch.nn.Module, - num_ranks: int, - num_chunks: int, - schedule: Optional[str] = "FillDrain", - split_policy: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None, - ranks: List[int] = None, - tracer=None, - loss_reducer: LossReducer = sum_reducer, - args_chunk_spec=None, - kwargs_chunk_spec=None, - output_chunk_spec=None, - checkpoint=False, - _debug_mask_minibatches: bool = False, - **kwargs, -): - return _compile( - False, - mod, - num_ranks, - num_chunks, - schedule=schedule, - split_policy=split_policy, - ranks=ranks, - tracer=tracer, - loss_reducer=loss_reducer, - args_chunk_spec=args_chunk_spec, - kwargs_chunk_spec=kwargs_chunk_spec, - output_chunk_spec=output_chunk_spec, - checkpoint=checkpoint, - _debug_mask_minibatches=_debug_mask_minibatches, - **kwargs, - ) - - -def all_compile( - mod: torch.nn.Module, - num_ranks: int, - num_chunks: int, - schedule: Optional[str] = "FillDrain", - split_policy: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None, - ranks: List[int] = None, - tracer=None, - loss_reducer: LossReducer = sum_reducer, - args_chunk_spec=None, - kwargs_chunk_spec=None, - output_chunk_spec=None, - checkpoint=False, - _debug_mask_minibatches: bool = False, - **kwargs, -): - return _compile( - True, - mod, - num_ranks, - num_chunks, - schedule=schedule, - split_policy=split_policy, - ranks=ranks, - tracer=tracer, - loss_reducer=loss_reducer, - args_chunk_spec=args_chunk_spec, - kwargs_chunk_spec=kwargs_chunk_spec, - output_chunk_spec=output_chunk_spec, - checkpoint=checkpoint, - _debug_mask_minibatches=_debug_mask_minibatches, - **kwargs, - ) - - -def compile_stage( - mod: torch.nn.Module, - stage_index: int, - num_stages: int, - num_chunks: int, - device: torch.device, - group: dist.ProcessGroup, - example_inputs: List[torch.Tensor], - split_policy: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None, - return_to_0: bool = False, - tracer=None, - loss_reducer: LossReducer = sum_reducer, - args_chunk_spec=None, - kwargs_chunk_spec=None, - output_chunk_spec=None, - schedule="FillDrain", - **kwargs, -) -> PipelineStage: - # If a param will be used in multiple pipeline stages, we default the strategy to REPLICATE'ing the param across - # stages instead of TRANSMIT'ting it - multi_use_param_spec = MultiUseParameterConfig.REPLICATE - - # Figure out which output is loss from output_chunk_spec - output_loss_value_spec: Any = None - if output_chunk_spec is not None: - output_loss_value_spec = fx.node.map_aggregate( - output_chunk_spec, lambda v: isinstance(v, LossReducer) - ) - - logging.info("[PiPPy] Tracing model ...") - pipe = Pipe.from_tracing( - mod, - multi_use_param_spec=multi_use_param_spec, - tracer=tracer, - output_loss_value_spec=output_loss_value_spec, - split_policy=split_policy, - return_to_0=return_to_0, - **kwargs, - ) - - gm = pipe.split_gm - if stage_index == 0: - logging.info(gm) - if PIPPY_VERBOSITY == "INFO": - gm.graph.print_tabular() - - # Get shape of chunked arguments - args_split, _ = split_args_kwargs_into_chunks( - example_inputs, - {}, # kwargs included in `example_inputs` - num_chunks, - args_chunk_spec, - kwargs_chunk_spec, # TODO: merge into args_chunk_spec - ) - - # Use fake tensor for shape propagation - # Since model itself may have been materialized, we need to use - # `allow_non_fake_inputs` - fake_mode = FakeTensorMode(allow_non_fake_inputs=True) - # In reality, the fake input should be created from shape info (potentially - # broadcast from Rank 0) - fake_args_split = fx.node.map_aggregate( - args_split, lambda a: fake_mode.from_tensor(a) - ) - - # Use 1st chunk of args for shape propagation - chunk0 = fake_args_split[0] - - sp = PiPPyShapeProp(gm) - sp.propagate(*chunk0) - - # Prepare output chunk/reduce spec for merging/reducing final outputs - output_chunk_spec = ( - output_chunk_spec - if output_chunk_spec - else gen_output_chunk_spec(pipe.loss_spec, loss_reducer) - ) - - # Create pipeline stage based on schedule - if schedule == "1F1B": - return PipelineStage1F1B( - pipe, - stage_index, - num_stages, - num_chunks, - device, - group=group, - args_chunk_spec=args_chunk_spec, - kwargs_chunk_spec=kwargs_chunk_spec, - output_chunk_spec=output_chunk_spec, - ) - else: - return PipelineStage( - pipe, - stage_index, - num_stages, - num_chunks, - device, - group=group, - args_chunk_spec=args_chunk_spec, - kwargs_chunk_spec=kwargs_chunk_spec, - output_chunk_spec=output_chunk_spec, - ) diff --git a/pippy/debug.py b/pippy/debug.py index c393581a9..4e96cf7d1 100644 --- a/pippy/debug.py +++ b/pippy/debug.py @@ -4,19 +4,19 @@ import torch -import pippy.fx - PIPPY_VERBOSITY = os.environ.get("PIPPY_VERBOSITY", "OFF") if PIPPY_VERBOSITY == "DEBUG": - logging.getLogger().setLevel(logging.DEBUG) + logging.getLogger("pippy").setLevel(logging.DEBUG) elif PIPPY_VERBOSITY == "INFO": - logging.getLogger().setLevel(logging.INFO) + logging.getLogger("pippy").setLevel(logging.INFO) elif PIPPY_VERBOSITY == "OFF": pass else: - print(f"Unsupported PIPPY_VERBOSITY level: {PIPPY_VERBOSITY}") + print(f"[PiPPy] Unsupported PIPPY_VERBOSITY level: {PIPPY_VERBOSITY}") + +print(f"[PiPPy] Setting logging level to: {PIPPY_VERBOSITY}") def friendly_debug_info(v): @@ -27,4 +27,4 @@ def friendly_debug_info(v): def map_debug_info(a): - return pippy.fx.node.map_aggregate(a, friendly_debug_info) + return torch.fx.node.map_aggregate(a, friendly_debug_info) diff --git a/pippy/fx/OVERVIEW.md b/pippy/fx/OVERVIEW.md deleted file mode 100644 index f2995eb7a..000000000 --- a/pippy/fx/OVERVIEW.md +++ /dev/null @@ -1,134 +0,0 @@ -# FX Technical Overview (WIP) - -FX is a toolkit for pass writers to facilitate Python-to-Python transformation of `nn.Module` instances. This toolkit aims to support a subset of Python language semantics—rather than the whole Python language—to facilitate ease of implementation of transforms. Currently, this feature is under a Beta release and its API may change. - -## Table of Contents - - - -- [Introduction](#introduction) - - [Motivation](#motivation) - - [Use Cases](#use-cases) - - [Technical Details](#technical-details) -- [Internal Structure](#internal-structure) - - [Graph](#graph) - - [GraphModule](#graphmodule) -- [Symbolic Tracing](#symbolic-tracing) - - [Tracer](#tracer) - - [Proxy](#proxy) -- [The FX IR](#the-fx-ir) -- [Transformation and Codegen](#transformation-and-codegen) - - - -# Introduction - -## Motivation ## - -TODO - -## Use Cases ## - -FX should be used by pass writers to provide functionality for capturing and constructing nn.Module code in a structured way. We do not expect end users to utilize FX directly. A useful property of framing FX in this way is that passes can be seen as functions of the form `pass(in_mod : nn.Module) -> nn.Module`. This means we can create composable pipelines of transformations. - -![An image of a sample nn.Module transformation pipeline that starts with a Quantize transformation, which is then composed with a Split transformation, then a Lower to Accelerator transformation](https://i.imgur.com/TzFIYMi.png "nn.Module transformation pipeline") - -In this example pipeline, we have a Quantize transformation, which is then composed with a Split transformation, then a Lower to Accelerator transformation. Finally, the transformed Modules are compiled with TorchScript for deployment. This last point emphasizes that not only should FX transforms be composable with each other, but their products are composable with other systems like TorchScript compilation or tracing. - -By using `nn.Module` as the interface between passes, FX transforms are interoperable with each other, and the resulting model can be used anywhere an `nn.Module` can be used. - -## Technical Details ## - -The following sections will walk us through the components that transform from original `torch.nn.Module` to FX IR and finally to generated Python code and a GraphModule instance: - -FX’s front-end makes use of the dynamic nature of Python to intercept call-sites for various entities (PyTorch operators, Module invocations, and Tensor method invocations). This functionality is exposed through an API called `torch.fx.symbolic_trace`. We can see how this works by way of an example: - -```python -import torch - -class MyModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.param = torch.nn.Parameter( - torch.rand(3, 4)) - self.linear = torch.nn.Linear(4, 5) - - def forward(self, x): - return self.linear(x + self.param).clamp(min=0.0, max=1.0) - -from torch.fx import symbolic_trace -module = MyModule() -symbolic_traced : torch.fx.GraphModule = symbolic_trace(module) - -input = torch.rand(3, 4) -torch.testing.assert_allclose(symbolic_traced(input), module(input)) -``` - -Here, we set up a simple Module that exercises different language features: fetching a parameter, applying an arithmetic operator, applying a submodule (linear), and applying a Tensor method. `symbolic_trace` returns an instance of GraphModule, which is in itself a subclass of `nn.Module`. We can see that the `symbolic_traced` instance runs and returns the same result as the original module instance module. - -# Internal Structure - -## [Graph](https://pytorch.org/docs/master/fx.html#torch.fx.Graph) ## -TODO - -## [GraphModule](https://pytorch.org/docs/master/fx.html#torch.fx.GraphModule) ## -TODO - -# Symbolic Tracing - -## [Tracer](https://pytorch.org/docs/master/fx.html#torch.fx.Tracer) ## - -`Tracer` is the class that implements the symbolic tracing functionality of `torch.fx.symbolic_trace`. A call to `symbolic_trace(m)` is equivalent to `Tracer().trace(m)`. Tracer can be subclassed to override various behaviors of the tracing process. The different behaviors that can be overridden are described in the docstrings of the methods on the class. - -In the default implementation of `Tracer().trace`, the tracer first creates Proxy objects for all arguments in the `forward` function. (This happens in the call to `create_args_for_root`.) Next, the `forward` function is called with the new Proxy arguments. As the Proxies flow through the program, they record all the operations (`torch` function calls, method calls, and operators) that they touch into the growing FX Graph as Nodes. - -## Proxy ## - -Proxy objects are Node wrappers used by the Tracer to record operations seen during symbolic tracing. The mechanism through which Proxy objects record computation is [`__torch_function__`](https://pytorch.org/docs/stable/notes/extending.html#extending-torch). If any custom Python type defines a method named `__torch_function__`, PyTorch will invoke that `__torch_function__` implementation when an instance of that custom type is passed to a function in the `torch` namespace. In FX, when operations on Proxy are dispatched to the `__torch_function__` handler, the `__torch_function__` handler records the operation in the Graph as a Node. The Node that was recorded in the Graph is then itself wrapped in a Proxy, facilitating further application of ops on that value. - -Consider the following example: - -```python - class M(torch.nn.Module): - def forward(self, x): - return torch.relu(x) - - m = M() - traced = symbolic_trace(m) -``` - -During the call to `symbolic_trace`, the parameter `x` is transformed into a Proxy object and the corresponding Node (a Node with op = “placeholder” and target = “x”) is added to the Graph. Then, the Module is run with Proxies as inputs, and recording happens via the `__torch_function__` dispatch path. - -If you're doing graph transforms, you can wrap your own Proxy method around a raw Node so that you can use the overloaded operators to add additional things to a Graph. - -# The FX IR - -Symbolic tracing captures an intermediate representation (IR), which is represented as a doubly-linked list of Nodes. - -Node is the data structure that represents individual operations within a Graph. For the most part, Nodes represent callsites to various entities, such as operators, methods, and Modules (some exceptions include Nodes that specify function inputs and outputs). Each Node has a function specified by its `op` property. The Node semantics for each value of `op` are as follows: - -- `placeholder` represents a function input. The `name` attribute specifies the name this value will take on. `target` is similarly the name of the argument. `args` holds either: 1) nothing, or 2) a single argument denoting the default parameter of the function input. `kwargs` is don't-care. Placeholders correspond to the function parameters (e.g. `x`) in the graph printout. -- `get_attr` retrieves a parameter from the module hierarchy. `name` is similarly the name the result of the fetch is assigned to. `target` is the fully-qualified name of the parameter's position in the module hierarchy. `args` and `kwargs` are don't-care -- `call_function` applies a free function to some values. `name` is similarly the name of the value to assign to. `target` is the function to be applied. `args` and `kwargs` represent the arguments to the function, following the Python calling convention -- `call_module` applies a module in the module hierarchy's `forward()` method to given arguments. `name` is as previous. `target` is the fully-qualified name of the module in the module hierarchy to call. `args` and `kwargs` represent the arguments to invoke the module on, *including the self argument*. -- `call_method` calls a method on a value. `name` is as similar. `target` is the string name of the method to apply to the `self` argument. `args` and `kwargs` represent the arguments to invoke the module on, *including the self argument* -- `output` contains the output of the traced function in its `args[0]` attribute. This corresponds to the "return" statement in the Graph printout. - -To facilitate easier analysis of data dependencies, Nodes have read-only properties `input_nodes` and `users`, which specify which Nodes in the Graph are used by this Node and which Nodes use this Node, respectively. Although Nodes are represented as a doubly-linked list, the use-def relationships form an acyclic graph and can be traversed as such. - -# Transformation and Codegen - -An invocation of `symbolic_traced` above requires a valid `forward()` method to be defined on the Module instance. How does this work? GraphModule actually generates valid Python source code based on the IR it is instantiated with. This can be seen by accessing the code attribute on the GraphModule: `print(symbolic_traced.code)`. - -After symbolic tracing, the code given under [Technical Details](#technical-details) is represented as follows: - -```python -def forward(self, x): - param = self.param - add_1 = x + param; x = param = None - linear_1 = self.linear(add_1); add_1 = None - clamp_1 = linear_1.clamp(min = 0.0, max = 1.0); linear_1 = None - return clamp_1 -``` - -This is the core of why FX is a Python-to-Python translation toolkit. Outside users can treat the results of FX transformations as they would any other `nn.Module` instance. diff --git a/pippy/fx/__init__.py b/pippy/fx/__init__.py deleted file mode 100644 index e52cc2619..000000000 --- a/pippy/fx/__init__.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -r''' -FX is a toolkit for developers to use to transform ``nn.Module`` -instances. FX consists of three main components: a **symbolic tracer,** -an **intermediate representation**, and **Python code generation**. A -demonstration of these components in action: - -:: - - import torch - # Simple module for demonstration - class MyModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.param = torch.nn.Parameter(torch.rand(3, 4)) - self.linear = torch.nn.Linear(4, 5) - - def forward(self, x): - return self.linear(x + self.param).clamp(min=0.0, max=1.0) - - module = MyModule() - - from pippy.fx import symbolic_trace - # Symbolic tracing frontend - captures the semantics of the module - symbolic_traced : pippy.fx.GraphModule = symbolic_trace(module) - - # High-level intermediate representation (IR) - Graph representation - print(symbolic_traced.graph) - """ - graph(): - %x : [#users=1] = placeholder[target=x] - %param : [#users=1] = get_attr[target=param] - %add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {}) - %linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {}) - %clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0}) - return clamp - """ - - # Code generation - valid Python code - print(symbolic_traced.code) - """ - def forward(self, x): - param = self.param - add = x + param; x = param = None - linear = self.linear(add); add = None - clamp = linear.clamp(min = 0.0, max = 1.0); linear = None - return clamp - """ - -The **symbolic tracer** performs "symbolic execution" of the Python -code. It feeds fake values, called Proxies, through the code. Operations -on theses Proxies are recorded. More information about symbolic tracing -can be found in the :func:`symbolic_trace` and :class:`Tracer` -documentation. - -The **intermediate representation** is the container for the operations -that were recorded during symbolic tracing. It consists of a list of -Nodes that represent function inputs, callsites (to functions, methods, -or :class:`torch.nn.Module` instances), and return values. More information -about the IR can be found in the documentation for :class:`Graph`. The -IR is the format on which transformations are applied. - -**Python code generation** is what makes FX a Python-to-Python (or -Module-to-Module) transformation toolkit. For each Graph IR, we can -create valid Python code matching the Graph's semantics. This -functionality is wrapped up in :class:`GraphModule`, which is a -:class:`torch.nn.Module` instance that holds a :class:`Graph` as well as a -``forward`` method generated from the Graph. - -Taken together, this pipeline of components (symbolic tracing -> -intermediate representation -> transforms -> Python code generation) -constitutes the Python-to-Python transformation pipeline of FX. In -addition, these components can be used separately. For example, -symbolic tracing can be used in isolation to capture a form of -the code for analysis (and not transformation) purposes. Code -generation can be used for programmatically generating models, for -example from a config file. There are many uses for FX! - -Several example transformations can be found at the -`examples `__ -repository. -''' - -from .graph_module import GraphModule -from ._symbolic_trace import symbolic_trace, Tracer, wrap, PH, ProxyableClassMeta -from .graph import Graph, CodeGen -from .node import Node, map_arg -from .proxy import Proxy -from .interpreter import Interpreter as Interpreter, Transformer as Transformer -from .subgraph_rewriter import replace_pattern diff --git a/pippy/fx/__init__.pyi b/pippy/fx/__init__.pyi deleted file mode 100644 index 2faf3b021..000000000 --- a/pippy/fx/__init__.pyi +++ /dev/null @@ -1,8 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from .graph import Graph as Graph -from .graph_module import GraphModule as GraphModule -from .node import Node as Node, map_arg as map_arg -from .proxy import Proxy as Proxy -from ._symbolic_trace import Tracer as Tracer, symbolic_trace as symbolic_trace, wrap as wrap -from .interpreter import Interpreter as Interpreter, Transformer as Transformer -from .subgraph_rewriter import replace_pattern as replace_pattern diff --git a/pippy/fx/_compatibility.py b/pippy/fx/_compatibility.py deleted file mode 100644 index 559232ce2..000000000 --- a/pippy/fx/_compatibility.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from typing import Any, Dict -import textwrap - -_BACK_COMPAT_OBJECTS : Dict[Any, None] = {} -_MARKED_WITH_COMATIBLITY : Dict[Any, None] = {} - -def compatibility(is_backward_compatible : bool): - if is_backward_compatible: - - def mark_back_compat(fn): - docstring = textwrap.dedent(getattr(fn, '__doc__', None) or '') - docstring += """ -.. note:: - Backwards-compatibility for this API is guaranteed. -""" - fn.__doc__ = docstring - _BACK_COMPAT_OBJECTS.setdefault(fn) - _MARKED_WITH_COMATIBLITY.setdefault(fn) - return fn - - return mark_back_compat - else: - - def mark_not_back_compat(fn): - docstring = textwrap.dedent(getattr(fn, '__doc__', None) or '') - docstring += """ -.. warning:: - This API is experimental and is *NOT* backward-compatible. -""" - fn.__doc__ = docstring - _MARKED_WITH_COMATIBLITY.setdefault(fn) - return fn - - return mark_not_back_compat diff --git a/pippy/fx/_pytree.py b/pippy/fx/_pytree.py deleted file mode 100644 index be8a61af2..000000000 --- a/pippy/fx/_pytree.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from typing import Callable, Any, Tuple, List, Dict, Type, NamedTuple -from torch.utils._pytree import PyTree, TreeSpec, LeafSpec -from collections import namedtuple - -FlattenFuncSpec = Callable[[PyTree, TreeSpec], List] - -SUPPORTED_NODES: Dict[Type[Any], Any] = {} -def register_pytree_flatten_spec(typ: Any, flatten_fn_spec: FlattenFuncSpec) -> None: - SUPPORTED_NODES[typ] = flatten_fn_spec - -def tree_flatten_spec(pytree: PyTree, spec: TreeSpec) -> List[Any]: - if isinstance(spec, LeafSpec): - return [pytree] - if spec.type not in SUPPORTED_NODES: - raise RuntimeError( - f"{type(pytree)} does not have a flatten_fn_spec associated with it. Please register one with" - "pippy.fx._pytree.register_pytree_flatten_spec. If you have serialized your model, make" - "sure that any custom pytrees have been registered before loading it.") - flatten_fn_spec = SUPPORTED_NODES[spec.type] - child_pytrees = flatten_fn_spec(pytree, spec) - result = [] - for child, child_spec in zip(child_pytrees, spec.children_specs): - flat = tree_flatten_spec(child, child_spec) - result += flat - return result - -def _dict_flatten_spec(d: Dict[Any, Any], spec: TreeSpec) -> List[Any]: - return list([d[k] for k in spec.context]) - -def _list_flatten_spec(d: List[Any], spec: TreeSpec) -> List[Any]: - return [d[i] for i in range(len(spec.children_specs))] - -def _tuple_flatten_spec(d: Tuple[Any], spec: TreeSpec) -> List[Any]: - return [d[i] for i in range(len(spec.children_specs))] - -def _namedtuple_flatten_spec(d: NamedTuple, spec: TreeSpec) -> List[Any]: - return [d[i] for i in range(len(spec.children_specs))] - -register_pytree_flatten_spec(dict, _dict_flatten_spec) -register_pytree_flatten_spec(list, _list_flatten_spec) -register_pytree_flatten_spec(tuple, _tuple_flatten_spec) -register_pytree_flatten_spec(namedtuple, _tuple_flatten_spec) diff --git a/pippy/fx/_symbolic_trace.py b/pippy/fx/_symbolic_trace.py deleted file mode 100644 index 00937803a..000000000 --- a/pippy/fx/_symbolic_trace.py +++ /dev/null @@ -1,1080 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import builtins -import copy -import functools -import inspect -import math -import os -import warnings -from itertools import chain -from types import CodeType, FunctionType, ModuleType -from typing import ( - Any, - Callable, - Dict, - List, - NamedTuple, - Optional, - Set, - Tuple, - Type, - Union, -) - -import torch -import torch.utils._pytree as pytree -from torch._C import ScriptObject # type: ignore[attr-defined] - -from ._compatibility import compatibility -from .graph import _PyTreeCodeGen, _PyTreeInfo, Graph -from .graph_module import GraphModule -from .node import Argument, base_types, map_aggregate # pylint: disable=unused-import -from .proxy import ParameterProxy, Proxy, TracerBase - -HAS_VARSTUFF = inspect.CO_VARARGS | inspect.CO_VARKEYWORDS - -# These need to run in global scope to handle nested calls correctly -_orig_module_call: Callable = torch.nn.Module.__call__ -_orig_module_getattr: Callable = torch.nn.Module.__getattr__ - -_proxyable_classes: Dict[Type, None] = {} - -_is_fx_tracing_flag = False - - -def is_fx_tracing(): - return _is_fx_tracing_flag - - -@compatibility(is_backward_compatible=True) -class ProxyableClassMeta(type): - """ - ProxyableClassMeta allows you to make construction of a given Python class - symbolically traceable. For example:: - - import torch - import pippy.fx - - class TensorPair(metaclass=pippy.fx.ProxyableClassMeta): - def __init__(self, left, right): - self.left, self.right = left, right - - def add(self, other): - l = self.left + other.left - r = self.right + other.right - return TensorPair(l, r) - - def mul(self, other): - l = self.left * other.left - r = self.right * other.right - return TensorPair(l, r) - - def use_tensor_pair_ctor(x : TensorPair, y : torch.Tensor): - s = x.add(TensorPair(y, y)) - return s.mul(x) - - x = TensorPair(torch.randn(5, 3), torch.randn(5, 3)) - y = torch.randn(5, 3) - ref_out = use_tensor_pair_ctor(x, y) - - traced = pippy.fx.symbolic_trace(use_tensor_pair_ctor) - print(traced.code) - ''' - def forward(self, x : __main___TensorPair, y : torch.Tensor): - tensor_pair = __main___TensorPair(y, y); y = None - add = x.add(tensor_pair); tensor_pair = None - mul = add.mul(x); add = x = None - return mul - ''' - - From this example, we can see that contruction of a class (``TensorPair``) - defined with ``ProxyableClassMeta`` as metaclass can be recorded in symbolic - tracing. - """ - - def __init__(cls, name, bases, attrs): - _proxyable_classes.setdefault(cls) - super().__init__(name, bases, attrs) - - def __call__(cls, *args, **kwargs): - instance = cls.__new__(cls) # type: ignore[call-overload] - - found_proxies = [] - - def check_proxy(a): - if isinstance(a, Proxy): - found_proxies.append(a) - - map_aggregate(args, check_proxy) - map_aggregate(kwargs, check_proxy) - - if len(found_proxies) != 0: - tracer = found_proxies[0].tracer - return tracer.create_proxy("call_function", cls, args, kwargs) - else: - cls.__init__(instance, *args, **kwargs) # type: ignore[misc] - return instance - - -def _patch_function(fn: FunctionType, nargs: int) -> FunctionType: - co = fn.__code__ - co_flags = co.co_flags & ~HAS_VARSTUFF - co_args: tuple - if hasattr(co, "co_posonlyargcount"): - co_args = ( - nargs, - 0, - 0, - co.co_nlocals, - co.co_stacksize, - co_flags, - co.co_code, - co.co_consts, - co.co_names, - co.co_varnames, - co.co_filename, - co.co_name, - co.co_firstlineno, - co.co_lnotab, - co.co_freevars, - co.co_cellvars, - ) - else: - co_args = ( - nargs, - 0, - co.co_nlocals, - co.co_stacksize, - co_flags, - co.co_code, - co.co_consts, - co.co_names, - co.co_varnames, - co.co_filename, - co.co_name, - co.co_firstlineno, - co.co_lnotab, - co.co_freevars, - co.co_cellvars, - ) - new_code = CodeType(*co_args) # type: ignore[arg-type] - return FunctionType( - new_code, fn.__globals__, fn.__name__, fn.__defaults__, fn.__closure__ - ) - - # we need to insert placeholder nodes for *args and **kwargs - # we can't call this function normally, otherwise it would try to unpack them - # instead, let's make python think that args and kwargs are normal variables - - -@compatibility(is_backward_compatible=False) -class PHBase(object): - """ - Object representing an input placeholder to `concrete_args` - """ - - def __repr__(self): - return "PH" - - -PH = PHBase() - - -@compatibility(is_backward_compatible=True) -class Tracer(TracerBase): - # Reference: https://github.com/pytorch/pytorch/issues/54354 - # The first line of this docstring overrides the one Sphinx generates for the - # documentation. We need it so that Sphinx doesn't leak `math`s path from the - # build environment (e.g. ` None: - # This method's signature is overridden by the first line of this class' - # docstring. If this method's signature is modified, the signature that - # overrides it also should be modified accordingly. - - """ - Construct a Tracer object. - - Args: - - autowrap_modules (Tuple[ModuleType]): defaults to `(math, )`, - Python modules whose functions should be wrapped automatically - without needing to use fx.wrap(). Backward-compatibility for - this parameter is guaranteed. - - autowrap_function (Tuple[Callable, ...]): defaults to `()`, - Python functions that should be wrapped automatically without - needing to use fx.wrap(). Backward compabilibility for this - parameter is guaranteed. - - param_shapes_constant (bool): When this flag is set, calls to shape, - size and a few other shape like attributes of a module's parameter - will be evaluted directly, rather than returning a new Proxy value - for an attribute access. Backward compatibility for this parameter - is guaranteed. - """ - - super().__init__() - - # Functions we will eagerly wrap when we see them while tracing - # this captures both `math.sqrt()` and `from math import sqrt` automatically - self._autowrap_function_ids: Set[int] = { - id(value) - for name, value in chain(*[m.__dict__.items() for m in autowrap_modules]) - if not name.startswith("_") and callable(value) - } - self._autowrap_function_ids.update(set([id(f) for f in autowrap_functions])) - - # Python modules to apply autowrap to at the start, in addition to - # modules we see while tracing - self._autowrap_search: List[ModuleType] = list(autowrap_modules) - self.param_shapes_constant = param_shapes_constant - - self.submodule_paths: Optional[Dict[torch.nn.Module, str]] = None - - @compatibility(is_backward_compatible=True) - def create_arg(self, a: Any) -> "Argument": - """ - A method to specify the behavior of tracing when preparing values to - be used as arguments to nodes in the ``Graph``. - - By default, the behavior includes: - - #. Iterate through collection types (e.g. tuple, list, dict) and recursively - call ``create_args`` on the elements. - #. Given a Proxy object, return a reference to the underlying IR ``Node`` - #. Given a non-Proxy Tensor object, emit IR for various cases: - - * For a Parameter, emit a ``get_attr`` node referring to that Parameter - * For a non-Parameter Tensor, store the Tensor away in a special - attribute referring to that attribute. - - This method can be overridden to support more types. - - Args: - - a (Any): The value to be emitted as an ``Argument`` in the ``Graph``. - - - Returns: - - The value ``a`` converted into the appropriate ``Argument`` - """ - # The base tracer is used to construct Graphs when there is no associated - # module hierarchy, so it can never create parameter references. - # The default tracer adds the ability to refer to parameters when - # tracing modules. - if isinstance(a, torch.nn.Parameter): - for n, p in self.root.named_parameters(): - if a is p: - return self.create_node("get_attr", n, (), {}) - raise NameError("parameter is not a member of this module") - elif isinstance(a, torch.Tensor): - for n_, p_ in self.root.named_buffers(): - if a is p_: - return self.create_node("get_attr", n_, (), {}) - elif isinstance(a, torch.nn.Module): - for n_, p_ in self.root.named_modules(): - if a is p_: - return self.create_node("get_attr", n_, (), {}) - # For NamedTuple instances that appear literally as args, we emit - # a node to construct the NamedTuple and use that Node as the argument. - if isinstance(a, tuple) and hasattr(a, "_fields"): - args = tuple(self.create_arg(elem) for elem in a) - return self.create_node("call_function", a.__class__, args, {}) - - # Tensors do not have a reliable string repr() from which they can be - # constructed (and we probably don't want to rely on that, either), so - # for any constant Tensor values we encounter, first search for if they - # are an attribute of some module in the module hierarchy. If so, emit - # a get_attr to retrieve that tensor. Otherwise, we'll store away the - # tensor value into a special attribute on the Module s.t. we can - # retrieve it with a get_attr. - if isinstance(a, (torch.Tensor, ScriptObject)): - qualname: Optional[str] = self.tensor_attrs.get(a) - - # Tensor was not found in the Module hierarchy, stow it away in a - # special attribute and set the qualname to refer to that - if not qualname: - i = 0 - while True: - qualname = f"_tensor_constant{i}" - if not hasattr(self.root, qualname): - break - i += 1 - self.tensor_attrs[a] = qualname - setattr(self.root, qualname, a) - - return self.create_node("get_attr", qualname, (), {}) - - if type(a) in _proxyable_classes: - # This is an instance of a proxyable class for which we did not - # witness its construction. Intern this as a constant attribute - - # TODO: binary search - i = 0 - while True: - qualname = f"_{a.__class__.__name__}_constant_{i}" - if not hasattr(self.root, qualname): - break - i += 1 - setattr(self.root, qualname, a) - - return self.create_node("get_attr", qualname, (), {}) - - return super().create_arg(a) - - @compatibility(is_backward_compatible=True) - def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool: - """ - A method to specify whether a given ``nn.Module`` is a "leaf" module. - - Leaf modules are the atomic units that appear in - the IR, referenced by ``call_module`` calls. By default, - Modules in the PyTorch standard library namespace (torch.nn) - are leaf modules. All other modules are traced through and - their constituent ops are recorded, unless specified otherwise - via this parameter. - - Args: - - m (Module): The module being queried about - module_qualified_name (str): The path to root of this module. For example, - if you have a module hierarchy where submodule ``foo`` contains - submodule ``bar``, which contains submodule ``baz``, that module will - appear with the qualified name ``foo.bar.baz`` here. - """ - return ( - (m.__module__.startswith("torch.nn") or m.__module__.startswith("torch.ao.nn")) - and not isinstance(m, torch.nn.Sequential) - ) - - @compatibility(is_backward_compatible=True) - def path_of_module(self, mod: torch.nn.Module) -> str: - """ - Helper method to find the qualified name of ``mod`` in the Module hierarchy - of ``root``. For example, if ``root`` has a submodule named ``foo``, which has - a submodule named ``bar``, passing ``bar`` into this function will return - the string "foo.bar". - - Args: - - mod (str): The ``Module`` to retrieve the qualified name for. - """ - # Prefer the O(1) algorithm - if self.submodule_paths: - path = self.submodule_paths.get(mod) - if path is None: - raise NameError("module is not installed as a submodule") - assert isinstance(path, str) - return path - # O(N^2) fallback in the case that we didn't store the submodule - # paths. - else: - for n, p in self.root.named_modules(): - if mod is p: - return n - raise NameError("module is not installed as a submodule") - - @compatibility(is_backward_compatible=True) - def call_module( - self, - m: torch.nn.Module, - forward: Callable[..., Any], - args: Tuple[Any, ...], - kwargs: Dict[str, Any], - ) -> Any: - """ - Method that specifies the behavior of this ``Tracer`` when it encounters - a call to an ``nn.Module`` instance. - - By default, the behavior is to check if the called module is a leaf module - via ``is_leaf_module``. If it is, emit a ``call_module`` node referring to - ``m`` in the ``Graph``. Otherwise, call the ``Module`` normally, tracing through - the operations in its ``forward`` function. - - This method can be overridden to--for example--create nested traced - GraphModules, or any other behavior you would want while tracing across - ``Module`` boundaries. - - Args: - - m (Module): The module for which a call is being emitted - forward (Callable): The forward() method of the ``Module`` to be invoked - args (Tuple): args of the module callsite - kwargs (Dict): kwargs of the module callsite - - Return: - - The return value from the Module call. In the case that a ``call_module`` - node was emitted, this is a ``Proxy`` value. Otherwise, it is whatever - value was returned from the ``Module`` invocation. - """ - module_qualified_name = self.path_of_module(m) - if not self.is_leaf_module(m, module_qualified_name): - return forward(*args, **kwargs) - return self.create_proxy("call_module", module_qualified_name, args, kwargs) - - @compatibility(is_backward_compatible=False) - def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: Dict[str, Any]): - """ - Method that specifies the behavior of this ``Tracer`` when we call getattr - on a call to an ``nn.Module`` instance. - - By default, the behavior is to return a proxy value for the attribute. It - also stores the proxy value in the ``parameter_proxy_cache``, so that future - calls will reuse the proxy rather than creating a new one. - - This method can be overridden to --for example-- not return proxies when - querying parameters. - - Args: - - attr (str): The name of the attribute being queried - attr_val (Any): The value of the attribute - parametr_proxy_cache (Dict[str, Any]): A cache of attr names to proxies - - Return: - - The return value from the getattr call. - """ - def maybe_get_proxy_for_attr( - attr_val, collection_to_search, parameter_proxy_cache - ): - for n, p in collection_to_search: - if attr_val is p: - if n not in parameter_proxy_cache: - kwargs = {} - if ( - "proxy_factory_fn" - in inspect.signature(self.create_proxy).parameters - ): - kwargs["proxy_factory_fn"] = ( - None - if not self.param_shapes_constant - else lambda node: ParameterProxy( - self, node, n, attr_val - ) - ) - val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type] - parameter_proxy_cache[n] = val_proxy - return parameter_proxy_cache[n] - return None - - if isinstance(attr_val, torch.nn.Parameter): - maybe_parameter_proxy = maybe_get_proxy_for_attr( - attr_val, self.root.named_parameters(), parameter_proxy_cache - ) - if maybe_parameter_proxy is not None: - return maybe_parameter_proxy - - if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor): - maybe_buffer_proxy = maybe_get_proxy_for_attr( - attr_val, self.root.named_buffers(), parameter_proxy_cache - ) - if maybe_buffer_proxy is not None: - return maybe_buffer_proxy - - return attr_val - - # This method will be refactored - @compatibility(is_backward_compatible=False) - def create_args_for_root(self, root_fn, is_module, concrete_args=None): - """ - Create ``placeholder`` nodes corresponding to the signature of the ``root`` - Module. This method introspects root's signature and emits those - nodes accordingly, also supporting ``*args`` and ``**kwargs``. - """ - # In some cases, a function or method has been decorated with a wrapper - # defined via ``functools.wraps``. In this case, the outer code object - # will likely not contain the actual parameters we care about, so unwrap - # the function to get to the innermost callable. - fn_for_analysis = inspect.unwrap(root_fn) - co = fn_for_analysis.__code__ - total_args = co.co_argcount + co.co_kwonlyargcount - orig_args = list(co.co_varnames) - names_iter = iter(co.co_varnames) - args: List[Any] = [] - skip_arg_idx = 0 - if is_module: - if total_args == 0: - raise RuntimeError( - "``self`` argument cannot be part of *args expansion!" - ) - skip_arg_idx = 1 - next(names_iter) # skip self - args.append(self.root) - - sig = inspect.signature(fn_for_analysis) - - def proxy_placeholder(name: str): - if concrete_args is not None and name in concrete_args: - cnt = 0 - - def replace_ph(x): - nonlocal cnt - cnt += 1 - param = sig.parameters[name] - default = ( - () - if param.default is inspect.Parameter.empty - else (param.default,) - ) - out = self.create_proxy( - "placeholder", f"{name}_{str(cnt)}", default, {} - ) - if x == PH: - return out - # Union[int, bool] == bool in Python <= 3.6 - if ( - type(x) == bool - or type(x) in base_types - and type(x) != torch.Tensor - ): - torch._assert( - out == x, - f"{name} has been specialized to have value {x} but got another value", - ) - elif type(x) == type(None): - args = ( - out, - f"{name} has been specialized to have value None but got another value", - ) - self.create_proxy("call_function", _assert_is_none, args, {}) - else: - warnings.warn( - f"Was not able to add assertion to guarantee correct input {name} to " - f"specialized function. It is up to the user to make sure that your inputs match the " - f"inputs you specialized the function with." - ) - - return x - - return pytree.tree_map(replace_ph, concrete_args[name]) - if name[0] == "*": - default = () - else: - param = sig.parameters[name] - default = () if param.default is inspect.Parameter.empty else (param.default,) # type: ignore[assignment] - return self.create_proxy( - "placeholder", - name, - default, - {}, - type_expr=fn_for_analysis.__annotations__.get(name, None), - ) - - arg_names = [next(names_iter) for idx in range(skip_arg_idx, total_args)] - if isinstance(concrete_args, tuple): - if len(arg_names) != len(concrete_args): - raise RuntimeError( - f"Tracing expected {len(arg_names)} arguments but got {len(concrete_args)} concrete arguments" - ) - concrete_args = {name: val for name, val in zip(arg_names, concrete_args)} - args.extend(proxy_placeholder(names) for names in arg_names) - - if co.co_kwonlyargcount > 0 or co.co_flags & HAS_VARSTUFF: - # TODO: type annotations for *args and **kwargs - if co.co_flags & inspect.CO_VARARGS: - args.append(proxy_placeholder("*" + next(names_iter))) - if co.co_flags & inspect.CO_VARKEYWORDS: - args.append(proxy_placeholder("**" + next(names_iter))) - root_fn = _patch_function(root_fn, len(args)) - - flat_args, in_spec = pytree.tree_flatten(tuple(args)) - if any(not isinstance(i, pytree.LeafSpec) for i in in_spec.children_specs): - # In the case that we have pytree-flattened inputs in - # `concrete_args`, generate a flattening wrapper around the - # original root function and return that. - self.graph._codegen = _PyTreeCodeGen( - _PyTreeInfo(orig_args[:total_args], in_spec, None) - ) - - def flatten_fn(*args): - tree_args = pytree.tree_unflatten(list(args), in_spec) - tree_out = root_fn(*tree_args) - out_args, out_spec = pytree.tree_flatten(tree_out) - assert isinstance(self.graph._codegen, _PyTreeCodeGen) - self.graph._codegen.pytree_info = ( - self.graph._codegen.pytree_info._replace(out_spec=out_spec) - ) - return out_args - - return flatten_fn, flat_args - return root_fn, args - - @compatibility(is_backward_compatible=True) - def trace( - self, - root: Union[torch.nn.Module, Callable[..., Any]], - concrete_args: Optional[Dict[str, Any]] = None, - ) -> Graph: - """ - Trace ``root`` and return the corresponding FX ``Graph`` representation. ``root`` - can either be an ``nn.Module`` instance or a Python callable. - - Note that after this call, ``self.root`` may be different from the ``root`` passed - in here. For example, when a free function is passed to ``trace()``, we will - create an ``nn.Module`` instance to use as the root and add embedded constants - to. - - - Args: - - root (Union[Module, Callable]): Either a ``Module`` or a function to be - traced through. Backwards-compatibility for this parameter is - guaranteed. - concrete_args (Optional[Dict[str, any]]): Concrete arguments that should - not be treated as Proxies. This parameter is experimental and - its backwards-compatibility is *NOT* guaranteed. - - Returns: - - A ``Graph`` representing the semantics of the passed-in ``root``. - """ - global _is_fx_tracing_flag - old_is_fx_tracing_flag = _is_fx_tracing_flag - _is_fx_tracing_flag = True - try: - if isinstance(root, torch.nn.Module): - self.root = root - - assert hasattr( - type(root), self.traced_func_name - ), f"traced_func_name={self.traced_func_name} doesn't exist in {type(root).__name__}" - - fn = getattr(type(root), self.traced_func_name) - self.submodule_paths = {mod: name for name, mod in root.named_modules()} - else: - self.root = torch.nn.Module() - fn = root - - tracer_cls: Optional[Type["Tracer"]] = getattr(self, "__class__", None) - self.graph = Graph(tracer_cls=tracer_cls) - - # When we encounter a Tensor value that's not a parameter, we look if it - # is some other attribute on the model. Construct a dict mapping Tensor - # values to the qualified name here for efficiency. This is used downstream - # in create_arg - self.tensor_attrs: Dict[Union[torch.Tensor, ScriptObject], str] = {} - - def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: List[str]): - for k, v in m.__dict__.items(): - if isinstance(v, (torch.Tensor, ScriptObject)): - self.tensor_attrs[v] = ".".join(prefix_atoms + [k]) - for k, v in m.named_children(): - collect_tensor_attrs(v, prefix_atoms + [k]) - - collect_tensor_attrs(self.root, []) - - assert isinstance(fn, FunctionType) - - fn_globals = fn.__globals__ # run before it gets patched - fn, args = self.create_args_for_root( - fn, isinstance(root, torch.nn.Module), concrete_args - ) - - parameter_proxy_cache: Dict[ - str, Proxy - ] = {} # Reduce number of get_attr calls - - # Method dispatch on parameters is not recorded unless it's directly used. - # Thus, we need to insert a proxy when __getattr__ requests a parameter. - @functools.wraps(_orig_module_getattr) - def module_getattr_wrapper(mod, attr): - attr_val = _orig_module_getattr(mod, attr) - return self.getattr(attr, attr_val, parameter_proxy_cache) - - @functools.wraps(_orig_module_call) - def module_call_wrapper(mod, *args, **kwargs): - def forward(*args, **kwargs): - return _orig_module_call(mod, *args, **kwargs) - - _autowrap_check( - patcher, - getattr(getattr(mod, "forward", mod), "__globals__", {}), - self._autowrap_function_ids, - ) - return self.call_module(mod, forward, args, kwargs) - - with _Patcher() as patcher: - # allow duplicate patches to support the case of nested calls - patcher.patch_method( - torch.nn.Module, - "__getattr__", - module_getattr_wrapper, - deduplicate=False, - ) - patcher.patch_method( - torch.nn.Module, "__call__", module_call_wrapper, deduplicate=False - ) - _patch_wrapped_functions(patcher) - _autowrap_check(patcher, fn_globals, self._autowrap_function_ids) - for module in self._autowrap_search: - _autowrap_check( - patcher, module.__dict__, self._autowrap_function_ids - ) - self.create_node( - "output", - "output", - (self.create_arg(fn(*args)),), - {}, - type_expr=fn.__annotations__.get("return", None), - ) - - self.submodule_paths = None - finally: - _is_fx_tracing_flag = old_is_fx_tracing_flag - return self.graph - - def __deepcopy__(self, memo): - # _autowrap_search contains modules, which cannot be deepcopied. - new_tracer = Tracer.__new__(Tracer) - - for k, v in self.__dict__.items(): - if k in {'_autowrap_search'}: - new_obj = copy.copy(v) - else: - new_obj = copy.deepcopy(v, memo) - - new_tracer.__dict__[k] = new_obj - - return new_tracer - - -# List of pairs of (global dict, function name) functions -# to patch for the purposes of the wrap() API. -_wrapped_fns_to_patch: List[Tuple[dict, str]] = [] - -# List of methods on classes to wrap (class type, function name) -# this currently only works for Tensor.* methods that aren't traced properly -_wrapped_methods_to_patch: List[Tuple[type, str]] = [] - -if os.environ.get("FX_PATCH_GETITEM") == "1": - # This change is needed to trace models like PositionalEmbedding from BERT: - # https://github.com/pytorch/benchmark/blob/master/torchbenchmark/models/BERT_pytorch/bert_pytorch/model/embedding/position.py - # but causes issues in quantization documented here: - # https://github.com/pytorch/pytorch/issues/50710 - # once that is fixed we can make this the default behavior. - _wrapped_methods_to_patch.append((torch.Tensor, "__getitem__")) - - -def _find_proxy(*objects_to_search): - """ - Recursively search a data structure for a Proxy() and return it, - return None if not found. - """ - proxy = None - - def find_proxy(x): - nonlocal proxy - if isinstance(x, Proxy): - proxy = x - - map_aggregate(objects_to_search, find_proxy) - return proxy - - -def _create_wrapped_func(orig_fn): - @functools.wraps(orig_fn) - def wrapped(*args, **kwargs): - """ - Given an closed-over ``orig_function`` to invoke, search the args and kwargs for - a Proxy object. If there is one, emit a ``call_function`` node to preserve the - call to this leaf function directly. Otherwise, just return the results of - this function call, as this function is not being traced. - """ - proxy = _find_proxy(args, kwargs) - if proxy is not None: - return_proxy = proxy.tracer.create_proxy( - "call_function", orig_fn, args, kwargs - ) - return_proxy.node.meta["is_wrapped"] = True - return return_proxy - return orig_fn(*args, **kwargs) - - return wrapped - - -def _create_wrapped_method(cls, name): - orig_fn = getattr(cls, name) - - @functools.wraps(orig_fn) - def wrapped(*args, **kwargs): - """ - Search the args and kwargs for a Proxy object. If there is one, - emit a ``call_method`` node to preserve the call to this method - directly. Otherwise, just return the results of this function - call, as this function is not being traced. - """ - proxy = _find_proxy(args, kwargs) - if proxy is not None: - return proxy.tracer.create_proxy("call_method", name, args, kwargs) - return orig_fn(*args, **kwargs) - - return wrapped - - -class _PatchedFn(NamedTuple): - frame_dict: Any - fn_name: str - orig_fn: Any - - def revert(self): - raise NotImplementedError() - - -class _PatchedFnSetItem(_PatchedFn): - def revert(self): - self.frame_dict[self.fn_name] = self.orig_fn - - -class _PatchedFnDel(_PatchedFn): - def revert(self): - del self.frame_dict[self.fn_name] - - -class _PatchedFnSetAttr(_PatchedFn): - def revert(self): - setattr(self.frame_dict, self.fn_name, self.orig_fn) - - -class _Patcher(object): - def __init__(self): - super(_Patcher, self).__init__() - self.patches_made: List[_PatchedFn] = [] - self.visited: Set[int] = set() - - def patch( - self, - frame_dict: Dict[str, Any], - name: str, - new_fn: Callable, - deduplicate: bool = True, - ): - """ - Replace frame_dict[name] with new_fn until we exit the context manager. - """ - new_fn.__fx_already_patched = deduplicate # type: ignore[attr-defined] - if name not in frame_dict and hasattr(builtins, name): - self.patches_made.append(_PatchedFnDel(frame_dict, name, None)) - elif getattr(frame_dict[name], "__fx_already_patched", False): - return # already patched, no need to do it again - else: - self.patches_made.append( - _PatchedFnSetItem(frame_dict, name, frame_dict[name]) - ) - frame_dict[name] = new_fn - - def patch_method( - self, cls: type, name: str, new_fn: Callable, deduplicate: bool = True - ): - """ - Replace object_or_dict.name with new_fn until we exit the context manager. - """ - new_fn.__fx_already_patched = deduplicate # type: ignore[attr-defined] - orig_fn = getattr(cls, name) - if getattr(orig_fn, "__fx_already_patched", False): - return # already patched, no need to do it again - self.patches_made.append(_PatchedFnSetAttr(cls, name, orig_fn)) - setattr(cls, name, new_fn) - - def visit_once(self, thing: Any): - """Return True on the first call to with thing, otherwise false""" - idx = id(thing) - if idx in self.visited: - return False - self.visited.add(idx) - return True - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - """ - Undo all the changes made via self.patch() and self.patch_method() - """ - while self.patches_made: - # unpatch in reverse order to handle duplicates correctly - self.patches_made.pop().revert() - self.visited.clear() - - -def _patch_wrapped_functions(patcher: _Patcher): - """ - Go through ``_wrapped_fn_patch_table`` and, for each frame object, wrap - the listed global functions in the `_create_wrapped_func` wrapper. - """ - for frame_dict, name in _wrapped_fns_to_patch: - if name not in frame_dict and hasattr(builtins, name): - orig_fn = getattr(builtins, name) - else: - orig_fn = frame_dict[name] - patcher.patch(frame_dict, name, _create_wrapped_func(orig_fn)) - - for cls, name in _wrapped_methods_to_patch: - patcher.patch_method(cls, name, _create_wrapped_method(cls, name)) - - -def _autowrap_check( - patcher: _Patcher, frame_dict: Dict[str, Any], function_ids: Set[int] -): - """ - Some methods, like `math.sqrt` are common enough we want to automatically wrap them as we see them. - This method searches a scope for them and patches them if found. - """ - if patcher.visit_once(frame_dict): - for name, value in frame_dict.items(): - if ( - not name.startswith("_") - and callable(value) - and id(value) in function_ids - ): - patcher.patch(frame_dict, name, _create_wrapped_func(value)) - - -@compatibility(is_backward_compatible=True) -def wrap(fn_or_name: Union[str, Callable]): - """ - This function can be called at module-level scope to register fn_or_name as a "leaf function". - A "leaf function" will be preserved as a CallFunction node in the FX trace instead of being - traced through:: - - # foo/bar/baz.py - def my_custom_function(x, y): - return x * x + y * y - - pippy.fx.wrap('my_custom_function') - - def fn_to_be_traced(x, y): - # When symbolic tracing, the below call to my_custom_function will be inserted into - # the graph rather than tracing it. - return my_custom_function(x, y) - - This function can also equivalently be used as a decorator:: - - # foo/bar/baz.py - @pippy.fx.wrap - def my_custom_function(x, y): - return x * x + y * y - - A wrapped function can be thought of a "leaf function", analogous to the concept of - "leaf modules", that is, they are functions that are left as calls in the FX trace - rather than traced through. - - Args: - - fn_or_name (Union[str, Callable]): The function or name of the global function to insert into the - graph when it's called - """ - if not callable(fn_or_name) and not isinstance(fn_or_name, str): - raise RuntimeError( - "Unsupported type for global function! Must be either a callable or " - "string name" - ) - - if callable(fn_or_name): - assert not isinstance(fn_or_name, str) # to make mypy happy - fn_name = fn_or_name.__name__ - else: - assert isinstance( - fn_or_name, str - ), "fn_or_name must be a global function or string name" - fn_name = fn_or_name - - currentframe = inspect.currentframe() - assert currentframe is not None - f = currentframe.f_back - assert f is not None - if f.f_code.co_name != "": - raise NotImplementedError("wrap must be called at the top level of a module") - - # consider implementing Callable version of this via _autowrap_function_ids / _autowrap_search - # semantics would be slightly different, but would add support `from x import wrapped_function` - _wrapped_fns_to_patch.append((f.f_globals, fn_name)) - return fn_or_name - - -@compatibility(is_backward_compatible=True) -def symbolic_trace( - root: Union[torch.nn.Module, Callable[..., Any]], - concrete_args: Optional[Dict[str, Any]] = None, -) -> GraphModule: - """ - Symbolic tracing API - - Given an ``nn.Module`` or function instance ``root``, this function will return a ``GraphModule`` - constructed by recording operations seen while tracing through ``root``. - - ``concrete_args`` allows you to partially specialize your function, whether it's to remove control flow or data structures. - - For example:: - - def f(a, b): - if b == True: - return a - else: - return a*2 - - FX can typically not trace through this due to the presence of control - flow. However, we can use `concrete_args` to specialize on the value of - `b` to trace through this. - - f = fx.symbolic_trace(f, concrete_args={'b': False}) - assert f(3, False) == 6 - - Note that although you can still pass in different values of `b`, they will be ignored. - - We can also use `concrete_args` to eliminate data-structure handling from - our function. This will use pytrees to flatten your input. To avoid - overspecializing, pass in `fx.PH` for values that shouldn't be - specialized. For example:: - - def f(x): - out = 0 - for v in x.values(): - out += v - return out - f = fx.symbolic_trace(f, concrete_args={'x': {'a': fx.PH, 'b': fx.PH, 'c': fx.PH}}) - assert f({'a': 1, 'b': 2, 'c': 4}) == 7 - - - Args: - root (Union[torch.nn.Module, Callable]): Module or function to be traced and converted - into a Graph representation. - concrete_args (Optional[Dict[str, any]]): Inputs to be partially specialized - - Returns: - GraphModule: a Module created from the recorded operations from ``root``. - """ - tracer = Tracer() - graph = tracer.trace(root, concrete_args) - name = ( - root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__ - ) - return GraphModule(tracer.root, graph, name) - - -@wrap -def _assert_is_none(value, msg): - assert value is None, msg diff --git a/pippy/fx/annotate.py b/pippy/fx/annotate.py deleted file mode 100644 index 906b9b811..000000000 --- a/pippy/fx/annotate.py +++ /dev/null @@ -1,22 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from pippy.fx.proxy import Proxy -from ._compatibility import compatibility - -@compatibility(is_backward_compatible=False) -def annotate(val, type): - # val could be either a regular value (not tracing) - # or fx.Proxy (tracing) - if isinstance(val, Proxy): - if val.node.type: - raise RuntimeError(f"Tried to annotate a value that already had a type on it!" - f" Existing type is {val.node.type} " - f"and new type is {type}. " - f"This could happen if you tried to annotate a function parameter " - f"value (in which case you should use the type slot " - f"on the function signature) or you called " - f"annotate on the same value twice") - else: - val.node.type = type - return val - else: - return val diff --git a/pippy/fx/experimental/__init__.py b/pippy/fx/experimental/__init__.py deleted file mode 100644 index f2661b8c6..000000000 --- a/pippy/fx/experimental/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates diff --git a/pippy/fx/experimental/accelerator_partitioner.py b/pippy/fx/experimental/accelerator_partitioner.py deleted file mode 100644 index a3254cb45..000000000 --- a/pippy/fx/experimental/accelerator_partitioner.py +++ /dev/null @@ -1,1083 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import operator -from typing import Dict, List, Set, NamedTuple, Tuple - -import torch -from pippy.fx.passes.graph_manipulation import get_size_of_all_nodes -from pippy.fx.experimental.partitioner_utils import ( - Partition, - Device, - PartitionerConfig, - get_partition_to_latency_mapping, - get_latency_of_partitioned_graph, - NodeLatency, - get_extra_size_of, - PartitionMode, -) -from pippy.fx.graph_module import GraphModule -from pippy.fx.node import Node, map_arg -from pippy.fx.passes.split_module import split_module - - -class DAGNode: - """DAGNode class maintains useful information for a partition (submodule), - and its input submodules and output submodules. - """ - - def __init__( - self, - submodule_node: Node, - input_nodes: List[Node], - output_nodes: List[Node], - logical_device_ids: List[int], - size_bytes: int, - ) -> None: - self.submodule_node: Node = submodule_node - self.input_nodes: List[Node] = input_nodes - self.output_nodes: List[Node] = output_nodes - self.logical_device_ids: List[int] = logical_device_ids - self.size_bytes = size_bytes - - def __str__(self) -> str: - return str(self.submodule_node) - - -class DAG: - """DAG class contains all the DAG nodes""" - - def __init__(self) -> None: - self.nodes: List[DAGNode] = [] - - def create_node( - self, - submodule_node: Node, - input_nodes: List[Node], - output_nodes: List[Node], - logical_devices: List[int], - size_bytes: int, - ) -> None: - node = DAGNode( - submodule_node, input_nodes, output_nodes, logical_devices, size_bytes - ) - self.nodes.append(node) - - -class PartitionResult(NamedTuple): - """NameTuple used for returning DAG and a new fx module""" - - dag: DAG - module_with_submodules: GraphModule - - -"""Followings are some helper functions for partition manipulation""" - - -def reset_partition_device(partitions): - for partition in partitions: - partition.logical_device_ids = [] - - -def combine_two_partitions( - partition_0: Partition, partition_1: Partition, partitions: List[Partition] -) -> None: - """Given a list of partitions and its two partitions, - combine these two partitions into a new one appending to the partitions - and remove the previous two partitions from the list of partitions - """ - partition = Partition(len(partitions)) - partition.nodes = partition_0.nodes.union(partition_1.nodes) - partition.recalculate_mem_size() - partitions.append(partition) - partitions.remove(partition_0) - partitions.remove(partition_1) - reorganize_partitions(partitions) - return - - -def set_parents_and_children(partitions: List[Partition]) -> None: - """Given a list of partitions, mark parents and children for each partition""" - # Go through all nodes in a partition. - # If a node's user is in other partition, - # then the other partition is this partition's children. - # This partition is the other partition's parent - for partition in partitions: - partition.children = set() - partition.parents = set() - for partition in partitions: - for node in partition.nodes: - # For each node in the current partition, find its users - users = node.users - for n in users: - # Find which the partition the user node belongs to. - # Note that if the node itself is also belongs to that partition, - # that partition is not the child of the current partition - for p in partitions: - if p != partition and n in p.nodes and node not in p.nodes: - partition.children.add(p) - p.parents.add(partition) - return - - -def reorganize_partitions(partitions: List[Partition]) -> None: - """Given a list of partitions, reorganzie partiton id, - its parents and its children for each partition - """ - # Rearrange partition ids - for i, partition in enumerate(partitions): - partition.partition_id = i - set_parents_and_children(partitions) - return - - -def get_bfs_level_partition(partitions: List[Partition]) -> None: - """Given a list of partitions, - mark the bfs level for each partition - """ - current_level: Set[Partition] = set() - visited: Set[Partition] = set() - for partition in partitions: - # If a partition has no parent, it should be in root level - if len(partition.parents) == 0: - current_level.add(partition) - next_level: Set[Partition] = set() - level = 0 - # bfs - while current_level: - partition = current_level.pop() - partition.bfs_level = level - visited.add(partition) - children = partition.children - for child in children: - if child not in next_level: - next_level.add(child) - if not current_level: - current_level = next_level.copy() - next_level = set() - level += 1 - return - - -def get_node_to_partition_mapping(partitions: List[Partition]) -> Dict[Node, int]: - """Given a list of partitions,return node to partition mapping""" - node_to_partition: Dict[Node, int] = {} - for partition in partitions: - for node in partition.nodes: - node_to_partition[node] = partition.partition_id - return node_to_partition - - -def get_logical_id_to_device(devices: List[Device]) -> Dict[int, Device]: - """Get a mapping from device logical ID to Device object.""" - logical_id_to_device: Dict[int, Device] = {} - for d in devices: - logical_id_to_device[d.logical_id] = d - return logical_id_to_device - - -def get_device_partition_stats( - partitions: List[Partition], devices: List[Device] -) -> Tuple[Dict[Device, List[Partition]], Dict[Device, int], List[Partition]]: - """Given a list of partitions and a list of devices, returns: - 1. A mapping from device to partitions on it; - 2. A mapping from device to its remaining memory size; - 3. A list of partitions that do not have a device. - """ - # logical id to device - logical_id_to_device = get_logical_id_to_device(devices) - # Track partitions on device - device_to_partitions: Dict[Device, List[Partition]] = {} - # Track device's left mem size - device_to_left_mem_bytes: Dict[Device, int] = {} - for d in devices: - device_to_partitions[d] = [] - device_to_left_mem_bytes[d] = d.available_mem_bytes - - # Deal with the partitions that already have a device - # and also collect all partitions without a device (no_device_partitions) - no_device_partitions = [] - for partition in partitions: - if partition.logical_device_ids != []: - for logical_id in partition.logical_device_ids: - device = logical_id_to_device[logical_id] - device_to_partitions[device].append(partition) - device_to_left_mem_bytes[device] -= partition.used_mem_bytes - else: - no_device_partitions.append(partition) - - return ( - device_to_partitions, - device_to_left_mem_bytes, - no_device_partitions, - ) - - -def get_device_to_partitions_mapping( - partitions: List[Partition], devices: List[Device] -): - """Given a list of partitions and a list of devices, - map each partition into a device. - """ - - def calculate_extra_mem_bytes_needed_for( - partition: Partition, partitions: List[Partition] - ): - all_nodes: Set[Node] = set() - for p in partitions: - all_nodes = all_nodes.union(p.nodes) - if len(all_nodes) == 0: - return partition.used_mem_bytes - all_nodes = all_nodes.union(partition.nodes) - extra_size_needed = 0 - for node in partition.nodes: - extra_size_needed += get_extra_size_of(node, all_nodes) - return extra_size_needed - - def find_device_for(partition: Partition): - """Given a partition, find a logical device for the partition - The algorithm is to put the partition on the device - that has just enough mem left for that partition. - device_to_left_mem_bytes is a dictionary between device and its left mem size - sorted by its left mem size - """ - for d in device_to_left_mem_bytes: - extra_size_needed = calculate_extra_mem_bytes_needed_for( - partition, device_to_partitions[d] - ) - if extra_size_needed < device_to_left_mem_bytes[d]: - device_to_partitions[d].append(partition) - partition.logical_device_ids.append(d.logical_id) - device_to_left_mem_bytes[d] -= extra_size_needed - return True - return False - - ( - device_to_partitions, - device_to_left_mem_bytes, - no_device_partitions, - ) = get_device_partition_stats(partitions, devices) - - # Find devices for all the partitions without a device - found_device = True - for partition in no_device_partitions: - device_to_left_mem_bytes = { - d: left_mem_bytes - for d, left_mem_bytes in sorted( - device_to_left_mem_bytes.items(), key=lambda item: item[1] - ) - } - found_device = find_device_for(partition) - if not found_device: - break - return found_device - - -def check_dependency(partition): - """Given a partition,check if there is a circular dependency on - this partition using bfs - """ - visited: Set[Partition] = set([partition]) - queue: List[Partition] = [partition] - while queue: - p = queue.pop(0) - for child in p.children: - if child == partition: - return True - else: - if child not in visited: - visited.add(child) - queue.append(child) - return False - - -class Partitioner: - """A fx module may not fit into one device. - Partitioner class helps partition one fx module into submodules (partitions), - so that the submodules can be executed crossing different accelerators. - The main function of this class is self.partition_graph. - It partitions the fx module based on the scheme specified in partition_config - A DAG structure is returned - along with a new fx module with submodule nodes. - """ - - def __init__(self) -> None: - self.partitions: List[Partition] = [] - self.node_to_partition: Dict[Node, int] = {} - self.devices: List[Device] = [] - - def partition_graph( - self, - fx_module: GraphModule, - torch_module: torch.nn.Module, - partitioner_config: PartitionerConfig, - ) -> PartitionResult: - """Given the fx module, torch module and partitioner_config, - find the partitions, do the partitions, - and then return a DAG and a new fx module with submodule nodes (partitions) - """ - self.graph_module = fx_module - self.torch_module = torch_module - self.devices = partitioner_config.devices - if len(self.devices) == 0: - raise RuntimeError("No devices") - # Tag the size in bytes to all nodes in the graph_module. - get_size_of_all_nodes(self.graph_module) - # Check if there are op nodes in the fx module - nodes = self.graph_module.graph.nodes - if all(node.op in {"placeholder", "get_attr", "output"} for node in nodes): - raise RuntimeError("No Partition since no operations in the module") - # Calculate total size of the fx module - total_size_of_graph = 0 - for node in nodes: - if node.op == "output": - break - total_size_of_graph += node.size_bytes.total_size - # Find the device with the max mem size - device_with_max_mem = max(self.devices, key=lambda d: d.available_mem_bytes) - # AOT based partition - if partitioner_config.mode == PartitionMode.aot_based: - self.aot_based_partition( - partitioner_config.node_to_partition_mapping, - partitioner_config.partition_to_logical_device_mapping, - ) - # Single partition if the whole module can be fit into one device - elif total_size_of_graph <= device_with_max_mem.available_mem_bytes: - self.find_single_partition( - total_size_of_graph, logical_device_id=device_with_max_mem.logical_id - ) - elif total_size_of_graph > sum([d.available_mem_bytes for d in self.devices]): - raise RuntimeError("Devices have no enough memory for the module") - else: - # Sparse nn based partition - if partitioner_config.mode == PartitionMode.sparse_nn: - available_mem_bytes = self.devices[0].available_mem_bytes - if not all( - device.available_mem_bytes == available_mem_bytes - for device in self.devices - ): - raise RuntimeError("All devices must have same memory size!") - # sparse_nn_partition only support same memory size - # TODO: add different size support for sparse_nn_partition - self.sparse_nn_partition(available_mem_bytes) - # Cost aware partition - elif partitioner_config.mode == PartitionMode.cost_aware: - self.cost_aware_partition( - partitioner_config.transfer_rate_bytes_per_sec, - partitioner_config.node_to_latency_mapping, - ) - # KL based partition - elif partitioner_config.mode == PartitionMode.kl_based: - self.kl_based_partition( - partitioner_config.transfer_rate_bytes_per_sec, - partitioner_config.node_to_latency_mapping, - ) - else: - self.size_based_partition() - - # Saturate host if possible. - if partitioner_config.saturate_host: - self.saturate_host() - - # Partition the graph module based on the partition assignment. - module_with_submodules = self.do_partition() - - # The DAG contains DAGNodes with info of each partition's input nodes, output nodes - # and how partitions are connected. - dag = self.dump_dag(module_with_submodules) - ret = PartitionResult(dag, module_with_submodules) - return ret - - def find_single_partition( - self, total_size_of_graph, logical_device_id: int = 0 - ) -> None: - """Fit the whole fx module into one device""" - partition_0 = self.create_partition() - for node in self.graph_module.graph.nodes: - if node.op == "output": - # Skip the output node, but there can - # be nodes after the output in certain cases. - continue - partition_0.nodes.add(node) - partition_0.used_mem_bytes = total_size_of_graph - partition_0.logical_device_ids = [logical_device_id] - # Get the node to partition mapping - self.node_to_partition = get_node_to_partition_mapping(self.partitions) - return - - def size_based_partition(self) -> None: - """This method is to partition the fx module based on memory size. - It uses greedy approach. The result may not be the best. - The basic idea is: - Step 1: - Find a device which has enough memory to fit the current node, create a empty partition - with the size of that device. - Then keep adding the following nodes into the partition until the partition is full. - Step 2: - Repeat Step 1 until no device left - Step 3: - If some nodes are left, create a partition for each left node (single node partition). - and then try to map those partitions into logical devices with enough mem left. - """ - - def find_device_based_on_size(node) -> Device: - """Given a node, this function is to find a logical device - that could fit the node. - """ - mem_size_needed = get_extra_size_of(node, set()) - device = Device("", -1, -1) - for d in self.devices: - if ( - d not in occupied_devices - and d.available_mem_bytes >= mem_size_needed - ): - device = d - break - if device.available_mem_bytes < 0: - raise RuntimeError(str(node) + "is too large to fit any device") - occupied_devices.append(device) - return device - - # Track partition and its left mem size - partition_to_left_mem_bytes: Dict[Partition, int] = {} - # Track all the devices that have been used - occupied_devices: List[Device] = [] - partition = self.create_partition() - for node in self.graph_module.graph.nodes: - if node.op in {"call_module", "call_method", "call_function"}: - # Check if there are devices left - if len(self.partitions) <= len(self.devices): - total_size_of_input_nodes = get_extra_size_of(node, partition.nodes) - # Check if the current partition is the very first partition - if partition.used_mem_bytes == 0: - # Find a device to fit the first node, return available mem size - device = find_device_based_on_size(node) - occupied_devices.append(device) - # Update partition and its left mem size - partition_to_left_mem_bytes[ - partition - ] = device.available_mem_bytes - # Update available mem for the current partition - partition.logical_device_ids.append(device.logical_id) - else: - # The current partition is not the first partition - # Check if the current node can fit into current partition - if ( - partition_to_left_mem_bytes[partition] - < total_size_of_input_nodes - ): - # Check if no device is left - if len(self.partitions) == len(self.devices): - # No device is left - # Put the previous partitions into a list (non_single_node_partitions) - non_single_node_partitions = self.partitions[:] - # Create the first single node partition for the current node - self.create_single_node_partition(node) - continue - # Some devices are still left - # Create a new partition with a mem size that is enough for the current node - device = find_device_based_on_size(node) - partition = self.create_partition() - total_size_of_input_nodes = get_extra_size_of( - node, partition.nodes - ) - partition_to_left_mem_bytes[ - partition - ] = device.available_mem_bytes - partition.logical_device_ids.append(device.logical_id) - partition.add_node(node) - partition_to_left_mem_bytes[partition] -= total_size_of_input_nodes - # Create single node partitions if no device is left - else: - self.create_single_node_partition(node) - reorganize_partitions(self.partitions) - # Get the node to partition mapping - self.node_to_partition = get_node_to_partition_mapping(self.partitions) - # Mapping all partitions into device - found_partition_to_device_mapping = get_device_to_partitions_mapping( - self.partitions, self.devices - ) - if not found_partition_to_device_mapping: - raise RuntimeError("Cannot Get a Valid Partition to Logical Device Mapping") - return - - def saturate_host(self) -> None: - """Saturate host by assigning replicates to unused devices with enough memory. - It uses a greedy approach to find a next available set of devices to place all split - partitions: For each used device, it searches for an idle device with minimal memory - size that can hold all the partition located on that device; If the search is successful - for all used devices, it then assigns the new devices' logical ID to the corresponding - partition. - """ - ( - device_to_partitions, - device_to_left_mem_bytes, - no_device_partitions, - ) = get_device_partition_stats(self.partitions, self.devices) - - assert ( - len(no_device_partitions) == 0 - ), f"Expect no_device_partitions has 0 device, but get {len(no_device_partitions)}" - - # Devices that hold partitions - used_devices = [d for d in self.devices if len(device_to_partitions[d]) > 0] - # Track replicates of the assigned devices - replicated_device_to_used_device: Dict[Device, Device] = {} - - while len(used_devices) * 2 + len(replicated_device_to_used_device) <= len( - self.devices - ): - # Success flag for this round - success = True - # Devices that have not been assigned - idle_devices = [ - d - for d in self.devices - if d not in used_devices and d not in replicated_device_to_used_device - ] - # Temporary mapping from replicated device to original device - temp_replicate_mapping = {} - - # Find a new device to replicate all partitions on an used device - for used_device in used_devices: - # Idle devices that have enough memory - available_devices = [ - d - for d in idle_devices - if d.available_mem_bytes - >= used_device.available_mem_bytes - - device_to_left_mem_bytes[used_device] - ] - if len(available_devices) == 0: - success = False - break - new_device = min(available_devices, key=lambda d: d.available_mem_bytes) - idle_devices.remove(new_device) - temp_replicate_mapping[new_device] = used_device - - if not success: - break - replicated_device_to_used_device.update(temp_replicate_mapping) - - # Update logical device IDs assigned to the partitions - for ( - replicate_device, - original_device, - ) in replicated_device_to_used_device.items(): - logical_id = replicate_device.logical_id - for partition in device_to_partitions[original_device]: - partition.logical_device_ids.append(logical_id) - for p in self.partitions: - print(p.logical_device_ids) - - def do_partition(self) -> GraphModule: - """Return a new fx module with submodule nodes (partitions).""" - module_with_submodules = split_module( - self.graph_module, - self.torch_module, - lambda node: self.node_to_partition[node], - ) - return module_with_submodules - - def dump_dag(self, module_with_submodules: GraphModule) -> DAG: - """Return the dag structure and the new fx module with submodules.""" - dag = DAG() - for node in module_with_submodules.graph.nodes: - if node.op == "output": - break - if node.op in {"placeholder", "get_attr"}: - continue - if node.target == operator.__getitem__: - continue - input_nodes: Dict[Node, None] = {} - map_arg(node.args, lambda n: input_nodes.setdefault(n)) - map_arg(node.kwargs, lambda n: input_nodes.setdefault(n)) - # When a node has two or more output nodes, - # it outputs its result to 'getitem' nodes. - # Those 'getitem' nodes are the output node for this node. - # Otherwise, the output node is this node itself. - if len(node.users) > 1: - output_nodes = list(node.users) - else: - output_nodes = [node] - partition_id = int(node.name.rsplit("_", 1)[-1]) - device_ids = self.partitions[partition_id].logical_device_ids - size_bytes = self.partitions[partition_id].used_mem_bytes - dag.create_node( - node, list(input_nodes), output_nodes, device_ids, size_bytes - ) - return dag - - def create_partition(self) -> Partition: - """Create a partition and append it to self.partitions.""" - partition_id = len(self.partitions) - partition = Partition(partition_id) - self.partitions.append(partition) - return partition - - def create_single_node_partition(self, node): - """Create a partition for a single node""" - partition = self.create_partition() - partition.add_node(node) - return - - def sparse_nn_partition(self, available_mem_bytes: int) -> None: - """This method partition a sparse nn module. - It is size based partition but different from size_based_partition, - it only works when all the devices have same memory size (available_mem_bytes). - In the future, devices with different mem sizes will be supported like size_based_partition. - It first traverse all the nodes and do the partitions based on the same memory size. - If the current partition has no enough memory left for a new op node - (call_module, call_method, call_function), a new partition is created. - When crossing the boundary between non-embedding nodes and embedding nodes, - a new partition is created regardlessly. - For example, if the current node is a non-embedding node but the next node is an - embedding node, a new partition is created for the next node. - After the partition, the partitions are combined as much as possible. - The rule is that a non-embedding partition only - combines with another non-embedding one. - So as the embedding partitions. - """ - - def combine_partitions_based_on_size( - partitions: List[Partition], available_mem_bytes: int - ) -> None: - """Combining small partitions together to keep as less partitions as possible. - Here is an example of the algorithm to do this: - Assume some partitions, we first sort them based on partiiton used memory size. - [(partition_4, 1), (partition_3, 1), (partition_2, 2), (partition_1, 7), (partition_0, 9)] - The available memory is 10. - step 1: self.find_partition_to_combine_based_on_size() - First, mark bfs level for each partition - Second, look the smallest partition, partition_4: 10 - 1 = 9 - It means any partition has a used memory equal or less than 9 could combine this partition - We go from the largest and selection partition_0. - Check the bfs level for two partitions, if the level difference is less than 2, - it can be combined. - step 2: repeat step 1 until no partitions can be combined - """ - find_combination = True - while find_combination: - # Sort partitions based on memory size - sorted_partitions = sorted(partitions, key=lambda p: p.used_mem_bytes) - # Mark bfs level - get_bfs_level_partition(self.partitions) - find_combination, partitions = find_partition_to_combine_based_on_size( - sorted_partitions, available_mem_bytes, partitions - ) - return - - def calculate_mem_bytes_needed(p1, p2): - """Given two partitions, calculate how many mem bytes - are needed if two partitions are combined - """ - nodes = p1.nodes.union(p2.nodes) - mem_bytes_needed = 0 - for node in nodes: - mem_bytes_needed += get_extra_size_of(node, nodes) - return mem_bytes_needed - - def find_partition_to_combine_based_on_size( - sorted_partitions: List[Partition], - available_mem_bytes: int, - partitions: List[Partition], - ) -> Tuple[bool, List[Partition]]: - """step 1 in combine_partition_based_on_size()""" - find_combination = False - smallest_partition = sorted_partitions.pop(0) - for p in sorted_partitions[::-1]: - if abs(smallest_partition.bfs_level - p.bfs_level) <= 1: - # Calculate how many bytes needed if combined - mem_bytes_needed = calculate_mem_bytes_needed(p, smallest_partition) - if mem_bytes_needed <= available_mem_bytes: - combine_two_partitions(p, smallest_partition, self.partitions) - partitions.remove(smallest_partition) - partitions.remove(p) - partitions.append(self.partitions[-1]) - find_combination = True - break - return find_combination, partitions - - def reset_partition_in_sparse_nn(partition, new_partition=True): - """If crossing the boudary between non-embedding nodes and - embedding nodes, create a new partition - """ - if in_embedding_region: - embedding_partitions.append(partition) - else: - non_embedding_partitions.append(partition) - if new_partition: - partition = self.create_partition() - partition.left_mem_bytes = available_mem_bytes - return partition - return None - - def is_embedding_node(node: Node) -> bool: - """Check if a node is an embedding node""" - if node.op == "call_module": - submodule = self.graph_module - for atom in str(node.target).split("."): - if not hasattr(submodule, atom): - raise RuntimeError( - f"Module {submodule} has no attribute {atom}" - ) - submodule = getattr(submodule, atom) - if "Embedding" in str(submodule): - return True - return False - - # Track embedding partitons and non-embedding partitions separately - embedding_partitions: List[Partition] = [] - non_embedding_partitions: List[Partition] = [] - # A Flag to check the boundary - in_embedding_region: bool = False - partition = self.create_partition() - for node in self.graph_module.graph.nodes: - if node.op in {"call_module", "call_method", "call_function"}: - # Check if crossing the boundary between embedding nodes and non embedding nodes - if is_embedding_node(node) != in_embedding_region: - # Crossing the boundary - # Check if the current partition is an empty partition - if partition.used_mem_bytes != 0: - # The current partition isn't an empty partition. Create a new one. - partition = reset_partition_in_sparse_nn(partition) - in_embedding_region = not in_embedding_region - total_size_of_input_nodes = get_extra_size_of(node, partition.nodes) - if ( - total_size_of_input_nodes + partition.used_mem_bytes - > available_mem_bytes - ): - partition = reset_partition_in_sparse_nn(partition) - total_size_of_input_nodes = get_extra_size_of(node, partition.nodes) - if total_size_of_input_nodes > available_mem_bytes: - raise RuntimeError( - node.target + "is too large to fit into a device" - ) - partition.add_node(node) - reset_partition_in_sparse_nn(partition, new_partition=False) - # Set parents and children for partitions - set_parents_and_children(self.partitions) - # Combining non-embedding partitions - combine_partitions_based_on_size(non_embedding_partitions, available_mem_bytes) - # Combining embedding partitions - combine_partitions_based_on_size(embedding_partitions, available_mem_bytes) - total_size_of_non_embedding_partitions = 0 - for partition in non_embedding_partitions: - total_size_of_non_embedding_partitions += partition.used_mem_bytes - # Check if devices are enough for all partitions - if len(embedding_partitions) > len(self.devices): - msg = ( - "Need " - + str(len(embedding_partitions)) - + " devices, but only " - + str(len(self.devices)) - + " provided" - ) - raise RuntimeError(msg) - occupied_devices = [] - for i, partition in enumerate(embedding_partitions): - # Check if all non-embedding partitions can fit into embedding partition devices - if ( - total_size_of_non_embedding_partitions + partition.used_mem_bytes - > available_mem_bytes - ): - raise RuntimeError( - "partition_" - + str(partition.partition_id) - + "(embedding partition) and non embedding partitions can not fit into one device" - ) - else: - # Add logical device to the partition - partition.logical_device_ids = [self.devices[i].logical_id] - occupied_devices.append(self.devices[i].logical_id) - # Add logical devices to the non_embedding_partitions - for partition in non_embedding_partitions: - partition.logical_device_ids = occupied_devices - # Get the node to partition mapping - self.node_to_partition = get_node_to_partition_mapping(self.partitions) - return - - def cost_aware_partition( - self, - transfer_rate_bytes_per_sec: float, - node_to_latency_mapping: Dict[Node, NodeLatency], - ) -> None: - """This method is to partition the fx module based on the cost. - The cost is the total latency of running the whole fx module. - In partitioner_utils.py, the cost model is built. - The cost aware partition algorithm is: - #1. At every begining, each node is a partition. - Then we map all the partitions to the devices - and calculate the cost - #2. Then try to pre-combine any two of the partitions if the two - partitions can be combined. - (the bfs level is less than 2 or two partitions are connected and - can find partition to device mapping) - See if any partition pair could reduce the current cost. - Choose the pair that shows the minimum cost and then combine them - #3. Repeat #2 until the cost cannot be reduced. - """ - - def try_combining_partitions(p0_index, p1_index, partitions) -> float: - """Given two partitions and a list of partitions, combine these two partitions - and see what is the cost of the modified partition list - """ - p0 = partitions[p0_index] - p1 = partitions[p1_index] - """If two partitions' bfs level are less than 2 or two partitions are connected to each other, - then they can be combined - """ - if ( - (abs(p0.bfs_level - p1.bfs_level) <= 1) - or (p0 in p1.parents) - or p0 in (p1.children) - ): - combine_two_partitions(p0, p1, partitions) - # Check if a circular dependency exists after combining - if check_dependency(partitions[-1]): - return float("inf") - # Check if the modified partition list can be mapped to devices after combination - reset_partition_device(partitions) - found_deivce = get_device_to_partitions_mapping( - partitions, self.devices - ) - if not found_deivce: - return float("inf") - # Calculate the new cost - partition_to_latency_mapping = get_partition_to_latency_mapping( - partitions, node_to_latency_mapping - ) - cost = get_latency_of_partitioned_graph( - partitions, - partition_to_latency_mapping, - transfer_rate_bytes_per_sec, - ) - return cost - # If two partition can not be combined, the cost is inf - return float("inf") - - def search_combination( - transfer_rate_bytes_per_sec, node_to_latency_mapping - ) -> bool: - """Given transfer rate between partitions and each node's latency, - find two partitions to combine so the cost of the partitions can - be reduced. - The algorithm is : - 1. Go through all the partition pairs and see - if any pair of partitions can be combined. - 2. Calculate the cost after the combination. - 3. Select the minimum cost and combine its cooresponding partition pair. - """ - partition_to_latency_mapping = get_partition_to_latency_mapping( - self.partitions, node_to_latency_mapping - ) - cost = get_latency_of_partitioned_graph( - self.partitions, - partition_to_latency_mapping, - transfer_rate_bytes_per_sec, - ) - if len(self.partitions) == 1: - return False - partition_pair: List[int] = [] - for i in range(len(self.partitions) - 1): - for j in range(i + 1, len(self.partitions)): - # Try to combine the partition pair - # and see the new cost after combination - new_cost = try_combining_partitions(i, j, self.partitions[:]) - if new_cost <= cost: - partition_pair = [i, j] - cost = new_cost - reorganize_partitions(self.partitions) - # If a partition pair is found, combine them - if len(partition_pair) != 0: - p0 = self.partitions[partition_pair[0]] - p1 = self.partitions[partition_pair[1]] - combine_two_partitions(p0, p1, self.partitions) - get_bfs_level_partition(self.partitions) - reset_partition_device(self.partitions) - get_device_to_partitions_mapping(self.partitions, self.devices) - return len(partition_pair) != 0 - - for node in self.graph_module.graph.nodes: - if node.op not in {"placeholder", "get_attr", "output"}: - self.create_single_node_partition(node) - # Set up parent partitions and children partitions for each partition - set_parents_and_children(self.partitions) - # Get bfs level for each partition - get_bfs_level_partition(self.partitions) - find_combination = True - while find_combination: - # Search for a pair partition to generate the minimum new cost, - # then combine them - find_combination = search_combination( - transfer_rate_bytes_per_sec, node_to_latency_mapping - ) - # Make sure all partitions are set up correctly - reorganize_partitions(self.partitions) - # Set up node to partition mapping - self.node_to_partition = get_node_to_partition_mapping(self.partitions) - return - - def kl_based_partition( - self, - transfer_rate_bytes_per_sec: float, - node_to_latency_mapping: Dict[Node, NodeLatency], - ) -> None: - """This function is a cost aware partition based - on Kernighan-Lin algorithm. - First, the graph is partitioned using size_based_partition. - Then, each node is swapped with any other node in a different - partition, and at the same time, the cost is estimated after - the swapping. - For example, we have nodes n0, n1, n2, n3 and n4. - Using size_based_partition, n0 and n1 are in Partition p0. - n2, n3 and n4 in Partition p1. The current cost is esimated. - We first tried using n0 to swap with n2 from the other partiton. - Then we see that swapping n0 and n2 shows a lower cost - than the current cost and it is the minimum among other pairs like - (n0, None)(This means moving n0 to Partition without swapping other nodes), - (n0, n3) and (n0, n4). We swap n0 and n2 and set the new cost - as the current cost. - Then We repeat this process for all the other nodes until all swapping pairs - are tried. - """ - - def swap_nodes(n0, n1, p0, p1): - # Either n0 or n1 could be None - # That means we simply move the node - # to another partition - if n0 is not None: - p0.remove_node(n0) - p1.add_node(n0) - if n1 is not None: - p0.add_node(n1) - p1.remove_node(n1) - - def try_swap_nodes( - n0, n1, p0, p1, node_to_latency_mapping, transfer_rate_per_sec - ): - cost = float("inf") - swap_nodes(n0, n1, p0, p1) - # Reorganize partitions after swapping - reorganize_partitions(self.partitions) - # Check if there is a circular dependency after swapping - if (not check_dependency(p0)) and (not check_dependency(p1)): - reset_partition_device(self.partitions) - partition_to_latency_mapping = get_partition_to_latency_mapping( - self.partitions, node_to_latency_mapping - ) - # Check if all partitions can be mapped to logical devices after swapping - found_device = get_device_to_partitions_mapping( - self.partitions, self.devices - ) - if not found_device: - cost = float("inf") - else: - cost = get_latency_of_partitioned_graph( - self.partitions, - partition_to_latency_mapping, - transfer_rate_bytes_per_sec, - ) - # Swap back and reset all partitions back to original - swap_nodes(n1, n0, p0, p1) - reorganize_partitions(self.partitions) - reset_partition_device(self.partitions) - get_device_to_partitions_mapping(self.partitions, self.devices) - return cost - - def swap_node_to_partition( - node, p0, p1, node_to_latency_mapping, transfer_rate_per_sec - ): - """This function helps to swap one node from partition p0 - with all the nodes in another partition p1 - """ - p1_nodes = list(p1.nodes) + [None] - min_cost = float("inf") - node_pair: List[Node] = [] - for n1 in p1_nodes: - # Ignore the node if it is not a op node - if n1 is not None and n1.op in {"placeholder", "get_attr"}: - continue - # Try swapping node in p0 with n1 in p1 - cost = try_swap_nodes( - node, n1, p0, p1, node_to_latency_mapping, transfer_rate_per_sec - ) - if cost < min_cost: - node_pair = [node, n1] - min_cost = cost - return cost, node_pair - - # First use size_base_partition - self.size_based_partition() - partition_to_latency_mapping = get_partition_to_latency_mapping( - self.partitions, node_to_latency_mapping - ) - # Calculate the cost of the partitions - cost = get_latency_of_partitioned_graph( - self.partitions, partition_to_latency_mapping, transfer_rate_bytes_per_sec - ) - # Keep tracking the node pair that shows the better cost - node_pair: List[Node] = [] - # Keep tracking the partition pair of node pair - partition_pair: List[Partition] = [] - # Collect all the op nodes from the graph - op_nodes = [] - for n in self.graph_module.graph.nodes: - if n.op not in {"placeholder", "get_attr", "output"}: - op_nodes.append(n) - for node in op_nodes: - # Find which partition the current node belongs - p0_index = self.node_to_partition[node] - p0 = self.partitions[p0_index] - # Go through all the other partitions to swap - # with other nodes from those partitions - for p1_index, _ in enumerate(self.partitions): - if p0_index != p1_index: - p1 = self.partitions[p1_index] - new_cost, new_node_pair = swap_node_to_partition( - node, - p0, - p1, - node_to_latency_mapping, - transfer_rate_bytes_per_sec, - ) - # Update the cost - # Track the swapped node pair and their partitions - if new_cost < cost: - cost = new_cost - node_pair = new_node_pair - partition_pair = [p0, p1] - # Do the swapping after trying all the nodes from a partition - if len(node_pair) != 0: - swap_nodes( - node_pair[0], node_pair[1], partition_pair[0], partition_pair[1] - ) - reorganize_partitions(self.partitions) - get_device_to_partitions_mapping(self.partitions, self.devices) - reorganize_partitions(self.partitions) - # Mapping the device to the partition - get_device_to_partitions_mapping(self.partitions, self.devices) - return - - def aot_based_partition( - self, node_to_partition_mapping, partition_to_logical_device_mapping - ): - """This function helps to rebuild the partitions given the nodes and its - corresponding partition id - """ - partition_id_to_partition_mapping: Dict[int, Partition] = {} - self.node_to_partition = node_to_partition_mapping - for node in self.node_to_partition: - partition_id = self.node_to_partition[node] - # If the requested partition has not been created, create the partition - if partition_id not in partition_id_to_partition_mapping: - partition = Partition(partition_id) - self.partitions.append(partition) - partition_id_to_partition_mapping[partition_id] = partition - partition.logical_device_ids = partition_to_logical_device_mapping[ - partition_id - ] - else: - partition = partition_id_to_partition_mapping[ - self.node_to_partition[node] - ] - # Add the current node into the partition - partition.add_node(node) diff --git a/pippy/fx/experimental/const_fold.py b/pippy/fx/experimental/const_fold.py deleted file mode 100644 index f0cf5433e..000000000 --- a/pippy/fx/experimental/const_fold.py +++ /dev/null @@ -1,294 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import re -from typing import Callable, Dict, Optional, Set, Union - -import torch -import pippy.fx -from pippy.fx.node import map_arg -from pippy.fx.passes.split_module import split_module - - -class FoldedGraphModule(pippy.fx.GraphModule): - """ - FoldedGraphModule is a GraphModule which also contains another - `const_subgraph_module` representing a subgraph which has all const attr - inputs and which can be run once before running the main standard - `graph`. The `const_output_names` are the ordered list names of attrs which - represent what each respective output from the const_subgraph should be set - on which attrs. - """ - - def __init__( - self, - root: torch.nn.Module, - graph: pippy.fx.Graph, - const_subgraph: Optional[pippy.fx.Graph] = None, - fx_const_folded_attrs_name: str = None, - device_for_folded_attrs: str = "cuda", - ): - # In init, we set graph's owning module to root which will make graph's - # owning module be None because graph already have a owning module. We - # need owning module to run DCE. To work around we set the number of - # graph's owners to 0. - graph._owners = 0 - super().__init__(root, graph) - self.const_subgraph_module = ( - None - if const_subgraph is None - else pippy.fx.GraphModule(root, const_subgraph) - ) - self.has_folding_been_run = False - self.fx_const_folded_attrs_name = fx_const_folded_attrs_name - self.device_for_folded_attrs = device_for_folded_attrs - - def __call__(self, *args, **kwargs): - if not self.has_folding_been_run: - self.run_folding() - return super().__call__(*args) - - def run_folding(self): - # If there's no const subgraph module or attr output names to use, return - # early as there is no const folding to perform. - if ( - self.const_subgraph_module is None - or self.fx_const_folded_attrs_name is None - ): - return - - assert not self.has_folding_been_run - self.has_folding_been_run = True - - # Actually run const folding subgraph. Note that single attr const fold - # subgraphs output a single Tensor while multiple outputs are returned as - # Tuple[Tensor,]. - folded_attrs = self.const_subgraph_module() - - def _create_param(i): - return torch.nn.Parameter( - i - if not isinstance(i, int) - else torch.Tensor([i]).to(device=self.device_for_folded_attrs), - requires_grad=i.requires_grad if isinstance(i, torch.Tensor) else False, - ) - - params = ( - torch.nn.ParameterList([_create_param(i) for i in folded_attrs]) - if isinstance(folded_attrs, tuple) - else _create_param(folded_attrs) - ) - setattr(self, self.fx_const_folded_attrs_name, params) - - -def _inline_module(gm: pippy.fx.GraphModule, inline_mod_name: str): - """ - Given `gm` and some graph module which is called with target name `inline_mod_name`, - this helper will inline all of the nodes from that called graph module into `gm`. - """ - # Fetch the inner graph module that we want to inline inside `gm`. - inline_mod = dict(gm.named_modules())[inline_mod_name] - assert isinstance(inline_mod, pippy.fx.GraphModule) - call_mod_node_to_replace = None - for node in gm.graph.nodes: - if node.op == "call_module" and node.target == inline_mod_name: - call_mod_node_to_replace = node - break - assert call_mod_node_to_replace is not None - - # Now actually do the swap. Note that we have to keep track of new nodes that are - # copied into `gm` -- we do this via replacement_mapping. - call_mod_args = call_mod_node_to_replace.args - replacement_mapping: Dict[pippy.fx.Node, pippy.fx.Node] = {} - ph_count = 0 - - def replacement_fn(node): - new_node = replacement_mapping[node] - new_node.meta = node.meta.copy() - return new_node - - for inline_node in inline_mod.graph.nodes: - if inline_node.op == "placeholder": - replacement_mapping[inline_node] = call_mod_args[ph_count] - ph_count += 1 - continue - - if inline_node.op == "output": - outputs = inline_node.args[0] - output_replacements = map_arg(outputs, replacement_fn) - call_mod_node_to_replace.replace_all_uses_with(output_replacements) - continue - - with gm.graph.inserting_before(call_mod_node_to_replace): - new_node = gm.graph.node_copy(inline_node, replacement_fn) - replacement_mapping[inline_node] = new_node - - gm.graph.eliminate_dead_code() - - -def get_unique_attr_name_in_module(mod_traced: pippy.fx.GraphModule, name: str) -> str: - """ - Make sure the name is unique (in a module) and can represents an attr. - """ - # Delete all characters that are illegal in a Python identifier. - name = re.sub("[^0-9a-zA-Z_]+", "_", name) - if name[0].isdigit(): - name = f"_{name}" - # Now make sure it is in fact unique to the module by incrementing suffix value. - while hasattr(mod_traced, name): - match = re.match(r"(.*)_(\d+)$", name) - if match is None: - name = name + "_1" - else: - base, num = match.group(1, 2) - name = f"{base}_{int(num) + 1}" - - return name - - -def split_const_subgraphs( - module: Union[torch.nn.Module, pippy.fx.GraphModule], - skip_folding_node_fn: Optional[Callable[[pippy.fx.Node], bool]] = None, - device_for_folded_attrs: str = "cpu", -) -> FoldedGraphModule: - """ - Looks through `module` for any nodes that have all constant attribute inputs - and separates them out into their own constant subgraph, and returns a - FoldedGraphModule which runs that constant subgraph on the first run to set - attributes on the module prior to running the non-constant portion of the - graph. - """ - if not isinstance(module, pippy.fx.GraphModule): - mod_traced = pippy.fx.symbolic_trace(module) - else: - mod_traced = module - - # Build up a list of const_nodes, defined as nodes that are themselves - # get_attrs, or have all get_attr or other constant node inputs. - const_nodes: Set[pippy.fx.Node] = set() - found_const_folding = False - for node in mod_traced.graph.nodes: - # Skip over placeholders/outputs because they can't be const folded and - # we don't want to add tags to them. - if node.op in {"placeholder", "output"}: - continue - - # If the node itself is constant, or all of its inputs are constant, - # then tag it as constant. - if node.op != "get_attr" and not set(node.all_input_nodes).issubset( - const_nodes - ): - continue - - # If provided skip folding function says to skip, then skip. - if skip_folding_node_fn and skip_folding_node_fn(node): - continue - - # Skip folding side-effectful functions - if node.is_impure(): - continue - - # Must be a constant foldable node at this point. - const_nodes.add(node) - if node.op != "get_attr": - found_const_folding = True - - # If we did not find any const folding then return early without a const fold subgraph. - if not found_const_folding: - return FoldedGraphModule(mod_traced, mod_traced.graph) - - # Partition the module into two: submod_0 for constant folding subgraph, and - # submod_1 for the rest. - def mod_partition(node: pippy.fx.Node): - return 0 if node in const_nodes else 1 - - split = split_module(mod_traced, module, mod_partition) - - const_gm, non_const_gm = split.submod_0, split.submod_1 - const_mod_name, non_const_mod_name = "submod_0", "submod_1" - - # The module that a call_module node refers to gets copied to submodules during split. - # The path to the module also gets inlined, i.e. mod.a.b -> mod_a_b. Here we need to - # attach inlined modules to `split` as it's the owning module now. - for node in non_const_gm.graph.nodes: - if node.op == "call_module": - setattr(split, node.target, getattr(non_const_gm, node.target)) - for node in const_gm.graph.nodes: - if node.op == "call_module": - setattr(split, node.target, getattr(const_gm, node.target)) - - # split_module currently does not use get_attrs for attrs. Instead it passes - # them in as args from the parent module, which used get_attrs. Here we set - # them as get_attrs inside const_gm, allowing for running folding without - # somehow a priori knowing the attrs that should be passed as args. We can - # unconditionally do this for all placeholders because we know all - # placeholders to const_gm must be constants accessible via get_attr. - call_const_gm_args = None - for node in split.graph.nodes: - if node.op == "call_module": - if node.target == const_mod_name: - call_const_gm_args = node.args - break - assert call_const_gm_args is not None - - # Here we do the actual replacement of placeholders to get_attrs. Note that here we - # set the const_gm.graph into a new root_const_gm with split as the root module, - # because we are fetching attributes directly from the root module, instead of - # fetching them from const_gm. Example: The const_gm must have some format like: - # graph(): - # %inp : [#users=1] = placeholder[target=const_inp] - # %add : [#users=1] = call_function[target=operator.add](args = (%inp, %inp), kwargs = {}) - # return add - # We replace that with the following, which does not have any placeholders: - # graph(): - # %inp_1 : [#users=1] = get_attr[target=const_inp] - # %add : [#users=1] = call_function[target=operator.add](args = (%inp_1, %inp_1), kwargs = {}) - # return add - root_const_gm = pippy.fx.GraphModule(split, const_gm.graph) - for node in root_const_gm.graph.nodes: - if node.op == "output": - multiple_outputs = isinstance(node.args[0], tuple) - continue - if node.op != "placeholder": - continue - in_node = next(n for n in call_const_gm_args if n.name == node.target) - assert in_node.op == "get_attr" - with root_const_gm.graph.inserting_before(node): - new_node = root_const_gm.graph.get_attr(in_node.target) - new_node.meta = node.meta.copy() - node.replace_all_uses_with(new_node) - root_const_gm.graph.erase_node(node) - assert "multiple_outputs" in locals() - - # Now find the call to const_gm inside split, and replace it with a getattr to the - # folded tensor(s) that result from constant folding. Note that we don't need to - # worry about whether this is one or more tensors because the original graph - # correctly uses getitem to extract individual tensors if there are multiple folded. - fx_const_folded_attrs_name = get_unique_attr_name_in_module( - split, "_FX_CONST_FOLDED_ATTRS" - ) - setattr( - split, - fx_const_folded_attrs_name, - torch.nn.ParameterList() if multiple_outputs else torch.nn.Parameter(), - ) - for node in split.graph.nodes: - if node.op == "call_module" and node.target == const_mod_name: - with node.graph.inserting_before(node): - folded_attrs = node.graph.get_attr(fx_const_folded_attrs_name) - folded_attrs.meta = node.meta.copy() - node.replace_all_uses_with(folded_attrs) - break - - split.graph.eliminate_dead_code() - - # Finally, inline the non-constant submod into the split submod. This is so that the - # original caller who may have passed in a graph module will get back out a graph - # module whose graph is traced to the same granularity. - _inline_module(split, non_const_mod_name) - - return FoldedGraphModule( - split, - split.graph, - root_const_gm.graph, - fx_const_folded_attrs_name, - device_for_folded_attrs, - ) diff --git a/pippy/fx/experimental/debug.py b/pippy/fx/experimental/debug.py deleted file mode 100644 index 916c605f8..000000000 --- a/pippy/fx/experimental/debug.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import pippy.fx as fx - -def set_trace(gm: fx.GraphModule) -> fx.GraphModule: - """ - Sets a breakpoint in `gm`'s generated python code. It drops into pdb when - `gm` gets run. - - Args: - gm: graph module to insert breakpoint. It is then recompiled for it to - take effect. - - Returns: - the `gm` with breakpoint inserted. - """ - def insert_pdb(body): - return ["import pdb; pdb.set_trace()\n", *body] - - with gm.graph.on_generate_code( - make_transformer=lambda cur_transform: ( - # new code transformer to register - lambda body: ( - insert_pdb( - cur_transform(body) if cur_transform - else body - ) - ) - ) - ): - gm.recompile() - - return gm diff --git a/pippy/fx/experimental/graph_gradual_typechecker.py b/pippy/fx/experimental/graph_gradual_typechecker.py deleted file mode 100644 index a3fe8cf37..000000000 --- a/pippy/fx/experimental/graph_gradual_typechecker.py +++ /dev/null @@ -1,927 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import itertools -import operator -from functools import reduce -from typing import Callable, Dict - -import torch -from torch.nn.modules.batchnorm import BatchNorm2d -from torch.nn.modules.conv import Conv2d - -import pippy -from pippy.fx.experimental.refinement_types import Equality -from pippy.fx.experimental.unification import Var # type: ignore[attr-defined] -from pippy.fx.node import Target, Node -from pippy.fx.tensor_type import Dyn, is_consistent, TensorType, is_more_precise - -try: - import sympy # type: ignore[import] - HAS_SYMPY = True -except ImportError: - HAS_SYMPY = False - -_INFERENCE_RULES: Dict[Target, Callable] = {} -_REFINEMENT_RULES: Dict[Target, Callable] = {} -_RULES: Dict[Target, Callable] = {} - - -def expand_to_tensor_dim(t, n): - """ - Expand a type to the desired tensor dimension if possible - Raise an error otherwise. - - t is the given type - - n is a number of dimensions to expand to - """ - if t == Dyn: - dims = [Dyn] * n - return TensorType(tuple(dims)) - elif isinstance(t, TensorType): - if len(t.__args__) != n: - raise TypeError(f'Cannot extend tensor. Tensor {t} has rank {len(t.__args__)}. It should have rank {n}') - return t - else: - raise TypeError(f'Cannot match the type {t}') - - -def broadcast_types(t1, t2): - """ - Applies broadcasting to both given types such that they - become consistent with eachother and returns two new - resulting types - """ - - # if either type is Dyn, do nothing since the types are already consistent - if t1 == Dyn or t2 == Dyn or isinstance(t1, Var) or isinstance(t2, Var): - return t1, t2 - - if isinstance(t1, TensorType) and isinstance(t2, TensorType): - s1 = len(t1.__args__) - s2 = len(t2.__args__) - - new_t1 = list(t1.__args__) - new_t2 = list(t2.__args__) - - # We make the types the same length which is the first requirement - # for consistency - if s1 > s2: - for i in range(s1 - s2): - new_t2.insert(0, 1) - - elif s2 > s1: - for i in range(s2 - s1): - new_t1.insert(0, 1) - - # we replace occurrences of "1" with each tensor with - # the corresponding type from the other tensor - for i, (x, y) in enumerate(zip(new_t1, new_t2)): - if x == 1: - new_t1[i] = y - elif y == 1: - new_t2[i] = x - - # at this point our tensors should be consistent - # and we can apply the element-wise operation and find the right dimension - # for the output of the operation - (t1, t2) = TensorType(tuple(new_t1)), TensorType(tuple(new_t2)) - return (t1, t2) - else: - raise TypeError(f'Cannot broadcast types {t1} and {t2}') - -def register_inference_rule(call_target): - def register(fn): - if call_target in _INFERENCE_RULES: - raise RuntimeError(f'Inference rule already registered for {call_target}!') - _INFERENCE_RULES[call_target] = fn - return fn - return register - -def register_refinement_rule(call_target): - def register(fn): - if call_target in _REFINEMENT_RULES: - raise RuntimeError(f'Refinement rule already registered for {call_target}!') - _REFINEMENT_RULES[call_target] = fn - return fn - return register - -def register_algebraic_expressions_inference_rule(call_target): - def register(fn): - if call_target in _RULES: - raise RuntimeError(f'Rule already registered for {call_target}!') - _RULES[call_target] = fn - return fn - return register - -@register_inference_rule(torch.add) -@register_inference_rule(operator.add) -def add_inference_rule(n: Node): - """ - Apply the addition inference rule. This includes: - - scalar addition - - broadcasting semantics - - Note that we always return the least precise type between - the operands (after applying broadcasting) to be the final type of the operation - - Note that we do not modify the operand types themselves after applying broadcasting - to them. We only use them to calculate the final type - """ - assert isinstance(n.args[0], Node) - assert isinstance(n.args[1], Node) - t1 = n.args[0].type - t2 = n.args[1].type - - # handle scalar addition - if t1 == int and isinstance(t2, TensorType): - n.type = t2 - return n.type - - # handle scalar addition - elif t2 == int and isinstance(t1, TensorType): - n.type = t1 - return n.type - - # we bring the new types to the point where - # we can check for consistency - # any inconsistency would not have been caused - # by broadcasting at this point - (new_t1, new_t2) = broadcast_types(t1, t2) - - if new_t1 != t1 or new_t2 != t2: - n.meta['broadcast'] = True - n.meta[str(n.args[0])] = new_t1 - n.meta[str(n.args[1])] = new_t2 - - else: - n.meta['broadcast'] = False - - new_t1 = t1 if not n.meta['broadcast'] else new_t1 - new_t2 = t2 if not n.meta['broadcast'] else new_t2 - - # we check for consistency between the new types - if is_consistent(new_t1, new_t2): - # we return the less precise type because - # broadcasting may have happened - # for operands with shape [1,2,Dyn] and [1,2,1] - # we have to assign the node [1,2,Dyn] - if is_more_precise(new_t1, new_t2): - n.type = new_t2 - else: - n.type = new_t1 - return n.type - else: - raise TypeError(f'Cannot add arguments {n.args[0]} ({ n.args[0].type}) and {n.args[1]} ({ n.args[1].type}) in node {n}.' - f' Types should match ') - -@register_inference_rule(getattr) -def get_attr_inference_rule(n: Node, traced): - """ - The current getattr rule only handles the shape attribute - Can be extended to other attributes - The most representitive type we have is "Dyn" but the system - can be extended with more types, such as a type to represent shapes - """ - attr_node = n.args[0] - attr_name = n.args[1] - - if attr_name == "shape": - n.type = Dyn - else: - raise TypeError("Not yet implelemted") - - # TODO. We leave it like this till we add a type to represent tensor sizes - return n.type - -@register_inference_rule(torch.transpose) -def transpose_inference_rule(n: Node): - """ - We check that dimentions for the transpose operations - are within range of the tensor type of the node - """ - if n.target == torch.transpose: - assert isinstance(n.args[0], Node) - t = n.args[0].type - - assert isinstance(n.args[1], int) - assert isinstance(n.args[2], int) - dim1, dim2 = n.args[1], n.args[2] - - if t == Dyn: - n.type = Dyn - return n.type - - elif isinstance(t, TensorType): - if 0 <= dim1 < len(t.__args__) and 0 <= dim2 < len(t.__args__): - new_type = list(t.__args__) - new_type[dim1], new_type[dim2] = new_type[dim2], new_type[dim1] - final = TensorType(new_type) - n.type = get_greatest_upper_bound(n.type, final) - return n.type - else: - raise TypeError(f'Cannot transpose {dim1} and {dim2} in type {t} for node {n}') - else: - raise TypeError(f'Cannot transpose {dim1} and {dim2} in type {t} for node {n}') - - -@register_inference_rule(torch.reshape) -def reshape_inference_rule(n: Node): - """ - Without dynamism, the rule checks that the - product of the elements of the argument tensor - type is equal to the product of the elements - of the required shape. We gradualize this rule - by adding a case to handle fully dynamic input - as well as input where some of the tensor dimensions - are unknown. In this case we check for divisibility - """ - assert isinstance(n.args[0], Node) - t1 = n.args[0].type - - assert isinstance(n.args[1], list) - t2 = n.args[1] - t2_type = TensorType([Dyn if elem == -1 else elem for elem in t2]) - - # if we do not know the original tensor dimension, - # we return the required dimension - if t1 == Dyn: - n.type = t2_type - return t2_type - - # if any of the dimensions are unknown, - # we check for divisibility - elif isinstance(t1, TensorType): - assert isinstance(t1, TensorType) - a = [e if e != Dyn else 1 for e in t1.__args__] - p1 = reduce(lambda x, y: x * y, a) - p2 = reduce(lambda x, y: x * y, t2) - if p1 % p2 == 0 or p2 % p1 == 0: - n.type = t2_type - return t2_type - else: - raise TypeError(f'Cannot reshape in node {n} from {t1} to {t2_type}') - else: - raise TypeError(f'Cannot reshape in node {n} from {t1} to {t2_type}') - -@register_inference_rule(BatchNorm2d) -def bn2d_inference_rule(n: Node, module_instance): - """ - Given a BatchNorm2D instance and a node check the following conditions: - - the input type can be expanded to a size 4 tensor: t = (x_1, x_2, x_3, x_4) - - the current node type can be expanded to a size 4 tensor: t' = (x_1', x_2', x_3', x_4') - - t is consistent with t' - - x_2 is consistent with the module's num_features - - x_2' is consistent with the module's num_features - output type: the more precise type of t and t' - """ - assert isinstance(n.args[0], Node) - n.args[0].type = expand_to_tensor_dim(n.args[0].type, 4) - arg_type = n.args[0].type - n.type = expand_to_tensor_dim(n.type, 4) - - # we check the conditions on the incoming argument - # and any existing annotation - # we also check for consistency between both annotations - if is_consistent(arg_type.__args__[1], module_instance.num_features) and \ - is_consistent(n.type.__args__[1], module_instance.num_features) and \ - is_consistent(arg_type, n.type): - - # we choose the more precise type - # to be the node type - # so if an incoming argument has more type information - # we set this node's type to be the argument type - n.type = get_greatest_upper_bound(arg_type, n.type) - return n.type - else: - raise TypeError(f'Cannot apply {module_instance} with input type {arg_type} and existing type {n.type} on {n}') - - -def calculate_out_dimension(d_in, module_instance, index): - """ - For calculating h_in and w_out according to the conv2D documentation - """ - padding = (module_instance.padding, module_instance.padding) \ - if isinstance(module_instance.padding, int) else module_instance.padding - kernel_size = (module_instance.kernel_size, module_instance.kernel_size) \ - if isinstance(module_instance.kernel_size, int) else module_instance.kernel_size - stride = (module_instance.stride, module_instance.stride) \ - if isinstance(module_instance.stride, int) else module_instance.stride - dilation = (module_instance.dilation, module_instance.dilation) \ - if isinstance(module_instance.dilation, int) else module_instance.dilation - - DIMENSION_TYPES = (int, sympy.Symbol) if HAS_SYMPY else (int,) - - if d_in == Dyn: - return Dyn - - elif isinstance(d_in, DIMENSION_TYPES): - n = d_in + 2 * padding[index] - \ - dilation[index] * \ - (kernel_size[index] - 1) - 1 - - return (n // stride[0]) + 1 - - else: - raise TypeError(f'{d_in} in {module_instance} must be a number or Dyn. Received {type(d_in)}') - - -def get_greatest_upper_bound(type1, type2): - """ - Get the most precise type that's consistent with the given types - """ - if type1 == Dyn: - return type2 - elif type2 == Dyn: - return type1 - elif isinstance(type1, TensorType) and isinstance(type2, TensorType): - if not is_consistent(type1, type2): - raise TypeError(f'Inconsistent types {type1}, {type2}') - gub = [t1 if is_more_precise(t1, t2) else t2 for (t1, t2) in zip(type1.__args__, type2.__args__)] - return TensorType(tuple(gub)) - - -@register_inference_rule(Conv2d) -def conv2d_inference_rule(n: Node, module_instance): - """ - Given a Conv2D instance and a node check the following conditions: - - the input type can be expanded to a size 4 tensor: t = (x_1, x_2, H, W) - - the current node type can be expanded to a size 4 tensor: t' = (x_1', x_2', x_3', x_4') - - x_2 is consistent with the module's in_channels - - let o = (x_1, out_channels, H_out, W_out) - then the output is the greatest upper bound of o and the existing node type t'. - """ - assert isinstance(n.args[0], Node) - n.args[0].type = expand_to_tensor_dim(n.args[0].type, 4) - arg_type = n.args[0].type - curr_node_type = expand_to_tensor_dim(n.type, 4) - - if is_consistent(arg_type.__args__[1], module_instance.in_channels): - w_in = arg_type.__args__[3] - h_in = arg_type.__args__[2] - h_out = calculate_out_dimension(h_in, module_instance, 0) - w_out = calculate_out_dimension(w_in, module_instance, 1) - new_type = TensorType((arg_type.__args__[0], module_instance.out_channels, h_out, w_out)) - gub = get_greatest_upper_bound(new_type, curr_node_type) - n.type = gub - return n.type - else: - raise TypeError(f'Cannot apply {module_instance} with input type { arg_type} and existing type {n.type} on {n}') - - -@register_inference_rule(torch.nn.ReLU) -def relu_inference_rule(n: Node, module_instance): - """ - Input and output shapes should be equal. - """ - assert isinstance(n.args[0], Node) - - if n.args[0].type == Dyn and isinstance(n.type, TensorType): - n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__)) - - if isinstance(n.args[0].type, TensorType): - n.type = get_greatest_upper_bound(n.args[0].type, n.type) - return n.type - - -def maxpool2d_check(typ, module_instance): - """ - Applies the maxpool2d shape information to the input - this affects the last two dimensions - """ - new_type_list = list(typ.__args__) - if len(new_type_list) == 4 or len(new_type_list) == 3: - w_in = new_type_list[-1] - h_in = new_type_list[-2] - - h_out = calculate_out_dimension(h_in, module_instance, 0) - w_out = calculate_out_dimension(w_in, module_instance, 1) - - new_type_list[-1] = w_out - new_type_list[-2] = h_out - return TensorType(tuple(new_type_list)) - - else: - raise TypeError(f'Wrong size {typ} for {module_instance}') - - -@register_inference_rule(torch.nn.MaxPool2d) -def maxpool2d_inference_rule(n: Node, module_instance): - """ - Given a MaxPool2D instance and a node check the following conditions: - - Input size matches size 3 or 4 - - Current node type is consistent with the output type we will calculate - - Input size matches output size and the last two dimensions of the output - are w_out and h_out. The remaining dimensions are the same as the input - - Our final result is the greatest upper bound of the output we calculate - and the current node type. - """ - assert isinstance(n.args[0], Node) - - if n.args[0].type == Dyn and isinstance(n.type, TensorType): - n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__)) - if isinstance(n.args[0].type, TensorType): - output = maxpool2d_check(n.args[0].type, module_instance) - n.type = get_greatest_upper_bound(output, n.type) - return n.type - - - -def linear_check(tensor_type, module_instance): - """ - Checks that an input tensor type satisfies the conditions for linear operation - and returns the output type based on in and out features given by module_instance - """ - if len(tensor_type.__args__) >= 2: - if is_consistent(module_instance.in_features, tensor_type.__args__[-1]): - new_type_args = list(tensor_type.__args__) - new_type_args[-1] = module_instance.out_features - return TensorType(tuple(new_type_args)) - else: - raise TypeError(f'Inconsistent {module_instance.in_features} and {tensor_type.__args__[-1]} in {module_instance}') - else: - raise TypeError(f'Type {tensor_type} must have rank 2 or more.') - - -@register_inference_rule(torch.nn.Linear) -def linear_inference_rule(n: Node, module_instance): - """ - Applies the shape information to the input then gets the greatest upper bound - of the resulting type and the existing type - """ - assert isinstance(n.args[0], Node) - if n.args[0].type == Dyn and isinstance(n.type, TensorType): - n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__)) - if isinstance(n.args[0].type, TensorType): - output_type = linear_check(n.args[0].type, module_instance) - n.type = get_greatest_upper_bound(output_type, n.type) - return n.type - - -def adaptiveavgpool2d_check(tensor_type, module_instance): - output_size = module_instance.output_size - if isinstance(output_size, int): - output_size = [output_size, output_size] - elif isinstance(output_size, tuple): - output_size = list(output_size) - if output_size[0] is None: - output_size[0] = output_size[1] - if output_size[1] is None: - output_size[1] = output_size[0] - - new_type_list = list(tensor_type.__args__) - - if len(tensor_type.__args__) == 4 or len(tensor_type.__args__) == 3: - new_type_list[-1] = output_size[1] - new_type_list[-2] = output_size[0] - - return TensorType(tuple(new_type_list)) - - else: - raise TypeError(f'Tensor ranks must be 3 or 4. Got {tensor_type}') - -@register_inference_rule(torch.nn.AdaptiveAvgPool2d) -def adaptiveavgpool2d_inference_rule(n: Node, module_instance): - """ - The input and output sizes should be the same except for the last - two dimensions taken from the input, which represent width and height - """ - assert isinstance(n.args[0], Node) - if n.args[0].type == Dyn and isinstance(n.type, TensorType): - n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__)) - if isinstance(n.args[0].type, TensorType): - output_type = adaptiveavgpool2d_check(n.args[0].type, module_instance) - n.type = get_greatest_upper_bound(n.type, output_type) - return n.type - -def flatten_check(tensor_type, start_dim, end_dim): - l = len(tensor_type.__args__) - - start_dim = l if start_dim == -1 else abs(start_dim) - end_dim = l + end_dim + 1 if end_dim < 0 else end_dim + 1 - - if 0 <= start_dim <= (l - 1) and 0 <= end_dim <= l and start_dim < end_dim: - my_args = list(tensor_type.__args__) - lhs = my_args[0:start_dim] - rhs = my_args[end_dim:] - mid = my_args[start_dim:end_dim] - if Dyn in mid: - mid = [Dyn] - else: - mid = [reduce(lambda x, y: x * y, my_args[start_dim:end_dim])] - new_type_list = lhs + mid + rhs - return TensorType(tuple(new_type_list)) - else: - raise TypeError(f'Incompatable dimentions {start_dim}, {end_dim - 1} in type {tensor_type}') - -@register_inference_rule(torch.flatten) -def flatten_inference_rule(n: Node): - """ - Applies the flatten shape information to the input then gets the - greatest upper bound of the resulting type and the existing type - """ - assert isinstance(n.args[0], Node) - - # set the default start and end dims - start_dim = 1 - end_dim = -1 - - if len(n.args) > 1: - assert isinstance(n.args[1], int) - start_dim = n.args[1] - - if len(n.args) > 2: - assert isinstance(n.args[2], int) - end_dim = n.args[2] - - if n.args[0].type == Dyn and isinstance(n.type, TensorType): - n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__)) - - if isinstance(n.args[0].type, TensorType): - output_type = flatten_check(n.args[0].type, start_dim, end_dim) - n.type = get_greatest_upper_bound(output_type , n.type) - - return n.type - -class GraphTypeChecker: - def __init__(self, env, traced): - self.env = env - self.traced = traced - - def type_check(self): - """ - A gradual type checker for graphs - Effect: every node's field type will be - populated with a type after type-checking is done - """ - graph = self.traced.graph - - # type check every node with gradual type rules - # if any node does not type check return false - for n in graph.nodes: - self.type_check_node(n) - return True - - def type_check_node(self, n: Node): - """ - Type check a given fx node. - Current operations: - - Reshape - - Transpose - - Add - - Relu - - conv2d - - batchnorm2d - - flatten - - maxpool2d - - adaptiveavgpool2d - - linear - """ - if n.type is None: - n.type = Dyn - - if n.op == 'placeholder': - return n.type - - elif n.op == 'get_attr': - t = get_parameter(self.traced, n.target) # type: ignore[arg-type] - if isinstance(t.data, torch.Tensor): - n.type = TensorType(t.data.shape) - return n.type - - elif n.op == 'call_function': - if n.target == getattr: - assert getattr in _INFERENCE_RULES - return _INFERENCE_RULES[n.target](n, self.traced) - - elif n.target in _INFERENCE_RULES: - return _INFERENCE_RULES[n.target](n) - else: - raise RuntimeError(f'No inference rule registered for target {n.target}!') - - elif n.op == 'call_module': - module_instance = self.traced.get_submodule(n.target) - if type(module_instance) in _INFERENCE_RULES: - return _INFERENCE_RULES[type(module_instance)](n, module_instance) - else: - raise RuntimeError(f'No inference rule registered for class {type(module_instance)}!') - - elif n.op == 'output': - def get_node_type(a): - return a.type - n.type = pippy.fx.node.map_arg(n.args[0], get_node_type) - return n.type - - else: - raise NotImplementedError(f"Method {n.op} not yet implemented") - - -@register_refinement_rule(Conv2d) -def conv_refinement_rule(n: Node): - """ - The equality constraints are between the first dimension of - the input and output - """ - res = [] - assert isinstance(n.args[0], Node) - arg_type = n.args[0].type - if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType): - res = [Equality(arg_type.__args__[0], n.type.__args__[0])] - return res - - -@register_refinement_rule(torch.nn.Linear) -def linear_refinement_rule(n: Node): - """ - The equality constraints are between the first dimension of - the input and output - """ - res = [] - assert isinstance(n.args[0], Node) - arg_type = n.args[0].type - if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType): - res = [Equality(arg_type.__args__[0], n.type.__args__[0])] - return res - -@register_refinement_rule(BatchNorm2d) -@register_refinement_rule(torch.nn.ReLU) -def all_eq(n: Node): - """ - For operations where the input shape is equal to the output shape - """ - res = [] - assert isinstance(n.args[0], Node) - arg_type = n.args[0].type - if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType): - args1 = arg_type.__args__ - args2 = n.type.__args__ - res = [Equality(args1[i], args2[i]) for i in range(len(args1))] - return res - - -@register_refinement_rule(torch.nn.AdaptiveAvgPool2d) -@register_refinement_rule(torch.nn.MaxPool2d) -def first_two_eq(n: Node): - """ - For operations where the first two dimensions of the input and output shape - are equal - """ - res = [] - assert isinstance(n.args[0], Node) - arg_type = n.args[0].type - if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType): - args1 = arg_type.__args__ - args2 = n.type.__args__ - res = [Equality(args1[0], args2[0]), Equality(args1[1], args2[1])] - return res - - -@register_refinement_rule(torch.add) -@register_refinement_rule(operator.add) -def element_wise_eq(n: Node): - """ - For element-wise operations and handles broadcasting. - Note that after applying broadcasting to the arguments - we are able to determine if certain dimensions have not been broadcast - if they are symbolicallu equal. - - in this case, we can establish equality between those dimensions and the - corresponding output dimensions. - - Note that it takes two iterations for this result. One iteration to establish - equality between certain dimensions of the operands (requiring the whole solver - including unification) and another iteration to establish equality between the operands - and the resulting type, requiring another round of constraint generation and unificaiton. - """ - res = [] - if isinstance(n.args[0], Node) and isinstance(n.args[1], Node): - arg_type1 = n.args[0].type - arg_type2 = n.args[1].type - if isinstance(arg_type1, TensorType) and isinstance(arg_type2, TensorType) and isinstance(n.type, TensorType): - args1, args2 = broadcast_types(arg_type1, arg_type2) - # by this point, we know that args1 and args2 are the same size. - a1 = args1.__args__ - a2 = args2.__args__ - a3 = n.type.__args__ - - # we would be here in the second iteration where we establish equality - # between operand type dimensions and the resulting type dimensions - r = [] - for x, y, z in zip(a1, a2, a3): - if x == y: - r.append(Equality(x, z)) - res = r - return res - - -@register_refinement_rule(torch.flatten) -def flatten_refinement_rule(n: Node): - """ - Generates equality constraints between the dimensions of the input and output - that will not be involved in the flatten operation - """ - assert isinstance(n.args[0], Node) - - eq_const = [] - - start_dim = 1 - end_dim = -1 - - if len(n.args) > 1: - assert isinstance(n.args[1], int) - start_dim = n.args[1] - - if len(n.args) > 2: - assert isinstance(n.args[2], int) - end_dim = n.args[2] - - if isinstance(n.type, TensorType) and isinstance(n.args[0].type, TensorType): - l = len(n.type.__args__) - arg_type = n.args[0].type - start_dim = l if start_dim == -1 else start_dim - end_dim = l + end_dim + 1 if end_dim < 0 else end_dim + 1 - - for t1, t2 in zip(n.type.__args__[0:start_dim], arg_type.__args__[0:start_dim]): - eq_const.append(Equality(t1, t2)) - - for t1, t2 in zip(n.type.__args__[end_dim:], arg_type.__args__[end_dim:]): - eq_const.append(Equality(t1, t2)) - return eq_const - - -@register_algebraic_expressions_inference_rule(Conv2d) -def conv_rule(n: Node, module_instance): - """ - Represents the outout in terms of an algrbraic expression w.r.t - the input when possible - """ - assert isinstance(n.args[0], Node) - arg_type = n.args[0].type - if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType): - w_in = arg_type.__args__[3] - h_in = arg_type.__args__[2] - h_out = calculate_out_dimension(h_in, module_instance, 0) - w_out = calculate_out_dimension(w_in, module_instance, 1) - new_type = TensorType((n.type.__args__[0], n.type.__args__[1], h_out, w_out)) - n.type = new_type - return new_type - -class Refine: - """ - Symbolic shape inference. - Generates constraints over type variables. - Currently all constraints are equality constraints. - """ - def __init__(self, traced): - self.constraints = [] - self.traced = traced - self.symbol_iter = itertools.count(start=0, step=1) - - def refine(self): - """ - Generates constraints for - every node in the graph based on - the operation. - """ - graph = self.traced.graph - for n in graph.nodes: - self.refine_node(n) - return True - - def symbolic_relations(self): - """ - Infers algebraic relations - """ - graph = self.traced.graph - for n in graph.nodes: - self.infer_symbolic_relations(n) - return True - - def replace_dyn_with_fresh_var(self, typ): - """ - Replace all unknown types with fresh type variables. - """ - if typ == Dyn: - new_symbol = Var(next(self.symbol_iter)) - return new_symbol - elif isinstance(typ, TensorType): - new_args = [self.replace_dyn_with_fresh_var(a) for a in typ.__args__] - return TensorType(tuple(new_args)) - elif isinstance(typ, list): - return [self.replace_dyn_with_fresh_var(t) for t in typ] - elif isinstance(typ, tuple): - return (self.replace_dyn_with_fresh_var(t) for t in typ) - else: - return typ - - - def convert_to_sympy_symbols(self, typ): - """ - Replace all unknown types with fresh type variables. - """ - if HAS_SYMPY: - if isinstance(typ, Var): - return sympy.symbols(str(typ)) - elif isinstance(typ, TensorType): - new_args = [self.convert_to_sympy_symbols(a) for a in typ.__args__] - return TensorType(tuple(new_args)) - elif isinstance(typ, list): - return [self.convert_to_sympy_symbols(t) for t in typ] - elif isinstance(typ, tuple): - return (self.convert_to_sympy_symbols(t) for t in typ) - else: - return typ - else: - return typ - - def refine_node(self, n: Node): - """ - Returns a list of equality constraints for - call_module and call_function nodes. - Models the relation between input and output dimensions - using constraints in case they are both tensors. - All operations used in resnet50 are defined. - """ - if n.type is None: - n.type = Dyn - - n.type = self.replace_dyn_with_fresh_var(n.type) - - if n.op == 'call_function': - if n.target in _REFINEMENT_RULES: - self.constraints += _REFINEMENT_RULES[n.target](n) - else: - pass - - if n.op == 'call_module': - module_instance = self.traced.get_submodule(n.target) - if type(module_instance) in _REFINEMENT_RULES: - self.constraints += _REFINEMENT_RULES[type(module_instance)](n) - else: - pass - - if n.op == 'output': - def get_node_type(a): - return a.type - n.type = pippy.fx.node.map_arg(n.args[0], get_node_type) - return n.type - - else: - pass - - def infer_symbolic_relations(self, n: Node): - if HAS_SYMPY: - n.type = self.convert_to_sympy_symbols(n.type) - if n.op == 'call_function': - if n.target in _RULES: - return _RULES[n.target](n) - else: - pass - - if n.op == 'call_module': - module_instance = self.traced.get_submodule(n.target) - if type(module_instance) in _RULES: - return _RULES[type(module_instance)](n, module_instance) - else: - pass - - if n.op == 'output': - def get_node_type(a): - return a.type - n.type = pippy.fx.node.map_arg(n.args[0], get_node_type) - return n.type - - else: - pass - else: - pass - -def get_parameter(traced, target: str): - """ - Returns the parameter given by ``target`` if it exists, - otherwise throws an error. - - See the docstring for ``get_submodule`` for a more detailed - explanation of this method's functionality as well as how to - correctly specify ``target``. - - Args: - target: The fully-qualified string name of the Parameter - to look for. (See ``get_submodule`` for how to specify a - fully-qualified string.) - - Returns: - torch.nn.Parameter: The Parameter referenced by ``target`` - - Raises: - AttributeError: If the target string references an invalid - path or resolves to something that is not an - ``nn.Parameter`` - """ - module_path, _, param_name = target.rpartition(".") - - mod: torch.nn.Module = traced.get_submodule(module_path) - - if not hasattr(mod, param_name): - raise AttributeError(mod._get_name() + " has no attribute `" + param_name + "`") - - param: torch.nn.Parameter = getattr(mod, param_name) - - return param diff --git a/pippy/fx/experimental/merge_matmul.py b/pippy/fx/experimental/merge_matmul.py deleted file mode 100644 index f53ea9c9f..000000000 --- a/pippy/fx/experimental/merge_matmul.py +++ /dev/null @@ -1,172 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import itertools -import operator -from typing import Dict, List - -import torch - -import pippy -import pippy.fx -from pippy.fx._symbolic_trace import symbolic_trace -from pippy.fx.node import Node -from pippy.fx.passes.tools_common import legalize_graph - - -def split_result_tensors(result: torch.Tensor, inputs: List[torch.Tensor]) -> List[torch.Tensor]: - """ - A free function for use in the merge_matmul graph transformation below that - splits the output from a merged matmul into the individual results for each - input tensor. - - Arguments: - result: The merged matmul result tensor. - inputs: The list of inputs that were merged into one for the matmul. - - Returns: - List of matmul results for each input tensor. - """ - # When fx tracer is running, x.shape[0] will be pippy.fx.Attribute but we - # need an int even when tracing - if isinstance(result, pippy.fx.Proxy): - splits = [0] * len(inputs) - else: - splits = [x.shape[0] for x in inputs] - - return torch.split(result, splits) - - -def may_depend_on(a: Node, b: Node, search_depth: int = 6): - """ - Determine if one node depends on another in a pippy.fx.Graph. - - Arguments: - a: The node that may have a dependency on b. - b: The node that a may have a dependency on. - search_depth: In the case of an indirect dependency, this function - searches upto this many nodes away in search of a - data dependency. If none is found, the function - makes the conservative assumption that there is a - dependency. - - Returns: - True if a may depend on b, False if it definitely does not. - """ - # Equivalence is defined as dependence. - if a == b: - return True - - # If a has no inputs, it cannot depend on b. - if len(a.all_input_nodes) == 0: - return False - - # If the search depth has been exhausted and no conclusion has been - # reached, assume that there is a data dependency. - if search_depth == 0: - return True - - # Recursively check all inputs of a. - for inp in a.all_input_nodes: - if may_depend_on(inp, b, search_depth - 1): - return True - - return False - - -def are_nodes_independent(nodes: List[Node]): - """ - Check if all of the given nodes are pairwise-data independent. - - Arguments: - nodes: The nodes to check for data dependencies. - - Returns: - True if any pair in nodes has a data dependency. - """ - # For each pair in nodes: - for i, j in itertools.combinations(nodes, 2): - if may_depend_on(i, j) or may_depend_on(j, i): - return False - - return True - - -def merge_matmul(in_mod: torch.nn.Module): - """ - A graph transformation that merges matrix multiplication operations that share the same right-hand - side operand into one large matrix multiplication. - ____ _________ _________ - ---- | | | | M| A * C | - M| A | T| B | * K| C | = |---------| - ---- , | | | | T| B * C | - K ---- --------- --------- - K R R - """ - gm = symbolic_trace(in_mod) - - rhs_users: Dict[Node, List[Node]] = {} - lhs_users: Dict[Node, List[Node]] = {} - - # Populate rhs_users and lhs_users - maps from LHS/RHS matrix multiply operands to - # the matmul of which they are the LHS/RHS. - for node in gm.graph.nodes: - if node.op != "call_function" or node.target is not torch.matmul: - continue - - lhs, rhs = node.args - - # TODO: Properly handle aliasing caused by get_attr. For now, - # use the attribute name as the operand if the node is a - # get_attr. - lhs = lhs.target if lhs.op == "get_attr" else lhs - rhs = rhs.target if rhs.op == "get_attr" else rhs - - lhs_users.setdefault(lhs, []).append(node) - rhs_users.setdefault(rhs, []).append(node) - - for rhs, mms in rhs_users.items(): - # There must be at least matmuls for a merge to make sense. - if len(mms) < 2: - continue - - # All matmuls must not depend on each other directly or indirectly - # in order for the merge to be possible. - if not are_nodes_independent(mms): - continue - - lhs_vals = [mm.args[0] for mm in mms] - - # Merge the matmul. - # Collect a list of LHS operands and the single RHS operand. - lhs = [gm.graph.get_attr(l) if isinstance(l, str) else l for l in lhs_vals] - rhs = gm.graph.get_attr(rhs) if isinstance(rhs, str) else rhs - - # Concatenate all the LHS operands. - merge_mm_cat = gm.graph.call_function(torch.cat, (lhs,), {}) - - # Multiply the concatenated LHS operands with the one RHS. This will produce - # the same results as all the individual matmuls involving rhs in the original graph, - # but they will all be concatenated together. - merge_mm = gm.graph.call_function(torch.matmul, (merge_mm_cat, rhs,), {}) - - # Split the result of the merged matmul using the shapes of the LHS operands - # to ascertain how large each chunk should be. - merge_mm_split = gm.graph.call_function( - split_result_tensors, (merge_mm, lhs), {} - ) - merge_mm_res = [ - gm.graph.call_function(operator.getitem, (merge_mm_split, out), {}) - for out in range(len(lhs)) - ] - - # Replace all uses of the original, unmerged matmuls with the equivalent split chunk from the merged matmul. - for old, new in zip(mms, merge_mm_res): - old.replace_all_uses_with(new) - gm.graph.erase_node(old) - - # All of the new nodes created above were inserted at the end, so we need to sort - # the nodes topologically to make sure all definitions precede uses. - legalize_graph(gm) - - gm.recompile() - gm.graph.lint() - return gm diff --git a/pippy/fx/experimental/meta_tracer.py b/pippy/fx/experimental/meta_tracer.py deleted file mode 100644 index 3f5caa599..000000000 --- a/pippy/fx/experimental/meta_tracer.py +++ /dev/null @@ -1,269 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import torch -import pippy.fx -import warnings -import functools -import builtins - -from typing import Any, Callable, Dict, Optional, Union - -def embedding_override(self, input): - return torch.empty(*input.shape, self.weight.shape[-1], device='meta') - - -def nn_layernorm_override(self, input): - return input - - -def torch_relu_override(x): - return x - - -def torch_nn_relu_override(self, x): - return x - - -def functional_relu_override(x, inplace=False): - assert not inplace, 'dont support inplace functional.relu for metatensor analysis' - return x - - -def torch_where_override(condition, x, y): - # torch.where returns the broadcasted tensor of condition, x, and y, - # so hack it by using addition - return condition.to(device='meta') + x.to(device='meta') + y.to(device='meta') - - -def torch_abs_override(input, *, out=None): - assert out is None, 'Dont support in-place abs for MetaTensor analysis' - return input - -manual_meta_overrides : Dict[Callable, Callable] = { - torch.nn.Embedding: embedding_override, - torch.nn.LayerNorm: nn_layernorm_override, - torch.relu: torch_relu_override, - torch.nn.functional.relu: functional_relu_override, - torch.nn.ReLU: torch_nn_relu_override, - torch.where: torch_where_override, - torch.abs: torch_abs_override, -} - -def gen_constructor_wrapper(target): - @functools.wraps(target) - def wrapper(*args, **kwargs): - proxy = None - - def check_has_proxy(v): - if isinstance(v, pippy.fx.Proxy): - nonlocal proxy - proxy = v - pippy.fx.node.map_aggregate(args, check_has_proxy) - pippy.fx.node.map_aggregate(kwargs, check_has_proxy) - - if proxy is not None: - return proxy.tracer.create_proxy('call_function', target, args, kwargs) - else: - return target(*args, **kwargs) - return wrapper, target - -class MetaProxy(pippy.fx.Proxy): - def install_tensor_meta(self, tensor_meta): - self._tensor_meta = tensor_meta - - def size(self, dim=None): - if hasattr(self, '_tensor_meta') and self._tensor_meta is not None: - return self._tensor_meta.size(*[dim] if dim else []) - return self.tracer.create_proxy('call_method', 'size', (self, dim) if dim else (self,), {}) - - def dim(self): - if hasattr(self, '_tensor_meta') and self._tensor_meta is not None: - return self._tensor_meta.dim() - return self.tracer.create_proxy('call_method', 'dim', (self,), {}) - - @property - def shape(self): - if hasattr(self, '_tensor_meta') and self._tensor_meta is not None: - return self._tensor_meta.shape - return self.tracer.create_proxy('call_function', builtins.getattr, (self, 'shape'), {}) - - @property - def dtype(self): - if hasattr(self, '_tensor_meta') and self._tensor_meta is not None: - return self._tensor_meta.dtype - return self.tracer.create_proxy('call_function', builtins.getattr, (self, 'dtype'), {}) - - @property - def device(self): - # Hack so we can track when devices are used. During meta-tensor propagation, - # replace these values with a constant 'meta' - return MetaDeviceAttribute(self, 'device') - - def __getattr__(self, k): - if k == '_tensor_meta': - return self.__getattribute__(k) - # note: not added to the graph yet, if this is a method call - # we peephole optimize to the method invocation - return MetaAttribute(self, k) - -class MetaAttribute(MetaProxy): - def __init__(self, root, attr: str): - - self.root = root - self.attr = attr - self.tracer = root.tracer - self._node = None - - @property - def node(self): - # the node for attributes is added lazily, since most will just be method calls - # which do not rely on the getitem call - if self._node is None: - self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node - return self._node - - def __call__(self, *args, **kwargs): - return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs) - -class MetaDeviceAttribute(MetaAttribute): - pass - -def proxys_to_metas(v): - if isinstance(v, MetaDeviceAttribute): - return 'meta' - if isinstance(v, pippy.fx.Proxy): - assert isinstance(v, MetaProxy), f'Expected MetaProxy but got {type(v)}' - assert hasattr(v, '_tensor_meta'), 'MetaProxy does not have an associated meta' - return v._tensor_meta - return v - -class MetaTracer(pippy.fx.Tracer): - allow_insert_stateless_mods : bool = True - - _TORCH_METHODS_TO_PATCH = ['arange', 'zeros', 'ones', 'full_like', 'eye'] - - def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None): - rv = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn) - - if kind == 'placeholder' and target in self.meta_args: - rv.install_tensor_meta(self.meta_args[target]) - return rv - - if target in self.orig_fns: - # NOTE: tensor constructors in PyTorch define the `device` argument as - # *kwargs-only*. That is why this works. If you add methods to - # _TORCH_METHODS_TO_PATCH that do not define `device` as kwarg-only, - # this will break and you will likely see issues where we cannot infer - # the size of the output. - if 'device' in kwargs: - kwargs['device'] = 'meta' - - try: - args_metas = pippy.fx.node.map_aggregate(args, proxys_to_metas) - kwargs_metas = pippy.fx.node.map_aggregate(kwargs, proxys_to_metas) - - if kind == 'call_function': - meta_target = manual_meta_overrides.get(target, target) - meta_out = meta_target(*args_metas, **kwargs_metas) - elif kind == 'call_method': - meta_out = getattr(args_metas[0], target)(*args_metas[1:], **kwargs_metas) - elif kind == 'call_module': - assert hasattr(self, 'orig_forward') - self._disable_module_getattr = True - try: - mod = self.root.get_submodule(target) - mod_type = type(mod) - if mod_type in manual_meta_overrides: - meta_out = manual_meta_overrides[mod_type](mod, *args_metas, **kwargs_metas) - else: - meta_out = self.orig_forward(*args_metas, **kwargs_metas) - finally: - self._disable_module_getattr = False - elif kind == 'get_attr': - self._disable_module_getattr = True - try: - attr_itr = self.root - atoms = target.split('.') - for atom in atoms: - attr_itr = getattr(attr_itr, atom) - assert isinstance(attr_itr, torch.Tensor) - meta_out = attr_itr.to(device='meta') - finally: - self._disable_module_getattr = False - else: - return rv - - # TODO - assert isinstance(rv, pippy.fx.Proxy), 'Dont support composite output yet' - rv.install_tensor_meta(meta_out) - except Exception as e: - warnings.warn(f'Could not compute metadata for {kind} target {target}: {e}') - - return rv - - def getattr(self, attr, attr_val, parameter_proxy_cache): - if getattr(self, '_disable_module_getattr', False): - return attr_val - else: - return super().getattr(attr, attr_val, parameter_proxy_cache) - - def call_module(self, m, forward, args, kwargs): - self.orig_forward = forward - return super().call_module(m, forward, args, kwargs) - - def _insert_module_as_submodule(self, mod: torch.nn.Module) -> str: - """ - Helper method which tries to insert a module that was not declared as submodule. - """ - idx = 0 - mod_name = mod.__class__.__name__.lower() - path = f"{mod_name}_{idx}" - while hasattr(self.root, path): - path = f"{mod_name}_{idx}" - idx += 1 - - self.root.add_module(path, mod) - return path - - def path_of_module(self, mod: torch.nn.Module) -> str: - try: - return super().path_of_module(mod) - except NameError as e: - if self.allow_insert_stateless_mods and len(list(mod.parameters())) == 0 and len(list(mod.buffers())) == 0: - path = self._insert_module_as_submodule(mod) - self.prev_module = path - return path - raise - - def proxy(self, node): - return MetaProxy(node, self) - - def trace(self, root, meta_args : Dict[str, torch.Tensor], concrete_args=None): - assert isinstance(meta_args, dict) - self.meta_args = meta_args - - self.patched_torch_methods = { - target: gen_constructor_wrapper(getattr(torch, target)) for target in self._TORCH_METHODS_TO_PATCH - } - self.orig_fns = set() - - for name, (wrapper, orig) in self.patched_torch_methods.items(): - setattr(torch, name, wrapper) - self.orig_fns.add(orig) - - try: - graph = super().trace(root, concrete_args) - graph._tracer_extras = {'meta_args': meta_args} - return graph - finally: - for name, (_, orig) in self.patched_torch_methods.items(): - setattr(torch, name, orig) - - -def symbolic_trace(root : Union[torch.nn.Module, Callable[..., Any]], - meta_args : Dict[str, torch.Tensor] = None, - concrete_args: Optional[Dict[str, Any]] = None) -> pippy.fx.GraphModule: - tracer = MetaTracer() - graph = tracer.trace(root, meta_args, concrete_args) - name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__ - gm = pippy.fx.GraphModule(tracer.root, graph, name) - return gm diff --git a/pippy/fx/experimental/migrate_gradual_types/__init__.py b/pippy/fx/experimental/migrate_gradual_types/__init__.py deleted file mode 100644 index f2661b8c6..000000000 --- a/pippy/fx/experimental/migrate_gradual_types/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates diff --git a/pippy/fx/experimental/migrate_gradual_types/constraint.py b/pippy/fx/experimental/migrate_gradual_types/constraint.py deleted file mode 100644 index 9188e8346..000000000 --- a/pippy/fx/experimental/migrate_gradual_types/constraint.py +++ /dev/null @@ -1,559 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -# -*- coding: utf-8 -*- -from pippy.fx.experimental.migrate_gradual_types.operation import op_add, op_sub, op_mul, op_div, \ - op_mod, op_gt, op_lt, op_neq, op_eq -from pippy.fx.tensor_type import TensorType, Dyn - - -class Constraint: - pass - - -class Conj(Constraint): - def __init__(self, conjuncts): - """ - :param conjuncts: Conjuction of constraints - """ - self.conjucts = conjuncts - - def __eq__(self, other): - if isinstance(other, Conj): - return self.conjucts == other.conjucts and self.conjucts == other.conjucts - else: - return False - - def __repr__(self): - return f'And({self.conjucts})' - - -class Disj(Constraint): - def __init__(self, disjuncts): - """ - :param disjuncts: Disjunction of constraints - """ - self.disjuncts = disjuncts - - def __eq__(self, other): - if isinstance(other, Disj): - return self.disjuncts == other.disjuncts and self.disjuncts == other.disjuncts - else: - return False - - def __repr__(self): - return f'Or({self.disjuncts})' - - -class Prod(Constraint): - def __init__(self, products): - """ - :param products: lists of dimensions to multiply - """ - self.products = products - - def __eq__(self, other): - if isinstance(other, Prod): - return self.products == other.products and self.products == other.products - else: - return False - - def __repr__(self): - return f'Product({self.products})' - - -class T(Constraint): - """ - True - """ - def __init__(self): - pass - - def __eq__(self, other): - return isinstance(other, T) - - def __repr__(self): - return 'True' - -class F(Constraint): - """ - False - """ - def __init__(self): - pass - - def __eq__(self, other): - return isinstance(other, F) - - def __repr__(self): - return 'False' - - -class BinaryConstraint(Constraint): - """ - Represents all binary operations - """ - def __init__(self, lhs, rhs, op): - """ - :param lhs: lhs of the constraint - :param rhs: rhs of the constraint - :param op: string reprsenting the operation - """ - self.lhs = lhs - self.rhs = rhs - self.op = op - - def __eq__(self, other): - if isinstance(other, BinaryConstraint): - return self.lhs == other.lhs and self.rhs == other.rhs and self.op == other.op - else: - return False - - def __repr__(self): - return f'({self.lhs} {self.op} {self.rhs})' - - -class BinConstraintT(BinaryConstraint): - """ - Binary constraints about tensors - """ - def __init__(self, lhs, rhs, op): - assert (isinstance(lhs, TVar) or isinstance(lhs, TensorType) or isinstance(lhs, int) or lhs == Dyn) and \ - (isinstance(rhs, TVar) or isinstance(rhs, TensorType) or isinstance(rhs, int) or rhs == Dyn) - super().__init__(lhs, rhs, op) - - def __eq__(self, other): - return super().__eq__(other) - - -class BinConstraintD(BinaryConstraint): - """ - Binary constraints about dimensions - """ - def __init__(self, lhs, rhs, op): - assert is_algebraic_expression(lhs) or is_dim(lhs) or is_bool_expr(lhs) - assert is_algebraic_expression(rhs) or is_dim(rhs) or is_bool_expr(rhs) - - super().__init__(lhs, rhs, op) - - def __eq__(self, other): - return super().__eq__(other) - - - -class TGreatestUpperBound(Constraint): - """ - Greatest Upper bound for tensors with dynamic type - """ - def __init__(self, res, rhs1, rhs2): - """ - :param res: tensor variable that stores the result of the outout - :param rhs1: tensor or tensor variable - :param rhs2: tensor or tensor variabke - """ - self.res = res - self.rhs1 = rhs1 - self.rhs2 = rhs2 - - def __repr__(self): - return f'{self.res} = {self.rhs1}⊔*{self.rhs2}' - - def __eq__(self, other): - if isinstance(other, TGreatestUpperBound): - return self.res == other.res and self.rhs1 == other.rhs1 and self.rhs2 == other.rhs2 - else: - return False - - -class DGreatestUpperBound(Constraint): - """ - Greatest Upper bound for dimensions - """ - def __init__(self, res, rhs1, rhs2): - """ - :param res: Dimension variable to store the result - :param rhs1: dimension variable 1 - :param rhs2: dimension variable 2 - """ - assert is_dim(res) - assert is_dim(rhs1) - assert is_dim(rhs2) - - self.res = res - self.rhs1 = rhs1 - self.rhs2 = rhs2 - - def __repr__(self): - return f'{self.res} = {self.rhs1}⊔{self.rhs2}' - - def __eq__(self, other): - if isinstance(other, DGreatestUpperBound): - return self.res == other.res and self.rhs1 == other.rhs1 and self.rhs2 == other.rhs2 - else: - return False - - -class CanReshape(Constraint): - """ - can_reshape constraint - """ - def __init__(self, src, target): - """ - :param src: tensor variable - :param target: tensor - """ - self.src = src - self.target = target - - def __repr__(self): - return f'can-reshape({self.src}, {self.target})' - - def __eq__(self, other): - if isinstance(other, CanReshape): - return self.src == other.src and self.target == other.target - else: - return False - - -class IndexSelect(Constraint): - - def __init__(self, tensor_size, input_var, dim_replace, index, output): - """ - Args: - input_var: input to index_select - tensor_size: tensor size we are considering - dim_replace: the dimension of the output at "index" - index: location of the dimensions to replace in the input - outut: variable to store the result - """ - assert isinstance(input_var, TVar) - assert isinstance(output, TVar) - assert isinstance(dim_replace, DVar) or dim_replace == Dyn - assert isinstance(index, int) - - self.input_var = input_var - self.tensor_size = tensor_size - self.dim_replace = dim_replace - self.index = index - self.output = output - - def __repr__(self): - - return f' {self.output} = ' \ - f'IndexSelect({self.input_var}, ' \ - f'tensor_size: {self.tensor_size}, ' \ - f'{self.dim_replace}, ' \ - f'{self.index})' - - def __eq__(self, other): - if isinstance(other, IndexSelect): - return self.tensor_size == other.tensor_size and \ - self.dim_replace == other.dim_replace and \ - self.index == other.index and \ - self.output == other.output and \ - self.input_var == other.input_var - else: - return False - - -class Transpose(Constraint): - - def __init__(self, tensor_size, input_var, index1, index2, output): - """ - Args: - tensor_size: current tensor size - input_var: variable to hold input - index1: dimension 1 - index2: dimension 2 - output: output that stores result - """ - assert isinstance(input_var, TVar) - assert isinstance(output, TVar) - assert isinstance(index1, int) - assert isinstance(index2, int) - - self.input_var = input_var - self.tensor_size = tensor_size - self.index1 = index1 - self.index2 = index2 - self.output = output - - def __repr__(self): - - return f' {self.output} = ' \ - f'Transpose({self.input_var}, ' \ - f'tensor_size: {self.tensor_size}, ' \ - f'{self.index1}, ' \ - f'{self.index2})' - - def __eq__(self, other): - if isinstance(other, Transpose): - return self.tensor_size == other.tensor_size and \ - self.index1 == other.index1 and \ - self.index2 == other.index2 and \ - self.output == other.output and \ - self.input_var == other.input_var - else: - return False - - -class GetItem(Constraint): - - def __init__(self, tensor_size, index, res, input_var): - """ - Constraint for getting item given a tensor size - :param tensor_size: actual number - :param index: actual number representing the index - :param res: dimension variable to carry the item we get - :param input_var: a tensor variable from which we will get item - """ - assert isinstance(res, DVar) - - self.res = res - self.tensor_size = tensor_size - self.index = index - self.input_var = input_var - - def __repr__(self): - return f' {self.res} = GetItem({self.input_var}, tensor_size: {self.tensor_size}, {self.index})' - - def __eq__(self, other): - if isinstance(other, GetItem): - return self.res == other.res and \ - self.tensor_size == other.tensor_size and \ - self.index == other.index and \ - self.input_var == other.input_var - else: - return False - -class GetItemTensor(Constraint): - - def __init__(self, tensor_size, index_tuple, res, input_var): - """ - Constraint for getting item given a tensor size - However, when the argument is a tuple, we will - expect a tensor - :param tensor_size: actual number representing the rank - :param index_tuple: tuple for indexing - :param res: tensor variable to carry the item we get - :param input_var: a tensor variable from which we will get item - """ - assert isinstance(res, TVar) - - self.res = res - self.tensor_size = tensor_size - self.index_tuple = index_tuple - self.input_var = input_var - - def __repr__(self): - return f' {self.res} = GetItemT({self.input_var}, tensor_size: {self.tensor_size}, {self.index_tuple})' - - def __eq__(self, other): - if isinstance(other, GetItemTensor): - return self.res == other.res and \ - self.tensor_size == other.tensor_size and \ - self.index_tuple == other.index_tuple and \ - self.input_var == other.input_var - else: - return False - -class CalcConv(Constraint): - - def __init__(self, conv_result, input_var, c_out, kernel, padding, stride, dilation, matching_constraint_vars): - """ - :param conv_result: the convolution result - :param input_var: input to convolution - :param c_out: output chanel type - :param kernel: kernel tuple - """ - self.conv_result = conv_result - self.input_var = input_var - self.c_out = c_out - self.kernel = kernel - self.padding = padding - self.stride = stride - self.dilation = dilation - self.matching_constraint = matching_constraint_vars - - def __repr__(self): - return f'{self.conv_result} =' \ - f' calc-conv({self.input_var},' \ - f' {self.c_out}, {self.kernel}, ' \ - f'{self.padding}, {self.stride},' \ - f' {self.dilation})' - - def __eq__(self, other): - if isinstance(other, CalcConv): - return self.conv_result == other.conv_result and self.input_var == other.input_var and \ - self.c_out == other.c_out and self.kernel == other.kernel and self.padding == other.padding \ - and self.stride == other.stride and self.dilation == other.dilation \ - and self.matching_constraint == other.matching_constraint - else: - return False - - -class CalcMaxPool(Constraint): - - def __init__(self, maxpool_result, input_var, kernel, padding, stride, dilation, matching_constraint_vars): - """ - :param maxpool_result: the result of maxpool - :param input_var: input to convolution - :param kernel: kernel tuple - """ - self.maxpool_result = maxpool_result - self.input_var = input_var - self.kernel = kernel - self.padding = padding - self.stride = stride - self.dilation = dilation - self.matching_constraint = matching_constraint_vars - - def __repr__(self): - return f'{self.maxpool_result} =' \ - f' calc-maxpool({self.input_var},' \ - f' {self.kernel}, ' \ - f'{self.padding}, {self.stride},' \ - f' {self.dilation})' - - def __eq__(self, other): - if isinstance(other, CalcMaxPool): - return self.maxpool_result == other.maxpool_result and self.input_var == other.input_var \ - and self.kernel == other.kernel and self.padding == other.padding \ - and self.stride == other.stride and self.dilation == other.dilation \ - and self.matching_constraint == other.matching_constraint - else: - return False - - -class ApplyBroadcasting(Constraint): - def __init__(self, res1, res2, input1, input2): - """ - :param res1: resulting tensor 1 - :param res2: resulting tensor 2 - :param input1: tensor variable 1 - :param input2: tensor variable 2 - """ - self.res1 = res1 - self.res2 = res2 - self.input1 = input1 - self.input2 = input2 - - def __eq__(self, other): - if isinstance(other, ApplyBroadcasting): - return self.res1 == other.res1 \ - and self.res2 == other.res2 \ - and self.input1 == other.input1 \ - and self.input2 == other.input2 - else: - return False - - def __repr__(self): - return f'{self.res1}, {self.res2} ='f' apply-broadcasting({self.input1},' f' {self.input2})' - - -class CalcProduct(Constraint): - """ - Given correct dimensions, calculate the product for flatten accounting for Dyn - """ - def __init__(self, start, end, flattened, dims_to_flatten): - """ - :param start: start index - :param end: end index - :param theta: variable to store the product - :param dims_to_flatten: the type which we will flatten - """ - assert isinstance(dims_to_flatten, list) - assert isinstance(flattened, TVar) - assert isinstance(start, int) - assert isinstance(end, int) - - self.start = start - self.end = end - self.dims_to_flatten = dims_to_flatten - self.flattened = flattened - - def __eq__(self, other): - if isinstance(other, CalcProduct): - return self.start == other.start and self.end == other.end and \ - self.dims_to_flatten == other.dims_to_flatten and self.flattened == other.flattened - - else: - return False - - def __repr__(self): - return f'{self.flattened} = CalcProduct({self.start}, {self.end}, {self.dims_to_flatten})' - - -class TVar: - """ - Tensor variable with no tensor constructor - """ - def __init__(self, tvar): - """ - :param tvar: tensor variable - """ - self.tvar = tvar - - def __repr__(self): - return f'TV({self.tvar})' - - def __eq__(self, other): - if isinstance(other, TVar): - return self.tvar == other.tvar - else: - return False - - -class DVar: - """ - Dimension variable - """ - def __init__(self, c): - """ - :param c: character or number - """ - self.c = c - - def __repr__(self): - return f'DV({self.c})' - - def __eq__(self, other): - if isinstance(other, DVar): - return self.c == other.c - else: - return False - - -class BVar: - """ - Boolean variable - """ - def __init__(self, c): - """ - :param c: character or number - """ - self.c = c - - def __repr__(self): - return f'BV({self.c})' - - def __eq__(self, other): - if isinstance(other, BVar): - return self.c == other.c - else: - return False - - -def is_algebraic_expression(constraint): - if isinstance(constraint, BinConstraintD): - return constraint.op in [op_add, op_sub, op_div, op_mul, op_mod] - else: - return isinstance(constraint, Prod) - - -def is_bool_expr(constraint): - if isinstance(constraint, BinConstraintD): - return constraint.op in [op_gt, op_lt, op_neq, op_eq] - else: - return isinstance(constraint, BVar) or isinstance(constraint, Conj) or isinstance(constraint, Disj) - -def is_dim(d): - return isinstance(d, DVar) or isinstance(d, int) or d == Dyn diff --git a/pippy/fx/experimental/migrate_gradual_types/constraint_generator.py b/pippy/fx/experimental/migrate_gradual_types/constraint_generator.py deleted file mode 100644 index b47f0160f..000000000 --- a/pippy/fx/experimental/migrate_gradual_types/constraint_generator.py +++ /dev/null @@ -1,1282 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import torch -import operator -import warnings -from typing import Callable, Dict, Iterable - -from pippy.fx._symbolic_trace import _assert_is_none -from pippy.fx.experimental.migrate_gradual_types.constraint import ApplyBroadcasting, CalcProduct, \ - Disj, TGreatestUpperBound, CalcMaxPool, CalcConv, Conj, BinConstraintT, CanReshape, BinConstraintD, GetItem, T, F, \ - TVar, DVar, GetItemTensor, IndexSelect, Transpose, DGreatestUpperBound -from pippy.fx.experimental.migrate_gradual_types.operation import \ - op_eq, op_matching, op_consistency, op_leq, op_precision, op_gt, op_div, op_sub, op_neq, op_lt, op_add, op_mul -from pippy.fx.node import Target, Node -from pippy.fx.experimental.migrate_gradual_types.util import gen_tensor_dims, gen_nat_constraints, gen_dvar, gen_tvar, \ - gen_bvar - -from pippy.fx.tensor_type import Dyn, TensorType -from torch.nn.modules.conv import Conv2d -from torch.nn.modules.batchnorm import BatchNorm2d - -_INFERENCE_RULES: Dict[Target, Callable] = {} - -MAX_TENSOR_RANK = 4 - -def register_inference_rule(call_target): - def register(fn): - if call_target in _INFERENCE_RULES: - raise RuntimeError(f'Inference rule already registered for {call_target}!') - _INFERENCE_RULES[call_target] = fn - return fn - return register - - -def generate_flatten_constraints(start_dim, end_dim, input, flattened, n, counter): - d, counter = gen_tensor_dims(n, counter) - c1 = BinConstraintT(input, TensorType(d), op_eq) - start_dim = n if start_dim == -1 else abs(start_dim) - end_dim = n + end_dim + 1 if end_dim < 0 else end_dim + 1 - c2 = CalcProduct(start_dim, end_dim, flattened, d) - nat_constraints = gen_nat_constraints(d) - return Conj([c1, c2, *nat_constraints]), counter - - -@register_inference_rule(getattr) -def get_attr_inference_rule(n: Node, symbols, constraints, counter): - """ - If the attribute is "device" then the tensor shape is preserved - """ - assert isinstance(n.args[0], Node) - assert isinstance(n.args[1], str) - output, counter = gen_tvar(counter) - symbols[n] = output - - input = symbols[n.args[0]] - attr = n.args[1] - - if attr == 'device': - return [BinConstraintT(input, output, op_eq)], counter - else: - raise NotImplementedError('Not yet implemented') - -@register_inference_rule(torch.bmm) -def bmm_inference_rule(n: Node, symbols, constraints, counter): - """ - Constraints that match the input to a size 3 tensor - and switch the dimensions according to the rules - of batch multiplication - """ - assert isinstance(n.args[0], Node) - assert isinstance(n.args[1], Node) - - bmm_output, counter = gen_tvar(counter) - symbols[n] = bmm_output - - bmm_input1 = symbols[n.args[0]] - bmm_input2 = symbols[n.args[1]] - - dims_input1, counter = gen_tensor_dims(3, counter) - dims_input2, counter = gen_tensor_dims(3, counter) - - inputs_dyn = Conj([BinConstraintT(bmm_input1, Dyn, op_eq), - BinConstraintT(bmm_input2, Dyn, op_eq), - BinConstraintT(bmm_output, Dyn, op_eq)]) - - input1_dyn = Conj([BinConstraintT(bmm_input1, Dyn, op_eq), - BinConstraintT(bmm_input2, TensorType(dims_input2), op_eq), - BinConstraintT(bmm_output, TensorType([dims_input2[0], Dyn, dims_input2[2]]), op_eq)]) - - input2_dyn = Conj([BinConstraintT(bmm_input2, Dyn, op_eq), - BinConstraintT(bmm_input1, TensorType(dims_input1), op_eq), - BinConstraintT(bmm_output, TensorType([dims_input1[0], dims_input1[1], Dyn]), op_eq)]) - - consistency_constraints = [BinConstraintD(dims_input1[0], dims_input2[0], op_consistency)] - - batch_size, counter = gen_dvar(counter) - - inputs_are_tensors = Conj([BinConstraintT(bmm_input1, TensorType(dims_input1), op_eq), - BinConstraintT(bmm_input2, TensorType(dims_input2), op_eq), - BinConstraintT(bmm_output, TensorType([batch_size, dims_input1[1], dims_input2[2]]), op_eq), - *consistency_constraints, DGreatestUpperBound(batch_size, dims_input1[0], dims_input2[0])]) - - return [Disj([inputs_dyn, input1_dyn, input2_dyn, inputs_are_tensors])], counter - - -@register_inference_rule("index_select") -def index_select_inference_rule(n: Node, symbols, constraints, counter): - """ - We constrain the second argument to a vector or Dyn. - The output replaces the input with the shape of the vector - at the position given by the index (first argument) - """ - # print(n.args) - assert isinstance(n.args[0], Node) - assert isinstance(n.args[1], int) - assert isinstance(n.args[2], Node) - - - - index_select, counter = gen_tvar(counter) - symbols[n] = index_select - - dims, counter = gen_tensor_dims(1, counter) - - # equality constraint - is_size_1 = BinConstraintT(symbols[n.args[2]], TensorType(dims), op_eq) - is_dyn = BinConstraintT(symbols[n.args[2]], Dyn, op_eq) - - c2 = Conj([is_size_1, Disj([IndexSelect(i + 1, symbols[n.args[0]], dims[0], n.args[1], index_select) - for i in range(MAX_TENSOR_RANK)])]) - c3 = Conj([is_dyn, Disj([IndexSelect(i + 1, symbols[n.args[0]], Dyn, n.args[1], index_select) - for i in range(MAX_TENSOR_RANK)])]) - - return [Disj([c2, c3])], counter - - -@register_inference_rule("expand") -def expand_inference_rule(n: Node, symbols, constraints, counter): - """ - We generate the exact constraints as we do for tensor additions but we constraint - the rank of this expression to be equal to len(n.args[1:]) so that only - those cases get considered for the output - """ - assert isinstance(n.args[0], Node) - - # define the output for expand - expand, counter = gen_tvar(counter) - symbols[n] = expand - - # since we do not have two nodes here, we will construct an argument variable - e1 = symbols[n.args[0]] - e2, counter = gen_tvar(counter) - - e2_nat_constraints = [] - for arg in n.args[1:]: - assert isinstance(arg, Node) or isinstance(arg, int) - if isinstance(arg, Node): - assert isinstance(symbols[arg], DVar) - e2_nat_constraints.append(BinConstraintD(0, symbols[arg], op_leq)) - - e2_constraint = BinConstraintT(e2, TensorType([arg if isinstance(arg, int) else symbols[arg] for arg in n.args[1:]]), op_eq) - - constraints, counter = gen_broadcasting_constraints(e1, e2, symbols, counter, expand) - - # constraint the output size - dims, counter = gen_tensor_dims(len(n.args[1:]), counter) - nat_constraints = gen_nat_constraints(dims) - c = [BinConstraintT(expand, TensorType(dims), op_eq), *nat_constraints, e2_constraint, *e2_nat_constraints] - constraints += c - - return constraints, counter - - -@register_inference_rule(torch.nn.functional.gelu) -@register_inference_rule(torch.nn.functional.dropout) -@register_inference_rule(torch.nn.functional.softmax) -@register_inference_rule("detach") -@register_inference_rule("to") -@register_inference_rule("int") -@register_inference_rule("long") -@register_inference_rule("contiguous") -@register_inference_rule(torch.ones) -@register_inference_rule(torch.zeros) -def equality_inference_rule(n: Node, symbols, constraints, counter): - """ - We generate the constraint: input = output - """ - output, counter = gen_tvar(counter) - symbols[n] = output - - if isinstance(n.args[0], Node): - input = symbols[n.args[0]] - if isinstance(input, TVar): - return [BinConstraintT(input, output, op_eq)], counter - - # then we have dimension variables - else: - for arg in n.args: - assert isinstance(symbols[arg], DVar) - my_size = [symbols[arg] for arg in n.args] - return [BinConstraintT(output, TensorType(my_size), op_eq)], counter - - elif isinstance(n.args[0], tuple): - # then the tuple is the size - assert len(n.args[0]) <= 4 - my_size = [symbols[arg] for arg in n.args[0]] - return [BinConstraintT(output, TensorType(my_size), op_eq)], counter - else: - raise NotImplementedError('Method not yet implemented') - - -@register_inference_rule("transpose") -def transpose_inference_rule(n: Node, symbols, constraints, counter): - """ - Can be considered as a sequence of two index selects, so we generate constraints accordingly - """ - assert isinstance(n.args[0], Node) - assert isinstance(n.args[1], int) - assert isinstance(n.args[2], int) - - output, counter = gen_tvar(counter) - symbols[n] = output - - from_arg = symbols[n.args[0]] - assert isinstance(from_arg, TVar) - - # input and output are dyn - is_dyn = Conj([BinConstraintT(from_arg, Dyn, op_eq), BinConstraintT(output, Dyn, op_eq)]) - - # or input is a tensor and we actually do the replacement - c3 = Disj([Transpose(i + 1, from_arg, n.args[1], n.args[2], output) for i in range(MAX_TENSOR_RANK)]) - - return [Disj([is_dyn, c3])], counter - - -@register_inference_rule("type_as") -def type_inference_rule(n: Node, symbols, constraints, counter): - """ - We generate the constraint: input = output - """ - assert isinstance(n.args[0], Node) - assert isinstance(n.args[1], Node) - - output, counter = gen_tvar(counter) - symbols[n] = output - - from_arg = symbols[n.args[0]] - to_arg = symbols[n.args[1]] - - assert isinstance(from_arg, TVar) - assert isinstance(to_arg, TVar) - - return [BinConstraintT(from_arg, to_arg, op_consistency), - BinConstraintT(output, to_arg, op_eq)], counter - -@register_inference_rule("masked_fill_") -def masked_fill_inference_rule(n: Node, symbols, constraints, counter): - """ - Similar to addition. For now we implemenent the constraints when - the argument is a boolean tensor. There is also a case for when - it is a condition. We will leave this out for now. - """ - - assert isinstance(n.args[0], Node) - assert isinstance(n.args[1], Node) - - # We will retrieve the type variables from the symbol table - # and confirm they are tensor variables - - e1 = symbols[n.args[0]] - e2 = symbols[n.args[1]] - - if isinstance(e1, TVar) and isinstance(e2, TVar): - masked_fill_tensor, counter = gen_tvar(counter) - symbols[n] = masked_fill_tensor - return gen_broadcasting_constraints(e1, e2, symbols, counter, masked_fill_tensor) - else: - raise NotImplementedError('Not yet implemented') - - -@register_inference_rule(torch.nn.functional.embedding) -def embedding_inference_rule_functional(n: Node, symbols, constraints, counter): - assert isinstance(n.args[0], Node) - - embedding_dim_weights = symbols[n.args[1]] - - # will treat this as a static shape. So we will not use matching. - weight_dims, counter = gen_tensor_dims(2, counter) - equality_constraint = BinConstraintT(embedding_dim_weights, TensorType(weight_dims), op_eq) - embedding_dim = weight_dims[1] - constraints, counter = gen_embedding_rules(n, symbols, embedding_dim, counter) - return [equality_constraint] + constraints, counter - - -@register_inference_rule(torch.nn.modules.sparse.Embedding) -def embedding_inference_rule(n: Node, module_instance, symbols, constraints, counter): - """ - The output shape differs from the input shape in the last dimension - """ - assert isinstance(n.args[0], Node) - return gen_embedding_rules(n, symbols, module_instance.embedding_dim, counter) - - -def gen_embedding_rules(n: Node, symbols, embedding_dim, counter): - - embedding_output, counter = gen_tvar(counter) - symbols[n] = embedding_output - embedding_input = symbols[n.args[0]] - - input_dyn = BinConstraintT(embedding_input, Dyn, op_eq) - output_dyn = BinConstraintT(embedding_output, Dyn, op_eq) - - c1 = Conj([input_dyn, output_dyn]) - c2 = [] - - for i in range(1, MAX_TENSOR_RANK): - new_dims, counter = gen_tensor_dims(i, counter) - nat_constraints = gen_nat_constraints(new_dims) - - # we consider all tensor sizes and append embedding_dim to the end of the output dimension in all cases - c_tensor_i = Conj([BinConstraintT(embedding_input, TensorType(new_dims), op_eq), - BinConstraintT(embedding_output, TensorType(new_dims + [embedding_dim]), op_eq)] + - nat_constraints) - c2.append(c_tensor_i) - - return [Disj([c1, Disj(c2)])], counter - - -@register_inference_rule(torch.tensor) -def tensor_inference_rule(n: Node, symbols, constraints, counter): - """ - If the tensor is a scalar, we will skip it since we - do not support scalars yet. We will add support in the future - if it's needed. For our examples so far, scalars are not needed. - """ - return [], counter - - -@register_inference_rule("reshape") -@register_inference_rule("view") -def view_inference_rule(n: Node, symbols, constraints, counter): - """ - Similar to reshape but with an extra condition on the strides - """ - assert isinstance(n.args[0], Node) - - # generate the new variable - my_view, counter = gen_tvar(counter) - symbols[n] = my_view - - - src_var = symbols[n.args[0]] - t2 = [symbols[elem] if isinstance(elem, Node) else elem for elem in n.args[1:]] # target shape - t2_type = [] - num_constraints = [] - - for t in t2: - if t == -1: - var, counter = gen_dvar(counter) - t2_type.append(var) - num_constraints.append(BinConstraintD(var, Dyn, op_neq)) - - else: - num_constraints.append(BinConstraintD(t, Dyn, op_neq)) - t2_type.append(t) - - t2_type = TensorType(t2_type) # type: ignore[assignment] - - c1 = BinConstraintT(my_view, t2_type, op_eq) - c2 = CanReshape(src_var, t2_type) - - # TODO: add the extra check mentioned here: - # https://pytorch.org/docs/stable/generated/torch.Tensor.view.html#torch.Tensor.view - - return [c1, c2] + num_constraints, counter # type: ignore[operator] - - -@register_inference_rule("size") -def size_inference_rule(n: Node, symbols, constraints, counter): - """ - The constraint is just lhs = rhs. - Ex: size = input_ids.size() - """ - - - if len(n.args) == 1: - # generate the new variable - size, counter = gen_tvar(counter) - symbols[n] = size - input = symbols[n.args[0]] - c = BinConstraintT(input, size, op_eq) - return [c], counter - - elif len(n.args) == 2: - # TODO: review this rule; should input = dyn; output = dyn be included here? - if isinstance(n.args[1], int): - # generate the new variable - size_index, counter = gen_dvar(counter) - symbols[n] = size_index - input = symbols[n.args[0]] - c2 = [GetItem(i + 1, n.args[1], size_index, input) for i in range(MAX_TENSOR_RANK)] - c3 = BinConstraintD(0, size_index, op_leq) - - input_dyn = BinConstraintT(input, Dyn, op_eq) - output_dyn = BinConstraintD(size_index, Dyn, op_eq) - c1 = Conj([input_dyn, output_dyn]) - - return [Disj([c1, Conj([Disj(c2), c3])])], counter - - else: - raise NotImplementedError - - else: - raise NotImplementedError - - -def range_check(i, n): - """ - Checks if an index i is within range of a size n list - Args: - i: index - n: list size - - Returns: Boolean - """ - if i >= 0: - return T() if i < n else F() - else: - return T() if i >= n else F() - - -@register_inference_rule(torch.cumsum) -def cumsum_inference_rule(n: Node, symbols, constraints, counter): - """ - Input and output shapes should be equal - We should verify that the index is valid - """ - assert isinstance(n.args[0], Node) - arg_1 = n.args[1] if len(n.args) > 1 else n.kwargs["dim"] - assert isinstance(arg_1, int) - - output, counter = gen_tvar(counter) - symbols[n] = output - input = symbols[n.args[0]] - - input_dyn = BinConstraintT(input, Dyn, op_eq) - output_dyn = BinConstraintT(output, Dyn, op_eq) - c1 = Conj([input_dyn, output_dyn]) - c2 = [] - for i in range(1, MAX_TENSOR_RANK + 1): - new_dims, counter = gen_tensor_dims(i, counter) - - nat_constraints = gen_nat_constraints(new_dims) - - c_tensor_i = Conj([BinConstraintT(input, TensorType(new_dims), op_eq), - BinConstraintT(output, TensorType(new_dims), op_eq)] + - [range_check(arg_1, i)] + nat_constraints) - - c2.append(c_tensor_i) - dyn_or_tensor = Disj([c1, Disj(c2)]) - return [dyn_or_tensor], counter - - -@register_inference_rule(_assert_is_none) -def assert_inference_rule(n: Node, symbols, constraints, counter): - assert len(n.users) == 0 - return [], counter - - -@register_inference_rule(operator.getitem) -def getitem_inference_rule(n: Node, symbols, constraints, counter): - assert isinstance(n.args[0], Node) - - # dimension output case - if isinstance(n.args[1], int): - # create and store the new dimension variable - get_item_output, counter = gen_dvar(counter) - symbols[n] = get_item_output - - # retreive arg variables - get_item_arg = symbols[n.args[0]] - assert isinstance(get_item_arg, TVar) - - - # if the input is dynamic, we accept any index and return - # a dynamic dimension as output - input_dyn = BinConstraintT(get_item_arg, Dyn, op_eq) - output_dyn = BinConstraintD(get_item_output, Dyn, op_eq) - c1 = Conj([input_dyn, output_dyn]) - - # if the input is a tensor, - # generate a getItem constraint which will be expanded based on the - # tensor dimension. - - c2 = [GetItem(i + 1, n.args[1], get_item_output, get_item_arg) for i in range(MAX_TENSOR_RANK)] - - - # since the output is a dimension, we make sure it's a natural number - # added as a conjunction to the disjuction of c2 - c3 = BinConstraintD(0, get_item_output, op_leq) - return [Disj([c1, Conj([Disj(c2), c3])])], counter - - # tensor output case - elif isinstance(n.args[1], tuple): - # create and store the new tensor variable - get_item_output, counter = gen_tvar(counter) - symbols[n] = get_item_output - - # retreive arg variables - if n.args[0] in symbols: - get_item_arg = symbols[n.args[0]] - assert isinstance(get_item_arg, TVar) - - input_dyn = BinConstraintT(get_item_arg, Dyn, op_eq) - output_dyn = BinConstraintT(get_item_output, Dyn, op_eq) # type: ignore[assignment] - c1 = Conj([input_dyn, output_dyn]) - - c2 = [GetItemTensor(i + 1, n.args[1], get_item_output, get_item_arg) # type: ignore[misc] - for i in range(MAX_TENSOR_RANK)] - else: - # TODO: we should figure out why there is a key-error here. - return [], counter - - return [Disj([c1, *c2])], counter - - else: - raise RuntimeError('Method not yet implemented') - - -@register_inference_rule(operator.gt) -def gt_inference_rule(n: Node, symbols, constraints, counter): - assert isinstance(n.args[0], Node) or isinstance(n.args[0], int) - assert isinstance(n.args[1], Node) or isinstance(n.args[1], int) - - # We make sure this node will not be used again. We do not - # generate a constraint about that node. Only about the operands. - - e1 = symbols[n.args[0]] if isinstance(n.args[0], Node) else n.args[0] - e2 = symbols[n.args[1]] if isinstance(n.args[1], Node) else n.args[1] - - if isinstance(n.args[0], Node) and isinstance(n.args[1], Node): - if isinstance(e1, TVar) and isinstance(e2, TVar): - gt_tensor, counter = gen_tvar(counter) - symbols[n] = gt_tensor - return gen_broadcasting_constraints(e1, e2, symbols, counter, gt_tensor) - - elif isinstance(e1, DVar) and isinstance(e2, DVar): - # This is meant to be used for flow analysis only - gt_constraint = BinConstraintD(e1, e2, op_gt) - - my_gt, counter = gen_bvar(counter) - equality_constraint = BinConstraintD(my_gt, gt_constraint, op_eq) - return [equality_constraint], counter - - else: - raise RuntimeError('Sort Mismatch') - - elif isinstance(n.args[0], Node) and not isinstance(n.args[1], Node): - if isinstance(e1, DVar): - # This is meant to be used for flow analysis only - gt_constraint = BinConstraintD(e1, e2, op_gt) - - my_gt, counter = gen_bvar(counter) - equality_constraint = BinConstraintD(my_gt, gt_constraint, op_eq) - return [equality_constraint], counter - - elif isinstance(e1, TVar) and isinstance(e2, int): - # then we made the wrong assumption about the argument being a tensor - # so we should fix the assumption - warnings.warn(f'Made the wrong assumption for node {n}. Correctness not guaranteed.') - - new_e1, counter = gen_dvar(counter) - symbols[n.args[0]] = new_e1 - symbols[n.args[0]] - - gt_constraint = BinConstraintD(new_e1, e2, op_gt) - - my_gt, counter = gen_bvar(counter) - equality_constraint = BinConstraintD(my_gt, gt_constraint, op_eq) - return [equality_constraint], counter - - else: - raise NotImplementedError('Method not yet implemented') - - else: - raise NotImplementedError('Method not yet implemented') - - -@register_inference_rule(operator.eq) -def eq_inference_rule(n: Node, symbols, constraints, counter): - assert isinstance(n.args[0], Node) or isinstance(n.args[0], int) - assert isinstance(n.args[1], Node) or isinstance(n.args[1], int) - - e1 = symbols[n.args[0]] if isinstance(n.args[0], Node) else n.args[0] - e2 = symbols[n.args[1]] if isinstance(n.args[1], Node) else n.args[1] - - if isinstance(n.args[0], Node) and isinstance(n.args[1], Node): - if isinstance(e1, TVar) and isinstance(e2, TVar): - eq_tensor, counter = gen_tvar(counter) - symbols[n] = eq_tensor - return gen_broadcasting_constraints(e1, e2, symbols, counter, eq_tensor) - - elif isinstance(e1, DVar) and isinstance(e2, DVar): - # This is meant to be used for flow analysis only - eq_constraint = BinConstraintD(e1, e2, op_eq) - - my_eq, counter = gen_bvar(counter) - equality_constraint = BinConstraintD(my_eq, eq_constraint, op_eq) - return [equality_constraint], counter - - else: - raise RuntimeError('Sort Mismatch') - - elif isinstance(n.args[0], Node) and not isinstance(n.args[1], Node): - if isinstance(e1, DVar): - # This is meant to be used for flow analysis only - eq_constraint = BinConstraintD(e1, e2, op_eq) - - my_eq, counter = gen_bvar(counter) - equality_constraint = BinConstraintD(my_eq, eq_constraint, op_eq) - return [equality_constraint], counter - else: - raise NotImplementedError('Method not yet implemented') - else: - raise NotImplementedError('Method not yet implemented') - -@register_inference_rule(operator.ne) -def neq_inference_rule(n: Node, symbols, constraints, counter): - """ - Translates to inconsistent in gradual types. - To prove inequality, we should prove that - tensors are either different sizes or - disagree on at least one dimension - - This is a WIP (works when the condition - is false. We are working on making this operation work - when the condition is true as well) - """ - assert isinstance(n.args[0], Node) - assert isinstance(n.args[1], tuple) - - # implementing for size 3 and 4 - if len(n.args[1]) == 3: - - assert isinstance(n.args[1][0], Node) or isinstance(n.args[1][0], int) - assert isinstance(n.args[1][1], Node) or isinstance(n.args[1][1], int) - assert isinstance(n.args[1][2], Node) or isinstance(n.args[1][2], int) - - lhs = symbols[n.args[0]] - - b, counter = gen_tensor_dims(4, counter) - input_is_size3 = BinConstraintT(lhs, TensorType([b[0], b[1], b[2]]), op_eq) - - d1 = n.args[1][0] if isinstance(n.args[1][0], int) else symbols[n.args[1][0]] - d2 = n.args[1][1] if isinstance(n.args[1][1], int) else symbols[n.args[1][1]] - d3 = n.args[1][2] if isinstance(n.args[1][2], int) else symbols[n.args[1][2]] - - # dimensions not equal - my_ne, counter = gen_bvar(counter) - neq_1 = BinConstraintD(d1, b[0], op_neq) - neq_2 = BinConstraintD(d2, b[1], op_neq) - neq_3 = BinConstraintD(d3, b[2], op_neq) - - # dimensions inconsistent - dims_inconsistent1 = Conj([BinConstraintD(d1, Dyn, op_neq), BinConstraintD(b[0], Dyn, op_neq), neq_1]) - dims_inconsistent2 = Conj([BinConstraintD(d2, Dyn, op_neq), BinConstraintD(b[1], Dyn, op_neq), neq_2]) - dims_inconsistent3 = Conj([BinConstraintD(d3, Dyn, op_neq), BinConstraintD(b[2], Dyn, op_neq), neq_3]) - - dims_inconsistent = Disj([dims_inconsistent1, dims_inconsistent2, dims_inconsistent3]) - - # we are covering size 3 and 4 only for now - ne_constraint = Conj([input_is_size3, dims_inconsistent]) - - my_ne, counter = gen_bvar(counter) - equality_constraint = BinConstraintD(my_ne, ne_constraint, op_eq) - - elif len(n.args[1]) == 4: - - assert isinstance(n.args[1][0], Node) or isinstance(n.args[1][0], int) - assert isinstance(n.args[1][1], Node) or isinstance(n.args[1][1], int) - assert isinstance(n.args[1][2], Node) or isinstance(n.args[1][2], int) - assert isinstance(n.args[1][3], Node) or isinstance(n.args[1][3], int) - - lhs = symbols[n.args[0]] - - b1, counter = gen_dvar(counter) - b2, counter = gen_dvar(counter) - b3, counter = gen_dvar(counter) - b4, counter = gen_dvar(counter) - - input_is_size4 = BinConstraintT(lhs, TensorType([b1, b2, b3, b4]), op_eq) - - d1 = n.args[1][0] if isinstance(n.args[1][0], int) else symbols[n.args[1][0]] - d2 = n.args[1][1] if isinstance(n.args[1][1], int) else symbols[n.args[1][1]] - d3 = n.args[1][2] if isinstance(n.args[1][2], int) else symbols[n.args[1][2]] - d4 = n.args[1][3] if isinstance(n.args[1][3], int) else symbols[n.args[1][3]] - - # dimensions not equal - my_ne, counter = gen_bvar(counter) - neq_1 = BinConstraintD(d1, b1, op_neq) - neq_2 = BinConstraintD(d2, b2, op_neq) - neq_3 = BinConstraintD(d3, b3, op_neq) - neq_4 = BinConstraintD(d4, b4, op_neq) - - # dimensions to inconsistent - dims_inconsistent1 = Conj([BinConstraintD(d1, Dyn, op_neq), BinConstraintD(b1, Dyn, op_neq), neq_1]) - dims_inconsistent2 = Conj([BinConstraintD(d2, Dyn, op_neq), BinConstraintD(b2, Dyn, op_neq), neq_2]) - dims_inconsistent3 = Conj([BinConstraintD(d3, Dyn, op_neq), BinConstraintD(b3, Dyn, op_neq), neq_3]) - dims_inconsistent4 = Conj([BinConstraintD(d4, Dyn, op_neq), BinConstraintD(b3, Dyn, op_neq), neq_4]) - - dims_inconsistent = Disj([dims_inconsistent1, dims_inconsistent2, dims_inconsistent3, dims_inconsistent4]) - - ne_constraint = Conj([input_is_size4, dims_inconsistent]) - - my_ne, counter = gen_bvar(counter) - - equality_constraint = BinConstraintD(my_ne, ne_constraint, op_eq) - - else: - raise NotImplementedError('Method not yet implemented') - - return [equality_constraint], counter - - -@register_inference_rule(operator.lt) -def lt_inference_rule(n: Node, symbols, constraints, counter): - assert isinstance(n.args[0], Node) or isinstance(n.args[0], int) - assert isinstance(n.args[1], Node) or isinstance(n.args[1], int) - - # We make sure this node will not be used again. We do not - # generate a constraint about that node. Only about the operands. - - e1 = symbols[n.args[0]] if isinstance(n.args[0], Node) else n.args[0] - e2 = symbols[n.args[1]] if isinstance(n.args[1], Node) else n.args[1] - - if isinstance(n.args[0], Node) and isinstance(n.args[1], Node): - if isinstance(e1, TVar) and isinstance(e2, TVar): - lt_tensor, counter = gen_tvar(counter) - symbols[n] = lt_tensor - return gen_broadcasting_constraints(e1, e2, symbols, counter, lt_tensor) - - elif isinstance(e1, DVar) and isinstance(e2, DVar): - # This is meant to be used for flow analysis only - lt_constraint = BinConstraintD(e1, e2, op_lt) - - my_lt, counter = gen_bvar(counter) - equality_constraint = BinConstraintD(my_lt, lt_constraint, op_eq) - return [equality_constraint], counter - - else: - raise RuntimeError('Sort Mismatch') - - elif isinstance(n.args[0], Node) and not isinstance(n.args[1], Node): - if isinstance(e1, DVar): - # This is meant to be used for flow analysis only - lt_constraint = BinConstraintD(e1, e2, op_lt) - - my_lt, counter = gen_bvar(counter) - equality_constraint = BinConstraintD(my_lt, lt_constraint, op_eq) - return [equality_constraint], counter - else: - raise NotImplementedError('Method not yet implemented') - - else: - raise NotImplementedError('Method not yet implemented') - - -@register_inference_rule(torch.full) -def full_inference_rule(n: Node, symbols, constraints, counter): - full, counter = gen_tvar(counter) - symbols[n] = full - res = [] - - assert isinstance(n.args[0], Iterable) - for arg in n.args[0]: - dim = arg if isinstance(arg, int) else symbols[arg] - res.append(dim) - c = BinConstraintT(full, TensorType(list(res)), op_eq) # type: ignore[arg-type] - return [c], counter - - -# TODO normalize index -@register_inference_rule(torch.arange) -def arange_inference_rule(n: Node, symbols, constraints, counter): - start = 0 - step = 1 - - if len(n.args) == 1: - end = symbols[n.args[0]] - else: - raise NotImplementedError('Not yet implemented') - - # int((end - start) / step) - d1, counter = gen_dvar(counter) - size_constraint = BinConstraintD(d1, BinConstraintD(BinConstraintD(end, start, op_sub), step, op_div), op_eq) - arange, counter = gen_tvar(counter) - symbols[n] = arange - - # either the a parameter is a number or it is Dyn - c1 = Disj([BinConstraintD(end, Dyn, op_eq), - BinConstraintD(start, Dyn, op_eq), - BinConstraintD(step, Dyn, op_eq)]) - c2 = BinConstraintD(d1, Dyn, op_eq) - both_dyn = Conj([c1, c2]) - - c11 = Conj([BinConstraintD(end, Dyn, op_neq), - BinConstraintD(start, Dyn, op_neq), - BinConstraintD(step, Dyn, op_neq)]) - c22 = BinConstraintD(d1, Dyn, op_neq) - both_numbers = Conj([c11, c22, size_constraint]) - - return [BinConstraintT(arange, TensorType([d1]), op_eq), Disj([both_dyn, both_numbers])], counter - -def gen_broadcasting_constraints(e1, e2, symbols, counter, output_var): - # additional vars that don't correspond to expressions - e11, counter = gen_tvar(counter) - e22, counter = gen_tvar(counter) - - # generate constraints - c1 = TGreatestUpperBound(output_var, e11, e22) - c2 = ApplyBroadcasting(e11, e22, e1, e2) - c3 = BinConstraintT(e11, e22, op_consistency) - return [c1, c2, c3], counter - - -@register_inference_rule(operator.mul) -@register_inference_rule(torch.ne) -@register_inference_rule("ne") -@register_inference_rule(torch.add) -@register_inference_rule(operator.add) -def broadcasting_inference_rule(n: Node, symbols, constraints, counter): - - op_code = None - if n.target == operator.add or n.target == torch.add: - op_code = op_add - elif n.target == operator.mul: - op_code = op_mul - - if isinstance(n.args[0], Node) and isinstance(n.args[1], Node): - if isinstance(symbols[n.args[0]], TVar) and isinstance(symbols[n.args[1]], TVar): - my_output, counter = gen_tvar(counter) - symbols[n] = my_output - e1 = symbols[n.args[0]] - e2 = symbols[n.args[1]] - - return gen_broadcasting_constraints(e1, e2, symbols, counter, my_output) - else: - raise NotImplementedError('Method not yet implemented') - - elif isinstance(n.args[0], Node) and (isinstance(n.args[1], int) or isinstance(n.args[1], float)): - if isinstance(symbols[n.args[0]], TVar): - my_output, counter = gen_tvar(counter) - symbols[n] = my_output - e1 = symbols[n.args[0]] - return [BinConstraintT(my_output, e1, op_eq)], counter - elif isinstance(symbols[n.args[0]], DVar): - my_output, counter = gen_dvar(counter) - symbols[n] = my_output - e1 = symbols[n.args[0]] - - # we will propagate the runtime value here since this is regular addition - c = Conj([BinConstraintD(my_output, BinConstraintD(e1, n.args[1], op_code), op_eq), - BinConstraintD(0, my_output, op_leq)]) - return [c], counter - - elif isinstance(n.args[1], Node) and (isinstance(n.args[0], int) or isinstance(n.args[1], float)): - if isinstance(symbols[n.args[1]], TVar): - my_output, counter = gen_tvar(counter) - symbols[n] = my_output - e2 = symbols[n.args[1]] - return [BinConstraintT(my_output, e2, op_eq)], counter - elif isinstance(symbols[n.args[1]], DVar): - my_output, counter = gen_dvar(counter) - symbols[n] = my_output - e2 = symbols[n.args[1]] - - # we will propagate the runtime value here since this is regular addition - c = Conj([BinConstraintD(my_output, BinConstraintD(e2, n.args[0], op_code), op_eq), - BinConstraintD(0, my_output, op_leq)]) - return [c], counter - - else: - raise NotImplementedError('Method not yet implemented') - - else: - # TODO generate add constraints for scalar addition - raise NotImplementedError('Addition not yet implemented') - - -@register_inference_rule(torch.flatten) -def flatten_inference_rule(n: Node, symbols, constraints, counter): - assert isinstance(n.args[0], Node) - - # generate the new variable - flattened, counter = gen_tvar(counter) - symbols[n] = flattened - - input = symbols[n.args[0]] - - # set the default start and end dims - start_dim = 1 - end_dim = -1 - - if len(n.args) > 1: - assert isinstance(n.args[1], int) - start_dim = n.args[1] - - if len(n.args) > 2: - assert isinstance(n.args[2], int) - end_dim = n.args[2] - - c1 = BinConstraintT(input, Dyn, op_eq) - c2 = BinConstraintT(flattened, Dyn, op_eq) - both_dyn = Conj([c1, c2]) - - const = [] - for i in range(1, MAX_TENSOR_RANK + 1): - c, counter = generate_flatten_constraints(start_dim, end_dim, input, flattened, i, counter) - const.append(c) - - return [Disj([both_dyn, *const])], counter - - -@register_inference_rule(torch.nn.functional.layer_norm) -def layer_norm_functional(n: Node, symbols, constraints, counter): - """ - We generate the constraint: input = output - """ - assert isinstance(n.args[0], Node) - return gen_layer_norm_constraints(n, n.args[1], symbols, counter) - - -@register_inference_rule(torch.nn.LayerNorm) -def layer_norm_inference_rule(n: Node, module_instance, symbols, constraints, counter): - """ - Input and output shapes should be equal. - Input should be consistent with the normalized_shape - """ - assert isinstance(n.args[0], Node) - return gen_layer_norm_constraints(n, module_instance.normalized_shape, symbols, counter) - - -def gen_layer_norm_constraints(n: Node, normalized_shape, symbols, counter): - output, counter = gen_tvar(counter) - symbols[n] = output - input = symbols[n.args[0]] - - input_dyn = BinConstraintT(input, Dyn, op_eq) - output_dyn = BinConstraintT(output, Dyn, op_eq) - - c1 = Conj([input_dyn, output_dyn]) - - c2 = [] - for i in range(1, MAX_TENSOR_RANK + 1): - new_dims_rhs, counter = gen_tensor_dims(i, counter) - nat_constraints = gen_nat_constraints(new_dims_rhs) - - c_tensor_i = Conj([BinConstraintT(input, TensorType(new_dims_rhs), op_eq), - BinConstraintT(output, TensorType(new_dims_rhs), op_eq)] + - add_layer_norm_constraints(new_dims_rhs, list(normalized_shape)) + - nat_constraints) - c2.append(c_tensor_i) - return [Disj([c1, Disj(c2)])], counter - -@register_inference_rule(torch.nn.Dropout) -@register_inference_rule(torch.nn.ReLU) -def relu_inference_rule(n: Node, module_instance, symbols, constraints, counter): - """ - Input and output shapes should be equal. - """ - assert isinstance(n.args[0], Node) - output, counter = gen_tvar(counter) - symbols[n] = output - input = symbols[n.args[0]] - assert isinstance(input, TVar) - return [BinConstraintT(input, output, op_eq)], counter - - -@register_inference_rule(torch.nn.Linear) -def linear_inference_rule(n: Node, module_instance, symbols, constraints, counter): - """ - Input and output sizes should be the same except for the last dimension - If the input is Dyn, then so should the output - """ - assert isinstance(n.args[0], Node) - return linear_constraints(n, module_instance.in_features, module_instance.out_features, symbols, counter) - - -@register_inference_rule("dim") # type: ignore[attr-defined] -def torch_dim_inference_rule(n: Node, symbols, constraints, counter): - assert isinstance(n.args[0], Node) - my_dim, counter = gen_dvar(counter) - symbols[n] = my_dim - input = symbols[n.args[0]] - - input_dyn = BinConstraintT(input, Dyn, op_eq) - output_dyn = BinConstraintD(my_dim, Dyn, op_eq) - - c1 = [] - - for i in range(1, MAX_TENSOR_RANK + 1): - new_dims_rhs_1, counter = gen_tensor_dims(i, counter) - - c_tensor_i = Conj([BinConstraintT(input, TensorType(new_dims_rhs_1), op_eq), - BinConstraintD(my_dim, i, op_eq)]) - c1.append(c_tensor_i) - - return [Disj([Conj([input_dyn, output_dyn]), Disj(c1)])], counter - - -@register_inference_rule(torch._C._nn.linear) # type: ignore[attr-defined] -def torch_linear_inference_rule(n: Node, symbols, constraints, counter): - assert isinstance(n.args[0], Node) - weight_dims, counter = gen_tensor_dims(2, counter) - equality_constraint = BinConstraintT(symbols[n.args[1]], TensorType(weight_dims), op_eq) - constraints, counter = linear_constraints(n, weight_dims[1], weight_dims[0], symbols, counter) - return [equality_constraint] + constraints, counter - - -def linear_constraints(n: Node, in_features, out_features, symbols, counter): - linear_output, counter = gen_tvar(counter) - symbols[n] = linear_output - linear_input = symbols[n.args[0]] - - input_dyn = BinConstraintT(linear_input, Dyn, op_eq) - output_dyn = BinConstraintT(linear_output, Dyn, op_eq) - - c1 = Conj([input_dyn, output_dyn]) - - c2 = [] - for i in range(1, MAX_TENSOR_RANK + 1): - new_dims_rhs_1, counter = gen_tensor_dims(i, counter) - new_dims_rhs_2, counter = gen_tensor_dims(i, counter) - - nat_constraints = gen_nat_constraints(new_dims_rhs_1 + new_dims_rhs_2) - - c_tensor_i = Conj([BinConstraintT(linear_input, TensorType(new_dims_rhs_1), op_eq), - BinConstraintT(linear_output, TensorType(new_dims_rhs_2), op_eq)] + - add_linear_constraints(new_dims_rhs_1, new_dims_rhs_2, in_features, out_features) + - nat_constraints) - c2.append(c_tensor_i) - return [Disj([c1, Disj(c2)])], counter - -def add_layer_norm_constraints(input_dim, normalized_dim): - """ - The constraints say that the type has te form: [*, 1024, 1024] - while the normalized_dim have the form [1024, 1024] - Args: - input_dim: Input shape of layer norm - normalized_dim: normalized_dim parameter of the module instance - - """ - - # in this case we return false since there's a pattern mismatch - if len(normalized_dim) > len(input_dim): - return [F()] - - else: - constraints = [] - for i, n in zip(reversed(input_dim), reversed(normalized_dim)): - constraints.append(BinConstraintD(i, n, op_consistency)) - return constraints - - -def add_linear_constraints(dims1, dims2, in_features, out_features): - assert len(dims1) == len(dims2) - constraints = [] - for i in range(len(dims1)): - if i == len(dims1) - 1: - constraints.append(BinConstraintD(dims1[i], in_features, op_consistency)) - constraints.append(BinConstraintD(dims2[i], out_features, op_eq)) - else: - constraints.append(BinConstraintD(dims1[i], dims2[i], op_eq)) - - return constraints - - -@register_inference_rule(torch.reshape) -def reshape_inference_rule(n: Node, symbols, constraints, counter): - assert isinstance(n.args[0], Node) - - # generate the new variable - my_reshape, counter = gen_tvar(counter) - symbols[n] = my_reshape - - src_var = symbols[n.args[0]] - t2 = n.args[1] - t2_type = TensorType([Dyn if elem == -1 else elem for elem in t2]) # type: ignore[union-attr] - c1 = BinConstraintT(my_reshape, t2_type, op_eq) # type: ignore[union-attr] - c2 = CanReshape(src_var, t2_type) - - return [c1, c2], counter - - -@register_inference_rule(BatchNorm2d) -def batchnorm_inference_rule(n: Node, module_instance, symbols, constraints, counter): - assert isinstance(n.args[0], Node) - - # generate the new variable - batchnorm_output, counter = gen_tvar(counter) - symbols[n] = batchnorm_output - batchnorm_input = symbols[n.args[0]] - - # dim vars - d1, counter = gen_dvar(counter) - d2, counter = gen_dvar(counter) - d3, counter = gen_dvar(counter) - d4, counter = gen_dvar(counter) - - nat_constraints = gen_nat_constraints([d1, d2, d3, d4]) - - c1 = BinConstraintT(batchnorm_input, TensorType([d1, d2, d3, d4]), op_matching) - c2 = BinConstraintT(batchnorm_input, batchnorm_output, op_eq) - return [c1, c2, *nat_constraints], counter - - -@register_inference_rule(torch.nn.AdaptiveAvgPool2d) -def adaptive_inference_rule(n: Node, module_instance, symbols, constraints, counter): - assert isinstance(n.args[0], Node) - - avg_pool, counter = gen_tvar(counter) - - symbols[n] = avg_pool - input_var = symbols[n.args[0]] - - # dim vars - d1, counter = gen_dvar(counter) - d2, counter = gen_dvar(counter) - d3, counter = gen_dvar(counter) - d4, counter = gen_dvar(counter) - nat_constraints = gen_nat_constraints([d1, d2, d3, d4]) - c1 = BinConstraintT(input_var, TensorType([d1, d2, d3, d4]), op_matching) - c2 = BinConstraintT(avg_pool, TensorType([d1, d2, module_instance.output_size[0], module_instance.output_size[1]]), op_eq) - - return [c1, c2, *nat_constraints], counter - - -@register_inference_rule(Conv2d) -def conv2d_inference_rule(n: Node, module_instance, symbols, constraints, counter): - assert isinstance(n.args[0], Node) - - my_conv, counter = gen_tvar(counter) - symbols[n] = my_conv - input_var = symbols[n.args[0]] - - # dim vars - [d1, d2, d3, d4], counter = gen_tensor_dims(MAX_TENSOR_RANK, counter) - - # c1 = Matching(input_var, TensorType([d1, d2, d3, d4])) - c1 = BinConstraintT(input_var, TensorType([d1, d2, d3, d4]), op_matching) - - # c2 = DConsistency(module_instance.in_channels, d2) - c2 = BinConstraintD(module_instance.in_channels, d2, op_consistency) - - c3 = CalcConv(my_conv, input_var, - module_instance.out_channels, - module_instance.kernel_size, - module_instance.padding, - module_instance.stride, - module_instance.dilation, [d1, d2, d3, d4]) - - nat_constraints = gen_nat_constraints([d1, d2, d3, d4]) - - return [c1, c2, c3, *nat_constraints], counter - - -@register_inference_rule(torch.nn.MaxPool2d) -def maxpool_inference_rule(n: Node, module_instance, symbols, constraints, counter): - assert isinstance(n.args[0], Node) - maxpool, counter = gen_tvar(counter) - symbols[n] = maxpool - input_var = symbols[n.args[0]] - - # dim vars - [d1, d2, d3, d4], counter = gen_tensor_dims(MAX_TENSOR_RANK, counter) - - c1 = BinConstraintT(input_var, TensorType([d1, d2, d3, d4]), op_matching) - - c2 = CalcMaxPool(maxpool, input_var, module_instance.kernel_size, module_instance.padding, - module_instance.stride, module_instance.dilation, [d1, d2, d3, d4]) - - nat_constraints = gen_nat_constraints([d1, d2, d3, d4]) - - return [c1, c2, *nat_constraints], counter - - -class ConstraintGenerator: - def __init__(self, traced, graph=None): - self.traced = traced # traced or tracer.root - self.traced_params = dict(self.traced.named_parameters()) - self.constraints = [] - self.symbol_dict = {} - self.graph = traced.graph if hasattr(traced, 'graph') else graph - - - def generate_constraints(self, counter=0): - """ - Iterate through every node and generate constraints - Effect: self.constraints will be populated with the final constraints - """ - graph = self.graph - - all_constraints = [] - - for n in graph.nodes: - (constraints, counter) = self.generate_constraints_node(n, counter) - all_constraints += constraints - - return Conj(all_constraints), counter - - def generate_constraints_node(self, n: Node, counter): - """ - Generate constraints the given node: - Currently supported operations: - - Reshape - - Add - - conv2d - """ - - if n.op == 'placeholder': - x, counter = gen_tvar(counter) - self.symbol_dict[n] = x - - my_type = n.type - - if n.type != Dyn and (not isinstance(n.type, TensorType)): - if n.type == torch.nn.parameter.Parameter: - # since we have a parameter, the shape must be static - assert 'example_value' in n.meta - my_type = TensorType(n.meta['example_value'].size()) - else: - my_type = Dyn - - c1 = BinConstraintT(my_type, x, op_precision) - c2 = BinConstraintT(x, MAX_TENSOR_RANK, op_leq) - return [c1, c2], counter - - elif n.op == 'call_function': - if n.target in _INFERENCE_RULES: - return _INFERENCE_RULES[n.target](n, self.symbol_dict, self.constraints, counter) - else: - raise RuntimeError(f'No inference rule registered for target {n.target}!') - - elif n.op == 'call_module': - - module_instance = self.traced.get_submodule(n.target) - if type(module_instance) in _INFERENCE_RULES: - return _INFERENCE_RULES[type(module_instance)](n, - module_instance, - self.symbol_dict, - self.constraints, counter) - else: - raise RuntimeError(f'No inference rule registered for class {type(module_instance)}!') - - elif n.op == 'call_method': - if n.target in _INFERENCE_RULES: - return _INFERENCE_RULES[n.target](n, self.symbol_dict, self.constraints, counter) - else: - raise RuntimeError(f'No inference rule registered for target {n.target}!') - - elif n.op == 'get_attr': - t = self.traced_params.get(n.target, None) - - if isinstance(t, torch.Tensor): - if len(t.shape) > 0: - res = [] - for t in t.shape: - res.append(t) - attr_type = TensorType(res) - output, counter = gen_tvar(counter) - self.symbol_dict[n] = output - return [BinConstraintT(output, attr_type, op_eq)], counter - else: - # scalar? - return [], counter - else: - return [], counter - - elif n.op == 'output': - return [], counter - - else: - raise NotImplementedError(f"Method {n.op} not yet implemented") diff --git a/pippy/fx/experimental/migrate_gradual_types/constraint_transformation.py b/pippy/fx/experimental/migrate_gradual_types/constraint_transformation.py deleted file mode 100644 index ec40a41f6..000000000 --- a/pippy/fx/experimental/migrate_gradual_types/constraint_transformation.py +++ /dev/null @@ -1,1041 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -# mypy: ignore-errors -import copy -import itertools -from pippy.fx.experimental.migrate_gradual_types.constraint_generator import BinConstraintT, MAX_TENSOR_RANK -from pippy.fx.experimental.migrate_gradual_types.constraint import T, BinConstraintD, Conj, Constraint, DVar, TVar, \ - Transpose -from pippy.fx.experimental.migrate_gradual_types.constraint import Disj, TGreatestUpperBound -from pippy.fx.experimental.migrate_gradual_types.constraint import DGreatestUpperBound -from pippy.fx.experimental.migrate_gradual_types.constraint import CalcConv, CalcMaxPool -from pippy.fx.experimental.migrate_gradual_types.constraint import CalcProduct, CanReshape -from pippy.fx.experimental.migrate_gradual_types.constraint import ApplyBroadcasting, Prod, F, GetItem, GetItemTensor, IndexSelect -from pippy.fx.experimental.migrate_gradual_types.operation import op_eq, op_precision, op_leq, op_matching -from pippy.fx.experimental.migrate_gradual_types.operation import op_consistency, op_neq -from pippy.fx.experimental.migrate_gradual_types.operation import op_mul, op_add, op_sub, op_div, op_mod -from pippy.fx.experimental.migrate_gradual_types.util import gen_tensor_dims, gen_nat_constraints, gen_dvar -from pippy.fx.tensor_type import TensorType, Dyn -from typing import Callable, Dict, List - -_TRANSFORMATION_RULES: Dict[Constraint, Callable] = {} - - -def register_transformation_rule(call_target): - def register(fn): - if call_target in _TRANSFORMATION_RULES: - raise RuntimeError(f'Transformation rule already registered for {call_target}!') - _TRANSFORMATION_RULES[call_target] = fn - return fn - return register - - -def valid_index(index, dims): - """ - Given a list of dimensions, checks if an index is valid in the list - """ - try: - dims[index] - return T() - except IndexError: - return F() - - -@register_transformation_rule(Transpose) -def transform_transpose(constraint, counter): - """ - Similar to a sequence of two index-selects - """ - dims, counter = gen_tensor_dims(constraint.tensor_size, counter) - is_valid_index1 = valid_index(constraint.index1, dims) - is_valid_index2 = valid_index(constraint.index2, dims) - new_dims = copy.deepcopy(dims) - nat_constraints = gen_nat_constraints(dims) - - if is_valid_index1 == T() and is_valid_index2 == T(): - new_dims[constraint.index1] = dims[constraint.index2] - new_dims[constraint.index2] = dims[constraint.index1] - - transformed_constraint = Conj([BinConstraintT(constraint.input_var, TensorType(dims), op_eq), - *nat_constraints, - is_valid_index1, is_valid_index2, - BinConstraintT(constraint.output, TensorType(new_dims), op_eq)]) - return transformed_constraint, counter - - -@register_transformation_rule(IndexSelect) -def transform_index_select(constraint, counter): - """ - The constraints consider the given tensor size, checks if the index is valid - and if so, generates a constraint for replacing the input dimension - with the required dimension - """ - dims, counter = gen_tensor_dims(constraint.tensor_size, counter) - is_valid_index = valid_index(constraint.index, dims) - nat_constraints = gen_nat_constraints(dims) - - # if the index is valid then replace the input dimension with the new dimension - # otherwise the dimension will not be replaced and the clause will contain False - if is_valid_index == T(): - new_dims = copy.deepcopy((dims)) - new_dims[constraint.index] = constraint.dim_replace - - transformed_constraint = Conj([BinConstraintT(constraint.input_var, TensorType(dims), op_eq), - *nat_constraints, - is_valid_index, - BinConstraintT(constraint.output, TensorType(new_dims), op_eq)]) - - # print(constraints) - return transformed_constraint, counter - - -@register_transformation_rule(GetItem) -def transform_get_item(constraint, counter): - """ - generate an equality of the form: - t = [a1, ..., an] - then generate constraints that check if the given index is valid - given this particular tensor size. - If the index is valid, generate a constraint to get the item - Note that we already handled the Dyn input case in the previous - step. - Args: - constraint: GetItem which assumes we are getting an item from a tensor (not Dyn) - counter: variable tracking - Returns: simplified constraints for GetItem - - """ - dims, counter = gen_tensor_dims(constraint.tensor_size, counter) - nat_constraints = gen_nat_constraints(dims) - - - is_valid_index = valid_index(constraint.index, dims) - - all_constraints = [BinConstraintT(constraint.input_var, TensorType(dims), op_eq), - *nat_constraints, - is_valid_index] - - # if the index is valid, we generate a constraint for getting an item - # otherwise this clause will have been UNSAT due to the wrong index - if is_valid_index == T(): - all_constraints.append(BinConstraintD(constraint.res, dims[constraint.index], op_eq)) - - return Conj(all_constraints), counter - -def valid_index_tensor(index, dims): - """ - if the slice instances exceed the length of the dimensions - then this is a type error so we return False - """ - slice_count = 0 - for s in index: - if isinstance(s, slice): - slice_count += 1 - if slice_count > len(dims): - return F() - else: - return T() - -@register_transformation_rule(GetItemTensor) -def transform_get_item_tensor(constraint, counter): - """ - When the index is a tuple, then the output will be a tensor - TODO: we have to check if this is the case for all HF models - - The cases we are covrering here are a tuple with one of: - - slice with default argument - - None - - None appends 1 to the input tensor dimensions - so each occurrence of 'None' increases the rank by 1 - - slice with default arguments does not change the rank - """ - assert isinstance(constraint.index_tuple, tuple) - - - # generate a result tensor of the expected size - dims, counter = gen_tensor_dims(constraint.tensor_size, counter) - nat_constraints = gen_nat_constraints(dims) - - # generate a place-holder list of the right rank - # where "slice" does not contribute to the rank and "None" does - none_c = constraint.index_tuple.count(None) - resulting_tensor_dims = (none_c + len(dims)) * [None] - - dim_index = 0 - for i in range(len(constraint.index_tuple)): - - # append 1 to the right location of the resulting tensor - if constraint.index_tuple[i] is None: - resulting_tensor_dims[i] = 1 - - elif constraint.index_tuple[i] == slice(None, None, None): - pass - - else: - raise NotImplementedError('Method not yet implemented') - - # append the remaining dimensions to the right location - dim_index = 0 - for i in range(len(resulting_tensor_dims)): - if resulting_tensor_dims[i] is None: - resulting_tensor_dims[i] = dims[dim_index] - dim_index += 1 - - # check if the index is valid - is_valid_index = valid_index_tensor(constraint.index_tuple, dims) - - # check if the resulting tensor is within bounds - if len(resulting_tensor_dims) > 4: - return F(), counter - - else: - constraints = [BinConstraintT(constraint.input_var, TensorType(dims), op_eq), - BinConstraintT(constraint.res, TensorType(resulting_tensor_dims), op_eq), - *nat_constraints, - is_valid_index] - return Conj(constraints), counter - - -@register_transformation_rule(BinConstraintT) -def generate_binconstraint_t(constraint, counter): - """ - Transform binary constraints for tensors - """ - - # precision constraints - if constraint.op == op_precision: - if constraint.lhs == Dyn: - return T(), counter - elif isinstance(constraint.lhs, TensorType): - is_fully_static = all([d != Dyn for d in constraint.lhs.__args__]) - if is_fully_static: - return BinConstraintT(constraint.lhs, constraint.rhs, op_eq), counter - else: - new_dims = [] - - for _ in range(len(constraint.lhs.__args__)): - dim, counter = gen_dvar(counter) - new_dims.append(dim) - - new_dim_constraints = [BinConstraintD(old_dim, new_dim, op_precision) for - new_dim, old_dim in zip(new_dims, constraint.lhs.__args__)] + \ - [BinConstraintT(constraint.rhs, TensorType(new_dims), op_eq)] + \ - [BinConstraintD(1, new_dim, op_leq) for - new_dim in new_dims] - return Conj(new_dim_constraints), counter - - # matching - elif constraint.op == op_matching: - assert isinstance(constraint.rhs, TensorType) - d1 = constraint.rhs.__args__[0] - d2 = constraint.rhs.__args__[1] - d3 = constraint.rhs.__args__[2] - d4 = constraint.rhs.__args__[3] - - conj = [BinConstraintT(constraint.lhs, Dyn, op_eq), - BinConstraintD(d1, Dyn, op_eq), - BinConstraintD(d2, Dyn, op_eq), - BinConstraintD(d3, Dyn, op_eq), - BinConstraintD(d4, Dyn, op_eq)] - return Disj([Conj(conj), - BinConstraintT(constraint.lhs, TensorType([d1, d2, d3, d4]), op_eq)]), counter - - elif constraint.op == op_consistency: - c_dyn = Disj([BinConstraintT(constraint.lhs, Dyn, op_eq), BinConstraintT(constraint.rhs, Dyn, op_eq)]) - [c_tensor_1, c_tensor_2, c_tensor_3, c_tensor_4], counter = gen_consistency_constraints(constraint, counter) - - return Disj([c_dyn, c_tensor_1, c_tensor_2, c_tensor_3, c_tensor_4]), counter - - elif constraint.op == op_leq: - assert isinstance(constraint.rhs, int) - disj = [BinConstraintT(constraint.lhs, Dyn, op_eq)] - for i in range(1, constraint.rhs + 1): - dims = [] - for j in range(1, i + 1): - dim_var, counter = gen_dvar(counter) - dims.append(dim_var) - disj.append(BinConstraintT(constraint.lhs, TensorType(dims), op_eq)) - return Disj(disj), counter - else: - return constraint, counter - - -@register_transformation_rule(BinConstraintD) -def generate_binconstraint_d(constraint, counter): - """ - Transform binary constraints for dimensions - """ - if constraint.op == op_precision: - if isinstance(constraint.lhs, int): - return BinConstraintD(constraint.lhs, constraint.rhs, op_eq), counter - elif constraint.lhs == Dyn: - return T(), counter - - elif constraint.op == op_consistency: - return Disj([BinConstraintD(constraint.lhs, constraint.rhs, op_eq), - BinConstraintD(constraint.rhs, Dyn, op_eq), BinConstraintD(constraint.lhs, Dyn, op_eq)]), counter - - else: - return constraint, counter - - -@register_transformation_rule(Conj) -def generate_conj(constraint, counter): - """ - Transform conjunctions - """ - new = [] - for c in constraint.conjucts: - new_c, counter = transform_constraint(c, counter) - new.append(new_c) - return Conj(new), counter - - -@register_transformation_rule(Disj) -def generate_disj(constraint, counter): - """ - Transform disjunctions - """ - new = [] - for c in constraint.disjuncts: - new_c, counter = transform_constraint(c, counter) - new.append(new_c) - return Disj(new), counter - - -@register_transformation_rule(TGreatestUpperBound) -def generate_gub(constraint, counter): - """ - Transform greatest upper bound for tensors. Results in equality and Greatest Upper Bound - on dimensions - """ - c1 = Conj([Disj([BinConstraintT(constraint.rhs1, Dyn, op_eq), - BinConstraintT(constraint.rhs2, Dyn, op_eq)]), BinConstraintT(constraint.res, Dyn, op_eq)]) - - [c2, c3, c4, c5], counter = gen_greatest_upper_bound(constraint, counter) - - return Disj([c1, c2, c3, c4, c5]), counter - - -@register_transformation_rule(DGreatestUpperBound) -def generate_d_gub(constraint, counter): - """ - Transform greatest upper bound for dimensions into equality constraints - """ - c1 = Conj([BinConstraintD(constraint.rhs1, Dyn, op_eq), BinConstraintD(constraint.res, constraint.rhs2, op_eq)]) - c2 = Conj([BinConstraintD(constraint.rhs2, Dyn, op_eq), BinConstraintD(constraint.res, constraint.rhs1, op_eq)]) - c3 = Conj([BinConstraintD(constraint.rhs2, constraint.rhs1, op_eq), BinConstraintD(constraint.res, constraint.rhs1, op_eq)]) - return Disj([c1, c2, c3]), counter - - -@register_transformation_rule(CalcConv) -def generate_calc_conv(constraint, counter): - d, counter = gen_tensor_dims(4, counter) - conv_result = TensorType([d[0], d[1], d[2], d[3]]) - - # the convolution result is a tensor of size 4 - c1 = BinConstraintT(constraint.conv_result, conv_result, op_eq) - - # the second dimension of the output is equal to the output channels - c2 = Conj([BinConstraintD(d[1], constraint.c_out, op_eq), BinConstraintD(d[1], Dyn, op_neq)]) - - # the input corresponds to the output in the first dimension of the convolution - c3 = BinConstraintD(constraint.matching_constraint[0], d[0], op_eq) - - c4, c5 = calc_last_two_dims(constraint, d) - - leq_constraints = Conj([BinConstraintD(0, d[0], op_leq), - BinConstraintD(0, d[1], op_leq), - BinConstraintD(0, d[2], op_leq), - BinConstraintD(0, d[3], op_leq)]) - - return Conj([c1, c2, c3, c4, c5, leq_constraints]), counter - - -@register_transformation_rule(CalcMaxPool) -def generate_calc_maxpool(constraint, counter): - """ - Transform maxpool constraints - """ - d, counter = gen_tensor_dims(4, counter) - maxpool_result = TensorType([d[0], d[1], d[2], d[3]]) - - # the maxpool result is a tensor of size 4 - c1 = BinConstraintT(constraint.maxpool_result, maxpool_result, op_eq) - - # the input corresponds to the output in the first and second dimension of maxpool - c2 = BinConstraintD(constraint.matching_constraint[1], d[1], op_eq) - c3 = BinConstraintD(constraint.matching_constraint[0], d[0], op_eq) - c4, c5 = calc_last_two_dims(constraint, d) - - leq_constraints = Conj([BinConstraintD(0, d[0], op_leq), - BinConstraintD(0, d[1], op_leq), - BinConstraintD(0, d[2], op_leq), - BinConstraintD(0, d[3], op_leq)]) - - return Conj([c1, c2, c3, c4, c5, leq_constraints]), counter - - -@register_transformation_rule(CalcProduct) -def generate_calc_product(constraint, counter): - """ - Transform flatten constraints - """ - start = constraint.start - end = constraint.end - dims = constraint.dims_to_flatten - flattened = constraint.flattened - n = len(constraint.dims_to_flatten) - - # this will be evaluated right here - boundary_check = (0 <= start and start < end and end <= n) - - c_boundary = T() if boundary_check else F() - - lhs = dims[0:start] - rhs = dims[end:] - mid = dims[start:end] - - all_possibilities = generate_all_int_dyn_dim_possibilities(mid) - - all_constraints = [] - - for p in all_possibilities: - p = list(p) - # this tells us there is a dynamic variable - contains_dyn = not(all([constraint.op == op_neq for constraint in p])) - if contains_dyn: - mid_var = [Dyn] - total_constraints = lhs + mid_var + rhs - if len(total_constraints) > 4: - all_constraints.append(F()) - else: - all_constraints.append(Conj([BinConstraintT(flattened, TensorType(lhs + mid_var + rhs), op_eq)] + p)) - else: - new_var, counter = gen_dvar(counter) - mid_eq_prod = Conj([BinConstraintD(new_var, Prod(mid), op_eq), BinConstraintD(new_var, Dyn, op_neq)]) - mid_var = [new_var] - total_constraints = lhs + mid_var + rhs - if len(total_constraints) > 4: - all_constraints.append(F()) - else: - all_constraints.append(Conj([BinConstraintT(flattened, TensorType(lhs + mid_var + rhs), op_eq), mid_eq_prod] + p)) - - return Conj([Disj(all_constraints), c_boundary]), counter - - -@register_transformation_rule(CanReshape) -def generate_reshape(constraint, counter): - """ - Transform reshape constraints - """ - d, counter = gen_tensor_dims(4, counter) - - d1 = d[0] - d2 = d[1] - d3 = d[2] - d4 = d[3] - - target = constraint.target.__args__ - - is_fully_static = all([d != Dyn for d in target]) - - # dynamic tensor - c1_dyn = BinConstraintT(constraint.src, Dyn, op_eq) - c2_tensor1 = BinConstraintT(constraint.src, TensorType([d1]), op_eq) - c2_tensor2 = BinConstraintT(constraint.src, TensorType([d1, d2]), op_eq) - c2_tensor3 = BinConstraintT(constraint.src, TensorType([d1, d2, d3]), op_eq) - c2_tensor4 = BinConstraintT(constraint.src, TensorType([d1, d2, d3, d4]), op_eq) - - d1_eq_dyn = BinConstraintD(d1, Dyn, op_eq) - d1_neq_dyn = BinConstraintD(d1, Dyn, op_neq) - - d2_eq_dyn = BinConstraintD(d2, Dyn, op_eq) - d2_neq_dyn = BinConstraintD(d2, Dyn, op_neq) - - d3_eq_dyn = BinConstraintD(d3, Dyn, op_eq) - d3_neq_dyn = BinConstraintD(d3, Dyn, op_neq) - - d4_eq_dyn = BinConstraintD(d3, Dyn, op_eq) - d4_neq_dyn = BinConstraintD(d3, Dyn, op_neq) - - nat_d1 = BinConstraintD(0, d1, op_leq) - nat_d2 = BinConstraintD(0, d2, op_leq) - nat_d3 = BinConstraintD(0, d3, op_leq) - nat_d4 = BinConstraintD(0, d4, op_leq) - - if is_fully_static: - # size 1 tensor - c3_tensor1 = Disj([d1_eq_dyn, - (Conj([d1_neq_dyn, - BinConstraintD(d1, Prod(target), op_eq)]))]) - all_tensor_1 = Conj([c2_tensor1, c3_tensor1]) - - # size 2 tensor - all_tensor_2 = Conj([c2_tensor2, gen_all_reshape_possibilities([d1, d2], target)]) - - # size 3 tensor - all_tensor_3 = Conj([c2_tensor3, gen_all_reshape_possibilities([d1, d2, d3], target)]) - - # size 4 tensor - all_tensor_4 = Conj([c2_tensor4, gen_all_reshape_possibilities([d1, d2, d3, d4], target)]) - - return Conj([Disj([c1_dyn, all_tensor_1, all_tensor_2, all_tensor_3, all_tensor_4]), - nat_d1, nat_d2, nat_d3, nat_d4]), counter - - # then there must be exactly one occurrence of dyn - else: - new_target = [] - - for n in target: - if n != Dyn: - new_target.append(n) - - # tensor 1 - c3_tensor1 = Disj([d1_eq_dyn, - (Conj([d1_neq_dyn, - is_dim_div_by_target(new_target, d1)]))]) - all_tensor_1 = Conj([c2_tensor1, c3_tensor1]) - - # tensor 2 - c21 = Disj([d1_eq_dyn, d2_eq_dyn]) - c22 = Conj([d1_neq_dyn, d2_neq_dyn, is_dim_div_by_target(new_target, Prod([d1, d2]))]) - all_tensor_2 = Conj([c2_tensor2, Disj([c21, c22])]) - - # tensor 3 - c31 = Disj([d1_eq_dyn, d2_eq_dyn, d3_eq_dyn]) - c32 = Conj([d1_neq_dyn, d2_neq_dyn, d3_neq_dyn, is_dim_div_by_target(new_target, Prod([d1, d2, d3]))]) - all_tensor_3 = Conj([c2_tensor3, Disj([c31, c32])]) - - # tensor 4 - c41 = Disj([d1_eq_dyn, d2_eq_dyn, d3_eq_dyn, d4_eq_dyn]) - c42 = Conj([d1_neq_dyn, d2_neq_dyn, d3_neq_dyn, d4_neq_dyn, is_dim_div_by_target(new_target, Prod([d1, d2, d3, d4]))]) - all_tensor_4 = Conj([c2_tensor4, Disj([c41, c42])]) - - return Conj([Disj([c1_dyn, all_tensor_1, all_tensor_2, all_tensor_3, all_tensor_4]), - nat_d1, nat_d2, nat_d3, nat_d4]), counter - - -@register_transformation_rule(ApplyBroadcasting) -def generate_broadcasting(constraint, counter): - """ - Transform broadcasting constraints - """ - e11, e12 = constraint.res1, constraint.res2 - e1, e2 = constraint.input1, constraint.input2 - - e1_dyn = BinConstraintT(e1, Dyn, op_eq) - e2_dyn = BinConstraintT(e2, Dyn, op_eq) - - # Introduce dimensions - e1_equal_e11 = BinConstraintT(e1, e11, op_eq) - e2_equal_e12 = BinConstraintT(e2, e12, op_eq) - - # dyn possibility - e1_dyn_constraint = Conj([e1_dyn, e1_equal_e11, e2_equal_e12]) - e2_dyn_constraint = Conj([e2_dyn, e1_equal_e11, e2_equal_e12]) - - # tensor possibility - # generate dimensions to create tensors of size 1 - final_tensor_1_constraint, _, _, nat_dims_1, counter = \ - gen_broadcasting_constraints(e1, e2, e11, e12, 1, counter) - - # generate dimensions to create tensors of size 2 - final_tensor_2_constraint_no_padding, final_tensor_2_constraint_padding_arg1, \ - final_tensor_2_constraint_padding_arg2, nat_dims_2, counter = \ - gen_broadcasting_constraints(e1, e2, e11, e12, 2, counter) - - # generate dimensions to create tensors of size 3 - final_tensor_3_constraint_no_padding, final_tensor_3_constraint_padding_arg1, \ - final_tensor_3_constraint_padding_arg2, nat_dims_3, counter = \ - gen_broadcasting_constraints(e1, e2, e11, e12, 3, counter) - - # generate dimensions to create tensors of size 4 - final_tensor_4_constraint_no_padding, final_tensor_4_constraint_padding_arg1, \ - final_tensor_4_constraint_padding_arg2, nat_dims_4, counter = \ - gen_broadcasting_constraints(e1, e2, e11, e12, 4, counter) - - final_result = Disj([ - e1_dyn_constraint, - e2_dyn_constraint, - final_tensor_1_constraint, - final_tensor_2_constraint_no_padding, - final_tensor_2_constraint_padding_arg1, - final_tensor_2_constraint_padding_arg2, - final_tensor_3_constraint_no_padding, - final_tensor_3_constraint_padding_arg1, - final_tensor_3_constraint_padding_arg2, - final_tensor_4_constraint_no_padding, - final_tensor_4_constraint_padding_arg1, - final_tensor_4_constraint_padding_arg2 - ]) - - return Conj([final_result, *nat_dims_1, *nat_dims_2, *nat_dims_3, *nat_dims_4]), counter - - -def transform_constraint(constraint: Constraint, counter: int): - """ - Transforms a constraint into a simpler constraint. - Ex: precision and consistency are transformed to equality - Args: - constraint: constraint to be transformed - counter: for variable tracking - - Returns: Constraint - - """ - if type(constraint) in _TRANSFORMATION_RULES: - return _TRANSFORMATION_RULES[type(constraint)](constraint, counter) - - else: - return constraint, counter - - - - -def calc_last_two_dims(constraint, d: List[DVar]): - """ - Generates constraints for the last two dimensions of a convolution or a maxpool output - Args: - constraint: CalcConv or CalcMaxPool - d: The list of output dimensions - - Returns: Constraints for calculating the last two dimensions of the output - - """ - - assert isinstance(constraint, CalcConv) or isinstance(constraint, CalcMaxPool) - - b3 = constraint.matching_constraint[2] - b4 = constraint.matching_constraint[3] - - b3_dyn = Conj([BinConstraintD(d[2], Dyn, op_eq), BinConstraintD(b3, Dyn, op_eq)]) - b4_dyn = Conj([BinConstraintD(d[3], Dyn, op_eq), BinConstraintD(b4, Dyn, op_eq)]) - - d3_not_dyn = Conj([BinConstraintD(d[2], Dyn, op_neq), BinConstraintD(b3, Dyn, op_neq)]) - d4_not_dyn = Conj([BinConstraintD(d[3], Dyn, op_neq), BinConstraintD(b4, Dyn, op_neq)]) - - # transform parameters into tuples incase they are not already - padding = (constraint.padding, constraint.padding) \ - if isinstance(constraint.padding, int) else constraint.padding - kernel = (constraint.kernel, constraint.kernel) \ - if isinstance(constraint.kernel, int) else constraint.kernel - stride = (constraint.stride, constraint.stride) \ - if isinstance(constraint.stride, int) else constraint.stride - dilation = (constraint.dilation, constraint.dilation) \ - if isinstance(constraint.dilation, int) else constraint.dilation - - f1 = BinConstraintD(b3, BinConstraintD(2, padding[0], op_mul), op_add) - f2 = BinConstraintD(dilation[0], BinConstraintD(kernel[0], 1, op_sub), op_mul) - f3 = BinConstraintD(BinConstraintD(BinConstraintD(f1, f2, op_sub), 1, op_sub), stride[0], op_div) - f4 = BinConstraintD(f3, 1, op_add) - - c4 = Disj([b3_dyn, Conj([d3_not_dyn, BinConstraintD(d[2], f4, op_eq)])]) - - f11 = BinConstraintD(b4, BinConstraintD(2, padding[1], op_mul), op_add) - f22 = BinConstraintD(dilation[1], BinConstraintD(kernel[1], 1, op_sub), op_mul) - f33 = BinConstraintD(BinConstraintD(BinConstraintD(f11, f22, op_sub), 1, op_sub), stride[1], op_div) - f44 = BinConstraintD(f33, 1, op_add) - - c5 = Disj([b4_dyn, Conj([d4_not_dyn, BinConstraintD(d[3], f44, op_eq)])]) - - return c4, c5 - - -def generate_all_int_dyn_dim_possibilities(my_list: List[DVar]): - """ - Generate all possibilities of being equal or not equal to dyn for my_list - Args: - my_list: List of tensor dimensions - - Returns: A list of a list of constraints. Each list of constraints corresponds to - one possibility about the values of the dimension variables - """ - # generate all possibilities of being equal or not equal to dyn for my_list - eq_possibilities = [BinConstraintD(my_list[i], Dyn, op_eq) for i in range(len(my_list))] - neq_possibilities = [BinConstraintD(my_list[i], Dyn, op_neq) for i in range(len(my_list))] - d_possibilities = [] - - for i in zip(eq_possibilities, neq_possibilities): - d_possibilities.append(list(i)) - all_possibilities = list(itertools.product(*d_possibilities)) - return all_possibilities - - -def is_target_div_by_dim(target: List[int], dim: List[DVar]): - """ - Generate constraints to check if the target dimensions are divisible by the input dimensions - Args: - target: Target dimensions - dim: Input dimensions - - Returns: Constraints to check divisibility - - """ - return BinConstraintD(BinConstraintD(Prod(target), dim, op_mod), 0, op_eq) - - -def is_dim_div_by_target(target: List[int], dim: List[DVar]): - """ - Generate constraints to check if the input dimensions is divisible by the target dimensions - Args: - target: Target dimensions - dim: Input dimensions - - Returns: Constraints to check divisibility - - """ - return BinConstraintD(BinConstraintD(dim, Prod(target), op_mod), 0, op_eq) - - -def gen_all_reshape_possibilities(list_of_dims, target): - """ - Consider all possibilities what the input dimensions could be (number or dynamic) - Then generate the appropriate constraints using multiplication or mod depending on the possibility - The possibilities we consider here are the cross product of being equal to dyn or not equal to dyn - for the input. Target is fixed because at most one dimension could be dyn. - We have different cases for this. - - Args: - list_of_dims: The input list of dimensions - target: The tensor we want to reshape to - - Returns: A disjuncition of transformed reshape constraints - - """ - all_possibilities = generate_all_int_dyn_dim_possibilities(list_of_dims) - - all_constraints = [] - - for p in all_possibilities: - to_multiply = [] - - p = list(p) - - for constraint in p: - assert isinstance(constraint, BinConstraintD) - if constraint.op == op_neq: - to_multiply.append(constraint.lhs) - - if not to_multiply: - all_constraints.append(Conj(p)) - - elif len(to_multiply) < len(list_of_dims): - all_constraints.append(Conj(p + [is_target_div_by_dim(target, Prod(to_multiply))])) - else: - all_constraints.append(Conj(p + [BinConstraintD(Prod(list_of_dims), - Prod(target), op_eq)])) - - return Disj(all_constraints) - - -def broadcast_dim(tensor_input1, tensor_input2, res1, res2, index, padding=False): - """ - Apply broadcasting to the 'index' dimension of tensor_input1. - Args: - tensor_input1: should represent [d1, ..., d_index, ...] where d_index = 1 - tensor_input2: represents the second input - res1: broadcasted result 1 - res2: broadcasted result 2 - index: the index to broadcast - padding: If padding was used, then tensor_input1[index] does not exist - - Returns: - - """ - if tensor_input1[index] is None: - assert padding - - - if not padding: - # then the inputs are the same length so they all have dimensions at "index" - return Conj([BinConstraintD(tensor_input1[index], 1, op_eq), - BinConstraintD(res1[index], res2[index], op_eq), - BinConstraintD(res2[index], tensor_input2[index], op_eq)]) - - else: - # we don't set the input dimension to 1, since it doesn't exist. - return Conj([BinConstraintD(res1[index], res2[index], op_eq), - BinConstraintD(res2[index], tensor_input2[index], op_eq)]) - - -def apply_padding(e1_var: TVar, - e11: BinConstraintT, - e2: BinConstraintT, - e12: BinConstraintT, - d2: List[DVar], - d11: List[DVar], - d12: List[DVar], - counter: int): - """ - We are considering the possibility where one input has less dimensions than - another input, so we apply padding to the broadcasted results - - Args: - e1_var: Variable representing the first input where padding will be - e11: constraint of the form e11 = Tensortype[d1, ..., dn] - e2: constraint of the form e2 = Tensortype[d1, ..., dn] - e12: constraint of the form e11 = Tensortype[d1, ..., dn] - d2: Tensor variables for the second input - d11: Tensor variables for the broadcasted first input - d12: Tensor variables for the broadcasted second input - counter: variable tracking - - Returns: A new constraint whose goal is to apply padding to the broadcasted result - - """ - - res = [] - - # pad the shorter input with None so we can pass it to the broadcasting helper function - for i in range(1, len(d2)): - - d1, counter = gen_tensor_dims(i, counter) - - nat_constraints = gen_nat_constraints(d1 + d2 + d11 + d12) - - e1 = BinConstraintT(e1_var, TensorType(d1), op_eq) - - simulate_padding = [None] * (len(d2) - i) - - assert len(simulate_padding + d1) == len(d2) - - broadcast_padding = [] - - # for every padding size, we also consider broadcasting - for j in range((len(d2) - i)): - broadcast_padding.append(broadcast_dim(simulate_padding, d2, d11, d12, j, True)) - - # we consider the possibilities for broadcasting for every dimension. Since we already - # padded d1, we do not consider it while broadcasting - all_broadcasting_possibilities = generate_all_broadcasting_possibilities_no_padding(d1, - d2[(len(d2) - i):], - d11[(len(d2) - i):], - d12[(len(d2) - i):]) - # combine all constraints into a conjunction - c = Conj([e1, e11, e2, e12, - *broadcast_padding, - all_broadcasting_possibilities, - *nat_constraints - ]) - res.append(c) - - return Disj(res), counter - - -def no_broadcast_dim_with_index(d1: List[DVar], - d2: List[DVar], - d3: List[DVar], - d4: List[DVar], - i: int): - """ - Args: - d1: inpput 1 - d2: inpput 2 - d3: simulated broadcasting for input 1 - d4: simulated broadcasting for input 2 - i: the rank of the resulting tensor addition - - Returns: Constraints for when no broadcasting occurs - """ - return Conj([ - Disj([ - Conj([BinConstraintD(d1[i], 1, op_eq), - BinConstraintD(d2[i], 1, op_eq)]), - - Conj([BinConstraintD(d1[i], 1, op_neq), - BinConstraintD(d2[i], 1, op_neq)])]), - - BinConstraintD(d1[i], d3[i], op_eq), - BinConstraintD(d2[i], d4[i], op_eq)]) - - - -def gen_lists_of_dims(num_tensors: int, dim_size: int, counter: int): - """ - Generate lists of DVar to represent tensor dimensions - Args: - num_tensors: the required number of tensors - dim_size: the number of dimensions for each tensor - counter: variable tracking - - Returns: A list of a list of tensor dimensions - - """ - res = [] - - for _ in range(num_tensors): - dims, counter = gen_tensor_dims(dim_size, counter) - res.append(dims) - - return res, counter - - -def create_equality_constraints_for_broadcasting(e1: TVar, - e2: TVar, - e11: TVar, - e12: TVar, - d1: List[DVar], - d2: List[DVar], - d11: List[DVar], - d12: List[DVar]): - """ - Create equality constraints for when no broadcasting occurs - Args: - e1: Input 1 - e2: Input 2 - e11: Broadcasted input 1 - e12: Broadcasted input 2 - d1: Variables that store dimensions for e1 - d2: Variables that store dimensions for e2 - d11: Variables that store dimensions for e11 - d12: Variables that store dimensions for e22 - - Returns: Four equality constraints - - """ - - e1_tensor = BinConstraintT(e1, TensorType(d1), op_eq) - e11_tensor = BinConstraintT(e11, TensorType(d11), op_eq) - e2_tensor = BinConstraintT(e2, TensorType(d2), op_eq) - e12_tensor = BinConstraintT(e12, TensorType(d12), op_eq) - return [e1_tensor, e11_tensor, e2_tensor, e12_tensor] - - -def gen_consistency_constraints(constraint: Constraint, counter: int): - """ - Args: - constraint: Consistency constraint on tensors - counter: for variable tracking - - Returns: Equality and consistency constraints on dimensions - - """ - - all_constraints = [] - - for i in range(1, MAX_TENSOR_RANK + 1): - new_dims_rhs_1, counter = gen_tensor_dims(i, counter) - new_dims_rhs_2, counter = gen_tensor_dims(i, counter) - - nat_constraints = gen_nat_constraints(new_dims_rhs_1 + new_dims_rhs_2) - - c_tensor_i = Conj([BinConstraintT(constraint.lhs, TensorType(new_dims_rhs_1), op_eq), - BinConstraintT(constraint.rhs, TensorType(new_dims_rhs_2), op_eq)] + - [BinConstraintD(d1, d2, op_consistency) for - d1, d2 in zip(new_dims_rhs_1, new_dims_rhs_2)] + nat_constraints) - - all_constraints.append(c_tensor_i) - - return all_constraints, counter - - -def gen_greatest_upper_bound(constraint: TGreatestUpperBound, counter: int): - """ - Args: - constraint: Greatest upper bound on tensors - counter: variable tracking - - Returns: A set of equality constraints and DGreatestUpperBound constraints - - """ - - all_constraints = [] - - for i in range(1, MAX_TENSOR_RANK + 1): - c = [] - dims1, counter = gen_tensor_dims(i, counter) - c1tensor = TensorType(dims1) - - dims2, counter = gen_tensor_dims(i, counter) - c2tensor = TensorType(dims2) - - dims3, counter = gen_tensor_dims(i, counter) - c3tensor = TensorType(dims3) - - c += [BinConstraintT(constraint.rhs1, c1tensor, op_eq), - BinConstraintT(constraint.rhs2, c2tensor, op_eq), - BinConstraintT(constraint.res, c3tensor, op_eq)] + \ - gen_nat_constraints(dims1 + dims2 + dims3) - - assert len(c3tensor.__args__) == len(c1tensor.__args__) == len(c2tensor.__args__) - for i in range(len(c3tensor.__args__)): - c.append(DGreatestUpperBound(c3tensor.__args__[i], - c1tensor.__args__[i], - c2tensor.__args__[i])) - - all_constraints.append(Conj(c)) - return all_constraints, counter - - -def generate_all_broadcasting_possibilities_no_padding(d1: List[DVar], d2: List[DVar], d11: List[DVar], d12: List[DVar]): - """ - Generate broadcasting constraints assuming no padding. Broadcasting can happen at any dimension. - We look at all combinations for all dimendions in d1 and d2 - Args: - d1: input1 dimensions - d2: input2 dimensions - d11: broadcasted input1 dimensions - d12: broadcasted input2 dimensions - - Returns: broadcasting constraints relating the input dimensions to the broadcasted dimensions - - """ - - size = len(d1) - - res2 = [] - - for i in range(size): - t1 = broadcast_dim(d1, d2, d11, d12, i) - t2 = broadcast_dim(d2, d1, d12, d11, i) - t3 = no_broadcast_dim_with_index(d1, d2, d11, d12, i) - - res2.append(Disj([t1, t2, t3])) - - return Conj(res2) - - -def gen_broadcasting_constraints(e1: TVar, e2: TVar, e11: TVar, e12: TVar, i: int, counter: int): - """ - Simulates broadcasting on e1 and e2 and returns the results - respectively in e11 and e12. Because of gradual types, - e1 and e2 may not be equal. Similarly, e11 and e12 may not - be equal. e11 and e12 should be guaranteed to be consistent - as they represent the shapes of the tensors to be added after - broadcasting. - Args: - e1: TVar representing the type of input 1 - e2: TVar representing the type of input 2 - e11: TVar representing the representing broadcasted input 1 - e12: TVar representing the representing broadcasted input 2 - i: The rank of the resulting type of addition - counter: for variable tracking - - Returns: Simplified broadcasting constraints - - """ - dims, counter = gen_lists_of_dims(4, i, counter) - [d1, d2, d3, d4] = dims - nat_dims_i = gen_nat_constraints(list(itertools.chain(*dims))) - - initialize_tensors_constraints = create_equality_constraints_for_broadcasting(e1, e2, e11, e12, - d1, d2, d3, d4) - - [e1_tensor, e11_tensor, e2_tensor, e12_tensor] = initialize_tensors_constraints - - # without padding, broadcast all possibilities for tensors of size i - final_tensor_constraint_no_padding = Conj([*initialize_tensors_constraints, - generate_all_broadcasting_possibilities_no_padding(d1, d2, d3, d4)]) - - # with padding, broadcast all possibilities for tensors of size i - final_tensor_constraint_padding_arg1, counter = \ - apply_padding(e1, e11_tensor, e2_tensor, e12_tensor, d2, d3, d4, counter) - - final_tensor_constraint_padding_arg2, counter = \ - apply_padding(e2, e12_tensor, e1_tensor, e11_tensor, d1, d4, d3, counter) - - return final_tensor_constraint_no_padding, \ - final_tensor_constraint_padding_arg1, \ - final_tensor_constraint_padding_arg2, nat_dims_i, counter diff --git a/pippy/fx/experimental/migrate_gradual_types/operation.py b/pippy/fx/experimental/migrate_gradual_types/operation.py deleted file mode 100644 index ef7c670bf..000000000 --- a/pippy/fx/experimental/migrate_gradual_types/operation.py +++ /dev/null @@ -1,16 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -# -*- coding: utf-8 -*- -op_add = '+' -op_sub = '-' -op_mul = '*' -op_div = '/' -op_eq = '=' -op_neq = '!=' -op_imp = '=>' -op_matching = '⊳' -op_consistency = '~' -op_precision = '⊑' -op_leq = '≤' -op_lt = '<' -op_gt = '>' -op_mod = '%' diff --git a/pippy/fx/experimental/migrate_gradual_types/transform_to_z3.py b/pippy/fx/experimental/migrate_gradual_types/transform_to_z3.py deleted file mode 100644 index dbf1b4d4c..000000000 --- a/pippy/fx/experimental/migrate_gradual_types/transform_to_z3.py +++ /dev/null @@ -1,349 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from pippy.fx.experimental.migrate_gradual_types.constraint import Conj, Disj, T, F, BinConstraintT, BVar, is_bool_expr -from pippy.fx.experimental.migrate_gradual_types.constraint import BinConstraintD, TVar, DVar -from pippy.fx.experimental.migrate_gradual_types.constraint import Prod, is_algebraic_expression, is_dim -from pippy.fx.experimental.migrate_gradual_types.constraint_generator import ConstraintGenerator -from pippy.fx.experimental.migrate_gradual_types.constraint_transformation import transform_constraint -from pippy.fx.experimental.migrate_gradual_types.operation import op_add, op_eq, op_neq, op_gt, op_lt -from pippy.fx.experimental.migrate_gradual_types.operation import op_leq, op_sub, op_div, op_mul, op_mod -from pippy.fx.tensor_type import TensorType, Dyn - -try: - import z3 # type: ignore[import] - from pippy.fx.experimental.migrate_gradual_types.z3_types import tensor_type, z3_dyn, D - HAS_Z3 = True - - def transform_to_z3(constraint, counter, dimension_dict): - if isinstance(constraint, Conj): - conjuncts = [] - for c in constraint.conjucts: - new_c, counter = transform_to_z3(c, counter, dimension_dict) - conjuncts.append(new_c) - return z3.And(conjuncts), counter - - elif isinstance(constraint, Disj): - disjuncts = [] - for c in constraint.disjuncts: - new_c, counter = transform_to_z3(c, counter, dimension_dict) - disjuncts.append(new_c) - return z3.Or(disjuncts), counter - - elif isinstance(constraint, T): - return True, counter - - elif isinstance(constraint, F): - return False, counter - - elif isinstance(constraint, BinConstraintT): - if constraint.op == op_eq: - lhs, counter = transform_var(constraint.lhs, counter, dimension_dict) - rhs, counter = transform_var(constraint.rhs, counter, dimension_dict) - return (lhs == rhs), counter - - else: - raise NotImplementedError('Method not yet implemented') - - elif isinstance(constraint, BinConstraintD): - if constraint.op == op_eq: - - if isinstance(constraint.lhs, BVar) and is_bool_expr(constraint.rhs): - transformed_rhs, counter = transform_to_z3(constraint.rhs, counter, dimension_dict) - transformed_lhs = z3.Bool(constraint.lhs.c) - return transformed_lhs == transformed_rhs, counter - - elif is_dim(constraint.lhs) and is_dim(constraint.rhs): - # with dimension tranformations we consider the encoding - lhs, counter = transform_dimension(constraint.lhs, counter, dimension_dict) - rhs, counter = transform_dimension(constraint.rhs, counter, dimension_dict) - return lhs == rhs, counter - - else: - # then we have an algebraic expression which means that we disregard the - # first element of the encoding - lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict) - rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict) - return lhs == rhs, counter - - # The assumption here is that the LHS and RHS must be dimensions - elif constraint.op == op_neq: - assert is_dim(constraint.lhs) - assert is_dim(constraint.rhs) - lhs, counter = transform_dimension(constraint.lhs, counter, dimension_dict) - rhs, counter = transform_dimension(constraint.rhs, counter, dimension_dict) - if constraint.rhs == Dyn or constraint.lhs == Dyn: - if constraint.rhs == Dyn: - return lhs.arg(0) == 1, counter - elif constraint.lhs == Dyn: - return rhs.arg(0) == 1, counter - - # if one of the instances is a number - elif isinstance(constraint.lhs, int) or isinstance(constraint.rhs, int): - if isinstance(constraint.lhs, int): - return z3.Or([rhs.arg(0) == 0, z3.And([rhs.arg(0) == 1, lhs.arg(1) != rhs.arg(1)])]), counter - - elif isinstance(constraint.rhs, int): - return z3.Or([lhs.arg(0) == 0, z3.And([lhs.arg(0) == 1, lhs.arg(1) != rhs.arg(1)])]), counter - - else: - return z3.Or([z3.And([lhs.arg(0) == 0, rhs.arg(0) != 0]), - z3.And([lhs.arg(0) != 0, rhs.arg(0) == 0]), - z3.And([lhs.arg(0) != 0, rhs.arg(0) != 0, lhs.arg(1) != rhs.arg(1)])]), counter - - - elif constraint.op == op_leq: - # if the dimensions are not dyn, this will come into effect - # there would have been another constraint specifying if a given dimension - # is dyn or not - assert is_dim(constraint.lhs) and is_dim(constraint.rhs) - lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict) - rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict) - return lhs <= rhs, counter - - elif constraint.op == op_gt: - assert is_dim(constraint.lhs) and is_dim(constraint.rhs) - lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict) - rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict) - return lhs > rhs, counter - - elif constraint.op == op_lt: - assert is_dim(constraint.lhs) and is_dim(constraint.rhs) - lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict) - rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict) - return lhs < rhs, counter - - else: - raise NotImplementedError('operation not yet implemented') - - else: - raise NotImplementedError('Operation not yet implemented') - - - def transform_var(tensor, counter, dimension_dict): - """ - Transforms tensor variables to a format understood by z3 - Args: - tensor: Tensor variable or a tensor type potentially with variable dimensions - Returns: Transformed variable to a z3 format - - """ - if isinstance(tensor, TensorType): - res = [] - for t in tensor.__args__: - transformed, counter = transform_dimension(t, counter, dimension_dict) - res.append(transformed) - - assert len(res) <= 4 - if len(tensor.__args__) == 1: - return tensor_type.tensor1(res[0]), counter - elif len(tensor.__args__) == 2: - return tensor_type.tensor2(res[0], res[1]), counter - elif len(tensor.__args__) == 3: - return tensor_type.tensor3(res[0], res[1], res[2]), counter - elif len(tensor.__args__) == 4: - return tensor_type.tensor4(res[0], res[1], res[2], res[3]), counter - - elif tensor == Dyn: - return z3_dyn, counter - - elif isinstance(tensor, TVar): - return z3.Const(tensor.tvar, tensor_type), counter - - def transform_dimension(dimension, counter, dimension_dict): - """ - Takes a dimension variable or a number and transforms it to a tuple - according to our scheme - Args: - dimension: The dimension to be transformed - counter: variable tracking - - Returns: tuple and the current counter - - """ - if dimension == Dyn: - counter += 1 - return D(0, z3.Int(counter)), counter - elif isinstance(dimension, int): - return D(1, dimension), counter - elif isinstance(dimension, DVar): - if dimension.c in dimension_dict: - return D(z3.Int(dimension_dict[dimension.c]), z3.Int(dimension.c)), counter - else: - counter += 1 - dimension_dict[dimension.c] = counter - return D(z3.Int(counter), z3.Int(dimension.c)), counter - - - def transform_algebraic_expression(expr, counter, dimension_dict): - """ - Transforms an algebraic expression to z3 format - Args: - expr: An expression is either a dimension variable or an algebraic-expression - - - Returns: the transformed expression - - """ - assert is_algebraic_expression(expr) or is_dim(expr) - - if is_dim(expr): - transformed, counter = transform_dimension(expr, counter, dimension_dict) - return transformed.arg(1), counter - - elif isinstance(expr, Prod): - - dims = [] - for dim in expr.products: - assert is_dim(dim) - d, counter = transform_dimension(dim, counter, dimension_dict) - dims.append(d.arg(1)) - return z3.Product(dims), counter - - elif is_algebraic_expression(expr): - - lhs, counter = transform_algebraic_expression(expr.lhs, counter, dimension_dict) - rhs, counter = transform_algebraic_expression(expr.rhs, counter, dimension_dict) - - if expr.op == op_sub: - c = lhs - rhs - - elif expr.op == op_add: - c = lhs + rhs - - elif expr.op == op_div: - c = lhs / rhs - - elif expr.op == op_mul: - c = lhs * rhs - - elif expr.op == op_mod: - c = lhs % rhs - - else: - raise NotImplementedError('operation not yet implemented') - - return c, counter - - else: - raise RuntimeError - - - def transform_all_constraints(traced, counter=0): - """ - Given a trace, generates constraints and transforms them to z3 format - - """ - dimension_dict = {} # type: ignore[var-annotated] - - generator = ConstraintGenerator(traced) - new_constraints, counter = generator.generate_constraints(counter) - - # print(new_constraints.conjucts[0]) - # print(*new_constraints.conjucts, sep='\n') - - # transform precision, matching, consistency till obtaining a fixed point - new_constraints, counter = iterate_till_fixed_point(new_constraints, counter) - # print(new_constraints) - # print(new_constraints.conjucts) - # new_constraints.conjucts = new_constraints.conjucts[:-1] - # print(*new_constraints.conjucts, sep='\n') - - transformed, counter = transform_to_z3(new_constraints, counter, dimension_dict) - # print(transformed) - return transformed - - def iterate_till_fixed_point(constraints, counter): - """ - Transform constraints till reaching a fixed point - """ - old_c = None - while old_c != constraints: - old_c = constraints - constraints, counter = transform_constraint(constraints, counter) - return constraints, counter - - def transform_all_constraints_trace_time(tracer_root, graph, node, counter=0): - """ - Takes a node and a graph and generates two sets of constraints. - One set constraints the node's constraints and another set - constraints the negation of the node's constraints - Args: - tracer_root: the root for getting the module instances - graph: the graph so far in the tracing process - node: node that represents a conditional - counter: variable tracking - - Returns: Two sets of constraints. One with a conjunction with the - the conditional constraint and the other with a conjunction with - its negation. - - """ - dimension_dict = {} # type: ignore[var-annotated] - - generator = ConstraintGenerator(tracer_root, graph) - new_constraints, counter = generator.generate_constraints(counter) - - condition_constraint = new_constraints.conjucts[-1] - - # we know the constraint is a conjunction where the last constraint is about the conditional - # so remove the last constraint - new_constraints.conjucts = new_constraints.conjucts[:-1] - - # transform precision, matching, consistency till obtaining a fixed point - new_constraints, counter = iterate_till_fixed_point(new_constraints, counter) - - - # since the function returns a list of one element, we get the first element - # we are only interested in the RHS in this case because the LHS just stores - # the result - - # we make sure the constraint is of the form: - # c = b where b is a boolean expression - # and we consider b (constraint.rhs) for transformation - assert isinstance(condition_constraint.lhs, BVar) - assert is_bool_expr(condition_constraint.rhs) - condition_constraint_rhs = condition_constraint.rhs - - # transform the condition constraint - condition_constraint_rhs, counter = iterate_till_fixed_point(condition_constraint_rhs, counter) - - transformed, counter = transform_to_z3(new_constraints, counter, dimension_dict) - - transformed_condition_constraint, counter = transform_to_z3(condition_constraint_rhs, counter, dimension_dict) - - negation_transformed_condition_constraint = z3.Not(transformed_condition_constraint) - - return z3.And([transformed, transformed_condition_constraint]),\ - z3.And([transformed, negation_transformed_condition_constraint]) - - - def evaluate_conditional_with_constraints(tracer_root, graph, node, counter=0, user_constraints=None): - """ - Given an IR and a node representing a conditional, evaluate the conditional - and its negation - Args: - tracer_root: Tracer root for module instances - node: The node to be evaluated - - Returns: the results of evaluating the condition and the negation with - the rest of the constraints - - """ - - transformed_positive, transformed_negative = \ - transform_all_constraints_trace_time(tracer_root, graph, node, counter) - - s = z3.Solver() - s.add(transformed_positive) - if user_constraints is not None: - s.add(user_constraints) - condition = s.check() - - s = z3.Solver() - s.add(transformed_negative) - if user_constraints is not None: - s.add(user_constraints) - negation = s.check() - return condition, negation - -except ImportError: - HAS_Z3 = False diff --git a/pippy/fx/experimental/migrate_gradual_types/util.py b/pippy/fx/experimental/migrate_gradual_types/util.py deleted file mode 100644 index 89ab32648..000000000 --- a/pippy/fx/experimental/migrate_gradual_types/util.py +++ /dev/null @@ -1,53 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from pippy.fx.experimental.migrate_gradual_types.constraint import TVar, DVar, BinConstraintD, \ - BVar -from pippy.fx.experimental.migrate_gradual_types.operation import op_leq - - -def gen_tvar(curr): - """ - Generate a tensor variable - :param curr: The current counter - :return: a tensor variable and the updated counter - """ - curr += 1 - return TVar(curr), curr - - -def gen_dvar(curr): - """ - Generate a dimension variable - :param curr: the current counter - :return: a dimension variable and an updated counter - """ - curr += 1 - return DVar(curr), curr - -def gen_bvar(curr): - """ - Generate a boolean variable - :param curr: the current counter - :return: a boolean variable and an updated counter - """ - curr += 1 - return BVar(curr), curr - -def gen_tensor_dims(n, curr): - """ - Generate a list of tensor dimensions - :param n: the number of dimensions - :param curr: the current counter - :return: a list of dimension variables and an updated counter - """ - dims = [] - for _ in range(n): - dvar, curr = gen_dvar(curr) - dims.append(dvar) - return dims, curr - - -def gen_nat_constraints(list_of_dims): - """ - Generate natural number constraints for dimensions - """ - return [BinConstraintD(0, d, op_leq) for d in list_of_dims] diff --git a/pippy/fx/experimental/migrate_gradual_types/z3_types.py b/pippy/fx/experimental/migrate_gradual_types/z3_types.py deleted file mode 100644 index 851e4bc89..000000000 --- a/pippy/fx/experimental/migrate_gradual_types/z3_types.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -try: - import z3 # type: ignore[import] - HAS_Z3 = True - # dynamic type - dyn = z3.DeclareSort('Dyn') - dyn_type = z3.Const('dyn', dyn) - - # dimension - dim = z3.Datatype('dim') - dim.declare('dim', ('0', z3.IntSort()), ('1', z3.IntSort())) - dim = dim.create() - - # tensors - tensor_type = z3.Datatype('TensorType') - tensor_type.declare('Dyn', ('dyn', dyn)) - tensor_type.declare('tensor1', ('0', dim)) - tensor_type.declare('tensor2', ('0', dim), ('1', dim)) - tensor_type.declare('tensor3', ('0', dim), ('1', dim), ('2', dim)) - tensor_type.declare('tensor4', ('0', dim), ('1', dim), ('2', dim), ('3', dim)) - tensor_type = tensor_type.create() - - # create dimension - D = dim.dim - - z3_dyn = tensor_type.Dyn(dyn_type) - - -except ImportError: - HAS_Z3 = False diff --git a/pippy/fx/experimental/normalize.py b/pippy/fx/experimental/normalize.py deleted file mode 100644 index c92dbf973..000000000 --- a/pippy/fx/experimental/normalize.py +++ /dev/null @@ -1,163 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import operator -from typing import Any, Callable, Dict, Tuple, Optional - -import torch -import pippy.fx -import pippy.fx as fx -from pippy.fx import Transformer, Proxy -from pippy.fx.node import Argument, Target, Node, map_aggregate -from pippy.fx.operator_schemas import ( - normalize_module, - normalize_function, - create_type_hint, -) - -from .schema_type_annotation import AnnotateTypesWithSchema - - -class NormalizeArgs(Transformer): - """ - Normalize arguments to Python targets. This means that - `args/kwargs` will be matched up to the module/functional's - signature and rewritten to exclusively kwargs in positional order - if `normalize_to_only_use_kwargs` is true. Also populates default - values. Does not support positional-only parameters or varargs - parameters (*args, **kwargs). - - If the nodes have 'type' metadata, it will use it to disambiguate - overloads. Otherwise, it will throw an error. - - Example usage: - m = torchvision.models.resnet18() - traced = pippy.fx.symbolic_trace(m) - traced = NormalizeArgs(traced).transform() - """ - - def __init__( - self, module: pippy.fx.GraphModule, normalize_to_only_use_kwargs: bool = True - ): - super().__init__(module) - self.node_map: Dict[Proxy, Node] = {} - self.normalize_to_only_use_kwargs = normalize_to_only_use_kwargs - - def run_node(self, n: Node) -> Any: - args, kwargs = self.fetch_args_kwargs_from_env(n) - - def get_type(arg): - if isinstance(arg, fx.Node): - return n.meta["type"] if "type" in n.meta else None - return type(arg) - - arg_types = map_aggregate(n.args, get_type) - assert isinstance(arg_types, tuple) - arg_types = tuple([create_type_hint(i) for i in arg_types]) - kwarg_types = {k: get_type(v) for k, v in kwargs.items()} - if n.op == "call_function": - out = self.call_function(n.target, args, kwargs, arg_types, kwarg_types) - else: - out = super().run_node(n) - if n.op != "output": - self.node_map[out] = n - out.node.meta = n.meta - out.node.type = n.type - return out - - def call_function( - self, - target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Any], - arg_types: Optional[Tuple[Any, ...]] = None, - kwarg_types: Optional[Dict[str, Any]] = None, - ): - assert callable(target) - new_args_and_kwargs = normalize_function( - target, - args, # type: ignore[arg-type] - kwargs, - arg_types, # type: ignore[arg-type] - kwarg_types, - self.normalize_to_only_use_kwargs, - ) - if new_args_and_kwargs: - new_args, new_kwargs = new_args_and_kwargs - return self.tracer.create_proxy( - "call_function", target, new_args, new_kwargs - ) - else: - return super().call_function(target, args, kwargs) - - def call_module( - self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any] - ): - assert isinstance(target, str) - new_args_and_kwargs = normalize_module( - self.module, - target, - args, # type: ignore[arg-type] - kwargs, - self.normalize_to_only_use_kwargs, - ) - if new_args_and_kwargs: - new_args, new_kwargs = new_args_and_kwargs - return super().call_module(target, new_args, new_kwargs) - else: - return super().call_module(target, args, kwargs) - - -class NormalizeOperators(AnnotateTypesWithSchema): - """ - Normalize callsites that are different ways of "spelling" the same - invocation into a single, canonical call. Currently supports: - - 1. Normalize operators (e.g. operator.add) to the `torch` ops they - ultimately invoke (e.g. torch.add) when it is possible to statically - reason that - - Example usage: - - m = torchvision.models.resnet18() - - traced = pippy.fx.symbolic_trace(m) - - traced = NormalizeOperators(traced).transform() - """ - - binary_magic_method_remap: Dict[ - Callable[[Any, Any], Any], Callable[[Any, Any], Any] - ] = { - torch.add: operator.add, - torch.mul: operator.mul, - torch.sub: operator.sub, - torch.div: operator.truediv, - torch.floor_divide: operator.floordiv, - torch.remainder: operator.mod, - torch.eq: operator.eq, - torch.ne: operator.ne, - torch.lt: operator.lt, - torch.le: operator.le, - torch.gt: operator.gt, - torch.ge: operator.ge, - } - - def call_function( - self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any] - ): - # Normalize operators according to the magic methods implemented on tensors here: - # https://github.com/pytorch/pytorch/blob/28c5d90b679c6b38bf4183ec99f16d933c2f1bcd/tools/autograd/templates/python_variable_methods.cpp#L1137 # noqa: B950 - - assert callable(target) - - if target in self.binary_magic_method_remap: - if len(args) != 2: - return super().call_function(target, args, kwargs) - lhs, rhs = args - - return super().call_function( - target=self.binary_magic_method_remap[target], - args=(lhs, rhs), - kwargs={}, - ) - - return super().call_function(target, args, kwargs) diff --git a/pippy/fx/experimental/optimization.py b/pippy/fx/experimental/optimization.py deleted file mode 100644 index 2f9eb07d8..000000000 --- a/pippy/fx/experimental/optimization.py +++ /dev/null @@ -1,406 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import pippy.fx as fx -from pippy.fx.node import Argument, Target -from torch.nn.utils.fusion import fuse_conv_bn_eval -from typing import Type, Dict, Any, Tuple, Iterable, Optional, List, cast -import torch -import torch.nn as nn -import torch.nn.functional as F -from pippy.fx.passes.shape_prop import ShapeProp -import copy -from collections import defaultdict -import torch.utils.mkldnn as th_mkldnn -import operator -import time -import logging -from enum import Enum - -def _parent_name(target : str) -> Tuple[str, str]: - """ - Splits a qualname into parent path and last atom. - For example, `foo.bar.baz` -> (`foo.bar`, `baz`) - """ - *parent, name = target.rsplit('.', 1) - return parent[0] if parent else '', name - -# Works for length 2 patterns with 2 modules -def matches_module_pattern(pattern: Iterable[Type], node: fx.Node, modules: Dict[str, Any]): - if len(node.args) == 0: - return False - nodes: Tuple[Any, fx.Node] = (node.args[0], node) - for expected_type, current_node in zip(pattern, nodes): - if not isinstance(current_node, fx.Node): - return False - if current_node.op != 'call_module': - return False - if not isinstance(current_node.target, str): - return False - if current_node.target not in modules: - return False - if type(modules[current_node.target]) is not expected_type: - return False - return True - - -def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module): - assert(isinstance(node.target, str)) - parent_name, name = _parent_name(node.target) - modules[node.target] = new_module - setattr(modules[parent_name], name, new_module) - -def fuse(model: torch.nn.Module, inplace=False) -> torch.nn.Module: - """ - Fuses convolution/BN layers for inference purposes. Will deepcopy your - model by default, but can modify the model inplace as well. - """ - patterns = [(nn.Conv1d, nn.BatchNorm1d), - (nn.Conv2d, nn.BatchNorm2d), - (nn.Conv3d, nn.BatchNorm3d)] - if not inplace: - model = copy.deepcopy(model) - fx_model = fx.symbolic_trace(model) - modules = dict(fx_model.named_modules()) - new_graph = copy.deepcopy(fx_model.graph) - - for pattern in patterns: - for node in new_graph.nodes: - if matches_module_pattern(pattern, node, modules): - if len(node.args[0].users) > 1: # Output of conv is used by other nodes - continue - conv = modules[node.args[0].target] - bn = modules[node.target] - if not bn.track_running_stats: - continue - fused_conv = fuse_conv_bn_eval(conv, bn) - replace_node_module(node.args[0], modules, fused_conv) - node.replace_all_uses_with(node.args[0]) - new_graph.erase_node(node) - return fx.GraphModule(fx_model, new_graph) - -def remove_dropout(model: nn.Module) -> nn.Module: - """ - Removes all dropout layers from the module. - """ - fx_model = fx.symbolic_trace(model) - - class DropoutRemover(fx.Transformer): - def call_module(self, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: - if isinstance(self.submodules[target], nn.Dropout): - assert len(args) == 1 - return args[0] - else: - return super().call_module(target, args, kwargs) - return DropoutRemover(fx_model).transform() - -def extract_subgraph(orig_module: nn.Module, nodes: List[fx.Node], inputs: List[fx.Node], outputs: List[fx.Node]): - """ - Given lists of nodes from an existing graph that represent a subgraph, returns a submodule that executes that subgraph. - """ - new_graph = fx.Graph() - env: Dict[fx.Node, fx.Node] = {} - for input in inputs: - new_node = new_graph.placeholder(input.name) - env[input] = new_node - for node in nodes: - new_node = new_graph.node_copy(node, lambda x: env[x]) - env[node] = new_node - new_graph.output([env[output] for output in outputs]) - new_graph.lint() - return fx.GraphModule(orig_module, new_graph) - -mkldnn_supported = [ - nn.Conv2d, nn.Linear, nn.BatchNorm2d, nn.ReLU, nn.MaxPool2d, nn.AvgPool2d, nn.AdaptiveAvgPool2d, - torch.relu, torch.transpose, torch.sigmoid, - F.relu, F.avg_pool2d, F.adaptive_avg_pool2d -] -# These are operators that may not be convertible into MKLDNN ops (e.g. the -# args are scalar values). Thus, we only include them in the subgraph if their -# arguments are already in MKLDNN. -# TODO: Determine whether this can be removed after type inference. -mkldnn_supported_unknown = [operator.add, operator.mul] -mkldnn_map = { - nn.Conv2d: th_mkldnn.MkldnnConv2d, - nn.Linear: th_mkldnn.MkldnnLinear, - nn.BatchNorm2d: lambda a, _: th_mkldnn.MkldnnBatchNorm(a) -} - - -def modules_to_mkldnn(nodes: List[fx.Node], modules: Dict[str, nn.Module]): - """ - For each node, if it's a module that can be preconverted into MKLDNN, - then we do so and create a mapping to allow us to convert from the MKLDNN - version of the module to the original. - """ - old_modules: Dict[nn.Module, nn.Module] = {} - for node in nodes: - if node.op == 'call_module': - assert(isinstance(node.target, str)) - cur_module = modules[node.target] - if type(cur_module) in mkldnn_map: - new_module = mkldnn_map[type(cur_module)](cur_module, torch.float) - assert(isinstance(new_module, nn.Module)) - old_modules[new_module] = copy.deepcopy(cur_module) - replace_node_module(node, modules, new_module) - return old_modules - -def reset_modules(nodes: List[fx.Node], modules: Dict[str, nn.Module], old_modules: Dict[nn.Module, nn.Module]): - """ - Maps each module that's been changed with `modules_to_mkldnn` back to its - original. - """ - for node in nodes: - if node.op == 'call_module': - assert(isinstance(node.target, str)) - cur_module = modules[node.target] - if cur_module in old_modules: - replace_node_module(node, modules, old_modules[cur_module]) - -class MklSubgraph: - def __init__(self, fx_graph: fx.Graph): - self.fx_graph = fx_graph - self.nodes: List[fx.Node] = [] - self.start_nodes: List[fx.Node] = [] - self.end_nodes: List[fx.Node] = [] - -def gen_mkl_autotuner(example_inputs, iters=10, warmup=1): - """ - This generates a heuristic that can be passed into `optimize_for_inference` that - determines whether a subgraph should be run in MKL by running it with the example_inputs. - - Example usage: - heuristic = gen_mkl_autotuner(example_inputs, iters=10) - fast_model = optimization.optimize_for_inference(model, heuristic) - """ - fx_model = None - old_modules = None - - def use_mkl_heuristic(graph: MklSubgraph) -> bool: - nonlocal fx_model, old_modules - input_nodes = graph.start_nodes - if fx_model is None: - fx_model = graph.fx_graph.owning_module - old_modules = graph.fx_graph.old_modules # type: ignore[attr-defined] - ShapeProp(fx_model).propagate(example_inputs) - sample_inputs = [torch.randn(node.shape) for node in input_nodes] # type: ignore[attr-defined] - output_args = cast(List[fx.Node], [node.args[0] for node in graph.end_nodes]) - submodule = extract_subgraph(fx_model, graph.nodes, input_nodes, output_args) - - def benchmark(f): - for _ in range(warmup): - f() - begin = time.time() - for _ in range(iters): - out = f() - return time.time() - begin - - mkl_time = benchmark(lambda: [i.to_dense() for i in submodule(*[i.to_mkldnn() for i in sample_inputs])]) - - reset_modules(submodule.graph.nodes, dict(submodule.named_modules()), old_modules) - no_mkl_time = benchmark(lambda: submodule(*sample_inputs)) - return mkl_time < no_mkl_time - return use_mkl_heuristic - -def use_mkl_length(graph: MklSubgraph) -> bool: - """ - This is a heuristic that can be passed into `optimize_for_inference` that - determines whether a subgraph should be run in MKL by checking if there - are more than 2 nodes in it - """ - return len(graph.nodes) > 2 - -class UnionFind: - def __init__(self, n): - self.parent: List[Optional[int]] = [None] * n - self.size: List[int] = [0] * n - - def make_set(self, v: int): - self.parent[v] = v - self.size[v] = 1 - - def find(self, v: int) -> int: - par = self.parent[v] - if v == par: - return v - assert(par is not None) - self.parent[v] = self.find(par) - return cast(int, self.parent[v]) - - def join(self, a: int, b: int): - a, b = self.find(a), self.find(b) - if a == b: - return a - if self.size[a] < self.size[b]: - a, b = b, a - self.parent[b] = a - self.size[a] += self.size[b] - -def optimize_for_inference( - model: torch.nn.Module, - pass_config: Optional[Dict[str, Any]] = None, - tracer: Type[fx.Tracer] = fx.Tracer -) -> torch.nn.Module: - """ - Performs a set of optimization passes to optimize a model for the - purposes of inference. Specifically, the passes that are run are: - 1. Conv/BN fusion - 2. Dropout removal - 3. MKL layout optimizations - - The third optimization takes a function `use_mkl_heuristic` that's used - to determine whether a subgraph should be explicity run in MKL layout. - - Note: As FX does not currently handle aliasing, this pass currently - assumes nothing aliases. If that isn't true, use at your own risk. - """ - default_pass_config = { - "conv_bn_fuse": True, - "remove_dropout": True, - "mkldnn_layout_optimize": {'heuristic': use_mkl_length}, - } - if pass_config is None: - pass_config = {} - default_pass_config.update(pass_config) - - if default_pass_config["conv_bn_fuse"]: - model = fuse(model) - if default_pass_config["remove_dropout"]: - model = remove_dropout(model) - if default_pass_config["mkldnn_layout_optimize"] is False: - return model - if not isinstance(default_pass_config["mkldnn_layout_optimize"], dict): - raise RuntimeError("mkldnn_layout_optimize config is not a dict") - if "heuristic" not in default_pass_config["mkldnn_layout_optimize"]: - raise RuntimeError("Heuristic not found in mkldnn_layout_optimize config") - use_mkl_heuristic = default_pass_config["mkldnn_layout_optimize"]["heuristic"] - - cur_tracer = tracer() - fx_graph = cur_tracer.trace(copy.deepcopy(model)) - fx_model = fx.GraphModule(cur_tracer.root, fx_graph) - modules: Dict[str, nn.Module] = dict(model.named_modules()) - - class MklSupport(Enum): - NO = 1 - YES = 2 - UNKNOWN = 3 - - # Inserts to_mkldnn and to_dense around every node we want to be a MKLDNN node. - # If the op is in `mkldnn_supported` then we always treat it as a MKLDNN node. - # However, if it's in `mkldnn_supported_unknown`, then we only treat it as - # a MKLDNN node if its inputs are MKLDNN nodes. - for node in list(fx_graph.nodes): - supports_mkldnn = MklSupport.NO - if node.op == 'call_module': - cur_module = modules[node.target] - if type(cur_module) in mkldnn_supported: - supports_mkldnn = MklSupport.YES - sample_parameter = next(cur_module.parameters(), None) - if sample_parameter is not None: - assert(sample_parameter.dtype == torch.float), "this pass is only for torch.float modules" - assert(sample_parameter.device == torch.device('cpu')), "this pass is only for CPU modules" - elif node.op == 'call_function': - if node.target in mkldnn_supported: - supports_mkldnn = MklSupport.YES - elif node.target in mkldnn_supported_unknown: - supports_mkldnn = MklSupport.UNKNOWN - - if supports_mkldnn != MklSupport.NO: - if supports_mkldnn == MklSupport.UNKNOWN: - if not any([arg.target == 'to_dense' for arg in node.args]): - continue - with fx_graph.inserting_before(node): - mkldnn_args = fx.map_arg(node.args, lambda n: fx_graph.call_method('to_mkldnn', (n, ))) - - node.args = cast(Tuple[fx.node.Argument], mkldnn_args) - - with fx_graph.inserting_after(node): - dense_x = fx_graph.create_node('call_method', 'to_dense', (node,)) - node.replace_all_uses_with(dense_x) - dense_x.args = (node,) - - # Does pre-conversion of all modules into MKLDNN (when possible) - old_modules = modules_to_mkldnn(list(fx_graph.nodes), modules) - fx_graph.old_modules = old_modules # type: ignore[attr-defined] - - # optimizes all a -> to_dense -> to_mkldnn -> b patterns into a -> b - for node in fx_graph.nodes: - if node.op == 'call_method' and node.target == 'to_dense': - prv_node = node.args[0] - users = list(node.users) - for user in users: - if user.op == 'call_method' and user.target == 'to_mkldnn': - user.replace_all_uses_with(prv_node) - fx_graph.erase_node(user) - if len(node.users) == 0: - fx_graph.erase_node(node) - - - num_nodes = len(fx_graph.nodes) - uf = UnionFind(num_nodes) - - def get_color(n): - if hasattr(n, 'color'): # Current node is part of a MKL subgraph - return uf.find(n.color) - if hasattr(n, 'start_color'): # Current node is input to MKL subgraph - return uf.find(n.start_color) - return None - - - # This code is to find each MKLDNN subgraph. Each MKLDNN subgraph consists - # of input nodes (which are only `to_mkldnn` calls), output nodes - # (`to_dense` calls), and intermediate nodes, which are run entirely on - # MKLDNN layout tensors. - # - # Specifically, this code does a flood fill on a directed acyclic graph - # (DAG), starting from each possible "start node" (i.e: `to_mkldnn` nodes). - # If every node only had one input, this would be sufficient. However, in - # the case that a node has multiple inputs coming from different start - # nodes (i.e. colors), we need to join these 2 colors into 1. That's done - # using a Disjoint Set Union. - for cur_idx, node in enumerate(fx_graph.nodes): - if node.op == 'call_method' and node.target == 'to_mkldnn': - node.start_color = cur_idx - uf.make_set(cur_idx) - elif node.op == 'call_method' and node.target == 'to_dense': - assert(get_color(node.args[0]) is not None) - node.end_color = get_color(node.args[0]) - else: - cur_colors = [get_color(i) for i in node.all_input_nodes if isinstance(i, fx.Node) if get_color(i) is not None] - - if len(cur_colors) == 0: - continue - assert(not any(i is None for i in cur_colors)) - cur_colors = sorted(cur_colors) - node.color = cur_colors[0] - for other_color in cur_colors[1:]: - uf.join(cur_colors[0], other_color) - - - mkldnn_graphs: Dict[int, MklSubgraph] = defaultdict(lambda: MklSubgraph(fx_graph)) - for node in fx_graph.nodes: - if hasattr(node, 'color'): - mkldnn_graphs[uf.find(node.color)].nodes.append(node) - if hasattr(node, 'start_color'): - mkldnn_graphs[uf.find(node.start_color)].start_nodes.append(node) - if hasattr(node, 'end_color'): - mkldnn_graphs[uf.find(node.end_color)].end_nodes.append(node) - - - # Now that we have all the subgraphs, we need to decide which MKLDNN - # subgraphs we actually want to keep in MKLDNN. - for graph in mkldnn_graphs.values(): - if not use_mkl_heuristic(graph): - for node in graph.start_nodes + graph.end_nodes: - prv = node.args[0] - node.replace_all_uses_with(prv) - fx_graph.erase_node(node) - reset_modules(graph.nodes, modules, old_modules) - - mkldnn_conversions = 0 - for node in fx_graph.nodes: - if node.target == 'to_mkldnn' or node.target == 'to_dense': - mkldnn_conversions += 1 - - logging.getLogger(__name__).info(f"mkldnn conversions: {mkldnn_conversions}") - fx_graph.lint() - result = fx.GraphModule(model, fx_graph) - return result diff --git a/pippy/fx/experimental/partitioner_utils.py b/pippy/fx/experimental/partitioner_utils.py deleted file mode 100644 index 334ef5d94..000000000 --- a/pippy/fx/experimental/partitioner_utils.py +++ /dev/null @@ -1,320 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from enum import Enum -from typing import NamedTuple, Dict, List, Set - -from pippy.fx.node import Node, map_arg - - -class Partition: - """Partition class contains all the information about an individual partition. - It also provides necessary methods for manipulation the partition. - """ - - def __init__(self, partition_id: int) -> None: - self.nodes: Set[Node] = set() - self.partition_id = partition_id - self.parents: Set["Partition"] = set() - self.children: Set["Partition"] = set() - self.bfs_level: int = -1 - self.used_mem_bytes: int = 0 - self.logical_device_ids: List[int] = [] - - def __str__(self): - return str(self.partition_id) - - def recalculate_mem_size(self): - self.used_mem_bytes = 0 - for node in self.nodes: - self.used_mem_bytes += get_extra_size_of(node, self.nodes) - - def add_node(self, node): - input_nodes: Dict[Node, None] = {} - map_arg(node.args, lambda n: input_nodes.setdefault(n)) - map_arg(node.kwargs, lambda n: input_nodes.setdefault(n)) - # Add current node's input nodes if they are placeholder or constants - for n in input_nodes: - if n.op in {"placeholder", "get_attr"}: - self.nodes.add(n) - self.nodes.add(node) - self.recalculate_mem_size() - - def remove_node(self, node): - # Remove a node only if the node is in the partition - if node in self.nodes: - self.nodes.remove(node) - # Collect the node's input nodes - input_nodes: Dict[Node, None] = {} - map_arg(node.args, lambda n: input_nodes.setdefault(n)) - map_arg(node.kwargs, lambda n: input_nodes.setdefault(n)) - # Check if an input node is a placeholder or get_attr, - # and this input node is not used by some other nodes in this partition, - # the remove this input node - for input_node in input_nodes: - if all( - [n not in self.nodes for n in input_node.users] - ) and input_node.op in {"placeholder", "get_attr"}: - self.nodes.remove(input_node) - self.recalculate_mem_size() - - -class Device(NamedTuple): - name: str - available_mem_bytes: int - logical_id: int - - -class NodeLatency(NamedTuple): - # Latency due to the memory bandwidth - mem_latency_sec: float - # Latency due to the computation - computer_latency_sec: float - - -class PartitionLatency(NamedTuple): - # Sum of all nodes' memory latency on the critical path - mem_latency_sec: float - # Sum of all nodes' compute latency on the critical path - computer_latency_sec: float - # Latency of the critical path - overall_latency_sec: float - - -class PartitionMode(Enum): - size_based = 0 - sparse_nn = 1 - cost_aware = 2 - kl_based = 3 - aot_based = 4 - - -class PartitionerConfig(NamedTuple): - devices: List[Device] - mode: PartitionMode = PartitionMode.size_based - transfer_rate_bytes_per_sec: float = 0.0 - node_to_latency_mapping: Dict[Node, NodeLatency] = {} - node_to_partition_mapping: Dict[Node, int] = {} - partition_to_logical_device_mapping: Dict[int, List[int]] = {} - # Saturate host by replicating partitions to the remaining idle devices. - saturate_host: bool = False - - -def get_extra_size_of(node: Node, nodes: Set[Node]) -> int: - """Given a node and a set of nodes, - this function return the extra size that needed - if this node is included in this set. - """ - # Find all its input nodes - input_nodes: Dict[Node, None] = {} - map_arg(node.args, lambda n: input_nodes.setdefault(n)) - map_arg(node.kwargs, lambda n: input_nodes.setdefault(n)) - # Calculate total size of related nodes - total_size_of_input_nodes = 0 - for n in input_nodes: - # Make sure this node hasn't been in this set yet - if n not in nodes: - size_bytes = getattr(n, "size_bytes", None) - if size_bytes: - total_size_of_input_nodes += size_bytes.output_size - else: - raise RuntimeError("node has no size_bytes attr") - # Don't forget the op node itself - size_bytes = getattr(node, "size_bytes", None) - if size_bytes: - total_size_of_input_nodes += size_bytes.total_size - else: - raise RuntimeError("node has no size_bytes attr") - return total_size_of_input_nodes - - -def get_latency_of_one_partition( - partition: Partition, node_to_latency_mapping: Dict[Node, NodeLatency] -) -> PartitionLatency: - """Given a partiton and its nodes' latency, return a PartitionLatency for this partition""" - - def get_top_nodes(partition: Partition) -> List[Node]: - """Given a partition, return a list of nodes on the top bfs level""" - top_nodes: List[Node] = [] - for node in partition.nodes: - # Skip placeholder and get_attr nodes - if node.op in {"placeholder", "get_attr"}: - continue - input_nodes: Dict[Node, None] = {} - map_arg(node.args, lambda n: input_nodes.setdefault(n)) - map_arg(node.kwargs, lambda n: input_nodes.setdefault(n)) - # If a node has no input nodes in this partition, - # or its input nodes in this partition are placeholders and get_attrs - # this node is on the top bfs level in this partition - if not any( - [ - n in partition.nodes and n.op not in {"placeholder", "get_attr"} - for n in input_nodes - ] - ): - top_nodes.append(node) - return top_nodes - - def dfs_helper(node: Node, partition_latency) -> PartitionLatency: - """Given a top node of a partition, this function returns - the latency of the critical path in the partition - """ - node_latency = node_to_latency_mapping[node] - # Calculate the current overall latency of the partition - overall_latency_sec = partition_latency.overall_latency_sec + max( - node_latency.computer_latency_sec, node_latency.mem_latency_sec - ) - # Update the mem latency of this path - mem_latency_sec = ( - partition_latency.mem_latency_sec + node_latency.mem_latency_sec - ) - # Update the compute latency of this path - computer_latency_sec = ( - partition_latency.computer_latency_sec + node_latency.computer_latency_sec - ) - # Get all users of this node that are in this partition - users = set(node.users).intersection(partition.nodes) - if users: - max_latency = PartitionLatency( - mem_latency_sec=0.0, computer_latency_sec=0.0, overall_latency_sec=0.0 - ) - for n in users: - # Get new partition latency recursively - new_partition_latency = dfs_helper( - n, - PartitionLatency( - mem_latency_sec, computer_latency_sec, overall_latency_sec - ), - ) - if ( - new_partition_latency.overall_latency_sec - > max_latency.overall_latency_sec - ): - max_latency = new_partition_latency - return max_latency - # If there is no user, the node is at bottom of the partition - return PartitionLatency( - mem_latency_sec, computer_latency_sec, overall_latency_sec - ) - - # Main part starts - # Get all top level nodes of this partition - top_nodes = get_top_nodes(partition) - critical_path_latency = PartitionLatency( - mem_latency_sec=0.0, computer_latency_sec=0.0, overall_latency_sec=0.0 - ) - # Go through all top nodes and find the largest latency (critical pass latency) - for node in top_nodes: - partition_latency = dfs_helper( - node, - PartitionLatency( - mem_latency_sec=0.0, computer_latency_sec=0.0, overall_latency_sec=0.0 - ), - ) - if ( - partition_latency.overall_latency_sec - > critical_path_latency.overall_latency_sec - ): - critical_path_latency = partition_latency - return critical_path_latency - - -def get_partition_to_latency_mapping( - partitions: List[Partition], node_to_latency_mapping: Dict[Node, NodeLatency] -) -> Dict[Partition, PartitionLatency]: - """Given all the partitions and node_to_latency_mapping dictionary, - return a mapping dictionary of each partition to its overall latency - """ - partition_to_latency_mapping: Dict[Partition, PartitionLatency] = {} - # Go through each partition and get its latency - for partition in partitions: - partition_latency = get_latency_of_one_partition( - partition, node_to_latency_mapping - ) - partition_to_latency_mapping[partition] = partition_latency - return partition_to_latency_mapping - - -def get_comm_latency_between( - parent_partition: Partition, - child_partition: Partition, - transfer_rate_bytes_per_sec: float, -): - """Given two partitions (parent and child), - calculate the communication latency between the two. - """ - # If two partitions are on the same device, the comm latency is 0. - if ( - parent_partition.logical_device_ids != [] - and child_partition.logical_device_ids != [] - and parent_partition.logical_device_ids == child_partition.logical_device_ids - ): - return 0.0 - # Keep tracking the communication size between parent and child - comm_size = 0 - # Keep tracking all the counted node - visited_nodes = set() - # Go through all nodes in the child partition - # If a node has input nodes from the parent partition, - # the output size of those input nodes will be counted - # and added to comm_size - for node in child_partition.nodes: - input_nodes: Dict[Node, None] = {} - map_arg(node.args, lambda n: input_nodes.setdefault(n)) - map_arg(node.kwargs, lambda n: input_nodes.setdefault(n)) - for n in input_nodes: - if n in parent_partition.nodes and n not in visited_nodes: - size_bytes = getattr(n, "size_bytes", None) - if size_bytes is not None: - comm_size += size_bytes.output_size - visited_nodes.add(n) - return comm_size / transfer_rate_bytes_per_sec - - -def get_latency_of_partitioned_graph( - partitions: List[Partition], - partition_to_latency_mapping: Dict[Partition, PartitionLatency], - transfer_rate_bytes_per_sec: float, -): - """Given all paritions in a graph, find the critical path among all partitions - and return its latency as the latency of the whole graph - """ - - def dfs_helper(partition: Partition, latency_so_far_sec: float) -> float: - """This function helps to recursively get the latency of a path of partitions""" - # Update latency by adding current partition's latency - latency_so_far_sec += partition_to_latency_mapping[ - partition - ].overall_latency_sec - children = partition.children - if partition.children: - max_latency_sec = 0.0 - for child in partition.children: - # Calculate latency between - comm_latency_sec = get_comm_latency_between( - partition, child, transfer_rate_bytes_per_sec - ) - new_latency_sec = dfs_helper( - child, latency_so_far_sec + comm_latency_sec - ) - if new_latency_sec > max_latency_sec: - max_latency_sec = new_latency_sec - return max_latency_sec - return latency_so_far_sec - - def get_top_partitions(partitions: List[Partition]) -> List[Partition]: - """This function is to return all the partitions without parents - as the starting points of all the paths - """ - top_partitions = [] - for partition in partitions: - # If a partition has no parents, then it is a top partition - if len(partition.parents) == 0: - top_partitions.append(partition) - return top_partitions - - top_partitions = get_top_partitions(partitions) - critical_path_latency_sec = 0.0 - for partition in top_partitions: - latency_sec = dfs_helper(partition, 0.0) - if latency_sec > critical_path_latency_sec: - critical_path_latency_sec = latency_sec - return critical_path_latency_sec diff --git a/pippy/fx/experimental/proxy_tensor.py b/pippy/fx/experimental/proxy_tensor.py deleted file mode 100644 index f223d1290..000000000 --- a/pippy/fx/experimental/proxy_tensor.py +++ /dev/null @@ -1,683 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. -import contextlib -import functools -import inspect -import operator -import weakref -from contextlib import contextmanager, nullcontext -from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Tuple, Union - -import torch -import torch.utils._pytree as pytree -from torch._dispatch.python import enable_python_dispatcher -from torch._subclasses import FakeTensor -from torch._subclasses.fake_tensor import FakeTensorMode -from torch.utils._python_dispatch import TorchDispatchMode, _pop_mode_temporarily, _get_current_dispatch_mode - -import pippy -import pippy.fx as fx -from pippy.fx import Proxy -from pippy.fx import Tracer, GraphModule -from pippy.fx.passes.shape_prop import _extract_tensor_metadata -from .symbolic_shapes import ShapeEnv, SymDispatchMode, PySymInt, PySymFloat - -__all__ = ["PythonKeyTracer", "dispatch_trace", "make_fx", "DecompositionInterpreter", "get_proxy", "has_proxy"] -aten = torch.ops.aten -prim = torch.ops.prim - -CURRENT_DECOMPOSITION_TABLE: Dict[torch._ops.OpOverload, Callable] = {} - -CONSTANT_NUMEL_LIMIT = 1 - - -def fake_signature(fn, nargs): - """FX gets confused by varargs, de-confuse it""" - argnames = ",".join(f"arg{i}" for i in range(nargs)) - return eval(f"lambda {argnames}: fn({argnames})", {"fn": fn}) - -@contextmanager -def decompose(decomposition_table): - global CURRENT_DECOMPOSITION_TABLE - old_decomposition_table = CURRENT_DECOMPOSITION_TABLE - CURRENT_DECOMPOSITION_TABLE = decomposition_table - try: - yield CURRENT_DECOMPOSITION_TABLE - finally: - CURRENT_DECOMPOSITION_TABLE = old_decomposition_table - -# ensure we cannot collide with other properties -proxy_slot = object() -no_default = object() - -def set_proxy_slot(obj, tracer, proxy): - d = obj.__dict__.setdefault(proxy_slot, weakref.WeakKeyDictionary()) - assert isinstance(d, weakref.WeakKeyDictionary) - d[tracer] = proxy - -def has_proxy_slot(obj, tracer): - return get_proxy_slot(obj, tracer, False, lambda _: True) - -# the default argument is what to return if the slot is not set. -# the transform argument is handy if you need to extract a subfield from -# the successfully looked up result (but NOT the default.) -def get_proxy_slot(obj, tracer, default=no_default, transform=lambda x: x): - d = obj.__dict__.get(proxy_slot) - if not d: - if default is no_default: - raise KeyError(f"{obj} is not tracked with proxy for {tracer}") - return default - assert isinstance(d, weakref.WeakKeyDictionary) - if tracer not in d: - if default is no_default: - raise KeyError(f"{obj} is not tracked with proxy for {tracer}") - else: - return default - return transform(d[tracer]) - - -def get_proxy_slots(obj): - return obj.__dict__.get(proxy_slot) - - -# Gets the proxy for a tensor, if it exists. -def get_proxy(obj): - res = get_proxy_slots(obj) - if res is None: - return None - vals = tuple(res.values()) - assert len(vals) == 1 - return vals[0] - -def has_proxy(obj): - return get_proxy(obj) is not None - -def set_meta(proxy, val): - if isinstance(val, FakeTensor): - proxy.node.meta['val'] = val - proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(val) - elif isinstance(val, PySymInt): - proxy.node.meta['val'] = val - elif isinstance(val, torch.Tensor): - if not val.is_sparse: - proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(val) - return proxy - -def thunkify(f, *args, **kwargs): - """ - Delays computation of f until it's called again - Also caches the result - """ - return functools.lru_cache(1)(functools.partial(f, *args, **kwargs)) - -def track_tensor(tensor, proxy, *, constant, tracer): - def try_set_proxy_slot(outer_s, proxy_callable, *args): - assert callable(proxy_callable) - if isinstance(outer_s, SymInt): - inner_s = outer_s.get_pyobj() - assert isinstance(inner_s, PySymInt) - - set_proxy_slot(inner_s, tracer, thunkify(proxy_callable, inner_s, *args)) - - # The basic idea is that we need to associate each tensor/SymInt - # with a Proxy. How do we setup this association? We just store - # the proxy on the proxy slot of the object, keyed on the tracer - # (so that if we have multiple tracers at the same time, they - # don't clobber each other.) - for i, s in enumerate(tensor.shape): - try_set_proxy_slot(s, lambda x, i: set_meta(torch.ops.aten.sym_size(proxy, i), x), i) - - for i, s in enumerate(tensor.stride()): - try_set_proxy_slot(s, lambda x, i: set_meta(torch.ops.aten.sym_stride(proxy, i), x), i) - - try_set_proxy_slot(tensor.numel(), lambda x: set_meta(torch.ops.aten.sym_numel(proxy), x)) - try_set_proxy_slot(tensor.storage_offset(), lambda x: set_meta(torch.ops.aten.sym_storage_offset(proxy), x)) - set_proxy_slot(tensor, tracer, _ProxyTensor(proxy, constant)) - -def track_tensor_tree(inner_res, proxy_res, *, constant, tracer): - def wrap_with_proxy(e, proxy, constant): - if isinstance(e, torch.Tensor): - track_tensor(e, proxy, tracer=tracer, constant=constant) - set_meta(proxy, e) - elif isinstance(e, list): - # example use case: allreduce_ returns ([tensor], work) - for idx, ee in enumerate(e): - wrap_with_proxy(ee, proxy[idx], get_constant(idx)) - - def get_constant(idx): - if constant is None: - return None - else: - return constant[idx] - - # Unfortunately, tree_map cannot directly be used here. As the resulting - # object may be a proxy that represents a tuple, we may need to - # explicitly unwrap the proxy by simulating the flattening operations. - if isinstance(inner_res, tuple) or isinstance(inner_res, list): - for idx, e in enumerate(inner_res): - wrap_with_proxy(e, proxy_res[idx], get_constant(idx)) - elif isinstance(inner_res, torch.Tensor): - wrap_with_proxy(inner_res, proxy_res, constant) - - return inner_res - - -def maybe_disable_fake_tensor_mode(): - # TODO: figure out if this API generally makes sense and bake it into the - # library - mb_fake_mode = _get_current_dispatch_mode() - if isinstance(mb_fake_mode, FakeTensorMode): - return _pop_mode_temporarily() - else: - return nullcontext() - - -@dataclass -class _ProxyTensor: - proxy: Proxy - constant: Optional[torch.Tensor] - - -def fetch_sym_proxy(tracer): - def inner(e): - n = e.get_pyobj() - if n.constant is not None: - return n.constant - else: - # NB: we REQUIRE all symints to be tracked - return get_proxy_slot(n, tracer)() - return inner - - -def fetch_tensor_proxy(tracer): - return lambda t: get_proxy_slot(t, tracer, t) - -HANDLED_TYPES = (torch.Tensor, torch.nn.Parameter) - -def proxy_call(proxy_mode, func, args, kwargs): - def can_handle_tensor(x): - return type(x) in HANDLED_TYPES or has_proxy_slot(x, proxy_mode.tracer) - - # If there are any tensor subclasses, we need to handle those tensor subclasses first - # TODO: we could use types to test this - if not pytree.tree_all_only(torch.Tensor, can_handle_tensor, (args, kwargs)): - return NotImplemented - - if func in CURRENT_DECOMPOSITION_TABLE: - with proxy_mode: - r = CURRENT_DECOMPOSITION_TABLE[func](*args, **kwargs) - if r is not NotImplemented: - return r - - with proxy_mode: - r = func.decompose(*args, **kwargs) - if r is not NotImplemented: - return r - - tracer = proxy_mode.tracer - f_args, f_kwargs = pytree.tree_map_only(torch.Tensor, fetch_tensor_proxy(tracer), (args, kwargs)) - - # If there are SymInts, we also should not consider this constant. - # However, fake tensor handling of SymInts is sufficiently broken that - # I couldn't write a test for this case - all_constant = ( - pytree.tree_all_only(_ProxyTensor, lambda t: t.constant is not None, (f_args, f_kwargs)) - # TODO: maybe constant SymInts should also be allowed? Not sure if - # this can happen - and pytree.tree_all_only((SymInt, SymFloat), lambda _: False, (args, kwargs)) - ) - - if torch.Tag.data_dependent_output in func.tags: # type: ignore[attr-defined] - # Check if all of the Tensor inputs are constants - if all_constant: - const_args, const_kwargs = pytree.tree_map_only( - _ProxyTensor, lambda t: t.constant, (f_args, f_kwargs) - ) - with maybe_disable_fake_tensor_mode(): - return func(*const_args, **const_kwargs) - raise RuntimeError( - "It appears that you're trying to get value out of a tracing tensor - erroring out! " - "It's likely that this is caused by data-dependent control flow or similar." - ) - proxy_args, proxy_kwargs = pytree.tree_map_only( - (SymInt, SymFloat), - fetch_sym_proxy(proxy_mode.tracer), - pytree.tree_map_only(_ProxyTensor, lambda e: e.proxy, (f_args, f_kwargs)) - ) - - # When we trace through a torch.tensor invocation, you never actually - # see a torch.ops.aten.tensor call. Instead, the way this function is - # implemented internally is that we allocate a plain tensor (this is - # *guaranteed* to be a plain tensor, we disable all modes when doing - # so), and then call at::lift_fresh on it (to give modes a chance to do - # their stuff). Furthermore, the tensor argument to lift_fresh is guaranteed - # to be freshly allocated, so we want lift_fresh to be a no-op (directly - # returning the input argument). - # - # Here is the basic problem: when we trace this sequence of executions - # into an FX graph, what happens to this call sequence? Traditionally, - # tensor constants get interned as buffers on the FX GraphModule. But - # this is dangerous. Consider: - # - # x = torch.tensor(1) - # x.add_(2) - # - # Naively, this traces into: - # - # t = self._tensor_constant0 # initialized to torch.tensor(1) - # x = torch.ops.aten.lift_fresh(t) - # x.add_(2) - # - # If lift_fresh returns t directly, the subsequent add_ call will - # modify the tensor constant. Really, the problem is we've violated - # the invariant the the argument to lift is fresh. So what we should - # preserve the invariant by replacing lift_fresh with lift_fresh_copy: - # - # t = self._tensor_constant0 # initialized to torch.tensor(1) - # x = torch.ops.aten.lift_fresh_copy(t) - # x.add_(2) - # - # This is what the overload modification does. - if func is torch.ops.aten.lift_fresh.default: - func = torch.ops.aten.lift_fresh_copy.default - - proxy_out = proxy_mode.tracer.create_proxy('call_function', func, proxy_args, proxy_kwargs, - name=proxy_mode.tracer.graph._target_to_str(func.overloadpacket.__name__)) - - # This makes DCE marginally less likely to DCE inplace operations. - # It is not strictly necessary - # Kind of a hacky way to test if an op is in-place or not - if func.overloadpacket.__name__[-1] == "_" and func.overloadpacket.__name__[0] != "_": - if isinstance(args[0], List): - # e.g., c10d::allreduce_ returns a list of tensors as the first element - # in the output. - for i, a in enumerate(args[0]): - a.proxy = proxy_out[0][i] - else: - args[0].proxy = proxy_out - - out = func(*args, **kwargs) - - # In some circumstances, we will be tracing in a situation where a tensor - # is *statically* known to be a constant (currently, this only happens if - # you run torch.tensor; deterministic factory functions like torch.arange - # don't get this treatment). When the tensor in question is small, it's - # helpful to due constant propagation in case we call item() (in which - # case we can return the constant value that is known, rather than give - # an error.) The logic here tests if constant propagation is possible - # (because all of the inputs are constant). If so, we disable fake tensor - # mode (if it is on) and do true compute on the constant. - # - # It's worth highlighting that we're making a policy decision here. - # There is a potential that the tensor is actually quite large, and we - # don't actually want to run the compute. The tensor being quite large - # is one of the reasons why factory functions don't get this treatment - # (since they can be quite large; if a parameter is initialized to a - # constant value it will be!) Similarly, there is also a potential - # to run an operator that blows up the size of a small tensor; we don't - # protect against this case, but we could force, e.g., only single - # element constant computation by testing the numel of the result before - # propagating const-ness. Similarly, we don't require the constant to - # live on CPU, but we could. - any_constant = pytree.tree_any_only(_ProxyTensor, lambda t: t.constant is not None, (f_args, f_kwargs)) - - constant = None - - # If this is a lift, the input tensor is guaranteed to be a - # constant, so we keep a copy of the original argument along so - # we can query it if we're asked to item() it at some later point - if func is torch.ops.aten.lift_fresh_copy.default and out.numel() <= CONSTANT_NUMEL_LIMIT: - with maybe_disable_fake_tensor_mode(): - constant = args[0].clone() - elif ( - torch.Tag.nondeterministic_seeded not in func.tags # type: ignore[attr-defined] - and all_constant - and any_constant - and pytree.tree_all_only(torch.Tensor, lambda t: t.numel() <= CONSTANT_NUMEL_LIMIT, out) - ): - # NB: do NOT include factories as constants - with maybe_disable_fake_tensor_mode(): - const_args, const_kwargs = pytree.tree_map_only( - _ProxyTensor, lambda t: t.constant, (f_args, f_kwargs) - ) - constant = func(*const_args, **const_kwargs) - else: - constant = None - - track_tensor_tree(out, proxy_out, constant=constant, tracer=tracer) - return out - - -class PythonKeyTracer(Tracer): - def __init__(self): - super().__init__() - - # In general, we don't want to make modules leaves. In principle, users of - # this tracer might want to override this in order to turn a couple specific - # modules into leaves in the traced graph. - def call_module( - self, m: torch.nn.Module, forward: Callable[..., Any], args: Tuple[Any, ...], kwargs: Dict[str, Any] - ) -> Any: - return forward(*args, **kwargs) - - # We don't want to turn getattr calls into proxies. So we just return the actual value. - def getattr(self, attr, attr_val, parameter_proxy_cache): - return attr_val - - def create_arg(self, a: Any): - if isinstance(a, torch.nn.Parameter): - for n, p in self.root.named_parameters(): - if a is p: - return self.create_node('get_attr', n, (), {}) - qualname: Optional[str] = None - - if not qualname: - i = 0 - while True: - qualname = f'_param_constant{i}' - if not hasattr(self.root, qualname): - break - i += 1 - setattr(self.root, qualname, a) - - return self.create_node('get_attr', qualname, (), {}) - elif isinstance(a, (SymInt, SymFloat)): - assert a.get_pyobj().constant is not None - return a.get_pyobj().constant - return super().create_arg(a) - - -def dispatch_trace( - root: Union[torch.nn.Module, Callable], - tracer: Tracer, - concrete_args: Optional[Tuple[Any, ...]] = None, -) -> GraphModule: - graph = tracer.trace(root, concrete_args) - name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__ - return GraphModule(tracer.root, graph, name) - - -def wrap_key(f, tensors, tracer): - flat_tensors, tensors_spec = pytree.tree_flatten(tensors) - - @functools.wraps(f) - def wrapped(*proxies): - flat_proxies, proxies_spec = pytree.tree_flatten(proxies) - assert len(flat_proxies) == len(flat_tensors) - track_tensor_tree(flat_tensors, flat_proxies, constant=None, tracer=tracer) - - out = f(*tensors) - return pytree.tree_map_only( - torch.Tensor, - lambda t: get_proxy_slot(t, tracer, t, lambda x: x.proxy), - out - ) - - return wrapped - - -class ProxyTorchDispatchMode(TorchDispatchMode): - def __init__(self, tracer): - self.tracer = tracer - self.enable_tracing = True - self.sym_mode = ProxySymDispatchMode(tracer) - self.trace_state = {} - self._managers = [] - - def __torch_dispatch__(self, func, types, args=(), kwargs=None): - with self.sym_mode.enable(False): - return self.inner_torch_dispatch(func, types, args, kwargs) - - def __enter__(self): - # sym mode first, then us... - m = self.sym_mode.enable(True) - self._managers.append(m) - m.__enter__() - return super().__enter__() - - def __exit__(self, exc_type, exc_value, traceback): - m = self._managers.pop() - # ...exit us first, then sym mode - b = super().__exit__(exc_type, exc_value, traceback) - if not b: - return m.__exit__(exc_type, exc_value, traceback) - else: - return m.__exit__(None, None, None) - - def inner_torch_dispatch(self, func, types, args=(), kwargs=None): - if not self.enable_tracing: - return func(*args, **kwargs) - - if func in [prim.device.default]: - return func(*args, **kwargs) - - out = proxy_call(self, func, args, kwargs) - return out - - -SymInt = torch.SymIntNode -SymFloat = torch.SymFloatNode - - -class ProxySymDispatchMode(SymDispatchMode): - def __init__(self, tracer): - super().__init__() - self.tracer = tracer - # When false, we don't trace operations. If you do this, you MUST - # call track_tensor/track_tensor_tree on all results of the operation - # to ensure we can adeduately track the results - self.enable_tracing = True - - @contextmanager - def enable(self, b): - old = self.enable_tracing - self.enable_tracing = b - try: - yield - finally: - self.enable_tracing = old - - def _compute_proxy(self, func, args, out): - n_args = tuple( - get_proxy_slot(a, self.tracer)().node if a.constant is None else a.constant - if isinstance(a, (PySymInt, PySymFloat)) else a - for a in args - ) - - # func doesn't have a __torch_function__ that Proxy can interpose, so - # we gotta do it manually - n_out = self.tracer.create_node("call_function", func, n_args, {}) - p_out = fx.Proxy(n_out, self.tracer) - set_meta(p_out, out) - return p_out - - def __sym_dispatch__(self, func, types, args, kwargs): - if not self.enable_tracing: - return func(*args, **kwargs) - - # Peephole optimize multiply by one - if func == operator.mul: - if isinstance(args[1], PySymInt) and args[1].constant == 1: - return args[0] - elif isinstance(args[0], PySymInt) and args[0].constant == 1: - return args[1] - - # For speed, we assume there are no nested data structures - # (otherwise we could use tree_map) - # We also assume there are no keyword arguments. - assert not kwargs - out = func(*args, **kwargs) - assert isinstance(out, (PySymInt, PySymFloat)), f"{func}(*{args}, **{kwargs}) = {out}" - - # Delays tracing out the proxies on this op until we actually need it - p_out_thunk = thunkify(self._compute_proxy, func=func, args=args, out=out) - set_proxy_slot(out, self.tracer, p_out_thunk) - return out - - -# TODO: I'm not sure what the point of this class is; you can just -# make_fx through a regular Interpreter -class DecompositionInterpreter(pippy.fx.Interpreter): - def __init__(self, module: pippy.fx.GraphModule, new_graph: pippy.fx.Graph, decomposition_table=None, **kwargs): - super().__init__(module, **kwargs) - self.new_graph = new_graph - self.tracer = pippy.fx.proxy.GraphAppendingTracer(self.new_graph) - self.decomposition_table = decomposition_table - if self.decomposition_table is None: - self.decomposition_table = {} - self.mode = ProxyTorchDispatchMode(self.tracer) - - def placeholder(self, target, args, kwargs): - out = super().placeholder(target, args, kwargs) - proxy = pippy.fx.Proxy(self.new_graph.placeholder(target), self.tracer) - track_tensor_tree(out, proxy, constant=None, tracer=self.tracer) - # TODO handle case where the first character of target is '*' - return out - - def get_attr(self, target, args, kwargs): - out = super().get_attr(target, args, kwargs) - proxy = pippy.fx.Proxy(self.new_graph.get_attr(target), self.tracer) - track_tensor_tree(out, proxy, constant=None, tracer=self.tracer) - return out - - # call_function, call_method, call_module get traced automatically by the outer mode. - - def output(self, target, args, kwargs): - out = super().output(target, args, kwargs) - - def unwrap(e): - return get_proxy_slot(e, self.tracer, e, lambda x: x.proxy.node) - self.new_graph.output(pytree.tree_map(unwrap, out)) - return out - - def run(self, *args, **kwargs): - # Should enter the mode at least once for being able to restore it later - # See: https://github.com/pytorch/pytorch/pull/82549#discussion_r934782025 - with decompose(self.decomposition_table), self.mode: - return super().run(*args, **kwargs) - - -def wrapper_and_args_for_make_fx(func, args, kwargs): - # make_fx doesn't support kwargs, so we need to do this flattening - # and then unflatten the args before calling func - flat_args, spec = pytree.tree_flatten((args, kwargs)) - - def wrapped(flat_args): - fn_args, fn_kwargs = pytree.tree_unflatten(flat_args, spec) - return func(*fn_args, **fn_kwargs) - return wrapped, flat_args - -@contextmanager -def disable_autocast_cache(): - old_value = torch.is_autocast_cache_enabled() - torch.set_autocast_cache_enabled(False) - try: - yield - finally: - torch.set_autocast_cache_enabled(old_value) - - -def make_fx(f, decomposition_table=None, tracing_mode="real"): - assert tracing_mode in ["real", "fake", "symbolic"] - - if decomposition_table is None: - decomposition_table = {} - - @functools.wraps(f) - def wrapped(*args): - phs = pytree.tree_map(lambda _: fx.PH, args) # type: ignore[attr-defined] - fx_tracer = PythonKeyTracer() - fake_tensor_mode: Any = nullcontext() - if tracing_mode == "real": - fake_tensor_mode = nullcontext() - elif tracing_mode == "fake": - fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=True) - elif tracing_mode == "symbolic": - fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=False) - else: - raise AssertionError(f"Unexpected tracing type: {tracing_mode}") - - python_dispatcher_mode: Any = nullcontext() - if tracing_mode == "symbolic": - python_dispatcher_mode = enable_python_dispatcher() - - proxy_mode = ProxyTorchDispatchMode(fx_tracer) - - def wrap_fake_concrete(x): - if isinstance(x, torch.Tensor): - return fake_tensor_mode.from_tensor(x) # type: ignore[attr-defined] - - return x - - shape_env = ShapeEnv() - sym_mode = proxy_mode.sym_mode - - # todo: Figure out a more informative name for symints - def wrap_fake_symbolic(x): - if isinstance(x, torch.Tensor): - return fake_tensor_mode.from_tensor(x, shape_env=shape_env) - return x - - wrap_fn_map = { - "real": lambda x: x, - "fake": wrap_fake_concrete, - "symbolic": wrap_fake_symbolic, - } - args = pytree.tree_map(wrap_fn_map[tracing_mode], args) - - if not hasattr(inspect.unwrap(f), '__code__') or inspect.unwrap(f).__code__.co_flags & inspect.CO_VARARGS: - # FX doesn't support varargs, so we gotta fake up a wrapper - # TODO: Would be nice to fix this at the source... - func = fake_signature(f, len(phs)) - else: - func = f - - # We disable the autocast cache as the autocast cache causes type conversions on parameters to - # check a cache, which introduces untracked tensors into the graph - with decompose(decomposition_table), fake_tensor_mode, python_dispatcher_mode, \ - sym_mode, proxy_mode, disable_autocast_cache(): # type: ignore[attr-defined] - t = dispatch_trace(wrap_key(func, args, fx_tracer), tracer=fx_tracer, concrete_args=tuple(phs)) - - # TODO: kind of a bad way to do it, should maybe figure out a better way - t.shape_env = shape_env # type: ignore[assignment] - return t - - return wrapped - - -def get_torch_dispatch_modes(): - return torch.utils._python_dispatch._get_current_dispatch_mode_stack() - - -@contextlib.contextmanager -def disable_proxy_modes_tracing(): - # TODO: This probably doesn't correctly also disable ProxySymDispatchMode - modes = get_torch_dispatch_modes() - proxy_tensor_modes = [m for m in modes if isinstance(m, ProxyTorchDispatchMode)] - olds = [m.enable_tracing for m in proxy_tensor_modes] - for proxy_mode in proxy_tensor_modes: - proxy_mode.enable_tracing = False - try: - yield - finally: - for proxy_mode, old in zip(proxy_tensor_modes, olds): - proxy_mode.enable_tracing = old - - -def get_isolated_graphmodule(func, args, kwargs, tracing_mode="real"): - """A helper function used to get the GraphModule for the given func. - - It's expected to be used in the ProxyTensor tracing context. - It detaches the args and kwargs from the current tracer so that the trace of - the current graph module can be created without any side-effects. - """ - wrapped, all_args = wrapper_and_args_for_make_fx(func, args, kwargs) - - with disable_proxy_modes_tracing(): - gm = make_fx(wrapped, tracing_mode=tracing_mode)(all_args) - return gm diff --git a/pippy/fx/experimental/refinement_types.py b/pippy/fx/experimental/refinement_types.py deleted file mode 100644 index 665c9d0d6..000000000 --- a/pippy/fx/experimental/refinement_types.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -class Equality: - def __init__(self, lhs, rhs): - self.lhs = lhs - self.rhs = rhs - - def __str__(self): - return f'{self.lhs} = {self.rhs}' - - def __repr__(self): - return f'{self.lhs} = {self.rhs}' - - def __eq__(self, other): - if isinstance(other, Equality): - return self.lhs == other.lhs and self.rhs == other.rhs - else: - return False diff --git a/pippy/fx/experimental/rewriter.py b/pippy/fx/experimental/rewriter.py deleted file mode 100644 index d09eba545..000000000 --- a/pippy/fx/experimental/rewriter.py +++ /dev/null @@ -1,122 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import ast -import inspect -import textwrap -import copy -import functools -from types import FunctionType -from typing import cast, Union, Callable, Dict, Optional, Any -from pippy.fx._symbolic_trace import Tracer -from pippy.fx.graph import Graph -from torch._sources import normalize_source_lines -import torch - -class AST_Rewriter(ast.NodeTransformer): - """ - Take a FunctionType object representing a `forward` method, then - perform an AST rewrite to swap out nodes that are not symbolically - traceable with a callsite to the FX alternative. - - To support swapping out an AST node, define a new `visit` method on - that node. For more details, see: - https://docs.python.org/3/library/ast.html#ast.NodeTransformer - """ - - def rewrite(self, fn: FunctionType): - - # Normalize the source lines - sourcelines, _ = inspect.getsourcelines(fn) - sourcelines = normalize_source_lines(sourcelines) - source = ''.join(sourcelines) - normalized_str = textwrap.dedent(source) - - # Rewrite the original AST - source_ast = ast.parse(normalized_str) - dest_ast = ast.fix_missing_locations(self.visit(source_ast)) - - # Pull out the compiled fucntion from the newly-created Module - code = compile(dest_ast, "", "exec") - globals_dict = copy.copy(fn.__globals__) - keys_before = set(globals_dict.keys()) - exec(code, globals_dict) - new_keys = list(set(globals_dict.keys()) - keys_before) - assert len(new_keys) == 1 - fn_compiled = globals_dict[new_keys[0]] - - # return the compiled function with the original globals - def change_func_globals(f, globals): - """Based on https://stackoverflow.com/a/13503277/2988730 (@unutbu)""" - # __globals__ is a private member of the function class - # so we have to copy the function, f, all of its member, except f.__globals__ - g = FunctionType( - f.__code__, - globals, - name=f.__name__, - argdefs=f.__defaults__, - closure=f.__closure__, - ) - g = functools.update_wrapper(g, f) - g.__kwdefaults__ = copy.copy(f.__kwdefaults__) - return g - # Return the correct FunctionType object - return change_func_globals(fn_compiled, globals=fn.__globals__) - - def visit_Assert(self, node): - """ - Swap out the Assert node (Python's `assert`) with a callsite to the - symbolically-traceable torch._assert function - """ - # Create the Call node - n = ast.parse('torch._assert()', mode='eval') - assert isinstance(n, ast.Expression) - call_node = n.body - assert isinstance(call_node, ast.Call) - msg = node.msg if node.msg else ast.Constant(value="", kind=None) - call_node.args = [node.test, msg] - - # Ensure that the new node conforms to the Python AST grammar - expr_wrapper = ast.Expr(value=call_node) - - # Return the new Call node to signify that we want to use it as - # a replacement for the original _assert node - return ast.copy_location(expr_wrapper, node) - - def visit_AnnAssign(self, node): - """ - Swap out Python's AnnAssign with an Assign node where the annotation function is called. - Example: - Original: - y: Tensor_Type(1,2,3, Dyn) = f2(x) - Output: - y = annotate(f2(x),Tensor_Type((1,2,3,Dyn))) - """ - return ast.Assign(targets=[node.target], value=ast.Call( - func=ast.Name(id='annotate', ctx=ast.Load()), - args=[node.value, node.annotation], keywords=[])) - - -class RewritingTracer(Tracer): - def trace(self, root: Union[torch.nn.Module, Callable], concrete_args: Optional[Dict[str, Any]] = None) -> Graph: - return super().trace(_rewrite(root), concrete_args) - - -def _rewrite(fn: Union[torch.nn.Module, Callable]) -> Union[torch.nn.Module, Callable]: - if isinstance(fn, torch.nn.Module): - # Rewrite this module's `forward` as well as the `forward`s of - # all of this module's recursive descendents. Return the new, - # rewritten module hierarchy. - def rewrite_module(m : torch.nn.Module): - class RewrittenModule(torch.nn.Module): - def __init__(self, orig): - super().__init__() - for k, v in orig.__dict__.items(): - if isinstance(v, torch.nn.Module): - self.__dict__[k] = copy.copy(rewrite_module(v)) - else: - self.__dict__[k] = copy.copy(v) - RewrittenModule.forward = AST_Rewriter().rewrite(cast(FunctionType, m.forward)) - return RewrittenModule(m) - return rewrite_module(fn) - else: - # Rewrite this single free function - return AST_Rewriter().rewrite(cast(FunctionType, fn)) diff --git a/pippy/fx/experimental/schema_type_annotation.py b/pippy/fx/experimental/schema_type_annotation.py deleted file mode 100644 index 93102a6b5..000000000 --- a/pippy/fx/experimental/schema_type_annotation.py +++ /dev/null @@ -1,112 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import torch -import pippy.fx -import inspect -from typing import Any, Dict, Optional, Tuple -from pippy.fx.node import Argument, Target -from torch._jit_internal import boolean_dispatched -from pippy.fx.operator_schemas import _torchscript_type_to_python_type - -from pippy.fx import Transformer - -class AnnotateTypesWithSchema(Transformer): - """ - Use Python function signatures to annotate types for `Nodes` within an FX graph. - This pulls out Python function signatures for: - - 1. Standard `torch.nn` Module calls - 2. `torch.nn.functional` calls - 3. Attribute fetches via `get_attr` - - Example usage: - - m = torchvision.models.resnet18() - - traced = pippy.fx.symbolic_trace(m) - - traced = AnnotateTypesWithSchema(traced).transform() - - """ - def __init__(self, module : torch.nn.Module, annotate_functionals : bool = True, - annotate_modules : bool = True, annotate_get_attrs : bool = True): - super().__init__(module) - self.annotate_functionals = annotate_functionals - self.annotate_modules = annotate_modules - self.annotate_get_attrs = annotate_get_attrs - - def call_function(self, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]): - python_ret_type = None - if self.annotate_functionals and target.__module__ == 'torch.nn.functional': - target_for_analysis = target - if target in boolean_dispatched: - # HACK: `boolean_dispatch` as used in `torch.nn.functional` makes it so that we have - # a 2-way dispatch based on a boolean value. Here we check that the `true` and `false` - # branches of the dispatch have exactly the same signature. If they do, use the `true` - # branch signature for analysis. Otherwise, leave this un-normalized - assert not isinstance(target, str) - dispatched = boolean_dispatched[target] - if_true, if_false = dispatched['if_true'], dispatched['if_false'] - # TODO: can we emit the union of these? What are the implications on TorchScript - # compilation? - if inspect.signature(if_true).return_annotation != inspect.signature(if_false).return_annotation: - return super().call_function(target, args, kwargs) - target_for_analysis = if_true - - python_ret_type = self._extract_python_return_type(target_for_analysis) - - return_proxy = super().call_function(target, args, kwargs) - return_proxy.node.type = return_proxy.node.type if return_proxy.node.type else python_ret_type - return return_proxy - - def call_module(self, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]): - python_ret_type = None - assert isinstance(target, str) - submod = self.fetch_attr(target) - if self.annotate_modules and hasattr(submod.__class__, '__name__'): - classname = submod.__class__.__name__ - if getattr(torch.nn, classname, None) == submod.__class__: - python_ret_type = self._extract_python_return_type(submod.forward) - return_proxy = super().call_module(target, args, kwargs) - return_proxy.node.type = return_proxy.node.type if return_proxy.node.type else python_ret_type - return return_proxy - - def get_attr(self, target : pippy.fx.node.Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]): - attr_proxy = super().get_attr(target, args, kwargs) - - if self.annotate_get_attrs: - module_itr = self.module - assert isinstance(target, str) - atoms = target.split('.') - for i, atom in enumerate(atoms): - if not hasattr(module_itr, atom): - raise RuntimeError(f'Node referenced nonextent target {".".join(atoms[:i])}!') - module_itr = getattr(module_itr, atom) - - maybe_inferred_ts_type = torch._C._jit_try_infer_type(module_itr) - if maybe_inferred_ts_type.success(): - python_type = _torchscript_type_to_python_type(maybe_inferred_ts_type.type()) - attr_proxy.node.type = python_type if not attr_proxy.node.type else attr_proxy.node.type - - return attr_proxy - - def _extract_python_return_type(self, target : Target) -> Optional[Any]: - """ - Given a Python call target, try to extract the Python return annotation - if it is available, otherwise return None - - Args: - - target (Callable): Python callable to get return annotation for - - Returns: - - Optional[Any]: Return annotation from the `target`, or None if it was - not available. - """ - assert callable(target) - try: - sig = inspect.signature(target) - except (ValueError, TypeError): - return None - - return sig.return_annotation if sig.return_annotation is not inspect.Signature.empty else None diff --git a/pippy/fx/experimental/symbolic_shapes.py b/pippy/fx/experimental/symbolic_shapes.py deleted file mode 100644 index 5817194e5..000000000 --- a/pippy/fx/experimental/symbolic_shapes.py +++ /dev/null @@ -1,472 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import torch -import torch.utils._pytree as pytree -from typing import Set, Dict, List, Type, Optional, cast # pylint: disable=unused-import -import operator -import functools -from functools import lru_cache, partial -import traceback -import collections -import textwrap -from torch._subclasses.meta_utils import MetaConverter - -try: - import sympy # type: ignore[import] - HAS_SYMPY = True -except ImportError: - HAS_SYMPY = False - -aten = torch.ops.aten # type: ignore[has-type] - -__all__ = [ - "has_symbolic_sizes_strides", "create_contiguous", "PySymInt", "ShapeEnv", - "SymDispatchMode", "PySymFloat", "sym_float", "FloorDiv" -] - -SYM_FUNCTION_MODE = None - -# We don't bother with the metaclass as all of the dispatching logic happens -# entirely from Python -# -# Didn't bother with ancestors for now, unlikely to have multiple modes for -# symints right now - - -# SymDispatchMode gets invoked whenever an operation is processed on -# a PySymInt. When this occurs, you get called at __sym_dispatch__ -# with the operation in question. This is symmetric to TorchDispatchMode -# but with some caveats: -# -# - In TorchDispatchMode, you get the same arguments as what a user -# invoked your API with; e.g., if you call torch.ops.aten.foo(a, b), -# you get (a, b) as args to your call. In SymDispatchMode, if -# you call a + b (where a and b are SymInts), you will get -# (a.get_pyobj(), b.get_pyobj()) as your args (these are PySymInts) -# -# - SymInt/PySymInt don't have FX proxy support (unlike, e.g., Tensor). -# So you have to manually call Tracer/create_node to write into -# the graph. See ProxySymDispatchMode for an example -# -class SymDispatchMode: - def __sym_dispatch__(self, func, types, args, kwargs): - raise NotImplementedError() - - def __enter__(self): - global SYM_FUNCTION_MODE - old = SYM_FUNCTION_MODE - if hasattr(self, "inner"): - raise RuntimeError(f"{self} has already been used as a mode. Please use a fresh version") - else: - self.inner = old - SYM_FUNCTION_MODE = self - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - global SYM_FUNCTION_MODE - SYM_FUNCTION_MODE = self.inner - -def has_symbolic_sizes_strides(elem): - return elem._has_symbolic_sizes_strides - -def create_contiguous(shape): - strides = [1] - for dim in reversed(shape[:-1]): - strides.append(dim * strides[-1]) - return list(reversed(strides)) - -def _handle_sym_dispatch(func, args, kwargs): - global SYM_FUNCTION_MODE - mode = SYM_FUNCTION_MODE - assert mode - SYM_FUNCTION_MODE = mode.inner - try: - # TODO: properly compute types - types: List[Type] = [] - return mode.__sym_dispatch__(func, types, args, kwargs) - finally: - SYM_FUNCTION_MODE = mode - -def sym_float(a): - if hasattr(a, '__sym_float__'): - return a.__sym_float__() - elif isinstance(a, torch._C.SymFloatNode): - return a - return float(a) - -# TODO: An incomplete list -# 1. Set variables to be equal when we do equality -# 2. Specialize on 0/1 when we do subtraction -class PySymInt(object): - """ - PySymInt objects are the primary "symbolic shape" objects that flow through - our program. They're what sit under FakeTensor, and contains our primary - implementation of symbolic shapes. - """ - def __init__(self, expr, shape_env, constant=None): - self.expr = expr - self.shape_env = shape_env - self.constant = constant - - def wrap(self, num): - return PySymInt(sympy.Integer(num), self.shape_env, constant=num) - - def __str__(self): - return f"{self.expr}" - - def __repr__(self): - return f"{self.expr}" - - # Today we error on calling int on a symbolic shape, as this is a very accessible footgun. - def __int__(self): - raise RuntimeError("Trying to extract a concrete int out of a symbolic int") - - # You can manually trigger a guard with this function - def guard_int(self, file, line): - # TODO: use the file/line for some useful diagnostic on why a - # guard occurred - return int(self.shape_env.evaluate_expr(self.expr)) - - def __sym_float__(self): - if SYM_FUNCTION_MODE: - return _handle_sym_dispatch(sym_float, (self,), {}) - # TODO: consider constant prop here - # TODO: wrapping the expr with sympy.Float doesn't seem to work, why - # not? - return PySymFloat(self.expr, self.shape_env) - - def __bool__(self): - return bool(self.shape_env.evaluate_expr(self.shape_env.replace(self.expr))) - -class PySymFloat: - def __init__(self, expr, shape_env, constant=None): - self.expr = expr - self.shape_env = shape_env - self.constant = constant - - def wrap(self, num): - return PySymFloat(sympy.Float(num), self.shape_env, constant=num) - - def __str__(self): - return f"{self.expr}" - -if HAS_SYMPY: - class FloorDiv(sympy.Function): - """ - We maintain this so that: - 1. We can use divisibility guards to simplify FloorDiv(a, b) to a / b. - 2. Printing out the expression is nicer (compared to say, representing a//b as (a - a % b) / b) - """ - nargs = (2,) - - @classmethod - def eval(cls, base, divisor): - if base == 0: - return sympy.Integer(0) - if divisor == 1: - return base - if isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer): - return base // divisor - gcd = sympy.gcd(base, divisor) - if gcd != 1: - return FloorDiv( - sympy.simplify(base / gcd), sympy.simplify(divisor / gcd) - ) - -# Methods that have a `__foo__` as well as `__rfoo__` -reflectable_magic_methods = { - 'add': lambda a, b: a + b, - 'sub': lambda a, b: a - b, - 'mul': lambda a, b: a * b, - 'mod': lambda a, b: a % b, - 'truediv': lambda a, b: a / b, - 'floordiv': lambda a, b: FloorDiv(a, b) -} - -magic_methods = { - **reflectable_magic_methods, - 'eq': lambda a, b: sympy.Eq(a, b), - 'gt': lambda a, b: sympy.Gt(a, b), - 'lt': lambda a, b: sympy.Lt(a, b), - 'le': lambda a, b: sympy.Le(a, b), - 'ge': lambda a, b: sympy.Ge(a, b), -} - -float_magic_methods = {"add", "sub", "mul", "truediv"} - -def _make_magic(method, func, py_type): - func = lru_cache(256)(func) - - def magic_impl(self, other): - if SYM_FUNCTION_MODE: - return _handle_sym_dispatch(getattr(operator, method), (self, other), {}) - if isinstance(other, py_type): - other = other.expr - # TODO: consider constant prop here - expr = self.shape_env.replace(self.expr) - other = self.shape_env.replace(other) - out = func(expr, other) - out = sympy.expand(out) - if method in ["truediv"]: - return PySymFloat(out, self.shape_env) - else: - # TODO: relational operators actually technically return a - # PySymBool, this is a type error - return py_type(out, self.shape_env) - - # this should be wrapped transparently into torch.SymIntNode - setattr(py_type, method, magic_impl) - setattr(py_type, f"__{method}__", magic_impl) - if method in reflectable_magic_methods: - setattr(py_type, f"__r{method}__", magic_impl) - -for method, func in magic_methods.items(): - _make_magic(method, func, PySymInt) - -for method, func in magic_methods.items(): - if method not in float_magic_methods: - continue - _make_magic(method, func, PySymFloat) - -del method -del func - -def _lru_cache(fn, maxsize=None): - """ - Wrapper around lru_cache that clears when new info about shapes has been - updated. - - Use lru_cache if the output is always the same, regardless of the - constraints we know now (i.e. evaluate_expr) - - Use _lru_cache otherwise. - """ - fn_cache = lru_cache(maxsize)(fn) - prior_key = None - - @functools.wraps(fn) - def wrapper(self, *args, **kwargs): - nonlocal prior_key - if prior_key != self._get_key(): - prior_key = self._get_key() - fn_cache.cache_clear() - return fn_cache(self, *args, **kwargs) - - wrapper.cache_info = fn_cache.cache_info # type: ignore[attr-defined] - return wrapper - - - -class ShapeEnv(object): - def __init__(self): - self.guards = [] - # Maps symbolic ints to their original concrete values - # Currently populated from tensors - self.var_to_val: Dict["sympy.Symbol", "sympy.Integer"] = {} - # Maps from sympy ints to expressions representing them - # Populated from equality guards (i.e. a.shape[0] == b.shape[0]) - self.replacements: Dict["sympy.Symbol", "sympy.Expr"] = {} # - # Set holds a % b expressions that evaluate to 0. - self.divisible: Set["sympy.Expr"] = set() - # Duck-shaping says that if two input tensors have the same size, - # they get assigned the same symbolic variable - self.val_to_symint: Dict[int, torch.SymIntNode] = {} - - def _get_key(self): - """ - Defines the current "state" of the guards we've accumulated in this ShapeEnv. - Determines when we need to invalidate our cache - """ - return (len(self.replacements), len(self.divisible)) - - # NB: This is only called for input symbolic sizes; intermediate symbolic - # sizes are allocated via a different mechanism - def create_symint(self, name, val): - assert val >= 0 - if not HAS_SYMPY: - raise RuntimeError("Need sympy installed to create symbolic shapes") - - # TODO: Put 0/1 specialization in guards - if val == 0 or val == 1: - return val - # This implements duck-shaping: input sizes that match are assigned - # the same symint - # TODO: Create a guard whenever this happens - # TODO: But how do I represent the guard in this case? - if val in self.val_to_symint: - return self.val_to_symint[val] - sympy_expr = sympy.Symbol(name, positive=True, integer=True) - py_sym_int = PySymInt(sympy_expr, self) - cpp_sym_int = torch.SymIntNode.new_symint(py_sym_int) # type: ignore[attr-defined] - self.var_to_val[sympy_expr] = sympy.Integer(val) - self.val_to_symint[val] = cpp_sym_int - return cpp_sym_int - - def evaluate_guards_for_args(self, *args): - new_env = ShapeEnv() - # NB: This must be kept in sync with create_aot_dispatcher_function - # and wrap_fake_symbolic - meta_converter = MetaConverter() - pytree.tree_map_only(torch.Tensor, partial(meta_converter, shape_env=new_env), args) - return all(guard.xreplace(new_env.var_to_val) == value for guard, value, _ in self.guards) - - def get_nontrivial_guards(self): - return [(self.simplify(guard), val) for guard, val, _ in self.guards if self._maybe_evaluate_static(guard) is None] - - def format_guards(self, verbose=False): - def format_val(guard, val): - if val is sympy.true: - return str(guard) - elif val is sympy.false: - return f"Not({guard})" - else: - return f"Eq({guard}, {val})" - - def format_tb(tb): - if not verbose: - return "" - return f"\n Guarded at:\n{textwrap.indent(tb, ' ')}" - - return '\n'.join(f" - {format_val(guard, val)}{format_tb(tb)}" for guard, val, tb in self.guards) - - def get_shape_groups(self): - shape_groups = collections.defaultdict(list) - for k, v in self.replacements.items(): - shape_groups[v].append(k) - return shape_groups - - @_lru_cache - def _maybe_evaluate_static(self, expr: "sympy.Expr") -> "Optional[sympy.Expr]": - """ - Tries to evaluate expr without introducing guards - """ - expr = self.simplify(expr) - # Simplifies assuming that shape vars > 1 (since we cache on 0/1 shape values) - symbols = list(expr.free_symbols) - new_shape_env = { - k: sympy.Symbol(f"shape_{idx}", positive=True, integer=True) + 1 - for idx, k in enumerate(symbols) - } - new_expr = expr.xreplace(new_shape_env) - floor_div_replace = {} - for atom in new_expr.atoms(FloorDiv): - floor_div_replace[atom] = sympy.floor(atom.args[0] / atom.args[1]) - new_expr = sympy.expand(new_expr.xreplace(floor_div_replace)) - if len(list(new_expr.free_symbols)) == 0: - return new_expr - return None - - @_lru_cache - def replace(self, expr: "sympy.Expr") -> "sympy.Expr": - replacements = {s: self._find(cast(sympy.Symbol, s)) for s in expr.free_symbols} - return sympy.expand(expr.xreplace(replacements)) - - @_lru_cache - def _update_divisible(self): - new_divisible = set() - for k in self.divisible: - res = self.replace(k) - if len(res.free_symbols) > 0: - new_divisible.add(k) - - self.divisible = new_divisible - - @_lru_cache - def simplify(self, expr: "sympy.Expr") -> "sympy.Expr": - expr = self.replace(expr) - if expr.has(FloorDiv): - self._update_divisible() - div_replacements = {} - for atom in expr.atoms(FloorDiv): - base, divisor = atom.args - if self.replace(base % divisor) in self.divisible: - div_replacements[atom] = base / divisor - expr = expr.xreplace(div_replacements) - expr = sympy.expand(expr) - return expr - - @lru_cache(256) - def size_hint(self, expr: "sympy.Expr"): - """ - Gets a size hint for a given expression from the underlying shapes we had. - Does not introduce a guard, so only use this when you can guarantee that - your code is still valid for arbitrary shapes (such as optimization decisions) - """ - result_expr = sympy.expand(expr).xreplace(self.var_to_val) - assert len(result_expr.free_symbols) == 0, "Size hint has variables we don't have underlying values for" - return result_expr - - @_lru_cache - def _find(self, a: "sympy.Symbol") -> "sympy.Expr": - """ - Implements a DSU-like algorithm to find the variable that represents a - Also handles transitive non-identity replacements. - - a: b + c - c: d - """ - if a not in self.replacements: - return a - res = self.replacements[a] - cur_replace = {s: self._find(s) for s in res.free_symbols} - self.replacements[a] = self.replacements[a].xreplace(cur_replace) - return self.replacements[a] - - @lru_cache(256) - def _maybe_guard_eq(self, expr: "sympy.Eq") -> None: - """ - Evaluates the result of an eq call. If true, uses information to - simplify shapes (i.e. a == b or a % 5 == 0) - """ - concrete_bool = bool(self.size_hint(expr)) - if not concrete_bool: - return - free = list(expr.free_symbols) - - assert len(free) > 0, "The expression should not be static by this point" - # In case of really gnarly expression, we don't blow up - if len(free) > 5: - return - free = sorted(free, key=lambda x: (self.size_hint(x), x.name), reverse=True) # type: ignore[attr-defined] - lhs = expr.lhs - rhs = expr.rhs - try: - solutions = sympy.solve(lhs - rhs, free[0], dict=True) - if len(solutions) != 1: - return - solution = solutions[0][free[0]] - if all(t.is_integer for t in sympy.preorder_traversal(solution)): - new_var = self._find(solution) - self.replacements[cast(sympy.Symbol, free[0])] = new_var - except NotImplementedError: - if expr.has(sympy.Mod): - mod_expr = tuple(expr.atoms(sympy.Mod))[0] - try: - solutions = sympy.solve(lhs - rhs, mod_expr, dict=True) - if len(solutions) == 1 and solutions[0][mod_expr] == 0: - self.divisible.add(mod_expr) - except NotImplementedError: - pass - return - - @lru_cache(256) - def evaluate_expr(self, expr: "sympy.Expr"): - """ - Given an expression, evaluates it, adding guards if necessary - """ - if len(expr.free_symbols) == 0: - return expr - expr = self.simplify(expr) - static_expr = self._maybe_evaluate_static(expr) - if static_expr is not None: - return static_expr - - if isinstance(expr, sympy.Eq): - self._maybe_guard_eq(expr) - concrete_val = self.size_hint(expr) - - # TODO: optimize this; avoid formatting traces until we need them - # NB: drop two frames; evaluate_expr and the Sym* function that - # actually called us - stack = ''.join(traceback.format_list(traceback.extract_stack()[:-2])) - self.guards.append((expr, concrete_val, stack)) - return concrete_val diff --git a/pippy/fx/experimental/unification/LICENSE.txt b/pippy/fx/experimental/unification/LICENSE.txt deleted file mode 100644 index 775eca52c..000000000 --- a/pippy/fx/experimental/unification/LICENSE.txt +++ /dev/null @@ -1,28 +0,0 @@ -Copyright (c) 2014 Matthew Rocklin - -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - - a. Redistributions of source code must retain the above copyright notice, - this list of conditions and the following disclaimer. - b. Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - c. Neither the name of Unification nor the names of its contributors - may be used to endorse or promote products derived from this software - without specific prior written permission. - - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR -ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT -LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY -OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH -DAMAGE. diff --git a/pippy/fx/experimental/unification/__init__.py b/pippy/fx/experimental/unification/__init__.py deleted file mode 100644 index 5e1477089..000000000 --- a/pippy/fx/experimental/unification/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -# type: ignore[attr-defined] -from .core import unify, reify # noqa: F403 -from .more import unifiable # noqa: F403 -from .variable import var, isvar, vars, variables, Var # noqa: F403 diff --git a/pippy/fx/experimental/unification/core.py b/pippy/fx/experimental/unification/core.py deleted file mode 100644 index c1eb2b3cf..000000000 --- a/pippy/fx/experimental/unification/core.py +++ /dev/null @@ -1,119 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from collections.abc import Iterator # type: ignore[import] -from functools import partial - -from .unification_tools import assoc # type: ignore[import] -from .utils import transitive_get as walk -from .variable import isvar -from .dispatch import dispatch - -__all__ = ["reify", "unify"] - -################ -# Reificiation # -################ - -@dispatch(Iterator, dict) -def _reify(t, s): - return map(partial(reify, s=s), t) - # return (reify(arg, s) for arg in t) -_reify - -@dispatch(tuple, dict) # type: ignore[no-redef] -def _reify(t, s): - return tuple(reify(iter(t), s)) -_reify - -@dispatch(list, dict) # type: ignore[no-redef] -def _reify(t, s): - return list(reify(iter(t), s)) -_reify - -@dispatch(dict, dict) # type: ignore[no-redef] -def _reify(d, s): - return dict((k, reify(v, s)) for k, v in d.items()) -_reify - -@dispatch(object, dict) # type: ignore[no-redef] -def _reify(o, s): - return o # catch all, just return the object - -def reify(e, s): - """ Replace variables of expression with substitution - >>> # xdoctest: +SKIP - >>> x, y = var(), var() - >>> e = (1, x, (3, y)) - >>> s = {x: 2, y: 4} - >>> reify(e, s) - (1, 2, (3, 4)) - >>> e = {1: x, 3: (y, 5)} - >>> reify(e, s) - {1: 2, 3: (4, 5)} - """ - if isvar(e): - return reify(s[e], s) if e in s else e - return _reify(e, s) - -############### -# Unification # -############### - -seq = tuple, list, Iterator - -@dispatch(seq, seq, dict) -def _unify(u, v, s): - if len(u) != len(v): - return False - for uu, vv in zip(u, v): # avoiding recursion - s = unify(uu, vv, s) - if s is False: - return False - return s -# -# @dispatch((set, frozenset), (set, frozenset), dict) -# def _unify(u, v, s): -# i = u & v -# u = u - i -# v = v - i -# return _unify(sorted(u), sorted(v), s) -# -# -# @dispatch(dict, dict, dict) -# def _unify(u, v, s): -# if len(u) != len(v): -# return False -# for key, uval in iteritems(u): -# if key not in v: -# return False -# s = unify(uval, v[key], s) -# if s is False: -# return False -# return s -# -# -# @dispatch(object, object, dict) -# def _unify(u, v, s): -# return False # catch all - - -@dispatch(object, object, dict) -def unify(u, v, s): # no check at the moment - """ Find substitution so that u == v while satisfying s - >>> x = var('x') - >>> unify((1, x), (1, 2), {}) - {~x: 2} - """ - u = walk(u, s) - v = walk(v, s) - if u == v: - return s - if isvar(u): - return assoc(s, u, v) - if isvar(v): - return assoc(s, v, u) - return _unify(u, v, s) -unify - -@dispatch(object, object) # type: ignore[no-redef] -def unify(u, v): - return unify(u, v, {}) diff --git a/pippy/fx/experimental/unification/dispatch.py b/pippy/fx/experimental/unification/dispatch.py deleted file mode 100644 index fd9b37188..000000000 --- a/pippy/fx/experimental/unification/dispatch.py +++ /dev/null @@ -1,7 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from functools import partial -from .multipledispatch import dispatch # type: ignore[import] - -namespace = {} # type: ignore[var-annotated] - -dispatch = partial(dispatch, namespace=namespace) diff --git a/pippy/fx/experimental/unification/match.py b/pippy/fx/experimental/unification/match.py deleted file mode 100644 index 09ca3ad5d..000000000 --- a/pippy/fx/experimental/unification/match.py +++ /dev/null @@ -1,123 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from .core import unify, reify # type: ignore[attr-defined] -from .variable import isvar -from .utils import _toposort, freeze -from .unification_tools import groupby, first # type: ignore[import] - - -class Dispatcher(object): - def __init__(self, name): - self.name = name - self.funcs = {} - self.ordering = [] - - def add(self, signature, func): - self.funcs[freeze(signature)] = func - self.ordering = ordering(self.funcs) - - def __call__(self, *args, **kwargs): - func, s = self.resolve(args) - return func(*args, **kwargs) - - def resolve(self, args): - n = len(args) - for signature in self.ordering: - if len(signature) != n: - continue - s = unify(freeze(args), signature) - if s is not False: - result = self.funcs[signature] - return result, s - raise NotImplementedError("No match found. \nKnown matches: " - + str(self.ordering) + "\nInput: " + str(args)) - - def register(self, *signature): - def _(func): - self.add(signature, func) - return self - return _ - -class VarDispatcher(Dispatcher): - """ A dispatcher that calls functions with variable names - >>> d = VarDispatcher('d') - >>> # xdoctest: +SKIP - >>> x = var('x') - >>> @d.register('inc', x) - ... def f(x): - ... return x + 1 - >>> @d.register('double', x) - ... def f(x): - ... return x * 2 - >>> d('inc', 10) - 11 - >>> d('double', 10) - 20 - """ - def __call__(self, *args, **kwargs): - func, s = self.resolve(args) - d = dict((k.token, v) for k, v in s.items()) - return func(**d) - - - - -global_namespace = {} # type: ignore[var-annotated] - - -def match(*signature, **kwargs): - namespace = kwargs.get('namespace', global_namespace) - dispatcher = kwargs.get('Dispatcher', Dispatcher) - - def _(func): - name = func.__name__ - - if name not in namespace: - namespace[name] = dispatcher(name) - d = namespace[name] - - d.add(signature, func) - - return d - return _ - - -def supercedes(a, b): - """ ``a`` is a more specific match than ``b`` """ - if isvar(b) and not isvar(a): - return True - s = unify(a, b) - if s is False: - return False - s = dict((k, v) for k, v in s.items() if not isvar(k) or not isvar(v)) - if reify(a, s) == a: - return True - if reify(b, s) == b: - return False - - -# Taken from multipledispatch -def edge(a, b, tie_breaker=hash): - """ A should be checked before B - Tie broken by tie_breaker, defaults to ``hash`` - """ - if supercedes(a, b): - if supercedes(b, a): - return tie_breaker(a) > tie_breaker(b) - else: - return True - return False - - -# Taken from multipledispatch -def ordering(signatures): - """ A sane ordering of signatures to check, first to last - Topoological sort of edges as given by ``edge`` and ``supercedes`` - """ - signatures = list(map(tuple, signatures)) - edges = [(a, b) for a in signatures for b in signatures if edge(a, b)] - edges = groupby(first, edges) - for s in signatures: - if s not in edges: - edges[s] = [] - edges = dict((k, [b for a, b in v]) for k, v in edges.items()) # type: ignore[attr-defined, assignment] - return _toposort(edges) diff --git a/pippy/fx/experimental/unification/more.py b/pippy/fx/experimental/unification/more.py deleted file mode 100644 index 81e72821f..000000000 --- a/pippy/fx/experimental/unification/more.py +++ /dev/null @@ -1,117 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from .core import unify, reify # type: ignore[attr-defined] -from .dispatch import dispatch - - -def unifiable(cls): - """ Register standard unify and reify operations on class - This uses the type and __dict__ or __slots__ attributes to define the - nature of the term - See Also: - >>> class A(object): - ... def __init__(self, a, b): - ... self.a = a - ... self.b = b - >>> # xdoctest: +SKIP - >>> unifiable(A) - - >>> x = var('x') - >>> a = A(1, 2) - >>> b = A(1, x) - >>> unify(a, b, {}) - {~x: 2} - """ - _unify.add((cls, cls, dict), unify_object) - _reify.add((cls, dict), reify_object) - - return cls - - -######### -# Reify # -######### - - -def reify_object(o, s): - """ Reify a Python object with a substitution - >>> class Foo(object): - ... def __init__(self, a, b): - ... self.a = a - ... self.b = b - ... def __str__(self): - ... return "Foo(%s, %s)"%(str(self.a), str(self.b)) - >>> # xdoctest: +SKIP - >>> x = var('x') - >>> f = Foo(1, x) - >>> print(f) - Foo(1, ~x) - >>> print(reify_object(f, {x: 2})) - Foo(1, 2) - """ - if hasattr(o, '__slots__'): - return _reify_object_slots(o, s) - else: - return _reify_object_dict(o, s) - - -def _reify_object_dict(o, s): - obj = object.__new__(type(o)) - d = reify(o.__dict__, s) - if d == o.__dict__: - return o - obj.__dict__.update(d) - return obj - - -def _reify_object_slots(o, s): - attrs = [getattr(o, attr) for attr in o.__slots__] - new_attrs = reify(attrs, s) - if attrs == new_attrs: - return o - else: - newobj = object.__new__(type(o)) - for slot, attr in zip(o.__slots__, new_attrs): - setattr(newobj, slot, attr) - return newobj - - -@dispatch(slice, dict) -def _reify(o, s): - """ Reify a Python ``slice`` object """ - return slice(*reify((o.start, o.stop, o.step), s)) - - -######### -# Unify # -######### - - -def unify_object(u, v, s): - """ Unify two Python objects - Unifies their type and ``__dict__`` attributes - >>> class Foo(object): - ... def __init__(self, a, b): - ... self.a = a - ... self.b = b - ... def __str__(self): - ... return "Foo(%s, %s)"%(str(self.a), str(self.b)) - >>> # xdoctest: +SKIP - >>> x = var('x') - >>> f = Foo(1, x) - >>> g = Foo(1, 2) - >>> unify_object(f, g, {}) - {~x: 2} - """ - if type(u) != type(v): - return False - if hasattr(u, '__slots__'): - return unify([getattr(u, slot) for slot in u.__slots__], - [getattr(v, slot) for slot in v.__slots__], - s) - else: - return unify(u.__dict__, v.__dict__, s) - -@dispatch(slice, slice, dict) -def _unify(u, v, s): - """ Unify a Python ``slice`` object """ - return unify((u.start, u.stop, u.step), (v.start, v.stop, v.step), s) diff --git a/pippy/fx/experimental/unification/multipledispatch/__init__.py b/pippy/fx/experimental/unification/multipledispatch/__init__.py deleted file mode 100644 index 26039e4ce..000000000 --- a/pippy/fx/experimental/unification/multipledispatch/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from .core import dispatch -from .dispatcher import (Dispatcher, halt_ordering, restart_ordering, - MDNotImplementedError) diff --git a/pippy/fx/experimental/unification/multipledispatch/conflict.py b/pippy/fx/experimental/unification/multipledispatch/conflict.py deleted file mode 100644 index 4ed1da4b0..000000000 --- a/pippy/fx/experimental/unification/multipledispatch/conflict.py +++ /dev/null @@ -1,120 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from .utils import _toposort, groupby -from .variadic import isvariadic - -__all__ = ["AmbiguityWarning", "supercedes", "consistent", "ambiguous", "ambiguities", "super_signature", - "edge", "ordering"] - -class AmbiguityWarning(Warning): - pass - - -def supercedes(a, b): - """ A is consistent and strictly more specific than B """ - if len(a) < len(b): - # only case is if a is empty and b is variadic - return not a and len(b) == 1 and isvariadic(b[-1]) - elif len(a) == len(b): - return all(map(issubclass, a, b)) - else: - # len(a) > len(b) - p1 = 0 - p2 = 0 - while p1 < len(a) and p2 < len(b): - cur_a = a[p1] - cur_b = b[p2] - if not (isvariadic(cur_a) or isvariadic(cur_b)): - if not issubclass(cur_a, cur_b): - return False - p1 += 1 - p2 += 1 - elif isvariadic(cur_a): - assert p1 == len(a) - 1 - return p2 == len(b) - 1 and issubclass(cur_a, cur_b) - elif isvariadic(cur_b): - assert p2 == len(b) - 1 - if not issubclass(cur_a, cur_b): - return False - p1 += 1 - return p2 == len(b) - 1 and p1 == len(a) - - -def consistent(a, b): - """ It is possible for an argument list to satisfy both A and B """ - - # Need to check for empty args - if not a: - return not b or isvariadic(b[0]) - if not b: - return not a or isvariadic(a[0]) - - # Non-empty args check for mutual subclasses - if len(a) == len(b): - return all(issubclass(aa, bb) or issubclass(bb, aa) - for aa, bb in zip(a, b)) - else: - p1 = 0 - p2 = 0 - while p1 < len(a) and p2 < len(b): - cur_a = a[p1] - cur_b = b[p2] - if not issubclass(cur_b, cur_a) and not issubclass(cur_a, cur_b): - return False - if not (isvariadic(cur_a) or isvariadic(cur_b)): - p1 += 1 - p2 += 1 - elif isvariadic(cur_a): - p2 += 1 - elif isvariadic(cur_b): - p1 += 1 - # We only need to check for variadic ends - # Variadic types are guaranteed to be the last element - return (isvariadic(cur_a) and p2 == len(b) or - isvariadic(cur_b) and p1 == len(a)) - - -def ambiguous(a, b): - """ A is consistent with B but neither is strictly more specific """ - return consistent(a, b) and not (supercedes(a, b) or supercedes(b, a)) - - -def ambiguities(signatures): - """ All signature pairs such that A is ambiguous with B """ - signatures = list(map(tuple, signatures)) - return set((a, b) for a in signatures for b in signatures - if hash(a) < hash(b) - and ambiguous(a, b) - and not any(supercedes(c, a) and supercedes(c, b) - for c in signatures)) - - -def super_signature(signatures): - """ A signature that would break ambiguities """ - n = len(signatures[0]) - assert all(len(s) == n for s in signatures) - - return [max([type.mro(sig[i]) for sig in signatures], key=len)[0] - for i in range(n)] - - -def edge(a, b, tie_breaker=hash): - """ A should be checked before B - Tie broken by tie_breaker, defaults to ``hash`` - """ - # A either supercedes B and B does not supercede A or if B does then call - # tie_breaker - return supercedes(a, b) and (not supercedes(b, a) or tie_breaker(a) > tie_breaker(b)) - - -def ordering(signatures): - """ A sane ordering of signatures to check, first to last - Topoological sort of edges as given by ``edge`` and ``supercedes`` - """ - signatures = list(map(tuple, signatures)) - edges = [(a, b) for a in signatures for b in signatures if edge(a, b)] - edges = groupby(lambda x: x[0], edges) - for s in signatures: - if s not in edges: - edges[s] = [] - edges = dict((k, [b for a, b in v]) for k, v in edges.items()) # type: ignore[assignment, attr-defined] - return _toposort(edges) diff --git a/pippy/fx/experimental/unification/multipledispatch/core.py b/pippy/fx/experimental/unification/multipledispatch/core.py deleted file mode 100644 index ca79fcadb..000000000 --- a/pippy/fx/experimental/unification/multipledispatch/core.py +++ /dev/null @@ -1,82 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import inspect -import sys - -from .dispatcher import Dispatcher, MethodDispatcher - -global_namespace = {} # type: ignore[var-annotated] - -__all__ = ["dispatch", "ismethod"] - -def dispatch(*types, **kwargs): - """ Dispatch function on the types of the inputs - Supports dispatch on all non-keyword arguments. - Collects implementations based on the function name. Ignores namespaces. - If ambiguous type signatures occur a warning is raised when the function is - defined suggesting the additional method to break the ambiguity. - Examples - -------- - >>> @dispatch(int) - ... def f(x): - ... return x + 1 - >>> @dispatch(float) - ... def f(x): - ... return x - 1 - >>> f(3) - 4 - >>> f(3.0) - 2.0 - >>> # Specify an isolated namespace with the namespace keyword argument - >>> my_namespace = {} - >>> @dispatch(int, namespace=my_namespace) - ... def foo(x): - ... return x + 1 - >>> # Dispatch on instance methods within classes - >>> class MyClass(object): - ... @dispatch(list) - ... def __init__(self, data): - ... self.data = data - ... @dispatch(int) - ... def __init__(self, datum): - ... self.data = [datum] - >>> MyClass([1, 2, 3]).data - [1, 2, 3] - >>> MyClass(3).data - [3] - """ - namespace = kwargs.get('namespace', global_namespace) - - types = tuple(types) - - def _df(func): - name = func.__name__ - - if ismethod(func): - dispatcher = inspect.currentframe().f_back.f_locals.get( # type: ignore[union-attr] - name, # type: ignore[union-attr] - MethodDispatcher(name), - ) - else: - if name not in namespace: - namespace[name] = Dispatcher(name) - dispatcher = namespace[name] - - dispatcher.add(types, func) - return dispatcher - return _df - - -def ismethod(func): - """ Is func a method? - Note that this has to work as the method is defined but before the class is - defined. At this stage methods look like functions. - """ - if hasattr(inspect, "signature"): - signature = inspect.signature(func) - return signature.parameters.get('self', None) is not None - else: - if sys.version_info.major < 3: - spec = inspect.getargspec(func) - else: - spec = inspect.getfullargspec(func) # type: ignore[union-attr, assignment] - return spec and spec.args and spec.args[0] == 'self' diff --git a/pippy/fx/experimental/unification/multipledispatch/dispatcher.py b/pippy/fx/experimental/unification/multipledispatch/dispatcher.py deleted file mode 100644 index 7427aebe5..000000000 --- a/pippy/fx/experimental/unification/multipledispatch/dispatcher.py +++ /dev/null @@ -1,433 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from warnings import warn -import inspect -from .conflict import ordering, ambiguities, super_signature, AmbiguityWarning -from .utils import expand_tuples -from .variadic import Variadic, isvariadic -import itertools as itl - -__all__ = ["MDNotImplementedError", "ambiguity_warn", "halt_ordering", "restart_ordering", "variadic_signature_matches_iter", - "variadic_signature_matches", "Dispatcher", "source", "MethodDispatcher", "str_signature", "warning_text"] - -class MDNotImplementedError(NotImplementedError): - """ A NotImplementedError for multiple dispatch """ - - -def ambiguity_warn(dispatcher, ambiguities): - """ Raise warning when ambiguity is detected - Parameters - ---------- - dispatcher : Dispatcher - The dispatcher on which the ambiguity was detected - ambiguities : set - Set of type signature pairs that are ambiguous within this dispatcher - See Also: - Dispatcher.add - warning_text - """ - warn(warning_text(dispatcher.name, ambiguities), AmbiguityWarning) - - -def halt_ordering(): - """Deprecated interface to temporarily disable ordering. - """ - warn( - 'halt_ordering is deprecated, you can safely remove this call.', - DeprecationWarning, - ) - - -def restart_ordering(on_ambiguity=ambiguity_warn): - """Deprecated interface to temporarily resume ordering. - """ - warn( - 'restart_ordering is deprecated, if you would like to eagerly order' - 'the dispatchers, you should call the ``reorder()`` method on each' - ' dispatcher.', - DeprecationWarning, - ) - - -def variadic_signature_matches_iter(types, full_signature): - """Check if a set of input types matches a variadic signature. - Notes - ----- - The algorithm is as follows: - Initialize the current signature to the first in the sequence - For each type in `types`: - If the current signature is variadic - If the type matches the signature - yield True - Else - Try to get the next signature - If no signatures are left we can't possibly have a match - so yield False - Else - yield True if the type matches the current signature - Get the next signature - """ - sigiter = iter(full_signature) - sig = next(sigiter) - for typ in types: - matches = issubclass(typ, sig) - yield matches - if not isvariadic(sig): - # we're not matching a variadic argument, so move to the next - # element in the signature - sig = next(sigiter) - else: - try: - sig = next(sigiter) - except StopIteration: - assert isvariadic(sig) - yield True - else: - # We have signature items left over, so all of our arguments - # haven't matched - yield False - - -def variadic_signature_matches(types, full_signature): - # No arguments always matches a variadic signature - assert full_signature - return all(variadic_signature_matches_iter(types, full_signature)) - - -class Dispatcher(object): - """ Dispatch methods based on type signature - Use ``dispatch`` to add implementations - Examples - -------- - >>> # xdoctest: +SKIP("bad import name") - >>> from multipledispatch import dispatch - >>> @dispatch(int) - ... def f(x): - ... return x + 1 - >>> @dispatch(float) - ... def f(x): - ... return x - 1 - >>> f(3) - 4 - >>> f(3.0) - 2.0 - """ - __slots__ = '__name__', 'name', 'funcs', '_ordering', '_cache', 'doc' - - def __init__(self, name, doc=None): - self.name = self.__name__ = name - self.funcs = {} - self.doc = doc - - self._cache = {} - - def register(self, *types, **kwargs): - """ register dispatcher with new implementation - >>> f = Dispatcher('f') - >>> @f.register(int) - ... def inc(x): - ... return x + 1 - >>> @f.register(float) - ... def dec(x): - ... return x - 1 - >>> @f.register(list) - ... @f.register(tuple) - ... def reverse(x): - ... return x[::-1] - >>> f(1) - 2 - >>> f(1.0) - 0.0 - >>> f([1, 2, 3]) - [3, 2, 1] - """ - def _df(func): - self.add(types, func, **kwargs) # type: ignore[call-arg] - return func - return _df - - @classmethod - def get_func_params(cls, func): - if hasattr(inspect, "signature"): - sig = inspect.signature(func) - return sig.parameters.values() - - @classmethod - def get_func_annotations(cls, func): - """ get annotations of function positional parameters - """ - params = cls.get_func_params(func) - if params: - Parameter = inspect.Parameter - - params = (param for param in params - if param.kind in - (Parameter.POSITIONAL_ONLY, - Parameter.POSITIONAL_OR_KEYWORD)) - - annotations = tuple( - param.annotation - for param in params) - - if all(ann is not Parameter.empty for ann in annotations): - return annotations - - def add(self, signature, func): - """ Add new types/method pair to dispatcher - >>> D = Dispatcher('add') - >>> D.add((int, int), lambda x, y: x + y) - >>> D.add((float, float), lambda x, y: x + y) - >>> D(1, 2) - 3 - >>> D(1, 2.0) - Traceback (most recent call last): - ... - NotImplementedError: Could not find signature for add: - >>> # When ``add`` detects a warning it calls the ``on_ambiguity`` callback - >>> # with a dispatcher/itself, and a set of ambiguous type signature pairs - >>> # as inputs. See ``ambiguity_warn`` for an example. - """ - # Handle annotations - if not signature: - annotations = self.get_func_annotations(func) - if annotations: - signature = annotations - - # Handle union types - if any(isinstance(typ, tuple) for typ in signature): - for typs in expand_tuples(signature): - self.add(typs, func) - return - - new_signature = [] - - for index, typ in enumerate(signature, start=1): - if not isinstance(typ, (type, list)): - str_sig = ', '.join(c.__name__ if isinstance(c, type) - else str(c) for c in signature) - raise TypeError("Tried to dispatch on non-type: %s\n" - "In signature: <%s>\n" - "In function: %s" % - (typ, str_sig, self.name)) - - # handle variadic signatures - if isinstance(typ, list): - if index != len(signature): - raise TypeError( - 'Variadic signature must be the last element' - ) - - if len(typ) != 1: - raise TypeError( - 'Variadic signature must contain exactly one element. ' - 'To use a variadic union type place the desired types ' - 'inside of a tuple, e.g., [(int, str)]' - ) - new_signature.append(Variadic[typ[0]]) - else: - new_signature.append(typ) - - self.funcs[tuple(new_signature)] = func - self._cache.clear() - - try: - del self._ordering - except AttributeError: - pass - - @property - def ordering(self): - try: - return self._ordering - except AttributeError: - return self.reorder() - - def reorder(self, on_ambiguity=ambiguity_warn): - self._ordering = od = ordering(self.funcs) - amb = ambiguities(self.funcs) - if amb: - on_ambiguity(self, amb) - return od - - def __call__(self, *args, **kwargs): - types = tuple([type(arg) for arg in args]) - try: - func = self._cache[types] - except KeyError: - func = self.dispatch(*types) - if not func: - raise NotImplementedError( - 'Could not find signature for %s: <%s>' % - (self.name, str_signature(types))) - self._cache[types] = func - try: - return func(*args, **kwargs) - - except MDNotImplementedError: - funcs = self.dispatch_iter(*types) - next(funcs) # burn first - for func in funcs: - try: - return func(*args, **kwargs) - except MDNotImplementedError: - pass - - raise NotImplementedError( - "Matching functions for " - "%s: <%s> found, but none completed successfully" % ( - self.name, str_signature(types),),) - - def __str__(self): - return "" % self.name - __repr__ = __str__ - - def dispatch(self, *types): - """Deterimine appropriate implementation for this type signature - This method is internal. Users should call this object as a function. - Implementation resolution occurs within the ``__call__`` method. - >>> # xdoctest: +SKIP - >>> from multipledispatch import dispatch - >>> @dispatch(int) - ... def inc(x): - ... return x + 1 - >>> implementation = inc.dispatch(int) - >>> implementation(3) - 4 - >>> print(inc.dispatch(float)) - None - See Also: - ``multipledispatch.conflict`` - module to determine resolution order - """ - - if types in self.funcs: - return self.funcs[types] - - try: - return next(self.dispatch_iter(*types)) - except StopIteration: - return None - - def dispatch_iter(self, *types): - - n = len(types) - for signature in self.ordering: - if len(signature) == n and all(map(issubclass, types, signature)): - result = self.funcs[signature] - yield result - elif len(signature) and isvariadic(signature[-1]): - if variadic_signature_matches(types, signature): - result = self.funcs[signature] - yield result - - def resolve(self, types): - """ Deterimine appropriate implementation for this type signature - .. deprecated:: 0.4.4 - Use ``dispatch(*types)`` instead - """ - warn("resolve() is deprecated, use dispatch(*types)", - DeprecationWarning) - - return self.dispatch(*types) - - def __getstate__(self): - return {'name': self.name, - 'funcs': self.funcs} - - def __setstate__(self, d): - self.name = d['name'] - self.funcs = d['funcs'] - self._ordering = ordering(self.funcs) - self._cache = {} - - @property - def __doc__(self): - docs = ["Multiply dispatched method: %s" % self.name] - - if self.doc: - docs.append(self.doc) - - other = [] - for sig in self.ordering[::-1]: - func = self.funcs[sig] - if func.__doc__: - s = 'Inputs: <%s>\n' % str_signature(sig) - s += '-' * len(s) + '\n' - s += func.__doc__.strip() - docs.append(s) - else: - other.append(str_signature(sig)) - - if other: - docs.append('Other signatures:\n ' + '\n '.join(other)) - - return '\n\n'.join(docs) - - def _help(self, *args): - return self.dispatch(*map(type, args)).__doc__ - - def help(self, *args, **kwargs): - """ Print docstring for the function corresponding to inputs """ - print(self._help(*args)) - - def _source(self, *args): - func = self.dispatch(*map(type, args)) - if not func: - raise TypeError("No function found") - return source(func) - - def source(self, *args, **kwargs): - """ Print source code for the function corresponding to inputs """ - print(self._source(*args)) - - -def source(func): - s = 'File: %s\n\n' % inspect.getsourcefile(func) - s = s + inspect.getsource(func) - return s - - -class MethodDispatcher(Dispatcher): - """ Dispatch methods based on type signature - See Also: - Dispatcher - """ - __slots__ = ('obj', 'cls') - - @classmethod - def get_func_params(cls, func): - if hasattr(inspect, "signature"): - sig = inspect.signature(func) - return itl.islice(sig.parameters.values(), 1, None) - - def __get__(self, instance, owner): - self.obj = instance - self.cls = owner - return self - - def __call__(self, *args, **kwargs): - types = tuple([type(arg) for arg in args]) - func = self.dispatch(*types) - if not func: - raise NotImplementedError('Could not find signature for %s: <%s>' % - (self.name, str_signature(types))) - return func(self.obj, *args, **kwargs) - - -def str_signature(sig): - """ String representation of type signature - >>> str_signature((int, float)) - 'int, float' - """ - return ', '.join(cls.__name__ for cls in sig) - - -def warning_text(name, amb): - """ The text for ambiguity warnings """ - text = "\nAmbiguities exist in dispatched function %s\n\n" % (name) - text += "The following signatures may result in ambiguous behavior:\n" - for pair in amb: - text += "\t" + \ - ', '.join('[' + str_signature(s) + ']' for s in pair) + "\n" - text += "\n\nConsider making the following additions:\n\n" - text += '\n\n'.join(['@dispatch(' + str_signature(super_signature(s)) - + ')\ndef %s(...)' % name for s in amb]) - return text diff --git a/pippy/fx/experimental/unification/multipledispatch/utils.py b/pippy/fx/experimental/unification/multipledispatch/utils.py deleted file mode 100644 index 3e427d2f4..000000000 --- a/pippy/fx/experimental/unification/multipledispatch/utils.py +++ /dev/null @@ -1,126 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from collections import OrderedDict - -__all__ = ["raises", "expand_tuples", "reverse_dict", "groupby", "typename"] - -def raises(err, lamda): - try: - lamda() - return False - except err: - return True - - -def expand_tuples(L): - """ - >>> expand_tuples([1, (2, 3)]) - [(1, 2), (1, 3)] - >>> expand_tuples([1, 2]) - [(1, 2)] - """ - if not L: - return [()] - elif not isinstance(L[0], tuple): - rest = expand_tuples(L[1:]) - return [(L[0],) + t for t in rest] - else: - rest = expand_tuples(L[1:]) - return [(item,) + t for t in rest for item in L[0]] - - -# Taken from theano/theano/gof/sched.py -# Avoids licensing issues because this was written by Matthew Rocklin -def _toposort(edges): - """ Topological sort algorithm by Kahn [1] - O(nodes + vertices) - inputs: - edges - a dict of the form {a: {b, c}} where b and c depend on a - outputs: - L - an ordered list of nodes that satisfy the dependencies of edges - >>> _toposort({1: (2, 3), 2: (3, )}) - [1, 2, 3] - >>> # Closely follows the wikipedia page [2] - >>> # [1] Kahn, Arthur B. (1962), "Topological sorting of large networks", - >>> # Communications of the ACM - >>> # [2] http://en.wikipedia.org/wiki/Toposort#Algorithms - """ - incoming_edges = reverse_dict(edges) - incoming_edges = OrderedDict((k, set(val)) - for k, val in incoming_edges.items()) - S = OrderedDict.fromkeys(v for v in edges if v not in incoming_edges) - L = [] - - while S: - n, _ = S.popitem() - L.append(n) - for m in edges.get(n, ()): - assert n in incoming_edges[m] - incoming_edges[m].remove(n) - if not incoming_edges[m]: - S[m] = None - if any(incoming_edges.get(v, None) for v in edges): - raise ValueError("Input has cycles") - return L - - -def reverse_dict(d): - """Reverses direction of dependence dict - >>> d = {'a': (1, 2), 'b': (2, 3), 'c':()} - >>> reverse_dict(d) # doctest: +SKIP - {1: ('a',), 2: ('a', 'b'), 3: ('b',)} - :note: dict order are not deterministic. As we iterate on the - input dict, it make the output of this function depend on the - dict order. So this function output order should be considered - as undeterministic. - """ - result = OrderedDict() # type: ignore[var-annotated] - for key in d: - for val in d[key]: - result[val] = result.get(val, tuple()) + (key, ) - return result - - -# Taken from toolz -# Avoids licensing issues because this version was authored by Matthew Rocklin -def groupby(func, seq): - """ Group a collection by a key function - >>> names = ['Alice', 'Bob', 'Charlie', 'Dan', 'Edith', 'Frank'] - >>> groupby(len, names) # doctest: +SKIP - {3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']} - >>> iseven = lambda x: x % 2 == 0 - >>> groupby(iseven, [1, 2, 3, 4, 5, 6, 7, 8]) # doctest: +SKIP - {False: [1, 3, 5, 7], True: [2, 4, 6, 8]} - See Also: - ``countby`` - """ - - d = OrderedDict() # type: ignore[var-annotated] - for item in seq: - key = func(item) - if key not in d: - d[key] = list() - d[key].append(item) - return d - - -def typename(type): - """Get the name of `type`. - Parameters - ---------- - type : Union[Type, Tuple[Type]] - Returns - ------- - str - The name of `type` or a tuple of the names of the types in `type`. - Examples - -------- - >>> typename(int) - 'int' - >>> typename((int, float)) - '(int, float)' - """ - try: - return type.__name__ - except AttributeError: - if len(type) == 1: - return typename(*type) - return '(%s)' % ', '.join(map(typename, type)) diff --git a/pippy/fx/experimental/unification/multipledispatch/variadic.py b/pippy/fx/experimental/unification/multipledispatch/variadic.py deleted file mode 100644 index 5802302ee..000000000 --- a/pippy/fx/experimental/unification/multipledispatch/variadic.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import six - -from .utils import typename - -__all__ = ["VariadicSignatureType", "isvariadic", "VariadicSignatureMeta", "Variadic"] - -class VariadicSignatureType(type): - # checking if subclass is a subclass of self - def __subclasscheck__(cls, subclass): - other_type = (subclass.variadic_type if isvariadic(subclass) - else (subclass,)) - return subclass is cls or all( - issubclass(other, cls.variadic_type) for other in other_type # type: ignore[attr-defined] - ) - - def __eq__(cls, other): - """ - Return True if other has the same variadic type - Parameters - ---------- - other : object (type) - The object (type) to check - Returns - ------- - bool - Whether or not `other` is equal to `self` - """ - return (isvariadic(other) and - set(cls.variadic_type) == set(other.variadic_type)) # type: ignore[attr-defined] - - def __hash__(cls): - return hash((type(cls), frozenset(cls.variadic_type))) # type: ignore[attr-defined] - - -def isvariadic(obj): - """Check whether the type `obj` is variadic. - Parameters - ---------- - obj : type - The type to check - Returns - ------- - bool - Whether or not `obj` is variadic - Examples - -------- - >>> isvariadic(int) - False - >>> isvariadic(Variadic[int]) - True - """ - return isinstance(obj, VariadicSignatureType) - - -class VariadicSignatureMeta(type): - """A metaclass that overrides ``__getitem__`` on the class. This is used to - generate a new type for Variadic signatures. See the Variadic class for - examples of how this behaves. - """ - def __getitem__(cls, variadic_type): - if not (isinstance(variadic_type, (type, tuple)) or type(variadic_type)): - raise ValueError("Variadic types must be type or tuple of types" - " (Variadic[int] or Variadic[(int, float)]") - - if not isinstance(variadic_type, tuple): - variadic_type = variadic_type, - return VariadicSignatureType( - 'Variadic[%s]' % typename(variadic_type), - (), - dict(variadic_type=variadic_type, __slots__=()) - ) - - -class Variadic(six.with_metaclass(VariadicSignatureMeta)): - """A class whose getitem method can be used to generate a new type - representing a specific variadic signature. - Examples - -------- - >>> Variadic[int] # any number of int arguments - >>> # xdoctest: +SKIP - - >>> Variadic[(int, str)] # any number of one of int or str arguments - - >>> issubclass(int, Variadic[int]) - True - >>> issubclass(int, Variadic[(int, str)]) - True - >>> issubclass(str, Variadic[(int, str)]) - True - >>> issubclass(float, Variadic[(int, str)]) - False - """ diff --git a/pippy/fx/experimental/unification/unification_tools.py b/pippy/fx/experimental/unification/unification_tools.py deleted file mode 100644 index d2ddc1df3..000000000 --- a/pippy/fx/experimental/unification/unification_tools.py +++ /dev/null @@ -1,393 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import collections -import operator -from functools import reduce -from collections.abc import Mapping - -__all__ = ('merge', 'merge_with', 'valmap', 'keymap', 'itemmap', - 'valfilter', 'keyfilter', 'itemfilter', - 'assoc', 'dissoc', 'assoc_in', 'update_in', 'get_in') - -def _get_factory(f, kwargs): - factory = kwargs.pop('factory', dict) - if kwargs: - raise TypeError("{}() got an unexpected keyword argument " - "'{}'".format(f.__name__, kwargs.popitem()[0])) - return factory - - -def merge(*dicts, **kwargs): - """ Merge a collection of dictionaries - - >>> merge({1: 'one'}, {2: 'two'}) - {1: 'one', 2: 'two'} - - Later dictionaries have precedence - - >>> merge({1: 2, 3: 4}, {3: 3, 4: 4}) - {1: 2, 3: 3, 4: 4} - - See Also: - merge_with - """ - if len(dicts) == 1 and not isinstance(dicts[0], Mapping): - dicts = dicts[0] - factory = _get_factory(merge, kwargs) - - rv = factory() - for d in dicts: - rv.update(d) - return rv - - -def merge_with(func, *dicts, **kwargs): - """ Merge dictionaries and apply function to combined values - - A key may occur in more than one dict, and all values mapped from the key - will be passed to the function as a list, such as func([val1, val2, ...]). - - >>> merge_with(sum, {1: 1, 2: 2}, {1: 10, 2: 20}) - {1: 11, 2: 22} - - >>> merge_with(first, {1: 1, 2: 2}, {2: 20, 3: 30}) # doctest: +SKIP - {1: 1, 2: 2, 3: 30} - - See Also: - merge - """ - if len(dicts) == 1 and not isinstance(dicts[0], Mapping): - dicts = dicts[0] - factory = _get_factory(merge_with, kwargs) - - result = factory() - for d in dicts: - for k, v in d.items(): - if k not in result: - result[k] = [v] - else: - result[k].append(v) - return valmap(func, result, factory) - - -def valmap(func, d, factory=dict): - """ Apply function to values of dictionary - - >>> bills = {"Alice": [20, 15, 30], "Bob": [10, 35]} - >>> valmap(sum, bills) # doctest: +SKIP - {'Alice': 65, 'Bob': 45} - - See Also: - keymap - itemmap - """ - rv = factory() - rv.update(zip(d.keys(), map(func, d.values()))) - return rv - - -def keymap(func, d, factory=dict): - """ Apply function to keys of dictionary - - >>> bills = {"Alice": [20, 15, 30], "Bob": [10, 35]} - >>> keymap(str.lower, bills) # doctest: +SKIP - {'alice': [20, 15, 30], 'bob': [10, 35]} - - See Also: - valmap - itemmap - """ - rv = factory() - rv.update(zip(map(func, d.keys()), d.values())) - return rv - - -def itemmap(func, d, factory=dict): - """ Apply function to items of dictionary - - >>> accountids = {"Alice": 10, "Bob": 20} - >>> itemmap(reversed, accountids) # doctest: +SKIP - {10: "Alice", 20: "Bob"} - - See Also: - keymap - valmap - """ - rv = factory() - rv.update(map(func, d.items())) - return rv - - -def valfilter(predicate, d, factory=dict): - """ Filter items in dictionary by value - - >>> iseven = lambda x: x % 2 == 0 - >>> d = {1: 2, 2: 3, 3: 4, 4: 5} - >>> valfilter(iseven, d) - {1: 2, 3: 4} - - See Also: - keyfilter - itemfilter - valmap - """ - rv = factory() - for k, v in d.items(): - if predicate(v): - rv[k] = v - return rv - - -def keyfilter(predicate, d, factory=dict): - """ Filter items in dictionary by key - - >>> iseven = lambda x: x % 2 == 0 - >>> d = {1: 2, 2: 3, 3: 4, 4: 5} - >>> keyfilter(iseven, d) - {2: 3, 4: 5} - - See Also: - valfilter - itemfilter - keymap - """ - rv = factory() - for k, v in d.items(): - if predicate(k): - rv[k] = v - return rv - - -def itemfilter(predicate, d, factory=dict): - """ Filter items in dictionary by item - - >>> def isvalid(item): - ... k, v = item - ... return k % 2 == 0 and v < 4 - - >>> d = {1: 2, 2: 3, 3: 4, 4: 5} - >>> itemfilter(isvalid, d) - {2: 3} - - See Also: - keyfilter - valfilter - itemmap - """ - rv = factory() - for item in d.items(): - if predicate(item): - k, v = item - rv[k] = v - return rv - - -def assoc(d, key, value, factory=dict): - """ Return a new dict with new key value pair - - New dict has d[key] set to value. Does not modify the initial dictionary. - - >>> assoc({'x': 1}, 'x', 2) - {'x': 2} - >>> assoc({'x': 1}, 'y', 3) # doctest: +SKIP - {'x': 1, 'y': 3} - """ - d2 = factory() - d2.update(d) - d2[key] = value - return d2 - - -def dissoc(d, *keys, **kwargs): - """ Return a new dict with the given key(s) removed. - - New dict has d[key] deleted for each supplied key. - Does not modify the initial dictionary. - - >>> dissoc({'x': 1, 'y': 2}, 'y') - {'x': 1} - >>> dissoc({'x': 1, 'y': 2}, 'y', 'x') - {} - >>> dissoc({'x': 1}, 'y') # Ignores missing keys - {'x': 1} - """ - factory = _get_factory(dissoc, kwargs) - d2 = factory() - - if len(keys) < len(d) * .6: - d2.update(d) - for key in keys: - if key in d2: - del d2[key] - else: - remaining = set(d) - remaining.difference_update(keys) - for k in remaining: - d2[k] = d[k] - return d2 - - -def assoc_in(d, keys, value, factory=dict): - """ Return a new dict with new, potentially nested, key value pair - - >>> purchase = {'name': 'Alice', - ... 'order': {'items': ['Apple', 'Orange'], - ... 'costs': [0.50, 1.25]}, - ... 'credit card': '5555-1234-1234-1234'} - >>> assoc_in(purchase, ['order', 'costs'], [0.25, 1.00]) # doctest: +SKIP - {'credit card': '5555-1234-1234-1234', - 'name': 'Alice', - 'order': {'costs': [0.25, 1.00], 'items': ['Apple', 'Orange']}} - """ - return update_in(d, keys, lambda x: value, value, factory) - - -def update_in(d, keys, func, default=None, factory=dict): - """ Update value in a (potentially) nested dictionary - - inputs: - d - dictionary on which to operate - keys - list or tuple giving the location of the value to be changed in d - func - function to operate on that value - - If keys == [k0,..,kX] and d[k0]..[kX] == v, update_in returns a copy of the - original dictionary with v replaced by func(v), but does not mutate the - original dictionary. - - If k0 is not a key in d, update_in creates nested dictionaries to the depth - specified by the keys, with the innermost value set to func(default). - - >>> inc = lambda x: x + 1 - >>> update_in({'a': 0}, ['a'], inc) - {'a': 1} - - >>> transaction = {'name': 'Alice', - ... 'purchase': {'items': ['Apple', 'Orange'], - ... 'costs': [0.50, 1.25]}, - ... 'credit card': '5555-1234-1234-1234'} - >>> update_in(transaction, ['purchase', 'costs'], sum) # doctest: +SKIP - {'credit card': '5555-1234-1234-1234', - 'name': 'Alice', - 'purchase': {'costs': 1.75, 'items': ['Apple', 'Orange']}} - - >>> # updating a value when k0 is not in d - >>> update_in({}, [1, 2, 3], str, default="bar") - {1: {2: {3: 'bar'}}} - >>> update_in({1: 'foo'}, [2, 3, 4], inc, 0) - {1: 'foo', 2: {3: {4: 1}}} - """ - ks = iter(keys) - k = next(ks) - - rv = inner = factory() - rv.update(d) - - for key in ks: - if k in d: - d = d[k] - dtemp = factory() - dtemp.update(d) - else: - d = dtemp = factory() - - inner[k] = inner = dtemp - k = key - - if k in d: - inner[k] = func(d[k]) - else: - inner[k] = func(default) - return rv - - -def get_in(keys, coll, default=None, no_default=False): - """ Returns coll[i0][i1]...[iX] where [i0, i1, ..., iX]==keys. - - If coll[i0][i1]...[iX] cannot be found, returns ``default``, unless - ``no_default`` is specified, then it raises KeyError or IndexError. - - ``get_in`` is a generalization of ``operator.getitem`` for nested data - structures such as dictionaries and lists. - - >>> transaction = {'name': 'Alice', - ... 'purchase': {'items': ['Apple', 'Orange'], - ... 'costs': [0.50, 1.25]}, - ... 'credit card': '5555-1234-1234-1234'} - >>> get_in(['purchase', 'items', 0], transaction) - 'Apple' - >>> get_in(['name'], transaction) - 'Alice' - >>> get_in(['purchase', 'total'], transaction) - >>> get_in(['purchase', 'items', 'apple'], transaction) - >>> get_in(['purchase', 'items', 10], transaction) - >>> get_in(['purchase', 'total'], transaction, 0) - 0 - >>> get_in(['y'], {}, no_default=True) - Traceback (most recent call last): - ... - KeyError: 'y' - - See Also: - itertoolz.get - operator.getitem - """ - try: - return reduce(operator.getitem, keys, coll) - except (KeyError, IndexError, TypeError): - if no_default: - raise - return default - -def getter(index): - if isinstance(index, list): - if len(index) == 1: - index = index[0] - return lambda x: (x[index],) - elif index: - return operator.itemgetter(*index) - else: - return lambda x: () - else: - return operator.itemgetter(index) - -def groupby(key, seq): - """ Group a collection by a key function - - >>> names = ['Alice', 'Bob', 'Charlie', 'Dan', 'Edith', 'Frank'] - >>> groupby(len, names) # doctest: +SKIP - {3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']} - - >>> iseven = lambda x: x % 2 == 0 - >>> groupby(iseven, [1, 2, 3, 4, 5, 6, 7, 8]) # doctest: +SKIP - {False: [1, 3, 5, 7], True: [2, 4, 6, 8]} - - Non-callable keys imply grouping on a member. - - >>> groupby('gender', [{'name': 'Alice', 'gender': 'F'}, - ... {'name': 'Bob', 'gender': 'M'}, - ... {'name': 'Charlie', 'gender': 'M'}]) # doctest:+SKIP - {'F': [{'gender': 'F', 'name': 'Alice'}], - 'M': [{'gender': 'M', 'name': 'Bob'}, - {'gender': 'M', 'name': 'Charlie'}]} - - Not to be confused with ``itertools.groupby`` - - See Also: - countby - """ - if not callable(key): - key = getter(key) - d = collections.defaultdict(lambda: [].append) # type: ignore[var-annotated] - for item in seq: - d[key(item)](item) - rv = {} - for k, v in d.items(): - rv[k] = v.__self__ # type: ignore[var-annotated, attr-defined] - return rv - -def first(seq): - """ The first element in a sequence - - >>> first('ABC') - 'A' - """ - return next(iter(seq)) diff --git a/pippy/fx/experimental/unification/utils.py b/pippy/fx/experimental/unification/utils.py deleted file mode 100644 index a54ad565d..000000000 --- a/pippy/fx/experimental/unification/utils.py +++ /dev/null @@ -1,106 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -__all__ = ["hashable", "transitive_get", "raises", "reverse_dict", "xfail", "freeze"] -def hashable(x): - try: - hash(x) - return True - except TypeError: - return False - - -def transitive_get(key, d): - """ Transitive dict.get - >>> d = {1: 2, 2: 3, 3: 4} - >>> d.get(1) - 2 - >>> transitive_get(1, d) - 4 - """ - while hashable(key) and key in d: - key = d[key] - return key - - -def raises(err, lamda): - try: - lamda() - return False - except err: - return True - - -# Taken from theano/theano/gof/sched.py -# Avoids licensing issues because this was written by Matthew Rocklin -def _toposort(edges): - """ Topological sort algorithm by Kahn [1] - O(nodes + vertices) - inputs: - edges - a dict of the form {a: {b, c}} where b and c depend on a - outputs: - L - an ordered list of nodes that satisfy the dependencies of edges - >>> _toposort({1: (2, 3), 2: (3, )}) - >>> # xdoctest: +SKIP - [1, 2, 3] - Closely follows the wikipedia page [2] - [1] Kahn, Arthur B. (1962), "Topological sorting of large networks", - Communications of the ACM - [2] http://en.wikipedia.org/wiki/Toposort#Algorithms - """ - incoming_edges = reverse_dict(edges) - incoming_edges = dict((k, set(val)) for k, val in incoming_edges.items()) - S = set((v for v in edges if v not in incoming_edges)) - L = [] - - while S: - n = S.pop() - L.append(n) - for m in edges.get(n, ()): - assert n in incoming_edges[m] - incoming_edges[m].remove(n) - if not incoming_edges[m]: - S.add(m) - if any(incoming_edges.get(v, None) for v in edges): - raise ValueError("Input has cycles") - return L - - -def reverse_dict(d): - """Reverses direction of dependence dict - >>> d = {'a': (1, 2), 'b': (2, 3), 'c':()} - >>> reverse_dict(d) # doctest: +SKIP - {1: ('a',), 2: ('a', 'b'), 3: ('b',)} - :note: dict order are not deterministic. As we iterate on the - input dict, it make the output of this function depend on the - dict order. So this function output order should be considered - as undeterministic. - """ - result = {} # type: ignore[var-annotated] - for key in d: - for val in d[key]: - result[val] = result.get(val, tuple()) + (key, ) - return result - - -def xfail(func): - try: - func() - raise Exception("XFailed test passed") # pragma:nocover - except Exception: - pass - - -def freeze(d): - """ Freeze container to hashable form - >>> freeze(1) - 1 - >>> freeze([1, 2]) - (1, 2) - >>> freeze({1: 2}) # doctest: +SKIP - frozenset([(1, 2)]) - """ - if isinstance(d, dict): - return frozenset(map(freeze, d.items())) - if isinstance(d, set): - return frozenset(map(freeze, d)) - if isinstance(d, (tuple, list)): - return tuple(map(freeze, d)) - return d diff --git a/pippy/fx/experimental/unification/variable.py b/pippy/fx/experimental/unification/variable.py deleted file mode 100644 index e836d7653..000000000 --- a/pippy/fx/experimental/unification/variable.py +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from contextlib import contextmanager -from .utils import hashable -from .dispatch import dispatch - -_global_logic_variables = set() # type: ignore[var-annotated] -_glv = _global_logic_variables - - -class Var(object): - """ Logic Variable """ - - _id = 1 - - def __new__(cls, *token): - if len(token) == 0: - token = "_%s" % Var._id # type: ignore[assignment] - Var._id += 1 - elif len(token) == 1: - token = token[0] - - obj = object.__new__(cls) - obj.token = token # type: ignore[attr-defined] - return obj - - def __str__(self): - return "~" + str(self.token) # type: ignore[attr-defined] - __repr__ = __str__ - - def __eq__(self, other): - return type(self) == type(other) and self.token == other.token # type: ignore[attr-defined] - - def __hash__(self): - return hash((type(self), self.token)) # type: ignore[attr-defined] - - -def var(): - return lambda *args: Var(*args) - -def vars(): - return lambda n: [var() for i in range(n)] - - -@dispatch(Var) -def isvar(v): - return True - -isvar - -@dispatch(object) # type: ignore[no-redef] -def isvar(o): - return not not _glv and hashable(o) and o in _glv - - -@contextmanager -def variables(*variables): - """ Context manager for logic variables - >>> from __future__ import with_statement - >>> with variables(1): - ... print(isvar(1)) - True - >>> print(isvar(1)) - False - >>> # xdoctest: +SKIP("undefined vars") - >>> # Normal approach - >>> from unification import unify - >>> x = var('x') - >>> unify(x, 1) - {~x: 1} - >>> # Context Manager approach - >>> with variables('x'): - ... print(unify('x', 1)) - {'x': 1} - """ - old_global_logic_variables = _global_logic_variables.copy() - _global_logic_variables.update(set(variables)) - try: - yield - finally: - _global_logic_variables.clear() - _global_logic_variables.update(old_global_logic_variables) diff --git a/pippy/fx/experimental/unify_refinements.py b/pippy/fx/experimental/unify_refinements.py deleted file mode 100644 index 07f9f4aca..000000000 --- a/pippy/fx/experimental/unify_refinements.py +++ /dev/null @@ -1,121 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from pippy.fx.experimental.graph_gradual_typechecker import Refine -from pippy.fx.tensor_type import TensorType -from pippy.fx.experimental.unification import Var, unify # type: ignore[attr-defined] - - -def infer_symbolic_types_single_pass(traced): - """ - Calls our symbolic inferencer once. - """ - r = Refine(traced) - r.refine() - mgu = unify_eq(r.constraints) - substitute_all_types(traced.graph, mgu) - -def infer_symbolic_types(traced): - """ - Calls our symbolic inferencer twice. - This is useful when one pass is not enough - to infer all the information such as the case - for braodcasting. - """ - r = Refine(traced) - r.refine() - mgu = unify_eq(r.constraints) - substitute_all_types(traced.graph, mgu) - - r = Refine(traced) - r.refine() - mgu = unify_eq(r.constraints) - substitute_all_types(traced.graph, mgu) - - r.symbolic_relations() - -def convert_eq(list_of_eq): - """ - Convert equality constraints in the right format - to be used by unification library. - """ - lhs = [] - rhs = [] - for eq in list_of_eq: - lhs.append(eq.lhs) - rhs.append(eq.rhs) - return tuple(lhs), tuple(rhs) - - -def unify_eq(list_of_eq): - """ - Apply unification to a set of - equality constraints - """ - lhs, rhs = convert_eq(list_of_eq) - return unify(lhs, rhs) - - -def substitute_solution_one_type(mapping, t): - """ - Apply the most general unifier to a type - """ - if isinstance(t, Var): - if t in mapping.keys(): - return mapping[t] - else: - return t - - elif isinstance(t, TensorType): - new_type = [] - for typ in t.__args__: - if typ in mapping.keys(): - new_type.append(mapping[typ]) - else: - new_type.append(typ) - return TensorType(tuple(new_type)) - - elif isinstance(t, list): - new_type = [] - for typ in t: - new_type.append(substitute_solution_one_type(mapping, typ)) - return new_type - - elif isinstance(t, tuple): - new_type = [] - for typ in t: - new_type.append(substitute_solution_one_type(mapping, typ)) - return tuple(new_type) - - else: - return t - - -def substitute_all_types(graph, mapping): - """ - Apply the most general unifier to all types in a graph - till reaching a fixed point. If the input and output graph - are the same, we converge. - """ - flag = True - while flag: - flag = False - for k in mapping: - old_mapping_val = mapping[k] - if mapping[k] in mapping.keys(): - new_key = mapping[k] - mapping[k] = mapping[new_key] - if old_mapping_val != mapping[k]: - flag = True - - for n in graph.nodes: - n.type = substitute_solution_one_type(mapping, n.type) - -def check_for_type_equality(g1, g2): - """ - A check equality to be used in fixed points. - We do not use graph equality but instead type - equality. - """ - for n, m in zip(g1.nodes, g2.nodes): - if n.type != m.type: - return False - return True diff --git a/pippy/fx/graph.py b/pippy/fx/graph.py deleted file mode 100644 index 2f30a750b..000000000 --- a/pippy/fx/graph.py +++ /dev/null @@ -1,1507 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import builtins -import contextlib -import copy -import inspect -import keyword -import math -import re -import warnings -from contextlib import contextmanager -from dataclasses import dataclass -from typing import TYPE_CHECKING, Callable, Any, List, Dict, NamedTuple, Optional, Tuple, Set, FrozenSet, Type - -import torch -import torch.utils._pytree as pytree - -import pippy -import pippy.fx -from . import _pytree as fx_pytree -from ._compatibility import compatibility -from .node import Node, Argument, Target, map_arg, _type_repr, _get_qualified_name - -__all__ = ["PythonCode", "CodeGen", "Graph"] - -if TYPE_CHECKING: - from .graph_module import GraphModule # noqa: F401 - from ._symbolic_trace import Tracer # noqa: F401 - - -# Mapping of builtins to their `typing` equivalent. -_origin_type_map = { - list: List, - dict: Dict, - set: Set, - frozenset: FrozenSet, - tuple: Tuple, -} - - -# Signature for functions thattransforms the body (`list[str]`) of the -# generated code -TransformCodeFunc = Callable[[List[str]], List[str]] - - -class _CustomBuiltin(NamedTuple): - """Additional objs that we add to every graph's globals. - - The repr() for some standard library objects is not valid Python code without - an import. For common objects of this sort, we bundle them in the globals of - every FX graph. - """ - # How to import this object from the standard library. - import_str: str - # The actual object, produced from that import string. - obj: Any - -_custom_builtins: Dict[str, _CustomBuiltin] = {} - - -def _register_custom_builtin(name: str, import_str: str, obj: Any): - _custom_builtins[name] = _CustomBuiltin(import_str, obj) - - -_register_custom_builtin('inf', 'from math import inf', math.inf) -_register_custom_builtin('nan', 'from math import nan', math.nan) -_register_custom_builtin('NoneType', 'NoneType = type(None)', type(None)) -_register_custom_builtin('torch', 'import torch', torch) -_register_custom_builtin('pippy', 'import pippy', pippy) -_register_custom_builtin('device', 'from torch import device', torch.device) -_register_custom_builtin('fx_pytree', 'import pippy.fx._pytree as fx_pytree', fx_pytree) -_register_custom_builtin('pytree', 'import torch.utils._pytree as pytree', pytree) - - -def _is_magic(x: str) -> bool: - return x.startswith('__') and x.endswith('__') - - -def _snake_case(s: str) -> str: - """ - Transforms the given string ``s`` to a Python-style variable name - - Examples: - ``mod.snake_case`` -> ``mod.snake_case`` - ``mod.pascalCase``-> ``mod.pascal_case`` - ``mod.ALL_CAPS`` -> ``mod.all_caps`` - """ - chars = [] - prev_lower = False - for c in s: - if prev_lower and c.isupper(): - chars.append('_') - chars.append(c.lower()) - prev_lower = c.islower() - return ''.join(chars) - - -def _is_from_torch(obj: Any) -> bool: - module_name = getattr(obj, '__module__', None) - if module_name is not None: - if module_name.startswith('pippy.fx'): - return True - - base_module = module_name.partition('.')[0] - return base_module == 'torch' - - name = getattr(obj, '__name__', None) - # exclude torch because torch.torch.torch.torch works. idk mang - if name is not None and name != 'torch': - for guess in [torch, torch.nn.functional]: - if getattr(guess, name, None) is obj: - return True - - return False - - -class _Namespace: - """A context for associating names uniquely with objects. - - The following invariants are enforced: - - Each object gets a single name. - - Each name is unique within a given namespace. - - Names generated do not shadow builtins, unless the object is indeed that builtin. - """ - def __init__(self): - self._obj_to_name: Dict[Any, str] = {} - self._unassociated_names = set() - self._used_names: Dict[str, int] = {} - - self._illegal_char_regex = re.compile('[^0-9a-zA-Z_]+') - self._name_suffix_regex = re.compile(r"(.*)_(\d+)$") - - def create_name(self, candidate: str, obj: Optional[Any]) -> str: - """Create a unique name. - - Arguments: - candidate: used as the basis for the unique name, relevant to the user. - obj: If not None, an object that will be associated with the unique name. - """ - if obj is not None and obj in self._obj_to_name: - return self._obj_to_name[obj] - - # delete all characters that are illegal in a Python identifier - candidate = self._illegal_char_regex.sub('_', candidate) - - if candidate[0].isdigit(): - candidate = f'_{candidate}' - - match = self._name_suffix_regex.match(candidate) - if match is None: - base = candidate - num = None - else: - base, num_str = match.group(1, 2) - num = int(num_str) - - candidate = base if num is None else f'{base}_{num}' - num = num if num else 0 - - while candidate in self._used_names or self._is_illegal_name(candidate, obj): - num += 1 - candidate = f'{base}_{num}' - - self._used_names.setdefault(candidate, 0) - if obj is None: - self._unassociated_names.add(candidate) - else: - self._obj_to_name[obj] = candidate - return candidate - - def associate_name_with_obj(self, name: str, obj: Any): - """Associate a unique name with an object. - - Neither `name` nor `obj` should be associated already. - """ - assert obj not in self._obj_to_name - assert name in self._unassociated_names - self._obj_to_name[obj] = name - self._unassociated_names.remove(name) - - def _is_illegal_name(self, name: str, obj: Any) -> bool: - # 1. keywords are never allowed as names. - if name in keyword.kwlist: - return True - - # 2. Can't shadow a builtin name, unless you *are* that builtin. - if name in builtins.__dict__: - return obj is not builtins.__dict__[name] - - # 3. Can't shadow our custom builtins either - if name in _custom_builtins: - return obj is not _custom_builtins[name].obj - - return False - - -@compatibility(is_backward_compatible=True) -@dataclass -class PythonCode: - """ - Represents all the information necessary to exec or save a graph as Python code. - """ - # Python source code for the forward function definition. - src: str - # Values in global scope during exection of `src_def`. - globals: Dict[str, Any] - - -def _format_target(base: str, target: str) -> str: - elems = target.split('.') - r = base - for e in elems: - if not e.isidentifier(): - r = f'getattr({r}, "{e}")' - else: - r = f'{r}.{e}' - return r - -class _InsertPoint: - def __init__(self, graph, new_insert): - self.graph = graph - self.orig_insert, graph._insert = graph._insert, new_insert - - def __enter__(self): - pass - - def __exit__(self, type, value, tb): - self.graph._insert = self.orig_insert - -class _node_list: - def __init__(self, graph: 'Graph', direction: str = '_next'): - assert direction in ['_next', '_prev'] - self.graph = graph - self.direction = direction - - def __len__(self): - return self.graph._len - - def __iter__(self): - root, direction = self.graph._root, self.direction - cur = getattr(root, direction) - while cur is not root: - if not cur._erased: - yield cur - cur = getattr(cur, direction) - - def __reversed__(self): - return _node_list(self.graph, '_next' if self.direction == '_prev' else '_prev') - -class _PyTreeInfo(NamedTuple): - """ - Contains extra info stored when we're using Pytrees - """ - orig_args: List[str] - in_spec: pytree.TreeSpec - out_spec: Optional[pytree.TreeSpec] - -@compatibility(is_backward_compatible=False) -class CodeGen(object): - def __init__(self): - self._body_transformer: Optional[TransformCodeFunc] = None - - def gen_fn_def(self, free_vars: List[str], maybe_return_annotation: str) -> str: - """ - Given the free variables and a return annotation, generates the beginning of the FX function. - By default, `gen_fn_def(['a', 'b'], '') == 'def forward(a, b):'` - """ - # If the original function didn't have self as its first argument, we - # would have added it. - if len(free_vars) == 0 or free_vars[0] != 'self': - free_vars.insert(0, 'self') - return f"def forward({', '.join(free_vars)}){maybe_return_annotation}:" - - def generate_output(self, output_args: Argument) -> str: - """ - Given the output arguments, generates the return statement of the FX function. - Note: The returned statement should not be indented. - """ - return f'return {repr(output_args)}' - - def process_inputs(self, *args: Any) -> Any: - """ - Transforms the inputs so that the graph can take them as arguments, as - non-default codegen may result in the inputs to the function being - different from the inputs to the graph. - - If the graph was directly runnable, this invariant should hold true - `f.graph.process_outputs(f.graph(*f.graph.process_inputs(*inputs))) == f(*inputs)` - """ - return args - - def process_outputs(self, outputs: Any) -> Any: - """ - Transforms the outputs of the graph to be identical to the codegen. - - See ``process_inputs`` for more details. - """ - return outputs - - def additional_globals(self) -> List[Tuple[str, Any]]: - """ - If your codegen uses extra global values, add tuples of (identifier,reference to the value) here. - For example, return ['List', typing.List] if you need ``List`` in the global context. - """ - return [] - - def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace, *, verbose: bool = False) -> PythonCode: - free_vars: List[str] = [] - body: List[str] = [] - globals_: Dict[str, Any] = {} - wrapped_fns: Dict[str, None] = {} - - # Wrap string in list to pass by reference - maybe_return_annotation : List[str] = [''] - - def add_global(name_hint: str, obj: Any): - """Add an obj to be tracked as a global. - - We call this for names that reference objects external to the - Graph, like functions or types. - - Returns: the global name that should be used to reference 'obj' in generated source. - """ - if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device - # HACK: workaround for how torch custom ops are registered. We - # can't import them like normal modules so they must retain their - # fully qualified name. - return _get_qualified_name(obj) - - # normalize the name hint to get a proper identifier - global_name = namespace.create_name(name_hint, obj) - - if global_name in globals_: - assert globals_[global_name] is obj - return global_name - globals_[global_name] = obj - return global_name - - # Pre-fill the globals table with registered builtins. - for name, (_, obj) in _custom_builtins.items(): - add_global(name, obj) - - def type_repr(o : Any): - if o == (): - # Empty tuple is used for empty tuple type annotation Tuple[()] - return '()' - - typename = _type_repr(o) - - if hasattr(o, '__origin__'): - # This is a generic type, e.g. typing.List[torch.Tensor] - origin_type = _origin_type_map.get(o.__origin__, o.__origin__) - origin_typename = add_global(_type_repr(origin_type), origin_type) - - if hasattr(o, '__args__'): - # Assign global names for each of the inner type variables. - args = [type_repr(arg) for arg in o.__args__] - - if len(args) == 0: - # Bare type, such as `typing.Tuple` with no subscript - # This code-path used in Python < 3.9 - return origin_typename - - return f'{origin_typename}[{",".join(args)}]' - else: - # Bare type, such as `typing.Tuple` with no subscript - # This code-path used in Python 3.9+ - return origin_typename - - # Common case: this is a regular module name like 'foo.bar.baz' - return add_global(typename, o) - - def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str: - def _get_repr(arg): - # Handle NamedTuples (if it has `_fields`) via add_global. - if isinstance(arg, tuple) and hasattr(arg, '_fields'): - qualified_name = _get_qualified_name(type(arg)) - global_name = add_global(qualified_name, type(arg)) - return f"{global_name}{repr(tuple(arg))}" - return repr(arg) - args_s = ', '.join(_get_repr(a) for a in args) - kwargs_s = ', '.join(f'{k} = {_get_repr(v)}' for k, v in kwargs.items()) - if args_s and kwargs_s: - return f'{args_s}, {kwargs_s}' - return args_s or kwargs_s - - # Run through reverse nodes and record the first instance of a use - # of a given node. This represents the *last* use of the node in the - # execution order of the program, which we will use to free unused - # values - node_to_last_use : Dict[Node, Node] = {} - user_to_last_uses : Dict[Node, List[Node]] = {} - - def register_last_uses(n : Node, user : Node): - if n not in node_to_last_use: - node_to_last_use[n] = user - user_to_last_uses.setdefault(user, []).append(n) - - for node in reversed(nodes): - map_arg(node.args, lambda n: register_last_uses(n, node)) - map_arg(node.kwargs, lambda n: register_last_uses(n, node)) - - def delete_unused_values(user : Node): - """ - Delete values after their last use. This ensures that values that are - not used in the remainder of the code are freed and the memory usage - of the code is optimal. - """ - if user.op == 'placeholder': - return - if user.op == 'output': - body.append('\n') - return - nodes_to_delete = user_to_last_uses.get(user, []) - if len(nodes_to_delete): - to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None']) - body.append(f'; {to_delete_str}\n') - else: - body.append('\n') - - prev_stacktrace = None - - def append_stacktrace_summary(node : Node): - """ - Append a summary of the stacktrace to the generated code. This is - useful for debugging. - """ - nonlocal prev_stacktrace - pattern = re.compile(r"^File \"(.+)\", line (\d+), in (.+)$") - - if node.op not in {'placeholder', 'output'}: - if node.stack_trace: - if node.stack_trace != prev_stacktrace: - prev_stacktrace = node.stack_trace - - lines = node.stack_trace.strip().split('\n') - idx = 0 - context_lines = [] - while idx < len(lines): - line = lines[idx].strip() - if line.startswith('File '): - break - context_lines.append(line) - idx += 1 - - summary_lines = [] - if context_lines: - summary_lines.append(', '.join(context_lines)) - - if idx + 1 < len(lines): - matches = pattern.match(lines[idx].strip()) - if matches: - file = matches.group(1) - lineno = matches.group(2) - lineage = f'File: {file}:{lineno}' - summary_lines.append(lineage) - - code = f"code: {lines[idx + 1].strip()}" - summary_lines.append(code) - - summary_str = ', '.join(summary_lines) - body.append(f'\n# {summary_str}\n') - elif prev_stacktrace != "": - prev_stacktrace = "" - body.append('\n# No stacktrace found for following nodes \n') - - def emit_node(node : Node): - maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}' - if node.op == 'placeholder': - assert isinstance(node.target, str) - maybe_default_arg = '' if not node.args else f' = {repr(node.args[0])}' - free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}') - raw_name = node.target.replace('*', '') - if raw_name != repr(node): - body.append(f'{repr(node)} = {raw_name}\n') - return - elif node.op == 'call_method': - assert isinstance(node.target, str) - body.append( - f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}' - f'({_format_args(node.args[1:], node.kwargs)})') - return - elif node.op == 'call_function': - assert callable(node.target) - # pretty print operators - if node.target.__module__ == '_operator' and node.target.__name__ in magic_methods: - assert isinstance(node.args, tuple) - body.append(f'{repr(node)}{maybe_type_annotation} = ' - f'{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}') - return - - # pretty print inplace operators; required for jit.script to work properly - # not currently supported in normal FX graphs, but generated by torchdynamo - if node.target.__module__ == '_operator' and node.target.__name__ in inplace_methods: - body.append(f'{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; ' - f'{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}') - return - - qualified_name = _get_qualified_name(node.target) - global_name = add_global(qualified_name, node.target) - # special case for getattr: node.args could be 2-argument or 3-argument - # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value - if global_name == 'getattr' and \ - isinstance(node.args, tuple) and \ - isinstance(node.args[1], str) and \ - node.args[1].isidentifier() and \ - len(node.args) == 2: - body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}') - return - body.append(f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})') - if node.meta.get('is_wrapped', False): - wrapped_fns.setdefault(global_name) - return - elif node.op == 'call_module': - assert isinstance(node.target, str) - body.append(f'{repr(node)}{maybe_type_annotation} = ' - f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})') - return - elif node.op == 'get_attr': - assert isinstance(node.target, str) - body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}') - return - elif node.op == 'output': - if node.type is not None: - maybe_return_annotation[0] = f" -> {type_repr(node.type)}" - body.append(self.generate_output(node.args[0])) - return - raise NotImplementedError(f'node: {node.op} {node.target}') - - for node in nodes: - # NOTE: emit_node does not emit a string with newline. It depends - # on delete_unused_values to append one - if verbose: - append_stacktrace_summary(node) - emit_node(node) - delete_unused_values(node) - - if len(body) == 0: - # If the Graph has no non-placeholder nodes, no lines for the body - # have been emitted. To continue to have valid Python code, emit a - # single pass statement - body.append('pass\n') - - - - if len(wrapped_fns) > 0: - wrap_name = add_global('wrap', pippy.fx.wrap) - wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns]) - else: - wrap_stmts = '' - - if self._body_transformer: - body = self._body_transformer(body) - - for name, value in self.additional_globals(): - add_global(name, value) - - prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0]) - - code = ''.join(body) - code = '\n'.join(' ' + line for line in code.split('\n')) - fn_code = f""" -{wrap_stmts} - -{prologue} -{code}""" - return PythonCode(fn_code, globals_) - - -# Ideally, we'd like to refactor all of the pytree logic into this codegen -# class. Unfortunately, there are 3 areas we currently need extra logic in FX. -# 1. In the initial symbolic trace, the pytree logic is tied up with `concrete_args`. -# 2. In the FX graph, we need to access 2 attributes - in_spec and out_spec. -# Since we can't access .graph within the FX forward, we need to copy the attribute to the module. -# 3. We currently can't register the pytree imports with `add_global` - not sure why. -class _PyTreeCodeGen(CodeGen): - def __init__(self, pytree_info: _PyTreeInfo): - super().__init__() - self.pytree_info: _PyTreeInfo = pytree_info - - def process_inputs(self, *inputs: Any) -> Any: - flat_args, _ = pytree.tree_flatten(inputs) - return flat_args - - def process_outputs(self, out: Any) -> Any: - if self.pytree_info is None: - return out - if not isinstance(out, list): - out = [out] - assert(self.pytree_info.out_spec is not None) - return pytree.tree_unflatten(out, self.pytree_info.out_spec) - - def gen_fn_def(self, free_vars, maybe_return_annotation): - if self.pytree_info is None: - return super().gen_fn_def(free_vars, maybe_return_annotation) - function_args = self.pytree_info.orig_args - has_orig_self = (function_args[0] == 'self') - if has_orig_self: - free_vars.insert(0, 'self') - function_definition = super().gen_fn_def(function_args[:], maybe_return_annotation) - if len(free_vars) > 0: # pytree has placeholders in it - function_definition += f""" - {', '.join(free_vars)}, = fx_pytree.tree_flatten_spec([{', '.join(function_args)}], self._in_spec)""" - return function_definition - - def generate_output(self, output_args): - if self.pytree_info: - return f'return pytree.tree_unflatten({repr(output_args)}, self._out_spec)' - else: - return super().generate_output(output_args) - -@compatibility(is_backward_compatible=True) -class Graph: - """ - ``Graph`` is the main data structure used in the FX Intermediate Representation. - It consists of a series of ``Node`` s, each representing callsites (or other - syntactic constructs). The list of ``Node`` s, taken together, constitute a - valid Python function. - - For example, the following code - - .. code-block:: python - - import torch - import pippy.fx - - class MyModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.param = torch.nn.Parameter(torch.rand(3, 4)) - self.linear = torch.nn.Linear(4, 5) - - def forward(self, x): - return torch.topk(torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3) - - m = MyModule() - gm = pippy.fx.symbolic_trace(m) - - Will produce the following Graph:: - - print(gm.graph) - - .. code-block:: text - - graph(x): - %linear_weight : [#users=1] = self.linear.weight - %add_1 : [#users=1] = call_function[target=operator.add](args = (%x, %linear_weight), kwargs = {}) - %linear_1 : [#users=1] = call_module[target=linear](args = (%add_1,), kwargs = {}) - %relu_1 : [#users=1] = call_method[target=relu](args = (%linear_1,), kwargs = {}) - %sum_1 : [#users=1] = call_function[target=torch.sum](args = (%relu_1,), kwargs = {dim: -1}) - %topk_1 : [#users=1] = call_function[target=torch.topk](args = (%sum_1, 3), kwargs = {}) - return topk_1 - - For the semantics of operations represented in the ``Graph``, please see :class:`Node`. - """ - - @compatibility(is_backward_compatible=True) - def __init__(self, owning_module: Optional["GraphModule"] = None, tracer_cls: Optional[Type["Tracer"]] = None, - tracer_extras: Optional[Dict[str, Any]] = None): - """ - Construct an empty Graph. - """ - self._root : Node = Node(self, '', 'root', '', (), {}) - self._used_names : Dict[str, int] = {} # base name -> number - self._insert = self._root.prepend - self._len = 0 - self._graph_namespace = _Namespace() - self._owners = 0 - self._owning_module = owning_module - self._tracer_cls = tracer_cls - self._tracer_extras = tracer_extras - self._codegen = CodeGen() - - @property - def owning_module(self): - """ - Return the module that owns this ``GraphModule``, if there is one, - ``None`` if there is no owning module or if there are multiple owning - modules. - """ - return self._owning_module - - @owning_module.setter - def owning_module(self, mod: Optional["GraphModule"]): - if mod: - self._owning_module = mod if not self._owners else None - self._owners += 1 - - @property - def nodes(self) -> _node_list: - """ - Get the list of Nodes that constitute this Graph. - - Note that this ``Node`` list representation is a doubly-linked list. Mutations - during iteration (e.g. delete a Node, add a Node) are safe. - - Returns: - - A doubly-linked list of Nodes. Note that ``reversed`` can be called on - this list to switch iteration order. - """ - return _node_list(self) - - @compatibility(is_backward_compatible=True) - def graph_copy(self, g : 'Graph', val_map : Dict[Node, Node], return_output_node=False) -> 'Optional[Argument]': - """ - Copy all nodes from a given graph into ``self``. - - Args: - - g (Graph): The source graph from which to copy Nodes. - - val_map (Dict[Node, Node]): a dictionary that will be populated with a mapping - from nodes in ``g`` to nodes in ``self``. Note that ``val_map`` can be passed - in with values in it already to override copying of certain values. - - Returns: - - The value in ``self`` that is now equivalent to the output value in ``g``, - if ``g`` had an ``output`` node. ``None`` otherwise. - """ - for node in g.nodes: - if node in val_map: - continue - if node.op == 'output': - rv = map_arg(node.args[0], lambda n: val_map[n]) - return rv if not return_output_node else (rv, node) - val_map[node] = self.node_copy(node, lambda n : val_map[n]) - return None - - def __deepcopy__(self, memo=None) -> 'Graph': - """ - Explicitly implement __deepcopy__ to prevent excessive recursion depth - from the default implementation. This uses graph_copy to copy the nodes - in an iterative way, rather than recursive. It also populates the - memoization table to prevent unnecessary copies (e.g. references to - nodes or other parts of the Graph from a custom GraphModule implementation. - """ - memo = memo if memo else {} - g = Graph(tracer_cls=self._tracer_cls) - output_vals = g.graph_copy(self, val_map=memo, return_output_node=True) - g._codegen = copy.deepcopy(self._codegen) - assert isinstance(output_vals, tuple) - output_val, old_output_val = output_vals - g.output(output_val, type_expr=getattr(old_output_val, 'type', None)) - return g - - @compatibility(is_backward_compatible=True) - def create_node(self, op: str, target: 'Target', - args: Optional[Tuple['Argument', ...]] = None, - kwargs: Optional[Dict[str, 'Argument']] = None, - name: Optional[str] = None, - type_expr: Optional[Any] = None) -> Node: - """ - Create a ``Node`` and add it to the ``Graph`` at the current insert-point. - Note that the current insert-point can be set via :meth:`Graph.inserting_before` - and :meth:`Graph.inserting_after`. - - Args: - op (str): the opcode for this Node. One of 'call_function', 'call_method', 'get_attr', - 'call_module', 'placeholder', or 'output'. The semantics of these opcodes are - described in the ``Graph`` docstring. - - args (Optional[Tuple[Argument, ...]]): is a tuple of arguments to this node. - - kwargs (Optional[Dict[str, Argument]]): the kwargs of this Node - - name (Optional[str]): an optional string name for the ``Node``. - This will influence the name of the value assigned to in the - Python generated code. - - type_expr (Optional[Any]): an optional type annotation representing the - Python type the output of this node will have. - - Returns: - - The newly-created and inserted node. - """ - assert op in ('call_function', 'call_method', 'get_attr', 'call_module', 'placeholder', 'output') - args = () if args is None else args - kwargs = {} if kwargs is None else kwargs - assert isinstance(args, tuple), "args must be a tuple" - assert isinstance(kwargs, dict), "kwargs must be a dict" - - candidate = name if name is not None else self._target_to_str(target) - name = self._graph_namespace.create_name(candidate, None) - n = Node(self, name, op, target, args, kwargs, type_expr) - - self._graph_namespace.associate_name_with_obj(name, n) - - self._insert(n) - self._len += 1 - return n - - @compatibility(is_backward_compatible=False) - def process_inputs(self, *args): - """ - Processes args so that they can be passed to the FX graph. - """ - return self._codegen.process_inputs(*args) - - @compatibility(is_backward_compatible=False) - def process_outputs(self, out): - return self._codegen.process_outputs(out) - - - @compatibility(is_backward_compatible=True) - def erase_node(self, to_erase : Node) -> None: - """ - Erases a ``Node`` from the ``Graph``. Throws an exception if - there are still users of that node in the ``Graph``. - - Args: - - to_erase (Node): The ``Node`` to erase from the ``Graph``. - """ - if len(to_erase.users) > 0: - raise RuntimeError(f'Tried to erase Node {to_erase} but it still had {len(to_erase.users)} ' - f'users in the graph: {to_erase.users}!') - - to_erase._remove_from_list() - to_erase._erased = True # iterators may retain handles to erased nodes - self._len -= 1 - - # Null out this Node's argument nodes so that the Nodes referred to - # can update their ``users`` accordingly - new_args = map_arg(to_erase.args, lambda n: None) - assert isinstance(new_args, tuple) - to_erase.args = new_args - new_kwargs = map_arg(to_erase.kwargs, lambda n: None) - assert isinstance(new_kwargs, dict) - to_erase.kwargs = new_kwargs - - @compatibility(is_backward_compatible=True) - def inserting_before(self, n: Optional[Node] = None): - """Set the point at which create_node and companion methods will insert into the graph. - When used within a 'with' statement, this will temporary set the insert point and - then restore it when the with statement exits:: - - with g.inserting_before(n): - ... # inserting before node n - ... # insert point restored to what it was previously - g.inserting_before(n) # set the insert point permanently - - Args: - - n (Optional[Node]): The node before which to insert. If None this will insert before - the beginning of the entire graph. - - Returns: - A resource manager that will restore the insert point on ``__exit__``. - """ - if n is None: - return self.inserting_after(self._root) - assert n.graph == self, "Node to insert before is not in graph." - return _InsertPoint(self, n.prepend) - - @compatibility(is_backward_compatible=True) - def inserting_after(self, n: Optional[Node] = None): - """Set the point at which create_node and companion methods will insert into the graph. - When used within a 'with' statement, this will temporary set the insert point and - then restore it when the with statement exits:: - - with g.inserting_after(n): - ... # inserting after node n - ... # insert point restored to what it was previously - g.inserting_after(n) # set the insert point permanently - - Args: - - n (Optional[Node]): The node before which to insert. If None this will insert after - the beginning of the entire graph. - - Returns: - A resource manager that will restore the insert point on ``__exit__``. - """ - if n is None: - return self.inserting_before(self._root) - assert n.graph == self, "Node to insert after is not in graph." - return _InsertPoint(self, n.append) - - @compatibility(is_backward_compatible=True) - def placeholder(self, name: str, type_expr: Optional[Any] = None, - default_value : Any = inspect.Signature.empty) -> Node: - """ - Insert a ``placeholder`` node into the Graph. A ``placeholder`` represents - a function input. - - Args: - - name (str): A name for the input value. This corresponds to the name - of the positional argument to the function this ``Graph`` represents. - - type_expr (Optional[Any]): an optional type annotation representing the - Python type the output of this node will have. This is needed in some - cases for proper code generation (e.g. when the function is used - subsequently in TorchScript compilation). - - default_value (Any): The default value this function argument should take - on. NOTE: to allow for `None` as a default value, `inspect.Signature.empty` - should be passed as this argument to specify that the parameter does _not_ - have a default value. - - .. note:: - The same insertion point and type expression rules apply for this method - as ``Graph.create_node``. - """ - args = () if default_value is inspect.Signature.empty else (default_value,) - return self.create_node('placeholder', name, args=args, type_expr=type_expr) - - @compatibility(is_backward_compatible=True) - def get_attr(self, qualified_name: str, type_expr: Optional[Any] = None) -> Node: - """ - Insert a ``get_attr`` node into the Graph. A ``get_attr`` ``Node`` represents the - fetch of an attribute from the ``Module`` hierarchy. - - Args: - - qualified_name (str): the fully-qualified name of the attribute to be retrieved. - For example, if the traced Module has a submodule named ``foo``, which has a - submodule named ``bar``, which has an attribute named ``baz``, the qualified - name ``foo.bar.baz`` should be passed as ``qualified_name``. - - type_expr (Optional[Any]): an optional type annotation representing the - Python type the output of this node will have. - - - Returns: - - The newly-created and inserted ``get_attr`` node. - - .. note:: - The same insertion point and type expression rules apply for this method - as ``Graph.create_node``. - """ - def _get_attr_reference_exists(mod: torch.nn.Module, qualified_name: str) -> bool: - module_path, _, name = qualified_name.rpartition(".") - - try: - submod: torch.nn.Module = mod.get_submodule(module_path) - except AttributeError: - warnings.warn(f"Failed to fetch module {module_path}!") - return False - - if not hasattr(submod, name): - return False - - res = getattr(submod, name) - - if (not isinstance(res, torch.nn.Module) - and not isinstance(res, torch.nn.Parameter) - and name not in submod._buffers): - return False - - return True - - if (self.owning_module and - not _get_attr_reference_exists(self.owning_module, qualified_name)): - warnings.warn("Attempted to insert a get_attr Node with no " - "underlying reference in the owning " - "GraphModule! Call " - "GraphModule.add_submodule to add the " - "necessary submodule, " - "GraphModule.add_parameter to add the " - "necessary Parameter, or " - "nn.Module.register_buffer to add the " - "necessary buffer", stacklevel=2) - return self.create_node('get_attr', qualified_name, type_expr=type_expr) - - @compatibility(is_backward_compatible=True) - def call_module(self, - module_name: str, - args: Optional[Tuple['Argument', ...]] = None, - kwargs: Optional[Dict[str, 'Argument']] = None, - type_expr: Optional[Any] = None) -> Node: - """ - Insert a ``call_module`` ``Node`` into the ``Graph``. A ``call_module`` node - represents a call to the forward() function of a ``Module`` in the ``Module`` - hierarchy. - - Args: - - module_name (str): The qualified name of the ``Module`` in the ``Module`` - hierarchy to be called. For example, if the traced ``Module`` has a - submodule named ``foo``, which has a submodule named ``bar``, the - qualified name ``foo.bar`` should be passed as ``module_name`` to - call that module. - - args (Optional[Tuple[Argument, ...]]): The positional arguments to be passed - to the called method. Note that this should *not* include a ``self`` argument. - - kwargs (Optional[Dict[str, Argument]]): The keyword arguments to be passed - to the called method - - type_expr (Optional[Any]): an optional type annotation representing the - Python type the output of this node will have. - - Returns: - - The newly-created and inserted ``call_module`` node. - - .. note:: - The same insertion point and type expression rules apply for this method - as :meth:`Graph.create_node`. - """ - if (self.owning_module and - self.owning_module.get_submodule(module_name) is None): - warnings.warn("Attempted to insert a call_module Node with " - "no underlying reference in the owning " - "GraphModule! Call " - "GraphModule.add_submodule to add the " - "necessary submodule") - return self.create_node('call_module', module_name, args, kwargs, type_expr=type_expr) - - @compatibility(is_backward_compatible=True) - def call_method(self, - method_name: str, - args: Optional[Tuple['Argument', ...]] = None, - kwargs: Optional[Dict[str, 'Argument']] = None, - type_expr: Optional[Any] = None) -> Node: - """ - Insert a ``call_method`` ``Node`` into the ``Graph``. A ``call_method`` node - represents a call to a given method on the 0th element of ``args``. - - Args: - - method_name (str): The name of the method to apply to the self argument. - For example, if args[0] is a ``Node`` representing a ``Tensor``, - then to call ``relu()`` on that ``Tensor``, pass ``relu`` to ``method_name``. - - args (Optional[Tuple[Argument, ...]]): The positional arguments to be passed - to the called method. Note that this *should* include a ``self`` argument. - - kwargs (Optional[Dict[str, Argument]]): The keyword arguments to be passed - to the called method - - type_expr (Optional[Any]): an optional type annotation representing the - Python type the output of this node will have. - - Returns: - - The newly created and inserted ``call_method`` node. - - .. note:: - The same insertion point and type expression rules apply for this method - as :meth:`Graph.create_node`. - """ - return self.create_node('call_method', method_name, args, kwargs, type_expr=type_expr) - - @compatibility(is_backward_compatible=True) - def call_function(self, - the_function: Callable[..., Any], - args: Optional[Tuple['Argument', ...]] = None, - kwargs: Optional[Dict[str, 'Argument']] = None, - type_expr: Optional[Any] = None) -> Node: - """ - Insert a ``call_function`` ``Node`` into the ``Graph``. A ``call_function`` node - represents a call to a Python callable, specified by ``the_function``. - - Args: - - the_function (Callable[..., Any]): The function to be called. Can be any PyTorch - operator, Python function, or member of the ``builtins`` or ``operator`` - namespaces. - - args (Optional[Tuple[Argument, ...]]): The positional arguments to be passed - to the called function. - - kwargs (Optional[Dict[str, Argument]]): The keyword arguments to be passed - to the called function - - type_expr (Optional[Any]): an optional type annotation representing the - Python type the output of this node will have. - - Returns: - - The newly created and inserted ``call_function`` node. - - .. note:: - The same insertion point and type expression rules apply for this method - as :meth:`Graph.create_node`. - """ - return self.create_node('call_function', the_function, args, kwargs, type_expr=type_expr) - - @compatibility(is_backward_compatible=True) - def node_copy(self, node: Node, arg_transform: Callable[[Node], 'Argument'] = lambda x: x) -> Node: - """ - Copy a node from one graph into another. ``arg_transform`` needs to transform arguments from - the graph of node to the graph of self. Example:: - - # Copying all the nodes in `g` into `new_graph` - g : pippy.fx.Graph = ... - new_graph = pippy.fx.graph() - value_remap = {} - for node in g.nodes: - value_remap[node] = new_graph.node_copy(node, lambda n : value_remap[n]) - - Args: - - node (Node): The node to copy into ``self``. - - arg_transform (Callable[[Node], Argument]): A function that transforms - ``Node`` arguments in node's ``args`` and ``kwargs`` into the - equivalent argument in ``self``. In the simplest case, this should - retrieve a value out of a table mapping Nodes in the original - graph to ``self``. - """ - args = map_arg(node.args, arg_transform) - kwargs = map_arg(node.kwargs, arg_transform) - assert isinstance(args, tuple) - assert isinstance(kwargs, dict) - result_node = self.create_node(node.op, node.target, args, kwargs, node.name, node.type) - result_node.meta = copy.copy(node.meta) - return result_node - - @compatibility(is_backward_compatible=True) - def output(self, result: 'Argument', type_expr: Optional[Any] = None): - """ - Insert an ``output`` ``Node`` into the ``Graph``. An ``output`` node represents - a ``return`` statement in Python code. ``result`` is the value that should - be returned. - - Args: - - result (Argument): The value to be returned. - - type_expr (Optional[Any]): an optional type annotation representing the - Python type the output of this node will have. - - .. note:: - - The same insertion point and type expression rules apply for this method - as ``Graph.create_node``. - """ - return self.create_node(op='output', target='output', args=(result,), type_expr=type_expr) - - def _target_to_str(self, target : Target) -> str: - if callable(target): - op = target.__name__ - else: - assert isinstance(target, str) - op = target - if _is_magic(op): - op = op[2:-2] - op = _snake_case(op) - return op - - @compatibility(is_backward_compatible=True) - def python_code(self, root_module: str, *, verbose: bool = False) -> PythonCode: - """ - Turn this ``Graph`` into valid Python code. - - Args: - - root_module (str): The name of the root module on which to look-up - qualified name targets. This is usually 'self'. - - Returns: - - A PythonCode object, consisting of two fields: - src: the Python source code representing the object - globals: a dictionary of global names in `src` -> the objects that they reference. - """ - # NOTE: [Graph Namespaces] - # - # There are two types of symbols in generated Python source code: - # locals and globals. - # Locals are locally defined by the output of a node in the Graph. - # Globals are references to external objects, like functions or types. - # - # When generating Python code, we need to make sure to name things - # appropriately. In particular: - # - All names should be unique, to avoid weird shadowing bugs. - # - These names need to be consistent, e.g. a object should always be - # referenced by the same name. - # - # To do this, we create a new namespace just for this source. All names - # that get printed must come from this namespace. - # - # Why can't we re-use node.name? Because it was generated within the - # namespace `self._graph_namespace`. In order to provide uniqueness - # over both locals (node.name) *and* globals, we create a completely - # new namespace to put all identifiers in. - namespace = _Namespace() - - # Override Node's repr to generate a valid name within our namespace. - # Since repr() is designed to produce a valid Python expression, it - # makes sense to re-use it. This way, it's easy to print something like - # Tuple[Node, Node] by simply calling repr() on it. Node's __repr__ is - # implemented cooperatively to allow this. - def node_repr(n: Node): - return namespace.create_name(n.name, n) - - @contextmanager - def override_node_repr(graph: Graph): - orig_repr_fns = {} - for node in graph.nodes: - orig_repr_fns[node] = node._repr_fn - node._repr_fn = node_repr - try: - yield None - finally: - # restore the original repr functions - for node in graph.nodes: - node._repr_fn = orig_repr_fns[node] - - with override_node_repr(self): - return self._python_code(root_module, namespace, verbose=verbose) - - def _python_code(self, root_module: str, namespace: _Namespace, *, verbose: bool = False) -> PythonCode: - return self._codegen._gen_python_code(self.nodes, root_module, namespace, verbose=verbose) - - - def __str__(self) -> str: - """ - Return a human-readable (not machine-readable) string representation - of this Graph - """ - placeholder_names : List[str] = [] - # This is a one-element array just so ``format_node`` can modify the closed - # over value - maybe_return_typename : List[str] = [''] - - node_strs = [node.format_node(placeholder_names) for node in self.nodes] - param_str = ', '.join(placeholder_names) - s = f'graph({param_str}){maybe_return_typename[0]}:' - for node_str in node_strs: - if node_str: - s += '\n ' + node_str - return s - - @compatibility(is_backward_compatible=True) - def print_tabular(self): - """ - Prints the intermediate representation of the graph in tabular - format. Note that this API requires the ``tabulate`` module to be - installed. - """ - try: - from tabulate import tabulate - except ImportError: - print("`print_tabular` relies on the library `tabulate`, " - "which could not be found on this machine. Run `pip " - "install tabulate` to install the library.") - node_specs = [[n.op, n.name, n.target, n.args, n.kwargs] - for n in self.nodes] - print(tabulate(node_specs, - headers=['opcode', 'name', 'target', 'args', 'kwargs'])) - - @compatibility(is_backward_compatible=True) - def lint(self): - """ - Runs various checks on this Graph to make sure it is well-formed. In - particular: - - Checks Nodes have correct ownership (owned by this graph) - - Checks Nodes appear in topological order - - If this Graph has an owning GraphModule, checks that targets - exist in that GraphModule - """ - - # Check topo order - def check_arg(arg : Node, n : Optional[Node] = None) -> None: - context_str = f' of Node \'{n}\' ' if n else ' ' - if arg.graph is not self: - raise RuntimeError(f'Argument \'{arg}\'{context_str}does not belong to this Graph, ' - f'but was used as an argument! If you are copying nodes from another graph, make ' - f'sure to use ``arg_transform`` on node_copy() to remap values\n{self}') - if arg not in seen_values: - raise RuntimeError(f'Argument \'{arg}\'{context_str}was used before it has been ' - f'defined! Please check that Nodes in the graph are topologically ordered\n{self}') - - seen_names : Set[str] = set() - seen_values : Set[Node] = set() - for node in self.nodes: - if node.op not in ['placeholder', 'call_method', 'call_module', 'call_function', 'get_attr', 'output']: - raise RuntimeError(f'Node {node} had unknown opcode {node.op}!') - if node.graph is not self: - raise RuntimeError(f'Node \'{node}\' does not belong to this Graph!') - map_arg(node.args, lambda arg: check_arg(arg, node)) - map_arg(node.kwargs, lambda arg: check_arg(arg, node)) - seen_values.add(node) - - if node.name in seen_names: - raise RuntimeError(f'Node redefined name {node.name}!') - seen_names.add(node.name) - - # Check targets are legit - if self.owning_module: - for node in self.nodes: - if node.op == 'call_function': - if not callable(node.target): - raise ValueError(f'Node {node} target {node.target} has type {torch.typename(node.target)} but ' - 'a Callable is expected') - else: - if not isinstance(node.target, str): - raise ValueError(f'Node {node} target {node.target} has type {torch.typename(node.target)} but ' - 'a str is expected') - if node.op in ['get_attr', 'call_module']: - target_atoms = node.target.split('.') - m_itr = self.owning_module - for i, atom in enumerate(target_atoms): - new_m_itr = getattr(m_itr, atom, None) - seen_qualname = '.'.join(target_atoms[:i]) - if new_m_itr is None: - raise RuntimeError(f'Node {node} target {node.target} references nonexistent attribute ' - f'{atom} of {seen_qualname}') - if (node.op == "call_module" - and not isinstance(new_m_itr, torch.nn.Module)): - raise RuntimeError(f'Node {node} target {node.target} {atom} of {seen_qualname} does ' - 'not reference an nn.Module') - elif (node.op == "get_attr" - and not isinstance(new_m_itr, torch.nn.Module) - and not isinstance(new_m_itr, torch.nn.Parameter) - and atom not in m_itr._buffers): - warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does ' - 'not reference an nn.Module, nn.Parameter, or buffer, which is ' - 'what \'get_attr\' Nodes typically target') - else: - m_itr = new_m_itr - - @compatibility(is_backward_compatible=True) - def eliminate_dead_code(self): - """ - Remove all dead code from the graph, based on each node's number of - users, and whether the nodes have any side effects. The graph must be - topologically sorted before calling. - - Returns: - bool: Whether the graph was changed as a result of the pass. - - Example: - - Before dead code is eliminated, `a` from `a = x + 1` below has no users - and thus can be eliminated from the graph without having an effect. - - .. code-block:: python - - def forward(self, x): - a = x + 1 - return x + self.attr_1 - - After dead code is eliminated, `a = x + 1` has been removed, and the rest - of `forward` remains. - - .. code-block:: python - - def forward(self, x): - return x + self.attr_1 - - .. warning:: - - Dead code elimination has some heuristics to avoid removing - side-effectful nodes (see Node.is_impure) but in general coverage - is very bad, so you should assume that this method is not sound - to call unless you know that your FX graph consists entirely - of functional operations. - """ - # Lint the graph first to make sure its topologically sorted, otherwise - # DCE below will not behave as expected. - self.lint() - - # Reverse iterate so that when we remove a node, any nodes used as an - # input to that node have an updated user count that no longer reflects - # the removed node. - changed = False - for node in reversed(self.nodes): - if not node.is_impure() and len(node.users) == 0: - self.erase_node(node) - changed = True - - return changed - - @compatibility(is_backward_compatible=False) - def set_codegen(self, codegen: CodeGen): - self._codegen = codegen - - @compatibility(is_backward_compatible=False) - def on_generate_code( - self, - make_transformer: Callable[[Optional[TransformCodeFunc]], TransformCodeFunc] - ): - """Register a transformer function when python code is generated - - Args: - make_transformer (Callable[[Optional[TransformCodeFunc]], TransformCodeFunc]): - a function that returns a code transformer to be registered. - This function is called by `on_generate_code` to obtain the - code transformer. - - This function is also given as its input the currently - registered code transformer (or None if nothing is registered), - in case it is not desirable to overwrite it. This is useful to - chain code transformers together. - - Returns: - a context manager that when used in a `with` statement, to automatically - restore the previously registered code transformer. - - Example: - - .. code-block:: python - - - gm: fx.GraphModule = ... - - # This is a code transformer we want to register. This code - # transformer prepends a pdb import and trace statement at the very - # beginning of the generated pippy.fx code to allow for manual - # debugging with the PDB library. - def insert_pdb(body): - return ["import pdb; pdb.set_trace()\\n", *body] - - # Registers `insert_pdb`, and overwrites the current registered - # code transformer (given by `_` to the lambda): - gm.graph.on_generate_code( - lambda _: insert_pdb - ) - - # Or alternatively, registers a code transformer which first - # runs `body` through existing registered transformer, then - # through `insert_pdb`: - gm.graph.on_generate_code( - lambda current_trans: ( - lambda body: insert_pdb( - current_trans(body) if current_trans - else body - ) - ) - ) - - gm.recompile() - gm(*inputs) # drops into pdb - - - This function can also be used as a context manager, with the benefit to - automatically restores the previously registered code transformer: - - .. code-block:: python - - # ... continue from previous example - - with gm.graph.on_generate_code(lambda _: insert_pdb): - # do more stuff with `gm`... - gm.recompile() - gm(*inputs) # drops into pdb - - # now previous code transformer is restored (but `gm`'s code with pdb - # remains - that means you can run `gm` with pdb here too, until you - # run next `recompile()`). - """ - on_gen_code_old = self._codegen._body_transformer - self._codegen._body_transformer = make_transformer(on_gen_code_old) - - @contextlib.contextmanager - def on_generate_code_context_manager(): - try: - yield - finally: - self._codegen._body_transformer = on_gen_code_old - - return on_generate_code_context_manager() - - -reflectable_magic_methods = { - 'add': '{} + {}', - 'sub': '{} - {}', - 'mul': '{} * {}', - 'floordiv': '{} // {}', - 'truediv': '{} / {}', - 'div': '{} / {}', - 'mod': '{} % {}', - 'pow': '{} ** {}', - 'lshift': '{} << {}', - 'rshift': '{} >> {}', - 'and_': '{} & {}', - 'or_': '{} | {}', - 'xor': '{} ^ {}', - 'getitem': '{}[{}]', - 'matmul': '{} @ {}', -} - -magic_methods = dict({ - 'eq': '{} == {}', - 'ne': '{} != {}', - 'lt': '{} < {}', - 'gt': '{} > {}', - 'le': '{} <= {}', - 'ge': '{} >= {}', - 'pos': '+{}', - 'neg': '-{}', - 'invert': '~{}'}, **reflectable_magic_methods) - -inplace_methods = { - 'iadd': '{} += {}', - 'iand': '{} &= {}', - 'ifloordiv': '{} //= {}', - 'ilshift': '{} <<= {}', - 'imod': '{} %= {}', - 'imul': '{} *= {}', - 'imatmul': '{} @= {}', - 'ior': '{} |= {}', - 'ipow': '{} **= {}', - 'irshift': '{} >>= {}', - 'isub': '{} -= {}', - 'itruediv': '{} /= {}', - 'ixor': '{} ^= {}', - 'setitem': '{}[{}] = {}', -} diff --git a/pippy/fx/graph_module.py b/pippy/fx/graph_module.py deleted file mode 100644 index f8f46d917..000000000 --- a/pippy/fx/graph_module.py +++ /dev/null @@ -1,759 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import copy -import itertools -import linecache -import os -import sys -import traceback -import warnings -from pathlib import Path -from typing import Type, Dict, List, Any, Union, Optional, Set # pylint: disable=unused-import - -import torch -import torch.nn as nn -import torch.overrides -from torch.nn.modules.module import _addindent -from torch.package import Importer, sys_importer -from torch.package import PackageImporter, PackageExporter - -from ._compatibility import compatibility -from .graph import Graph, _PyTreeCodeGen, _is_from_torch, _custom_builtins, PythonCode - - -# Normal exec loses the source code, however we can work with -# the linecache module to recover it. -# Using _exec_with_source will add it to our local cache -# and then tools like TorchScript will be able to get source info. -class _EvalCacheLoader(object): - def __init__(self): - self.eval_cache = {} - self.next_id = 0 - - def cache(self, src: str, globals: Dict[str, Any]): - """Store the source in a private cache, and add a lazy entry in linecache - that allows the source to be retrieved by 'filename'. - - Args: - src (str): The module source to cache - globals (dict): The module globals - - Returns: - str: The cache key (and dummy filename) generated for src. - """ - - key = self._get_key() - self.eval_cache[key] = src - - # Don't mutate globals so that this loader is only used - # to populate linecache, and doesn't interact with other modules - # that might check `__loader__` - globals_copy = globals.copy() - globals_copy['__file__'] = key - globals_copy['__name__'] = key - globals_copy['__loader__'] = self - linecache.lazycache(key, globals_copy) - - return key - - # Part of the loader protocol (PEP 302) - # linecache will use this method when trying to find source code - def get_source(self, module_name) -> Optional[str]: - if module_name in self.eval_cache: - return self.eval_cache[module_name] - return None - - def _get_key(self): - key = f'.{self.next_id}' - self.next_id += 1 - return key - -_loader = _EvalCacheLoader() - - -def _exec_with_source(src: str, globals: Dict[str, Any]): - key = _loader.cache(src, globals) - exec(compile(src, key, 'exec'), globals) - - -def _forward_from_src(src: str, globals: Dict[str, Any]): - # avoid mutating the passed in dict - globals_copy = globals.copy() - _exec_with_source(src, globals_copy) - forward_fn = globals_copy['forward'] - del globals_copy['forward'] - return forward_fn - - -def _format_import_statement(name: str, obj: Any, importer: Importer) -> str: - if name in _custom_builtins: - return _custom_builtins[name].import_str - if _is_from_torch(name): - return 'import torch' - module_name, attr_name = importer.get_name(obj) - return f'from {module_name} import {attr_name} as {name}' - - -def _format_import_block(globals: Dict[str, Any], importer: Importer): - import_strs: Set[str] = set() - for name, obj in globals.items(): - import_strs.add(_format_import_statement(name, obj, importer)) - return '\n'.join(import_strs) - - -@compatibility(is_backward_compatible=True) -def reduce_graph_module(body: Dict[Any, Any], import_block: str) -> torch.nn.Module: - # BC: attribute name was changed from `code` to `_code` to facilitate - # making `code` into a property and adding a docstring to it - fn_src = body.get('_code') or body['code'] - forward = _forward_from_src(import_block + fn_src, {}) - return _deserialize_graph_module(forward, body) - - -@compatibility(is_backward_compatible=True) -def reduce_package_graph_module( - importer: PackageImporter, body: Dict[Any, Any], generated_module_name: str -) -> torch.nn.Module: - forward = importer.import_module(generated_module_name).forward - return _deserialize_graph_module(forward, body) - -@compatibility(is_backward_compatible=True) -def reduce_deploy_graph_module( - importer: PackageImporter, body: Dict[Any, Any], import_block: str -) -> torch.nn.Module: - ns = {} - ns["__builtins__"] = importer.patched_builtins - fn_src = body.get('_code') - assert fn_src is not None - forward = _forward_from_src(import_block + fn_src, ns) - return _deserialize_graph_module(forward, body) - - -def _deserialize_graph_module(forward, body: Dict[Any, Any]) -> torch.nn.Module: - """ - Deserialize a GraphModule given the dictionary of the original module, - using the code to reconstruct the graph. We delete the actual graph before - saving the dictionary so that changes to the in-memory graph format do not - get serialized. - """ - # We create a dummy class here because symbolic_trace pulls the forward() - # function off of the class, rather than the instance - class CodeOnlyModule(torch.nn.Module): - def __init__(self, body): - super().__init__() - self.__dict__ = body - - # Try to retrieve the forward source in a backward-compatible way - CodeOnlyModule.forward = forward - - tracer_cls = body.get('_tracer_cls') - if tracer_cls is None: - from ._symbolic_trace import Tracer - tracer_cls = Tracer - - graphmodule_cls_name = body.get('_graphmodule_cls_name', 'GraphModule') - - # This is a workaround for a mypy linter issue related to - # passing base class as an argument - https://github.com/python/mypy/issues/5865. - cls_tracer : Any = tracer_cls - - class KeepModules(cls_tracer): - # we shouldn't trace into any of the submodules, - # because they were not traced in the original GraphModule - def is_leaf_module(self, _: torch.nn.Module, __: str) -> bool: - return True - - com = CodeOnlyModule(body) - - tracer_extras = body.get('_tracer_extras', {}) - graph = KeepModules().trace(com, **tracer_extras) - - # Manually set Tracer class on the reconstructed Graph, to avoid - # referencing the private local subclass KeepModules. - graph._tracer_cls = tracer_cls - gm = GraphModule(com, graph, class_name=graphmodule_cls_name) - - # The GraphModule constructor only retains attributes referenced by the graph. - # In this case, our goal is return a GraphModule as close to identical as the one - # put into the package. If any additional attributes were present in body, - # we should keep them. - for k, v in body.items(): - if not hasattr(gm, k): - setattr(gm, k, v) - return gm - -# copy an attribute value with qualified name 'target' from 'from_module' to 'to_module' -# This installs empty Modules where none exist yet if they are subpaths of target -def _copy_attr(from_module: torch.nn.Module, to_module: torch.nn.Module, target: str): - *prefix, field = target.split('.') - for item in prefix: - f = getattr(from_module, item) - t = getattr(to_module, item, None) - if f is t: - # we have already installed one of its parents - # (e.g. target = root.linear.weight, but we have already installed root.linear) - # once we install a parent, we no longer need to copy the children - # since all the needed properties will already be present - return - - if t is None: - t = torch.nn.Module() - setattr(to_module, item, t) - from_module, to_module = f, t - - orig = getattr(from_module, field) - # If it is a tensor and not a parameter attribute of a module, it should be a named buffer. - # So, we register it as a named buffer in the target module. - if isinstance(orig, torch.Tensor) and not isinstance(orig, torch.nn.Parameter): - to_module.register_buffer(field, orig) - else: - setattr(to_module, field, orig) - -# Assign attribute 'from_obj' to the qualified name 'target' on 'to_module -# This installs empty Modules where none exist yet if they are subpaths of target -def _assign_attr(from_obj: Any, to_module: torch.nn.Module, target: str): - *prefix, field = target.split('.') - for item in prefix: - t = getattr(to_module, item, None) - - if t is None: - t = torch.nn.Module() - setattr(to_module, item, t) - to_module = t - - # If it is a tensor and not a parameter attribute of a module, it should be a named buffer. - # So, we register it as a named buffer in the target module. - if isinstance(from_obj, torch.Tensor) and not isinstance(from_obj, torch.nn.Parameter): - to_module.register_buffer(field, from_obj) - else: - setattr(to_module, field, from_obj) - -class _WrappedCall: - def __init__(self, cls, cls_call): - self.cls = cls - self.cls_call = cls_call - - # Previously, if an error occurred when valid - # symbolically-traced code was run with an invalid input, the - # user would see the source of the error as coming from - # `File "`, where N is some number. We use - # this function to generate a more informative error message. We - # return the traceback itself, a message explaining that the - # error occurred in a traced Module's generated forward - # function, and five lines of context surrounding the faulty - # line - @staticmethod - def _generate_error_message(frame_summary: traceback.FrameSummary) -> str: - # auxiliary variables (for readability) - err_lineno = frame_summary.lineno - assert err_lineno is not None - line = frame_summary.line - assert line is not None - err_line_len = len(line) - all_src_lines = linecache.getlines(frame_summary.filename) - - # constituent substrings of the error message - tb_repr = traceback.format_exc() - custom_msg = ("Call using an FX-traced Module, " - f"line {err_lineno} of the traced Module's " - "generated forward function:") - before_err = "".join(all_src_lines[err_lineno - 2 : err_lineno]) - marker = "~" * err_line_len + "~~~ <--- HERE" - err_and_after_err = "\n".join(all_src_lines[err_lineno : err_lineno + 2]) - - # joined message - return "\n".join([tb_repr, custom_msg, before_err, marker, err_and_after_err]) - - def __call__(self, obj, *args, **kwargs): - try: - if self.cls_call is not None: - return self.cls_call(obj, *args, **kwargs) - else: - return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc] - except Exception as e: - assert e.__traceback__ - topmost_framesummary: traceback.FrameSummary = \ - traceback.StackSummary.extract(traceback.walk_tb(e.__traceback__))[-1] # type: ignore[arg-type] - if "eval_with_key" in topmost_framesummary.filename: - print(_WrappedCall._generate_error_message(topmost_framesummary), - file=sys.stderr) - raise e.with_traceback(None) - else: - raise e - -@compatibility(is_backward_compatible=True) -class GraphModule(torch.nn.Module): - """ - GraphModule is an nn.Module generated from an fx.Graph. Graphmodule has a - ``graph`` attribute, as well as ``code`` and ``forward`` attributes generated - from that ``graph``. - - .. warning:: - - When ``graph`` is reassigned, ``code`` and ``forward`` will be automatically - regenerated. However, if you edit the contents of the ``graph`` without reassigning - the ``graph`` attribute itself, you must call ``recompile()`` to update the generated - code. - """ - def __new__(cls: 'Type[GraphModule]', *args, **kwargs): - # each instance of a graph module needs its own forward method - # so create a new singleton class for each instance. - # it is a subclass of the user-defined class, the only difference - # is an extra layer to install the forward method - - # address issue described at https://github.com/pytorch/pytorch/issues/63883 - # in other words, traverse class hierarchy to fix the redundant class definition problem - for t in cls.__mro__: - c = t.__qualname__.split('.')[-1] - if c != 'GraphModuleImpl': - cls = t - break - - class GraphModuleImpl(cls): # type: ignore[misc, valid-type] - pass - return super().__new__(GraphModuleImpl) - - @compatibility(is_backward_compatible=True) - def __init__(self, - root: Union[torch.nn.Module, Dict[str, Any]], - graph: Graph, - class_name: str = 'GraphModule'): - """ - Construct a GraphModule. - - Args: - - root (Union[torch.nn.Module, Dict[str, Any]): - ``root`` can either be an nn.Module instance or a Dict mapping strings to any attribute type. - In the case that ``root`` is a Module, any references to Module-based objects (via qualified - name) in the Graph's Nodes' ``target`` field will be copied over from the respective place - within ``root``'s Module hierarchy into the GraphModule's module hierarchy. - In the case that ``root`` is a dict, the qualified name found in a Node's ``target`` will be - looked up directly in the dict's keys. The object mapped to by the Dict will be copied - over into the appropriate place within the GraphModule's module hierarchy. - - graph (Graph): ``graph`` contains the nodes this GraphModule should use for code generation - - class_name (str): ``name`` denotes the name of this GraphModule for debugging purposes. If it's unset, all - error messages will report as originating from ``GraphModule``. It may be helpful to set this - to ``root``'s original name or a name that makes sense within the context of your transform. - """ - super().__init__() - self.__class__.__name__ = class_name - if isinstance(root, torch.nn.Module): - if hasattr(root, 'training'): - self.training = root.training - for node in graph.nodes: - if node.op in ['get_attr', 'call_module']: - assert isinstance(node.target, str) - _copy_attr(root, self, node.target) - elif isinstance(root, dict): - targets_to_copy = [] - for node in graph.nodes: - if node.op in ['get_attr', 'call_module']: - assert isinstance(node.target, str) - if node.target not in root: - raise RuntimeError('Node ' + str(node) + ' referenced target ' + node.target + - ' but that target was not provided in ``root``!') - targets_to_copy.append(node.target) - # Sort targets in ascending order of the # of atoms. - # This will ensure that less deeply nested attributes are assigned - # before more deeply nested attributes. For example, foo.bar - # will be assigned before foo.bar.baz. Otherwise, we might assign - # the user-provided ``foo.bar`` and wipe out the previously-assigned - # ``foo.bar.baz`` - targets_to_copy.sort(key=lambda t: t.count('.')) - for target_to_copy in targets_to_copy: - _assign_attr(root[target_to_copy], self, target_to_copy) - else: - raise RuntimeError('Unsupported type ' + str(root) + ' passed for root!') - - self.graph = graph - - # Store the Tracer class responsible for creating a Graph separately as part of the - # GraphModule state, except when the Tracer is defined in a local namespace. - # Locally defined Tracers are not pickleable. This is needed because torch.package will - # serialize a GraphModule without retaining the Graph, and needs to use the correct Tracer - # to re-create the Graph during deserialization. - self._tracer_cls = None - if self.graph._tracer_cls and '' not in self.graph._tracer_cls.__qualname__: - self._tracer_cls = self.graph._tracer_cls - - self._tracer_extras = {} - if self.graph._tracer_extras: - self._tracer_extras = self.graph._tracer_extras - - # Dictionary to store metadata - self.meta : Dict[str, Any] = {} - - # TorchScript breaks trying to compile the graph setter because of the - # continued string literal. Issue here: https://github.com/pytorch/pytorch/issues/44842 - # - # Shouldn't be an issue since these methods shouldn't be used in TorchScript anyway - __jit_unused_properties__ = ['graph'] - - @property - def graph(self) -> Graph: - """ - Return the ``Graph`` underlying this ``GraphModule`` - """ - return self._graph - - @graph.setter - def graph(self, g : Graph) -> None: - """ - Set the underlying ``Graph`` for this ``GraphModule``. This will internally - recompile the ``GraphModule`` so that the generated ``forward()`` function - corresponds to ``g`` - """ - assert isinstance(g, Graph), f'Expected a Graph instance, but got {type(g)}' - self._graph = g - g.owning_module = self - self.recompile() - - @compatibility(is_backward_compatible=False) - def to_folder(self, folder: Union[str, os.PathLike], module_name : str = "FxModule"): - """Dumps out module to ``folder`` with ``module_name`` so that it can be - imported with ``from import `` - - Args: - - folder (Union[str, os.PathLike]): The folder to write the code out to - - module_name (str): Top-level name to use for the ``Module`` while - writing out the code - """ - folder = Path(folder) - Path(folder).mkdir(exist_ok=True) - torch.save(self.state_dict(), folder / 'state_dict.pt') - tab = " " * 4 - custom_builtins = '\n'.join([v.import_str for v in _custom_builtins.values()]) - model_str = f""" -import torch -{custom_builtins} - -from torch.nn import * -class {module_name}(torch.nn.Module): - def __init__(self): - super().__init__() -""" - - def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]: - safe_reprs = [nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d] - if type(module) in safe_reprs: - return f"{module.__repr__()}" - else: - return None - - blobified_modules = [] - for module_name, module in self.named_children(): - module_str = _gen_model_repr(module_name, module) - if module_str is None: - module_file = folder / f'{module_name}.pt' - torch.save(module, module_file) - blobified_modules.append(module_name) - module_repr = module.__repr__().replace('\r', ' ').replace('\n', ' ') - module_str = f"torch.load(r'{module_file}') # {module_repr}" - model_str += f"{tab*2}self.{module_name} = {module_str}\n" - - for buffer_name, buffer in self._buffers.items(): - if buffer is None: - continue - model_str += f"{tab*2}self.register_buffer('{buffer_name}', torch.empty({list(buffer.shape)}, dtype={buffer.dtype}))\n" - - for param_name, param in self._parameters.items(): - if param is None: - continue - model_str += f"{tab*2}self.{param_name} = torch.nn.Parameter(torch.empty({list(param.shape)}, dtype={param.dtype}))\n" - - model_str += f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n" - model_str += f"{_addindent(self.code, 4)}\n" - - module_file = folder / 'module.py' - module_file.write_text(model_str) - - init_file = folder / '__init__.py' - init_file.write_text('from .module import *') - - if len(blobified_modules) > 0: - warnings.warn("Was not able to save the following children modules as reprs -" - f"saved as pickled files instead: {blobified_modules}") - - @compatibility(is_backward_compatible=True) - def add_submodule(self, target: str, m: torch.nn.Module) -> bool: - """ - Adds the given submodule to ``self``. - - This installs empty Modules where none exist yet if they are - subpaths of ``target``. - - Args: - target: The fully-qualified string name of the new submodule - (See example in ``nn.Module.get_submodule`` for how to - specify a fully-qualified string.) - m: The submodule itself; the actual object we want to - install in the current Module - - Return: - bool: Whether or not the submodule could be inserted. For - this method to return True, each object in the chain - denoted by ``target`` must either a) not exist yet, - or b) reference an ``nn.Module`` (not a parameter or - other attribute) - """ - *prefix, field = target.split('.') - mod: torch.nn.Module = self - - for item in prefix: - - submod = getattr(mod, item, None) - - if submod is None: - submod = torch.nn.Module() - setattr(mod, item, submod) - - if not isinstance(submod, torch.nn.Module): - return False - - mod = submod - - mod.add_module(field, m) - return True - - @compatibility(is_backward_compatible=True) - def delete_submodule(self, target: str) -> bool: - """ - Deletes the given submodule from ``self``. - - The module will not be deleted if ``target`` is not a valid - target. - - Args: - target: The fully-qualified string name of the new submodule - (See example in ``nn.Module.get_submodule`` for how to - specify a fully-qualified string.) - - Returns: - bool: Whether or not the target string referenced a - submodule we want to delete. A return value of ``False`` - means that the ``target`` was not a valid reference to - a submodule. - """ - atoms = target.split(".") - path, target_submod = atoms[:-1], atoms[-1] - mod: torch.nn.Module = self - - # Get the parent module - for item in path: - - if not hasattr(mod, item): - return False - - mod = getattr(mod, item) - - if not isinstance(mod, torch.nn.Module): - return False - - if not hasattr(mod, target_submod): - return False - - if not isinstance(getattr(mod, target_submod), torch.nn.Module): - return False - - delattr(mod, target_submod) - return True - - @compatibility(is_backward_compatible=True) - def delete_all_unused_submodules(self) -> None: - """ - Deletes all unused submodules from ``self``. - - A Module is considered "used" if any one of the following is - true: - 1. It has children that are used - 2. Its forward is called directly via a ``call_module`` node - 3. It has a non-Module attribute that is used from a - ``get_attr`` node - - This method can be called to clean up an ``nn.Module`` without - manually calling ``delete_submodule`` on each unused submodule. - """ - used: List[str] = [] - - for node in self.graph.nodes: - - if node.op == "call_module" or node.op == "get_attr": - - # A list of strings representing the different parts - # of the path. For exmaple, `foo.bar.baz` gives us - # ["foo", "bar", "baz"] - fullpath = node.target.split(".") - - # If we're looking at multiple parts of a path, join - # join them with a dot. Otherwise, return that single - # element without doing anything to it. - def join_fn(x: str, y: str) -> str: - return '.'.join([x, y] if y else [x]) - - # Progressively collect all the names of intermediate - # modules. For example, if we have the target - # `foo.bar.baz`, we'll add `foo`, `foo.bar`, and - # `foo.bar.baz` to the list. - for path in itertools.accumulate(fullpath, join_fn): - used.append(path) - - # For a `call_module` node, also register all recursive submodules - # as used - if node.op == "call_module": - try: - submod = self.get_submodule(node.target) - - for submod_name, _ in submod.named_modules(): - if submod_name != '': - used.append('.'.join([node.target, submod_name])) - except AttributeError: - # Node referenced nonexistent submodule, don't need to - # worry about GCing anything - pass - - to_delete = [name for name, _ in self.named_modules() - if name not in used] - - for name in to_delete: - self.delete_submodule(name) - - @property - def code(self) -> str: - """ - Return the Python code generated from the ``Graph`` underlying this - ``GraphModule``. - """ - if not hasattr(self, '_code'): - raise RuntimeError('Code has not been generated! Please report a bug to PyTorch') - return self._code - - @compatibility(is_backward_compatible=True) - def recompile(self) -> PythonCode: - """ - Recompile this GraphModule from its ``graph`` attribute. This should be - called after editing the contained ``graph``, otherwise the generated - code of this ``GraphModule`` will be out of date. - """ - if isinstance(self._graph._codegen, _PyTreeCodeGen): - self._in_spec = self._graph._codegen.pytree_info.in_spec - self._out_spec = self._graph._codegen.pytree_info.out_spec - python_code = self._graph.python_code(root_module='self') - self._code = python_code.src - - cls = type(self) - cls.forward = _forward_from_src(self._code, python_code.globals) - - # Determine whether this class explicitly defines a __call__ implementation - # to wrap. If it does, save it in order to have wrapped_call invoke it. - # If it does not, wrapped_call can use a dynamic call to super() instead. - # In most cases, super().__call__ should be torch.nn.Module.__call__. - # We do not want to hold a reference to Module.__call__ here; doing so will - # bypass patching of torch.nn.Module.__call__ done while symbolic tracing. - cls_call = cls.__call__ if "__call__" in vars(cls) else None - - if '_wrapped_call' not in vars(cls): - cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined] - - def call_wrapped(self, *args, **kwargs): - return self._wrapped_call(self, *args, **kwargs) - - cls.__call__ = call_wrapped - - return python_code - - # Passing Tracer as argument allows subclasses extending fx.GraphModule - # define their own Tracer (extending fx.Tracer). - def __reduce_deploy__(self, importer: Importer): - dict_without_graph = self.__dict__.copy() - dict_without_graph['_graphmodule_cls_name'] = self.__class__.__name__ - del dict_without_graph['_graph'] - - python_code = self.recompile() - import_block = _format_import_block(python_code.globals, importer) - return (reduce_deploy_graph_module, (dict_without_graph, import_block)) - - def __reduce_package__(self, exporter: PackageExporter): - dict_without_graph = self.__dict__.copy() - dict_without_graph['_graphmodule_cls_name'] = self.__class__.__name__ - del dict_without_graph['_graph'] - - generated_module_name = f'fx-generated._{exporter.get_unique_id()}' - python_code = self.recompile() - import_block = _format_import_block(python_code.globals, exporter.importer) - module_code = import_block + self.code - exporter.save_source_string(generated_module_name, module_code) - return (reduce_package_graph_module, (dict_without_graph, generated_module_name)) - - def __reduce__(self): - """ - Serialization of GraphModule. We serialize only the generated code, not - the underlying ``Graph``. This is because ``Graph`` does not have on-disk - backward-compatibility guarantees, whereas Python source code does. - On the deserialization side, we symbolically trace through the generated - code to regenerate the underlying ``Graph`` - """ - dict_without_graph = self.__dict__.copy() - python_code = self.recompile() - import_block = _format_import_block(python_code.globals, sys_importer) - del dict_without_graph['_graph'] - return (reduce_graph_module, (dict_without_graph, import_block)) - - # because __reduce__ is defined for serialization, - # we need to define deepcopy otherwise it will call __reduce__ - # and cause symbolic tracing to occur every time we try to copy the object - def __deepcopy__(self, memo): - fake_mod = torch.nn.Module() - fake_mod.__dict__ = copy.deepcopy(self.__dict__) - return GraphModule(fake_mod, fake_mod.__dict__['_graph']) - - def __copy__(self): - return GraphModule(self, self.graph) - - @compatibility(is_backward_compatible=False) - def print_readable(self): - """ - Return the Python code generated for current GraphModule and its children GraphModules - """ - verbose_python_code = self._graph.python_code(root_module='self', verbose=True) - module_code = verbose_python_code.src - module_code = module_code.lstrip('\n') - module_code = f"class {self._get_name()}(torch.nn.Module):\n" + module_code - module_code = _addindent(module_code, 4) - - submodule_code_list = [""] - for submodule in self.children(): - if isinstance(submodule, GraphModule): - submodule_code_list.append(submodule.__nested_code()) - submodule_code = "\n".join(submodule_code_list) - submodule_code = _addindent(submodule_code, 4) - - print(module_code + submodule_code) - - def __str__(self) -> str: - orig_str = super().__str__() - print_readable_reminder = "# To see more debug info, please use `graph_module.print_readable()`" - return '\n'.join([orig_str, self._code, print_readable_reminder]) - - def _replicate_for_data_parallel(self): - new_gm = self.__copy__() - new_gm._is_replica = True - return new_gm - -# workarounds for issues in __torch_function__ - -# WAR for __torch_function__ not handling tensor lists, -# fix is in https://github.com/pytorch/pytorch/pull/34725 -# orig_cat = torch.cat -# def patched_cat(*args, **kwargs): -# tensors = args[0] -# for t in tensors: -# if isinstance(t, Proxy): -# return t.__torch_function__(patched_cat, (), args, kwargs) -# return orig_cat(*args, **kwargs) -# patched_cat.__module__ = 'torch' -# patched_cat.__name__ = 'cat' -# torch.cat = patched_cat diff --git a/pippy/fx/immutable_collections.py b/pippy/fx/immutable_collections.py deleted file mode 100644 index de884c205..000000000 --- a/pippy/fx/immutable_collections.py +++ /dev/null @@ -1,53 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from typing import Any, Dict, Tuple, List - -from ._compatibility import compatibility -from torch.utils._pytree import Context, _register_pytree_node - -__all__ = ["immutable_list", "immutable_dict"] - -_help_mutation = """\ -If you are attempting to modify the kwargs or args of a pippy.fx.Node object, -instead create a new copy of it and assign the copy to the node: - new_args = ... # copy and mutate args - node.args = new_args -""" - -def _no_mutation(self, *args, **kwargs): - raise NotImplementedError(f"'{type(self).__name__}' object does not support mutation. {_help_mutation}") - -def _create_immutable_container(base, mutable_functions): - container = type('immutable_' + base.__name__, (base,), {}) - for attr in mutable_functions: - setattr(container, attr, _no_mutation) - return container - -immutable_list = _create_immutable_container(list, - ['__delitem__', '__iadd__', '__imul__', '__setitem__', 'append', - 'clear', 'extend', 'insert', 'pop', 'remove']) -immutable_list.__reduce__ = lambda self: (immutable_list, (tuple(iter(self)),)) - -compatibility(is_backward_compatible=True)(immutable_list) - -immutable_dict = _create_immutable_container(dict, ['__delitem__', '__setitem__', 'clear', 'pop', 'popitem', 'update']) -immutable_dict.__reduce__ = lambda self: (immutable_dict, (iter(self.items()),)) -compatibility(is_backward_compatible=True)(immutable_dict) - - -# Register immutable collections for PyTree operations - -def _immutable_dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]: - return list(d.values()), list(d.keys()) - -def _immutable_dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]: - return immutable_dict({key: value for key, value in zip(context, values)}) - -def _immutable_list_flatten(d: List[Any]) -> Tuple[List[Any], Context]: - return d, None - -def _immutable_list_unflatten(values: List[Any], context: Context) -> List[Any]: - return immutable_list(values) - - -_register_pytree_node(immutable_dict, _immutable_dict_flatten, _immutable_dict_unflatten) -_register_pytree_node(immutable_list, _immutable_list_flatten, _immutable_list_unflatten) diff --git a/pippy/fx/interpreter.py b/pippy/fx/interpreter.py deleted file mode 100644 index 7a7cf7fc0..000000000 --- a/pippy/fx/interpreter.py +++ /dev/null @@ -1,481 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from .graph_module import GraphModule -from .graph import Graph -from .node import Argument, Node, Target, map_arg, map_aggregate # pylint: disable=unused-import -from .proxy import Proxy -from ._symbolic_trace import Tracer -from ._compatibility import compatibility -import pippy.fx.traceback as fx_traceback -from typing import Any, Dict, Iterator, List, Optional, Tuple, Union -import inspect -from contextlib import contextmanager - -__all__ = ['Interpreter', 'Transformer'] - -@compatibility(is_backward_compatible=True) -class Interpreter: - """ - An Interpreter executes an FX graph Node-by-Node. This pattern - can be useful for many things, including writing code - transformations as well as analysis passes. - - Methods in the Interpreter class can be overridden to customize - the behavior of execution. The map of overrideable methods - in terms of call hierarchy:: - - run() - +-- run_node - +-- placeholder() - +-- get_attr() - +-- call_function() - +-- call_method() - +-- call_module() - +-- output() - - Example: - - Suppose we want to swap all instances of ``torch.neg`` with - ``torch.sigmoid`` and vice versa (including their ``Tensor`` - method equivalents). We could subclass Interpreter like so:: - - class NegSigmSwapInterpreter(Interpreter): - def call_function(self, target : Target, - args : Tuple, kwargs : Dict) -> Any: - if target == torch.sigmoid: - return torch.neg(*args, **kwargs) - return super().call_function(n) - - def call_method(self, target : Target, - args : Tuple, kwargs : Dict) -> Any: - if target == 'neg': - call_self, *args_tail = args - return call_self.sigmoid(*args_tail, **kwargs) - return super().call_method(n) - - def fn(x): - return torch.sigmoid(x).neg() - - gm = pippy.fx.symbolic_trace(fn) - input = torch.randn(3, 4) - result = NegSigmSwapInterpreter(gm).run(input) - torch.testing.assert_allclose(result, torch.neg(input).sigmoid()) - - Args: - module (GraphModule): The module to be executed - garbage_collect_values (bool): Whether to delete values after their last - use within the Module's execution. This ensures optimal memory usage during - execution. This can be disabled to, for example, examine all of the intermediate - values in the execution by looking at the ``Interpreter.env`` attribute. - """ - @compatibility(is_backward_compatible=True) - def __init__(self, module : GraphModule, garbage_collect_values : bool = True): - assert isinstance(module, GraphModule) - self.module = module - self.submodules = dict(self.module.named_modules()) - self.env : Dict[Node, Any] = {} - - self.garbage_collect_values = garbage_collect_values - - if self.garbage_collect_values: - # Run through reverse nodes and record the first instance of a use - # of a given node. This represents the *last* use of the node in the - # execution order of the program, which we will use to free unused - # values - node_to_last_use : Dict[Node, Node] = {} - self.user_to_last_uses : Dict[Node, List[Node]] = {} - - def register_last_uses(n : Node, user : Node): - if n not in node_to_last_use: - node_to_last_use[n] = user - self.user_to_last_uses.setdefault(user, []).append(n) - - for node in reversed(self.module.graph.nodes): - map_arg(node.args, lambda n: register_last_uses(n, node)) - map_arg(node.kwargs, lambda n: register_last_uses(n, node)) - - @compatibility(is_backward_compatible=True) - def run(self, *args, initial_env : Optional[Dict[Node, Any]] = None, enable_io_processing : bool = True) -> Any: - """ - Run `module` via interpretation and return the result. - - Args: - *args: The arguments to the Module to run, in positional order - initial_env (Optional[Dict[Node, Any]]): An optional starting environment for execution. - This is a dict mapping `Node` to any value. This can be used, for example, to - pre-populate results for certain `Nodes` so as to do only partial evaluation within - the interpreter. - enable_io_processing (bool): If true, we process the inputs and outputs with graph's process_inputs and - process_outputs function first before using them. - - Returns: - Any: The value returned from executing the Module - """ - self.env = initial_env if initial_env else {} - - # Positional function args are consumed left-to-right by - # `placeholder` nodes. Use an iterator to keep track of - # position and extract those values. - if enable_io_processing: - args = self.module.graph.process_inputs(*args) - self.args_iter : Iterator[Any] = iter(args) - - for node in self.module.graph.nodes: - if node in self.env: - # Short circuit if we have this value. This could - # be used, for example, for partial evaluation - # where the caller has pre-populated `env` with - # values for a subset of the program. - continue - - try: - self.env[node] = self.run_node(node) - except Exception as e: - msg = f"While executing {node.format_node()}" - msg = '{}\n\n{}'.format(e.args[0], msg) if e.args else str(msg) - msg += f"\nOriginal traceback:\n{node.stack_trace}" - e.args = (msg,) + e.args[1:] - if isinstance(e, KeyError): - raise RuntimeError(*e.args) - raise - - if self.garbage_collect_values: - for to_delete in self.user_to_last_uses.get(node, []): - del self.env[to_delete] - - if node.op == 'output': - output_val = self.env[node] - return self.module.graph.process_outputs(output_val) if enable_io_processing else output_val - - @contextmanager - def _set_current_node(self, node): - with fx_traceback.append_stack_trace(node.stack_trace): - yield - - @compatibility(is_backward_compatible=True) - def run_node(self, n : Node) -> Any: - """ - Run a specific node ``n`` and return the result. - Calls into placeholder, get_attr, call_function, - call_method, call_module, or output depending - on ``node.op`` - - Args: - n (Node): The Node to execute - - Returns: - Any: The result of executing ``n`` - """ - with fx_traceback.append_stack_trace(n.stack_trace): - args, kwargs = self.fetch_args_kwargs_from_env(n) - assert isinstance(args, tuple) - assert isinstance(kwargs, dict) - return getattr(self, n.op)(n.target, args, kwargs) - - # Main Node running APIs - @compatibility(is_backward_compatible=True) - def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: - """ - Execute a ``placeholder`` node. Note that this is stateful: - ``Interpreter`` maintains an internal iterator over - arguments passed to ``run`` and this method returns - next() on that iterator. - - Args: - target (Target): The call target for this node. See - `Node `__ for - details on semantics - args (Tuple): Tuple of positional args for this invocation - kwargs (Dict): Dict of keyword arguments for this invocation - - Returns: - Any: The argument value that was retrieved. - """ - assert isinstance(target, str) - if target.startswith('*'): - # For a starred parameter e.g. `*args`, retrieve all - # remaining values from the args list. - return list(self.args_iter) - else: - try: - return next(self.args_iter) - except StopIteration as si: - if len(args) > 0: - return args[0] - else: - raise RuntimeError(f'Expected positional argument for parameter {target}, but one was not passed in!') - - @compatibility(is_backward_compatible=True) - def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: - """ - Execute a ``get_attr`` node. Will retrieve an attribute - value from the ``Module`` hierarchy of ``self.module``. - - Args: - target (Target): The call target for this node. See - `Node `__ for - details on semantics - args (Tuple): Tuple of positional args for this invocation - kwargs (Dict): Dict of keyword arguments for this invocation - - Return: - Any: The value of the attribute that was retrieved - """ - assert isinstance(target, str) - return self.fetch_attr(target) - - @compatibility(is_backward_compatible=True) - def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: - """ - Execute a ``call_function`` node and return the result. - - Args: - target (Target): The call target for this node. See - `Node `__ for - details on semantics - args (Tuple): Tuple of positional args for this invocation - kwargs (Dict): Dict of keyword arguments for this invocation - - Return - Any: The value returned by the function invocation - """ - assert not isinstance(target, str) - - # Execute the function and return the result - return target(*args, **kwargs) - - @compatibility(is_backward_compatible=True) - def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: - """ - Execute a ``call_method`` node and return the result. - - Args: - target (Target): The call target for this node. See - `Node `__ for - details on semantics - args (Tuple): Tuple of positional args for this invocation - kwargs (Dict): Dict of keyword arguments for this invocation - - Return - Any: The value returned by the method invocation - """ - # args[0] is the `self` object for this method call - self_obj, *args_tail = args - - # Execute the method and return the result - assert isinstance(target, str) - return getattr(self_obj, target)(*args_tail, **kwargs) - - @compatibility(is_backward_compatible=True) - def call_module(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: - """ - Execute a ``call_module`` node and return the result. - - Args: - target (Target): The call target for this node. See - `Node `__ for - details on semantics - args (Tuple): Tuple of positional args for this invocation - kwargs (Dict): Dict of keyword arguments for this invocation - - Return - Any: The value returned by the module invocation - """ - # Retrieve executed args and kwargs values from the environment - - # Execute the method and return the result - assert isinstance(target, str) - submod = self.fetch_attr(target) - - return submod(*args, **kwargs) - - @compatibility(is_backward_compatible=True) - def output(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: - """ - Execute an ``output`` node. This really just retrieves - the value referenced by the ``output`` node and returns it. - - Args: - target (Target): The call target for this node. See - `Node `__ for - details on semantics - args (Tuple): Tuple of positional args for this invocation - kwargs (Dict): Dict of keyword arguments for this invocation - - Return: - Any: The return value referenced by the output node - """ - return args[0] - - # Helper methods - @compatibility(is_backward_compatible=True) - def fetch_attr(self, target : str): - """ - Fetch an attribute from the ``Module`` hierarchy of ``self.module``. - - Args: - target (str): The fully-qualfiied name of the attribute to fetch - - Return: - Any: The value of the attribute. - """ - target_atoms = target.split('.') - attr_itr = self.module - for i, atom in enumerate(target_atoms): - if not hasattr(attr_itr, atom): - raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}") - attr_itr = getattr(attr_itr, atom) - return attr_itr - - @compatibility(is_backward_compatible=True) - def fetch_args_kwargs_from_env(self, n : Node) -> Tuple[Tuple, Dict]: - """ - Fetch the concrete values of ``args`` and ``kwargs`` of node ``n`` - from the current execution environment. - - Args: - n (Node): The node for which ``args`` and ``kwargs`` should be fetched. - - Return: - Tuple[Tuple, Dict]: ``args`` and ``kwargs`` with concrete values for ``n``. - """ - args = self.map_nodes_to_values(n.args, n) - assert isinstance(args, tuple) - kwargs = self.map_nodes_to_values(n.kwargs, n) - assert isinstance(kwargs, dict) - return args, kwargs - - @compatibility(is_backward_compatible=True) - def map_nodes_to_values(self, args : Argument, n : Node) -> Argument: - """ - Recursively descend through ``args`` and look up the concrete value - for each ``Node`` in the current execution environment. - - Args: - args (Argument): Data structure within which to look up concrete values - - n (Node): Node to which ``args`` belongs. This is only used for error reporting. - """ - def load_arg(n_arg : Node) -> Any: - if n_arg not in self.env: - raise RuntimeError(f'Node {n} referenced nonexistent value {n_arg}! Run Graph.lint() ' - f'to diagnose such issues') - return self.env[n_arg] - return map_arg(args, load_arg) - -@compatibility(is_backward_compatible=True) -class Transformer(Interpreter): - """ - ``Transformer`` is a special type of interpreter that produces a - new ``Module``. It exposes a ``transform()`` method that returns - the transformed ``Module``. ``Transformer`` does not require - arguments to run, as ``Interpreter`` does. ``Transformer`` works - entirely symbolically. - - Example: - - Suppose we want to swap all instances of ``torch.neg`` with - ``torch.sigmoid`` and vice versa (including their ``Tensor`` - method equivalents). We could subclass ``Transformer`` like so:: - - class NegSigmSwapXformer(Transformer): - def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: - if target == torch.sigmoid: - return torch.neg(*args, **kwargs) - return super().call_function(n) - - def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: - if target == 'neg': - call_self, *args_tail = args - return call_self.sigmoid(*args_tail, **kwargs) - return super().call_method(n) - - def fn(x): - return torch.sigmoid(x).neg() - - gm = pippy.fx.symbolic_trace(fn) - - transformed : torch.nn.Module = NegSigmSwapXformer(gm).transform() - input = torch.randn(3, 4) - torch.testing.assert_allclose(transformed(input), torch.neg(input).sigmoid()) - - Args: - module (GraphModule): The ``Module`` to be transformed. - """ - - @compatibility(is_backward_compatible=True) - def __init__(self, module): - super().__init__(module) - self.new_graph = Graph() - self.new_graph.set_codegen(module.graph._codegen) - - class TransformerTracer(Tracer): - def __init__(self, graph: Graph): - super().__init__() - self.graph = graph - - def is_leaf_module(self, _, __) -> bool: - return True - - self.tracer = TransformerTracer(self.new_graph) - self.tracer.root = module - - @compatibility(is_backward_compatible=True) - def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Proxy: - """ - Execute a ``placeholder`` node. In ``Transformer``, this is - overridden to insert a new ``placeholder`` into the output - graph. - - Args: - target (Target): The call target for this node. See - `Node `__ for - details on semantics - args (Tuple): Tuple of positional args for this invocation - kwargs (Dict): Dict of keyword arguments for this invocation - """ - assert isinstance(target, str) - default_value = next(iter(args)) if args else inspect.Signature.empty - return Proxy(self.new_graph.placeholder(target, default_value=default_value), self.tracer) - - @compatibility(is_backward_compatible=True) - def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Proxy: - """ - Execute a ``get_attr`` node. In ``Transformer``, this is - overridden to insert a new ``get_attr`` node into the output - graph. - - Args: - target (Target): The call target for this node. See - `Node `__ for - details on semantics - args (Tuple): Tuple of positional args for this invocation - kwargs (Dict): Dict of keyword arguments for this invocation - """ - assert isinstance(target, str) - return Proxy(self.new_graph.get_attr(target), self.tracer) - - @compatibility(is_backward_compatible=True) - def call_module(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: - # Override so that the leaf module policy from `self.tracer` is respected. - assert isinstance(target, str) - submod = self.fetch_attr(target) - return self.tracer.call_module(submod, submod.forward, args, kwargs) - - @compatibility(is_backward_compatible=True) - def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: - # Override so that functions that were wrapped are still wrapped. - return self.tracer.create_proxy('call_function', target, args, kwargs) - - @compatibility(is_backward_compatible=True) - def transform(self) -> GraphModule: - """ - Transform ``self.module`` and return the transformed - ``GraphModule``. - """ - with fx_traceback.override_stack_trace(): - result = super().run(enable_io_processing=False) - if result is not None: - def strip_proxy(a : Union[Argument, Proxy]) -> Any: - return a.node if isinstance(a, Proxy) else a - self.new_graph.output(map_aggregate(result, strip_proxy)) - return GraphModule(self.module, self.new_graph) diff --git a/pippy/fx/node.py b/pippy/fx/node.py deleted file mode 100644 index 5e4600319..000000000 --- a/pippy/fx/node.py +++ /dev/null @@ -1,627 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -# Nodes represent a definition of a value in our graph of operators. -from typing import TYPE_CHECKING, Union, Callable, Any, Tuple, List, Optional, Dict, Set -from ._compatibility import compatibility -from .immutable_collections import immutable_dict, immutable_list -import torch -import builtins -import types -import warnings -from pippy.fx.operator_schemas import normalize_function, normalize_module, ArgsKwargsPair - -if TYPE_CHECKING: - from .graph import Graph - -__all__ = ['Node', 'map_arg', 'map_aggregate'] - -BaseArgumentTypes = Union[str, int, float, bool, complex, torch.dtype, - torch.Tensor, torch.device, torch.memory_format, torch.layout] -base_types = BaseArgumentTypes.__args__ # type: ignore[attr-defined] - -Target = Union[Callable[..., Any], str] - -Argument = Optional[Union[ - Tuple[Any, ...], # actually Argument, but mypy can't represent recursive types - List[Any], # actually Argument - Dict[str, Any], # actually Argument - slice, # Slice[Argument, Argument, Argument], but slice is not a templated type in typing - 'Node', - BaseArgumentTypes -]] - -_side_effectful_functions: Set[Callable] = { - torch._assert, - torch.ops.profiler._record_function_enter, - torch.ops.profiler._record_function_enter_new, - torch.ops.profiler._record_function_exit} - -# this is fixed on master, WAR for 1.5 -def _find_module_of_method(orig_method: Callable[..., Any]) -> str: - name = orig_method.__name__ - module = orig_method.__module__ - if module is not None: - return module - for guess in [torch, torch.nn.functional]: - if getattr(guess, name, None) is orig_method: - return guess.__name__ - raise RuntimeError(f'cannot find module for {orig_method}') - -# Borrowed from CPython typing module -# https://github.com/python/cpython/blob/f90dc36c15d7fee0efaf6d39e97be0bdf2683e93/Lib/typing.py#L156 -def _type_repr(obj): - """Return the repr() of an object, special-casing types (internal helper). - If obj is a type, we return a shorter version than the default - type.__repr__, based on the module and qualified name, which is - typically enough to uniquely identify a type. For everything - else, we fall back on repr(obj). - """ - if isinstance(obj, type): - if obj.__module__ == 'builtins': - return obj.__qualname__ - return f'{obj.__module__}.{obj.__qualname__}' - if obj is ...: - return('...') - if isinstance(obj, types.FunctionType): - return obj.__name__ - return repr(obj) - -def _get_qualified_name(func: Callable[..., Any]) -> str: - # things like getattr just appear in builtins - if getattr(builtins, func.__name__, None) is func: - return func.__name__ - name = func.__name__ - module = _find_module_of_method(func) - module = module.replace('torch._ops', 'torch.ops') # WAR for bug in how torch.ops assigns module - return f'{module}.{name}' - -def _format_arg(arg, max_list_len=float('inf')) -> str: - if hasattr(arg, '_custom_fx_repr_fn'): - return arg._custom_fx_repr_fn() - elif isinstance(arg, list): - items = ', '.join(_format_arg(a) for idx, a in enumerate(arg) if idx < max_list_len) - maybe_len = '' if len(arg) < max_list_len + 1 else f', ...[total_len={len(arg)}]' - return f'[{items}{maybe_len}]' - elif isinstance(arg, tuple): - items = ', '.join(_format_arg(a) for idx, a in enumerate(arg) if idx < max_list_len) - maybe_len = '' if len(arg) < max_list_len + 1 else f', ...[total_len={len(arg)}]' - maybe_comma = ',' if len(arg) == 1 else '' - return f'({items}{maybe_comma}{maybe_len})' - elif isinstance(arg, dict): - items_str = ', '.join(f'{k}: {_format_arg(v)}' for k, v in arg.items()) - return f'{{{items_str}}}' - - if isinstance(arg, Node): - return '%' + str(arg) - else: - return str(arg) - -@compatibility(is_backward_compatible=True) -class Node: - """ - ``Node`` is the data structure that represents individual operations within - a ``Graph``. For the most part, Nodes represent callsites to various entities, - such as operators, methods, and Modules (some exceptions include nodes that - specify function inputs and outputs). Each ``Node`` has a function specified - by its ``op`` property. The ``Node`` semantics for each value of ``op`` are as follows: - - - ``placeholder`` represents a function input. The ``name`` attribute specifies the name this value will take on. - ``target`` is similarly the name of the argument. ``args`` holds either: 1) nothing, or 2) a single argument - denoting the default parameter of the function input. ``kwargs`` is don't-care. Placeholders correspond to - the function parameters (e.g. ``x``) in the graph printout. - - ``get_attr`` retrieves a parameter from the module hierarchy. ``name`` is similarly the name the result of the - fetch is assigned to. ``target`` is the fully-qualified name of the parameter's position in the module hierarchy. - ``args`` and ``kwargs`` are don't-care - - ``call_function`` applies a free function to some values. ``name`` is similarly the name of the value to assign - to. ``target`` is the function to be applied. ``args`` and ``kwargs`` represent the arguments to the function, - following the Python calling convention - - ``call_module`` applies a module in the module hierarchy's ``forward()`` method to given arguments. ``name`` is - as previous. ``target`` is the fully-qualified name of the module in the module hierarchy to call. - ``args`` and ``kwargs`` represent the arguments to invoke the module on, *including the self argument*. - - ``call_method`` calls a method on a value. ``name`` is as similar. ``target`` is the string name of the method - to apply to the ``self`` argument. ``args`` and ``kwargs`` represent the arguments to invoke the module on, - *including the self argument* - - ``output`` contains the output of the traced function in its ``args[0]`` attribute. This corresponds to the "return" statement - in the Graph printout. - """ - - @compatibility(is_backward_compatible=True) - def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target', - args: Tuple['Argument', ...], kwargs: Dict[str, 'Argument'], - return_type : Optional[Any] = None) -> None: - """ - Instantiate an instance of ``Node``. Note: most often, you want to use the - Graph APIs, i.e. ``Graph.call_module``, ``Graph.call_method``, etc. rather - than instantiating a ``Node`` directly. - - Args: - graph (Graph): The ``Graph`` to which this ``Node`` should belong. - - name (str): The name to which the output of this ``Node`` should be assigned - - op (str): The opcode for this ``Node``. Can be one of 'placeholder', - 'call_method', 'call_module', 'call_function', 'get_attr', - 'output' - - target ('Target'): The target this op should call. See the broader - ``Node`` docstring for more details. - - args (Tuple['Argument']): The args to be passed to ``target`` - - kwargs (Dict[str, 'Argument']): The kwargs to be passed to ``target`` - - return_type (Optional[Any]): The python type expression representing the - type of the output of this node. This field can be used for - annotation of values in the generated code or for other types - of analyses. - """ - self.graph = graph - self.name = name # unique name of value being created - assert op in ['placeholder', 'call_method', 'call_module', 'call_function', 'get_attr', 'output', 'root'] - self.op = op # the kind of operation = placeholder|call_method|call_module|call_function|get_attr - if op == 'call_function': - if not callable(target): - raise ValueError(f'Node [graph = {graph}, name = \'{name}\'] target {target} has type {torch.typename(target)} ' - 'but a Callable is expected') - else: - if not isinstance(target, str): - raise ValueError(f'Node [graph = {graph}, name = \'{name}\'] target {target} has type {torch.typename(target)} ' - 'but a str is expected') - self.target = target # for method/module/function, the name of the method/module/function/attr - # being invoked, e.g add, layer1, or torch.add - - # All `Node`-valued inputs. Key is the Node, value is don't-care. - # The public API for this is `all_input_nodes`, this private attribute - # should not be accessed directly. - self._input_nodes : Dict[Node, None] = {} - self.__update_args_kwargs(map_arg(args, lambda x: x), map_arg(kwargs, lambda x: x)) # type: ignore[arg-type] - - # All of the nodes that use the value produced by this Node - # Note one user may correspond to several uses, e.g. the node fo ``x + x`` - # would appear once here, but represents two uses. - # - # Is a dict to act as an "ordered set". Keys are significant, value dont-care - self.users : Dict['Node', None] = {} - # Type expression representing the output value of this node. - # This should contain the same class of Type objects that would appear - # as type annotations for function inputs/outputs. - # - # For placeholder nodes, this value will be used to type-annotate the - # generated function parameters. - # For the return node, this value will be used to type-annotate the - # generated function return type. (Note this is a special case. ``return`` - # does not produce a value, it's more of a notation. Thus, this value - # describes the type of args[0] in the ``return`` node. - self.type : Optional[Any] = return_type - self._prev = self - self._next = self - self._erased = False - - # If set, use this fn to print this node - self._repr_fn : Optional[Callable[[Node], str]] = None - - # Dictionary to store metadata passes need to do their - # transformations. This metadata is preserved across node copies - self.meta : Dict[str, Any] = {} - - @property - def next(self) -> 'Node': - """ - Returns the next ``Node`` in the linked list of Nodes. - - Returns: - - The next ``Node`` in the linked list of Nodes. - """ - return self._next - - @property - def prev(self) -> 'Node': - """ - Returns the previous ``Node`` in the linked list of Nodes. - - Returns: - - The previous ``Node`` in the linked list of Nodes. - """ - return self._prev - - @compatibility(is_backward_compatible=True) - def prepend(self, x: 'Node') -> None: - """ - Insert x before this node in the list of nodes in the graph. Example:: - - Before: p -> self - bx -> x -> ax - After: p -> x -> self - bx -> ax - - Args: - x (Node): The node to put before this node. Must be a member of the same graph. - """ - assert self.graph == x.graph, "Attempting to move a Node into a different Graph" - if self == x: - warnings.warn("Trying to prepend a node to itself. This behavior has no effect on the graph.") - return - x._remove_from_list() - p = self._prev - p._next, x._prev = x, p - x._next, self._prev = self, x - - @compatibility(is_backward_compatible=True) - def append(self, x: 'Node') -> None: - """ - Insert ``x`` after this node in the list of nodes in the graph. - Equivalent to ``self.next.prepend(x)`` - - Args: - x (Node): The node to put after this node. Must be a member of the same graph. - """ - self._next.prepend(x) - - def _remove_from_list(self): - p, n = self._prev, self._next - p._next, n._prev = n, p - - @property - def args(self) -> Tuple[Argument, ...]: - """ - The tuple of arguments to this ``Node``. The interpretation of arguments - depends on the node's opcode. See the :class:`Node` docstring for more - information. - - Assignment to this property is allowed. All accounting of uses and users - is updated automatically on assignment. - """ - return self._args - - @args.setter - def args(self, a : Tuple[Argument, ...]): - """ - Set the tuple of arguments to this Node. The interpretation of arguments - depends on the node's opcode. See the ``fx.Graph`` docstring for more - information. - """ - # DO NOT CALL `__update_args_kwargs` directly. The correct way to - # set `args` is via direct assignment, i.e. `node.args = new_args` - self.__update_args_kwargs(map_arg(a, lambda x: x), self._kwargs) # type: ignore[arg-type] - - @property - def kwargs(self) -> Dict[str, Argument]: - """ - The dict of keyword arguments to this ``Node``. The interpretation of arguments - depends on the node's opcode. See the :class:`Node` docstring for more - information. - - Assignment to this property is allowed. All accounting of uses and users - is updated automatically on assignment. - """ - return self._kwargs - - @kwargs.setter - def kwargs(self, k : Dict[str, Argument]): - """ - Set the dict of kwargs to this Node. The interpretation of arguments - depends on the node's opcode. See the ``fx.Graph`` docstring for more - information. - """ - # DO NOT CALL `__update_args_kwargs` directly. The correct way to - # set `args` is via direct assignment, i.e. `node.kwargs = new_kwargs` - self.__update_args_kwargs(self._args, map_arg(k, lambda x: x)) # type: ignore[arg-type] - - @property - def all_input_nodes(self) -> List['Node']: - """ - Return all Nodes that are inputs to this Node. This is equivalent to - iterating over ``args`` and ``kwargs`` and only collecting the values that - are Nodes. - - Returns: - - List of ``Nodes`` that appear in the ``args`` and ``kwargs`` of this - ``Node``, in that order. - """ - return list(self._input_nodes.keys()) - - @compatibility(is_backward_compatible=True) - def update_arg(self, idx : int, arg : Argument) -> None: - """ - Update an existing positional argument to contain the new value - ``arg``. After calling, ``self.args[idx] == arg``. - - Args: - - idx (int): The index into ``self.args`` of the element to update - arg (Argument): The new argument value to write into ``args`` - """ - args = list(self.args) - args[idx] = arg - self.args = tuple(args) - - @compatibility(is_backward_compatible=True) - def update_kwarg(self, key : str, arg : Argument) -> None: - """ - Update an existing keyword argument to contain the new value - ``arg``. After calling, ``self.kwargs[key] == arg``. - - Args: - - key (str): The key in ``self.kwargs`` of the element to update - arg (Argument): The new argument value to write into ``kwargs`` - """ - kwargs = dict(self.kwargs) - kwargs[key] = arg - self.kwargs = kwargs - - @property - def stack_trace(self) -> Optional[str]: - """ - Return the Python stack trace that was recorded during tracing, if any. - This property is usually populated by `Tracer.create_proxy`. To record - stack traces during tracing for debug purposes, set - `record_stack_traces = True` on the `Tracer` instance. - """ - return self.meta.get("stack_trace", None) - - @stack_trace.setter - def stack_trace(self, trace : Optional[str]): - self.meta["stack_trace"] = trace - - def __update_args_kwargs(self, new_args : Tuple['Argument', ...], new_kwargs : Dict[str, 'Argument']): - """ - This API is internal. Do *not* call it directly. - """ - self._args = new_args - self._kwargs = new_kwargs - - for old_use in self._input_nodes.keys(): - old_use.users.pop(self) - - self._input_nodes = {} - map_arg(self._args, lambda n: self._input_nodes.setdefault(n)) - map_arg(self._kwargs, lambda n: self._input_nodes.setdefault(n)) - - for new_use in self._input_nodes.keys(): - new_use.users.setdefault(self) - - def __repr__(self) -> str: - if self._repr_fn: - return self._repr_fn(self) - return self.name - - def _pretty_print_target(self, target): - """ - Make target printouts more user-friendly. - 1) builtins will be printed as `builtins.xyz` - 2) operators will be printed as `operator.xyz` - 3) other callables will be printed with qualfied name, e.g. torch.add - """ - if isinstance(target, str): - return target - if hasattr(target, '__module__'): - if not hasattr(target, '__name__'): - # Just to be defensive, if we don't have `__name__`, get the - # qualname. Not sure if this happens for any members of `operator` - # or `builtins`. This fallback path is not as good, since e.g. - # things in `operator` have `_operator` as their __module__. - return _get_qualified_name(target) - if target.__module__ == 'builtins': - return f'builtins.{target.__name__}' - elif target.__module__ == '_operator': - return f'operator.{target.__name__}' - return _get_qualified_name(target) - - @compatibility(is_backward_compatible=True) - def format_node(self, - placeholder_names: Optional[List[str]] = None, - maybe_return_typename: Optional[List[str]] = None) -> Optional[str]: - """ - Return a descriptive string representation of ``self``. - - This method can be used with no arguments as a debugging - utility. - - This function is also used internally in the ``__str__`` method - of ``Graph``. Together, the strings in ``placeholder_names`` - and ``maybe_return_typename`` make up the signature of the - autogenerated ``forward`` function in this Graph's surrounding - GraphModule. ``placeholder_names`` and ``maybe_return_typename`` - should not be used otherwise. - - Args: - placeholder_names: A list that will store formatted strings - representing the placeholders in the generated - ``forward`` function. Internal use only. - maybe_return_typename: A single-element list that will store - a formatted string representing the output of the - generated ``forward`` function. Internal use only. - - Returns: - str: If 1) we're using ``format_node`` as an internal helper - in the ``__str__`` method of ``Graph``, and 2) ``self`` - is a placeholder Node, return ``None``. Otherwise, - return a descriptive string representation of the - current Node. - """ - if self.op == 'placeholder': - assert isinstance(self.target, str) - arg_str = self.target - arg_str += arg_str + f': {_type_repr(self.type)}' if self.type else '' - if placeholder_names: - placeholder_names.append(arg_str) - return None - maybe_typename = f'{_type_repr(self.type)} ' if self.type else '' - default_val = '(default=' + str(self.args[0]) + ')' if self.args else '' - return f'%{self.name} : {maybe_typename}[#users={len(self.users)}] = {self.op}[target={self.target}]{default_val}' - elif self.op == 'get_attr': - maybe_typename = f'{_type_repr(self.type)} ' if self.type is not None else '' - return f'%{self.name} : {maybe_typename}[#users={len(self.users)}] = ' \ - f'{self.op}[target={self._pretty_print_target(self.target)}]' - elif self.op == 'output': - if self.type and maybe_return_typename: - maybe_return_typename[0] = f' -> {_type_repr(self.type)}' - return f'return {self.args[0]}' - else: - maybe_typename = f'{_type_repr(self.type)} ' if self.type is not None else '' - return f'%{self.name} : {maybe_typename}[#users={len(self.users)}] = ' \ - f'{self.op}[target={self._pretty_print_target(self.target)}](' \ - f'args = {_format_arg(self.args)}, kwargs = {_format_arg(self.kwargs)})' - - @compatibility(is_backward_compatible=True) - def replace_all_uses_with(self, - replace_with : 'Node', - delete_user_cb: Callable[['Node'], bool] = lambda user: True - ) -> List['Node']: - """ - Replace all uses of ``self`` in the Graph with the Node ``replace_with``. - - Args: - - replace_with (Node): The node to replace all uses of ``self`` with. - delete_user_cb (Callable): Callback that is called to determine - whether a given user of the self node should be removed. - - Returns: - - The list of Nodes on which this change was made. - """ - to_process = list(self.users) - skipped = [] - for use_node in to_process: - if not delete_user_cb(use_node): - skipped.append(use_node) - continue - - def maybe_replace_node(n : Node) -> Node: - if n == self: - return replace_with - else: - return n - - new_args = map_arg(use_node.args, maybe_replace_node) - new_kwargs = map_arg(use_node.kwargs, maybe_replace_node) - assert isinstance(new_args, tuple) - assert isinstance(new_kwargs, dict) - use_node.__update_args_kwargs(new_args, new_kwargs) - - assert len(self.users) - len(skipped) == 0 - return [n for n in to_process if n not in skipped] - - @compatibility(is_backward_compatible=False) - def is_impure(self): - """ - Returns whether this op is impure, i.e. if its op is a placeholder or - output, or if a call_function or call_module which is impure. - - Returns: - - bool: If the op is impure or not. - """ - if self.op in {"placeholder", "output"}: - return True - - # Check if an impure function. - if self.op == "call_function": - return self.target in _side_effectful_functions - - # Check if an impure module. - if self.op == "call_module": - assert ( - self.graph.owning_module is not None - ), "self.graph.owning_module not set for purity check" - target_mod = self.graph.owning_module.get_submodule(self.target) - assert ( - target_mod is not None - ), f"Did not find expected submodule target {self.target}" - return getattr(target_mod, "_is_impure", False) - - return False - - @compatibility(is_backward_compatible=False) - def normalized_arguments( - self, root : torch.nn.Module, arg_types : Optional[Tuple[Any]] = None, - kwarg_types : Optional[Dict[str, Any]] = None, - normalize_to_only_use_kwargs : bool = False) -> Optional[ArgsKwargsPair]: - """ - Returns normalized arguments to Python targets. This means that - `args/kwargs` will be matched up to the module/functional's - signature and return exclusively kwargs in positional order - if `normalize_to_only_use_kwargs` is true. - Also populates default values. Does not support positional-only - parameters or varargs parameters. - - Supports module calls. - - May require `arg_types` and `kwarg_types` in order to disambiguate overloads. - - Args: - root (torch.nn.Module): Module upon which to resolve module targets. - arg_types (Optional[Tuple[Any]]): Tuple of arg types for the args - kwarg_types (Optional[Dict[str, Any]]): Dict of arg types for the kwargs - normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs. - - Returns: - - Returns NamedTuple ArgsKwargsPair, or `None` if not successful. - """ - if self.op == 'call_function': - assert callable(self.target) - return normalize_function(self.target, self.args, self.kwargs, arg_types, kwarg_types) # type: ignore[arg-type] - elif self.op == 'call_module': - assert isinstance(self.target, str) - return normalize_module(root, self.target, self.args, self.kwargs) # type: ignore[arg-type] - - return None - - @compatibility(is_backward_compatible=True) - def replace_input_with(self, old_input: 'Node', new_input: 'Node'): - """ - Loop through input nodes of ``self``, and replace all instances of - ``old_input`` with ``new_input``. - - Args: - - old_input (Node): The old input node to be replaced. - new_input (Node): The new input node to replace ``old_input``. - """ - def maybe_replace_node(n : Node) -> Node: - return new_input if n == old_input else n - - new_args = map_arg(self.args, maybe_replace_node) - new_kwargs = map_arg(self.kwargs, maybe_replace_node) - assert isinstance(new_args, tuple) - assert isinstance(new_kwargs, dict) - self.__update_args_kwargs(new_args, new_kwargs) - - -@compatibility(is_backward_compatible=True) -def map_arg(a: Argument, fn: Callable[[Node], Argument]) -> Argument: - """ - Apply fn to each Node appearing arg. arg may be a list, tuple, slice, or dict with string keys. - """ - assert callable(fn), "pippy.fx.map_arg(a, fn): fn must be a callable" - return map_aggregate(a, lambda x: fn(x) if isinstance(x, Node) else x) - - -@compatibility(is_backward_compatible=True) -def map_aggregate(a: Argument, fn: Callable[[Argument], Argument], - should_traverse_fn: Optional[Callable[[Argument], bool]] = None) -> Argument: - """ - Apply fn to each Node appearing arg. arg may be a list, tuple, slice, or dict with string keys. - Traverses list, tuple, slice, or dict if ``should_traverse_fn`` is either None or returns True for supplied argument - """ - if should_traverse_fn and not should_traverse_fn(a): - return fn(a) - - if isinstance(a, tuple): - t = tuple(map_aggregate(elem, fn, should_traverse_fn) for elem in a) - # Support NamedTuple (if it has `_fields`) by repacking into original type. - return t if not hasattr(a, '_fields') else type(a)(*t) - elif isinstance(a, list): - return immutable_list(map_aggregate(elem, fn, should_traverse_fn) for elem in a) - elif isinstance(a, dict): - return immutable_dict((k, map_aggregate(v, fn, should_traverse_fn)) for k, v in a.items()) - elif isinstance(a, slice): - return slice(map_aggregate(a.start, fn, should_traverse_fn), map_aggregate(a.stop, fn, should_traverse_fn), - map_aggregate(a.step, fn, should_traverse_fn)) - else: - return fn(a) diff --git a/pippy/fx/operator_schemas.py b/pippy/fx/operator_schemas.py deleted file mode 100644 index eccabf917..000000000 --- a/pippy/fx/operator_schemas.py +++ /dev/null @@ -1,409 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import torch -import inspect -import numbers -import types -import typing -import enum -import warnings -from typing import Any, Callable, Dict, List, Optional, Tuple, NamedTuple, cast, TYPE_CHECKING -from torch._jit_internal import boolean_dispatched -from ._compatibility import compatibility -from torch._ops import OpOverloadPacket, OpOverload - -if TYPE_CHECKING: - from .node import Argument - -@compatibility(is_backward_compatible=False) -class ArgsKwargsPair(NamedTuple): - """ - Simple named tuple for wrapping args/kwargs pairs. - """ - args: Tuple[Any, ...] - kwargs: Dict[str, Any] - -_manual_overrides : Dict[Callable, List[inspect.Signature]] = {} - -def _nonzero_schemas(): - signatures = [] - - def nonzero(self): - pass - signatures.append(inspect.signature(nonzero)) - - def nonzero(self, *, as_tuple : bool): # type: ignore[no-redef] - pass - signatures.append(inspect.signature(nonzero)) - - return signatures - -_manual_overrides[torch.nonzero] = _nonzero_schemas() - -class _FakeGlobalNamespace: - def __getattr__(self, name): - if name == 'torch': - return torch - raise RuntimeError('Expected a torch namespace lookup') - -_type_eval_globals = {'Tensor' : torch.Tensor, 'Device' : torch.device, 'Layout' : torch.layout, - 'number' : numbers.Number, 'Future' : torch.jit.Future, - 'AnyEnumType' : enum.Enum, 'QScheme' : torch.qscheme, - '__torch__': _FakeGlobalNamespace(), 'NoneType': type(None), - 't': typing.TypeVar('t')} -for k in dir(typing): - _type_eval_globals[k] = getattr(typing, k) - -def _torchscript_type_to_python_type(ts_type : 'torch._C.JitType') -> Any: - """ - Convert a TorchScript type to a Python type (including subtypes) via - eval'ing the annotation_str. _type_eval_globals sets up expressions - like "List" and "Future" to map to actual types (typing.List and jit.Future) - """ - return eval(ts_type.annotation_str, _type_eval_globals) - -def _torchscript_schema_to_signature(ts_schema : torch._C.FunctionSchema) -> inspect.Signature: - parameters : List[inspect.Parameter] = [] - for arg in ts_schema.arguments: - arg_type = _torchscript_type_to_python_type(arg.type) - default = arg.default_value if arg.has_default_value() else inspect.Parameter.empty - # TODO: Figure out if this is safe. It seems like when generating the type signatures for - # PythonArgParser, we emit signatures with `input` instead of `self` as the first tensor - # argument name. Downstream, if someone converts that positional argument to a keyword - # argument, the name mismatch will break things, so here we're going to normalize the - # name to "input" - name = arg.name if arg.name != 'self' else 'input' - kind = inspect.Parameter.KEYWORD_ONLY if arg.kwarg_only else inspect.Parameter.POSITIONAL_OR_KEYWORD - parameters.append(inspect.Parameter(name=name, kind=kind, default=default, annotation=arg_type)) - return_types = [_torchscript_type_to_python_type(ret.type) for ret in ts_schema.returns] - if len(return_types) == 0: - return_type = None - elif len(return_types) == 1: - return_type = return_types[0] - else: - return_type = tuple(return_types) - - return inspect.Signature(parameters, return_annotation=return_type) - -@compatibility(is_backward_compatible=False) -def check_for_mutable_operation(target : Callable, args : Tuple['Argument', ...], kwargs : Dict[str, 'Argument']): - signatures, schemas = get_signature_for_torch_op(target, return_schemas=True) - - if signatures and schemas: - matched_schemas = [] - - # Iterate through all of the schema until we find one that matches - # If one matches, populate `new_args_and_kwargs` with the new args/kwargs - # values. If none matches, `new_args_and_kwargs` will be None - for candidate_signature, schema in zip(signatures, schemas): - try: - candidate_signature.bind(*args, **kwargs) - matched_schemas.append((candidate_signature, schema)) - except TypeError as e: - continue - - def throw_if_mutable(schema): - if schema.is_mutable: - raise RuntimeError(f'Tried to trace mutable operation {schema}. FX only supports functional ' - f'code, so operations that mutate operands in-place (e.g. via `out` arguments) ' - f'are not supported') - - if len(matched_schemas) == 0: - # Did not match any schema. Cannot check for mutation - pass - elif len(matched_schemas) == 1: - # Matched exactly one schema, unambiguous - _, schema_to_check = matched_schemas[0] - throw_if_mutable(schema_to_check) - pass - else: - # Ambiguous schema match. Since mutability checking is best effort, - # do nothing. - pass - -@compatibility(is_backward_compatible=False) -def get_signature_for_torch_op(op : Callable, return_schemas : bool = False): - """ - Given an operator on the `torch` namespace, return a list of `inspect.Signature` - objects corresponding to the overloads of that op.. May return `None` if a signature - could not be retrieved. - - Args: - op (Callable): An operator on the `torch` namespace to look up a signature for - - Returns: - Optional[List[inspect.Signature]]: A list of signatures for the overloads of this - operator, or None if the operator signatures could not be retrieved. If - return_schemas=True, returns a tuple containing the optional Python signatures - and the optional TorchScript Function signature - """ - if isinstance(op, OpOverload): - schemas = [op._schema] - elif isinstance(op, OpOverloadPacket): - schemas = [getattr(op, overload)._schema for overload in op.overloads()] - else: - override = _manual_overrides.get(op) - if override: - return (override, None) if return_schemas else None - - aten_fn = torch.jit._builtins._find_builtin(op) - - if aten_fn is None: - return (None, None) if return_schemas else None - schemas = torch._C._jit_get_schemas_for_operator(aten_fn) - - signatures = [_torchscript_schema_to_signature(schema) for schema in schemas] - return (signatures, schemas) if return_schemas else signatures - -@compatibility(is_backward_compatible=False) -def create_type_hint(x): - try: - if isinstance(x, list) or isinstance(x, tuple): - # todo(chilli): Figure out the right way for mypy to handle this - if isinstance(x, list): - def ret_type(x): - return List[x] # type: ignore[valid-type] - else: - def ret_type(x): - return Tuple[x, ...] - if len(x) == 0: - return ret_type(Any) - base_type = x[0] - for t in x: - if issubclass(t, base_type): - continue - elif issubclass(base_type, t): - base_type = t - else: - return ret_type(Any) - return ret_type(base_type) - except Exception as e: - # We tried to create a type hint for list but failed. - warnings.warn(f"We were not able to successfully create type hint from the type {x}") - pass - return x - -@compatibility(is_backward_compatible=False) -def type_matches(signature_type : Any, argument_type : Any): - sig_origin_type = getattr(signature_type, '__origin__', signature_type) - - if signature_type is argument_type: - return True - - # Union types in signature. Given type needs to match one of the - # contained types in the Union - if sig_origin_type is typing.Union and signature_type != argument_type: - sig_contained = signature_type.__args__ - return any(type_matches(c, argument_type) for c in sig_contained) - - if signature_type is List[int] and argument_type is int: - # int can be promoted to List[int] - return True - - if getattr(signature_type, '__origin__', None) in {list, List}: - sig_el_type = signature_type.__args__[0] - if not inspect.isclass(sig_el_type): - warnings.warn( - f"Does not support nested parametric types, got {signature_type}. Please file a bug.") - return False - if getattr(argument_type, '__origin__', None) in {list, List}: - return issubclass(argument_type.__args__[0], sig_el_type) - - def is_homogeneous_tuple(t): - if not getattr(t, '__origin__', None) in {tuple, Tuple}: - return False - contained = t.__args__ - if t.__args__ == ((),): # Tuple[()].__args__ == ((),) for some reason - return True - return all((c is Ellipsis) or issubclass(c, sig_el_type) for c in contained) - - # Tuple[T] is accepted for List[T] parameters - return is_homogeneous_tuple(argument_type) - - # Dtype is an int in schemas - if signature_type is int and argument_type is torch.dtype: - return True - - if signature_type is numbers.Number and argument_type in {int, float}: - return True - if inspect.isclass(argument_type) and inspect.isclass(signature_type): - return issubclass(argument_type, signature_type) - - return False - -@compatibility(is_backward_compatible=False) -def normalize_function( - target: Callable, args: Tuple[Any], kwargs : Optional[Dict[str, Any]] = None, arg_types : Optional[Tuple[Any]] = None, - kwarg_types : Optional[Dict[str, Any]] = None, - normalize_to_only_use_kwargs : bool = False) -> Optional[ArgsKwargsPair]: - """ - Returns normalized arguments to PyTorch functions. This means that - `args/kwargs` will be matched up to the functional's - signature and return exclusively kwargs in positional order if - `normalize_to_only_use_kwargs` is True. - Also populates default values. Does not support positional-only - parameters or varargs parameters (*args, **kwargs). Does not support modules. - - May require `arg_types` and `kwarg_types` in order to disambiguate overloads. - - Args: - target (Callable): Function that we are normalizing - args (Tuple[Any]): Tuple of args to the function - kwargs (Optional[Dict[str, Any]]): Dict of kwargs to the function - arg_types (Optional[Tuple[Any]]): Tuple of arg types for the args - kwarg_types (Optional[Dict[str, Any]]): Dict of arg types for the kwargs - normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs. - - Returns: - - Returns normalized_args_and_kwargs, or `None` if not successful. - """ - if kwargs is None: - kwargs = {} - new_args_and_kwargs = None - if not isinstance(target, types.BuiltinFunctionType) and not ( - isinstance(target, OpOverloadPacket) or isinstance(target, OpOverload) - ): - target_for_analysis = target - if target in boolean_dispatched: - # HACK: `boolean_dispatch` as used in `torch.nn.functional` makes it so that we have - # a 2-way dispatch based on a boolean value. Here we check that the `true` and `false` - # branches of the dispatch have exactly the same signature. If they do, use the `true` - # branch signature for analysis. Otherwise, leave this un-normalized - assert not isinstance(target, str) - dispatched = boolean_dispatched[target] - if_true, if_false = dispatched['if_true'], dispatched['if_false'] - if inspect.signature(if_true).parameters != inspect.signature(if_false).parameters: - return None - target_for_analysis = if_true - - assert callable(target_for_analysis) - sig = inspect.signature(inspect.unwrap(target_for_analysis)) - new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(sig, args, kwargs, normalize_to_only_use_kwargs) - else: - assert callable(target) - torch_op_schemas = get_signature_for_torch_op(target) - matched_schemas = [] - if torch_op_schemas: - # Iterate through all of the schema until we find one that matches - # If one matches, populate `new_args_and_kwargs` with the new args/kwargs - # values. If none matches, `new_args_and_kwargs` will be None - for candidate_signature in torch_op_schemas: - try: - candidate_signature.bind(*args, **kwargs) - matched_schemas.append(candidate_signature) - except TypeError as e: - continue - - if len(matched_schemas) == 0: - # Did not match any schema. Cannot normalize - pass - elif len(matched_schemas) == 1: - # Matched exactly one schema, unambiguous - new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(matched_schemas[0], args, kwargs, - normalize_to_only_use_kwargs) - else: - if arg_types is not None or kwarg_types is not None: - arg_types = arg_types if arg_types else cast(Tuple[Any], ()) - kwarg_types = kwarg_types if kwarg_types else {} - for candidate_signature in torch_op_schemas: - sig_matches = True - try: - bound_types = candidate_signature.bind(*arg_types, **kwarg_types) - for arg_name, arg_type in bound_types.arguments.items(): - param = candidate_signature.parameters[arg_name] - sig_matches = sig_matches and type_matches(param.annotation, arg_type) - except TypeError as e: - sig_matches = False - if sig_matches: - new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(candidate_signature, args, kwargs, - normalize_to_only_use_kwargs) - break - else: - # Matched more than one schema. In this situation, the caller must provide the types of - # the arguments of the overload they expect. - schema_printouts = '\n'.join(str(schema) for schema in matched_schemas) - raise RuntimeError(f'Tried to normalize arguments to {torch.typename(target)} but ' - f'the schema match was ambiguous! Please provide argument types to ' - f'the normalize_arguments() call. Available schemas:\n{schema_printouts}') - - return new_args_and_kwargs - -@compatibility(is_backward_compatible=False) -def normalize_module( - root: torch.nn.Module, target: str, args: Tuple[Any], kwargs : Optional[Dict[str, Any]] = None, - normalize_to_only_use_kwargs : bool = False) -> Optional[ArgsKwargsPair]: - """ - Returns normalized arguments to PyTorch modules. This means that - `args/kwargs` will be matched up to the functional's - signature and return exclusively kwargs in positional order if - `normalize_to_only_use_kwargs` is True. - Also populates default values. Does not support positional-only - parameters or varargs parameters (*args, **kwargs). - - Args: - root (nn.Module): root module upon which we query modules - target (Callable): Function that we are normalizing - args (Tuple[Any]): Tuple of args to the function - kwargs (Optional[Dict[str, Any]]): Dict of kwargs to the function - normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs. - - Returns: - - Returns normalized_args_and_kwargs, or `None` if not successful. - """ - try: - submod = root.get_submodule(target) - except AttributeError: - raise RuntimeError(f"Tried to normalize node with target {target} but root did not " - f"have that target!") - if hasattr(submod.__class__, '__name__'): - classname = submod.__class__.__name__ - if getattr(torch.nn, classname, None) == submod.__class__: - sig = inspect.signature(inspect.unwrap(submod.forward)) - if kwargs is None: - kwargs = {} - new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(sig, args, kwargs, - normalize_to_only_use_kwargs) - return new_args_and_kwargs - return None - -def _args_kwargs_to_normalized_args_kwargs(sig : inspect.Signature, args : Tuple[Any, ...], - kwargs : Dict[str, Any], - normalize_to_only_use_kwargs : bool) -> Optional[ArgsKwargsPair]: - """ - Given a call target, args, and kwargs, return the arguments normalized into - an ArgsKwargsPair, or None if the type signature is not supported by - this normalization. - - Args: - - target (inspect.Signature): Signature object for the target - args (Tuple): Arguments that appear at the callsite for `target` - kwargs (Dict): Keyword arguments that appear at the callsite for `target` - normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs. - - Returns: - - Optional[ArgsKwargsPair]: Normalized args and kwargs for `target`, or `None` if - this target is not supported. - """ - - # Don't currently support positional-only - # or varargs (*args, **kwargs) signatures - supported_parameter_types = { - inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY} - if any(p.kind not in supported_parameter_types for p in sig.parameters.values()): - return None - - bound_args = sig.bind(*args, **kwargs) - bound_args.apply_defaults() - - new_kwargs : Dict[str, Any] = {} - new_args : List[Any] = [] - for i, param in enumerate(sig.parameters): - if not normalize_to_only_use_kwargs and i < len(args): - new_args.append(bound_args.arguments[param]) - else: - new_kwargs[param] = bound_args.arguments[param] - - return ArgsKwargsPair(tuple(new_args), new_kwargs) diff --git a/pippy/fx/passes/README.md b/pippy/fx/passes/README.md deleted file mode 100644 index a29968487..000000000 --- a/pippy/fx/passes/README.md +++ /dev/null @@ -1,20 +0,0 @@ -## FX Pass Infrastructure -This folder contains the pass infarstructure and passes for transforming fx.Graph. - - -## Code Structure - -* [infra](infra) - Common infrastructure, such as PassManager, PassBase - * [partitioner.py](infra/partitioner.py) - backend agnostic FX graph partitioner -* [utils](utils) - Utility classes and functions - * [common.py](utils/common.py) - common utility functions - * [fuser_utis.py](utils/fuser_utils.py) - utility functions for fusing list of nodes into a single node -* [dialect](dialect) - dialect specific passes - * [common](dialect/common) - common passes that can be shared by all dialects - * [cse_pass.py](dialect/common/cse_pass.py) - a CSE pass - * [aten](dialect/aten) - aten dialect specific passes - * [prims](dialect/prims) - prim dialect specific passes -* [backends](backends) - Backend specific passes - * [nvfuser](backends/nvfuser) - passes for nvfuser - * [operator_support.py](backends/nvfuser/operator_support.py) - nvFuser supported ops -* [conversion](conversion) - Conversion passes between dialects diff --git a/pippy/fx/passes/__init__.py b/pippy/fx/passes/__init__.py deleted file mode 100644 index d20580680..000000000 --- a/pippy/fx/passes/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from . import graph_drawer -from . import graph_manipulation -from . import net_min_base -from . import operator_support -from . import param_fetch -from . import reinplace -from . import shape_prop -from . import split_module -from . import split_utils -from . import splitter_base -from . import tools_common diff --git a/pippy/fx/passes/backends/__init__.py b/pippy/fx/passes/backends/__init__.py deleted file mode 100644 index f2661b8c6..000000000 --- a/pippy/fx/passes/backends/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates diff --git a/pippy/fx/passes/backends/cudagraphs.py b/pippy/fx/passes/backends/cudagraphs.py deleted file mode 100644 index 3898a2e74..000000000 --- a/pippy/fx/passes/backends/cudagraphs.py +++ /dev/null @@ -1,56 +0,0 @@ -import torch -from pippy.fx.passes.infra.partitioner import CapabilityBasedPartitioner -from pippy.fx.passes.operator_support import OperatorSupport -from pippy.fx.passes.tools_common import CALLABLE_NODE_OPS -from pippy.fx.passes.fake_tensor_prop import FakeTensorProp -from torch.utils._pytree import tree_map - -import operator - -class CudaGraphsSupport(OperatorSupport): - # TODO: why is submodules passed here - def is_node_supported(self, submodules, node: pippy.fx.Node) -> bool: - if node.op not in CALLABLE_NODE_OPS: - return False - - if node.target in [torch.ops.aten.embedding_dense_backward.default]: - return False - - if node.target in [operator.getitem]: - return True - - found_not_cuda = False - - def meta_fk(meta): - return meta["val"] if "val" in meta else meta["fake_result"] - - def find_not_cuda(t): - nonlocal found_not_cuda - if isinstance(t, torch.Tensor) and t.device.type != 'cuda': - found_not_cuda = True - - for n in node.all_input_nodes: - tree_map(find_not_cuda, meta_fk(n.meta)) - - tree_map(find_not_cuda, meta_fk(node.meta)) - - # NB: factory function is accounted for because the result would be - # cpu or cuda - - return not found_not_cuda - -def partition_cudagraphs(gm, inputs): - """ - Partition an FX graph into sub-GraphModules that can be validly run under - CUDA graphs. For a subgraph to be runnable under CUDA, all of the operations - must involve CUDA tensors only/ - """ - - FakeTensorProp(gm).propagate(*inputs) - supported_ops = CudaGraphsSupport() - # TODO: single node partition may be wrong due to the pessimization - # from copying in and out the data. Check in benchmarks, perhaps - partitioner = CapabilityBasedPartitioner(gm, supported_ops, allows_single_node_partition=True) - partitions = partitioner.propose_partitions() - fused_graph = partitioner.fuse_partitions(partitions) - return fused_graph diff --git a/pippy/fx/passes/backends/nvfuser.py b/pippy/fx/passes/backends/nvfuser.py deleted file mode 100644 index 689ab8432..000000000 --- a/pippy/fx/passes/backends/nvfuser.py +++ /dev/null @@ -1,287 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from typing import Dict - -import torch -from torch.nn import Module -from torch._ops import OpOverload - -from pippy.fx import GraphModule -from pippy.fx.node import Node, _get_qualified_name -from pippy.fx.passes.operator_support import OperatorSupport -from pippy.fx.passes.tools_common import CALLABLE_NODE_OPS -from pippy.fx.passes.infra.partitioner import CapabilityBasedPartitioner -from torch._prims.executor import execute -from pippy.fx.experimental.proxy_tensor import DecompositionInterpreter -from torch._decomp import decomposition_table - -import typing as t - -import logging - -logging.basicConfig(level=logging.WARNING) -logger = logging.getLogger(__name__) - -def aten_to_dtype(self, dtype: torch.dtype, **kwargs): - if len(kwargs) > 0 or not dtype: - raise RuntimeError("No support for other to.dtype() formats other than to.dtype(self, dtype)") - return torch._prims.convert_element_type(self, dtype) - -# decomposition_table currently contains both aten2aten and aten2prim decomposition -# this is a hack to separate them, as we only need aten2prim decomposition for nvfuser-supported aten graph lowering -aten2aten_decomp = {} -aten2prim_decomp = {} - -for op, decomp_fn in decomposition_table.items(): - if "torch._refs" in decomp_fn.__module__: - aten2prim_decomp[op] = decomp_fn - else: - aten2aten_decomp[op] = decomp_fn - -aten2aten_decomp_skips = { - "aten.native_layer_norm_backward.default", - "aten.embedding_dense_backward.default", # This is hurting nvfuser's perf - "aten.addmm.default" -} - -for op, decomp_fn in decomposition_table.items(): - if "torch._refs" in decomp_fn.__module__: - aten2prim_decomp[op] = decomp_fn - else: - if str(op) not in aten2aten_decomp_skips: - aten2aten_decomp[op] = decomp_fn - - -aten2prim_decomp[torch.ops.aten.to.dtype] = aten_to_dtype - - -class NvFuserOperatorSupport(OperatorSupport): - """ - Operator support for nvFuser backend. - - Currently, partitioning is based on FX ATen graph. The fused subgraph will latter be decomposed into prims. - To determine if an ATen ops is supported by nvFuser, we shall check the prim ops used in its ref decomposition. - Only if all the prim ops in the ref has a nvfuser_impl, we say this Aten op is suppported by nvFuser. - - Note: When adding a rule, please add it to the corresponding section and follow the - alphabetical order. - """ - - def __init__(self): - - # TODO: current list copied from torch/csrc/jit/codegen/cuda/parser.cpp is incorrect, - # as that file is solely for TorchScript and doesn't represent the actual status - # whether operation would be runnable by primTorch+nvFuser. - # We will iterate on this list to reflect the the reality. - support_dict = { - # =============================================================== - # call_function aten - # =============================================================== - # Following supported aten ops is copied from torch/csrc/jit/codegen/cuda/parser.cpp - # TODO: might need to update according to supported input types - "torch.ops.aten.add": None, - "torch.ops.aten.sub": None, - # "torch.ops.aten.rsub": None, # rsub decomp is supported at aten2aten level - "torch.ops.aten.div": None, - "torch.ops.aten.atan2": None, - "torch.ops.aten.mul": None, - "torch.ops.aten.max": None, - "torch.ops.aten.min": None, - "torch.ops.aten.pow": None, - "torch.ops.aten.remainder": None, - "torch.ops.aten.fmod": None, - "torch.ops.aten.bitwise_and": None, - "torch.ops.aten.__and__": None, - "torch.ops.aten.bitwise_or": None, - "torch.ops.aten.__or__": None, - "torch.ops.aten.bitwise_xor": None, - "torch.ops.aten.__xor__": None, - "torch.ops.aten.bitwise_left_shift": None, - "torch.ops.aten.__lshift__": None, - "torch.ops.aten.bitwise_right_shift": None, - "torch.ops.aten.__rshift__": None, - "torch.ops.aten.eq": None, - "torch.ops.aten.ne": None, - "torch.ops.aten.ge": None, - "torch.ops.aten.gt": None, - "torch.ops.aten.le": None, - "torch.ops.aten.lt": None, - "torch.ops.aten.abs": None, - "torch.ops.aten.bitwise_not": None, - "torch.ops.aten.ceil": None, - "torch.ops.aten.floor": None, - "torch.ops.aten.frac": None, - "torch.ops.aten.neg": None, - "torch.ops.aten.relu": None, - "torch.ops.aten.round": None, - "torch.ops.aten.silu": None, - "torch.ops.aten.trunc": None, - "torch.ops.aten.log": None, - "torch.ops.aten.log10": None, - "torch.ops.aten.log1p": None, - "torch.ops.aten.log2": None, - "torch.ops.aten.lgamma": None, - "torch.ops.aten.exp": None, - "torch.ops.aten.expm1": None, - "torch.ops.aten.erf": None, - "torch.ops.aten.erfc": None, - "torch.ops.aten.cos": None, - "torch.ops.aten.acos": None, - "torch.ops.aten.cosh": None, - "torch.ops.aten.sin": None, - "torch.ops.aten.asin": None, - "torch.ops.aten.sinh": None, - "torch.ops.aten.tan": None, - "torch.ops.aten.atan": None, - "torch.ops.aten.tanh": None, - "torch.ops.aten.atanh": None, - "torch.ops.aten.sqrt": None, - "torch.ops.aten.rsqrt": None, - "torch.ops.aten.reciprocal": None, - "torch.ops.aten.sigmoid": None, - "torch.ops.aten.isfinite": None, - "torch.ops.aten.isinf": None, - "torch.ops.aten.isnan": None, - "torch.ops.aten.isneginf": None, - "torch.ops.aten.isposinf": None, - "torch.ops.aten.isreal": None, - # "torch.ops.aten.rand_like": None, # causing Node empty_like_default does not support nvfuser - "torch.ops.aten.softplus": None, - "torch.ops.aten.threshold": None, - # relying on aten->aten->prim decomp, aten2aten is using unsupported aten.new_zero op - # "torch.ops.aten.threshold_backward": None, - "torch.ops.aten.clamp": None, - # "torch.ops.aten.clone": None, - # Failing with where(): incompatible function arguments: \ - # [aten->prim decomp, aten2aten is using unsupported aten.div - # "torch.ops.aten.native_layer_norm_backward": None, - "torch.ops.aten.softmax.int": None, - "torch.ops.aten.log_softmax.int": None, - # relying on aten->aten->prim decomp, aten2aten is using unsupported aten.amax - # "torch.ops.aten._softmax": None, - "torch.ops.aten._log_softmax_backward_data": None, - # "torch.ops.aten._softmax_backward_data": None, # Node _softmax_backward_data_default does not support nvfuser - # "torch.ops.aten.var.dim": None, # missing refs - "torch.ops.aten.std.dim": None, - "torch.ops.aten.sum": None, - # "torch.ops.aten.mean.dim": None, # missing refs - "torch.ops.aten._grad_sum_to_size": None, - "torch.ops.aten.sum_to_size": None, - "torch.ops.aten._autocast_to_reduced_precision": None, - "torch.ops.aten._autocast_to_full_precision": None, - # "torch.ops.aten.to.dtype": None, # causing segfault - # "torch.ops.aten.type_as": None, # missing refs - "torch.ops.aten.linear": None, - "torch.ops.aten.gelu": None, - # "torch.ops.aten.gelu_backward": None, # gelu_backward is handled at aten2aten decomp - # "torch.ops.aten.hardtanh": None, # has functional ref, using unsupported aten.clamp - "torch.ops.aten.leaky_relu": None, - "torch.ops.aten.square": None, - # relying on aten->aten->prim decomp, aten2aten is using unsupported aten.conj_physical - "torch.ops.aten.tanh_backward": None, - # "torch.ops.aten.amax": None, # missing prim decomp - # "torch.ops.aten.amin": None, # missing prim decomp - # "torch.ops.aten.reshape": None, - # "torch.ops.aten.view": None, # missing prim decomp - "torch.ops.aten.flatten.using_ints": None, - - # =============================================================== - # call_function builtins and operator - # =============================================================== - "getattr": None, - "_operator.getitem": None, - } - - super().__init__(support_dict) - - def is_node_supported( - self, submodules: t.Mapping[str, Module], node: Node - ) -> bool: - - # nvFuser FX subgraph should be purely functional - if node.op not in CALLABLE_NODE_OPS: - return False - - # ops in supported_dict doesn't have overload name - # use overloadpacket's qualified_name for OpOverload - if isinstance(node.target, OpOverload): - target = _get_qualified_name(node.target.overloadpacket) - if target in self._support_dict: - return True - - return super().is_node_supported(submodules, node) - - -class NvFuserBackend: - def __init__(self): - self.supported_ops = NvFuserOperatorSupport() - - # TODO: this is a naive implementation of cache without proper guard - self.partitioner_cache: Dict[GraphModule, GraphModule] = {} - - # TODO: this is a naive implementation of cache without proper guard, this will only work for identical inputs - self.prim_decomp_cache: Dict[GraphModule, GraphModule] = {} - - def lower_to_prims_and_execute(self, graph_module: GraphModule, *args, **kwargs): - # `graph_module` is an Aten-Fx graph - # "lowering to prims" and "trace execution" are grouped into this function, as they are both input dependent - - if graph_module in self.prim_decomp_cache: - logging.debug("prim_decomp_cache hit!") - prim_module = self.prim_decomp_cache[graph_module] - else: - prim_graph = pippy.fx.Graph() - DecompositionInterpreter(graph_module, prim_graph, decomposition_table=aten2prim_decomp).run(*args, **kwargs) - prim_module = pippy.fx.GraphModule(graph_module, prim_graph) - self.prim_decomp_cache[graph_module] = prim_module - - logging.debug("Lower to prims graph: ", prim_module.code) - - # invokes trace executor for running the prim graph - return execute(prim_module, *args, executor="nvfuser") - - def compile(self, graph_module: GraphModule) -> GraphModule: - # entry function for nvFuser backend - logging.debug("Compiling graph_module: ", graph_module.code) - - # FX graph based partitioning based on nvfuser supported ops - if graph_module in self.partitioner_cache: - logging.debug("partitioner_cache hit!") - fused_graph_module = self.partitioner_cache[graph_module] - else: - partitioner = CapabilityBasedPartitioner( - graph_module, self.supported_ops, allows_single_node_partition=False) - fused_graph_module = partitioner.partition_and_fuse() - - self.partitioner_cache[graph_module] = fused_graph_module - - # Overriding fused_module's __call__() function with lower_to_prims_and_execute() - for node in fused_graph_module.graph.nodes: - # TODO: use a better way to identify fused submodule - if node.op == "call_module" and "fused_" in node.name: - fused_module = getattr(fused_graph_module, node.name) - fused_module._wrapped_call = self.lower_to_prims_and_execute - - return fused_graph_module - - def __call__(self, graph_module: GraphModule, _) -> GraphModule: - # wrap self.compile as __call__ function to fit the interface for AOTAutograd's fw_compiler - return self.compile(graph_module) diff --git a/pippy/fx/passes/dialect/__init__.py b/pippy/fx/passes/dialect/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/pippy/fx/passes/dialect/common/__init__.py b/pippy/fx/passes/dialect/common/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/pippy/fx/passes/dialect/common/cse_pass.py b/pippy/fx/passes/dialect/common/cse_pass.py deleted file mode 100644 index 365781794..000000000 --- a/pippy/fx/passes/dialect/common/cse_pass.py +++ /dev/null @@ -1,114 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from typing import Dict, Tuple, Any - -import torch -from torch.utils._pytree import tree_flatten - -import pippy -from pippy.fx import GraphModule, Graph -from pippy.fx import Node -from pippy.fx.passes.infra.pass_base import PassBase, PassResult - -aten = torch.ops.aten - - -# stateful ops are banned from CSE -rand_ops = set([aten.dropout, aten._fused_dropout, aten._standard_gamma, aten.bernoulli, aten.multinomial, aten.native_dropout, aten.normal, aten.poisson, aten.binomial, aten.rrelu, aten.rand_like, aten.rand, aten.randint, aten.randn, aten.randperm]) # noqa: E501 - -inplace_ops = set([aten.add_, aten.sub_, aten.mul_, aten.div_, aten.pow_, aten.lerp_, aten.relu_, aten.sigmoid_, aten.tanh_]) # noqa: E501 - - -@pippy.fx._compatibility.compatibility(is_backward_compatible=False) -def get_CSE_banned_ops(): - return rand_ops.union(inplace_ops) - - -@pippy.fx._compatibility.compatibility(is_backward_compatible=False) -class CSEPass(PassBase): - - def __init__(self, banned_ops=None): - """ - This version of CSE Pass aims to be dialect agnostic, and it's implemented purely based on the connectivity between fx.Node. - - For functional dialects, user would only need to specify the random ops in ban list. - - Warning: CSE Pass cannot be safely applied on a FX graph in non-functional dialects. - If your dialect contains stateful operators, please customized the banned_ops. - - """ - if banned_ops is None: - banned_ops = set() - self.banned_ops = banned_ops - super().__init__() - - def call(self, graph_module: GraphModule) -> PassResult: - """ - Return a new copy of pippy.fx.GraphModule with CSE applied to the input graph - - Example usage: - - from pippy.fx.experimental.proxy_tensor import make_fx - def f(a): - b = a * a - c = a * a - return b+c - - p = CSEPass() - traced_graph = make_fx(f)(torch.tensor(1)) - print(traced_graph) - result = p(traced_graph) - print(result.graph_module) - """ - def get_aten_target(node): - if hasattr(node.target, 'overloadpacket'): - return node.target.overloadpacket - return node.target - - modified = False - new_graph = Graph() - env: Dict[Node, Node] = {} # map from node in the old graph to node in the new graph - hash_env: Dict[Tuple[torch._ops.OpOverload, int], Node] = {} # map from hash to a node in the new graph - token_map: Dict[Tuple[torch._ops.OpOverload, int], Dict[str, Any]] = {} # map from hash to token - for n in graph_module.graph.nodes: - # The placeholder, output, and get_attr nodes are copied to the new grpah without change - # do not CSE away random operations - if n.op == 'placeholder' or n.op == 'output' or n.op == 'get_attr' or get_aten_target(n) in self.banned_ops: - new_node = new_graph.node_copy(n, lambda x: env[x]) - env[n] = new_node - else: # n.op == 'call_function', should never see n.op == 'call_module' or 'call_method' - # substitute args and kwargs memebrs to their mapping in env if exists - # specs can be used to reconstruct nested list/dictionaries - def substitute(arg_list): - arg_list, spec = tree_flatten(arg_list) - for i in range(len(arg_list)): - v = arg_list[i] - if isinstance(v, Node) and v in env: - arg_list[i] = env[v] - return tuple(arg_list), spec - args, args_spec = substitute(n.args) - kwargs, kwargs_spec = substitute(n.kwargs) - - # each token corresponds to a unique node - # nodes with the same token can be substituted - token = {"target": n.target, "args": args, "args_spec": args_spec, - "kwargs": kwargs, "kwargs_spec": kwargs_spec} - - # hash substituted args to a number, do not hash specs because specs are not hashable - hash_arg = hash((args, kwargs)) - hash_val = (n.target, hash_arg) - - # check if a node has a substitute and can be eliminated - hash_val_in_hash_env = hash_val in hash_env - if hash_val_in_hash_env and token_map[hash_val] == token: - modified = True # substition happens and the graph is modified - env[n] = hash_env[hash_val] - continue - - new_node = new_graph.node_copy(n, lambda x: env[x]) - env[n] = new_node - if not hash_val_in_hash_env: - hash_env[hash_val] = new_node - token_map[hash_val] = token - - csed_gm = GraphModule(graph_module, new_graph) - return PassResult(csed_gm, modified) diff --git a/pippy/fx/passes/fake_tensor_prop.py b/pippy/fx/passes/fake_tensor_prop.py deleted file mode 100644 index bf30bd3f6..000000000 --- a/pippy/fx/passes/fake_tensor_prop.py +++ /dev/null @@ -1,30 +0,0 @@ -import pippy.fx -from pippy.fx import Node -from pippy.fx._compatibility import compatibility -from torch._subclasses.fake_tensor import FakeTensorMode - -__all__ = ['FakeTensorProp'] - -@compatibility(is_backward_compatible=False) -class FakeTensorProp(pippy.fx.Interpreter): - """ - Execute an FX graph Node-by-Node and record a fake tensor representing - the metadata for the node. Unlike ShapeProp, (1) this propagation - is cheap--it does the propagation with meta tensors which do not actually - store data, and (2) the fake tensors have much more fine grained information, - e.g., they have accurate alias information that can be consulted by looking - at the storages. - - Args: - module (GraphModule): The module to be executed - """ - - def run_node(self, n: Node): - result = super().run_node(n) - n.meta['val'] = result - return result - - def propagate(self, *args): - with FakeTensorMode.push() as mode: - fake_args = [mode.from_tensor(a) for a in args] - return super().run(*fake_args) diff --git a/pippy/fx/passes/graph_drawer.py b/pippy/fx/passes/graph_drawer.py deleted file mode 100644 index cddd6d99f..000000000 --- a/pippy/fx/passes/graph_drawer.py +++ /dev/null @@ -1,328 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from __future__ import absolute_import, division, print_function, unicode_literals - -import hashlib -import torch -import pippy.fx -from typing import Dict, Any, TYPE_CHECKING -from pippy.fx.node import _get_qualified_name, _format_arg -from pippy.fx.passes.shape_prop import TensorMetadata -from pippy.fx._compatibility import compatibility -from itertools import chain - -__all__ = ['FxGraphDrawer'] -try: - import pydot - HAS_PYDOT = True -except ImportError: - HAS_PYDOT = False - -_COLOR_MAP = { - "placeholder": '"AliceBlue"', - "call_module": "LemonChiffon1", - "get_param": "Yellow2", - "get_attr": "LightGrey", - "output": "PowderBlue", -} - -_HASH_COLOR_MAP = [ - "CadetBlue1", - "Coral", - "DarkOliveGreen1", - "DarkSeaGreen1", - "GhostWhite", - "Khaki1", - "LavenderBlush1", - "LightSkyBlue", - "MistyRose1", - "MistyRose2", - "PaleTurquoise2", - "PeachPuff1", - "Salmon", - "Thistle1", - "Thistle3", - "Wheat1", -] - -_WEIGHT_TEMPLATE = { - "shape": "record", - "fillcolor": "Salmon", - "style": '"filled,rounded"', - "fontcolor": "#000000", -} - -if HAS_PYDOT: - @compatibility(is_backward_compatible=False) - class FxGraphDrawer: - """ - Visualize a pippy.fx.Graph with graphviz - Basic usage: - g = FxGraphDrawer(symbolic_traced, "resnet18") - with open("a.svg", "w") as f: - f.write(g.get_dot_graph().create_svg()) - """ - - def __init__( - self, - graph_module: pippy.fx.GraphModule, - name: str, - ignore_getattr: bool = False, - ignore_parameters_and_buffers: bool = False, - skip_node_names_in_args: bool = True, - ): - self._name = name - self._dot_graphs = { - name: self._to_dot( - graph_module, name, ignore_getattr, ignore_parameters_and_buffers, skip_node_names_in_args - ) - } - - for node in graph_module.graph.nodes: - if node.op != "call_module": - continue - - leaf_node = self._get_leaf_node(graph_module, node) - - if not isinstance(leaf_node, pippy.fx.GraphModule): - continue - - self._dot_graphs[f"{name}_{node.target}"] = self._to_dot( - leaf_node, - f"{name}_{node.target}", - ignore_getattr, - ignore_parameters_and_buffers, - skip_node_names_in_args, - ) - - def get_dot_graph(self, submod_name=None) -> pydot.Dot: - if submod_name is None: - return self.get_main_dot_graph() - else: - return self.get_submod_dot_graph(submod_name) - - def get_main_dot_graph(self) -> pydot.Dot: - return self._dot_graphs[self._name] - - def get_submod_dot_graph(self, submod_name) -> pydot.Dot: - return self._dot_graphs[f"{self._name}_{submod_name}"] - - def get_all_dot_graphs(self) -> Dict[str, pydot.Dot]: - return self._dot_graphs - - def _get_node_style(self, node: pippy.fx.Node) -> Dict[str, str]: - template = { - "shape": "record", - "fillcolor": "#CAFFE3", - "style": '"filled,rounded"', - "fontcolor": "#000000", - } - if node.op in _COLOR_MAP: - template["fillcolor"] = _COLOR_MAP[node.op] - else: - # Use a random color for each node; based on its name so it's stable. - target_name = node._pretty_print_target(node.target) - target_hash = int(hashlib.md5(target_name.encode()).hexdigest()[:8], 16) - template["fillcolor"] = _HASH_COLOR_MAP[target_hash % len(_HASH_COLOR_MAP)] - return template - - def _get_leaf_node( - self, module: torch.nn.Module, node: pippy.fx.Node - ) -> torch.nn.Module: - py_obj = module - assert isinstance(node.target, str) - atoms = node.target.split(".") - for atom in atoms: - if not hasattr(py_obj, atom): - raise RuntimeError( - str(py_obj) + " does not have attribute " + atom + "!" - ) - py_obj = getattr(py_obj, atom) - return py_obj - - def _typename(self, target: Any) -> str: - if isinstance(target, torch.nn.Module): - ret = torch.typename(target) - elif isinstance(target, str): - ret = target - else: - ret = _get_qualified_name(target) - - # Escape "{" and "}" to prevent dot files like: - # https://gist.github.com/SungMinCho/1a017aab662c75d805c5954d62c5aabc - # which triggers `Error: bad label format (...)` from dot - return ret.replace("{", r"\{").replace("}", r"\}") - - def _get_node_label( - self, - module: pippy.fx.GraphModule, - node: pippy.fx.Node, - skip_node_names_in_args: bool, - ) -> str: - def _get_str_for_args_kwargs(arg): - if isinstance(arg, tuple): - prefix, suffix = r"|args=(\l", r",\n)\l" - arg_strs_list = [_format_arg(a, max_list_len=8) for a in arg] - elif isinstance(arg, dict): - prefix, suffix = r"|kwargs={\l", r",\n}\l" - arg_strs_list = [ - f"{k}: {_format_arg(v, max_list_len=8)}" - for k, v in arg.items() - ] - else: # Fall back to nothing in unexpected case. - return "" - - # Strip out node names if requested. - if skip_node_names_in_args: - arg_strs_list = [a for a in arg_strs_list if "%" not in a] - if len(arg_strs_list) == 0: - return "" - arg_strs = prefix + r",\n".join(arg_strs_list) + suffix - return arg_strs.replace("{", r"\{").replace("}", r"\}") - - - label = "{" + f"name=%{node.name}|op_code={node.op}\n" - - if node.op == "call_module": - leaf_module = self._get_leaf_node(module, node) - label += r"\n" + self._typename(leaf_module) + r"\n|" - extra = "" - if hasattr(leaf_module, "__constants__"): - extra = r"\n".join( - [f"{c}: {getattr(leaf_module, c)}" for c in leaf_module.__constants__] # type: ignore[union-attr] - ) - label += extra + r"\n" - else: - label += f"|target={self._typename(node.target)}" + r"\n" - if len(node.args) > 0: - label += _get_str_for_args_kwargs(node.args) - if len(node.kwargs) > 0: - label += _get_str_for_args_kwargs(node.kwargs) - label += f"|num_users={len(node.users)}" + r"\n" - - tensor_meta = node.meta.get('tensor_meta') - label += self._tensor_meta_to_label(tensor_meta) - - return label + "}" - - def _tensor_meta_to_label(self, tm) -> str: - if tm is None: - return "" - elif isinstance(tm, TensorMetadata): - return self._stringify_tensor_meta(tm) - elif isinstance(tm, list): - result = "" - for item in tm: - result += self._tensor_meta_to_label(item) - return result - elif isinstance(tm, dict): - result = "" - for k, v in tm.items(): - result += self._tensor_meta_to_label(v) - return result - elif isinstance(tm, tuple): - result = "" - for item in tm: - result += self._tensor_meta_to_label(item) - return result - else: - raise RuntimeError(f"Unsupported tensor meta type {type(tm)}") - - def _stringify_tensor_meta(self, tm: TensorMetadata) -> str: - result = "" - if not hasattr(tm, "dtype"): - print("tm", tm) - result += "|" + "dtype" + "=" + str(tm.dtype) + r"\n" - result += "|" + "shape" + "=" + str(tuple(tm.shape)) + r"\n" - result += "|" + "requires_grad" + "=" + str(tm.requires_grad) + r"\n" - result += "|" + "stride" + "=" + str(tm.stride) + r"\n" - if tm.is_quantized: - assert tm.qparams is not None - assert "qscheme" in tm.qparams - qscheme = tm.qparams["qscheme"] - if qscheme in { - torch.per_tensor_affine, - torch.per_tensor_symmetric, - }: - result += "|" + "q_scale" + "=" + str(tm.qparams["scale"]) + r"\n" - result += "|" + "q_zero_point" + "=" + str(tm.qparams["zero_point"]) + r"\n" - elif qscheme in { - torch.per_channel_affine, - torch.per_channel_symmetric, - torch.per_channel_affine_float_qparams, - }: - result += "|" + "q_per_channel_scale" + "=" + str(tm.qparams["scale"]) + r"\n" - result += "|" + "q_per_channel_zero_point" + "=" + str(tm.qparams["zero_point"]) + r"\n" - result += "|" + "q_per_channel_axis" + "=" + str(tm.qparams["axis"]) + r"\n" - else: - raise RuntimeError(f"Unsupported qscheme: {qscheme}") - result += "|" + "qscheme" + "=" + str(tm.qparams["qscheme"]) + r"\n" - return result - - def _get_tensor_label(self, t: torch.Tensor) -> str: - return str(t.dtype) + str(list(t.shape)) + r"\n" - - def _to_dot( - self, - graph_module: pippy.fx.GraphModule, - name: str, - ignore_getattr: bool, - ignore_parameters_and_buffers: bool, - skip_node_names_in_args: bool, - ) -> pydot.Dot: - """ - Actual interface to visualize a fx.Graph. Note that it takes in the GraphModule instead of the Graph. - If ignore_parameters_and_buffers is True, the parameters and buffers - created with the module will not be added as nodes and edges. - """ - dot_graph = pydot.Dot(name, rankdir="TB") - - for node in graph_module.graph.nodes: - if ignore_getattr and node.op == "get_attr": - continue - - style = self._get_node_style(node) - dot_node = pydot.Node( - node.name, label=self._get_node_label(graph_module, node, skip_node_names_in_args), **style - ) - dot_graph.add_node(dot_node) - - def get_module_params_or_buffers(): - for pname, ptensor in chain( - leaf_module.named_parameters(), leaf_module.named_buffers() - ): - pname1 = node.name + "." + pname - label1 = ( - pname1 + "|op_code=get_" + "parameter" - if isinstance(ptensor, torch.nn.Parameter) - else "buffer" + r"\l" - ) - dot_w_node = pydot.Node( - pname1, - label="{" + label1 + self._get_tensor_label(ptensor) + "}", - **_WEIGHT_TEMPLATE, - ) - dot_graph.add_node(dot_w_node) - dot_graph.add_edge(pydot.Edge(pname1, node.name)) - - if node.op == "call_module": - leaf_module = self._get_leaf_node(graph_module, node) - - if not ignore_parameters_and_buffers and not isinstance(leaf_module, pippy.fx.GraphModule): - get_module_params_or_buffers() - - for node in graph_module.graph.nodes: - if ignore_getattr and node.op == "get_attr": - continue - - for user in node.users: - dot_graph.add_edge(pydot.Edge(node.name, user.name)) - - return dot_graph - -else: - if not TYPE_CHECKING: - @compatibility(is_backward_compatible=False) - class FxGraphDrawer: - def __init__(self, graph_module: pippy.fx.GraphModule, name: str, ignore_getattr: bool = False): - raise RuntimeError('FXGraphDrawer requires the pydot package to be installed. Please install ' - 'pydot through your favorite Python package manager.') diff --git a/pippy/fx/passes/graph_manipulation.py b/pippy/fx/passes/graph_manipulation.py deleted file mode 100644 index c4c6716e6..000000000 --- a/pippy/fx/passes/graph_manipulation.py +++ /dev/null @@ -1,111 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from typing import Any, Dict, List, NamedTuple, Optional - -import torch -from pippy.fx._compatibility import compatibility -from pippy.fx.graph import Graph -from pippy.fx.graph_module import GraphModule -from pippy.fx.node import ( - map_arg, - Node, - Target, -) -from pippy.fx.passes.shape_prop import ShapeProp - -__all__ = ['replace_target_nodes_with', 'size_bytes', 'get_size_of_all_nodes', 'get_tensor_meta', - 'get_size_of_node'] - -@compatibility(is_backward_compatible=False) -def replace_target_nodes_with( - fx_module: GraphModule, - old_op: str, - old_target: Target, - new_op: str, - new_target: Target, -): - """Modifies all nodes in fx_module.graph.nodes which match the specified op code and target, - and updates them to match the new op code and target""" - new_graph = Graph() - val_map: Dict[Node, Node] = {} - for node in fx_module.graph.nodes: - if node.op == old_op and node.target == old_target: - args = map_arg(node.args, lambda n: val_map[n]) - kwargs = map_arg(node.kwargs, lambda n: val_map[n]) - assert isinstance(args, tuple) - assert isinstance(kwargs, dict) - val_map[node] = new_graph.create_node( - new_op, new_target, args, kwargs, node.name - ) - else: - val_map[node] = new_graph.node_copy(node, lambda n: val_map[n]) - fx_module.graph = new_graph - - -@compatibility(is_backward_compatible=False) -class size_bytes(NamedTuple): - output_size: int - total_size: int - - -@compatibility(is_backward_compatible=False) -def get_size_of_all_nodes( - fx_module: GraphModule, args: Optional[List[torch.Tensor]] = None -) -> None: - """Given a fx graph module, update each node with its total size (weights + bias + output) - and its output_size(output). For a non-module node, the total size is the output size. - return total size""" - if args is not None: - # Mark shape and dtype for each node (node.shape and node.dtype) - ShapeProp(fx_module).propagate(*args) - # Calculate the total size of the whole fx graph - total_size_of_graph = 0.0 - for node in fx_module.graph.nodes: - if node.op == "output": - break - node.size_bytes = get_size_of_node(fx_module, node) - return - - -@compatibility(is_backward_compatible=False) -def get_tensor_meta(node: Node) -> Any: - tensor_meta = node.meta.get("tensor_meta") - - if not tensor_meta: - raise RuntimeError( - f"Node {node} has no tensor metadata associated with it! " - f"Check that shape propagation has run." - ) - - return tensor_meta - - -@compatibility(is_backward_compatible=False) -def get_size_of_node(fx_module: GraphModule, node: Node) -> size_bytes: - """Given a node with node.dtype and node.shape, return its total size and its output size. - total_size = weights + bias + output_size - """ - # Total num of elements - total_num_of_elems = 0 - # For a module, conside all parameters - if node.op == "call_module": - submodule_dict = dict(fx_module.named_modules()) - submodule = submodule_dict[node.target] - parameters = submodule.named_parameters() - # Parameters are named tuples - for name, p in parameters: - total_num_of_elems += p.numel() - # Don't forget the output size - # node.shape is the shape of this node's output - tensor_meta = get_tensor_meta(node) - output_elem = tensor_meta.shape.numel() - total_num_of_elems += output_elem - # Assume for now if it's quantized then it's qint8 or quint8 - if tensor_meta.is_quantized: - size_per_elem_bytes = torch._empty_affine_quantized( - [], dtype=tensor_meta.dtype - ).element_size() - else: - size_per_elem_bytes = torch.tensor([], dtype=tensor_meta.dtype).element_size() - total_size = size_per_elem_bytes * total_num_of_elems - output_size = size_per_elem_bytes * output_elem - return size_bytes(output_size, total_size) diff --git a/pippy/fx/passes/infra/__init__.py b/pippy/fx/passes/infra/__init__.py deleted file mode 100644 index c53c3b3f7..000000000 --- a/pippy/fx/passes/infra/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from . import pass_manager diff --git a/pippy/fx/passes/infra/partitioner.py b/pippy/fx/passes/infra/partitioner.py deleted file mode 100644 index b5a39fb2b..000000000 --- a/pippy/fx/passes/infra/partitioner.py +++ /dev/null @@ -1,228 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from typing import Dict, List, Set, Iterable, Optional - -from pippy.fx.passes.utils.fuser_utils import fuse_by_partitions -from pippy.fx.passes.tools_common import NodeList - -from pippy.fx.graph_module import GraphModule -from pippy.fx.node import Node, _get_qualified_name -from pippy.fx.passes.operator_support import OperatorSupportBase - -from collections import defaultdict -import logging -import itertools - -logging.basicConfig(level=logging.WARNING) -logger = logging.getLogger(__name__) - -class Partition: - def __init__(self, id: int = None, nodes: Iterable[Node] = None): - self.id = id - self.nodes: Set[Node] = set(nodes) if nodes is not None else set() - - def __repr__(self) -> str: - return str(self.nodes) - - def add_node(self, node: Node): - self.nodes.add(node) - - def remove_node(self, node: Node): - self.nodes.remove(node) - - def size(self): - return len(self.nodes) - -class CapabilityBasedPartitioner: - - def __init__(self, - graph_module: GraphModule, - operator_support: OperatorSupportBase, - allows_single_node_partition: bool = False - ) -> None: - self.graph_module = graph_module - self.operator_support = operator_support - self.allows_single_node_partition = allows_single_node_partition - - # map of node to it's upstream dependency nodes - # if A is found in dependency_map[B], then B depends on A (or a is an upstream depedency of b) - self.dependency_map = self.__build_dependency_map() - - def __build_dependency_map(self) -> Dict[Node, Set[Node]]: - dependency_map = defaultdict(set) - - # assumptions: nodes in graph are sorted in topological order - for node in self.graph_module.graph.nodes: - for input_node in node.all_input_nodes: - # add input_node and input_node's upstream dependency - dependency_map[node].add(input_node) - dependency_map[node].update(dependency_map[input_node]) - - return dependency_map - - def __node_depends_on(self, a: Node, b: Node) -> int: - # Returns - # 1 if b depends on a (,or equivalently a is an upstream depedency of b) - # -1 if a depends on b (,or equivalently b is an upstream depedency of a) - # 0 if a and b doesn't have dependency between each other - - if a in self.dependency_map[b]: - return 1 - elif b in self.dependency_map[a]: - return -1 - else: - return 0 - - def __partition_depends_on(self, partition_a: Partition, partition_b: Partition) -> int: - # Returns - # 1 if b depends on a (,or equivalently a is an upstream depedency of b) - # -1 if a depends on b (,or equivalently b is an upstream depedency of a) - # 0 if a and b doesn't have dependency between each other - - # TODO: build a cache here to speedup the query - - for node_a in partition_a.nodes: - for node_b in partition_b.nodes: - dependency = self.__node_depends_on(node_a, node_b) - if dependency != 0: - return dependency - return 0 - - def __get_supported_nodes(self) -> NodeList: - logging.debug("Collecting supported nodes...") - supported_nodes = [] - for node in self.graph_module.graph.nodes: - if self.operator_support.is_node_supported(dict(self.graph_module.named_modules()), node): - supported_nodes.append(node) - return supported_nodes - - def propose_partitions(self) -> List[Partition]: - candidates: NodeList = self.__get_supported_nodes() - - # assumptions: nodes in candidate list is sorted in topological order - assignment: Dict[Node, int] = {} # maping from node to partition_id - partitions_by_id: Dict[int, Partition] = {} # mapping from partition_id to partition - new_partition_id = itertools.count() - - def assign(node: Node, id: Optional[int] = None): - # If id is None, remove the node from original assigment - - # node has been assigned before, clean up and re-assign - if node in assignment: - original_id = assignment[node] - del assignment[node] - partitions_by_id[original_id].remove_node(node) - if partitions_by_id[original_id].size() == 0: - del partitions_by_id[original_id] - - if id is not None: - assignment[node] = id - if id not in partitions_by_id: - partitions_by_id[id] = Partition(id=id, nodes=[node]) - else: - partitions_by_id[id].add_node(node) - - logging.debug("Proposing partitions...") - - # visit candidates in reversed topological order - for node in reversed(candidates): - # use Dict as an ordered set to ensure deterministic partitioning result, don't care value - user_partitions: Dict[Partition, None] = {} - for user_node in node.users: - if user_node in assignment: - id = assignment[user_node] - user_partitions[partitions_by_id[id]] = None - else: - user_partitions[Partition(nodes=[user_node])] = None - - # Filter out all the partitions that has dependency on other users - # TODO: find a better way to do this, rather than pair-wise comparision - user_partitions_list = list(user_partitions.keys()) - for i in range(len(user_partitions_list)): - for j in range(i + 1, len(user_partitions_list)): - pi = user_partitions_list[i] - pj = user_partitions_list[j] - dependency = self.__partition_depends_on(pi, pj) - if dependency == 1 and pj in user_partitions: - del user_partitions[pj] - elif dependency == -1 and pi in user_partitions: - del user_partitions[pi] - - # We use the following rules for partition assignment: - # 1. If none of the candidates has been assigned to a partition, create a new partition - # 2. If there is one partition candidate, assign to the partition - # 3. If there are more than one partition candidates, assign current node to the first partition and - # merge the other partitions with first partition, since user_partitions doesn't have depedency between - # each other. - - assigned_candidate_partition_ids = [partition.id for partition in user_partitions if partition.id is not None] - - if len(assigned_candidate_partition_ids) == 0: - # create a new partition - assign(node, next(new_partition_id)) - elif len(assigned_candidate_partition_ids) == 1: - id = assigned_candidate_partition_ids[0] - assign(node, id) - else: - # users are assigned to more than one partition, since user_partitions doesn't have - # dependency on each other, they can be fused into a single partition - id = assigned_candidate_partition_ids[0] - assign(node, id) - - reassignment: Dict[Node, int] = {} - for other_id in assigned_candidate_partition_ids[1:]: - for other_node in partitions_by_id[other_id].nodes: - reassignment[other_node] = id - for other_node in reassignment: - assign(other_node, id) - - # post processing to re-assign "getitem" nodes into upstream partition - logger.debug("Reassigning getitem nodes to its producer node's partition...") - nodes_reassignment: Dict[Node, int] = {} - for node in self.graph_module.graph.nodes: - is_tuple_output = True - for user in node.users: - if user.op != "call_function" or \ - _get_qualified_name(user.target) != "_operator.getitem": # type: ignore[arg-type] - is_tuple_output = False - break - - # node has tuple outputs, re-assign all following getitem node into node's partition - if is_tuple_output: - id = assignment.get(node, None) # type: ignore[arg-type] - for user in node.users: - if assignment.get(user, None) != id: # type: ignore[arg-type] - nodes_reassignment[user] = id - for node, id in nodes_reassignment.items(): - assign(node, id) - - # filter out single node partitions - if not self.allows_single_node_partition: - logger.debug("Filtering out single node partitions...") - non_compute_ops = {"torch.ops.aten.view", "_operator.getitem"} - partitions_to_remove: List[int] = [] - for id, partition in partitions_by_id.items(): - compute_node_count = 0 - for node in partition.nodes: - if node.op == "call_function" and \ - _get_qualified_name(node.target) not in non_compute_ops: # type: ignore[arg-type] - compute_node_count += 1 - if compute_node_count <= 1: - partitions_to_remove.append(id) - for id in partitions_to_remove: - del partitions_by_id[id] - - logging.debug("Partitions proposed:") - for id, partition in partitions_by_id.items(): - logging.debug(f"partition #{id}", [node.name for node in partition.nodes]) - - return list(partitions_by_id.values()) - - def fuse_partitions(self, partitions: List[Partition]) -> GraphModule: - logging.debug("Fusing partitions...") - # fuse_by_partitions expects partitions in List[List[Node]]: [ [node0, node1], [node2, node3] ] - return fuse_by_partitions(self.graph_module, [list(partition.nodes) for partition in partitions]) - - def partition_and_fuse(self) -> GraphModule: - partitions = self.propose_partitions() - fused_gm = self.fuse_partitions(partitions) - return fused_gm diff --git a/pippy/fx/passes/infra/pass_base.py b/pippy/fx/passes/infra/pass_base.py deleted file mode 100644 index 711a1a1ca..000000000 --- a/pippy/fx/passes/infra/pass_base.py +++ /dev/null @@ -1,79 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import abc -from collections import namedtuple -from typing import Optional - -from pippy.fx.graph_module import GraphModule -from pippy.fx._compatibility import compatibility - - -__all__ = ['PassResult', 'PassBase'] - -@compatibility(is_backward_compatible=False) -class PassResult(namedtuple("PassResult", ["graph_module", "modified"])): - """ - Result of a pass: - graph_module: The modified graph module - modified: A flag for if the pass has modified the graph module - """ - def __new__(cls, graph_module, modified): - return super().__new__(cls, graph_module, modified) - -@compatibility(is_backward_compatible=False) -class PassBase(abc.ABC): - """ - Base interface for implementing passes. - - It is required to implement the `call` function so that we can directly - pass instances of the Pass directly to the PassManager and call them as a - function. - - We can directly pass an instance of a class implementing this interface into - the PassManager's `passes` attribute. - """ - - def __init__(self) -> None: - pass - - def __call__(self, graph_module: GraphModule) -> Optional[PassResult]: - """ - Runs the precondition check, the pass itself, and the postcondition check. - """ - - self.requires(graph_module) - res = self.call(graph_module) - self.ensures(graph_module) - return res - - @abc.abstractmethod - def call(self, graph_module: GraphModule) -> Optional[PassResult]: - """ - The pass that is run through the given graph module. To implement a - pass, it is required to implement this function. - - Args: - graph_module: The graph module we will run a pass on - """ - pass - - def requires(self, graph_module: GraphModule) -> None: - """ - This function will be called before the pass is run and will check that - the given graph module contains the preconditions needed to run the - pass. It is not required to implement this function. - - Args: - graph_module: The graph module we will run checks on - """ - pass - - def ensures(self, graph_module: GraphModule) -> None: - """ - This function will be called after the pass is run and will check that - the given graph module contains the postconditions needed to run the - pass. It is not required to implement this function. - - Args: - graph_module: The graph module we will run checks on - """ - pass diff --git a/pippy/fx/passes/infra/pass_manager.py b/pippy/fx/passes/infra/pass_manager.py deleted file mode 100644 index 27fc6618a..000000000 --- a/pippy/fx/passes/infra/pass_manager.py +++ /dev/null @@ -1,308 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import inspect -from queue import Queue -from functools import wraps -from typing import Callable, Dict, List - -import torch.nn as nn -from pippy.fx.graph_module import GraphModule -from pippy.fx._compatibility import compatibility -from pippy.fx.passes.infra.pass_base import PassResult - -__all__ = ['inplace_wrapper', 'pass_result_wrapper', 'this_before_that_pass_constraint', 'PassManager'] - -@compatibility(is_backward_compatible=False) -def inplace_wrapper(fn: Callable) -> Callable: - """ - Convenience wrapper for passes which modify an object inplace. This - wrapper makes them return a PassResult containing the modified object and - True for the "modified" flag. - - Args: - fn (Callable[Module, Any]) - - Returns: - wrapped_fn (Callable[Module, PassResult]) - """ - if fn is None: - return None - - @wraps(fn) - def wrapped_fn(gm): - return fn(gm) or PassResult(gm, True) - - if wrapped_fn.__name__ == 'wrapped_fn': - wrapped_fn.__name__ = str(fn) - return wrapped_fn - -@compatibility(is_backward_compatible=False) -def pass_result_wrapper(fn: Callable) -> Callable: - """ - Wrapper for passes which currently do not return a PassResult. - This wrapper makes them return a PassResult containing the modified object - and True for the "modified" flag. - - Args: - fn (Callable[Module, Any]) - - Returns: - wrapped_fn (Callable[Module, PassResult]) - """ - if fn is None: - return None - - @wraps(fn) - def wrapped_fn(gm): - gm = fn(gm) - return PassResult(gm, True) - - return wrapped_fn - -def _validate_pass_schedule_constraint( - constraint: Callable[[Callable, Callable], bool], passes: List[Callable] -) -> None: - for i, a in enumerate(passes): - for j, b in enumerate(passes[i + 1 :]): - if constraint(a, b): - continue - raise RuntimeError( - f"pass schedule constraint violated. Expected {a} before {b}" - f" but found {a} at index {i} and {b} at index{j} in pass" - f" list." - ) - -def _topological_sort_passes( - passes: List[Callable], constraints: List[Callable] -) -> List[Callable]: - """ - Args - passes: Passes that we are ordering - constraints: Constraints applied on these passes - - Returns - A sorted list of callables and a boolean of if a circular dependency - existed - """ - if len(constraints) == 0: - return passes - - # Contruct a graph mapping nodes to a list of their users - graph: Dict[Callable, List[Callable]] = {p : [] for p in passes} - indegree_map: Dict[Callable, int] = {p : 0 for p in passes} - candidates: Queue = Queue() - for a in passes: - for b in passes: - if a == b: - continue - - for constraint in constraints: - if not constraint(a, b): - graph[b].append(a) - indegree_map[a] += 1 - - if indegree_map[a] == 0: - candidates.put(a) - - visited: Dict[Callable, bool] = {p : False for p in passes} - sorted_passes: List[Callable] = [] - - while not candidates.empty(): - p = candidates.get() - sorted_passes.append(p) - visited[p] = True - - for n in graph[p]: - if not visited[n]: - indegree_map[n] -= 1 - if indegree_map[n] == 0: - candidates.put(n) - - # Check if there are unvisited nodes (aka cycles in the graph) - cycle_passes = list(filter(lambda p: indegree_map[p] != 0, indegree_map.keys())) - if len(cycle_passes) != 0: - error = f"Circular dependency detected within the following passes: {cycle_passes}" - raise RuntimeError(error) - - return sorted_passes - -@compatibility(is_backward_compatible=False) -def this_before_that_pass_constraint(this: Callable, that: Callable) -> Callable: - """ - Defines a partial order ('depends on' function) where `this` must occur - before `that`. - - For example, the following pass list and constraint list would be invalid. - ``` - passes = [pass_b, pass_a] - - constraints = [ - this_before_that_pass_constraint(pass_a, pass_b) - ] - ``` - - Args: - this (Callable): pass which should occur first - that (Callable): pass which should occur later - - Returns: - depends_on (Callable[[Object, Object], bool] - """ - - def depends_on(a: Callable, b: Callable): - if a == that and b == this: - return False - return True - - return depends_on - - -@compatibility(is_backward_compatible=False) -class PassManager: - """ - Construct a PassManager. - - Collects passes and constraints. This defines the pass schedule, manages - pass constraints and pass execution. - - Args: - passes (Optional[List[Callable]]): List of passes. A pass is a - callable which modifies an object and returns a PassResult - constraint (Optional[List[Callable]]): List of constraints. A - constraint is a callable which takes two passes (A, B) and returns - True if A depends on B and False otherwise. See implementation of - `this_before_that_pass_constraint` for example. - steps (int): Max number of times we run the passes (default = 1). - run_checks_after_each_pass (bool): Whether to run checks and linting - after each pass - suppress_check_failures (bool): Whether to raise errors when running - checks - """ - - passes: List[Callable[[nn.Module], PassResult]] = [] - constraints: List[Callable[[Callable, Callable], bool]] = [] - _validated: bool = False - steps: int = 1 - - def __init__( - self, - passes=None, - constraints=None, - steps=None, - run_checks_after_each_pass: bool = False, - suppress_check_failures: bool = False, - debug: bool = False, - ): - if passes: - self.passes = passes - if constraints: - self.constraints = constraints - if steps: - self.steps = steps - - self.run_checks_after_each_pass = run_checks_after_each_pass - self.suppress_check_failures = suppress_check_failures - self.debug = debug - - def add_pass(self, _pass: Callable): - """ - Adds a pass into the current list of passes. - """ - self.passes.append(_pass) - self._validated = False - - def add_constraint(self, constraint: Callable): - """ - Adds a constraint into the current list of constraints. - """ - self.constraints.append(constraint) - self._validated = False - - def validate_constraints(self): - """ - Validates that current pass schedule defined by `self.passes` is valid - according to all constraints in `self.constraints` - """ - if self._validated: - return - for constraint in self.constraints: - _validate_pass_schedule_constraint(constraint, self.passes) - self._validated = True - - def solve_constraints(self): - """ - Finds a valid traversal order based on the given constraints and orders - the passes based on this order. - - If a circular dependency exists between the constraints and steps = 1, - then we will raise an error because if steps != 1 this means that we - will re-run the passes, allowing for circular dependencies. - """ - self.passes = _topological_sort_passes(self.passes, self.constraints) - self._validated = True - - def add_checks(self, check: Callable) -> None: - """ - Adds a function which takes runs various checks on a given graph module. - This function is run before and after each pass if the - `run_checks_after_each_pass` flag is enabled. - """ - sig = inspect.signature(check) - - if len(list(sig.parameters.values())) != 1: - raise TypeError("PassManager check function should only take in one variable, a module") - - setattr(self, "check", check) # noqa: B010 - - def check(self, module: nn.Module) -> None: - pass - - def __call__(self, module: nn.Module) -> PassResult: - """ - Runs a list of passes in the order based on `self.passes` on the given - graph module. Each time a pass is run, checks and linting will be run on - the graph module if `run_checks_after_each_pass` is set. - - If the module is a graph module, we will run the list of passes until - the graph stops changing, or until `steps` number of times. - """ - # Order the passes based on the constraints - if not self._validated: - self.solve_constraints() - - # Check graph invariants - self.check(module) - - # Run the set of passes `steps` number of times or until the graph stops - # changing - overall_modified = False - for _ in range(self.steps): - modified = False - - # Run the set of passes on the graph module - for i, fn in enumerate(self.passes): - if self.debug: - print(f"Running pass \'{fn.__name__}\'") - - try: - res = fn(module) - except Exception as e: - prev_pass_names = [p.__name__ for p in self.passes[:i]] - msg = f"An error occurred when running the \'{fn.__name__}\' pass after the following passes: {prev_pass_names}" - raise type(e)(msg) from e - - module = res.graph_module - modified = modified or res.modified - - if isinstance(module, GraphModule): - module.recompile() - - # Check graph invariants - if self.run_checks_after_each_pass: - self.check(module) - - # If the graph no longer changes, then we can stop running these passes - overall_modified = overall_modified or modified - if not modified: - break - - return PassResult(module, overall_modified) diff --git a/pippy/fx/passes/net_min_base.py b/pippy/fx/passes/net_min_base.py deleted file mode 100644 index da18d980e..000000000 --- a/pippy/fx/passes/net_min_base.py +++ /dev/null @@ -1,619 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import logging -from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Tuple - -import torch -import pippy.fx -from pippy.fx._compatibility import compatibility -from pippy.fx.node import map_arg - -from .shape_prop import ShapeProp -from .split_utils import split_by_tags -from .tools_common import ( - CALLABLE_NODE_OPS, - FxNetAccFusionsFinder, - Names, - NodeList, - NodeSet, - TensorOrTensors, - Tensors, -) - -__all__ = [ - "FxNetMinimizerBadModuleError", - "FxNetMinimizerRunFuncError", - "FxNetMinimizerResultMismatchError", -] - -_LOGGER = logging.getLogger(__name__) - - -@compatibility(is_backward_compatible=False) -class FxNetMinimizerBadModuleError(Exception): - """ - Raised if failed to split out a minimize module - """ - - pass - - -@compatibility(is_backward_compatible=False) -class FxNetMinimizerRunFuncError(Exception): - """ - Raised if error occurs during run_a or run_b functions - """ - - pass - - -@compatibility(is_backward_compatible=False) -class FxNetMinimizerResultMismatchError(Exception): - """ - Raised if comparing function thinks the results are mismatching. - """ - - pass - - -@dataclass -class _MinimizerSettingBase: - """ - Args: - `accumulate_error`: Instead of using a's input for both converted module to verify - , use the previous outputs of each converted module as input to accumulate the - errors. - - `traverse_method`: "sequential" or "binary" or "accumulate" - Determine the way of traverse the nodes in FX module. - - `find_all`: Minimizer will go through the entire model and return all problematic nodes. - - `return_intermediate`: If true, when using `run_nodes()` function to run the - model, intermediate results of all the ops will be returned as output. - """ - - accumulate_error: bool = False - traverse_method: str = "sequential" - find_all: bool = False - return_intermediate: bool = False - - def __str__(self): - settings_str = "FX Minimizer Settings:\n" - - for k, v in vars(self).items(): - settings_str += f"\t{k}: {v}\n" - - return settings_str - - -class _MinimizerBase: - """ - This class is used to automatically find problematic nodes in a model. It takes a FX - graphmodule and generate some submodules while traverse the graph. Then two functions - `run_a` and `run_b` will be used to run the same submodule and a function `compare_fn` - will be used to compare the results. - - Currently we provides two ways to traverse the graph and generate submodules. - 1. Sequential traversal: this will traverse the graph node by node and generate - one submodule with one sigle node. - 2. Binary searching: this will do a binary search style traversal on the graph. - - For internal Users, a guide can be found here https://fb.quip.com/HDtuAgiKGfkP. - """ - - def __init__( - self, - module: pippy.fx.GraphModule, - sample_input: Tensors, - compare_fn: Callable[ - [TensorOrTensors, TensorOrTensors, Names], Tuple[float, bool] - ], - settings: _MinimizerSettingBase, - ): - assert isinstance(module, pippy.fx.GraphModule) - - self.module = module - self.sample_input = sample_input - self.compare_fn = compare_fn - self.settings = settings - - # Stores outputs of run_a function - self.a_outputs: Dict[str, Any] = {} - - # Stores outputs of run_b function - self.b_outputs: Dict[str, Any] = {} - - # Stores the results of compare_fn - self.results: Dict[Any, Any] = {} - - # Stores the report for the runs - self.reports: List[List[str]] = [] - - # Current iteration - self.iteration: int = 0 - - callable_nodes = { - node for node in self.module.graph.nodes if node.op in CALLABLE_NODE_OPS - } - ShapeProp(self.module).propagate(*self.sample_input) - self.fusions = FxNetAccFusionsFinder(self.module, callable_nodes)() - - # Check if number of input in sample_input matches the number of placeholders - placeholders = [ - node.name for node in self.module.graph.nodes if node.op == "placeholder" - ] - assert len(placeholders) == len(self.sample_input) - - # Store sample_input - for i, name in enumerate(placeholders): - self.a_outputs[name] = sample_input[i] - self.b_outputs[name] = sample_input[i] - - def run_a(self, mod: pippy.fx.GraphModule, inputs: Tensors) -> TensorOrTensors: - """ - Run `mod` with `inputs` and generate output. The output will be compared with - output of run_b(). - """ - raise RuntimeError("run_a() is not implemented.") - - def run_b(self, mod: pippy.fx.GraphModule, inputs: Tensors) -> TensorOrTensors: - """ - Run `mod` with `inputs` and generate output. The output will be compared with - output of run_a(). - """ - raise RuntimeError("run_b() is not implemented.") - - def _store_outputs( - self, - a_result: TensorOrTensors, - b_result: TensorOrTensors, - submodule: pippy.fx.GraphModule, - ): - """ - Store the outputs of self.run_a() and self.run_b() into self.a_outputs and - self.b_outputs, so that we can use them when execute preceding nodes that - use those outputs as inputs. - - Args: - a_result: Output of self.run_a(). Could be a tensor or tensors. - b_result: Output of self.run_b(). Could be a tensor or tensors. - submodule: The module that generates a_result and b_result. - """ - output_node = next( - node for node in submodule.graph.nodes if node.op == "output" - ) - - # Only one output - if isinstance(output_node.args[0], pippy.fx.Node): - self.a_outputs[output_node.args[0].name] = a_result - self.b_outputs[output_node.args[0].name] = b_result - # Multiple outputs - else: - for i, arg in enumerate(output_node.args[0]): - self.a_outputs[arg.name] = a_result[i] - self.b_outputs[arg.name] = b_result[i] - - def _get_submod_inputs( - self, main_module: pippy.fx.GraphModule, submod_path: str - ) -> Tuple[Tensors, Tensors]: - """ - Try get submodule inputs from stored outputs. If not found then use - torch_glow.get_submod_inputs to get the inputs. - - If accumulate_error is False, use a_input for run_a() and run_b() - otherwise use a_input for run_a and b_input for run_b. - - Args: - main_module: Top-levlel fx module. - submod_path: Path to the submodule we want to run and compare results. - - Returns: - a_input: List of tensor(s) that will be used by run_a() as submodule inputs. - b_input: List of tensor(s) that will be used by run_b() as submodule inputs. - """ - a_input = [] - b_input = [] - submodule = getattr(main_module, submod_path) - placeholders = [ - node.name for node in submodule.graph.nodes if node.op == "placeholder" - ] - - # If all placeholder can be found in stored outputs, use stored - # outputs as inputs. Otherwise, use `torch_glow.get_submod_inputs` - # to get the inputs. - if set(placeholders) <= self.a_outputs.keys(): - for name in placeholders: - a_input.append(self.a_outputs[name]) - b_input.append(self.b_outputs[name]) - else: - if self.settings.accumulate_error: - print(f"Can't find previous stored outputs named {placeholders}!") - - def get_inputs(self: torch.nn.Module, inputs: Any): - nonlocal a_input - a_input = inputs - - # Use forward hook to get the inputs to the submodule - handle = submodule.register_forward_pre_hook(get_inputs) - main_module(*self.sample_input) - handle.remove() - - b_input = a_input - - if not self.settings.accumulate_error: - return a_input, a_input - - return a_input, b_input - - def _tag_nodes(self, selected_nodes: NodeSet): - """ - Tag selected nodes with tag "minimize". Nodes with the same tags will - be split to the same submodule afterwards. - - Args: - selected_nodes: Nodes that we want to minimize. We will tag those nodes - with "minimize", all preceding nodes with "main_0" and all following - nodes with "main_1". - """ - for node in self.module.graph.nodes: - if node.op not in CALLABLE_NODE_OPS: - continue - - if node in selected_nodes: - node.tag = "minimize" - elif any( - n.tag in {"minimize", "main_1"} - for n in node.all_input_nodes - if n.op in CALLABLE_NODE_OPS - ): - node.tag = "main_1" - else: - node.tag = "main_0" - - def _build_submodule(self, nodes: NodeSet) -> Tuple[pippy.fx.GraphModule, str]: - """ - Split self.module so that one submodule consists of `nodes` and only `nodes`. - - Args: - nodes: Nodes that we want to include in the minimize submodule. - - Returns: - split_module (pippy.fx.GraphModule): the module after split. - submodule_name (str): the name of the submodule that consists of `nodes`. - """ - # Color provided nodes - self._tag_nodes(nodes) - - # Split module based on coloring - split_module = split_by_tags(self.module, ["main_0", "minimize", "main_1"]) - - # Find submodule containing colored nodes - submodule_name: str = "" - for child_name, _ in split_module.named_children(): - # Skip submodules we're not interested in at the moment - if "minimize" not in child_name: - continue - - if submodule_name == "": - submodule_name = child_name - else: - raise FxNetMinimizerBadModuleError( - f"Expected only one minimize submodule with nodes {nodes}" - ) - - if submodule_name == "": - raise FxNetMinimizerBadModuleError( - f"Minimize submodule was not found with nodes {nodes}" - ) - - return split_module, submodule_name - - def _run_and_compare( - self, split_module: pippy.fx.GraphModule, submod_name: str, output_names: Names - ): - """ - Run the submodule in `split_module` that has name `submod_name` - using `self.run_a` and `self.run_b` and compare their results. - - Args: - split_module: Main module that contains the minimize submodule. - submod_name: Name of the minimize submodule. - output_names: Names of the node we want to output. If None, we - will use the original output. - """ - submodule = getattr(split_module, submod_name) - a_input, b_input = self._get_submod_inputs(split_module, submod_name) - - if len(self.reports) == 0: - self.reports.append([]) - self.iteration = 1 - - report = self.reports[self.iteration - 1] - report.append("Run and compare ...") - - if output_names: - output_nodes: NodeList = [] - for node in submodule.graph.nodes: - if node.op == "output": - submodule.graph.erase_node(node) - - if node.name in output_names: - output_nodes.append(node) - - submodule.graph.output( - output_nodes[0] if len(output_nodes) == 1 else tuple(output_nodes) - ) - submodule.graph.lint() - submodule.recompile() - - # Use name of args in output node as key to store comparison result - for node in submodule.graph.nodes: - if node.op == "output": - result_key = map_arg(node.args, lambda x: x.name) - - a_result = self.run_a(submodule, a_input) - b_result = self.run_b(submodule, b_input) - self._store_outputs(a_result, b_result, submodule) - - # Compare results - names: Names = output_names - if output_names is None: - names = [str(v) for v in result_key] - - numeric_result, bool_result = self.compare_fn(a_result, b_result, names) - - self.results[result_key] = numeric_result - report.append(f"Numerical accuracy = {numeric_result}") - if not bool_result: - report.append(f"Result mismatch for {result_key}") - raise FxNetMinimizerResultMismatchError(f"Result mismatch for {result_key}") - - def _binary_search_impl( - self, all_nodes: NodeList, start_idx: int, end_idx: int - ) -> NodeSet: - """ - Recursive binary search implementation. - """ - nodes: NodeList = all_nodes[start_idx:end_idx] - - report: List[str] = [] - self.reports.append(report) - self.iteration += 1 - report.append(f"Binary search iteration {self.iteration}.") - report.append( - f"From node index {start_idx} to {end_idx-1}. " - f"Size of the interested node list is {len(nodes)}" - ) - - cur_nodes: NodeSet = set(nodes) - - for node in nodes: - if node in self.fusions: - cur_nodes.update(self.fusions[node]) - - try: - split_module, submod_name = self._build_submodule(cur_nodes) - self._run_and_compare(split_module, submod_name, []) - except (FxNetMinimizerRunFuncError, FxNetMinimizerResultMismatchError): - - if len(nodes) == 1: - report.append( - f"This is the last node in the sub-module. " - f"Search in the current branch is successful with culprit = {cur_nodes}." - ) - self.print_report(report) - return cur_nodes - - report.append( - "Proceed to split and lower the halves of the current " - "sub-module individually." - ) - self.print_report(report) - - mid = len(nodes) // 2 - culprits = self._binary_search_impl(all_nodes, start_idx, start_idx + mid) - - if len(culprits) != 0 and not self.settings.find_all: - return culprits - - culprits = self._binary_search_impl(all_nodes, start_idx + mid, end_idx) - - if len(culprits) == 0: - report.append( - f"Further split and lowering found no errors. " - f"Unable to minimize the submodule with list of nodes: {nodes}" - ) - self.print_report(report) - - return culprits - else: - report.append("No discrepancy found.") - self.print_report(report) - return set() - - def _binary_traverse(self, nodes: NodeList) -> NodeSet: - """ - Binary search on `nodes` for culprit. - """ - return self._binary_search_impl(nodes, 0, len(nodes)) - - def _sequential_traverse(self, nodes: NodeList) -> NodeSet: - """ - Traverse `nodes` one by one and determine if any of them is a culprit. - """ - culprits: NodeSet = set() - - for node in nodes: - report: List[str] = [] - self.reports.append(report) - self.iteration += 1 - report.append(f"Sequential traverse iteration {self.iteration}.") - report.append(f"Visit node: {node.name}") - - _LOGGER.info(f"Visit node: {node.name}") - cur_nodes: NodeSet = {node} - - if node in self.fusions: - cur_nodes = self.fusions[node] - - try: - split_module, submod_name = self._build_submodule(cur_nodes) - self._run_and_compare(split_module, submod_name, [node.name]) - self.print_report(report) - except (FxNetMinimizerResultMismatchError): - culprits.add(node) - report.append(f"Found culprit from numeric error: {node}") - self.print_report(report) - if not self.settings.find_all: - return culprits - except (FxNetMinimizerRunFuncError): - culprits.update(cur_nodes) - report.append(f"Found culprit from run error: {node}") - self.print_report(report) - if not self.settings.find_all: - return culprits - - return culprits - - def _accumulate_traverse(self, nodes: NodeList) -> NodeSet: - culprits: NodeSet = set() - nodes_to_run: NodeSet = set() - - # find_all is not supported for accumulate traversal because all the - # ops run on NNPI. So we return after the first op that raises error. - if self.settings.find_all: - print("'Find All' mode is not supported in accumulate traversal.") - return culprits - - for node in nodes: - report: List[str] = [] - self.reports.append(report) - self.iteration += 1 - report.append(f"Accumulate traverse iteration {self.iteration}.") - - nodes_to_run.add(node) - - node_name = node.name - if node_name is not None and isinstance(node_name, tuple): - node_name = node_name[0] - assert node_name is not None and isinstance( - node_name, str - ), f"minimize: node_name: {node_name}" - - report.append(f"Add node: {node_name}") - - try: - split_module, submod_name = self._build_submodule(nodes_to_run) - self._run_and_compare(split_module, submod_name, [node_name]) - self.print_report(report) - except (FxNetMinimizerResultMismatchError, FxNetMinimizerRunFuncError): - culprits.add(node) - report.append(f"Found culprit {node}") - self.print_report(report) - return culprits - - return culprits - - def _collect_nodes(self, start: Optional[str], end: Optional[str]) -> NodeList: - """ - Collect nodes in the model that between nodes with name of `start` and `end`. - These two nodes are also included. - """ - nodes: NodeList = [] - add_node = start is None - - for node in self.module.graph.nodes: - if node.op not in CALLABLE_NODE_OPS: - continue - - if node.name == start: - add_node = True - - if add_node: - nodes.append(node) - - if node.name == end: - break - - return nodes - - def run_nodes(self, start: Optional[str] = None, end: Optional[str] = None): - """ - Run part of the model from `start` node to `end` node. If `start` is None - then we start from the beginning of the model. If `end` is None then we - stop at the end of the model. - - Args: - start: The name of the node which is the first node of the submodule - we want to run. If set to None, then we'll start with the first - node of the model. - end: The name of the node which is the last node of the submodule we - want to run. If set to None, we'll end with the last node of the - model. - """ - nodes = self._collect_nodes(start, end) - cur_nodes = set(nodes) - - for node in nodes: - if node in self.fusions: - cur_nodes.update(self.fusions[node]) - - output_names = [] - if self.settings.return_intermediate: - output_names = [node.name for node in nodes] - - try: - split_module, submod_name = self._build_submodule(cur_nodes) - self._run_and_compare(split_module, submod_name, output_names) - except ( - FxNetMinimizerRunFuncError, - FxNetMinimizerResultMismatchError, - ) as e: - print(e) - - def print_report(self, report: List[str]): - for i in range(len(report)): - if i > 0: - print(" . " + report[i]) - else: - print(report[i]) - - def print_reports(self): - for report in self.reports: - self.print_report(report) - - def minimize( - self, start: Optional[str] = None, end: Optional[str] = None - ) -> NodeSet: - """ - Minimizing the model from node with name `start` to node with name `end` base - on self.settings. Find culprits that causes FxNetMinimizerRunFuncError or - FxNetMinimizerResultMismatchError errors. - - Args: - start: The name of the node where we want to start minimizing. If set - to None, then we'll start with the first node of the model. - end: The name of the node where we want to terminate minimizing. If - set to None, we'll end with the last node of the model. - - Returns: - nodes: A list of nodes that causes FxNetMinimizerRunFuncError or - FxNetMinimizerResultMismatchError errors during minimizing. - """ - - print(self.settings) - print(self.module.graph) - - nodes = self._collect_nodes(start, end) - - if self.settings.traverse_method == "sequential": - return self._sequential_traverse(nodes) - - if self.settings.traverse_method == "binary": - return self._binary_traverse(nodes) - - if self.settings.traverse_method == "accumulate": - return self._accumulate_traverse(nodes) - - raise RuntimeError(f"Unknow traverse method {self.settings.traverse_method}!") diff --git a/pippy/fx/passes/operator_support.py b/pippy/fx/passes/operator_support.py deleted file mode 100644 index 62aa708a7..000000000 --- a/pippy/fx/passes/operator_support.py +++ /dev/null @@ -1,207 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import abc -import typing as t - -import torch -import pippy.fx -from pippy.fx._compatibility import compatibility -from .shape_prop import TensorMetadata -from .tools_common import get_node_target, CALLABLE_NODE_OPS - - -__all__ = ['OperatorSupportBase', 'OperatorSupport', 'create_op_support', 'chain', 'OpSupports'] - -# fx.Node.target typename, as returned by `get_node_target()` -TargetTypeName = str - -# Arguments' dtypes for a given node, see `OperatorSupport` -SupportedArgumentDTypes = t.Optional[ - t.Tuple[ - t.Sequence[t.Sequence[torch.dtype]], - t.Dict[str, t.Sequence[torch.dtype]], - ] -] - -SupportDict = t.Mapping[TargetTypeName, SupportedArgumentDTypes] - - -@compatibility(is_backward_compatible=False) -class OperatorSupportBase(abc.ABC): - """Interface for determining if a fx.Node is supported by a backend""" - @abc.abstractmethod - def is_node_supported( - self, submodules: t.Mapping[str, torch.nn.Module], node: pippy.fx.Node - ) -> bool: - raise NotImplementedError() - - -@compatibility(is_backward_compatible=False) -class OperatorSupport(OperatorSupportBase): - """ - `_support_dict` maps node.target typename to supported inputs dtypes. - - node.target typename is retrieved using helper function `get_node_target()` - - If supported inputs dtypes is None, it means any dtype is supported, else - we should see a tuple like (([dtypes], ...), {"name":[dtypes], ...}). - - The first tuple ([dtypes], ...) indicates what dtypes are supported for - inputs in node.args and the second dict {"name": [dtypes], ...} indicates - what dtypes are supported for inputs in node.kwargs. - - For inputs in args, if we don't want to check it, we can put None there, - e.g. (None, [torch.float]) indicates that we don't care about the type of - the first input in args. And for inputs in kwargs, if not listed, will not - be checked. - """ - - _support_dict: SupportDict - - def __init__( - self, - support_dict: t.Optional[SupportDict] = None - ): - self._support_dict = support_dict or {} - - def is_node_supported( - self, submodules: t.Mapping[str, torch.nn.Module], node: pippy.fx.Node - ) -> bool: - """ - Args: - `sumodules`: mapping from module name to the module. This can be - retrieved by calling model.named_modules(). - - `node`: a Fx node that we want to determine whether it's supported. - - Returns: - `is_supported`: whether the arg `node` is supported. - """ - if node.op not in CALLABLE_NODE_OPS: - return True - - target = get_node_target(submodules, node) - - # Target not found in _support_dict meaning that we don't support this op at all - if target not in self._support_dict: - return False - - # The rule for target is None meaning that we accept any dtype - if self._support_dict[target] is None: - return True - - args_dtypes, kwargs_dtypes = self._support_dict[target] # type: ignore[misc] - - # Check args dtypes - for i, dtypes in enumerate(args_dtypes): - if len(node.args) <= i: - break - - # None indicates we don't care about the dtype of args[i] - if dtypes is None: - continue - - # If arg is not a node then we don't check it - if not isinstance(node.args[i], pippy.fx.Node): - continue - - arg_dtype = _get_arg_dtype(node.args[i]) # type: ignore[arg-type] - if arg_dtype not in dtypes: - return False - - # Check kwargs dtypes - for k, dtypes in kwargs_dtypes.items(): - if k not in node.kwargs: - continue - - # If arg is not a node then we don't check it - if not isinstance(node.kwargs[k], pippy.fx.Node): - continue - - kwarg_dtype = _get_arg_dtype(node.kwargs[k]) # type: ignore[arg-type] - if kwarg_dtype not in dtypes: - return False - - return True - - -# ====================================================================== -# Functional interfaces and utils for defining basic operator support logic -# and composing them into more complex ones -# ====================================================================== - -IsNodeSupported = t.Callable[[t.Mapping[str, torch.nn.Module], pippy.fx.Node], bool] - - -@compatibility(is_backward_compatible=False) -def create_op_support(is_node_supported: IsNodeSupported) -> OperatorSupportBase: - """Wraps a `IsNodeSupported` function into an `OperatorSupportBase` instance - - `IsNodeSupported` has the same call signature as - `OperatorSupportBase.is_node_supported` - """ - class FunctionalOperatorSupport(OperatorSupportBase): - def is_node_supported( - self, submodules: t.Mapping[str, torch.nn.Module], node: pippy.fx.Node - ) -> bool: - return is_node_supported(submodules, node) - return FunctionalOperatorSupport() - - -@compatibility(is_backward_compatible=False) -def chain(*op_support: OperatorSupportBase) -> OperatorSupportBase: - """Combines a sequence of `OperatorSupportBase` instances to form a single `OperatorSupportBase` - instance by evaluating each input `OperatorSupportBase` instance, and returns False if - any of it reports False. - """ - def _chain(submods, node) -> bool: - return all( - x.is_node_supported(submods, node) - for x in op_support - ) - return create_op_support(_chain) - - -@compatibility(is_backward_compatible=False) -class OpSupports: - """A set of atomic `OperatorSupportBase` instances that can be combined together - to form more complex operator support logic. - """ - @classmethod - def decline_if_input_dtype(cls, dtype: torch.dtype) -> OperatorSupportBase: - """Report a node as non-supported, if any of its arguments is of dtype""" - - def _decline_if_input_dtype( - submodules: t.Mapping[str, torch.nn.Module], - node: pippy.fx.Node, - ) -> bool: - for arg in node.all_input_nodes: - # escape dtype check for get_attr node - if arg.op == "get_attr": - continue - arg_dtype = _get_arg_dtype(arg) - if arg_dtype == dtype: - return False - return True - return create_op_support(_decline_if_input_dtype) - - @classmethod - def decline_if_node_in_names(cls, disallow_set: t.Set[str]) -> OperatorSupportBase: - """ - If a node has a name that is in the disallow set, reported it as non-supported. - """ - def _decline_if_node_in_names( - submodules: t.Mapping[str, torch.nn.Module], - node: pippy.fx.Node, - ) -> bool: - if node.name in disallow_set: - return False - else: - return True - return create_op_support(_decline_if_node_in_names) - - -def _get_arg_dtype(arg: pippy.fx.Node) -> t.Any: - assert isinstance(arg, pippy.fx.Node) - tensor_meta = arg.meta.get("tensor_meta") # type: ignore[union-attr] - dtype = tensor_meta.dtype if isinstance(tensor_meta, TensorMetadata) else arg.meta["type"] - return dtype diff --git a/pippy/fx/passes/param_fetch.py b/pippy/fx/passes/param_fetch.py deleted file mode 100644 index 411134f7b..000000000 --- a/pippy/fx/passes/param_fetch.py +++ /dev/null @@ -1,67 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from pippy.fx.graph_module import GraphModule -from typing import Any, Callable, Dict, List, Tuple, Type -import torch -import torch.nn as nn - -from pippy.fx._compatibility import compatibility - -__all__ = ['default_matching', 'extract_attrs_for_lowering', 'lift_lowering_attrs_to_nodes'] - -# Matching method matches the attribute name of current version to the attribute name of `target_version` -@compatibility(is_backward_compatible=False) -def default_matching(name: str, target_version: int) -> str: - """Default matching method - """ - return name - -# This dict maps the nn.Module class name to the attribute name list that we want to fetch for lowering. -# The first integer in the tuple is the version number of the nn.Module class when we create the parameter list. -# If there's a version mismatch then it means the parameter names in the book might be mismatched with nn.Module. -module_fetch_book: Dict[Type, Tuple[int, List[str], Callable[[str, int], str]]] = { - torch.nn.modules.linear.Linear: (1, ["weight", "bias"], default_matching), - torch.nn.modules.conv.Conv2d: ( - 1, ["weight", "bias", "kernel_size", "stride", "padding", "dilation", "groups", "padding_mode"], default_matching - ), - torch.nn.modules.batchnorm.BatchNorm2d: (2, ["weight", "bias", "running_mean", "running_var", "eps"], default_matching), - torch.nn.modules.pooling.AdaptiveAvgPool2d: (1, [], default_matching), - torch.nn.modules.pooling.MaxPool2d: ( - 1, ["kernel_size", "stride", "padding", "dilation", "return_indices", "ceil_mode"], default_matching - ), - torch.nn.modules.activation.ReLU: (1, ["inplace"], default_matching), -} - -@compatibility(is_backward_compatible=False) -def extract_attrs_for_lowering(mod: nn.Module) -> Dict[str, Any]: - """If `mod` is in `module_fetch_book`, fetch the mod's attributes that in the `module_fetch_book` - after checking module's version is compatible with the `module_fetch_book`. - """ - attrs_for_lowering: Dict[str, Any] = {} - attrs_for_lowering["name"] = torch.typename(mod) - - if type(mod) in module_fetch_book: - version, param_to_fetch, matching_method = module_fetch_book[type(mod)] - if version < mod._version: - raise RuntimeError(f"Fetcher version {version} try to fetch {torch.typename(mod)} version {mod._version}, " - "please upgrade the module_fetch_book, open an issue and @842974287 " - "or report a bug to AIACC team directly.") - for attr in param_to_fetch: - attrs_for_lowering[attr] = getattr(mod, matching_method(attr, mod._version)) - else: - raise RuntimeError(f"{torch.typename(mod)} is not in the module_fetch_book yet, " - "please add it to the module_fetch_book, open an issue and @842974287 " - "or report a bug to AIACC team directly.") - return attrs_for_lowering - -@compatibility(is_backward_compatible=False) -def lift_lowering_attrs_to_nodes(fx_module: GraphModule) -> None: - """Recursively traverse all `fx_module` nodes and fetch the module's attributes if the node is a leaf module. - """ - submodules = dict(fx_module.named_modules()) - - for node in fx_module.graph.nodes: - if node.op == "call_module": - if isinstance(submodules[node.target], GraphModule): - lift_lowering_attrs_to_nodes(submodules[node.target]) - else: - node.attrs_for_lowering = extract_attrs_for_lowering(submodules[node.target]) diff --git a/pippy/fx/passes/pass_manager.py b/pippy/fx/passes/pass_manager.py deleted file mode 100644 index 3bdde31b5..000000000 --- a/pippy/fx/passes/pass_manager.py +++ /dev/null @@ -1,242 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from functools import wraps -from inspect import unwrap -from typing import Callable, List -import logging - -logger = logging.getLogger(__name__) - - -# for callables which modify object inplace and return something other than -# the object on which they act -def inplace_wrapper(fn: Callable) -> Callable: - """ - Convenience wrapper for passes which modify an object inplace. This - wrapper makes them return the modified object instead. - - Args: - fn (Callable[Object, Any]) - - Returns: - wrapped_fn (Callable[Object, Object]) - """ - - @wraps(fn) - def wrapped_fn(gm): - val = fn(gm) - return gm - - return wrapped_fn - -def log_hook(fn: Callable, level=logging.INFO) -> Callable: - """ - Logs callable output. - - This is useful for logging output of passes. Note inplace_wrapper replaces - the pass output with the modified object. If we want to log the original - output, apply this wrapper before inplace_wrapper. - - - ``` - def my_pass(d: Dict) -> bool: - changed = False - if 'foo' in d: - d['foo'] = 'bar' - changed = True - return changed - - pm = PassManager( - passes=[ - inplace_wrapper(log_hook(my_pass)) - ] - ) - ``` - - Args: - fn (Callable[Type1, Type2]) - level: logging level (e.g. logging.INFO) - - Returns: - wrapped_fn (Callable[Type1, Type2]) - """ - @wraps(fn) - def wrapped_fn(gm): - val = fn(gm) - logger.log(level, f"Ran pass {fn}\t Return value: {val}",) - return val - - return wrapped_fn - - - -def loop_pass(base_pass: Callable, n_iter: int = None, predicate: Callable = None): - """ - Convenience wrapper for passes which need to be applied multiple times. - - Exactly one of `n_iter`or `predicate` must be specified. - - Args: - base_pass (Callable[Object, Object]): pass to be applied in loop - n_iter (int, optional): number of times to loop pass - predicate (Callable[Object, bool], optional): - - """ - assert (n_iter is not None) ^ ( - predicate is not None - ), "Exactly one of `n_iter`or `predicate` must be specified." - - @wraps(base_pass) - def new_pass(source): - output = source - if n_iter is not None and n_iter > 0: - for _ in range(n_iter): - output = base_pass(output) - elif predicate is not None: - while predicate(output): - output = base_pass(output) - else: - raise RuntimeError( - f"loop_pass must be given positive int n_iter (given " - f"{n_iter}) xor predicate (given {predicate})" - ) - return output - - return new_pass - - -# Pass Schedule Constraints: -# -# Implemented as 'depends on' operators. A constraint is satisfied iff a list -# has a valid partial ordering according to this comparison operator. -def _validate_pass_schedule_constraint( - constraint: Callable[[Callable, Callable], bool], passes: List[Callable] -): - for i, a in enumerate(passes): - for j, b in enumerate(passes[i + 1 :]): - if constraint(a, b): - continue - raise RuntimeError( - f"pass schedule constraint violated. Expected {a} before {b}" - f" but found {a} at index {i} and {b} at index{j} in pass" - f" list." - ) - - -def this_before_that_pass_constraint(this: Callable, that: Callable): - """ - Defines a partial order ('depends on' function) where `this` must occur - before `that`. - """ - - def depends_on(a: Callable, b: Callable): - if a == that and b == this: - return False - return True - - return depends_on - - -def these_before_those_pass_constraint(these: Callable, those: Callable): - """ - Defines a partial order ('depends on' function) where `these` must occur - before `those`. Where the inputs are 'unwrapped' before comparison. - - For example, the following pass list and constraint list would be invalid. - ``` - passes = [ - loop_pass(pass_b, 3), - loop_pass(pass_a, 5), - ] - - constraints = [ - these_before_those_pass_constraint(pass_a, pass_b) - ] - ``` - - Args: - these (Callable): pass which should occur first - those (Callable): pass which should occur later - - Returns: - depends_on (Callable[[Object, Object], bool] - """ - - def depends_on(a: Callable, b: Callable): - if unwrap(a) == those and unwrap(b) == these: - return False - return True - - return depends_on - - -class PassManager: - """ - Construct a PassManager. - - Collects passes and constraints. This defines the pass schedule, manages - pass constraints and pass execution. - - Args: - passes (Optional[List[Callable]]): list of passes. A pass is a - callable which modifies an object and returns modified object - constraint (Optional[List[Callable]]): list of constraints. A - constraint is a callable which takes two passes (A, B) and returns - True if A depends on B and False otherwise. See implementation of - `this_before_that_pass_constraint` for example. - """ - - passes: List[Callable] = [] - constraints: List[Callable] = [] - _validated: bool = False - - def __init__( - self, - passes=None, - constraints=None, - ): - if passes: - self.passes = passes - if constraints: - self.constraints = constraints - - @classmethod - def build_from_passlist(cls, passes): - pm = PassManager(passes) - # TODO(alexbeloi): add constraint management/validation - return pm - - def add_pass(self, _pass: Callable): - self.passes.append(_pass) - self._validated = False - - def add_constraint(self, constraint): - self.constraints.append(constraint) - self._validated = False - - def remove_pass(self, _passes: List[Callable]): - if _passes is None: - return - passes_left = [] - for ps in self.passes: - if ps.__name__ not in _passes: - passes_left.append(ps) - self.passes = passes_left - self._validated = False - - def validate(self): - """ - Validates that current pass schedule defined by `self.passes` is valid - according to all constraints in `self.constraints` - """ - if self._validated: - return - for constraint in self.constraints: - _validate_pass_schedule_constraint(constraint, self.passes) - self._validated = True - - def __call__(self, source): - self.validate() - out = source - for _pass in self.passes: - out = _pass(out) - return out diff --git a/pippy/fx/passes/reinplace.py b/pippy/fx/passes/reinplace.py deleted file mode 100644 index 94419bbc9..000000000 --- a/pippy/fx/passes/reinplace.py +++ /dev/null @@ -1,663 +0,0 @@ -import torch -import pippy -from pippy.fx import Node -from pippy.fx._compatibility import compatibility -from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensor -from torch.utils._pytree import tree_map, tree_flatten, tree_map_only -from torch.multiprocessing.reductions import StorageWeakRef - -import _operator -from enum import Enum -import itertools -from typing import Set, Dict -from collections import defaultdict - -__all__ = ['reinplace'] - -class _ViewType(Enum): - NonView = 0 - SingleOutputView = 1 - MultiOutputView = 2 - -def _is_view_op(tgt): - if tgt is not None and isinstance(tgt, torch._ops.OpOverload): - schema = tgt._schema - if len(schema.arguments) > 0: - first_arg = schema.arguments[0] - # check if op is a view - return first_arg.alias_info is not None and not first_arg.alias_info.is_write - -def _get_view_type(tgt) -> _ViewType: - if tgt is not None and isinstance(tgt, torch._ops.OpOverload): - schema = tgt._schema - if len(schema.arguments) > 0: - first_arg = schema.arguments[0] - # check if op is a view - if first_arg.alias_info is not None and not first_arg.alias_info.is_write: - # check if op is a multi-output view - if '*' in first_arg.alias_info.after_set: - return _ViewType.MultiOutputView - else: - return _ViewType.SingleOutputView - return _ViewType.NonView - - -# Stores a bunch of metadata related to functionalization each node. -# Relevant metadata: -# n.meta['fake_result']: FakeTensor (same type as the output of the node, but with FakeTenors instead of Tensors) -# The fake tensor output from running the current node -# n.meta['view_of']: Node -# If the current node n is a view of some base tensor, the 'view_of' field tells us which -# view node was used to generate the current node (a view tensor). -# This information actually makes `fake_result` redundant, but we can use `fake_result` -# to sanity check that our aliasing information is correct. -@compatibility(is_backward_compatible=False) -class _FunctionalizationMetadataProp(pippy.fx.Interpreter): - - def run_node(self, node: Node): - self.node_counter += 1 - result = super().run_node(node) - node.meta['fake_result'] = result - node.meta['node_idx'] = self.node_counter - - # (1) Update metadata with the list of nodes that are used by this node - # copy_() doesn't read from its first argument; it writes to it, overwriting previous data. - # We don't want to treat it as "being used as an input". - node_args = node.args - if node.target is torch.ops.aten.copy_.default: - node_args = node_args[1:] - - # (2) Update metadata to track aliasing information about view tensor nodes. - if node.op == 'call_function': - view_type = _get_view_type(node.target) - if view_type == _ViewType.SingleOutputView: - assert isinstance(node.args[0], Node) - node.meta['view_of'] = node.args[0] - elif view_type == _ViewType.MultiOutputView: - self.multi_output_view_nodes[node] = node.args[0] - - # Check if we returned a multi-output view, - # and we're now grabbing the individual views from the output. - # - # For multi-output views, we want to map each output view to the base, - # but this mapping involves two separate nodes in FX IR. - # e.g. "a, b = x_1.split(...)" becomes: - # %split_tensor : [#users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%x_1, 2), kwargs = {}) - # %getitem : [#users=1] = call_function[target=operator.getitem](args = (%split_tensor, 0), kwargs = {}) - # %getitem_1 : [#users=1] = call_function[target=operator.getitem](args = (%split_tensor, 1), kwargs = {}) - # And we'd like to set: - # getitem1.meta['view_of'] = x_1 - elif node.target is _operator.getitem: - list_arg = node.args[0] - maybe_base_of_view = self.multi_output_view_nodes.get(list_arg, None) - if maybe_base_of_view is not None: - # Note: we could also track indexing info here for multi-output views. - # I don't think this metadata is strictly needed for de-functionalization. - assert isinstance(maybe_base_of_view, Node) - node.meta['view_of'] = maybe_base_of_view - - if 'view_of' in node.meta: - # We're linking the current node with its first argument as views. - # Assert here that this is actually the case, and their storages are the same. - assert isinstance(node.meta['fake_result'], FakeTensor) - assert isinstance(node.meta['view_of'].meta['fake_result'], FakeTensor) - view_storage = StorageWeakRef(node.meta['fake_result'].storage()) - base_storage = StorageWeakRef(node.meta['view_of'].meta['fake_result'].storage()) - assert view_storage == base_storage - return result - - - - def propagate(self, *args): - self.multi_output_view_nodes = {} - self.node_counter = -1 - - with FakeTensorMode(allow_meta=True) as mode: - fake_args = [mode.from_tensor(a) for a in args] - return super().run(*fake_args) - -def _schemas_match(functional_schema, inplace_schema): - names_match = inplace_schema.name.endswith("_") and inplace_schema.name[:-1] == functional_schema.name - arg_types_match = len(functional_schema.arguments) == len(inplace_schema.arguments) and all( - a1.type == a2.type for a1, a2 in zip(functional_schema.arguments, inplace_schema.arguments)) - # for the inplace op, its first argument should be mutable - assert inplace_schema.arguments[0].alias_info is not None and inplace_schema.arguments[0].alias_info.is_write - # and its remaining arguments shouldn't be. - assert all(a.alias_info is None for a in inplace_schema.arguments[1:]) - return names_match and arg_types_match - -# TODO: this should be beefed up to be able to properly re-inplace with: -# - mutating ops (e.g. _fused_moving_avg_obs_fq_helper) -# - out= ops (e.g. angle -> angle.out) -# TODO: we should also figure this info out using torchgen. -def _maybe_get_inplace_op(op): - # __module__ seems broken; it returns torch._ops.aten which doesn't exist - if not isinstance(op, torch._ops.OpOverload): - return None - # Some view ops have inplace variants (as_strided_, etc), - # but we do NOT want the reinplacing pass to directly add these into the program. - # (they'll require extra special handling, aren't aren't really useful for perf anyway) - if _is_view_op(op): - return None - op_namespace = op.__module__.split(".")[-1] - op_base_name = op.overloadpacket.__name__ - maybe_namespace_module = getattr(torch.ops, op_namespace) - maybe_inplace_op = None if maybe_namespace_module is None else getattr(maybe_namespace_module, f'{op_base_name}_', None) - if maybe_inplace_op is None: - return None - - inplace_overloads = [ - getattr(maybe_inplace_op, overload_name) for overload_name in maybe_inplace_op.overloads() - ] - inplace_overloads_with_matching_schemas = [ - f - for f in inplace_overloads - if _schemas_match(op._schema, f._schema) - ] - # Just becuase foo() and foo_() are both existing operators, - # They aren't guaranteed to have compatible schemas. - # For example, pow.Scalar(Scalar self, Tensor exponent) has no valid inplace variant, - # Even though several overloads of pow_ exist. - if len(inplace_overloads_with_matching_schemas) == 0: - return None - assert len(inplace_overloads_with_matching_schemas) == 1 - inplace_op = inplace_overloads_with_matching_schemas[0] - return inplace_op - -_VIEW_INVERSE_MAP = { - torch.ops.aten.diagonal_scatter.default: torch.ops.aten.diagonal.default, - torch.ops.aten.select_scatter.default: torch.ops.aten.select.int, - torch.ops.aten.slice_scatter.default: torch.ops.aten.slice.Tensor, - torch.ops.aten.as_strided_scatter.default: torch.ops.aten.as_strided.default, -} - -# This function, given a set of set of (aliased) tensor nodes, -# Returns any nodes in the graph that *use* any of the aliases, that occur *after* op_index -# in the node ordering. -def _get_all_later_node_usages(tensor_aliases: Set[Node], op_index: int): - def _add_if_tensor(x, set_): - if isinstance(x, FakeTensor): - set_.add(StorageWeakRef(x.storage())) - - nodes_used_after = set() - for t in tensor_aliases: - # get all nodes that use the current alias - usage_nodes = t.users - for n in usage_nodes: - # We only care about usages after the current node - if 'node_idx' not in n.meta or n.meta['node_idx'] <= op_index: - continue - # We also don't care about intermediate view ops. - # They only matter if their output is then used elsewhere - # (either in an out-of-place op, or as an output to the function). - if n in tensor_aliases: - if isinstance(n.target, torch._ops.OpOverload) or n.target == _operator.getitem: - continue - nodes_used_after.add(n) - return nodes_used_after - -# Given an op that we're trying to re-inplace, "b = foo(a)", -# And given a {view}_scatter op that shows up later in the graph, "y = {view}_scatter(base, x, args...)" -# Then re-inplacing `foo()` would allow us to remove the `{view}_scatter` op entirely, IF: -# If there are any aliases in the alias_set(a) that satisfy: -# (1) The base of "alias", "alias_base", has the same size/stride/offset metadata as "base" -# (2) The output of running {view}(alias, args...) gives you the same size/stride/offset metadata -# as "alias" -def _get_view_inverse_node_usages(later_node_usages: Set[Node], self_aliases: Set[Node]) -> Set[Node]: - def matching_view_metadata(a, b): - return a.size() == b.size() and \ - a.stride() == b.stride() and \ - a.storage_offset() == b.storage_offset() - - view_inverse_nodes = set() - # Go through them in node order, so we can see chains of view_scatter ops. - for n in sorted(later_node_usages, key=lambda x: x.meta['node_idx']): - if n.target not in _VIEW_INVERSE_MAP: - continue - base = n.args[0] - mutated_view = n.args[1] - assert isinstance(base, Node) - assert isinstance(base.meta['fake_result'], FakeTensor) - assert isinstance(mutated_view, Node) - assert isinstance(mutated_view.meta['fake_result'], FakeTensor) - # Check that this view_inverse op actually corresponds to taking doing the inverse - # of one of our existing self_alias nodes. - original_view = _VIEW_INVERSE_MAP[n.target] - for self_alias in self_aliases: - # We're looking for some alias of the self arg, "alias", - # that was created from some op `alias = foo(base, args...)` - # such that the current _scatter op "inverts" that foo call. - # We can check that by running the original op again, and checking that the strides match. - if 'view_of' not in self_alias.meta: - continue - self_alias_base = self_alias.meta['view_of'] - try: - # The we're trying to re-use the args from the view_scatter call inside of the corresponding - # view op, which might throw. This just indicates that view_scatter op isn't a valid inverse - # of the current alias we're looking at. - view_replay_metadata = original_view(self_alias_base.meta['fake_result'], *n.args[2:], **n.kwargs) - expected_metadata = self_alias.meta['fake_result'] - # If the alias and its base both have matching metadata, then this view_scatter op is valid to re-inplace. - if matching_view_metadata(self_alias_base.meta['fake_result'], base.meta['fake_result']) and \ - matching_view_metadata(view_replay_metadata, expected_metadata): - view_inverse_nodes.add(n) - except Exception: - continue - - return view_inverse_nodes - - -@compatibility(is_backward_compatible=True) -def reinplace(gm, *sample_args): - """ - Given an fx.GraphModule, modifies it to perform "reinplacing", - mutating the nodes of the graph. - We look for out-of-place op call sites like `b = a.add(...)`, - and convert them to be inplace (`b = a.add_(...)`), - as long as the input to the current operator ("a") isn't re-used - anywhere later in the graph. - - This pass currently expects to operate on a **functional, ATen** graph. - This can be obtained by running `make_fx(functionalize(f))`. - - Sample inputs are needed to determine aliasing relationships of the inputs. - In general, we can't reinplace node `b = a.add(...)` if "a" aliases any of the - inputs to the program. - - Given a node "b = foo(a, args...) the algorithm for re-inplacing is as follows: - - (1) Perform some initial checks on the metadata of "a" and "args..." - that can disqualify them from being reinplaced. - - (1a) Check that the self argument we're attempting to reinplace - has acceptable dtype/size metadata to reinplace with. - - For example, if we have: - a = torch.ones(1) - b = torch.ones(10) - out = torch.add(a, b) - We can't turn that into - a.add_(b) - Because that would require resizing "a". - - Similarly, we can't convert torch.ge(a, b) into a.ge_(b), - beause that would require changing a's dtype (from e.g. float32 to bool). - Note that in this specific example, we could technically do better.. - - If we see the pattern: - a_1 = a.ge(b) - a_2 = aten._to_copy(a_1, a.dtype) - Then we this should be valid to completely re-inplace - (this is exactly what functionalization will emit when it sees a.ge_(b)). - - This optimization is only really important for user programs - that directly use inplace comparison ops though. - - We also cannot re-inplace on tensors that have overlapping memory, - e.g. torch.ones(1).expand(4, 4).add_(1) - - (1b) Check if "a" is an alias of any of the program inputs. - - If it is, skip and move to the next node. - Inplace'ing an op that would cause it to mutate a program is not sound, - because that would be a side effect visible to the user. - - NOTE: there's a future optimization that we should make: - if "a" is a (alias of a) program input, but later in the program - there is a node that looks like "a.copy_(...)", - Then re-inplacing is ok to do - we are temporarily re-using a's buffer, - which will later be overwritten by the copy_() call. - - This will be an important optimization to have for programs that mutate - their inputs. It currently isn't implemented though. - - (1c) Check if "a" and "args..." alias - - For example, re-inplacing to create code like the below - isn't guaranteed to be sound: - - aten.mul_(a, a) - - (2) Check that "a" and all of its outstanding aliases are not used anywhere - later in the graph. If this is the case, then it's safe to re-inplace - to "b = foo_(a)". - - There are a few caveats to this, explained in more detail below: - (a) If "a" is used later as an argument to a view op, that is okay. - It's only a problem if "a" (or that view) is later passed - into a normal operator, or if it is returned as the program output. - (b) If "a" is a repeat argument in `foo()`, then don't reinplace. - Most ATen kernels don't make any guarantees that this is sound, - e.g. if you do aten.mul_(a, a). - So we'll just ban re-inplacing in this case. - It's only a problem if "a" (or that view) is later passed - (c) If "a" is used as an input into a view "inverse" / "scatter" - operator, it is potentially fine to re-inplace - (and remove that scatter operator from the graph). - See below for a more detailed example. - - NOTE: there is an optimization in this step that is crucial - to fully recovering performance from functionalization. - - Given this program: - def f(x): - a = torch.ops.aten.add(x, x) - b = torch.ops.aten.diagonal(a) - torch.ops.aten.fill_(b, 0) - return d - - Functionalization will emit the following: - def f(x): - a = torch.ops.aten.add(x, x) - b = torch.ops.aten.diagonal(a, 0, 1) - b_updated = torch.ops.aten.fill(b, 0) - a_updated = torch.ops.aten.diagonal_scatter(a, b_updated, 0, 1) - return a_updated - - Ordinarily, we would not be able to reinplace the fill, - because "b" aliases with "a" which is used by the diagonal_scatter call. - - "re-inplacing" is on the hook for figuring out that it is ok to - completely, the expensive diagonal_scatter call, if we re-inplace the add(). - - So, for every `alias in alias_set(a)`, instead of checking - that "alias" is not used anywhere later in the graph, - we check that - EITHER: - (a) alias is not used anywhere later in the graph - OR: - (b) alias is used exactly once later on in the graph, - in the following op: - - out = foo_scatter(alias, x, args...) - - where the following must hold: - (i) "foo_scatter" is the "inverse" operator for foo. - This only applies to "foo" ops that are view operators, - which view into a subset of the original tensor's memory. - In practice, there are ~4 operators where this applies: - diagonal -> diagonal_scatter - slice -> slice_scatter - select -> select_scatter - as_strided -> as_strided_scatter - (ii) "args..." are the same between the foo() and foo_scatter() calls. - - (3) Perform the actual re-inplacing on foo! - - (3b) is the common case, but special care is needed for {view}_scatter (3a) - - (3a) {view}_scatter ops. - - Consider this program: - a = torch.zeros(2, 2) - b = torch.ones(2) - a[0] = b - - Post functionalization, that will look like: - a = torch.zeros(2) - b = torch.ones(1) - a_updated = torch.select_scatter(a, b, 0, 0) - - In this case though, there is no "functional" op to re-inplace! - Instead, we'd like to directly remove toe select_scatter call. - We already know from (3) that this is valid, - because "a" has no later usages in the graph. - - We perform the re-inplacing on the {view}_scatter op like so - Before: - a_updated = torch.select_scatter(a, b, args...) - After: - a_slice = a.select(a, args...) - a_slice.copy_(b) - - (3b) Otherwise, replace the functional op with its inplace variant. - Before: - b = foo(a, args...) - After: - a.foo_(args...) - - (4) Finally, after converting either: - Before: - b = foo(a) - After: - foo_(a) - or - Before: - b = {slice}_scatter(a, mutated_slice, args...) - After: - slice = {slice}(a, args...) - slice.copy_(mutated_slice) - - We now need to find all later nodes that use "b" as an argument - and update them to take in "a" instead. - - Note that for the majority of inplace ops, this isn't actually necessary - (because most inplace ops return "self" as their output). - This isn't generally true for all mutable ops though, which is why - we need to actually replace all of the arguments. - - We also need to update our metadata of Dict[StorageWeakRef, Set[Node]], - That maps a given tensor storage to the set of all nodes that take in that storage - as an input. - Specifically, re-inplacing `b = foo(a)` causes "a" and "b"'s sets to get fused - together. - - (5) Any "view_inverse/scatter" nodes that were identified as "it's ok to ignore them" - during step (3) get manually deleted from the graph. - Their outputs are no longer used, so technically standard DCE would be able - to do this, but we can no longer run FX's DCE pass now that we have mutable - ops in the graph. - """ - _FunctionalizationMetadataProp(gm).propagate(*sample_args) - - # Useful debug printing - # def _print(x): - # if isinstance(x, FakeTensor): - # print(f'fake_result: {StorageWeakRef(x.storage()).cdata}') - - # for n in gm.graph.nodes: - # print(n.format_node()) - # if hasattr(n, 'meta'): - # print(f'node_idx: {n.meta["node_idx"]}') - # if 'fake_result' in n.meta: - # tree_map(_print, n.meta['fake_result']) - # if 'view_of' in n.meta: - # print(f'view_of: {str(n.meta["view_of"])}') - # print() - - # We need to know which nodes correspond to inputs (or their aliases) - # so we know not to re-inplace them. - # NOTE: later, we'll need to add an optimization for fully recovering performance - # on programs that mutate inputs. - input_storages = set(StorageWeakRef(node.meta['fake_result'].storage()) for node in gm.graph.nodes if node.op == 'placeholder') - - - # We also need to know for a given node, what are all of its aliasing nodes. - storage_to_nodes: Dict[StorageWeakRef, Set[Node]] = defaultdict(set) - for n in gm.graph.nodes: - if 'fake_result' in n.meta: - # Tree-mapping because some ops can return lists of tensors. - def _add_to_map(x): - if isinstance(x, FakeTensor): - storage_to_nodes[StorageWeakRef(x.storage())].add(n) - tree_map(_add_to_map, n.meta['fake_result']) - - # inplace-ify functional ops, subject to the constraints written below. - all_later_view_inverse_nodes_to_delete = set() - for idx, node in enumerate(gm.graph.nodes): - if node.op == 'call_function': - - # Today, the re-inplace pass on directly acts on: - # - functional ops with an inplace variant - # - {view}_scatter ops that can be potentially removed from the graph. - # Both of these ops take in tensor first args, so filtering on this condition - # makes the later code simpler. - # We should revisit this at some point though, particularly when we also want - # the reinplacer to be able to handle out= and mutable operators - # and tensorlist first args (like `_foreach_` ops). - if not isinstance(node.target, torch._ops.OpOverload): - continue - if len(node.target._schema.arguments) < 1: - continue - if type(node.target._schema.arguments[0].type) != torch.TensorType: - continue - - # Step 1a: Check that the self argument we're attempting to reinplace - # has the same size/stride as the output. - # For example, we shouldn't try to reinplace torch.add(scalar_tensor, larger_tensor) - # As it would require resizing scalar_tensor. - # (We could potentially swizzle this into larger_tensor.add_(scalar_tensor), - # this is probably an optimization to revisit later). - self_arg = node.args[0] - self_flattened, _ = tree_flatten(self_arg.meta['fake_result']) - node_flattened, _ = tree_flatten(node.meta['fake_result']) - self_has_wrong_metadata = False - if len(self_flattened) == len(node_flattened): - for self_meta, node_meta in zip(self_flattened, node_flattened): - if self_meta.numel() != node_meta.numel(): - self_has_wrong_metadata = True - if self_meta.dtype != node_meta.dtype: - self_has_wrong_metadata = True - # We also cannot re-inplace on tensors that have internal memory overlap. - # e.g. torch.ones(1).expand(4, 4).add_(1) - if torch._debug_has_internal_overlap(self_meta) == 1: - self_has_wrong_metadata = True - # Here, we (optimistically) assume that a.resize(b) is valid to re-inplace, - # Since users should never really be calling the functional "torch.ops.aten.resize" - # op directly in their programs. - if self_has_wrong_metadata and node.target != torch.ops.aten.resize.default: - continue - - # Step 1b: ensure that the op we're trying to re-inplace isn't a program input - self_arg_name = self_arg.name - self_arg_storage = StorageWeakRef(self_arg.meta['fake_result'].storage()) - if self_arg_storage in input_storages: - # TODO: later, add the optimization for handling `copy_()` calls in the graph. - continue - if len([x for x in node.args if x is self_arg]) > 1: - # Step 1c: - # Calling stuff like aten.mul_(a, a) isn't guaranteed to be sound, - # so we prevent re-inplacing in this case. - continue - - self_arg_storage = StorageWeakRef(self_arg.meta['fake_result'].storage()) - self_aliases = storage_to_nodes[self_arg_storage] - - # First, we find all later usages of any of the aliases of self_arg. - later_node_usages = _get_all_later_node_usages(self_aliases, node.meta['node_idx']) - # Then, we check if any of those later usages are actually view_scatter ops - # that are safe to fully remove. - later_view_inverse_node_usages = _get_view_inverse_node_usages(later_node_usages, self_aliases) - - # Step 2: Check to see if the input to the op is re-used later in the graph. - # If not (same goes for its aliases), then this op is safe to re-in place. - # This is a slightly roundabout way to check that there are no later usages of the current self argument. - # (later_view_inverse_node_usages corresponds to "view_scatter" nodes that we are allowed to delete) - can_reinplace = len(later_node_usages - later_view_inverse_node_usages) == 0 - if not can_reinplace: - continue - - # Step 3a: Special handling for when we see *_scatter operators. - # When we see an operator like `b = torch.slice_scatter(a, ...)`, - # instead of trying to "inplace" it into a.slice_scatter_(..._), - # we would prefer to remove it from the graph entirely, - # and instead copy_() the slice directly into the larger tensor. - # See the description of the algorithm for a full example. - if node.target in _VIEW_INVERSE_MAP and node not in all_later_view_inverse_nodes_to_delete: - view_op = _VIEW_INVERSE_MAP[node.target] - # Before: - # base_updated = torch.ops.aten.slice_scatter.default(base, mutated_slice, args...) - # After: - # slice = torch.ops.aten.slice.default(base, args...) - # slice.copy_(mutated_slice) - with gm.graph.inserting_before(node): - mutated_slice_node = node.args[1] - remaining_slice_args = node.args[2:] - slice_node = gm.graph.create_node( - 'call_function', view_op, (self_arg,) + tuple(remaining_slice_args), node.kwargs) - copy_node = gm.graph.create_node( - 'call_function', torch.ops.aten.copy_.default, (slice_node, mutated_slice_node,), {}) - # Add the slice_scatter node to our "nodes to delete" list. - all_later_view_inverse_nodes_to_delete.add(node) - - - else: - # Step 3b: Check to see if this operator has an inplace variant. - maybe_inplace_op = _maybe_get_inplace_op(node.target) - if maybe_inplace_op is None: - continue - # And if so, replace it with its inplace variant. - node.target = maybe_inplace_op - - # At this point, 'storage_to_nodes' will be stale. - # Now that we're inplacing `b = foo(a)`, we need to effectively - # union together the dict values for b and a's storage. - # Hmm... morally I think we also want to keep the `fake_result` metadata - # up to date here, but I'm not sure how easy it is to do. - # Maybe it's fine to wait until the end of the pass to update it. - curr_node_storage = StorageWeakRef(node.meta['fake_result'].storage()) - storage_to_nodes[self_arg_storage].update(storage_to_nodes[curr_node_storage]) - storage_to_nodes[curr_node_storage].update(storage_to_nodes[self_arg_storage]) - - # Need to remember the view_scatter view nodes we found so we can remove them alter. - all_later_view_inverse_nodes_to_delete.update(later_view_inverse_node_usages) - - # Step 4: - # Now that we've replaced b = a.foo() with a.foo_(), - # We need to replace any later usages of "b" with "a" - for old in itertools.chain([node], later_view_inverse_node_usages): - new = old.args[0] - nodes_to_update = [n for n in old.users if n.meta['node_idx'] > node.meta['node_idx']] - for node_to_update in nodes_to_update: - new_args = [] - args = node_to_update.args - - def replace_arg(a): - if a == old: - return new - return a - - # First, replace usages of "b" with "a" - node_to_update.args = tree_map_only(Node, replace_arg, node_to_update.args) - node_to_update.kwargs = tree_map_only(Node, replace_arg, node_to_update.kwargs) - - # Second, update our storage_to_nodes data structure. - old_flattened_res, _ = tree_flatten(old.meta['fake_result']) - node_flattened_res, _ = tree_flatten(node_to_update.meta['fake_result']) - - old_res_storage = set(StorageWeakRef(x.storage()) for x in old_flattened_res if isinstance(x, FakeTensor)) - node_res_storage = set(StorageWeakRef(x.storage()) for x in node_flattened_res if isinstance(x, FakeTensor)) - - # This will happen if we're updating a view op, e.g. - # e.g. replacing - # x = view(old) - # x = view(new) - # When that happens, we need to make sure to keep our - # storage mapping up to date. - # - # We're checking for len(...) == 1 here because all view ops are guaranteed to return either a single tensor, - # or multiple tensors that all share the same storage. - # We can't just check equality because we might encounter FX nodes that return zero tensor outputs. - if len(old_res_storage) == 1 and len(node_res_storage) == 1 and old_res_storage == node_res_storage: - new_flattened_res, _ = tree_flatten(new.meta['fake_result']) - new_res_storage = set(StorageWeakRef(x.storage()) for x in new_flattened_res if isinstance(x, FakeTensor)) - assert len(new_res_storage) == 1 - (old_ref,) = old_res_storage - (new_ref,) = new_res_storage - (node_ref,) = node_res_storage - # Technically, "old_ref" and all its aliases will remain - # in our mapping. - # That should be fine though, since we deleted "old" - # from the graph at this point. - storage_to_nodes[node_ref].update(storage_to_nodes[new_ref]) - storage_to_nodes[new_ref].update(storage_to_nodes[node_ref]) - - # Step 4: delete any _scatter nodes that we de-functionalized - # Need to take care not to delete any of these nodes until after *all* modifications - # to the graph are finished. - for to_delete in all_later_view_inverse_nodes_to_delete: - gm.graph.erase_node(to_delete) - - - gm.recompile() - return gm diff --git a/pippy/fx/passes/shape_prop.py b/pippy/fx/passes/shape_prop.py deleted file mode 100644 index 9745136a2..000000000 --- a/pippy/fx/passes/shape_prop.py +++ /dev/null @@ -1,153 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import torch -import pippy.fx -import traceback - -from pippy.fx.node import Node, map_aggregate -from typing import Any, Tuple, NamedTuple, Optional, Dict -from pippy.fx._compatibility import compatibility - -__all__ = ['TensorMetadata', 'ShapeProp'] - -@compatibility(is_backward_compatible=True) -class TensorMetadata(NamedTuple): - # TensorMetadata is a structure containing pertinent information - # about a tensor within a PyTorch program. - - # General Tensor metadata - shape : torch.Size - dtype : torch.dtype - requires_grad : bool - stride : Tuple[int] - memory_format : Optional[torch.memory_format] - - # Quantization metadata - is_quantized : bool - qparams: Dict[str, Any] - -def _extract_tensor_metadata(result : torch.Tensor) -> TensorMetadata: - """ - Extract a TensorMetadata NamedTuple describing `result`. - """ - shape = result.shape - dtype = result.dtype - requires_grad = result.requires_grad - stride = result.stride() - - memory_formats = { - torch.contiguous_format, - torch.channels_last, - torch.channels_last_3d, - } - - memory_format = None - - for query_format in memory_formats: - if result.is_contiguous(memory_format=query_format): - memory_format = query_format - break - - is_quantized = result.is_quantized - qparams: Dict[str, Any] = {} - if is_quantized: - qscheme = result.qscheme() - qparams["qscheme"] = qscheme - if qscheme in {torch.per_tensor_affine, torch.per_tensor_symmetric}: - qparams["scale"] = result.q_scale() # type: ignore[assignment] - qparams["zero_point"] = result.q_zero_point() # type: ignore[assignment] - elif qscheme in {torch.per_channel_affine, torch.per_channel_affine_float_qparams, torch.per_channel_symmetric}: - # In this branch, scale and zero_point are expected to be tensors, - # we store the values as immutable_list in TensorMetadata for - # easier serialization downstream - qparams["scale"] = result.q_per_channel_scales().tolist() # type: ignore[assignment] - qparams["zero_point"] = result.q_per_channel_zero_points().tolist() # type: ignore[assignment] - qparams["axis"] = result.q_per_channel_axis() # type: ignore[assignment] - - return TensorMetadata( - shape, dtype, requires_grad, stride, memory_format, is_quantized, qparams) - -@compatibility(is_backward_compatible=True) -class ShapeProp(pippy.fx.Interpreter): - """ - Execute an FX graph Node-by-Node and - record the shape and type of the result - into the corresponding node. - - Example: - In this example, we record the shape - and data type of a module given - an example input ``torch.randn(50, D_in)``. - We print the name, shape and dtype of each node. - - class TwoLayerNet(torch.nn.Module): - def __init__(self, D_in, H, D_out): - super(TwoLayerNet, self).__init__() - self.linear1 = torch.nn.Linear(D_in, H) - self.linear2 = torch.nn.Linear(H, D_out) - def forward(self, x): - h_relu = self.linear1(x).clamp(min=0) - y_pred = self.linear2(h_relu) - return y_pred - N, D_in, H, D_out = 64, 1000, 100, 10 - x = torch.randn(N, D_in) - y = torch.randn(N, D_out) - model = TwoLayerNet(D_in, H, D_out) - gm = pippy.fx.symbolic_trace(model) - sample_input = torch.randn(50, D_in) - ShapeProp(gm).propagate(sample_input) - - for node in gm.graph.nodes: - print(node.name, node.meta['tensor_meta'].dtype, - node.meta['tensor_meta'].shape) - - The output of this code is: - - x torch.float32 torch.Size([50, 1000]) - linear1 torch.float32 torch.Size([50, 100]) - clamp_1 torch.float32 torch.Size([50, 100]) - linear2 torch.float32 torch.Size([50, 10]) - output torch.float32 torch.Size([50, 10]) - - Args: - module (GraphModule): The module to be executed - - """ - def run_node(self, n : Node) -> Any: - try: - result = super().run_node(n) - except Exception: - traceback.print_exc() - raise RuntimeError( - f"ShapeProp error for: node={n.format_node()} with " - f"meta={n.meta}" - ) - - found_tensor = False - - def extract_tensor_meta(obj): - if isinstance(obj, torch.Tensor): - nonlocal found_tensor - found_tensor = True - return _extract_tensor_metadata(obj) - else: - return obj - - meta = map_aggregate(result, extract_tensor_meta) - if found_tensor: - n.meta['tensor_meta'] = meta - - n.meta['type'] = type(result) - return result - - def propagate(self, *args): - """ - Run `module` via interpretation and return the result and - record the shape and type of each node. - - Args: - *args (Tensor): the sample input. - - Returns: - Any: The value returned from executing the Module - """ - return super().run(*args) diff --git a/pippy/fx/passes/split_module.py b/pippy/fx/passes/split_module.py deleted file mode 100644 index 2ccc28108..000000000 --- a/pippy/fx/passes/split_module.py +++ /dev/null @@ -1,327 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import inspect -from typing import Any, Callable, Dict, List, Optional - -import torch - -import pippy -import pippy.fx -from pippy.fx._compatibility import compatibility -from pippy.fx.graph_module import GraphModule - -__all__ = ["Partition", "split_module"] - - -@compatibility(is_backward_compatible=True) -class Partition: - def __init__(self, name: str): - self.name: str = name - self.submod_name = f"submod_{name}" - self.node_names: List[str] = [] - self.inputs: Dict[str, None] = {} - self.outputs: Dict[str, None] = {} - self.partitions_dependent_on: Dict[str, None] = {} - self.partition_dependents: Dict[str, None] = {} - self.graph: pippy.fx.graph.Graph = pippy.fx.graph.Graph() - self.environment: Dict[pippy.fx.node.Node, pippy.fx.node.Node] = {} - self.targets: Dict[str, Any] = {} - - def __repr__(self) -> str: - return ( - f"name: {self.name},\n" - f" nodes: {self.node_names},\n" - f" inputs: {self.inputs},\n" - f" outputs: {self.outputs},\n" - f" partitions depenent on: {self.partitions_dependent_on},\n" - f" parition dependents: {self.partition_dependents}" - ) - - -# Creates subgraphs out of main graph -@compatibility(is_backward_compatible=True) -def split_module( - m: GraphModule, - root_m: torch.nn.Module, - split_callback: Callable[[pippy.fx.node.Node], int], - qualname_map: Optional[Dict[str, str]] = None, - keep_original_order: Optional[bool] = False, -): - """ - Creates subgraphs out of main graph - - Args: - m (GraphModule): Graph module to split - root_m (torch.nn.Module): root nn module. Not currently used. Included - because the root nn module is usually transformed via - pippy.fx._symbolic_trace.symbolic_trace (see example below) - split_callback (Callable[[pippy.fx.node.Node], int]): Callable function - that maps a given Node instance to a numeric partition identifier. - split_module will use this function as the policy for which operations - appear in which partitions in the output Module. - qualname_map: Optional[Dict[str, str]]: optional output parameter that returns a - mapping from new target names in the module after split to old target - names in the original module. - keep_original_order: Optional[bool]: keep the original order of the GraphModule - or use the Topological order of the new constructed GraphModule - - - Returns: - GraphModule: the module after split. - - Example: - - This is a sample setup: - - import torch - from pippy.fx.symbolic_trace import symbolic_trace - from pippy.fx.graph_module import GraphModule - from pippy.fx.node import Node - from pippy.fx.passes.split_module import split_module - - class MyModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.param = torch.nn.Parameter(torch.rand(3, 4)) - self.linear = torch.nn.Linear(4, 5) - - def forward(self, x, y): - z = self.linear(x + self.param).clamp(min=0.0, max=1.0) - w = self.linear(y).clamp(min=0.0, max=1.0) - return z + w - - # symbolically trace model - my_module = MyModule() - my_module_traced = symbolic_trace(my_module) - - # random mod partitioning - partition_counter = 0 - NPARTITIONS = 3 - - def mod_partition(node: Node): - global partition_counter - partition = partition_counter % NPARTITIONS - partition_counter = (partition_counter + 1) % NPARTITIONS - return partition - - # split module in module with submodules - module_with_submodules = split_module( - my_module_traced, my_module, mod_partition - ) - - Output looks like this. Original graph is broken into partitions - - > print(module_with_submodules) - GraphModule( - (submod_0): GraphModule( - (linear): Linear(in_features=4, out_features=5, bias=True) - ) - (submod_1): GraphModule( - (linear): Linear(in_features=4, out_features=5, bias=True) - ) - (submod_2): GraphModule() - ) - - def forward(self, x, y): - param = self.param - submod_0 = self.submod_0(x, param, y); x = param = y = None - getitem = submod_0[0] - getitem_1 = submod_0[1]; submod_0 = None - submod_1 = self.submod_1(getitem, getitem_1); getitem = getitem_1 = None - getitem_2 = submod_1[0] - getitem_3 = submod_1[1]; submod_1 = None - submod_2 = self.submod_2(getitem_2, getitem_3); getitem_2 = getitem_3 = None - return submod_2 - - Output of split module is the same as output of input traced module. - This is an example within a test setting: - - > orig_out = my_module_traced(x, y) - > submodules_out = module_with_submodules(x, y) - > self.assertEqual(orig_out, submodules_out) - True - """ - partitions: Dict[str, Partition] = {} - orig_nodes: Dict[str, pippy.fx.node.Node] = {} - - def record_cross_partition_use( - def_node: pippy.fx.node.Node, use_node: Optional[pippy.fx.node.Node] - ): # noqa: B950 - def_partition_name = getattr(def_node, "_fx_partition", None) - use_partition_name = getattr(use_node, "_fx_partition", None) - if def_partition_name != use_partition_name: - if def_partition_name is not None: - def_partition = partitions[def_partition_name] - def_partition.outputs.setdefault(def_node.name) - if use_partition_name is not None: - def_partition.partition_dependents.setdefault(use_partition_name) - - if use_partition_name is not None: - use_partition = partitions[use_partition_name] - use_partition.inputs.setdefault(def_node.name) - if def_partition_name is not None: - use_partition.partitions_dependent_on.setdefault(def_partition_name) - - # split nodes into parititons - for node in m.graph.nodes: - orig_nodes[node.name] = node - - # TODO currently placeholders/parameters aren't put into random partitions, - # rather they're added to the graphs where they are used down below - if node.op in ["placeholder", "get_attr"]: - continue - if node.op == "output": - pippy.fx.graph.map_arg( - node.args[0], lambda n: record_cross_partition_use(n, None) - ) - continue - partition_name = str(split_callback(node)) - - # add node to partitions - partition = partitions.get(partition_name) - if partition is None: - partitions[partition_name] = partition = Partition(partition_name) - - partition.node_names.append(node.name) - node._fx_partition = partition_name - - pippy.fx.graph.map_arg( - node.args, lambda def_node: record_cross_partition_use(def_node, node) - ) - pippy.fx.graph.map_arg( - node.kwargs, lambda def_node: record_cross_partition_use(def_node, node) - ) # noqa: B950 - - original_partition_order = list(partitions.keys()) - # find partitions with no dependencies - root_partitions: List[str] = [] - for partition_name, partition in partitions.items(): - if not len(partition.partitions_dependent_on): - root_partitions.append(partition_name) - - # check partitions for circular dependencies and create topological partition ordering - sorted_partitions: List[str] = [] - while root_partitions: - root_partition = root_partitions.pop() - sorted_partitions.append(root_partition) - for dependent in partitions[root_partition].partition_dependents: - partitions[dependent].partitions_dependent_on.pop(root_partition) - if not partitions[dependent].partitions_dependent_on: - root_partitions.append(dependent) - if len(sorted_partitions) != len(partitions): - raise RuntimeError("cycle exists between partitions!") - - # add placeholders to parititons - for partition_name in sorted_partitions: - partition = partitions[partition_name] - for input in partition.inputs: - placeholder = partition.graph.placeholder(input) - placeholder.meta = orig_nodes[input].meta.copy() - partition.environment[orig_nodes[input]] = placeholder - - # Transform nodes and collect targets for partition's submodule - for node in m.graph.nodes: - if hasattr(node, "_fx_partition"): - partition = partitions[node._fx_partition] - - # swap out old graph nodes in kw/args with references to new nodes in this submodule - environment = partition.environment - gathered_args = pippy.fx.graph.map_arg(node.args, lambda n: environment[n]) - gathered_kwargs = pippy.fx.graph.map_arg( - node.kwargs, lambda n: environment[n] - ) - - if node.op not in ["call_module", "get_attr"]: - target = node.target - else: - target_atoms = node.target.split(".") - target_attr = m - for atom in target_atoms: - if not hasattr(target_attr, atom): - raise RuntimeError(f"Operator target {node.target} not found!") - target_attr = getattr(target_attr, atom) - # target = target_atoms[-1] - target = "_".join(target_atoms) - partition.targets[target] = target_attr - # Fill in the passed-in mapping from new qualname to old qualname - if qualname_map is not None: - # When creating the split module later, the submodules will have - # path prefix matching the corresponding partition's submod_name - qualname = f"{partition.submod_name}.{target}" - qualname_map[qualname] = node.target - - assert isinstance(gathered_args, tuple) - assert isinstance(gathered_kwargs, dict) - new_node = partition.graph.create_node( - op=node.op, target=target, args=gathered_args, kwargs=gathered_kwargs - ) - new_node.meta = node.meta.copy() - partition.environment[node] = new_node - - # Set up values to construct base module - base_mod_env: Dict[str, pippy.fx.node.Node] = {} - base_mod_graph: pippy.fx.graph.Graph = pippy.fx.graph.Graph() - base_mod_attrs: Dict[str, pippy.fx.graph_module.GraphModule] = {} - for node in m.graph.nodes: - if node.op == "placeholder": - default_value = ( - node.args[0] if len(node.args) > 0 else inspect.Signature.empty - ) - base_mod_env[node.name] = base_mod_graph.placeholder( - node.target, type_expr=node.type, default_value=default_value - ) - base_mod_env[node.name].meta = node.meta.copy() - elif node.op == "get_attr": - base_mod_env[node.name] = base_mod_graph.get_attr(node.target) - base_mod_env[node.name].meta = node.meta.copy() - attr_val = m - for atom in node.target.split("."): - if not hasattr(attr_val, atom): - raise RuntimeError(f"Node target {node.target} not found!") - attr_val = getattr(attr_val, atom) - base_mod_attrs[node.target] = attr_val - - # Do some things iterating over the partitions in topological order again: - # 1) Finish off submodule Graphs by setting corresponding outputs - # 2) Construct GraphModules for each submodule - # 3) Construct the base graph by emitting calls to those submodules in - # topological order - - construct_order_partitions = ( - sorted_partitions if not keep_original_order else original_partition_order - ) - - for partition_name in construct_order_partitions: - partition = partitions[partition_name] - - # Set correct output values - output_vals = tuple( - partition.environment[orig_nodes[name]] for name in partition.outputs - ) - output_vals = output_vals[0] if len(output_vals) == 1 else output_vals # type: ignore[assignment] - partition.graph.output(output_vals) - - # Construct GraphModule for this partition - base_mod_attrs[partition.submod_name] = pippy.fx.graph_module.GraphModule( - partition.targets, partition.graph - ) # noqa: B950 - - # Emit call in base graph to this submodule - output_val = base_mod_graph.call_module( - partition.submod_name, - tuple(base_mod_env[name] for name in partition.inputs), - ) - if len(partition.outputs) > 1: - # Unpack multiple return values from submodule - output_val_proxy = pippy.fx.proxy.Proxy(output_val) - for i, output_name in enumerate(partition.outputs): - base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index] - else: - base_mod_env[list(partition.outputs)[0]] = output_val - - for node in m.graph.nodes: - if node.op == "output": - base_mod_graph.output( - pippy.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name]) - ) # noqa: B950 - - return pippy.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph) diff --git a/pippy/fx/passes/split_utils.py b/pippy/fx/passes/split_utils.py deleted file mode 100644 index f6f8b90be..000000000 --- a/pippy/fx/passes/split_utils.py +++ /dev/null @@ -1,278 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from dataclasses import dataclass, field -from typing import List, Optional, Dict - -import pippy.fx -from pippy.fx.graph import map_arg -from .tools_common import NodeList -from pippy.fx._compatibility import compatibility -from pippy.fx.passes.utils import lift_subgraph_as_module, HolderModule - -__all__ = ['getattr_recursive', 'setattr_recursive', 'Component', 'split_by_tags'] - -@compatibility(is_backward_compatible=False) -def getattr_recursive(obj, name): - for layer in name.split("."): - if hasattr(obj, layer): - obj = getattr(obj, layer) - else: - return None - return obj - - -@compatibility(is_backward_compatible=False) -def setattr_recursive(obj, attr, value): - if "." not in attr: - setattr(obj, attr, value) - else: - layer = attr.split(".") - setattr_recursive(getattr(obj, layer[0]), ".".join(layer[1:]), value) - - -@compatibility(is_backward_compatible=False) -@dataclass -class Component: - """ - A component serves as a container for a subgraph we want to create afterwards. - """ - - graph: pippy.fx.Graph - order: int - name: str - - # Stores the placeholder nodes in `graph`. - input_placeholders: List = field(default_factory=list) - - # Store the nodes in original graph that are placeholder in `graph`. - orig_inputs: List = field(default_factory=list) - - # Store the nodes in original graph that are outputs in `graph`. - orig_outputs: List = field(default_factory=list) - - # Mapping from get_attr node in original graph to get_attr node in `graph`. - getattr_maps: Dict[pippy.fx.Node, pippy.fx.Node] = field(default_factory=dict) - constructor_args: List[str] = field(default_factory=list) - gm: Optional[pippy.fx.GraphModule] = None - - -@compatibility(is_backward_compatible=False) -def split_by_tags(gm: pippy.fx.GraphModule, tags: List[str]) -> pippy.fx.GraphModule: - """ - Splits a GraphModule using tags on its graph nodes. We honor the order of - tags. For example, we have tags = ["a", "b", "c"], the function will create - the initial submodules in the order of "a_0", "b_1", "c_2". - - To set a tag: - gm.graph.nodes[idx].tag = "mytag" - - This will result in all nodes with the same tag being extracted and placed in their - own submodule. For placeholder, output and get_attr node, the tag is ignored. placeholder - and output nodes are created when needed while get_attr nodes get copied to submodules - where they are used. - - Given the following module def: - - class SimpleModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear1 = torch.nn.Linear(...) - self.linear2 = torch.nn.Linear(...) - self.linear3 = torch.nn.Linear(...) - - def forward(self, in1, in2): - r1 = self.linear1(in1) - r2 = self.linear2(in2) - r3 = torch.cat([r1, r2]) - return self.linear3(r3) - - Marking the node corresponding to in1 with the tag sc.REQUEST_ONLY.lower() results in the following split: - - ro_0: - def forward(self, in1): - self = self.root - linear1 = self.linear1(in1) - return linear1 - - main_1: - def forward(self, in2, linear1): - self = self.root - linear2 = self.linear2(in2) - cat_1 = torch.cat([linear1, linear2]) - linear3 = self.linear3(cat_1) - return linear3 - - main_0: - def forward(self, in1, in2): - self = self.root - ro_0 = self.ro_0(in1) - main_1 = self.main_1(in2, ro_0) - return main_1 - """ - - def flatten(x: pippy.fx.node.Argument) -> NodeList: - """ - Stores nodes in x to a list and returns the list. - """ - r: NodeList = [] - map_arg(x, r.append) - return r - - # Mapping from node in original module to node in created submodule. - node_remapping: Dict[pippy.fx.Node, pippy.fx.Node] = {} - - # Mapping from node in original module or created submodules to - # corresponding component. - node_to_component: Dict[pippy.fx.Node, Component] = {} - - # Mapping from tag to the corresponding component. - tag_to_component: Dict[str, Component] = {} - - # Stores all components. - all_components: List[Component] = [] - - # Stores nodes that will be used in main graph. - used_in_main: Dict[pippy.fx.Node, None] = {} - - # Main graph after split. - main_g = pippy.fx.Graph() - - # Mapping from node in original module to node in main graph after split. - main_remapping: Dict[pippy.fx.Node, pippy.fx.Node] = {} - - # Output node of original module. - output_node: Optional[pippy.fx.Node] = None - - # Create a component for each tag, we don't expect to create other components afterwards. - for tag in tags: - comp = Component(pippy.fx.Graph(), len(all_components), f"{tag}") - all_components.append(comp) - tag_to_component[tag] = comp - - # Traverse the nodes in original graph and take care of them. - for node in gm.graph.nodes: - if node.op == "output": - if output_node is not None: - raise RuntimeError("Multiple output nodes in graph!") - output_node = node - continue - - # Placeholders in the original graph get copied to main graph. - if node.op == "placeholder": - main_remapping[node] = main_g.placeholder(node.name, type_expr=node.type) - continue - - # Get_attr nodes are ignored because we are not tagging them. - # Instead, we copy them directly to the submodules use them afterwards. - if node.op == "get_attr": - continue - - # Now we process callable nodes which are nodes with op of call_module, - # call_function or call_method. Every callable nodes should be tagged. - assert hasattr(node, "tag") - - upstream_components = [ - node_to_component[x] - for x in flatten(node.args) + flatten(node.kwargs) - if x.op not in {"placeholder", "get_attr"} - ] - - comp = tag_to_component[node.tag] - node_to_component[node] = comp - - # Max order of upperstream components. - mx = max((c.order for c in upstream_components), default=0) - - # Expect the componet for `node` has higher order then its upstream components. - assert comp.order >= mx - - # Map a input of `node` to nodes in the component's graph. - def remap_func(x): - # If input is a get_attr node, copy it to current component's graph. - # Returns the get_attr node in current component's graph. - if x.op == "get_attr": - if x not in comp.getattr_maps: - comp.getattr_maps[x] = comp.graph.get_attr( - x.target, type_expr=x.type - ) - return comp.getattr_maps[x] - - # If input is not a placeholder, it should have been put into a component - # already. If it's the current component then we return the corresponding - # node in the component. - if x.op != "placeholder" and node_to_component[x] == comp: - return node_remapping[x] - - # If input is a placeholder or it's in other components, we want to make it - # as a placeholder in current component's graph. - if x not in comp.orig_inputs: - comp.orig_inputs.append(x) - comp.input_placeholders.append( - comp.graph.placeholder(x.name, type_expr=x.type) - ) - used_in_main[x] = None - - return comp.input_placeholders[ - next(i for i, y in enumerate(comp.orig_inputs) if x is y) - ] - - n = comp.graph.node_copy(node, remap_func) - n.tag = node.tag # type: ignore[attr-defined] - node_remapping[node] = n - node_to_component[n] = comp - - if output_node is None: - raise RuntimeError("Graph had no output node!") - - for x in flatten(output_node.args[0]): - if x.op == "get_attr": - # We don't need components mapping for nodes of type "get_attr" - # that are consumed by the output. Only need to make sure we create - # corresponding counterparts in the resulting graph. - main_remapping[x] = main_g.get_attr(x.name, type_expr=x.type) - else: - # All component results consumed by the output node should be - # marked as "used in main". - used_in_main[x] = None - - # If a node is used in main graph then we mark it as an output in the component - # it belongs to. - for n in used_in_main: - if n.op != "placeholder": - node_to_component[n].orig_outputs.append(n) - - # Now we create a graphmodule for each component. - for comp in all_components: - outs = tuple(map(node_remapping.__getitem__, comp.orig_outputs)) - - # Take care of the args of FX output node. If there's a single - # output then the output node args is like (output_single), else - # if there're multiple outputs then the output node args is like - # ((output_0, output_1, ...)). - comp.graph.output(outs[0] if len(outs) == 1 else outs) - - comp.gm = lift_subgraph_as_module(gm, comp.graph) - - # Create a call_module node in main graph. - main_node = main_g.call_module( - comp.name, - args=tuple(map(main_remapping.__getitem__, comp.orig_inputs)), - kwargs=None, - ) - - if len(outs) == 1: - main_remapping[comp.orig_outputs[0]] = main_node - else: - for i, o in enumerate(comp.orig_outputs): - # Use Proxy to record getitem access. - main_remapping[o] = pippy.fx.Proxy(main_node)[i].node # type: ignore[index] - - main_g.output(map_arg(output_node.args[0], main_remapping.__getitem__)) - main_root = HolderModule({comp.name: comp.gm for comp in all_components}) - - # If the output nodes consumes get_attr directly in the original graph, - # then we need to make sure get_attr is copied to the new graph. - for x in flatten(output_node.args[0]): - if x.op == "get_attr": - setattr(main_root, x.name, getattr_recursive(gm, x.target)) # type: ignore[arg-type] - - return pippy.fx.GraphModule(main_root, main_g) diff --git a/pippy/fx/passes/splitter_base.py b/pippy/fx/passes/splitter_base.py deleted file mode 100644 index 99ed92bd0..000000000 --- a/pippy/fx/passes/splitter_base.py +++ /dev/null @@ -1,854 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import argparse -import logging -import warnings -from collections import defaultdict -from dataclasses import dataclass -from typing import NamedTuple, Sequence, Iterable, Any, List, Dict, Optional, Tuple - -import torch - -import pippy -import pippy.fx -from pippy.fx._compatibility import compatibility -from pippy.fx.node import map_arg -from pippy.fx.passes.graph_manipulation import get_size_of_node -from .graph_drawer import FxGraphDrawer -from .operator_support import ( - get_node_target, - OperatorSupportBase, -) -from .shape_prop import ShapeProp -from .split_utils import split_by_tags -from .tools_common import ( - FxNetAccFusionsFinder, - CALLABLE_NODE_OPS, - Tensors, - NodeList, - NodeSet, - is_node_output_tensor, -) - -__all__ = ['FxNetAccNodesFinder', 'FxNetSplitterInternalError', 'Subgraph', 'SplitResult', 'generate_inputs_for_submodules'] -_LOGGER = logging.getLogger(__name__) - - -class _SplitterSettingBase: - def __init__(self): - parser = argparse.ArgumentParser() - parser.add_argument( - "--min_acc_module_size", - default=1, - type=int, - help="Minimum size limit of an accelerator subgraph.", - ) - parser.add_argument( - "--skip_fusion", - default=False, - action="store_true", - help="If true then no fusion groups. Fusion group is used to " - "enforce no non-tensor data flow between submodules. If we don't " - "have this constrain, setting this to false is recommended as it " - "can reduce overhead.", - ) - parser.add_argument( - "--allow_non_tensor", - default=False, - action="store_true", - help="For some backends non-tensor data flow between cpu and them " - "are not allowed. Therefore, if a node supported by accelerator but " - "it has non-tensor inputs or outputs to a cpu node we would want to " - "consider it as a cpu node during splitting. However, for some backends " - "we might not care about non-tensor data flow and we can set this option " - "to true to disable the functionality that prevent non-tensor data flow.", - ) - args, unknown = parser.parse_known_args() - - self.min_acc_module_size: int = args.min_acc_module_size - self.skip_fusion: bool = args.skip_fusion - self.allow_non_tensor: bool = args.allow_non_tensor - - -@compatibility(is_backward_compatible=False) -class FxNetAccNodesFinder: - """ - Finds a set of nodes that can be supported on ACC, excluding nodes that have non-tensor - input/output to cpu nodes to prevent non-tensor data flow between backends and cpu. - - I.e. if we have a chain: - - ACC_NODE_1 -> ACC_NODE_2 -> ACC_NODE_3 -> CPU_NODE_1 - - where every ACC node produces non-tensor output, then they all should be treated as CPU nodes. - - This behavior can be turned off by passing allow_non_tensor=True. - """ - - def __init__( - self, - module: pippy.fx.GraphModule, - operator_support: OperatorSupportBase, - allow_non_tensor: bool, - ): - self.module = module - self.operator_support = operator_support - self.allow_non_tensor = allow_non_tensor - - def reduce_acc_nodes_non_tensor_input_helper( - self, cpu_worklist: NodeList - ): - """ - Transitively excludes nodes from ACC supported set. - For every node in the worklist: - - removes its downstream ACC nodes from ACC supported set, - - if any downstream ACC node produces non-tensor output, - then it gets added into the worklist. - """ - while cpu_worklist: - node = cpu_worklist.pop(0) - - for user in node.users: - if user in self.acc_nodes: - self.acc_nodes.remove(user) - if not is_node_output_tensor(user): - cpu_worklist.append(user) - - def reduce_acc_nodes_non_tensor_input(self): - """ - Excludes nodes from ACC supported set that have direct - upstream CPU nodes that produce non-tensor outputs. - """ - non_tensor_cpu_nodes: NodeList = [] - - for node in self.module.graph.nodes: - if node.op not in CALLABLE_NODE_OPS: - continue - if node in self.acc_nodes: - continue - if is_node_output_tensor(node): - continue - non_tensor_cpu_nodes.append(node) - - self.reduce_acc_nodes_non_tensor_input_helper(non_tensor_cpu_nodes) - - def reduce_acc_nodes_non_tensor_output(self): - """ - Excludes nodes from ACC supported set that produce non-tensor - outputs and have downstream CPU nodes. - """ - while True: - new_cpu_nodes: NodeList = [] - - for acc_node in self.acc_nodes: - if is_node_output_tensor(acc_node): - continue - for user in acc_node.users: - if user not in self.acc_nodes: - new_cpu_nodes.append(acc_node) - break - - if not new_cpu_nodes: - break - - for new_cpu_node in new_cpu_nodes: - self.acc_nodes.remove(new_cpu_node) - - self.reduce_acc_nodes_non_tensor_input_helper(new_cpu_nodes) - - def __call__(self) -> NodeSet: - submodules = dict(self.module.named_modules()) - self.acc_nodes = { - n - for n in self.module.graph.nodes - if n.op in CALLABLE_NODE_OPS - and self.operator_support.is_node_supported(submodules, n) - } - - if not self.allow_non_tensor: - self.reduce_acc_nodes_non_tensor_input() - self.reduce_acc_nodes_non_tensor_output() - - return self.acc_nodes - -@compatibility(is_backward_compatible=False) -class FxNetSplitterInternalError(Exception): - pass - -@compatibility(is_backward_compatible=False) -@dataclass -class Subgraph: - is_acc: bool - nodes: NodeList - - -@compatibility(is_backward_compatible=False) -class SplitResult(NamedTuple): - """ - Stores the results of the splitter. - - Attributes: - split_module: root module after splitting. - submodule_inputs: a dict that maps submodule name to its inputs. - non_acc_submodule_prefix: the prefix for non acc submodules. For - acc submodule the prefix is alwasy "_run_on_acc_". - """ - - split_module: pippy.fx.GraphModule - submodule_inputs: Dict[str, Any] - non_acc_submodule_prefix: str - - -@compatibility(is_backward_compatible=False) -def generate_inputs_for_submodules( - model: torch.nn.Module, - inputs: Sequence[Any], - target_submodules: Iterable[str] -) -> Dict[str, Any]: - """ - Generate inputs for targeting submdoules in the given model. Note that if two submodules refer to the same obj, this - function doesn't work. - - Args: - model: root model. - inputs: inputs to the root model. - target_submodules: submodules that we want to generate inputs for. - - Returns: - A dict that maps from submodule name to its inputs. - """ - - handles = [] - results = {} - submodule_to_names = dict((mod, name) for name, mod in model.named_modules()) - - def pre_forward(module, module_inputs): - results[submodule_to_names[module]] = module_inputs - try: - for name, mod in model.named_modules(): - if name in target_submodules: - handles.append(mod.register_forward_pre_hook(pre_forward)) - model(*inputs) - except Exception as e: - warnings.warn(f"Failed to generate submodule inputs because of the following error:\n{e}") - finally: - for h in handles: - h.remove() - return results - - -class _SplitterBase: - """ - Splits a GraphModule into sub-GraphModules for execution on CPU or the accelerator. - Output is a GraphModule with supported and unsupported operators grouped into as few sub-GraphModules as possible. - Assumes that only "call_module", "call_function" and "call_method" from FX IR can potentially be executed on the accelerator. - - Given the following graph: - ==> b ==> - // \\ - a d - \\ // - ==> c ==> - - class SimpleModule(torch.nn.Module): - def forward(self, a): - b = torch.sin(a) - c = torch.cos(a) - d = b + c - return d - - and providing "operator_support" that indicates that 'b' and 'c' can be executed on the accelerator, - we will get the following split result: - - main: - def forward(self, a): - run_on_acc_0_0 = self._run_on_acc_0_0(a) - getitem = run_on_acc_0_0[0] - getitem_1 = run_on_acc_0_0[1] - run_on_cpu_1_1 = self._run_on_cpu_1_1(getitem, getitem_1) - return run_on_cpu_1_1 - - _run_on_acc_0_0: - def forward(self, a): - sin_1 = torch.sin(a) - cos_1 = torch.cos(a) - return (sin_1, cos_1) - - _run_on_cpu_1_1: - def forward(self, sin_1, cos_1): - add_1 = sin_1 + cos_1 - return add_1 - """ - - # PCIe bandwidth for the backend, default to 100 GB/s - PCIe_BW = 100 * 2 ** 30 - - def __init__( - self, - module: pippy.fx.GraphModule, - sample_input: Sequence[Any], - operator_support: OperatorSupportBase, - settings: _SplitterSettingBase, - non_acc_submodule_name: str = "_run_on_cpu_", - ): - """ - Preprocesses graph before splitting: - - finds nodes supported by ACC, - - finds fusion groups for ACC nodes having non-tensor IO, - - builds a graph of direct dependencies, - - builds a map of fused nodes to their fusions. - As a result we get self.acc_nodes, self.deps and self.fusions. - """ - assert isinstance(module, pippy.fx.GraphModule) - - self.module = module - ShapeProp(self.module).propagate(*sample_input) - - self.settings = settings - self.operator_support = operator_support - self.sample_input = sample_input - self.acc_nodes = FxNetAccNodesFinder(self.module, self.operator_support, self.settings.allow_non_tensor)() - - if self.settings.skip_fusion: - self.fusions = {} - else: - self.fusions = FxNetAccFusionsFinder(module, self.acc_nodes)() - - # Modify deps to add more deps for fused nodes - self.deps = self.find_deps() - self.update_deps_for_fusions() - - self.non_acc_submodule_name = non_acc_submodule_name - self._node_submodule_map: Dict[str, str] = {} - - # =============================================================== - # Helpers for ctor and initial state - # =============================================================== - - def get_node_submodule_map(self) -> Dict[str, str]: - """ Returns a map from node name to submodule name, e.g. - node: main_module_impl_impl_over_arch_unary_multiple_embedding - _pooling_embedding_pooling_sparse_entity_equivalence_key - _proxy_embedding_bag - maps to submodule name of: _run_on_acc_1 - """ - return self._node_submodule_map - - def find_deps(self) -> Dict[pippy.fx.Node, NodeSet]: - """ - Builds a graph of node dependencies. Leaf nodes don't have any - dependencies and the "output" node doesn't have nodes depending on it. - - Resulting graph has only direct dependencies, i.e. there are no - transitive dependencies. - """ - deps: Dict[pippy.fx.Node, NodeSet] = defaultdict(set) - for node in self.module.graph.nodes: - if node.op not in CALLABLE_NODE_OPS: - continue - - for user in node.users: - if user.op != "output": - deps[user].add(node) - return deps - - def update_deps_for_fusions(self): - """ - Updates graph of dependencies so that: - - nodes from the same fusion depend on the same set of outer nodes, - - outer nodes depending on a fusion depend on all nodes in that fusion. - """ - for node in self.fusions: - fusion = self.fusions[node] - for fused_neighbor in fusion: - self.deps[node].update(self.deps[fused_neighbor] - fusion) - - for user in fused_neighbor.users: - if user not in fusion: - self.deps[user].add(node) - - # =============================================================== - # Helpers for preview - # =============================================================== - - def _lower_model_to_backend( - self, mod: pippy.fx.GraphModule, inputs: Tensors - ) -> torch.nn.Module: - """ - Lower the model to a backend. - """ - - return mod - - def _find_culprit( - self, mod: pippy.fx.GraphModule, inputs: Tensors - ) -> str: - """ - When an error occurs during lowering or running the lowered mod, we use this - function to find culprits in the `mod` that causes the error. - """ - - return "Unable to find a culprit because _find_culprit() function is not implemented." - - def _draw_graph_based_on_node_support( - self, mod: pippy.fx.GraphModule, supported_nodes: NodeList - ): - color_map = { - "default": "AliceBlue", - "supported": "chartreuse1", - "unsupported": "crimson", - } - - class CustomDrawer(FxGraphDrawer): - def _get_node_style(self, node): - template = super()._get_node_style(node) - if node in supported_nodes: - template["fillcolor"] = color_map["supported"] - elif node.op in CALLABLE_NODE_OPS: - template["fillcolor"] = color_map["unsupported"] - else: - template["fillcolor"] = color_map["default"] - - return template - - drawer = CustomDrawer(mod, "node_support", ignore_getattr=True) - dot_graph = drawer.get_main_dot_graph() - dot_graph.write_raw("node_support.dot") - - def node_support_preview(self, dump_graph: bool = False): - submodules = dict(self.module.named_modules()) - - supported_nodes: NodeList = [] - supported_node_types = defaultdict(set) - unsupported_node_types = defaultdict(set) - - def get_dtype(arg): - tensor_meta = arg.meta.get("tensor_meta") - return getattr(tensor_meta, "dtype", None) - - for node in self.module.graph.nodes: - if node.op not in CALLABLE_NODE_OPS: - continue - - target = get_node_target(submodules, node) - - # Store dtype of arg in node.args. If arg doesn't have dtype, i.e. not a tensor, we'll store None. - arg_dtypes = [ - get_dtype(arg) if isinstance(arg, pippy.fx.Node) else None - for arg in node.args - ] - - # Find last non-None element. If all elements are None, return max_len. - last_index = len(arg_dtypes) - next( - ( - i - for i, dtype in enumerate(reversed(arg_dtypes)) - if dtype is not None - ), - len(arg_dtypes), - ) - - # Strip None elements at the end. - arg_dtypes_tuple = tuple(arg_dtypes[:last_index]) - kwarg_dtypes_tuple = tuple( - (k, get_dtype(arg)) - for k, arg in node.kwargs.items() - if isinstance(arg, pippy.fx.Node) - ) - - if self.operator_support.is_node_supported(submodules, node): - supported_nodes.append(node) - supported_node_types[target].add((arg_dtypes_tuple, kwarg_dtypes_tuple)) - else: - unsupported_node_types[target].add((arg_dtypes_tuple, kwarg_dtypes_tuple)) - - if dump_graph: - self._draw_graph_based_on_node_support(self.module, supported_nodes) - - reports = "\nSupported node types in the model:\n" - for t, dtypes in supported_node_types.items(): - for arg_dtypes_tuple, kwarg_dtypes_tuple in dtypes: - reports += f"{t}: ({arg_dtypes_tuple}, {dict(kwarg_dtypes_tuple)})\n" - - reports += "\nUnsupported node types in the model:\n" - for t, dtypes in unsupported_node_types.items(): - for arg_dtypes_tuple, kwarg_dtypes_tuple in dtypes: - reports += f"{t}: ({arg_dtypes_tuple}, {dict(kwarg_dtypes_tuple)})\n" - - print(reports) - - # Return reports for testing purpose - return reports - - def split_preview(self, dump_graph: bool = False): - reports = "" - subgraphs = self.put_nodes_into_subgraphs() - acc_subgraphs_num = len([g for g in subgraphs if g.is_acc]) - cpu_subgraphs_num = len(subgraphs) - acc_subgraphs_num - reports += f"Before removing small acc subgraphs, total {len(subgraphs)} subgraphs are created:" - reports += f" {acc_subgraphs_num} acc subgraphs and {cpu_subgraphs_num} cpu subgraphs.\n" - - subgraphs = self.remove_small_acc_subgraphs(subgraphs) - acc_subgraphs_num = len([g for g in subgraphs if g.is_acc]) - cpu_subgraphs_num = len(subgraphs) - acc_subgraphs_num - reports += f"After removing small acc subgraphs, total {len(subgraphs)} subgraphs are created:" - reports += f" {acc_subgraphs_num} acc subgraphs and {cpu_subgraphs_num} cpu subgraphs.\n" - - for i, subgraph in enumerate(subgraphs): - reports += f"_run_on_acc_{i}: " if subgraph.is_acc else f"{self.non_acc_submodule_name}{i}: " - reports += f"{len(subgraph.nodes)} node(s)\n" - - self.tag(subgraphs) - split_mod = self.split(remove_tag=True) - split_mod.eval() - - if dump_graph: - drawer = FxGraphDrawer( - split_mod, "preview", ignore_getattr=True - ) - dot_graphs = drawer.get_all_dot_graphs() - for name, dot_graph in dot_graphs.items(): - dot_graph.write_raw(f"{name}.dot") - - max_qps: float = self.PCIe_BW - bottleneck_module = "" - - for node in split_mod.graph.nodes: - if node.op == "call_module" and "acc" in node.target: - reports += f"\nProcessing acc submodule {node.target}\n" - - submod = getattr(split_mod, node.target) - - def get_submod_inputs(main_mod, submod, example_inputs): - sub_inputs = None - - def get_inputs(self, inputs): - nonlocal sub_inputs - sub_inputs = inputs - - handle = submod.register_forward_pre_hook(get_inputs) - main_mod(*example_inputs) - handle.remove() - return sub_inputs - - submod_inputs = get_submod_inputs( - split_mod, submod, self.sample_input - ) - ShapeProp(submod).propagate(*submod_inputs) - - total_input_bytes = 0 - total_output_bytes = 0 - - reports += "Checking inputs...\n" - for n in submod.graph.nodes: - if n.op == "placeholder": - if not is_node_output_tensor(n): - reports += f"Input {n.name} is not a tensor, this might cause problems during lowering!\n" - else: - total_input_bytes += get_size_of_node(submod, n)[0] - if n.op == "output": - output_node = n - - reports += "Checking outputs...\n" - - def get_bytes(node: pippy.fx.Node): - nonlocal total_output_bytes - nonlocal reports - if not is_node_output_tensor(node): - reports += f"Output {node.name} is not a tensor, this might cause problems during lowering!\n" - else: - total_output_bytes += get_size_of_node(submod, node)[0] - - map_arg(output_node.args, get_bytes) - qps = self.PCIe_BW / max(total_input_bytes, total_output_bytes) - reports += f"Total input size in bytes is {total_input_bytes}, total output size in bytes is {total_output_bytes}," - reports += f" theoretical max qps (bounds by PCIe bandwidth) for this submodule is {qps}.\n" - - if qps < max_qps: - max_qps = qps - bottleneck_module = node.target - - try: - lowered_submod = self._lower_model_to_backend(submod, submod_inputs) - except RuntimeError: - reports += "Run into an error during lowering!\n" - reports += self._find_culprit(submod, submod_inputs) - continue - - try: - lowered_submod(*submod_inputs) - except RuntimeError: - reports += "Run into an error during inference!\n" - reports += self._find_culprit(submod, submod_inputs) - else: - reports += "Lowering and running succeed!\n" - - reports += f"\nTheoretical max qps (bounds by PCIe bandwidth) for this model is {max_qps}," - reports += f" bottleneck is submodule {bottleneck_module}." - print(reports) - - # return the reports for testing purposes - return reports - - # =============================================================== - # Helpers for extend_acc_subgraph() method - # =============================================================== - - def find_reverse_deps( - self, tag_id: Optional[int] = None - ) -> Dict[pippy.fx.Node, NodeSet]: - """ - Builds reversed topological node dependencies, if tag_id is specified, - we ignore nodes that are in later subgraph i.e. nodes have greater tag_id. - """ - result: Dict[pippy.fx.Node, NodeSet] = defaultdict(set) - - for node in self.module.graph.nodes: - if node.op not in CALLABLE_NODE_OPS: - continue - - for user in node.users: - if user.op not in CALLABLE_NODE_OPS: - continue - - if tag_id is None or (int(user.tag.split("_")[-1]) < tag_id): - result[node].add(user) - - return result - - def update_reverse_deps_for_fusions( - self, deps: Dict[pippy.fx.Node, NodeSet] - ): - processed_node = set() - - for node, fusion in self.fusions.items(): - if node in processed_node: - continue - - new_dep = set() - - # Create a new dependency set which include all the - # dependencies of the nodes in the fusion group - for n in fusion: - new_dep.update(deps[n]) - - # Exclude nodes in the fusion - new_dep.difference_update(fusion) - - # Update dependency - for n in fusion: - deps[n] = new_dep - - for arg in n.all_input_nodes: - if arg not in fusion: - deps[arg].update(fusion) - - processed_node.add(n) - - def find_parent_nodes_of_subgraph(self, tag: str) -> NodeSet: - """ - Finds parent nodes of the `tag` subgraph. - - Traverse the inputs of nodes in the subgraph, if input doesn't belong to the subgraph - and is not a placeholder, we consider it as the parent node of the subgraph. - """ - parent_nodes = set() - - for node in self.module.graph.nodes: - if node.op in CALLABLE_NODE_OPS and node.tag == tag: - for arg in node.all_input_nodes: - if arg.op in CALLABLE_NODE_OPS and arg.tag != tag: - parent_nodes.add(arg) - - return parent_nodes - - def extend_acc_subgraph(self, tag: str): - """ - Extend the acc subgraph with `tag` going the reversed topological direction. - """ - # Dict that maps node to its users and ignore users that - # are in the subgraph that has greater tag - deps = self.find_reverse_deps(tag_id=int(tag.split("_")[-1])) - self.update_reverse_deps_for_fusions(deps) - - # Parent nodes of the subgraph - parent_nodes = self.find_parent_nodes_of_subgraph(tag) - - visited_nodes: NodeSet = set() - - while parent_nodes: - node = None - - # Find a acc node that depends on visited nodes only - for n in parent_nodes: - if deps[n] <= visited_nodes and n in self.acc_nodes: - node = n - break - - if node is None: - break - - # Put the node into `tag` subgraph - node.tag = tag # type: ignore[attr-defined] - parent_nodes.remove(node) - visited_nodes.add(node) - - # If node is in a fusion group, add all fusion buddies to parent nodes - if node in self.fusions: - for fusion_node in self.fusions[node]: - if fusion_node not in visited_nodes: - parent_nodes.add(fusion_node) - - # Add inputs of the node to parent nodes - for arg in node.all_input_nodes: - if arg.op in CALLABLE_NODE_OPS and arg not in visited_nodes: - parent_nodes.add(arg) - - # =============================================================== - # Helpers for split() method - # =============================================================== - - def starter_nodes(self) -> Tuple[NodeSet, NodeSet]: - """ - Finds nodes that consume module inputs or get_attr nodes. - """ - starter_cpu_nodes: NodeSet = set() - starter_acc_nodes: NodeSet = set() - for node in self.module.graph.nodes: - if node.op not in {"placeholder", "get_attr"}: - continue - for user in node.users: - if user in self.acc_nodes: - starter_acc_nodes.add(user) - else: - starter_cpu_nodes.add(user) - return starter_cpu_nodes, starter_acc_nodes - - def put_nodes_into_subgraphs(self) -> List[Subgraph]: - # We start graph traversal from leaf nodes - current_cpu_nodes, current_acc_nodes = self.starter_nodes() - visited_nodes: NodeSet = set() - - # Determine which subgraph to start from based on which subgraph has - # 0-dep node - acc_subgraph: bool = not any([len(self.deps[n]) == 0 for n in current_cpu_nodes]) - - current_subgraph_nodes: NodeList = [] - - # Result accumulator - subgraphs: List[Subgraph] = [] - while current_cpu_nodes or current_acc_nodes: - # Find the first node that should belong to the current subgraph and has all dependencies resolved - current_nodes = current_acc_nodes if acc_subgraph else current_cpu_nodes - node = next( - (n for n in current_nodes if self.deps[n] <= visited_nodes), - None, - ) - - # If nothing was found, then it's time to flip the mode and start a new subgraph - if node is None: - if not current_subgraph_nodes: - raise FxNetSplitterInternalError("Subgraph can't be empty") - - subgraphs.append( - Subgraph(is_acc=acc_subgraph, nodes=current_subgraph_nodes) - ) - acc_subgraph = not acc_subgraph - current_subgraph_nodes = [] - continue - - current_nodes.remove(node) - visited_nodes.add(node) - current_subgraph_nodes.append(node) - - # Add fusion buddies - if node in self.fusions: - if node in self.acc_nodes: - current_acc_nodes.update(self.fusions[node] - visited_nodes) - else: - current_cpu_nodes.update(self.fusions[node] - visited_nodes) - - # Put depending nodes into the queue - for user in node.users: - if user.op not in CALLABLE_NODE_OPS: - continue - - # Add downstream nodes - if user in self.acc_nodes: - current_acc_nodes.add(user) - else: - current_cpu_nodes.add(user) - - # Check if the last subgraph was not created - if current_subgraph_nodes: - subgraphs.append( - Subgraph(is_acc=acc_subgraph, nodes=current_subgraph_nodes) - ) - - if not subgraphs: - raise FxNetSplitterInternalError("Couldn't create subgraphs") - - return subgraphs - - def remove_small_acc_subgraphs(self, subgraphs: List[Subgraph]) -> List[Subgraph]: - """ - This pass finds ACC submodules with less than specified size and merges - them with adjacent CPU submodules. - """ - result: List[Subgraph] = [] - for subgraph in subgraphs: - if subgraph.is_acc: - if len(subgraph.nodes) >= self.settings.min_acc_module_size: - result.append(subgraph) - else: - print( - "Eliminating acc subgraph because it's smaller than the threshold: " - f"{len(subgraph.nodes)} < {self.settings.min_acc_module_size}" - ) - if result: - result[-1].nodes.extend(subgraph.nodes) - else: - subgraph.is_acc = False - result.append(subgraph) - else: - if result and not result[-1].is_acc: - result[-1].nodes.extend(subgraph.nodes) - else: - result.append(subgraph) - return result - - def tag(self, subgraphs: List[Subgraph]): - self.tags: List[str] = [] - for subgraph in subgraphs: - tag = f"_run_on_acc_{len(self.tags)}" if subgraph.is_acc else f"{self.non_acc_submodule_name}{len(self.tags)}" - self.tags.append(tag) - for node in subgraph.nodes: - if hasattr(node, "tag"): - raise FxNetSplitterInternalError(f"Node {node} was already tagged") - - node.tag = tag # type: ignore[attr-defined] - self._node_submodule_map[node.name] = tag - - def split(self, remove_tag: bool = False) -> pippy.fx.GraphModule: - split_module = split_by_tags(self.module, self.tags) - if remove_tag: - for node in self.module.graph.nodes: - if hasattr(node, "tag"): - del node.tag - return split_module - - def __call__(self) -> pippy.fx.GraphModule: - subgraphs = self.put_nodes_into_subgraphs() - subgraphs = self.remove_small_acc_subgraphs(subgraphs) - acc_subgraphs_count = len([s for s in subgraphs if s.is_acc]) - non_acc_subgraphs_count = len(subgraphs) - acc_subgraphs_count - print(f"Got {acc_subgraphs_count} acc subgraphs and {non_acc_subgraphs_count} non-acc subgraphs") - self.tag(subgraphs) - return self.split() - - def generate_split_results(self) -> SplitResult: - split_module = self() - submodule_names = [] - for name, mod in split_module.named_children(): - submodule_names.append(name) - submodule_inputs = generate_inputs_for_submodules(split_module, self.sample_input, submodule_names) - return SplitResult(split_module, submodule_inputs, self.non_acc_submodule_name) diff --git a/pippy/fx/passes/tests/__init__.py b/pippy/fx/passes/tests/__init__.py deleted file mode 100644 index f2661b8c6..000000000 --- a/pippy/fx/passes/tests/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates diff --git a/pippy/fx/passes/tests/test_pass_manager.py b/pippy/fx/passes/tests/test_pass_manager.py deleted file mode 100644 index 34b325355..000000000 --- a/pippy/fx/passes/tests/test_pass_manager.py +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import unittest - -from ..pass_manager import ( - inplace_wrapper, - PassManager, - these_before_those_pass_constraint, - this_before_that_pass_constraint, -) - - -class TestPassManager(unittest.TestCase): - def test_pass_manager_builder(self) -> None: - passes = [lambda x: 2 * x for _ in range(10)] - pm = PassManager(passes) - pm.validate() - - def test_this_before_that_pass_constraint(self) -> None: - passes = [lambda x: 2 * x for _ in range(10)] - pm = PassManager(passes) - - # add unfulfillable constraint - pm.add_constraint(this_before_that_pass_constraint(passes[-1], passes[0])) - - self.assertRaises(RuntimeError, pm.validate) - - def test_these_before_those_pass_constraint(self) -> None: - passes = [lambda x: 2 * x for _ in range(10)] - constraint = these_before_those_pass_constraint(passes[-1], passes[0]) - pm = PassManager( - [inplace_wrapper(p) for p in passes] - ) - - # add unfulfillable constraint - pm.add_constraint(constraint) - - self.assertRaises(RuntimeError, pm.validate) diff --git a/pippy/fx/passes/tools_common.py b/pippy/fx/passes/tools_common.py deleted file mode 100644 index 50a242c88..000000000 --- a/pippy/fx/passes/tools_common.py +++ /dev/null @@ -1,254 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from typing import List, Tuple, Union, Dict, Any, Set, Mapping -import collections -from dataclasses import dataclass - -import torch -import pippy.fx -from pippy.fx.node import _get_qualified_name -from pippy.fx._compatibility import compatibility - -__all__ = ['get_acc_ops_name', 'get_node_target', 'is_node_output_tensor', 'FxNetAccFusionsFinder', 'legalize_graph'] - -Tensors = Union[Tuple[torch.Tensor], List[torch.Tensor]] -TensorOrTensors = Union[torch.Tensor, Tensors] -NodeList = List[pippy.fx.Node] -NodeSet = Set[pippy.fx.Node] -Names = List[str] -CALLABLE_NODE_OPS = {"call_module", "call_function", "call_method"} - - -@compatibility(is_backward_compatible=False) -def get_acc_ops_name(k): - if isinstance(k, str): - return k - elif k.__module__ and "acc_ops" in k.__module__: - return f"acc_ops.{k.__name__}" - else: - module = k.__module__.replace('torch._ops', 'torch.ops') # WAR for bug in how torch.ops assigns module - return f"{module if module else ''}.{k.__name__}" - - -@compatibility(is_backward_compatible=False) -def get_node_target(submodules: Mapping[str, torch.nn.Module], node: pippy.fx.Node) -> str: - """ - Given a `node` returns its target typename. - - For "call_method" node, return node.target which is the name of that method being called. - This could potential lead to conflict but should be okay because normally it's on a tensor. - - For "call_function" node, return typename of node.target. - - For "call_module" node, return typename of the module that node.target point to. - - If seeing "_VariableFunctionsClass" in the target name string, it will be replaced by - "torch". e.g. _VariableFunctionsClass.relu would become torch.relu. - """ - - assert node.op in CALLABLE_NODE_OPS, ( - "Expect op types of " + ", ".join(CALLABLE_NODE_OPS) + f", but found {node.op}" - ) - - if node.op == "call_module": - assert isinstance(node.target, str) - submod = submodules[node.target] - submod_type = getattr(submod, "_base_class_origin", type(submod)) - return get_acc_ops_name(submod_type) - elif node.op == "call_function": - target: Any = node.target - return ( - f"acc_ops.{target.__name__}" - if target.__module__ is not None and "acc_ops" in target.__module__ - else _get_qualified_name(target) - ) - else: - assert isinstance(node.target, str) - return node.target - -@compatibility(is_backward_compatible=False) -def is_node_output_tensor(node: pippy.fx.Node) -> bool: - """Checks if the node output produces a Tensor or not. - - NOTE: This requires to run `ShapeProp` on the containing fx graph before - calling this function. This is because it works by checking the `type` - metadata on the node. This metadata is produced by the `ShapeProp`. - """ - type_ = node.meta.get("type", None) - return type_ is not None and issubclass(type_, torch.Tensor) - -@compatibility(is_backward_compatible=False) -class FxNetAccFusionsFinder: - """ - Finds groups of connected ACC nodes that pass non-tensor data between each other. - Such groups are called fusion groups. - """ - - def __init__(self, module: pippy.fx.GraphModule, acc_nodes: NodeSet): - self.module = module - self.nodes = list(module.graph.nodes) - self.acc_nodes = acc_nodes - - @dataclass - class FusionGroup: - # The smallest idx of nodes in the fusion group after topological sorting all the nodes in the model. - top_node_idx: int - - # Nodes in this fusion group. - nodes: NodeSet - - # Inputs to this fusion group. - inputs: NodeSet - - # Nodes that in the fusion group that haven't been processed yet. - nodes_need_process: NodeSet - - def add_node(self, node): - """ - Add a node to fusion group. - """ - if node in self.nodes: - return - - self.nodes_need_process.add(node) - self.nodes.add(node) - self.inputs.discard(node) - self.inputs.update( - { - n - for n in node.all_input_nodes - if n.op in CALLABLE_NODE_OPS and n not in self.nodes - } - ) - - def recursive_add_node( - self, - fusion_group: "FxNetAccFusionsFinder.FusionGroup", - inputs: Union[NodeSet, NodeList], - ): - """ - Start from inputs and going reverse topological order. If any upstream node - is in the fusion group, add all the nodes in this path to fusion group. - """ - for arg in inputs: - # Skip placeholder and get_attr because they won't be in the fusion group. - if arg.op not in CALLABLE_NODE_OPS: - continue - - # If the node has smaller idx, it's already an upstream node of the fusion - # group. We don't need to check it anymore. - if self.nodes.index(arg) < fusion_group.top_node_idx: - continue - - # If the node is in the fusion group, return True. - if arg in fusion_group.nodes: - return True - - # Check the upstream nodes of the node, if any of them is in the fusion group - # we'll add this node to fusion group and return True. - if self.recursive_add_node(fusion_group, arg.all_input_nodes): - fusion_group.add_node(arg) - return True - - return False - - def __call__(self) -> Dict[pippy.fx.Node, NodeSet]: - result: Dict[pippy.fx.Node, NodeSet] = {} - acc_nodes = list(self.acc_nodes) - - for node in acc_nodes: - if node in result: - continue - if node.op not in CALLABLE_NODE_OPS: - continue - if "tensor_meta" in node.meta: - continue - if node not in self.acc_nodes: - continue - - fusion_group: "FxNetAccFusionsFinder.FusionGroup" = self.FusionGroup( - top_node_idx=self.nodes.index(node), - nodes={node}, - inputs=set(node.all_input_nodes), - nodes_need_process={node}, - ) - while fusion_group.nodes_need_process: - node = fusion_group.nodes_need_process.pop() - self.recursive_add_node(fusion_group, fusion_group.inputs) - - # Optionally add downstream nodes - if "tensor_meta" not in node.meta: - for user in node.users: - if user.op not in CALLABLE_NODE_OPS: - continue - if user in fusion_group.nodes: - continue - - fusion_group.add_node(user) - self.recursive_add_node(fusion_group, fusion_group.inputs) - - # Add some upstream nodes - for arg in node.all_input_nodes: - if arg.op not in CALLABLE_NODE_OPS: - continue - if "tensor_meta" in arg.meta: - continue - if arg in fusion_group.nodes: - continue - - fusion_group.add_node(arg) - fusion_group.top_node_idx = min( - fusion_group.top_node_idx, self.nodes.index(arg) - ) - self.recursive_add_node(fusion_group, fusion_group.inputs) - - if not (set(fusion_group.nodes) <= self.acc_nodes): - self.acc_nodes -= fusion_group.nodes - else: - for n in fusion_group.nodes: - result[n] = fusion_group.nodes - - return result - - -@compatibility(is_backward_compatible=False) -def legalize_graph(gm: pippy.fx.GraphModule) -> pippy.fx.GraphModule: - """ - Replace the graph of the given GraphModule with one that contains the same nodes as the - original, but in topologically sorted order. - - This is used by the merge_matmul transformation below, which disturbs the topologically sorted - order of its input GraphModule, so that this order is restored before further transformation. - - Arguments: - gm: The graph module to topologically sort. It is modified in-place. - - Returns: - The graph module in-place sorted - """ - indeg = {node: 0 for node in gm.graph.nodes} - new_graph = pippy.fx.Graph() - # Track how many unfulfilled dependencies each node has - for node in gm.graph.nodes: - for user in node.users: - indeg[user] += 1 - queue: collections.deque = collections.deque() - # Add all nodes with no dependencies to the queue - for node in gm.graph.nodes: - if indeg[node] == 0: - queue.append(node) - env: Dict[pippy.fx.Node, pippy.fx.Node] = {} - # Pop nodes from the queue, and add nodes that have had all their - # dependencies fulfilled - while len(queue) > 0: - cur = queue.popleft() - env[cur] = new_graph.node_copy(cur, lambda x: env[x]) - for user in cur.users: - indeg[user] -= 1 - if indeg[user] == 0: - queue.append(user) - # If the new graph's size is not as large as the old one, then there must be - # a cycle (i.e. some node's dependencies were not satisfied.) - if len(new_graph.nodes) < len(gm.graph.nodes): - raise RuntimeError(f"Input graph has cycles, unable to add {[node for node in indeg if indeg[node] != 0]}") - gm.graph = new_graph - return gm diff --git a/pippy/fx/passes/utils/__init__.py b/pippy/fx/passes/utils/__init__.py deleted file mode 100644 index e4af20899..000000000 --- a/pippy/fx/passes/utils/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from .common import lift_subgraph_as_module, HolderModule diff --git a/pippy/fx/passes/utils/common.py b/pippy/fx/passes/utils/common.py deleted file mode 100644 index 972848347..000000000 --- a/pippy/fx/passes/utils/common.py +++ /dev/null @@ -1,84 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from torch.nn import Module - -from pippy.fx.graph_module import GraphModule -from pippy.fx.graph import Graph -from pippy.fx.passes.utils.matcher_utils import SubgraphMatcher -from pippy.fx._compatibility import compatibility - - -__all__ = ['HolderModule', 'lift_subgraph_as_module', 'compare_graphs'] - -@compatibility(is_backward_compatible=False) -class HolderModule(Module): - """ - HolderModule is used to copy all the attributes from original module to submodules - that uses the attributes - """ - - def __init__(self, d): - super().__init__() - for k, v in d.items(): - self.add_module(k, v) - - -@compatibility(is_backward_compatible=False) -def lift_subgraph_as_module(gm: GraphModule, subgraph: Graph, class_name: str = 'GraphModule') -> GraphModule: - """ - Create a GraphModule for subgraph, which copies the necessory attributes from the original parent graph_module. - - Args: - gm (GraphModule): parent graph module - - subgraph (Graph): a valid subgraph that contains copied nodes from the parent graph - - class_name (str): name for the submodule - - """ - - # Loop through all module calls (call_module) and param fetches (get_attr) - # in this component, creating HolderModules as necessary to match the path. - # e.g. if in the original module there's a get_attr node fetches "conv.weight". - # We create a HolderModule as root -> add a HolderModule named "conv" -> - # make "weight" a attribute of "conv" HolderModule and point to conv.weight in - # the original module. - submodule = HolderModule({}) - for n in subgraph.nodes: - if n.op not in ("call_module", "get_attr"): - continue - - target = n.target - assert isinstance(target, str) - target_name_parts = target.split(".") - curr = submodule - orig_gm = gm - - for name in target_name_parts[:-1]: - if not hasattr(curr, name): - curr.add_module(name, HolderModule({})) - - curr = getattr(curr, name) - orig_gm = getattr(orig_gm, name) - - leaf_node_name = target_name_parts[-1] - leaf_node = getattr(orig_gm, leaf_node_name) - - # Relies on custom __setattr__ magic. - setattr(curr, leaf_node_name, leaf_node) - - return GraphModule(submodule, subgraph, class_name) - - -@compatibility(is_backward_compatible=False) -def compare_graphs(left: Graph, right: Graph) -> bool: - """ - Return True if two graphs are identical, i.e they - - have the same number of outputs in the same order - - have the same number of inputs in the same order - - have the same set of nodes, and identical connectivity - """ - - matcher = SubgraphMatcher(left, match_output=True, match_placeholder=True) - matches = matcher.match(right) - - return len(matches) > 0 diff --git a/pippy/fx/passes/utils/fuser_utils.py b/pippy/fx/passes/utils/fuser_utils.py deleted file mode 100644 index 270739078..000000000 --- a/pippy/fx/passes/utils/fuser_utils.py +++ /dev/null @@ -1,214 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import copy -from queue import SimpleQueue -from typing import List, Dict, Tuple - -import pippy.fx -from pippy.fx.graph_module import GraphModule -from pippy.fx.graph import Graph -from pippy.fx.node import Node -from pippy.fx.passes.tools_common import NodeList, NodeSet, legalize_graph -from pippy.fx.passes.utils import lift_subgraph_as_module - -def topo_sort(nodes: NodeList) -> NodeList: - # sort nodes according to the topological order - indegree_map = {node : 0 for node in nodes} - candidates: SimpleQueue = SimpleQueue() - - for node in nodes: - for n in node.all_input_nodes: - if n in indegree_map: - indegree_map[node] += 1 - if indegree_map[node] == 0: - candidates.put(node) - - sorted_nodes: NodeList = list() - while not candidates.empty(): - node = candidates.get() - sorted_nodes.append(node) - - for n in node.users: - if n in indegree_map: - indegree_map[n] -= 1 - if indegree_map[n] == 0: - candidates.put(n) - - assert len(nodes) == len(sorted_nodes), "topological sorted nodes doesn't have same length as input nodes" - - return sorted_nodes - - -def validate_partition(partition: NodeList) -> bool: - # verify the partition does't form a dependency cycle in the original graph - # returns True for valid partition, False for invalid - - partition_set = set(partition) - - outputs: NodeList = list() - for node in partition_set: - for user_node in node.users: - if user_node not in partition_set: - # external user node, need to expose as an output - outputs.append(user_node) - - # perform DFS on the parition outputs - # if it reaches a node within the partition, then it found a cycle - visited: NodeSet = set() - - def dfs_find_cycle(node): - if node in partition_set: - return True # found cycle, return - - visited.add(node) - for user_node in node.users: - if user_node not in visited: - if dfs_find_cycle(user_node): - return True - return False - - for output_node in outputs: - if dfs_find_cycle(output_node): - return False - - return True - - -def fuse_as_graphmodule(gm: GraphModule, - nodes: NodeList, - module_name: str) -> Tuple[GraphModule, Tuple[Node, ...], Tuple[Node, ...]]: - - """ - Fuse nodes in graph_module into a GraphModule. - - Args: - gm (GraphModule): target graph_module - - nodes (List[Node]): list of nodes in `gm` to fuse, where the node must be topologically sorted - - module_name: class name for the fused GraphModule - - Returns: - fused_gm (GraphModule): fused graph module, where its node is a copy of `nodes` in `gm` - - original_inputs (Tuple[Node, ...]): input nodes to `nodes` in original `gm` - - original_outputs (Tuple[Node, ...]): consumer nodes of `nodes` in original `gm` - - """ - - # assumption: nodes are already sorted in topo order - - for node in nodes: - assert node.graph.owning_module is gm, f"{node} doesn't belong to passed in graph module {gm._get_name()}" - assert not node._erased, f"{node} has been removed from owning graph" - assert node in gm.graph.nodes, f"{node} is not found in graph module {gm._get_name()}" - - # validates partition doesn't introduce dependency circles in the graph - assert validate_partition(nodes), "Invalid partition, found dependency cycles" - - subgraph = Graph() - - node_to_placeholder: Dict[Node, Node] = {} # mapping of nodes from old graph to placeholder in new graph - node_map: Dict[Node, Node] = {} # mapping of nodes from old graph to new graph - - # handles inputs throught graph.node_copy's arg_transform functions - def remap_inputs(x): - if x.op == "get_attr": - # TODO: do we really need copy the get_attr node into the graph? - # do something here - pass - - if x in nodes: - # x is inside subgraph, return the copied node - # the node should have been copied aleady, as we are copying graph in the topological order - return node_map[x] - - if x not in node_to_placeholder: - # x is not in subgraph, create a new placeholder for subgraph - placeholder_node = subgraph.placeholder(x.name, type_expr=x.type) - # copy all meta fields, even if some fields might be irrelvant for the placeholder node - placeholder_node.meta = copy.copy(x.meta) - node_to_placeholder[x] = placeholder_node - - return node_to_placeholder[x] - - # copy nodes in topological order - for node in nodes: - new_node = subgraph.node_copy(node, remap_inputs) - node_map[node] = new_node - - # handles outputs - output_mapping: Dict[Node, Node] = {} # mapping from old output to new outputs - - for node in nodes: - for user_node in node.users: - if user_node not in nodes: - # external user node, need to expose as an output - output_mapping[node] = node_map[node] - - # outs contain nodes in the new subgraph - outs = tuple(output_mapping.values()) - - # Take care of the args of FX output node. If there's a single - # output then the output node args is like (output_single), else - # if there're multiple outputs then the output node args is like - # ((output_0, output_1, ...)). - subgraph.output(outs[0] if len(outs) == 1 else outs) - - # lint to ensure correctness - subgraph.lint() - - fused_gm: GraphModule = lift_subgraph_as_module(gm, subgraph, class_name=module_name) - - # sub_gm's input nodes in the original module - original_inputs: Tuple[Node, ...] = tuple(node_to_placeholder.keys()) - - # sub_gm's outputs node in the original module - original_outputs: Tuple[Node, ...] = tuple(output_mapping.keys()) - - return fused_gm, original_inputs, original_outputs - - -def insert_subgm(gm: GraphModule, sub_gm: GraphModule, orig_inputs: Tuple[Node, ...], orig_outputs: Tuple[Node, ...]): - # add sub_gm into gm - submodule_name = sub_gm.__class__.__name__ - gm.add_submodule(submodule_name, sub_gm) - - # Create a call_module node in main graph. - module_node = gm.graph.call_module( - submodule_name, - args=orig_inputs, - kwargs=None) - - if len(orig_outputs) == 1: - # main_remapping[comp.orig_outputs[0]] = module_node - orig_outputs[0].replace_all_uses_with(module_node) - else: - for i, orig_output in enumerate(orig_outputs): - # Use Proxy to record getitem access. - proxy_out = pippy.fx.Proxy(module_node)[i].node # type: ignore[index] - orig_output.replace_all_uses_with(proxy_out) - return gm - -def erase_nodes(gm: GraphModule, nodes: NodeList): - - # erase original nodes in inversed topological order - for node in reversed(nodes): - gm.graph.erase_node(node) - - -def fuse_by_partitions(gm: GraphModule, partitions: List[NodeList]) -> GraphModule: - for partition_id, nodes in enumerate(partitions): - sorted_nodes = topo_sort(nodes) - - submodule_name = "fused_" + str(partition_id) - sub_gm, orig_inputs, orig_outputs = fuse_as_graphmodule(gm, sorted_nodes, submodule_name) - - insert_subgm(gm, sub_gm, orig_inputs, orig_outputs) - - erase_nodes(gm, sorted_nodes) - - # topological sort original gm with newly created sub_gm - legalize_graph(gm) - - return gm diff --git a/pippy/fx/passes/utils/matcher_utils.py b/pippy/fx/passes/utils/matcher_utils.py deleted file mode 100644 index 27bb9240a..000000000 --- a/pippy/fx/passes/utils/matcher_utils.py +++ /dev/null @@ -1,309 +0,0 @@ -from dataclasses import dataclass, field -from collections import defaultdict -import copy -from pippy.fx.graph import Graph -from pippy.fx.node import Node -from pippy.fx._compatibility import compatibility -import torch.utils._pytree as pytree -from typing import Dict, List, Set, Any -import os -import logging - -__all__ = ['SubgraphMatcher', 'InternalMatch'] - -format_str = "%(levelname)s > %(message)s" -LOGLEVEL = os.environ.get('LOGLEVEL', 'WARNING').upper() -logging.basicConfig(level=LOGLEVEL, format=format_str) -logger = logging.getLogger(__name__) - -@compatibility(is_backward_compatible=False) -@dataclass -class InternalMatch(): - # Nodes from which the match was found - anchors: List[Node] - # Maps nodes in the pattern subgraph to nodes in the larger graph - nodes_map: Dict[Node, Node] = field(default_factory=dict) - - # nodes in target graph that are matched placeholder in pattern - placeholder_nodes: List[Node] = field(default_factory=list) - - # nodes in matched subgraph returned by output - returning_nodes: List[Node] = field(default_factory=list) - - def __copy__(self): - return InternalMatch(anchors=self.anchors, nodes_map=self.nodes_map.copy(), - placeholder_nodes=self.placeholder_nodes.copy(), - returning_nodes=self.returning_nodes.copy()) - -@compatibility(is_backward_compatible=False) -class SubgraphMatcher: - def __init__(self, pattern: Graph, - match_output: bool = False, - match_placeholder: bool = False, - remove_overlapping_matches: bool = True) -> None: - """ - Args: - pattern: the targeted matching pattern, represented in fx.Graph. - match_output: If True, output node in the pattern graph will be treated as a part of the targeted pattern. - If False, output node is ignored during match. - match_placeholder: If True, placeholder node in the pattern graph will be treated as a part of - the targeted pattern. If False, placeholder nodes will be used a wildcard. - remove_overlapping_matches: If True, in the case of overlapping matches, only the first match - will be returned. - """ - - self.pattern = pattern - self.match_output = match_output - self.match_placeholder = match_placeholder - self.remove_overlapping_matches = remove_overlapping_matches - - if len(pattern.nodes) == 0: - raise ValueError("SubgraphMatcher cannot be initialized with an empty pattern") - - for node in pattern.nodes: - if node.op != "output": - assert len(node.users) > 0, \ - "SubgraphMatcher cannot be initialized with an pattern with dead code" - - # TODO: assert pattern is a connected graph - - self.pattern_placeholder_nodes = [n for n in pattern.nodes if n.op == "placeholder"] - output_node = next(iter(reversed(pattern.nodes))) - # nodes returned by outputs - self.pattern_returning_nodes: List[Node] = output_node.all_input_nodes - - self.pattern_anchors: List[Node] = [] - if match_output: - self.pattern_anchors = [output_node] - else: - # If a node has output_node as the ONLY user, then this node is a graph sink, - # and should be matched against as an anchor - self.pattern_anchors = [n for n in output_node.all_input_nodes if len(n.users) == 1] - - def _nodes_are_equal(self, pn: Node, gn: Node) -> bool: - # if exact match for placeholder is not required, then use placeholder as a wildcard - if not self.match_placeholder and pn.op == "placeholder": - return True - - if pn.op == gn.op: - if pn.op == "placeholder" or pn.op == "output": - return True - return pn.target == gn.target - return False - - def _is_contained(self, nodes_map: Dict[Node, Node]) -> bool: - # `lookup` represents all the nodes in `original_graph` - # that are part of `pattern` - lookup: Dict[Node, Node] = {gn : pn for pn, gn in nodes_map.items()} - for gn, pn in lookup.items(): - # Placeholders can be used by other nodes in the graphs - if pn.op == "placeholder": - continue - - # nodes returned by output are allowed to be used in other areas of the graph - if pn in self.pattern_returning_nodes: - continue - - for user in gn.users: - # If this node has users that were not in `lookup`, then it must leak out of the - # pattern subgraph - if user not in lookup: - return False - return True - - def _remove_overlapping_matches(self, matches: List[InternalMatch]) -> List[InternalMatch]: - non_overlapping_matches: List[InternalMatch] = list() - nodes_matched: Set[Node] = set() - - for match in matches: - found_overlap = False - for pn, gn in match.nodes_map.items(): - if pn.op not in {"placeholder", "output"} and gn in nodes_matched: - found_overlap = True - break - - if not found_overlap: - non_overlapping_matches.append(match) - for pn, gn in match.nodes_map.items(): - if pn.op not in {"placeholder", "output"}: - nodes_matched.add(gn) - return non_overlapping_matches - - def _match_args(self, pn: Any, gn: Any, match: InternalMatch) -> bool: - assert not(isinstance(pn, Node) and isinstance(gn, Node)), "pn and gn cannot both be Node" - - if isinstance(pn, Node) and not isinstance(gn, Node): - if pn.op == "placeholder": - # Check if we've already matched these nodes in the current - # traversal - if pn in match.nodes_map: - return match.nodes_map[pn] == gn - - match.nodes_map[pn] = gn - return True - else: - return False - elif not isinstance(pn, Node) and isinstance(gn, Node): - return False - else: - return type(gn) == type(pn) and gn == pn - - def _match_nodes(self, pn: Node, gn: Node, match: InternalMatch) -> bool: - logger.info(f" matching {pn} to {gn}") - - assert isinstance(pn, Node) and isinstance(gn, Node), str(f"pn and gn must be Node, pn: {pn}, gn: {gn}") - - # Check if we've already matched these nodes in the current - # traversal - if pn in match.nodes_map: - return match.nodes_map[pn] == gn - - # TODO: use a more efficienty way to check if gn is matched before: two-way dict - if gn in match.nodes_map.values(): - return False - - if not self._nodes_are_equal(pn, gn): - return False - - # Optimistically mark `pn` as a match for `gn`, and save a local copy of match - saved_match = copy.copy(match) - match.nodes_map[pn] = gn - - if pn.op == "placeholder": - return True - - # Recursively traverse upwards to check if `pn` is a true - # match for `gn` - match_found = True - - pn_flatten_args, _ = pytree.tree_flatten(pn.args) - gn_flatten_args, _ = pytree.tree_flatten(gn.args) - - if pn.kwargs.keys() == gn.kwargs.keys(): - for key in pn.kwargs.keys(): - pn_flatten_args.append(pn.kwargs[key]) - gn_flatten_args.append(gn.kwargs[key]) - else: - match_found = False - - if match_found and len(pn_flatten_args) == len(gn_flatten_args): - for pn_, gn_ in zip(pn_flatten_args, gn_flatten_args): - if isinstance(gn_, Node) and isinstance(pn_, Node): - matched = self._match_nodes(pn_, gn_, match) - else: - matched = self._match_args(pn_, gn_, match) - - if not matched: - match_found = False - break - else: - match_found = False - - if not match_found: - # revert to saved_match before matching with current node - match = copy.copy(saved_match) - return False - - return True - - def match(self, graph: Graph) -> List[InternalMatch]: - """ - Returns: - The matched subgraphs. - Thre returned subgraph would be fully self-contained, meaning the nodes (except placeholder - and nodes returned by output) can only be consumed by nodes within the matched subgraph. - - Subgraph pattern matcher is implemented with the backtracking style in the following steps: - - 1. We first identify all the anchor nodes in the pattern graph. The anchor nodes - are the "sinks" (nodes with no user other than the output node) of the pattern graph. - One pattern graph could have multiple anchors if it has multiple return values. - - 2. In the target graph, we identify the potential candidate nodes that can be matched - with each anchor. These anchor-candidate pairs are the starting points for - pairwise per-node matching. - - 3. For each anchor-candidate pair, we simultaneously traverse backwards (DFS) in both - pattern and target graphs. For every pattern nodes along traversal path, we compare it - against the target nodes. In case any comparison failed, the match for this anchor-candidate - pair fails. A match is found when DFS completes traversing the graph. See `self._match_nodes` - for more details. - - 4. In the case of multiple anchors, every anchor will need to find a match using step 3. - In addition, the matches found between anchors need to have a common intersection node - in order for the match to be valid. This is implemented with backtracking. See `backtracking` - for more details. - - Notice: graph traversal must be done in the reverser order because a tensor can have multiple - consumers, but can only have a single producer. Only with reverser order, we can we jointly - traverse the pattern and target graph in a deterministic path. - - Warning: In theory, this backtracking algorithm have an **exponential** time complexity. However, - in practice, it's unlikely to blow up. - - """ - from pippy.fx.passes.utils.fuser_utils import validate_partition - - # find candidate nodes to match with pattern anchors - match_candidates: Dict[Node, List[Node]] = defaultdict(list) - for pattern_anchor in self.pattern_anchors: - for node in graph.nodes: - if self._nodes_are_equal(pattern_anchor, node): - match_candidates[pattern_anchor].append(node) - match_candidates_list = list(match_candidates.items()) - matches: List[InternalMatch] = [] - - def backtracking(anchor_index, match): - if anchor_index == len(match_candidates_list): - match.placeholder_nodes = [match.nodes_map[pn] for pn in self.pattern_placeholder_nodes] - match.returning_nodes = [match.nodes_map[pn] for pn in self.pattern_returning_nodes] - matches.append(match) - - logger.info(f"Found a match: {match}\n") - return - - pattern_anchor, candidate_nodes = match_candidates_list[anchor_index] - saved_match = copy.copy(match) - - for node in candidate_nodes: - logger.info(f"Trying to match anchor {pattern_anchor} to {node}") - - match_found = self._match_nodes(pattern_anchor, node, match) - if match_found: - # match next anchor - backtracking(anchor_index + 1, match) - else: - logger.info(f"Failed to match anchor {pattern_anchor} to {node}\n") - - # revert to saved_match before matching with current anchor - match = copy.copy(saved_match) - - match = InternalMatch(anchors=self.pattern_anchors) - backtracking(0, match) - - # filter out the matches where the subgraph is not fully_contained - before = len(matches) - matches = [match for match in matches if self._is_contained(match.nodes_map)] - after = len(matches) - if before != after: - logger.info(f"Filtered out {before - after} matches because they are not fully contained") - - # filter out the matches that that forms a cycle if the subgraph is fused - valid_matches = [] - for match in matches: - matched_compute_nodes = \ - [gn for pn, gn in match.nodes_map.items() if pn.op not in {"placeholder", "output"}] - if validate_partition(matched_compute_nodes): - valid_matches.append(match) - if len(valid_matches) != len(matches): - logger.info(f"Filtered out {len(matches) - len(valid_matches)} matches because \ - matched subgraph would form a cycle if fused") - - if self.remove_overlapping_matches: - before = len(valid_matches) - matches = self._remove_overlapping_matches(valid_matches) - after = len(matches) - if before != after: - logger.info(f"Filtered out {before - after} matches because matched subgraphs are overlapping") - - return matches diff --git a/pippy/fx/proxy.py b/pippy/fx/proxy.py deleted file mode 100644 index 73ec01089..000000000 --- a/pippy/fx/proxy.py +++ /dev/null @@ -1,419 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import dis -import inspect -import operator -import traceback -from typing import Tuple, Dict, Optional, Iterable, Any, Iterator, Callable - -import torch - -import pippy.fx.traceback as fx_traceback -from ._compatibility import compatibility -from .graph import magic_methods, reflectable_magic_methods, Graph -from .node import Target, Node, Argument, base_types, map_aggregate -from .operator_schemas import check_for_mutable_operation - -__all__ = ['TracerBase', 'GraphAppendingTracer', 'TraceError', 'Proxy', 'Attribute', 'ParameterProxy'] - -@compatibility(is_backward_compatible=True) -class TracerBase: - graph: Graph - record_stack_traces : bool = False - # Feature flag for mutable schema checking - # Enableby default in 1.12 - check_mutable_operations : bool = False - # Feature flag for assert tracing - trace_asserts : bool = False - # Feature flag for proxying accesses to buffer values - proxy_buffer_attributes : bool = False - - # Name of the function to be traced. It will only be used when - # ``root`` is an instance of ``nn.Module`` - traced_func_name: str = "forward" - - @compatibility(is_backward_compatible=True) - def create_node(self, kind : str, target : Target, - args : Tuple[Argument, ...], kwargs : Dict[str, Argument], name : Optional[str] = None, - type_expr : Optional[Any] = None) -> Node: - """ - Inserts a graph node given target, args, kwargs, and name. - - This method can be overridden to do extra checking, validation, or - modification of values used in node creation. For example, one might - want to disallow in-place operations from being recorded. - """ - if kind == 'call_function' and self.check_mutable_operations: - check_for_mutable_operation(target, args, kwargs) - - return self.graph.create_node(kind, target, args, kwargs, name, type_expr) - - @compatibility(is_backward_compatible=True) - def proxy(self, node: Node) -> 'Proxy': - return Proxy(node, self) - - @compatibility(is_backward_compatible=True) - def create_proxy(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any], - name: Optional[str] = None, type_expr : Optional[Any] = None, - proxy_factory_fn: Callable[[Node], 'Proxy'] = None): - ''' - Create a Node from the given arguments, then return the Node - wrapped in a Proxy object. - - If kind = 'placeholder', then we're creating a Node that - represents the parameter of a function. If we need to encode - a default parameter, we use the ``args`` tuple. ``args`` is - otherwise empty for ``placeholder`` Nodes. - ''' - - args_ = self.create_arg(args) - kwargs_ = self.create_arg(kwargs) - assert isinstance(args_, tuple) - assert isinstance(kwargs_, dict) - - node = self.create_node(kind, target, args_, kwargs_, name, type_expr) - - if not proxy_factory_fn: - proxy = self.proxy(node) - else: - proxy = proxy_factory_fn(node) - - # Optionally set stack trace on the created Node for debugging purposes - if fx_traceback.is_stack_trace_overridden(): - stacks = fx_traceback.format_stack() - proxy.node.stack_trace = '\n'.join(reversed(stacks)) - elif self.record_stack_traces: - user_frame = self._find_user_frame() - if user_frame: - walk_stack_gen = traceback.walk_stack(user_frame) - summary = traceback.StackSummary.extract(walk_stack_gen) # type: ignore[arg-type] - tb_lines = summary.format() - proxy.node.stack_trace = ''.join(tb_lines) - - return proxy - - def _find_user_frame(self): - """ - Find the Python stack frame executing the user code during - symbolic tracing. - """ - # We have to do a little dance here. Basically, walk up the callstack and - # record the first frame not in the pytorch source. This is the frame executing - # the user code during tracing. - frame = inspect.currentframe() - - pt_files = ['torch/fx/proxy.py', - 'torch/fx/_symbolic_trace.py', - 'torch/fx/experimental/proxy_tensor.py', - 'torch/_ops.py', - 'torch/_tensor.py', - 'torch/utils/_python_dispatch.py', - 'torch/_prims_common/wrappers.py', - 'torch/_refs/__init__.py', - 'torch/_refs/nn/functional/__init__.py' - ] - while frame: - frame = frame.f_back - if frame and all(not frame.f_code.co_filename.endswith(file) for file in pt_files): - break - - if not frame: - return None - - return frame - - @compatibility(is_backward_compatible=True) - def create_arg(self, a: Any) -> Argument: - """ - A method that lowers the objects seen as arguments during symbolic evaluation - into Argument types that can be stored in IR. - - Can be override to support more trace-specific types. - """ - if not isinstance(a, Proxy) and hasattr(a, '__fx_create_arg__'): - return a.__fx_create_arg__(self) - # aggregates - elif isinstance(a, tuple) and hasattr(a, '_fields'): - # NamedTuple constructors don't seem to like getting a generator - # expression as an argument to their constructor, so build this - # intermediate tuple and unpack it into the NamedTuple constructor - args = tuple(self.create_arg(elem) for elem in a) - return type(a)(*args) # type: ignore[arg-type] - elif isinstance(a, (tuple, list)): - return type(a)(self.create_arg(elem) for elem in a) - elif isinstance(a, dict): - r = {} - for k, v in a.items(): - # Check for invalid dict keys. We do not want a Proxy to appear - # anywhere within the key. Since keys can be collection types, - # we iterate through the key with map_aggregate - k = self.create_arg(k) - - def no_node(arg): - if isinstance(arg, Node): - raise RuntimeError("Keys for dictionaries used as an argument cannot contain a " - "Node. Got key: {k}") - map_aggregate(k, no_node) - - r[k] = self.create_arg(v) - return r - elif isinstance(a, slice): - return slice(self.create_arg(a.start), self.create_arg(a.stop), self.create_arg(a.step)) - - if isinstance(a, Proxy): - # base case: we unwrap the Proxy object - return a.node - elif isinstance(a, base_types) or a is None or a is ...: - return a - raise NotImplementedError(f"argument of type: {type(a)}") - - @compatibility(is_backward_compatible=True) - def to_bool(self, obj: 'Proxy') -> bool: - """Called when a proxy object is being converted to a boolean, such as - when used in control flow. Normally we don't know what to do because - we don't know the value of the proxy, but a custom tracer can attach more - information to the graph node using create_node and can choose to return a value. - """ - raise TraceError('symbolically traced variables cannot be used as inputs to control flow') - - @compatibility(is_backward_compatible=True) - def iter(self, obj: 'Proxy') -> Iterator: - """Called when a proxy object is being iterated over, such as - when used in control flow. Normally we don't know what to do because - we don't know the value of the proxy, but a custom tracer can attach more - information to the graph node using create_node and can choose to return an iterator. - """ - raise TraceError('Proxy object cannot be iterated. This can be ' - 'attempted when the Proxy is used in a loop or' - ' as a *args or **kwargs function argument. ' - 'See the pippy.fx docs on pytorch.org for a ' - 'more detailed explanation of what types of ' - 'control flow can be traced, and check out the' - ' Proxy docstring for help troubleshooting ' - 'Proxy iteration errors') - - @compatibility(is_backward_compatible=True) - def keys(self, obj: 'Proxy') -> Any: - """Called when a proxy object is has the keys() method called. - This is what happens when ** is called on a proxy. This should return an - iterator it ** is suppose to work in your custom tracer. - """ - return Attribute(obj, 'keys')() - - -# used in Proxy object when just appending to the graph while not tracing. -@compatibility(is_backward_compatible=True) -class GraphAppendingTracer(TracerBase): - def __init__(self, graph: Graph): - super().__init__() - self.graph = graph - -@compatibility(is_backward_compatible=False) -def assert_fn(x): - assert x - -@compatibility(is_backward_compatible=True) -class TraceError(ValueError): - pass - -@compatibility(is_backward_compatible=True) -class Proxy: - """ - ``Proxy`` objects are ``Node`` wrappers that flow through the - program during symbolic tracing and record all the operations - (``torch`` function calls, method calls, operators) that they touch - into the growing FX Graph. - - If you're doing graph transforms, you can wrap your own ``Proxy`` - method around a raw ``Node`` so that you can use the overloaded - operators to add additional things to a ``Graph``. - - ``Proxy`` objects cannot be iterated. In other words, the symbolic - tracer will throw an error if a ``Proxy`` is used in a loop or as - an ``*args``/``**kwargs`` function argument. - - There are two main ways around this: - 1. Factor out the untraceable logic into a top-level function and - use ``fx.wrap`` on it. - 2. If the control flow is static (i.e. the loop trip count is - based on some hyperparameter), the code can be kept in its original - position and refactored into something like:: - - for i in range(self.some_hyperparameter): - indexed_item = proxied_value[i] - - For a more detailed description into the Proxy internals, check out - the "Proxy" section in `torch/fx/OVERVIEW.md` - """ - - @compatibility(is_backward_compatible=True) - def __init__(self, node: Node, tracer: 'Optional[TracerBase]' = None): - if tracer is None: - # This allows you to create a Proxy object around a raw Node - tracer = GraphAppendingTracer(node.graph) - self.tracer = tracer - self.node = node - - def __repr__(self) -> str: - return f'Proxy({self.node.name})' - - def __getattr__(self, k) -> 'Attribute': - # note: not added to the graph yet, if this is a method call - # we peephole optimize to the method invocation - return Attribute(self, k) - - def __call__(self, *args, **kwargs) -> 'Proxy': - return self.tracer.create_proxy('call_method', '__call__', (self,) + args, kwargs) - - def __iter__(self) -> Iterable['Proxy']: - frame = inspect.currentframe() - assert frame is not None - calling_frame = frame.f_back - assert calling_frame is not None - inst = list(dis.get_instructions(calling_frame.f_code))[calling_frame.f_lasti // 2] - if inst.opname == 'UNPACK_SEQUENCE': - return (self[i] for i in range(inst.argval)) # type: ignore[index] - - return self.tracer.iter(self) - - def __bool__(self) -> bool: - if self.tracer.trace_asserts: - # check if this boolean is used in an assertion, bytecode pattern for assertions - # is pretty stable for Python 3.7--3.9 - frame = inspect.currentframe() - assert frame is not None - calling_frame = frame.f_back - assert calling_frame is not None - insts = list(dis.get_instructions(calling_frame.f_code)) - cur = calling_frame.f_lasti // 2 - inst = insts[cur] - - if inst.opname == 'POP_JUMP_IF_TRUE': - first = insts[cur + 1] - assert inst.arg is not None - last = insts[inst.arg // 2 - 1] - starts_with_assert = (first.opname == 'LOAD_GLOBAL' and first.argval == 'AssertionError' - or first.opname == 'LOAD_ASSERTION_ERROR') - if starts_with_assert and last.opname == 'RAISE_VARARGS': - self.tracer.create_proxy('call_function', assert_fn, (self,), {}) - return True - - return self.tracer.to_bool(self) - - @compatibility(is_backward_compatible=True) - def keys(self): - return self.tracer.keys(self) - - def __len__(self): - raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want " - "this call to be recorded, please call pippy.fx.wrap('len') at " - "module scope") - - @classmethod - def __torch_function__(cls, orig_method, types, args=None, kwargs=None): - args = args if args else () - kwargs = kwargs if kwargs else {} - - tracers : Dict[Any, None] = {} - - def find_tracer(a): - if isinstance(a, cls): - tracers[a.tracer] = None - map_aggregate(args, find_tracer) - map_aggregate(kwargs, find_tracer) - - if len(tracers) > 1: - raise RuntimeError(f'Found multiple different tracers {list(tracers.keys())} while ' - f'trying to trace operations {orig_method}') - tracer = next(iter(tracers.keys())) - - if isinstance(orig_method, torch._C.ScriptMethod): - args = (orig_method.owner,) + args - return tracer.create_proxy('call_method', orig_method.name, args, kwargs) - if torch.overrides.is_tensor_method_or_property(orig_method): - return tracer.create_proxy('call_method', orig_method.__name__, args, kwargs) - else: - return tracer.create_proxy('call_function', orig_method, args, kwargs, - name=tracer.graph._target_to_str(orig_method.__name__)) - - -@compatibility(is_backward_compatible=True) -class Attribute(Proxy): - @compatibility(is_backward_compatible=True) - def __init__(self, root: Proxy, attr: str): - self.root = root - self.attr = attr - self.tracer = root.tracer - self._node: Optional[Node] = None - - @property - def node(self): - # the node for attributes is added lazily, since most will just be method calls - # which do not rely on the getitem call - if self._node is None: - self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node - return self._node - - def __call__(self, *args, **kwargs): - return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs) - - -@compatibility(is_backward_compatible=False) -class ParameterProxy(Proxy): - """ - A special proxy which lets "shape", "size", "dim", and a few other - attribute accesses pass through to the underlying module parameter object, - so that conditional tests on these attributes will not throw exception during tracing - """ - def __init__(self, tracer: TracerBase, node: Node, name, param): - super().__init__(node, tracer) - assert(isinstance(param, torch.nn.Parameter)) - self.param = param - self.name = name - - def __repr__(self) -> str: - return f'ParameterProxy({self.name})' - - @property - def shape(self): - return self.param.shape - - def size(self): - return self.param.size() - - def dim(self): - return self.param.dim() - - @property - def ndim(self): - return self.param.ndim - - def numel(self): - return self.param.numel() - - def nelement(self): - return self.param.nelement() - - -for method in magic_methods: - def _scope(method): - def impl(*args, **kwargs): - tracer = args[0].tracer - target = getattr(operator, method) - return tracer.create_proxy('call_function', target, args, kwargs) - impl.__name__ = method - as_magic = f'__{method.strip("_")}__' - setattr(Proxy, as_magic, impl) - _scope(method) - -def _define_reflectable(orig_method_name): - method_name = f'__r{orig_method_name.strip("_")}__' - - def impl(self, rhs): - target = getattr(operator, orig_method_name) - return self.tracer.create_proxy('call_function', target, (rhs, self), {}) - impl.__name__ = method_name - impl.__qualname__ = method_name - setattr(Proxy, method_name, impl) - -for orig_method_name in reflectable_magic_methods: - _define_reflectable(orig_method_name) diff --git a/pippy/fx/subgraph_rewriter.py b/pippy/fx/subgraph_rewriter.py deleted file mode 100644 index 42620e05c..000000000 --- a/pippy/fx/subgraph_rewriter.py +++ /dev/null @@ -1,255 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from .graph_module import GraphModule -from .graph import Graph -from .node import Node -from ._symbolic_trace import symbolic_trace -from ._compatibility import compatibility - -import copy -from typing import Callable, Dict, List, NamedTuple, Optional, Set -import torch - -__all__ = ['Match', 'replace_pattern'] - -@compatibility(is_backward_compatible=True) -class Match(NamedTuple): - # Node from which the match was found - anchor: Node - # Maps nodes in the pattern subgraph to nodes in the larger graph - nodes_map: Dict[Node, Node] - - -def _replace_submodules(gm: GraphModule, replacement: torch.nn.Module) -> None: - gm.delete_all_unused_submodules() - - if isinstance(replacement, GraphModule): - replacement.graph.lint() - - def try_get_submodule(mod: torch.nn.Module, target: str) -> Optional[torch.nn.Module]: - try: - mod_match = mod.get_submodule(target) - return mod_match - except AttributeError: - return None - - for node in gm.graph.nodes: - if node.op == "call_module" or node.op == "get_attr": - - gm_submod = try_get_submodule(gm, node.target) - - replacement_submod = try_get_submodule(replacement, node.target) - - # CASE 1: This target already exists as a submodule in our - # result GraphModule. Whether or not it exists in - # `replacement`, the existing submodule takes precedence. - if gm_submod is not None: - continue - - # CASE 2: The target exists as a submodule in `replacement` - # only, so we need to copy it over. - elif replacement_submod is not None: - new_submod = copy.deepcopy(getattr(replacement, node.target)) - gm.add_submodule(node.target, new_submod) - - # CASE 3: The target doesn't exist as a submodule in `gm` - # or `replacement` - else: - raise RuntimeError("Attempted to create a \"", node.op, - "\" node during subgraph rewriting " - f"with target {node.target}, but " - "the referenced submodule does not " - "exist in either the original " - "GraphModule `gm` or the replacement" - " GraphModule `replacement`") - - gm.graph.lint() - -@compatibility(is_backward_compatible=True) -def replace_pattern(gm: GraphModule, pattern: Callable, replacement: Callable) -> List[Match]: - """ - Matches all possible non-overlapping sets of operators and their - data dependencies (``pattern``) in the Graph of a GraphModule - (``gm``), then replaces each of these matched subgraphs with another - subgraph (``replacement``). - - Args: - ``gm``: The GraphModule that wraps the Graph to operate on - ``pattern``: The subgraph to match in ``gm`` for replacement - ``replacement``: The subgraph to replace ``pattern`` with - - Returns: - List[Match]: A list of ``Match`` objects representing the places - in the original graph that ``pattern`` was matched to. The list - is empty if there are no matches. ``Match`` is defined as: - - .. code-block:: python - - class Match(NamedTuple): - # Node from which the match was found - anchor: Node - # Maps nodes in the pattern subgraph to nodes in the larger graph - nodes_map: Dict[Node, Node] - - Examples: - - .. code-block:: python - - import torch - from pippy.fx import symbolic_trace, subgraph_rewriter - - class M(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, w1, w2): - m1 = torch.cat([w1, w2]).sum() - m2 = torch.cat([w1, w2]).sum() - return x + torch.max(m1) + torch.max(m2) - - def pattern(w1, w2): - return torch.cat([w1, w2]).sum() - - def replacement(w1, w2): - return torch.stack([w1, w2]) - - traced_module = symbolic_trace(M()) - - subgraph_rewriter.replace_pattern(traced_module, pattern, replacement) - - The above code will first match ``pattern`` in the ``forward`` - method of ``traced_module``. Pattern-matching is done based on - use-def relationships, not node names. For example, if you had - ``p = torch.cat([a, b])`` in ``pattern``, you could match - ``m = torch.cat([a, b])`` in the original ``forward`` function, - despite the variable names being different (``p`` vs ``m``). - - The ``return`` statement in ``pattern`` is matched based on its - value only; it may or may not match to the ``return`` statement in - the larger graph. In other words, the pattern doesn't have to extend - to the end of the larger graph. - - When the pattern is matched, it will be removed from the larger - function and replaced by ``replacement``. If there are multiple - matches for ``pattern`` in the larger function, each non-overlapping - match will be replaced. In the case of a match overlap, the first - found match in the set of overlapping matches will be replaced. - ("First" here being defined as the first in a topological ordering - of the Nodes' use-def relationships. In most cases, the first Node - is the parameter that appears directly after ``self``, while the - last Node is whatever the function returns.) - - One important thing to note is that the parameters of the - ``pattern`` Callable must be used in the Callable itself, - and the parameters of the ``replacement`` Callable must match - the pattern. The first rule is why, in the above code block, the - ``forward`` function has parameters ``x, w1, w2``, but the - ``pattern`` function only has parameters ``w1, w2``. ``pattern`` - doesn't use ``x``, so it shouldn't specify ``x`` as a parameter. - As an example of the second rule, consider replacing - - .. code-block:: python - - def pattern(x, y): - return torch.neg(x) + torch.relu(y) - - with - - .. code-block:: python - - def replacement(x, y): - return torch.relu(x) - - In this case, ``replacement`` needs the same number of parameters - as ``pattern`` (both ``x`` and ``y``), even though the parameter - ``y`` isn't used in ``replacement``. - - After calling ``subgraph_rewriter.replace_pattern``, the generated - Python code looks like this: - - .. code-block:: python - - def forward(self, x, w1, w2): - stack_1 = torch.stack([w1, w2]) - sum_1 = stack_1.sum() - stack_2 = torch.stack([w1, w2]) - sum_2 = stack_2.sum() - max_1 = torch.max(sum_1) - add_1 = x + max_1 - max_2 = torch.max(sum_2) - add_2 = add_1 + max_2 - return add_2 - """ - from pippy.fx.passes.utils.matcher_utils import SubgraphMatcher, InternalMatch - - # Get the graphs for `gm`, `pattern`, `replacement` - original_graph: Graph = gm.graph - pattern_graph: Graph = symbolic_trace(pattern).graph - replacement_graph: Graph = symbolic_trace(replacement).graph - - matcher = SubgraphMatcher(pattern_graph, match_output=False, match_placeholder=False, - remove_overlapping_matches=True) - _matches: List[InternalMatch] = matcher.match(original_graph) - - replacement_placeholders = [n for n in replacement_graph.nodes if n.op == "placeholder"] - - # As we progressively replace nodes, we'll need to keep track of how the match results should change - match_changed_node: Dict[Node, Node] = {} - - for match in _matches: - - # Build connecting between replacement graph's input and original graph input producer node - - # Initialize `val_map` with mappings from placeholder nodes in - # `replacement` to their corresponding node in `original_graph` - assert len(match.placeholder_nodes) == len(replacement_placeholders) - val_map: Dict[Node, Node] = {} - for rn, gn in zip(replacement_placeholders, match.placeholder_nodes): - val_map[rn] = match_changed_node.get(gn, gn) - - # Copy the replacement graph over - user_nodes: Set[Node] = set() - for n in match.returning_nodes: - for user in n.users: - user_nodes.add(user) - assert user_nodes, "The returning_nodes should have at least one user node" - - if len(user_nodes) == 1: - first_user_node = list(user_nodes)[0] - else: - # If there are multiple user nodes, we need to find the first user node - # in the current execution order of the `original_graph` - for n in original_graph.nodes: - if n in user_nodes: - first_user_node = n - break - - with original_graph.inserting_before(first_user_node): - copied_returning_nodes = original_graph.graph_copy(replacement_graph, val_map) - - if isinstance(copied_returning_nodes, Node): - copied_returning_nodes = (copied_returning_nodes, ) - - # Hook the output Node of the replacement subgraph in to the - # original Graph at the correct location - assert len(match.returning_nodes) == len(copied_returning_nodes) - for gn, copied_node in zip(match.returning_nodes, copied_returning_nodes): - gn.replace_all_uses_with(copied_node) - match_changed_node[gn] = copied_node - # Remove the original nodes - for node in reversed(pattern_graph.nodes): - if node.op != "placeholder" and node.op != "output": - gn = match.nodes_map[node] - gm.graph.erase_node(gn) - - # Update the passed-in GraphModule to reflect the new state of - # `original_graph` - gm.recompile() - - # If `replacement` was an nn.Module, we'll need to make sure that - # all the submodules have been copied over correctly - if isinstance(replacement, torch.nn.Module): - _replace_submodules(gm, replacement) - - # Convert _matches: InternalMatch to Match to comply with backward compatibility of this function - matches: List[Match] = [Match(anchor=match.anchors[0], nodes_map=match.nodes_map) for match in _matches] - return matches diff --git a/pippy/fx/tensor_type.py b/pippy/fx/tensor_type.py deleted file mode 100644 index a85292ea3..000000000 --- a/pippy/fx/tensor_type.py +++ /dev/null @@ -1,105 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from pippy.fx.experimental.unification import Var # type: ignore[attr-defined] - -from ._compatibility import compatibility - - -@compatibility(is_backward_compatible=False) -class TensorType: - """ - TensorType defines a type for tensors, which consists of a list of dimensions. - Example: - class M(torch.nn.Module): - def forward(self, x:TensorType((1,2,3, Dyn)), y:TensorType((1,2,3, Dyn))): - return torch.add(x, y) - """ - - def __init__(self, dim): - self.__origin__ = TensorType - self.__args__ = dim - - def __repr__(self): - return f'TensorType[{self.__args__}]' - - def __eq__(self, other): - if isinstance(other, self.__class__): - return list(self.__args__) == list(other.__args__) - else: - return False - - @staticmethod - def __class_getitem__(*args): - if len(args) == 1 and isinstance(args[0], tuple): - args = args[0] - return TensorType(tuple(args)) - - -class _DynType: - """ - _DynType defines a type which stands for the absence of type information. - """ - def __init__(self): - self.__name__ = '_DynType' - - def __eq__(self, other): - return isinstance(other, self.__class__) - - def __str__(self): - return "Dyn" - - def __repr__(self): - return "Dyn" - - -Dyn = _DynType() - -@compatibility(is_backward_compatible=False) -def is_consistent(t1, t2): - """ - A binary relation denoted by ~ that determines if t1 is consistent with t2. - The relation is reflexive, semmetric but not transitive. - returns True if t1 and t2 are consistent and False otherwise. - Example: - Dyn ~ TensorType((1,2,3)) - int ~ Dyn - int ~ int - TensorType((1,Dyn,3)) ~ TensorType((1,2,3)) - """ - - if t1 == t2: - return True - - if t1 == Dyn or t2 == Dyn or isinstance(t1, Var) or isinstance(t2, Var): - return True - - if isinstance(t1, TensorType) and isinstance(t2, TensorType): - return len(t1.__args__) == len(t2.__args__) and \ - all([is_consistent(elem1, elem2) for elem1, elem2 in zip(t1.__args__, t2.__args__)]) - else: - return False - - -@compatibility(is_backward_compatible=False) -def is_more_precise(t1, t2): - """ - A binary relation denoted by <= that determines if t1 is more precise than t2. - The relation is reflexive and transitive. - returns True if t1 is more precise than t2 and False otherwise. - Example: - Dyn >= TensorType((1,2,3)) - int >= Dyn - int >= int - TensorType((1,Dyn,3)) <= TensorType((1,2,3)) - """ - if t1 == t2: - return True - - if isinstance(t2, _DynType): - return True - - if isinstance(t1, TensorType) and isinstance(t2, TensorType): - return len(t1.__args__) == len(t2.__args__) and \ - all([is_more_precise(elem1, elem2) for elem1, elem2 in zip(t1.__args__, t2.__args__)]) - - else: - return False diff --git a/pippy/fx/traceback.py b/pippy/fx/traceback.py deleted file mode 100644 index a07b36b99..000000000 --- a/pippy/fx/traceback.py +++ /dev/null @@ -1,62 +0,0 @@ -import traceback -from contextlib import contextmanager -from typing import Optional, List -from ._compatibility import compatibility - -__all__ = ['override_stack_trace', 'set_stack_trace', 'append_stack_trace', 'format_stack', 'is_stack_trace_overridden'] - - -current_stack: List[str] = [] -is_overridden = False - - -@compatibility(is_backward_compatible=False) -@contextmanager -def override_stack_trace(): - global is_overridden - - saved_is_overridden = is_overridden - try: - is_overridden = True - yield - finally: - is_overridden = saved_is_overridden - - -@compatibility(is_backward_compatible=False) -def set_stack_trace(stack : List[str]): - global current_stack - - if is_overridden and stack: - current_stack = stack - -@compatibility(is_backward_compatible=False) -@contextmanager -def append_stack_trace(stack : Optional[str]): - """ - The content of stack here is an entire stacktraces as a string - """ - global current_stack - - if is_overridden and stack: - try: - current_stack.append(stack) - yield - finally: - current_stack.pop() - else: - yield - - -@compatibility(is_backward_compatible=False) -def format_stack() -> List[str]: - if is_overridden: - return current_stack.copy() - else: - # fallback to traceback.format_stack() - return traceback.format_stack() - - -@compatibility(is_backward_compatible=False) -def is_stack_trace_overridden() -> bool: - return is_overridden diff --git a/pippy/microbatch.py b/pippy/microbatch.py index eb2cace9a..a84c81bbf 100644 --- a/pippy/microbatch.py +++ b/pippy/microbatch.py @@ -1,13 +1,18 @@ # Copyright (c) Meta Platforms, Inc. and affiliates import logging -import warnings -from typing import Any import torch - from torch.utils._pytree import tree_flatten, tree_unflatten -from pippy.IR import TrivialLossWrapper + +logger = logging.getLogger(__name__) + +""" +_debug_mask_minibatches specifies to send masked versions of the mini-batch +through instead of micro-batch slices--this can be used for more stable +numerical testing (see [A Note About Correctness Testing]) +""" +_debug_mask_minibatches = False class CustomReducer: @@ -48,7 +53,6 @@ def shard_dict_of_args( args_dict, args_chunk_spec, num_chunks, - _debug_mask_minibatches: bool = False, ): # Stage 1+2: flatten and shard/replicate @@ -95,7 +99,7 @@ def shard_dict_of_args( if first_tensor: # We can only adjust number of chunks when we hit this # issue at the first tensor encountered - warnings.warn( + logger.warning( f"Tensor size on chunking dimension is {v_split_dim_size}, " f"downsizing the number of chunks from {num_chunks} to {v_split_dim_size}." ) @@ -173,7 +177,6 @@ def split_args_kwargs_into_chunks( chunks, args_chunk_spec=None, kwargs_chunk_spec=None, - _debug_mask_minibatches: bool = False, ): # Given `args` and `kwargs`, we want to yield a set of `chunks` args and kwargs such that # the constituent Tensor values have been sharded/replicated according to the `args_chunk_spec` @@ -221,12 +224,13 @@ def split_args_kwargs_into_chunks( dict(enumerate(args)), dict(enumerate(args_chunk_spec)), chunks, - _debug_mask_minibatches, ) real_num_chunks = len(args_split_dict) kwargs_split = shard_dict_of_args( - kwargs, kwargs_chunk_spec, real_num_chunks, _debug_mask_minibatches + kwargs, + kwargs_chunk_spec, + real_num_chunks, ) if len(kwargs_split) < real_num_chunks: @@ -238,7 +242,6 @@ def split_args_kwargs_into_chunks( dict(enumerate(args)), dict(enumerate(args_chunk_spec)), real_num_chunks, - _debug_mask_minibatches, ) if len(args_split_dict) != len(kwargs_split): @@ -254,7 +257,7 @@ def split_args_kwargs_into_chunks( return args_split, kwargs_split -def merge_chunks(chunks, chunk_spec, _debug_mask_minibatches: bool = False): +def merge_chunks(chunks, chunk_spec): # Given a list of chunks and a chunk specification, merge the chunks # into a single value according to that chunk spec. This is essentially # the inverse of `split_args_kwargs_into_chunks`, so the steps are @@ -374,6 +377,8 @@ def merge_chunks(chunks, chunk_spec, _debug_mask_minibatches: bool = False): return tree_unflatten(args_flattened, flatten_spec) +# TODO: determine if we still need this helper +""" def gen_output_chunk_spec(loss_spec, loss_reducer): output_chunk_spec: Any = None if loss_spec is None: @@ -390,8 +395,9 @@ def gen_output_chunk_spec(loss_spec, loss_reducer): else: raise ValueError(f"Cannot generate output chunk spec for {loss_spec}") - logging.info( + logger.info( f"Generated output_chunk_spec for loss_spec {loss_spec}: " f"{output_chunk_spec}" ) return output_chunk_spec +""" diff --git a/pippy/utils.py b/pippy/utils.py index 7e3d4d7d4..69a9b7ae1 100644 --- a/pippy/utils.py +++ b/pippy/utils.py @@ -1,277 +1,7 @@ # Copyright (c) Meta Platforms, Inc. and affiliates -import logging -import os -import socket -from typing import List - -import torch.distributed as dist - - -# Pinning process to a separate GPU if not yet done by launch script -# Notes: -# 1. Previously this env was added to work around an issue that each RPC process creates an extra CUDA context on device -# 0. This issue may have been caused by RPC not automatically pinning spawned worker threads to same CUDA device as the -# main thread. So pinning each RPC process to one device would avoid the issue. -# 2. This pinning must be done before `import torch` at which point CUDA context may have been created. Thus, if user -# code has `import torch` before importing PiPPy, this may not work. -# (Update): the issue in #1 seems to be gone as of March 2023. Hence, we are setting the default value of -# `PIPPY_PIN_DEVICE` to 0 now. -if os.getenv("PIPPY_PIN_DEVICE", "0") == "1": - cuda_devices_str = os.getenv("CUDA_VISIBLE_DEVICES") - if ( - cuda_devices_str is None # not set - or len(cuda_devices_str.split(",")) > 1 - ): # or set to all devices - # If launchers like Torchrun sets `LOCAL_RANK`, we would use this information - local_rank_str = os.getenv("LOCAL_RANK") - if local_rank_str is not None: - os.environ["CUDA_VISIBLE_DEVICES"] = local_rank_str - print( - f"Pinning local process {local_rank_str} to gpu {os.getenv('CUDA_VISIBLE_DEVICES')}" - ) - - import torch -import torch.distributed.rpc as rpc -import torch.multiprocessing as mp - -import pippy.fx - - -def get_rank() -> int: - worker_info = rpc.get_worker_info() - logging.debug(worker_info) - return worker_info.id - - -def get_device() -> torch.device: - worker_info = rpc.get_worker_info() - agent = rpc._get_current_rpc_agent() - dev_map = agent._get_device_map(worker_info) - logging.debug(dev_map) - num_devs = len(dev_map) - - if num_devs == 0: - logging.debug("Empty device mapping, assuming device type to be cpu") - device = torch.device("cpu") - elif num_devs != 1: - raise AssertionError( - f"Expecting at most one device for RPC worker {worker_info}, " - f"but got device map of length {num_devs}: {dev_map}" - ) - else: - src_dev = next(iter(dev_map)) - dst_dev = dev_map[src_dev] - if src_dev != dst_dev: - raise AssertionError( - f"Expecting at most one device for RPC worker {worker_info}, " - f"but got {dev_map}" - ) - device = src_dev - - logging.info(f"Found device {device} for rank {worker_info.id}") - return device - - -def get_pp_rank(rank: int, ranks: List[int]) -> int: - for index, r in enumerate(ranks): - if rank == r: - return index - raise ValueError(f"Rank {rank} not in ranks {ranks}") - - -def has_efa() -> bool: - try: - import subprocess - - return ( - subprocess.run( - ["fi_info", "-p", "efa", "-t", "FI_EP_RDM"], - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL, - ).returncode - == 0 - ) - except FileNotFoundError: - return False - except PermissionError: - return False - - -def tp_transports(): - return ["shm", "uv"] if has_efa() else None - - -global _pp_group_barrier -# Defined later in `run_worker` (triggered via `run_pippy`) - - -# A barrier util for pipeline dimension -def pp_group_barrier(): - _pp_group_barrier() # type: ignore[name-defined] - - -def run_pippy(run_func, args, *extra_args): - if not hasattr(args, "world_size"): - assert hasattr(args, "pp_group_size") - args.dp_group_size = ( - args.dp_group_size if hasattr(args, "dp_group_size") else 1 - ) - else: - if not hasattr(args, "dp_group_size"): - args.pp_group_size = ( - args.pp_group_size - if hasattr(args, "pp_group_size") - else args.world_size - ) - assert args.world_size % args.pp_group_size == 0 - args.dp_group_size = args.world_size // args.pp_group_size - elif not hasattr(args, "pp_group_size"): - args.dp_group_size = ( - args.dp_group_size if hasattr(args, "dp_group_size") else 1 - ) - assert args.world_size % args.dp_group_size == 0 - args.pp_group_size = args.world_size // args.dp_group_size - else: - pass - # TODO: doesn't work for PiPPyTrainingArguments - # assert args.world_size == args.dp_group_size * args.pp_group_size - - actual_world_size = args.dp_group_size * args.pp_group_size - print( - f"[PiPPy] World size: {actual_world_size}, " - f"DP group size: {args.dp_group_size}, " - f"PP group size: {args.pp_group_size}" - ) - - if args.rank == -1: - mp.spawn( - run_worker, - args=(run_func, args, *extra_args), - nprocs=actual_world_size, - join=True, - ) - elif args.rank < actual_world_size: - run_worker(args.rank, run_func, args, *extra_args) - else: - print("I'm unused, exiting") - - -def run_worker(rank, run_func, args, *extra_args): - args.rank = rank - - os.environ["MASTER_ADDR"] = args.master_addr - os.environ["MASTER_PORT"] = args.master_port - - actual_world_size = args.dp_group_size * args.pp_group_size - - # TODO: Move to training args, blocked by: cannot pickle 'TensorPipeRpcBackendOptions' object - # Exclude IB for metadata transport due to lack of EFA support on AWS - if hasattr(args, "num_worker_threads"): - num_worker_threads = args.num_worker_threads - else: - num_worker_threads = 512 - - if hasattr(args, "rpc_timeout"): - rpc_timeout = args.rpc_timeout - else: - rpc_timeout = 1800 - - options = rpc.TensorPipeRpcBackendOptions( - num_worker_threads=num_worker_threads, - rpc_timeout=rpc_timeout, - _transports=tp_transports(), - ) - if args.cuda: - n_devs = torch.cuda.device_count() - if n_devs > 0: - dev_id = rank % n_devs - for i in range(actual_world_size): - options.set_device_map(f"worker{i}", {dev_id: i % n_devs}) - # Does not seem effective for RPC device pinning. TODO - # options.set_devices([f'cuda:{dev_id}']) - else: - args.cuda = 0 - print("Warning: no CUDA device found. Running on CPU instead.") - - args.device = f"cuda:{dev_id}" if args.cuda else "cpu" - print( - f"rank = {rank} host/pid/device = " - f"{socket.gethostname()}/{os.getpid()}/{args.device}" - ) - - # Init DDP process group - backend = "nccl" if args.cuda else "gloo" - torch.distributed.init_process_group( - backend=backend, rank=rank, world_size=actual_world_size - ) - - rpc.init_rpc( - f"worker{rank}", - rank=rank, - world_size=actual_world_size, - rpc_backend_options=options, - ) - - global dp_pg_per_pp_rank - dp_ranks_per_pp_rank = ( - torch.arange(actual_world_size) - .reshape(args.pp_group_size, args.dp_group_size) - .tolist() - ) - dp_pg_per_pp_rank = [ # type: ignore[name-defined] - torch.distributed.new_group(ranks) for ranks in dp_ranks_per_pp_rank - ] - - pp_ranks_per_dp_group = [ - [i * args.dp_group_size + rank for i in range(args.pp_group_size)] - for rank in range(args.dp_group_size) - ] - - my_pp_ranks = pp_ranks_per_dp_group[rank % args.dp_group_size] - - args.driver_group = torch.distributed.new_group( - list(range(args.dp_group_size)) - ) - - global exclude_master - exclude_master = ( # type: ignore[name-defined] - args.exclude_master if hasattr(args, "exclude_master") else 0 - ) - gspmd = ( # type: ignore[name-defined] - args.gspmd if hasattr(args, "gspmd") else 0 - ) - - # A barrier util for pipeline dimension - global _pp_group_barrier - - # ProcessGroupGloo cannot create group with strided ranks, e.g. [0, 2, 4, 6, ...] - # Skipping the `pp_group` and `pp_group_barrier` creation here - # TODO: unskip - if torch.distributed.get_backend() == "gloo" and args.dp_group_size > 1: - - def _pp_group_barrier(): - logging.warning( - f"pp_group_barrier() does not support ProcessGroupGloo with strided ranks {my_pp_ranks}. This will be a no-op." - ) - - else: - pp_group = torch.distributed.new_group(my_pp_ranks) - - def _pp_group_barrier(): - logging.debug( - f"Running pipeline group barrier on ranks {my_pp_ranks}" - ) - torch.distributed.barrier(pp_group) - - if rank >= 0 and rank // args.dp_group_size == 0: - args.driver_index = rank - args.local_driver_index = os.getenv("LOCAL_RANK", rank) - run_func(my_pp_ranks, args, *extra_args) - elif gspmd == 1: - run_func(my_pp_ranks, args, *extra_args) - - rpc.shutdown() +import torch.distributed as dist +from torch import fx def flatten_args_detach(args): @@ -287,11 +17,14 @@ def extract_tensor_args(a): flat_detached_args.append(a) return a + """ def dont_traverse_size(a): return type(a) != torch.Size + """ - new_args = pippy.fx.node.map_aggregate( - args, extract_tensor_args, dont_traverse_size + new_args = fx.node.map_aggregate( + args, + extract_tensor_args, # dont_traverse_size ) return new_args, flat_detached_args @@ -305,10 +38,15 @@ def extract_tensor_args(a): flat_args.append(a) return a + """ def dont_traverse_size(a): return type(a) != torch.Size + """ - pippy.fx.node.map_aggregate(args, extract_tensor_args, dont_traverse_size) + fx.node.map_aggregate( + args, + extract_tensor_args, # dont_traverse_size + ) return flat_args diff --git a/requirements.txt b/requirements.txt index ac954cf5d..4d73a1d52 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ -torch >= 1.13.0 +torch >= 2.2.0.dev packaging >= 21.3 diff --git a/setup.py b/setup.py index f5229da0c..b50de9597 100644 --- a/setup.py +++ b/setup.py @@ -45,7 +45,7 @@ def write_version_file(): requirements = [ # If the torch version has a ".dev" suffix, it would represent a nightly version of PyTorch. # It can be installed as a binary or from source. - "torch>=1.13.0", + "torch>=2.2.0.dev", ] extras: Dict = {} diff --git a/test/expect/TestFXAPIBackwardCompatibility.test_class_member_back_compat-fx_backcompat_class_members.expect b/test/expect/TestFXAPIBackwardCompatibility.test_class_member_back_compat-fx_backcompat_class_members.expect deleted file mode 100644 index 1b732fd1f..000000000 --- a/test/expect/TestFXAPIBackwardCompatibility.test_class_member_back_compat-fx_backcompat_class_members.expect +++ /dev/null @@ -1,19 +0,0 @@ -pippy.fx._symbolic_trace.ProxyableClassMeta [] -pippy.fx._symbolic_trace.Tracer ['call_module', 'create_arg', 'create_args_for_root', 'getattr', 'is_leaf_module', 'path_of_module', 'trace'] -pippy.fx.graph.Graph ['call_function', 'call_method', 'call_module', 'create_node', 'eliminate_dead_code', 'erase_node', 'get_attr', 'graph_copy', 'inserting_after', 'inserting_before', 'lint', 'node_copy', 'nodes', 'on_generate_code', 'output', 'owning_module', 'placeholder', 'print_tabular', 'process_inputs', 'process_outputs', 'python_code', 'set_codegen'] -pippy.fx.graph.PythonCode [] -pippy.fx.graph_module.GraphModule ['add_submodule', 'code', 'delete_all_unused_submodules', 'delete_submodule', 'graph', 'print_readable', 'recompile', 'to_folder'] -pippy.fx.immutable_collections.immutable_dict ['clear', 'pop', 'popitem', 'update'] -pippy.fx.immutable_collections.immutable_list ['append', 'clear', 'extend', 'insert', 'pop', 'remove'] -pippy.fx.interpreter.Interpreter ['call_function', 'call_method', 'call_module', 'fetch_args_kwargs_from_env', 'fetch_attr', 'get_attr', 'map_nodes_to_values', 'output', 'placeholder', 'run', 'run_node'] -pippy.fx.interpreter.Transformer ['call_function', 'call_module', 'get_attr', 'placeholder', 'transform'] -pippy.fx.node.Node ['all_input_nodes', 'append', 'args', 'format_node', 'is_impure', 'kwargs', 'next', 'normalized_arguments', 'prepend', 'prev', 'replace_all_uses_with', 'replace_input_with', 'stack_trace', 'update_arg', 'update_kwarg'] -pippy.fx.passes.shape_prop.ShapeProp ['propagate', 'run_node'] -pippy.fx.passes.shape_prop.TensorMetadata ['dtype', 'is_quantized', 'memory_format', 'qparams', 'requires_grad', 'shape', 'stride'] -pippy.fx.passes.split_module.Partition [] -pippy.fx.proxy.Attribute ['node'] -pippy.fx.proxy.GraphAppendingTracer [] -pippy.fx.proxy.Proxy ['keys'] -pippy.fx.proxy.TraceError [] -pippy.fx.proxy.TracerBase ['check_mutable_operations', 'create_arg', 'create_node', 'create_proxy', 'iter', 'keys', 'proxy', 'proxy_buffer_attributes', 'record_stack_traces', 'to_bool', 'trace_asserts', 'traced_func_name'] -pippy.fx.subgraph_rewriter.Match ['anchor', 'nodes_map'] \ No newline at end of file diff --git a/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect b/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect deleted file mode 100644 index 25e1e641c..000000000 --- a/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect +++ /dev/null @@ -1,74 +0,0 @@ -pippy.fx._symbolic_trace.Tracer.__init__(self, autowrap_modules: Tuple[Callable] = (,), autowrap_functions: Tuple[Callable, ...] = (,), param_shapes_constant: bool = False) -> None -pippy.fx._symbolic_trace.Tracer.call_module(self, m: torch.nn.modules.module.Module, forward: Callable[..., Any], args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any -pippy.fx._symbolic_trace.Tracer.create_arg(self, a: Any) -> 'Argument' -pippy.fx._symbolic_trace.Tracer.is_leaf_module(self, m: torch.nn.modules.module.Module, module_qualified_name: str) -> bool -pippy.fx._symbolic_trace.Tracer.path_of_module(self, mod: torch.nn.modules.module.Module) -> str -pippy.fx._symbolic_trace.Tracer.trace(self, root: Union[torch.nn.modules.module.Module, Callable[..., Any]], concrete_args: Optional[Dict[str, Any]] = None) -> pippy.fx.graph.Graph -pippy.fx._symbolic_trace.symbolic_trace(root: Union[torch.nn.modules.module.Module, Callable[..., Any]], concrete_args: Optional[Dict[str, Any]] = None) -> pippy.fx.graph_module.GraphModule -pippy.fx._symbolic_trace.wrap(fn_or_name: Union[str, Callable]) -pippy.fx.graph.Graph.__init__(self, owning_module: Optional[GraphModule] = None, tracer_cls: Optional[Type[Tracer]] = None, tracer_extras: Optional[Dict[str, Any]] = None) -pippy.fx.graph.Graph.call_function(self, the_function: Callable[..., Any], args: Optional[Tuple[Argument, ...]] = None, kwargs: Optional[Dict[str, Argument]] = None, type_expr: Optional[Any] = None) -> pippy.fx.node.Node -pippy.fx.graph.Graph.call_method(self, method_name: str, args: Optional[Tuple[Argument, ...]] = None, kwargs: Optional[Dict[str, Argument]] = None, type_expr: Optional[Any] = None) -> pippy.fx.node.Node -pippy.fx.graph.Graph.call_module(self, module_name: str, args: Optional[Tuple[Argument, ...]] = None, kwargs: Optional[Dict[str, Argument]] = None, type_expr: Optional[Any] = None) -> pippy.fx.node.Node -pippy.fx.graph.Graph.create_node(self, op: str, target: 'Target', args: Optional[Tuple[Argument, ...]] = None, kwargs: Optional[Dict[str, Argument]] = None, name: Optional[str] = None, type_expr: Optional[Any] = None) -> pippy.fx.node.Node -pippy.fx.graph.Graph.eliminate_dead_code(self) -pippy.fx.graph.Graph.erase_node(self, to_erase: pippy.fx.node.Node) -> None -pippy.fx.graph.Graph.get_attr(self, qualified_name: str, type_expr: Optional[Any] = None) -> pippy.fx.node.Node -pippy.fx.graph.Graph.graph_copy(self, g: 'Graph', val_map: Dict[pippy.fx.node.Node, pippy.fx.node.Node], return_output_node = False) -> 'Optional[Argument]' -pippy.fx.graph.Graph.inserting_after(self, n: Optional[pippy.fx.node.Node] = None) -pippy.fx.graph.Graph.inserting_before(self, n: Optional[pippy.fx.node.Node] = None) -pippy.fx.graph.Graph.lint(self) -pippy.fx.graph.Graph.node_copy(self, node: pippy.fx.node.Node, arg_transform: Callable[[pippy.fx.node.Node], Argument] = >) -> pippy.fx.node.Node -pippy.fx.graph.Graph.output(self, result: 'Argument', type_expr: Optional[Any] = None) -pippy.fx.graph.Graph.placeholder(self, name: str, type_expr: Optional[Any] = None, default_value: Any) -> pippy.fx.node.Node -pippy.fx.graph.Graph.print_tabular(self) -pippy.fx.graph.Graph.python_code(self, root_module: str, verbose: bool = False) -> pippy.fx.graph.PythonCode -pippy.fx.graph_module.GraphModule.__init__(self, root: Union[torch.nn.modules.module.Module, Dict[str, Any]], graph: pippy.fx.graph.Graph, class_name: str = 'GraphModule') -pippy.fx.graph_module.GraphModule.add_submodule(self, target: str, m: torch.nn.modules.module.Module) -> bool -pippy.fx.graph_module.GraphModule.delete_all_unused_submodules(self) -> None -pippy.fx.graph_module.GraphModule.delete_submodule(self, target: str) -> bool -pippy.fx.graph_module.GraphModule.recompile(self) -> pippy.fx.graph.PythonCode -pippy.fx.graph_module.reduce_deploy_graph_module(importer: Callable, body: Dict[Any, Any], import_block: str) -> torch.nn.modules.module.Module -pippy.fx.graph_module.reduce_graph_module(body: Dict[Any, Any], import_block: str) -> torch.nn.modules.module.Module -pippy.fx.graph_module.reduce_package_graph_module(importer: Callable, body: Dict[Any, Any], generated_module_name: str) -> torch.nn.modules.module.Module -pippy.fx.interpreter.Interpreter.__init__(self, module: pippy.fx.graph_module.GraphModule, garbage_collect_values: bool = True) -pippy.fx.interpreter.Interpreter.call_function(self, target: 'Target', args: Tuple[pippy.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> Any -pippy.fx.interpreter.Interpreter.call_method(self, target: 'Target', args: Tuple[pippy.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> Any -pippy.fx.interpreter.Interpreter.call_module(self, target: 'Target', args: Tuple[pippy.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> Any -pippy.fx.interpreter.Interpreter.fetch_args_kwargs_from_env(self, n: pippy.fx.node.Node) -> Tuple[Tuple, Dict] -pippy.fx.interpreter.Interpreter.fetch_attr(self, target: str) -pippy.fx.interpreter.Interpreter.get_attr(self, target: 'Target', args: Tuple[pippy.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> Any -pippy.fx.interpreter.Interpreter.map_nodes_to_values(self, args: pippy.fx.node.Argument, n: pippy.fx.node.Node) -> pippy.fx.node.Argument -pippy.fx.interpreter.Interpreter.output(self, target: 'Target', args: Tuple[pippy.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> Any -pippy.fx.interpreter.Interpreter.placeholder(self, target: 'Target', args: Tuple[pippy.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> Any -pippy.fx.interpreter.Interpreter.run(self, *args, initial_env: Optional[Dict[pippy.fx.node.Node, Any]] = None, enable_io_processing: bool = True) -> Any -pippy.fx.interpreter.Interpreter.run_node(self, n: pippy.fx.node.Node) -> Any -pippy.fx.interpreter.Transformer.__init__(self, module) -pippy.fx.interpreter.Transformer.call_function(self, target: 'Target', args: Tuple[pippy.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> Any -pippy.fx.interpreter.Transformer.call_module(self, target: 'Target', args: Tuple[pippy.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> Any -pippy.fx.interpreter.Transformer.get_attr(self, target: 'Target', args: Tuple[pippy.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> pippy.fx.proxy.Proxy -pippy.fx.interpreter.Transformer.placeholder(self, target: 'Target', args: Tuple[pippy.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> pippy.fx.proxy.Proxy -pippy.fx.interpreter.Transformer.transform(self) -> pippy.fx.graph_module.GraphModule -pippy.fx.node.Node.__init__(self, graph: 'Graph', name: str, op: str, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Argument], return_type: Optional[Any] = None) -> None -pippy.fx.node.Node.append(self, x: 'Node') -> None -pippy.fx.node.Node.format_node(self, placeholder_names: Optional[List[str]] = None, maybe_return_typename: Optional[List[str]] = None) -> Optional[str] -pippy.fx.node.Node.prepend(self, x: 'Node') -> None -pippy.fx.node.Node.replace_all_uses_with(self, replace_with: 'Node', delete_user_cb: Callable[[Node], bool] = >) -> List[Node] -pippy.fx.node.Node.replace_input_with(self, old_input: 'Node', new_input: 'Node') -pippy.fx.node.Node.update_arg(self, idx: int, arg: pippy.fx.node.Argument) -> None -pippy.fx.node.Node.update_kwarg(self, key: str, arg: pippy.fx.node.Argument) -> None -pippy.fx.node.map_aggregate(a: pippy.fx.node.Argument, fn: Callable[[pippy.fx.node.Argument], pippy.fx.node.Argument], should_traverse_fn: Optional[Callable[[pippy.fx.node.Argument], bool]] = None) -> pippy.fx.node.Argument -pippy.fx.node.map_arg(a: pippy.fx.node.Argument, fn: Callable[[pippy.fx.node.Node], pippy.fx.node.Argument]) -> pippy.fx.node.Argument -pippy.fx.passes.reinplace.reinplace(gm, *sample_args) -pippy.fx.passes.split_module.split_module(m: pippy.fx.graph_module.GraphModule, root_m: torch.nn.modules.module.Module, split_callback: Callable[[pippy.fx.node.Node], int], qualname_map: Optional[Dict[str, str]] = None, keep_original_order: Optional[bool] = False) -pippy.fx.proxy.Attribute.__init__(self, root: pippy.fx.proxy.Proxy, attr: str) -pippy.fx.proxy.Proxy.__init__(self, node: pippy.fx.node.Node, tracer: 'Optional[TracerBase]' = None) -pippy.fx.proxy.Proxy.keys(self) -pippy.fx.proxy.TracerBase.create_arg(self, a: Any) -> pippy.fx.node.Argument -pippy.fx.proxy.TracerBase.create_node(self, kind: str, target: pippy.fx.node.Target, args: Tuple[pippy.fx.node.Argument, ...], kwargs: Dict[str, pippy.fx.node.Argument], name: Optional[str] = None, type_expr: Optional[Any] = None) -> pippy.fx.node.Node -pippy.fx.proxy.TracerBase.create_proxy(self, kind: str, target: pippy.fx.node.Target, args: Tuple[Any, ...], kwargs: Dict[str, Any], name: Optional[str] = None, type_expr: Optional[Any] = None, proxy_factory_fn: Callable[[pippy.fx.node.Node], Proxy] = None) -pippy.fx.proxy.TracerBase.iter(self, obj: 'Proxy') -> Iterator -pippy.fx.proxy.TracerBase.keys(self, obj: 'Proxy') -> Any -pippy.fx.proxy.TracerBase.proxy(self, node: pippy.fx.node.Node) -> 'Proxy' -pippy.fx.proxy.TracerBase.to_bool(self, obj: 'Proxy') -> bool -pippy.fx.subgraph_rewriter.replace_pattern(gm: pippy.fx.graph_module.GraphModule, pattern: Callable, replacement: Callable) -> List[pippy.fx.subgraph_rewriter.Match] diff --git a/test/fx/named_tup.py b/test/fx/named_tup.py deleted file mode 100644 index 2d4f63113..000000000 --- a/test/fx/named_tup.py +++ /dev/null @@ -1,8 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from typing import NamedTuple - -import torch - -class MyNamedTup(NamedTuple): - i : torch.Tensor - f : torch.Tensor diff --git a/test/fx/quantization.py b/test/fx/quantization.py deleted file mode 100644 index 75589ddbc..000000000 --- a/test/fx/quantization.py +++ /dev/null @@ -1,325 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -r''' -**This file is EXPERIMENTAL and is mostly used for testing purposes! Do not -rely on it for anything!** -''' -from pippy.fx import Graph, GraphModule -from pippy.fx.graph import map_arg -from pippy.fx.proxy import Proxy -import sys -import torch -from torch.nn.utils import fuse_conv_bn_weights -import operator - -# can be a -# module type, a builtin function, or a string to match target - -def _minmax_scale_zeropoint(min_val, max_val, qmin=-127, qmax=128, eps=torch.finfo(torch.float32).eps): - min_val = min(0.0, min_val) - max_val = max(0.0, max_val) - if max_val == min_val: - return 1.0, 0 - else: - scale = (max_val - min_val) / float(qmax - qmin) - scale = max(scale, eps) - zero_point = qmin - round(min_val / scale) - zero_point = max(qmin, zero_point) - zero_point = min(qmax, zero_point) - zero_point = int(zero_point) - return scale, zero_point - -class MinMaxObserver: - def __init__(self, quantizer, node): - self.min, self.max = float('inf'), float('-inf') - self.all_tensors = True - - def observe(self, node, env): - v = env[node.name] - if not isinstance(v, torch.Tensor): - self.all_tensors = False - return - self.max = max(self.max, float(v.max())) - self.min = min(self.min, float(v.min())) - - def scale_zeropoint(self): - return _minmax_scale_zeropoint(self.min, self.max, qmin=0, qmax=255) - -class NoObserver: - def __init__(self, quantizer, node): - pass - - def observe(self, node, env): - pass - -DEFAULT_QUANTIZATION_PATTERNS = {} -def register_pattern(pattern): - def insert(fn): - DEFAULT_QUANTIZATION_PATTERNS[pattern] = fn - return fn - return insert - - -@register_pattern(operator.add) -class Add(MinMaxObserver): - def quantize(self, quantizer, node, load_arg): - if not self.all_tensors: - return NotImplemented - scale, zeropoint = self.scale_zeropoint() - return quantizer.quantized_graph.create_node( - 'call_function', torch.ops.quantized.add, load_arg(node.args), {'scale': scale, 'zero_point': zeropoint}) - - -class Relu(NoObserver): - def quantize(self, quantizer, node, load_arg): - return torch.relu(load_arg(node.args[0])) # torch.relu works directly on quantized tensors? - -# these ops have quantized equivalents that do not need any extra information -@register_pattern(torch.nn.ReLU) -@register_pattern(torch.nn.AvgPool2d) -@register_pattern(torch.nn.MaxPool2d) -@register_pattern(torch.nn.AdaptiveAvgPool2d) -class CopyNode(NoObserver): - def quantize(self, quantizer, node, load_arg): - return quantizer.quantized_graph.node_copy(node, load_arg) - -class IdentityModule(torch.nn.Module): - def forward(self, x): - return x - -# handle conv, maybe followed by bn, maybe followed by relu -@register_pattern(torch.nn.modules.conv.Conv2d) -@register_pattern((torch.nn.ReLU, torch.nn.modules.conv.Conv2d)) -@register_pattern((torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.conv.Conv2d)) -@register_pattern((torch.nn.ReLU, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.conv.Conv2d))) -class ConvNormRelu(MinMaxObserver): - def __init__(self, quantizer, node): - super().__init__(quantizer, node) - self.relu_node, self.bn_node = None, None - if isinstance(quantizer.modules[node.target], torch.nn.ReLU): - self.relu_node = node - node = node.args[0] - if isinstance(quantizer.modules[node.target], torch.nn.BatchNorm2d): - self.bn_node = node - self.bn = quantizer.modules[self.bn_node.target] - node = node.args[0] - assert isinstance(quantizer.modules[node.target], torch.nn.modules.Conv2d) - self.conv_node = node - self.conv = quantizer.modules[self.conv_node.target] - - def quantize(self, quantizer, node, load_arg): - mod = self.conv - weight, bias = mod.weight, mod.bias - - if self.bn_node is not None: - weight, bias = fuse_conv_bn_weights( - weight, bias, self.bn.running_mean, self.bn.running_var, - self.bn.eps, self.bn.weight, self.bn.bias) - - min_val, max_val = float(weight.min()), float(weight.max()) - - act_scale, act_zp = self.scale_zeropoint() - - weight_scale, weight_zp = _minmax_scale_zeropoint(min_val, max_val) - qweight = torch.quantize_per_tensor(weight, weight_scale, weight_zp, torch.qint8) - - ctor = torch.ao.nn.intrinsic.quantized.ConvReLU2d if self.relu_node is not None else torch.ao.nn.quantized.Conv2d - - qconv = ctor(mod.in_channels, mod.out_channels, mod.kernel_size, - mod.stride, mod.padding, mod.dilation, mod.groups, - mod.bias is not None, mod.padding_mode) - - qconv.set_weight_bias(qweight, bias) - qconv.scale = float(act_scale) - qconv.zero_point = int(act_zp) - parent_name, name = _parent_name(self.conv_node.target) - setattr(quantizer.modules[parent_name], name, qconv) - if self.bn_node is not None: - parent_bn, bn_name = _parent_name(self.bn_node.target) - # we can't just delete this because submodules's forwards (which are not longer use) - # try to call it, so replace with something that does nothing. - setattr(quantizer.modules[parent_name], bn_name, IdentityModule()) - - return quantizer.quantized_graph.create_node('call_module', self.conv_node.target, (load_arg(self.conv_node.args[0]),), {}) - - -# turn foo.bar -> ['foo', 'bar'] -def _parent_name(target): - r = target.rsplit('.', 1) - if len(r) == 1: - return '', r[0] - else: - return r[0], r[1] - - - -class DefaultQuant(MinMaxObserver): - def quantize(self, input): - assert self.all_tensors - scale, zeropoint = self.scale_zeropoint() - return torch.quantize_per_tensor(Proxy(input), scale, zeropoint, torch.quint8).node - -def matches(modules, node, pattern, max_uses=sys.maxsize): - if isinstance(pattern, tuple): - self_match, *arg_matches = pattern - else: - self_match = pattern - arg_matches = None - - if len(node.users) > max_uses: - return False - - if isinstance(self_match, type) and issubclass(self_match, torch.nn.Module): - if node.op != 'call_module': - return False - if not isinstance(modules[node.target], self_match): - return False - elif callable(self_match): - if node.op != 'call_function' or node.target is not self_match: - return False - elif node.target != self_match: - return False - - if not arg_matches: - return True - - if len(arg_matches) != len(node.args): - return False - - return all(matches(modules, node, arg_match, max_uses=1) for node, arg_match in zip(node.args, arg_matches)) - - -class Quantizer: - def __init__(self, mod, patterns=DEFAULT_QUANTIZATION_PATTERNS, quant_ctor=DefaultQuant): - self.root = mod - self.graph = mod.graph - self.quant_ctor = quant_ctor - - # cached information for observe - self.state_dict = self.root.state_dict() - self.modules = dict(self.root.named_modules()) - - # match the patterns that will get quantized - self.matches = self._find_matches(patterns) - # find _inputs_ to matched nodes that are not quantized, these - # have to be quantized, which requires measuring stats, - # initialize an quant_ctor object for each - self.quants = self._find_quants(quant_ctor) - - - - def observe(self, args): - # most of this function is just an interpreter for the graph - # it would be possible to put this in some abstraction, but - # it is pretty nice to just be able to see exactly what is happening here - # and hack on it. - # maybe we should just provide an example interpreter that people copy/paste - # then edit. - args_iter = iter(args) - env = {} - - def load_arg(a): - return map_arg(a, lambda node: env[node.name]) - - output_node : Optional[Node] = None - for node in self.graph.nodes: - if node.op == 'placeholder': - result = next(args_iter) - elif node.op == 'get_attr': - result = self.state_dict[node.target] - elif node.op == 'call_function': - result = node.target(*load_arg(node.args), **load_arg(node.kwargs)) - elif node.op == 'call_method': - self_obj, *args = load_arg(node.args) - kwargs = load_arg(node.kwargs) - result = getattr(self_obj, node.target)(*args, **kwargs) - elif node.op == 'call_module': - result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs)) - elif node.op == 'output': - return load_arg(node.args[0]) - - env[node.name] = result - root_node, obj = self.matches.get(node.name, (None, None)) - if root_node is node: - obj.observe(node, env) - if node.name in self.quants: - self.quants[node.name].observe(node, env) - - raise RuntimeError('Graph had no output node!') - - def quantize(self): - self.quantized_graph = Graph() - - env = {} - quant_env = {} - - def load_arg(n, quantized): - if not quantized: - if n.name not in env and n.name in quant_env: - env[n.name] = Proxy(quant_env[n.name]).dequantize().node - return env[n.name] - else: - if n.name not in quant_env and n.name in env: - quant_env[n.name] = self.quants[n.name].quantize(env[n.name]) - return quant_env[n.name] - - def copy_recursive(node): - def load_or_emit(n): - if n.name in env or e.name in quant_env: - return load_arg(n, quantized=False) - else: - return copy_recusive(n) - r = env[node.name] = self.quantized_graph.node_copy(node, lambda n: load_arg(n, quantized=False)) - return r - - for node in self.graph.nodes: - root_node, obj = self.matches.get(node.name, (None, None)) - if root_node is None: - # not quantized just copy it - env[node.name] = self.quantized_graph.node_copy(node, lambda n: load_arg(n, quantized=False)) - - elif root_node is node: - r = obj.quantize(self, node, lambda a: map_arg(a, lambda n: load_arg(n, quantized=True))) - if r is NotImplemented: - # quantizer choose to to quantize the node take the entire match, and just copy it over - env[node.name] = copy_recursive(node) - else: - quant_env[node.name] = r - - return GraphModule(self.root, self.quantized_graph) - - def _find_matches(self, patterns): - modules = dict(self.root.named_modules()) - match_map = {} # node name -> (root_node, match_value?) - - def apply_match(pattern, node, match): - if isinstance(pattern, tuple): - s, *args = pattern - apply_match(s, node, match) - for subpattern, arg in zip(args, node.args): - apply_match(subpattern, arg, match) - else: - match_map[node.name] = match - - for node in reversed(self.graph.nodes): - if node.name not in match_map: - for pattern, value in patterns.items(): - if matches(modules, node, pattern): - apply_match(pattern, node, (node, value(self, node))) - - return match_map - - def _find_quants(self, quant_ctor): - quants = {} - - def visit_arg(n): - # note: we have to measure quantization information - # even for nodes where we might not use it because it is already - # quantized. This is because each match has the option to - # say NotImplemented (if for instance, it is an __add__ and the data type is not appropriate) - if n.name not in quants: - quants[n.name] = quant_ctor(self, n) - for node in self.graph.nodes: - if node.name in self.matches: - map_arg(node.args, visit_arg) - map_arg(node.kwargs, visit_arg) - return quants diff --git a/test/fx/test_common_passes.py b/test/fx/test_common_passes.py deleted file mode 100644 index d16020e69..000000000 --- a/test/fx/test_common_passes.py +++ /dev/null @@ -1,116 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -# Owner(s): ["oncall: fx"] - -import torch - -from torch.testing._internal.common_utils import ( - TestCase, parametrize, instantiate_parametrized_tests, run_tests) -from pippy.fx.experimental.proxy_tensor import make_fx -from pippy.fx.passes.dialect.common.cse_pass import CSEPass -from pippy.fx.graph_module import GraphModule - -import itertools - -def FactoryFunctionCall(x, device): - y = torch.full(x.shape, 3, device=device) - z = torch.add(y, x) - return z - - -def TorchTensorCall(x): - y = torch.tensor(3) - return x + y - - -def TakeList(x): - z = torch.cat([x, x]) - return z - - -def ReturnList(x): - a = torch.arange(10).reshape(5, 2) - z = torch.split(a, [1, 4]) - return z - - -def Mutation(x): - y = x + 2 - y.add_(1) - return x + y - - -def MutationInput(x): - x.add_(1) - y = x + 2 - return x + y - - -def MutationFactory(x, device): - y = torch.full(x.shape, 3, device=device) - y.add_(1) - return x + y - - -def MutationTorchTensorCall(x): - y = torch.tensor(3) - y.add_(1) - return x + y - - -def MutationMetadata(x): - x.resize_(2) - return x - - -Passes = [CSEPass] -Test_Cases = [TakeList, - ReturnList, - Mutation, - MutationInput, - MutationMetadata, - MutationTorchTensorCall] -Factory_Test_Cases = [FactoryFunctionCall, MutationFactory] -Devices = ["cpu"] -if torch.cuda.is_available(): - Devices.append("cuda") - -@instantiate_parametrized_tests -class TestCommonPass(TestCase): - - @parametrize("common_pass,f,device", itertools.product(Passes, Test_Cases, Devices)) - def test_correctness(self, common_pass, f, device): - inp = torch.randn(10, device=device) - - traced_m = make_fx(f)(inp) - P = common_pass() - - res = P(traced_m) - modified_m = res.graph_module - assert isinstance(modified_m, GraphModule) - - inp_copy = inp.clone() - expected = f(inp) - result = modified_m(inp_copy) - - self.assertEqual(result, expected) - - - @parametrize("common_pass,f,device", itertools.product(Passes, Factory_Test_Cases, Devices)) - def test_correctness_factory(self, common_pass, f, device): - inp = torch.randn(10, device=device) - traced_m = make_fx(f)(inp, device) - P = common_pass() - - res = P(traced_m) - modified_m = res.graph_module - assert isinstance(modified_m, GraphModule) - - inp_copy = inp.clone() - expected = f(inp, device) - result = modified_m(inp_copy, device) - - self.assertEqual(result, expected) - - -if __name__ == '__main__': - run_tests() diff --git a/test/fx/test_cse_pass.py b/test/fx/test_cse_pass.py deleted file mode 100644 index 21a83e66d..000000000 --- a/test/fx/test_cse_pass.py +++ /dev/null @@ -1,234 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -# Owner(s): ["oncall: fx"] - -import torch - -from torch.testing._internal.common_utils import ( - TestCase, run_tests) -from pippy.fx.experimental.proxy_tensor import make_fx -from pippy.fx.passes.dialect.common.cse_pass import CSEPass, get_CSE_banned_ops -from pippy.fx import symbolic_trace - -import random - - -banned_ops = get_CSE_banned_ops() -P_default = CSEPass(banned_ops=banned_ops) - -def check(self, f, t, delta, check_val=True, graph_input=False, P=None): - """ - check if the CSE modified graph of ``f`` - 1) has delta less nodes, and - 2) do not reduce the number of nodes further on a second pass, and - 3) modified returned is true only if the number of nodes decreases. - - Args: - f: function to be checked - t: tensor to be passed to f - delta: an integer >= -1. - If delta = -1, it only checks if the new graph has less or equal number of nodes - check_val: if True, check if the output of f is correct - graph_input: True is f is type GraphModule - P: the pass to use. If None, use P_default - """ - if graph_input: - fx_g = f - else: - fx_g = make_fx(f)(t) - - if P is None: - P = P_default - - res = P(fx_g) - new_g = res.graph_module - new_graph = new_g.graph - modified = res.modified - - # the number of nodes decrease/ or stay the same - old_num_nodes = len(fx_g.graph.nodes) - new_num_nodes = len(new_graph.nodes) - - assert (new_num_nodes < old_num_nodes) == modified, "modified should be True if the number of nodes decrease" - - if delta == -1: - self.assertTrue(old_num_nodes >= new_num_nodes, ( - f"number of nodes increased {old_num_nodes}, {new_num_nodes}")) - else: - self.assertTrue(old_num_nodes == new_num_nodes + delta, ( - f"number of nodes not the same {old_num_nodes - delta}, {new_num_nodes}\n {fx_g.graph} \n {new_graph}")) - - # a second pass should not reduce more nodes - res = P(new_g) - pass_2_graph = res.graph_module.graph - pass_2_num_nodes = len(pass_2_graph.nodes) - self.assertTrue(pass_2_num_nodes == new_num_nodes, ( - f"second pass graph has less node {pass_2_num_nodes}, {new_num_nodes}\n {new_graph} \n {pass_2_graph}")) - - # check correctness - if check_val: - true_result = fx_g(t) - our_result = new_g(t) - if true_result is None: # both return None - self.assertTrue(our_result is None, f"true result is None, CSE result is {our_result}") - else: # results returned are the same - self.assertTrue(torch.all(true_result == our_result), ( - f"results are different {true_result}, {our_result}")) # check results are the same - -class TestCSEPass(TestCase): - - def test_nochange(self): - def f(x): - a = x + 1 - b = x + a - a = x - d = x + a - return b + d - t = torch.randn(2, 2) - check(self, f, t, 0) - - def test_empty(self): - def f(x): - pass - t = torch.randn(2, 2) - check(self, f, t, 0) - - - def test_immutable_list_type(self): - def f(x): - a = x.sum(dim=1) - b = x.sum(dim=1) - c = x.sum() - d = x.sum() - return a + b + c + d - t = torch.randn(2, 2) - check(self, f, t, 2) - - def test_immutable_list_multiple_entries(self): - def f(x): - a = x.sum(dim=[0, 1]) - b = x.sum(dim=[0, 1]) - c = x.sum(dim=1) - d = x.sum(dim=1) - return a + b + c + d - t = torch.randn(2, 2) - check(self, f, t, 2) - - def test_simple(self): - def f(x): - a = x.cos() - b = x.cos() - c = a + a - d = b + b - return c + d - t = torch.randn(2, 2) - check(self, f, t, 2) - - def test_simple_2(self): - def f(x): - a = x.cos().sin() - b = x.cos().sin() - c = a + a - d = b + b - return c + d - t = torch.randn(1) - check(self, f, t, 3) - - def test_two_args_default(self): - def f(x): - a = x.sum(dim=1) - b = x.sum(dim=1, keepdim=False) - c = x.sum(dim=1, keepdim=False) - d = x.sum(dim=1) - return a + b + c + d - t = torch.randn(2, 2) - check(self, f, t, 3) - - def test_two_args(self): - def f(x): - a = x.sum(dim=1) - b = x.sum(dim=1, keepdim=True) - c = x.sum(dim=1, keepdim=True) - d = x.sum(dim=1) - return a + b + c + d - t = torch.randn(2, 2) - check(self, f, t, 2) - - def test_simple_multiple_same_ops(self): - def f(x): - a = x.sum() - b = x.sum() - c = x.sum() - d = x.sum() - return a + b + c + d - t = torch.randn(2, 2) - check(self, f, t, 3) - - def test_nested_immutable_list_type(self): - def f(x): - a = torch.cat((x, x)) - b = torch.cat((x, x)) - return a + b - t = torch.randn(2, 2) - check(self, f, t, 1) - - def test_kwarg(self): - def f(x): - a = torch.ones_like(x) - b = torch.ones_like(x) - return a + b - t = torch.randn(2, 2) - check(self, f, t, 1) - - """ - Generate function with random ops and check if the result is the same - """ - def test_random(self): - def f(x): - vals = [x] - ops = [torch.clone, torch.cos, torch.tanh, torch.nn.functional.gelu] - for _ in range(100): - new_val = random.choice(ops)(random.choice(vals)) - vals.append(new_val) - return vals[-1] - - fx_g = symbolic_trace(f) - fx_g.graph.eliminate_dead_code() - fx_g.recompile() - t = torch.randn(2, 2) - - for _ in range(30): - check(self, fx_g, t, -1, graph_input=True) - - """ - Test that banned list ban ops as expected. - """ - def test_banned_list(self): - def f(x): - a = x + 1 - b = x + 1 - return a + b - - t = torch.randn(2, 2) - P_ban_add = P = CSEPass(banned_ops=[torch.ops.aten.add]) - check(self, f, t, 0, P=P_ban_add) # check that add is banned - check(self, f, t, 1) # check that add is not banned by default - - def test_rand_like(self): - def f(x): - a = torch.rand_like(x) - b = torch.rand_like(x) - return a + b - t = torch.randn(2, 2) - check(self, f, t, 0, check_val=False) - - def test_rand_n(self): - def f(x): - a = torch.randn(4) - b = torch.randn(4) - return a + b - t = torch.randn(2, 2) - check(self, f, t, 0, check_val=False) - - -if __name__ == '__main__': - run_tests() diff --git a/test/fx/test_dce_pass.py b/test/fx/test_dce_pass.py deleted file mode 100644 index bb29df4c9..000000000 --- a/test/fx/test_dce_pass.py +++ /dev/null @@ -1,185 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -# Owner(s): ["module: fx"] - -from typing import Set, Type -import torch -import pippy.fx - -from torch.testing._internal.common_utils import TestCase - - -class TestDCE(TestCase): - def _has_nodes_without_users(self, m: pippy.fx.GraphModule): - for node in m.graph.nodes: - if node.is_impure(): - continue - if len(node.users) == 0: - return True - return False - - def _get_num_placeholders(self, m: pippy.fx.GraphModule) -> int: - count = 0 - for node in m.graph.nodes: - if node.op == "placeholder": - count += 1 - return count - - def _run_dce_and_test( - self, - m: torch.nn.Module, - expect_dce_changes: bool, - modules_to_be_leafs: Set[Type] = None, - ): - class TestTracer(pippy.fx.Tracer): - def is_leaf_module(self, m, qualname): - if modules_to_be_leafs and type(m) in modules_to_be_leafs: - return True - return super().trace(m, qualname) - - traced: pippy.fx.GraphModule = pippy.fx.GraphModule(m, TestTracer().trace(m)) - print(str(traced.graph)) - - # Verify there are nodes without users (if expected). - has_nodes_without_users = self._has_nodes_without_users(traced) - if expect_dce_changes: - self.assertTrue(has_nodes_without_users) - else: - self.assertFalse(has_nodes_without_users) - - # Get the original number of placeholders to verify it doesn't change - # during DCE. - orig_num_phs = self._get_num_placeholders(traced) - changed = traced.graph.eliminate_dead_code() - - self.assertTrue(changed if expect_dce_changes else not changed) - - # Verify there are no nodes without users after DCE is run. - self.assertFalse(self._has_nodes_without_users(traced)) - new_num_phs = self._get_num_placeholders(traced) - self.assertEqual(orig_num_phs, new_num_phs) - - traced.recompile() - # Make sure we run and get the same results before/after DCE. - inputs = [torch.tensor([1.5])] * new_num_phs - self.assertTrue(torch.equal(m(*inputs), traced(*inputs))) - - def test_simple(self): - """ - Tests that a single node in the graph is DCE'd correctly. - """ - - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.attr_1 = torch.nn.Parameter(torch.tensor([-0.9])) - - def forward(self, x): - a = x + 1 - return x + self.attr_1 - - self._run_dce_and_test(TestModule(), expect_dce_changes=True) - - def test_dead_chain(self): - """ - Tests that a chain of two nodes in the graph are DCE'd correctly. - """ - - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.attr_1 = torch.nn.Parameter(torch.tensor([-0.9])) - - def forward(self, x): - a = x + 1 - b = a * 7 - return x + self.attr_1 - - self._run_dce_and_test(TestModule(), expect_dce_changes=True) - - def test_dead_getattr(self): - """ - Tests that a getatrr in the graph is DCE'd correctly. - """ - - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.attr_1 = torch.nn.Parameter(torch.tensor([-0.9])) - - def forward(self, x): - a = x + 1 - b = a * self.attr_1 - return x + 11 - - self._run_dce_and_test(TestModule(), expect_dce_changes=True) - - def test_dead_placeholder(self): - """ - Tests that a placeholder in the graph is not DCE'd, as that would change - the function signature. - """ - - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y): - return x + 7 - - self._run_dce_and_test(TestModule(), expect_dce_changes=False) - - def test_dead_placeholder_with_user(self): - """ - Tests that a placeholder in the graph is not DCE'd, as that would change - the function signature. Also verifies that a dead node that uses the - placeholder is DCE'd. - - """ - - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y): - a = y + 2 - return x + 7 - - self._run_dce_and_test(TestModule(), expect_dce_changes=True) - - def test_keep_module_with_side_effects(self): - """ - Test that DCE doesn't remove a module if it's specified as having side effects. - """ - - class ReLUImpure(torch.nn.ReLU): - _is_impure = True - - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.relu = ReLUImpure() - - def forward(self, a: torch.Tensor) -> torch.Tensor: - r = self.relu(a) - return a * 2 - - self._run_dce_and_test( - TestModule(), expect_dce_changes=False, modules_to_be_leafs={ReLUImpure} - ) - - def test_keep_torch_assert(self): - """ - Test that DCE doesn't remove torch._assert since it has side effects. - """ - - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, a: torch.Tensor) -> torch.Tensor: - torch._assert(torch.equal(a, a), "a must equal a") - return a * 2 - - # Note: Don't need to specify torch._assert as having side effects - # because it's known to. - self._run_dce_and_test(TestModule(), expect_dce_changes=False) diff --git a/test/fx/test_future.py b/test/fx/test_future.py deleted file mode 100644 index de9af1487..000000000 --- a/test/fx/test_future.py +++ /dev/null @@ -1,51 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -# Owner(s): ["module: fx"] - -from __future__ import annotations # type: ignore[attr-defined] -import torch -import typing -from pippy.fx import symbolic_trace - -class A: - def __call__(self, x: torch.Tensor): - return torch.add(x, x) - -# No forward references -class M1(torch.nn.Module): - def forward(self, x: torch.Tensor, a: A) -> torch.Tensor: - return a(x) - -# Forward references -class M2(torch.nn.Module): - def forward(self, x: 'torch.Tensor', a: 'A') -> 'torch.Tensor': - return a(x) - -# Non-torch annotation with no internal forward references -class M3(torch.nn.Module): - def forward(self, x: typing.List[torch.Tensor], a: A) -> torch.Tensor: - return a(x[0]) - -# Non-torch annotation with internal forward references -class M4(torch.nn.Module): - def forward(self, x: typing.List['torch.Tensor'], a: A) -> 'torch.Tensor': - return a(x[0]) - -x = torch.rand(2, 3) - -ref = torch.add(x, x) - -traced1 = symbolic_trace(M1()) -res1 = traced1(x, A()) -assert torch.all(torch.eq(ref, res1)) - -traced2 = symbolic_trace(M2()) -res2 = traced2(x, A()) -assert torch.all(torch.eq(ref, res2)) - -traced3 = symbolic_trace(M3()) -res3 = traced3([x], A()) -assert torch.all(torch.eq(ref, res3)) - -traced4 = symbolic_trace(M4()) -res4 = traced4([x], A()) -assert torch.all(torch.eq(ref, res4)) diff --git a/test/fx/test_fx_const_fold.py b/test/fx/test_fx_const_fold.py deleted file mode 100644 index b8207c2b5..000000000 --- a/test/fx/test_fx_const_fold.py +++ /dev/null @@ -1,712 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -# Owner(s): ["module: fx"] - -import operator - -import torch -import pippy.fx -from pippy.fx.experimental import const_fold -from pippy.fx.passes.shape_prop import _extract_tensor_metadata, ShapeProp -from torch.testing._internal.common_utils import TestCase - - -class TestConstFold(TestCase): - def _get_attr(self, node): - mod = node.graph.owning_module - target = str(node.target) - target_atoms = target.split(".") - curr_obj = mod - for i, atom in enumerate(target_atoms): - if not hasattr(curr_obj, atom): - raise RuntimeError( - f"Node referenced nonexistent target '{'.'.join(target_atoms[:i])}'; " - f" original whole target: '{target}'" - ) - curr_obj = getattr(curr_obj, atom) - return curr_obj - - def _verify_const_fold_mod(self, mod_folded: const_fold.FoldedGraphModule): - self.assertTrue(mod_folded.const_subgraph_module is not None) - - # Check that we don't have the const or non-const fold graphs in the gm, and - # that we do have the const folded get_attr. - found_folded_attrs = False - for n in mod_folded.graph.nodes: - if n.op == "get_attr" and n.target.startswith("_FX_CONST_FOLDED_ATTRS"): - found_folded_attrs = True - elif n.op == "call_module": - self.assertTrue(n.target not in {"submod_0", "submod_1"}) - self.assertTrue(found_folded_attrs) - - def test_const_fold_basic_one_attr_no_name_collision(self): - r""" - Perform constant folding conversion, from original mod to split constant folding - module with two split subgraphs, where there's a single attr to fold and - a single output attr result to replace. - - attr1 attr1 - | | | | - x add add - \ / | - sub y output (becomes attr add_1) - \ / ==> -------+------- (const/base subgraph split) - mul attr2 x / (input from previous subgraph - \ / \ / is attr) - add sub y - | \ / - output mul attr2 - \ / - add - | - output - """ - - class ConstFoldTestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.attr_1 = torch.nn.Parameter(torch.tensor([[-0.9]])) - self.attr_2 = torch.nn.Parameter(torch.tensor([[17.1]])) - - def forward(self, x, y): - a = self.attr_1 + self.attr_1 - x = x - a - return x * y + self.attr_2 - - mod = ConstFoldTestModule() - mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod) - self._verify_const_fold_mod(mod_folded) - - # Now run both folded and non-folded to check results equal. - in_x, in_y = torch.tensor([[-0.45]]), torch.tensor([0.9]) - base_result = mod(in_x, in_y) - fold_result = mod_folded(in_x, in_y) - self.assertTrue(torch.equal(fold_result, base_result)) - - def test_const_fold_basic_one_attr_name_collision(self): - r""" - Perform constant folding conversion, from original mod to split constant folding - module with two split subgraphs, where there's a single attr to fold and - a single output attr result to replace. Name the attrs such that they will - collide by name with folded attrs. - - add_1 add_1 - | | | | - x add add - \ / | - sub y output (becomes attr add_1) - \ / ==> -------+------- (const/base subgraph split) - mul add_2 x / (input from previous subgraph - \ / \ / is attr) - add sub y - | \ / - output mul add_2 - \ / - add - | - output - """ - - class ConstFoldTestModule(torch.nn.Module): - def __init__(self): - super().__init__() - # Note: Named as such to result in name collision. - self.add_1__CF = torch.nn.Parameter(torch.tensor([[1.0]])) - self.add_2__CF = torch.nn.Parameter(torch.tensor([[17.1]])) - - def forward(self, x, y): - a = self.add_1__CF + self.add_1__CF - x = x - a - return x * y + self.add_2__CF - - mod = ConstFoldTestModule() - mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod) - self._verify_const_fold_mod(mod_folded) - - # Now run both folded and non-folded to check results equal. - in_x, in_y = torch.tensor([[5.0]]), torch.tensor([4.0]) - base_result = mod(in_x, in_y) - fold_result = mod_folded(in_x, in_y) - self.assertTrue(torch.equal(fold_result, base_result)) - - def test_const_fold_basic_placeholder_reordered(self): - """ - Test code path where placeholder comes after normal op node in FX - """ - - class ConstFoldTestModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y): - return x * 2 + y - - mod = ConstFoldTestModule() - mod = pippy.fx.symbolic_trace(mod) - yy = None - for n in mod.graph.nodes: - if n.op == "placeholder" and n.target == "y": - yy = n - elif yy is not None and n.op == "call_function": - yy.prepend(n) - break - - mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod) - - self.assertTrue(mod_folded.const_subgraph_module is None) - # Now run both folded and non-folded to check results equal. - in_x = torch.tensor([[-0.45]]) - in_y = torch.tensor([[0.45]]) - base_result = mod(in_x, in_y) - fold_result = mod_folded(in_x, in_y) - self.assertTrue(torch.equal(fold_result, base_result)) - - def test_const_fold_noop(self): - r""" - Check that a graph with no constant folding is handled correctly. - - x attr1 - \ / - sub - | - output - """ - - class ConstFoldTestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.attr1 = torch.nn.Parameter(torch.tensor([[-0.9]])) - - def forward(self, x): - return x - self.attr1 - - mod = ConstFoldTestModule() - mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod) - - # Check that the folded graph module is None, since there was no folding to do. - self.assertTrue(mod_folded.const_subgraph_module is None) - - # Now run both folded and non-folded to check results equal. - in_x = torch.tensor([[-0.45]]) - base_result = mod(in_x) - fold_result = mod_folded(in_x) - self.assertTrue(torch.equal(fold_result, base_result)) - - def test_const_fold_basic_two_attr_three_input(self): - r""" - Perform constant folding conversion, from original mod to split constant - folding module with two split subgraphs, where there are two attrs to - fold into a single output, and there are three placeholder inputs. - - attr1 attr2 attr1 attr2 - \ / \ / - x add add - \ / | - sub y output (becomes attr add_1) - \ / ==> -------+------- (const/base subgraph split) - mul z x / (input from previous subgraph - \ / \ / is attr) - div sub y - | \ / - output mul z - \ / - div - | - output - """ - - class ConstFoldTestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.attr1 = torch.nn.Parameter(torch.tensor([[-0.9]])) - self.attr1 = torch.nn.Parameter(torch.tensor([[1.32]])) - - def forward(self, x, y, z): - a = self.attr1 + self.attr1 - sub = x - a - mul = sub * y - return mul / z - - mod = ConstFoldTestModule() - mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod) - self._verify_const_fold_mod(mod_folded) - - # Now run both folded and non-folded to check results equal. - in_x, in_y, in_z = ( - torch.tensor([[-0.45]]), - torch.tensor([0.9]), - torch.tensor([1.1]), - ) - base_result = mod(in_x, in_y, in_z) - fold_result = mod_folded(in_x, in_y, in_z) - self.assertTrue(torch.equal(fold_result, base_result)) - - def test_const_fold_basic_two_attr(self): - r""" - Perform constant folding conversion, from original mod to split constant - folding module with two split subgraphs, where there are two attrs to - fold into a single output. - - attr1 attr2 attr1 attr2 - \ / \ / - x add add (becomes attr add_1) - \ / ==> -------+------- (const/base subgraph split) - sub x | (input from previous subgraph is attr) - | \ / - output sub - | - output - """ - - class ConstFoldTestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.attr1 = torch.nn.Parameter(torch.randn(2, 3)) - self.attr2 = torch.nn.Parameter(torch.randn(2, 3)) - - def forward(self, x): - y = self.attr1 + self.attr2 - return x + y - - mod = ConstFoldTestModule() - mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod) - self._verify_const_fold_mod(mod_folded) - - # Now run both folded and non-folded to check results equal. - in_x = torch.randn(2, 3) - fold_result = mod_folded(in_x) - base_result = mod(in_x) - self.assertTrue(torch.equal(fold_result, base_result)) - - def test_const_fold_multi_const_folded_attrs(self): - r""" - Perform constant folding conversion, from original mod to split constant - folding module with two split subgraphs, where there are two attrs to - fold into two new attrs. - - attr1 attr2 attr1 attr2 - / \ | / \ | - permute | sum permute | sum - \ / / \ / | - x add y / add | - \ / \ / | | - sub add output output (become attrs add_1 and mul_1) - \ / ==> --------+-------+------ (const/base subgraph split) - \ / x | y | (inputs from previous subgraph - add \ / \ / are attrs) - | sub add - linear \ / - | add - sigmoid | - | linear - output | - sigmoid - | - output - """ - - class ConstFoldTestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.attr1 = torch.nn.Parameter(torch.randn(4, 4)) - self.attr2 = torch.nn.Parameter(torch.randn(4, 4)) - self.lin = torch.nn.Linear(4, 4) - - def forward(self, x, y): - a = self.attr1 + self.attr1.permute(1, 0) - x = x - a - amax = torch.sum(self.attr2, dim=1) - y = y + amax - return torch.sigmoid(self.lin(x + y)) - - mod = ConstFoldTestModule() - mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod) - self._verify_const_fold_mod(mod_folded) - - # Now run both folded and non-folded to check results equal. - in_x, in_y = torch.randn(4, 4), torch.randn(4) - fold_result = mod_folded(in_x, in_y) - base_result = mod(in_x, in_y) - self.assertTrue(torch.equal(fold_result, base_result)) - - def test_const_fold_submod_hierarchy(self): - r""" - Perform constant folding conversion, from original mod to split constant folding - module where one of the folded attrs comes from a submod deeper in the hierarchy - of the base module. - """ - - class TracedThroughModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.internal_attr = torch.nn.Parameter(torch.randn(2, 3)) - - def forward(self): - return self.internal_attr - - class ConstFoldTestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.my_mod = TracedThroughModule() - self.attr = torch.nn.Parameter(torch.randn(2, 3)) - - def forward(self, x): - return self.attr + self.my_mod() + x - - mod = ConstFoldTestModule() - mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod) - self._verify_const_fold_mod(mod_folded) - - # Now run both folded and non-folded to check results equal. - in_x = torch.randn(2, 3) - fold_result = mod_folded(in_x) - base_result = mod(in_x) - self.assertTrue(torch.equal(fold_result, base_result)) - - def test_retain_node_meta(self): - r""" - Perform constant folding conversion, and validate that node meta is retained. - """ - - class ConstFoldTestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.attr = torch.nn.Parameter(torch.randn(2, 3)) - - def forward(self, x): - a = self.attr + self.attr - return x - a - - mod = ConstFoldTestModule() - gm = pippy.fx.symbolic_trace(mod) - - # Add a count for each node to check after we const fold. - for idx, node in enumerate(gm.graph.nodes): - if node.op != "output": - node.meta["meta_idx"] = idx - - # Pre-folding: - # idx 0: placeholder - # idx 1: get_attr (will no longer be used, hence removed) - # idx 2: add (will be folded into a get_attr) - # idx 3: sub - - gm_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(gm) - self._verify_const_fold_mod(gm_folded) - - # Post-folding: - # idx 0: placeholder - # idx 2: get_attr (replaced original add; original get_attr was removed) - # idx 3: sub - - # Check the expected indices are still here. - for node in gm_folded.graph.nodes: - if node.op == "placeholder": - self.assertEqual(node.meta["meta_idx"], 0) - elif node.op == "get_attr": - self.assertEqual(node.meta["meta_idx"], 2) - elif node.op == "call_function" and node.target == operator.sub: - self.assertEqual(node.meta["meta_idx"], 3) - else: - self.assertEqual(node.op, "output") - - # Now run both folded and non-folded to check results equal. - in_x = torch.randn(2, 3) - fold_result = gm_folded(in_x) - base_result = mod(in_x) - self.assertTrue(torch.equal(fold_result, base_result)) - - def test_const_fold_has_inlined_call_module_node(self): - class ConstFoldTestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.attr = torch.nn.Parameter(torch.randn(2, 3)) - self.mod = torch.nn.Identity() - self.mod.relu = torch.nn.ReLU() - - def forward(self, x): - a = self.attr + self.attr - return self.mod.relu(x - a) - - mod = ConstFoldTestModule() - gm_folded = const_fold.split_const_subgraphs(mod) - - # Now run both folded and non-folded to check results equal. - in_x = torch.randn(2, 3) - fold_result = gm_folded(in_x) - base_result = mod(in_x) - self.assertTrue(torch.equal(fold_result, base_result)) - - def test_const_fold_module_attr(self): - class ConstFoldTestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.const = torch.nn.Parameter(torch.randn(2, 3)) - self.mod = torch.nn.Identity() - self.mod.attr = torch.nn.Parameter(torch.randn(2, 3)) - - def forward(self, x): - a = self.const + self.mod.attr - x = x + a - return x + self.mod.attr - - mod = ConstFoldTestModule() - gm_folded = const_fold.split_const_subgraphs(mod) - - # Now run both folded and non-folded to check results equal. - in_x = torch.randn(2, 3) - fold_result = gm_folded(in_x) - base_result = mod(in_x) - self.assertTrue(torch.equal(fold_result, base_result)) - - def test_const_fold_unused_placeholder(self): - class ConstFoldTestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.const = torch.nn.Parameter(torch.randn(2, 3)) - - def forward(self, x, y, z): - a = self.const + self.const - return y + a - - mod = ConstFoldTestModule() - gm_folded = const_fold.split_const_subgraphs(mod) - - # Now run both folded and non-folded to check results equal. - in_x = torch.randn(2, 3) - fold_result = gm_folded(in_x, in_x, in_x) - base_result = mod(in_x, in_x, in_x) - self.assertTrue(torch.equal(fold_result, base_result)) - - def test_dict_output(self): - class ConstFoldTestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.const = torch.nn.Parameter(torch.randn(2, 3)) - - def forward(self, x): - a = self.const + self.const - return {"result": x + a} - - mod = ConstFoldTestModule() - gm_folded = const_fold.split_const_subgraphs(mod) - - # Now run both folded and non-folded to check results equal. - in_x = torch.randn(2, 3) - fold_result = gm_folded(in_x) - base_result = mod(in_x) - self.assertTrue(torch.equal(fold_result["result"], base_result["result"])) - - def test_two_outputs(self): - class ConstFoldTestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.const = torch.nn.Parameter(torch.randn(2, 3)) - - def forward(self, x): - a = self.const + self.const - return x, x + a - - mod = ConstFoldTestModule() - gm_folded = const_fold.split_const_subgraphs(mod) - - # Now run both folded and non-folded to check results equal. - in_x = torch.randn(2, 3) - fold_result = gm_folded(in_x) - base_result = mod(in_x) - self.assertTrue(torch.equal(fold_result[0], base_result[0])) - self.assertTrue(torch.equal(fold_result[1], base_result[1])) - - def test_three_outputs(self): - class ConstFoldTestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.const = torch.nn.Parameter(torch.randn(2, 3)) - - def forward(self, x): - a = self.const + self.const - return x, x + a, x + a - - mod = ConstFoldTestModule() - gm_folded = const_fold.split_const_subgraphs(mod) - - # Now run both folded and non-folded to check results equal. - in_x = torch.randn(2, 3) - fold_result = gm_folded(in_x) - base_result = mod(in_x) - self.assertTrue(torch.equal(fold_result[0], base_result[0])) - self.assertTrue(torch.equal(fold_result[1], base_result[1])) - self.assertTrue(torch.equal(fold_result[2], base_result[2])) - - def test_check_inline_non_const(self): - r""" - Perform constant folding conversion and check that the non-const module is inlined - correctly. - """ - - class ConstFoldTestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.attr = torch.nn.Parameter(torch.randn(2, 3)) - - def forward(self, x): - a = self.attr + self.attr - return (x - a * x) / 2 - - mod = ConstFoldTestModule() - gm = pippy.fx.symbolic_trace(mod) - - gm_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(gm) - self._verify_const_fold_mod(gm_folded) - - # Check there are no call modules, because they've been inlined or extracted for - # const folding. - for node in gm_folded.graph.nodes: - self.assertNotEqual(node.op, "call_module") - - # Now run both folded and non-folded to check results equal. - in_x = torch.randn(2, 3) - fold_result = gm_folded(in_x) - base_result = mod(in_x) - self.assertTrue(torch.equal(fold_result, base_result)) - - def test_check_inline_non_const_mult_return(self): - r""" - Perform constant folding conversion and check that the non-const module is inlined - correctly. - """ - - class ConstFoldTestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.attr = torch.nn.Parameter(torch.randn(2, 3)) - - def forward(self, x): - a = self.attr + self.attr - return x - a, x / 2 - - mod = ConstFoldTestModule() - gm = pippy.fx.symbolic_trace(mod) - - gm_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(gm) - self._verify_const_fold_mod(gm_folded) - - # Check there are no call modules, because they've been inlined or extracted for - # const folding. - for node in gm_folded.graph.nodes: - self.assertNotEqual(node.op, "call_module") - - # Now run both folded and non-folded to check results equal. - in_x = torch.randn(2, 3) - fold_result = gm_folded(in_x) - base_result = mod(in_x) - self.assertTrue(torch.equal(fold_result[0], base_result[0])) - self.assertTrue(torch.equal(fold_result[1], base_result[1])) - - def test_check_skip_folding_quant_dequant_pattern(self): - r""" - Set up skip_folding_quant_dequant function to skip quant/dequant pattern. - This example shows how to use skip_folding_node_fn. - """ - - class ConstFoldTestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.weight = torch.nn.Parameter(torch.randn(4, 4)) - self.bias = torch.nn.Parameter(torch.randn(4)) - self.relu = torch.nn.ReLU() - - def forward(self, x): - quant_weight = torch.quantize_per_tensor( - self.weight, 0.5, 3, torch.quint8 - ) - dequant_weight = torch.dequantize(quant_weight) - output = torch.nn.functional.linear(x, dequant_weight, self.bias) - return self.relu(output) - - mod = ConstFoldTestModule() - in_x = torch.randn(2, 4) - gm = pippy.fx.symbolic_trace(mod) - - def skip_folding_quant_dequant(node: pippy.fx.Node): - if node.target != torch.quantize_per_tensor: - return False - # If quantize_per_node -> dequantize, then skip folding. - for user in node.users: - if user.target == torch.dequantize: - return True - return False - - gm_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs( - gm, skip_folding_node_fn=skip_folding_quant_dequant - ) - - # Check that the folded graph module is None, since there was no folding to do. - self.assertTrue(gm_folded.const_subgraph_module is None) - - # Now run both folded and non-folded to check results equal. - fold_result = gm_folded(in_x) - base_result = mod(in_x) - self.assertTrue(torch.equal(fold_result, base_result)) - - def test_fold_module(self): - r""" - Perform constant folding with a call_module node. - """ - - class ConstFoldTestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.lin_input = torch.nn.Parameter(torch.randn(4, 4)) - self.lin = torch.nn.Linear(4, 4) - - def forward(self, x): - return self.lin(self.lin_input) + x - - mod = ConstFoldTestModule() - mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod) - self._verify_const_fold_mod(mod_folded) - - # Now run both folded and non-folded to check results equal. - inp = torch.randn(4, 4) - self.assertTrue(torch.equal(mod_folded(inp), mod(inp))) - - def test_const_fold_tensor_meta(self): - self._test_const_fold_tensor_meta(True) - self._test_const_fold_tensor_meta(False) - - def _test_const_fold_tensor_meta(self, requires_grad): - """ - Verify tensor_meta is handled correctly. - """ - - class ConstFoldTestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.attr_1 = torch.nn.Parameter(torch.tensor([[-0.9]]), requires_grad) - self.attr_2 = torch.nn.Parameter(torch.tensor([[17.1]]), requires_grad) - - def forward(self, x, y): - a = self.attr_1 + self.attr_1 - x = x - a - return x * y + self.attr_2 - - mod = ConstFoldTestModule() - gm = pippy.fx.symbolic_trace(mod) - in_x, in_y = torch.tensor([[-0.45]]), torch.tensor([0.9]) - ShapeProp(gm).propagate(in_x, in_y) - mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs( - gm, device_for_folded_attrs="cpu" - ) - self._verify_const_fold_mod(mod_folded) - - mod_folded.run_folding() - - for n in mod_folded.graph.nodes: - if n.op == "get_attr": - attr = self._get_attr(n) - self.assertEquals(_extract_tensor_metadata(attr), n.meta["tensor_meta"]) - - # Now run both folded and non-folded to check results equal. - base_result = mod(in_x, in_y) - fold_result = mod_folded(in_x, in_y) - self.assertTrue(torch.equal(fold_result, base_result)) diff --git a/test/fx/test_fx_param_shape_control_flow.py b/test/fx/test_fx_param_shape_control_flow.py deleted file mode 100644 index 88f19642c..000000000 --- a/test/fx/test_fx_param_shape_control_flow.py +++ /dev/null @@ -1,156 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -# Owner(s): ["module: fx"] - -import unittest -import torch -import pippy.fx - -from torch.testing._internal.common_utils import TestCase - - -class MyModuleBase(torch.nn.Module): - def forward(self, x): - matrx = self.get_mul_matrix() - if self.no_relu(): - return torch.mm(x, matrx) - else: - return torch.relu(torch.mm(x, matrx)) - - def get_mul_matrix(self): - return self.param - - def no_relu(self): - raise Exception("not implemented") - -class MyModuleParamShape(MyModuleBase): - def __init__(self, in_channels): - super().__init__() - self.param = torch.nn.Parameter(torch.randn(in_channels, 3)) - - def no_relu(self): - return self.param.shape[0] < 10 - - -class MyModuleParamSize(MyModuleBase): - def __init__(self, in_channels): - super().__init__() - self.param = torch.nn.Parameter(torch.randn(in_channels, 3)) - - def no_relu(self): - return self.param.size()[0] < 10 - - -class MyModuleParamDim(MyModuleBase): - def __init__(self, param): - super().__init__() - self.param = param - - def get_mul_matrix(self): - return self.param[0] if (self.param.dim() == 3) else self.param - - def no_relu(self): - return self.param.dim() == 3 - - -class MyModuleParamNDim(MyModuleBase): - def __init__(self, param): - super().__init__() - self.param = param - - def get_mul_matrix(self): - return self.param[0] if (self.param.ndim == 3) else self.param - - def no_relu(self): - return self.param.ndim == 3 - - -class MyModuleParamNumEl(MyModuleBase): - def __init__(self, in_channels): - super().__init__() - self.param = torch.nn.Parameter(torch.randn(in_channels, 3)) - - def no_relu(self): - return self.param.numel() < 10 * 3 - - - -class MyModuleParamNElement(MyModuleBase): - def __init__(self, in_channels): - super().__init__() - self.param = torch.nn.Parameter(torch.randn(in_channels, 3)) - - def no_relu(self): - return self.param.nelement() < 10 * 3 - - - -class TestConstParamShapeInControlFlow(TestCase): - - def verify_mm_relu_mods(self, mm_only_mod, relu_mod): - """ - Verify one module only does a mm op while the other - performs both mm and relu ops in cascade - """ - x = torch.randn(10, 5) - torch.testing.assert_allclose(mm_only_mod(x), torch.mm(x, mm_only_mod.get_mul_matrix())) - tracer = pippy.fx.Tracer(param_shapes_constant=True) - traced_graph = tracer.trace(mm_only_mod) - - # verify the graph module calculates the same result - graph_mod_mm = pippy.fx.GraphModule(mm_only_mod, traced_graph) - torch.testing.assert_allclose(graph_mod_mm(x), torch.mm(x, mm_only_mod.get_mul_matrix())) - - - # Make a new module with different parameter shape to go down the different - # code path - x = torch.randn(10, 15) - torch.testing.assert_allclose(relu_mod(x), torch.relu(torch.mm(x, relu_mod.get_mul_matrix()))) - - tracer2 = pippy.fx.Tracer(param_shapes_constant=True) - traced_graph2 = tracer2.trace(relu_mod) - - # verify the graph module calculates the same result - graph_mod_relu = pippy.fx.GraphModule(relu_mod, traced_graph2) - torch.testing.assert_allclose(graph_mod_relu(x), torch.relu(torch.mm(x, relu_mod.get_mul_matrix()))) - - - graph1_node_targets = [n.target for n in traced_graph.nodes] - graph2_node_targets = [n.target for n in traced_graph2.nodes] - - # the second graph has an exta relu function call node - assert torch.mm in graph1_node_targets and torch.mm in graph2_node_targets - assert torch.relu not in graph1_node_targets and torch.relu in graph2_node_targets - - def test_param_shape_const(self): - mymod = MyModuleParamShape(in_channels=5) - mymod2 = MyModuleParamShape(in_channels=15) - self.verify_mm_relu_mods(mymod, mymod2) - - def test_param_size_const(self): - mymod = MyModuleParamSize(in_channels=5) - mymod2 = MyModuleParamSize(in_channels=15) - self.verify_mm_relu_mods(mymod, mymod2) - - def test_param_dim_const(self): - mymod = MyModuleParamDim(torch.nn.Parameter(torch.randn(2, 5, 3))) - mymod2 = MyModuleParamDim(torch.nn.Parameter(torch.randn(15, 3))) - self.verify_mm_relu_mods(mymod, mymod2) - - def test_param_ndim_const(self): - mymod = MyModuleParamNDim(torch.nn.Parameter(torch.randn(2, 5, 3))) - mymod2 = MyModuleParamNDim(torch.nn.Parameter(torch.randn(15, 3))) - self.verify_mm_relu_mods(mymod, mymod2) - - def test_param_numel_const(self): - mymod = MyModuleParamNumEl(in_channels=5) - mymod2 = MyModuleParamNumEl(in_channels=15) - self.verify_mm_relu_mods(mymod, mymod2) - - def test_param_nelement_const(self): - mymod = MyModuleParamNElement(in_channels=5) - mymod2 = MyModuleParamNElement(in_channels=15) - self.verify_mm_relu_mods(mymod, mymod2) - - -if __name__ == '__main__': - unittest.main() diff --git a/test/fx/test_gradual_type.py b/test/fx/test_gradual_type.py deleted file mode 100644 index 9f82f0810..000000000 --- a/test/fx/test_gradual_type.py +++ /dev/null @@ -1,1017 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -# Owner(s): ["module: fx"] - -import unittest -import torch -import pippy -import pippy.fx -from pippy.fx import symbolic_trace -from pippy.fx.experimental.unify_refinements import infer_symbolic_types -from pippy.fx.experimental.refinement_types import Equality -from pippy.fx.tensor_type import TensorType, Dyn, is_consistent, is_more_precise -from pippy.fx.annotate import annotate -from pippy.fx.experimental.graph_gradual_typechecker import GraphTypeChecker, broadcast_types, Refine -from pippy.fx.experimental.rewriter import RewritingTracer -from pippy.fx import GraphModule -from pippy.fx.passes.shape_prop import ShapeProp -from torch.testing._internal.common_utils import TestCase - - -try: - import sympy - HAS_SYMPY = True -except ImportError: - HAS_SYMPY = False -skipIfNoSympy = unittest.skipIf(not HAS_SYMPY, "no sympy") - - -try: - from torchvision.models import resnet50 - - HAS_TORCHVISION = True -except ImportError: - HAS_TORCHVISION = False -skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision") - -def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): - """3x3 convolution with padding""" - return torch.nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, - padding=dilation, groups=groups, bias=False, dilation=dilation) - -class AnnotationsTest(TestCase): - - def test_annotations(self): - """ - Test type annotations in the forward function. - The annoation should appear in the n.graph - where n is the corresoinding node in the resulting graph. - """ - class M(torch.nn.Module): - def forward(self, - x: TensorType((1, 2, 3, Dyn)), - y: Dyn, - z: TensorType[Dyn, 3, Dyn]): - return torch.add(x, y) + z - - module = M() - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(module) - - expected_ph_types = [TensorType((1, 2, 3, Dyn)), Dyn, TensorType((Dyn, 3, Dyn))] - expected_iter = iter(expected_ph_types) - - for n in symbolic_traced.graph.nodes: - if n.op == 'placeholder': - assert n.type == next(expected_iter) - - def test_annotate(self): - class M(torch.nn.Module): - - def forward(self, x): - y = annotate(x, TensorType((1, 2, 3, Dyn))) - return torch.add(x, y) - - module = M() - symbolic_traced : pippy.fx.GraphModule = symbolic_trace(module) - for n in symbolic_traced.graph.nodes: - if n.op == 'placeholder': - assert n.type == TensorType((1, 2, 3, Dyn)) - - def test_consistency(self): - """ - Test the consistency relation. - """ - self.assertTrue(is_consistent(TensorType((1, 2, 3)), TensorType((1, Dyn, 3)))) - self.assertTrue(is_consistent(int, Dyn)) - self.assertTrue(is_consistent(int, int)) - self.assertFalse(is_consistent(TensorType((1, 2, 3)), TensorType((1, 2, 3, 5)))) - self.assertFalse(is_consistent(TensorType((1, 2, 3)), int)) - - def test_precision(self): - """ - Test the consistency relation. - """ - self.assertTrue(is_more_precise(TensorType((1, 2, 3)), TensorType((1, Dyn, 3)))) - self.assertTrue(is_more_precise(int, Dyn)) - self.assertTrue(is_more_precise(int, int)) - self.assertFalse(is_more_precise(TensorType((1, 2, 3)), TensorType((1, 2, 3, 5)))) - self.assertFalse(is_more_precise(TensorType((1, 2, 3)), int)) - - def test_broadcasting1(self): - t1 = TensorType((1, 2, 3, 4)) - t2 = TensorType((1, 2, 1, 4)) - t3 = TensorType(()) - t4 = TensorType((4, 1)) - t5 = TensorType((4, 4, 4)) - # todo switch all code to use list instead of tuple - t6 = TensorType([1]) - assert broadcast_types(t1, t2) == (TensorType((1, 2, 3, 4)), TensorType((1, 2, 3, 4))) - assert broadcast_types(t3, t4) == (t4, t4) - assert broadcast_types(t5, t6) == (t5, t5) - - def test_broadcasting2(self): - t1 = TensorType((2, 3, 4)) - t2 = TensorType((1, 2, 1, 4)) - - assert broadcast_types(t1, t2) == (TensorType((1, 2, 3, 4)), TensorType((1, 2, 3, 4))) - - def test_broadcasting3(self): - t1 = TensorType((1, 2, 3, Dyn)) - t2 = TensorType((2, 3, 4)) - assert broadcast_types(t1, t2) == (TensorType((1, 2, 3, Dyn)), TensorType((1, 2, 3, 4))) - -class TypeCheckerTest(TestCase): - - def test_type_check_add_with_broadcast(self): - class M(torch.nn.Module): - def forward(self, x: TensorType((1, 2, 3, Dyn)), y: TensorType((2, 3, 4))): - return torch.add(x, y) - module = M() - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(module) - tc = GraphTypeChecker({}, symbolic_traced) - tc.type_check() - expected_ph_types = [TensorType((1, 2, 3, Dyn)), - TensorType((2, 3, 4)), - TensorType((1, 2, 3, Dyn)), - TensorType((1, 2, 3, Dyn))] - expected_iter = iter(expected_ph_types) - - for n in symbolic_traced.graph.nodes: - if n.op == 'call_function': - assert n.meta['broadcast'] - assert n.type == next(expected_iter) - - def test_type_check_add_with_scalar(self): - class M(torch.nn.Module): - def forward(self, x: int, y: TensorType((2, 3, 4))): - return torch.add(x, y) - module = M() - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(module) - tc = GraphTypeChecker({}, symbolic_traced) - tc.type_check() - expected_ph_types = [int, - TensorType((2, 3, 4)), - TensorType((2, 3, 4)), - TensorType((2, 3, 4))] - expected_iter = iter(expected_ph_types) - - for n in symbolic_traced.graph.nodes: - assert n.type == next(expected_iter) - - def test_type_check_add_false(self): - class M(torch.nn.Module): - def forward(self, x: TensorType((1, 2, 3, Dyn)), y: TensorType((1, 2, 3))): - return torch.add(x, y) - module = M() - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(module) - tc = GraphTypeChecker({}, symbolic_traced) - with self.assertRaises(TypeError): - tc.type_check() - - def test_type_check_add_true(self): - class M(torch.nn.Module): - def forward(self, x: TensorType((1, 2, Dyn)), y: TensorType((1, 2, 3))): - return torch.add(x, y) - module = M() - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(module) - tc = GraphTypeChecker({}, symbolic_traced) - self.assertTrue(tc.type_check()) - - expected_ph_types = [TensorType((1, 2, Dyn)), TensorType((1, 2, 3))] - expected_iter = iter(expected_ph_types) - - for n in symbolic_traced.graph.nodes: - if n.op == 'placeholder': - assert n.type == next(expected_iter) - if n.op == 'output': - assert n.type == TensorType((1, 2, Dyn)) - - def test_type_check_reshape_true(self): - class M(torch.nn.Module): - def forward(self, x: TensorType((1, 6))): - return torch.reshape(x, [1, 2, 3]) - - module = M() - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(module) - tc = GraphTypeChecker({}, symbolic_traced) - self.assertTrue(tc.type_check()) - - for n in symbolic_traced.graph.nodes: - if n.op == 'placeholder': - assert n.type == TensorType((1, 6)) - - if n.op == 'call_function': - assert n.type == TensorType((1, 2, 3)) - - if n.op == 'output': - assert n.type == TensorType((1, 2, 3)) - - def test_type_check_reshape_false(self): - class M(torch.nn.Module): - def forward(self, x: TensorType((1, 5))): - return torch.reshape(x, [1, 2, 3]) - - module = M() - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(module) - tc = GraphTypeChecker({}, symbolic_traced) - with self.assertRaises(TypeError): - tc.type_check() - - def test_type_check_reshape_dyn_false(self): - class M(torch.nn.Module): - def forward(self, x: TensorType((1, 5))): - return torch.reshape(x, [1, 2, -1]) - - module = M() - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(module) - tc = GraphTypeChecker({}, symbolic_traced) - with self.assertRaises(TypeError): - tc.type_check() - - def test_type_check_reshape_dyn_true(self): - class M(torch.nn.Module): - def forward(self, x: TensorType((1, 15))): - return torch.reshape(x, [1, 5, -1]) - - module = M() - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(module) - tc = GraphTypeChecker({}, symbolic_traced) - self.assertTrue(tc.type_check()) - - def test_type_check_reshape_dyn_true_param_false(self): - class M(torch.nn.Module): - def forward(self, x: TensorType((Dyn, 5))): - return torch.reshape(x, [1, 2, -1]) - - module = M() - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(module) - tc = GraphTypeChecker({}, symbolic_traced) - with self.assertRaises(TypeError): - tc.type_check() - - def test_type_check_transpose_true(self): - class M(torch.nn.Module): - def forward(self, x: TensorType((1, 2, 3, 5))): - return torch.transpose(x, 0, 1) - - module = M() - symbolic_traced : pippy.fx.GraphModule = symbolic_trace(module) - tc = GraphTypeChecker({}, symbolic_traced) - self.assertTrue(tc.type_check()) - - for n in symbolic_traced.graph.nodes: - if n.op == 'call_function': - assert n.type == TensorType([2, 1, 3, 5]) - if n.op == 'output': - assert n.type == TensorType([2, 1, 3, 5]) - if n.op == 'x': - assert n.placeholder == TensorType([1, 2, 3, 5]) - - def test_type_check_transpose_False(self): - class M(torch.nn.Module): - def forward(self, x: TensorType((1, 2, 3, 5))): - return torch.transpose(x, 0, 10) - - module = M() - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(module) - tc = GraphTypeChecker({}, symbolic_traced) - with self.assertRaises(TypeError): - tc.type_check() - - def test_type_check_batch_norm_2D(self): - class BasicBlock(torch.nn.Module): - - def __init__(self, inplanes, planes): - super(BasicBlock, self).__init__() - norm_layer = torch.nn.BatchNorm2d - self.bn1 = norm_layer(planes) - - def forward(self, x: TensorType((2, 2, 5, 4))): - identity = x - out: TensorType((2, 2, Dyn, 4)) = self.bn1(x) - out += identity - return out - - B = BasicBlock(2, 2) - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(B) - traced = GraphModule(ast_rewriter.root, graph, "gm") - tc = GraphTypeChecker({}, traced) - tc.type_check() - - for n in graph.nodes: - if n.op == 'placeholder': - assert n.type == TensorType((2, 2, 5, 4)) - if n.op == 'output': - assert n.type == TensorType((2, 2, 5, 4)) - if n.op == 'call_module': - assert n.type == TensorType((2, 2, 5, 4)) - if n.op == 'call_function': - assert n.type == TensorType((2, 2, 5, 4)) - - def test_type_check_batch_norm_2D_false(self): - class BasicBlock(torch.nn.Module): - - def __init__(self, inplanes, planes): - super(BasicBlock, self).__init__() - norm_layer = torch.nn.BatchNorm2d - self.bn1 = norm_layer(planes) - - def forward(self, x: TensorType((2, 2, 5))): - identity = x - out: TensorType((2, 2, Dyn, 4)) = self.bn1(x) - out += identity - return out - - B = BasicBlock(2, 2) - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(B) - traced = GraphModule(ast_rewriter.root, graph, "gm") - tc = GraphTypeChecker({}, traced) - with self.assertRaises(TypeError): - tc.type_check() - - def test_type_check_batch_norm_2D_broadcast(self): - class BasicBlock(torch.nn.Module): - - def __init__(self, inplanes, planes): - super(BasicBlock, self).__init__() - norm_layer = torch.nn.BatchNorm2d - self.bn1 = norm_layer(planes) - - def forward(self, x: Dyn): - identity = x - out: TensorType((2, 2, Dyn, 4)) = self.bn1(x) - out += identity - return out - - B = BasicBlock(2, 2) - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(B) - traced = GraphModule(ast_rewriter.root, graph, "gm") - tc = GraphTypeChecker({}, traced) - tc.type_check() - for n in graph.nodes: - if n.op == 'placeholder': - assert n.type == TensorType((Dyn, Dyn, Dyn, Dyn)) - if n.op == 'call_function': - assert n.type == TensorType((Dyn, Dyn, Dyn, Dyn)) - if n.op == 'output': - assert n.type == TensorType((Dyn, Dyn, Dyn, Dyn)) - if n.op == 'call_module': - assert n.type == TensorType((2, 2, Dyn, 4)) - - B = BasicBlock(1, 1) - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(B) - traced = GraphModule(ast_rewriter.root, graph, "gm") - tc = GraphTypeChecker({}, traced) - with self.assertRaises(TypeError): - tc.type_check() - - def test_type_check_conv2D(self): - class BasicBlock(torch.nn.Module): - def __init__(self, inplanes, planes, stride=1): - super(BasicBlock, self).__init__() - norm_layer = torch.nn.BatchNorm2d - self.conv1 = conv3x3(inplanes, planes, stride) - self.bn1 = norm_layer(planes) - - def forward(self, x: Dyn): - identity = x - out: TensorType((2, 2, Dyn, 4)) = self.conv1(x) - out += identity - return out - - B = BasicBlock(2, 2) - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(B) - traced = GraphModule(ast_rewriter.root, graph, "gm") - tc = GraphTypeChecker({}, traced) - tc.type_check() - for n in graph.nodes: - if n.op == 'placeholder': - assert n.type == TensorType((Dyn, Dyn, Dyn, Dyn)) - if n.op == 'call_function': - assert n.type == TensorType((Dyn, Dyn, Dyn, Dyn)) - if n.op == 'output': - assert n.type == TensorType((Dyn, Dyn, Dyn, Dyn)) - if n.op == 'call_module': - assert n.type == TensorType((2, 2, Dyn, 4)) - - def test_type_check_conv2D_2(self): - class BasicBlock(torch.nn.Module): - def __init__(self, inplanes, planes, stride=1): - super(BasicBlock, self).__init__() - norm_layer = torch.nn.BatchNorm2d - self.conv1 = conv3x3(inplanes, planes, stride) - self.bn1 = norm_layer(planes) - - def forward(self, x: TensorType((5, 2, 3, 4))): - identity = x - out = self.conv1(x) - out += identity - return out - - B = BasicBlock(2, 2) - b = B.forward(torch.rand(5, 2, 3, 4)) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(B) - traced = GraphModule(ast_rewriter.root, graph, "gm") - tc = GraphTypeChecker({}, traced) - tc.type_check() - t = TensorType((5, 2, 3, 4)) - for n in graph.nodes: - if n.op == 'placeholder': - assert n.type == t - if n.op == 'call_function': - assert n.type == t - if n.op == 'output': - assert torch.Size(n.type.__args__) == b.shape - if n.op == 'call_module': - assert n.type == t - - B = BasicBlock(1, 2) - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(B) - traced = GraphModule(ast_rewriter.root, graph, "gm") - tc = GraphTypeChecker({}, traced) - with self.assertRaises(TypeError): - tc.type_check() - - def test_type_check_conv2D_2_fully_static(self): - annotation_list = [(1, 2, 3, 5), (2, 5, 6, 9), (10, 15, 13, 14), - (10, Dyn, 13, 14), (Dyn, Dyn, Dyn, 3)] - input_list = [(1, 2, 3, 5), (2, 5, 6, 9), (10, 15, 13, 14), - (10, 15, 13, 14), (1, 2, 2, 3)] - intermediate_types = [(1, Dyn, Dyn, 7), (2, Dyn, 4, 6), (10, 15, Dyn, 5), - (10, 15, 7, 7), (1, Dyn, Dyn, Dyn)] - in_planes_list = [2, 5, 15, 15, 2] - stride_list = [1, 2, 3, 2, 2] - out_planes_list = [2, 5, 15, 15, 2] - groups_list = [1, 5, 5, 5, 2] - dilation_list = [1, 2, 3, 3, 3] - padding_list = [1, 2, 3, 3, 3] - kernel_size_list = [1, 2, 3, 3, 3] - output_types = [(1, 2, Dyn, 7), (2, 5, 4, 6), (10, 15, Dyn, 5), (10, 15, 7, 7), (1, 2, Dyn, Dyn)] - - for i in range(5): - annotation = annotation_list[i] - input = input_list[i] - in_planes = in_planes_list[i] - stride = stride_list[i] - out_planes = out_planes_list[i] - groups = groups_list[i] - dilation = dilation_list[i] - padding = padding_list[i] - kernel_size = kernel_size_list[i] - intermediate_type = intermediate_types[i] - - class BasicBlock(torch.nn.Module): - def __init__(self, in_planes, out_planes, kernel_size, stride, padding, groups, dilation): - super(BasicBlock, self).__init__() - self.conv1 = torch.nn.Conv2d(in_channels=in_planes, out_channels=out_planes, - kernel_size=kernel_size, stride=stride, - padding=padding, groups=groups, bias=False, dilation=dilation) - - def forward(self, x): - out = self.conv1(x) - return out - - B = BasicBlock(in_planes, out_planes, kernel_size, stride, padding, groups, dilation) - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(B) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - # annotate our argument - for n in graph.nodes: - if n.op == 'placeholder': - n.type = TensorType(annotation) - - b = B.forward(torch.rand(input)) - tc = GraphTypeChecker({}, traced) - tc.type_check() - - for n in graph.nodes: - if n.op == 'output': - assert is_consistent(n.type, TensorType(b.size())) - - # test with intermediate annotations - class BasicBlock(torch.nn.Module): - def __init__(self, in_planes, out_planes, kernel_size, stride, padding, groups, dilation): - super(BasicBlock, self).__init__() - self.conv1 = torch.nn.Conv2d(in_channels=in_planes, out_channels=out_planes, - kernel_size=kernel_size, stride=stride, - padding=padding, groups=groups, bias=False, dilation=dilation) - - def forward(self, x): - out = self.conv1(x) - return out - - B = BasicBlock(in_planes, out_planes, kernel_size, stride, padding, groups, dilation) - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(B) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - # populate our intermediate notes - for n in traced.graph.nodes: - if n.op == 'call_module': - n.type = TensorType(intermediate_type) - - tc = GraphTypeChecker({}, traced) - tc.type_check() - - for n in traced.graph.nodes: - if n.op == 'output': - assert n.type == TensorType(output_types[i]) - assert is_consistent(n.type, TensorType(b.size())) - - def test_typecheck_basicblock(self): - class BasicBlock(torch.nn.Module): - expansion = 1 - - def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, - base_width=64, dilation=1): - super(BasicBlock, self).__init__() - norm_layer = torch.nn.BatchNorm2d - if groups != 1 or base_width != 64: - raise ValueError('BasicBlock only supports groups=1 and base_width=64') - if dilation > 1: - raise NotImplementedError("Dilation > 1 not supported in BasicBlock") - # Both self.conv1 and self.downsample layers downsample the input when stride != 1 - self.conv1 = conv3x3(inplanes, planes, stride) - self.bn1 = norm_layer(planes) - self.relu = torch.nn.ReLU(inplace=True) - self.conv2 = conv3x3(planes, planes) - self.bn2 = norm_layer(planes) - self.downsample = downsample - self.stride = stride - - def forward(self, x: TensorType((2, 2, 4, 5))): - identity = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - - out = self.conv2(out) - out = self.bn2(out) - - if self.downsample is not None: - identity = self.downsample(x) - - out += identity - out = self.relu(out) - - return out - - B = BasicBlock(2, 2) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(B) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - tc = GraphTypeChecker({}, traced) - tc.type_check() - - for n in traced.graph.nodes: - if n.target == 'output': - assert isinstance(n.type, TensorType) - assert torch.Size(n.type.__args__) == B.forward(torch.rand(2, 2, 4, 5)).size() - - def test_type_check_conv2D_maxpool2d_flatten(self): - - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - self.conv1 = torch.nn.Conv2d(3, 6, 5) - self.pool = torch.nn.MaxPool2d(2, 2) - self.conv2 = torch.nn.Conv2d(6, 16, 5) - self.fc1 = torch.nn.Linear(5, 120) - self.pool2 = torch.nn.AdaptiveAvgPool2d((6, 7)) - - def forward(self, x : TensorType((4, 3, 32, 32))): - out = self.conv1(x) - out = self.pool(out) - out = self.conv2(out) - out = self.pool(out) - out = self.fc1(out) - out = self.pool2(out) - out = torch.flatten(out, 1) - return out - - B = BasicBlock() - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(B) - traced = GraphModule(ast_rewriter.root, graph, "gm") - tc = GraphTypeChecker({}, traced) - tc.type_check() - - expected_ph_types = [TensorType((4, 3, 32, 32)), TensorType((4, 6, 28, 28)), - TensorType((4, 6, 14, 14)), TensorType((4, 16, 10, 10)), - TensorType((4, 16, 5, 5)), TensorType((4, 16, 5, 120)), - TensorType((4, 16, 6, 7)), TensorType((4, 672)), TensorType((4, 672))] - - expected_iter = iter(expected_ph_types) - traced.graph.eliminate_dead_code() - - for n in traced.graph.nodes: - assert n.type == next(expected_iter) - - def test_type_check_flatten(self): - class M(torch.nn.Module): - def forward(self, x: TensorType((1, 2, 3, 5, Dyn))): - return torch.flatten(x, 1, 2) - - module = M() - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(module) - tc = GraphTypeChecker({}, symbolic_traced) - tc.type_check() - for n in symbolic_traced.graph.nodes: - if n.op == 'output': - assert n.type == TensorType((1, 6, 5, Dyn)) - - - def test_type_check_flatten_2(self): - class M(torch.nn.Module): - def forward(self, x: TensorType((1, Dyn, 3, 5, Dyn))): - return torch.flatten(x, 1, 2) - - module = M() - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(module) - tc = GraphTypeChecker({}, symbolic_traced) - tc.type_check() - for n in symbolic_traced.graph.nodes: - if n.op == 'output': - assert n.type == TensorType((1, Dyn, 5, Dyn)) - - def test_type_check_flatten3(self): - class M(torch.nn.Module): - def forward(self, x: TensorType((2, 3, 4, 5))): - return torch.flatten(x, start_dim=1, end_dim=3) - - module = M() - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(module) - tc = GraphTypeChecker({}, symbolic_traced) - tc.type_check() - for n in symbolic_traced.graph.nodes: - if n.op == 'output': - assert n.type == TensorType((2, 60)) - r = Refine(symbolic_traced) - r.refine() - c = r.constraints - assert c == [Equality(2, 2)] - - def test_type_typechecl_maxpool2d_3dinput(self): - - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - self.pool = torch.nn.MaxPool2d(5, 8) - - def forward(self, x : TensorType((64, 8, 8))): - out = self.pool(x) - return out - - B = BasicBlock() - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(B) - traced = GraphModule(ast_rewriter.root, graph, "gm") - tc = GraphTypeChecker({}, traced) - tc.type_check() - - for n in traced.graph.nodes: - if n.target == 'output': - assert n.type == TensorType((64, 1, 1)) - - def test_type_maxpool2d_fully_static(self): - annotation_list = [(Dyn, Dyn, 3, 5), (2, 5, 6, 9), (10, 15, 13, 14), - (10, Dyn, 13, 14), (Dyn, Dyn, Dyn, 10)] - input_list = [(1, 2, 3, 5), (2, 5, 6, 9), (10, 15, 13, 14), - (10, 15, 13, 14), (2, 2, 10, 10)] - intermediate_types = [(1, 2, Dyn, Dyn), (2, Dyn, 2, 4), (10, 15, Dyn, 2), - (10, 15, 2, 3), (2, Dyn, Dyn, Dyn)] - stride_list = [1, 2, 3, 2, 1] - dilation_list = [1, 2, 3, 3, 2] - padding_list = [1, 2, 3, 3, 1] - kernel_size_list = [2, 4, 6, 6, 3] - output_types = [(1, 2, 4, 6), (2, 5, 2, 4), (10, 15, 2, 2), (10, 15, 2, 3), (2, Dyn, Dyn, 8)] - - for i in range(5): - annotation = annotation_list[i] - input = input_list[i] - stride = stride_list[i] - dilation = dilation_list[i] - padding = padding_list[i] - kernel_size = kernel_size_list[i] - intermediate_type = intermediate_types[i] - - class BasicBlock(torch.nn.Module): - def __init__(self, kernel_size, stride, padding, dilation): - super(BasicBlock, self).__init__() - self.pool = torch.nn.MaxPool2d(kernel_size, stride=stride, - padding=padding, dilation=dilation, - return_indices=False, ceil_mode=False) - - def forward(self, x): - out = self.pool(x) - return out - - B = BasicBlock(kernel_size, stride, padding, dilation) - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(B) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - # annotate our argument - for n in graph.nodes: - if n.op == 'placeholder': - n.type = TensorType(annotation) - - b = B.forward(torch.rand(input)) - tc = GraphTypeChecker({}, traced) - tc.type_check() - - for n in graph.nodes: - if n.op == 'output': - assert is_consistent(n.type, TensorType(b.size())) - - # test with intermediate annotations - class BasicBlock(torch.nn.Module): - def __init__(self, kernel_size, stride, padding, dilation): - super(BasicBlock, self).__init__() - self.pool = torch.nn.MaxPool2d(kernel_size, stride=stride, - padding=padding, dilation=dilation, - return_indices=False, ceil_mode=False) - - def forward(self, x): - out = self.pool(x) - return out - - B = BasicBlock(kernel_size, stride, padding, dilation) - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(B) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - # annotate our argument - for n in graph.nodes: - if n.op == 'placeholder': - n.type = TensorType(annotation) - - # populate our intermediate notes - for n in traced.graph.nodes: - if n.op == 'call_module': - n.type = TensorType(intermediate_type) - - tc = GraphTypeChecker({}, traced) - tc.type_check() - - for n in traced.graph.nodes: - if n.op == 'output': - assert n.type == TensorType(output_types[i]) - assert is_consistent(n.type, TensorType(b.size())) - - def test_flatten_fully_static(self): - annotation_list = [Dyn, TensorType((2, 5, 6, 9)), TensorType((10, 15, 13, 14)), - TensorType((10, Dyn, 13, 14)), TensorType((Dyn, Dyn, Dyn, 10))] - input_list = [(1, 2, 3, 5), (2, 5, 6, 9), (10, 15, 13, 14), - (10, 15, 13, 14), (2, 2, 10, 10)] - - intermediate_list = [Dyn, (2, 5, 6, 9), (10, 15, 13, 14), - (10, 15, 13, 14), (2, 2, 10, 10)] - - start_dim = [1, 2, 1, 2, 0] - end_dim = [1, 3, 3, 3, -2] - - for i in range(5): - annotation = annotation_list[i] - input = input_list[i] - # intermediate_type = intermediate_list[i] - - class BasicBlock(torch.nn.Module): - def __init__(self, start, end): - super(BasicBlock, self).__init__() - self.start = start - self.end = end - - def forward(self, x): - out = torch.flatten(x, self.start, self.end) - return out - - B = BasicBlock(start_dim[i], end_dim[i]) - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(B) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - # annotate our argument - for n in graph.nodes: - if n.op == 'placeholder': - n.type = annotation - - b = B.forward(torch.rand(input)) - tc = GraphTypeChecker({}, traced) - tc.type_check() - - for n in graph.nodes: - if n.op == 'output': - assert is_consistent(n.type, TensorType(b.size())) - - @skipIfNoSympy - @skipIfNoTorchVision - def test_resnet50(self): - gm_run = symbolic_trace(resnet50()) - sample_input = torch.randn(1, 3, 224, 224) - - # run our nodes - ShapeProp(gm_run).propagate(sample_input) - - gm_static = symbolic_trace(resnet50()) - - for n in gm_static.graph.nodes: - n.type = None - - g = GraphTypeChecker({}, gm_static) - g.type_check() - gm_static.graph.eliminate_dead_code() - gm_run.graph.eliminate_dead_code() - # here we are checking for consistency with fully dynamic nodes - for n1, n2 in zip(gm_static.graph.nodes, gm_run.graph.nodes): - assert is_consistent(n1.type, TensorType(n2.meta['tensor_meta'].shape)) - - # here we give the same input as to runtume - gm_static_with_types = symbolic_trace(resnet50()) - - # we initialize our placeholder - for n in gm_static_with_types.graph.nodes: - if n.op == 'placeholder': - n.type = TensorType((1, 3, 224, 224)) - - g = GraphTypeChecker({}, gm_static_with_types) - g.type_check() - for n1, n2 in zip(gm_static_with_types.graph.nodes, gm_run.graph.nodes): - assert n1.type == TensorType(n2.meta['tensor_meta'].shape) - - # apply shape inference to graph and check - # that the batch size is equal across all layers - infer_symbolic_types(gm_static) - - - batch_sizes = set() - gm_static.graph.eliminate_dead_code() - for n in gm_static.graph.nodes: - assert isinstance(n.type, TensorType) - batch_sizes.add(n.type.__args__[0]) - assert (len(batch_sizes) == 1) - - @skipIfNoSympy - def test_type_check_batch_norm_symbolic(self): - class BasicBlock(torch.nn.Module): - - def __init__(self, inplanes, planes): - super(BasicBlock, self).__init__() - norm_layer = torch.nn.BatchNorm2d - self.bn1 = norm_layer(planes) - - def forward(self, x: Dyn): - identity = x - out: TensorType((2, 2, Dyn, 4)) = self.bn1(x) - out += identity - return out - - B = BasicBlock(2, 2) - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(B) - traced = GraphModule(ast_rewriter.root, graph, "gm") - tc = GraphTypeChecker({}, traced) - tc.type_check() - - infer_symbolic_types(traced) - - my_types = iter([TensorType[(2, 2, sympy.symbols('~7'), 4)], - TensorType[(2, 2, sympy.symbols('~7'), 4)], - TensorType[(2, 2, sympy.symbols('~7'), 4)], - TensorType[(2, 2, sympy.symbols('~7'), 4)]]) - - for n in graph.nodes: - assert n.type == next(my_types) - - @skipIfNoSympy - def test_symbolic_add_with_broadcast(self): - class M(torch.nn.Module): - def forward(self, x: TensorType((1, 2, 3, Dyn)), y: TensorType((2, 3, 4))): - return torch.add(x, y) - module = M() - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(module) - tc = GraphTypeChecker({}, symbolic_traced) - tc.type_check() - infer_symbolic_types(symbolic_traced) - r = Refine(symbolic_traced) - r.refine() - - assert r.constraints == [Equality(1, 1), Equality(2, 2), Equality(3, 3)] - # note that there is no equality constraint between dyn and 4 because - # dyn could be 4 or 1 - - infer_symbolic_types(symbolic_traced) - - expected_ph_types = [TensorType((1, 2, 3, sympy.symbols('~0'))), - TensorType((2, 3, 4)), - TensorType((1, 2, 3, sympy.symbols('~1'))), - TensorType((1, 2, 3, sympy.symbols('~1')))] - expected_iter = iter(expected_ph_types) - - - for n in symbolic_traced.graph.nodes: - assert n.type == next(expected_iter) - - @skipIfNoSympy - def test_symbolic_add_with_broadcast_2(self): - class M(torch.nn.Module): - def forward(self, x: TensorType((1, 2)), y: TensorType((Dyn, 2))): - return torch.add(x, y) - module = M() - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(module) - tc = GraphTypeChecker({}, symbolic_traced) - tc.type_check() - infer_symbolic_types(symbolic_traced) - r = Refine(symbolic_traced) - r.refine() - - expected_ph_types = [TensorType((1, 2)), - TensorType((sympy.symbols('~1'), 2)), - TensorType((sympy.symbols('~1'), 2)), - TensorType((sympy.symbols('~1'), 2))] - expected_iter = iter(expected_ph_types) - - for n in symbolic_traced.graph.nodes: - assert n.type == next(expected_iter) - - @skipIfNoSympy - def test_type_check_conv2D_types(self): - class BasicBlock(torch.nn.Module): - def __init__(self, inplanes, planes, stride=1): - super(BasicBlock, self).__init__() - norm_layer = torch.nn.BatchNorm2d - self.conv1 = conv3x3(inplanes, planes, stride) - self.bn1 = norm_layer(planes) - - def forward(self, x: Dyn): - identity = x - out: TensorType((2, 2, Dyn, 4)) = self.conv1(x) - out += identity - return out - - B = BasicBlock(2, 2) - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(B) - traced = GraphModule(ast_rewriter.root, graph, "gm") - tc = GraphTypeChecker({}, traced) - tc.type_check() - infer_symbolic_types(traced) - - for n in traced.graph.nodes: - if n.op == 'call_module': - assert isinstance(n.type.__args__[2], sympy.floor) - assert isinstance(n.type.__args__[3], sympy.floor) - - @skipIfNoSympy - def test_type_check_symbolic_inferenceconv2D_maxpool2d_flatten(self): - - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - self.conv1 = torch.nn.Conv2d(3, 6, 5) - self.pool = torch.nn.MaxPool2d(2, 2) - self.conv2 = torch.nn.Conv2d(6, 16, 5) - self.fc1 = torch.nn.Linear(5, 120) - self.pool2 = torch.nn.AdaptiveAvgPool2d((6, 7)) - - def forward(self, x : TensorType((4, 3, Dyn, Dyn))): - out = self.conv1(x) - out = self.pool(out) - out = self.conv2(out) - out = self.pool(out) - out = self.fc1(out) - out = self.pool2(out) - out = torch.flatten(out, 1) - return out - - B = BasicBlock() - ast_rewriter = RewritingTracer() - traced = symbolic_trace(B) - tc = GraphTypeChecker({}, traced) - tc.type_check() - infer_symbolic_types(traced) - - for n in traced.graph.nodes: - if n.target == 'conv1': - assert n.type == TensorType((4, 6, sympy.floor((sympy.symbols('~0') - 4)), - sympy.floor((sympy.symbols('~1') - 4)))) - - elif n.target == 'conv2': - assert n.type == TensorType((4, 16, sympy.floor((sympy.symbols('~4') - 4)), - sympy.floor((sympy.symbols('~5') - 4)))) - -if __name__ == '__main__': - unittest.main() diff --git a/test/fx/test_pass_infra.py b/test/fx/test_pass_infra.py deleted file mode 100644 index e41aaf0a6..000000000 --- a/test/fx/test_pass_infra.py +++ /dev/null @@ -1,176 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -# Owner(s): ["module: fx"] - -import torch -from torch.testing._internal.common_utils import TestCase - -import pippy -import pippy.fx as fx -from pippy.fx.passes.infra.pass_base import PassResult -from pippy.fx.passes.infra.pass_manager import ( - PassManager, - this_before_that_pass_constraint, - _topological_sort_passes, -) - - -def replace_add_with_mul_pass(gm): - modified = False - for node in gm.graph.nodes: - if node.op == "call_function" and node.target == torch.add: - node.target = torch.mul - modified = True - return PassResult(gm, modified) - -def replace_mul_with_div_pass(gm): - modified = False - for node in gm.graph.nodes: - if node.op == "call_function" and node.target == torch.mul: - node.target = torch.div - modified = True - return PassResult(gm, modified) - -class AddModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - y = torch.add(x, x) - z = torch.add(y, x) - return z - - -class TestPassManager(TestCase): - def test_pass_manager(self): - """ - Tests that the pass manager runs the passes correctly. - """ - - m = AddModule() - traced_m = pippy.fx.symbolic_trace(m) - pm = PassManager(passes=[replace_add_with_mul_pass, replace_mul_with_div_pass], steps=5) - - pm.validate_constraints() - self.assertEqual(len(pm.passes), 2) - - res = pm(traced_m) - modified_m = res.graph_module - assert isinstance(modified_m, fx.GraphModule) - - # Check that all call_function nodes are divs - for node in modified_m.graph.nodes: - if node.op == "call_function": - self.assertEqual(node.target, torch.div) - - def test_this_before_that_pass_constraint(self): - """ - Tests the construction of constraints - """ - passes = [lambda x: 2 * x for _ in range(10)] - pm = PassManager(passes) - - # add unfulfillable constraint - pm.add_constraint(this_before_that_pass_constraint(passes[-1], passes[0])) - - with self.assertRaises(RuntimeError): - pm.validate_constraints() - - - def test_pass_manager_checks(self): - """ - Tests that users can add in check functions correctly - """ - m = AddModule() - traced_m = fx.symbolic_trace(m) - pm = PassManager(passes=[replace_add_with_mul_pass, replace_mul_with_div_pass]) - - def check_div_target(graph_module): - for node in graph_module.graph.nodes: - if node.op == "call_function" and node.target != torch.div: - raise ValueError("Target should be div!") - pm.add_checks(check_div_target) - - with self.assertRaises(ValueError): - pm(traced_m) - - def test_pass_manager_bad_checks(self): - """ - Checks that we error if we pass in a check function with the wrong parameters - """ - def check_bad_args(graph_module, i): - pass - - pm = PassManager() - self.assertRaises(TypeError, pm.add_checks, check_bad_args) - - def test_topological_sort(self): - """ - Tests that passes are correctly ordered based on contraints. - """ - - def pass0(x): - return x - - def pass1(x): - return x + 1 - - def pass2(x): - return x + 2 - - def pass3(x): - return x + 3 - - def pass4(x): - return x + 4 - - def pass5(x): - return x + 5 - - # Not passing any constraints should keep the original order - passes = [pass0, pass1, pass2, pass3, pass4, pass5] - sorted = _topological_sort_passes(passes, []) - self.assertEqual(sorted, passes) - - # Graph that we are constructing: - # 5 ----> 0 <---- 4 - # | | - # +-> 2 -> 3 -> 1 <-+ - # Which has a possible topological order of: [4, 5, 0, 2, 3, 1] - passes = [pass0, pass1, pass2, pass3, pass4, pass5] - constraints = [ - this_before_that_pass_constraint(pass5, pass0), - this_before_that_pass_constraint(pass5, pass2), - this_before_that_pass_constraint(pass4, pass0), - this_before_that_pass_constraint(pass4, pass1), - this_before_that_pass_constraint(pass2, pass3), - this_before_that_pass_constraint(pass3, pass1), - ] - sorted = _topological_sort_passes(passes, constraints) - self.assertEqual(sorted, [pass4, pass5, pass0, pass2, pass3, pass1]) - - # Circular dependency should result in the circular_dep flag being set - passes = [pass0, pass1, pass2] - constraints = [ - this_before_that_pass_constraint(passes[0], passes[1]), - this_before_that_pass_constraint(passes[1], passes[2]), - this_before_that_pass_constraint(passes[2], passes[0]), - ] - with self.assertRaises(RuntimeError) as e: - _topological_sort_passes(passes, constraints) - expected_error_msg = f"Circular dependency detected within the following passes: {passes}" - self.assertEqual(e.exception.args[0], expected_error_msg) - - def test_pass_manager_error(self): - """ - Tests error catching + debug - """ - def pass_fail(graph_module): - raise RuntimeError("bad") - - m = AddModule() - traced_m = pippy.fx.symbolic_trace(m) - pm = PassManager(passes=[replace_add_with_mul_pass, replace_mul_with_div_pass, pass_fail]) - - # Comment out this line to see the actual error message - with self.assertRaises(RuntimeError): - pm(traced_m) diff --git a/test/fx/test_subgraph_rewriter.py b/test/fx/test_subgraph_rewriter.py deleted file mode 100644 index d8a1bc77d..000000000 --- a/test/fx/test_subgraph_rewriter.py +++ /dev/null @@ -1,777 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -# Owner(s): ["module: fx"] - -import os -import sys - -import torch - -import pippy -from pippy.fx import symbolic_trace, subgraph_rewriter -from pippy.fx.annotate import annotate -# Make the helper files in test/ importable -from pippy.fx.experimental.rewriter import RewritingTracer - -pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) -sys.path.append(pytorch_test_dir) -from torch.testing._internal.jit_utils import JitTestCase - -if __name__ == '__main__': - raise RuntimeError("This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_fx.py TESTNAME\n\n" - "instead.") - -@pippy.fx.wrap -def wrapped_gemm_bias_mul(a, b, bias): - lin_res = torch.nn.functional.linear(a, b, bias=bias) - mul_res = lin_res * a - return lin_res, mul_res - -@pippy.fx.wrap -def wrapped_gemm_bias_mul_with_c(a, b, bias, c): - lin_res = torch.nn.functional.linear(a, b, bias=bias) - mul_res = lin_res * c - return lin_res, mul_res - -class TestSubgraphRewriter(JitTestCase): - - def test_subgraph_rewriter_preserves_logic(self): - class M(torch.nn.Module): - def forward(self, x): - val = torch.neg(x) + torch.relu(x) - return torch.add(val, val) - - def pattern(x): - return torch.neg(x) + torch.relu(x) - - def comparison(x): - val = torch.neg(x) + torch.relu(x) - return torch.add(val, val) - - traced = symbolic_trace(M()) - comparison_fn = symbolic_trace(comparison) - - x = torch.rand(1, 3) - - # Replace `pattern` with the same pattern (shouldn't change - # the underlying logic) - subgraph_rewriter.replace_pattern(traced, pattern, pattern) - - traced.graph.lint() - - ref_output = comparison_fn(x) - test_output = traced.forward(x) - self.assertEqual(ref_output, test_output) - - def test_subgraph_rewriter_with_oneliner_pattern(self): - class M(torch.nn.Module): - def forward(self, x): - val = torch.neg(x) - return torch.add(val, val) - - def pattern(x): - return torch.neg(x) - - def replacement(x): - return torch.relu(x) - - def comparison(x): - val = torch.relu(x) - return torch.add(val, val) - - traced = symbolic_trace(M()) - comparison_fn = symbolic_trace(comparison) - - x = torch.rand(1, 3) - - subgraph_rewriter.replace_pattern(traced, pattern, replacement) - - traced.graph.lint() - - ref_output = comparison_fn(x) - test_output = traced.forward(x) - self.assertEqual(ref_output, test_output) - - def test_subgraph_rewriter_single_pattern_match(self): - class M(torch.nn.Module): - def forward(self, x): - val = torch.neg(x) + torch.relu(x) - return torch.add(val, val) - - def pattern(x): - return torch.neg(x) + torch.relu(x) - - def replacement(x): - return torch.relu(x) - - def comparison(x): - val = torch.relu(x) - return torch.add(val, val) - - traced = symbolic_trace(M()) - comparison_fn = symbolic_trace(comparison) - - x = torch.rand(1, 3) - - subgraph_rewriter.replace_pattern(traced, pattern, replacement) - - traced.graph.lint() - - ref_output = comparison_fn(x) - test_output = traced.forward(x) - self.assertEqual(ref_output, test_output) - - def test_subgraph_rewriter_multiple_pattern_match(self): - class M(torch.nn.Module): - def forward(self, x, w1, w2): - m1 = torch.cat([w1, w2]).sum() - m2 = torch.cat([w1, w2]).sum() - return x + torch.max(m1) + torch.max(m2) - - def pattern(w1, w2): - return torch.cat([w1, w2]).sum() - - def replacement(w1, w2): - return torch.stack([w1, w2]) - - def comparison(x, w1, w2): - m1 = torch.stack([w1, w2]) - m2 = torch.stack([w1, w2]) - return x + torch.max(m1) + torch.max(m2) - - traced = symbolic_trace(M()) - comparison_fn = symbolic_trace(comparison) - - x = torch.rand(1, 3) - w1 = torch.rand(1, 3) - w2 = torch.rand(1, 3) - - subgraph_rewriter.replace_pattern(traced, pattern, replacement) - - traced.graph.lint() - - ref_outs = comparison_fn(x, w1, w2) - test_outs = traced.forward(x, w1, w2) - self.assertEqual(ref_outs, test_outs) - - def test_subgraph_rewriter_graph_argument_order(self): - class M(torch.nn.Module): - def forward(self, x, y): - return torch.mm(x, y) - - def pattern(x, y): - return torch.mm(x, y) - - def comparison(x, y): - return torch.mm(x, y) - - traced = symbolic_trace(M()) - comparison_fn = symbolic_trace(comparison) - - x = torch.randn(3, 4) - y = torch.randn(4, 5) - - subgraph_rewriter.replace_pattern(traced, pattern, pattern) - - traced.graph.lint() - - ref_outs = comparison_fn(x, y) - test_outs = traced.forward(x, y) - self.assertEqual(ref_outs, test_outs) - - def test_subgraph_rewriter_correct_output_replacement(self): - class M(torch.nn.Module): - def forward(self, x, y): - val = torch.neg(y) + torch.relu(x) - return torch.add(val, val) - - def pattern(x): - return torch.relu(x) - - def replacement(x): - return torch.neg(x) - - def comparison(x, y): - val = torch.neg(y) + torch.neg(x) - return torch.add(val, val) - - traced = symbolic_trace(M()) - comparison_fn = symbolic_trace(comparison) - - x = torch.randn(4, 4) - y = torch.randn(4, 4) - - subgraph_rewriter.replace_pattern(traced, pattern, replacement) - - traced.graph.lint() - - ref_outs = comparison_fn(x, y) - test_outs = traced.forward(x, y) - self.assertEqual(ref_outs, test_outs) - - def test_subgraph_rewriter_traced_as_callable(self): - class M(torch.nn.Module): - def forward(self, x): - val = torch.neg(x) + torch.relu(x) - return torch.add(val, val) - - class Pattern(torch.nn.Module): - def forward(self, x): - return torch.neg(x) + torch.relu(x) - - class Replacement(torch.nn.Module): - def forward(self, x): - return torch.sigmoid(x) - - def comparison(x): - val = torch.sigmoid(x) - return torch.add(val, val) - - traced = symbolic_trace(M()) - traced_pattern = symbolic_trace(Pattern()) - traced_replacement = symbolic_trace(Replacement()) - comparison_fn = symbolic_trace(comparison) - - x = torch.randn(3, 4) - - subgraph_rewriter.replace_pattern(traced, traced_pattern, traced_replacement) - - traced.graph.lint() - - ref_outs = comparison_fn(x) - test_outs = traced.forward(x) - self.assertEqual(ref_outs, test_outs) - - def test_subgraph_rewriter_pattern_is_entire_graph(self): - class M(torch.nn.Module): - def forward(self, x): - a = torch.neg(x) - return torch.add(a, a) - - def pattern(x): - a = torch.neg(x) - return torch.add(a, a) - - def replacement(x): - a = torch.sigmoid(x) - return torch.cat([a, a]) - - traced = symbolic_trace(M()) - comparison_fn = symbolic_trace(replacement) - - x = torch.randn(3, 4) - - subgraph_rewriter.replace_pattern(traced, pattern, replacement) - - traced.graph.lint() - - ref_outs = comparison_fn(x) - test_outs = traced.forward(x) - self.assertEqual(ref_outs, test_outs) - - def test_subgraph_rewriter_pattern_output_pattern_node_can_have_users_that_are_not_matched(self): - class M(torch.nn.Module): - def forward(self, x): - y = torch.relu(x) - return torch.neg(y) - y - - def pattern(x): - return torch.relu(x) - - def replacement(x): - return torch.sigmoid(x) - - def comparison(x): - y = torch.sigmoid(x) - return torch.neg(y) - y - - traced = symbolic_trace(M()) - comparison_fn = symbolic_trace(comparison) - - x = torch.randn(3, 4) - - subgraph_rewriter.replace_pattern(traced, pattern, replacement) - - traced.graph.lint() - - ref_outs = comparison_fn(x) - test_outs = traced.forward(x) - self.assertEqual(ref_outs, test_outs) - - def test_subgraph_rewriter_internal_pattern_nodes_cannot_have_users_that_are_not_matched(self): - class M(torch.nn.Module): - def forward(self, x, w1, w2, b1, b2): - m0 = torch.cat([w1, w2]) - m1 = torch.cat([w1, w2]) - m2 = torch.cat([x, b2]) - t0 = torch.addmm(b1, m1, m2.t()) - t1 = torch.sum(w1, 1) - t2 = torch.addmm(b1, m1, m2.t()) - return torch.sum(t1), torch.sum(t2) - - def pattern(x, w1, w2, b1, b2): - m1 = torch.cat([w1, w2]) - m2 = torch.cat([x, b2]) - return torch.addmm(b1, m1, m2.t()) - - def replacement(x, w1, w2, b1, b2): - return torch.cat([x, w1, w2]) - - traced = symbolic_trace(M()) - - # Result should be [] since no matches can be found - res = subgraph_rewriter.replace_pattern(traced, pattern, replacement) - - traced.graph.lint() - - self.assertEqual(res, []) - - def test_subgraph_rewriter_placeholder_matching(self): - """ - This tests that a placeholder Node can be matched to a Node with - a different number of input Nodes. In the example below, the - original traced Module looks like this: - - opcode target args kwargs - ------------- ---------------------------------------------------------- ------------------------ -------- - placeholder x () {} - call_function (x, 3) {} - call_method dequantize (add,) {} - call_function (dequantize,) {} - call_method to (sigmoid, torch.float16) {} - output output (to,) {} - - while the pattern we want to match looks like this: - - opcode target args kwargs - ------------- ---------------------------------------------------------- ------------------------ -------- - placeholder x () {} - call_method dequantize (x,) {} - call_function (dequantize,) {} - call_method to (sigmoid, torch.float16) {} - output output (to,) {} - - Here, we want to be able to match the original graph's - `call_function.add` Node with the pattern graph's - `plaeholder.x` Node. - - Credit to Jerry Zhang (GitHub: jerryzh168) for this test case - """ - class M(torch.nn.Module): - def __init__(self): - super().__init__() - self.dtype = torch.float16 - - def forward(self, x): - x += 3 - x = x.dequantize() - x = torch.sigmoid(x) - dtype = self.dtype - x = x.to(dtype) - return x - - def pattern(x): - x = x.dequantize() - x = torch.sigmoid(x) - x = x.to(torch.float16) - return x - - def replacement(x): - return x - - def comparison(x): - return x + 3 - - traced = symbolic_trace(M()) - comparison_fn = symbolic_trace(comparison) - - x = torch.randn(3, 4) - - subgraph_rewriter.replace_pattern(traced, pattern, replacement) - - traced.graph.lint() - - ref_outs = comparison_fn(x) - test_outs = traced.forward(x) - self.assertEqual(ref_outs, test_outs) - - def test_subgraph_rewriter_replaces_referenced_submodules(self): - class M(torch.nn.Module): - def __init__(self): - super().__init__() - self.sigmoid = torch.nn.Sigmoid() - self.submod = torch.nn.ReLU() - - def forward(self, x): - x = x + 1 - return self.submod(self.sigmoid(x)) - - class Pattern(torch.nn.Module): - def __init__(self): - super().__init__() - self.sigmoid = torch.nn.Sigmoid() - self.submod = torch.nn.ReLU() - - def forward(self, x): - return self.submod(self.sigmoid(x)) - - class Replacement(torch.nn.Module): - def __init__(self): - super().__init__() - self.tanh = torch.nn.Tanh() - self.submod = torch.nn.ReLU() - - def forward(self, x): - return self.submod(self.tanh(x)) - - class Comparison(torch.nn.Module): - def __init__(self): - super().__init__() - self.tanh = torch.nn.Tanh() - self.submod = torch.nn.ReLU() - - def forward(self, x): - x = x + 1 - return self.submod(self.tanh(x)) - - traced = symbolic_trace(M()) - comparison = Comparison() - - x = torch.randn(3, 4) - - subgraph_rewriter.replace_pattern(traced, Pattern(), Replacement()) - - traced.graph.lint() - - ref_outs = comparison(x) - test_outs = traced.forward(x) - self.assertEqual(ref_outs, test_outs) - - traced.get_submodule("tanh") - with self.assertRaisesRegex(AttributeError, "has no attribute"): - traced.get_submodule("sigmoid") - - submod = traced.get_submodule("submod") - self.assertEqual(type(submod), torch.nn.ReLU) - - def test_subgraph_rewriter_annotations_int(self): - - class M1(torch.nn.Module): - def forward(self, x): - y: int = x - return torch.add(x, y) - - class M2(torch.nn.Module): - def forward(self, x): - y = annotate(x, int) - return torch.add(x, y) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(M1()) - - module = M2() - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(module) - for n, m in zip(symbolic_traced.graph.nodes, graph.nodes): - if n.op == 'placeholder': - assert n.type == int - assert m.type == int - - def test_subgraph_rewriter_replace_consecutive_submodules(self): - - def f(x): - x = torch.sigmoid(x) - x = torch.sigmoid(x) - return torch.sigmoid(x) - - def pattern(x): - return torch.sigmoid(x) - - def replacement(x): - return torch.exp(x) - - def comparison(x): - x = torch.exp(x) - x = torch.exp(x) - return torch.exp(x) - - traced = symbolic_trace(f) - comparison_fn = symbolic_trace(comparison) - - x = torch.randn(3, 4) - - subgraph_rewriter.replace_pattern(traced, pattern, replacement) - - traced.graph.lint() - - ref_outs = comparison_fn(x) - test_outs = traced.forward(x) - self.assertEqual(ref_outs, test_outs) - - def test_subgraph_rewriter_with_overlapping_matches(self): - - def f(x): - x = torch.sigmoid(x) - x = torch.sigmoid(x) - x = torch.sigmoid(x) - return torch.sigmoid(x) - - def pattern(x): - x = torch.sigmoid(x) - x = torch.sigmoid(x) - return x - - def replacement(x): - return torch.neg(x) - - def comparison(x): - x = torch.neg(x) - return torch.neg(x) - - traced = symbolic_trace(f) - comparison_fn = symbolic_trace(comparison) - - x = torch.randn(3, 4) - - subgraph_rewriter.replace_pattern(traced, pattern, replacement) - - traced.graph.lint() - - ref_outs = comparison_fn(x) - test_outs = traced.forward(x) - self.assertEqual(ref_outs, test_outs) - - def test_subgraph_rewriter_replace_with_multiple_outputs(self): - - def f(x): - y = torch.sigmoid(x) - z = torch.relu(x) - return y + z - - def pattern(a): - b = torch.sigmoid(a) - c = torch.relu(a) - return b, c - - def replacement(x): - return torch.exp(x), torch.abs(x) - - def comparison(x): - y = torch.exp(x) - z = torch.abs(x) - return y + z - - traced = symbolic_trace(f) - comparison_fn = symbolic_trace(comparison) - - x = torch.randn(3, 4) - - subgraph_rewriter.replace_pattern(traced, pattern, replacement) - - traced.graph.lint() - - ref_outs = comparison_fn(x) - test_outs = traced.forward(x) - self.assertEqual(ref_outs, test_outs) - - def test_subgraph_rewriter_replace_with_duplicated_outputs(self): - - def f(x1, x2): - x = x1 - x2 - y = torch.sigmoid(x) - z = torch.relu(x) - return y + z - - def pattern(a1, a2): - a = a1 - a2 - b = torch.sigmoid(a) - c = torch.relu(a) - return b, c, a - - def replacement(x1, x2): - y1 = torch.exp(x1) - y2 = torch.abs(x2) - return y2, y2, y1 - - def comparison(x1, x2): - y2 = torch.abs(x2) - return y2 + y2 - - traced = symbolic_trace(f) - comparison_fn = symbolic_trace(comparison) - - x1 = torch.randn(3, 4) - x2 = torch.randn(3, 4) - - subgraph_rewriter.replace_pattern(traced, pattern, replacement) - - traced.graph.lint() - - ref_outs = comparison_fn(x1, x2) - test_outs = traced.forward(x1, x2) - self.assertEqual(ref_outs, test_outs) - - def test_subgraph_rewriter_with_unused_args(self): - class M(torch.nn.Module): - def forward(self, x, y, z): - return x + y - - def pattern(x, y): - return x + y - - def replacement(x, y): - return x - y - - def comparison(x1, x2, x3): - return x1 - x2 - - traced = symbolic_trace(M()) - comparison_fn = symbolic_trace(comparison) - - x1 = torch.randn(3, 4) - x2 = torch.randn(3, 4) - x3 = torch.randn(3, 4) - - subgraph_rewriter.replace_pattern(traced, pattern, replacement) - - traced.graph.lint() - placeholder_nodes = [n for n in traced.graph.nodes if n.op == "placeholder"] - assert len(placeholder_nodes) == 3 - - ref_outs = comparison_fn(x1, x2, x3) - test_outs = traced.forward(x1, x2, x3) - self.assertEqual(ref_outs, test_outs) - - def test_subgraph_rewriter_call_method(self): - - class M(torch.nn.Module): - def forward(self, x): - x = x.dequantize() - x = x.sigmoid() - x = x.to(torch.float16) - return x - - def pattern(x): - x = x.dequantize() - x = x.sigmoid() - x = x.to(torch.float16) - return x - - def replacement(x): - return x - - traced = symbolic_trace(M()) - comparison_fn = symbolic_trace(replacement) - - x1 = torch.randn(3, 4) - - subgraph_rewriter.replace_pattern(traced, pattern, replacement) - - traced.graph.lint() - - ref_outs = comparison_fn(x1) - test_outs = traced.forward(x1) - self.assertEqual(ref_outs, test_outs) - - def test_subgraph_rewriter_nodes_with_kwargs(self): - - class M(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - self.w0 = torch.nn.Parameter(torch.empty([128, 128])) - self.b0 = torch.nn.Parameter(torch.empty([128])) - - def forward(self, in0): - lin_res = torch.nn.functional.linear(in0, self.w0, bias=self.b0) - mul_res = in0 * lin_res - sum_res = mul_res + in0 - return sum_res - - def pattern(a, b, bias): - lin_res = torch.nn.functional.linear(a, b, bias=bias) - mul_res = a * lin_res - return lin_res, mul_res - - def replacement(a, b, bias): - lin_res, mul_res = wrapped_gemm_bias_mul(a, b, bias) - return lin_res, mul_res - - traced = symbolic_trace(M()) - matches = subgraph_rewriter.replace_pattern(traced, pattern, replacement) - - self.assertEqual(len(matches), 1) - - found_repalcement_node = False - for node in traced.graph.nodes: - if node.target == wrapped_gemm_bias_mul: - found_repalcement_node = True - break - - self.assertTrue(found_repalcement_node) - - def test_subgraph_rewriter_local_revert(self): - - # Following model will have 3 anchors as the matching candidate with the given pattern - # Anchor 1 and 3 is a real match, but anchor 2 is not. - # The subgraph rewriter should be able to revert the changes made while matching anchor 2. - # Final match with anchor 3 should be successful. - - class M(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - self.w0 = torch.nn.Parameter(torch.empty([128, 128])) - self.b0 = torch.nn.Parameter(torch.empty([128])) - self.w1 = torch.nn.Parameter(torch.empty([128, 128])) - self.b1 = torch.nn.Parameter(torch.empty([128])) - self.w2 = torch.nn.Parameter(torch.empty([128, 128])) - self.b2 = torch.nn.Parameter(torch.empty([128])) - self.w3 = torch.nn.Parameter(torch.empty([128, 128])) - self.b3 = torch.nn.Parameter(torch.empty([128])) - self.w4 = torch.nn.Parameter(torch.empty([128, 128])) - self.b4 = torch.nn.Parameter(torch.empty([128])) - - def forward(self, in0, in1): - lin_res_1 = torch.nn.functional.linear(in1, self.w0, bias=self.b0) - lin_res_2 = torch.nn.functional.linear(lin_res_1, self.w1, bias=self.b1) - # potential match at anchor 1 - mul_res_1 = in1 * lin_res_2 - sum_res_1 = mul_res_1 + in1 - lin_res_3 = torch.nn.functional.linear( - sum_res_1, self.w2, bias=self.b2 - ) - sigmoid_res_1 = torch.sigmoid(lin_res_3) - # potential match at anchor 2 - mul_res_2 = lin_res_3 * sigmoid_res_1 - lin_res_4 = torch.nn.functional.linear(in0, self.w3, bias=self.b3) - lin_res_5 = torch.nn.functional.linear(lin_res_4, self.w4, bias=self.b4) - # potential match at anchor 3 - mul_res_3 = in0 * lin_res_5 - sum_res_2 = mul_res_3 + in0 - cat_res = torch.cat( - [mul_res_2, sum_res_2], - dim=1, - ) - return cat_res - - def gemm_bias_mul_pattern_with_c(a, b, bias, c): - lin_res = torch.nn.functional.linear(a, b, bias=bias) - mul_res = c * lin_res - return lin_res, mul_res - - def gemm_bias_mul_replacement_with_c(a, b, bias, c): - lin_res, mul_res = wrapped_gemm_bias_mul_with_c(a, b, bias, c) - return lin_res, mul_res - - traced = symbolic_trace(M()) - matches = subgraph_rewriter.replace_pattern( - traced, - gemm_bias_mul_pattern_with_c, - gemm_bias_mul_replacement_with_c) - - self.assertEqual(len(matches), 2) - - repalcement_node_found = 0 - for node in traced.graph.nodes: - if node.target == wrapped_gemm_bias_mul_with_c: - repalcement_node_found += 1 - - self.assertEqual(repalcement_node_found, 2) diff --git a/test/fx/test_z3_gradual_types.py b/test/fx/test_z3_gradual_types.py deleted file mode 100644 index 567a31dfe..000000000 --- a/test/fx/test_z3_gradual_types.py +++ /dev/null @@ -1,2481 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -# Owner(s): ["module: fx"] -import operator -import unittest -from pippy.fx import GraphModule, symbolic_trace -from pippy.fx.experimental.meta_tracer import symbolic_trace as meta_symbolic_trace -from pippy.fx.experimental.migrate_gradual_types.constraint import BinConstraintT, DVar, TVar, T -from pippy.fx.experimental.migrate_gradual_types.constraint_generator import ConstraintGenerator -from pippy.fx.experimental.migrate_gradual_types.constraint_transformation import transform_constraint -from pippy.fx.experimental.migrate_gradual_types.operation import op_precision, op_matching, op_consistency -from pippy.fx.experimental.migrate_gradual_types.transform_to_z3 import transform_all_constraints,\ - evaluate_conditional_with_constraints -from pippy.fx.experimental.migrate_gradual_types.z3_types import tensor_type, D, z3_dyn -from pippy.fx.experimental.rewriter import RewritingTracer -from pippy.fx.tensor_type import Dyn, TensorType -import torch - - -try: - import z3 # type: ignore[import] - HAS_Z3 = True -except ImportError: - HAS_Z3 = False - - -try: - from torchvision import models - HAS_TORCHVISION = True -except ImportError: - HAS_TORCHVISION = False -skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision") - -class TorchDynamoUseCases(unittest.TestCase): - - def test_dim(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x: TensorType([1, 2])): - y = x.dim() - return y - - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(BasicBlock()) - transformed = transform_all_constraints(symbolic_traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEqual(s.check(), z3.sat) - y_res = z3.z3.Int(2) - self.assertEqual(s.model()[y_res], 2) - - - def test_reshape(self): - """ - In this example, we prove that some nodes must - always have a fixed shape regardless of the input - """ - - class BasicBlock(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x: Dyn): - y = x.view(100) - tmp = y.size()[0] - return tmp - - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(BasicBlock()) - transformed = transform_all_constraints(symbolic_traced, counter=0) - - s = z3.Solver() - s.add(transformed) - self.assertEqual(s.check(), z3.sat) - dim = z3.Int(4) - self.assertEqual(s.model()[dim], 100) - # print(s.model()[dim]) - - -class HFOperations(unittest.TestCase): - - def test_eq_dim(self): - """ - test dimensions and equalities - """ - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([32, 4, 4])): - eq = x.dim() == 3 - return eq - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - - # The node we are considering is the gt node - for n in graph.nodes: - if n.target == operator.eq: - node = n - - positive, negative = evaluate_conditional_with_constraints(ast_rewriter.root, graph, node) - self.assertEqual(positive, z3.sat) - self.assertEqual(negative, z3.unsat) - - def test_conditional_ne_1(self): - """ - This test case is for the HFmodels interface. - A function takes a node and a graph and considers - the conditional the node represents and its negation - and solves each formula with the remaining sets of constraints - Returns: - - """ - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([32, 4, 4]), y: TensorType([32, 4, 4])): - size_5 = x.size() - getitem_7 = size_5[0] - getitem_8 = size_5[1] - getitem_9 = size_5[2] - ne_1 = y != (getitem_7, getitem_8, getitem_9) - return ne_1 - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - - # The node we are considering is the gt node - for n in graph.nodes: - if n.target == operator.ne: - node = n - - # since x and y are equal, the requirement that x != y cannot be true, so we should get unsat - # for the positive condition and sat for the negative condition - positive, negative = evaluate_conditional_with_constraints(ast_rewriter.root, graph, node) - self.assertEqual(positive, z3.unsat) - self.assertEqual(negative, z3.sat) - - def test_bmm(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([Dyn, 2, 3]), y: TensorType([1, 3, 2])): - bmm = torch.bmm(x, y) - return bmm - - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(BasicBlock()) - b = BasicBlock().forward(torch.rand(1, 2, 3), torch.rand(1, 3, 2)) - transformed = transform_all_constraints(symbolic_traced, counter=0) - - s = z3.Solver() - s.add(transformed) - - output = z3.Const(3, tensor_type) - self.assertEqual(s.check(), z3.sat) - self.assertEqual(s.model()[output].arg(0).arg(1), b.shape[0]) - self.assertEqual(s.model()[output].arg(1).arg(1), b.shape[1]) - self.assertEqual(s.model()[output].arg(2).arg(1), b.shape[2]) - - - def test_bmm2(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: Dyn, y: TensorType([1, 3, 2])): - bmm = torch.bmm(x, y) - return bmm - - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(BasicBlock()) - b = BasicBlock().forward(torch.rand(1, 2, 3), torch.rand(1, 3, 2)) - transformed = transform_all_constraints(symbolic_traced, counter=0) - - s = z3.Solver() - s.add(transformed) - - output = z3.Const(3, tensor_type) - self.assertEqual(s.check(), z3.sat) - self.assertEqual(s.model()[output].arg(0).arg(1), b.shape[0]) - self.assertEqual(s.model()[output].arg(1).arg(0), 0) - self.assertEqual(s.model()[output].arg(2).arg(1), b.shape[2]) - - def test_bmm3(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([2, 3, 3]), y: TensorType([1, 3, 2])): - bmm = torch.bmm(x, y) - return bmm - - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(BasicBlock()) - transformed = transform_all_constraints(symbolic_traced, counter=0) - - s = z3.Solver() - s.add(transformed) - self.assertEqual(s.check(), z3.unsat) - - - def test_transpose(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([1, 2, 3, 4])): - transpose = x.transpose(0, 1) - return transpose - - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(BasicBlock()) - b = BasicBlock().forward(torch.rand(1, 2, 3, 4)) - - transformed = transform_all_constraints(symbolic_traced, counter=0) - - s = z3.Solver() - s.add(transformed) - output = z3.Const(2, tensor_type) - self.assertEqual(s.check(), z3.sat) - self.assertEqual(s.model()[output].arg(0).arg(1), b.shape[0]) - self.assertEqual(s.model()[output].arg(1).arg(1), b.shape[1]) - self.assertEqual(s.model()[output].arg(2).arg(1), b.shape[2]) - self.assertEqual(s.model()[output].arg(3).arg(1), b.shape[3]) - - # change the annotation to Dyn - for n in symbolic_traced.graph.nodes: - if n.op == 'placeholder': - n.type = Dyn - - transformed = transform_all_constraints(symbolic_traced, counter=0) - - s = z3.Solver() - s.add(transformed) - self.assertEqual(s.check(), z3.sat) - - - def test_index_select(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([2050, 1024]), y: Dyn): - index_select = x.index_select(0, y) - return index_select - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(BasicBlock()) - # print(symbolic_traced) - b = BasicBlock().forward(torch.rand(2050, 1024), torch.ones(8).int()) - transformed = transform_all_constraints(symbolic_traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEqual(s.check(), z3.sat) - index_select = z3.Const(3, tensor_type) - - # the second dimension of the result should not be affected since - # the index is 0 - self.assertEqual(s.model()[index_select].arg(1).arg(1), b.shape[1]) - - replacement_vector = z3.Const(2, tensor_type) - - # we set the vector to Dyn - s = z3.Solver() - s.add(transformed) - self.assertEqual(s.check(), z3.sat) - index_select = z3.Const(3, tensor_type) - s.add(replacement_vector == z3_dyn) - self.assertEqual(s.check(), z3.sat) - - # this implies that the index at 0 should be dyn - self.assertEqual(s.model()[index_select].arg(0).arg(0), 0) - - def test_get_attr(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([1, 2, 3])): - getattr = x.device - to = x.to(getattr) - return to - - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(BasicBlock()) - b = BasicBlock().forward(torch.rand(1, 2, 3)) - transformed = transform_all_constraints(symbolic_traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEqual(s.check(), z3.sat) - attr_res = z3.Const(3, tensor_type) - assert s.model()[attr_res].arg(0).arg(1) == b.shape[0] - assert s.model()[attr_res].arg(1).arg(1) == b.shape[1] - assert s.model()[attr_res].arg(2).arg(1) == b.shape[2] - - - def test_expand(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([1, 4])): - size = x.size() - getitem = size[-1] - expand = x.expand(getitem, 4) - return expand - - b = BasicBlock().forward(torch.rand(1, 4)) - - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(BasicBlock()) - transformed = transform_all_constraints(symbolic_traced, counter=0) - - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - expand_res = z3.Const(4, tensor_type) - assert s.model()[expand_res].arg(0).arg(1) == b.shape[0] - assert s.model()[expand_res].arg(1).arg(1) == b.shape[1] - - # change the annotation on the input to Dyn. - # the last dimension should still be 4 - for n in symbolic_traced.graph.nodes: - if n.op == 'placeholder': - n.type = Dyn - - transformed = transform_all_constraints(symbolic_traced, counter=0) - - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - - assert s.model()[expand_res].arg(1).arg(1) == b.shape[1] - - def test_getitem_tensor(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([4, 4])): - getitem = x[(None, None, slice(None, None, None), slice(None, None, None))] - return getitem - - B = BasicBlock() - b = B.forward(torch.rand(4, 4)) - - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(B) - transformed = transform_all_constraints(symbolic_traced, counter=0) - - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - get_item_res = z3.Const(2, tensor_type) - assert s.model()[get_item_res].arg(0).arg(1) == b.shape[0] - assert s.model()[get_item_res].arg(1).arg(1) == b.shape[1] - assert s.model()[get_item_res].arg(2).arg(1) == b.shape[2] - assert s.model()[get_item_res].arg(3).arg(1) == b.shape[3] - - # change the annotation on the input to make sure it propagates - # to the output - for n in symbolic_traced.graph.nodes: - if n.op == 'placeholder': - n.type = TensorType([Dyn, 4]) - - transformed = transform_all_constraints(symbolic_traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEqual(s.check(), z3.sat) - # dyn check - assert s.model()[get_item_res].arg(2).arg(0) == 0 - - - def test_getitem_tensor2(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([4, 4])): - getitem = x[(None, None)] - return getitem - - B = BasicBlock() - b = B.forward(torch.rand(4, 4)) - - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(B) - transformed = transform_all_constraints(symbolic_traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - get_item_res = z3.Const(2, tensor_type) - assert s.model()[get_item_res].arg(0).arg(1) == b.shape[0] - assert s.model()[get_item_res].arg(1).arg(1) == b.shape[1] - assert s.model()[get_item_res].arg(2).arg(1) == b.shape[2] - assert s.model()[get_item_res].arg(3).arg(1) == b.shape[3] - - - def test_getitem_tensor_3(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([4, 4])): - getitem = x[(None, slice(None, None, None), None, slice(None, None, None))] - return getitem - - B = BasicBlock() - b = B.forward(torch.rand(4, 4)) - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(B) - transformed = transform_all_constraints(symbolic_traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - get_item_res = z3.Const(2, tensor_type) - assert s.model()[get_item_res].arg(0).arg(1) == b.shape[0] - assert s.model()[get_item_res].arg(1).arg(1) == b.shape[1] - assert s.model()[get_item_res].arg(2).arg(1) == b.shape[2] - assert s.model()[get_item_res].arg(3).arg(1) == b.shape[3] - - - - def test_layer_norm(self): - - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - self.l = torch.nn.LayerNorm((1024,)) - - def forward(self, x: Dyn): - return self.l(x) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - transformed = transform_all_constraints(traced, counter=0) - - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - - # make the output a size 1 tensor which should result - # in the migration of the input - - b = BasicBlock().forward(torch.rand(1024)) - input = z3.Const(1, tensor_type) - output = z3.Const(2, tensor_type) - s.add(output == tensor_type.tensor1(D(1, 1024))) - s.check() - self.assertEqual(s.model()[input], s.model()[output]) - # input shape = output shape - self.assertEqual(b.shape[0], s.model()[input].arg(0).arg(1)) - - # change annotation to the wrong shape - for n in graph.nodes: - if n.op == 'placeholder': - n.type = TensorType([10, 10]) - - traced = GraphModule(ast_rewriter.root, graph, "gm") - transformed = transform_all_constraints(traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEqual(s.check(), z3.unsat) - - # fix the annotation - for n in graph.nodes: - if n.op == 'placeholder': - n.type = TensorType([10, 1024]) - - traced = GraphModule(ast_rewriter.root, graph, "gm") - transformed = transform_all_constraints(traced, counter=0) - s = z3.Solver() - s.add(transformed) - s.check() - b = BasicBlock().forward(torch.rand(10, 1024)).shape - self.assertEqual(s.model()[output].arg(0).arg(1), b[0]) - self.assertEqual(s.model()[output].arg(1).arg(1), b[1]) - - - def test_layer_norm_functional(self): - - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: Dyn): - return torch.nn.functional.layer_norm(x, (1024,)) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - transformed = transform_all_constraints(traced, counter=0) - - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - - # make the output a size 1 tensor which should result - # in the migration of the input - - b = BasicBlock().forward(torch.rand(1024)) - input = z3.Const(1, tensor_type) - output = z3.Const(2, tensor_type) - s.add(output == tensor_type.tensor1(D(1, 1024))) - s.check() - self.assertEqual(s.model()[input], s.model()[output]) - # input shape = output shape - self.assertEqual(b.shape[0], s.model()[input].arg(0).arg(1)) - - def test_ne_int_long_type_as(self): - - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([Dyn, Dyn]), y: TensorType([Dyn, Dyn])): - ne_int = torch.ne(x, y).int() - type_as = ne_int.type_as(y) - long = type_as.long() - return long - - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(BasicBlock()) - transformed = transform_all_constraints(symbolic_traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - - # migrate one of the parameters to a fully static shape so we can compare - - input = z3.Const(1, tensor_type) - input_2 = z3.Const(2, tensor_type) - s1, s2 = z3.Ints('s1 s2') - - output_long = z3.Const(8, tensor_type) - s.add(input == tensor_type.tensor2(D(1, 2), D(1, 4))) - s.add(input_2 == tensor_type.tensor2(D(1, s1), D(1, s2))) - - self.assertEquals(s.check(), z3.sat) - actual_shape = BasicBlock().forward(torch.rand(2, 4), torch.rand(2, 4)).shape - self.assertEqual(s.model()[output_long].arg(0).arg(1), actual_shape[0]) - self.assertEqual(s.model()[output_long].arg(1).arg(1), actual_shape[1]) - - - def test_ne(self): - s1, s2 = z3.Ints('s1 s2') - s11, s22 = z3.Ints('s11 s22') - d1, d2 = D(s11, s1), D(0, s2) - - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: Dyn, y: Dyn): - return torch.ne(x, y) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - transformed = transform_all_constraints(traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - - # change the annotations - for n in graph.nodes: - if n.name == 'x': - n.type = TensorType([1, 2]) - if n.name == 'y': - n.type = TensorType([2, Dyn]) - - # resulting type should be TensorType([2, 2]) - transformed = transform_all_constraints(traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - - # force the second dimension to be Dyn - # output should still be TensorType([2, 2]) - input = z3.Const(2, tensor_type) - s.add(input == tensor_type.tensor2(d1, d2)) - self.assertEqual(s.check(), z3.sat) - B = BasicBlock().forward(torch.rand(1, 2), torch.rand(2, 1)) - output = z3.Const(3, tensor_type) - self.assertEqual(s.model()[output].arg(0).arg(1), B.shape[0]) - self.assertEqual(s.model()[output].arg(1).arg(1), B.shape[0]) - - - def test_cumsum(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([Dyn, 4, 3])): - t = torch.cumsum(x, 3) - return t - - symbolic_traced: pippy.fx.GraphModule = meta_symbolic_trace(BasicBlock(), meta_args={}) - transformed = transform_all_constraints(symbolic_traced, counter=0) - s = z3.Solver() - s.add(transformed) - - # should be unsat since the index is not valid for this annotation - self.assertEqual(s.check(), z3.unsat) - - # modify the annotation to Dyn which should give sat - for n in symbolic_traced.graph.nodes: - if n.op == 'placeholder': - n.type = Dyn - - transformed = transform_all_constraints(symbolic_traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEqual(s.check(), z3.sat) - - # # modify the annotation to the right tensor size - for n in symbolic_traced.graph.nodes: - if n.op == 'placeholder': - n.type = TensorType([1, 2, 3, 4]) - - # verify that the input is equal to the output - B = BasicBlock().forward(torch.rand(1, 2, 3, 4)) - res_shape = B.shape - transformed = transform_all_constraints(symbolic_traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEqual(s.check(), z3.sat) - - # confirm the output matches the expected tensor - result = z3.Const(2, tensor_type) - self.assertEqual(s.model()[result].arg(0).arg(1), res_shape[0]) - self.assertEqual(s.model()[result].arg(1).arg(1), res_shape[1]) - self.assertEqual(s.model()[result].arg(2).arg(1), res_shape[2]) - self.assertEqual(s.model()[result].arg(3).arg(1), res_shape[3]) - - # confirm the output is not dyn - self.assertNotEqual(s.model()[result].arg(0).arg(0).as_long(), 0) - self.assertNotEqual(s.model()[result].arg(1).arg(0).as_long(), 0) - self.assertNotEqual(s.model()[result].arg(2).arg(0).as_long(), 0) - self.assertNotEqual(s.model()[result].arg(3).arg(0).as_long(), 0) - - - def test_cumsum_kwargs(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([Dyn, 4, 3])): - t = torch.cumsum(x, dim=3) - return t - - symbolic_traced: pippy.fx.GraphModule = meta_symbolic_trace(BasicBlock(), meta_args={}) - transformed = transform_all_constraints(symbolic_traced, counter=0) - s = z3.Solver() - s.add(transformed) - - # should be unsat since the index is not valid for this annotation - self.assertEqual(s.check(), z3.unsat) - - # modify the annotation to Dyn which should give sat - for n in symbolic_traced.graph.nodes: - if n.op == 'placeholder': - n.type = Dyn - - transformed = transform_all_constraints(symbolic_traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEqual(s.check(), z3.sat) - - - def test_arange(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([2, 4])): - size = x.size() - getitem = size[-1] - arange = torch.arange(getitem) - return arange - - B = BasicBlock().forward(torch.rand(2, 4)) - - symbolic_traced: pippy.fx.GraphModule = meta_symbolic_trace(BasicBlock(), meta_args={}) - transformed = transform_all_constraints(symbolic_traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEqual(s.check(), z3.sat) - arange_result = z3.Const(5, tensor_type) - self.assertNotEqual(s.model()[arange_result].arg(0).arg(0).as_long(), 0) - self.assertEqual(s.model()[arange_result].arg(0).arg(1).as_long(), B.size()[0]) - - # change the annotation to Dyn. This will migrate to an arbitirary type - for n in symbolic_traced.graph.nodes: - if n.op == 'placeholder': - n.type = Dyn - - transformed = transform_all_constraints(symbolic_traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEqual(s.check(), z3.sat) - - for n in symbolic_traced.graph.nodes: - if n.op == 'placeholder': - n.type = TensorType([Dyn, Dyn, Dyn, Dyn]) - - transformed = transform_all_constraints(symbolic_traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEqual(s.check(), z3.sat) - - def test_scalar_add(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([2, 4])): - size = x.size() - getitem = size[-1] - arange = torch.arange(getitem) - add = arange + 1 - return add - - symbolic_traced: pippy.fx.GraphModule = meta_symbolic_trace(BasicBlock(), meta_args={}) - transformed = transform_all_constraints(symbolic_traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEqual(s.check(), z3.sat) - - arange_result = z3.Const(5, tensor_type) - add_result = z3.Const(6, tensor_type) - self.assertEqual(s.model()[arange_result], s.model()[add_result]) - - - def test_regular_add_2(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([2, 4])): - to = x.to() - size = to.size() - getitem = size[-1] - add = getitem + 1 - return add - - b = BasicBlock().forward(torch.rand(2, 4)) - - symbolic_traced: pippy.fx.GraphModule = meta_symbolic_trace(BasicBlock(), meta_args={}) - transformed = transform_all_constraints(symbolic_traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEqual(s.check(), z3.sat) - res = z3.Int(5) - self.assertEqual(s.model()[res], b) - - - def test_regular_add_3(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([2, 4])): - to = x.to() - size = to.size() - getitem = size[-1] - add = 1 + getitem - return add - - b = BasicBlock().forward(torch.rand(2, 4)) - - symbolic_traced: pippy.fx.GraphModule = meta_symbolic_trace(BasicBlock(), meta_args={}) - transformed = transform_all_constraints(symbolic_traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEqual(s.check(), z3.sat) - res = z3.Int(5) - self.assertEqual(s.model()[res], b) - - def test_embedding(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - self.embedding = torch.nn.Embedding(256008, 1024, padding_idx=1) - - def forward(self, x: TensorType([2, 4])): - return self.embedding(x) - - B = BasicBlock().forward(torch.ones([2, 4], dtype=torch.long)).size() - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - transformed = transform_all_constraints(traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - embedding_result = z3.Const(2, tensor_type) - - assert s.model()[embedding_result].arg(0).arg(1) == B[0] - assert s.model()[embedding_result].arg(1).arg(1) == B[1] - assert s.model()[embedding_result].arg(2).arg(1) == B[2] - - # change the type. This should still be satisfiable - for n in traced.graph.nodes: - if n.op == 'placeholder': - n.type = TensorType([Dyn, Dyn]) - - transformed = transform_all_constraints(traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - assert s.model()[embedding_result].arg(0).arg(0) == 0 - assert s.model()[embedding_result].arg(1).arg(0) == 0 - assert s.model()[embedding_result].arg(2).arg(1) == B[2] - - # change the type to Dyn. Here, we will get an arbitirary migration - for n in traced.graph.nodes: - if n.op == 'placeholder': - n.type = Dyn - - transformed = transform_all_constraints(traced, counter=0) - s = z3.Solver() - s.add(transformed) - - self.assertEquals(s.check(), z3.sat) - - - def test_embedding_2(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([2, 4]), y: TensorType([Dyn, 1024])): - return torch.nn.functional.embedding(x, y) - - B = BasicBlock().forward(torch.ones([2, 4], dtype=torch.long), torch.rand(256008, 1024)).size() - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - transformed = transform_all_constraints(traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - embedding_result = z3.Const(5, tensor_type) - - assert s.model()[embedding_result].arg(0).arg(1) == B[0] - assert s.model()[embedding_result].arg(1).arg(1) == B[1] - assert s.model()[embedding_result].arg(2).arg(1) == B[2] - - def test_size_two_args(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([Dyn, 2, Dyn])): - size = x.size(-1) - return size - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - transformed = transform_all_constraints(traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEqual(s.check(), z3.sat) - - d1, d2 = z3.Int(39), z3.Int(2) - d4, d5 = z3.Int('input_d1'), z3.Int('input_d2') - - # migrate the third dimension - s.add(d1 != 0) - - self.assertEqual(s.check(), z3.sat) - input = z3.Const(1, tensor_type) - s.add(input == tensor_type.tensor3(D(3, 39), D(1, 2), D(d4, d5))) - - # check if the item we got is the right one - self.assertEqual(s.check(), z3.sat) - self.assertEqual(s.model()[d5], s.model()[d2]) - self.assertEqual(s.model()[d1], s.model()[d4]) - - def test_size_getitem(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: Dyn): - size = x.size() - getitem = size[-1] - return getitem - - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - transformed = transform_all_constraints(traced, counter=0) - s = z3.Solver() - s.add(transformed) - - self.assertEquals(s.check(), z3.sat) - - # force the input to be of size 4 - - s1, s2, s3, s4 = z3.Ints('x1 x2 x3 x4') - s11, s22, s33, s44 = z3.Ints('x11 x22 x33 x44') - d1, d2, d3, d4 = D(s11, s1), D(s22, s2), D(s33, s3), D(s44, s4), - - input = z3.Const(1, tensor_type) - s.add(input == tensor_type.tensor4(d1, d2, d3, d4)) - - # check if the model is still SAT - self.assertEquals(s.check(), z3.sat) - - s1, s2 = z3.Int(23), z3.Int(3) - - # check that the item is correct - self.assertEquals(s.model()[s1], s.model()[s2]) - - # invalid index but should still be SAT because input will be Dyn - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: Dyn): - size = x.size() - getitem = size[-10] - return getitem - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - transformed = transform_all_constraints(traced, counter=0) - s = z3.Solver() - s.add(transformed) - - self.assertEquals(s.check(), z3.sat) - s.add(input != z3_dyn) - self.assertEqual(s.check(), z3.unsat) - - def test_view_mul(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - self.embed_tokens = torch.nn.Embedding(256008, 1024, padding_idx=1) - - def forward(self, x: TensorType([2, 4])): - size = x.size() - getitem = size[-1] - view = x.view(-1, getitem) - embed_tokens = self.embed_tokens(view) - mul = embed_tokens * 32.0 - return mul - - - # print(B) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - # print(traced) - - transformed = transform_all_constraints(traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - # print(s.model()) - - embedding_result = z3.Const(6, tensor_type) - - # note that the view output will be: tensor3(dim(0, 0), dim(1, 4), dim(1, 1024)) - # this is due to the reshape constraints. This can be lifted - # but would require revising the type rules accordingly so we leave it for now - assert (s.model()[embedding_result].arg(1).arg(1)) == 4 - assert (s.model()[embedding_result].arg(2).arg(1)) == 1024 - - mul_result = z3.Const(13, tensor_type) - assert s.model()[mul_result] == s.model()[embedding_result] - - def test_gt(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([Dyn, 4])): - size = x.size() - getitem_1 = size[-1] - gt = getitem_1 > 1 - return gt - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - transformed = transform_all_constraints(traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - res = z3.Bool(4) - self.assertEqual(s.model()[res], True) - - def test_view(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([2, 4])): - view = x.view(-1, 8) - return view - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - transformed = transform_all_constraints(traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - - def test_lt_tensor(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([2, 4]), y: Dyn): - lt = x > y - return lt - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - transformed = transform_all_constraints(traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - - - def test_conditional_wrong_assumption(self): - """ - Test condition after making the wrong assumption about the input - """ - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: Dyn): - gt = x > 1 - return gt - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - - # The node we are considering is the gt node - for n in graph.nodes: - if n.target == operator.gt: - node = n - - positive, negative = evaluate_conditional_with_constraints(ast_rewriter.root, graph, node) - - self.assertEqual(positive, z3.sat) - self.assertEqual(negative, z3.sat) - - def test_conditional(self): - """ - This test case is for the HFmodels interface. - A function takes a node and a graph and considers - the conditional the node represents and its negation - and solves each formula with the remaining sets of constraints - Returns: - - """ - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - self.embed_tokens = torch.nn.Embedding(256008, 1024, padding_idx=1) - - def forward(self, x: TensorType([Dyn, 4])): - size = x.size() - getitem = size[-1] - view = x.view(-1, getitem) - embed_tokens = self.embed_tokens(view) - mul = embed_tokens * 32.0 - getitem_1 = size[-1] - gt = getitem_1 > 1 - return gt - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - - # The node we are considering is the gt node - for n in graph.nodes: - if n.target == operator.gt: - node = n - - positive, negative = evaluate_conditional_with_constraints(ast_rewriter.root, graph, node) - self.assertEqual(positive, z3.sat) - self.assertEqual(negative, z3.unsat) - - # change the annotation to Dyn - for n in graph.nodes: - if n.op == 'placeholder': - n.type = Dyn - - # here, both should be SAT since the input is Dyn - positive, negative = evaluate_conditional_with_constraints(ast_rewriter.root, graph, node) - - self.assertEqual(positive, z3.sat) - self.assertEqual(negative, z3.sat) - - - # change the annotation to TensorType[Dyn, Dyn] - for n in graph.nodes: - if n.op == 'placeholder': - n.type = TensorType([Dyn, Dyn]) - - # here, both should be SAT as well - positive, negative = evaluate_conditional_with_constraints(ast_rewriter.root, graph, node) - - self.assertEqual(positive, z3.sat) - self.assertEqual(negative, z3.sat) - - - def test_conditional_2(self): - """ - This test case is for the HFmodels interface. - A function takes a node and a graph and considers - the conditional the node represents and its negation - and solves each formula with the remaining sets of constraints - Returns the opposite result of the above testcase - - """ - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - self.embed_tokens = torch.nn.Embedding(256008, 1024, padding_idx=1) - - def forward(self, x: TensorType([Dyn, 4])): - size = x.size() - getitem = size[-1] - view = x.view(-1, getitem) - embed_tokens = self.embed_tokens(view) - mul = embed_tokens * 32.0 - getitem_1 = size[-1] - lt = getitem_1 < 1 - return lt - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - - # The node we are considering is the gt node - for n in graph.nodes: - if n.target == operator.lt: - node = n - - positive, negative = evaluate_conditional_with_constraints(ast_rewriter.root, graph, node) - self.assertEqual(positive, z3.unsat) - self.assertEqual(negative, z3.sat) - - -class ComposeOperationsGradualTypes(unittest.TestCase): - - def test_masked_fill(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([2, 4])): - size = x.size() - getitem = size[-1] - arange = torch.arange(getitem) - view = x.view(-1, getitem) - lt = arange > view - masked_fill = x.masked_fill_(lt, 0) - return masked_fill - - B = BasicBlock().forward(torch.rand(2, 4)) - # print(B.shape) - - symbolic_traced: pippy.fx.GraphModule = meta_symbolic_trace(BasicBlock(), meta_args={}) - # print(symbolic_traced) - transformed = transform_all_constraints(symbolic_traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEqual(s.check(), z3.sat) - masked_fill_res = z3.Const(10, tensor_type) - self.assertEqual(s.model()[masked_fill_res].arg(0).arg(1).as_long(), B.size()[0]) - self.assertEqual(s.model()[masked_fill_res].arg(1).arg(1).as_long(), B.size()[1]) - - # change the annotation to Dyn. This will migrate to an arbitirary type - for n in symbolic_traced.graph.nodes: - if n.op == 'placeholder': - n.type = Dyn - - transformed = transform_all_constraints(symbolic_traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEqual(s.check(), z3.sat) - - for n in symbolic_traced.graph.nodes: - if n.op == 'placeholder': - n.type = TensorType([Dyn, Dyn, Dyn, Dyn]) - - transformed = transform_all_constraints(symbolic_traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEqual(s.check(), z3.sat) - - def test_add_reshape_1(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: Dyn, y: Dyn): - return torch.add(torch.reshape(x, (1, 2)), torch.reshape(y, (2, 2))) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - transformed = transform_all_constraints(traced, counter=0) - - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - - def test_add_reshape_2(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: Dyn, y: Dyn): - return torch.add(torch.reshape(x, (-1, 2)), torch.reshape(y, (2, 2, 2))) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - transformed = transform_all_constraints(traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - - def test_conv_reshape_add_0(self): - class BasicBlock(torch.nn.Module): - def __init__(self, in_planes, out_planes, kernel_size, stride, padding, groups, dilation): - super(BasicBlock, self).__init__() - self.conv1 = torch.nn.Conv2d(in_channels=in_planes, out_channels=out_planes, - kernel_size=kernel_size, stride=stride, - padding=padding, groups=groups, bias=False, dilation=dilation) - - def forward(self, x: Dyn, y: Dyn): - return torch.add(self.conv1(torch.reshape(x, (1, 2, 10, 20))), y) - - B = BasicBlock(2, 2, 2, 3, 2, 2, 2) - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(B) - traced = GraphModule(ast_rewriter.root, graph, "gm") - new_transformed_c = transform_all_constraints(traced) - solver = z3.Solver() - solver.add(new_transformed_c) - self.assertEquals(solver.check(), z3.sat) - - - def test_conv_reshape_add_0_2(self): - class BasicBlock(torch.nn.Module): - def __init__(self, in_planes, out_planes, kernel_size, stride, padding, groups, dilation): - super(BasicBlock, self).__init__() - self.conv1 = torch.nn.Conv2d(in_channels=in_planes, out_channels=out_planes, - kernel_size=kernel_size, stride=stride, - padding=padding, groups=groups, bias=False, dilation=dilation) - - def forward(self, x: Dyn, y: TensorType([4, 1])): - return torch.add(self.conv1(torch.reshape(x, (1, 2, 10, 20))), y) - - B = BasicBlock(2, 2, 2, 3, 2, 2, 2) - - # 4,1 - # 1, 2, 4, 8 - res = B.forward(torch.rand(20, 20), torch.rand(1, 2, 4, 8)).size() - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(B) - traced = GraphModule(ast_rewriter.root, graph, "gm") - new_transformed_c = transform_all_constraints(traced) - solver = z3.Solver() - solver.add(new_transformed_c) - self.assertEquals(solver.check(), z3.sat) - - - conv_result = z3.Const(4, tensor_type) - add_result = z3.Const(9, tensor_type) - input_2 = z3.Const(2, tensor_type) - - s1, s2, s3, s4 = z3.Ints('x1 x2 x3 x4') - s11, s22, s33, s44 = z3.Ints('x11 x22 x33 x44') - d1, d2, d3, d4 = D(s11, s1), D(s22, s2), D(s33, s3), D(s44, s4), - - - solver.add(conv_result == tensor_type.tensor4(d1, d2, d3, d4)) - solver.check() - assert solver.model()[s1].as_long() == res[0] - assert solver.model()[s2].as_long() == res[1] - assert solver.model()[s3].as_long() == res[2] - assert solver.model()[s4].as_long() == res[3] - - solver.add(input_2 == tensor_type.tensor2(D(1, 4), D(1, 1))) - self.assertEquals(solver.check(), z3.sat) - solver.add(add_result == tensor_type.tensor4(d1, d2, d3, d4)) - self.assertEquals(solver.check(), z3.sat) - - # first dimension could be anything because we have broadcasting - assert solver.model()[s1] == res[0] - assert solver.model()[s2] == res[1] - assert solver.model()[s3] == res[2] - assert solver.model()[s4] == res[3] - - def test_conv_reshape_add_0_3(self): - class BasicBlock(torch.nn.Module): - def __init__(self, in_planes, out_planes, kernel_size, stride, padding, groups, dilation): - super(BasicBlock, self).__init__() - self.conv1 = torch.nn.Conv2d(in_channels=in_planes, out_channels=out_planes, - kernel_size=kernel_size, stride=stride, - padding=padding, groups=groups, bias=False, dilation=dilation) - - def forward(self, x: Dyn, y: TensorType([11, 1])): - return torch.add(self.conv1(torch.reshape(x, (1, 2, 10, 20))), y) - - B = BasicBlock(2, 2, 2, 3, 2, 2, 2) - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(B) - traced = GraphModule(ast_rewriter.root, graph, "gm") - new_transformed_c = transform_all_constraints(traced) - solver = z3.Solver() - solver.add(new_transformed_c) - self.assertEquals(solver.check(), z3.unsat) - - - def test_conv_reshape_add_1(self): - class BasicBlock(torch.nn.Module): - def __init__(self, in_planes, out_planes, kernel_size, stride, padding, groups, dilation): - super(BasicBlock, self).__init__() - self.conv1 = torch.nn.Conv2d(in_channels=in_planes, out_channels=out_planes, - kernel_size=kernel_size, stride=stride, - padding=padding, groups=groups, bias=False, dilation=dilation) - - def forward(self, x: Dyn, y: TensorType([1, 2, 10, 20])): - return torch.add(self.conv1(torch.reshape(x, (1, 2, 10, 20))), y) - - B = BasicBlock(2, 2, 2, 3, 2, 2, 2) - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(B) - traced = GraphModule(ast_rewriter.root, graph, "gm") - new_transformed_c = transform_all_constraints(traced) - solver = z3.Solver() - solver.add(new_transformed_c) - self.assertEquals(solver.check(), z3.unsat) - - -class GradualTypes(unittest.TestCase): - def test_conv_reshape_unsat(self): - - class BasicBlock(torch.nn.Module): - def __init__(self, in_planes, out_planes, kernel_size, stride, padding, groups, dilation): - super(BasicBlock, self).__init__() - self.conv1 = torch.nn.Conv2d(in_channels=in_planes, out_channels=out_planes, - kernel_size=kernel_size, stride=stride, - padding=padding, groups=groups, bias=False, dilation=dilation) - - def forward(self, x: Dyn): - return self.conv1(torch.reshape(x, (1, 2, 10))) - - B = BasicBlock(2, 2, 2, 3, 2, 2, 2) - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(B) - traced = GraphModule(ast_rewriter.root, graph, "gm") - new_transformed_c = transform_all_constraints(traced) - solver = z3.Solver() - solver.add(new_transformed_c) - self.assertEquals(solver.check(), z3.unsat) - - def test_conv_reshape0(self): - class BasicBlock(torch.nn.Module): - def __init__(self, in_planes, out_planes, kernel_size, stride, padding, groups, dilation): - super(BasicBlock, self).__init__() - self.conv1 = torch.nn.Conv2d(in_channels=in_planes, out_channels=out_planes, - kernel_size=kernel_size, stride=stride, - padding=padding, groups=groups, bias=False, dilation=dilation) - - def forward(self, x: Dyn): - return self.conv1(torch.reshape(x, (1, 2, 10, 20))) - - B = BasicBlock(2, 2, 2, 3, 2, 2, 2) - res = B.forward(torch.rand(20, 20)).size() - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(B) - traced = GraphModule(ast_rewriter.root, graph, "gm") - new_transformed_c = transform_all_constraints(traced) - - solver = z3.Solver() - solver.add(new_transformed_c) - self.assertEquals(solver.check(), z3.sat) - conv_result = z3.Const(3, tensor_type) - - s1, s2, s3, s4 = z3.Ints('x1 x2 x3 x4') - s11, s22, s33, s44 = z3.Ints('x11 x22 x33 x44') - d1, d2, d3, d4 = D(s11, s1), D(s22, s2), D(s33, s3), D(s44, s4), - - solver.add(conv_result == tensor_type.tensor4(d1, d2, d3, d4)) - solver.check() - # print(solver.model()) - # print(type(solver.model()[s1])) - assert solver.model()[s1].as_long() == res[0] - assert solver.model()[s2].as_long() == res[1] - assert solver.model()[s3].as_long() == res[2] - assert solver.model()[s4].as_long() == res[3] - - s1, s2, s3, s4 = z3.Ints('y1 y2 y3 y4') - s11, s22, s33, s44 = z3.Ints('y11 y22 y33 y44') - d1, d2, d3, d4 = D(s11, s1), D(s22, s2), D(s33, s3), D(s44, s4), - - input = z3.Const(1, tensor_type) - solver.add(input == tensor_type.tensor4(d1, d2, d3, d4)) - - # assert solver.check() == sat - # solver.add(s11 == 1) - # solver.add(s22 == 1) - # solver.add(s33 == 1) - # solver.add(s44 == 1) - # - # print(solver.check()) - # print(solver.model()) - - - def test_conv_reshape1(self): - class BasicBlock(torch.nn.Module): - def __init__(self, in_planes, out_planes, kernel_size, stride, padding, groups, dilation): - super(BasicBlock, self).__init__() - self.conv1 = torch.nn.Conv2d(in_channels=in_planes, out_channels=out_planes, - kernel_size=kernel_size, stride=stride, - padding=padding, groups=groups, bias=False, dilation=dilation) - - def forward(self, x: TensorType([20, 20])): - return self.conv1(torch.reshape(x, (1, -1, 10, 20))) - - B = BasicBlock(2, 2, 2, 3, 2, 2, 2) - res = B.forward(torch.rand(20, 20)).size() - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(B) - traced = GraphModule(ast_rewriter.root, graph, "gm") - new_transformed_c = transform_all_constraints(traced) - - solver = z3.Solver() - solver.add(new_transformed_c) - self.assertEquals(solver.check(), z3.sat) - conv_result = z3.Const(3, tensor_type) - - s1, s2, s3, s4 = z3.Ints('x1 x2 x3 x4') - s11, s22, s33, s44 = z3.Ints('x11 x22 x33 x44') - d1, d2, d3, d4 = D(s11, s1), D(s22, s2), D(s33, s3), D(s44, s4), - - solver.add(conv_result == tensor_type.tensor4(d1, d2, d3, d4)) - solver.check() - # print(solver.model()) - assert solver.model()[s1].as_long() == res[0] - assert solver.model()[s2].as_long() == res[1] - assert solver.model()[s3].as_long() == res[2] - assert solver.model()[s4].as_long() == res[3] - - -class TestSingleOperation(unittest.TestCase): - - def test_conv_wrong_example(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - self.conv1 = torch.nn.Conv2d(in_channels=2, out_channels=2, - kernel_size=2, stride=2, - padding=2, groups=2, bias=False, dilation=2) - - self.conv2 = torch.nn.Conv2d(in_channels=4, out_channels=2, - kernel_size=2, stride=2, - padding=2, groups=2, bias=False, dilation=2) - - self.relu = torch.nn.ReLU(inplace=True) - - def forward(self, x: Dyn): - y = self.relu(self.conv1(x)) - z = self.relu(self.conv2(x)) - return z - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - transformed = transform_all_constraints(traced) - - solver3 = z3.Solver() - solver3.add(transformed) - print(solver3.check()) - assert solver3.check() == z3.sat - - s1, s2, s3, s4 = z3.Ints('s1 s2 s3 s4') - s11, s22, s33, s44 = z3.Ints('s11 s22 s33 s44') - d1, d2, d3, d4 = D(s11, s1), D(s22, s2), D(s33, s3), D(s44, s4), - x = z3.Const(1, tensor_type) - solver3.add(x == tensor_type.tensor4(d1, d2, d3, d4)) - assert solver3.check() == z3.sat - - solver3.add(s22 != 0) - assert solver3.check() == z3.unsat - - def test_conv_dyn(self): - - s1, s2, s3, s4 = z3.Ints('s1 s2 s3 s4') - e1, e2, e3, e4 = z3.Ints('e1 e2 e3 e4') - s11, s22, s33, s44 = z3.Ints('s11 s22 s33 s44') - e11, e22, e33, e44 = z3.Ints('e11 e22 e33 e44') - d1, d2, d3, d4 = D(s11, s1), D(s22, s2), D(s33, s3), D(s44, s4), - b1, b2, b3, b4 = D(e11, e1), D(e22, e2), D(e33, e3), D(e44, e4) - - class BasicBlock(torch.nn.Module): - def __init__(self, in_planes, out_planes, kernel_size, stride, padding, groups, dilation): - super(BasicBlock, self).__init__() - self.conv1 = torch.nn.Conv2d(in_channels=in_planes, out_channels=out_planes, - kernel_size=kernel_size, stride=stride, - padding=padding, groups=groups, bias=False, dilation=dilation) - - def forward(self, x: Dyn): - return self.conv1(x) - - BasicBlock(2, 2, 2, 2, 2, 2, 2).forward(torch.rand(4, 2, 3, 4)) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock(2, 2, 2, 2, 2, 2, 2)) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - transformed = transform_all_constraints(traced) - - solver3 = z3.Solver() - solver3.add(transformed) - assert solver3.check() == z3.sat - - x = z3.Const(1, tensor_type) - y = z3.Const(2, tensor_type) - - solver3.add(x == tensor_type.tensor4(d1, d2, d3, d4), - y == tensor_type.tensor4(b1, b2, b3, b4)) - - assert solver3.check() == z3.sat - assert solver3.model()[s1].as_long() == solver3.model()[e1].as_long() - assert solver3.model()[s11].as_long() == solver3.model()[e11].as_long() - - solver3.add(s2 != 2) - assert solver3.check() == z3.sat - assert solver3.model()[s22].as_long() == 0 - - solver3.add(s22 != 0) - self.assertEquals(solver3.check(), z3.unsat) - - solver2 = z3.Solver() - solver2.add(transformed) - assert solver2.check() == z3.sat - solver2.add(x == tensor_type.tensor3(d1, d2, d3)) - self.assertEquals(solver2.check(), z3.unsat) - - - def test_add(self): - s1, s2, s3, s4 = z3.Ints('s1 s2 s3 s4') - s11, s22, s33, s44 = z3.Ints('s11 s22 s33 s44') - d1, d2, d3, d4 = D(s11, s1), D(s22, s2), D(s33, s3), D(s44, s4), - - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: Dyn, y: Dyn): - return torch.add(x, y) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - transformed = transform_all_constraints(traced, counter=0) - - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - - # make the tensor be of size 1 - x = z3.Const(1, tensor_type) - s.add(x == tensor_type.tensor1(D(1, s11))) - self.assertEquals(s.check(), z3.sat) - - y = z3.Const(2, tensor_type) - s.add(y == tensor_type.tensor1(D(1, s22))) - self.assertEquals(s.check(), z3.sat) - - s.add(s11 == 1) # tensor[1] - s.add(s22 == 2) # tensor[2] - self.assertEquals(s.check(), z3.sat) - - class BasicBlock2(torch.nn.Module): - def __init__(self): - super(BasicBlock2, self).__init__() - - def forward(self, x: TensorType((Dyn,)), y: Dyn): - return torch.add(x, y) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock2()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - transformed = transform_all_constraints(traced) - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - # make the tensor be of size 1 - x = z3.Const(1, tensor_type) - s.add(x == tensor_type.tensor1(D(1, s11))) - self.assertEquals(s.check(), z3.sat) - y = z3.Const(2, tensor_type) - s.add(y == tensor_type.tensor1(D(1, s22))) - self.assertEquals(s.check(), z3.sat) - s.add(s11 == 4) # tensor[4] - s.add(s22 == 5) # tensor[5] - self.assertEquals(s.check(), z3.unsat) - - class BasicBlock3(torch.nn.Module): - def __init__(self): - super(BasicBlock3, self).__init__() - - def forward(self, x: TensorType((Dyn,)), y: Dyn): - return torch.add(x, y) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock3()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - transformed = transform_all_constraints(traced) - s = z3.Solver() - s.add(transformed) - x = z3.Const(1, tensor_type) - s.add(x == tensor_type.tensor2(d1, d2)) - self.assertEquals(s.check(), z3.unsat) - - def test_add_padding(self): - s1, s2, s3, s4 = z3.Ints('s1 s2 s3 s4') - - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType((Dyn,)), y: TensorType((Dyn, Dyn))): - return torch.add(x, y) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - transformed = transform_all_constraints(traced, counter=0) - - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - - x = z3.Const(1, tensor_type) - s.add(x == tensor_type.tensor1(D(1, s1))) - - self.assertEquals(s.check(), z3.sat) - - # print(s.model()) - - def test_add_padding_2(self): - s1, s2, s3, s4 = z3.Ints('s1 s2 s3 s4') - - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([Dyn, Dyn]), y: TensorType([Dyn])): - return torch.add(x, y) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - transformed = transform_all_constraints(traced, counter=0) - - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - # print(s.model()) - - x = z3.Const(1, tensor_type) - s.add(x == tensor_type.tensor2(D(1, s1), D(1, s2))) - self.assertEquals(s.check(), z3.sat) - - y = z3.Const(2, tensor_type) - s.add(y == tensor_type.tensor1(D(0, s3))) - self.assertEquals(s.check(), z3.sat) - - add_result = z3.Const(3, tensor_type) - broadcast_res1, broadcast_res2 = z3.Const(4, tensor_type), z3.Const(5, tensor_type) - - # print(s.model()) - - assert s.model()[broadcast_res1].decl() == tensor_type.tensor2 - assert s.model()[broadcast_res2].decl() == tensor_type.tensor2 - assert s.model()[add_result].decl() == tensor_type.tensor2 - assert s.model()[y].decl() == tensor_type.tensor1 - - # print(s.model()) - - # prevent broadcasting for that dimension - s.add(s2 > 1) - - assert s.check() - - # the second dimension of the result is a number, not Dyn. - # however if the first input dimension had been 1, we would - # have had dyn in the result, as seen in the next test case - assert s.model()[add_result].arg(1).arg(0).as_long() != 0 - - def test_add_padding_3(self): - s1, s2, s3, s4 = z3.Ints('s1 s2 s3 s4') - - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([Dyn, 1]), y: TensorType([Dyn])): - return torch.add(x, y) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - transformed = transform_all_constraints(traced, counter=0) - - s = z3.Solver() - s.add(transformed) - # print(transformed) - self.assertEquals(s.check(), z3.sat) - - x = z3.Const(1, tensor_type) - y = z3.Const(2, tensor_type) - - s.add(s2 != 0) - s.add(x == tensor_type.tensor2(D(0, s1), D(s2, 1))) - s.add(y == tensor_type.tensor1(D(0, s3))) - - self.assertEquals(s.check(), z3.sat) - - # print(s.model()) - - add_result = z3.Const(3, tensor_type) - assert s.model()[add_result].arg(0).arg(0).as_long() == 0 - assert s.model()[add_result].arg(1).arg(0).as_long() == 0 - - - def test_add_padding_4(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([2, 1]), y: TensorType([3])): - return torch.add(x, y) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - transformed = transform_all_constraints(traced, counter=0) - - s = z3.Solver() - s.add(transformed) - - self.assertEquals(s.check(), z3.sat) - - add_result = z3.Const(3, tensor_type) - assert s.model()[add_result] == tensor_type.tensor2(D(1, 2), D(1, 3)) - - def test_add_padding_5(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([2, 2]), y: TensorType([3])): - return torch.add(x, y) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - transformed = transform_all_constraints(traced, counter=0) - - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.unsat) - - def test_add_size_3(self): - - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([Dyn, Dyn, Dyn]), y: TensorType([Dyn, Dyn, Dyn])): - return torch.add(x, y) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - transformed = transform_all_constraints(traced, counter=0) - - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - - x = z3.Const(1, tensor_type) - y = z3.Const(2, tensor_type) - - s1, s2, s3, s4, s5 = z3.Ints('s1 s2 s3 s4 s5') - - s.add(x == tensor_type.tensor3(D(1, s1), D(1, 1), D(1, s2))) - s.add(y == tensor_type.tensor3(D(1, s3), D(1, s4), D(1, s5))) - - self.assertEquals(s.check(), z3.sat) - s.add(s2 == 5) - self.assertEquals(s.check(), z3.sat) - s.add(s5 == 6) - self.assertEquals(s.check(), z3.unsat) - - def test_add_padding_6(self): - - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([Dyn]), y: TensorType([Dyn, Dyn, Dyn])): - return torch.add(x, y) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - transformed = transform_all_constraints(traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - - x = z3.Const(1, tensor_type) - y = z3.Const(2, tensor_type) - - s1, s2, s3, s4, s5 = z3.Ints('s1 s2 s3 s4 s5') - - s.add(x == tensor_type.tensor1(D(1, s1))) - s.add(y == tensor_type.tensor3(D(1, s2), D(1, s3), D(1, s4))) - - self.assertEquals(s.check(), z3.sat) - - s.add(s1 == 4) - s.add(s4 == 5) - - self.assertEquals(s.check(), z3.unsat) - - def test_add_padding_7(self): - - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([Dyn]), y: TensorType([Dyn, Dyn, Dyn, Dyn])): - return torch.add(x, y) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - transformed = transform_all_constraints(traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - x = z3.Const(1, tensor_type) - s1, s2, s3, s4, s5 = z3.Ints('s1 s2 s3 s4 s5') - s.add(x == tensor_type.tensor2(D(s1, s2), D(s2, s3))) - self.assertEquals(s.check(), z3.unsat) - - - def test_add_padding_8(self): - - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([Dyn]), y: TensorType([Dyn, Dyn, Dyn, Dyn])): - return torch.add(x, y) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - transformed = transform_all_constraints(traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - x = z3.Const(1, tensor_type) - y = z3.Const(2, tensor_type) - - s1, s2, s3, s4, s5 = z3.Ints('s1 s2 s3 s4 s5') - s.add(x == tensor_type.tensor1(D(s1, 1))) - s.add(s1 >= 0) - - self.assertEquals(s.check(), z3.sat) - - s.add(y == tensor_type.tensor4(D(0, s2), D(0, s3), D(0, s4), D(0, s5))) - self.assertEquals(s.check(), z3.sat) - - def test_add_padding_9(self): - - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: Dyn, y: TensorType([Dyn, Dyn, Dyn, Dyn])): - return torch.add(x, y) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - transformed = transform_all_constraints(traced, counter=0) - s = z3.Solver() - s.add(transformed) - - self.assertEquals(s.check(), z3.sat) - x = z3.Const(1, tensor_type) - y = z3.Const(2, tensor_type) - - s1, s2, s3, s4, s5, s6, s7 = z3.Ints('s1 s2 s3 s4 s5 s6 s7') - s.add(x == tensor_type.tensor1(D(s1, s7))) - s.add(s1 == 1) - self.assertEquals(s.check(), z3.sat) - - s.add(y == tensor_type.tensor4(D(0, s2), D(0, s3), D(0, s4), D(s6, s5))) - self.assertEquals(s.check(), z3.sat) - - s.add(s6 == 1) - - self.assertEquals(s.check(), z3.sat) - s.add(s5 != 1, s7 != 1) - assert s.check() - - assert s.model()[s5].as_long() == s.model()[s7].as_long() - - def test_conv_static(self): - s1, s2, s3, s4 = z3.Ints('s1 s2 s3 s4') - e1, e2, e3, e4 = z3.Ints('e1 e2 e3 e4') - s11, s22, s33, s44 = z3.Ints('s11 s22 s33 s44') - e11, e22, e33, e44 = z3.Ints('e11 e22 e33 e44') - d1, d2, d3, d4 = D(s11, s1), D(s22, s2), D(s33, s3), D(s44, s4), - b1, b2, b3, b4 = D(e11, e1), D(e22, e2), D(e33, e3), D(e44, e4) - - class BasicBlock(torch.nn.Module): - def __init__(self, in_planes, out_planes, kernel_size, stride, padding, groups, dilation): - super(BasicBlock, self).__init__() - self.conv1 = torch.nn.Conv2d(in_channels=in_planes, out_channels=out_planes, - kernel_size=kernel_size, stride=stride, - padding=padding, dilation=dilation) - - def forward(self, x: TensorType((1, 2, 10, 20))): - return self.conv1(x) - - ast_rewriter = RewritingTracer() - - B = BasicBlock(2, 2, 2, 3, 2, 2, 2) - res = B.forward(torch.rand(1, 2, 10, 20)).size() - - graph = ast_rewriter.trace(B) - traced = GraphModule(ast_rewriter.root, graph, "gm") - new_transformed_c = transform_all_constraints(traced) - solver = z3.Solver() - solver.add(new_transformed_c) - self.assertEquals(solver.check(), z3.sat) - - x = z3.Const(1, tensor_type) - y = z3.Const(2, tensor_type) - - solver.add(x == tensor_type.tensor4(d1, d2, d3, d4)) - solver.add(y == tensor_type.tensor4(b1, b2, b3, b4)) - self.assertEquals(solver.check(), z3.sat) - # print(solver.model()) - assert solver.model()[e3].as_long() == res[2] - assert solver.model()[e4].as_long() == res[3] - - B2 = BasicBlock(2, 4, 5, 2, 9, 2, 2) - res2 = B2.forward(torch.rand(1, 2, 10, 20)).size() - - graph2 = ast_rewriter.trace(B2) - traced2 = GraphModule(ast_rewriter.root, graph2, "gm") - new_transformed_c = transform_all_constraints(traced2) - solver = z3.Solver() - solver.add(new_transformed_c) - - solver.add(x == tensor_type.tensor4(d1, d2, d3, d4)) - solver.add(y == tensor_type.tensor4(b1, b2, b3, b4)) - - self.assertEquals(solver.check(), z3.sat) - assert solver.model()[e3].as_long() == res2[2] - assert solver.model()[e4].as_long() == res2[3] - - def test_reshape_dyn(self): - s11, s22, s33, s44 = z3.Ints('s11 s22 s33 s44') - - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: Dyn): - return torch.reshape(x, (2, -1)) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - transformed = transform_all_constraints(traced) - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - x = z3.Const(1, tensor_type) - s.add(x == tensor_type.tensor1(D(1, s11))) - self.assertEquals(s.check(), z3.sat) - s.add(z3.Or([s11 == 2, s11 == 4, s11 == 9])) - self.assertEquals(s.check(), z3.sat) - s.add(s11 == 9) - self.assertEquals(s.check(), z3.unsat) - - - def test_reshape_annotated(self): - s1, s2, s3, s4 = z3.Ints('s1 s2 s3 s4') - s11, s22, s33, s44 = z3.Ints('s11 s22 s33 s44') - d1, d2, d3, d4 = D(s11, s1), D(s22, s2), D(s33, s3), D(s44, s4), - - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([Dyn])): - return torch.reshape(x, (2, -1)) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - transformed = transform_all_constraints(traced) - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - x = z3.Const(1, tensor_type) - s.add(x == tensor_type.tensor2(d1, d2)) - self.assertEquals(s.check(), z3.unsat) - - def test_reshape_static_target(self): - s11, s22, s33, s44 = z3.Ints('s11 s22 s33 s44') - - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([Dyn])): - return torch.reshape(x, (2, 3)) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - transformed = transform_all_constraints(traced) - # print(transformed) - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - x = z3.Const(1, tensor_type) - s.add(x == tensor_type.tensor1(D(1, s11))) - s.check() - assert s.model()[s11].as_long() == 6 - s.add(s11 != 6) - self.assertEquals(s.check(), z3.unsat) - - def test_reshape_static_target2(self): - s11, s22, s33, s44 = z3.Ints('s11 s22 s33 s44') - - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: Dyn): - return torch.reshape(x, (2, 3, 1, 1)) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - transformed = transform_all_constraints(traced) - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - x = z3.Const(1, tensor_type) - s.add(x == tensor_type.tensor1(D(1, s11))) - s.check() - assert s.model()[s11].as_long() == 6 - s.add(s11 != 6) - self.assertEquals(s.check(), z3.unsat) - - - def test_conv2D_maxpool2d_flatten(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - self.conv1 = torch.nn.Conv2d(3, 6, 5) - self.pool = torch.nn.MaxPool2d(2, 2) - self.conv2 = torch.nn.Conv2d(6, 16, 5) - self.fc1 = torch.nn.Linear(5, 120) - self.pool2 = torch.nn.AdaptiveAvgPool2d((6, 7)) - - def forward(self, x : TensorType((4, 3, 32, 32))): - out = self.conv1(x) - out = self.pool(out) - out = self.conv2(out) - out = self.pool(out) - out = self.fc1(out) - out = self.pool2(out) - out = torch.flatten(out, 1) - return out - - B = BasicBlock() - res = B.forward(torch.rand(4, 3, 32, 32)).shape - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(B) - traced = GraphModule(ast_rewriter.root, graph, "gm") - constraints = transform_all_constraints(traced, counter=0) - solver = z3.Solver() - solver.add(constraints) - solver.check() - input = z3.Const(1, tensor_type) - solver.add(input == tensor_type.tensor4(D(1, 4), D(1, 3), D(1, 32), D(1, 32))) - solver.check() - output = z3.Const(48, tensor_type) - assert solver.model()[output].arg(0).arg(1) == res[0] - assert solver.model()[output].arg(1).arg(1) == res[1] - - def test_conv2D_maxpool2d_flatten_unsat(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - self.conv1 = torch.nn.Conv2d(3, 6, 5) - self.pool = torch.nn.MaxPool2d(2, 2) - self.conv2 = torch.nn.Conv2d(6, 16, 5) - self.fc1 = torch.nn.Linear(5, 120) - self.pool2 = torch.nn.AdaptiveAvgPool2d((6, 7)) - - def forward(self, x : TensorType((4, 3, 32, 32))): - out = self.conv1(x) - out = self.pool(out) - out = self.conv2(out) - out = self.pool(out) - out = self.fc1(out) - out = self.pool2(out) - out = torch.flatten(out, 1) - return out - - B = BasicBlock() - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(B) - traced = GraphModule(ast_rewriter.root, graph, "gm") - constraints = transform_all_constraints(traced, counter=0) - solver = z3.Solver() - solver.add(constraints) - solver.check() - input = z3.Const(1, tensor_type) - solver.add(input == tensor_type.tensor4(D(1, 4), D(1, 3), D(1, 32), D(1, 45))) - self.assertEquals(solver.check(), z3.unsat) - - def test_conv2D_maxpool2d_flatten_dyn(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - self.conv1 = torch.nn.Conv2d(3, 6, 5) - self.pool = torch.nn.MaxPool2d(2, 2) - self.conv2 = torch.nn.Conv2d(6, 16, 5) - self.fc1 = torch.nn.Linear(5, 120) - self.pool2 = torch.nn.AdaptiveAvgPool2d((6, 7)) - - def forward(self, x : TensorType((Dyn, 3, 32, 32))): - out = self.conv1(x) - out = self.pool(out) - out = self.conv2(out) - out = self.pool(out) - out = self.fc1(out) - out = self.pool2(out) - out = torch.flatten(out, 1) - return out - - B = BasicBlock() - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(B) - traced = GraphModule(ast_rewriter.root, graph, "gm") - constraints = transform_all_constraints(traced, counter=0) - solver = z3.Solver() - solver.add(constraints) - self.assertEquals(solver.check(), z3.sat) - - def test_type_check_flatten(self): - s1, s2, s3, s4 = z3.Ints('s1 s2 s3 s4') - - class M(torch.nn.Module): - def forward(self, x: TensorType([2, 3, 4, 5])): - return torch.flatten(x, start_dim=1, end_dim=3) - - module = M() - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(module) - constraints = transform_all_constraints(symbolic_traced, counter=0) - solver = z3.Solver() - solver.add(constraints) - self.assertEquals(solver.check(), z3.sat) - flatten = z3.Const(2, tensor_type) - - res = M().forward(torch.rand(2, 3, 4, 5)).size() - assert solver.model()[flatten].arg(0).arg(1) == res[0] - assert solver.model()[flatten].arg(1).arg(1) == res[1] - - class M(torch.nn.Module): - def forward(self, x: TensorType([2, 3, Dyn, 5])): - return torch.flatten(x, start_dim=1, end_dim=3) - - module = M() - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(module) - constraints = transform_all_constraints(symbolic_traced, counter=0) - solver = z3.Solver() - solver.add(constraints) - self.assertEquals(solver.check(), z3.sat) - x = z3.Const(1, tensor_type) - y = z3.Const(2, tensor_type) - - solver.add(x == tensor_type.tensor4(D(1, 2), D(1, 3), D(0, s1), D(1, 5))) - self.assertEquals(solver.check(), z3.sat) - assert solver.model()[y].arg(1).arg(0) == 0 - - - class M(torch.nn.Module): - def forward(self, x: TensorType([2, 3, Dyn])): - return torch.flatten(x, 10, 0) - - module = M() - # print(module.forward(torch.rand(2,3,5)).shape) - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(module) - constraints = transform_all_constraints(symbolic_traced, counter=0) - solver = z3.Solver() - solver.add(constraints) - self.assertEquals(solver.check(), z3.unsat) - -class ConstraintGeneration(unittest.TestCase): - - def test_add_reshape(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: Dyn, y: Dyn): - return torch.add(torch.reshape(x, (1, 2)), torch.reshape(y, (2, 2))) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - generator = ConstraintGenerator(traced) - new_constraints, counter = generator.generate_constraints(0) - assert len(new_constraints.conjucts) == 11 - - - def test_conv_reshape_add(self): - class BasicBlock(torch.nn.Module): - def __init__(self, in_planes, out_planes, kernel_size, stride, padding, groups, dilation): - super(BasicBlock, self).__init__() - self.conv1 = torch.nn.Conv2d(in_channels=in_planes, out_channels=out_planes, - kernel_size=kernel_size, stride=stride, - padding=padding, groups=groups, bias=False, dilation=dilation) - - def forward(self, x: Dyn, y: Dyn): - return torch.add(self.conv1(torch.reshape(x, (1, 2, 10, 20))), y) - - B = BasicBlock(2, 2, 2, 3, 2, 2, 2) - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(B) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - generator = ConstraintGenerator(traced) - new_constraints, counter = generator.generate_constraints(0) - assert len(new_constraints.conjucts) == 16 - - -class TestInternalConstraints(unittest.TestCase): - def test_precision(self): - - c1 = BinConstraintT(Dyn, TVar('x'), op_precision) - transformed, _ = transform_constraint(c1, 0) - assert transformed == T() - - c2 = BinConstraintT(TensorType([1, Dyn, 3]), TVar('x'), op_precision) - transformed, counter = transform_constraint(c2, 0) - assert len(transformed.conjucts) == 7 - - def test_matching(self): - c1 = BinConstraintT(TVar('x'), - TensorType([DVar('a'), DVar('b'), DVar('c'), DVar('d')]), op_matching) - transformed, _ = transform_constraint(c1, 0) - assert len(transformed.disjuncts) == 2 - - def test_consistency(self): - c1 = BinConstraintT(TVar('x'), - TensorType([DVar('a'), DVar('b')]), op_consistency) - transformed, count = transform_constraint(c1, 0) - - assert len(transformed.disjuncts) == 5 - transformed, count = transform_constraint(transformed, count) - assert len(transformed.disjuncts) == 5 - - # def test_apply_broadcasting(self): - # c1 = ApplyBroadcasting(TVar(1), TVar(2), TVar(3), TVar(4)) - # transformed, count = transform_apply_broadcasting(c1, 5) - # assert len(transformed.conjucts) == 41 - -@skipIfNoTorchVision -class TestResNet(unittest.TestCase): - - def test_resnet50_unsat(self): - traced = symbolic_trace(models.resnet50()) - for n in traced.graph.nodes: - n.type = Dyn - - constraints = transform_all_constraints(traced, counter=0) - solver = z3.Solver() - solver.add(constraints) - input = z3.Const(1, tensor_type) - # input with 3 dimensions - solver.add(input == tensor_type.tensor3(D(1, 1), D(1, 3), D(1, 224))) - self.assertEquals(solver.check(), z3.unsat) - - - - def test_resnet50(self): - traced = symbolic_trace(models.resnet50()) - for n in traced.graph.nodes: - n.type = Dyn - - sample_input = torch.randn(1, 3, 224, 224) - res = models.resnet50().forward(sample_input).size() - constraints = transform_all_constraints(traced, counter=0) - solver = z3.Solver() - solver.add(constraints) - self.assertEquals(solver.check(), z3.sat) - linear = z3.Const(650, tensor_type) - - input = z3.Const(1, tensor_type) - solver.add(input == tensor_type.tensor4(D(1, 1), D(1, 3), D(1, 224), D(1, 224))) - self.assertEquals(solver.check(), z3.sat) - assert solver.model()[linear] == tensor_type.tensor2(D(1, res[0]), D(1, res[1])) - - def test_resnet502(self): - traced = symbolic_trace(models.resnet50()) - for n in traced.graph.nodes: - n.type = Dyn - - constraints = transform_all_constraints(traced, counter=0) - solver = z3.Solver() - solver.add(constraints) - linear = z3.Const(650, tensor_type) - input = z3.Const(1, tensor_type) - batch = z3.Int('b') - solver.add(input == tensor_type.tensor4(D(1, batch), D(1, 3), D(1, 224), D(1, 224))) - solver.add(batch > 4) - solver.check() - assert solver.model()[batch] == solver.model()[linear].arg(0).arg(1) - - def test_resnet503(self): - traced = symbolic_trace(models.resnet50()) - for n in traced.graph.nodes: - n.type = Dyn - - constraints = transform_all_constraints(traced, counter=0) - solver = z3.Solver() - solver.add(constraints) - linear = z3.Const(650, tensor_type) - input = z3.Const(1, tensor_type) - batch, d1, d2 = z3.Ints('b d1 d2') - solver.add(input == tensor_type.tensor4(D(1, batch), D(1, 3), D(1, 224), D(1, 224))) - solver.add(linear == tensor_type.tensor2(D(1, d1), D(1, d2))) - self.assertEquals(solver.check(), z3.sat) - solver.add(batch != d1) - self.assertEquals(solver.check(), z3.unsat) - -@skipIfNoTorchVision -class TestAlexNet(unittest.TestCase): - def test_alexnet1(self): - - alexnet = models.alexnet() - symbolic_traced : pippy.fx.GraphModule = symbolic_trace(alexnet) - - for n in symbolic_traced.graph.nodes: - n.type = Dyn - - # print(symbolic_traced) - - res = alexnet.forward(torch.rand(10, 3, 227, 227)).size() - constraints = transform_all_constraints(symbolic_traced, counter=0) - solver = z3.Solver() - solver.add(constraints) - self.assertEquals(solver.check(), z3.sat) - input = z3.Const(1, tensor_type) - conv = z3.Const(2, tensor_type) - solver.add(input == tensor_type.tensor4(D(1, 10), D(1, 3), D(1, 227), D(1, 227))) - self.assertEquals(solver.check(), z3.sat) - assert solver.model()[conv] == tensor_type.tensor4(D(1, 10), D(1, 64), D(1, 56), D(1, 56)) - - relu = z3.Const(7, tensor_type) - assert solver.model()[relu] == tensor_type.tensor4(D(1, 10), D(1, 64), D(1, 56), D(1, 56)) - - maxpool = z3.Const(8, tensor_type) - assert solver.model()[maxpool] == tensor_type.tensor4(D(1, 10), D(1, 64), D(1, 27), D(1, 27)) - - maxpool2 = z3.Const(42, tensor_type) - assert solver.model()[maxpool2] == tensor_type.tensor4(D(1, 10), D(1, 256), D(1, 6), D(1, 6)) - - flatten = z3.Const(52, tensor_type) - assert solver.model()[flatten] == tensor_type.tensor2(D(1, 10), D(1, 9216)) - - linear = z3.Const(64, tensor_type) - assert solver.model()[linear] == tensor_type.tensor2(D(1, 10), D(1, 4096)) - - linear2 = z3.Const(109, tensor_type) - assert solver.model()[linear2] == tensor_type.tensor2(D(1, res[0]), D(1, res[1])) - - - def test_alexnet2(self): - alexnet = models.alexnet() - symbolic_traced : pippy.fx.GraphModule = symbolic_trace(alexnet) - - for n in symbolic_traced.graph.nodes: - if n.op == 'placeholder': - n.type = TensorType([Dyn, 4, 227, 227]) - - constraints = transform_all_constraints(symbolic_traced, counter=0) - solver = z3.Solver() - solver.add(constraints) - self.assertEquals(solver.check(), z3.unsat) - - def test_alexnet3(self): - alexnet = models.alexnet() - symbolic_traced : pippy.fx.GraphModule = symbolic_trace(alexnet) - - for n in symbolic_traced.graph.nodes: - if n.op == 'placeholder': - n.type = TensorType([Dyn, Dyn, 227, 227]) - - constraints = transform_all_constraints(symbolic_traced, counter=0) - solver = z3.Solver() - solver.add(constraints) - self.assertEquals(solver.check(), z3.sat) - - def test_alexnet4(self): - alexnet = models.alexnet() - symbolic_traced : pippy.fx.GraphModule = symbolic_trace(alexnet) - - for n in symbolic_traced.graph.nodes: - if n.op == 'placeholder': - n.type = TensorType([Dyn, Dyn, 227]) - - constraints = transform_all_constraints(symbolic_traced, counter=0) - solver = z3.Solver() - solver.add(constraints) - self.assertEquals(solver.check(), z3.unsat) - - - -if __name__ == '__main__': - unittest.main() diff --git a/test/local_test_compile.py b/test/local_test_compile.py deleted file mode 100644 index 6ee079912..000000000 --- a/test/local_test_compile.py +++ /dev/null @@ -1,110 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import argparse -import os -import unittest - -import pippy - -import torch -from pippy import run_pippy -from pippy.IR import pipe_split - -d_hid = 512 -bs = 256 - - -class ExampleCode(torch.nn.Module): - def __init__(self): - super().__init__() - self.mm_param = torch.nn.Parameter(torch.randn(d_hid, d_hid)) - self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) - self.lin = torch.nn.Linear(d_hid, d_hid) - - def forward(self, x): - x = torch.mm(x, self.mm_param) - skip_connection = x - x = torch.relu(x) - pipe_split() - x = torch.mm(x, self.mm_param) - x = self.lin(x) - pipe_split() - x = torch.relu(x) - x = x + skip_connection - x = torch.mm(x, self.mm_param2) - pipe_split() - x = self.lin(x) - x = torch.relu(x) - return {"out": x} - - -def run_master(_, args): - ec = ExampleCode() - ec.to(args.device) - ec_input = torch.randn(bs, d_hid, device=args.device) - - # Create pipeline model - pipe_ec = pippy.compile( - ec, - num_ranks=args.world_size, - num_chunks=4, - schedule=args.schedule, - checkpoint=bool(args.checkpoint), - _debug_mask_minibatches=True, # for numerical equivalence test only - ) - - # Warm up and correctness runs - out = pipe_ec(ec_input) - ref_out = ec(ec_input) - - # run with different chunk size to exercise microbatch and scheduling components - torch.testing.assert_close(out["out"], ref_out["out"]) - print( - f'equivalence test passed {torch.sum(out["out"])} ref {torch.sum(ref_out["out"])}' - ) - - -def main(args=None): - parser = argparse.ArgumentParser() - parser.add_argument( - "--world_size", type=int, default=int(os.getenv("WORLD_SIZE", 4)) - ) - parser.add_argument("--rank", type=int, default=int(os.getenv("RANK", -1))) - parser.add_argument( - "--master_addr", type=str, default=os.getenv("MASTER_ADDR", "localhost") - ) - parser.add_argument( - "--master_port", type=str, default=os.getenv("MASTER_PORT", "29500") - ) - parser.add_argument( - "-s", - "--schedule", - type=str, - default="FillDrain", - ) - parser.add_argument( - "--cuda", type=int, default=int(torch.cuda.is_available()) - ) - parser.add_argument("--checkpoint", type=int, default=0, choices=[0, 1]) - args = parser.parse_args(args) - - # Interleaved 1F1B uses less ranks than number of stages - if args.schedule == "Interleaved1F1B": - args.world_size = 2 - - run_pippy(run_master, args) - - -if __name__ == "__main__": - main() - - -class LocalTestCompileTest(unittest.TestCase): - def test_compile(self): - import random - - port = random.randint(29500, 30000) - args = [ - "--master_port", - str(port), - ] - main(args) diff --git a/test/local_test_forward.py b/test/local_test_forward.py deleted file mode 100644 index 924759cba..000000000 --- a/test/local_test_forward.py +++ /dev/null @@ -1,168 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import argparse -import os -import unittest - -import pippy.fx - -import torch -import torch.autograd.profiler_legacy -from pippy import run_pippy -from pippy.IR import MultiUseParameterConfig, Pipe, pipe_split -from pippy.PipelineDriver import ( - PipelineDriver1F1B, - PipelineDriverBase, - PipelineDriverFillDrain, - PipelineDriverInterleaved1F1B, -) - -PROFILING_ENABLED = True -CHECK_NUMERIC_EQUIVALENCE = True - -schedules = { - "FillDrain": PipelineDriverFillDrain, - "1F1B": PipelineDriver1F1B, - "Interleaved1F1B": PipelineDriverInterleaved1F1B, -} - -pippy.fx.Tracer.proxy_buffer_attributes = True - - -def run_master(_, args): - d_hid = 512 - bs = 503 - - MULTI_USE_PARAM_CONFIG = ( - MultiUseParameterConfig.REPLICATE - if args.replicate - else MultiUseParameterConfig.TRANSMIT - ) - print(f"REPLICATE config: {args.replicate} -> {MULTI_USE_PARAM_CONFIG}") - - print("Using schedule:", args.schedule) - - class ExampleCode(torch.nn.Module): - def __init__(self): - super().__init__() - self.mm_param = torch.nn.Parameter(torch.randn(d_hid, d_hid)) - self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) - self.lin = torch.nn.Linear(d_hid, d_hid) - self.register_buffer("buffer", torch.randn(bs + 100, d_hid)) - - def forward(self, x): - x = torch.mm(x, self.mm_param) - skip_connection = x - x = torch.relu(x) - pipe_split() - x = torch.mm(x, self.mm_param) + self.buffer[: x.shape[0]] - x = self.lin(x) - pipe_split() - x = torch.relu(x) - x = x + skip_connection - x = torch.mm(x, self.mm_param2) - x = self.lin(x) - pipe_split() - x = torch.relu(x) - return {"out": x} - - ec = ExampleCode() - ec.to(args.device) - ec_input = torch.randn(bs, d_hid, device=args.device) - ec(ec_input) - - ec_pipe = Pipe.from_tracing(ec, MULTI_USE_PARAM_CONFIG) - print(ec_pipe.split_gm) - - pipe_driver: PipelineDriverBase = schedules[args.schedule]( - ec_pipe, - 5, - args.world_size, - _debug_mask_minibatches=True, - _record_mem_dumps=bool(args.record_mem_dumps), - checkpoint=bool(args.checkpoint), - ) - - # # Warm up and correctness runs - out = pipe_driver(ec_input) - ref_out = ec_pipe(ec_input) - - # run with different chunk size to exercise microbatch and scheduling components - pipe_driver.chunks = 1 - pipe_driver(ec_input) - pipe_driver.chunks = 100 - pipe_driver(ec_input) - - if CHECK_NUMERIC_EQUIVALENCE: - torch.testing.assert_close(out["out"], ref_out["out"]) - print( - f'equivalence test passed {torch.sum(out["out"])} ref {torch.sum(ref_out["out"])}' - ) - - # # Profiling runs - with torch.autograd.profiler_legacy.profile( - enabled=PROFILING_ENABLED - ) as prof: - pipe_driver.chunks = 5 - out = pipe_driver(ec_input) - ref_out = ec_pipe(ec_input) - print( - f'profiling run completed {torch.sum(out["out"])} ref {torch.sum(ref_out["out"])}' - ) - if PROFILING_ENABLED: - prof.export_chrome_trace( - f"{os.path.splitext(os.path.basename(__file__))[0]}.json" - ) - - -def main(args=None): - parser = argparse.ArgumentParser() - parser.add_argument( - "--world_size", type=int, default=int(os.getenv("WORLD_SIZE", 4)) - ) - parser.add_argument("--rank", type=int, default=int(os.getenv("RANK", -1))) - parser.add_argument( - "--master_addr", type=str, default=os.getenv("MASTER_ADDR", "localhost") - ) - parser.add_argument( - "--master_port", type=str, default=os.getenv("MASTER_PORT", "29500") - ) - parser.add_argument( - "-s", - "--schedule", - type=str, - default=list(schedules.keys())[0], - choices=schedules.keys(), - ) - parser.add_argument( - "--replicate", type=int, default=int(os.getenv("REPLICATE", "0")) - ) - parser.add_argument( - "--cuda", type=int, default=int(torch.cuda.is_available()) - ) - parser.add_argument( - "--record_mem_dumps", type=int, default=0, choices=[0, 1] - ) - parser.add_argument("--checkpoint", type=int, default=0, choices=[0, 1]) - args = parser.parse_args(args) - - # Interleaved 1F1B uses less ranks than number of stages - if args.schedule == "Interleaved1F1B": - args.world_size = 2 - - run_pippy(run_master, args) - - -if __name__ == "__main__": - main() - - -class LocalTestForwardTest(unittest.TestCase): - def test_forward(self): - import random - - port = random.randint(29500, 30000) - args = [ - "--master_port", - str(port), - ] - main(args) diff --git a/test/local_test_visualizer.py b/test/local_test_visualizer.py deleted file mode 100644 index 95eacbe20..000000000 --- a/test/local_test_visualizer.py +++ /dev/null @@ -1,365 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import argparse -import os -import time -import unittest -from collections import defaultdict -from functools import reduce -from typing import Any, Dict, List - -import pippy.fx - -import torch -import torch.nn as nn -from pippy import run_pippy -from pippy.events import Event -from pippy.IR import ( - MultiUseParameterConfig, - Pipe, - pipe_split, - TrivialLossWrapper, -) -from pippy.PipelineDriver import ( - EventsContext, - Phase, - PipelineDriver1F1B, - PipelineDriverBase, - PipelineDriverFillDrain, - PipelineDriverInterleaved1F1B, -) -from pippy.visualizer import events_to_json -from torch.autograd import Function - -PROFILING_ENABLED = True -CHECK_NUMERIC_EQUIVALENCE = True - -schedules = { - "FillDrain": PipelineDriverFillDrain, - "1F1B": PipelineDriver1F1B, - "Interleaved1F1B": PipelineDriverInterleaved1F1B, -} - -pippy.fx.Tracer.proxy_buffer_attributes = True - - -@pippy.fx.wrap -def sleep(x, t=1.0): - time.sleep(t) - return x - - -class SlowMSELoss(nn.MSELoss): - def forward(self, input, target): - return super().forward(sleep(input, t=0.01), target) - - -# Inherit from Function -class MyLinearFunction(Function): - # Note that both forward and backward are @staticmethods - @staticmethod - # bias is an optional argument - def forward(ctx, input, weight, bias=None): - # print("my forward") - input = sleep(input, t=0.1) - ctx.save_for_backward(input, weight, bias) - output = input.mm(weight.t()) - if bias is not None: - output += bias.unsqueeze(0).expand_as(output) - return output - - # This function has only a single output, so it gets only one gradient - @staticmethod - def backward(ctx, grad_output): - # print("my backward") - grad_output = sleep(grad_output, t=0.3) - # This is a pattern that is very convenient - at the top of backward - # unpack saved_tensors and initialize all gradients w.r.t. inputs to - # None. Thanks to the fact that additional trailing Nones are - # ignored, the return statement is simple even when the function has - # optional inputs. - input, weight, bias = ctx.saved_tensors - grad_input = grad_weight = grad_bias = None - - # These needs_input_grad checks are optional and there only to - # improve efficiency. If you want to make your code simpler, you can - # skip them. Returning gradients for inputs that don't require it is - # not an error. - if ctx.needs_input_grad[0]: - grad_input = grad_output.mm(weight) - if ctx.needs_input_grad[1]: - grad_weight = grad_output.t().mm(input) - if bias is not None and ctx.needs_input_grad[2]: - grad_bias = grad_output.sum(0) - - return grad_input, grad_weight, grad_bias - - -@pippy.fx.wrap -def linear(input, weight, bias): - return MyLinearFunction.apply(input, weight, bias) - - -class MyLinear(nn.Module): - def __init__(self, input_features, output_features, bias=True): - super(MyLinear, self).__init__() - self.input_features = input_features - self.output_features = output_features - - # nn.Parameter is a special kind of Tensor, that will get - # automatically registered as Module's parameter once it's assigned - # as an attribute. Parameters and buffers need to be registered, or - # they won't appear in .parameters() (doesn't apply to buffers), and - # won't be converted when e.g. .cuda() is called. You can use - # .register_buffer() to register buffers. - # nn.Parameters require gradients by default. - self.weight = nn.Parameter(torch.empty(output_features, input_features)) - if bias: - self.bias = nn.Parameter(torch.empty(output_features)) - else: - # You should always register all possible parameters, but the - # optional ones can be None if you want. - self.register_parameter("bias", None) - - # Not a very smart way to initialize weights - nn.init.uniform_(self.weight, -0.1, 0.1) - if self.bias is not None: - nn.init.uniform_(self.bias, -0.1, 0.1) - - def forward(self, input): - # See the autograd section for explanation of what happens here. - return linear(input, self.weight, self.bias) - - def extra_repr(self): - # (Optional)Set the extra information about this module. You can test - # it by printing an object of this class. - return "input_features={}, output_features={}, bias={}".format( - self.input_features, self.output_features, self.bias is not None - ) - - -def run_master(_, args): - d_hid = 100 - bs = 400 - chunks = 4 - batches = 1 - - MULTI_USE_PARAM_CONFIG = ( - MultiUseParameterConfig.REPLICATE - if args.replicate - else MultiUseParameterConfig.TRANSMIT - ) - print(f"REPLICATE config: {args.replicate} -> {MULTI_USE_PARAM_CONFIG}") - print("Using schedule:", args.schedule) - - class ExampleCode(torch.nn.Module): - def __init__(self): - super().__init__() - self.l1 = MyLinear(d_hid, d_hid) - self.l2 = MyLinear(d_hid, d_hid) - self.l3 = MyLinear(d_hid, d_hid) - self.l4 = MyLinear(d_hid, d_hid) - - def forward(self, x): - x = self.l1(x) - pipe_split() - x = self.l2(x) - pipe_split() - x = self.l3(x) - pipe_split() - x = self.l4(x) - return x - - ec = ExampleCode() - ec.to(args.device) - - mse_loss = SlowMSELoss(reduction="sum") - wrapper = TrivialLossWrapper(ec, mse_loss) - ec_pipe = Pipe.from_tracing(wrapper, MULTI_USE_PARAM_CONFIG) - - all_ranks = list(range(1, args.world_size)) # exclude master rank = 0 - pipe_driver: PipelineDriverBase = schedules[args.schedule]( - ec_pipe, - chunks, - args.world_size - 1, - all_ranks=all_ranks, - _debug_mask_minibatches=True, - _record_mem_dumps=bool(args.record_mem_dumps), - checkpoint=bool(args.checkpoint), - ) - - ec_input = torch.randn(bs, d_hid, device=args.device) - target = torch.randn(bs, d_hid, device=args.device) - - pipe_visualized_filename = "pipe_visualized.json" - batches_events_contexts = [] - for i in range(batches): - pipe_driver(ec_input, target) - batches_events_contexts.append(pipe_driver.retrieve_events()) - - # first: save file - all_events_contexts: EventsContext = reduce( - lambda c1, c2: EventsContext().update(c1).update(c2), - batches_events_contexts, - EventsContext(), - ) - with open(pipe_visualized_filename, "w") as f: - f.write(events_to_json(all_events_contexts)) - - # TODO: Investigate flakiness! TODO(https://github.com/pytorch/PiPPy/issues/136) - # # second: perform checks - # for events_context in batches_events_contexts: - # check_events_for_single_batch(events_context.events, all_ranks, chunks, pipe_visualized_filename) - - -def check_events_for_single_batch( - events: List[Event], - all_stages: List[int], - chunks: int, - pipe_visualized_filename: str, -): - events_by_type_by_rank_by_mbid: Dict[ - Any, Dict[Any, Dict[Any, Event]] - ] = defaultdict(lambda: defaultdict(lambda: dict())) - for event in events: - events_by_type_by_rank_by_mbid[event.type][event.rank][ - event.mbid - ] = event - - def start_ts(e: Event, eps=0.1): - return e.start_ts + (e.finish_ts - e.start_ts) * eps - - def finish_ts(e: Event, eps=0.1): - return e.finish_ts - (e.finish_ts - e.start_ts) * eps - - # Basic happens-before cross rank checks - for i in range(len(all_stages) - 1): - rank = all_stages[i] - next_rank = all_stages[i + 1] - for mbid in range(chunks): - rank_forward = events_by_type_by_rank_by_mbid[Phase.FORWARD][rank][ - mbid - ] - next_rank_forward = events_by_type_by_rank_by_mbid[Phase.FORWARD][ - next_rank - ][mbid] - # happens-before cross-rank forward check - assert start_ts(next_rank_forward) >= finish_ts(rank_forward), ( - f"{rank_forward.name}({rank_forward.finish_ts}) must happen before " - f"{next_rank_forward.name}({next_rank_forward.start_ts}), see {pipe_visualized_filename}" - ) - - rank_backward = events_by_type_by_rank_by_mbid[Phase.BACKWARD][ - next_rank - ][mbid] - next_rank_backward = events_by_type_by_rank_by_mbid[Phase.BACKWARD][ - rank - ][mbid] - # happens-before cross-rank backward check - assert start_ts(next_rank_backward) >= finish_ts(rank_backward), ( - f"{rank_backward.name}({rank_backward.finish_ts}) must happen before " - f"{next_rank_backward.name}({next_rank_backward.start_ts}), see {pipe_visualized_filename}" - ) - - # Basic happens-before cross-microbatch checks - for mbid in range(chunks - 1): - next_mbid = mbid + 1 - for i in range(len(all_stages) - 1): - rank = all_stages[i] - rank_forward = events_by_type_by_rank_by_mbid[Phase.FORWARD][rank][ - mbid - ] - next_mbid_forward = events_by_type_by_rank_by_mbid[Phase.FORWARD][ - rank - ][next_mbid] - # happens-before cross-microbatch forward check - assert start_ts(next_mbid_forward) >= finish_ts(rank_forward), ( - f"{rank_forward.name}({rank_forward.finish_ts}) must happen before " - f"{next_mbid_forward.name}({next_mbid_forward.start_ts}), see {pipe_visualized_filename}" - ) - - rank_backward = events_by_type_by_rank_by_mbid[Phase.BACKWARD][ - rank - ][mbid] - next_mbid_backward = events_by_type_by_rank_by_mbid[Phase.BACKWARD][ - rank - ][next_mbid] - # happens-before cross-microbatch backward check - assert start_ts(next_mbid_backward) >= finish_ts(rank_backward), ( - f"{rank_backward.name}({rank_backward.finish_ts}) must happen before " - f"{next_mbid_backward.name}({next_mbid_backward.start_ts}), see {pipe_visualized_filename}" - ) - - # Overlap checks - for mbid in range(chunks - 1): - next_mbid = mbid + 1 - last_forward = events_by_type_by_rank_by_mbid[Phase.FORWARD][ - all_stages[-1] - ][mbid] - first_next_forward = events_by_type_by_rank_by_mbid[Phase.FORWARD][ - all_stages[0] - ][next_mbid] - # cross-microbatch forward overlap check - assert ( - last_forward.finish_ts >= first_next_forward.start_ts - ), f"Forward microbatch {mbid} doesn't overlap with next microbatch {next_mbid}" - - last_backward = events_by_type_by_rank_by_mbid[Phase.BACKWARD][ - all_stages[0] - ][mbid] - first_next_backward = events_by_type_by_rank_by_mbid[Phase.BACKWARD][ - all_stages[-1] - ][next_mbid] - # cross-microbatch forward overlap check - assert ( - last_backward.finish_ts >= first_next_backward.start_ts - ), f"Backward microbatch {mbid} doesn't overlap with next microbatch {next_mbid}" - - -def main(args=None): - parser = argparse.ArgumentParser() - parser.add_argument( - "--world_size", type=int, default=int(os.getenv("WORLD_SIZE", 5)) - ) - parser.add_argument("--rank", type=int, default=int(os.getenv("RANK", -1))) - parser.add_argument( - "--master_addr", type=str, default=os.getenv("MASTER_ADDR", "localhost") - ) - parser.add_argument( - "--master_port", type=str, default=os.getenv("MASTER_PORT", "29500") - ) - parser.add_argument( - "-s", - "--schedule", - type=str, - default=list(schedules.keys())[0], - choices=schedules.keys(), - ) - parser.add_argument( - "--replicate", type=int, default=int(os.getenv("REPLICATE", "0")) - ) - parser.add_argument( - "--cuda", type=int, default=int(torch.cuda.is_available()) - ) - parser.add_argument( - "--record_mem_dumps", type=int, default=0, choices=[0, 1] - ) - parser.add_argument("--checkpoint", type=int, default=0, choices=[0, 1]) - args = parser.parse_args(args) - - run_pippy(run_master, args) - - -if __name__ == "__main__": - main() - - -class LocalTestVisualizer(unittest.TestCase): - def test_visualizer(self): - import random - - port = random.randint(29500, 30000) - args = [ - "--master_port", - str(port), - ] - main(args) diff --git a/test/local_test_c10d_bwd.py b/test/test_bwd.py similarity index 72% rename from test/local_test_c10d_bwd.py rename to test/test_bwd.py index 21db17240..fb5f45aa0 100644 --- a/test/local_test_c10d_bwd.py +++ b/test/test_bwd.py @@ -3,12 +3,16 @@ import os import unittest +import pippy + import torch import torch.distributed as dist +from pippy.IR import Pipe, pipe_split +from pippy.microbatch import sum_reducer, TensorChunkSpec +from pippy.PipelineStage import PipelineStage -from pippy.compile import compile_stage -from pippy.IR import pipe_split +pippy.microbatch._debug_mask_minibatches = True schedules = [ "FillDrain", @@ -16,11 +20,12 @@ ] d_hid = 512 -chunk_size = 256 +batch_size = 256 torch.manual_seed(0) +# Basic example class ExampleCode(torch.nn.Module): def __init__(self): super().__init__() @@ -29,7 +34,7 @@ def __init__(self): self.lin = torch.nn.Linear(d_hid, d_hid) self.mse_loss = torch.nn.MSELoss(reduction="sum") - def forward(self, x, target): + def forward(self, x, y): x = torch.mm(x, self.mm_param) skip_connection = x x = torch.relu(x) @@ -42,35 +47,41 @@ def forward(self, x, target): x = torch.mm(x, self.mm_param2) pipe_split() x = self.lin(x) - x = torch.relu(x) - loss = self.mse_loss(x, target) - return {"logits": x, "loss": loss} + logits = torch.relu(x) + loss = self.mse_loss(x, y) + return logits, loss def run_worker(args): - ec = ExampleCode() - ec.to(args.device) - ec.train() + mod = ExampleCode() + mod.to(args.device) - ec_x = torch.randn(args.chunks * chunk_size, d_hid, device=args.device) - target = torch.randn(args.chunks * chunk_size, d_hid, device=args.device) + x = torch.randn(batch_size, d_hid, device=args.device) + y = torch.randn(batch_size, d_hid, device=args.device) - stage = compile_stage( - ec, - args.rank, - args.world_size, + output_chunk_spec = ( + TensorChunkSpec(0), # logits + sum_reducer, # loss + ) + + pipe = Pipe.from_tracing( + mod, args.chunks, - args.device, - None, - [ec_x, target], - schedule=args.schedule, + example_args=(x, y), + output_chunk_spec=output_chunk_spec, + ) + + stage = PipelineStage( + pipe, + args.rank, + device=args.device, ) # Run if args.rank == 0: - out = stage(ec_x) + out = stage(x) elif args.rank == args.world_size - 1: - out = stage(target) + out = stage(y) else: stage() @@ -79,11 +90,9 @@ def run_worker(args): # Last rank checks result if args.rank == args.world_size - 1: - ref_out = ec(ec_x, target) + ref_out = mod(x, y) torch.testing.assert_close(out, ref_out) - print( - f"equivalence test passed, loss = {out['loss']}, ref loss = {ref_out['loss']}" - ) + print(f"equivalence test passed loss={out[1]} ref_loss={ref_out[1]}") def main(args=None): @@ -135,8 +144,8 @@ def main(args=None): main() -class LocalTestC10DBwdTest(unittest.TestCase): - def test_c10d_bwd(self): +class TestBwd(unittest.TestCase): + def test_bwd(self): import random port = random.randint(29500, 30000) diff --git a/test/local_test_c10d.py b/test/test_fwd.py similarity index 82% rename from test/local_test_c10d.py rename to test/test_fwd.py index 5044c50d3..675844170 100644 --- a/test/local_test_c10d.py +++ b/test/test_fwd.py @@ -3,15 +3,18 @@ import os import unittest +import pippy + import torch import torch.distributed as dist +from pippy.IR import Pipe, pipe_split +from pippy.PipelineStage import PipelineStage -from pippy.compile import compile_stage -from pippy.IR import pipe_split +pippy.microbatch._debug_mask_minibatches = True d_hid = 512 -chunk_size = 256 +batch_size = 256 torch.manual_seed(0) @@ -42,25 +45,27 @@ def forward(self, x, y): def run_worker(args): - ec = ExampleCode() - ec.to(args.device) + mod = ExampleCode() + mod.to(args.device) - ec_x = torch.randn(args.chunks * chunk_size, d_hid, device=args.device) - ec_y = torch.randn(args.chunks * chunk_size, d_hid, device=args.device) + x = torch.randn(batch_size, d_hid, device=args.device) + y = torch.randn(batch_size, d_hid, device=args.device) - stage = compile_stage( - ec, - args.rank, - args.world_size, + pipe = Pipe.from_tracing( + mod, args.chunks, - args.device, - None, - [ec_x, ec_y], + example_args=(x, y), + ) + + stage = PipelineStage( + pipe, + args.rank, + device=args.device, ) # Run if args.rank == 0: - out = stage(ec_x, ec_y) + stage(x, y) elif args.rank == args.world_size - 1: out = stage() else: @@ -71,7 +76,7 @@ def run_worker(args): # Last rank checks result if args.rank == args.world_size - 1: - ref_out = ec(ec_x, ec_y) + ref_out = mod(x, y) torch.testing.assert_close(out, ref_out) print( f"equivalence test passed {torch.sum(out)} ref {torch.sum(ref_out)}" @@ -121,8 +126,8 @@ def main(args=None): main() -class LocalTestC10DTest(unittest.TestCase): - def test_c10d(self): +class TestFwd(unittest.TestCase): + def test_fwd(self): import random port = random.randint(29500, 30000) diff --git a/test/test_fx.py b/test/test_fx.py deleted file mode 100644 index b366395ec..000000000 --- a/test/test_fx.py +++ /dev/null @@ -1,4658 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -# Owner(s): ["module: fx"] - -import builtins -import contextlib -import copy -import functools -import inspect -import io -import math -import numbers -import operator -import os -import pickle -import sys -import traceback -import types -import typing -import unittest -import warnings -from collections import namedtuple -from copy import deepcopy -from math import sqrt - -from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Union - -import pippy -import pippy.fx._pytree as fx_pytree - -import torch -import torch.nn.utils._stateless as _stateless -import torch.utils._pytree as pytree - -from fx.named_tup import MyNamedTup -from pippy.fx import ( - CodeGen, - Graph, - GraphModule, - Interpreter, - Node, - PH, - Proxy, - symbolic_trace, - Tracer, - Transformer, - wrap, -) -from pippy.fx._compatibility import ( - _BACK_COMPAT_OBJECTS, - _MARKED_WITH_COMATIBLITY, -) -from pippy.fx.experimental.rewriter import RewritingTracer -from pippy.fx.immutable_collections import immutable_dict, immutable_list -from pippy.fx.node import _format_arg, Argument, Target -from pippy.fx.operator_schemas import get_signature_for_torch_op -from pippy.fx.passes import shape_prop -from pippy.fx.proxy import TraceError -from torch.multiprocessing import Process -from torch.testing import FileCheck -from torch.testing._internal.common_device_type import ( - instantiate_device_type_tests, - onlyCPU, - ops, -) -from torch.testing._internal.common_methods_invocations import op_db -from torch.testing._internal.common_utils import ( - find_library_location, - IS_FBCODE, - IS_MACOS, - IS_WINDOWS, - run_tests, - skipIfSlowGradcheckEnv, -) -from torch.testing._internal.jit_utils import JitTestCase - -try: - from torchvision import models as torchvision_models - - HAS_TORCHVISION = True -except ImportError: - HAS_TORCHVISION = False -skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision") - - -class SimpleTest(torch.nn.Module): - def forward(self, x): - return torch.relu(x + 3.0) - - -def a_non_torch_leaf(a, b): - return a + b - - -# Used for test_autowrap_function. Autowrapped functions need to be global -def fx_int(x: float) -> int: - return int(x) - - -def fx_int_x2(x: float) -> int: - return int(x) * 2 - - -# used in test_pytree. It's all the way out here because pickling a GraphModule -# that uses Point errors out if Point is local to the function -Point = namedtuple("Point", ["x", "y"]) - - -# Test wrap() passing both a function name as well as a function -# directly -def a_lifted_leaf(a, b): - return a[0] + a[1] + b - - -wrap("a_lifted_leaf") -# Test wrapping twice doesn't break anything -wrap("a_lifted_leaf") - - -def a_lifted_leaf2(a, b): - return a[0] + a[1] + b - - -wrap(a_lifted_leaf2) - -wrap("len") - -wrap("getattr") - - -def wrapped_named_tup(p1, *, p2): - return p1.x + p2.y - - -wrap(wrapped_named_tup) - - -@wrap -def wrapped_via_decorator(a): - return a + 1 - - -wrap("wrapped_with_submodule") - - -def wrapped_with_submodule(x: torch.Tensor, batchnorm1d: torch.nn.BatchNorm1d): - return batchnorm1d(x) - - -def my_decorator(f): - @functools.wraps(f) - def wrapper_inside_decorator(*args, **kwargs): - return f(*args, **kwargs) - - return wrapper_inside_decorator - - -@wrap -@my_decorator -def wrapped_decorated_fn(x): - return x - - -real_wrapped_via_decorator = wrapped_via_decorator -real_a_lifed_leaf = a_lifted_leaf -real_a_lifed_leaf2 = a_lifted_leaf2 -_sqrt = sqrt - -wrap("wrapper_fn") - - -def wrapper_fn(x): - return torch.foo(x) - - -class Pair(NamedTuple): - x: torch.Tensor - y: torch.Tensor - - def _custom_fx_repr_fn(self) -> str: - return f"Pair(x={_format_arg(self.x)}, y={_format_arg(self.y)})" - - -# for testing pytrees -class Foo(object): # noqa: B209 - def __init__(self, a, b): - self.a = a - self.b = b - - -class TestFX(JitTestCase): - def setUp(self): - super().setUp() - # Checking for mutable operations whil tracing is feature flagged - # Enable it in testing but not by default - self.orig_tracer_mutable_flag = ( - pippy.fx.proxy.TracerBase.check_mutable_operations - ) - pippy.fx.proxy.TracerBase.check_mutable_operations = True - - if not (IS_FBCODE or IS_WINDOWS or IS_MACOS): - lib_file_path = find_library_location("libtorchbind_test.so") - torch.ops.load_library(str(lib_file_path)) - - def tearDown(self): - super().tearDown() - pippy.fx.proxy.TracerBase.check_mutable_operations = ( - self.orig_tracer_mutable_flag - ) - - def checkGraphModule(self, m: torch.nn.Module, args, kwargs=None): - """Check that an nn.Module's results match the GraphModule version - for a given set of args/kwargs. - """ - kwargs = kwargs if kwargs else {} - ref_outs = m(*args, **kwargs) - gm = symbolic_trace(m) - gm.graph.lint() - test_outs = gm(*args, **kwargs) - self.assertEqual(ref_outs, test_outs) - - def test_graph_module(self): - class MySub(torch.nn.Module): - def __init__(self): - super().__init__() - self.w = torch.nn.Parameter(torch.rand(4, 3)) - - def forward(self, x): - return self.w + x - - class MyModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.lin = torch.nn.Linear(4, 3) - self.sub_mod = MySub() - self.w = torch.nn.Parameter(torch.rand(3)) - - def forward(self, A, B, c): - t = torch.sigmoid(A) + self.lin(c) - return self.sub_mod( - t.data - + self.w - + t - + 1 - - A - + B // A - + -A - + A.add(B, alpha=3) - ) - - m = MyModule() - gm = symbolic_trace(m) - - ms = torch.jit.script(gm) - - class M2(torch.nn.Module): - def forward(self, A): - m, idx = torch.max(A, 0) - return m + 1, idx + 1 - - m2 = M2() - gm2 = symbolic_trace(m2) - - class T(torch.nn.Module): - def forward(self, A, b=4, *args, c=5, **kwargs): - x = A + 1 + args[0] + kwargs["3"] - return x - - t = T() - symbolic_trace(t) - - # test for issue described at https://github.com/pytorch/pytorch/issues/63883 - class M3(torch.nn.Module): - def forward(self, x): - return torch.relu(x) - - m3 = M3() - gm3 = symbolic_trace(m3) - new_instance = gm3.__new__(type(gm3)) - new_instance.__init__(gm3, gm3.graph) - - x = torch.randn(5, 3) - torch.testing.assert_allclose(new_instance(x), torch.relu(x)) - - def test_custom_import(self): - graph = pippy.fx.Graph() - a = graph.placeholder("x") - b = graph.placeholder("y") - c = graph.call_function(a_non_torch_leaf, (a, b)) - d = graph.call_function(torch.sin, (c,)) - graph.output(d) - gm = GraphModule(torch.nn.Module(), graph) - x, y = torch.rand(1), torch.rand(1) - self.assertEqual(torch.sin(x + y), gm(x, y)) - - def test_args_kwargs(self): - class T(torch.nn.Module): - def forward(self, *args, **kwargs): - x = args[0] + kwargs["foo"] - return x - - t = T() - self.checkGraphModule( - t, (torch.rand(1), torch.rand(1)), {"foo": torch.rand(1)} - ) - - def test_args_kwargs_no_self(self): - class T(torch.nn.Module): - def forward(*args, **kwargs): # noqa: B902 - self = args[0] - return torch.relu(args[1]) - - t = T() - with self.assertRaisesRegex( - RuntimeError, r"cannot be part of \*args expansion" - ): - self.checkGraphModule( - t, (torch.rand(1), torch.rand(1)), {"foo": torch.rand(1)} - ) - - def test_fx_shifts(self): - class MyModule(torch.nn.Module): - def forward(self, x): - return x << 3, x >> 3 - - input = torch.LongTensor(10).random_(0, 1024) - - m = MyModule() - self.checkGraphModule(m, (input,)) - - def test_fx_and_or(self): - class MyModule(torch.nn.Module): - def forward(self, x): - return x & x, x | x - - input = torch.LongTensor(10).random_(0, 1024) - - m = MyModule() - self.checkGraphModule(m, (input,)) - - def test_dict(self): - class MyDictMod(torch.nn.Module): - def forward(self, d): - return d["3"].relu(), {"4": d["3"].neg()} - - input_dict = {"3": torch.rand(3, 4)} - m = MyDictMod() - - self.checkGraphModule(m, (input_dict,)) - - def test_matmul_tracing(self): - const = torch.randn(3) - - def matmul_f(x): - return x @ const - - mod = symbolic_trace(matmul_f) - inp = torch.randn(3) - self.assertEqual(mod(inp), matmul_f(inp)) - - def rmatmul_f(x): - return const @ x - - mod = symbolic_trace(rmatmul_f) - inp = torch.randn(3) - self.assertEqual(mod(inp), rmatmul_f(inp)) - - def test_disallow_override(self): - # Custom delegate to disallow in-place tensor operations - class NoMutableCallTracer(Tracer): - def create_node( - self, - kind: str, - target: Union[str, Callable], - args: Tuple[Argument, ...], - kwargs: Dict[str, Any], - name: Optional[str] = None, - type_expr: Optional[Any] = None, - ) -> Node: - name = ( - target - if isinstance(target, str) - else torch.typename(target) - ) - if name[-1] == "_": - raise RuntimeError("In-place operations are not supported") - return super().create_node(kind, target, args, kwargs, name) - - # Test method - class MyInplaceMod(torch.nn.Module): - def forward(self, x): - x.add_(3.0) - return x - - m = MyInplaceMod() - - with self.assertRaisesRegex(RuntimeError, "In-place operations"): - NoMutableCallTracer().trace(m) - - # Test free function - class MyInplaceMod2(torch.nn.Module): - def forward(self, x): - torch.log_(x) - return x - - m2 = MyInplaceMod2() - with self.assertRaisesRegex(RuntimeError, "In-place operations"): - NoMutableCallTracer().trace(m2) - - # Test symbolic node as an arg - class MyInplaceMod3(torch.nn.Module): - def forward(self, x): - y = torch.ones(3, 4) - y.add_(x) - return x - - m3 = MyInplaceMod3() - with self.assertRaisesRegex(RuntimeError, "In-place operations"): - NoMutableCallTracer().trace(m3) - - def test_leaf_module(self): - # Custom delegate to make it so that there are no leaf modules, everything - # should get traced through - class NoLeafModulesTracer(Tracer): - def is_leaf_module(self, m, qualname): - return False - - class MyReluMod(torch.nn.Module): - def __init__(self): - super().__init__() - self.relu = torch.nn.ReLU() - - def forward(self, x): - return self.relu(x) - - mrm = MyReluMod() - sym = NoLeafModulesTracer().trace(mrm) - for node in sym.nodes: - self.assertNotEqual(node.op, "call_module") - sym.lint() - - def test_wrap(self): - self.assertEqual(3 + 4 + 5, a_lifted_leaf((3, 4), 5)) - - def to_trace(y): - return ( - a_lifted_leaf((4, y), 3) - + a_lifted_leaf((3, 4), 5) - + a_lifted_leaf((y, y), y) - ) - - m = symbolic_trace(to_trace) - self.assertIn("a_lifted_leaf", m.code) - self.assertEqual(27, m(2)) - self.assertIs(a_lifted_leaf, real_a_lifed_leaf) - - def test_wrap_fn_directly(self): - self.assertEqual(3 + 4 + 5, a_lifted_leaf2((3, 4), 5)) - - def to_trace(y): - return ( - a_lifted_leaf2((4, y), 3) - + a_lifted_leaf2((3, 4), 5) - + a_lifted_leaf2((y, y), y) - ) - - m = symbolic_trace(to_trace) - self.assertIn("a_lifted_leaf2", m.code) - self.assertEqual(27, m(2)) - self.assertIs(a_lifted_leaf2, real_a_lifed_leaf2) - - def test_wrapped_via_decorator(self): - self.assertEqual(wrapped_via_decorator(0), 1) - - def to_trace(y): - return wrapped_via_decorator(y) - - m = symbolic_trace(to_trace) - self.assertIn("wrapped_via_decorator", m.code) - self.assertEqual(m(0), 1) - self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator) - self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched")) - - def test_wrapped_via_decorator_and_transformed(self): - self.assertEqual(wrapped_via_decorator(0), 1) - - def to_trace(y): - return wrapped_via_decorator(y) - - m = symbolic_trace(to_trace) - self.assertIn("wrapped_via_decorator", m.code) - self.assertEqual(m(0), 1) - self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator) - self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched")) - - transformed = pippy.fx.Transformer(m).transform() - self.assertIn("wrapped_via_decorator", transformed.code) - self.assertEqual(transformed(0), 1) - self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator) - self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched")) - - def test_wrap_with_submodule(self): - class M(torch.nn.Module): - def __init__(self): - super(M, self).__init__() - self.batchnorm1d = torch.nn.BatchNorm1d(2, affine=False) - - def forward(self, x: torch.Tensor): - return wrapped_with_submodule(x, self.batchnorm1d) - - m = symbolic_trace(M()) - - self.assertIn("wrapped_with_submodule", m.code) - - input = torch.rand(3, 2) - ref_batchnorm1d = torch.nn.BatchNorm1d(2, affine=False) - self.assertEqual(ref_batchnorm1d(input), m(input)) - - def test_wrapped_retrace(self): - def to_trace(y): - return wrapped_via_decorator(y) - - m = symbolic_trace(to_trace) - self.assertIn("wrapped_via_decorator", m.code) - self.assertEqual(m(0), 1) - - retraced = symbolic_trace(m) - self.assertIn("wrapped_via_decorator", retraced.code) - self.assertEqual(retraced(0), 1) - - def test_wrap_decorated_function(self): - def to_trace(y): - return wrapped_decorated_fn(y) - - m = symbolic_trace(to_trace) - self.assertIn("wrapped_decorated_fn", m.code) - self.assertEqual(m(1), 1) - - def test_graph_edit_with_proxy(self): - class M(torch.nn.Module): - def forward(self, a, b): - return a + b - - m = M() - g = symbolic_trace(m).graph - new_g = pippy.fx.Graph() - val_map: Dict[Node, Node] = {} - output_val = new_g.graph_copy(g, val_map) - t = Proxy(output_val) - # test that we can use proxy objects to generate more graph code later for things that do not need to work with modules. - new_g.output((t + t).node) - gm = GraphModule(m, new_g) - gm.graph.lint() - self.assertEqual(gm(3, 4), 14) - - def test_concrete_arg_none_assert(self): - class Foo(torch.nn.Module): - def forward(self, x, val=None): - return x if val is None else x + val - - f = Foo() - traced = pippy.fx.symbolic_trace(f, concrete_args={"val": None}) - with self.assertRaisesRegex( - AssertionError, "val has been specialized to have value None" - ): - traced(torch.randn(5), torch.randn(5)) - - x = torch.randn(5) - torch.testing.assert_close(traced(x), f(x)) - - def test_trace_multiple_funcs(self): - class Foo(torch.nn.Module): - def forward(self, x, y): - return x + y - - def minus_forward(self, x, y): - return x - y - - def multiply_forward(self, x, y): - return x * y - - f = Foo() - x, y = torch.randn(5), torch.randn(5) - - print(torch.__version__) - - tracer = Tracer() - torch.testing.assert_close( - GraphModule(f, tracer.trace(f))(x, y), f(x, y) - ) - - tracer.traced_func_name = "minus_forward" - torch.testing.assert_close( - GraphModule(f, tracer.trace(f))(x, y), - f.minus_forward(x, y), - ) - - tracer.traced_func_name = "multiply_forward" - torch.testing.assert_close( - GraphModule(f, tracer.trace(f))(x, y), - f.multiply_forward(x, y), - ) - - tracer.traced_func_name = "add_forward" - with self.assertRaisesRegex(AssertionError, "doesn't exist in"): - tracer.trace(f) - - def test_graph_unique_names(self): - class M(torch.nn.Module): - def forward(self, a, b): - return a + b - - m = M() - g = symbolic_trace(m).graph - new_g = pippy.fx.Graph() - val_map: Dict[Node, Node] = {} - output_val = new_g.graph_copy(g, val_map) - t = Proxy(output_val) - # test that we can use proxy objects to generate more graph code later for things that do not need to work with modules. - new_g.output((t + t).node) - gm = GraphModule(m, new_g) - seen_names: Set[str] = set() - for node in gm.graph.nodes: - assert node.name not in seen_names - seen_names.add(node.name) - - def test_stack_traces(self): - class M(torch.nn.Module): - def forward(self, a, b): - return a + b - - tracer = pippy.fx.Tracer() - tracer.record_stack_traces = True - - graph = tracer.trace(M()) - # saving the original list because we will insert new nodes as a part of a test - orig_graph_nodes = list(graph.nodes) - for node in orig_graph_nodes: - if node.op == "output": - continue - self.assertTrue(node.stack_trace is not None) - assert "test_fx.py" in node.stack_trace - - # verify that copying the node does not lose the stack trace - new_node = graph.node_copy(node) - self.assertTrue(new_node.stack_trace is not None) - assert "test_fx.py" in new_node.stack_trace - - def test_stack_traces_with_transformer(self): - class M(torch.nn.Module): - def forward(self, a, b): - return a + b - - tracer = pippy.fx.Tracer() - tracer.record_stack_traces = True - - graph = tracer.trace(M()) - gm = GraphModule(tracer.root, graph) - new_gm = Transformer(gm).transform() - - # nodes after Transformer should still preserve the original node's stack trace - for node in new_gm.graph.nodes: - if node.op in {"placeholder", "output"}: - continue - self.assertTrue(node.stack_trace is not None) - assert "test_fx.py" in node.stack_trace - - def test_graph_unique_names_manual(self): - graph: pippy.fx.Graph = pippy.fx.Graph() - a: pippy.fx.Node = graph.create_node("placeholder", "x") - b: pippy.fx.Node = graph.create_node( - "call_module", "linear_mod", args=(a,), name="foo_1_1" - ) - c: pippy.fx.Node = graph.create_node("get_attr", "y_attr", name="foo_1") - d: pippy.fx.Node = graph.create_node( - "call_function", operator.add, args=(b, c) - ) - graph.output(d) - graph2 = pippy.fx.Graph() - val_map: Dict[Node, Node] = {} - graph2.graph_copy(graph, val_map) - seen_names: Set[str] = set() - for node in graph2.nodes: - assert node.name not in seen_names - seen_names.add(node.name) - - def test_unpack(self): - class M(torch.nn.Module): - def forward(self, a, b): - c, d = a - return c + d + b - - a = (torch.rand(1), torch.rand(1)) - b = torch.rand(1) - m = M() - self.checkGraphModule(m, (a, b)) - - def test_native_callable(self): - if IS_FBCODE or IS_WINDOWS or IS_MACOS: - raise unittest.SkipTest( - "non-portable load_library call used in test" - ) - # This test exercises the case where we use FX to translate from Python - # code to some native callable object - # - # For the purposes of testing, we use ElementwiseInterpreter defined - # in test_custom_class.cpp. - # - # We test that we can - # 1) Construct a native callable from FX IR - # 2) Construct a drop-in replacement module that delegates to the - # native callable rather than the original code - # 3) Run both the original code and native callable wrapper with - # equivalent results - # 4) TorchScript compile the native callable wrapper and confirm - # equivalent results with the reference - # 5) TorchScript serialize and deserialize the native callable - # and confirm equivalent results with the reference - - # We use this simple Module as a reference computation - class MySimpleMod(torch.nn.Module): - def forward(self, x): - return 3.0 * x + x - - msm = MySimpleMod() - - # This is what a lowering pass might look like: a function that takes - # a valid nn.Module, symbolically traces it, lowers the Module to some - # representation, and wraps that representation up into another - # nn.Module instance that handles dispatch to the compiled/lowered code. - def lower_to_elementwise_interpreter( - orig_mod: torch.nn.Module, - ) -> torch.nn.Module: - # ===== Stage 1: Symbolic trace the module ===== - mod = symbolic_trace(orig_mod) - - # ===== Stage 2: Lower GraphModule representation to the C++ - # interpreter's instruction format ====== - instructions = [] - constant_idx = 0 - constants = {} - fn_input_names = [] - - target_to_name = {operator.add: "add", operator.mul: "mul"} - - output_node: Optional[Node] = None - # For each instruction, create a triple - # (instruction_name : str, inputs : List[str], output : str) - # to feed into the C++ interpreter - for n in mod.graph.nodes: - target, args, out_name = n.target, n.args, n.name - assert len(n.kwargs) == 0, "kwargs currently not supported" - - if n.op == "placeholder": - # Placeholders specify function argument names. Save these - # for later when we generate the wrapper GraphModule - fn_input_names.append(target) - elif n.op == "call_function": - assert target in target_to_name, ( - "Unsupported call target " + target - ) - arg_names = [] - for arg in args: - if not isinstance(arg, Node): - # Pull out constants. These constants will later be - # fed to the interpreter C++ object via add_constant() - arg_name = f"constant_{constant_idx}" - constants[arg_name] = torch.tensor( - [arg] - if isinstance(arg, numbers.Number) - else arg - ) - arg_names.append(arg_name) - constant_idx += 1 - else: - arg_names.append(arg.name) - instructions.append( - (target_to_name[target], arg_names, out_name) - ) - elif n.op == "output": - if output_node is not None: - raise RuntimeError("Multiple output nodes!") - output_node = n - else: - raise RuntimeError("Unsupported opcode " + n.op) - - interpreter = ( - torch.classes._TorchScriptTesting._ElementwiseInterpreter() - ) - # Load constants - for k, v in constants.items(): - interpreter.add_constant(k, v) - # Specify names for positional input arguments - interpreter.set_input_names(fn_input_names) - # Load instructions - interpreter.set_instructions(instructions) - # Specify name for single output - assert isinstance(output_node.args[0], pippy.fx.Node) - interpreter.set_output_name(output_node.args[0].name) - - # ===== Stage 3: Create a wrapper GraphModule around the interpreter ===== - class WrapperModule(torch.nn.Module): - def __init__(self, interpreter): - super().__init__() - self.interpreter = interpreter - - wrapper = WrapperModule(interpreter) - - # Create a graph that: 1) Takes function arguments 2) Invokes the interpreter - # 3) Returns the speficied return value - - # FIXME: The following code could be greatly simplified by symbolic_trace'ing - # the wrapper with a Tracer that considers the Wrapper instance a root - # module, however, I can't get `__call__` exposed on TorchBind classes - # without it messing up Python `hasattr` for some reason. More digging - # into CPython's implementation of hasattr is probably in order... - - graph = pippy.fx.Graph() - # Add placeholders for fn inputs - placeholder_nodes = [] - for name in fn_input_names: - placeholder_nodes.append(graph.create_node("placeholder", name)) - - # Get the interpreter object - interpreter_node = graph.create_node("get_attr", "interpreter") - - # Add a node to call the interpreter instance - output_node = graph.create_node( - op="call_method", - target="__call__", - args=(interpreter_node, placeholder_nodes), - ) - - # Register output - graph.output(output_node) - - graph.lint() - - # Return final GraphModule!!! - return GraphModule(wrapper, graph) - - # Lower GraphModule to C++ interpreter - lowered = lower_to_elementwise_interpreter(msm) - - # Compare correctness with original module - x = torch.rand(3, 4) - ref_out = msm(x) - test_out = lowered(x) - torch.testing.assert_close(test_out, ref_out) - - # Test TorchScript compilation - scripted_lowered = torch.jit.script(lowered) - script_out = scripted_lowered(x) - torch.testing.assert_close(script_out, ref_out) - - # Test TorchScript ser/de - import_copy = self.getExportImportCopy(scripted_lowered) - imported_out = import_copy(x) - torch.testing.assert_close(imported_out, ref_out) - - def test_reserved_getattr(self): - """Ensure that we do not name any nodes with a reserved builtin like `getattr`""" - - class M(torch.nn.Module): - def forward(self, a): - return a.foo.bar.baz - - m = M() - m_g = symbolic_trace(m) - m_g.graph.lint() - for node in m_g.graph.nodes: - self.assertTrue(node.name != "getattr") - - @unittest.skip("Hotfix for SEV remediation") - def test_trace_buffer_slice(self): - bs, d_hid = 10, 23 - - class ExampleCode(torch.nn.Module): - def __init__(self): - super().__init__() - self.mm_param = torch.nn.Parameter(torch.randn(d_hid, d_hid)) - self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) - self.lin = torch.nn.Linear(d_hid, d_hid) - self.register_buffer("buffer", torch.randn(bs + 100, d_hid)) - - def forward(self, x): - x = torch.mm(x, self.mm_param) - skip_connection = x - x = torch.relu(x) - x = torch.mm(x, self.mm_param) + self.buffer[: x.shape[0]] - x = self.lin(x) - x = torch.relu(x) - x = x + skip_connection - x = torch.mm(x, self.mm_param2) - x = self.lin(x) - return x - - ec = ExampleCode() - - traced = pippy.fx.symbolic_trace(ec) - - x = torch.randn(bs, d_hid) - torch.testing.assert_allclose(ec(x), traced(x)) - - def test_node_tagging(self): - class TaggingTracer(Tracer): - def create_node( - self, - kind: str, - target: Union[str, Callable], - args: Tuple[Argument, ...], - kwargs: Dict[str, Any], - name: Optional[str] = None, - type_expr: Optional[Any] = None, - ) -> Node: - n = super().create_node(kind, target, args, kwargs, name) - n.tag = "foo" - return n - - class M(torch.nn.Module): - def forward(self, a, b): - return a + b - - m = M() - g = TaggingTracer().trace(m) - g.lint() - for n in g.nodes: - self.assertTrue(hasattr(n, "tag")) - self.assertEqual(n.tag, "foo") - - def test_tensor_attribute(self): - class TensorAttribute(torch.nn.Module): - def __init__(self): - super().__init__() - self.tensor = torch.rand(3, 4) - - def forward(self, x): - return torch.nn.functional.linear(x, self.tensor) - - ta = TensorAttribute() - traced = symbolic_trace(ta) - traced(torch.rand(4, 4)) - - class WrapperForQualname(torch.nn.Module): - def __init__(self): - super().__init__() - self.ta = TensorAttribute() - - def forward(self, x): - return torch.nn.functional.linear(x, self.ta.tensor) - - wfq = WrapperForQualname() - traced2 = symbolic_trace(wfq) - traced2.graph.lint() - traced2(torch.rand(4, 4)) - - def test_tensor_attribute_coalseced(self): - def count_attrs(fx_module): - targets = set() - for node in traced.graph.nodes: - if node.op == "get_attr": - targets.add(node.target) - return len(targets) - - val = torch.tensor(5) - - def f(x): - return x + val + val - - traced = symbolic_trace(f) - traced.graph.lint() - self.assertEqual(count_attrs(traced), 1) - - val2 = torch.tensor(5) - - def f(x): - val = torch.tensor(5) - return x + val + val2 - - traced = symbolic_trace(f) - traced.graph.lint() - self.assertEqual(count_attrs(traced), 2) - - def test_symbolic_trace_sequential(self): - class Simple(torch.nn.Module): - def forward(self, x): - return torch.neg(x) - - seq = torch.nn.Sequential(Simple(), Simple(), Simple()) - traced = symbolic_trace(seq) - traced.graph.lint() - x = torch.rand(3, 4) - self.assertEqual(traced(x), seq(x)) - - def test_tensor_constant(self): - class ConstTensor(torch.nn.Module): - def forward(self, x): - return torch.nn.functional.linear(x, torch.zeros(3, 4)) - - ct = ConstTensor() - traced = symbolic_trace(ct) - traced.graph.lint() - traced(torch.rand(4, 4)) - - def test_pickle_graphmodule(self): - class Nested(torch.nn.Module): - def __init__(self): - super().__init__() - self.st = torch.nn.Linear(4, 4) - - def forward(self, x): - return self.st(x) - - n = Nested() - traced = symbolic_trace(n) - traced.graph.lint() - pickled = pickle.dumps(traced) - loaded = pickle.loads(pickled) - loaded.graph.lint() - x = torch.rand(3, 4) - self.assertEqual(loaded(x), traced(x)) - - def test_pickle_custom_import(self): - graph = pippy.fx.Graph() - a = graph.placeholder("x") - b = graph.placeholder("y") - c = graph.call_function(a_non_torch_leaf, (a, b)) - d = graph.call_function(torch.sin, (c,)) - graph.output(d) - gm = GraphModule(torch.nn.Module(), graph) - pickled = pickle.dumps(gm) - loaded = pickle.loads(pickled) - loaded.graph.lint() - x, y = torch.rand(1), torch.rand(1) - self.assertEqual(loaded(x, y), gm(x, y)) - - def test_all_input_nodes(self): - graph: pippy.fx.Graph = pippy.fx.Graph() - a: pippy.fx.Node = graph.placeholder("x") - b: pippy.fx.Node = graph.call_module("linear_mod", args=(a,)) - c: pippy.fx.Node = graph.get_attr("y_attr") - d: pippy.fx.Node = graph.call_function(operator.add, args=(b, c)) - e: pippy.fx.Node = graph.call_function(torch.unsqueeze, args=(d, 0)) - graph.output(e) - graph.lint() - - self.assertEqual(b.all_input_nodes, [a]) - self.assertEqual(c.all_input_nodes, []) - self.assertEqual(d.all_input_nodes, [b, c]) - self.assertEqual(e.all_input_nodes, [d]) - - def test_deepcopy_graphmodule_with_transform(self): - st = SimpleTest() - traced = symbolic_trace(st) - traced.graph.lint() - - def transform(traced): - new_graph = pippy.fx.Graph() - val_map: Dict[Node, Node] = {} - output_value = new_graph.graph_copy(traced.graph, val_map) - relu_out = new_graph.create_node( - op="call_method", target="neg", args=(output_value,), kwargs={} - ) - new_graph.output(relu_out) - return GraphModule(traced, new_graph) - - transformed = transform(traced) - transformed.graph.lint() - copied = copy.deepcopy(transformed) - self.assertNotEqual(id(type(transformed)), id(type(copied))) - x = torch.randn(3, 4) - self.assertEqual(copied(x), transformed(x)) - - def test_deepcopy_with_submods_params(self): - class Bar(torch.nn.Module): - def __init__(self): - super().__init__() - self.param = torch.nn.Parameter(torch.rand(3, 4)) - - def forward(self, x): - return torch.relu(x) + self.param - - class Baz(torch.nn.Module): - def __init__(self): - super().__init__() - self.param = torch.nn.Parameter(torch.rand(3, 4)) - self.bar = Bar() - - def forward(self, x): - return self.bar(x) - self.param - - baz = Baz() - traced = symbolic_trace(baz) - traced.graph.lint() - copied = copy.deepcopy(traced) - copied.graph.lint() - - def test_deepcopy_graph_with_tracer_cls(self): - class TestTracer(Tracer): - def is_leaf_module(self, module, name): - return True - - g = Graph(tracer_cls=TestTracer) - x = g.placeholder("x") - g.output(x) - - h = copy.deepcopy(g) - self.assertIsNotNone(h._tracer_cls) - self.assertTrue(g._tracer_cls == h._tracer_cls) - - def test_unpack_list_better_error(self): - class SomeArgs(torch.nn.Module): - def forward(self, a, b): - return torch.rand(3, 4) - - class UnpacksList(torch.nn.Module): - def __init__(self): - super().__init__() - self.sa = SomeArgs() - - def forward(self, x: list): - return self.sa(*x) - - ul = UnpacksList() - with self.assertRaisesRegex( - TraceError, "Proxy object cannot be iterated." - ): - symbolic_trace(ul) - - def test_unpack_dict_better_error(self): - class SomeKwargs(torch.nn.Module): - def forward(self, x=3, y=4): - return torch.rand(3, 4) - - class UnpacksDict(torch.nn.Module): - def __init__(self): - super().__init__() - self.sk = SomeKwargs() - - def forward(self, x: dict): - return self.sk(**x) - - ud = UnpacksDict() - with self.assertRaisesRegex( - TraceError, "Proxy object cannot be iterated." - ): - symbolic_trace(ud) - - def test_pretty_print_targets(self): - # Test that Graph pretty-print prints friendly name for targets - # in `operator` and `builtins` - - class SomeMod(torch.nn.Module): - def forward(self, x): - return torch.add(x.foo + x.bar, 3.0) - - traced = symbolic_trace(SomeMod()) - graph_str = str(traced.graph) - self.assertIn("builtins.getattr", graph_str) - self.assertIn("operator.add", graph_str) - self.assertIn("torch.add", graph_str) - - def test_pretty_print_node(self): - class M(torch.nn.Module): - def __init__(self): - super().__init__() - self.param: torch.nn.Parameter = torch.nn.Parameter( - torch.rand(3, 4) - ) - self.linear = torch.nn.Linear(4, 5) - - def forward(self, x: torch.Tensor, y: int = 2): - return self.linear(x[y] + self.param).clamp(min=0.0, max=1.0) - - traced = symbolic_trace(M()) - - all_formatted = "\n".join([n.format_node() for n in traced.graph.nodes]) - - FileCheck().check("x").check("placeholder").check("y").check( - "placeholder" - ).check("getitem").check("call_function").check("param").check( - "get_attr" - ).check( - "add" - ).check( - "call_function" - ).check( - "linear" - ).check( - "call_module" - ).check( - "clamp" - ).check( - "call_method" - ).run( - all_formatted - ) - - def test_script_tensor_constant(self): - # TorchScript seems to ignore attributes that start with `__`. - # We used to call anonymous Tensor values `__tensor_constant*`, but - # they were getting ignored by script. Now they're called - # `_tensor_constant*` - class IHaveATensorConstant(torch.nn.Module): - def forward(self, x): - return x + torch.rand(3, 4) - - traced = pippy.fx.symbolic_trace(IHaveATensorConstant()) - torch.jit.script(traced) - - def test_autowrap_functions(self): - class AutowrapFnTest(torch.nn.Module): - def forward(self, x): - return fx_int(x.shape[0] / 2) - - class AutowrapFnTest2(torch.nn.Module): - def forward(self, x): - return fx_int(x.shape[0] / 2) + fx_int_x2(x.shape[0] / 2) - - # Check function(s) are wrapped - # `int` would normally throw a TypeError as argument can't be `Proxy` - tracer = Tracer(autowrap_functions=(fx_int,)) - graph = tracer.trace(AutowrapFnTest()) - traced = GraphModule(tracer.root, graph, "test") - tracer_2 = Tracer(autowrap_functions=(fx_int, fx_int_x2)) - tracer_2.trace(AutowrapFnTest2()) - - # Test scriptability - traced_scripted = torch.jit.script(traced) - self.assertEqual(traced_scripted(torch.rand(4)), 2) - - def test_tuple_no_subscript(self): - def foo(x: Tuple): - return x[0] - - traced = pippy.fx.symbolic_trace(foo) - x = (torch.randn(5, 3),) - torch.testing.assert_allclose(traced(x), x[0]) - - bio = io.BytesIO() - - torch.save(traced, bio) - - bio.seek(0) - - loaded = torch.load(bio) - - torch.testing.assert_allclose(loaded(x), x[0]) - - def test_torch_fx_len(self): - class FXLenTest(torch.nn.Module): - def forward(self, x): - return len(x) - - traced = symbolic_trace(FXLenTest()) - self.assertEqual(traced(torch.rand(3, 4)), 3) - - # Test scriptability - scripted = torch.jit.script(FXLenTest()) - self.assertEqual(scripted(torch.rand(3)), 3) - - traced_scripted = torch.jit.script(traced) - self.assertEqual(traced_scripted(torch.rand(3)), 3) - - # Test non-proxy len - class FXLenTest2(torch.nn.Module): - def __init__(self): - super().__init__() - self.l = [3, 4, 5] - - def forward(self, x): - return x + len(self.l) - - traced2 = symbolic_trace(FXLenTest2()) - inp = torch.rand(3, 4) - self.assertEqual(traced2(inp), inp + 3.0) - self.assertIs(len, builtins.len) - - def test_torch_fx_getattr(self): - class FXGetattrTest(torch.nn.Module): - def forward(self, x): - return getattr(x, "nonexistent_attr", torch.Tensor([2, 3])) - - traced = symbolic_trace(FXGetattrTest()) - self.assertEqual(traced(torch.rand(3, 4)), torch.Tensor([2, 3])) - - def test_sqrt(self): - class Sqrt1(torch.nn.Module): - def forward(self, x): - return sqrt(x.size(0)) - - class Sqrt2(torch.nn.Module): - def forward(self, x): - return math.sqrt(x.size(0)) - - class Sqrt3(torch.nn.Module): - def forward(self, x): - return x + math.sqrt(2) + sqrt(2) - - self.checkGraphModule(Sqrt1(), [torch.zeros(8)]) - self.checkGraphModule(Sqrt2(), [torch.zeros(8)]) - self.checkGraphModule(Sqrt3(), [torch.zeros(8)]) - self.assertIs(sqrt, _sqrt) - self.assertIs(math.sqrt, _sqrt) - - def test_torch_custom_ops(self): - class M(torch.nn.Module): - def forward(self, a): - b = torch.ops.aten.sigmoid(a) - c = torch.ops.aten.cat([a, b]) - return torch.ops.aten.cat((c, c)) - - m = M() - input = torch.randn(3) - ref_out = m(input) - gm = symbolic_trace(m) - gm.graph.lint() - out = gm(input) - self.assertEqual(out, ref_out) - - def test_torch_op_overloads(self): - class M(torch.nn.Module): - def forward(self, a): - b = torch.ops.aten.add.Tensor(a, a) - return b - - m = M() - input = torch.randn(3) - ref_out = m(input) - gm = symbolic_trace(m) - gm.graph.lint() - out = gm(input) - self.assertEqual(out, ref_out) - - for node in gm.graph.nodes: - if node.op == "call_function": - assert isinstance(node.target, torch._ops.OpOverload) - assert node.target.__name__ == "add.Tensor" - - def test_pickle_torch_custom_ops(self): - class M(torch.nn.Module): - def forward(self, a): - b = torch.ops.aten.sigmoid(a) - c = torch.ops.aten.cat([a, b]) - return torch.ops.aten.cat((c, c)) - - m = M() - input = torch.randn(3) - ref_out = m(input) - gm = symbolic_trace(m) - gm.graph.lint() - pickled = pickle.dumps(gm) - loaded = pickle.loads(pickled) - self.assertEqual(loaded(input), gm(input)) - - def test_pretty_print(self): - st = SimpleTest() - traced = symbolic_trace(st) - traced.graph.lint() - printed = str(traced) - assert "SimpleTest()" in printed - assert "torch.relu" in printed - - def test_pretty_print_graph(self): - class KwargPrintTest(torch.nn.Module): - def forward(self, x): - return torch.squeeze(x + 3.0, dim=2) - - st = KwargPrintTest() - traced = symbolic_trace(st) - traced.graph.lint() - stringed = str(traced.graph) - for s in ["args", "kwargs", "#users"]: - assert s in stringed - - def test_custom_proxy_type(self): - class TensorPair: - def __init__(self, left, right): - self.left, self.right = left, right - - def add(self, other): - l = self.left + other.left - r = self.right + other.right - return TensorPair(l, r) - - def mul(self, other): - l = self.left * other.left - r = self.right * other.right - return TensorPair(l, r) - - def use_tensor_pair(x: TensorPair, y: TensorPair): - s = x.add(y) - return s.mul(x) - - x = TensorPair(torch.randn(5, 3), torch.randn(5, 3)) - y = TensorPair(torch.randn(5, 3), torch.randn(5, 3)) - - ref_out = use_tensor_pair(x, y) - - traced = symbolic_trace(use_tensor_pair) - - traced_out = traced(x, y) - self.assertEqual(traced_out.left, ref_out.left) - self.assertEqual(traced_out.right, ref_out.right) - - def test_custom_proxy_type_literal(self): - class TensorPair(metaclass=pippy.fx.ProxyableClassMeta): - def __init__(self, left, right): - self.left, self.right = left, right - - def add(self, other): - l = self.left + other.left - r = self.right + other.right - return TensorPair(l, r) - - def mul(self, other): - l = self.left * other.left - r = self.right * other.right - return TensorPair(l, r) - - def use_tensor_pair_literal(x: TensorPair): - s = x.add(TensorPair(torch.zeros(5, 3), torch.zeros(5, 3))) - return s.mul(x) - - x = TensorPair(torch.randn(5, 3), torch.randn(5, 3)) - - ref_out = use_tensor_pair_literal(x) - - traced = symbolic_trace(use_tensor_pair_literal) - - traced_out = traced(x) - self.assertEqual(traced_out.left, ref_out.left) - self.assertEqual(traced_out.right, ref_out.right) - - def test_custom_proxy_dynamic_value(self): - class TensorPair(metaclass=pippy.fx.ProxyableClassMeta): - def __init__(self, left, right): - self.left, self.right = left, right - - def add(self, other): - l = self.left + other.left - r = self.right + other.right - return TensorPair(l, r) - - def mul(self, other): - l = self.left * other.left - r = self.right * other.right - return TensorPair(l, r) - - def use_tensor_pair_ctor(x: TensorPair, y: torch.Tensor): - s = x.add(TensorPair(y, y)) - return s.mul(x) - - x = TensorPair(torch.randn(5, 3), torch.randn(5, 3)) - y = torch.randn(5, 3) - ref_out = use_tensor_pair_ctor(x, y) - - traced = symbolic_trace(use_tensor_pair_ctor) - - traced_out = traced(x, y) - self.assertEqual(traced_out.left, ref_out.left) - self.assertEqual(traced_out.right, ref_out.right) - - def test_custom_proxy_input_dependent_control_flow(self): - class ZeroTensor(metaclass=pippy.fx.ProxyableClassMeta): - def __init__(self, inp): - if inp.sum() == 0: - self.is_zero = True - self.tensor = torch.tensor([]) - else: - self.is_zero = False - self.tensor = inp - - def add(self, other): - if self.is_zero: - return ZeroTensor(other.tensor) - elif other.is_zero: - return self - - def use_zero_tensor(x: torch.Tensor, y: torch.Tensor): - return ZeroTensor(x + y) - - x, y = torch.randn(5, 3), torch.randn(5, 3) - - ref_out = use_zero_tensor(x, y) - - traced = symbolic_trace(use_zero_tensor) - - traced_out = traced(x, y) - - self.assertEqual(traced_out.is_zero, ref_out.is_zero) - self.assertEqual(traced_out.tensor, ref_out.tensor) - - def test_graph_fns(self): - g = Graph() - a = g.placeholder("a") - b = g.call_module("linear", (a,)) - c = g.get_attr("bias") - d = g.call_method("add", (b, c)) - e = g.call_function(torch.sin, (d,)) - g.output(e) - mod = torch.nn.Module() - mod.linear = torch.nn.Linear(3, 4) - mod.bias = torch.rand(4) - gm = GraphModule(mod, g) - gm.graph.lint() - input = torch.rand(3) - r = gm(input) - ref = torch.sin(mod.linear(input) + mod.bias) - self.assertEqual(r, ref) - - def test_remove_uses(self): - g: pippy.fx.Graph = Graph() - x: pippy.fx.Node = g.placeholder("x") - relu: pippy.fx.Node = g.call_function(torch.relu, (x,)) - neg: pippy.fx.Node = g.call_function(torch.neg, (relu,)) - g.output(neg) - - neg.replace_all_uses_with(relu) - g.erase_node(neg) - - self.assertTrue(neg not in relu.users) - - def test_remove_uses_with_custom_filter(self): - g: pippy.fx.Graph = Graph() - x: pippy.fx.Node = g.placeholder("x") - relu: pippy.fx.Node = g.call_function(torch.relu, (x,)) - neg: pippy.fx.Node = g.call_function(torch.neg, (relu,)) - g.output(neg) - - neg.replace_all_uses_with(relu, lambda x: x != neg) - - self.assertTrue(neg in relu.users) - - def test_nonetype_annotation(self): - eb = torch.nn.EmbeddingBag(3, 4) - symbolic_trace(eb) - - def test_pickle_nonetype_annotation(self): - eb = torch.nn.EmbeddingBag(10, 3, mode="sum") - traced = symbolic_trace(eb) - pickled = pickle.dumps(traced) - loaded = pickle.loads(pickled) - loaded.graph.lint() - input = torch.LongTensor([1, 2, 4, 5, 4, 3, 2, 9]) - offsets = torch.LongTensor([0, 4]) - self.assertEqual(loaded(input, offsets), traced(input, offsets)) - - def test_return_tuple(self): - class M(torch.nn.Module): - def forward( - self, x: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - return (x, x + x) - - original = M() - traced = symbolic_trace(original) - self.assertEqual(traced(torch.ones(1)), original.forward(torch.ones(1))) - - def test_construct_root_dict(self): - graph: pippy.fx.Graph = pippy.fx.Graph() - a: pippy.fx.Node = graph.create_node("placeholder", "x") - b: pippy.fx.Node = graph.create_node( - "call_module", "foo.bar.baz", args=(a,) - ) - c: pippy.fx.Node = graph.create_node("get_attr", "zip.zap.zam") - d: pippy.fx.Node = graph.create_node( - "call_function", operator.add, args=(b, c) - ) - graph.output(d) - - linear_mod: torch.nn.Module = torch.nn.Linear(3, 4) - add_param: torch.Tensor = torch.rand(3, 4) - gm: pippy.fx.GraphModule = pippy.fx.GraphModule( - {"foo.bar.baz": linear_mod, "zip.zap.zam": add_param}, graph - ) - gm.graph.lint() - - assert "self.foo.bar.baz" in gm.code - - x: torch.Tensor = torch.rand(3, 3) - out: torch.Tensor = gm(x) - ref_out: torch.Tensor = linear_mod(x) + add_param - self.assertEqual(out, ref_out) - - def test_symbolic_trace_assert(self): - class AssertsTensorShape(torch.nn.Module): - def forward(self, x): - torch._assert(x.shape[1] > 4, "assert_foobar") - return x - - m = AssertsTensorShape() - # verify traceability - traced = symbolic_trace(m) - # verify assertion on traced model works correctly at runtime - traced(torch.rand(4, 5)) - with self.assertRaisesRegex(AssertionError, "assert_foobar"): - traced(torch.rand(4, 3)) - # verify the symbolically traced module is scriptable - ms = torch.jit.script(m) - with self.assertRaisesRegex(torch.jit.Error, "assert_foobar"): - ms(torch.rand(4, 3)) - - def test_fx_create_arg(self): - class CustomArgObject: - def __init__(self, x, y): - self.x = x - self.y = y - - def __fx_create_arg__(self, tracer: pippy.fx.Tracer): - return tracer.create_node( - "call_function", - CustomArgObject, - args=( - tracer.create_arg(self.x), - tracer.create_arg(self.y), - ), - kwargs={}, - ) - - class HasCustomArgObjectWhenLeaf(torch.nn.Module): - def forward(self, o: CustomArgObject): - # Not normally traceable; good reason to make - # this module a leaf. - for x in o.x: - o.y += x - return o.y - - class Root(torch.nn.Module): - def __init__(self): - super().__init__() - self.inner = HasCustomArgObjectWhenLeaf() - - def forward(self, x, y): - o = CustomArgObject(x, y) - return self.inner(o) - - class CreateArgTracer(pippy.fx.Tracer): - def is_leaf_module(self, m, module_qualified_name): - return type(m) is HasCustomArgObjectWhenLeaf - - m = Root() - graph = CreateArgTracer().trace(m) - gm = pippy.fx.GraphModule(m, graph) - assert "CustomArgObject(" in gm.code - - def test_trace_fn_constant(self): - some_constant = torch.rand(3, 4) - - def add_const(x): - return some_constant + x - - traced = symbolic_trace(add_const) - - input = torch.rand(3, 4) - self.assertEqual(traced(input), add_const(input)) - - def test_copy_no_remap(self): - traced = symbolic_trace(SimpleTest()) - g = traced.graph - copied = pippy.fx.Graph() - for node in g.nodes: - copied.node_copy(node) - with self.assertRaisesRegex( - RuntimeError, "does not belong to this Graph" - ): - copied.lint() - - def test_wrong_topo(self): - graph: pippy.fx.Graph = pippy.fx.Graph() - a: pippy.fx.Node = graph.create_node("placeholder", "x") - b: pippy.fx.Node = graph.create_node( - "call_module", "foo.bar.baz", args=(a,) - ) - c: pippy.fx.Node = graph.create_node("get_attr", "zip.zap.zam") - d: pippy.fx.Node = graph.create_node( - "call_function", operator.add, args=(b, c) - ) - graph.output(d) - nodes = list(graph.nodes) - nodes[3].append(nodes[2]) - with self.assertRaisesRegex( - RuntimeError, "was used before it has been defined" - ): - graph.lint() - - def test_wrong_target_type(self): - graph: pippy.fx.Graph = pippy.fx.Graph() - with self.assertRaises(ValueError): - n = pippy.fx.Node( - graph=graph, - name="foo", - op="call_function", - target="foo", - args=(), - kwargs={}, - ) - - def test_example_shape_prop(self): - class TestCase(torch.nn.Module): - def __init__(self): - super().__init__() - self.attr = torch.randn(3, 4) - self.submod = torch.nn.Linear(4, 4) - - def forward(self, x): - return torch.neg(self.submod(x.relu() + self.attr)) - - tc = TestCase() - tc_traced = symbolic_trace(tc) - ref_out = tc_traced(torch.rand(3, 4)) - shape_prop.ShapeProp(tc_traced).propagate(torch.rand(3, 4)) - - # Make sure we're testing all opcodes - opcodes = set() - output_shape: Optional[torch.Shape] = None - output_stride: Optional[Tuple[int]] = None - for node in tc_traced.graph.nodes: - opcodes.add(node.op) - if node.op == "output": - output_shape = node.args[0].meta["tensor_meta"].shape - output_stride = node.args[0].meta["tensor_meta"].stride - self.assertEqual( - opcodes, - set( - [ - "placeholder", - "get_attr", - "call_function", - "call_method", - "call_module", - "output", - ] - ), - ) - - # Test shape propagation and make sure results match actual - self.assertEqual(output_shape, ref_out.shape) - self.assertEqual(output_stride, ref_out.stride()) - - def test_shape_prop_layout(self): - class ConvTest(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv_mod = torch.nn.Conv2d(5, 5, 3) - - def forward(self, x): - return self.conv_mod(x) - - # contiguous layout - test_mod = ConvTest() - traced = symbolic_trace(test_mod) - x = torch.randn(5, 5, 224, 224) - shape_prop.ShapeProp(traced).propagate(x) - - assert all( - node.meta["tensor_meta"].memory_format is torch.contiguous_format - for node in traced.graph.nodes - ) - - x_channels_last = x.contiguous(memory_format=torch.channels_last) - traced.to(memory_format=torch.channels_last) - shape_prop.ShapeProp(traced).propagate(x_channels_last) - for node in traced.graph.nodes: - # NB: the implementation of conv may not preserve the memory format, - # unfortunately. The best we can do is just check that the placeholder - # node is channels-last - if node.op in {"placeholder"}: - self.assertEqual( - node.meta["tensor_meta"].memory_format, torch.channels_last - ) - - def test_shape_prop_aggregate(self): - class ReturnTwo(torch.nn.Module): - def forward(self, x): - return (3, torch.sum(x)) - - class UnderTest(torch.nn.Module): - def __init__(self): - super().__init__() - self.rt = ReturnTwo() - - def forward(self, x): - return self.rt(x) - - ut = UnderTest() - - class RTTracer(pippy.fx.Tracer): - def is_leaf_module(self, m, module_qualified_name): - return type(m) is ReturnTwo - - graph = RTTracer().trace(ut) - mod = pippy.fx.GraphModule(ut, graph) - - shape_prop.ShapeProp(mod).propagate(torch.rand(3, 4)) - - for node in mod.graph.nodes: - if node.op == "call_module": - assert "tensor_meta" in node.meta - tensor_meta = node.meta["tensor_meta"] - assert tensor_meta[0] == 3 - assert tensor_meta[1].shape == torch.Size([]) - - def test_shape_prop_layout_3d(self): - class ConvTest3d(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv_mod = torch.nn.Conv3d(5, 5, 3) - - def forward(self, x): - return self.conv_mod(x) - - test_mod_3d = ConvTest3d() - traced_3d = symbolic_trace(test_mod_3d) - x_3d = torch.randn(5, 5, 224, 224, 15) - shape_prop.ShapeProp(traced_3d).propagate(x_3d) - assert all( - node.meta["tensor_meta"].memory_format is torch.contiguous_format - for node in traced_3d.graph.nodes - ) - - x_channels_last_3d = x_3d.contiguous( - memory_format=torch.channels_last_3d - ) - traced_3d.to(memory_format=torch.channels_last_3d) - shape_prop.ShapeProp(traced_3d).propagate(x_channels_last_3d) - for node in traced_3d.graph.nodes: - # NB: the implementation of conv may not preserve the memory format, - # unfortunately. The best we can do is just check that the placeholder - # node is channels-last - if node.op in {"placeholder"}: - self.assertEqual( - node.meta["tensor_meta"].memory_format, - torch.channels_last_3d, - ) - - def test_interpreter(self): - class MyModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.param = torch.nn.Parameter(torch.rand(3, 4)) - self.linear = torch.nn.Linear(4, 5) - - def forward(self, x): - return self.linear(x + self.param).clamp(min=0.0, max=1.0) - - m = MyModule() - gm = pippy.fx.symbolic_trace(m) - - interpreter = Interpreter(gm) - input = torch.randn(3, 4) - self.assertEqual(interpreter.run(input), gm(input)) - self.assertEqual(interpreter.run(input), m(input)) - - def test_interpreter_run_node_override(self): - class MyModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.param = torch.nn.Parameter(torch.rand(3, 4)) - self.linear = torch.nn.Linear(4, 5) - - def forward(self, x): - return self.linear(x + self.param).clamp(min=0.0, max=1.0) - - m = MyModule() - gm = pippy.fx.symbolic_trace(m) - - class RunNodeInterpreter(Interpreter): - def __init__(self, module): - super().__init__(module) - - def run_node(self, n: Node) -> Any: - result = super().run_node(n) - n.cached_value = result - return result - - input = torch.randn(3, 4) - RunNodeInterpreter(gm).run(input) - for node in gm.graph.nodes: - assert hasattr(node, "cached_value") - - def test_interpreter_onthefly_swap(self): - def fn(x): - return torch.sigmoid(x).neg() - - gm = pippy.fx.symbolic_trace(fn) - - class NegSigmSwapInterpreter(Interpreter): - def call_function( - self, target: Target, args: Tuple, kwargs: Dict - ) -> Any: - if target == torch.sigmoid: - return torch.neg(*args, **kwargs) - return super().call_function(n) - - def call_method( - self, target: Target, args: Tuple, kwargs: Dict - ) -> Any: - if target == "neg": - call_self, *args_tail = args - return call_self.sigmoid(*args_tail, **kwargs) - return super().call_method(n) - - input = torch.randn(3, 4) - result = NegSigmSwapInterpreter(gm).run(input) - self.assertEqual(result, torch.neg(input).sigmoid()) - - def test_interpreter_partial_eval(self): - class MyModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.param = torch.nn.Parameter(torch.rand(3, 4)) - self.linear = torch.nn.Linear(4, 5) - - def forward(self, x): - return self.linear(x + self.param).clamp(min=0.0, max=1.0) - - gm = pippy.fx.symbolic_trace(MyModule()) - interp = Interpreter(gm) - env = {} - for node in gm.graph.nodes: - if node.op == "call_module" and node.target == "linear": - env[node] = torch.arange(0, 12, 1).reshape(3, 4) - 6.0 - break - assert len(env) == 1 - x = torch.randn(3, 4) - result = interp.run(x, initial_env=env) - self.assertEqual( - result, (torch.arange(0, 12, 1).reshape(3, 4) - 6.0).clamp(0.0, 1.0) - ) - - def test_interpreter_star_args(self): - def with_star_args(x, *args): - return x + args[0] - - gm = pippy.fx.symbolic_trace(with_star_args) - interp = Interpreter(gm) - result = interp.run( - torch.ones(3, 4), torch.ones(3, 4), torch.rand(3, 4) - ) - self.assertEqual(result, torch.ones(3, 4) * 2.0) - - @skipIfNoTorchVision - def test_interpreter_noop_resnet18(self): - rn18 = torchvision_models.resnet18() - transformed = pippy.fx.Transformer(symbolic_trace(rn18)).transform() - inp = torch.randn(5, 3, 224, 224) - self.assertEqual(transformed(inp), rn18(inp)) - - @skipIfNoTorchVision - def test_interpreter_gc_values(self): - rn18 = torchvision_models.resnet18() - interp = Interpreter(symbolic_trace(rn18)) - inp = torch.rand(5, 3, 224, 224) - out = interp.run(inp) - env_key_names = set(n.name for n in interp.env.keys()) - self.assertEqual(env_key_names, set(["output"])) - - def test_interpreter_default_args(self): - class Model(torch.nn.Module): - def forward(self, x, y=3.14159): - return x + y - - model = Model() - gm = pippy.fx.symbolic_trace(model) - - interp = Interpreter(gm) - x = torch.randn(5, 3) - out = interp.run(x) - torch.testing.assert_allclose(out, x + 3.14159) - - def test_interpreter_not_enough_args(self): - class Model(torch.nn.Module): - def forward(self, x, y): - return x + y - - model = Model() - gm = pippy.fx.symbolic_trace(model) - - interp = Interpreter(gm) - x = torch.randn(5, 3) - with self.assertRaisesRegex( - RuntimeError, - "Expected positional argument for parameter y, but one was not passed in", - ): - out = interp.run(x) - - def test_transformer_noop(self): - class MyModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.param = torch.nn.Parameter(torch.rand(3, 4)) - self.linear = torch.nn.Linear(4, 5) - - def forward(self, x): - return self.linear(x + self.param).clamp(min=0.0, max=1.0) - - m = MyModule() - gm = pippy.fx.symbolic_trace(m) - - new_gm = Transformer(gm).transform() - - input = torch.randn(3, 4) - self.assertEqual(new_gm(input), gm(input)) - - def test_transformer_op_swap(self): - def fn(x): - return torch.sigmoid(x).neg() - - gm = pippy.fx.symbolic_trace(fn) - - class NegSigmSwapXformer(Transformer): - def call_function( - self, target: Target, args: Tuple, kwargs: Dict - ) -> Any: - if target == torch.sigmoid: - return torch.neg(*args, **kwargs) - return super().call_function(n) - - def call_method( - self, target: Target, args: Tuple, kwargs: Dict - ) -> Any: - if target == "neg": - call_self, *args_tail = args - return call_self.sigmoid(*args_tail, **kwargs) - return super().call_method(n) - - transformed = NegSigmSwapXformer(gm).transform() - input = torch.randn(3, 4) - self.assertEqual(transformed(input), torch.neg(input).sigmoid()) - - def test_transformer_multi_outputs(self): - class MyModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.param = torch.nn.Parameter(torch.rand(3, 4)) - self.linear = torch.nn.Linear(4, 5) - - def forward(self, x): - x = x + self.param - out = self.linear(x) - return x, out - - m = MyModule() - gm = pippy.fx.symbolic_trace(m) - - new_gm = Transformer(gm).transform() - - input = torch.randn(3, 4) - self.assertEqual(new_gm(input), gm(input)) - - def test_fn_type_annotations(self): - class Foo(torch.nn.Module): - def forward( - self, p: Pair, z: torch.Tensor, i: int - ) -> Dict[str, torch.Tensor]: - return {"a": p.x + p.y + z + i} - - foo_scripted = torch.jit.script(Foo()) - foo_scripted(Pair(torch.rand(5), torch.rand(5)), torch.rand(5), 3) - - fxed = symbolic_trace(Foo()) - fxed_scripted = torch.jit.script(fxed) - fxed_scripted(Pair(torch.rand(5), torch.rand(5)), torch.rand(5), 3) - - def test_fn_type_annotation_empty(self): - def forward(a: List[torch.Tensor]): - return a[0] - - torch.jit.script(symbolic_trace(forward)) - - def test_wrapped_method(self): - def wrap_with_relu(fn): - @functools.wraps(fn) - def wrapper(*args, **kwargs): - return torch.relu(fn(*args, **kwargs)) - - return wrapper - - class Foo(torch.nn.Module): - @wrap_with_relu - def forward(self, x, w): - return torch.matmul(x, w) - - f = Foo() - traced = symbolic_trace(f) - x, w = torch.rand(3, 4), torch.rand(4, 4) - self.assertTrue(any(n.target == torch.relu for n in traced.graph.nodes)) - - def test_empty_graph_codegen(self): - graph = pippy.fx.Graph() - gm = pippy.fx.GraphModule(torch.nn.Module(), graph) - self.assertEqual(gm(), None) - - def test_sequential(self): - m = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 1)) - gm = pippy.fx.symbolic_trace(m) - gm_copy = copy.deepcopy(gm) - - def test_ctx_mgr(self): - @contextlib.contextmanager - def do_nothing(): - yield - - class M(torch.nn.Module): - def __init__(self): - super().__init__() - - @do_nothing() - def forward(self, x): - return torch.relu(x) - - m = M() - self.checkGraphModule(m, (torch.rand(3, 4),)) - - def test_typename_print(self): - graph: pippy.fx.Graph = pippy.fx.Graph() - x: pippy.fx.Node = graph.create_node("placeholder", "x") - b: pippy.fx.Node = graph.create_node( - "call_function", target=torch.relu, args=(x,), type_expr=List[float] - ) - output: pippy.fx.Node = graph.output(b) - - self.assertTrue("typing.List[float]" in str(graph)) - - def test_layout(self): - class M(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return torch.empty_like( - x, layout=torch.strided, pin_memory=False - ).fill_(0) - - traced = symbolic_trace(M()) - x = torch.rand(5, 9, 3, 4) - self.assertEqual(traced(x), torch.zeros_like(x)) - - def test_ellipsis(self): - class M(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y): - return x + y[:, 1:10, ...] - - traced = symbolic_trace(M()) - x, y = torch.rand(5, 9, 3, 4), torch.rand(5, 15, 3, 4) - self.assertEqual(traced(x, y), x + y[:, 1:10, ...]) - - def test_inf_nan(self): - class FooMod(torch.nn.Module): - def forward(self, x): - return x + float("inf"), x + float("-inf"), x + float("nan") - - fm = FooMod() - self.checkGraphModule(fm, (torch.rand(3, 4),)) - - def test_inf_nan_kwds(self): - graph: pippy.fx.Graph = pippy.fx.Graph() - x: pippy.fx.Node = graph.create_node("placeholder", "x") - b: pippy.fx.Node = graph.create_node( - "call_function", operator.add, (x, float("inf")), {}, name="inf" - ) - c: pippy.fx.Node = graph.create_node( - "call_function", operator.add, (x, float("nan")), {}, name="nan" - ) - graph.output((b, c)) - - gm = pippy.fx.GraphModule(torch.nn.Module(), graph) - x = torch.rand(3, 4) - self.assertEqual(gm(x), (x + float("inf"), x + float("nan"))) - - def test_deepcopy_recursion_depth(self): - depth = sys.getrecursionlimit() + 20 - - g = pippy.fx.Graph() - x = g.placeholder("x") - for i in range(depth): - x = g.call_function(torch.relu, (x,)) - g.output(x) - - copied_graph = copy.deepcopy(g) - - val_map = {} - for orig_node, new_node in zip(g.nodes, copied_graph.nodes): - val_map[orig_node] = new_node - - for orig_node, new_node in zip(g.nodes, copied_graph.nodes): - orig_users = set(orig_node.users.keys()) - orig_users_equiv = set(val_map[u] for u in orig_users) - new_users = set(new_node.users.keys()) - self.assertEqual(orig_users_equiv, new_users) - - @skipIfNoTorchVision - def test_replace_uses(self): - rn18 = torchvision_models.resnet18() - - class LowerReluTracer(pippy.fx.Tracer): - def is_leaf_module(self, m: torch.nn.Module, qualname: str): - if isinstance(m, torch.nn.ReLU): - return False - return super().is_leaf_module(m, qualname) - - rn18_traced = GraphModule(rn18, LowerReluTracer().trace(rn18)) - - to_erase = [] - for node in rn18_traced.graph.nodes: - if node.op == "call_function" and node.target in [ - torch.relu, - torch.nn.functional.relu, - ]: - kwargs = node.kwargs.copy() - # Neg doesn't have in-place - kwargs.pop("inplace") - with rn18_traced.graph.inserting_before(node): - new_node = rn18_traced.graph.call_function( - the_function=torch.neg, - args=node.args, - kwargs=node.kwargs, - ) - node.replace_all_uses_with(replace_with=new_node) - to_erase.append(node) - - for node in to_erase: - rn18_traced.graph.erase_node(node) - - def test_replace_input(self): - graph: pippy.fx.Graph = pippy.fx.Graph() - x: pippy.fx.Node = graph.create_node("placeholder", "x") - y: pippy.fx.Node = graph.create_node("placeholder", "y") - b: pippy.fx.Node = graph.create_node( - "call_function", target=torch.relu, args=(x,) - ) - output: pippy.fx.Node = graph.output(b) - - b.replace_input_with(x, y) - - gm = pippy.fx.GraphModule(torch.nn.Module(), graph) - - input_x = torch.randn(33, 44) - input_y = torch.randn(11, 22) - self.assertEqual(gm(input_x, input_y), torch.relu(input_y)) - - def test_insertion_point(self): - graph: pippy.fx.Graph = pippy.fx.Graph() - x: pippy.fx.Node = graph.create_node("placeholder", "x") - b: pippy.fx.Node = graph.create_node( - "call_function", target=torch.relu, args=(x,) - ) - output: pippy.fx.Node = graph.output(b) - - with graph.inserting_before(b): - neg: pippy.fx.Node = graph.call_function( - the_function=torch.neg, args=(x,) - ) - _, *relu_args = b.args - b.args = (neg, *relu_args) - - gm = pippy.fx.GraphModule(torch.nn.Module(), graph) - - input = torch.randn(33, 44) - self.assertEqual(gm(input), torch.relu(torch.neg(input))) - - def test_update_args_api(self): - graph: pippy.fx.Graph = pippy.fx.Graph() - x: pippy.fx.Node = graph.create_node("placeholder", "x") - y: pippy.fx.Node = graph.create_node("placeholder", "y") - b: pippy.fx.Node = graph.create_node( - "call_function", target=torch.relu, args=(x,) - ) - output: pippy.fx.Node = graph.output(b) - - orig_gm = pippy.fx.GraphModule(torch.nn.Module(), graph) - inp_x, inp_y = torch.randn(5, 3), torch.randn(3, 5) - self.assertEqual(orig_gm(inp_x, inp_y), torch.relu(inp_x)) - - b.update_arg(0, y) - new_gm = pippy.fx.GraphModule(torch.nn.Module(), graph) - self.assertEqual(new_gm(inp_x, inp_y), torch.relu(inp_y)) - - def test_update_kwargs_api(self): - graph: pippy.fx.Graph = pippy.fx.Graph() - x: pippy.fx.Node = graph.create_node("placeholder", "x") - y: pippy.fx.Node = graph.create_node("placeholder", "y") - b: pippy.fx.Node = graph.create_node( - "call_function", target=torch.relu, kwargs={"input": x} - ) - output: pippy.fx.Node = graph.output(b) - - orig_gm = pippy.fx.GraphModule(torch.nn.Module(), graph) - inp_x, inp_y = torch.randn(5, 3), torch.randn(3, 5) - self.assertEqual(orig_gm(inp_x, inp_y), torch.relu(inp_x)) - - b.update_kwarg("input", y) - new_gm = pippy.fx.GraphModule(torch.nn.Module(), graph) - self.assertEqual(new_gm(inp_x, inp_y), torch.relu(inp_y)) - - def test_immutable_list_pytree_ops(self): - rand_tensor = torch.randn(5, 3) - l = immutable_list([3, [rand_tensor, 42]]) - - flattened, spec = pytree.tree_flatten(l) - assert flattened == [3, rand_tensor, 42] - - unflattened = pytree.tree_unflatten(flattened, spec) - assert unflattened == l - assert isinstance(unflattened, immutable_list) - - def test_immutable_dict_pytree_ops(self): - rand_tensor = torch.randn(5, 3) - d = immutable_dict({"a": 3, "b": [rand_tensor, 42]}) - - flattened, spec = pytree.tree_flatten(d) - assert flattened == [3, rand_tensor, 42] - - unflattened = pytree.tree_unflatten(flattened, spec) - assert unflattened == d - assert isinstance(unflattened, immutable_dict) - - def test_move_before(self): - graph: pippy.fx.Graph = pippy.fx.Graph() - x: pippy.fx.Node = graph.create_node("placeholder", "x") - b: pippy.fx.Node = graph.create_node( - "call_function", target=torch.relu, args=(x,) - ) - output: pippy.fx.Node = graph.output(b) - - neg: pippy.fx.Node = graph.call_function( - the_function=torch.neg, args=(x,) - ) - _, *relu_args = b.args - b.args = (neg, *relu_args) - b.prepend(neg) - - gm = pippy.fx.GraphModule(torch.nn.Module(), graph) - - input = torch.randn(33, 44) - self.assertEqual(gm(input), torch.relu(torch.neg(input))) - - def test_prepend_self(self): - graph: pippy.fx.Graph = pippy.fx.Graph() - x: pippy.fx.Node = graph.create_node("placeholder", "x") - b: pippy.fx.Node = graph.create_node( - "call_function", target=torch.relu, args=(x,) - ) - output: pippy.fx.Node = graph.output(b) - - b.prepend(b) - x.append(b) - self.assertEqual(len(graph.nodes), 3) - - def test_erase_node_error(self): - st = SimpleTest() - traced = symbolic_trace(st) - - for node in traced.graph.nodes: - # Test deleting with uses both in another Node and at the output - if node.target in [operator.add, torch.relu]: - with self.assertRaisesRegex( - RuntimeError, "but it still had .* users in the graph" - ): - traced.graph.erase_node(node) - - def test_copy_it(self): - d = immutable_dict([(3, 4), (5, 6)]) - l = immutable_list([(3, 4), (5, 6)]) - - self.assertEqual(d, deepcopy(d)) - self.assertEqual(l, deepcopy(l)) - - def test_get_torch_func_signature(self): - for key in dir(torch): - obj = getattr(torch, key) - if callable(obj): - schemas = get_signature_for_torch_op(obj) - - def test_find_uses(self): - graph = pippy.fx.Graph() - x = pippy.fx.Proxy(graph.placeholder("x")) - - y = torch.relu(x) - z = x + x - u = torch.neg(x) - graph.output((y + z + u).node) - graph.lint() - - users_of_x = x.node.users - self.assertEqual(len(users_of_x), 3) - expected_ops = set(["relu", "add", "neg"]) - for use in users_of_x: - assert any(use.name.startswith(prefix) for prefix in expected_ops) - - def test_inline_graph(self): - class InlineInto(torch.nn.Module): - def forward(self, x): - return torch.relu(x) - - class ToInline(torch.nn.Module): - def forward(self, x): - return torch.neg(x) - - inline_into = symbolic_trace(InlineInto()) - to_inline = symbolic_trace(ToInline()) - - combined_graph = pippy.fx.Graph() - output_node = combined_graph.graph_copy(inline_into.graph, {}) - - input_node = list(to_inline.graph.nodes)[0] - assert input_node and input_node.op == "placeholder" - - val_map = {input_node: output_node} - output = combined_graph.graph_copy(to_inline.graph, val_map) - combined_graph.output(output) - - combined_module = pippy.fx.GraphModule( - torch.nn.Module(), combined_graph - ) - - input = torch.rand(3, 4) - self.assertEqual(combined_module(input), input.relu().neg()) - - def test_multi_insert_point(self): - graph = pippy.fx.Graph() - x = pippy.fx.Proxy(graph.placeholder("x")) - relu = torch.relu(x) - - with graph.inserting_before(relu.node): - y = torch.neg(x) - z = torch.tanh(y) - - graph.output((relu.node, z.node)) - graph.lint() - - expected_ops = ["x", "neg", "tanh", "relu"] - for node, expected in zip(graph.nodes, expected_ops): - assert expected in node.name - - def test_reassign_args_kwargs_uses(self): - graph = pippy.fx.Graph() - x, y = Proxy(graph.placeholder("x")), Proxy(graph.placeholder("y")) - z = x + y - zed = z + z + z - graph.output(zed.node) - graph.lint() - - # zed = z + z + z -> zed = z + z + x - zed.node.args = (zed.node.args[0], x.node) - self.assertEqual(list(x.node.users.keys()), [z.node, zed.node]) - - # z = x + y -> z = y + y - z.node.args = (y.node, y.node) - self.assertEqual(list(x.node.users.keys()), [zed.node]) - - def test_trace_function(self): - def foo(x, y): - return torch.relu(x) + y - - x, y = torch.randn(3, 4), torch.randn(3, 4) - self.checkGraphModule(foo, (x, y)) - - def test_trace_dict_int_keys(self): - class ModWithDictArg(torch.nn.Module): - def forward(self, d: Dict[int, torch.Tensor]): - return d[42] - - class CallsModWithDict(torch.nn.Module): - def __init__(self): - super().__init__() - self.m = ModWithDictArg() - - def forward(self, x): - return self.m({42: x}) - - class MyTracer(pippy.fx.Tracer): - def is_leaf_module( - self, m: torch.nn.Module, module_qualified_name: str - ) -> bool: - return isinstance(m, ModWithDictArg) - - traced_graph = MyTracer().trace(CallsModWithDict()) - - def test_trace_dict_proxy_keys(self): - class ModWithDictArg(torch.nn.Module): - def forward(self, d: Dict[torch.Tensor, torch.Tensor]): - return d[42] - - class CallsModWithDict(torch.nn.Module): - def __init__(self): - super().__init__() - self.m = ModWithDictArg() - - def forward(self, x): - return self.m({x: x}) - - class MyTracer(pippy.fx.Tracer): - def is_leaf_module( - self, m: torch.nn.Module, module_qualified_name: str - ) -> bool: - return isinstance(m, ModWithDictArg) - - with self.assertRaisesRegex(RuntimeError, "cannot contain a Node"): - traced_graph = MyTracer().trace(CallsModWithDict()) - - def test_module_deepcopy_edit_nodes(self): - class Foo(torch.nn.Module): - def forward(self, x): - return torch.relu(x) - - traced1 = symbolic_trace(Foo()) - copied = copy.deepcopy(traced1) - - for node in copied.graph.nodes: - if node.target == torch.relu: - node.target = torch.neg - - copied.recompile() - traced1.recompile() - - x = torch.randn(15, 15) - torch.testing.assert_allclose(traced1(x), torch.relu(x)) - torch.testing.assert_allclose(copied(x), torch.neg(x)) - - def test_direct_param_use(self): - class TransposeTest(torch.nn.Module): - def __init__(self): - super().__init__() - self.b = torch.nn.Parameter(torch.rand(4, 3)) - - def forward(self, x): - return self.b - - class Foo(torch.nn.Module): - def __init__(self): - super().__init__() - self.a = TransposeTest() - - def forward(self, x): - return self.a.b, self.a.b.t(), self.a.b.view(12) - - traced = pippy.fx.symbolic_trace(Foo()) - assert all("constant" not in node.target for node in traced.graph.nodes) - - def test_single_default_arg(self): - class M(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, y=1): - return y - - m = M() - self.checkGraphModule(m, ()) - self.checkGraphModule(m, (3,)) - - def test_multiple_default_args(self): - class M(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, y=1, z=2): - return y + z - - m = M() - self.checkGraphModule(m, ()) - self.checkGraphModule(m, (3,)) - self.checkGraphModule(m, (3, 4)) - - def test_regular_and_default_args(self): - class M(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y=1): - return x + y - - m = M() - self.checkGraphModule(m, (2,)) - self.checkGraphModule(m, (2, 3)) - - def test_string_literal_return(self): - class M(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self): - return "foo" - - m = M() - self.checkGraphModule(m, ()) - - def test_namedtuple_return_qualname(self): - class NamedTupReturn(torch.nn.Module): - def forward(self, x): - return MyNamedTup(x, x) - - traced = symbolic_trace(NamedTupReturn()) - input = torch.rand(3, 4) - self.assertEqual(traced(input), MyNamedTup(input, input)) - - def test_update_args_kwargs_yells_at_you(self): - symtraced = symbolic_trace(SimpleTest()) - node = next(iter(symtraced.graph.nodes)) - with self.assertRaisesRegex(AttributeError, "__update_args_kwargs"): - node.__update_args_kwargs((), {}) - - def test_torchbind_class_attribute_in_fx(self): - if IS_FBCODE or IS_WINDOWS or IS_MACOS: - self.skipTest( - "torch.classes._TorchScriptTesting._StackString is registered, skipping" - ) - - class FooBar1234(torch.nn.Module): - def __init__(self): - super(FooBar1234, self).__init__() - self.f = torch.classes._TorchScriptTesting._StackString( - ["3", "4"] - ) - - def forward(self): - return self.f.top() - - m = FooBar1234() - self.checkGraphModule(m, ()) - - def test_torchbind_class_attribute_in_fx_tensor_arg(self): - if IS_FBCODE or IS_WINDOWS or IS_MACOS: - self.skipTest( - "torch.classes._TorchScriptTesting._ReLUClass is registered, skipping" - ) - - class FooBar2341(torch.nn.Module): - def __init__(self): - super(FooBar2341, self).__init__() - self.f = torch.classes._TorchScriptTesting._ReLUClass() - - def forward(self, x): - return self.f.run(x) - - m = FooBar2341() - - traced = symbolic_trace(m) - input = torch.randn(3, 4) - self.assertEqual(traced(input), m(input)) - - self.assertTrue(any(n.op == "call_method" for n in traced.graph.nodes)) - - def test_script_method_trace(self): - class Scripted(torch.nn.Module): - def forward(self, x): - return torch.relu(x) - - class Holder(torch.nn.Module): - def __init__(self): - super().__init__() - self.s = torch.jit.script(Scripted()) - - def forward(self, x): - return self.s(x) - - h = Holder() - traced = symbolic_trace(h) - input = torch.randn(3, 4) - self.assertEqual(traced(input), h(input)) - - self.assertTrue(any(n.op == "call_method" for n in traced.graph.nodes)) - - def test_namedtuple_return_trace(self): - class NamedTupReturn(torch.nn.Module): - def forward(self, x): - return Pair(x, x) - - traced = symbolic_trace(NamedTupReturn()) - input = torch.rand(3, 4) - self.assertEqual(traced(input), Pair(input, input)) - - def test_named_tuple_inlined(self): - class NamedTupMod(torch.nn.Module): - def forward(self, inp): - return wrapped_named_tup(Pair(inp, 1.2), p2=Pair(3.4, inp)) - - m = NamedTupMod() - input = torch.rand(3, 4) - ref = m(input) - traced = symbolic_trace(m) - - res = traced(input) - self.assertEqual(ref, res) - - # Check Pair NamedTuple works when inlined into the function call. - ph = call_func = None - for node in traced.graph.nodes: - if node.op == "placeholder": - ph = node - elif ( - node.op == "call_function" and node.target == wrapped_named_tup - ): - node.update_arg(0, Pair(ph, 1.2)) - node.update_kwarg("p2", Pair(3.4, ph)) - call_func = node - break - self.assertTrue(call_func is not None) - self.assertTrue(isinstance(call_func.args[0], Pair)) - self.assertTrue(isinstance(call_func.kwargs["p2"], Pair)) - self.assertEqual(_format_arg(call_func.args[0]), "Pair(x=%inp, y=1.2)") - self.assertEqual( - _format_arg(call_func.kwargs["p2"]), "Pair(x=3.4, y=%inp)" - ) - - traced.graph.eliminate_dead_code() - traced.recompile() - res = traced(input) - self.assertEqual(ref, res) - - def test_return_type_exists(self): - class ReturnTypeModule(torch.nn.Module): - def other(self, x: List[str]) -> List[str]: - return x - - def forward(self, x: List[str]) -> List[str]: - return self.other(x) - - traced = symbolic_trace(ReturnTypeModule()) - self.assertIn("-> typing_List[str]", traced._code) - scripted = torch.jit.script(traced) - self.assertIn("-> List[str]", scripted.code) - - def getitem_inner(self): - class GetItemBase(torch.nn.Module): - def __init__(self): - super().__init__() - self.register_buffer("pe", torch.randn(8, 8)) - - class GetItem1(GetItemBase): - def forward(self, x): - return self.pe[:, : x.size(0)] - - class GetItem2(GetItemBase): - def forward(self, x): - return self.pe[x.size(0)] - - class GetItem3(GetItemBase): - def forward(self, x): - return self.pe[4] # fx creates `self._tensor_constant0` here - - self.checkGraphModule(GetItem1(), [torch.zeros(4)]) - self.checkGraphModule(GetItem2(), [torch.zeros(4)]) - self.checkGraphModule(GetItem3(), [torch.zeros(4)]) - - @unittest.skipUnless( - os.environ.get("FX_PATCH_GETITEM") == "1", - "Will be checked in test_getitem_subproc", - ) - def test_getitem(self): - self.getitem_inner() - - def test_getitem_subproc(self): - # need to run this test in a subproc to work around: - # https://github.com/pytorch/pytorch/issues/50710 - proc = Process(target=run_getitem_target) - proc.start() - proc.join() - self.assertEqual(proc.exitcode, 0) - - def test_user_friendly_call_provenance_with_function(self): - def fn(x): - return wrapper_fn(x) - - traced = pippy.fx.symbolic_trace(fn) - - with self.assertRaisesRegex( - RuntimeError, - "'wrapper_fn' is " - "being compiled since it was called" - " from 'fn.forward'", - ): - scripted = torch.jit.script(traced) - - def test_user_friendly_call_provenance_with_module(self): - class M(torch.nn.Module): - def forward(self, x): - return wrapper_fn(x) - - traced = pippy.fx.symbolic_trace(M()) - - with self.assertRaisesRegex( - RuntimeError, - "'wrapper_fn' is " - "being compiled since it was called" - " from 'M.forward'", - ): - scripted = torch.jit.script(traced) - - def test_snake_case(self): - class M(torch.nn.Module): - def __init__(self): - super(M, self).__init__() - self.activations = torch.nn.ModuleDict( - [ - ["snake_case", torch.nn.ReLU()], - ["PascalCase", torch.nn.LeakyReLU()], - ["ALL_CAPS", torch.nn.PReLU()], - ] - ) - - def forward(self, x): - a = self.activations["snake_case"](x) - b = self.activations["PascalCase"](x) - c = self.activations["ALL_CAPS"](x) - return a, b, c - - traced = symbolic_trace(M()) - - check = [ - ("activations_snake_case", "activations.snake_case"), - ("activations_pascal_case", "activations.PascalCase"), - ("activations_all_caps", "activations.ALL_CAPS"), - ] - - i = 0 - for node in traced.graph.nodes: - if node.op == "placeholder" or node.op == "output": - continue - name = check[i][0] - target = check[i][1] - self.assertEqual(name, node.name) - self.assertEqual(target, node.target) - i += 1 - self.assertEqual(i, 3) - - def test_no_mutation(self): - from pippy.fx.immutable_collections import immutable_list - - x = immutable_list([3, 4]) - with self.assertRaisesRegex(NotImplementedError, "new_args"): - x[0] = 4 - - def test_partial_trace(self): - class Foo(torch.nn.Module): - def forward(self, x, y): - if y: - return 2 * x - else: - return x - - mod = Foo() - mod_true = symbolic_trace(mod, concrete_args={"y": True}) - mod_false = symbolic_trace(mod, concrete_args={"y": False}) - self.assertEqual(mod_true(3, True), 6) - print(mod_true.code) - assert any([i.target == torch._assert for i in mod_true.graph.nodes]) - with self.assertRaises(AssertionError): - mod_true(3, False) - self.assertEqual(mod_false(3, False), 3) - with self.assertRaises(AssertionError): - mod_false(3, True) - - def f_higher(a, f): - return f(a) - - nf = symbolic_trace(f_higher, concrete_args={"f": lambda x: x * 2}) - self.assertEqual(nf(3, lambda x: x * 2), 6) - - def test_custom_traceback_raised_when_exception_source_is_graphmodule(self): - class M(torch.nn.Module): - def __init__(self): - super(M, self).__init__() - self.W = torch.nn.Parameter(torch.randn(5)) - - def forward(self, x): - return torch.dot(self.W, x) - - traced = pippy.fx.symbolic_trace(M()) - - out = [n for n in traced.graph.nodes if n.op == "output"][-1] - with traced.graph.inserting_before(out): - relu_out = traced.graph.call_method( - method_name="relu", args=(out.args[0],) - ) - out.args = (relu_out,) - - traced.recompile() - - with self.capture_stderr() as captured: - with self.assertRaises(TypeError): - traced(5) - - self.assertRegex( - captured[0], - r"Call using an FX-traced Module, line .* of the " - r"traced Module's generated forward function:", - ) - - def test_custom_traceback_not_raised_when_exception_source_is_submodule( - self, - ): - class M(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(3, 4) - - def forward(self, x): - return self.linear(x) - - traced = pippy.fx.symbolic_trace(M()) - - # Do not change this to `capture_stderr` or another context - # manager without ensuring that the output is as expected - try: - traced(torch.rand(5, 5)) - except RuntimeError: - captured = traceback.format_exc() - - self.assertNotRegex( - captured, - r"Call using an FX-traced Module, line .* of the " - r"traced Module's generated forward function:", - ) - - def test_graph_module_replicate_for_dp(self): - class Foo(torch.nn.Module): - def forward(self, x): - return torch.relu(x) - - gm = pippy.fx.symbolic_trace(Foo()) - - x = torch.randn(5, 3) - out = gm(x) - - replica = gm._replicate_for_data_parallel() - out_replica = replica(x) - - torch.testing.assert_allclose(out_replica, out) - - def test_ast_rewriter_rewrites_assert(self): - class M(torch.nn.Module): - def forward(self, x: torch.Tensor, y: int, z: int): - assert y == z - return torch.add(x, x) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(M()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - traced.graph.lint() - - def test_ast_rewriter_rewrites_assert_with_message(self): - class M(torch.nn.Module): - def forward(self, x: torch.Tensor, y: int, z: int): - assert y == z, "msg" - return torch.add(x, x) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(M()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - traced.graph.lint() - - def test_throw_out_variant(self): - def foo(x): - y = torch.rand_like(x) - torch.sigmoid(x, out=y) - return y - - class MyTracer(pippy.fx.Tracer): - check_mutable_operations = True - - tracer = MyTracer() - with self.assertRaisesRegex( - RuntimeError, "mutable operation aten::sigmoid.out" - ): - traced_graph = tracer.trace(foo) - - def test_ast_rewriter_reassigns_submodules(self): - class M(torch.nn.Module): - def __init__(self): - super().__init__() - self.bn = torch.nn.BatchNorm2d(100) - - def forward(self, x: torch.Tensor): - return torch.add(x, x) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(M()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - traced.graph.lint() - - def test_ast_rewriter_wrap(self): - self.assertEqual(3 + 4 + 5, a_lifted_leaf((3, 4), 5)) - - def to_trace(y): - return ( - a_lifted_leaf((4, y), 3) - + a_lifted_leaf((3, 4), 5) - + a_lifted_leaf((y, y), y) - ) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(to_trace) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - self.assertIn("a_lifted_leaf", traced.code) - self.assertEqual(27, traced(2)) - self.assertIs(a_lifted_leaf, real_a_lifed_leaf) - - def test_ast_rewriter_wrap_fn_directly(self): - self.assertEqual(3 + 4 + 5, a_lifted_leaf2((3, 4), 5)) - - def to_trace(y): - return ( - a_lifted_leaf2((4, y), 3) - + a_lifted_leaf2((3, 4), 5) - + a_lifted_leaf2((y, y), y) - ) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(to_trace) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - self.assertIn("a_lifted_leaf2", traced.code) - self.assertEqual(27, traced(2)) - self.assertIs(a_lifted_leaf2, real_a_lifed_leaf2) - - def test_profiler_ranges_side_effect(self): - g = pippy.fx.Graph() - handle = g.call_function( - torch.ops.profiler._record_function_enter, ("test_range",) - ) - g.call_function(torch.ops.profiler._record_function_exit, (handle,)) - g.output(None) - - found_targets = {} - for node in g.nodes: - if node.op == "call_function": - found_targets.setdefault(node.target) - self.assertEqual( - list(found_targets.keys()), - [ - torch.ops.profiler._record_function_enter, - torch.ops.profiler._record_function_exit, - ], - ) - - g.eliminate_dead_code() - found_targets = {} - for node in g.nodes: - if node.op == "call_function": - found_targets.setdefault(node.target) - self.assertEqual( - list(found_targets.keys()), - [ - torch.ops.profiler._record_function_enter, - torch.ops.profiler._record_function_exit, - ], - ) - - def test_ast_rewriter_wrapped_via_decorator(self): - class F(torch.nn.Module): - def forward(self, x): - return wrapped_via_decorator(x) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(F()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - self.assertIn("wrapped_via_decorator", traced.code) - self.assertEqual(traced(0), 1) - self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator) - self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched")) - - def test_ast_rewriter_wrapped_via_decorator_and_transformed(self): - self.assertEqual(wrapped_via_decorator(0), 1) - - def to_trace(y): - return wrapped_via_decorator(y) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(to_trace) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - self.assertIn("wrapped_via_decorator", traced.code) - self.assertEqual(traced(0), 1) - self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator) - self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched")) - - transformed = pippy.fx.Transformer(traced).transform() - self.assertIn("wrapped_via_decorator", transformed.code) - self.assertEqual(transformed(0), 1) - self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator) - self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched")) - - def test_ast_rewriter_wrap_with_submodule(self): - class M(torch.nn.Module): - def __init__(self): - super(M, self).__init__() - self.batchnorm1d = torch.nn.BatchNorm1d(2, affine=False) - - def forward(self, x: torch.Tensor): - return wrapped_with_submodule(x, self.batchnorm1d) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(M()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - self.assertIn("wrapped_with_submodule", traced.code) - - input = torch.rand(3, 2) - ref_batchnorm1d = torch.nn.BatchNorm1d(2, affine=False) - self.assertEqual(ref_batchnorm1d(input), traced(input)) - - def test_submodule_manipulation_API(self): - class C(torch.nn.Module): - def __init__(self): - super(C, self).__init__() - self.conv = torch.nn.Conv2d(16, 33, 3, stride=2) - self.param = torch.nn.Parameter(torch.rand(2, 3)) - - def forward(self, x): - return self.conv(torch.cat([self.param, x])) - - class B(torch.nn.Module): - def __init__(self): - super(B, self).__init__() - self.linear = torch.nn.Linear(100, 200) - self.register_buffer("buf", torch.randn(2, 3)) - self.net_c = C() - - def forward(self, x): - return self.linear(torch.cat([self.buf, self.net_c(x)])) - - class A(torch.nn.Module): - def __init__(self): - super(A, self).__init__() - self.net_b = B() - self.param = torch.nn.Parameter(torch.rand(2, 3)) - - def forward(self, x): - return self.net_b(x) + self.param - - a = symbolic_trace(A()) - - a.add_submodule("net_b.net_c.dropout", torch.nn.Dropout(p=0.2)) - - conv = [n for n in a.graph.nodes if n.target == "net_b.net_c.conv"][-1] - with a.graph.inserting_before(conv): - with warnings.catch_warnings(record=True) as w: - dropout = a.graph.call_module( - module_name="net_b.net_c.dropout", args=conv.args - ) - self.assertEqual(len(w), 0) - - conv.replace_all_uses_with(dropout) - a.graph.erase_node(conv) - a.recompile() - - def module_exists(gm: GraphModule, path: str) -> bool: - return any(path == name for name, _ in gm.named_modules()) - - def parameter_exists(gm: GraphModule, path: str) -> bool: - return any( - path == name for name, _ in gm.named_parameters() - ) and any(path == name for name in gm.state_dict().keys()) - - def buffer_exists(gm: GraphModule, path: str) -> bool: - return any(path == name for name, _ in gm.named_buffers()) and any( - path == name for name in gm.state_dict().keys() - ) - - # Test that we added the "dropout" submodule - self.assertTrue(module_exists(a, "net_b.net_c.dropout")) - - # Test `get_submodule` with an added submodule - self.assertIsNotNone(a.get_submodule("net_b.net_c.dropout")) - - # Test that the "conv" submodule is still there - self.assertTrue(module_exists(a, "net_b.net_c.conv")) - - # Test `get_submodule` with an original module - self.assertIsNotNone(a.get_submodule("net_b.net_c.conv")) - - # Test that the "conv" node is NOT still there - conv = [n for n in a.graph.nodes if n.target == "net_b.net_c.conv"] - self.assertEqual(conv, []) - - a.delete_submodule("net_b.net_c.conv") - - # Test that the "conv" submodule is now gone - self.assertFalse(module_exists(a, "net_b.net_c.conv")) - - # Test `get_submodule` with a deleted submodule - with self.assertRaisesRegex( - AttributeError, "has no attribute " "`conv`" - ): - self.assertIsNone(a.get_submodule("net_b.net_c.conv")) - - # Test `get_attr` warnings - cat = [n for n in a.graph.nodes if n.target == torch.cat][-1] - - with a.graph.inserting_before(cat): - with warnings.catch_warnings(record=True) as w: - param = a.graph.get_attr(qualified_name="net_b.net_c.param") - self.assertEqual(len(w), 0) - - with self.assertWarnsRegex( - UserWarning, - "Attempted to " - "insert a get_attr Node with no " - "underlying reference in the " - "owning GraphModule", - ): - bad_param = a.graph.get_attr(qualified_name="net_b.param") - a.graph.erase_node(bad_param) - - cat.args = (*cat.args, param) - - a.recompile() - - a.graph.lint() - - # Test `get_parameter` - a.get_parameter("net_b.net_c.param") - with self.assertRaisesRegex( - AttributeError, "is not an " "nn.Parameter" - ): - a.get_parameter("net_b.buf") - with self.assertRaisesRegex( - AttributeError, "has no attribute " "`param`" - ): - a.get_parameter("net_b.param") - - # Test `get_buffer` - a.get_buffer("net_b.buf") - with self.assertRaisesRegex(AttributeError, "is not a " "buffer"): - a.get_buffer("net_b.net_c.param") - with self.assertRaisesRegex( - AttributeError, "has no attribute " "`buf`" - ): - a.get_buffer("net_b.net_c.buf") - - # Test non-nested attributes - a.get_submodule("") - a.get_parameter("param") - - # Insert some unused submodules - a.add_submodule("net_b.embedding", torch.nn.Embedding(10, 3)) - a.add_submodule("net_b.net_c.embedding", torch.nn.Embedding(10, 3)) - a.add_submodule("net_b.net_c.rnn", torch.nn.RNN(10, 20, 2)) - a.add_submodule("batch_norm_2d", torch.nn.BatchNorm2d(100)) - - # Garbage collection - a.delete_all_unused_submodules() - - # Test that all the unused submodules are gone - self.assertFalse(module_exists(a, "net_b.embedding")) - self.assertFalse(module_exists(a, "net_b.net_c.embedding")) - self.assertFalse(module_exists(a, "net_b.net_c.rnn")) - self.assertFalse(module_exists(a, "batch_norm_2d")) - - # Test that we didn't delete any unused Parameters or buffers - self.assertTrue(parameter_exists(a, "net_b.net_c.param")) - self.assertTrue(buffer_exists(a, "net_b.buf")) - - a.graph.lint() - - def test_delete_unused_submodules_leaf(self): - class SubModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(10, 10) - self.relu = torch.nn.ReLU() - - def forward(self, x): - x = self.linear(x) - x = self.relu(x) - return x - - class Model(torch.nn.Module): - def __init__(self): - super().__init__() - self.submod = SubModule() - - def forward(self, x): - x = self.submod(x) - return x - - model = Model() - - class MyCustomTracer(pippy.fx.Tracer): - def is_leaf_module( - self, m: torch.nn.Module, module_qualified_name: str - ) -> bool: - return module_qualified_name == "submod" - - inputs = torch.randn(1, 10) - traced_graph = MyCustomTracer().trace(model) - gm2 = pippy.fx.GraphModule(model, traced_graph) - gm2.delete_all_unused_submodules() - torch.testing.assert_allclose(gm2(inputs), model(inputs)) - - def test_fx_stateless(self): - class MockModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.l1 = torch.nn.Linear(1, 1) - self.register_buffer("buffer", torch.ones(1)) - - def forward(self, x): - return self.l1(x) + self.buffer - - module = MockModule() - x = torch.rand((1, 1)) - weight = torch.tensor([[1.0]], requires_grad=True) - bias = torch.tensor([0.0], requires_grad=True) - buffer = torch.tensor([0.0]) - parameters = {"l1.weight": weight, "l1.bias": bias, "buffer": buffer} - fx_module = pippy.fx.symbolic_trace(module) - res = _stateless.functional_call(fx_module, parameters, x) - res.backward() - self.assertIsNotNone(weight.grad) - self.assertIsNotNone(bias.grad) - self.assertIsNone(buffer.grad) - # Gradient was not calculated for the module stated and buffers - self.assertIsNone(module.l1.weight.grad) - self.assertIsNone(module.l1.bias.grad) - self.assertIsNone(module.buffer.grad) - - def test_tracing_graphmodules_as_leaf_submodules(self): - class A(torch.nn.Module): - def forward(self, t): - return t + t - - class B(torch.nn.Module): - def __init__(self): - super(type(self), self).__init__() - self.calling = False - self.called = False - - def forward(self, t): - if self.calling: - return t - t - else: - return t + t - - def __call__(self, *args): - self.called = True - self.calling = True - return super(type(self), self).__call__(*args) - self.calling = False - - class M(torch.nn.Module): - def __init__(self, a, b): - super().__init__() - self.a = a - self.b = b - - def forward(self, t): - x = self.a(t) - y = self.b(t) - return x + y - - class LeafTracer(Tracer): - def is_leaf_module(self, module, name): - return True - - class LeafTracerNotB(Tracer): - def is_leaf_module(self, module, name): - return False if "b" in name else True - - # Recompile calls added "for fun", since they - # chain __call__ wrappers. - - # - # Test: B as a regular, non-leaf module - # - a = symbolic_trace(A()) - a.recompile() - m = M(a, B()) - graph = LeafTracerNotB().trace(m) - gm = GraphModule(m, graph) - gm.recompile() - - # Test graphmodule/submodule a is not inlined. - self.assertTrue(isinstance(gm.get_submodule("a"), GraphModule)) - match = [ - n - for n in gm.graph.nodes - if n.op == "call_module" and n.target == "a" - ] - self.assertTrue(len(match) == 1) - - # Test submodule b is not treated as leaf. - self.assertFalse(hasattr(gm, "b")) - - # Test assert custom __call__ on submodule b was honored. - match = [ - n - for n in gm.graph.nodes - if n.op == "call_function" and n.target == operator.sub - ] - self.assertTrue(len(match) == 1) - - # - # Test: B as a regular, leaf module - # symbolic_trace should only patch torch.nn.Module.__call__, - # which means B.__call__ should still execute - # - a = symbolic_trace(A()) - a.recompile() - b = B() - m = M(a, b) - graph = LeafTracer().trace(m) - gm = GraphModule(m, graph) - gm.recompile() - - # Test graphmodule/submodule a is not inlined. - self.assertTrue(isinstance(gm.get_submodule("a"), GraphModule)) - match = [ - n - for n in gm.graph.nodes - if n.op == "call_module" and n.target == "a" - ] - self.assertTrue(len(match) == 1) - - # Test submodule b is leaf: - self.assertTrue(isinstance(gm.get_submodule("b"), torch.nn.Module)) - match = [ - n - for n in gm.graph.nodes - if n.op == "call_module" and n.target == "b" - ] - self.assertTrue(len(match) == 1) - - # Test b.__call__ was run - self.assertTrue(b.called) - self.assertTrue(gm.get_submodule("b").called) - - # - # Test: B as GraphModule leaf - # __call__ not honored since symbolic_trace directly invokes forward() - # - a = symbolic_trace(A()) - a.recompile() - b = symbolic_trace(B()) - b.recompile() - m = M(a, b) - graph = LeafTracer().trace(m) - gm = GraphModule(m, graph) - gm.recompile() - - self.assertTrue(isinstance(gm.get_submodule("a"), GraphModule)) - match = [ - n - for n in gm.graph.nodes - if n.op == "call_module" and n.target == "a" - ] - self.assertTrue(len(match) == 1) - - self.assertTrue(isinstance(gm.get_submodule("b"), torch.nn.Module)) - match = [ - n - for n in gm.graph.nodes - if n.op == "call_module" and n.target == "b" - ] - self.assertTrue(len(match) == 1) - - def _test_graph_module_init_buffer_param_copied(self, use_dict_init: bool): - class MyModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.register_buffer("my_buff", torch.rand(3, 4)) - self.register_parameter( - "my_param", torch.nn.Parameter(torch.rand(3, 4)) - ) - - def forward(self, x): - return x + self.my_buff + self.my_param - - mod = MyModule() - mod_traced = symbolic_trace(mod) - - # Create new GraphModule based on original, either w/ dict or root module. - orig_buff = mod_traced.get_buffer("my_buff") - orig_param = mod_traced.get_parameter("my_param") - mod_traced_new = GraphModule( - {"my_buff": orig_buff, "my_param": orig_param} - if use_dict_init - else mod, - mod_traced.graph, - ) - - # Check that both my_buff and my_param are found and the same. - try: - new_buff = mod_traced_new.get_buffer("my_buff") - except Exception: - self.fail("Did not find my_buff") - self.assertEqual(orig_buff, new_buff) - - try: - new_param = mod_traced_new.get_parameter("my_param") - except Exception: - self.fail("Did not find my_param") - self.assertEqual(orig_param, new_param) - - x = torch.rand(3, 4) - orig_out = mod_traced(x) - submodules_out = mod_traced_new(x) - - self.assertEqual(orig_out, submodules_out) - - def test_graph_module_init_buffer_param_copied_dict_init(self): - self._test_graph_module_init_buffer_param_copied(use_dict_init=True) - - def test_graph_module_init_buffer_param_copied_mod_init(self): - self._test_graph_module_init_buffer_param_copied(use_dict_init=False) - - def test_annotations_with_no_forward_references(self): - class A: - def __call__(self, x: torch.Tensor): - return torch.add(x, x) - - class M(torch.nn.Module): - def forward(self, x: torch.Tensor, a: A) -> torch.Tensor: - return a(x) - - self.checkGraphModule(M(), (torch.rand(2, 3), A()), kwargs=None) - - def test_annotations_with_forward_references(self): - class A: - def __call__(self, x: torch.Tensor): - return torch.add(x, x) - - class M(torch.nn.Module): - def forward(self, x: "torch.Tensor", a: "A") -> "torch.Tensor": - return a(x) - - self.checkGraphModule(M(), (torch.rand(2, 3), A()), kwargs=None) - - def test_annotations_with_non_torch_reference_and_no_internal_forward_references( - self, - ): - class A: - def __call__(self, x: torch.Tensor): - return torch.add(x, x) - - class M(torch.nn.Module): - def forward(self, x: List[torch.Tensor], a: A) -> torch.Tensor: - return a(x[0]) - - self.checkGraphModule(M(), (torch.rand(2, 3), A()), kwargs=None) - - def test_annotations_with_non_torch_reference_and_internal_forward_references( - self, - ): - class A: - def __call__(self, x: torch.Tensor): - return torch.add(x, x) - - class M(torch.nn.Module): - def forward(self, x: List["torch.Tensor"], a: A) -> "torch.Tensor": - return a(x)[0] - - self.checkGraphModule(M(), (torch.rand(2, 3), A()), kwargs=None) - - @unittest.skipIf( - sys.version_info < (3, 7), - "`__future__` feature " "`annotations` is not defined in Python <3.7", - ) - def test_annotation_with_future(self): - try: - import fx.test_future # noqa: F401 - finally: - del sys.modules["__future__"] - - def test_annotations_empty_tuple(self): - class Foo(torch.nn.Module): - def forward(self, x: Tuple[()], y: Tuple[str, Tuple[()]]): - return "foo" - - traced = pippy.fx.symbolic_trace(Foo()) - - x = () - y = ("bar", ()) - - traced(x, y) - - FileCheck().check("_Tuple[()]").check( - "typing_Tuple[str,typing_Tuple[()]]" - ).run(traced.code) - - scripted = torch.jit.script(traced) - - scripted(x, y) - - FileCheck().check("Tuple[()]").check("Tuple[str, Tuple[()]]").run( - scripted.code - ) - - @unittest.skipIf( - IS_WINDOWS, "Python Windows bug? https://bugs.python.org/issue45108" - ) - @unittest.skipIf( - sys.version_info >= (3, 10), "Does not work on Python-3.10" - ) - def test_assert(self): - def f(x): - assert x > 1 - return x + 1 - - try: - pippy.fx.proxy.TracerBase.trace_asserts = True - traced = symbolic_trace(f) - finally: - pippy.fx.proxy.TracerBase.trace_asserts = False - - self.assertEqual(f(2), traced(2)) - with self.assertRaises(AssertionError): - traced(0) - - def test_pytree(self): - def f_sum(x): - return sum(x) - - def f_sum_dict(x): - out = 0 - for k, v in x.items(): - out += v - return out - - def f_dict_list_map(x): - new_dict = {} - for k, v in x.items(): - new_dict[k] = [i + 1 for i in v] - return new_dict - - def f_dict_add(x): - return x["a"] + sum(x["z"]) - - def f_namedtuple_add(x): - return x.x + x.y - - pytree._register_pytree_node( - Foo, - lambda x: ([x.a, x.b], None), - lambda x, _: Foo(x[0], x[1]), - ) - fx_pytree.register_pytree_flatten_spec(Foo, lambda x, _: [x.a, x.b]) - - def f_custom(x): - return x.a + x.b - - def f_custom_dict(x): - return f_sum_dict(x.a) + x.b - - def f_return_custom(x): - return Foo(x.b, x.a) - - tests = [ - (f_sum, [PH, PH, PH]), - (f_sum, []), - (f_sum_dict, {"a": PH, "b": PH, "c": PH}), - (f_dict_list_map, {"a": (PH, PH), "b": [PH], "c": []}), - (f_dict_list_map, {5: (PH, PH, PH)}), - (f_dict_add, {"a": PH, "z": (PH, PH, PH)}), - (f_dict_add, {"a": PH, "z": []}), - (f_custom, Foo(PH, PH)), - (f_custom, Foo(PH, 3)), - (f_custom_dict, Foo({"a": PH, "b": PH}, PH)), - # (f_return_custom, Foo(PH, PH)), # Don't currently support output pytrees - (f_namedtuple_add, Point(PH, PH)), - ] - - def verify_pytree(f, inp): - val = pytree.tree_map( - lambda x: torch.randn(3) if x == PH else x, inp - ) - num_flat_args = len([i == PH for i in pytree.tree_flatten(inp)[0]]) - orig_out = f(val) - nf = symbolic_trace(f, concrete_args={"x": inp}) - self.assertEqual(nf(val), orig_out) - - bare_fx = GraphModule({}, copy.deepcopy(nf.graph)) - bare_fx.graph.set_codegen(CodeGen()) - bare_fx.recompile() - self.assertEqual( - nf.graph.process_outputs( - bare_fx(*nf.graph.process_inputs(val)) - ), - orig_out, - ) - - assert num_flat_args == 0 or "tree_flatten_spec" in nf.code - assert ( - sum([i.op == "placeholder" for i in nf.graph.nodes]) - == num_flat_args - ) - - nf = symbolic_trace(nf) - self.assertEqual(nf(val), orig_out) - assert "tree_flatten_spec" not in nf.code - assert sum([i.op == "placeholder" for i in nf.graph.nodes]) == 1 - - nf = symbolic_trace(nf, concrete_args={"x": inp}) - self.assertEqual(nf(val), orig_out) - assert num_flat_args == 0 or "tree_flatten_spec" in nf.code - assert ( - sum([i.op == "placeholder" for i in nf.graph.nodes]) - == num_flat_args - ) - - pickled = pickle.dumps(nf) - nf = pickle.loads(pickled) - self.assertEqual(nf(val), orig_out) - - for f, inp in tests: - verify_pytree(f, inp) - - def test_pytree_concrete(self): - def f(b, a): - if b: - return a["a"] - else: - return a["z"] - - inp = {"a": {"a": PH, "z": PH}, "b": True} - nf = symbolic_trace(f, concrete_args=inp) - val = pytree.tree_map(lambda x: torch.randn(3) if x == PH else x, inp) - self.assertEqual(nf(**val), f(**val)) - - nf = symbolic_trace(nf) - self.assertEqual(nf(**val), f(**val)) - - def test_custom_codegen(self): - class ListCodeGen(CodeGen): - def gen_fn_def(self, free_vars, maybe_return_annotation): - lst_unpack = f""" -def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}: - {', '.join(free_vars)} = args_list""" - return lst_unpack - - def additional_globals(self): - return [("List", typing.List)] - - def process_inputs(self, *inputs): - assert len(inputs) == 1 - return inputs[0] - - def f(a, b): - return a + b - - nf = symbolic_trace(f) - vals = [torch.randn(3), torch.randn(3)] - self.assertEqual(nf(*vals), f(*vals)) - - nf.graph.set_codegen(ListCodeGen()) - nf.recompile() - - bare_fx = GraphModule({}, copy.deepcopy(nf.graph)) - bare_fx.graph.set_codegen(CodeGen()) - bare_fx.recompile() - - self.assertEqual(nf(vals), f(*vals)) - self.assertEqual( - nf.graph.process_outputs(bare_fx(*nf.graph.process_inputs(vals))), - f(*vals), - ) - - ts_f = torch.jit.script(nf) - self.assertEqual(nf(vals), ts_f(vals)) - - def test_custom_codegen_with_transformer(self): - class ListCodeGen(CodeGen): - def gen_fn_def(self, free_vars, maybe_return_annotation): - lst_unpack = f""" -def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}: - {', '.join(free_vars)} = args_list""" - return lst_unpack - - def additional_globals(self): - return [("List", typing.List)] - - def process_inputs(self, *inputs): - assert len(inputs) == 1 - return inputs[0] - - def f(a, b): - return a + b - - nf = symbolic_trace(f) - vals = [torch.randn(3), torch.randn(3)] - self.assertEqual(nf(*vals), f(*vals)) - - nf.graph.set_codegen(ListCodeGen()) - nf.recompile() - self.assertEqual(nf(vals), f(*vals)) - - transformed_gm = Transformer(nf).transform() - self.assertEqual(nf(vals), transformed_gm(vals)) - - def test_interpreter_with_codegen(self): - class ListCodeGen(CodeGen): - def gen_fn_def(self, free_vars, maybe_return_annotation): - lst_unpack = f""" -def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}: - {', '.join(free_vars)} = args_list""" - return lst_unpack - - def additional_globals(self): - return [("List", typing.List)] - - def process_inputs(self, *inputs): - assert len(inputs) == 1 - return inputs[0] - - def generate_output(self, output_args): - return f"return list({repr(output_args)})" - - def process_outputs(self, outputs): - return list(outputs) - - def f(a, b): - a = a + b - b = a + b - return a, b - - nf = symbolic_trace(f) - vals = [torch.randn(3), torch.randn(3)] - nf.graph.set_codegen(ListCodeGen()) - nf.recompile() - self.assertEqual(Interpreter(nf).run(vals), nf(vals)) - - def test_imul_code_print(self): - graph = pippy.fx.Graph() - a = graph.placeholder("a") - b = graph.placeholder("b") - graph.call_function(operator.imul, (a, b), {}) - graph.output(a) - gm = pippy.fx.GraphModule({}, graph) - gm.recompile() - self.assertEqual(gm(2, 3), 6) - self.assertIn("a *= b", gm.code) - - def test_deepcopy_tracer(self): - def fn(x, y): - return (x + y).relu().sin() - - tracer = Tracer() - tracer_before = copy.deepcopy(tracer) - tracer.trace(fn) - tracer_after = copy.deepcopy(tracer) - - self.assertEqual(str(tracer.graph), str(tracer_after.graph)) - self.assertTrue( - not hasattr(tracer_before, "graph") - or str(tracer.graph) != str(tracer_before.graph) - ) - - -def run_getitem_target(): - from pippy.fx._symbolic_trace import _wrapped_methods_to_patch - - _wrapped_methods_to_patch.append((torch.Tensor, "__getitem__")) - try: - TestFX().getitem_inner() - finally: - _wrapped_methods_to_patch.pop() - - -class TestOperatorSignatures(JitTestCase): - def setUp(self): - # Checking for mutable operations whil tracing is feature flagged - # Enable it in testing but not by default - self.orig_tracer_mutable_flag = ( - pippy.fx.proxy.TracerBase.check_mutable_operations - ) - pippy.fx.proxy.TracerBase.check_mutable_operations = True - - def tearDown(self): - pippy.fx.proxy.TracerBase.check_mutable_operations = ( - self.orig_tracer_mutable_flag - ) - - @onlyCPU - @ops(op_db, allowed_dtypes=(torch.float,)) - def test_get_torch_func_signature_exhaustive(self, device, dtype, op): - if not isinstance(op.op, types.BuiltinFunctionType): - raise unittest.SkipTest( - "This path doesn't work on Python functions" - ) - sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False) - schemas = get_signature_for_torch_op(op.op) - if not schemas: - raise RuntimeError("No Schemas Returned") - for sample_input in sample_inputs_itr: - # Iterate through overloads until we hit a match. If we exit this - # loop via `else`, we haven't found a match - for schema in schemas: - try: - bound_args = schema.bind( - sample_input.input, - *sample_input.args, - **sample_input.kwargs, - ) - bound_args.apply_defaults() - op(*bound_args.args, **bound_args.kwargs) - break - except TypeError as e: - pass - else: - raise RuntimeError( - f"Did not match any schemas for op {op.name}!" - ) - - -class TestFXAPIBackwardCompatibility(JitTestCase): - def setUp(self): - super().setUp() - self.maxDiff = None - - # Checking for mutable operations whil tracing is feature flagged - # Enable it in testing but not by default - self.orig_tracer_mutable_flag = ( - pippy.fx.proxy.TracerBase.check_mutable_operations - ) - pippy.fx.proxy.TracerBase.check_mutable_operations = True - - def tearDown(self): - super().tearDown() - pippy.fx.proxy.TracerBase.check_mutable_operations = ( - self.orig_tracer_mutable_flag - ) - - def _fn_to_stable_annotation_str(self, obj): - """ - Unfortunately we have to serialize function signatures manually since - serialization for `inspect.Signature` objects is not stable across - python versions - """ - fn_name = torch.typename(obj) - - signature = inspect.signature(obj) - - sig_str = f"{fn_name}{signature}" - - arg_strs = [] - for k, v in signature.parameters.items(): - maybe_type_annotation = ( - f": {self._annotation_type_to_stable_str(v.annotation, sig_str)}" - if v.annotation is not inspect.Signature.empty - else "" - ) - - def default_val_str(val): - if isinstance(val, (tuple, list)): - str_pieces = ["(" if isinstance(val, tuple) else "["] - str_pieces.append( - ", ".join(default_val_str(v) for v in val) - ) - if isinstance(val, tuple) and len(str_pieces) == 2: - str_pieces.append(",") - str_pieces.append(")" if isinstance(val, tuple) else "]") - return "".join(str_pieces) - - # Need to fix up some default value strings. - # First case: modules. Default module `repr` contains the FS path of the module. - # Don't leak that - if isinstance(val, types.ModuleType): - return f"" - - # Second case: callables. Callables (such as lambdas) encode their address in - # their string repr. Don't do that - if callable(val): - return f"" - - return str(val) - - if v.default is not inspect.Signature.empty: - default_val_str = ( - default_val_str(v.default) - if not isinstance(v.default, str) - else f"'{v.default}'" - ) - maybe_default = f" = {default_val_str}" - else: - maybe_default = "" - maybe_stars = "" - if v.kind == inspect.Parameter.VAR_POSITIONAL: - maybe_stars = "*" - elif v.kind == inspect.Parameter.VAR_KEYWORD: - maybe_stars = "**" - arg_strs.append( - f"{maybe_stars}{k}{maybe_type_annotation}{maybe_default}" - ) - - return_annot = ( - f" -> {self._annotation_type_to_stable_str(signature.return_annotation, sig_str)}" - if signature.return_annotation is not inspect.Signature.empty - else "" - ) - - return f'{fn_name}({", ".join(arg_strs)}){return_annot}' - - def _annotation_type_to_stable_str(self, t, sig_str): - if t is inspect.Signature.empty: - return "" - - # Forward ref - if isinstance(t, str): - return f"'{t}'" - if hasattr(typing, "ForwardRef") and isinstance(t, typing.ForwardRef): - return t.__forward_arg__ - if hasattr(typing, "_ForwardRef") and isinstance(t, typing._ForwardRef): - return t.__forward_arg__ - - trivial_mappings = { - str: "str", - int: "int", - float: "float", - bool: "bool", - torch.dtype: "torch.dtype", - torch.Tensor: "torch.Tensor", - torch.device: "torch.device", - torch.memory_format: "torch.memory_format", - slice: "slice", - torch.nn.Module: "torch.nn.modules.module.Module", - pippy.fx.Graph: "pippy.fx.graph.Graph", - pippy.fx.Node: "pippy.fx.node.Node", - pippy.fx.Proxy: "pippy.fx.proxy.Proxy", - pippy.fx.node.Target: "pippy.fx.node.Target", - pippy.fx.node.Argument: "pippy.fx.node.Argument", - pippy.fx.graph.PythonCode: "pippy.fx.graph.PythonCode", - pippy.fx.graph_module.GraphModule: "pippy.fx.graph_module.GraphModule", - pippy.fx.subgraph_rewriter.Match: "pippy.fx.subgraph_rewriter.Match", - Ellipsis: "...", - typing.Any: "Any", - type(None): "NoneType", - None: "None", - typing.Iterator: "Iterator", - } - - mapping = trivial_mappings.get(t, None) - if mapping: - return mapping - - # Handle types with contained types - contained = getattr(t, "__args__", None) or [] - - # Callables contain a bare List for arguments - contained = t if isinstance(t, list) else contained - - # Python 3.8 puts type vars into __args__ for unbound types such as Dict - if all(isinstance(ct, typing.TypeVar) for ct in contained): - contained = [] - - contained_type_annots = [ - self._annotation_type_to_stable_str(ct, sig_str) for ct in contained - ] - contained_type_str = ( - f'[{", ".join(contained_type_annots)}]' - if len(contained_type_annots) > 0 - else "" - ) - - origin = getattr(t, "__origin__", None) - if origin is None: - # Unbound types don't have `__origin__` in some Python versions, so fix that up here. - origin = ( - t - if t - in { - typing.Tuple, - typing.Union, - typing.Dict, - typing.List, - typing.Type, - typing.Callable, - } - else origin - ) - - if origin in {tuple, typing.Tuple}: - return f"Tuple{contained_type_str}" - if origin in {typing.Union}: - # Annoying hack to detect Optional - if len(contained) == 2 and (contained[0] is type(None)) ^ ( - contained[1] is type(None) - ): - not_none_param = ( - contained[0] - if contained[0] is not type(None) - else contained[1] - ) - return f"Optional[{self._annotation_type_to_stable_str(not_none_param, sig_str)}]" - return f"Union{contained_type_str}" - if origin in {dict, typing.Dict}: - return f"Dict{contained_type_str}" - if origin in {list, typing.List}: - return f"List{contained_type_str}" - if origin in {type, typing.Type}: - return f"Type{contained_type_str}" - if isinstance(t, typing.Callable): - if len(contained) > 0 and contained[0] is not Ellipsis: - return f'Callable[[{", ".join(contained_type_annots[:-1])}], {contained_type_annots[-1]}]' - else: - return f"Callable{contained_type_str}" - - raise RuntimeError( - f"Unrecognized type {t} used in BC-compatible type signature {sig_str}." - f"Please add support for this type and confirm with the " - f"FX team that your signature change is valid." - ) - - def test_function_back_compat(self): - """ - Test backward compatibility for function signatures with - @compatibility(is_backward_compatible=True). Currently this checks for - exact signature matches, which may lead to false positives. If this - becomes too annoying, we can refine this check to actually parse out - the saved schema strings and check if the change is truly backward- - incompatible. - """ - signature_strs = [] - - for obj in _BACK_COMPAT_OBJECTS: - if not isinstance(obj, type): - signature_strs.append(self._fn_to_stable_annotation_str(obj)) - - signature_strs.sort() - - try: - self.assertExpected( - "\n".join(signature_strs) + "\n", - "fx_backcompat_function_signatures", - ) - except AssertionError as e: - msg = ( - f"{e}\n****** ERROR ******\nAn FX function that has been marked " - f"as backwards-compatible has experienced a signature change. See the " - f"above exception context for more information. If this change was " - f"unintended, please revert it. If it was intended, check with the FX " - f"team to ensure that the proper deprecation protocols have been followed " - f"and subsequently --accept the change." - ) - raise AssertionError(msg) - - def test_class_member_back_compat(self): - """ - Test backward compatibility for members of classes with - @compatibility(is_backward_compatible=True). Currently this checks for - exact matches on the publicly visible members of the class. - """ - class_method_strs = [] - - for obj in _BACK_COMPAT_OBJECTS: - if isinstance(obj, type): - public_members = [ - name for name in obj.__dict__ if not name.startswith("_") - ] - class_method_strs.append( - f"{torch.typename(obj)} {sorted(public_members)}" - ) - - class_method_strs.sort() - - try: - self.assertExpected( - "\n".join(class_method_strs), "fx_backcompat_class_members" - ) - except AssertionError as e: - msg = ( - f"{e}\n****** ERROR ******\nAn FX class that has been marked " - f"as backwards-compatible has experienced change in its public members. See the " - f"above exception context for more information. If this change was " - f"unintended, please revert it. If it was intended, check with the FX " - f"team to ensure that the proper deprecation protocols have been followed " - f"and subsequently --accept the change." - ) - raise AssertionError(msg) - - def test_public_api_surface(self): - non_back_compat_objects = {} - - def check_symbols_have_bc_designation(m, prefix): - if not m.__name__.startswith("pippy.fx"): - return - if m.__name__.startswith("pippy.fx.experimental"): - return - for k, v in m.__dict__.items(): - if v is m: - continue - if k.startswith("_"): - continue - if isinstance(v, types.ModuleType): - check_symbols_have_bc_designation(v, prefix + [k]) - elif isinstance(v, type) or isinstance(v, types.FunctionType): - if v not in _MARKED_WITH_COMATIBLITY: - non_back_compat_objects.setdefault(v) - - check_symbols_have_bc_designation(pippy.fx, ["torch", "fx"]) - check_symbols_have_bc_designation( - pippy.fx.passes, ["torch", "fx", "passes"] - ) - - non_back_compat_strs = [ - torch.typename(obj) for obj in non_back_compat_objects.keys() - ] - # Only want objects in pippy.fx - non_back_compat_strs = [ - s - for s in non_back_compat_strs - if s.startswith("pippy.fx") - and not s.startswith("pippy.fx.experimental") - ] - # Only want objects in public namespaces - non_back_compat_strs = [ - s - for s in non_back_compat_strs - if all(not atom.startswith("_") for atom in s.split(".")) - ] - non_back_compat_strs.sort() - - if len(non_back_compat_strs) != 0: - raise AssertionError( - f"Public FX API(s) {non_back_compat_strs} introduced but not given a " - f"backwards-compatibility classification! Please decorate these " - f"API(s) with `@pippy.fx._compatibility.compatibility` to specify " - f"BC guarantees." - ) - - -class TestFunctionalTracing(JitTestCase): - def setUp(self): - super().setUp() - # Checking for mutable operations whil tracing is feature flagged - # Enable it in testing but not by default - self.orig_tracer_mutable_flag = ( - pippy.fx.proxy.TracerBase.check_mutable_operations - ) - pippy.fx.proxy.TracerBase.check_mutable_operations = True - - def tearDown(self): - super().tearDown() - pippy.fx.proxy.TracerBase.check_mutable_operations = ( - self.orig_tracer_mutable_flag - ) - - IGNORE_FUNCS = ( - "has_torch_function", - "has_torch_function_unary", - "has_torch_function_variadic", - "handle_torch_function", - "boolean_dispatch", - ) - TO_PATCH = { - "has_torch_function": None, - "has_torch_function_unary": None, - "has_torch_function_variadic": None, - } - - BUILT_IN_FUNC = (AssertionError, "") - PROXY_ITERABLE = (TypeError, r"argument of type 'Proxy' is not iterable") - PROXY_ITERATED = (TraceError, r"Proxy object cannot be iterated") - LEN_ERROR = ( - RuntimeError, - r"'len' is not supported in symbolic tracing by default", - ) - ARG_TYPE_MISMATCH = (TypeError, r", not Proxy$") - CONTROL_FLOW = ( - TraceError, - r"symbolically traced variables cannot be used as inputs to control flow", - ) - INTERPOLATE_ARGS_CONFLICT = ( - ValueError, - r"only one of size or scale_factor should be defined", - ) - MUTABLE = (RuntimeError, r"Tried to trace mutable operation") - - UNTRACEABLE_FUNCTIONALS = { - "adaptive_avg_pool1d": BUILT_IN_FUNC, - "avg_pool1d": BUILT_IN_FUNC, - "avg_pool2d": BUILT_IN_FUNC, - "avg_pool3d": BUILT_IN_FUNC, - "bilinear": BUILT_IN_FUNC, - "celu_": BUILT_IN_FUNC, - "channel_shuffle": BUILT_IN_FUNC, - "native_channel_shuffle": BUILT_IN_FUNC, - "conv1d": BUILT_IN_FUNC, - "conv2d": BUILT_IN_FUNC, - "conv3d": BUILT_IN_FUNC, - "conv_tbc": BUILT_IN_FUNC, - "conv_transpose1d": BUILT_IN_FUNC, - "conv_transpose2d": BUILT_IN_FUNC, - "conv_transpose3d": BUILT_IN_FUNC, - "cosine_similarity": BUILT_IN_FUNC, - "elu_": BUILT_IN_FUNC, - "gelu": BUILT_IN_FUNC, - "hardshrink": BUILT_IN_FUNC, - "hardtanh_": BUILT_IN_FUNC, - "leaky_relu_": BUILT_IN_FUNC, - "linear": BUILT_IN_FUNC, - "logsigmoid": BUILT_IN_FUNC, - "one_hot": BUILT_IN_FUNC, - "pad": BUILT_IN_FUNC, - "pairwise_distance": BUILT_IN_FUNC, - "pdist": BUILT_IN_FUNC, - "pixel_shuffle": BUILT_IN_FUNC, - "pixel_unshuffle": BUILT_IN_FUNC, - "prelu": BUILT_IN_FUNC, - "relu_": BUILT_IN_FUNC, - "rrelu_": BUILT_IN_FUNC, - "selu_": BUILT_IN_FUNC, - "softplus": BUILT_IN_FUNC, - "softshrink": BUILT_IN_FUNC, - "threshold_": BUILT_IN_FUNC, - "adaptive_avg_pool2d": LEN_ERROR, - "adaptive_avg_pool3d": LEN_ERROR, - "adaptive_max_pool2d_with_indices": LEN_ERROR, - "adaptive_max_pool3d_with_indices": LEN_ERROR, - "instance_norm": CONTROL_FLOW, - "adaptive_max_pool1d": PROXY_ITERABLE, - "adaptive_max_pool2d": PROXY_ITERABLE, - "adaptive_max_pool3d": PROXY_ITERABLE, - "fractional_max_pool2d": PROXY_ITERABLE, - "fractional_max_pool3d": PROXY_ITERABLE, - "max_pool1d": PROXY_ITERABLE, - "max_pool2d": PROXY_ITERABLE, - "max_pool3d": PROXY_ITERABLE, - "group_norm": PROXY_ITERATED, - "lp_pool2d": PROXY_ITERATED, - "max_unpool1d": PROXY_ITERATED, - "max_unpool2d": PROXY_ITERATED, - "max_unpool3d": PROXY_ITERATED, - "adaptive_max_pool1d_with_indices": ARG_TYPE_MISMATCH, - "fractional_max_pool2d_with_indices": ARG_TYPE_MISMATCH, - "fractional_max_pool3d_with_indices": ARG_TYPE_MISMATCH, - "layer_norm": ARG_TYPE_MISMATCH, - "lp_pool1d": ARG_TYPE_MISMATCH, - "affine_grid": CONTROL_FLOW, - "alpha_dropout": CONTROL_FLOW, - "batch_norm": CONTROL_FLOW, - "binary_cross_entropy": CONTROL_FLOW, - "binary_cross_entropy_with_logits": CONTROL_FLOW, - "celu": CONTROL_FLOW, - "cosine_embedding_loss": CONTROL_FLOW, - "cross_entropy": CONTROL_FLOW, - "ctc_loss": CONTROL_FLOW, - "dropout": CONTROL_FLOW, - "dropout1d": CONTROL_FLOW, - "dropout2d": CONTROL_FLOW, - "dropout3d": CONTROL_FLOW, - "elu": CONTROL_FLOW, - "embedding": CONTROL_FLOW, - "embedding_bag": CONTROL_FLOW, - "feature_alpha_dropout": CONTROL_FLOW, - "fold": CONTROL_FLOW, - "gaussian_nll_loss": CONTROL_FLOW, - "glu": CONTROL_FLOW, - "grid_sample": CONTROL_FLOW, - "gumbel_softmax": CONTROL_FLOW, - "hardsigmoid": CONTROL_FLOW, - "hardswish": CONTROL_FLOW, - "hardtanh": CONTROL_FLOW, - "hinge_embedding_loss": CONTROL_FLOW, - "huber_loss": CONTROL_FLOW, - "interpolate": CONTROL_FLOW, - "kl_div": CONTROL_FLOW, - "l1_loss": CONTROL_FLOW, - "leaky_relu": CONTROL_FLOW, - "local_response_norm": CONTROL_FLOW, - "margin_ranking_loss": CONTROL_FLOW, - "max_pool1d_with_indices": ARG_TYPE_MISMATCH, - "max_pool2d_with_indices": ARG_TYPE_MISMATCH, - "max_pool3d_with_indices": ARG_TYPE_MISMATCH, - "mse_loss": CONTROL_FLOW, - "multi_head_attention_forward": CONTROL_FLOW, - "multi_margin_loss": CONTROL_FLOW, - "multilabel_margin_loss": CONTROL_FLOW, - "multilabel_soft_margin_loss": CONTROL_FLOW, - "nll_loss": CONTROL_FLOW, - "poisson_nll_loss": CONTROL_FLOW, - "relu": CONTROL_FLOW, - "relu6": CONTROL_FLOW, - "rrelu": CONTROL_FLOW, - "selu": CONTROL_FLOW, - "silu": CONTROL_FLOW, - "mish": CONTROL_FLOW, - "smooth_l1_loss": CONTROL_FLOW, - "soft_margin_loss": CONTROL_FLOW, - "threshold": CONTROL_FLOW, - "triplet_margin_loss": CONTROL_FLOW, - "triplet_margin_with_distance_loss": CONTROL_FLOW, - "unfold": CONTROL_FLOW, - "upsample": CONTROL_FLOW, - "upsample_bilinear": INTERPOLATE_ARGS_CONFLICT, - "upsample_nearest": INTERPOLATE_ARGS_CONFLICT, - } - - # List of nn.functionals with Tensor inputs but not with type annotation - FUNCTIONALS_WITHOUT_ANNOTATION = ( - "adaptive_max_pool1d", - "adaptive_max_pool2d", - "adaptive_max_pool3d", - "fractional_max_pool2d", - "fractional_max_pool3d", - "max_pool1d", - "max_pool2d", - "max_pool3d", - "gaussian_nll_loss", - "upsample", - "upsample_bilinear", - "upsample_nearest", - ) - - # Inconsistent behavior between Python 3.8 and other Python versions: - # - Python 3.8+: Re-raise internal exception like `PROXY_ITERATED` - # - Other Python: Raise `argument of type 'Proxy' is not iterable` due to the same - # internal exception above - # Use the following map to override the expected exception for Python 3.8 - UNTRACEABLE_FUNCTIONALS_PY38 = { - "adaptive_max_pool1d": PROXY_ITERATED, - "adaptive_max_pool2d": PROXY_ITERATED, - "adaptive_max_pool3d": PROXY_ITERATED, - "fractional_max_pool2d": PROXY_ITERATED, - "fractional_max_pool3d": PROXY_ITERATED, - "max_pool1d": PROXY_ITERATED, - "max_pool2d": PROXY_ITERATED, - "max_pool3d": PROXY_ITERATED, - "group_norm": LEN_ERROR, - } - - @classmethod - def _get_functional(cls): - functional_list = [] - for f in dir(torch.nn.functional): - if not f.islower(): - continue - # Ignore internal functions - if f.startswith("_"): - continue - # Ignore supporting functions - if f in cls.IGNORE_FUNCS: - continue - fn = getattr(torch.nn.functional, f) - # Ignore non-callable object like modules - if not isinstance(fn, Callable): - continue - if f not in cls.FUNCTIONALS_WITHOUT_ANNOTATION: - try: - sig = inspect.signature(fn) - has_tensor_arg = False - for arg, param in sig.parameters.items(): - if isinstance(param.annotation, type) and issubclass( - param.annotation, torch.Tensor - ): - has_tensor_arg = True - if not has_tensor_arg: - continue - # No signature or Object is not supported - except ValueError: - pass - functional_list.append((f, fn)) - return functional_list - - @classmethod - def generate_test_func(cls, func_name, fn): - def functional_test(self): - if ( - func_name in self.UNTRACEABLE_FUNCTIONALS_PY38 - and sys.version_info >= (3, 8) - and sys.version_info < (3, 11) - ): - exc, err = self.UNTRACEABLE_FUNCTIONALS_PY38[func_name] - with self.assertRaisesRegex(exc, err): - symbolic_trace(fn) - elif func_name in self.UNTRACEABLE_FUNCTIONALS: - exc, err = self.UNTRACEABLE_FUNCTIONALS[func_name] - with self.assertRaisesRegex(exc, err): - symbolic_trace(fn) - else: - symbolic_trace(fn) - - return functional_test - - @classmethod - def generate_tests(cls): - functional_list = cls._get_functional() - for func_name, fn in functional_list: - test_name = "test_nn_functional_" + func_name - functional_test = cls.generate_test_func(func_name, fn) - setattr(cls, test_name, functional_test) - - @classmethod - def setUpClass(cls): - def no(*args, **kwargs): - return False - - for name in cls.TO_PATCH.keys(): - cls.TO_PATCH[name] = getattr(torch.nn.functional, name) - setattr(torch.nn.functional, name, no) - - @classmethod - def tearDownClass(cls): - for name in cls.TO_PATCH.keys(): - setattr(torch.nn.functional, name, cls.TO_PATCH[name]) - - -TestFunctionalTracing.generate_tests() - - -instantiate_device_type_tests(TestOperatorSignatures, globals()) - - -@skipIfNoTorchVision -@skipIfSlowGradcheckEnv -class TestVisionTracing(JitTestCase): - def setUp(self): - # Checking for mutable operations while tracing is feature flagged - # Enable it in testing but not by default - self.orig_tracer_mutable_flag = ( - pippy.fx.proxy.TracerBase.check_mutable_operations - ) - pippy.fx.proxy.TracerBase.check_mutable_operations = True - - def tearDown(self): - pippy.fx.proxy.TracerBase.check_mutable_operations = ( - self.orig_tracer_mutable_flag - ) - - PROXY_ITERATED = (TraceError, r"Proxy object cannot be iterated") - INCONSISTENT_TYPE = ( - RuntimeError, - r"Return value was annotated as having type __torch__.torchvision.models[.\w]+ but is actually of type Tensor", - ) - - UNTRACEABLE_MODELS = { - "fasterrcnn_resnet50_fpn": PROXY_ITERATED, - "fasterrcnn_resnet50_fpn_v2": PROXY_ITERATED, - "fasterrcnn_mobilenet_v3_large_320_fpn": PROXY_ITERATED, - "fasterrcnn_mobilenet_v3_large_fpn": PROXY_ITERATED, - "maskrcnn_resnet50_fpn": PROXY_ITERATED, - "maskrcnn_resnet50_fpn_v2": PROXY_ITERATED, - "keypointrcnn_resnet50_fpn": PROXY_ITERATED, - "retinanet_resnet50_fpn": PROXY_ITERATED, - "retinanet_resnet50_fpn_v2": PROXY_ITERATED, - "ssd300_vgg16": PROXY_ITERATED, - "fcos_resnet50_fpn": PROXY_ITERATED, - "ssdlite320_mobilenet_v3_large": PROXY_ITERATED, - } - UNSCRIPTABLE_MODELS = { - "googlenet": INCONSISTENT_TYPE, - "inception_v3": INCONSISTENT_TYPE, - } - - output_transform = { - "fcn_resnet50": lambda x: x["out"], - "fcn_resnet101": lambda x: x["out"], - "deeplabv3_resnet50": lambda x: x["out"], - "deeplabv3_resnet101": lambda x: x["out"], - "deeplabv3_mobilenet_v3_large": lambda x: x["out"], - "lraspp_mobilenet_v3_large": lambda x: x["out"], - "fasterrcnn_resnet50_fpn": lambda x: x[1], - "fasterrcnn_mobilenet_v3_large_fpn": lambda x: x[1], - "fasterrcnn_mobilenet_v3_large_320_fpn": lambda x: x[1], - "maskrcnn_resnet50_fpn": lambda x: x[1], - "keypointrcnn_resnet50_fpn": lambda x: x[1], - "retinanet_resnet50_fpn": lambda x: x[1], - } - - @classmethod - def generate_test_fn(cls, name, x, kwargs): - def run_test(self): - model = torchvision_models.get_model(name, **kwargs) - model = model.eval() - if name in self.UNTRACEABLE_MODELS: - err, exc = self.UNTRACEABLE_MODELS[name] - with self.assertRaisesRegex(err, exc): - graph = symbolic_trace(model) - else: - out_transform = self.output_transform.get(name, lambda x: x) - graph: pippy.fx.GraphModule = symbolic_trace(model) - a = out_transform(model(x)) - b = out_transform(graph(x)) - self.assertEqual(a, b) - - if name in self.UNSCRIPTABLE_MODELS: - err, exc = self.UNSCRIPTABLE_MODELS[name] - with self.assertRaisesRegex(err, exc): - script = torch.jit.script(graph) - else: - script = torch.jit.script(graph) - c = out_transform(script(x)) - self.assertEqual(a, c) - - return run_test - - @classmethod - def generate_classification_tests(cls): - for k in torchvision_models.list_models(module=torchvision_models): - test_name = "test_torchvision_models_" + k - x = ( - torch.rand(1, 3, 299, 299) - if k in ["inception_v3"] - else torch.rand(1, 3, 224, 224) - ) - kwargs = dict(num_classes=50) - model_test = cls.generate_test_fn(k, x, kwargs) - setattr(cls, test_name, model_test) - - @classmethod - def generate_segmentation_tests(cls): - for k in torchvision_models.list_models( - module=torchvision_models.segmentation - ): - test_name = "test_torchvision_models_segmentation_" + k - x = torch.rand(1, 3, 32, 32) - kwargs = dict(num_classes=10, pretrained_backbone=False) - model_test = cls.generate_test_fn(k, x, kwargs) - setattr(cls, test_name, model_test) - - @classmethod - def generate_detection_tests(cls): - for k in torchvision_models.list_models( - module=torchvision_models.detection - ): - test_name = "test_torchvision_models_detection_" + k - x = [torch.rand(3, 300, 300)] - kwargs = dict(num_classes=10, pretrained_backbone=False) - model_test = cls.generate_test_fn(k, x, kwargs) - setattr(cls, test_name, model_test) - - @classmethod - def generate_video_tests(cls): - for k in torchvision_models.list_models( - module=torchvision_models.video - ): - test_name = "test_torchvision_models_video_" + k - x = ( - torch.rand(1, 3, 4, 112, 112) - if k not in {"mvit_v1_b", "mvit_v2_s", "s3d"} - else torch.rand(1, 3, 16, 224, 224) - ) - kwargs = dict(num_classes=50) - model_test = cls.generate_test_fn(k, x, kwargs) - setattr(cls, test_name, model_test) - - @classmethod - def generate_tests(cls): - cls.generate_classification_tests() - cls.generate_detection_tests() - cls.generate_segmentation_tests() - cls.generate_video_tests() - - -if HAS_TORCHVISION: - TestVisionTracing.generate_tests() - -if __name__ == "__main__": - run_tests() diff --git a/test/test_fx_experimental.py b/test/test_fx_experimental.py deleted file mode 100644 index 0ba3f2b8e..000000000 --- a/test/test_fx_experimental.py +++ /dev/null @@ -1,1717 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -# Owner(s): ["module: fx"] - -import math -import numbers -import operator -import pickle -import sys -import tempfile -import unittest -from types import BuiltinFunctionType -from typing import Callable, Dict, List, Optional, Union - -import pippy.fx.experimental.meta_tracer -import pippy.fx.experimental.optimization as optimization - -import torch -from pippy.fx._symbolic_trace import symbolic_trace -from pippy.fx.experimental import merge_matmul -from pippy.fx.experimental.accelerator_partitioner import Partitioner -from pippy.fx.experimental.normalize import NormalizeArgs, NormalizeOperators -from pippy.fx.experimental.partitioner_utils import ( - Device, - get_latency_of_partitioned_graph, - get_partition_to_latency_mapping, - NodeLatency, - PartitionerConfig, - PartitionMode, -) -from pippy.fx.experimental.rewriter import RewritingTracer -from pippy.fx.experimental.schema_type_annotation import AnnotateTypesWithSchema -from pippy.fx.graph_module import GraphModule -from pippy.fx.node import Node -from pippy.fx.operator_schemas import ( - _torchscript_type_to_python_type, - create_type_hint, - normalize_function, - normalize_module, - type_matches, -) -from pippy.fx.passes import graph_manipulation -from pippy.fx.passes.param_fetch import lift_lowering_attrs_to_nodes -from pippy.fx.passes.shape_prop import ShapeProp -from pippy.fx.passes.split_module import split_module -from torch.testing._internal.common_device_type import ( - instantiate_device_type_tests, - onlyCPU, - ops, -) -from torch.testing._internal.common_methods_invocations import op_db -from torch.testing._internal.common_nn import module_tests, new_module_tests -from torch.testing._internal.common_utils import run_tests -from torch.testing._internal.jit_utils import JitTestCase - -try: - import torchvision.models - from torchvision.models import resnet18 - - HAS_TORCHVISION = True -except ImportError: - HAS_TORCHVISION = False -skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision") -skipIfNoMkldnn = unittest.skipIf( - not ( - torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available() - ), - "no MKLDNN", -) - - -def symbolic_trace_with_rewrite( - root: Union[torch.nn.Module, Callable] -) -> GraphModule: - return GraphModule( - root if isinstance(root, torch.nn.Module) else torch.nn.Module(), - RewritingTracer().trace(root), - ) - - -class TestFXExperimental(JitTestCase): - def test_find_single_partition(self): - class TestModule(torch.nn.Module): - def forward(self, a, b): - return a + b - - m = TestModule() - traced = symbolic_trace(m) - a = torch.rand(1) - b = torch.rand(1) - graph_manipulation.get_size_of_all_nodes(traced, [a, b]) - partitioner = Partitioner() - devices = [ - Device("dev_0", 125, 0), - Device("dev_1", 150, 1), - Device("dev_2", 125, 2), - ] - partitioner_config = PartitionerConfig(devices) - ret = partitioner.partition_graph(traced, m, partitioner_config) - module_with_submodules = ret.module_with_submodules - dag = ret.dag - self.assertEqual(traced(a, b), module_with_submodules(a, b)) - assert dag.nodes[0].logical_device_ids == [1] - - def test_lack_of_devices(self): - class TestModule(torch.nn.Module): - def forward(self, a, b): - return a + b - - m = TestModule() - traced = symbolic_trace(m) - a = torch.rand(4) - b = torch.rand(4) - graph_manipulation.get_size_of_all_nodes(traced, [a, b]) - partitioner = Partitioner() - devices = [Device("dev_0", 4, 0), Device("dev_1", 4, 1)] - partitioner_config = PartitionerConfig( - devices, PartitionMode.size_based - ) - catch_runtime_error = False - try: - ret = partitioner.partition_graph(traced, m, partitioner_config) - except RuntimeError: - catch_runtime_error = True - assert catch_runtime_error - - def test_large_node_error(self): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(4, 4) - - def forward(self, a): - linear = self.linear(a) - add = linear + a - return add - - m = TestModule() - traced = symbolic_trace(m) - a = torch.rand(4) - graph_manipulation.get_size_of_all_nodes(traced, [a]) - partitioner = Partitioner() - devices = [ - Device("dev_0", 40, 0), - Device("dev_1", 40, 0), - Device("dev_2", 40, 0), - Device("dev_3", 40, 0), - Device("dev_4", 40, 0), - ] - partitioner_config = PartitionerConfig( - devices, PartitionMode.size_based - ) - catch_runtime_error = False - try: - ret = partitioner.partition_graph(traced, m, partitioner_config) - except RuntimeError: - catch_runtime_error = True - assert catch_runtime_error - - def test_partition_node_manipulation(self): - class TestModule(torch.nn.Module): - def forward(self, a, b): - add_1 = a + b - add_2 = add_1 + torch.rand(4) - add_3 = add_2 + torch.rand(4) - return add_3 - - m = TestModule() - traced = symbolic_trace(m) - a, b = torch.rand(4), torch.rand(4) - graph_manipulation.get_size_of_all_nodes(traced, [a, b]) - partitioner = Partitioner() - devices = [Device("dev_0", 1000, 0)] - partitioner_config = PartitionerConfig(devices) - ret = partitioner.partition_graph(traced, m, partitioner_config) - partition = partitioner.partitions[0] - assert partition.used_mem_bytes == 112 - # Select add_2 node to remove - selected_node = None - for node in partition.nodes: - if node.name == "add_2": - selected_node = node - partition.remove_node(selected_node) - assert partition.used_mem_bytes == 80 - - def test_size_based_partition(self): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(4, 4) - self.c = torch.rand(4) - - def forward(self, a, b): - add_1 = a + b - linear = self.linear(add_1) - add_2 = linear + self.c - return add_2 - - m = TestModule() - traced = symbolic_trace(m) - a = torch.rand(4) - b = torch.rand(4) - graph_manipulation.get_size_of_all_nodes(traced, [a, b]) - partitioner = Partitioner() - devices = [ - Device("dev_0", 125, 0), - Device("dev_1", 125, 1), - Device("dev_2", 125, 2), - ] - partitioner_config = PartitionerConfig( - devices, PartitionMode.size_based - ) - ret = partitioner.partition_graph(traced, m, partitioner_config) - module_with_submodules = ret.module_with_submodules - dag = ret.dag - self.assertEqual(traced(a, b), module_with_submodules(a, b)) - for i, node in enumerate(dag.nodes): - assert node.logical_device_ids == [i] - - def test_partition_device_mapping(self): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(4, 4) - - def forward(self, a): - b = torch.rand(4) - add_1 = a + b - linear_1 = self.linear(add_1) - add_2 = torch.rand(4) + a - add_3 = add_2 + linear_1 - return add_3 - - m = TestModule() - traced = symbolic_trace(m) - a = torch.rand(4) - graph_manipulation.get_size_of_all_nodes(traced, [a]) - partitioner = Partitioner() - devices = [Device("dev_0", 120, 0), Device("dev_1", 160, 1)] - partitioner_config = PartitionerConfig( - devices, PartitionMode.size_based - ) - ret = partitioner.partition_graph(traced, m, partitioner_config) - module_with_submodules = ret.module_with_submodules - dag = ret.dag - self.assertEqual(traced(a), module_with_submodules(a)) - for i, node in enumerate(dag.nodes): - if i == 1: - assert node.logical_device_ids == [1] - else: - assert node.logical_device_ids == [0] - - def test_sparse_nn_partition(self): - class MyRecommendationModule(torch.nn.Module): - def create_mlp( - self, num_of_layers: int, input_size: int, output_size: int - ): - layers = torch.nn.ModuleList() - for _ in range(num_of_layers): - ll = torch.nn.Linear(input_size, output_size) - layers.append(ll) - layers.append(torch.nn.ReLU()) - return layers - - def __init__(self): - super(MyRecommendationModule, self).__init__() - layers = self.create_mlp(4, 4, 4) - self.bottom_layers = torch.nn.Sequential(*layers) - layers = self.create_mlp(3, 24, 24) - self.top_layers = torch.nn.Sequential(*layers) - self.embedding_layers = torch.nn.ModuleList() - el = torch.nn.EmbeddingBag(500000, 4, mode="sum", sparse=True) - self.embedding_layers.append(el) - for i in range(3): - el = torch.nn.EmbeddingBag( - 1000000, 4, mode="sum", sparse=True - ) - self.embedding_layers.append(el) - el = torch.nn.EmbeddingBag(500000, 4, mode="sum", sparse=True) - self.embedding_layers.append(el) - - def forward(self, a, b, offset): - x = self.bottom_layers(a) - y = [] - c = [] - for i in range(len(self.embedding_layers)): - temp = torch.randint(10, (8,)) - c.append(temp + b) - for i in range(len(self.embedding_layers)): - if i % 2 == 0: - y.append(self.embedding_layers[i](c[i], offset)) - else: - y.append( - self.embedding_layers[i]( - torch.randint(10, (8,)), offset - ) - ) - z = torch.cat([x] + y, dim=1) - p = self.top_layers(z) - return p - - m = MyRecommendationModule() - a = torch.rand(2, 4) - b = torch.randint(10, (8,)) - offset = torch.randint(1, (2,)) - traced = symbolic_trace(m) - graph_manipulation.get_size_of_all_nodes(traced, [a, b, offset]) - devices = [ - Device("dev_0", 33000000, 0), - Device("dev_1", 33000000, 1), - Device("dev_2", 33000000, 2), - ] - partitioner_config = PartitionerConfig(devices, PartitionMode.sparse_nn) - partitioner = Partitioner() - ret = partitioner.partition_graph(traced, m, partitioner_config) - module_with_submodules = ret.module_with_submodules - dag = ret.dag - self.assertEqual( - traced(a, b, offset), module_with_submodules(a, b, offset) - ) - assert len(module_with_submodules.graph.nodes) == 24 - - def test_partition_latency(self): - class TestModule(torch.nn.Module): - def __init__(self): - super(TestModule, self).__init__() - self.linear = torch.nn.Linear(4, 4) - - def forward(self, a): - add_1 = a + torch.rand(4) - add_2 = add_1 + torch.rand(4) - linear_1 = self.linear(add_1) - add_3 = add_2 + linear_1 - add_4 = add_2 + add_3 - return add_4 - - def get_node_to_latency_mapping(fx_module: GraphModule): - """Given a fx module, generate node latency for each node - based on the size of each node - """ - node_to_latency_mapping: Dict[Node, NodeLatency] = {} - for node in fx_module.graph.nodes: - if node.op not in {"output", "placeholder", "get_attr"}: - if ( - node.size_bytes.total_size - == node.size_bytes.output_size - ): - node_to_latency_mapping[node] = NodeLatency( - node.size_bytes.total_size, - 2.0 * node.size_bytes.total_size, - ) - else: - node_to_latency_mapping[node] = NodeLatency( - node.size_bytes.total_size, - node.size_bytes.output_size, - ) - return node_to_latency_mapping - - m = TestModule() - traced = symbolic_trace(m) - a = torch.rand(4) - graph_manipulation.get_size_of_all_nodes(traced, [a]) - node_to_latency_mapping = get_node_to_latency_mapping(traced) - devices = [Device("dev_0", 200, 0), Device("dev_1", 200, 1)] - partitioner = Partitioner() - partitioner_config = PartitionerConfig(devices) - ret = partitioner.partition_graph(traced, m, partitioner_config) - module_with_submodules = ret.module_with_submodules - self.assertEqual(traced(a), module_with_submodules(a)) - partitions = partitioner.partitions - partition_to_latency_mapping = get_partition_to_latency_mapping( - partitions, node_to_latency_mapping - ) - for p in partition_to_latency_mapping: - if p.partition_id == 0: - assert partition_to_latency_mapping[p] == (128.0, 80.0, 160.0) - else: - assert partition_to_latency_mapping[p] == (16.0, 32.0, 32.0) - transfer_rate_bytes_per_sec = 2 - critical_path_latency_sec = get_latency_of_partitioned_graph( - partitions, - partition_to_latency_mapping, - transfer_rate_bytes_per_sec, - ) - assert critical_path_latency_sec == 208.0 - - def test_cost_aware_partition(self): - class MyModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(4, 4) - - def forward(self, a): - add_1 = a + torch.rand(4) - add_2 = add_1 + torch.rand(4) - linear_1 = self.linear(add_1) - add_3 = add_2 + torch.rand(4) - add_4 = add_2 + linear_1 - add_5 = add_3 + add_4 - return add_5 - - def get_node_to_latency_mapping(fx_module: GraphModule): - node_to_latency_mapping: Dict[Node, NodeLatency] = {} - for node in fx_module.graph.nodes: - if node.op not in {"output", "placeholder", "get_attr"}: - if ( - node.size_bytes.total_size - == node.size_bytes.output_size - ): - node_to_latency_mapping[node] = NodeLatency( - node.size_bytes.total_size, 1 - ) - else: - node_to_latency_mapping[node] = NodeLatency( - node.size_bytes.total_size, - node.size_bytes.output_size, - ) - return node_to_latency_mapping - - m = MyModule() - traced = symbolic_trace(m) - a = torch.rand(4) - graph_manipulation.get_size_of_all_nodes(traced, [a]) - devices = [ - Device("dev_0", 125, 0), - Device("dev_1", 125, 1), - Device("dev_2", 125, 2), - Device("dev_3", 125, 3), - ] - node_to_latency_mapping = get_node_to_latency_mapping(traced) - partitioner_config = PartitionerConfig( - devices, - mode=PartitionMode.cost_aware, - transfer_rate_bytes_per_sec=2, - node_to_latency_mapping=node_to_latency_mapping, - ) - partitioner = Partitioner() - ret = partitioner.partition_graph(traced, m, partitioner_config) - module_with_submodules = ret.module_with_submodules - dag = ret.dag - self.assertEqual(traced(a), module_with_submodules(a)) - partitions = partitioner.partitions - partition_to_latency_mapping = get_partition_to_latency_mapping( - partitions, node_to_latency_mapping - ) - critical_path_latency_sec = get_latency_of_partitioned_graph( - partitions, - partition_to_latency_mapping, - partitioner_config.transfer_rate_bytes_per_sec, - ) - assert critical_path_latency_sec == 160.0 - - def test_aot_based_partition(self): - class TestModule(torch.nn.Module): - def __init__(self): - super(TestModule, self).__init__() - self.b = torch.rand(4) - self.c = torch.rand(4) - - def forward(self, a): - add_1 = a + self.b - add_2 = self.c + add_1 - return add_2 - - m = TestModule() - traced = symbolic_trace(m) - a = torch.rand(4) - node_to_partition_id = {} - partition_to_logical_devices = {} - count = 0 - graph_manipulation.get_size_of_all_nodes(traced, [a]) - for node in traced.graph.nodes: - if node.op not in {"placeholder", "get_attr", "output"}: - node_to_partition_id[node] = count - partition_to_logical_devices[count] = [0] - count += 1 - devices = [Device("dev_0", 200, 0)] - partitioner_config = PartitionerConfig( - devices=devices, - mode=PartitionMode.aot_based, - node_to_partition_mapping=node_to_partition_id, - partition_to_logical_device_mapping=partition_to_logical_devices, - ) - partitioner = Partitioner() - ret = partitioner.partition_graph(traced, m, partitioner_config) - module_with_submodules = ret.module_with_submodules - dag = ret.dag - self.assertEqual(module_with_submodules(a), traced(a)) - for node in dag.nodes: - assert node.size_bytes == 48 - assert node.logical_device_ids == [0] - - def test_replace_target_nodes_with(self): - class testModule(torch.nn.Module): - def forward(self, a, b): - return a + b - - m = testModule() - traced = symbolic_trace(m) - input1 = torch.randn(1) - input2 = torch.randn(1) - assert (input1 + input2) == traced(input1, input2) - graph_manipulation.replace_target_nodes_with( - fx_module=traced, - old_op="call_function", - old_target=operator.add, - new_op="call_function", - new_target=operator.mul, - ) - assert (input1 * input2) == traced(input1, input2) - - def test_saturate_host(self): - class TestModule(torch.nn.Module): - def __init__(self): - super(TestModule, self).__init__() - self.linear = torch.nn.Linear(4, 4) - - def forward(self, a): - add_1 = a + torch.rand(4) - add_2 = add_1 + torch.rand(4) - linear_1 = self.linear(add_1) - add_3 = add_2 + linear_1 - add_4 = add_2 + add_3 - return add_4 - - m = TestModule() - traced = symbolic_trace(m) - a = torch.rand(4) - graph_manipulation.get_size_of_all_nodes(traced, [a]) - devices = [ - Device("dev_0", 200, 0), - Device("dev_1", 200, 1), - Device("dev_2", 100, 2), - Device("dev_3", 100, 3), - Device("dev_4", 200, 4), - Device("dev_5", 100, 5), - ] - partitioner = Partitioner() - # Without host saturation, the model will be split into two partitions. - # dev_0 holds partition 0 of 192 bytes and dev_1 holds partition 1 of 48 bytes. - partitioner_config = PartitionerConfig(devices, saturate_host=True) - ret = partitioner.partition_graph(traced, m, partitioner_config) - module_with_submodules = ret.module_with_submodules - self.assertEqual(traced(a), module_with_submodules(a)) - - partitions = partitioner.partitions - self.assertEqual(len(partitions), 2) - # With host saturation, partition 1 will be replicated to dev_4, and partition 2 - # will be replicated to dev_2. - self.assertEqual(partitions[0].logical_device_ids, [0, 4]) - self.assertEqual(partitions[1].logical_device_ids, [1, 2]) - - @skipIfNoTorchVision - def test_conv_bn_fusion(self): - rn18 = resnet18().eval() - traced = symbolic_trace(rn18) - fused = optimization.fuse(traced) - - self.assertTrue( - all( - not isinstance(m, torch.nn.BatchNorm2d) for m in fused.modules() - ) - ) - - N, C, H, W = 20, 3, 224, 224 - inp = torch.randn(N, C, H, W) - - self.assertEqual(fused(inp), rn18(inp)) - - def test_conv_bn_fusion_not_running_state(self): - class M(torch.nn.Module): - def __init__(self): - super(M, self).__init__() - self.conv = torch.nn.Conv2d(32, 64, 3, stride=2) - self.bn = torch.nn.BatchNorm2d( - 64, - eps=1e-05, - momentum=0.1, - affine=True, - track_running_stats=False, - ) - - def forward(self, x): - x = self.conv(x) - x = self.bn(x) - return x - - model = M().eval() - - traced = symbolic_trace(model) - fused = optimization.fuse(traced) - inp = torch.randn([1, 32, 50, 50]) - - # bn need not be folded in conv - self.assertTrue( - any(isinstance(m, torch.nn.BatchNorm2d) for m in fused.modules()) - ) - self.assertEqual(fused(inp), model(inp)) - - def test_call_to_assert_no_msg(self): - class M(torch.nn.Module): - def forward(self, a, b): - assert a == b - return a + b - - m = M() - traced = symbolic_trace_with_rewrite(m) - - # Make sure the graph is well-formed - traced.graph.lint() - - # Check the IR to make sure there's a call_function node with target == "Assert" - self.assertTrue( - any( - node.op == "call_function" and node.target == torch._assert - for node in traced.graph.nodes - ) - ) - - # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to - traced(3, 3) - with self.assertRaisesRegex(AssertionError, ""): - traced(3, 5) - - # Confirm that the output is correct - self.assertEqual(traced(3, 3), m(3, 3)) - - def test_meta_tracer(self): - class MetaTracerTestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.emb = torch.nn.Embedding( - num_embeddings=42, embedding_dim=16 - ) - self.layernorm = torch.nn.LayerNorm(16) - - def forward(self, x): - emb = self.emb(x) - emb = emb + torch.arange( - emb.shape[-1], dtype=torch.float, device=emb.device - ) - lol = self.layernorm(emb) - return ( - torch.relu(lol) if lol.shape[0] < 30 else torch.sigmoid(lol) - ) - - mttm = MetaTracerTestModule() - for BS in [15, 35]: - x = torch.zeros(BS, dtype=torch.long).random_(42) - meta_args = {"x": x.to(device="meta")} - gm = pippy.fx.experimental.meta_tracer.symbolic_trace( - mttm, meta_args=meta_args - ) - torch.testing.assert_close(gm(x), mttm(x)) - - # Test serialization/deserialization - with tempfile.TemporaryDirectory() as tmp_dir: - with open(f"{tmp_dir}/meta_module.pkl", "wb") as f: - pickle.dump(gm, f) - - with open(f"{tmp_dir}/meta_module.pkl", "rb") as f: - loaded = pickle.load(f) - - torch.testing.assert_close(loaded(x), mttm(x)) - - def test_call_to_assert_with_msg(self): - class M(torch.nn.Module): - def forward(self, a, b): - assert a == b, "test message" - return a + b - - m = M() - traced = symbolic_trace_with_rewrite(m) - - # Make sure the graph is well-formed - traced.graph.lint() - - # Check the IR to make sure there's a call_function node with target == "Assert" - self.assertTrue( - any( - node.op == "call_function" and node.target == torch._assert - for node in traced.graph.nodes - ) - ) - - # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to - traced(3, 3) - with self.assertRaisesRegex(AssertionError, "test message"): - traced(3, 5) - - # Confirm that the output is correct - self.assertEqual(traced(3, 3), m(3, 3)) - - def test_call_to_assert_with_empty_msg(self): - class M(torch.nn.Module): - def forward(self, a, b): - assert a == b, "" - return a + b - - m = M() - traced = symbolic_trace_with_rewrite(m) - - # Make sure the graph is well-formed - traced.graph.lint() - - # Check the IR to make sure there's a call_function node with target == "Assert" - self.assertTrue( - any( - node.op == "call_function" and node.target == torch._assert - for node in traced.graph.nodes - ) - ) - - # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to - traced(3, 3) - with self.assertRaisesRegex(AssertionError, ""): - traced(3, 5) - - # Confirm that the output is correct - self.assertEqual(traced(3, 3), m(3, 3)) - - def test_call_to_assert_with_multiline_message(self): - class M(torch.nn.Module): - def forward(self, a, b): - error_msg = """ -An error message with -terrible spacing - """ - assert a == b, error_msg - return a + b - - m = M() - traced = symbolic_trace_with_rewrite(m) - - # Make sure the graph is well-formed - traced.graph.lint() - - # Check the IR to make sure there's a call_function node with target == "Assert" - self.assertTrue( - any( - node.op == "call_function" and node.target == torch._assert - for node in traced.graph.nodes - ) - ) - - # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to - error_msg = """ -An error message with -terrible spacing - """ - traced(3, 3) - with self.assertRaisesRegex(AssertionError, error_msg): - traced(3, 5) - - # Confirm that the output is correct - self.assertEqual(traced(3, 3), m(3, 3)) - - def test_subgraph_creation(self): - class MyModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.param = torch.nn.Parameter(torch.rand(3, 4)) - self.linear = torch.nn.Linear(4, 5) - - def forward(self, x, y): - z = self.linear(x + self.param).clamp(min=0.0, max=1.0) - w = self.linear(y).clamp(min=0.0, max=1.0) - return z + w - - # symbolically trace model - my_module = MyModule() - my_module_traced = symbolic_trace(my_module) - - # random mod partitioning - partition_counter = 0 - NPARTITIONS = 3 - - # Add some random meta info to make sure it is kept around. - for node in my_module_traced.graph.nodes: - if node.op != "output": - node.meta["test_meta_info"] = True - - def mod_partition(node: Node): - nonlocal partition_counter - partition = partition_counter % NPARTITIONS - partition_counter = (partition_counter + 1) % NPARTITIONS - return partition - - # split module in module with submodules - module_with_submodules = split_module( - my_module_traced, my_module, mod_partition - ) - - # Check that test_meta_info was still on all nodes. - submodules = dict(module_with_submodules.named_modules()) - for node in module_with_submodules.graph.nodes: - if node.op == "call_module": - submod = submodules[node.target] - self.assertTrue(isinstance(submod, pippy.fx.GraphModule)) - for submod_node in submod.graph.nodes: - if submod_node.op != "output": - stored_op = submod_node.meta.get("test_meta_info") - self.assertTrue(stored_op is not None and stored_op) - - x = torch.rand(3, 4) - y = torch.rand(3, 4) - - orig_out = my_module_traced(x, y) - submodules_out = module_with_submodules(x, y) - - self.assertEqual(orig_out, submodules_out) - - def test_split_module_kwargs_expansion(self): - class ModuleWithKwargsExpansion(torch.nn.Module): - def forward(self, x, **kwargs): - return x + kwargs["foo"] - - mod = ModuleWithKwargsExpansion() - traced = pippy.fx.symbolic_trace(mod) - - seen_getitem = False - - def split_callback(n): - nonlocal seen_getitem - split_idx = int(seen_getitem) - if n.target == operator.getitem: - seen_getitem = True - return split_idx - - split = split_module(traced, mod, split_callback) - - x = torch.randn(5, 3) - foo = torch.randn(5, 3) - torch.testing.assert_allclose(split(x, foo=foo), traced(x, foo=foo)) - - @skipIfNoTorchVision - def test_subgraph_trivial_resnet(self): - # Smoke test trivially splitting resnet into 1 partition works - # There was an issue before causing submodule names to be aliased - m = resnet18() - traced = symbolic_trace(m) - a = torch.rand(64, 3, 7, 7) - module_with_submodules = split_module(traced, m, lambda node: 0) - module_with_submodules(a) - - def test_split_module_default_arg(self): - class ModelToTrace(torch.nn.Module): - def __init__(self): - super().__init__() - self.lin = torch.nn.Linear(512, 512) - - def forward(self, x, targets=None): - x = self.lin(x) - - if targets is not None: - x = x + targets - - return x - - mtt = ModelToTrace() - traced = pippy.fx.symbolic_trace(mtt, concrete_args={"targets": None}) - - split = split_module(traced, mtt, lambda node: 0) - - x = torch.randn(50, 512) - torch.testing.assert_allclose(split(x), traced(x)) - - def test_normalize_binary_operators(self): - ops_to_test = { - torch.add, - torch.mul, - torch.sub, - torch.div, - torch.floor_divide, - torch.remainder, - torch.eq, - torch.ne, - torch.lt, - torch.le, - torch.gt, - torch.ge, - } - - # Test Tensor/Tensor callsite - for op in ops_to_test: - - class WrapperMod(torch.nn.Module): - def forward(self, x, y): - return op(x, y) - - traced = symbolic_trace(WrapperMod()) - normalized = NormalizeOperators(traced).transform() - x, y = torch.randn(3, 4), torch.randn(3, 4) - torch.testing.assert_close(traced(x, y), normalized(x, y)) - self.assertFalse( - any(n.target in ops_to_test for n in normalized.graph.nodes) - ) - - # Test Tensor/scalar callsite - for op in ops_to_test: - - class WrapperMod(torch.nn.Module): - def forward(self, x): - return op(x, 42) - - traced = symbolic_trace(WrapperMod()) - normalized = NormalizeOperators(traced).transform() - x = torch.randn(3, 4) - torch.testing.assert_close(traced(x), normalized(x)) - self.assertFalse( - any(n.target in ops_to_test for n in normalized.graph.nodes) - ) - - @skipIfNoTorchVision - def test_normalize_args(self): - m = resnet18() - - class FunctionalTracer(pippy.fx.Tracer): - def is_leaf_module( - self, m: torch.nn.Module, module_qualified_name: str - ) -> bool: - # `leaves` contains the set of standard `nn.Modules` that are not - # currently symbolically traceable. Ideally this set would be empty - leaves = set([torch.nn.BatchNorm2d]) - return type(m) in leaves - - traced = pippy.fx.GraphModule(m, FunctionalTracer().trace(m)) - - input = torch.randn(5, 3, 224, 224) - ref_outs = traced(input) - - ShapeProp(traced).propagate(input) - traced = NormalizeArgs(traced).transform() - - modules = dict(traced.named_modules()) - - for node in traced.graph.nodes: - if node.op == "call_function" and node.target != operator.add: - self.assertEqual(len(node.args), 0) - elif node.op == "call_module": - submod_class = modules[node.target].__class__ - nn_class = getattr(torch.nn, submod_class.__name__) - if submod_class == nn_class: - self.assertEqual(len(node.args), 0) - traced(input) - self.assertEqual(traced(input), ref_outs) - - def test_normalize_modules_exhaustive(self): - """ - Exhaustively test `Node.normalized_arguments` on all standard - torch.nn Module classes - """ - for test_params in module_tests + new_module_tests: - if "constructor" not in test_params: - constructor = getattr(torch.nn, test_params["module_name"]) - else: - constructor = test_params["constructor"] - - if "constructor_args" not in test_params: - args = () - else: - args = test_params["constructor_args"] - - mod = constructor(*args) - # Skip modules that are not standard `torch.nn` - # instances, including functionals. (functionals - # are tested in test_normalize_args) - if mod.__class__.__name__ not in dir(torch.nn): - continue - - if "input_fn" not in test_params: - inputs = torch.randn(test_params["input_size"]) - else: - inputs = test_params["input_fn"]() - - if not isinstance(inputs, (tuple, list)): - inputs = (inputs,) - - params = ", ".join(f"v{i}" for i in range(len(inputs))) - - # Generate a class to wrap this standard `nn.Module` instance - test_classname = f"Test{mod.__class__.__name__}" - test_mod_code = f""" -class {test_classname}(torch.nn.Module): - def __init__(self, mod): - super().__init__() - self.mod = mod - - def forward(self, {params}): - return self.mod({params}) - """ - - gbls = {"torch": torch} - exec(test_mod_code, gbls) - - test_instance = gbls[test_classname](mod) - traced = symbolic_trace(test_instance) - - # Use `Node.normalized_arguments` to get a new set of arguments - # to feed to the Module. Then, rewrite the node to only take - # in those arguments as kwargs - modules = dict(traced.named_modules()) - for node in traced.graph.nodes: - if node.op == "call_module": - submod_class = modules[node.target].__class__ - nn_class = getattr(torch.nn, submod_class.__name__) - if submod_class == nn_class: - normalized_args = node.normalized_arguments(traced) - normalized_args2 = normalize_module( - traced, node.target, node.args, node.kwargs - ) - assert normalized_args == normalized_args2 - assert normalized_args - node.args = normalized_args.args - node.kwargs = normalized_args.kwargs - - traced.recompile() - - # These Modules have an RNG in their forward, so testing - # correctness by comparing outputs is not correct. Skip that - # check for these - stochastic_modules = { - "FractionalMaxPool2d", - "FractionalMaxPool3d", - "RReLU", - } - - if mod.__class__.__name__ not in stochastic_modules: - self.assertEqual(traced(*inputs), mod(*inputs)) - - traced = NormalizeArgs(symbolic_trace(test_instance)).transform() - modules = dict(traced.named_modules()) - for node in traced.graph.nodes: - if node.op == "call_module": - submod_class = modules[node.target].__class__ - nn_class = getattr(torch.nn, submod_class.__name__) - if submod_class == nn_class: - self.assertEqual(len(node.args), 0) - - def test_normalize_args_preserve_meta(self): - class MyModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, a): - return torch.add(a, 3) - - m = MyModule() - traced = symbolic_trace(m) - - for node in traced.graph.nodes: - if node.op == "call_function" and node.target == torch.add: - node.meta["my_key"] = 7 - break - else: - self.fail("Didn't find call_function torch.add") - - input = torch.randn(2, 3) - ShapeProp(traced).propagate(input) - traced = NormalizeArgs(traced).transform() - - for node in traced.graph.nodes: - if node.op == "call_function" and node.target == torch.add: - self.assertTrue("my_key" in node.meta) - self.assertEqual(node.meta["my_key"], 7) - break - else: - self.fail("Didn't find call_function torch.add") - - def test_normalize_args_perserve_type(self): - class MyModule(torch.nn.Module): - def forward(self, a: List[torch.Tensor]): - return torch.add(a[0], a[1]) - - m = MyModule() - traced = symbolic_trace(m) - traced = NormalizeArgs(traced).transform() - - for node in traced.graph.nodes: - if node.op == "placeholder": - self.assertEqual(node.type, List[torch.Tensor]) - - @skipIfNoTorchVision - def test_annotate_returns_with_schema(self): - m = resnet18() - - traced_modules = symbolic_trace(m) - traced_modules_annotated = AnnotateTypesWithSchema( - traced_modules - ).transform() - for node in traced_modules_annotated.graph.nodes: - if node.type is None: - check = (node.op, node.target) - self.assertIn( - check, - { - ("placeholder", "x"), - ("call_module", "maxpool"), - ("call_function", operator.add), - ("call_function", torch.flatten), - ("output", "output"), - }, - ) - - # Smoke test torchscript compilation since now we're emitting type annotations - torch.jit.script(traced_modules_annotated) - - class FunctionalTracer(pippy.fx.Tracer): - def is_leaf_module( - self, m: torch.nn.Module, module_qualified_name: str - ) -> bool: - # `leaves` contains the set of standard `nn.Modules` that are not - # currently symbolically traceable. Ideally this set would be empty - leaves = set([torch.nn.BatchNorm2d]) - return type(m) in leaves - - traced_functionals = pippy.fx.GraphModule( - m, FunctionalTracer().trace(m) - ) - - traced_functionals_annotated = AnnotateTypesWithSchema( - traced_functionals - ).transform() - for node in traced_functionals_annotated.graph.nodes: - if node.type is None: - check = (node.op, node.target) - excluded_nodes = { - ("placeholder", "x"), - # Return type differs based on boolean dispatch :( - ("call_function", torch.nn.functional.max_pool2d), - ("output", "output"), - } - # AnnotateTypesWithSchema doesn't work with bound C++ functions - if not isinstance(node.target, BuiltinFunctionType): - self.assertIn(check, excluded_nodes) - - # Smoke test torchscript compilation since now we're emitting type annotations - torch.jit.script(traced_functionals_annotated) - - def test_subgraph_uniquename(self): - class MyModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(4, 4) - - def forward(self, a, b, c, d): - add_1 = a + b - add_2 = add_1 + c - linear_1 = self.linear(add_1) - add_3 = add_2 + d - add_4 = add_2 + linear_1 - add_5 = add_3 + add_4 - return add_5 - - a, b, c, d = torch.ones(4), torch.ones(4), torch.ones(4), torch.ones(4) - mm = MyModule() - traced = symbolic_trace(mm) - - def split_cb(node: pippy.fx.Node): - if node.name == "a" or node.name == "b" or node.name == "add": - return 0 - else: - return 1 - - module_with_submodule = split_module(traced, mm, split_cb) - self.assertEqual(module_with_submodule(a, b, c, d), traced(a, b, c, d)) - - def test_split_qualname_mapping(self): - d_hid = 4 - - class ExampleCode(torch.nn.Module): - def __init__(self): - super().__init__() - self.mm_param = torch.nn.Parameter(torch.randn(d_hid, d_hid)) - self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) - self.lin = torch.nn.Linear(d_hid, d_hid) - - def forward(self, x): - x = torch.mm(x, self.mm_param) - x = torch.relu(x) - x = torch.mm(x, self.mm_param) - x = self.lin(x) - x = torch.relu(x) - x = torch.mm(x, self.mm_param2) - x = self.lin(x) - return x - - my_module = ExampleCode() - my_module_traced = symbolic_trace(my_module) - - part_idx = 0 - - def split_callback(n: pippy.fx.Node): - nonlocal part_idx - if (n.op, n.target) == ("call_module", "lin"): - part_idx += 1 - return part_idx - - # split module in module with submodules - qualname_map: Dict[str, str] = {} - module_with_submodules = split_module( - my_module_traced, my_module, split_callback, qualname_map - ) - expected_qualname_map = {"submod_1.lin": "lin", "submod_2.lin": "lin"} - self.assertEqual(qualname_map, expected_qualname_map) - - def test_traceable_function_with_nonstandard_name(self): - def foo(x): - return torch.relu(x) - - traced = symbolic_trace_with_rewrite(foo) - - def test_to_folder(self): - class Test(torch.nn.Module): - def __init__(self): - super(Test, self).__init__() - self.W = torch.nn.Parameter(torch.randn(2)) - self.seq = torch.nn.Sequential(torch.nn.BatchNorm1d(2, 2)) - self.linear = torch.nn.Linear(2, 2) - self.attr = torch.randn(2) - self.register_buffer("attr2", torch.randn(2)) - self.register_buffer("attr3", torch.ones(2, dtype=torch.int32)) - - def forward(self, x): - return self.linear( - self.seq(self.W + self.attr + self.attr2 + self.attr3 + x) - ) - - mod = symbolic_trace(Test()) - module_name = "Foo" - import tempfile - from pathlib import Path - - with tempfile.TemporaryDirectory() as tmp_dir: - tmp_dir = Path(tmp_dir) - mod.to_folder(tmp_dir, module_name) - # Recipe taken from here: - # https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly - import importlib.util - - spec = importlib.util.spec_from_file_location( - module_name, tmp_dir / "__init__.py" - ) - module = importlib.util.module_from_spec(spec) - sys.modules[module_name] = module - spec.loader.exec_module(module) - t = torch.randn(2, 2) - self.assertEqual(module.Foo()(t), mod(t)) - - def test_fetch(self): - attrs_for_lowering: Dict[str, List[str]] = { - "torch.nn.modules.conv.Conv2d": [ - "weight", - "bias", - "kernel_size", - "stride", - "padding", - "dilation", - "groups", - "padding_mode", - ], - "torch.nn.modules.batchnorm.BatchNorm2d": [ - "weight", - "bias", - "running_mean", - "running_var", - "eps", - ], - } - - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv2d(3, 3, 2) - self.bn = torch.nn.BatchNorm2d(3) - - def forward(self, a): - a = self.conv(a) - a += a - return self.bn(a) - - mod = TestModule() - traced = symbolic_trace(mod) - lift_lowering_attrs_to_nodes(traced) - - for node in traced.graph.nodes: - if node.op == "call_module": - assert hasattr(node, "attrs_for_lowering") - para_list = attrs_for_lowering[node.attrs_for_lowering["name"]] - - # node.attrs_for_lowering has an addition field of class name - assert len(para_list) + 1 == len(node.attrs_for_lowering) - for p_name in para_list: - assert p_name in node.attrs_for_lowering - - def test_merge_matmuls(self): - """ - A collection of test cases for pippy.fx.experimental.merge_matmul, - a graph transformation that merges matrix multiplication operations. - """ - - # Utility function for counting matmuls for test assertions. - def _count_matmuls(mod): - gm = pippy.fx.symbolic_trace(mod) - - num_matmuls = 0 - for node in gm.graph.nodes: - if node.target == torch.matmul: - num_matmuls += 1 - - return num_matmuls - - # Simple test case in which there are two matmuls of the same size to merge. - class SimpleMergeMatmulModule(torch.nn.Module): - def __init__(self, rhs): - super().__init__() - self.rhs = rhs - - def forward(self, x, y): - a = torch.matmul(x, self.rhs) - b = torch.matmul(y, self.rhs) - return a + b - - # Initialize inputs. - a = torch.randn(3, 3) - b = torch.randn(3, 3) - - # Initialize RHS for matmuls. - rhs = torch.randn(3, 4) - - # Construct SimpleMergeMatmulModule and call merge_matmul on it. - module = SimpleMergeMatmulModule(rhs) - opt_module = merge_matmul.merge_matmul(module) - - # Numerical correctness check. - before = module(a, b) - after = opt_module(a, b) - before.allclose(after) - - # Basic graph structure check; original module should have 2 matmuls - # and optimized module should have 1. - self.assertEqual(_count_matmuls(module), 2) - self.assertEqual(_count_matmuls(opt_module), 1) - - # Test case in which there are multiple matmuls of different sizes to merge. - class FiveMergeMatmulModule(torch.nn.Module): - def __init__(self, rhs): - super().__init__() - self.rhs = rhs - - def forward(self, a, b, c, d, e): - s = torch.tensor([]) - matmuls = [] - - # For some reason using a list comprehension or for-loop for this - # doesn't work. - matmuls.append(torch.matmul(a, self.rhs)) - matmuls.append(torch.matmul(b, self.rhs)) - matmuls.append(torch.matmul(c, self.rhs)) - matmuls.append(torch.matmul(d, self.rhs)) - matmuls.append(torch.matmul(e, self.rhs)) - - for m in matmuls: - s += torch.sum(m) - - return s - - # Initialize inputs. - inputs = [torch.randn(2 * i + 1, 5) for i in range(5)] - - # Initialize RHS. - rhs = torch.randn(5, 4) - - # Construct FiveMergeMatmulModule and call merge_matmul on it. - module = FiveMergeMatmulModule(rhs) - opt_module = merge_matmul.merge_matmul(module) - - # Numerical correctness check. - before = module(*inputs) - after = opt_module(*inputs) - before.allclose(after) - - # Basic graph structure check; original module should have len(inputs) matmuls - # and optimized module should have 1. - self.assertEqual(_count_matmuls(module), len(inputs)) - self.assertEqual(_count_matmuls(opt_module), 1) - - # Simple test case in which two matmuls cannot be merged due to a data dependency between - # the LHS operands. - class UnmergeableMatmulModule(torch.nn.Module): - def __init__(self, rhs): - super().__init__() - self.rhs = rhs - - def forward(self, x): - a = torch.matmul(x, self.rhs) - a_abs = torch.abs(a) - b = torch.matmul(a_abs.transpose(1, 0), self.rhs) - return b - - # Initialize inputs. - a = torch.randn(3, 3) - - # Initialize RHS for matmuls. - rhs = torch.randn(3, 4) - - # Construct UnmergeableMatmulModule and call merge_matmul on it. - module = UnmergeableMatmulModule(rhs) - opt_module = merge_matmul.merge_matmul(module) - - # Numerical correctness check. - before = module(a) - after = opt_module(a) - before.allclose(after) - - # Basic graph structure check; the number of matrix multiplcations should not have changed. - self.assertEqual(_count_matmuls(module), 2) - self.assertEqual(_count_matmuls(opt_module), 2) - - def test_type_matches(self): - should_be_equal = [ - (int, type(5)), - (numbers.Number, type(5)), - (numbers.Number, type(5.0)), - (int, type(torch.float)), - (Union[int, float], type(5)), - (Union[int, float], type(5.0)), - (List[int], type(5)), - (List[int], create_type_hint([int, int])), - (List[int], create_type_hint((int, int))), - ( - List[torch.Tensor], - create_type_hint([torch.Tensor, torch.Tensor]), - ), - ( - List[torch.Tensor], - create_type_hint([torch.nn.Parameter, torch.nn.Parameter]), - ), - (torch.Tensor, torch.nn.Parameter), - ( - List[torch.Tensor], - create_type_hint([torch.nn.Parameter, torch.Tensor]), - ), - ( - List[torch.Tensor], - create_type_hint([torch.Tensor, torch.nn.Parameter]), - ), - ( - List[torch.Tensor], - create_type_hint((torch.Tensor, torch.Tensor)), - ), - ( - List[torch.Tensor], - create_type_hint((torch.nn.Parameter, torch.nn.Parameter)), - ), - (torch.Tensor, torch.nn.Parameter), - ( - List[torch.Tensor], - create_type_hint((torch.nn.Parameter, torch.Tensor)), - ), - ( - List[torch.Tensor], - create_type_hint((torch.Tensor, torch.nn.Parameter)), - ), - (Optional[List[torch.Tensor]], List[torch.Tensor]), - (Optional[List[int]], List[int]), - ] - for sig_type, arg_type in should_be_equal: - self.assertTrue(type_matches(sig_type, arg_type)) - - should_fail = [ - (int, float), - (Union[int, float], str), - (List[torch.Tensor], List[int]), - ] - - for sig_type, arg_type in should_fail: - self.assertFalse(type_matches(sig_type, arg_type)) - - @skipIfNoMkldnn - def test_optimize_for_inference_cpu(self): - import torch.nn as nn - - class Foo(nn.Module): - def __init__(self): - super().__init__() - layers = [] - layers2 = [] - for _ in range(10): - layers.append(nn.Conv2d(3, 3, 1)) - layers.append(nn.BatchNorm2d(3)) - layers.append(nn.ReLU()) - - layers2.append(nn.Conv2d(3, 3, 1)) - layers2.append(nn.BatchNorm2d(3)) - layers2.append(nn.ReLU()) - self.model = nn.Sequential(*layers) - self.model2 = nn.Sequential(*layers2) - - def forward(self, x): - return self.model(x) + self.model2(x) - - (N, C, H, W) = ( - 1, - 3, - 224, - 224, - ) - inp = torch.randn(N, C, H, W) - with torch.no_grad(): - model = Foo().eval() - optimized_model = optimization.optimize_for_inference(model) - torch.testing.assert_close(model(inp), optimized_model(inp)) - - optimized_model2 = optimization.optimize_for_inference( - model, pass_config={"remove_dropout": False} - ) - torch.testing.assert_close(model(inp), optimized_model2(inp)) - - @skipIfNoTorchVision - @skipIfNoMkldnn - def test_optimize_for_inference_cpu_torchvision(self): - models = [ - torchvision.models.resnet18, - torchvision.models.resnet50, - torchvision.models.densenet121, - torchvision.models.shufflenet_v2_x1_0, - torchvision.models.vgg16, - torchvision.models.mobilenet_v2, - torchvision.models.mnasnet1_0, - torchvision.models.resnext50_32x4d, - ] - with torch.no_grad(): - for model_type in models: - model = model_type() - (C, H, W) = ( - 3, - 224, - 224, - ) - inp = torch.randn(3, C, H, W) - model(inp) - model.eval() - inp = torch.randn(1, C, H, W) - heuristic = optimization.gen_mkl_autotuner( - inp, iters=0, warmup=0 - ) - optimized_model = optimization.optimize_for_inference(model) - - orig_out = model(inp) - new_out = optimized_model(inp) - torch.testing.assert_close(orig_out, new_out) - - -class TestNormalizeOperators(JitTestCase): - @onlyCPU - @ops(op_db, allowed_dtypes=(torch.float,)) - def test_normalize_operator_exhaustive(self, device, dtype, op): - # These ops currently don't trace in FX for various reasons (i.e. they take a list of tensors) - fx_fail = { - "cat", - "stack", - "hstack", - "vstack", - "dstack", - "linalg.multi_dot", - } - sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False) - if isinstance(op.op, torch._ops.OpOverload): - self.skipTest("normalize operator doesn't work on torch.ops") - for sample_input in sample_inputs_itr: - unsupported_arg_type = False - arg_values = [sample_input.input] + list(sample_input.args) - kwarg_values = sample_input.kwargs - arg_types = [] - kwarg_types = {} - - def jit_infer_type(v): - inferred_arg_type = torch._C._jit_try_infer_type(v) - assert inferred_arg_type.success() - t = _torchscript_type_to_python_type(inferred_arg_type.type()) - return t - - for v in arg_values: - if isinstance(v, torch.Tensor): - arg_types.append(type(v)) - else: - if isinstance(v, complex): - # Complex type not supported in FX - unsupported_arg_type = True - arg_types.append(jit_infer_type(v)) - - for k, v in kwarg_values.items(): - if isinstance(v, torch.Tensor): - kwarg_types[k] = type(v) - else: - if isinstance(v, complex): - # Complex type not supported in FX - unsupported_arg_type = True - kwarg_types[k] = jit_infer_type(v) - - if unsupported_arg_type: - continue - # Test normalize_function by itself - ref_out = op.op(*arg_values, **kwarg_values) - norm_args_and_kwargs = normalize_function( - op.op, arg_values, kwarg_values, arg_types, kwarg_types - ) - if norm_args_and_kwargs is None: - raise RuntimeError( - """ - FX failed to normalize op - add the op to the op_skip list. - A common reason is if your OpInfo was implemented with a lambda - - otherwise, file an issue - """ - ) - test_out = op.op( - *norm_args_and_kwargs.args, **norm_args_and_kwargs.kwargs - ) - self.assertEqual(test_out, ref_out) - - # Test normalized_arguments as part of FX - if op.name in fx_fail: - continue - param_names = [] - param_values = [] - fx_args = [] - for idx, v in enumerate(arg_values): - if isinstance(v, torch.Tensor): - param_names.append(f"arg_{idx}") - param_values.append(v) - fx_args.append(param_names[-1]) - else: - fx_args.append(f"{repr(v)}") - - for k, v in kwarg_values.items(): - if isinstance(v, torch.Tensor): - param_names.append(k) - param_values.append(v) - fx_args.append(f"{k} = {k}") - else: - fx_args.append(f"{k} = {repr(v)}") - - code = f""" -class TestModule(torch.nn.Module): - def forward(self, {', '.join(param_names)}): - return torch.{op.name}({', '.join(fx_args)}) - """ - - g = {"torch": torch, "inf": math.inf} - exec(code, g) - TestModule = g["TestModule"] - - m = TestModule() - traced = pippy.fx.symbolic_trace(m) - ref_out = traced(*param_values) - - for node in traced.graph.nodes: - if node.op == "call_function": - normalized_args = node.normalized_arguments( - traced, arg_types, kwarg_types - ) - assert normalized_args - node.args = normalized_args.args - node.kwargs = normalized_args.kwargs - traced.recompile() - - test_out = traced(*param_values) - self.assertEqual(test_out, ref_out) - - def test_normalize_quantized_eb(self): - target = torch.ops.quantized.embedding_bag_byte_rowwise_offsets - args = ( - torch.empty((2, 3), dtype=torch.uint8), - torch.empty((2,), dtype=torch.int64), - torch.empty((2,), dtype=torch.int64), - ) - norm_args_and_kwargs = normalize_function( - target, args, normalize_to_only_use_kwargs=True - ) - self.assertTrue(norm_args_and_kwargs is not None) - self.assertEqual( - set(norm_args_and_kwargs.kwargs.keys()), - { - "weight", - "indices", - "offsets", - "scale_grad_by_freq", - "mode", - "pruned_weights", - "per_sample_weights", - "compressed_indices_mapping", - "include_last_offset", - }, - ) - self.assertEqual(norm_args_and_kwargs.args, tuple()) - - def test_normalize_args_op_overload(self): - for target in [ - torch.ops.aten.resize_as_.default, - torch.ops.aten.resize_as_, - ]: - inp1 = torch.rand([1]) - inp2 = torch.rand([4]) - args, kwargs = normalize_function( - target, - (inp1,), - {"the_template": inp2}, - normalize_to_only_use_kwargs=True, - ) - self.assertIs(kwargs["input"], inp1) - self.assertIs(kwargs["the_template"], inp2) - - -instantiate_device_type_tests(TestNormalizeOperators, globals()) - -if __name__ == "__main__": - run_tests() diff --git a/test/test_pipe.py b/test/test_pipe.py new file mode 100644 index 000000000..6259953e7 --- /dev/null +++ b/test/test_pipe.py @@ -0,0 +1,113 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +import argparse +import unittest + +import torch + +from pippy.IR import Pipe, pipe_split + + +d_hid = 512 +batch_size = 256 + +torch.manual_seed(0) + + +# Basic example +class ExampleCode(torch.nn.Module): + def __init__(self): + super().__init__() + self.mm_param = torch.nn.Parameter(torch.randn(d_hid, d_hid)) + self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) + self.lin = torch.nn.Linear(d_hid, d_hid) + + def forward(self, x, y): + x = torch.mm(x, self.mm_param) + skip_connection = x + x = x + y + x = torch.relu(x) + pipe_split() + x = torch.mm(x, self.mm_param) + x = self.lin(x) + pipe_split() + x = torch.relu(x) + x = x + skip_connection + x = torch.mm(x, self.mm_param2) + pipe_split() + x = self.lin(x) + x = torch.relu(x) + return x + + +# MLP example +class MLPModule(torch.nn.Module): + def __init__(self, d_hid): + super(MLPModule, self).__init__() + self.net1 = torch.nn.Linear(d_hid, d_hid) + self.relu = torch.nn.ReLU() + self.net2 = torch.nn.Linear(d_hid, d_hid) + + def forward(self, x): + x = self.net1(x) + x = self.relu(x) + x = self.net2(x) + return x + + +class MultiMLP(torch.nn.Module): + def __init__(self): + super().__init__() + self.mlp0 = MLPModule(d_hid) + self.mlp1 = MLPModule(d_hid) + self.mlp2 = MLPModule(d_hid) + self.mlp3 = MLPModule(d_hid) + + def forward(self, x, y): + x = self.mlp0(x) + pipe_split() + x = self.mlp1(x) + pipe_split() + x = self.mlp2(x) + pipe_split() + x = self.mlp3(x) + return x - y + + +def run_worker(args, model_class): + mod = model_class() + x = torch.randn(batch_size, d_hid) + y = torch.randn(batch_size, d_hid) + + pipe = Pipe.from_tracing( + mod, + args.chunks, + example_args=(x, y), + ) + + ref_out = mod(x, y) + out = pipe(x, y)[0] + torch.testing.assert_close(out, ref_out) + print(f"equivalence test passed {torch.sum(out)} ref {torch.sum(ref_out)}") + + +def main(args=None): + parser = argparse.ArgumentParser() + parser.add_argument( + "--chunks", + type=int, + default=4, + ) + args = parser.parse_args(args) + + for model_class in [ExampleCode, MultiMLP]: + print("Testing ", model_class.__name__) + run_worker(args, model_class) + + +if __name__ == "__main__": + main() + + +class TestPipe(unittest.TestCase): + def test_pipe(self): + main(args) diff --git a/test/test_pipe_bwd.py b/test/test_pipe_bwd.py new file mode 100644 index 000000000..e9657621d --- /dev/null +++ b/test/test_pipe_bwd.py @@ -0,0 +1,123 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +import argparse +import unittest + +import torch +from pippy.IR import Pipe, pipe_split + +from pippy.microbatch import sum_reducer, TensorChunkSpec + + +d_hid = 512 +batch_size = 256 + +torch.manual_seed(0) + + +# Basic example +class ExampleCode(torch.nn.Module): + def __init__(self): + super().__init__() + self.mm_param = torch.nn.Parameter(torch.randn(d_hid, d_hid)) + self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) + self.lin = torch.nn.Linear(d_hid, d_hid) + self.mse_loss = torch.nn.MSELoss(reduction="sum") + + def forward(self, x, y): + x = torch.mm(x, self.mm_param) + skip_connection = x + x = torch.relu(x) + pipe_split() + x = torch.mm(x, self.mm_param) + x = self.lin(x) + pipe_split() + x = torch.relu(x) + x = x + skip_connection + x = torch.mm(x, self.mm_param2) + pipe_split() + x = self.lin(x) + logits = torch.relu(x) + loss = self.mse_loss(x, y) + return logits, loss + + +# MLP example +class MLPModule(torch.nn.Module): + def __init__(self, d_hid): + super(MLPModule, self).__init__() + self.net1 = torch.nn.Linear(d_hid, d_hid) + self.relu = torch.nn.ReLU() + self.net2 = torch.nn.Linear(d_hid, d_hid) + + def forward(self, x): + x = self.net1(x) + x = self.relu(x) + x = self.net2(x) + return x + + +class MultiMLP(torch.nn.Module): + def __init__(self): + super().__init__() + self.mlp0 = MLPModule(d_hid) + self.mlp1 = MLPModule(d_hid) + self.mlp2 = MLPModule(d_hid) + self.mlp3 = MLPModule(d_hid) + self.mse_loss = torch.nn.MSELoss(reduction="sum") + + def forward(self, x, y): + x = self.mlp0(x) + pipe_split() + x = self.mlp1(x) + pipe_split() + x = self.mlp2(x) + pipe_split() + x = self.mlp3(x) + loss = self.mse_loss(x, y) + return x, loss + + +def run_worker(args, model_class): + mod = model_class() + x = torch.randn(batch_size, d_hid) + y = torch.randn(batch_size, d_hid) + + output_chunk_spec = ( + TensorChunkSpec(0), # logits + sum_reducer, # loss + ) + + pipe = Pipe.from_tracing( + mod, + args.chunks, + example_args=(x, y), + output_chunk_spec=output_chunk_spec, + ) + + ref_out = mod(x, y) + out = pipe(x, y) + torch.testing.assert_close(out, ref_out) + print(f"equivalence test passed loss={out[1]} ref_loss={ref_out[1]}") + + +def main(args=None): + parser = argparse.ArgumentParser() + parser.add_argument( + "--chunks", + type=int, + default=4, + ) + args = parser.parse_args(args) + + for model_class in [ExampleCode, MultiMLP]: + print("Testing ", model_class.__name__) + run_worker(args, model_class) + + +if __name__ == "__main__": + main() + + +class TestPipeBwd(unittest.TestCase): + def test_pipe_bwd(self): + main(args) From 2146783ad84d9418e2de36308af3fb114d750537 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Wed, 29 Nov 2023 17:36:23 -0500 Subject: [PATCH 45/96] Use module based logger (#877) ## Description Use module based logger to enjoy automatic logging hierarchy. --- pippy/IR.py | 2 +- pippy/LoadModule.py | 4 +++- pippy/ModelSplit.py | 7 +++++-- pippy/SaveModule.py | 5 ++++- pippy/debug.py | 22 ++++++++++------------ pippy/hf/utils.py | 2 +- 6 files changed, 24 insertions(+), 18 deletions(-) diff --git a/pippy/IR.py b/pippy/IR.py index 62333c560..bb98d2827 100644 --- a/pippy/IR.py +++ b/pippy/IR.py @@ -1017,7 +1017,7 @@ def _trace_with_export( Callable[[torch.fx.GraphModule], torch.fx.GraphModule] ] = None, ) -> torch.fx.GraphModule: - logger.info("[PiPPy] Tracing model ...") + logger.info("Tracing model ...") try: torch._dynamo.allow_in_graph(pipe_split) traced: torch.fx.GraphModule = torch._export._export_to_torch_ir( diff --git a/pippy/LoadModule.py b/pippy/LoadModule.py index 95b4d3c84..bac0295b5 100644 --- a/pippy/LoadModule.py +++ b/pippy/LoadModule.py @@ -11,6 +11,8 @@ from pippy.utils import _get_binary_filename +logger = logging.getLogger(__name__) + TYPICAL_PREFIXES = [ "model", # facebook/opt-6.7b "transformer", # bigscience/bloom-7b1 @@ -58,7 +60,7 @@ def load_checkpoint( used_files = file_to_weights.keys() import time - logging.info( + logger.info( f"Timestamp {time.time():.2f} " f"Opening checkpoint: {used_files}" ) diff --git a/pippy/ModelSplit.py b/pippy/ModelSplit.py index 1fd5e32fe..f52353b40 100644 --- a/pippy/ModelSplit.py +++ b/pippy/ModelSplit.py @@ -7,6 +7,9 @@ from pippy.IR import pipe_split + +logger = logging.getLogger(__name__) + """ Analyze size of parameters/buffers used by each node in the graph Here node can be a `call_function` or a `call_module` @@ -45,7 +48,7 @@ def _analyze_node_size( node_param_sizes.setdefault(node, mod_param_sizes) for node, param_sizes in node_param_sizes.items(): - logging.debug(f"{node} has params: {param_sizes}") + logger.debug(f"{node} has params: {param_sizes}") return node_param_sizes @@ -185,7 +188,7 @@ def _split_into_nstages_equal_size( total_size = param_size + buffer_size per_stage_size = total_size // nstages - logging.debug( + logger.debug( f"Total model size: {total_size}, " f"per stage size: {per_stage_size}" ) diff --git a/pippy/SaveModule.py b/pippy/SaveModule.py index 966cc0b4a..158234b37 100644 --- a/pippy/SaveModule.py +++ b/pippy/SaveModule.py @@ -13,6 +13,9 @@ from pippy.IR import Pipe from pippy.utils import _get_binary_filename + +logger = logging.getLogger(__name__) + CKPT_INDEX_JSON_FILENAME = "pytorch_model.bin.index.json" DTYPE_SIZES = { @@ -105,7 +108,7 @@ def _save_index( # write index file atomically to avoid partial/corrupted writes _atomic_write(json_str, filepath) - logging.info(f"Saved index file to {filepath}") + logger.info(f"Saved index file to {filepath}") def _save_params(submod: torch.nn.Module, checkpoint_dir: str) -> None: diff --git a/pippy/debug.py b/pippy/debug.py index 4e96cf7d1..a1f95a35c 100644 --- a/pippy/debug.py +++ b/pippy/debug.py @@ -5,18 +5,16 @@ import torch -PIPPY_VERBOSITY = os.environ.get("PIPPY_VERBOSITY", "OFF") - -if PIPPY_VERBOSITY == "DEBUG": - logging.getLogger("pippy").setLevel(logging.DEBUG) -elif PIPPY_VERBOSITY == "INFO": - logging.getLogger("pippy").setLevel(logging.INFO) -elif PIPPY_VERBOSITY == "OFF": - pass -else: - print(f"[PiPPy] Unsupported PIPPY_VERBOSITY level: {PIPPY_VERBOSITY}") - -print(f"[PiPPy] Setting logging level to: {PIPPY_VERBOSITY}") +PIPPY_VERBOSITY = os.environ.get("PIPPY_VERBOSITY") +if PIPPY_VERBOSITY not in [None, "WARNING", "INFO", "DEBUG"]: + logging.warning(f"Unsupported PIPPY_VERBOSITY level: {PIPPY_VERBOSITY}") + PIPPY_VERBOSITY = None + +if PIPPY_VERBOSITY: + logging.getLogger("pippy").setLevel(PIPPY_VERBOSITY) + # It seems we need to print something to make the level setting effective + # for child loggers. Doing it here. + logging.warning(f"Setting PiPPy logging level to: {PIPPY_VERBOSITY}") def friendly_debug_info(v): diff --git a/pippy/hf/utils.py b/pippy/hf/utils.py index aa5072d79..7a8053f78 100644 --- a/pippy/hf/utils.py +++ b/pippy/hf/utils.py @@ -305,7 +305,7 @@ def inject_pipeline_forward( model: torch.nn.Module, pipe_driver: PipelineDriverBase, ): - logging.info( + logger.info( f"Inserting PiPPy pipeline forward into model {model._get_name()}" ) # Inject pipeline driver as a member object of original model From 92038fb2b624e439bce4530c466198c64b8276e6 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Thu, 30 Nov 2023 16:57:32 -0500 Subject: [PATCH 46/96] Migrate GPT-2 to new tracer (#875) ## Description Migrated GPT-2 example to work with new tracer based pippy. examples/hf/hf_utils.py contains utility to generate inputs for HuggingFace models. Model architecture: ``` GPT2ForSequenceClassification( (transformer): GPT2Model( (wte): Embedding(50257, 768) (wpe): Embedding(1024, 768) (drop): Dropout(p=0.1, inplace=False) (h): ModuleList( (0-11): 12 x GPT2Block( (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True) (attn): GPT2Attention( (c_attn): Conv1D() (c_proj): Conv1D() (attn_dropout): Dropout(p=0.1, inplace=False) (resid_dropout): Dropout(p=0.1, inplace=False) ) (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True) (mlp): GPT2MLP( (c_fc): Conv1D() (c_proj): Conv1D() (act): NewGELUActivation() (dropout): Dropout(p=0.1, inplace=False) ) ) ) (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True) ) (score): Linear(in_features=768, out_features=2, bias=False) ) ``` ## Run ``` $ torchrun --nproc-per-node 4 pippy_gpt2.py ``` ## Output https://gist.github.com/kwen2501/b9ed6158d8d0dc90b16824aa6abd8d72 --- examples/hf/gpt2/README.md | 49 ------ examples/hf/gpt2/pippy_gpt2.py | 211 ------------------------- examples/hf/gpt2/pippy_sbatch.sh | 20 --- examples/hf/gpt2/pippy_wrapper.sh | 11 -- examples/hf/hf_utils.py | 250 ++++++++++++++++++++++++++++++ examples/hf/pippy_gpt2.py | 123 +++++++++++++++ 6 files changed, 373 insertions(+), 291 deletions(-) delete mode 100644 examples/hf/gpt2/README.md delete mode 100644 examples/hf/gpt2/pippy_gpt2.py delete mode 100755 examples/hf/gpt2/pippy_sbatch.sh delete mode 100755 examples/hf/gpt2/pippy_wrapper.sh create mode 100755 examples/hf/hf_utils.py create mode 100644 examples/hf/pippy_gpt2.py diff --git a/examples/hf/gpt2/README.md b/examples/hf/gpt2/README.md deleted file mode 100644 index 96137a994..000000000 --- a/examples/hf/gpt2/README.md +++ /dev/null @@ -1,49 +0,0 @@ -# PiPPy CUDA RPC Demo - -Requires 2 nodes with 8 GPUs on slurm partition `train`. - -Splits Huggingface GPT2 model into 14 submodules and runs it on 2 nodes with 8 GPUs each(16 workers total, 2 unused). - -Run command: -```commandline -sbatch pippy_sbatch.sh -``` -Sample output: -``` -10: rank = 10 host/pid = train-st-p3dn24xlarge-3/85935 - 8: rank = 8 host/pid = train-st-p3dn24xlarge-3/85939 -14: rank = 14 host/pid = train-st-p3dn24xlarge-3/85937 -13: rank = 13 host/pid = train-st-p3dn24xlarge-3/85932 -11: rank = 11 host/pid = train-st-p3dn24xlarge-3/85940 - 0: rank = 0 host/pid = train-st-p3dn24xlarge-2/66387 - 9: rank = 9 host/pid = train-st-p3dn24xlarge-3/85934 - 7: rank = 7 host/pid = train-st-p3dn24xlarge-2/66384 -12: rank = 12 host/pid = train-st-p3dn24xlarge-3/85933 - 2: rank = 2 host/pid = train-st-p3dn24xlarge-2/66386 - 6: rank = 6 host/pid = train-st-p3dn24xlarge-2/66385 -15: rank = 15 host/pid = train-st-p3dn24xlarge-3/85938 - 4: rank = 4 host/pid = train-st-p3dn24xlarge-2/66391 - 5: rank = 5 host/pid = train-st-p3dn24xlarge-2/66389 - 1: rank = 1 host/pid = train-st-p3dn24xlarge-2/66388 - 3: rank = 3 host/pid = train-st-p3dn24xlarge-2/66390 - 0: REPLICATE config: False -> MultiUseParameterConfig.TRANSMIT - 0: Instantiating GPT2 Pipeline - 0: /fsx/users/pbelevich/miniconda/envs/PiPPy39/lib/python3.9/site-packages/transformers/modeling_utils.py:2327: UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. (Triggered internally at ../torch/csrc/utils/python_arg_parser.cpp:225.) - 0: x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight) - 0: /fsx/users/pbelevich/miniconda/envs/PiPPy39/lib/python3.9/site-packages/transformers/models/gpt2/modeling_gpt2.py:193: UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. (Triggered internally at ../torch/csrc/utils/python_arg_parser.cpp:225.) - 0: attn_weights = torch.matmul(query, key.transpose(-1, -2)) - 0: /fsx/users/pbelevich/miniconda/envs/PiPPy39/lib/python3.9/site-packages/transformers/models/gpt2/modeling_gpt2.py:206: UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. (Triggered internally at ../torch/csrc/utils/python_arg_parser.cpp:225.) - 0: attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype)) - 0: /fsx/users/pbelevich/miniconda/envs/PiPPy39/lib/python3.9/site-packages/transformers/models/gpt2/modeling_gpt2.py:222: UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. (Triggered internally at ../torch/csrc/utils/python_arg_parser.cpp:225.) - 0: attn_output = torch.matmul(attn_weights, value) - 0: /fsx/users/pbelevich/miniconda/envs/PiPPy39/lib/python3.9/site-packages/transformers/activations.py:34: UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. (Triggered internally at ../torch/csrc/utils/python_arg_parser.cpp:225.) - 0: return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0)))) - 0: /fsx/users/pbelevich/PiPPy/pippy/PipelineDriver.py:471: UserWarning: Running pipeline with 14 stages on world_size of 16. Remaining ranks will be idle. - 0: warnings.warn(f'Running pipeline with {len(executor_descriptors)} stages on world_size of {self.world_size}. ' - 0: Running GPT2 pipeline. NB: if this is too slow, set OMP_NUM_THREADS to a higher value - 0: Running reference pipeline - 0: /fsx/users/pbelevich/miniconda/envs/PiPPy39/lib/python3.9/site-packages/torch/testing/_deprecated.py:35: FutureWarning: torch.testing.assert_allclose() is deprecated since 1.12 and will be removed in 1.14. Use torch.testing.assert_close() instead. For detailed upgrade instructions see https://github.com/pytorch/pytorch/issues/61844. - 0: warnings.warn(msg, FutureWarning) - 0: equivalence test passed -0.00017547607421875 ref -0.00017547607421875 - 0: profiling run completed -0.00072479248046875 ref -0.00017547607421875 -``` diff --git a/examples/hf/gpt2/pippy_gpt2.py b/examples/hf/gpt2/pippy_gpt2.py deleted file mode 100644 index 67dc2fd75..000000000 --- a/examples/hf/gpt2/pippy_gpt2.py +++ /dev/null @@ -1,211 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import argparse -import inspect -import os -import time -from functools import reduce - -import torch -from transformers import GPT2LMHeadModel, GPT2Config - -import pippy.fx -from pippy import run_pippy -from pippy.IR import MultiUseParameterConfig, Pipe, PipeSplitWrapper, annotate_split_points -from pippy.PipelineDriver import PipelineDriverFillDrain, PipelineDriver1F1B, PipelineDriverInterleaved1F1B, \ - PipelineDriverBase -from pippy.events import EventsContext -from pippy.hf import PiPPyHFTracer -from pippy.microbatch import sum_reducer, TensorChunkSpec, Replicate -from pippy.visualizer import events_to_json - -PROFILING_ENABLED = True -CHECK_NUMERIC_EQUIVALENCE = True - -schedules = { - 'FillDrain': PipelineDriverFillDrain, - '1F1B': PipelineDriver1F1B, - 'Interleaved1F1B': PipelineDriverInterleaved1F1B, -} - -pippy.fx.Tracer.proxy_buffer_attributes = True - - -def get_number_of_params(model): - return sum(p.numel() for p in model.parameters() if p.requires_grad) - - -def add_split_points(gpt2, decoders_per_rank): - for i in range(0, gpt2.config.n_layer // decoders_per_rank): - annotate_split_points(gpt2, {f'transformer.h.{i * decoders_per_rank}': PipeSplitWrapper.SplitPoint.BEGINNING}) - annotate_split_points(gpt2, {f'transformer.ln_f': PipeSplitWrapper.SplitPoint.BEGINNING}) - return gpt2.config.n_layer // decoders_per_rank + 2 - - -def calc_flop(args, conf): - # https://arxiv.org/pdf/2104.04473.pdf page 8, formula 3 - B = args.batch_size - s = args.seq_length - l = conf.n_layer - h = conf.n_embd - V = conf.vocab_size - return 96 * B * s * l * h * h * (1 + s/6/h + V/16/l/h) - - -def run_gspmd(pp_ranks, args): - print(args) - MULTI_USE_PARAM_CONFIG = MultiUseParameterConfig.REPLICATE if args.replicate else MultiUseParameterConfig.TRANSMIT - assert args.world_size >= 4, "This program requires at least 3 workers + 1 master" - - config = GPT2Config() - config.n_embd = args.n_embd or config.n_embd - config.n_layer = args.n_layer or config.n_layer - config.n_head = args.n_head or config.n_head - print("GPT-2 model instantiation started") - start = time.time() - gpt2 = GPT2LMHeadModel(config) - finish = time.time() - print(f"GPT-2 model instantiation finished in {(finish - start) / 60:1.2f} minutes") - gpt2.eval() - if args.rank == 0: - print(gpt2.config) - print(f"GPT-2 total number of params = {get_number_of_params(gpt2) // 10 ** 6}M") - print(gpt2) - - emb_head = 2 # embeddings + head - master_emb_head = 1 + emb_head # master + embeddings + head - decoders_per_rank = (gpt2.config.n_layer + (args.world_size - master_emb_head) - 1) // ( - args.world_size - master_emb_head) # a divider of gpt2.config.n_layer: [1, 2, 3, 4, 6, 12] - print(f"decoders_per_rank = {decoders_per_rank}") - number_of_workers = emb_head + gpt2.config.n_layer // decoders_per_rank # 3 + a divider of gpt2.config.n_layer: [4, 5, 6, 7, 9, 15] - print(f"number_of_workers = {number_of_workers}") - - all_worker_ranks = pp_ranks[ - pippy.utils.exclude_master : pippy.utils.exclude_master - + number_of_workers - ] - chunks = len(all_worker_ranks) - seq_length = args.seq_length - batch_size = args.batch_size * chunks - vocab_size = gpt2.config.vocab_size - - device = args.device - print("Using device:", device) - - gpt2_input_dict = { - 'input_ids': torch.empty(batch_size, seq_length, dtype=torch.long, device=device).random_(vocab_size), - 'labels': torch.empty(batch_size, seq_length, dtype=torch.long, device=device).random_(vocab_size), - 'position_ids': torch.arange(0, seq_length, dtype=torch.long, device=device)} - - sm_cnt = add_split_points(gpt2, decoders_per_rank) - assert sm_cnt == len(all_worker_ranks), f"sm_cnt = {sm_cnt} all_worker_ranks = {all_worker_ranks}" - - if args.rank == 0: - print(gpt2) - - input_names = gpt2_input_dict.keys() - sig = inspect.signature(gpt2.forward) - concrete_args = {p.name: p.default for p in sig.parameters.values() if p.name not in input_names} - - print('Instantiating GPT-2 Pipeline') - output_loss_value_spec = {'loss': True, 'logits': False, - 'past_key_values': [[False for _ in range(2)] for _ in range(12)]} - gpt2_pipe = Pipe.from_tracing(gpt2, MULTI_USE_PARAM_CONFIG, tracer=PiPPyHFTracer(), concrete_args=concrete_args, - output_loss_value_spec=output_loss_value_spec, deep_copy_module=False) - if args.rank == 0: - print(gpt2_pipe.split_gm) - - assert sm_cnt == len(list(gpt2_pipe.split_gm.children())) - - # Materialize model differently depending on run mode - if args.gspmd == 1: - print(f"Deferring stage init on device {device}") - gpt2_pipe.defer_stage_init(device) - # Make sure every rank has deferred its stage init before master creates the driver - pippy.utils.pp_group_barrier() - else: - gpt2_pipe.to(device) - - if args.rank != 0: - # Workers return here - return - - # gpt2_pipe(**gpt2_input_dict) - - for i, sm in enumerate(gpt2_pipe.split_gm.children()): - print(f"submod_{i} {get_number_of_params(sm) // 10 ** 6}M params") - - kwargs_chunk_spec = {'input_ids': TensorChunkSpec(0), 'labels': TensorChunkSpec(0), 'position_ids': Replicate} - output_chunk_spec = {'loss': sum_reducer, 'logits': TensorChunkSpec(0), - 'past_key_values': [[TensorChunkSpec(0) for _ in range(2)] for _ in range(config.n_layer)]} - pipe_driver: PipelineDriverBase = schedules[args.schedule](gpt2_pipe, chunks, - len(all_worker_ranks), - all_ranks=all_worker_ranks, - kwargs_chunk_spec = kwargs_chunk_spec, - output_chunk_spec=output_chunk_spec, - _record_mem_dumps=bool(args.record_mem_dumps), - checkpoint=bool(args.checkpoint), - ) - - this_file_name = os.path.splitext(os.path.basename(__file__))[0] - - if args.warmup_batches > 0: - print(f'Running {args.warmup_batches} warm-up batches') - for i in range(args.warmup_batches): - pipe_driver(**gpt2_input_dict) - - FLOP = calc_flop(args, config) - print(f"FLOP per iteration {FLOP}") - print(f'Running GPT-2 pipeline {args.batches} batches for TFLOP/s/GPU measurement.') - - start = time.time() - for i in range(args.batches): - pipe_driver(**gpt2_input_dict) - finish = time.time() - total_latency = finish - start - print(f"TFLOP/s/GPU: {FLOP/1e12/total_latency}") - - print(f'Running GPT-2 pipeline {args.batches} batches for visualization.') - batches_events_contexts = [] - for i in range(args.batches): - pipe_driver(**gpt2_input_dict) - batches_events_contexts.append(pipe_driver.retrieve_events()) - all_events_contexts: EventsContext = reduce(lambda c1, c2: EventsContext().update(c1).update(c2), - batches_events_contexts, EventsContext()) - pipe_visualized_filename = f"{this_file_name}_visualized.json" - with open(pipe_visualized_filename, "w") as f: - f.write(events_to_json(all_events_contexts)) - print(f"Saved {pipe_visualized_filename}") - print('Finished.') - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 16))) - parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1))) - parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost')) - parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500')) - parser.add_argument('-s', '--schedule', type=str, default=list(schedules.keys())[0], choices=schedules.keys()) - parser.add_argument('--replicate', type=int, default=int(os.getenv("REPLICATE", '0'))) - parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available())) - parser.add_argument('--exclude_master', type=int, default=0, choices=[0, 1]) - - parser.add_argument('--record_mem_dumps', type=int, default=0, choices=[0, 1]) - parser.add_argument('--checkpoint', type=int, default=0, choices=[0, 1]) - - parser.add_argument('--rpc_timeout', type=int, default=1800) - parser.add_argument('--num_worker_threads', type=int, default=512) - - parser.add_argument('--batch_size', type=int, default=1) - parser.add_argument('--warmup_batches', type=int, default=0) - parser.add_argument('--batches', type=int, default=1) - parser.add_argument('--seq_length', type=int, default=16) - - parser.add_argument('--n_embd', type=int, default=None) - parser.add_argument('--n_layer', type=int, default=None) - parser.add_argument('--n_head', type=int, default=None) - - parser.add_argument('--gspmd', type=int, default=0, choices=[0, 1]) - - args = parser.parse_args() - - run_pippy(run_gspmd, args) diff --git a/examples/hf/gpt2/pippy_sbatch.sh b/examples/hf/gpt2/pippy_sbatch.sh deleted file mode 100755 index 521c72a14..000000000 --- a/examples/hf/gpt2/pippy_sbatch.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/bash -# Copyright (c) Meta Platforms, Inc. and affiliates - -#SBATCH --job-name=gpt2_pippy - -#SBATCH --open-mode=append - -#SBATCH --partition=train - -#SBATCH --nodes=2 - -#SBATCH --ntasks-per-node=8 - -#SBATCH --cpus-per-task=12 - -#SBATCH --gpus-per-node=8 - -#SBATCH --time=1:00:00 - -srun --label pippy_wrapper.sh diff --git a/examples/hf/gpt2/pippy_wrapper.sh b/examples/hf/gpt2/pippy_wrapper.sh deleted file mode 100755 index c9d6fb2bb..000000000 --- a/examples/hf/gpt2/pippy_wrapper.sh +++ /dev/null @@ -1,11 +0,0 @@ -#!/bin/bash -# Copyright (c) Meta Platforms, Inc. and affiliates - -export MASTER_PORT=29500 -export MASTER_ADDR=$(scontrol show hostname ${SLURM_NODELIST} | head -n 1) -export LOCAL_RANK=${SLURM_LOCALID} -export CUDA_VISIBLE_DEVICES=${SLURM_LOCALID} -export WORLD_SIZE=${SLURM_NTASKS} -export RANK=${SLURM_PROCID} - -python -u pippy_gpt2.py --record_mem_dumps=0 --checkpoint=0 diff --git a/examples/hf/hf_utils.py b/examples/hf/hf_utils.py new file mode 100755 index 000000000..1e4ee7924 --- /dev/null +++ b/examples/hf/hf_utils.py @@ -0,0 +1,250 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +# This file contains code to generate inputs for all of the models in the +# support list. The `generate_inputs_for_model` function is extracted from +# pytorch/benchmarks/dynamo/huggingface.py + +#!/usr/bin/env python3 +import importlib +import logging +import subprocess +import sys + +import torch + + +log = logging.getLogger(__name__) + + +def pip_install(package): + subprocess.check_call([sys.executable, "-m", "pip", "install", package]) + + +# Disable the flake warnings for the imports. Flake8 does not provide a way to +# disable just warning for the entire file. Disabling flake8 entirely. +# flake8: noqa +imports = [ + "AlbertForPreTraining", + "AutoConfig", + "AutoModelForCausalLM", + "AutoModelForMaskedLM", + "AutoModelForSeq2SeqLM", + "BigBirdConfig", + "BlenderbotForConditionalGeneration", + "BlenderbotModel", + "BlenderbotSmallForConditionalGeneration", + "BlenderbotSmallModel", + "CLIPModel", + "CLIPVisionModel", + "ElectraForPreTraining", + "GPT2ForSequenceClassification", + "GPTJForSequenceClassification", + "GPTNeoForSequenceClassification", + "HubertForSequenceClassification", + "LxmertForPreTraining", + "LxmertForQuestionAnswering", + "MarianForCausalLM", + "MarianModel", + "MarianMTModel", + "PegasusForConditionalGeneration", + "PegasusModel", + "ReformerConfig", + "ViTForImageClassification", + "ViTForMaskedImageModeling", + "ViTModel", +] + + +try: + mod = importlib.import_module("transformers") + for cls in imports: + if not hasattr(mod, cls): + raise ModuleNotFoundError +except ModuleNotFoundError: + print("Installing HuggingFace Transformers...") + pip_install("git+https://github.com/huggingface/transformers.git#egg=transformers") +finally: + for cls in imports: + exec(f"from transformers import {cls}") + + +def get_sequence_length(model_cls, model_name): + if model_name.startswith(("Blenderbot",)): + seq_length = 128 + elif model_name.startswith(("GPT2", "Bart", "T5", "PLBart", "MBart")): + seq_length = 1024 + elif model_name in ("AllenaiLongformerBase", "BigBird"): + seq_length = 1024 + elif model_name.startswith("OPT"): + seq_length = 2048 + elif "Reformer" in model_name: + seq_length = 4096 + elif model_name.startswith( + ( + "Albert", + "Deberta", + "Layout", + "Electra", + "XLNet", + "MegatronBert", + "Bert", + "Roberta", + ) + ) or model_name in ("DistillGPT2", "GoogleFnet", "YituTechConvBert", "CamemBert"): + seq_length = 512 + elif model_name in ("TrOCRForCausalLM"): + seq_length = 256 + elif model_name.startswith("MobileBert"): + seq_length = 128 + elif model_name.startswith("Wav2Vec2"): + # If too short, will fail with something like + # ValueError: `mask_length` has to be smaller than `sequence_length`, + # but got `mask_length`: 10 and `sequence_length`: 9` + seq_length = 10000 # NB: a more realistic size is 155136 + else: + log.info( + f"Sequence Length not defined for {model_name}. Choosing 128 arbitrarily" + ) + seq_length = 128 + return seq_length + + +def generate_inputs_for_model( + model_cls, model, model_name, bs, device, include_loss_args=False +): + # TODO - Check if following values are representative + num_choices = 3 + num_visual_features = 42 + seq_length = get_sequence_length(model_cls, model_name) + vocab_size = model.config.vocab_size + + if model_name.startswith("Wav2Vec2"): + # TODO: If we add more input_values style models, try to work this + # into the overall control flow + target_length = 100 + return { + "input_values": torch.randn((bs, seq_length), device=device), + # Added because that's what the example training script has + "attention_mask": rand_int_tensor(device, 0, 2, (bs, seq_length)), + "labels": rand_int_tensor(device, 0, vocab_size, (bs, target_length)), + } + + if model_name.endswith("MultipleChoice"): + input = rand_int_tensor(device, 0, vocab_size, (bs, num_choices, seq_length)) + elif model_name.startswith("Roberta"): + input = rand_int_tensor(device, 0, 1, (bs, seq_length)) + else: + input = rand_int_tensor(device, 0, vocab_size, (bs, seq_length)) + + if "Bart" in model_name: + input[:, -1] = model.config.eos_token_id + + input_dict = {"input_ids": input} + + if ( + model_name.startswith("T5") + or model_name.startswith("M2M100") + or model_name.startswith("MT5") + or model_cls + in [ + BlenderbotModel, + BlenderbotSmallModel, + BlenderbotForConditionalGeneration, + BlenderbotSmallForConditionalGeneration, + PegasusModel, + PegasusForConditionalGeneration, + MarianModel, + MarianMTModel, + ] + ): + input_dict["decoder_input_ids"] = input + + if model_name.startswith("Lxmert"): + visual_feat_dim, visual_pos_dim = ( + model.config.visual_feat_dim, + model.config.visual_pos_dim, + ) + input_dict["visual_feats"] = torch.randn( + bs, num_visual_features, visual_feat_dim + ) + input_dict["visual_pos"] = torch.randn(bs, num_visual_features, visual_pos_dim) + + if include_loss_args: + if model_name.endswith("PreTraining"): + if model_cls in [ElectraForPreTraining, LxmertForPreTraining]: + input_dict["labels"] = rand_int_tensor(device, 0, 1, (bs, seq_length)) + else: + label_name = ( + "sentence_order_label" + if model_cls in [AlbertForPreTraining] + else "next_sentence_label" + ) + input_dict["labels"] = ( + rand_int_tensor(device, 0, vocab_size, (bs, seq_length)), + ) + input_dict[label_name] = rand_int_tensor(device, 0, 1, (bs,)) + elif model_name.endswith("QuestionAnswering"): + input_dict["start_positions"] = rand_int_tensor( + device, 0, seq_length, (bs,) + ) + input_dict["end_positions"] = rand_int_tensor(device, 0, seq_length, (bs,)) + elif ( + model_name.endswith("MaskedLM") + or model_name.endswith("HeadModel") + or model_name.endswith("CausalLM") + or model_name.endswith("DoubleHeadsModel") + ): + input_dict["labels"] = rand_int_tensor( + device, 0, vocab_size, (bs, seq_length) + ) + elif model_name.endswith("TokenClassification"): + input_dict["labels"] = rand_int_tensor( + device, 0, model.config.num_labels - 1, (bs, seq_length) + ) + elif model_name.endswith("MultipleChoice"): + input_dict["labels"] = rand_int_tensor(device, 0, num_choices, (bs,)) + elif model_name.endswith("SequenceClassification"): + input_dict["labels"] = rand_int_tensor( + device, 0, model.config.num_labels - 1, (bs,) + ) + elif model_name.endswith("NextSentencePrediction"): + input_dict["labels"] = rand_int_tensor(device, 0, 1, (bs,)) + elif model_name.endswith("ForConditionalGeneration"): + input_dict["labels"] = rand_int_tensor( + device, 0, vocab_size - 1, (bs, seq_length) + ) + elif model_name in EXTRA_MODELS: + input_dict["labels"] = rand_int_tensor( + device, 0, vocab_size, (bs, seq_length) + ) + else: + raise NotImplementedError( + f"Class {model_name} unsupported for training test " + ) + + return input_dict + + +def rand_int_tensor(device, low, high, shape): + return torch.randint( + low, + high, + shape, + device=device, + dtype=torch.int64, + requires_grad=False, + ) + + +def get_number_of_params(model): + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + +def calc_flop(args, conf): + # https://arxiv.org/pdf/2104.04473.pdf page 8, formula 3 + B = args.batch_size + s = args.seq_length + l = conf.n_layer + h = conf.n_embd + V = conf.vocab_size + return 96 * B * s * l * h * h * (1 + s/6/h + V/16/l/h) diff --git a/examples/hf/pippy_gpt2.py b/examples/hf/pippy_gpt2.py new file mode 100644 index 000000000..060212277 --- /dev/null +++ b/examples/hf/pippy_gpt2.py @@ -0,0 +1,123 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +# Minimum effort to run this example: +# $ torchrun --nproc-per-node 4 pippy_gpt2.py + +import argparse +import os + +import torch +import torch.distributed as dist + +from pippy.IR import Pipe, PipeSplitWrapper, annotate_split_points +from pippy.PipelineStage import PipelineStage + +from transformers import GPT2ForSequenceClassification, GPT2Config + +from hf_utils import generate_inputs_for_model, get_number_of_params + + +def add_split_points(gpt2, nranks): + decoders_per_rank = (gpt2.config.n_layer + nranks - 1) // nranks + print(f"decoders_per_rank = {decoders_per_rank}") + nstages = 1 + for i in range(1, gpt2.config.n_layer // decoders_per_rank): + annotate_split_points( + gpt2, + {f'transformer.h.{i * decoders_per_rank}': PipeSplitWrapper.SplitPoint.BEGINNING}, + ) + nstages += 1 + assert nstages == nranks, f"nstages = {nstages} nranks = {nranks}" + + +def run(args): + # Model configs + config = GPT2Config() + config.n_embd = args.n_embd or config.n_embd + config.n_layer = args.n_layer or config.n_layer + config.n_head = args.n_head or config.n_head + print("Using device:", args.device) + + # Create model + model_class = GPT2ForSequenceClassification + model_name = "GPT2ForSequenceClassification" + gpt2 = model_class(config) + gpt2.to(args.device) + gpt2.eval() + if args.rank == 0: + print(gpt2.config) + print(f"GPT-2 total number of params = {get_number_of_params(gpt2) // 10 ** 6}M") + print(gpt2) + + # Input configs + example_inputs = generate_inputs_for_model( + model_class, gpt2, model_name, args.batch_size, args.device) + input_ids = example_inputs["input_ids"] + + # Annotate split points + add_split_points(gpt2, args.world_size) + + # Create pipeline + gpt2_pipe = Pipe.from_tracing( + gpt2, + num_chunks=args.chunks, + example_args=(input_ids, ), + ) + assert len(list(gpt2_pipe.split_gm.children())) == args.world_size + if args.rank == 0: + for i, sm in enumerate(gpt2_pipe.split_gm.children()): + print(f"Pipeline stage {i} {get_number_of_params(sm) // 10 ** 6}M params") + + # Create schedule runtime + stage = PipelineStage( + gpt2_pipe, + args.rank, + device=args.device, + ) + + # Run + if args.rank == 0: + stage(input_ids) + elif args.rank == args.world_size - 1: + out = stage() + else: + stage() + + print(f"Rank {args.rank} completes") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 4))) + parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1))) + parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost')) + parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500')) + parser.add_argument('--schedule', type=str, default="FillDrain") + parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available())) + parser.add_argument("--chunks", type=int, default=4) + # Note: this specific example requires: 1) a batch size that is divisible by + # the number of chunks; 2) the division result (i.e. chunk size) must be 1, + # otherwise padding token must be provided too (see GPT-2's forward function) + parser.add_argument('--batch_size', type=int, default=4) + parser.add_argument('--batches', type=int, default=1) + parser.add_argument('--n_embd', type=int, default=None) + parser.add_argument('--n_layer', type=int, default=None) + parser.add_argument('--n_head', type=int, default=None) + + args = parser.parse_args() + + if args.cuda: + dev_id = args.rank % torch.cuda.device_count() + args.device = torch.device(f"cuda:{dev_id}") + else: + args.device = torch.device("cpu") + + # Init process group + backend = "nccl" if args.cuda else "gloo" + dist.init_process_group( + backend=backend, + rank=args.rank, + world_size=args.world_size, + ) + + run(args) From d4de1d76b6e973a58677b4f95e73919efb829b7f Mon Sep 17 00:00:00 2001 From: Howard Huang Date: Thu, 30 Nov 2023 17:06:45 -0500 Subject: [PATCH 47/96] Add pipeline schedules script (#876) Adding `PipelineStage` and pipeline schedules to source control under PiPPy for the OSS use. File can be run as-is with `python test_pipeline_schedule.py` or with torchrun Files: - The Schedule and PipelineStage components are in the PipelineSchedule.py file - The main in the test file (does not conform to python unit tests or run in CI yet). --- pippy/PipelineSchedule.py | 460 +++++++++++++++++++++++++++++++++ test/test_pipeline_schedule.py | 241 +++++++++++++++++ 2 files changed, 701 insertions(+) create mode 100644 pippy/PipelineSchedule.py create mode 100644 test/test_pipeline_schedule.py diff --git a/pippy/PipelineSchedule.py b/pippy/PipelineSchedule.py new file mode 100644 index 000000000..094bfdcbb --- /dev/null +++ b/pippy/PipelineSchedule.py @@ -0,0 +1,460 @@ +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +import logging +from collections import deque +from typing import Deque, List, Optional, Tuple + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.profiler import record_function + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) +handler = logging.StreamHandler() +handler.setLevel(logging.INFO) +logger.addHandler(handler) + + +class PipelineStage(nn.Module): + def __init__( + self, + module: nn.Module, + stage_id: int, + num_stages: int, + rank: int, + world_size: int, + meta_input: torch.Tensor, + device: torch.device, + ): + super().__init__() + self.rank = rank + self.stage_id = stage_id + self.is_first_stage = stage_id == 0 + self.is_last_stage = stage_id == num_stages - 1 + self.num_stages = num_stages + # When we materialize the model partition on cuda, we call reset_parameters() if it is available + self.module = module.to(device) + + meta_output = self.module(meta_input) + self.fwd_input = torch.empty_like(meta_input, device=device) + self.fwd_output = None + self.fwd_output_grads = torch.empty_like(meta_output, device=device) + self.fwd_outputs_for_backward: Deque[ + Tuple[torch.tensor, torch.tensor] + ] = deque() + + self.prev_stage = (rank - 1) % world_size + self.next_stage = (rank + 1) % world_size + + self.fwd_recv_queue = None + self.bwd_recv_queue = None + + self.requests: List[dist.P2POp] = [] + logger.info( + f"finished pipeline stage init, {self.stage_id=}, {self.is_first_stage=}, {self.is_last_stage=}, {self.num_stages=}, {self.fwd_input.shape=}, {self.fwd_output_grads.shape=}" + ) + + def init_p2p_neighbors(self): + """ + Set up p2p communitors between previous and next stages + by sending a dummy tensor. + + If this is used, must be called for all pipeline stages. + """ + ops = [] + recv_tensor = torch.zeros(1, device="cuda") + send_tensor = torch.ones(1, device="cuda") + # forward + if not self.is_first_stage: + ops.append(dist.P2POp(dist.irecv, recv_tensor, self.prev_stage)) + if not self.is_last_stage: + ops.append(dist.P2POp(dist.isend, send_tensor, self.next_stage)) + + # backward + if not self.is_first_stage: + ops.append(dist.P2POp(dist.isend, send_tensor, self.prev_stage)) + if not self.is_last_stage: + ops.append(dist.P2POp(dist.irecv, recv_tensor, self.next_stage)) + + return True + + def get_fwd_recv_ops(self) -> List[dist.P2POp]: + if self.is_first_stage: + return [] + return [dist.P2POp(dist.irecv, self.fwd_input, self.prev_stage)] + + def get_fwd_send_ops(self) -> List[dist.P2POp]: + if self.is_last_stage: + return [] + return [dist.P2POp(dist.isend, self.fwd_output, self.next_stage)] + + def forward(self, input_data, is_first_mb, is_last_mb): + logger.info( + f"[{self.rank} FORWARD {self.stage_id}] is_first_mb {is_first_mb} is_last_mb {is_last_mb}" + ) + if self.is_first_stage: + self.fwd_input = input_data + + # this is needed when we access the gradients for this in backward() + self.fwd_input.requires_grad = True + self.fwd_input.retain_grad() + + # perform forward pass on module + self.fwd_output = self.module(self.fwd_input) + + output_for_backward = ( + self.compute_loss() if self.is_last_stage else self.fwd_output + ) + + # we store a ref to the input/output pair for this forward to be later used by the corresponding backward + self.fwd_outputs_for_backward.append( + (self.fwd_input, output_for_backward) + ) + + return self.fwd_output + + def get_bwd_send_ops(self) -> List[dist.P2POp]: + if self.is_first_stage: + return [] + assert self.fwd_input.grad is not None, "grad must be valid" + return [dist.P2POp(dist.isend, self.fwd_input.grad, self.prev_stage)] + + def get_bwd_recv_ops(self) -> List[dist.P2POp]: + if self.is_last_stage: + return [] + return [dist.P2POp(dist.irecv, self.fwd_output_grads, self.next_stage)] + + def sync_recv_backward_inputs(self) -> None: + ops = self.get_bwd_recv_ops() + if ops: + dist.batch_isend_irecv(ops).pop().wait() + + def _wait_backward_inputs(self): + assert ( + self.bwd_recv_queue is not None + ), "Waiting for backward input without enqueueing one" + self.bwd_recv_queue.wait() + self.bwd_recv_queue = None + return self.fwd_output_grads + + def backward(self, is_first_mb, is_last_mb): + logger.info( + f"[{self.rank} BACKWARD {self.stage_id}] is_first_mb {is_first_mb} is_last_mb {is_last_mb}" + ) + + if self.is_last_stage: + fwd_inputs, loss = self.fwd_outputs_for_backward.popleft() + else: + fwd_inputs, fwd_outputs = self.fwd_outputs_for_backward.popleft() + + # Compute gradients + if self.is_last_stage: + torch.autograd.backward(loss, retain_graph=True) + else: + torch.autograd.backward( + fwd_outputs, self.fwd_output_grads, retain_graph=True + ) + + return fwd_inputs + + def compute_loss(self): + if self.fwd_output is None: + raise RuntimeError("forward() must be called before compute_loss()") + # TODO: use a real loss function passed in + return self.fwd_output.mean() + + +class PipelineScheduleGPipe: + def __init__(self, stage: PipelineStage): + self._stage = stage + + def step(self, microbatches): + for i, mb in enumerate(microbatches): + with record_function(f"Forward {i}"): + is_last_mb = i == len(microbatches) - 1 + + ops = self._stage.get_fwd_recv_ops() + if ops: + dist.batch_isend_irecv(ops).pop().wait() + + self._stage.forward( + mb, is_first_mb=i == 0, is_last_mb=is_last_mb + ) + + ops = self._stage.get_fwd_send_ops() + if ops: + dist.batch_isend_irecv(ops) + + logger.info( + f"{self._stage.stage_id} forward {i} finished, microbatch: {mb.shape}" + ) + + for i, _ in enumerate(microbatches): + with record_function(f"Backward {i}"): + ops = self._stage.get_bwd_recv_ops() + if ops: + dist.batch_isend_irecv(ops).pop().wait() + + self._stage.backward( + is_first_mb=i == 0, + is_last_mb=i == len(microbatches) - 1, + ) + + ops = self._stage.get_bwd_send_ops() + if ops: + dist.batch_isend_irecv(ops) + + logger.info(f"{self._stage.stage_id} backward {i} finished") + + +class PipelineScheduleLoopedBFS: + def __init__(self, stages: List[PipelineStage]): + self._stages = stages + + def step(self, microbatches): + for s, stage in enumerate(self._stages): + for i, mb in enumerate(microbatches): + with record_function(f"Stage {s} Forward"): + is_last_mb = i == len(microbatches) - 1 + + ops = stage.get_fwd_recv_ops() + if ops: + dist.batch_isend_irecv(ops).pop().wait() + + stage.forward(mb, is_first_mb=i == 0, is_last_mb=is_last_mb) + + ops = stage.get_fwd_send_ops() + if ops: + dist.batch_isend_irecv(ops) + + for stage in reversed(self._stages): + for i in range(len(microbatches)): + with record_function(f"Stage {stage.stage_id} Backward"): + ops = stage.get_bwd_recv_ops() + if ops: + dist.batch_isend_irecv(ops).pop().wait() + + stage.backward( + is_first_mb=i == 0, + is_last_mb=i == len(microbatches) - 1, + ) + + ops = stage.get_bwd_send_ops() + if ops: + dist.batch_isend_irecv(ops) + + +class PipelineScheduleLoopedDFS: + def __init__(self, stages: List[PipelineStage], n_microbatch, pp_id, n_pp): + assert ( + n_microbatch % n_pp == 0 + ), f"Looped DFS schedule requires microbatch_size ({n_microbatch}) to be a multiple of n_pp ({n_pp})" + + self.stages = stages + self.n_microbatch = n_microbatch + + self.n_local_stages = len(stages) + self.total_stages = self.n_local_stages * n_pp + # world_size + self.n_pp = n_pp + + self.stage_id_to_global_stage_id = [ + (i * n_pp) + pp_id for i in range(self.n_local_stages) + ] + + # pp_id is the same as local rank within the PP dimension + self.pp_id = pp_id + + # number of sequences (chunks) + self.seq_size = n_pp + + # warmup steps for latest pp stage is trivial to compute + # increment warmup_steps by 2 for each hop away + self.warmup_steps = (len(stages) - 1) * self.seq_size + self.warmup_steps += 2 * ((n_pp - 1) - pp_id) + self.forward_steps = len(stages) * n_microbatch + self.total_steps = self.warmup_steps + (len(stages) * n_microbatch) + logger.info( + f"pp_id {pp_id} warmup_steps {self.warmup_steps} forward_steps {self.forward_steps} total_steps {self.total_steps}" + ) + + def step(self, microbatches): + """ + # n_loop = n_stage / n_pp + # run microbatches in sequences of NPp + + schedule operates at the rank level + + highest rank has a warmup (F only) count of [len(stages) - 1] * seq_size + each hop away from highest rank adds 2 warmup stages + - one happened before highest rank's warmup started, + - one waiting for backward result to trickle down from highest rank + dist_from_highest = (worldsize - 1) - rank + + total_steps = warmup_steps + (num_stages * num_microbatch) + + + Rank 0: 0F 0F 0F 0F 2F 2F 2F 2F + Rank 1: 1F 1F 1F 1F 3F3B 3F 3F 3F + """ + + def minibatch_index(step): + # Given the step index, find the corresponding minibatch index. + + # equivalent to a triple nested loop like this + # for sequence_id in range(self.seq_size): + # for stage in self.stages: + # for microbatch_within_sequence: + # ... + # step: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 + # index:0, 1, 0, 1, 2, 3, 2, 3, 4, 5, 4, 5, 6, 7, 6, 7 + return (step % self.seq_size) + self.seq_size * int( + step / (self.seq_size * self.n_local_stages) + ) + + def stage_index(step): + # step: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 + # index:0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1 + return int((step / self.seq_size) % self.n_local_stages) + + """ + + my theory was that the hang could be fixed if I orchestrate the recvs after the sends from the schedule side, but i should probably + see if i can prove what caused the hang before i work on it further + """ + logger.info( + f"rank {self.pp_id} - minibatch_index {[minibatch_index(step) for step in range(self.total_steps)]}" + ) + logger.info( + f"rank {self.pp_id} - stage_index {[stage_index(step) for step in range(self.total_steps)]}" + ) + + forward_batched_op_handle: Optional[dist.Work] = None + backward_batched_op_handle: Optional[dist.Work] = None + + # edge case for first stage on each rank we need to call receive, recv for future microbatches will be fetched after fwd + # TODO: move this to its own class? `OneTimeUseRecv`? + forward_first_recv: Optional[List[dist.P2POp]] = self.stages[ + 0 + ].get_fwd_recv_ops() + + # edge case for the last stage on each rank we need to call receive, recv for future microbatches will be fetched after bwd + backward_first_recv: Optional[List[dist.P2POp]] = self.stages[ + -1 + ].get_bwd_recv_ops() + + backward_stages = list(reversed(self.stages)) + for step in range(self.total_steps): + mb_id_fwd = minibatch_index(step) + fwd_stage_id = stage_index(step) + forward_stage = self.stages[fwd_stage_id] + fwd_stage_id_next = None + forward_stage_next = None + + backward_step = step - self.warmup_steps + mb_id_bwd = minibatch_index(backward_step) + bwd_stage_id = stage_index(backward_step) + bwd_stage_id_next = None + backward_stage_next = None + backward_stage = backward_stages[bwd_stage_id] + + # info for next stages + if step < self.total_steps: + fwd_stage_id_next = stage_index(step + 1) + forward_stage_next = self.stages[fwd_stage_id_next] + bwd_stage_id_next = stage_index(backward_step + 1) + backward_stage_next = backward_stages[bwd_stage_id_next] + + if step < self.forward_steps: + if forward_first_recv: + logger.info( + f"rank {self.pp_id} - forward edge case for first stage" + ) + dist.batch_isend_irecv(forward_first_recv).pop().wait() + forward_first_recv = None + + if forward_batched_op_handle: + logger.info( + f"rank: {self.pp_id} - waiting on batched_op_handle before fwd" + ) + forward_batched_op_handle.wait() + forward_batched_op_handle = None + + with record_function(f"Stage {forward_stage.stage_id} Forward"): + logger.info( + f"pp_id {self.pp_id} step {step} forward_stage {forward_stage.stage_id} mb_id {mb_id_fwd}" + ) + forward_stage.forward( + microbatches[mb_id_fwd], + is_first_mb=mb_id_fwd == 0, + is_last_mb=mb_id_fwd == len(microbatches) - 1, + ) + + requests: List[dist.P2POp] = [] + + # send output activations if this is not the last stage + ops = forward_stage.get_fwd_send_ops() + requests.extend(ops) + + # add recv for the NEXT stage, do not do this for last stage + if forward_stage_next is not None: + ops = forward_stage_next.get_fwd_recv_ops() + if mb_id_fwd != len(microbatches) - 1: + requests.extend(ops) + + if requests: + logger.info( + f"rank: {self.pp_id}, current stage_id {self.stage_id_to_global_stage_id[fwd_stage_id]}, - {[(req.op, req.peer) for req in requests]}" + ) + forward_batched_op_handle = dist.batch_isend_irecv( + requests + ).pop() + + if step >= self.warmup_steps: + if backward_first_recv: + logger.info( + f"rank {self.pp_id} - backward edge case for last stage" + ) + dist.batch_isend_irecv(backward_first_recv).pop().wait() + backward_first_recv = None + + if backward_batched_op_handle: + logger.info( + f"rank: {self.pp_id} - waiting on batched_op_handles before bwd" + ) + backward_batched_op_handle.wait() + backward_batched_op_handle = None + + with record_function( + f"Stage {backward_stage.stage_id} Backward" + ): + logger.info( + f"pp_id {self.pp_id} step {step}/{self.total_steps} backward_step {backward_step} backward_stage_id {backward_stage.stage_id} mb_id {mb_id_bwd}" + ) + backward_stage.backward( + is_first_mb=mb_id_bwd == 0, + is_last_mb=mb_id_bwd == len(microbatches) - 1, + ) + + requests = [] + + # send bwd grad if this is not the first stage + ops = backward_stage.get_bwd_send_ops() + requests.extend(ops) + + # add recv for the NEXT stage, do not do this for first stage + if backward_stage_next is not None: + ops = backward_stage_next.get_bwd_recv_ops() + if mb_id_bwd != len(microbatches) - 1: + requests.extend(ops) + + if requests: + logger.info( + f"rank: {self.pp_id} - {[(req.op, req.peer) for req in requests]}" + ) + backward_batched_op_handle = dist.batch_isend_irecv( + requests + ).pop() + + logger.info("Step exiting") diff --git a/test/test_pipeline_schedule.py b/test/test_pipeline_schedule.py new file mode 100644 index 000000000..9b18c0c07 --- /dev/null +++ b/test/test_pipeline_schedule.py @@ -0,0 +1,241 @@ +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +""" +SINGLE HOST: + +python test_pipeline_schedule.py + +or + +with torchrun (1x2, 1 host with 2 processes): +torchrun --rdzv-backend=c10d --rdzv-endpoint=localhost:29500 --nnodes=1 --nproc-per-node=2 test_pipeline_schedule.py + +MULTIPLE HOSTS: + +torchrun --rdzv-backend=c10d --rdzv-endpoint=$HOST_NODE_ADDR --nnodes=$NUM_NODES --nproc-per-node=$NUM_TRAINERS test_pipeline_schedule.py + +e.g. (2x2, 2 hosts with 2 processes) +torchrun --rdzv-backend=c10d --rdzv-endpoint=node1.example.com:29400 --nnodes=2 --nproc-per-node=2 test_pipeline_schedule.py +""" + +import argparse +import logging +import os + +from datetime import timedelta + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import torch.nn as nn +from pippy.PipelineSchedule import ( + PipelineScheduleGPipe, + PipelineScheduleLoopedBFS, + PipelineScheduleLoopedDFS, + PipelineStage, +) + +logger = logging.getLogger(__name__) + + +class MLP(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + out_dim: int, + ): + super().__init__() + self.wi = nn.Linear(dim, hidden_dim, bias=False) + self.wh1 = nn.Linear(hidden_dim, hidden_dim, bias=False) + self.wh2 = nn.Linear(hidden_dim, hidden_dim, bias=False) + self.wh3 = nn.Linear(hidden_dim, hidden_dim, bias=False) + self.wo = nn.Linear(hidden_dim, out_dim, bias=False) + self.gelu_act = nn.GELU(approximate="tanh") + + def forward(self, x): + a = self.wi(x) + a = self.wh1(a) + a = self.wh2(a) + a = self.wh3(a) + b = self.gelu_act(a) + c = self.wo(b) + return c + + +def setup(local_rank, world_size): + # If this is a child process (i.e., its PID is not the same as the PID of the process that started this script) + if os.getppid() != os.getpid(): + set_up_logging(local_rank) + + # initialize the process group + logger.info(f"init for rank {local_rank}") + dist.init_process_group("nccl", timeout=timedelta(seconds=20)) + if torch.distributed.is_initialized(): + torch.cuda.set_device(local_rank) + + logger.info(f"finish init for rank {local_rank}") + + +def main(**kwargs): + torch.manual_seed(42) + print(f"MY KWARGS ARE {kwargs}") + rank = kwargs["rank"] + local_rank = kwargs["local_rank"] + world_size = kwargs["world_size"] + device = torch.device(kwargs["device"]) + + setup(local_rank, world_size) + logger.info( + f"====== World Rank {rank}, Local Rank {local_rank}, World Size {world_size}, Device {device} main ======" + ) + + input_dim = 4000 + hidden_dim = 8000 + output_dim = 4000 + + module_list = torch.nn.ModuleList( + modules=[ + MLP(input_dim, hidden_dim, output_dim) for i in range(world_size) + ] + ) + microbatch_size = 8 + global_batch_size = 64 + assert global_batch_size % microbatch_size == 0 + n_microbatches = int(global_batch_size / microbatch_size) + n_pp = world_size + + x = torch.randn([microbatch_size, input_dim]).to("meta") + + stage_model = PipelineStage( + module_list[rank], rank, world_size, rank, world_size, x, device + ) + stage_model.init_p2p_neighbors() + + stage_model_looped = [ + PipelineStage( + module_list[rank], + stage_id=(world_size * i) + rank, + num_stages=world_size * world_size, + rank=rank, + world_size=world_size, + meta_input=x, + device=device, + ) + for i in range(world_size) + ] + x_cuda_empty = torch.empty_like(x, device="cuda") + microbatches = [ + torch.randn_like(x_cuda_empty) for _ in range(n_microbatches) + ] + + for schedule in kwargs["schedules"]: + logger.info(f"====== Rank {rank} running schedule {schedule} ======") + if schedule == "gpipe": + pipeline = PipelineScheduleGPipe(stage_model) + elif schedule == "looped_bfs": + pipeline = PipelineScheduleLoopedBFS(stage_model_looped) + elif schedule == "looped_dfs": + pipeline = PipelineScheduleLoopedDFS( + stage_model_looped, + n_microbatch=n_microbatches, + pp_id=rank, + n_pp=n_pp, + ) + + logger.info(f"====== Rank {rank} profile ======") + + # with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: + # with record_function(schedule): + pipeline.step(microbatches) + + # TODO - default should be no profiling. + """if not kwargs["no_trace"]: + trace_dir = kwargs["trace_dir"] + if not os.path.exists(trace_dir): + os.mkdir(trace_dir) + prof.export_chrome_trace(f"{trace_dir}/{schedule}_rank{rank}_trace.json") + """ + logger.info(f"====== Rank {rank} finished {schedule} ======") + + +def main_wrapper(rank, local_rank, world_size, kwargs): + rank = int(rank) + world_size = int(world_size) + if local_rank is None: + local_rank = rank + local_rank = int(local_rank) + + os.environ["RANK"] = str(rank) + main(rank=rank, local_rank=local_rank, world_size=world_size, **kwargs) + + +def set_up_logging(rank, log_level=logging.INFO): + """Set up logging""" + logger.setLevel(log_level) + handler = logging.StreamHandler() + handler.setLevel(log_level) + + # TODO: seeing double logging due to global logging setup in + # - fx/passes/utils/matcher_utils.py + + # class FstringFormatter(logging.Formatter): + # def format(self, record): + # return f"[{rank}][{record.levelname}][{self.formatTime(record)}][{os.path.basename(__file__)}:{record.lineno}]:{record.getMessage()}" + + # formatter = FstringFormatter() + # handler.setFormatter(formatter) + # logger.addHandler(handler) + + +if __name__ == "__main__": + rank = os.environ.get("RANK", None) + local_rank = os.environ.get("LOCAL_RANK", None) + world_size = os.environ.get("WORLD_SIZE", None) + master_addr = os.environ.get("MASTER_ADDR", None) + master_port = os.environ.get("MASTER_PORT", None) + + parser = argparse.ArgumentParser(description="Pipeline Stages Runner") + parser.add_argument("--no_trace", action="store_true") + parser.add_argument("--trace_dir", type=str, default="./traces") + parser.add_argument( + "--schedules", + type=str, + nargs="+", + choices=["gpipe", "looped_bfs", "looped_dfs"], + default=["gpipe", "looped_bfs", "looped_dfs"], + ) + parser.add_argument("--device", type=str, default="cuda") + args = parser.parse_args() + kwargs = vars(args) + print(kwargs) + + if ( + rank is None + or local_rank is None + or world_size is None + or master_addr is None + ): + # single host code path + master_port = "23456" + master_addr = "localhost" + os.environ["MASTER_ADDR"] = master_addr + os.environ["MASTER_PORT"] = master_port + n_gpus = 4 + world_size = n_gpus + os.environ["WORLD_SIZE"] = str(world_size) + print( + f"Torchrun was not used. Spawning {world_size} processes on {master_addr}:{master_port}" + ) + mp.spawn( + main_wrapper, + args=( + None, + world_size, + kwargs, + ), + nprocs=world_size, + ) + else: + # multihost code path (ran with torchrun) + main_wrapper(rank, local_rank, world_size, kwargs) From 852f91ca11a4148fde0f882798b77b7848e835b1 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Thu, 30 Nov 2023 23:15:57 -0500 Subject: [PATCH 48/96] Add Albert example (#879) ## Description Add Albert example. Model description: https://huggingface.co/docs/transformers/model_doc/albert Model architecture: ``` AlbertForMaskedLM( (albert): AlbertModel( (embeddings): AlbertEmbeddings( (word_embeddings): Embedding(30000, 128, padding_idx=0) (position_embeddings): Embedding(512, 128) (token_type_embeddings): Embedding(2, 128) (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0, inplace=False) ) (encoder): AlbertTransformer( (embedding_hidden_mapping_in): Linear(in_features=128, out_features=4096, bias=True) (albert_layer_groups): ModuleList( (0): AlbertLayerGroup( (albert_layers): ModuleList( (0): AlbertLayer( (full_layer_layer_norm): LayerNorm((4096,), eps=1e-12, elementwise_affine=True) (attention): AlbertAttention( (query): Linear(in_features=4096, out_features=4096, bias=True) (key): Linear(in_features=4096, out_features=4096, bias=True) (value): Linear(in_features=4096, out_features=4096, bias=True) (attention_dropout): Dropout(p=0, inplace=False) (output_dropout): Dropout(p=0, inplace=False) (dense): Linear(in_features=4096, out_features=4096, bias=True) (LayerNorm): LayerNorm((4096,), eps=1e-12, elementwise_affine=True) ) (ffn): Linear(in_features=4096, out_features=16384, bias=True) (ffn_output): Linear(in_features=16384, out_features=4096, bias=True) (activation): NewGELUActivation() (dropout): Dropout(p=0, inplace=False) ) ) ) ) ) ) (predictions): AlbertMLMHead( (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True) (dense): Linear(in_features=4096, out_features=128, bias=True) (decoder): Linear(in_features=128, out_features=30000, bias=True) (activation): NewGELUActivation() ) ) ``` ## Run ``` torchrun --nproc-per-node 4 pippy_albert.py ``` ## Output https://gist.github.com/kwen2501/3fd89d3f3f0c743d1e726c71c32a35e7 --- examples/hf/pippy_albert.py | 112 ++++++++++++++++++++++++++++++++++++ 1 file changed, 112 insertions(+) create mode 100644 examples/hf/pippy_albert.py diff --git a/examples/hf/pippy_albert.py b/examples/hf/pippy_albert.py new file mode 100644 index 000000000..816ece388 --- /dev/null +++ b/examples/hf/pippy_albert.py @@ -0,0 +1,112 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +# Minimum effort to run this example: +# $ torchrun --nproc-per-node 4 pippy_albert.py + +import argparse +import os + +import torch +import torch.distributed as dist + +from pippy.IR import Pipe, PipeSplitWrapper, annotate_split_points +from pippy.PipelineStage import PipelineStage + +from transformers import AlbertForMaskedLM, AlbertConfig + +from hf_utils import generate_inputs_for_model, get_number_of_params + + +def add_split_points(albert, nranks): + albert_layer_fqn = "albert.encoder.albert_layer_groups.0.albert_layers.0" + annotate_split_points( + albert, {albert_layer_fqn: PipeSplitWrapper.SplitPoint.BEGINNING}) + # Because of the for loop structure in albert's forward method, this will + # split the albert model `num_hidden_layers` times (hence + # `num_hidden_layers`+1 stages) + + +def run(args): + # Model configs + config = AlbertConfig() + config.num_hidden_layers = 3 + print("Using device:", args.device) + + # Create model + model_class = AlbertForMaskedLM + model_name = "AlbertForMaskedLM" + albert = model_class(config) + albert.to(args.device) + albert.eval() + if args.rank == 0: + print(albert.config) + print(f"Total number of params = {get_number_of_params(albert) // 10 ** 6}M") + print(albert) + + # Input configs + example_inputs = generate_inputs_for_model( + model_class, albert, model_name, args.batch_size, args.device) + input_ids = example_inputs["input_ids"] + + # Annotate split points + add_split_points(albert, args.world_size) + + # Create pipeline + albert_pipe = Pipe.from_tracing( + albert, + num_chunks=args.chunks, + example_args=(input_ids, ), + ) + nstages = len(list(albert_pipe.split_gm.children())) + assert nstages == args.world_size, f"nstages = {nstages} nranks = {args.world_size}" + if args.rank == 0: + for i, sm in enumerate(albert_pipe.split_gm.children()): + print(f"Pipeline stage {i} {get_number_of_params(sm) // 10 ** 6}M params") + + # Create schedule runtime + stage = PipelineStage( + albert_pipe, + args.rank, + device=args.device, + ) + + # Run + if args.rank == 0: + stage(input_ids) + elif args.rank == args.world_size - 1: + out = stage() + else: + stage() + + print(f"Rank {args.rank} completes") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 4))) + parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1))) + parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost')) + parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500')) + parser.add_argument('--schedule', type=str, default="FillDrain") + parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available())) + parser.add_argument("--chunks", type=int, default=4) + parser.add_argument('--batch_size', type=int, default=4) + parser.add_argument('--batches', type=int, default=1) + + args = parser.parse_args() + + if args.cuda: + dev_id = args.rank % torch.cuda.device_count() + args.device = torch.device(f"cuda:{dev_id}") + else: + args.device = torch.device("cpu") + + # Init process group + backend = "nccl" if args.cuda else "gloo" + dist.init_process_group( + backend=backend, + rank=args.rank, + world_size=args.world_size, + ) + + run(args) From 9b82ff9383180d9eebaa167c86a8fefdb4b87f46 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Fri, 1 Dec 2023 17:51:20 -0500 Subject: [PATCH 49/96] Migrate BERT example (#880) ## Description BERT architecture: ``` BertForMaskedLM( (bert): BertModel( (embeddings): BertEmbeddings( (word_embeddings): Embedding(30522, 768, padding_idx=0) (position_embeddings): Embedding(512, 768) (token_type_embeddings): Embedding(2, 768) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) (encoder): BertEncoder( (layer): ModuleList( (0-11): 12 x BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (intermediate_act_fn): GELUActivation() ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) ) ) ) (cls): BertOnlyMLMHead( (predictions): BertLMPredictionHead( (transform): BertPredictionHeadTransform( (dense): Linear(in_features=768, out_features=768, bias=True) (transform_act_fn): GELUActivation() (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) ) (decoder): Linear(in_features=768, out_features=30522, bias=True) ) ) ) ``` ## Output https://gist.github.com/kwen2501/e99608e9a410c6ff15923a4989c26fc8 --- examples/hf/bert/pippy_bert.py | 107 ----------------------------- examples/hf/bert/pippy_sbatch.sh | 20 ------ examples/hf/bert/pippy_wrapper.sh | 11 --- examples/hf/pippy_bert.py | 109 ++++++++++++++++++++++++++++++ 4 files changed, 109 insertions(+), 138 deletions(-) delete mode 100644 examples/hf/bert/pippy_bert.py delete mode 100755 examples/hf/bert/pippy_sbatch.sh delete mode 100755 examples/hf/bert/pippy_wrapper.sh create mode 100644 examples/hf/pippy_bert.py diff --git a/examples/hf/bert/pippy_bert.py b/examples/hf/bert/pippy_bert.py deleted file mode 100644 index 59e95f319..000000000 --- a/examples/hf/bert/pippy_bert.py +++ /dev/null @@ -1,107 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import argparse -import os -from functools import reduce - -import torch -from transformers import BertLMHeadModel, BertConfig - -import pippy -import pippy.fx -from pippy import run_pippy -from pippy.microbatch import sum_reducer, TensorChunkSpec -from pippy.events import EventsContext -from pippy.hf import PiPPyHFTracer -from pippy.visualizer import events_to_json - - -pippy.fx.Tracer.proxy_buffer_attributes = True - - -def get_number_of_params(model): - return sum(p.numel() for p in model.parameters() if p.requires_grad) - - -def print_submod_sizes(model_pipe): - total_params = 0 - for i, sm in enumerate(model_pipe.split_gm.children()): - params = get_number_of_params(sm) - print(f"submod_{i} {params // 10 ** 6}M params") - total_params += params - print(f"total {total_params // 10 ** 6}M params") - - -def run_master(_, args): - print("Using schedule:", args.schedule) - - bert = BertLMHeadModel(BertConfig(is_decoder=True)) - print(bert.config) - print(f"BERT total number of params = {get_number_of_params(bert) // 10 ** 6}M") - # print(bert) - - chunks = args.chunks or args.world_size - batches = 1 - bs = 1 * chunks - seq_length = 16 - - device = args.device - bert.to(device) - - bert_input_dict = { - 'input_ids': torch.zeros(bs, seq_length, dtype=torch.long, device=device).random_(bert.config.vocab_size), - 'labels': torch.zeros(bs, seq_length, dtype=torch.long, device=device).random_(bert.config.vocab_size), - 'attention_mask': torch.ones(bs, seq_length, device=device)} - # bert(**bert_input_dict) - - concrete_args = pippy.create_default_args(bert, - except_keys=bert_input_dict.keys()) - - output_chunk_spec = {"loss": sum_reducer, - "logits": TensorChunkSpec(0)} - - split_policy = pippy.split_into_equal_size(args.world_size) - - print('Instantiating BERT Pipeline') - pipe_driver = pippy.compile( - bert, - num_ranks=args.world_size, - num_chunks=chunks, - schedule=args.schedule, - split_policy=split_policy, - tracer=PiPPyHFTracer(), - checkpoint=bool(args.checkpoint), - output_chunk_spec=output_chunk_spec, - concrete_args=concrete_args, - ) - print_submod_sizes(pipe_driver.pipe) - - this_file_name = os.path.splitext(os.path.basename(__file__))[0] - - print('Running BERT pipeline.') - pipe_visualized_filename = f"{this_file_name}_visualized.json" - batches_events_contexts = [] - for i in range(batches): - pipe_driver(**bert_input_dict) - batches_events_contexts.append(pipe_driver.retrieve_events()) - - all_events_contexts: EventsContext = reduce(lambda c1, c2: EventsContext().update(c1).update(c2), - batches_events_contexts, EventsContext()) - with open(pipe_visualized_filename, "w") as f: - f.write(events_to_json(all_events_contexts)) - print(f"Saved {pipe_visualized_filename}") - print('Finished') - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 4))) - parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1))) - parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost')) - parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500')) - parser.add_argument('-s', '--schedule', type=str, default="FillDrain") - parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available())) - parser.add_argument('--checkpoint', type=int, default=0, choices=[0, 1]) - parser.add_argument("--chunks", type=int, default=None) - args = parser.parse_args() - - run_pippy(run_master, args) diff --git a/examples/hf/bert/pippy_sbatch.sh b/examples/hf/bert/pippy_sbatch.sh deleted file mode 100755 index a0fdda65b..000000000 --- a/examples/hf/bert/pippy_sbatch.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/bash -# Copyright (c) Meta Platforms, Inc. and affiliates - -#SBATCH --job-name=bert_pippy - -#SBATCH --open-mode=append - -#SBATCH --partition=train - -#SBATCH --nodes=2 - -#SBATCH --ntasks-per-node=8 - -#SBATCH --cpus-per-task=12 - -#SBATCH --gpus-per-node=8 - -#SBATCH --time=1:00:00 - -srun --label pippy_wrapper.sh diff --git a/examples/hf/bert/pippy_wrapper.sh b/examples/hf/bert/pippy_wrapper.sh deleted file mode 100755 index 6a73e36e0..000000000 --- a/examples/hf/bert/pippy_wrapper.sh +++ /dev/null @@ -1,11 +0,0 @@ -#!/bin/bash -# Copyright (c) Meta Platforms, Inc. and affiliates - -export MASTER_PORT=29500 -export MASTER_ADDR=$(scontrol show hostname ${SLURM_NODELIST} | head -n 1) -export LOCAL_RANK=${SLURM_LOCALID} -export CUDA_VISIBLE_DEVICES=${SLURM_LOCALID} -export WORLD_SIZE=${SLURM_NTASKS} -export RANK=${SLURM_PROCID} - -python -u pippy_bert.py --record_mem_dumps=0 --checkpoint=0 diff --git a/examples/hf/pippy_bert.py b/examples/hf/pippy_bert.py new file mode 100644 index 000000000..7d67d6077 --- /dev/null +++ b/examples/hf/pippy_bert.py @@ -0,0 +1,109 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +# Minimum effort to run this example: +# $ torchrun --nproc-per-node 4 pippy_bert.py + +import argparse +import os + +import torch +import torch.distributed as dist + +from pippy.IR import Pipe, PipeSplitWrapper, annotate_split_points +from pippy.PipelineStage import PipelineStage + +from transformers import BertForMaskedLM, BertConfig + +from hf_utils import generate_inputs_for_model, get_number_of_params + + +def add_split_points(bert, nranks): + layers_per_rank = bert.config.num_hidden_layers // nranks + for i in range(1, nranks): + annotate_split_points( + bert, {f"bert.encoder.layer.{i * layers_per_rank}": PipeSplitWrapper.SplitPoint.BEGINNING}) + + +def run(args): + # Model configs + config = BertConfig() + print("Using device:", args.device) + + # Create model + model_class = BertForMaskedLM + model_name = "BertForMaskedLM" + bert = model_class(config) + bert.to(args.device) + bert.eval() + if args.rank == 0: + print(bert.config) + print(f"Total number of params = {get_number_of_params(bert) // 10 ** 6}M") + print(bert) + + # Input configs + example_inputs = generate_inputs_for_model( + model_class, bert, model_name, args.batch_size, args.device) + input_ids = example_inputs["input_ids"] + + # Annotate split points + add_split_points(bert, args.world_size) + + # Create pipeline + bert_pipe = Pipe.from_tracing( + bert, + num_chunks=args.chunks, + example_args=(input_ids, ), + ) + nstages = len(list(bert_pipe.split_gm.children())) + assert nstages == args.world_size, f"nstages = {nstages} nranks = {args.world_size}" + if args.rank == 0: + for i, sm in enumerate(bert_pipe.split_gm.children()): + print(f"Pipeline stage {i} {get_number_of_params(sm) // 10 ** 6}M params") + + # Create schedule runtime + stage = PipelineStage( + bert_pipe, + args.rank, + device=args.device, + ) + + # Run + if args.rank == 0: + stage(input_ids) + elif args.rank == args.world_size - 1: + out = stage() + else: + stage() + + print(f"Rank {args.rank} completes") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 4))) + parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1))) + parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost')) + parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500')) + parser.add_argument('--schedule', type=str, default="FillDrain") + parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available())) + parser.add_argument("--chunks", type=int, default=4) + parser.add_argument('--batch_size', type=int, default=4) + parser.add_argument('--batches', type=int, default=1) + + args = parser.parse_args() + + if args.cuda: + dev_id = args.rank % torch.cuda.device_count() + args.device = torch.device(f"cuda:{dev_id}") + else: + args.device = torch.device("cpu") + + # Init process group + backend = "nccl" if args.cuda else "gloo" + dist.init_process_group( + backend=backend, + rank=args.rank, + world_size=args.world_size, + ) + + run(args) From 1164df7cdd006b270af29485ae037edbca69e380 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Fri, 1 Dec 2023 23:10:23 -0500 Subject: [PATCH 50/96] Remove stale examples (#882) Remove gspmd, TorchDynamo, resnet and some old impl for HF examples --- examples/TorchDynamo/pippy_dynamo.py | 186 ----- examples/gspmd/pippy_gspmd.py | 125 --- examples/hf/language-modeling/README.md | 188 ----- .../hf/language-modeling/requirements.txt | 5 - examples/hf/language-modeling/run_clm.py | 633 --------------- .../language-modeling/run_clm_no_trainer.py | 658 --------------- examples/hf/language-modeling/run_mlm.py | 661 ---------------- .../language-modeling/run_mlm_no_trainer.py | 703 ---------------- examples/hf/text-classification/README.md | 203 ----- .../hf/text-classification/requirements.txt | 7 - examples/hf/text-classification/run_glue.py | 670 ---------------- .../run_glue_no_trainer.py | 635 --------------- examples/hf/text-classification/run_xnli.py | 437 ---------- examples/hf/translation/README.md | 211 ----- examples/hf/translation/requirements.txt | 7 - examples/hf/translation/run_translation.py | 749 ------------------ .../translation/run_translation_no_trainer.py | 743 ----------------- examples/hf/translation/t5-3b_config.json | 58 -- examples/resnet/.gitignore | 1 - examples/resnet/local_resnet.py | 69 -- examples/resnet/pippy_resnet.py | 161 ---- examples/resnet/pippy_sbatch.sh | 20 - examples/resnet/pippy_wrapper.sh | 13 - examples/resnet/resnet.py | 150 ---- 24 files changed, 7293 deletions(-) delete mode 100644 examples/TorchDynamo/pippy_dynamo.py delete mode 100644 examples/gspmd/pippy_gspmd.py delete mode 100644 examples/hf/language-modeling/README.md delete mode 100644 examples/hf/language-modeling/requirements.txt delete mode 100755 examples/hf/language-modeling/run_clm.py delete mode 100755 examples/hf/language-modeling/run_clm_no_trainer.py delete mode 100755 examples/hf/language-modeling/run_mlm.py delete mode 100755 examples/hf/language-modeling/run_mlm_no_trainer.py delete mode 100644 examples/hf/text-classification/README.md delete mode 100644 examples/hf/text-classification/requirements.txt delete mode 100755 examples/hf/text-classification/run_glue.py delete mode 100644 examples/hf/text-classification/run_glue_no_trainer.py delete mode 100755 examples/hf/text-classification/run_xnli.py delete mode 100644 examples/hf/translation/README.md delete mode 100644 examples/hf/translation/requirements.txt delete mode 100755 examples/hf/translation/run_translation.py delete mode 100644 examples/hf/translation/run_translation_no_trainer.py delete mode 100644 examples/hf/translation/t5-3b_config.json delete mode 100644 examples/resnet/.gitignore delete mode 100644 examples/resnet/local_resnet.py delete mode 100644 examples/resnet/pippy_resnet.py delete mode 100755 examples/resnet/pippy_sbatch.sh delete mode 100755 examples/resnet/pippy_wrapper.sh delete mode 100644 examples/resnet/resnet.py diff --git a/examples/TorchDynamo/pippy_dynamo.py b/examples/TorchDynamo/pippy_dynamo.py deleted file mode 100644 index 45700ed7f..000000000 --- a/examples/TorchDynamo/pippy_dynamo.py +++ /dev/null @@ -1,186 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import argparse -import os -import unittest - -import torch -import torch.autograd.profiler_legacy -import torch.fx -# TorchDynamo is moved into PyTorch as of PyTorch 2.0 -import torch._dynamo as dynamo - -import pippy -import pippy.fx -from pippy.IR import Pipe, pipe_split -from pippy import run_pippy - -PROFILING_ENABLED = True - -pippy.fx.Tracer.proxy_buffer_attributes = True - - -def inspect_split_module( - pipe: Pipe, - expected_stages: int = -1, -): - gm: pippy.fx.GraphModule = pipe.split_gm - # Check returned number of stages - nstages = len(list(gm.children())) - if expected_stages > 0: - assert ( - nstages == expected_stages - ), f"Model is split into {nstages} instead of {expected_stages} stages" - - print(f"\n======= GraphModule after Auto-split =======") - print(gm) - - for i, submod in enumerate(gm.children()): - print(f"\n======= Child module {i} =======") - print(submod) - - -def run_master(_, args): - print("Using schedule:", args.schedule) - - # Ask Dynamo to let PiPPy annotation stay in graph - dynamo.allow_in_graph(pipe_split) - - # Define a compiler backend made by PiPPy for use by Dynamo - # The backend comprising: - # - pippy.compile - # The driver is return as a compiled runtime callable, which will be used in the actual data execution - def my_pippy_compiler(gm: torch.fx.GraphModule, example_inputs, **kwargs): - print("\n============= my_pippy_compiler() called with FX graph =============") - gm.graph.print_tabular() - - # Create PipelineDriver - pipe_driver = pippy.compile( - gm, - args.world_size, - num_chunks=1, - schedule=args.schedule, - checkpoint=bool(args.checkpoint), - ) - - inspect_split_module(pipe_driver.pipe, args.world_size) - - # Return a runtime Callable - # This PipelineDriver is a distributed runtime - return pipe_driver - - # Model parameters - d_hid = 512 - bs = 503 - - class ExampleCode(torch.nn.Module): - def __init__(self): - super().__init__() - self.mm_param = torch.nn.Parameter(torch.randn(d_hid, d_hid)) - self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) - self.lin = torch.nn.Linear(d_hid, d_hid) - self.register_buffer("buffer", torch.randn(bs + 100, d_hid)) - - # Decorate with Dynamo here, or - # explicitly call optimize in the main code. - # We do the latter for zero change on the model, hence commenting out the decoration here - # @dynamo.optimize(my_pippy_compiler) - def forward(self, x): - x = torch.mm(x, self.mm_param) - skip_connection = x - x = torch.relu(x) - pipe_split() - x = torch.mm(x, self.mm_param) + self.buffer[: x.shape[0]] - x = self.lin(x) - pipe_split() - x = torch.relu(x) - x = x + skip_connection - x = torch.mm(x, self.mm_param2) - pipe_split() - x = self.lin(x) - x = torch.relu(x) - return x - - # Create model as usual - ec = ExampleCode() - ec.to(args.device) - ec_input = torch.randn(bs, d_hid, device=args.device) - ref_out = ec(ec_input) - - # Optimize and distribute model using Dynamo + PiPPy - ec_pipe = dynamo.optimize(my_pippy_compiler)(ec) - - print(f"\n======= Runtime tests =======") - # This would already be output returned by PiPPy's distributed pipeline - pipe_out = ec_pipe(ec_input) - - # Check correctness - torch.testing.assert_close(pipe_out, ref_out) - print( - f'equivalence test passed {torch.sum(pipe_out)} ref {torch.sum(ref_out)}' - ) - - # Profiling run - # This run would not trigger compilation - # We can also change the size to test dynamic shape support - # ec_input = torch.randn(bs + 10, d_hid, device=args.device) - with torch.autograd.profiler_legacy.profile( - enabled=PROFILING_ENABLED - ) as prof: - pipe_out = ec_pipe(ec_input) - print( - f'profiling run completed {torch.sum(pipe_out)} ref {torch.sum(ref_out)}' - ) - if PROFILING_ENABLED: - prof.export_chrome_trace( - f"{os.path.splitext(os.path.basename(__file__))[0]}.json" - ) - - -def main(args=None): - parser = argparse.ArgumentParser() - parser.add_argument( - "--world_size", type=int, default=int(os.getenv("WORLD_SIZE", 4)) - ) - parser.add_argument("--rank", type=int, default=int(os.getenv("RANK", -1))) - parser.add_argument( - "--master_addr", type=str, default=os.getenv("MASTER_ADDR", "localhost") - ) - parser.add_argument( - "--master_port", type=str, default=os.getenv("MASTER_PORT", "29500") - ) - parser.add_argument( - "-s", - "--schedule", - type=str, - default="FillDrain", - ) - parser.add_argument( - "--replicate", type=int, default=int(os.getenv("REPLICATE", "0")) - ) - parser.add_argument( - "--cuda", type=int, default=int(torch.cuda.is_available()) - ) - parser.add_argument("--checkpoint", type=int, default=0, choices=[0, 1]) - args = parser.parse_args(args) - - # Interleaved 1F1B uses fewer ranks than number of stages - if args.schedule == "Interleaved1F1B": - args.world_size = 2 - - run_pippy(run_master, args) - - -if __name__ == "__main__": - main() - - -class LocalTestDynamoTest(unittest.TestCase): - def test_forward(self): - import random - - port = random.randint(29500, 30000) - args = [ - "--master_port", - str(port), - ] - main(args) diff --git a/examples/gspmd/pippy_gspmd.py b/examples/gspmd/pippy_gspmd.py deleted file mode 100644 index 48478c4f0..000000000 --- a/examples/gspmd/pippy_gspmd.py +++ /dev/null @@ -1,125 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import argparse -import os -import unittest - -import torch -import torch.autograd.profiler_legacy - -import pippy -import pippy.fx -from pippy import run_pippy -from pippy.IR import pipe_split - - -pippy.fx.Tracer.proxy_buffer_attributes = True - -d_hid = 512 -bs = 500 - -class ExampleCode(torch.nn.Module): - def __init__(self): - super().__init__() - self.mm_param = torch.nn.Parameter(torch.randn(d_hid, d_hid)) - self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) - self.lin = torch.nn.Linear(d_hid, d_hid) - self.register_buffer("buffer", torch.randn(bs + 100, d_hid)) - - def forward(self, x): - x = torch.mm(x, self.mm_param) - skip_connection = x - x = torch.relu(x) - pipe_split() - x = torch.mm(x, self.mm_param) + self.buffer[: x.shape[0]] - x = self.lin(x) - pipe_split() - x = torch.relu(x) - x = x + skip_connection - x = torch.mm(x, self.mm_param2) - pipe_split() - x = self.lin(x) - x = torch.relu(x) - return {"out": x} - - -def run_gspmd(pp_ranks, args): - # Make sure all ranks have the same seed - torch.manual_seed(5) - ec = ExampleCode() - ec_input = torch.randn(bs, d_hid, device=args.device) - chunks = 5 - - pipe_driver, stage_mod = pippy.all_compile( - ec, - args.world_size, - chunks, - schedule=args.schedule, - _debug_mask_minibatches=True, # For numeric check only - ) - print( - f"Rank {args.rank}: {stage_mod}" - ) - - # PiPPy run - if pipe_driver: - out = pipe_driver(ec_input) - - # Reference run - ec.to(args.device) - ref_out = ec(ec_input) - - # Numeric check - if pipe_driver: - torch.testing.assert_close(out["out"], ref_out["out"]) - print( - f'equivalence test passed {torch.sum(out["out"])} ref {torch.sum(ref_out["out"])}' - ) - - print(f"Rank {args.rank} completed") - - -def main(args=None): - parser = argparse.ArgumentParser() - parser.add_argument( - "--world_size", type=int, default=int(os.getenv("WORLD_SIZE", 4)) - ) - parser.add_argument("--rank", type=int, default=int(os.getenv("RANK", -1))) - parser.add_argument( - "--master_addr", type=str, default=os.getenv("MASTER_ADDR", "localhost") - ) - parser.add_argument( - "--master_port", type=str, default=os.getenv("MASTER_PORT", "29500") - ) - parser.add_argument( - "-s", - "--schedule", - type=str, - default="FillDrain", - ) - parser.add_argument( - "--cuda", type=int, default=int(torch.cuda.is_available()) - ) - args = parser.parse_args(args) - - # Interleaved 1F1B uses less ranks than number of stages - if args.schedule == "Interleaved1F1B": - args.world_size = 2 - - args.gspmd = 1 - run_pippy(run_gspmd, args) - - -if __name__ == "__main__": - main() - - -class LocalTestGspmdTest(unittest.TestCase): - def test_gspmd(self): - import random - - port = random.randint(29500, 30000) - args = [ - "--master_port", - str(port), - ] - main(args) diff --git a/examples/hf/language-modeling/README.md b/examples/hf/language-modeling/README.md deleted file mode 100644 index 035a6dd69..000000000 --- a/examples/hf/language-modeling/README.md +++ /dev/null @@ -1,188 +0,0 @@ - - -## Language model training - -Fine-tuning (or training from scratch) the library models for language modeling on a text dataset for GPT, GPT-2, -ALBERT, BERT, DistilBERT, RoBERTa, XLNet... GPT and GPT-2 are trained or fine-tuned using a causal language modeling -(CLM) loss while ALBERT, BERT, DistilBERT and RoBERTa are trained or fine-tuned using a masked language modeling (MLM) -loss. XLNet uses permutation language modeling (PLM), you can find more information about the differences between those -objectives in our [model summary](https://huggingface.co/transformers/model_summary.html). - -There are two sets of scripts provided. The first set leverages the Trainer API. The second set with `no_trainer` in the suffix uses a custom training loop and leverages the 🤗 Accelerate library . Both sets use the 🤗 Datasets library. You can easily customize them to your needs if you need extra processing on your datasets. - -**Note:** The old script `run_language_modeling.py` is still available [here](https://github.com/huggingface/transformers/blob/main/examples/legacy/run_language_modeling.py). - -The following examples, will run on datasets hosted on our [hub](https://huggingface.co/datasets) or with your own -text files for training and validation. We give examples of both below. - -### GPT-2/GPT and causal language modeling - -The following example fine-tunes GPT-2 on WikiText-2. We're using the raw WikiText-2 (no tokens were replaced before -the tokenization). The loss here is that of causal language modeling. - -```bash -python run_clm.py \ - --model_name_or_path gpt2 \ - --dataset_name wikitext \ - --dataset_config_name wikitext-2-raw-v1 \ - --per_device_train_batch_size 8 \ - --per_device_eval_batch_size 8 \ - --do_train \ - --do_eval \ - --output_dir /tmp/test-clm -``` - -This takes about half an hour to train on a single K80 GPU and about one minute for the evaluation to run. It reaches -a score of ~20 perplexity once fine-tuned on the dataset. - -To run on your own training and validation files, use the following command: - -```bash -python run_clm.py \ - --model_name_or_path gpt2 \ - --train_file path_to_train_file \ - --validation_file path_to_validation_file \ - --per_device_train_batch_size 8 \ - --per_device_eval_batch_size 8 \ - --do_train \ - --do_eval \ - --output_dir /tmp/test-clm -``` - -This uses the built in HuggingFace `Trainer` for training. If you want to use a custom training loop, you can utilize or adapt the `run_clm_no_trainer.py` script. Take a look at the script for a list of supported arguments. An example is shown below: - -```bash -python run_clm_no_trainer.py \ - --dataset_name wikitext \ - --dataset_config_name wikitext-2-raw-v1 \ - --model_name_or_path gpt2 \ - --output_dir /tmp/test-clm -``` - -### RoBERTa/BERT/DistilBERT and masked language modeling - -The following example fine-tunes RoBERTa on WikiText-2. Here too, we're using the raw WikiText-2. The loss is different -as BERT/RoBERTa have a bidirectional mechanism; we're therefore using the same loss that was used during their -pre-training: masked language modeling. - -In accordance to the RoBERTa paper, we use dynamic masking rather than static masking. The model may, therefore, -converge slightly slower (over-fitting takes more epochs). - -```bash -python run_mlm.py \ - --model_name_or_path roberta-base \ - --dataset_name wikitext \ - --dataset_config_name wikitext-2-raw-v1 \ - --per_device_train_batch_size 8 \ - --per_device_eval_batch_size 8 \ - --do_train \ - --do_eval \ - --output_dir /tmp/test-mlm -``` - -To run on your own training and validation files, use the following command: - -```bash -python run_mlm.py \ - --model_name_or_path roberta-base \ - --train_file path_to_train_file \ - --validation_file path_to_validation_file \ - --per_device_train_batch_size 8 \ - --per_device_eval_batch_size 8 \ - --do_train \ - --do_eval \ - --output_dir /tmp/test-mlm -``` - -If your dataset is organized with one sample per line, you can use the `--line_by_line` flag (otherwise the script -concatenates all texts and then splits them in blocks of the same length). - -This uses the built in HuggingFace `Trainer` for training. If you want to use a custom training loop, you can utilize or adapt the `run_mlm_no_trainer.py` script. Take a look at the script for a list of supported arguments. An example is shown below: - -```bash -python run_mlm_no_trainer.py \ - --dataset_name wikitext \ - --dataset_config_name wikitext-2-raw-v1 \ - --model_name_or_path roberta-base \ - --output_dir /tmp/test-mlm -``` - -**Note:** On TPU, you should use the flag `--pad_to_max_length` in conjunction with the `--line_by_line` flag to make -sure all your batches have the same length. - -### Whole word masking - -This part was moved to `examples/research_projects/mlm_wwm`. - -### XLNet and permutation language modeling - -XLNet uses a different training objective, which is permutation language modeling. It is an autoregressive method -to learn bidirectional contexts by maximizing the expected likelihood over all permutations of the input -sequence factorization order. - -We use the `--plm_probability` flag to define the ratio of length of a span of masked tokens to surrounding -context length for permutation language modeling. - -The `--max_span_length` flag may also be used to limit the length of a span of masked tokens used -for permutation language modeling. - -Here is how to fine-tune XLNet on wikitext-2: - -```bash -python run_plm.py \ - --model_name_or_path=xlnet-base-cased \ - --dataset_name wikitext \ - --dataset_config_name wikitext-2-raw-v1 \ - --per_device_train_batch_size 8 \ - --per_device_eval_batch_size 8 \ - --do_train \ - --do_eval \ - --output_dir /tmp/test-plm -``` - -To fine-tune it on your own training and validation file, run: - -```bash -python run_plm.py \ - --model_name_or_path=xlnet-base-cased \ - --train_file path_to_train_file \ - --validation_file path_to_validation_file \ - --per_device_train_batch_size 8 \ - --per_device_eval_batch_size 8 \ - --do_train \ - --do_eval \ - --output_dir /tmp/test-plm -``` - -If your dataset is organized with one sample per line, you can use the `--line_by_line` flag (otherwise the script -concatenates all texts and then splits them in blocks of the same length). - -**Note:** On TPU, you should use the flag `--pad_to_max_length` in conjunction with the `--line_by_line` flag to make -sure all your batches have the same length. - - -## Creating a model on the fly - -When training a model from scratch, configuration values may be overridden with the help of `--config_overrides`: - - -```bash -python run_clm.py --model_type gpt2 --tokenizer_name gpt2 \ --config_overrides="n_embd=1024,n_head=16,n_layer=48,n_positions=102" \ -[...] -``` - -This feature is only available in `run_clm.py`, `run_plm.py` and `run_mlm.py`. diff --git a/examples/hf/language-modeling/requirements.txt b/examples/hf/language-modeling/requirements.txt deleted file mode 100644 index bec267b98..000000000 --- a/examples/hf/language-modeling/requirements.txt +++ /dev/null @@ -1,5 +0,0 @@ -accelerate -torch >= 1.3 -datasets >= 1.8.0 -sentencepiece != 0.1.92 -protobuf diff --git a/examples/hf/language-modeling/run_clm.py b/examples/hf/language-modeling/run_clm.py deleted file mode 100755 index b878c3d28..000000000 --- a/examples/hf/language-modeling/run_clm.py +++ /dev/null @@ -1,633 +0,0 @@ -#!/usr/bin/env python -# coding=utf-8 -# Copyright 2020 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset. - -Here is the full list of checkpoints on the hub that can be fine-tuned by this script: -https://huggingface.co/models?filter=text-generation -""" -# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments. - -import logging -import math -import os -import sys -from dataclasses import dataclass, field -from itertools import chain -from typing import Optional - -import datasets -from datasets import load_dataset - -import evaluate -import pippy -import transformers -from transformers import ( - CONFIG_MAPPING, - MODEL_FOR_CAUSAL_LM_MAPPING, - AutoConfig, - AutoModelForCausalLM, - AutoTokenizer, - HfArgumentParser, - default_data_collator, - is_torch_tpu_available, - set_seed, -) -from transformers.testing_utils import CaptureLogger -from transformers.trainer_utils import get_last_checkpoint -from transformers.utils import check_min_version, send_example_telemetry -from transformers.utils.versions import require_version - -from pippy import run_pippy -from pippy.hf import PiPPyTrainingArguments, PiPPyTrainer, gpt2, PiPPyHFTracer -from pippy.microbatch import TensorChunkSpec, sum_reducer - -# Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.22.0.dev0") - -require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") - -logger = logging.getLogger(__name__) - - -MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys()) -MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) - - -@dataclass -class ModelArguments: - """ - Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch. - """ - - model_name_or_path: Optional[str] = field( - default=None, - metadata={ - "help": ( - "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch." - ) - }, - ) - model_type: Optional[str] = field( - default=None, - metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)}, - ) - config_overrides: Optional[str] = field( - default=None, - metadata={ - "help": ( - "Override some existing default config settings when a model is trained from scratch. Example: " - "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index" - ) - }, - ) - config_name: Optional[str] = field( - default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} - ) - tokenizer_name: Optional[str] = field( - default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} - ) - cache_dir: Optional[str] = field( - default=None, - metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, - ) - use_fast_tokenizer: bool = field( - default=True, - metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, - ) - model_revision: str = field( - default="main", - metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, - ) - use_auth_token: bool = field( - default=False, - metadata={ - "help": ( - "Will use the token generated when running `transformers-cli login` (necessary to use this script " - "with private models)." - ) - }, - ) - - def __post_init__(self): - if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None): - raise ValueError( - "--config_overrides can't be used in combination with --config_name or --model_name_or_path" - ) - - -@dataclass -class DataTrainingArguments: - """ - Arguments pertaining to what data we are going to input our model for training and eval. - """ - - dataset_name: Optional[str] = field( - default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} - ) - dataset_config_name: Optional[str] = field( - default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} - ) - train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."}) - validation_file: Optional[str] = field( - default=None, - metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."}, - ) - max_train_samples: Optional[int] = field( - default=None, - metadata={ - "help": ( - "For debugging purposes or quicker training, truncate the number of training examples to this " - "value if set." - ) - }, - ) - max_eval_samples: Optional[int] = field( - default=None, - metadata={ - "help": ( - "For debugging purposes or quicker training, truncate the number of evaluation examples to this " - "value if set." - ) - }, - ) - - block_size: Optional[int] = field( - default=None, - metadata={ - "help": ( - "Optional input sequence length after tokenization. " - "The training dataset will be truncated in block of this size for training. " - "Default to the model max input length for single sentence inputs (take into account special tokens)." - ) - }, - ) - overwrite_cache: bool = field( - default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} - ) - validation_split_percentage: Optional[int] = field( - default=5, - metadata={ - "help": "The percentage of the train set used as validation set in case there's no validation split" - }, - ) - preprocessing_num_workers: Optional[int] = field( - default=None, - metadata={"help": "The number of processes to use for the preprocessing."}, - ) - keep_linebreaks: bool = field( - default=True, metadata={"help": "Whether to keep line breaks when using TXT files or not."} - ) - - def __post_init__(self): - if self.dataset_name is None and self.train_file is None and self.validation_file is None: - raise ValueError("Need either a dataset name or a training/validation file.") - else: - if self.train_file is not None: - extension = self.train_file.split(".")[-1] - assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file." - if self.validation_file is not None: - extension = self.validation_file.split(".")[-1] - assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file." - - -def main(): - # See all possible arguments in src/transformers/training_args.py - # or by passing the --help flag to this script. - # We now keep distinct sets of args, for a cleaner separation of concerns. - - # =============================================== PiPPy change start =============================================== - parser = HfArgumentParser((ModelArguments, DataTrainingArguments, PiPPyTrainingArguments)) - # ================================================ PiPPy change end ================================================ - if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): - # If we pass only one argument to the script and it's the path to a json file, - # let's parse it to get our arguments. - model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) - else: - model_args, data_args, training_args = parser.parse_args_into_dataclasses() - - # =============================================== PiPPy change start =============================================== - run_pippy(run_master, training_args, model_args, data_args) - - -def run_master(pp_ranks, training_args, model_args, data_args): - # ================================================ PiPPy change end ================================================ - # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The - # information sent is the one passed as arguments along with your Python/PyTorch versions. - send_example_telemetry("run_clm", model_args, data_args) - - # Setup logging - logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", - datefmt="%m/%d/%Y %H:%M:%S", - handlers=[logging.StreamHandler(sys.stdout)], - ) - - log_level = training_args.get_process_log_level() - logger.setLevel(log_level) - datasets.utils.logging.set_verbosity(log_level) - transformers.utils.logging.set_verbosity(log_level) - transformers.utils.logging.enable_default_handler() - transformers.utils.logging.enable_explicit_format() - - # Log on each process the small summary: - logger.warning( - f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" - + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" - ) - logger.info(f"Training/evaluation parameters {training_args}") - - # Detecting last checkpoint. - last_checkpoint = None - if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: - last_checkpoint = get_last_checkpoint(training_args.output_dir) - if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: - raise ValueError( - f"Output directory ({training_args.output_dir}) already exists and is not empty. " - "Use --overwrite_output_dir to overcome." - ) - elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: - logger.info( - f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " - "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." - ) - - # Set seed before initializing model. - set_seed(training_args.seed) - - # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) - # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ - # (the dataset will be downloaded automatically from the datasets Hub). - # - # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called - # 'text' is found. You can easily tweak this behavior (see below). - # - # In distributed training, the load_dataset function guarantee that only one local process can concurrently - # download the dataset. - if data_args.dataset_name is not None: - # Downloading and loading a dataset from the hub. - raw_datasets = load_dataset( - data_args.dataset_name, - data_args.dataset_config_name, - cache_dir=model_args.cache_dir, - use_auth_token=True if model_args.use_auth_token else None, - ) - if "validation" not in raw_datasets.keys(): - raw_datasets["validation"] = load_dataset( - data_args.dataset_name, - data_args.dataset_config_name, - split=f"train[:{data_args.validation_split_percentage}%]", - cache_dir=model_args.cache_dir, - use_auth_token=True if model_args.use_auth_token else None, - ) - raw_datasets["train"] = load_dataset( - data_args.dataset_name, - data_args.dataset_config_name, - split=f"train[{data_args.validation_split_percentage}%:]", - cache_dir=model_args.cache_dir, - use_auth_token=True if model_args.use_auth_token else None, - ) - else: - data_files = {} - dataset_args = {} - if data_args.train_file is not None: - data_files["train"] = data_args.train_file - if data_args.validation_file is not None: - data_files["validation"] = data_args.validation_file - extension = ( - data_args.train_file.split(".")[-1] - if data_args.train_file is not None - else data_args.validation_file.split(".")[-1] - ) - if extension == "txt": - extension = "text" - dataset_args["keep_linebreaks"] = data_args.keep_linebreaks - raw_datasets = load_dataset( - extension, - data_files=data_files, - cache_dir=model_args.cache_dir, - use_auth_token=True if model_args.use_auth_token else None, - **dataset_args, - ) - # If no validation data is there, validation_split_percentage will be used to divide the dataset. - if "validation" not in raw_datasets.keys(): - raw_datasets["validation"] = load_dataset( - extension, - data_files=data_files, - split=f"train[:{data_args.validation_split_percentage}%]", - cache_dir=model_args.cache_dir, - use_auth_token=True if model_args.use_auth_token else None, - **dataset_args, - ) - raw_datasets["train"] = load_dataset( - extension, - data_files=data_files, - split=f"train[{data_args.validation_split_percentage}%:]", - cache_dir=model_args.cache_dir, - use_auth_token=True if model_args.use_auth_token else None, - **dataset_args, - ) - - # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at - # https://huggingface.co/docs/datasets/loading_datasets.html. - - # Load pretrained model and tokenizer - # - # Distributed training: - # The .from_pretrained methods guarantee that only one local process can concurrently - # download model & vocab. - - config_kwargs = { - "cache_dir": model_args.cache_dir, - "revision": model_args.model_revision, - "use_auth_token": True if model_args.use_auth_token else None, - } - if model_args.config_name: - config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs) - elif model_args.model_name_or_path: - config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs) - else: - config = CONFIG_MAPPING[model_args.model_type]() - logger.warning("You are instantiating a new config instance from scratch.") - if model_args.config_overrides is not None: - logger.info(f"Overriding config: {model_args.config_overrides}") - config.update_from_string(model_args.config_overrides) - logger.info(f"New config: {config}") - - tokenizer_kwargs = { - "cache_dir": model_args.cache_dir, - "use_fast": model_args.use_fast_tokenizer, - "revision": model_args.model_revision, - "use_auth_token": True if model_args.use_auth_token else None, - } - if model_args.tokenizer_name: - tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs) - elif model_args.model_name_or_path: - tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, **tokenizer_kwargs) - else: - raise ValueError( - "You are instantiating a new tokenizer from scratch. This is not supported by this script." - "You can do it from another script, save it, and load it from here, using --tokenizer_name." - ) - - if model_args.model_name_or_path: - model = AutoModelForCausalLM.from_pretrained( - model_args.model_name_or_path, - from_tf=bool(".ckpt" in model_args.model_name_or_path), - config=config, - cache_dir=model_args.cache_dir, - revision=model_args.model_revision, - use_auth_token=True if model_args.use_auth_token else None, - ) - else: - model = AutoModelForCausalLM.from_config(config) - n_params = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values()) - logger.info(f"Training new model from scratch - Total size={n_params/2**20:.2f}M params") - - model.resize_token_embeddings(len(tokenizer)) - - model.to(training_args.device) - - training_args.label_names = ['labels'] # https://github.com/huggingface/transformers/blob/c8b6ae858d61e5bc10e388d095aa74f7690d1021/src/transformers/trainer.py#L629-L630 - - # =============================================== PiPPy change start =============================================== - # Setting model to training mode so that PiPPy would automatically look for "loss" and generate backward pass - model.train() - - logger.info("[PiPPy] Splitting model ...") - num_ranks=len(pp_ranks) - gpt2.split(model, num_ranks) - - # Prepare concrete args in case FX tracing needs them - inputs = ['input_ids', 'labels', 'attention_mask'] - concrete_args = pippy.create_default_args(model, - except_keys=inputs) - output_chunk_spec = {'loss': sum_reducer, 'logits': TensorChunkSpec(0), - 'past_key_values': [[TensorChunkSpec(0) for _ in range(2)] for _ in - range(model.config.n_layer)]} - - # Compile into pipeline parallel, distributed model - pipe_mod = pippy.compile( - model, - num_ranks, - num_chunks=training_args.chunks or num_ranks, - ranks=pp_ranks, - tracer=PiPPyHFTracer(), - output_chunk_spec=output_chunk_spec, - checkpoint=bool(training_args.checkpoint), - concrete_args=concrete_args, - ) - pipe_mod.init_data_parallel(dp_group_size=training_args.dp_group_size) - pipe_mod.config = model.config - model = pipe_mod - - # ================================================ PiPPy change end ================================================ - - # Preprocessing the datasets. - # First we tokenize all the texts. - if training_args.do_train: - column_names = raw_datasets["train"].column_names - else: - column_names = raw_datasets["validation"].column_names - text_column_name = "text" if "text" in column_names else column_names[0] - - # since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function - tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base") - - def tokenize_function(examples): - with CaptureLogger(tok_logger) as cl: - output = tokenizer(examples[text_column_name]) - # clm input could be much much longer than block_size - if "Token indices sequence length is longer than the" in cl.out: - tok_logger.warning( - "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits" - " before being passed to the model." - ) - return output - - with training_args.main_process_first(desc="dataset map tokenization"): - tokenized_datasets = raw_datasets.map( - tokenize_function, - batched=True, - num_proc=data_args.preprocessing_num_workers, - remove_columns=column_names, - load_from_cache_file=not data_args.overwrite_cache, - desc="Running tokenizer on dataset", - ) - - if data_args.block_size is None: - block_size = tokenizer.model_max_length - if block_size > 1024: - logger.warning( - f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). " - "Picking 1024 instead. You can change that default value by passing --block_size xxx." - ) - block_size = 1024 - else: - if data_args.block_size > tokenizer.model_max_length: - logger.warning( - f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model" - f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}." - ) - block_size = min(data_args.block_size, tokenizer.model_max_length) - - # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size. - def group_texts(examples): - # Concatenate all texts. - concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} - total_length = len(concatenated_examples[list(examples.keys())[0]]) - # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can - # customize this part to your needs. - if total_length >= block_size: - total_length = (total_length // block_size) * block_size - # Split by chunks of max_len. - result = { - k: [t[i : i + block_size] for i in range(0, total_length, block_size)] - for k, t in concatenated_examples.items() - } - result["labels"] = result["input_ids"].copy() - return result - - # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder - # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower - # to preprocess. - # - # To speed up this part, we use multiprocessing. See the documentation of the map method for more information: - # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map - - with training_args.main_process_first(desc="grouping texts together"): - lm_datasets = tokenized_datasets.map( - group_texts, - batched=True, - num_proc=data_args.preprocessing_num_workers, - load_from_cache_file=not data_args.overwrite_cache, - desc=f"Grouping texts in chunks of {block_size}", - ) - - if training_args.do_train: - if "train" not in tokenized_datasets: - raise ValueError("--do_train requires a train dataset") - train_dataset = lm_datasets["train"] - if data_args.max_train_samples is not None: - max_train_samples = min(len(train_dataset), data_args.max_train_samples) - train_dataset = train_dataset.select(range(max_train_samples)) - - if training_args.do_eval: - if "validation" not in tokenized_datasets: - raise ValueError("--do_eval requires a validation dataset") - eval_dataset = lm_datasets["validation"] - if data_args.max_eval_samples is not None: - max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) - eval_dataset = eval_dataset.select(range(max_eval_samples)) - - def preprocess_logits_for_metrics(logits, labels): - if isinstance(logits, tuple): - # Depending on the model and config, logits may contain extra tensors, - # like past_key_values, but logits always come first - logits = logits[0] - return logits.argmax(dim=-1) - - metric = evaluate.load("accuracy") - - def compute_metrics(eval_preds): - preds, labels = eval_preds - # preds have the same shape as the labels, after the argmax(-1) has been calculated - # by preprocess_logits_for_metrics but we need to shift the labels - labels = labels[:, 1:].reshape(-1) - preds = preds[:, :-1].reshape(-1) - return metric.compute(predictions=preds, references=labels) - - # Initialize our Trainer - # =============================================== PiPPy change start =============================================== - trainer = PiPPyTrainer( - # ================================================ PiPPy change end ================================================ - model=model, - args=training_args, - train_dataset=train_dataset if training_args.do_train else None, - eval_dataset=eval_dataset if training_args.do_eval else None, - tokenizer=tokenizer, - # Data collator will default to DataCollatorWithPadding, so we change it. - data_collator=default_data_collator, - compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None, - preprocess_logits_for_metrics=preprocess_logits_for_metrics - if training_args.do_eval and not is_torch_tpu_available() - else None, - ) - # =============================================== PiPPy change start =============================================== - trainer._signature_columns = inputs - # ================================================ PiPPy change end ================================================ - - # Training - if training_args.do_train: - checkpoint = None - if training_args.resume_from_checkpoint is not None: - checkpoint = training_args.resume_from_checkpoint - elif last_checkpoint is not None: - checkpoint = last_checkpoint - train_result = trainer.train(resume_from_checkpoint=checkpoint) - # TODO: overwrite save_model method so that it does not pickle process group - #trainer.save_model() # Saves the tokenizer too for easy upload - - metrics = train_result.metrics - - max_train_samples = ( - data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) - ) - metrics["train_samples"] = min(max_train_samples, len(train_dataset)) - - trainer.log_metrics("train", metrics) - trainer.save_metrics("train", metrics) - trainer.save_state() - - # Evaluation - if training_args.do_eval: - logger.info("*** Evaluate ***") - - metrics = trainer.evaluate() - - max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) - metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) - try: - perplexity = math.exp(metrics["eval_loss"]) - except OverflowError: - perplexity = float("inf") - metrics["perplexity"] = perplexity - - trainer.log_metrics("eval", metrics) - trainer.save_metrics("eval", metrics) - - kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-generation"} - if data_args.dataset_name is not None: - kwargs["dataset_tags"] = data_args.dataset_name - if data_args.dataset_config_name is not None: - kwargs["dataset_args"] = data_args.dataset_config_name - kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}" - else: - kwargs["dataset"] = data_args.dataset_name - - if training_args.push_to_hub: - trainer.push_to_hub(**kwargs) - else: - trainer.create_model_card(**kwargs) - - -def _mp_fn(index): - # For xla_spawn (TPUs) - main() - - -if __name__ == "__main__": - main() diff --git a/examples/hf/language-modeling/run_clm_no_trainer.py b/examples/hf/language-modeling/run_clm_no_trainer.py deleted file mode 100755 index 183ac24a0..000000000 --- a/examples/hf/language-modeling/run_clm_no_trainer.py +++ /dev/null @@ -1,658 +0,0 @@ -#!/usr/bin/env python -# coding=utf-8 -# Copyright 2021 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) -on a text file or a dataset without using HuggingFace Trainer. - -Here is the full list of checkpoints on the hub that can be fine-tuned by this script: -https://huggingface.co/models?filter=text-generation -""" -# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments. - -import argparse -import json -import logging -import math -import os -import random -from itertools import chain -from pathlib import Path - -import datasets -import torch -from datasets import load_dataset -from torch.utils.data import DataLoader -from tqdm.auto import tqdm - -import transformers -from accelerate import Accelerator, DistributedType -from accelerate.logging import get_logger -from accelerate.utils import set_seed -from huggingface_hub import Repository -from transformers import ( - CONFIG_MAPPING, - MODEL_MAPPING, - AutoConfig, - AutoModelForCausalLM, - AutoTokenizer, - SchedulerType, - default_data_collator, - get_scheduler, -) -from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry -from transformers.utils.versions import require_version - - -# Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.22.0.dev0") - -logger = get_logger(__name__) - -require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") - -MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys()) -MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) - - -def parse_args(): - parser = argparse.ArgumentParser(description="Finetune a transformers model on a causal language modeling task") - parser.add_argument( - "--dataset_name", - type=str, - default=None, - help="The name of the dataset to use (via the datasets library).", - ) - parser.add_argument( - "--dataset_config_name", - type=str, - default=None, - help="The configuration name of the dataset to use (via the datasets library).", - ) - parser.add_argument( - "--train_file", type=str, default=None, help="A csv or a json file containing the training data." - ) - parser.add_argument( - "--validation_file", type=str, default=None, help="A csv or a json file containing the validation data." - ) - parser.add_argument( - "--validation_split_percentage", - default=5, - help="The percentage of the train set used as validation set in case there's no validation split", - ) - parser.add_argument( - "--model_name_or_path", - type=str, - help="Path to pretrained model or model identifier from huggingface.co/models.", - required=False, - ) - parser.add_argument( - "--config_name", - type=str, - default=None, - help="Pretrained config name or path if not the same as model_name", - ) - parser.add_argument( - "--tokenizer_name", - type=str, - default=None, - help="Pretrained tokenizer name or path if not the same as model_name", - ) - parser.add_argument( - "--use_slow_tokenizer", - action="store_true", - help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).", - ) - parser.add_argument( - "--per_device_train_batch_size", - type=int, - default=8, - help="Batch size (per device) for the training dataloader.", - ) - parser.add_argument( - "--per_device_eval_batch_size", - type=int, - default=8, - help="Batch size (per device) for the evaluation dataloader.", - ) - parser.add_argument( - "--learning_rate", - type=float, - default=5e-5, - help="Initial learning rate (after the potential warmup period) to use.", - ) - parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") - parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.") - parser.add_argument( - "--max_train_steps", - type=int, - default=None, - help="Total number of training steps to perform. If provided, overrides num_train_epochs.", - ) - parser.add_argument( - "--gradient_accumulation_steps", - type=int, - default=1, - help="Number of updates steps to accumulate before performing a backward/update pass.", - ) - parser.add_argument( - "--lr_scheduler_type", - type=SchedulerType, - default="linear", - help="The scheduler type to use.", - choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"], - ) - parser.add_argument( - "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler." - ) - parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.") - parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") - parser.add_argument( - "--model_type", - type=str, - default=None, - help="Model type to use if training from scratch.", - choices=MODEL_TYPES, - ) - parser.add_argument( - "--block_size", - type=int, - default=None, - help=( - "Optional input sequence length after tokenization. The training dataset will be truncated in block of" - " this size for training. Default to the model max input length for single sentence inputs (take into" - " account special tokens)." - ), - ) - parser.add_argument( - "--preprocessing_num_workers", - type=int, - default=None, - help="The number of processes to use for the preprocessing.", - ) - parser.add_argument( - "--overwrite_cache", type=bool, default=False, help="Overwrite the cached training and evaluation sets" - ) - parser.add_argument( - "--no_keep_linebreaks", action="store_true", help="Do not keep line breaks when using TXT files." - ) - parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") - parser.add_argument( - "--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`." - ) - parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.") - parser.add_argument( - "--checkpointing_steps", - type=str, - default=None, - help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.", - ) - parser.add_argument( - "--resume_from_checkpoint", - type=str, - default=None, - help="If the training should continue from a checkpoint folder.", - ) - parser.add_argument( - "--with_tracking", - action="store_true", - help="Whether to enable experiment trackers for logging.", - ) - parser.add_argument( - "--report_to", - type=str, - default="all", - help=( - 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,' - ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.' - "Only applicable when `--with_tracking` is passed." - ), - ) - args = parser.parse_args() - - # Sanity checks - if args.dataset_name is None and args.train_file is None and args.validation_file is None: - raise ValueError("Need either a dataset name or a training/validation file.") - else: - if args.train_file is not None: - extension = args.train_file.split(".")[-1] - assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, json or txt file." - if args.validation_file is not None: - extension = args.validation_file.split(".")[-1] - assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file." - - if args.push_to_hub: - assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed." - - return args - - -def main(): - args = parse_args() - - # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The - # information sent is the one passed as arguments along with your Python/PyTorch versions. - send_example_telemetry("run_clm_no_trainer", args) - - # Initialize the accelerator. We will let the accelerator handle device placement for us in this example. - # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers - # in the environment - accelerator = ( - Accelerator(log_with=args.report_to, logging_dir=args.output_dir) if args.with_tracking else Accelerator() - ) - # Make one log on every process with the configuration for debugging. - logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", - datefmt="%m/%d/%Y %H:%M:%S", - level=logging.INFO, - ) - logger.info(accelerator.state, main_process_only=False) - if accelerator.is_local_main_process: - datasets.utils.logging.set_verbosity_warning() - transformers.utils.logging.set_verbosity_info() - else: - datasets.utils.logging.set_verbosity_error() - transformers.utils.logging.set_verbosity_error() - - # If passed along, set the training seed now. - if args.seed is not None: - set_seed(args.seed) - - # Handle the repository creation - if accelerator.is_main_process: - if args.push_to_hub: - if args.hub_model_id is None: - repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) - else: - repo_name = args.hub_model_id - repo = Repository(args.output_dir, clone_from=repo_name) - - with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: - if "step_*" not in gitignore: - gitignore.write("step_*\n") - if "epoch_*" not in gitignore: - gitignore.write("epoch_*\n") - elif args.output_dir is not None: - os.makedirs(args.output_dir, exist_ok=True) - accelerator.wait_for_everyone() - - # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) - # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ - # (the dataset will be downloaded automatically from the datasets Hub). - # - # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called - # 'text' is found. You can easily tweak this behavior (see below). - # - # In distributed training, the load_dataset function guarantee that only one local process can concurrently - # download the dataset. - if args.dataset_name is not None: - # Downloading and loading a dataset from the hub. - raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name) - if "validation" not in raw_datasets.keys(): - raw_datasets["validation"] = load_dataset( - args.dataset_name, - args.dataset_config_name, - split=f"train[:{args.validation_split_percentage}%]", - ) - raw_datasets["train"] = load_dataset( - args.dataset_name, - args.dataset_config_name, - split=f"train[{args.validation_split_percentage}%:]", - ) - else: - data_files = {} - dataset_args = {} - if args.train_file is not None: - data_files["train"] = args.train_file - if args.validation_file is not None: - data_files["validation"] = args.validation_file - extension = args.train_file.split(".")[-1] - if extension == "txt": - extension = "text" - dataset_args["keep_linebreaks"] = not args.no_keep_linebreaks - raw_datasets = load_dataset(extension, data_files=data_files, **dataset_args) - # If no validation data is there, validation_split_percentage will be used to divide the dataset. - if "validation" not in raw_datasets.keys(): - raw_datasets["validation"] = load_dataset( - extension, - data_files=data_files, - split=f"train[:{args.validation_split_percentage}%]", - **dataset_args, - ) - raw_datasets["train"] = load_dataset( - extension, - data_files=data_files, - split=f"train[{args.validation_split_percentage}%:]", - **dataset_args, - ) - - # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at - # https://huggingface.co/docs/datasets/loading_datasets.html. - - # Load pretrained model and tokenizer - # - # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently - # download model & vocab. - if args.config_name: - config = AutoConfig.from_pretrained(args.config_name) - elif args.model_name_or_path: - config = AutoConfig.from_pretrained(args.model_name_or_path) - else: - config = CONFIG_MAPPING[args.model_type]() - logger.warning("You are instantiating a new config instance from scratch.") - - if args.tokenizer_name: - tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=not args.use_slow_tokenizer) - elif args.model_name_or_path: - tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer) - else: - raise ValueError( - "You are instantiating a new tokenizer from scratch. This is not supported by this script." - "You can do it from another script, save it, and load it from here, using --tokenizer_name." - ) - - if args.model_name_or_path: - model = AutoModelForCausalLM.from_pretrained( - args.model_name_or_path, - from_tf=bool(".ckpt" in args.model_name_or_path), - config=config, - ) - else: - logger.info("Training new model from scratch") - model = AutoModelForCausalLM.from_config(config) - - model.resize_token_embeddings(len(tokenizer)) - - # Preprocessing the datasets. - # First we tokenize all the texts. - column_names = raw_datasets["train"].column_names - text_column_name = "text" if "text" in column_names else column_names[0] - - def tokenize_function(examples): - return tokenizer(examples[text_column_name]) - - with accelerator.main_process_first(): - tokenized_datasets = raw_datasets.map( - tokenize_function, - batched=True, - num_proc=args.preprocessing_num_workers, - remove_columns=column_names, - load_from_cache_file=not args.overwrite_cache, - desc="Running tokenizer on dataset", - ) - - if args.block_size is None: - block_size = tokenizer.model_max_length - if block_size > 1024: - logger.warning( - f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). " - "Picking 1024 instead. You can change that default value by passing --block_size xxx." - ) - block_size = 1024 - else: - if args.block_size > tokenizer.model_max_length: - logger.warning( - f"The block_size passed ({args.block_size}) is larger than the maximum length for the model" - f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}." - ) - block_size = min(args.block_size, tokenizer.model_max_length) - - # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size. - def group_texts(examples): - # Concatenate all texts. - concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} - total_length = len(concatenated_examples[list(examples.keys())[0]]) - # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can - # customize this part to your needs. - if total_length >= block_size: - total_length = (total_length // block_size) * block_size - # Split by chunks of max_len. - result = { - k: [t[i : i + block_size] for i in range(0, total_length, block_size)] - for k, t in concatenated_examples.items() - } - result["labels"] = result["input_ids"].copy() - return result - - # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder - # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower - # to preprocess. - # - # To speed up this part, we use multiprocessing. See the documentation of the map method for more information: - # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map - - with accelerator.main_process_first(): - lm_datasets = tokenized_datasets.map( - group_texts, - batched=True, - num_proc=args.preprocessing_num_workers, - load_from_cache_file=not args.overwrite_cache, - desc=f"Grouping texts in chunks of {block_size}", - ) - - train_dataset = lm_datasets["train"] - eval_dataset = lm_datasets["validation"] - - # Log a few random samples from the training set: - for index in random.sample(range(len(train_dataset)), 3): - logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") - - # DataLoaders creation: - train_dataloader = DataLoader( - train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=args.per_device_train_batch_size - ) - eval_dataloader = DataLoader( - eval_dataset, collate_fn=default_data_collator, batch_size=args.per_device_eval_batch_size - ) - - # Optimizer - # Split weights in two groups, one with weight decay and the other not. - no_decay = ["bias", "LayerNorm.weight"] - optimizer_grouped_parameters = [ - { - "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], - "weight_decay": args.weight_decay, - }, - { - "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], - "weight_decay": 0.0, - }, - ] - optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate) - - # On TPU, the tie weights in our model have been disconnected, so we need to restore the ties. - if accelerator.distributed_type == DistributedType.TPU: - model.tie_weights() - - # Scheduler and math around the number of training steps. - overrode_max_train_steps = False - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - if args.max_train_steps is None: - args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch - overrode_max_train_steps = True - - lr_scheduler = get_scheduler( - name=args.lr_scheduler_type, - optimizer=optimizer, - num_warmup_steps=args.num_warmup_steps, - num_training_steps=args.max_train_steps, - ) - - # Prepare everything with our `accelerator`. - model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare( - model, optimizer, train_dataloader, eval_dataloader, lr_scheduler - ) - - # We need to recalculate our total training steps as the size of the training dataloader may have changed. - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - if overrode_max_train_steps: - args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch - # Afterwards we recalculate our number of training epochs - args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) - - # Figure out how many steps we should save the Accelerator states - if hasattr(args.checkpointing_steps, "isdigit"): - checkpointing_steps = args.checkpointing_steps - if args.checkpointing_steps.isdigit(): - checkpointing_steps = int(args.checkpointing_steps) - else: - checkpointing_steps = None - - # We need to initialize the trackers we use, and also store our configuration. - # We initialize the trackers only on main process because `accelerator.log` - # only logs on main process and we don't want empty logs/runs on other processes. - if args.with_tracking: - if accelerator.is_main_process: - experiment_config = vars(args) - # TensorBoard cannot log Enums, need the raw value - experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value - accelerator.init_trackers("clm_no_trainer", experiment_config) - - # Train! - total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps - - logger.info("***** Running training *****") - logger.info(f" Num examples = {len(train_dataset)}") - logger.info(f" Num Epochs = {args.num_train_epochs}") - logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") - logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") - logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") - logger.info(f" Total optimization steps = {args.max_train_steps}") - # Only show the progress bar once on each machine. - progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) - completed_steps = 0 - starting_epoch = 0 - - # Potentially load in the weights and states from a previous save - if args.resume_from_checkpoint: - if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": - accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}") - accelerator.load_state(args.resume_from_checkpoint) - path = os.path.basename(args.resume_from_checkpoint) - else: - # Get the most recent checkpoint - dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] - dirs.sort(key=os.path.getctime) - path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last - # Extract `epoch_{i}` or `step_{i}` - training_difference = os.path.splitext(path)[0] - - if "epoch" in training_difference: - starting_epoch = int(training_difference.replace("epoch_", "")) + 1 - resume_step = None - else: - resume_step = int(training_difference.replace("step_", "")) - starting_epoch = resume_step // len(train_dataloader) - resume_step -= starting_epoch * len(train_dataloader) - - for epoch in range(starting_epoch, args.num_train_epochs): - model.train() - if args.with_tracking: - total_loss = 0 - for step, batch in enumerate(train_dataloader): - # We need to skip steps until we reach the resumed step - if args.resume_from_checkpoint and epoch == starting_epoch: - if resume_step is not None and step < resume_step: - completed_steps += 1 - continue - outputs = model(**batch) - loss = outputs.loss - # We keep track of the loss at each epoch - if args.with_tracking: - total_loss += loss.detach().float() - loss = loss / args.gradient_accumulation_steps - accelerator.backward(loss) - if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad() - progress_bar.update(1) - completed_steps += 1 - - if isinstance(checkpointing_steps, int): - if completed_steps % checkpointing_steps == 0: - output_dir = f"step_{completed_steps }" - if args.output_dir is not None: - output_dir = os.path.join(args.output_dir, output_dir) - accelerator.save_state(output_dir) - if completed_steps >= args.max_train_steps: - break - - model.eval() - losses = [] - for step, batch in enumerate(eval_dataloader): - with torch.no_grad(): - outputs = model(**batch) - - loss = outputs.loss - losses.append(accelerator.gather(loss.repeat(args.per_device_eval_batch_size))) - - losses = torch.cat(losses) - losses = losses[: len(eval_dataset)] - try: - eval_loss = torch.mean(losses) - perplexity = math.exp(eval_loss) - except OverflowError: - perplexity = float("inf") - - logger.info(f"epoch {epoch}: perplexity: {perplexity} eval_loss: {eval_loss}") - - if args.with_tracking: - accelerator.log( - { - "perplexity": perplexity, - "eval_loss": eval_loss, - "train_loss": total_loss.item() / len(train_dataloader), - "epoch": epoch, - "step": completed_steps, - }, - step=completed_steps, - ) - - if args.push_to_hub and epoch < args.num_train_epochs - 1: - accelerator.wait_for_everyone() - unwrapped_model = accelerator.unwrap_model(model) - unwrapped_model.save_pretrained( - args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save - ) - if accelerator.is_main_process: - tokenizer.save_pretrained(args.output_dir) - repo.push_to_hub( - commit_message=f"Training in progress epoch {epoch}", blocking=False, auto_lfs_prune=True - ) - - if args.checkpointing_steps == "epoch": - output_dir = f"epoch_{epoch}" - if args.output_dir is not None: - output_dir = os.path.join(args.output_dir, output_dir) - accelerator.save_state(output_dir) - - if args.output_dir is not None: - accelerator.wait_for_everyone() - unwrapped_model = accelerator.unwrap_model(model) - unwrapped_model.save_pretrained( - args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save - ) - if accelerator.is_main_process: - tokenizer.save_pretrained(args.output_dir) - if args.push_to_hub: - repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True) - - with open(os.path.join(args.output_dir, "all_results.json"), "w") as f: - json.dump({"perplexity": perplexity}, f) - - -if __name__ == "__main__": - main() diff --git a/examples/hf/language-modeling/run_mlm.py b/examples/hf/language-modeling/run_mlm.py deleted file mode 100755 index ad34fe8d8..000000000 --- a/examples/hf/language-modeling/run_mlm.py +++ /dev/null @@ -1,661 +0,0 @@ -#!/usr/bin/env python -# coding=utf-8 -# Copyright 2020 The HuggingFace Team All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Fine-tuning the library models for masked language modeling (BERT, ALBERT, RoBERTa...) on a text file or a dataset. - -Here is the full list of checkpoints on the hub that can be fine-tuned by this script: -https://huggingface.co/models?filter=fill-mask -""" -# You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments. - -import logging -import math -import os -import sys -from dataclasses import dataclass, field -from itertools import chain -from typing import Optional - -import datasets -from datasets import load_dataset - -import evaluate -import pippy -import transformers -from transformers import ( - CONFIG_MAPPING, - MODEL_FOR_MASKED_LM_MAPPING, - AutoConfig, - AutoModelForMaskedLM, - AutoTokenizer, - DataCollatorForLanguageModeling, - HfArgumentParser, - is_torch_tpu_available, - set_seed, -) -from transformers.trainer_utils import get_last_checkpoint -from transformers.utils import check_min_version, send_example_telemetry -from transformers.utils.versions import require_version - -from pippy import run_pippy -from pippy.hf import PiPPyTrainingArguments, PiPPyTrainer, roberta, PiPPyHFTracer - -# Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.22.0.dev0") - -require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") - -logger = logging.getLogger(__name__) -MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys()) -MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) - - -@dataclass -class ModelArguments: - """ - Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch. - """ - - model_name_or_path: Optional[str] = field( - default=None, - metadata={ - "help": ( - "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch." - ) - }, - ) - model_type: Optional[str] = field( - default=None, - metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)}, - ) - config_overrides: Optional[str] = field( - default=None, - metadata={ - "help": ( - "Override some existing default config settings when a model is trained from scratch. Example: " - "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index" - ) - }, - ) - config_name: Optional[str] = field( - default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} - ) - tokenizer_name: Optional[str] = field( - default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} - ) - cache_dir: Optional[str] = field( - default=None, - metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, - ) - use_fast_tokenizer: bool = field( - default=True, - metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, - ) - model_revision: str = field( - default="main", - metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, - ) - use_auth_token: bool = field( - default=False, - metadata={ - "help": ( - "Will use the token generated when running `transformers-cli login` (necessary to use this script " - "with private models)." - ) - }, - ) - - def __post_init__(self): - if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None): - raise ValueError( - "--config_overrides can't be used in combination with --config_name or --model_name_or_path" - ) - - -@dataclass -class DataTrainingArguments: - """ - Arguments pertaining to what data we are going to input our model for training and eval. - """ - - dataset_name: Optional[str] = field( - default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} - ) - dataset_config_name: Optional[str] = field( - default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} - ) - train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."}) - validation_file: Optional[str] = field( - default=None, - metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."}, - ) - overwrite_cache: bool = field( - default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} - ) - validation_split_percentage: Optional[int] = field( - default=5, - metadata={ - "help": "The percentage of the train set used as validation set in case there's no validation split" - }, - ) - max_seq_length: Optional[int] = field( - default=None, - metadata={ - "help": ( - "The maximum total input sequence length after tokenization. Sequences longer " - "than this will be truncated." - ) - }, - ) - preprocessing_num_workers: Optional[int] = field( - default=None, - metadata={"help": "The number of processes to use for the preprocessing."}, - ) - mlm_probability: float = field( - default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"} - ) - line_by_line: bool = field( - default=False, - metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."}, - ) - pad_to_max_length: bool = field( - default=False, - metadata={ - "help": ( - "Whether to pad all samples to `max_seq_length`. " - "If False, will pad the samples dynamically when batching to the maximum length in the batch." - ) - }, - ) - max_train_samples: Optional[int] = field( - default=None, - metadata={ - "help": ( - "For debugging purposes or quicker training, truncate the number of training examples to this " - "value if set." - ) - }, - ) - max_eval_samples: Optional[int] = field( - default=None, - metadata={ - "help": ( - "For debugging purposes or quicker training, truncate the number of evaluation examples to this " - "value if set." - ) - }, - ) - - def __post_init__(self): - if self.dataset_name is None and self.train_file is None and self.validation_file is None: - raise ValueError("Need either a dataset name or a training/validation file.") - else: - if self.train_file is not None: - extension = self.train_file.split(".")[-1] - if extension not in ["csv", "json", "txt"]: - raise ValueError("`train_file` should be a csv, a json or a txt file.") - if self.validation_file is not None: - extension = self.validation_file.split(".")[-1] - if extension not in ["csv", "json", "txt"]: - raise ValueError("`validation_file` should be a csv, a json or a txt file.") - - -def main(): - # See all possible arguments in src/transformers/training_args.py - # or by passing the --help flag to this script. - # We now keep distinct sets of args, for a cleaner separation of concerns. - - # =============================================== PiPPy change start =============================================== - parser = HfArgumentParser((ModelArguments, DataTrainingArguments, PiPPyTrainingArguments)) - # ================================================ PiPPy change end ================================================ - if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): - # If we pass only one argument to the script and it's the path to a json file, - # let's parse it to get our arguments. - model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) - else: - model_args, data_args, training_args = parser.parse_args_into_dataclasses() - - # =============================================== PiPPy change start =============================================== - run_pippy(run_master, training_args, model_args, data_args) - - -def run_master(pp_ranks, training_args, model_args, data_args): - # ================================================ PiPPy change end ================================================ - # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The - # information sent is the one passed as arguments along with your Python/PyTorch versions. - send_example_telemetry("run_mlm", model_args, data_args) - - # Setup logging - logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", - datefmt="%m/%d/%Y %H:%M:%S", - handlers=[logging.StreamHandler(sys.stdout)], - ) - - log_level = training_args.get_process_log_level() - logger.setLevel(log_level) - datasets.utils.logging.set_verbosity(log_level) - transformers.utils.logging.set_verbosity(log_level) - transformers.utils.logging.enable_default_handler() - transformers.utils.logging.enable_explicit_format() - - # Log on each process the small summary: - logger.warning( - f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" - + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" - ) - # Set the verbosity to info of the Transformers logger (on main process only): - logger.info(f"Training/evaluation parameters {training_args}") - - # Detecting last checkpoint. - last_checkpoint = None - if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: - last_checkpoint = get_last_checkpoint(training_args.output_dir) - if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: - raise ValueError( - f"Output directory ({training_args.output_dir}) already exists and is not empty. " - "Use --overwrite_output_dir to overcome." - ) - elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: - logger.info( - f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " - "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." - ) - - # Set seed before initializing model. - set_seed(training_args.seed) - - # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) - # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ - # (the dataset will be downloaded automatically from the datasets Hub - # - # For CSV/JSON files, this script will use the column called 'text' or the first column. You can easily tweak this - # behavior (see below) - # - # In distributed training, the load_dataset function guarantee that only one local process can concurrently - # download the dataset. - if data_args.dataset_name is not None: - # Downloading and loading a dataset from the hub. - raw_datasets = load_dataset( - data_args.dataset_name, - data_args.dataset_config_name, - cache_dir=model_args.cache_dir, - use_auth_token=True if model_args.use_auth_token else None, - ) - if "validation" not in raw_datasets.keys(): - raw_datasets["validation"] = load_dataset( - data_args.dataset_name, - data_args.dataset_config_name, - split=f"train[:{data_args.validation_split_percentage}%]", - cache_dir=model_args.cache_dir, - use_auth_token=True if model_args.use_auth_token else None, - ) - raw_datasets["train"] = load_dataset( - data_args.dataset_name, - data_args.dataset_config_name, - split=f"train[{data_args.validation_split_percentage}%:]", - cache_dir=model_args.cache_dir, - use_auth_token=True if model_args.use_auth_token else None, - ) - else: - data_files = {} - if data_args.train_file is not None: - data_files["train"] = data_args.train_file - extension = data_args.train_file.split(".")[-1] - if data_args.validation_file is not None: - data_files["validation"] = data_args.validation_file - extension = data_args.validation_file.split(".")[-1] - if extension == "txt": - extension = "text" - raw_datasets = load_dataset( - extension, - data_files=data_files, - cache_dir=model_args.cache_dir, - use_auth_token=True if model_args.use_auth_token else None, - ) - - # If no validation data is there, validation_split_percentage will be used to divide the dataset. - if "validation" not in raw_datasets.keys(): - raw_datasets["validation"] = load_dataset( - extension, - data_files=data_files, - split=f"train[:{data_args.validation_split_percentage}%]", - cache_dir=model_args.cache_dir, - use_auth_token=True if model_args.use_auth_token else None, - ) - raw_datasets["train"] = load_dataset( - extension, - data_files=data_files, - split=f"train[{data_args.validation_split_percentage}%:]", - cache_dir=model_args.cache_dir, - use_auth_token=True if model_args.use_auth_token else None, - ) - - # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at - # https://huggingface.co/docs/datasets/loading_datasets.html. - - # Load pretrained model and tokenizer - # - # Distributed training: - # The .from_pretrained methods guarantee that only one local process can concurrently - # download model & vocab. - config_kwargs = { - "cache_dir": model_args.cache_dir, - "revision": model_args.model_revision, - "use_auth_token": True if model_args.use_auth_token else None, - } - if model_args.config_name: - config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs) - elif model_args.model_name_or_path: - config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs) - else: - config = CONFIG_MAPPING[model_args.model_type]() - logger.warning("You are instantiating a new config instance from scratch.") - if model_args.config_overrides is not None: - logger.info(f"Overriding config: {model_args.config_overrides}") - config.update_from_string(model_args.config_overrides) - logger.info(f"New config: {config}") - - tokenizer_kwargs = { - "cache_dir": model_args.cache_dir, - "use_fast": model_args.use_fast_tokenizer, - "revision": model_args.model_revision, - "use_auth_token": True if model_args.use_auth_token else None, - } - if model_args.tokenizer_name: - tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs) - elif model_args.model_name_or_path: - tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, **tokenizer_kwargs) - else: - raise ValueError( - "You are instantiating a new tokenizer from scratch. This is not supported by this script." - "You can do it from another script, save it, and load it from here, using --tokenizer_name." - ) - - if model_args.model_name_or_path: - model = AutoModelForMaskedLM.from_pretrained( - model_args.model_name_or_path, - from_tf=bool(".ckpt" in model_args.model_name_or_path), - config=config, - cache_dir=model_args.cache_dir, - revision=model_args.model_revision, - use_auth_token=True if model_args.use_auth_token else None, - ) - else: - logger.info("Training new model from scratch") - model = AutoModelForMaskedLM.from_config(config) - - model.resize_token_embeddings(len(tokenizer)) - - model.to(training_args.device) - - training_args.label_names = ['labels'] # https://github.com/huggingface/transformers/blob/c8b6ae858d61e5bc10e388d095aa74f7690d1021/src/transformers/trainer.py#L629-L630 - - # =============================================== PiPPy change start =============================================== - # Setting model to training mode so that PiPPy would automatically look for "loss" and generate backward pass - model.train() - - logger.info("[PiPPy] Splitting model ...") - num_ranks=len(pp_ranks) - roberta.split(model, num_ranks) - - # Prepare concrete args in case FX tracing needs them - inputs = ['input_ids', 'labels', 'attention_mask'] - concrete_args = pippy.create_default_args(model, - except_keys=inputs) - - # Compile into pipeline parallel, distributed model - pipe_mod = pippy.compile( - model, - num_ranks, - num_chunks=training_args.chunks or num_ranks, - ranks=pp_ranks, - tracer=PiPPyHFTracer(), - checkpoint=bool(training_args.checkpoint), - concrete_args=concrete_args, - ) - pipe_mod.init_data_parallel(dp_group_size=training_args.dp_group_size) - pipe_mod.config = model.config - model = pipe_mod - - # ================================================ PiPPy change end ================================================ - - # Preprocessing the datasets. - # First we tokenize all the texts. - if training_args.do_train: - column_names = raw_datasets["train"].column_names - else: - column_names = raw_datasets["validation"].column_names - text_column_name = "text" if "text" in column_names else column_names[0] - - if data_args.max_seq_length is None: - max_seq_length = tokenizer.model_max_length - if max_seq_length > 1024: - logger.warning( - f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). " - "Picking 1024 instead. You can change that default value by passing --max_seq_length xxx." - ) - max_seq_length = 1024 - else: - if data_args.max_seq_length > tokenizer.model_max_length: - logger.warning( - f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the" - f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." - ) - max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) - - if data_args.line_by_line: - # When using line_by_line, we just tokenize each nonempty line. - padding = "max_length" if data_args.pad_to_max_length else False - - def tokenize_function(examples): - # Remove empty lines - examples[text_column_name] = [ - line for line in examples[text_column_name] if len(line) > 0 and not line.isspace() - ] - return tokenizer( - examples[text_column_name], - padding=padding, - truncation=True, - max_length=max_seq_length, - # We use this option because DataCollatorForLanguageModeling (see below) is more efficient when it - # receives the `special_tokens_mask`. - return_special_tokens_mask=True, - ) - - with training_args.main_process_first(desc="dataset map tokenization"): - tokenized_datasets = raw_datasets.map( - tokenize_function, - batched=True, - num_proc=data_args.preprocessing_num_workers, - remove_columns=[text_column_name], - load_from_cache_file=not data_args.overwrite_cache, - desc="Running tokenizer on dataset line_by_line", - ) - else: - # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts. - # We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more - # efficient when it receives the `special_tokens_mask`. - def tokenize_function(examples): - return tokenizer(examples[text_column_name], return_special_tokens_mask=True) - - with training_args.main_process_first(desc="dataset map tokenization"): - tokenized_datasets = raw_datasets.map( - tokenize_function, - batched=True, - num_proc=data_args.preprocessing_num_workers, - remove_columns=column_names, - load_from_cache_file=not data_args.overwrite_cache, - desc="Running tokenizer on every text in dataset", - ) - - # Main data processing function that will concatenate all texts from our dataset and generate chunks of - # max_seq_length. - def group_texts(examples): - # Concatenate all texts. - concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} - total_length = len(concatenated_examples[list(examples.keys())[0]]) - # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can - # customize this part to your needs. - if total_length >= max_seq_length: - total_length = (total_length // max_seq_length) * max_seq_length - # Split by chunks of max_len. - result = { - k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)] - for k, t in concatenated_examples.items() - } - return result - - # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a - # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value - # might be slower to preprocess. - # - # To speed up this part, we use multiprocessing. See the documentation of the map method for more information: - # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map - - with training_args.main_process_first(desc="grouping texts together"): - tokenized_datasets = tokenized_datasets.map( - group_texts, - batched=True, - num_proc=data_args.preprocessing_num_workers, - load_from_cache_file=not data_args.overwrite_cache, - desc=f"Grouping texts in chunks of {max_seq_length}", - ) - - if training_args.do_train: - if "train" not in tokenized_datasets: - raise ValueError("--do_train requires a train dataset") - train_dataset = tokenized_datasets["train"] - if data_args.max_train_samples is not None: - max_train_samples = min(len(train_dataset), data_args.max_train_samples) - train_dataset = train_dataset.select(range(max_train_samples)) - - if training_args.do_eval: - if "validation" not in tokenized_datasets: - raise ValueError("--do_eval requires a validation dataset") - eval_dataset = tokenized_datasets["validation"] - if data_args.max_eval_samples is not None: - max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) - eval_dataset = eval_dataset.select(range(max_eval_samples)) - - def preprocess_logits_for_metrics(logits, labels): - if isinstance(logits, tuple): - # Depending on the model and config, logits may contain extra tensors, - # like past_key_values, but logits always come first - logits = logits[0] - return logits.argmax(dim=-1) - - metric = evaluate.load("accuracy") - - def compute_metrics(eval_preds): - preds, labels = eval_preds - # preds have the same shape as the labels, after the argmax(-1) has been calculated - # by preprocess_logits_for_metrics - labels = labels.reshape(-1) - preds = preds.reshape(-1) - mask = labels != -100 - labels = labels[mask] - preds = preds[mask] - return metric.compute(predictions=preds, references=labels) - - # Data collator - # This one will take care of randomly masking the tokens. - pad_to_multiple_of_8 = data_args.line_by_line and training_args.fp16 and not data_args.pad_to_max_length - data_collator = DataCollatorForLanguageModeling( - tokenizer=tokenizer, - mlm_probability=data_args.mlm_probability, - pad_to_multiple_of=8 if pad_to_multiple_of_8 else None, - ) - - # Initialize our Trainer - # =============================================== PiPPy change start =============================================== - trainer = PiPPyTrainer( - # ================================================ PiPPy change end ================================================ - model=model, - args=training_args, - train_dataset=train_dataset if training_args.do_train else None, - eval_dataset=eval_dataset if training_args.do_eval else None, - tokenizer=tokenizer, - data_collator=data_collator, - compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None, - preprocess_logits_for_metrics=preprocess_logits_for_metrics - if training_args.do_eval and not is_torch_tpu_available() - else None, - ) - # =============================================== PiPPy change start =============================================== - trainer._signature_columns = inputs - # ================================================ PiPPy change end ================================================ - - # Training - if training_args.do_train: - checkpoint = None - if training_args.resume_from_checkpoint is not None: - checkpoint = training_args.resume_from_checkpoint - elif last_checkpoint is not None: - checkpoint = last_checkpoint - train_result = trainer.train(resume_from_checkpoint=checkpoint) - # TODO: overwrite save_model method so that it does not pickle process group - #trainer.save_model() # Saves the tokenizer too for easy upload - metrics = train_result.metrics - - max_train_samples = ( - data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) - ) - metrics["train_samples"] = min(max_train_samples, len(train_dataset)) - - trainer.log_metrics("train", metrics) - trainer.save_metrics("train", metrics) - trainer.save_state() - - # Evaluation - if training_args.do_eval: - logger.info("*** Evaluate ***") - - metrics = trainer.evaluate() - - max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) - metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) - try: - perplexity = math.exp(metrics["eval_loss"]) - except OverflowError: - perplexity = float("inf") - metrics["perplexity"] = perplexity - - trainer.log_metrics("eval", metrics) - trainer.save_metrics("eval", metrics) - - kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "fill-mask"} - if data_args.dataset_name is not None: - kwargs["dataset_tags"] = data_args.dataset_name - if data_args.dataset_config_name is not None: - kwargs["dataset_args"] = data_args.dataset_config_name - kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}" - else: - kwargs["dataset"] = data_args.dataset_name - - if training_args.push_to_hub: - trainer.push_to_hub(**kwargs) - else: - trainer.create_model_card(**kwargs) - - -def _mp_fn(index): - # For xla_spawn (TPUs) - main() - - -if __name__ == "__main__": - main() diff --git a/examples/hf/language-modeling/run_mlm_no_trainer.py b/examples/hf/language-modeling/run_mlm_no_trainer.py deleted file mode 100755 index 07299d47e..000000000 --- a/examples/hf/language-modeling/run_mlm_no_trainer.py +++ /dev/null @@ -1,703 +0,0 @@ -#!/usr/bin/env python -# coding=utf-8 -# Copyright 2021 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Fine-tuning the library models for masked language modeling (BERT, ALBERT, RoBERTa...) -on a text file or a dataset without using HuggingFace Trainer. - -Here is the full list of checkpoints on the hub that can be fine-tuned by this script: -https://huggingface.co/models?filter=fill-mask -""" -# You can also adapt this script on your own mlm task. Pointers for this are left as comments. - -import argparse -import json -import logging -import math -import os -import random -from itertools import chain -from pathlib import Path - -import datasets -import torch -from datasets import load_dataset -from torch.utils.data import DataLoader -from tqdm.auto import tqdm - -import transformers -from accelerate import Accelerator, DistributedType -from accelerate.logging import get_logger -from accelerate.utils import set_seed -from huggingface_hub import Repository -from transformers import ( - CONFIG_MAPPING, - MODEL_MAPPING, - AutoConfig, - AutoModelForMaskedLM, - AutoTokenizer, - DataCollatorForLanguageModeling, - SchedulerType, - get_scheduler, -) -from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry -from transformers.utils.versions import require_version - - -# Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.22.0.dev0") - -logger = get_logger(__name__) -require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") -MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys()) -MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) - - -def parse_args(): - parser = argparse.ArgumentParser(description="Finetune a transformers model on a Masked Language Modeling task") - parser.add_argument( - "--dataset_name", - type=str, - default=None, - help="The name of the dataset to use (via the datasets library).", - ) - parser.add_argument( - "--dataset_config_name", - type=str, - default=None, - help="The configuration name of the dataset to use (via the datasets library).", - ) - parser.add_argument( - "--train_file", type=str, default=None, help="A csv or a json file containing the training data." - ) - parser.add_argument( - "--validation_file", type=str, default=None, help="A csv or a json file containing the validation data." - ) - parser.add_argument( - "--validation_split_percentage", - default=5, - help="The percentage of the train set used as validation set in case there's no validation split", - ) - parser.add_argument( - "--pad_to_max_length", - action="store_true", - help="If passed, pad all samples to `max_length`. Otherwise, dynamic padding is used.", - ) - parser.add_argument( - "--model_name_or_path", - type=str, - help="Path to pretrained model or model identifier from huggingface.co/models.", - required=False, - ) - parser.add_argument( - "--config_name", - type=str, - default=None, - help="Pretrained config name or path if not the same as model_name", - ) - parser.add_argument( - "--tokenizer_name", - type=str, - default=None, - help="Pretrained tokenizer name or path if not the same as model_name", - ) - parser.add_argument( - "--use_slow_tokenizer", - action="store_true", - help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).", - ) - parser.add_argument( - "--per_device_train_batch_size", - type=int, - default=8, - help="Batch size (per device) for the training dataloader.", - ) - parser.add_argument( - "--per_device_eval_batch_size", - type=int, - default=8, - help="Batch size (per device) for the evaluation dataloader.", - ) - parser.add_argument( - "--learning_rate", - type=float, - default=5e-5, - help="Initial learning rate (after the potential warmup period) to use.", - ) - parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") - parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.") - parser.add_argument( - "--max_train_steps", - type=int, - default=None, - help="Total number of training steps to perform. If provided, overrides num_train_epochs.", - ) - parser.add_argument( - "--gradient_accumulation_steps", - type=int, - default=1, - help="Number of updates steps to accumulate before performing a backward/update pass.", - ) - parser.add_argument( - "--lr_scheduler_type", - type=SchedulerType, - default="linear", - help="The scheduler type to use.", - choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"], - ) - parser.add_argument( - "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler." - ) - parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.") - parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") - parser.add_argument( - "--model_type", - type=str, - default=None, - help="Model type to use if training from scratch.", - choices=MODEL_TYPES, - ) - parser.add_argument( - "--max_seq_length", - type=int, - default=None, - help=( - "The maximum total input sequence length after tokenization. Sequences longer than this will be truncated." - ), - ) - parser.add_argument( - "--line_by_line", - type=bool, - default=False, - help="Whether distinct lines of text in the dataset are to be handled as distinct sequences.", - ) - parser.add_argument( - "--preprocessing_num_workers", - type=int, - default=None, - help="The number of processes to use for the preprocessing.", - ) - parser.add_argument( - "--overwrite_cache", type=bool, default=False, help="Overwrite the cached training and evaluation sets" - ) - parser.add_argument( - "--mlm_probability", type=float, default=0.15, help="Ratio of tokens to mask for masked language modeling loss" - ) - parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") - parser.add_argument( - "--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`." - ) - parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.") - parser.add_argument( - "--checkpointing_steps", - type=str, - default=None, - help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.", - ) - parser.add_argument( - "--resume_from_checkpoint", - type=str, - default=None, - help="If the training should continue from a checkpoint folder.", - ) - parser.add_argument( - "--with_tracking", - action="store_true", - help="Whether to enable experiment trackers for logging.", - ) - parser.add_argument( - "--report_to", - type=str, - default="all", - help=( - 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,' - ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.' - "Only applicable when `--with_tracking` is passed." - ), - ) - args = parser.parse_args() - - # Sanity checks - if args.dataset_name is None and args.train_file is None and args.validation_file is None: - raise ValueError("Need either a dataset name or a training/validation file.") - else: - if args.train_file is not None: - extension = args.train_file.split(".")[-1] - if extension not in ["csv", "json", "txt"]: - raise ValueError("`train_file` should be a csv, json or txt file.") - if args.validation_file is not None: - extension = args.validation_file.split(".")[-1] - if extension not in ["csv", "json", "txt"]: - raise ValueError("`validation_file` should be a csv, json or txt file.") - - if args.push_to_hub: - assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed." - - return args - - -def main(): - args = parse_args() - - # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The - # information sent is the one passed as arguments along with your Python/PyTorch versions. - send_example_telemetry("run_mlm_no_trainer", args) - - # Initialize the accelerator. We will let the accelerator handle device placement for us in this example. - # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers - # in the environment - accelerator = ( - Accelerator(log_with=args.report_to, logging_dir=args.output_dir) if args.with_tracking else Accelerator() - ) - # Make one log on every process with the configuration for debugging. - logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", - datefmt="%m/%d/%Y %H:%M:%S", - level=logging.INFO, - ) - logger.info(accelerator.state, main_process_only=False) - if accelerator.is_local_main_process: - datasets.utils.logging.set_verbosity_warning() - transformers.utils.logging.set_verbosity_info() - else: - datasets.utils.logging.set_verbosity_error() - transformers.utils.logging.set_verbosity_error() - - # If passed along, set the training seed now. - if args.seed is not None: - set_seed(args.seed) - - # Handle the repository creation - if accelerator.is_main_process: - if args.push_to_hub: - if args.hub_model_id is None: - repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) - else: - repo_name = args.hub_model_id - repo = Repository(args.output_dir, clone_from=repo_name) - - with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: - if "step_*" not in gitignore: - gitignore.write("step_*\n") - if "epoch_*" not in gitignore: - gitignore.write("epoch_*\n") - elif args.output_dir is not None: - os.makedirs(args.output_dir, exist_ok=True) - accelerator.wait_for_everyone() - - # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) - # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ - # (the dataset will be downloaded automatically from the datasets Hub). - # - # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called - # 'text' is found. You can easily tweak this behavior (see below). - # - # In distributed training, the load_dataset function guarantee that only one local process can concurrently - # download the dataset. - if args.dataset_name is not None: - # Downloading and loading a dataset from the hub. - raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name) - if "validation" not in raw_datasets.keys(): - raw_datasets["validation"] = load_dataset( - args.dataset_name, - args.dataset_config_name, - split=f"train[:{args.validation_split_percentage}%]", - ) - raw_datasets["train"] = load_dataset( - args.dataset_name, - args.dataset_config_name, - split=f"train[{args.validation_split_percentage}%:]", - ) - else: - data_files = {} - if args.train_file is not None: - data_files["train"] = args.train_file - if args.validation_file is not None: - data_files["validation"] = args.validation_file - extension = args.train_file.split(".")[-1] - if extension == "txt": - extension = "text" - raw_datasets = load_dataset(extension, data_files=data_files) - # If no validation data is there, validation_split_percentage will be used to divide the dataset. - if "validation" not in raw_datasets.keys(): - raw_datasets["validation"] = load_dataset( - extension, - data_files=data_files, - split=f"train[:{args.validation_split_percentage}%]", - ) - raw_datasets["train"] = load_dataset( - extension, - data_files=data_files, - split=f"train[{args.validation_split_percentage}%:]", - ) - - # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at - # https://huggingface.co/docs/datasets/loading_datasets.html. - - # Load pretrained model and tokenizer - # - # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently - # download model & vocab. - if args.config_name: - config = AutoConfig.from_pretrained(args.config_name) - elif args.model_name_or_path: - config = AutoConfig.from_pretrained(args.model_name_or_path) - else: - config = CONFIG_MAPPING[args.model_type]() - logger.warning("You are instantiating a new config instance from scratch.") - - if args.tokenizer_name: - tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=not args.use_slow_tokenizer) - elif args.model_name_or_path: - tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer) - else: - raise ValueError( - "You are instantiating a new tokenizer from scratch. This is not supported by this script." - "You can do it from another script, save it, and load it from here, using --tokenizer_name." - ) - - if args.model_name_or_path: - model = AutoModelForMaskedLM.from_pretrained( - args.model_name_or_path, - from_tf=bool(".ckpt" in args.model_name_or_path), - config=config, - ) - else: - logger.info("Training new model from scratch") - model = AutoModelForMaskedLM.from_config(config) - - model.resize_token_embeddings(len(tokenizer)) - - # Preprocessing the datasets. - # First we tokenize all the texts. - column_names = raw_datasets["train"].column_names - text_column_name = "text" if "text" in column_names else column_names[0] - - if args.max_seq_length is None: - max_seq_length = tokenizer.model_max_length - if max_seq_length > 1024: - logger.warning( - f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). " - "Picking 1024 instead. You can change that default value by passing --max_seq_length xxx." - ) - max_seq_length = 1024 - else: - if args.max_seq_length > tokenizer.model_max_length: - logger.warning( - f"The max_seq_length passed ({args.max_seq_length}) is larger than the maximum length for the" - f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." - ) - max_seq_length = min(args.max_seq_length, tokenizer.model_max_length) - - if args.line_by_line: - # When using line_by_line, we just tokenize each nonempty line. - padding = "max_length" if args.pad_to_max_length else False - - def tokenize_function(examples): - # Remove empty lines - examples[text_column_name] = [ - line for line in examples[text_column_name] if len(line) > 0 and not line.isspace() - ] - return tokenizer( - examples[text_column_name], - padding=padding, - truncation=True, - max_length=max_seq_length, - # We use this option because DataCollatorForLanguageModeling (see below) is more efficient when it - # receives the `special_tokens_mask`. - return_special_tokens_mask=True, - ) - - with accelerator.main_process_first(): - tokenized_datasets = raw_datasets.map( - tokenize_function, - batched=True, - num_proc=args.preprocessing_num_workers, - remove_columns=[text_column_name], - load_from_cache_file=not args.overwrite_cache, - desc="Running tokenizer on dataset line_by_line", - ) - else: - # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts. - # We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more - # efficient when it receives the `special_tokens_mask`. - def tokenize_function(examples): - return tokenizer(examples[text_column_name], return_special_tokens_mask=True) - - with accelerator.main_process_first(): - tokenized_datasets = raw_datasets.map( - tokenize_function, - batched=True, - num_proc=args.preprocessing_num_workers, - remove_columns=column_names, - load_from_cache_file=not args.overwrite_cache, - desc="Running tokenizer on every text in dataset", - ) - - # Main data processing function that will concatenate all texts from our dataset and generate chunks of - # max_seq_length. - def group_texts(examples): - # Concatenate all texts. - concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} - total_length = len(concatenated_examples[list(examples.keys())[0]]) - # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can - # customize this part to your needs. - if total_length >= max_seq_length: - total_length = (total_length // max_seq_length) * max_seq_length - # Split by chunks of max_len. - result = { - k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)] - for k, t in concatenated_examples.items() - } - return result - - # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a - # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value - # might be slower to preprocess. - # - # To speed up this part, we use multiprocessing. See the documentation of the map method for more information: - # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map - - with accelerator.main_process_first(): - tokenized_datasets = tokenized_datasets.map( - group_texts, - batched=True, - num_proc=args.preprocessing_num_workers, - load_from_cache_file=not args.overwrite_cache, - desc=f"Grouping texts in chunks of {max_seq_length}", - ) - - train_dataset = tokenized_datasets["train"] - eval_dataset = tokenized_datasets["validation"] - - # Conditional for small test subsets - if len(train_dataset) > 3: - # Log a few random samples from the training set: - for index in random.sample(range(len(train_dataset)), 3): - logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") - - # Data collator - # This one will take care of randomly masking the tokens. - data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=args.mlm_probability) - - # DataLoaders creation: - train_dataloader = DataLoader( - train_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.per_device_train_batch_size - ) - eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size) - - # Optimizer - # Split weights in two groups, one with weight decay and the other not. - no_decay = ["bias", "LayerNorm.weight"] - optimizer_grouped_parameters = [ - { - "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], - "weight_decay": args.weight_decay, - }, - { - "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], - "weight_decay": 0.0, - }, - ] - optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate) - - # On TPU, the tie weights in our model have been disconnected, so we need to restore the ties. - if accelerator.distributed_type == DistributedType.TPU: - model.tie_weights() - - # Note -> the training dataloader needs to be prepared before we grab his length below (cause its length will be - # shorter in multiprocess) - - # Scheduler and math around the number of training steps. - overrode_max_train_steps = False - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - if args.max_train_steps is None: - args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch - overrode_max_train_steps = True - - lr_scheduler = get_scheduler( - name=args.lr_scheduler_type, - optimizer=optimizer, - num_warmup_steps=args.num_warmup_steps, - num_training_steps=args.max_train_steps, - ) - - # Prepare everything with our `accelerator`. - model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare( - model, optimizer, train_dataloader, eval_dataloader, lr_scheduler - ) - - # We need to recalculate our total training steps as the size of the training dataloader may have changed. - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - if overrode_max_train_steps: - args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch - # Afterwards we recalculate our number of training epochs - args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) - - # Figure out how many steps we should save the Accelerator states - if hasattr(args.checkpointing_steps, "isdigit"): - checkpointing_steps = args.checkpointing_steps - if args.checkpointing_steps.isdigit(): - checkpointing_steps = int(args.checkpointing_steps) - else: - checkpointing_steps = None - - # We need to initialize the trackers we use, and also store our configuration. - # We initialize the trackers only on main process because `accelerator.log` - # only logs on main process and we don't want empty logs/runs on other processes. - if args.with_tracking: - if accelerator.is_main_process: - experiment_config = vars(args) - # TensorBoard cannot log Enums, need the raw value - experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value - accelerator.init_trackers("mlm_no_trainer", experiment_config) - - # Train! - total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps - - logger.info("***** Running training *****") - logger.info(f" Num examples = {len(train_dataset)}") - logger.info(f" Num Epochs = {args.num_train_epochs}") - logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") - logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") - logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") - logger.info(f" Total optimization steps = {args.max_train_steps}") - # Only show the progress bar once on each machine. - progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) - completed_steps = 0 - starting_epoch = 0 - - # Potentially load in the weights and states from a previous save - if args.resume_from_checkpoint: - if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": - accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}") - accelerator.load_state(args.resume_from_checkpoint) - path = os.path.basename(args.resume_from_checkpoint) - else: - # Get the most recent checkpoint - dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] - dirs.sort(key=os.path.getctime) - path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last - # Extract `epoch_{i}` or `step_{i}` - training_difference = os.path.splitext(path)[0] - - if "epoch" in training_difference: - starting_epoch = int(training_difference.replace("epoch_", "")) + 1 - resume_step = None - else: - resume_step = int(training_difference.replace("step_", "")) - starting_epoch = resume_step // len(train_dataloader) - resume_step -= starting_epoch * len(train_dataloader) - - for epoch in range(starting_epoch, args.num_train_epochs): - model.train() - if args.with_tracking: - total_loss = 0 - for step, batch in enumerate(train_dataloader): - # We need to skip steps until we reach the resumed step - if args.resume_from_checkpoint and epoch == starting_epoch: - if resume_step is not None and step < resume_step: - completed_steps += 1 - continue - outputs = model(**batch) - loss = outputs.loss - # We keep track of the loss at each epoch - if args.with_tracking: - total_loss += loss.detach().float() - loss = loss / args.gradient_accumulation_steps - accelerator.backward(loss) - if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad() - progress_bar.update(1) - completed_steps += 1 - - if isinstance(checkpointing_steps, int): - if completed_steps % checkpointing_steps == 0: - output_dir = f"step_{completed_steps }" - if args.output_dir is not None: - output_dir = os.path.join(args.output_dir, output_dir) - accelerator.save_state(output_dir) - - if completed_steps >= args.max_train_steps: - break - - model.eval() - losses = [] - for step, batch in enumerate(eval_dataloader): - with torch.no_grad(): - outputs = model(**batch) - - loss = outputs.loss - losses.append(accelerator.gather(loss.repeat(args.per_device_eval_batch_size))) - - losses = torch.cat(losses) - losses = losses[: len(eval_dataset)] - try: - eval_loss = torch.mean(losses) - perplexity = math.exp(eval_loss) - except OverflowError: - perplexity = float("inf") - - logger.info(f"epoch {epoch}: perplexity: {perplexity}") - - if args.with_tracking: - accelerator.log( - { - "perplexity": perplexity, - "eval_loss": eval_loss, - "train_loss": total_loss.item() / len(train_dataloader), - "epoch": epoch, - "step": completed_steps, - }, - step=completed_steps, - ) - - if args.push_to_hub and epoch < args.num_train_epochs - 1: - accelerator.wait_for_everyone() - unwrapped_model = accelerator.unwrap_model(model) - unwrapped_model.save_pretrained( - args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save - ) - if accelerator.is_main_process: - tokenizer.save_pretrained(args.output_dir) - repo.push_to_hub( - commit_message=f"Training in progress epoch {epoch}", blocking=False, auto_lfs_prune=True - ) - - if args.checkpointing_steps == "epoch": - output_dir = f"epoch_{epoch}" - if args.output_dir is not None: - output_dir = os.path.join(args.output_dir, output_dir) - accelerator.save_state(output_dir) - - if args.output_dir is not None: - accelerator.wait_for_everyone() - unwrapped_model = accelerator.unwrap_model(model) - unwrapped_model.save_pretrained( - args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save - ) - if accelerator.is_main_process: - tokenizer.save_pretrained(args.output_dir) - if args.push_to_hub: - repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True) - - with open(os.path.join(args.output_dir, "all_results.json"), "w") as f: - json.dump({"perplexity": perplexity}, f) - - -if __name__ == "__main__": - main() diff --git a/examples/hf/text-classification/README.md b/examples/hf/text-classification/README.md deleted file mode 100644 index 391aaf4d3..000000000 --- a/examples/hf/text-classification/README.md +++ /dev/null @@ -1,203 +0,0 @@ - - -# Text classification examples - -## GLUE tasks - -Based on the script [`run_glue.py`](https://github.com/huggingface/transformers/blob/main/examples/pytorch/text-classification/run_glue.py). - -Fine-tuning the library models for sequence classification on the GLUE benchmark: [General Language Understanding -Evaluation](https://gluebenchmark.com/). This script can fine-tune any of the models on the [hub](https://huggingface.co/models) -and can also be used for a dataset hosted on our [hub](https://huggingface.co/datasets) or your own data in a csv or a JSON file -(the script might need some tweaks in that case, refer to the comments inside for help). - -GLUE is made up of a total of 9 different tasks. Here is how to run the script on one of them: - -```bash -export TASK_NAME=mrpc - -python run_glue.py \ - --model_name_or_path bert-base-cased \ - --task_name $TASK_NAME \ - --do_train \ - --do_eval \ - --max_seq_length 128 \ - --per_device_train_batch_size 32 \ - --learning_rate 2e-5 \ - --num_train_epochs 3 \ - --output_dir /tmp/$TASK_NAME/ -``` - -where task name can be one of cola, sst2, mrpc, stsb, qqp, mnli, qnli, rte, wnli. - -We get the following results on the dev set of the benchmark with the previous commands (with an exception for MRPC and -WNLI which are tiny and where we used 5 epochs instead of 3). Trainings are seeded so you should obtain the same -results with PyTorch 1.6.0 (and close results with different versions), training times are given for information (a -single Titan RTX was used): - -| Task | Metric | Result | Training time | -|-------|------------------------------|-------------|---------------| -| CoLA | Matthews corr | 56.53 | 3:17 | -| SST-2 | Accuracy | 92.32 | 26:06 | -| MRPC | F1/Accuracy | 88.85/84.07 | 2:21 | -| STS-B | Pearson/Spearman corr. | 88.64/88.48 | 2:13 | -| QQP | Accuracy/F1 | 90.71/87.49 | 2:22:26 | -| MNLI | Matched acc./Mismatched acc. | 83.91/84.10 | 2:35:23 | -| QNLI | Accuracy | 90.66 | 40:57 | -| RTE | Accuracy | 65.70 | 57 | -| WNLI | Accuracy | 56.34 | 24 | - -Some of these results are significantly different from the ones reported on the test set of GLUE benchmark on the -website. For QQP and WNLI, please refer to [FAQ #12](https://gluebenchmark.com/faq) on the website. - -The following example fine-tunes BERT on the `imdb` dataset hosted on our [hub](https://huggingface.co/datasets): - -```bash -python run_glue.py \ - --model_name_or_path bert-base-cased \ - --dataset_name imdb \ - --do_train \ - --do_predict \ - --max_seq_length 128 \ - --per_device_train_batch_size 32 \ - --learning_rate 2e-5 \ - --num_train_epochs 3 \ - --output_dir /tmp/imdb/ -``` - -> If your model classification head dimensions do not fit the number of labels in the dataset, you can specify `--ignore_mismatched_sizes` to adapt it. - - -### Mixed precision training - -If you have a GPU with mixed precision capabilities (architecture Pascal or more recent), you can use mixed precision -training with PyTorch 1.6.0 or latest, or by installing the [Apex](https://github.com/NVIDIA/apex) library for previous -versions. Just add the flag `--fp16` to your command launching one of the scripts mentioned above! - -Using mixed precision training usually results in 2x-speedup for training with the same final results: - -| Task | Metric | Result | Training time | Result (FP16) | Training time (FP16) | -|-------|------------------------------|-------------|---------------|---------------|----------------------| -| CoLA | Matthews corr | 56.53 | 3:17 | 56.78 | 1:41 | -| SST-2 | Accuracy | 92.32 | 26:06 | 91.74 | 13:11 | -| MRPC | F1/Accuracy | 88.85/84.07 | 2:21 | 88.12/83.58 | 1:10 | -| STS-B | Pearson/Spearman corr. | 88.64/88.48 | 2:13 | 88.71/88.55 | 1:08 | -| QQP | Accuracy/F1 | 90.71/87.49 | 2:22:26 | 90.67/87.43 | 1:11:54 | -| MNLI | Matched acc./Mismatched acc. | 83.91/84.10 | 2:35:23 | 84.04/84.06 | 1:17:06 | -| QNLI | Accuracy | 90.66 | 40:57 | 90.96 | 20:16 | -| RTE | Accuracy | 65.70 | 57 | 65.34 | 29 | -| WNLI | Accuracy | 56.34 | 24 | 56.34 | 12 | - - -## PyTorch version, no Trainer - -Based on the script [`run_glue_no_trainer.py`](https://github.com/huggingface/transformers/blob/main/examples/pytorch/text-classification/run_glue_no_trainer.py). - -Like `run_glue.py`, this script allows you to fine-tune any of the models on the [hub](https://huggingface.co/models) on a -text classification task, either a GLUE task or your own data in a csv or a JSON file. The main difference is that this -script exposes the bare training loop, to allow you to quickly experiment and add any customization you would like. - -It offers less options than the script with `Trainer` (for instance you can easily change the options for the optimizer -or the dataloaders directly in the script) but still run in a distributed setup, on TPU and supports mixed precision by -the mean of the [🤗 `Accelerate`](https://github.com/huggingface/accelerate) library. You can use the script normally -after installing it: - -```bash -pip install git+https://github.com/huggingface/accelerate -``` - -then - -```bash -export TASK_NAME=mrpc - -python run_glue_no_trainer.py \ - --model_name_or_path bert-base-cased \ - --task_name $TASK_NAME \ - --max_length 128 \ - --per_device_train_batch_size 32 \ - --learning_rate 2e-5 \ - --num_train_epochs 3 \ - --output_dir /tmp/$TASK_NAME/ -``` - -You can then use your usual launchers to run in it in a distributed environment, but the easiest way is to run - -```bash -accelerate config -``` - -and reply to the questions asked. Then - -```bash -accelerate test -``` - -that will check everything is ready for training. Finally, you can launch training with - -```bash -export TASK_NAME=mrpc - -accelerate launch run_glue_no_trainer.py \ - --model_name_or_path bert-base-cased \ - --task_name $TASK_NAME \ - --max_length 128 \ - --per_device_train_batch_size 32 \ - --learning_rate 2e-5 \ - --num_train_epochs 3 \ - --output_dir /tmp/$TASK_NAME/ -``` - -This command is the same and will work for: - -- a CPU-only setup -- a setup with one GPU -- a distributed training with several GPUs (single or multi node) -- a training on TPUs - -Note that this library is in alpha release so your feedback is more than welcome if you encounter any problem using it. - -## XNLI - -Based on the script [`run_xnli.py`](https://github.com/huggingface/transformers/examples/pytorch/text-classification/run_xnli.py). - -[XNLI](https://www.nyu.edu/projects/bowman/xnli/) is a crowd-sourced dataset based on [MultiNLI](http://www.nyu.edu/projects/bowman/multinli/). It is an evaluation benchmark for cross-lingual text representations. Pairs of text are labeled with textual entailment annotations for 15 different languages (including both high-resource language such as English and low-resource languages such as Swahili). - -#### Fine-tuning on XNLI - -This example code fine-tunes mBERT (multi-lingual BERT) on the XNLI dataset. It runs in 106 mins on a single tesla V100 16GB. - -```bash -python run_xnli.py \ - --model_name_or_path bert-base-multilingual-cased \ - --language de \ - --train_language en \ - --do_train \ - --do_eval \ - --per_device_train_batch_size 32 \ - --learning_rate 5e-5 \ - --num_train_epochs 2.0 \ - --max_seq_length 128 \ - --output_dir /tmp/debug_xnli/ \ - --save_steps -1 -``` - -Training with the previously defined hyper-parameters yields the following results on the **test** set: - -```bash -acc = 0.7093812375249501 -``` diff --git a/examples/hf/text-classification/requirements.txt b/examples/hf/text-classification/requirements.txt deleted file mode 100644 index 2a0e0d7de..000000000 --- a/examples/hf/text-classification/requirements.txt +++ /dev/null @@ -1,7 +0,0 @@ -accelerate -datasets >= 1.8.0 -sentencepiece != 0.1.92 -scipy -scikit-learn -protobuf -torch >= 1.3 diff --git a/examples/hf/text-classification/run_glue.py b/examples/hf/text-classification/run_glue.py deleted file mode 100755 index 3d2108e99..000000000 --- a/examples/hf/text-classification/run_glue.py +++ /dev/null @@ -1,670 +0,0 @@ -#!/usr/bin/env python -# coding=utf-8 -# Copyright 2020 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" Finetuning the library models for sequence classification on GLUE.""" -# You can also adapt this script on your own text classification task. Pointers for this are left as comments. - -import logging -import os -import random -import sys -from dataclasses import dataclass, field -from typing import Optional - -import datasets -import numpy as np -from datasets import load_dataset - -import evaluate -import transformers -from transformers import ( - AutoConfig, - AutoModelForSequenceClassification, - AutoTokenizer, - DataCollatorWithPadding, - EvalPrediction, - HfArgumentParser, - PretrainedConfig, - default_data_collator, - set_seed, -) -from transformers.trainer_utils import get_last_checkpoint -from transformers.utils import check_min_version, send_example_telemetry -from transformers.utils.versions import require_version - -import pippy -from pippy import run_pippy -from pippy.hf import PiPPyTrainingArguments, PiPPyTrainer, bert, PiPPyHFTracer - -# Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.22.0.dev0") - -require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt") - -task_to_keys = { - "cola": ("sentence", None), - "mnli": ("premise", "hypothesis"), - "mrpc": ("sentence1", "sentence2"), - "qnli": ("question", "sentence"), - "qqp": ("question1", "question2"), - "rte": ("sentence1", "sentence2"), - "sst2": ("sentence", None), - "stsb": ("sentence1", "sentence2"), - "wnli": ("sentence1", "sentence2"), -} - -logger = logging.getLogger(__name__) - - -@dataclass -class DataTrainingArguments: - """ - Arguments pertaining to what data we are going to input our model for training and eval. - - Using `HfArgumentParser` we can turn this class - into argparse arguments to be able to specify them on - the command line. - """ - - task_name: Optional[str] = field( - default=None, - metadata={"help": "The name of the task to train on: " + ", ".join(task_to_keys.keys())}, - ) - dataset_name: Optional[str] = field( - default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} - ) - dataset_config_name: Optional[str] = field( - default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} - ) - max_seq_length: int = field( - default=128, - metadata={ - "help": ( - "The maximum total input sequence length after tokenization. Sequences longer " - "than this will be truncated, sequences shorter will be padded." - ) - }, - ) - overwrite_cache: bool = field( - default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."} - ) - pad_to_max_length: bool = field( - default=True, - metadata={ - "help": ( - "Whether to pad all samples to `max_seq_length`. " - "If False, will pad the samples dynamically when batching to the maximum length in the batch." - ) - }, - ) - max_train_samples: Optional[int] = field( - default=None, - metadata={ - "help": ( - "For debugging purposes or quicker training, truncate the number of training examples to this " - "value if set." - ) - }, - ) - max_eval_samples: Optional[int] = field( - default=None, - metadata={ - "help": ( - "For debugging purposes or quicker training, truncate the number of evaluation examples to this " - "value if set." - ) - }, - ) - max_predict_samples: Optional[int] = field( - default=None, - metadata={ - "help": ( - "For debugging purposes or quicker training, truncate the number of prediction examples to this " - "value if set." - ) - }, - ) - train_file: Optional[str] = field( - default=None, metadata={"help": "A csv or a json file containing the training data."} - ) - validation_file: Optional[str] = field( - default=None, metadata={"help": "A csv or a json file containing the validation data."} - ) - test_file: Optional[str] = field(default=None, metadata={"help": "A csv or a json file containing the test data."}) - - def __post_init__(self): - if self.task_name is not None: - self.task_name = self.task_name.lower() - if self.task_name not in task_to_keys.keys(): - raise ValueError("Unknown task, you should pick one in " + ",".join(task_to_keys.keys())) - elif self.dataset_name is not None: - pass - elif self.train_file is None or self.validation_file is None: - raise ValueError("Need either a GLUE task, a training/validation file or a dataset name.") - else: - train_extension = self.train_file.split(".")[-1] - assert train_extension in ["csv", "json"], "`train_file` should be a csv or a json file." - validation_extension = self.validation_file.split(".")[-1] - assert ( - validation_extension == train_extension - ), "`validation_file` should have the same extension (csv or json) as `train_file`." - - -@dataclass -class ModelArguments: - """ - Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. - """ - - model_name_or_path: str = field( - metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} - ) - config_name: Optional[str] = field( - default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} - ) - tokenizer_name: Optional[str] = field( - default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} - ) - cache_dir: Optional[str] = field( - default=None, - metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, - ) - use_fast_tokenizer: bool = field( - default=True, - metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, - ) - model_revision: str = field( - default="main", - metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, - ) - use_auth_token: bool = field( - default=False, - metadata={ - "help": ( - "Will use the token generated when running `transformers-cli login` (necessary to use this script " - "with private models)." - ) - }, - ) - ignore_mismatched_sizes: bool = field( - default=False, - metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."}, - ) - - -def main(): - # See all possible arguments in src/transformers/training_args.py - # or by passing the --help flag to this script. - # We now keep distinct sets of args, for a cleaner separation of concerns. - - # =============================================== PiPPy change start =============================================== - parser = HfArgumentParser((ModelArguments, DataTrainingArguments, PiPPyTrainingArguments)) - # ================================================ PiPPy change end ================================================ - if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): - # If we pass only one argument to the script and it's the path to a json file, - # let's parse it to get our arguments. - model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) - else: - model_args, data_args, training_args = parser.parse_args_into_dataclasses() - - # =============================================== PiPPy change start =============================================== - run_pippy(run_master, training_args, model_args, data_args) - - -def run_master(pp_ranks, training_args, model_args, data_args): - # ================================================ PiPPy change end ================================================ - # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The - # information sent is the one passed as arguments along with your Python/PyTorch versions. - send_example_telemetry("run_glue", model_args, data_args) - - # Setup logging - logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", - datefmt="%m/%d/%Y %H:%M:%S", - handlers=[logging.StreamHandler(sys.stdout)], - ) - - log_level = training_args.get_process_log_level() - logger.setLevel(log_level) - datasets.utils.logging.set_verbosity(log_level) - transformers.utils.logging.set_verbosity(log_level) - transformers.utils.logging.enable_default_handler() - transformers.utils.logging.enable_explicit_format() - - # Log on each process the small summary: - logger.warning( - f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" - + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" - ) - logger.info(f"Training/evaluation parameters {training_args}") - - # Detecting last checkpoint. - last_checkpoint = None - if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: - last_checkpoint = get_last_checkpoint(training_args.output_dir) - if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: - raise ValueError( - f"Output directory ({training_args.output_dir}) already exists and is not empty. " - "Use --overwrite_output_dir to overcome." - ) - elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: - logger.info( - f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " - "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." - ) - - # Set seed before initializing model. - set_seed(training_args.seed) - - # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below) - # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub). - # - # For CSV/JSON files, this script will use as labels the column called 'label' and as pair of sentences the - # sentences in columns called 'sentence1' and 'sentence2' if such column exists or the first two columns not named - # label if at least two columns are provided. - # - # If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this - # single column. You can easily tweak this behavior (see below) - # - # In distributed training, the load_dataset function guarantee that only one local process can concurrently - # download the dataset. - if data_args.task_name is not None: - # Downloading and loading a dataset from the hub. - raw_datasets = load_dataset( - "glue", - data_args.task_name, - cache_dir=model_args.cache_dir, - use_auth_token=True if model_args.use_auth_token else None, - ) - elif data_args.dataset_name is not None: - # Downloading and loading a dataset from the hub. - raw_datasets = load_dataset( - data_args.dataset_name, - data_args.dataset_config_name, - cache_dir=model_args.cache_dir, - use_auth_token=True if model_args.use_auth_token else None, - ) - else: - # Loading a dataset from your local files. - # CSV/JSON training and evaluation files are needed. - data_files = {"train": data_args.train_file, "validation": data_args.validation_file} - - # Get the test dataset: you can provide your own CSV/JSON test file (see below) - # when you use `do_predict` without specifying a GLUE benchmark task. - if training_args.do_predict: - if data_args.test_file is not None: - train_extension = data_args.train_file.split(".")[-1] - test_extension = data_args.test_file.split(".")[-1] - assert ( - test_extension == train_extension - ), "`test_file` should have the same extension (csv or json) as `train_file`." - data_files["test"] = data_args.test_file - else: - raise ValueError("Need either a GLUE task or a test file for `do_predict`.") - - for key in data_files.keys(): - logger.info(f"load a local file for {key}: {data_files[key]}") - - if data_args.train_file.endswith(".csv"): - # Loading a dataset from local csv files - raw_datasets = load_dataset( - "csv", - data_files=data_files, - cache_dir=model_args.cache_dir, - use_auth_token=True if model_args.use_auth_token else None, - ) - else: - # Loading a dataset from local json files - raw_datasets = load_dataset( - "json", - data_files=data_files, - cache_dir=model_args.cache_dir, - use_auth_token=True if model_args.use_auth_token else None, - ) - # See more about loading any type of standard or custom dataset at - # https://huggingface.co/docs/datasets/loading_datasets.html. - - # Labels - if data_args.task_name is not None: - is_regression = data_args.task_name == "stsb" - if not is_regression: - label_list = raw_datasets["train"].features["label"].names - num_labels = len(label_list) - else: - num_labels = 1 - else: - # Trying to have good defaults here, don't hesitate to tweak to your needs. - is_regression = raw_datasets["train"].features["label"].dtype in ["float32", "float64"] - if is_regression: - num_labels = 1 - else: - # A useful fast method: - # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique - label_list = raw_datasets["train"].unique("label") - label_list.sort() # Let's sort it for determinism - num_labels = len(label_list) - - # Load pretrained model and tokenizer - # - # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently - # download model & vocab. - config = AutoConfig.from_pretrained( - model_args.config_name if model_args.config_name else model_args.model_name_or_path, - num_labels=num_labels, - finetuning_task=data_args.task_name, - cache_dir=model_args.cache_dir, - revision=model_args.model_revision, - use_auth_token=True if model_args.use_auth_token else None, - ) - tokenizer = AutoTokenizer.from_pretrained( - model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, - cache_dir=model_args.cache_dir, - use_fast=model_args.use_fast_tokenizer, - revision=model_args.model_revision, - use_auth_token=True if model_args.use_auth_token else None, - ) - model = AutoModelForSequenceClassification.from_pretrained( - model_args.model_name_or_path, - from_tf=bool(".ckpt" in model_args.model_name_or_path), - config=config, - cache_dir=model_args.cache_dir, - revision=model_args.model_revision, - use_auth_token=True if model_args.use_auth_token else None, - ignore_mismatched_sizes=model_args.ignore_mismatched_sizes, - ) - - model.to(training_args.device) - - model.config.problem_type = "single_label_classification" # "regression", "single_label_classification", or "multi_label_classification" - training_args.label_names = ['labels'] # https://github.com/huggingface/transformers/blob/c8b6ae858d61e5bc10e388d095aa74f7690d1021/src/transformers/trainer.py#L629-L630 - - # =============================================== PiPPy change start =============================================== - # Setting model to training mode so that PiPPy would automatically look for "loss" and generate backward pass - model.train() - - logger.info("[PiPPy] Splitting model ...") - num_ranks=len(pp_ranks) - bert.split(model, num_ranks) - - # Prepare concrete args in case FX tracing needs them - inputs = ['input_ids', 'token_type_ids', 'labels', 'attention_mask'] - concrete_args = pippy.create_default_args(model, - except_keys=inputs) - - # Compile into pipeline parallel, distributed model - pipe_mod = pippy.compile( - model, - num_ranks, - num_chunks=training_args.chunks or num_ranks, - ranks=pp_ranks, - tracer=PiPPyHFTracer(), - checkpoint=bool(training_args.checkpoint), - concrete_args=concrete_args, - ) - pipe_mod.init_data_parallel(dp_group_size=training_args.dp_group_size) - pipe_mod.config = model.config - model = pipe_mod - - # ================================================ PiPPy change end ================================================ - - # Preprocessing the raw_datasets - if data_args.task_name is not None: - sentence1_key, sentence2_key = task_to_keys[data_args.task_name] - else: - # Again, we try to have some nice defaults but don't hesitate to tweak to your use case. - non_label_column_names = [name for name in raw_datasets["train"].column_names if name != "label"] - if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names: - sentence1_key, sentence2_key = "sentence1", "sentence2" - else: - if len(non_label_column_names) >= 2: - sentence1_key, sentence2_key = non_label_column_names[:2] - else: - sentence1_key, sentence2_key = non_label_column_names[0], None - - # Padding strategy - if data_args.pad_to_max_length: - padding = "max_length" - else: - # We will pad later, dynamically at batch creation, to the max sequence length in each batch - padding = False - - # Some models have set the order of the labels to use, so let's make sure we do use it. - label_to_id = None - if ( - model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id - and data_args.task_name is not None - and not is_regression - ): - # Some have all caps in their config, some don't. - label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()} - if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)): - label_to_id = {i: int(label_name_to_id[label_list[i]]) for i in range(num_labels)} - else: - logger.warning( - "Your model seems to have been trained with labels, but they don't match the dataset: ", - f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}." - "\nIgnoring the model labels as a result.", - ) - elif data_args.task_name is None and not is_regression: - label_to_id = {v: i for i, v in enumerate(label_list)} - - if label_to_id is not None: - model.config.label2id = label_to_id - model.config.id2label = {id: label for label, id in config.label2id.items()} - elif data_args.task_name is not None and not is_regression: - model.config.label2id = {l: i for i, l in enumerate(label_list)} - model.config.id2label = {id: label for label, id in config.label2id.items()} - - if data_args.max_seq_length > tokenizer.model_max_length: - logger.warning( - f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the" - f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." - ) - max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) - - def preprocess_function(examples): - # Tokenize the texts - args = ( - (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key]) - ) - result = tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True) - - # Map labels to IDs (not necessary for GLUE tasks) - if label_to_id is not None and "label" in examples: - result["label"] = [(label_to_id[l] if l != -1 else -1) for l in examples["label"]] - return result - - with training_args.main_process_first(desc="dataset map pre-processing"): - raw_datasets = raw_datasets.map( - preprocess_function, - batched=True, - load_from_cache_file=not data_args.overwrite_cache, - desc="Running tokenizer on dataset", - ) - if training_args.do_train: - if "train" not in raw_datasets: - raise ValueError("--do_train requires a train dataset") - train_dataset = raw_datasets["train"] - if data_args.max_train_samples is not None: - max_train_samples = min(len(train_dataset), data_args.max_train_samples) - train_dataset = train_dataset.select(range(max_train_samples)) - - if training_args.do_eval: - if "validation" not in raw_datasets and "validation_matched" not in raw_datasets: - raise ValueError("--do_eval requires a validation dataset") - eval_dataset = raw_datasets["validation_matched" if data_args.task_name == "mnli" else "validation"] - if data_args.max_eval_samples is not None: - max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) - eval_dataset = eval_dataset.select(range(max_eval_samples)) - - if training_args.do_predict or data_args.task_name is not None or data_args.test_file is not None: - if "test" not in raw_datasets and "test_matched" not in raw_datasets: - raise ValueError("--do_predict requires a test dataset") - predict_dataset = raw_datasets["test_matched" if data_args.task_name == "mnli" else "test"] - if data_args.max_predict_samples is not None: - max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples) - predict_dataset = predict_dataset.select(range(max_predict_samples)) - - # Log a few random samples from the training set: - if training_args.do_train: - for index in random.sample(range(len(train_dataset)), 3): - logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") - - # Get the metric function - if data_args.task_name is not None: - metric = evaluate.load("glue", data_args.task_name) - else: - metric = evaluate.load("accuracy") - - # You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a - # predictions and label_ids field) and has to return a dictionary string to float. - def compute_metrics(p: EvalPrediction): - preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions - preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1) - if data_args.task_name is not None: - result = metric.compute(predictions=preds, references=p.label_ids) - if len(result) > 1: - result["combined_score"] = np.mean(list(result.values())).item() - return result - elif is_regression: - return {"mse": ((preds - p.label_ids) ** 2).mean().item()} - else: - return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()} - - # Data collator will default to DataCollatorWithPadding when the tokenizer is passed to Trainer, so we change it if - # we already did the padding. - if data_args.pad_to_max_length: - data_collator = default_data_collator - elif training_args.fp16: - data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8) - else: - data_collator = None - - # Initialize our Trainer - # =============================================== PiPPy change start =============================================== - trainer = PiPPyTrainer( - # ================================================ PiPPy change end ================================================ - model=model, - args=training_args, - train_dataset=train_dataset if training_args.do_train else None, - eval_dataset=eval_dataset if training_args.do_eval else None, - compute_metrics=compute_metrics, - tokenizer=tokenizer, - data_collator=data_collator, - ) - # =============================================== PiPPy change start =============================================== - trainer._signature_columns = inputs - # TODO(pbelevich): investigate! - trainer._signature_columns.remove('labels') - trainer._signature_columns.append('label') - # ================================================ PiPPy change end ================================================ - - # Training - if training_args.do_train: - checkpoint = None - if training_args.resume_from_checkpoint is not None: - checkpoint = training_args.resume_from_checkpoint - elif last_checkpoint is not None: - checkpoint = last_checkpoint - train_result = trainer.train(resume_from_checkpoint=checkpoint) - metrics = train_result.metrics - max_train_samples = ( - data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) - ) - metrics["train_samples"] = min(max_train_samples, len(train_dataset)) - - # TODO: overwrite save_model method so that it does not pickle process group - #trainer.save_model() # Saves the tokenizer too for easy upload - - trainer.log_metrics("train", metrics) - trainer.save_metrics("train", metrics) - trainer.save_state() - - # Evaluation - if training_args.do_eval: - logger.info("*** Evaluate ***") - - # Loop to handle MNLI double evaluation (matched, mis-matched) - tasks = [data_args.task_name] - eval_datasets = [eval_dataset] - if data_args.task_name == "mnli": - tasks.append("mnli-mm") - eval_datasets.append(raw_datasets["validation_mismatched"]) - combined = {} - - for eval_dataset, task in zip(eval_datasets, tasks): - metrics = trainer.evaluate(eval_dataset=eval_dataset) - - max_eval_samples = ( - data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) - ) - metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) - - if task == "mnli-mm": - metrics = {k + "_mm": v for k, v in metrics.items()} - if task is not None and "mnli" in task: - combined.update(metrics) - - trainer.log_metrics("eval", metrics) - trainer.save_metrics("eval", combined if task is not None and "mnli" in task else metrics) - - if training_args.do_predict: - logger.info("*** Predict ***") - - # Loop to handle MNLI double evaluation (matched, mis-matched) - tasks = [data_args.task_name] - predict_datasets = [predict_dataset] - if data_args.task_name == "mnli": - tasks.append("mnli-mm") - predict_datasets.append(raw_datasets["test_mismatched"]) - - for predict_dataset, task in zip(predict_datasets, tasks): - # Removing the `label` columns because it contains -1 and Trainer won't like that. - predict_dataset = predict_dataset.remove_columns("label") - predictions = trainer.predict(predict_dataset, metric_key_prefix="predict").predictions - predictions = np.squeeze(predictions) if is_regression else np.argmax(predictions, axis=1) - - output_predict_file = os.path.join(training_args.output_dir, f"predict_results_{task}.txt") - if trainer.is_world_process_zero(): - with open(output_predict_file, "w") as writer: - logger.info(f"***** Predict results {task} *****") - writer.write("index\tprediction\n") - for index, item in enumerate(predictions): - if is_regression: - writer.write(f"{index}\t{item:3.3f}\n") - else: - item = label_list[item] - writer.write(f"{index}\t{item}\n") - - kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-classification"} - if data_args.task_name is not None: - kwargs["language"] = "en" - kwargs["dataset_tags"] = "glue" - kwargs["dataset_args"] = data_args.task_name - kwargs["dataset"] = f"GLUE {data_args.task_name.upper()}" - - if training_args.push_to_hub: - trainer.push_to_hub(**kwargs) - else: - trainer.create_model_card(**kwargs) - - -def _mp_fn(index): - # For xla_spawn (TPUs) - main() - - -if __name__ == "__main__": - main() diff --git a/examples/hf/text-classification/run_glue_no_trainer.py b/examples/hf/text-classification/run_glue_no_trainer.py deleted file mode 100644 index f74e55206..000000000 --- a/examples/hf/text-classification/run_glue_no_trainer.py +++ /dev/null @@ -1,635 +0,0 @@ -# coding=utf-8 -# Copyright 2021 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" Finetuning a 🤗 Transformers model for sequence classification on GLUE.""" -import argparse -import json -import logging -import math -import os -import random -from pathlib import Path - -import datasets -import torch -from datasets import load_dataset -from torch.utils.data import DataLoader -from tqdm.auto import tqdm - -import evaluate -import transformers -from accelerate import Accelerator -from accelerate.logging import get_logger -from accelerate.utils import set_seed -from huggingface_hub import Repository -from transformers import ( - AutoConfig, - AutoModelForSequenceClassification, - AutoTokenizer, - DataCollatorWithPadding, - PretrainedConfig, - SchedulerType, - default_data_collator, - get_scheduler, -) -from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry -from transformers.utils.versions import require_version - - -# Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.22.0.dev0") - -logger = get_logger(__name__) - -require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt") - -task_to_keys = { - "cola": ("sentence", None), - "mnli": ("premise", "hypothesis"), - "mrpc": ("sentence1", "sentence2"), - "qnli": ("question", "sentence"), - "qqp": ("question1", "question2"), - "rte": ("sentence1", "sentence2"), - "sst2": ("sentence", None), - "stsb": ("sentence1", "sentence2"), - "wnli": ("sentence1", "sentence2"), -} - - -def parse_args(): - parser = argparse.ArgumentParser(description="Finetune a transformers model on a text classification task") - parser.add_argument( - "--task_name", - type=str, - default=None, - help="The name of the glue task to train on.", - choices=list(task_to_keys.keys()), - ) - parser.add_argument( - "--train_file", type=str, default=None, help="A csv or a json file containing the training data." - ) - parser.add_argument( - "--validation_file", type=str, default=None, help="A csv or a json file containing the validation data." - ) - parser.add_argument( - "--max_length", - type=int, - default=128, - help=( - "The maximum total input sequence length after tokenization. Sequences longer than this will be truncated," - " sequences shorter will be padded if `--pad_to_max_lengh` is passed." - ), - ) - parser.add_argument( - "--pad_to_max_length", - action="store_true", - help="If passed, pad all samples to `max_length`. Otherwise, dynamic padding is used.", - ) - parser.add_argument( - "--model_name_or_path", - type=str, - help="Path to pretrained model or model identifier from huggingface.co/models.", - required=True, - ) - parser.add_argument( - "--use_slow_tokenizer", - action="store_true", - help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).", - ) - parser.add_argument( - "--per_device_train_batch_size", - type=int, - default=8, - help="Batch size (per device) for the training dataloader.", - ) - parser.add_argument( - "--per_device_eval_batch_size", - type=int, - default=8, - help="Batch size (per device) for the evaluation dataloader.", - ) - parser.add_argument( - "--learning_rate", - type=float, - default=5e-5, - help="Initial learning rate (after the potential warmup period) to use.", - ) - parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") - parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.") - parser.add_argument( - "--max_train_steps", - type=int, - default=None, - help="Total number of training steps to perform. If provided, overrides num_train_epochs.", - ) - parser.add_argument( - "--gradient_accumulation_steps", - type=int, - default=1, - help="Number of updates steps to accumulate before performing a backward/update pass.", - ) - parser.add_argument( - "--lr_scheduler_type", - type=SchedulerType, - default="linear", - help="The scheduler type to use.", - choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"], - ) - parser.add_argument( - "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler." - ) - parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.") - parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") - parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") - parser.add_argument( - "--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`." - ) - parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.") - parser.add_argument( - "--checkpointing_steps", - type=str, - default=None, - help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.", - ) - parser.add_argument( - "--resume_from_checkpoint", - type=str, - default=None, - help="If the training should continue from a checkpoint folder.", - ) - parser.add_argument( - "--with_tracking", - action="store_true", - help="Whether to enable experiment trackers for logging.", - ) - parser.add_argument( - "--report_to", - type=str, - default="all", - help=( - 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,' - ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.' - "Only applicable when `--with_tracking` is passed." - ), - ) - parser.add_argument( - "--ignore_mismatched_sizes", - action="store_true", - help="Whether or not to enable to load a pretrained model whose head dimensions are different.", - ) - args = parser.parse_args() - - # Sanity checks - if args.task_name is None and args.train_file is None and args.validation_file is None: - raise ValueError("Need either a task name or a training/validation file.") - else: - if args.train_file is not None: - extension = args.train_file.split(".")[-1] - assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." - if args.validation_file is not None: - extension = args.validation_file.split(".")[-1] - assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." - - if args.push_to_hub: - assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed." - - return args - - -def main(): - args = parse_args() - # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The - # information sent is the one passed as arguments along with your Python/PyTorch versions. - send_example_telemetry("run_glue_no_trainer", args) - - # Initialize the accelerator. We will let the accelerator handle device placement for us in this example. - # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers - # in the environment - accelerator = ( - Accelerator(log_with=args.report_to, logging_dir=args.output_dir) if args.with_tracking else Accelerator() - ) - # Make one log on every process with the configuration for debugging. - logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", - datefmt="%m/%d/%Y %H:%M:%S", - level=logging.INFO, - ) - logger.info(accelerator.state, main_process_only=False) - if accelerator.is_local_main_process: - datasets.utils.logging.set_verbosity_warning() - transformers.utils.logging.set_verbosity_info() - else: - datasets.utils.logging.set_verbosity_error() - transformers.utils.logging.set_verbosity_error() - - # If passed along, set the training seed now. - if args.seed is not None: - set_seed(args.seed) - - # Handle the repository creation - if accelerator.is_main_process: - if args.push_to_hub: - if args.hub_model_id is None: - repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) - else: - repo_name = args.hub_model_id - repo = Repository(args.output_dir, clone_from=repo_name) - - with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: - if "step_*" not in gitignore: - gitignore.write("step_*\n") - if "epoch_*" not in gitignore: - gitignore.write("epoch_*\n") - elif args.output_dir is not None: - os.makedirs(args.output_dir, exist_ok=True) - accelerator.wait_for_everyone() - - # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below) - # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub). - - # For CSV/JSON files, this script will use as labels the column called 'label' and as pair of sentences the - # sentences in columns called 'sentence1' and 'sentence2' if such column exists or the first two columns not named - # label if at least two columns are provided. - - # If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this - # single column. You can easily tweak this behavior (see below) - - # In distributed training, the load_dataset function guarantee that only one local process can concurrently - # download the dataset. - if args.task_name is not None: - # Downloading and loading a dataset from the hub. - raw_datasets = load_dataset("glue", args.task_name) - else: - # Loading the dataset from local csv or json file. - data_files = {} - if args.train_file is not None: - data_files["train"] = args.train_file - if args.validation_file is not None: - data_files["validation"] = args.validation_file - extension = (args.train_file if args.train_file is not None else args.validation_file).split(".")[-1] - raw_datasets = load_dataset(extension, data_files=data_files) - # See more about loading any type of standard or custom dataset at - # https://huggingface.co/docs/datasets/loading_datasets.html. - - # Labels - if args.task_name is not None: - is_regression = args.task_name == "stsb" - if not is_regression: - label_list = raw_datasets["train"].features["label"].names - num_labels = len(label_list) - else: - num_labels = 1 - else: - # Trying to have good defaults here, don't hesitate to tweak to your needs. - is_regression = raw_datasets["train"].features["label"].dtype in ["float32", "float64"] - if is_regression: - num_labels = 1 - else: - # A useful fast method: - # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique - label_list = raw_datasets["train"].unique("label") - label_list.sort() # Let's sort it for determinism - num_labels = len(label_list) - - # Load pretrained model and tokenizer - # - # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently - # download model & vocab. - config = AutoConfig.from_pretrained(args.model_name_or_path, num_labels=num_labels, finetuning_task=args.task_name) - tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer) - model = AutoModelForSequenceClassification.from_pretrained( - args.model_name_or_path, - from_tf=bool(".ckpt" in args.model_name_or_path), - config=config, - ignore_mismatched_sizes=args.ignore_mismatched_sizes, - ) - - # Preprocessing the datasets - if args.task_name is not None: - sentence1_key, sentence2_key = task_to_keys[args.task_name] - else: - # Again, we try to have some nice defaults but don't hesitate to tweak to your use case. - non_label_column_names = [name for name in raw_datasets["train"].column_names if name != "label"] - if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names: - sentence1_key, sentence2_key = "sentence1", "sentence2" - else: - if len(non_label_column_names) >= 2: - sentence1_key, sentence2_key = non_label_column_names[:2] - else: - sentence1_key, sentence2_key = non_label_column_names[0], None - - # Some models have set the order of the labels to use, so let's make sure we do use it. - label_to_id = None - if ( - model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id - and args.task_name is not None - and not is_regression - ): - # Some have all caps in their config, some don't. - label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()} - if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)): - logger.info( - f"The configuration of the model provided the following label correspondence: {label_name_to_id}. " - "Using it!" - ) - label_to_id = {i: label_name_to_id[label_list[i]] for i in range(num_labels)} - else: - logger.warning( - "Your model seems to have been trained with labels, but they don't match the dataset: ", - f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}." - "\nIgnoring the model labels as a result.", - ) - elif args.task_name is None and not is_regression: - label_to_id = {v: i for i, v in enumerate(label_list)} - - if label_to_id is not None: - model.config.label2id = label_to_id - model.config.id2label = {id: label for label, id in config.label2id.items()} - elif args.task_name is not None and not is_regression: - model.config.label2id = {l: i for i, l in enumerate(label_list)} - model.config.id2label = {id: label for label, id in config.label2id.items()} - - padding = "max_length" if args.pad_to_max_length else False - - def preprocess_function(examples): - # Tokenize the texts - texts = ( - (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key]) - ) - result = tokenizer(*texts, padding=padding, max_length=args.max_length, truncation=True) - - if "label" in examples: - if label_to_id is not None: - # Map labels to IDs (not necessary for GLUE tasks) - result["labels"] = [label_to_id[l] for l in examples["label"]] - else: - # In all cases, rename the column to labels because the model will expect that. - result["labels"] = examples["label"] - return result - - with accelerator.main_process_first(): - processed_datasets = raw_datasets.map( - preprocess_function, - batched=True, - remove_columns=raw_datasets["train"].column_names, - desc="Running tokenizer on dataset", - ) - - train_dataset = processed_datasets["train"] - eval_dataset = processed_datasets["validation_matched" if args.task_name == "mnli" else "validation"] - - # Log a few random samples from the training set: - for index in random.sample(range(len(train_dataset)), 3): - logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") - - # DataLoaders creation: - if args.pad_to_max_length: - # If padding was already done ot max length, we use the default data collator that will just convert everything - # to tensors. - data_collator = default_data_collator - else: - # Otherwise, `DataCollatorWithPadding` will apply dynamic padding for us (by padding to the maximum length of - # the samples passed). When using mixed precision, we add `pad_to_multiple_of=8` to pad all tensors to multiple - # of 8s, which will enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta). - data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=(8 if accelerator.use_fp16 else None)) - - train_dataloader = DataLoader( - train_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.per_device_train_batch_size - ) - eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size) - - # Optimizer - # Split weights in two groups, one with weight decay and the other not. - no_decay = ["bias", "LayerNorm.weight"] - optimizer_grouped_parameters = [ - { - "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], - "weight_decay": args.weight_decay, - }, - { - "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], - "weight_decay": 0.0, - }, - ] - optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate) - - # Scheduler and math around the number of training steps. - overrode_max_train_steps = False - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - if args.max_train_steps is None: - args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch - overrode_max_train_steps = True - - lr_scheduler = get_scheduler( - name=args.lr_scheduler_type, - optimizer=optimizer, - num_warmup_steps=args.num_warmup_steps, - num_training_steps=args.max_train_steps, - ) - - # Prepare everything with our `accelerator`. - model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare( - model, optimizer, train_dataloader, eval_dataloader, lr_scheduler - ) - - # We need to recalculate our total training steps as the size of the training dataloader may have changed - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - if overrode_max_train_steps: - args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch - # Afterwards we recalculate our number of training epochs - args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) - - # Figure out how many steps we should save the Accelerator states - if hasattr(args.checkpointing_steps, "isdigit"): - checkpointing_steps = args.checkpointing_steps - if args.checkpointing_steps.isdigit(): - checkpointing_steps = int(args.checkpointing_steps) - else: - checkpointing_steps = None - - # We need to initialize the trackers we use, and also store our configuration. - # We initialize the trackers only on main process because `accelerator.log` - # only logs on main process and we don't want empty logs/runs on other processes. - if args.with_tracking: - if accelerator.is_main_process: - experiment_config = vars(args) - # TensorBoard cannot log Enums, need the raw value - experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value - accelerator.init_trackers("glue_no_trainer", experiment_config) - - # Get the metric function - if args.task_name is not None: - metric = evaluate.load("glue", args.task_name) - else: - metric = evaluate.load("accuracy") - - # Train! - total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps - - logger.info("***** Running training *****") - logger.info(f" Num examples = {len(train_dataset)}") - logger.info(f" Num Epochs = {args.num_train_epochs}") - logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") - logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") - logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") - logger.info(f" Total optimization steps = {args.max_train_steps}") - # Only show the progress bar once on each machine. - progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) - completed_steps = 0 - starting_epoch = 0 - # Potentially load in the weights and states from a previous save - if args.resume_from_checkpoint: - if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": - accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}") - accelerator.load_state(args.resume_from_checkpoint) - path = os.path.basename(args.resume_from_checkpoint) - else: - # Get the most recent checkpoint - dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] - dirs.sort(key=os.path.getctime) - path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last - # Extract `epoch_{i}` or `step_{i}` - training_difference = os.path.splitext(path)[0] - - if "epoch" in training_difference: - starting_epoch = int(training_difference.replace("epoch_", "")) + 1 - resume_step = None - else: - resume_step = int(training_difference.replace("step_", "")) - starting_epoch = resume_step // len(train_dataloader) - resume_step -= starting_epoch * len(train_dataloader) - - for epoch in range(starting_epoch, args.num_train_epochs): - model.train() - if args.with_tracking: - total_loss = 0 - for step, batch in enumerate(train_dataloader): - # We need to skip steps until we reach the resumed step - if args.resume_from_checkpoint and epoch == starting_epoch: - if resume_step is not None and step < resume_step: - completed_steps += 1 - continue - outputs = model(**batch) - loss = outputs.loss - # We keep track of the loss at each epoch - if args.with_tracking: - total_loss += loss.detach().float() - loss = loss / args.gradient_accumulation_steps - accelerator.backward(loss) - if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad() - progress_bar.update(1) - completed_steps += 1 - - if isinstance(checkpointing_steps, int): - if completed_steps % checkpointing_steps == 0: - output_dir = f"step_{completed_steps }" - if args.output_dir is not None: - output_dir = os.path.join(args.output_dir, output_dir) - accelerator.save_state(output_dir) - - if completed_steps >= args.max_train_steps: - break - - model.eval() - samples_seen = 0 - for step, batch in enumerate(eval_dataloader): - with torch.no_grad(): - outputs = model(**batch) - predictions = outputs.logits.argmax(dim=-1) if not is_regression else outputs.logits.squeeze() - predictions, references = accelerator.gather((predictions, batch["labels"])) - # If we are in a multiprocess environment, the last batch has duplicates - if accelerator.num_processes > 1: - if step == len(eval_dataloader) - 1: - predictions = predictions[: len(eval_dataloader.dataset) - samples_seen] - references = references[: len(eval_dataloader.dataset) - samples_seen] - else: - samples_seen += references.shape[0] - metric.add_batch( - predictions=predictions, - references=references, - ) - - eval_metric = metric.compute() - logger.info(f"epoch {epoch}: {eval_metric}") - - if args.with_tracking: - accelerator.log( - { - "accuracy" if args.task_name is not None else "glue": eval_metric, - "train_loss": total_loss.item() / len(train_dataloader), - "epoch": epoch, - "step": completed_steps, - }, - step=completed_steps, - ) - - if args.push_to_hub and epoch < args.num_train_epochs - 1: - accelerator.wait_for_everyone() - unwrapped_model = accelerator.unwrap_model(model) - unwrapped_model.save_pretrained( - args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save - ) - if accelerator.is_main_process: - tokenizer.save_pretrained(args.output_dir) - repo.push_to_hub( - commit_message=f"Training in progress epoch {epoch}", blocking=False, auto_lfs_prune=True - ) - - if args.checkpointing_steps == "epoch": - output_dir = f"epoch_{epoch}" - if args.output_dir is not None: - output_dir = os.path.join(args.output_dir, output_dir) - accelerator.save_state(output_dir) - - if args.output_dir is not None: - accelerator.wait_for_everyone() - unwrapped_model = accelerator.unwrap_model(model) - unwrapped_model.save_pretrained( - args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save - ) - if accelerator.is_main_process: - tokenizer.save_pretrained(args.output_dir) - if args.push_to_hub: - repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True) - - if args.task_name == "mnli": - # Final evaluation on mismatched validation set - eval_dataset = processed_datasets["validation_mismatched"] - eval_dataloader = DataLoader( - eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size - ) - eval_dataloader = accelerator.prepare(eval_dataloader) - - model.eval() - for step, batch in enumerate(eval_dataloader): - outputs = model(**batch) - predictions = outputs.logits.argmax(dim=-1) - metric.add_batch( - predictions=accelerator.gather(predictions), - references=accelerator.gather(batch["labels"]), - ) - - eval_metric = metric.compute() - logger.info(f"mnli-mm: {eval_metric}") - - if args.output_dir is not None: - with open(os.path.join(args.output_dir, "all_results.json"), "w") as f: - json.dump({"eval_accuracy": eval_metric["accuracy"]}, f) - - -if __name__ == "__main__": - main() diff --git a/examples/hf/text-classification/run_xnli.py b/examples/hf/text-classification/run_xnli.py deleted file mode 100755 index e140d41ff..000000000 --- a/examples/hf/text-classification/run_xnli.py +++ /dev/null @@ -1,437 +0,0 @@ -#!/usr/bin/env python -# coding=utf-8 -# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. -# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" Finetuning multi-lingual models on XNLI (e.g. Bert, DistilBERT, XLM). - Adapted from `examples/text-classification/run_glue.py`""" - -import logging -import os -import random -import sys -from dataclasses import dataclass, field -from typing import Optional - -import datasets -import numpy as np -from datasets import load_dataset - -import evaluate -import transformers -from transformers import ( - AutoConfig, - AutoModelForSequenceClassification, - AutoTokenizer, - DataCollatorWithPadding, - EvalPrediction, - HfArgumentParser, - Trainer, - TrainingArguments, - default_data_collator, - set_seed, -) -from transformers.trainer_utils import get_last_checkpoint -from transformers.utils import check_min_version, send_example_telemetry -from transformers.utils.versions import require_version - - -# Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.22.0.dev0") - -require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt") - -logger = logging.getLogger(__name__) - - -@dataclass -class DataTrainingArguments: - """ - Arguments pertaining to what data we are going to input our model for training and eval. - - Using `HfArgumentParser` we can turn this class - into argparse arguments to be able to specify them on - the command line. - """ - - max_seq_length: Optional[int] = field( - default=128, - metadata={ - "help": ( - "The maximum total input sequence length after tokenization. Sequences longer " - "than this will be truncated, sequences shorter will be padded." - ) - }, - ) - overwrite_cache: bool = field( - default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."} - ) - pad_to_max_length: bool = field( - default=True, - metadata={ - "help": ( - "Whether to pad all samples to `max_seq_length`. " - "If False, will pad the samples dynamically when batching to the maximum length in the batch." - ) - }, - ) - max_train_samples: Optional[int] = field( - default=None, - metadata={ - "help": ( - "For debugging purposes or quicker training, truncate the number of training examples to this " - "value if set." - ) - }, - ) - max_eval_samples: Optional[int] = field( - default=None, - metadata={ - "help": ( - "For debugging purposes or quicker training, truncate the number of evaluation examples to this " - "value if set." - ) - }, - ) - max_predict_samples: Optional[int] = field( - default=None, - metadata={ - "help": ( - "For debugging purposes or quicker training, truncate the number of prediction examples to this " - "value if set." - ) - }, - ) - - -@dataclass -class ModelArguments: - """ - Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. - """ - - model_name_or_path: str = field( - default=None, metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} - ) - language: str = field( - default=None, metadata={"help": "Evaluation language. Also train language if `train_language` is set to None."} - ) - train_language: Optional[str] = field( - default=None, metadata={"help": "Train language if it is different from the evaluation language."} - ) - config_name: Optional[str] = field( - default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} - ) - tokenizer_name: Optional[str] = field( - default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} - ) - cache_dir: Optional[str] = field( - default=None, - metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, - ) - do_lower_case: Optional[bool] = field( - default=False, - metadata={"help": "arg to indicate if tokenizer should do lower case in AutoTokenizer.from_pretrained()"}, - ) - use_fast_tokenizer: bool = field( - default=True, - metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, - ) - model_revision: str = field( - default="main", - metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, - ) - use_auth_token: bool = field( - default=False, - metadata={ - "help": ( - "Will use the token generated when running `transformers-cli login` (necessary to use this script " - "with private models)." - ) - }, - ) - ignore_mismatched_sizes: bool = field( - default=False, - metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."}, - ) - - -def main(): - # See all possible arguments in src/transformers/training_args.py - # or by passing the --help flag to this script. - # We now keep distinct sets of args, for a cleaner separation of concerns. - - parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) - model_args, data_args, training_args = parser.parse_args_into_dataclasses() - - # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The - # information sent is the one passed as arguments along with your Python/PyTorch versions. - send_example_telemetry("run_xnli", model_args) - - # Setup logging - logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", - datefmt="%m/%d/%Y %H:%M:%S", - handlers=[logging.StreamHandler(sys.stdout)], - ) - - log_level = training_args.get_process_log_level() - logger.setLevel(log_level) - datasets.utils.logging.set_verbosity(log_level) - transformers.utils.logging.set_verbosity(log_level) - transformers.utils.logging.enable_default_handler() - transformers.utils.logging.enable_explicit_format() - - # Log on each process the small summary: - logger.warning( - f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" - + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" - ) - logger.info(f"Training/evaluation parameters {training_args}") - - # Detecting last checkpoint. - last_checkpoint = None - if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: - last_checkpoint = get_last_checkpoint(training_args.output_dir) - if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: - raise ValueError( - f"Output directory ({training_args.output_dir}) already exists and is not empty. " - "Use --overwrite_output_dir to overcome." - ) - elif last_checkpoint is not None: - logger.info( - f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " - "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." - ) - - # Set seed before initializing model. - set_seed(training_args.seed) - - # In distributed training, the load_dataset function guarantees that only one local process can concurrently - # download the dataset. - # Downloading and loading xnli dataset from the hub. - if training_args.do_train: - if model_args.train_language is None: - train_dataset = load_dataset( - "xnli", - model_args.language, - split="train", - cache_dir=model_args.cache_dir, - use_auth_token=True if model_args.use_auth_token else None, - ) - else: - train_dataset = load_dataset( - "xnli", - model_args.train_language, - split="train", - cache_dir=model_args.cache_dir, - use_auth_token=True if model_args.use_auth_token else None, - ) - label_list = train_dataset.features["label"].names - - if training_args.do_eval: - eval_dataset = load_dataset( - "xnli", - model_args.language, - split="validation", - cache_dir=model_args.cache_dir, - use_auth_token=True if model_args.use_auth_token else None, - ) - label_list = eval_dataset.features["label"].names - - if training_args.do_predict: - predict_dataset = load_dataset( - "xnli", - model_args.language, - split="test", - cache_dir=model_args.cache_dir, - use_auth_token=True if model_args.use_auth_token else None, - ) - label_list = predict_dataset.features["label"].names - - # Labels - num_labels = len(label_list) - - # Load pretrained model and tokenizer - # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently - # download model & vocab. - config = AutoConfig.from_pretrained( - model_args.config_name if model_args.config_name else model_args.model_name_or_path, - num_labels=num_labels, - finetuning_task="xnli", - cache_dir=model_args.cache_dir, - revision=model_args.model_revision, - use_auth_token=True if model_args.use_auth_token else None, - ) - tokenizer = AutoTokenizer.from_pretrained( - model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, - do_lower_case=model_args.do_lower_case, - cache_dir=model_args.cache_dir, - use_fast=model_args.use_fast_tokenizer, - revision=model_args.model_revision, - use_auth_token=True if model_args.use_auth_token else None, - ) - model = AutoModelForSequenceClassification.from_pretrained( - model_args.model_name_or_path, - from_tf=bool(".ckpt" in model_args.model_name_or_path), - config=config, - cache_dir=model_args.cache_dir, - revision=model_args.model_revision, - use_auth_token=True if model_args.use_auth_token else None, - ignore_mismatched_sizes=model_args.ignore_mismatched_sizes, - ) - - # Preprocessing the datasets - # Padding strategy - if data_args.pad_to_max_length: - padding = "max_length" - else: - # We will pad later, dynamically at batch creation, to the max sequence length in each batch - padding = False - - def preprocess_function(examples): - # Tokenize the texts - return tokenizer( - examples["premise"], - examples["hypothesis"], - padding=padding, - max_length=data_args.max_seq_length, - truncation=True, - ) - - if training_args.do_train: - if data_args.max_train_samples is not None: - max_train_samples = min(len(train_dataset), data_args.max_train_samples) - train_dataset = train_dataset.select(range(max_train_samples)) - with training_args.main_process_first(desc="train dataset map pre-processing"): - train_dataset = train_dataset.map( - preprocess_function, - batched=True, - load_from_cache_file=not data_args.overwrite_cache, - desc="Running tokenizer on train dataset", - ) - # Log a few random samples from the training set: - for index in random.sample(range(len(train_dataset)), 3): - logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") - - if training_args.do_eval: - if data_args.max_eval_samples is not None: - max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) - eval_dataset = eval_dataset.select(range(max_eval_samples)) - with training_args.main_process_first(desc="validation dataset map pre-processing"): - eval_dataset = eval_dataset.map( - preprocess_function, - batched=True, - load_from_cache_file=not data_args.overwrite_cache, - desc="Running tokenizer on validation dataset", - ) - - if training_args.do_predict: - if data_args.max_predict_samples is not None: - max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples) - predict_dataset = predict_dataset.select(range(max_predict_samples)) - with training_args.main_process_first(desc="prediction dataset map pre-processing"): - predict_dataset = predict_dataset.map( - preprocess_function, - batched=True, - load_from_cache_file=not data_args.overwrite_cache, - desc="Running tokenizer on prediction dataset", - ) - - # Get the metric function - metric = evaluate.load("xnli") - - # You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a - # predictions and label_ids field) and has to return a dictionary string to float. - def compute_metrics(p: EvalPrediction): - preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions - preds = np.argmax(preds, axis=1) - return metric.compute(predictions=preds, references=p.label_ids) - - # Data collator will default to DataCollatorWithPadding, so we change it if we already did the padding. - if data_args.pad_to_max_length: - data_collator = default_data_collator - elif training_args.fp16: - data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8) - else: - data_collator = None - - # Initialize our Trainer - trainer = Trainer( - model=model, - args=training_args, - train_dataset=train_dataset if training_args.do_train else None, - eval_dataset=eval_dataset if training_args.do_eval else None, - compute_metrics=compute_metrics, - tokenizer=tokenizer, - data_collator=data_collator, - ) - - # Training - if training_args.do_train: - checkpoint = None - if training_args.resume_from_checkpoint is not None: - checkpoint = training_args.resume_from_checkpoint - elif last_checkpoint is not None: - checkpoint = last_checkpoint - train_result = trainer.train(resume_from_checkpoint=checkpoint) - metrics = train_result.metrics - max_train_samples = ( - data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) - ) - metrics["train_samples"] = min(max_train_samples, len(train_dataset)) - - # TODO: overwrite save_model method so that it does not pickle process group - #trainer.save_model() # Saves the tokenizer too for easy upload - - trainer.log_metrics("train", metrics) - trainer.save_metrics("train", metrics) - trainer.save_state() - - # Evaluation - if training_args.do_eval: - logger.info("*** Evaluate ***") - metrics = trainer.evaluate(eval_dataset=eval_dataset) - - max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) - metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) - - trainer.log_metrics("eval", metrics) - trainer.save_metrics("eval", metrics) - - # Prediction - if training_args.do_predict: - logger.info("*** Predict ***") - predictions, labels, metrics = trainer.predict(predict_dataset, metric_key_prefix="predict") - - max_predict_samples = ( - data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset) - ) - metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset)) - - trainer.log_metrics("predict", metrics) - trainer.save_metrics("predict", metrics) - - predictions = np.argmax(predictions, axis=1) - output_predict_file = os.path.join(training_args.output_dir, "predictions.txt") - if trainer.is_world_process_zero(): - with open(output_predict_file, "w") as writer: - writer.write("index\tprediction\n") - for index, item in enumerate(predictions): - item = label_list[item] - writer.write(f"{index}\t{item}\n") - - -if __name__ == "__main__": - main() diff --git a/examples/hf/translation/README.md b/examples/hf/translation/README.md deleted file mode 100644 index 4bd66ea0a..000000000 --- a/examples/hf/translation/README.md +++ /dev/null @@ -1,211 +0,0 @@ - - -## Translation - -This directory contains examples for finetuning and evaluating transformers on translation tasks. -Please tag @patil-suraj with any issues/unexpected behaviors, or send a PR! -For deprecated `bertabs` instructions, see [`bertabs/README.md`](https://github.com/huggingface/transformers/blob/main/examples/research_projects/bertabs/README.md). -For the old `finetune_trainer.py` and related utils, see [`examples/legacy/seq2seq`](https://github.com/huggingface/transformers/blob/main/examples/legacy/seq2seq). - -### Supported Architectures - -- `BartForConditionalGeneration` -- `FSMTForConditionalGeneration` (translation only) -- `MBartForConditionalGeneration` -- `MarianMTModel` -- `PegasusForConditionalGeneration` -- `T5ForConditionalGeneration` -- `MT5ForConditionalGeneration` - -`run_translation.py` is a lightweight examples of how to download and preprocess a dataset from the [🤗 Datasets](https://github.com/huggingface/datasets) library or use your own files (jsonlines or csv), then fine-tune one of the architectures above on it. - -For custom datasets in `jsonlines` format please see: https://huggingface.co/docs/datasets/loading_datasets.html#json-files -and you also will find examples of these below. - - -## With Trainer - -Here is an example of a translation fine-tuning with a MarianMT model: - -```bash -python examples/pytorch/translation/run_translation.py \ - --model_name_or_path Helsinki-NLP/opus-mt-en-ro \ - --do_train \ - --do_eval \ - --source_lang en \ - --target_lang ro \ - --dataset_name wmt16 \ - --dataset_config_name ro-en \ - --output_dir /tmp/tst-translation \ - --per_device_train_batch_size=4 \ - --per_device_eval_batch_size=4 \ - --overwrite_output_dir \ - --predict_with_generate -``` - -MBart and some T5 models require special handling. - -T5 models `t5-small`, `t5-base`, `t5-large`, `t5-3b` and `t5-11b` must use an additional argument: `--source_prefix "translate {source_lang} to {target_lang}"`. For example: - -```bash -python examples/pytorch/translation/run_translation.py \ - --model_name_or_path t5-small \ - --do_train \ - --do_eval \ - --source_lang en \ - --target_lang ro \ - --source_prefix "translate English to Romanian: " \ - --dataset_name wmt16 \ - --dataset_config_name ro-en \ - --output_dir /tmp/tst-translation \ - --per_device_train_batch_size=4 \ - --per_device_eval_batch_size=4 \ - --overwrite_output_dir \ - --predict_with_generate -``` - -If you get a terrible BLEU score, make sure that you didn't forget to use the `--source_prefix` argument. - -For the aforementioned group of T5 models it's important to remember that if you switch to a different language pair, make sure to adjust the source and target values in all 3 language-specific command line argument: `--source_lang`, `--target_lang` and `--source_prefix`. - -MBart models require a different format for `--source_lang` and `--target_lang` values, e.g. instead of `en` it expects `en_XX`, for `ro` it expects `ro_RO`. The full MBart specification for language codes can be found [here](https://huggingface.co/facebook/mbart-large-cc25). For example: - -```bash -python examples/pytorch/translation/run_translation.py \ - --model_name_or_path facebook/mbart-large-en-ro \ - --do_train \ - --do_eval \ - --dataset_name wmt16 \ - --dataset_config_name ro-en \ - --source_lang en_XX \ - --target_lang ro_RO \ - --output_dir /tmp/tst-translation \ - --per_device_train_batch_size=4 \ - --per_device_eval_batch_size=4 \ - --overwrite_output_dir \ - --predict_with_generate - ``` - -And here is how you would use the translation finetuning on your own files, after adjusting the -values for the arguments `--train_file`, `--validation_file` to match your setup: - -```bash -python examples/pytorch/translation/run_translation.py \ - --model_name_or_path t5-small \ - --do_train \ - --do_eval \ - --source_lang en \ - --target_lang ro \ - --source_prefix "translate English to Romanian: " \ - --dataset_name wmt16 \ - --dataset_config_name ro-en \ - --train_file path_to_jsonlines_file \ - --validation_file path_to_jsonlines_file \ - --output_dir /tmp/tst-translation \ - --per_device_train_batch_size=4 \ - --per_device_eval_batch_size=4 \ - --overwrite_output_dir \ - --predict_with_generate -``` - -The task of translation supports only custom JSONLINES files, with each line being a dictionary with a key `"translation"` and its value another dictionary whose keys is the language pair. For example: - -```json -{ "translation": { "en": "Others have dismissed him as a joke.", "ro": "Alții l-au numit o glumă." } } -{ "translation": { "en": "And some are holding out for an implosion.", "ro": "Iar alții așteaptă implozia." } } -``` -Here the languages are Romanian (`ro`) and English (`en`). - -If you want to use a pre-processed dataset that leads to high BLEU scores, but for the `en-de` language pair, you can use `--dataset_name stas/wmt14-en-de-pre-processed`, as following: - -```bash -python examples/pytorch/translation/run_translation.py \ - --model_name_or_path t5-small \ - --do_train \ - --do_eval \ - --source_lang en \ - --target_lang de \ - --source_prefix "translate English to German: " \ - --dataset_name stas/wmt14-en-de-pre-processed \ - --output_dir /tmp/tst-translation \ - --per_device_train_batch_size=4 \ - --per_device_eval_batch_size=4 \ - --overwrite_output_dir \ - --predict_with_generate - ``` - -## With Accelerate - -Based on the script [`run_translation_no_trainer.py`](https://github.com/huggingface/transformers/blob/main/examples/pytorch/translation/run_translationn_no_trainer.py). - -Like `run_translation.py`, this script allows you to fine-tune any of the models supported on a -translation task, the main difference is that this -script exposes the bare training loop, to allow you to quickly experiment and add any customization you would like. - -It offers less options than the script with `Trainer` (for instance you can easily change the options for the optimizer -or the dataloaders directly in the script) but still run in a distributed setup, on TPU and supports mixed precision by -the mean of the [🤗 `Accelerate`](https://github.com/huggingface/accelerate) library. You can use the script normally -after installing it: - -```bash -pip install git+https://github.com/huggingface/accelerate -``` - -then - -```bash -python run_translation_no_trainer.py \ - --model_name_or_path Helsinki-NLP/opus-mt-en-ro \ - --source_lang en \ - --target_lang ro \ - --dataset_name wmt16 \ - --dataset_config_name ro-en \ - --output_dir ~/tmp/tst-translation -``` - -You can then use your usual launchers to run in it in a distributed environment, but the easiest way is to run - -```bash -accelerate config -``` - -and reply to the questions asked. Then - -```bash -accelerate test -``` - -that will check everything is ready for training. Finally, you can launch training with - -```bash -accelerate launch run_translation_no_trainer.py \ - --model_name_or_path Helsinki-NLP/opus-mt-en-ro \ - --source_lang en \ - --target_lang ro \ - --dataset_name wmt16 \ - --dataset_config_name ro-en \ - --output_dir ~/tmp/tst-translation -``` - -This command is the same and will work for: - -- a CPU-only setup -- a setup with one GPU -- a distributed training with several GPUs (single or multi node) -- a training on TPUs - -Note that this library is in alpha release so your feedback is more than welcome if you encounter any problem using it. diff --git a/examples/hf/translation/requirements.txt b/examples/hf/translation/requirements.txt deleted file mode 100644 index c34795fff..000000000 --- a/examples/hf/translation/requirements.txt +++ /dev/null @@ -1,7 +0,0 @@ -accelerate -datasets >= 1.8.0 -sentencepiece != 0.1.92 -protobuf -sacrebleu >= 1.4.12 -py7zr -torch >= 1.3 diff --git a/examples/hf/translation/run_translation.py b/examples/hf/translation/run_translation.py deleted file mode 100755 index bb6598ff7..000000000 --- a/examples/hf/translation/run_translation.py +++ /dev/null @@ -1,749 +0,0 @@ -#!/usr/bin/env python -# coding=utf-8 -# Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Fine-tuning the library models for sequence to sequence. -""" -# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments. - -import logging -import os -import sys -import types -from dataclasses import dataclass, field -from typing import Optional - -# Run before `import torch` to pin CUDA devices first -from pippy import run_pippy - -import datasets -import numpy as np -import pippy -import torch -from datasets import load_dataset - -import evaluate -import transformers -from transformers import ( - AutoConfig, - AutoModelForSeq2SeqLM, - AutoTokenizer, - DataCollatorForSeq2Seq, - HfArgumentParser, - M2M100Tokenizer, - MBart50Tokenizer, - MBart50TokenizerFast, - MBartTokenizer, - MBartTokenizerFast, - default_data_collator, - set_seed, -) -from transformers.trainer_utils import get_last_checkpoint -from transformers.utils import check_min_version, send_example_telemetry -from transformers.utils.versions import require_version - -from pippy.hf import PiPPySeq2SeqTrainingArguments, PiPPySeq2SeqTrainer, bart, t5, PiPPyHFTracer -from pippy.microbatch import TensorChunkSpec, sum_reducer - -# Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.22.0.dev0") - -require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt") - -logger = logging.getLogger(__name__) - -# A list of all multilingual tokenizer which require src_lang and tgt_lang attributes. -MULTILINGUAL_TOKENIZERS = [MBartTokenizer, MBartTokenizerFast, MBart50Tokenizer, MBart50TokenizerFast, M2M100Tokenizer] - - -@dataclass -class ModelArguments: - """ - Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. - """ - - model_name_or_path: str = field( - metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} - ) - config_name: Optional[str] = field( - default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} - ) - tokenizer_name: Optional[str] = field( - default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} - ) - cache_dir: Optional[str] = field( - default=None, - metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"}, - ) - use_fast_tokenizer: bool = field( - default=True, - metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, - ) - model_revision: str = field( - default="main", - metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, - ) - use_auth_token: bool = field( - default=False, - metadata={ - "help": ( - "Will use the token generated when running `huggingface-cli login` (necessary to use this script " - "with private models)." - ) - }, - ) - - -@dataclass -class DataTrainingArguments: - """ - Arguments pertaining to what data we are going to input our model for training and eval. - """ - - source_lang: str = field(default=None, metadata={"help": "Source language id for translation."}) - target_lang: str = field(default=None, metadata={"help": "Target language id for translation."}) - - dataset_name: Optional[str] = field( - default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} - ) - dataset_config_name: Optional[str] = field( - default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} - ) - train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a jsonlines)."}) - validation_file: Optional[str] = field( - default=None, - metadata={ - "help": "An optional input evaluation data file to evaluate the metrics (sacrebleu) on a jsonlines file." - }, - ) - test_file: Optional[str] = field( - default=None, - metadata={"help": "An optional input test data file to evaluate the metrics (sacrebleu) on a jsonlines file."}, - ) - overwrite_cache: bool = field( - default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} - ) - preprocessing_num_workers: Optional[int] = field( - default=None, - metadata={"help": "The number of processes to use for the preprocessing."}, - ) - max_source_length: Optional[int] = field( - default=1024, - metadata={ - "help": ( - "The maximum total input sequence length after tokenization. Sequences longer " - "than this will be truncated, sequences shorter will be padded." - ) - }, - ) - max_target_length: Optional[int] = field( - default=128, - metadata={ - "help": ( - "The maximum total sequence length for target text after tokenization. Sequences longer " - "than this will be truncated, sequences shorter will be padded." - ) - }, - ) - val_max_target_length: Optional[int] = field( - default=None, - metadata={ - "help": ( - "The maximum total sequence length for validation target text after tokenization. Sequences longer " - "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`." - "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used " - "during ``evaluate`` and ``predict``." - ) - }, - ) - pad_to_max_length: bool = field( - default=False, - metadata={ - "help": ( - "Whether to pad all samples to model maximum sentence length. " - "If False, will pad the samples dynamically when batching to the maximum length in the batch. More " - "efficient on GPU but very bad for TPU." - ) - }, - ) - max_train_samples: Optional[int] = field( - default=None, - metadata={ - "help": ( - "For debugging purposes or quicker training, truncate the number of training examples to this " - "value if set." - ) - }, - ) - max_eval_samples: Optional[int] = field( - default=None, - metadata={ - "help": ( - "For debugging purposes or quicker training, truncate the number of evaluation examples to this " - "value if set." - ) - }, - ) - max_predict_samples: Optional[int] = field( - default=None, - metadata={ - "help": ( - "For debugging purposes or quicker training, truncate the number of prediction examples to this " - "value if set." - ) - }, - ) - num_beams: Optional[int] = field( - default=None, - metadata={ - "help": ( - "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, " - "which is used during ``evaluate`` and ``predict``." - ) - }, - ) - ignore_pad_token_for_loss: bool = field( - default=True, - metadata={ - "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not." - }, - ) - source_prefix: Optional[str] = field( - default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."} - ) - forced_bos_token: Optional[str] = field( - default=None, - metadata={ - "help": ( - "The token to force as the first generated token after the :obj:`decoder_start_token_id`.Useful for" - " multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token needs to" - " be the target language token.(Usually it is the target language token)" - ) - }, - ) - - def __post_init__(self): - if self.dataset_name is None and self.train_file is None and self.validation_file is None: - raise ValueError("Need either a dataset name or a training/validation file.") - elif self.source_lang is None or self.target_lang is None: - raise ValueError("Need to specify the source language and the target language.") - - # accepting both json and jsonl file extensions, as - # many jsonlines files actually have a .json extension - valid_extensions = ["json", "jsonl"] - - if self.train_file is not None: - extension = self.train_file.split(".")[-1] - assert extension in valid_extensions, "`train_file` should be a jsonlines file." - if self.validation_file is not None: - extension = self.validation_file.split(".")[-1] - assert extension in valid_extensions, "`validation_file` should be a jsonlines file." - if self.val_max_target_length is None: - self.val_max_target_length = self.max_target_length - - -def main(): - # See all possible arguments in src/transformers/training_args.py - # or by passing the --help flag to this script. - # We now keep distinct sets of args, for a cleaner separation of concerns. - - # =============================================== PiPPy change start =============================================== - parser = HfArgumentParser((ModelArguments, DataTrainingArguments, PiPPySeq2SeqTrainingArguments)) - # ================================================ PiPPy change end ================================================ - if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): - # If we pass only one argument to the script and it's the path to a json file, - # let's parse it to get our arguments. - model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) - else: - model_args, data_args, training_args = parser.parse_args_into_dataclasses() - - # =============================================== PiPPy change start =============================================== - run_pippy(run_master, training_args, model_args, data_args) - - -def run_master(pp_ranks, training_args, model_args, data_args): - # ================================================ PiPPy change end ================================================ - # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The - # information sent is the one passed as arguments along with your Python/PyTorch versions. - send_example_telemetry("run_translation", model_args, data_args) - - # Setup logging - logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", - datefmt="%m/%d/%Y %H:%M:%S", - handlers=[logging.StreamHandler(sys.stdout)], - ) - - log_level = training_args.get_process_log_level() - logger.setLevel(log_level) - datasets.utils.logging.set_verbosity(log_level) - transformers.utils.logging.set_verbosity(log_level) - transformers.utils.logging.enable_default_handler() - transformers.utils.logging.enable_explicit_format() - - # Log on each process the small summary: - logger.warning( - f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" - + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" - ) - logger.info(f"Training/evaluation parameters {training_args}") - - if data_args.source_prefix is None and model_args.model_name_or_path in [ - "t5-small", - "t5-base", - "t5-large", - "t5-3b", - "t5-11b", - ]: - logger.warning( - "You're running a t5 model but didn't provide a source prefix, which is expected, e.g. with " - "`--source_prefix 'translate English to German: ' `" - ) - - # Detecting last checkpoint. - last_checkpoint = None - if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: - last_checkpoint = get_last_checkpoint(training_args.output_dir) - if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: - raise ValueError( - f"Output directory ({training_args.output_dir}) already exists and is not empty. " - "Use --overwrite_output_dir to overcome." - ) - elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: - logger.info( - f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " - "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." - ) - - # Set seed before initializing model. - set_seed(training_args.seed) - - # Get the datasets: you can either provide your own JSON training and evaluation files (see below) - # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ - # (the dataset will be downloaded automatically from the datasets Hub). - # - # For translation, only JSON files are supported, with one field named "translation" containing two keys for the - # source and target languages (unless you adapt what follows). - # - # In distributed training, the load_dataset function guarantee that only one local process can concurrently - # download the dataset. - if data_args.dataset_name is not None: - # Downloading and loading a dataset from the hub. - raw_datasets = load_dataset( - data_args.dataset_name, - data_args.dataset_config_name, - cache_dir=model_args.cache_dir, - use_auth_token=True if model_args.use_auth_token else None, - ) - else: - data_files = {} - if data_args.train_file is not None: - data_files["train"] = data_args.train_file - extension = data_args.train_file.split(".")[-1] - if data_args.validation_file is not None: - data_files["validation"] = data_args.validation_file - extension = data_args.validation_file.split(".")[-1] - if data_args.test_file is not None: - data_files["test"] = data_args.test_file - extension = data_args.test_file.split(".")[-1] - raw_datasets = load_dataset( - extension, - data_files=data_files, - cache_dir=model_args.cache_dir, - use_auth_token=True if model_args.use_auth_token else None, - ) - # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at - # https://huggingface.co/docs/datasets/loading_datasets.html. - - # Load pretrained model and tokenizer - # - # Distributed training: - # The .from_pretrained methods guarantee that only one local process can concurrently - # download model & vocab. - config = AutoConfig.from_pretrained( - model_args.config_name if model_args.config_name else model_args.model_name_or_path, - cache_dir=model_args.cache_dir, - revision=model_args.model_revision, - use_auth_token=True if model_args.use_auth_token else None, - ) - tokenizer = AutoTokenizer.from_pretrained( - model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, - cache_dir=model_args.cache_dir, - use_fast=model_args.use_fast_tokenizer, - revision=model_args.model_revision, - use_auth_token=True if model_args.use_auth_token else None, - ) - model = AutoModelForSeq2SeqLM.from_pretrained( - model_args.model_name_or_path, - from_tf=bool(".ckpt" in model_args.model_name_or_path), - config=config, - cache_dir=model_args.cache_dir, - revision=model_args.model_revision, - use_auth_token=True if model_args.use_auth_token else None, - ) - - model.resize_token_embeddings(len(tokenizer)) - - model.to(training_args.device) - - # =============================================== PiPPy change start =============================================== - logger.info("[PiPPy] Splitting model ...") - num_ranks=len(pp_ranks) - model_name : str = model.__class__.__name__ - if model_name.startswith("T5"): - t5.split(model, num_ranks) - elif model_name.startswith("Bart"): - bart.split(model, num_ranks) - else: - raise ValueError(f"Split method does not exist for model {model_name}") - - # Prepare concrete args in case FX tracing needs them - inputs = ['input_ids', 'decoder_input_ids', 'labels', 'attention_mask'] - concrete_args = pippy.create_default_args(model, - except_keys=inputs) - output_chunk_spec = {'loss': sum_reducer, - 'logits': TensorChunkSpec(0), - 'encoder_last_hidden_state': TensorChunkSpec(0), - } - if model_name.startswith("T5") and model.config.use_cache: - # past_key_values, optional, returned when use_cache=True is passed or when config.use_cache=True. - output_chunk_spec['past_key_values'] = [ - [TensorChunkSpec(0) for _ in range(model.config.num_decoder_layers)] for _ in range(4) - ] - - # Compile into pipeline parallel, distributed model - pipe_mod = pippy.compile( - model, - num_ranks, - num_chunks=training_args.chunks or num_ranks, - ranks=pp_ranks, - tracer=PiPPyHFTracer(), - output_chunk_spec=output_chunk_spec, - checkpoint=bool(training_args.checkpoint), - concrete_args=concrete_args, - ) - pipe_mod.init_data_parallel(dp_group_size=training_args.dp_group_size) - pipe_mod.config = model.config - model = pipe_mod - - # ================================================ PiPPy change end ================================================ - - # Set decoder_start_token_id - if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)): - if isinstance(tokenizer, MBartTokenizer): - model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.target_lang] - else: - model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.target_lang) - - if model.config.decoder_start_token_id is None: - raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined") - - prefix = data_args.source_prefix if data_args.source_prefix is not None else "" - - # Preprocessing the datasets. - # We need to tokenize inputs and targets. - if training_args.do_train: - column_names = raw_datasets["train"].column_names - elif training_args.do_eval: - column_names = raw_datasets["validation"].column_names - elif training_args.do_predict: - column_names = raw_datasets["test"].column_names - else: - logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.") - return - - # For translation we set the codes of our source and target languages (only useful for mBART, the others will - # ignore those attributes). - if isinstance(tokenizer, tuple(MULTILINGUAL_TOKENIZERS)): - assert data_args.target_lang is not None and data_args.source_lang is not None, ( - f"{tokenizer.__class__.__name__} is a multilingual tokenizer which requires --source_lang and " - "--target_lang arguments." - ) - - tokenizer.src_lang = data_args.source_lang - tokenizer.tgt_lang = data_args.target_lang - - # For multilingual translation models like mBART-50 and M2M100 we need to force the target language token - # as the first generated token. We ask the user to explicitly provide this as --forced_bos_token argument. - forced_bos_token_id = ( - tokenizer.lang_code_to_id[data_args.forced_bos_token] if data_args.forced_bos_token is not None else None - ) - model.config.forced_bos_token_id = forced_bos_token_id - - # Get the language codes for input/target. - source_lang = data_args.source_lang.split("_")[0] - target_lang = data_args.target_lang.split("_")[0] - - # Temporarily set max_target_length for training. - max_target_length = data_args.max_target_length - padding = "max_length" if data_args.pad_to_max_length else False - - # =============================================== PiPPy change start =============================================== - def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): - return self._shift_right(labels) - - def _shift_right(self, input_ids): - decoder_start_token_id = self.config.decoder_start_token_id - pad_token_id = self.config.pad_token_id - - assert ( - decoder_start_token_id is not None - ), "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. See T5 docs for more information" - - # shift inputs to the right - shifted_input_ids = input_ids.new_zeros(input_ids.shape) - shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() - shifted_input_ids[..., 0] = decoder_start_token_id - - assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." - # replace possible -100 values in labels by `pad_token_id` - shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) - - assert torch.all(shifted_input_ids >= 0).item(), "Verify that `shifted_input_ids` has only positive values" - - return shifted_input_ids - - model.prepare_decoder_input_ids_from_labels = types.MethodType(prepare_decoder_input_ids_from_labels, model) - model._shift_right = types.MethodType(_shift_right, model) - # ================================================ PiPPy change end ================================================ - - if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"): - logger.warning( - "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for" - f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory" - ) - - def preprocess_function(examples): - inputs = [ex[source_lang] for ex in examples["translation"]] - targets = [ex[target_lang] for ex in examples["translation"]] - inputs = [prefix + inp for inp in inputs] - model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True) - - # Tokenize targets with the `text_target` keyword argument - labels = tokenizer(text_target=targets, max_length=max_target_length, padding=padding, truncation=True) - - # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore - # padding in the loss. - if padding == "max_length" and data_args.ignore_pad_token_for_loss: - labels["input_ids"] = [ - [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"] - ] - - model_inputs["labels"] = labels["input_ids"] - return model_inputs - - if training_args.do_train: - if "train" not in raw_datasets: - raise ValueError("--do_train requires a train dataset") - train_dataset = raw_datasets["train"] - if data_args.max_train_samples is not None: - max_train_samples = min(len(train_dataset), data_args.max_train_samples) - train_dataset = train_dataset.select(range(max_train_samples)) - with training_args.main_process_first(desc="train dataset map pre-processing"): - train_dataset = train_dataset.map( - preprocess_function, - batched=True, - num_proc=data_args.preprocessing_num_workers, - remove_columns=column_names, - load_from_cache_file=not data_args.overwrite_cache, - desc="Running tokenizer on train dataset", - ) - - if training_args.do_eval: - max_target_length = data_args.val_max_target_length - if "validation" not in raw_datasets: - raise ValueError("--do_eval requires a validation dataset") - eval_dataset = raw_datasets["validation"] - if data_args.max_eval_samples is not None: - max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) - eval_dataset = eval_dataset.select(range(max_eval_samples)) - with training_args.main_process_first(desc="validation dataset map pre-processing"): - eval_dataset = eval_dataset.map( - preprocess_function, - batched=True, - num_proc=data_args.preprocessing_num_workers, - remove_columns=column_names, - load_from_cache_file=not data_args.overwrite_cache, - desc="Running tokenizer on validation dataset", - ) - - if training_args.do_predict: - max_target_length = data_args.val_max_target_length - if "test" not in raw_datasets: - raise ValueError("--do_predict requires a test dataset") - predict_dataset = raw_datasets["test"] - if data_args.max_predict_samples is not None: - max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples) - predict_dataset = predict_dataset.select(range(max_predict_samples)) - with training_args.main_process_first(desc="prediction dataset map pre-processing"): - predict_dataset = predict_dataset.map( - preprocess_function, - batched=True, - num_proc=data_args.preprocessing_num_workers, - remove_columns=column_names, - load_from_cache_file=not data_args.overwrite_cache, - desc="Running tokenizer on prediction dataset", - ) - - # Data collator - label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id - if data_args.pad_to_max_length: - data_collator = default_data_collator - else: - data_collator = DataCollatorForSeq2Seq( - tokenizer, - model=model, - label_pad_token_id=label_pad_token_id, - pad_to_multiple_of=8 if training_args.fp16 else None, - ) - - # Metric - metric = evaluate.load("sacrebleu") - - def postprocess_text(preds, labels): - preds = [pred.strip() for pred in preds] - labels = [[label.strip()] for label in labels] - - return preds, labels - - def compute_metrics(eval_preds): - preds, labels = eval_preds - if isinstance(preds, tuple): - preds = preds[0] - decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) - if data_args.ignore_pad_token_for_loss: - # Replace -100 in the labels as we can't decode them. - labels = np.where(labels != -100, labels, tokenizer.pad_token_id) - decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) - - # Some simple post-processing - decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels) - - result = metric.compute(predictions=decoded_preds, references=decoded_labels) - result = {"bleu": result["score"]} - - prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds] - result["gen_len"] = np.mean(prediction_lens) - result = {k: round(v, 4) for k, v in result.items()} - return result - - # Initialize our Trainer - # =============================================== PiPPy change start =============================================== - trainer = PiPPySeq2SeqTrainer( - # ================================================ PiPPy change end ================================================ - model=model, - args=training_args, - train_dataset=train_dataset if training_args.do_train else None, - eval_dataset=eval_dataset if training_args.do_eval else None, - tokenizer=tokenizer, - data_collator=data_collator, - compute_metrics=compute_metrics if training_args.predict_with_generate else None, - ) - # =============================================== PiPPy change start =============================================== - trainer._signature_columns = inputs - # ================================================ PiPPy change end ================================================ - - # Training - if training_args.do_train: - checkpoint = None - if training_args.resume_from_checkpoint is not None: - checkpoint = training_args.resume_from_checkpoint - elif last_checkpoint is not None: - checkpoint = last_checkpoint - train_result = trainer.train(resume_from_checkpoint=checkpoint) - # TODO: overwrite save_model method so that it does not pickle process group - #trainer.save_model() # Saves the tokenizer too for easy upload - - metrics = train_result.metrics - max_train_samples = ( - data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) - ) - metrics["train_samples"] = min(max_train_samples, len(train_dataset)) - - trainer.log_metrics("train", metrics) - trainer.save_metrics("train", metrics) - trainer.save_state() - - # Evaluation - results = {} - max_length = ( - training_args.generation_max_length - if training_args.generation_max_length is not None - else data_args.val_max_target_length - ) - num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams - if training_args.do_eval: - logger.info("*** Evaluate ***") - - metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, metric_key_prefix="eval") - max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) - metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) - - trainer.log_metrics("eval", metrics) - trainer.save_metrics("eval", metrics) - - if training_args.do_predict: - logger.info("*** Predict ***") - - predict_results = trainer.predict( - predict_dataset, metric_key_prefix="predict", max_length=max_length, num_beams=num_beams - ) - metrics = predict_results.metrics - max_predict_samples = ( - data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset) - ) - metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset)) - - trainer.log_metrics("predict", metrics) - trainer.save_metrics("predict", metrics) - - if trainer.is_world_process_zero(): - if training_args.predict_with_generate: - predictions = tokenizer.batch_decode( - predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True - ) - predictions = [pred.strip() for pred in predictions] - output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt") - with open(output_prediction_file, "w", encoding="utf-8") as writer: - writer.write("\n".join(predictions)) - - kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "translation"} - if data_args.dataset_name is not None: - kwargs["dataset_tags"] = data_args.dataset_name - if data_args.dataset_config_name is not None: - kwargs["dataset_args"] = data_args.dataset_config_name - kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}" - else: - kwargs["dataset"] = data_args.dataset_name - - languages = [l for l in [data_args.source_lang, data_args.target_lang] if l is not None] - if len(languages) > 0: - kwargs["language"] = languages - - if training_args.push_to_hub: - trainer.push_to_hub(**kwargs) - else: - trainer.create_model_card(**kwargs) - - return results - - -def _mp_fn(index): - # For xla_spawn (TPUs) - main() - - -if __name__ == "__main__": - main() diff --git a/examples/hf/translation/run_translation_no_trainer.py b/examples/hf/translation/run_translation_no_trainer.py deleted file mode 100644 index a6b0988f6..000000000 --- a/examples/hf/translation/run_translation_no_trainer.py +++ /dev/null @@ -1,743 +0,0 @@ -#!/usr/bin/env python -# coding=utf-8 -# Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Fine-tuning a 🤗 Transformers model on text translation. -""" -# You can also adapt this script on your own text translation task. Pointers for this are left as comments. - -import argparse -import json -import logging -import math -import os -import random -from pathlib import Path - -import datasets -import numpy as np -import torch -from datasets import load_dataset -from torch.utils.data import DataLoader -from tqdm.auto import tqdm - -import evaluate -import transformers -from accelerate import Accelerator -from accelerate.logging import get_logger -from accelerate.utils import set_seed -from huggingface_hub import Repository -from transformers import ( - CONFIG_MAPPING, - MODEL_MAPPING, - AutoConfig, - AutoModelForSeq2SeqLM, - AutoTokenizer, - DataCollatorForSeq2Seq, - MBartTokenizer, - MBartTokenizerFast, - SchedulerType, - default_data_collator, - get_scheduler, -) -from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry -from transformers.utils.versions import require_version - - -# Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.22.0.dev0") - -logger = get_logger(__name__) -require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt") - -# You should update this to your particular problem to have better documentation of `model_type` -MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys()) -MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) - - -# Parsing input arguments -def parse_args(): - - parser = argparse.ArgumentParser(description="Finetune a transformers model on a text classification task") - parser.add_argument( - "--dataset_name", - type=str, - default=None, - help="The name of the dataset to use (via the datasets library).", - ) - - parser.add_argument( - "--predict_with_generate", - type=bool, - default=True, - help="", - ) - parser.add_argument( - "--dataset_config_name", - type=str, - default=None, - help="The configuration name of the dataset to use (via the datasets library).", - ) - parser.add_argument( - "--train_file", type=str, default=None, help="A csv or a json file containing the training data." - ) - - parser.add_argument( - "--num_beams", - type=int, - default=None, - help=( - "Number of beams to use for evaluation. This argument will be " - "passed to ``model.generate``, which is used during ``evaluate`` and ``predict``." - ), - ) - - parser.add_argument( - "--max_source_length", - type=int, - default=1024, - help=( - "The maximum total input sequence length after " - "tokenization.Sequences longer than this will be truncated, sequences shorter will be padded." - ), - ) - parser.add_argument( - "--max_target_length", - type=int, - default=128, - help=( - "The maximum total sequence length for target text after " - "tokenization. Sequences longer than this will be truncated, sequences shorter will be padded." - "during ``evaluate`` and ``predict``." - ), - ) - parser.add_argument( - "--val_max_target_length", - type=int, - default=None, - help=( - "The maximum total sequence length for validation " - "target text after tokenization.Sequences longer than this will be truncated, sequences shorter will be " - "padded. Will default to `max_target_length`.This argument is also used to override the ``max_length`` " - "param of ``model.generate``, which is used during ``evaluate`` and ``predict``." - ), - ) - parser.add_argument( - "--pad_to_max_length", - type=bool, - default=False, - help=( - "Whether to pad all samples to model maximum sentence " - "length. If False, will pad the samples dynamically when batching to the maximum length in the batch. More" - "efficient on GPU but very bad for TPU." - ), - ) - parser.add_argument( - "--validation_file", type=str, default=None, help="A csv or a json file containing the validation data." - ) - parser.add_argument( - "--ignore_pad_token_for_loss", - type=bool, - default=True, - help="Whether to ignore the tokens corresponding to padded labels in the loss computation or not.", - ) - parser.add_argument("--source_lang", type=str, default=None, help="Source language id for translation.") - parser.add_argument("--target_lang", type=str, default=None, help="Target language id for translation.") - parser.add_argument( - "--source_prefix", - type=str, - default=None, - help="A prefix to add before every source text (useful for T5 models).", - ) - parser.add_argument( - "--preprocessing_num_workers", - type=int, - default=None, - help="The number of processes to use for the preprocessing.", - ) - parser.add_argument( - "--overwrite_cache", type=bool, default=None, help="Overwrite the cached training and evaluation sets" - ) - parser.add_argument( - "--max_length", - type=int, - default=128, - help=( - "The maximum total input sequence length after tokenization. Sequences longer than this will be truncated," - " sequences shorter will be padded if `--pad_to_max_lengh` is passed." - ), - ) - parser.add_argument( - "--model_name_or_path", - type=str, - help="Path to pretrained model or model identifier from huggingface.co/models.", - required=False, - ) - parser.add_argument( - "--config_name", - type=str, - default=None, - help="Pretrained config name or path if not the same as model_name", - ) - parser.add_argument( - "--tokenizer_name", - type=str, - default=None, - help="Pretrained tokenizer name or path if not the same as model_name", - ) - parser.add_argument( - "--use_slow_tokenizer", - action="store_true", - help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).", - ) - parser.add_argument( - "--per_device_train_batch_size", - type=int, - default=8, - help="Batch size (per device) for the training dataloader.", - ) - parser.add_argument( - "--per_device_eval_batch_size", - type=int, - default=8, - help="Batch size (per device) for the evaluation dataloader.", - ) - parser.add_argument( - "--learning_rate", - type=float, - default=5e-5, - help="Initial learning rate (after the potential warmup period) to use.", - ) - parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") - parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.") - parser.add_argument( - "--max_train_steps", - type=int, - default=None, - help="Total number of training steps to perform. If provided, overrides num_train_epochs.", - ) - parser.add_argument( - "--gradient_accumulation_steps", - type=int, - default=1, - help="Number of updates steps to accumulate before performing a backward/update pass.", - ) - parser.add_argument( - "--lr_scheduler_type", - type=SchedulerType, - default="linear", - help="The scheduler type to use.", - choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"], - ) - parser.add_argument( - "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler." - ) - parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.") - parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") - parser.add_argument( - "--model_type", - type=str, - default=None, - help="Model type to use if training from scratch.", - choices=MODEL_TYPES, - ) - parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") - parser.add_argument( - "--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`." - ) - parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.") - parser.add_argument( - "--checkpointing_steps", - type=str, - default=None, - help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.", - ) - parser.add_argument( - "--resume_from_checkpoint", - type=str, - default=None, - help="If the training should continue from a checkpoint folder.", - ) - parser.add_argument( - "--with_tracking", - action="store_true", - help="Whether to enable experiment trackers for logging.", - ) - parser.add_argument( - "--report_to", - type=str, - default="all", - help=( - 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,' - ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.' - "Only applicable when `--with_tracking` is passed." - ), - ) - args = parser.parse_args() - - # Sanity checks - - if args.dataset_name is None and args.train_file is None and args.validation_file is None: - raise ValueError("Need either a task name or a training/validation file.") - - if args.train_file is not None: - extension = args.train_file.split(".")[-1] - assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." - if args.validation_file is not None: - extension = args.validation_file.split(".")[-1] - assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." - - if args.push_to_hub: - assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed." - - return args - - -def main(): - # Parse the arguments - args = parse_args() - - # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The - # information sent is the one passed as arguments along with your Python/PyTorch versions. - send_example_telemetry("run_translation_no_trainer", args) - - # Initialize the accelerator. We will let the accelerator handle device placement for us in this example. - # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers - # in the environment - accelerator = ( - Accelerator(log_with=args.report_to, logging_dir=args.output_dir) if args.with_tracking else Accelerator() - ) - - # Make one log on every process with the configuration for debugging. - logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", - datefmt="%m/%d/%Y %H:%M:%S", - level=logging.INFO, - ) - logger.info(accelerator.state, main_process_only=False) - if accelerator.is_local_main_process: - datasets.utils.logging.set_verbosity_warning() - transformers.utils.logging.set_verbosity_info() - else: - datasets.utils.logging.set_verbosity_error() - transformers.utils.logging.set_verbosity_error() - - # If passed along, set the training seed now. - if args.seed is not None: - set_seed(args.seed) - - # Handle the repository creation - if accelerator.is_main_process: - if args.push_to_hub: - if args.hub_model_id is None: - repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) - else: - repo_name = args.hub_model_id - repo = Repository(args.output_dir, clone_from=repo_name) - - with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: - if "step_*" not in gitignore: - gitignore.write("step_*\n") - if "epoch_*" not in gitignore: - gitignore.write("epoch_*\n") - elif args.output_dir is not None: - os.makedirs(args.output_dir, exist_ok=True) - accelerator.wait_for_everyone() - - # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) - # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ - # (the dataset will be downloaded automatically from the datasets Hub). - # - # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called - # 'text' is found. You can easily tweak this behavior (see below). - # - # In distributed training, the load_dataset function guarantee that only one local process can concurrently - # download the dataset. - if args.dataset_name is not None: - # Downloading and loading a dataset from the hub. - raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name) - else: - data_files = {} - if args.train_file is not None: - data_files["train"] = args.train_file - if args.validation_file is not None: - data_files["validation"] = args.validation_file - extension = args.train_file.split(".")[-1] - raw_datasets = load_dataset(extension, data_files=data_files) - # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at - # https://huggingface.co/docs/datasets/loading_datasets.html. - - # Load pretrained model and tokenizer - # - # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently - # download model & vocab. - if args.config_name: - config = AutoConfig.from_pretrained(args.config_name) - elif args.model_name_or_path: - config = AutoConfig.from_pretrained(args.model_name_or_path) - else: - config = CONFIG_MAPPING[args.model_type]() - logger.warning("You are instantiating a new config instance from scratch.") - - if args.tokenizer_name: - tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=not args.use_slow_tokenizer) - elif args.model_name_or_path: - tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer) - else: - raise ValueError( - "You are instantiating a new tokenizer from scratch. This is not supported by this script." - "You can do it from another script, save it, and load it from here, using --tokenizer_name." - ) - - if args.model_name_or_path: - model = AutoModelForSeq2SeqLM.from_pretrained( - args.model_name_or_path, - from_tf=bool(".ckpt" in args.model_name_or_path), - config=config, - ) - else: - logger.info("Training new model from scratch") - model = AutoModelForSeq2SeqLM.from_config(config) - - model.resize_token_embeddings(len(tokenizer)) - - # Set decoder_start_token_id - if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)): - assert ( - args.target_lang is not None and args.source_lang is not None - ), "mBart requires --target_lang and --source_lang" - if isinstance(tokenizer, MBartTokenizer): - model.config.decoder_start_token_id = tokenizer.lang_code_to_id[args.target_lang] - else: - model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(args.target_lang) - - if model.config.decoder_start_token_id is None: - raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined") - - prefix = args.source_prefix if args.source_prefix is not None else "" - - # Preprocessing the datasets. - # First we tokenize all the texts. - column_names = raw_datasets["train"].column_names - - # For translation we set the codes of our source and target languages (only useful for mBART, the others will - # ignore those attributes). - if isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)): - if args.source_lang is not None: - tokenizer.src_lang = args.source_lang - if args.target_lang is not None: - tokenizer.tgt_lang = args.target_lang - - # Get the language codes for input/target. - source_lang = args.source_lang.split("_")[0] - target_lang = args.target_lang.split("_")[0] - - padding = "max_length" if args.pad_to_max_length else False - - # Temporarily set max_target_length for training. - max_target_length = args.max_target_length - padding = "max_length" if args.pad_to_max_length else False - - def preprocess_function(examples): - inputs = [ex[source_lang] for ex in examples["translation"]] - targets = [ex[target_lang] for ex in examples["translation"]] - inputs = [prefix + inp for inp in inputs] - model_inputs = tokenizer(inputs, max_length=args.max_source_length, padding=padding, truncation=True) - - # Tokenize targets with the `text_target` keyword argument - labels = tokenizer(text_target=targets, max_length=max_target_length, padding=padding, truncation=True) - - # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore - # padding in the loss. - if padding == "max_length" and args.ignore_pad_token_for_loss: - labels["input_ids"] = [ - [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"] - ] - - model_inputs["labels"] = labels["input_ids"] - return model_inputs - - with accelerator.main_process_first(): - processed_datasets = raw_datasets.map( - preprocess_function, - batched=True, - num_proc=args.preprocessing_num_workers, - remove_columns=column_names, - load_from_cache_file=not args.overwrite_cache, - desc="Running tokenizer on dataset", - ) - - train_dataset = processed_datasets["train"] - eval_dataset = processed_datasets["validation"] - - # Log a few random samples from the training set: - for index in random.sample(range(len(train_dataset)), 3): - logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") - - # DataLoaders creation: - label_pad_token_id = -100 if args.ignore_pad_token_for_loss else tokenizer.pad_token_id - if args.pad_to_max_length: - # If padding was already done ot max length, we use the default data collator that will just convert everything - # to tensors. - data_collator = default_data_collator - else: - # Otherwise, `DataCollatorWithPadding` will apply dynamic padding for us (by padding to the maximum length of - # the samples passed). When using mixed precision, we add `pad_to_multiple_of=8` to pad all tensors to multiple - # of 8s, which will enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta). - data_collator = DataCollatorForSeq2Seq( - tokenizer, - model=model, - label_pad_token_id=label_pad_token_id, - pad_to_multiple_of=8 if accelerator.use_fp16 else None, - ) - - train_dataloader = DataLoader( - train_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.per_device_train_batch_size - ) - eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size) - - # Optimizer - # Split weights in two groups, one with weight decay and the other not. - no_decay = ["bias", "LayerNorm.weight"] - optimizer_grouped_parameters = [ - { - "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], - "weight_decay": args.weight_decay, - }, - { - "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], - "weight_decay": 0.0, - }, - ] - optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate) - - # Scheduler and math around the number of training steps. - overrode_max_train_steps = False - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - if args.max_train_steps is None: - args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch - overrode_max_train_steps = True - - lr_scheduler = get_scheduler( - name=args.lr_scheduler_type, - optimizer=optimizer, - num_warmup_steps=args.num_warmup_steps, - num_training_steps=args.max_train_steps, - ) - - # Prepare everything with our `accelerator`. - model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare( - model, optimizer, train_dataloader, eval_dataloader, lr_scheduler - ) - - # We need to recalculate our total training steps as the size of the training dataloader may have changed. - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - if overrode_max_train_steps: - args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch - # Afterwards we recalculate our number of training epochs - args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) - # Figure out how many steps we should save the Accelerator states - if hasattr(args.checkpointing_steps, "isdigit"): - checkpointing_steps = args.checkpointing_steps - if args.checkpointing_steps.isdigit(): - checkpointing_steps = int(args.checkpointing_steps) - else: - checkpointing_steps = None - - # We need to initialize the trackers we use, and also store our configuration. - # We initialize the trackers only on main process because `accelerator.log` - # only logs on main process and we don't want empty logs/runs on other processes. - if args.with_tracking: - if accelerator.is_main_process: - experiment_config = vars(args) - # TensorBoard cannot log Enums, need the raw value - experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value - accelerator.init_trackers("translation_no_trainer", experiment_config) - - metric = evaluate.load("sacrebleu") - - def postprocess_text(preds, labels): - preds = [pred.strip() for pred in preds] - labels = [[label.strip()] for label in labels] - - return preds, labels - - # Train! - total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps - - logger.info("***** Running training *****") - logger.info(f" Num examples = {len(train_dataset)}") - logger.info(f" Num Epochs = {args.num_train_epochs}") - logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") - logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") - logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") - logger.info(f" Total optimization steps = {args.max_train_steps}") - # Only show the progress bar once on each machine. - progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) - completed_steps = 0 - starting_epoch = 0 - - # Potentially load in the weights and states from a previous save - if args.resume_from_checkpoint: - if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": - accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}") - accelerator.load_state(args.resume_from_checkpoint) - path = os.path.basename(args.resume_from_checkpoint) - else: - # Get the most recent checkpoint - dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] - dirs.sort(key=os.path.getctime) - path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last - # Extract `epoch_{i}` or `step_{i}` - training_difference = os.path.splitext(path)[0] - - if "epoch" in training_difference: - starting_epoch = int(training_difference.replace("epoch_", "")) + 1 - resume_step = None - else: - resume_step = int(training_difference.replace("step_", "")) - starting_epoch = resume_step // len(train_dataloader) - resume_step -= starting_epoch * len(train_dataloader) - - for epoch in range(starting_epoch, args.num_train_epochs): - model.train() - if args.with_tracking: - total_loss = 0 - for step, batch in enumerate(train_dataloader): - # We need to skip steps until we reach the resumed step - if args.resume_from_checkpoint and epoch == starting_epoch: - if resume_step is not None and step < resume_step: - completed_steps += 1 - continue - outputs = model(**batch) - loss = outputs.loss - # We keep track of the loss at each epoch - if args.with_tracking: - total_loss += loss.detach().float() - loss = loss / args.gradient_accumulation_steps - accelerator.backward(loss) - if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad() - progress_bar.update(1) - completed_steps += 1 - - if isinstance(checkpointing_steps, int): - if completed_steps % checkpointing_steps == 0: - output_dir = f"step_{completed_steps }" - if args.output_dir is not None: - output_dir = os.path.join(args.output_dir, output_dir) - accelerator.save_state(output_dir) - - if completed_steps >= args.max_train_steps: - break - - model.eval() - - if args.val_max_target_length is None: - args.val_max_target_length = args.max_target_length - - gen_kwargs = { - "max_length": args.val_max_target_length if args is not None else config.max_length, - "num_beams": args.num_beams, - } - samples_seen = 0 - for step, batch in enumerate(eval_dataloader): - with torch.no_grad(): - generated_tokens = accelerator.unwrap_model(model).generate( - batch["input_ids"], - attention_mask=batch["attention_mask"], - **gen_kwargs, - ) - - generated_tokens = accelerator.pad_across_processes( - generated_tokens, dim=1, pad_index=tokenizer.pad_token_id - ) - labels = batch["labels"] - if not args.pad_to_max_length: - # If we did not pad to max length, we need to pad the labels too - labels = accelerator.pad_across_processes(batch["labels"], dim=1, pad_index=tokenizer.pad_token_id) - - generated_tokens = accelerator.gather(generated_tokens).cpu().numpy() - labels = accelerator.gather(labels).cpu().numpy() - - if args.ignore_pad_token_for_loss: - # Replace -100 in the labels as we can't decode them. - labels = np.where(labels != -100, labels, tokenizer.pad_token_id) - - decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True) - decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) - - decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels) - - # If we are in a multiprocess environment, the last batch has duplicates - if accelerator.num_processes > 1: - if step == len(eval_dataloader) - 1: - decoded_preds = decoded_preds[: len(eval_dataloader.dataset) - samples_seen] - decoded_labels = decoded_labels[: len(eval_dataloader.dataset) - samples_seen] - else: - samples_seen += len(decoded_labels) - - metric.add_batch(predictions=decoded_preds, references=decoded_labels) - eval_metric = metric.compute() - logger.info({"bleu": eval_metric["score"]}) - - if args.with_tracking: - accelerator.log( - { - "bleu": eval_metric["score"], - "train_loss": total_loss.item() / len(train_dataloader), - "epoch": epoch, - "step": completed_steps, - }, - step=completed_steps, - ) - - if args.push_to_hub and epoch < args.num_train_epochs - 1: - accelerator.wait_for_everyone() - unwrapped_model = accelerator.unwrap_model(model) - unwrapped_model.save_pretrained( - args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save - ) - if accelerator.is_main_process: - tokenizer.save_pretrained(args.output_dir) - repo.push_to_hub( - commit_message=f"Training in progress epoch {epoch}", blocking=False, auto_lfs_prune=True - ) - - if args.checkpointing_steps == "epoch": - output_dir = f"epoch_{epoch}" - if args.output_dir is not None: - output_dir = os.path.join(args.output_dir, output_dir) - accelerator.save_state(output_dir) - - if args.output_dir is not None: - accelerator.wait_for_everyone() - unwrapped_model = accelerator.unwrap_model(model) - unwrapped_model.save_pretrained( - args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save - ) - if accelerator.is_main_process: - tokenizer.save_pretrained(args.output_dir) - if args.push_to_hub: - repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True) - with open(os.path.join(args.output_dir, "all_results.json"), "w") as f: - json.dump({"eval_bleu": eval_metric["score"]}, f) - - -if __name__ == "__main__": - - main() diff --git a/examples/hf/translation/t5-3b_config.json b/examples/hf/translation/t5-3b_config.json deleted file mode 100644 index ff18a0695..000000000 --- a/examples/hf/translation/t5-3b_config.json +++ /dev/null @@ -1,58 +0,0 @@ -{ - "architectures": [ - "T5WithLMHeadModel" - ], - "d_ff": 16384, - "d_kv": 128, - "d_model": 1024, - "decoder_start_token_id": 0, - "dense_act_fn": "relu", - "dropout_rate": 0.1, - "eos_token_id": 1, - "feed_forward_proj": "relu", - "initializer_factor": 1.0, - "is_encoder_decoder": true, - "is_gated_act": false, - "layer_norm_epsilon": 1e-06, - "model_type": "t5", - "n_positions": 512, - "num_decoder_layers": 24, - "num_heads": 32, - "num_layers": 24, - "output_past": true, - "pad_token_id": 0, - "relative_attention_max_distance": 128, - "relative_attention_num_buckets": 32, - "task_specific_params": { - "summarization": { - "early_stopping": true, - "length_penalty": 2.0, - "max_length": 200, - "min_length": 30, - "no_repeat_ngram_size": 3, - "num_beams": 4, - "prefix": "summarize: " - }, - "translation_en_to_de": { - "early_stopping": true, - "max_length": 300, - "num_beams": 4, - "prefix": "translate English to German: " - }, - "translation_en_to_fr": { - "early_stopping": true, - "max_length": 300, - "num_beams": 4, - "prefix": "translate English to French: " - }, - "translation_en_to_ro": { - "early_stopping": true, - "max_length": 300, - "num_beams": 4, - "prefix": "translate English to Romanian: " - } - }, - "transformers_version": "4.22.0.dev0", - "use_cache": false, - "vocab_size": 32128 -} diff --git a/examples/resnet/.gitignore b/examples/resnet/.gitignore deleted file mode 100644 index 1269488f7..000000000 --- a/examples/resnet/.gitignore +++ /dev/null @@ -1 +0,0 @@ -data diff --git a/examples/resnet/local_resnet.py b/examples/resnet/local_resnet.py deleted file mode 100644 index c3c103f1f..000000000 --- a/examples/resnet/local_resnet.py +++ /dev/null @@ -1,69 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import argparse -import os - -import torch -import torch.nn as nn -import torch.optim as optim -from torchvision import datasets, transforms # type: ignore -from resnet import ResNet50 -from tqdm import tqdm # type: ignore - -USE_TQDM = bool(int(os.getenv('USE_TQDM', '1'))) - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument('--max_epochs', type=int, default=10) - parser.add_argument('--batch_size', type=int, default=25) - args = parser.parse_args() - - chunks = 4 - batch_size = args.batch_size * chunks - - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - print("Using device:", device) - - transform = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), - ]) - - train_data = datasets.CIFAR10('./data', train=True, download=True, transform=transform) - valid_data = datasets.CIFAR10('./data', train=False, transform=transform) - - train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=batch_size) - valid_dataloader = torch.utils.data.DataLoader(valid_data, batch_size=batch_size) - - model = ResNet50().to(device) - - criterion = nn.CrossEntropyLoss() - optimizer = optim.Adam(model.parameters()) - - loaders = {"train": train_dataloader, "valid": valid_dataloader} - - for epoch in range(args.max_epochs): - print(f"Epoch: {epoch + 1}") - for k, dataloader in loaders.items(): - epoch_correct = 0 - epoch_all = 0 - for i, (x_batch, y_batch) in enumerate(tqdm(dataloader) if USE_TQDM else dataloader): - x_batch = x_batch.to(device) - y_batch = y_batch.to(device) - if k == "train": - model.train() - optimizer.zero_grad() - outp = model(x_batch) - else: - model.eval() - with torch.no_grad(): - outp = model(x_batch) - preds = outp.argmax(-1) - correct = (preds == y_batch).sum() - all = len(y_batch) - epoch_correct += correct.item() - epoch_all += all - if k == "train": - loss = criterion(outp, y_batch) - loss.backward() - optimizer.step() - print(f"Loader: {k}. Accuracy: {epoch_correct / epoch_all}") diff --git a/examples/resnet/pippy_resnet.py b/examples/resnet/pippy_resnet.py deleted file mode 100644 index 597912880..000000000 --- a/examples/resnet/pippy_resnet.py +++ /dev/null @@ -1,161 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import argparse -import os -from functools import reduce - -import torch -from torch import optim -from torch.nn.functional import cross_entropy -from torchvision import datasets, transforms # type: ignore -from tqdm import tqdm # type: ignore - -import pippy.fx -from pippy import run_pippy -from pippy.IR import MultiUseParameterConfig, Pipe, LossWrapper, PipeSplitWrapper, annotate_split_points -from pippy.PipelineDriver import PipelineDriverFillDrain, PipelineDriver1F1B, PipelineDriverInterleaved1F1B, \ - PipelineDriverBase -from pippy.events import EventsContext -from pippy.microbatch import sum_reducer, TensorChunkSpec -from pippy.visualizer import events_to_json -from resnet import ResNet50 - -PROFILING_ENABLED = True -CHECK_NUMERIC_EQUIVALENCE = True - -schedules = { - 'FillDrain': PipelineDriverFillDrain, - '1F1B': PipelineDriver1F1B, - 'Interleaved1F1B': PipelineDriverInterleaved1F1B, -} - -pippy.fx.Tracer.proxy_buffer_attributes = True - -USE_TQDM = bool(int(os.getenv('USE_TQDM', '1'))) - - -def run_master(_, args): - MULTI_USE_PARAM_CONFIG = MultiUseParameterConfig.REPLICATE if args.replicate else MultiUseParameterConfig.TRANSMIT - print(f'REPLICATE config: {args.replicate} -> {MULTI_USE_PARAM_CONFIG}') - print("Using schedule:", args.schedule) - print("Using device:", args.device) - - number_of_workers = 4 - all_worker_ranks = list(range(1, 1 + number_of_workers)) # exclude master rank = 0 - chunks = len(all_worker_ranks) - batch_size = args.batch_size * chunks - - transform = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), - ]) - - train_data = datasets.CIFAR10('./data', train=True, download=True, transform=transform) - valid_data = datasets.CIFAR10('./data', train=False, transform=transform) - - train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=batch_size) - valid_dataloader = torch.utils.data.DataLoader(valid_data, batch_size=batch_size) - - class OutputLossWrapper(LossWrapper): - def __init__(self, module, loss_fn): - super().__init__(module, loss_fn) - - def forward(self, input, target): - output = self.module(input) - loss = self.loss_fn(output, target) - # Here we use a dict with the "loss" keyword so that PiPPy can automatically find the loss field when - # generating the backward pass - return {"output": output, "loss": loss} - - model = ResNet50() - - annotate_split_points(model, { - 'layer1': PipeSplitWrapper.SplitPoint.END, - 'layer2': PipeSplitWrapper.SplitPoint.END, - 'layer3': PipeSplitWrapper.SplitPoint.END, - }) - - wrapper = OutputLossWrapper(model, cross_entropy) - - pipe = Pipe.from_tracing(wrapper, MULTI_USE_PARAM_CONFIG) - pipe.to(args.device) - - output_chunk_spec = (TensorChunkSpec(0), sum_reducer) - pipe_driver: PipelineDriverBase = schedules[args.schedule](pipe, chunks, - len(all_worker_ranks), - all_ranks=all_worker_ranks, - output_chunk_spec=output_chunk_spec, - _record_mem_dumps=bool(args.record_mem_dumps), - checkpoint=bool(args.checkpoint)) - - optimizer = pipe_driver.instantiate_optimizer(optim.Adam, lr=1e-3, betas=(0.9, 0.999), eps=1e-8) - - loaders = { - "train": train_dataloader, - "valid": valid_dataloader - } - - this_file_name = os.path.splitext(os.path.basename(__file__))[0] - pipe_visualized_filename = f"{this_file_name}_visualized_{args.rank}.json" - batches_events_contexts = [] - - for epoch in range(args.max_epochs): - print(f"Epoch: {epoch + 1}") - for k, dataloader in loaders.items(): - epoch_correct = 0 - epoch_all = 0 - for i, (x_batch, y_batch) in enumerate(tqdm(dataloader) if USE_TQDM else dataloader): - x_batch = x_batch.to(args.device) - y_batch = y_batch.to(args.device) - if k == "train": - pipe_driver.train() - optimizer.zero_grad() - outp, _ = pipe_driver(x_batch, y_batch) - preds = outp.argmax(-1) - correct = (preds == y_batch).sum() - all = len(y_batch) - epoch_correct += correct.item() - epoch_all += all - optimizer.step() - else: - pipe_driver.eval() - with torch.no_grad(): - outp, _ = pipe_driver(x_batch, y_batch) - preds = outp.argmax(-1) - correct = (preds == y_batch).sum() - all = len(y_batch) - epoch_correct += correct.item() - epoch_all += all - - if args.visualize: - batches_events_contexts.append(pipe_driver.retrieve_events()) - print(f"Loader: {k}. Accuracy: {epoch_correct / epoch_all}") - - if args.visualize: - all_events_contexts: EventsContext = reduce(lambda c1, c2: EventsContext().update(c1).update(c2), - batches_events_contexts, EventsContext()) - with open(pipe_visualized_filename, "w") as f: - f.write(events_to_json(all_events_contexts)) - print(f"Saved {pipe_visualized_filename}") - print('Finished') - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - # parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 5))) - parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1))) - parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost')) - parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500')) - - parser.add_argument('--max_epochs', type=int, default=10) - parser.add_argument('--batch_size', type=int, default=10) - - parser.add_argument('-s', '--schedule', type=str, default=list(schedules.keys())[0], choices=schedules.keys()) - parser.add_argument('--replicate', type=int, default=int(os.getenv("REPLICATE", '0'))) - parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available())) - parser.add_argument('--visualize', type=int, default=0, choices=[0, 1]) - parser.add_argument('--record_mem_dumps', type=int, default=0, choices=[0, 1]) - parser.add_argument('--checkpoint', type=int, default=0, choices=[0, 1]) - args = parser.parse_args() - args.world_size = 5 # "This program requires exactly 4 workers + 1 master" - - run_pippy(run_master, args) diff --git a/examples/resnet/pippy_sbatch.sh b/examples/resnet/pippy_sbatch.sh deleted file mode 100755 index ad1a7b769..000000000 --- a/examples/resnet/pippy_sbatch.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/bash -# Copyright (c) Meta Platforms, Inc. and affiliates - -#SBATCH --job-name=mnist_pippy - -#SBATCH --open-mode=append - -#SBATCH --partition=train - -#SBATCH --nodes=1 - -#SBATCH --ntasks-per-node=8 - -#SBATCH --cpus-per-task=12 - -#SBATCH --gpus-per-node=8 - -#SBATCH --time=1:00:00 - -srun --label pippy_wrapper.sh diff --git a/examples/resnet/pippy_wrapper.sh b/examples/resnet/pippy_wrapper.sh deleted file mode 100755 index c1be5c43a..000000000 --- a/examples/resnet/pippy_wrapper.sh +++ /dev/null @@ -1,13 +0,0 @@ -#!/bin/bash -# Copyright (c) Meta Platforms, Inc. and affiliates - -export MASTER_PORT=29500 -export MASTER_ADDR=$(scontrol show hostname ${SLURM_NODELIST} | head -n 1) -export LOCAL_RANK=${SLURM_LOCALID} -export CUDA_VISIBLE_DEVICES=${SLURM_LOCALID} -export WORLD_SIZE=${SLURM_NTASKS} -export RANK=${SLURM_PROCID} - -export USE_TQDM=0 - -python -u pippy_resnet.py --record_mem_dumps=0 --checkpoint=0 diff --git a/examples/resnet/resnet.py b/examples/resnet/resnet.py deleted file mode 100644 index d19d88e41..000000000 --- a/examples/resnet/resnet.py +++ /dev/null @@ -1,150 +0,0 @@ -# https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py - -# MIT License - -# Copyright (c) 2017 liukuang - -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. - -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -'''ResNet in PyTorch. - -For Pre-activation ResNet, see 'preact_resnet.py'. - -Reference: -[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun - Deep Residual Learning for Image Recognition. arXiv:1512.03385 -''' -import torch.nn as nn -import torch.nn.functional as F - - -class BasicBlock(nn.Module): - expansion = 1 - - def __init__(self, in_planes, planes, stride=1): - super(BasicBlock, self).__init__() - self.conv1 = nn.Conv2d( - in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) - self.bn1 = nn.BatchNorm2d(planes) - self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, - stride=1, padding=1, bias=False) - self.bn2 = nn.BatchNorm2d(planes) - - self.shortcut = nn.Sequential() - if stride != 1 or in_planes != self.expansion*planes: - self.shortcut = nn.Sequential( - nn.Conv2d(in_planes, self.expansion*planes, - kernel_size=1, stride=stride, bias=False), - nn.BatchNorm2d(self.expansion*planes) - ) - - def forward(self, x): - out = F.relu(self.bn1(self.conv1(x))) - out = self.bn2(self.conv2(out)) - out += self.shortcut(x) - out = F.relu(out) - return out - - -class Bottleneck(nn.Module): - expansion = 4 - - def __init__(self, in_planes, planes, stride=1): - super(Bottleneck, self).__init__() - self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) - self.bn1 = nn.BatchNorm2d(planes) - self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, - stride=stride, padding=1, bias=False) - self.bn2 = nn.BatchNorm2d(planes) - self.conv3 = nn.Conv2d(planes, self.expansion * - planes, kernel_size=1, bias=False) - self.bn3 = nn.BatchNorm2d(self.expansion*planes) - - self.shortcut = nn.Sequential() - if stride != 1 or in_planes != self.expansion*planes: - self.shortcut = nn.Sequential( - nn.Conv2d(in_planes, self.expansion*planes, - kernel_size=1, stride=stride, bias=False), - nn.BatchNorm2d(self.expansion*planes) - ) - - def forward(self, x): - out = F.relu(self.bn1(self.conv1(x))) - out = F.relu(self.bn2(self.conv2(out))) - out = self.bn3(self.conv3(out)) - out += self.shortcut(x) - out = F.relu(out) - return out - - -class ResNet(nn.Module): - def __init__(self, block, num_blocks, num_classes=10): - super(ResNet, self).__init__() - self.in_planes = 64 - - self.conv1 = nn.Conv2d(3, 64, kernel_size=3, - stride=1, padding=1, bias=False) - self.bn1 = nn.BatchNorm2d(64) - self.relu = nn.ReLU(inplace=True) - self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) - self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) - self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) - self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) - self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) - self.flatten = nn.Flatten() - self.linear = nn.Linear(512*block.expansion, num_classes) - - def _make_layer(self, block, planes, num_blocks, stride): - strides = [stride] + [1]*(num_blocks-1) - layers = [] - for stride in strides: - layers.append(block(self.in_planes, planes, stride)) - self.in_planes = planes * block.expansion - return nn.Sequential(*layers) - - def forward(self, x): - out = self.relu(self.bn1(self.conv1(x))) - out = self.layer1(out) - out = self.layer2(out) - out = self.layer3(out) - out = self.layer4(out) - out = self.avgpool(out) - out = self.flatten(out) # out.view(out.size(0), -1) - out = self.linear(out) - return out - - -def ResNet18(): - return ResNet(BasicBlock, [2, 2, 2, 2]) - - -def ResNet34(): - return ResNet(BasicBlock, [3, 4, 6, 3]) - - -def ResNet50(): - return ResNet(Bottleneck, [3, 4, 6, 3]) - - -def ResNet101(): - return ResNet(Bottleneck, [3, 4, 23, 3]) - - -def ResNet152(): - return ResNet(Bottleneck, [3, 8, 36, 3]) From 4f023e0e1310bd3a5bbc127bce3fe66a0ac47b8c Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Mon, 4 Dec 2023 17:49:11 -0500 Subject: [PATCH 51/96] Migrate T5 example (#881) ## Description This example requires PyTorch PR https://github.com/pytorch/pytorch/pull/114982 to work, because stage 0 and stage 2 seem to be transmitting non-contiguous tensors. ## Test https://gist.github.com/kwen2501/33fa5723496992691f8b1cc7daaadd89 --- examples/hf/pippy_t5.py | 129 ++++++++ .../hf/t5/pippy_sbatch_16_gpus_per_node.sh | 20 -- .../hf/t5/pippy_sbatch_8_gpus_per_node.sh | 20 -- examples/hf/t5/pippy_t5.py | 282 ------------------ examples/hf/t5/pippy_wrapper.sh | 13 - examples/hf/t5/t5_200m_config.json | 25 -- examples/hf/t5/t5_3b_config.json | 29 -- 7 files changed, 129 insertions(+), 389 deletions(-) create mode 100644 examples/hf/pippy_t5.py delete mode 100755 examples/hf/t5/pippy_sbatch_16_gpus_per_node.sh delete mode 100755 examples/hf/t5/pippy_sbatch_8_gpus_per_node.sh delete mode 100644 examples/hf/t5/pippy_t5.py delete mode 100755 examples/hf/t5/pippy_wrapper.sh delete mode 100644 examples/hf/t5/t5_200m_config.json delete mode 100644 examples/hf/t5/t5_3b_config.json diff --git a/examples/hf/pippy_t5.py b/examples/hf/pippy_t5.py new file mode 100644 index 000000000..9fb6ad148 --- /dev/null +++ b/examples/hf/pippy_t5.py @@ -0,0 +1,129 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +# Minimum effort to run this example: +# $ torchrun --nproc-per-node 4 pippy_t5.py + + +import argparse +import os + +import torch +import torch.distributed as dist + +from pippy.IR import Pipe, PipeSplitWrapper, annotate_split_points +from pippy.PipelineStage import PipelineStage + +from transformers import T5ForConditionalGeneration, T5Config + +from hf_utils import generate_inputs_for_model, get_number_of_params + + +def add_split_points(t5, nranks): + # Number of encoder layers: t5.config.num_layers + # Number of decoder layers: t5.config.num_decoder_layers + # 6 encoder layers, 6 decoder layers, 12 layers in total + total_layers = t5.config.num_layers + t5.config.num_decoder_layers + layers_per_rank = (total_layers + nranks - 1) // nranks + print(f"Layers per rank = {layers_per_rank}") + nstages = 1 + # Split encoder + for i in range(1, t5.config.num_layers // layers_per_rank): + annotate_split_points( + t5, {f'encoder.block.{i * layers_per_rank}': PipeSplitWrapper.SplitPoint.BEGINNING}) + nstages += 1 + # Split at the boundary of encoder and decoder + annotate_split_points( + t5, {f'decoder.embed_tokens': PipeSplitWrapper.SplitPoint.BEGINNING}) + nstages += 1 + # Split decoder + for i in range(1, t5.config.num_decoder_layers // layers_per_rank): + annotate_split_points( + t5, {f'decoder.block.{i * layers_per_rank}': PipeSplitWrapper.SplitPoint.BEGINNING}) + nstages += 1 + assert nstages == nranks, f"nstages = {nstages} nranks = {nranks}" + + +def run(args): + # Model configs + config = T5Config() + print("Using device:", args.device) + + # Create model + model_class = T5ForConditionalGeneration + model_name = "T5ForConditionalGeneration" + t5 = model_class(config) + t5.to(args.device) + t5.eval() + if args.rank == 0: + print(t5.config) + print(f"Total number of params = {get_number_of_params(t5) // 10 ** 6}M") + print(t5) + + # Input configs + example_inputs = generate_inputs_for_model( + model_class, t5, model_name, args.batch_size, args.device) + + # Annotate split points + add_split_points(t5, args.world_size) + + # Create pipeline + t5_pipe = Pipe.from_tracing( + t5, + num_chunks=args.chunks, + example_args=(), + example_kwargs=example_inputs, + ) + assert len(list(t5_pipe.split_gm.children())) == args.world_size + if args.rank == 0: + for i, sm in enumerate(t5_pipe.split_gm.children()): + print(f"Pipeline stage {i} {get_number_of_params(sm) // 10 ** 6}M params") + + # Create schedule runtime + stage = PipelineStage( + t5_pipe, + args.rank, + device=args.device, + ) + + # Run + if args.rank == 0: + stage(example_inputs["input_ids"]) + elif args.rank == 1: + stage(example_inputs["decoder_input_ids"]) + elif args.rank == args.world_size - 1: + out = stage() + else: + stage() + + print(f"Rank {args.rank} completes") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 4))) + parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1))) + parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost')) + parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500')) + parser.add_argument('--schedule', type=str, default="FillDrain") + parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available())) + parser.add_argument("--chunks", type=int, default=4) + parser.add_argument('--batch_size', type=int, default=4) + parser.add_argument('--batches', type=int, default=1) + + args = parser.parse_args() + + if args.cuda: + dev_id = args.rank % torch.cuda.device_count() + args.device = torch.device(f"cuda:{dev_id}") + else: + args.device = torch.device("cpu") + + # Init process group + backend = "nccl" if args.cuda else "gloo" + dist.init_process_group( + backend=backend, + rank=args.rank, + world_size=args.world_size, + ) + + run(args) diff --git a/examples/hf/t5/pippy_sbatch_16_gpus_per_node.sh b/examples/hf/t5/pippy_sbatch_16_gpus_per_node.sh deleted file mode 100755 index baf3e8714..000000000 --- a/examples/hf/t5/pippy_sbatch_16_gpus_per_node.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/bash -# Copyright (c) Meta Platforms, Inc. and affiliates - -#SBATCH --job-name=t5_pippy - -#SBATCH --open-mode=append - -#SBATCH --partition=train - -#SBATCH --nodes=1 - -#SBATCH --ntasks-per-node=16 - -#SBATCH --cpus-per-task=6 - -#SBATCH --gpus-per-node=16 - -#SBATCH --time=1:00:00 - -srun --label pippy_wrapper.sh diff --git a/examples/hf/t5/pippy_sbatch_8_gpus_per_node.sh b/examples/hf/t5/pippy_sbatch_8_gpus_per_node.sh deleted file mode 100755 index 986d35d21..000000000 --- a/examples/hf/t5/pippy_sbatch_8_gpus_per_node.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/bash -# Copyright (c) Meta Platforms, Inc. and affiliates - -#SBATCH --job-name=t5_pippy - -#SBATCH --open-mode=append - -#SBATCH --partition=train - -#SBATCH --nodes=1 - -#SBATCH --ntasks-per-node=8 - -#SBATCH --cpus-per-task=12 - -#SBATCH --gpus-per-node=8 - -#SBATCH --time=1:00:00 - -srun --label pippy_wrapper.sh diff --git a/examples/hf/t5/pippy_t5.py b/examples/hf/t5/pippy_t5.py deleted file mode 100644 index 81215b095..000000000 --- a/examples/hf/t5/pippy_t5.py +++ /dev/null @@ -1,282 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import argparse -import os -from functools import reduce - -import torch -from transformers import T5ForConditionalGeneration, T5Config - -import pippy -import pippy.fx -import pippy.ModelSplit -from pippy import run_pippy -from pippy.IR import ( - PipeSplitWrapper, - annotate_split_points, -) -from pippy import split_on_size_threshold, split_into_equal_size -from pippy.events import EventsContext -from pippy.hf import PiPPyHFTracer -from pippy.visualizer import events_to_json - - -pippy.fx.Tracer.proxy_buffer_attributes = True - - -def print_submod_sizes(model_pipe): - total_params = 0 - for i, sm in enumerate(model_pipe.split_gm.children()): - params = get_number_of_params(sm) - print(f"submod_{i} {params // 10 ** 6}M params") - total_params += params - print(f"total {total_params // 10 ** 6}M params") - - -def get_number_of_params(model): - return sum(p.numel() for p in model.parameters() if p.requires_grad) - - -def add_split_points(t5, num_submodules): - if num_submodules == 1: - pass - elif num_submodules == 3: - # assert num_submodules == _add_split_points(t5, [16, 30]) - assert num_submodules == _add_split_points(t5, [17, 31]) - elif num_submodules == 4: - assert num_submodules == _add_split_points(t5, [13, 24, 35]) - elif num_submodules == 7: - # assert num_submodules == _add_split_points(t5, [8, 14, 20, 26, 32, 38]) - assert num_submodules == _add_split_points(t5, [9, 15, 21, 27, 33, 39]) - elif num_submodules == 8: - # assert num_submodules == _add_split_points(t5, [7, 13, 19, 25, 31, 37, 43]) - assert num_submodules == _add_split_points( - t5, [9, 14, 19, 24, 29, 34, 39] - ) - elif num_submodules == 15: - # assert num_submodules == _add_split_points(t5, [3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42]) - assert num_submodules == _add_split_points( - t5, [1, 5, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42] - ) - elif num_submodules == 16: - # assert num_submodules == _add_split_points(t5, [3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 44]) - # assert num_submodules == _add_split_points(t5, [1, 4, 7, 8, 11, 14, 17, 20, 23, 26, 29, 32, 35, 38, 41]) - assert num_submodules == _add_split_points( - t5, [1, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37, 40, 43] - ) - else: - raise ValueError(f"Unsupported num_submodules = {num_submodules}") - - -def _add_split_points(t5, split_indices): - enc_emb = 1 - num_enc = t5.config.num_layers - dec_emb = 1 - num_dec = t5.config.num_decoder_layers - lm_head = 1 - count = 0 - for index in split_indices: - if index < enc_emb: - # index = 0: do nothing - pass - elif index < enc_emb + num_enc: - if index == enc_emb: - # index = 1: insert a split point after `encoder.embed_tokens` before the first encoder - # to put encoder's dropout with the first encoder and not with encoders' embeddings - annotate_split_points( - t5, - {f"encoder.embed_tokens": PipeSplitWrapper.SplitPoint.END}, - ) - else: - # 1 < index < 1 + num_enc: insert a split point before the `index - enc_emb`-th encoder - annotate_split_points( - t5, - { - f"encoder.block.{index - enc_emb}": PipeSplitWrapper.SplitPoint.BEGINNING - }, - ) - count += 1 - elif index < enc_emb + num_enc + dec_emb + num_dec: - # 1 + num_enc <= index < 1 + num_enc + 1 + num_dec - if index == enc_emb + num_enc: - # index = 1 + num_enc: insert a split point before `decoder.embed_tokens` - annotate_split_points( - t5, - { - f"decoder.embed_tokens": PipeSplitWrapper.SplitPoint.BEGINNING - }, - ) - elif index == enc_emb + num_enc + dec_emb: - # index = 1 + num_enc + 1: insert a split point after `decoder.embed_tokens` before the first decoder - # to put decoder's dropout with the first decoder and not with decoders' embeddings - annotate_split_points( - t5, - {f"decoder.embed_tokens": PipeSplitWrapper.SplitPoint.END}, - ) - else: - # 1 + num_enc + 1 < index < 1 + num_enc + 1 + num_dec: - # insert a split point before the `index - (enc_emb + num_enc + dec_emb)`-th encoder - annotate_split_points( - t5, - { - f"decoder.block.{index - (enc_emb + num_enc + dec_emb)}": PipeSplitWrapper.SplitPoint.BEGINNING - }, - ) - count += 1 - elif index < enc_emb + num_enc + dec_emb + num_dec + lm_head: - # index = 1 + num_enc + 1 + num_dec: insert a split point before the `lm_head` - annotate_split_points( - t5, {f"lm_head": PipeSplitWrapper.SplitPoint.BEGINNING} - ) - count += 1 - return count + 1 - - -def resolve_pg_per_stage(pp_rank): - assert pippy.utils.dp_pg_per_pp_rank - return pippy.utils.dp_pg_per_pp_rank[pp_rank] - - -def run_master(pp_ranks, args): - torch.manual_seed(42) - print("Using schedule:", args.schedule) - device = args.device - - t5_config = T5Config.from_pretrained(args.model_config) - t5_config.num_layers = args.num_encoder_layers or t5_config.num_layers - t5_config.num_decoder_layers = ( - args.num_decoder_layers or t5_config.num_decoder_layers - ) - t5_config.use_cache = False # don't output `past_key_values` - t5 = T5ForConditionalGeneration(t5_config) - t5.to(device) - print(t5.config) - print(f"T5 total number of params = {get_number_of_params(t5) // 10 ** 6}M") - - num_ranks = len(pp_ranks) - print(f"number_of_workers = {num_ranks}") - - # Specify auto_split policy for use by `pippy.compile` call later - if args.auto_split == "threshold": - split_policy = split_on_size_threshold(490 * 1e6) - elif args.auto_split == "equal_size": - split_policy = split_into_equal_size(num_ranks) - else: - # Manually insert split points before `pippy.compile` call - add_split_points(t5, num_ranks) - split_policy = None - - chunks = args.chunks or num_ranks - bs = args.batch_size * chunks - seq_length = args.seq_length - - torch.manual_seed(args.rank) - inp = torch.empty(bs, seq_length, dtype=torch.long, device=device).random_( - t5.config.vocab_size - ) - - if args.train: - t5_input_dict = { - "input_ids": inp, - "decoder_input_ids": inp, - "labels": torch.empty( - bs, seq_length, dtype=torch.long, device=device - ).random_(t5.config.vocab_size - 1), - } - else: - t5.eval() - t5_input_dict = {"input_ids": inp, "decoder_input_ids": inp} - - concrete_args = pippy.create_default_args(t5, - except_keys=t5_input_dict.keys()) - - print("Instantiating T5 Pipeline") - pipe_driver = pippy.compile( - t5, - num_ranks=num_ranks, - num_chunks=chunks, - schedule=args.schedule, - split_policy=split_policy, - ranks=pp_ranks, - tracer=PiPPyHFTracer(), - checkpoint=bool(args.checkpoint) if args.train else False, - concrete_args=concrete_args, - ) - print_submod_sizes(pipe_driver.pipe) - - print("Running T5 pipeline.") - this_file_name = os.path.splitext(os.path.basename(__file__))[0] - pipe_visualized_filename = f"{this_file_name}_visualized_{args.rank}.json" - batches_events_contexts = [] - for i in range(args.batches): - pipe_driver(**t5_input_dict) - if args.visualize: - batches_events_contexts.append(pipe_driver.retrieve_events()) - - if args.visualize: - all_events_contexts: EventsContext = reduce( - lambda c1, c2: EventsContext().update(c1).update(c2), - batches_events_contexts, - EventsContext(), - ) - with open(pipe_visualized_filename, "w") as f: - f.write(events_to_json(all_events_contexts)) - print(f"Saved {pipe_visualized_filename}") - print("Finished") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--world_size", type=int, default=int(os.getenv("WORLD_SIZE", 8)) - ) - parser.add_argument("--rank", type=int, default=int(os.getenv("RANK", -1))) - parser.add_argument( - "--master_addr", type=str, default=os.getenv("MASTER_ADDR", "localhost") - ) - parser.add_argument( - "--master_port", type=str, default=os.getenv("MASTER_PORT", "29500") - ) - - model_config = ( - os.path.dirname(os.path.realpath(__file__)) - + "/" - + "t5_200m_config.json" - ) - parser.add_argument("--model_config", type=str, default=model_config) - parser.add_argument("--batch_size", type=int, default=1) - parser.add_argument("--batches", type=int, default=1) - parser.add_argument("--chunks", type=int, default=None) - parser.add_argument("--seq_length", type=int, default=16) - - parser.add_argument("--num_encoder_layers", type=int, default=None) - parser.add_argument("--num_decoder_layers", type=int, default=None) - - parser.add_argument( - "-s", - "--schedule", - type=str, - default="FillDrain", - ) - parser.add_argument( - "--cuda", type=int, default=int(torch.cuda.is_available()) - ) - parser.add_argument("--visualize", type=int, default=1, choices=[0, 1]) - parser.add_argument("--checkpoint", type=int, default=1, choices=[0, 1]) - parser.add_argument("--pp_group_size", type=int, default=8) - parser.add_argument("--auto_split", type=str, default=None) - parser.add_argument( - "--train", - type=int, - default=1, - choices=[0, 1], - help="Choose the mode to run, 1: train, 0: inference", - ) - args = parser.parse_args() - - if (args.pp_group_size > args.world_size): - args.pp_group_size = args.world_size - assert args.world_size % args.pp_group_size == 0 - - args.dp_group_size = args.world_size // args.pp_group_size - - run_pippy(run_master, args) diff --git a/examples/hf/t5/pippy_wrapper.sh b/examples/hf/t5/pippy_wrapper.sh deleted file mode 100755 index a1667fb9b..000000000 --- a/examples/hf/t5/pippy_wrapper.sh +++ /dev/null @@ -1,13 +0,0 @@ -#!/bin/bash -# Copyright (c) Meta Platforms, Inc. and affiliates - -export MASTER_PORT=29500 -export MASTER_ADDR=$(scontrol show hostname ${SLURM_NODELIST} | head -n 1) -export LOCAL_RANK=${SLURM_LOCALID} -export CUDA_VISIBLE_DEVICES=${SLURM_LOCALID} -export WORLD_SIZE=${SLURM_NTASKS} -export RANK=${SLURM_PROCID} - -python -u pippy_t5.py \ - --model_config=t5_3b_config.json \ - --checkpoint=1 diff --git a/examples/hf/t5/t5_200m_config.json b/examples/hf/t5/t5_200m_config.json deleted file mode 100644 index d5a6daef1..000000000 --- a/examples/hf/t5/t5_200m_config.json +++ /dev/null @@ -1,25 +0,0 @@ -{ - "architectures": [ - "T5WithLMHeadModel" - ], - "pad_token_id": 0, - "decoder_start_token_id": 0, - "bos_token_id": 1, - "eos_token_id": 2, - - "num_layers": 6, - "num_decoder_layers": 36, - - "feed_forward_proj": "relu", - "initializer_factor": 1.0, - "is_encoder_decoder": true, - "layer_norm_epsilon": 1e-06, - "model_type": "t5", - "n_positions": 768, - "dropout_rate": 0.1, - "output_past": true, - "relative_attention_num_buckets": 32, - "transformers_version": "4.16.2", - "use_cache": true, - "vocab_size": 32100 -} diff --git a/examples/hf/t5/t5_3b_config.json b/examples/hf/t5/t5_3b_config.json deleted file mode 100644 index 308d7659e..000000000 --- a/examples/hf/t5/t5_3b_config.json +++ /dev/null @@ -1,29 +0,0 @@ -{ - "architectures": [ - "T5WithLMHeadModel" - ], - "pad_token_id": 0, - "decoder_start_token_id": 0, - "bos_token_id": 1, - "eos_token_id": 2, - - "d_ff": 12288, - "d_kv": 128, - "d_model": 2048, - "num_layers": 6, - "num_decoder_layers": 36, - "num_heads": 16, - - "feed_forward_proj": "relu", - "initializer_factor": 1.0, - "is_encoder_decoder": true, - "layer_norm_epsilon": 1e-06, - "model_type": "t5", - "n_positions": 768, - "dropout_rate": 0.1, - "output_past": true, - "relative_attention_num_buckets": 32, - "transformers_version": "4.16.2", - "use_cache": true, - "vocab_size": 32100 -} From d41ebbee93252ce5c243f9cf27b985a14f795d16 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Mon, 4 Dec 2023 17:50:09 -0500 Subject: [PATCH 52/96] Add CamemBert example (#884) ``` CamembertForMaskedLM( (roberta): CamembertModel( (embeddings): CamembertEmbeddings( (word_embeddings): Embedding(30522, 768, padding_idx=1) (position_embeddings): Embedding(512, 768, padding_idx=1) (token_type_embeddings): Embedding(2, 768) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) (encoder): CamembertEncoder( (layer): ModuleList( (0-11): 12 x CamembertLayer( (attention): CamembertAttention( (self): CamembertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): CamembertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): CamembertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (intermediate_act_fn): GELUActivation() ) (output): CamembertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) ) ) ) (lm_head): CamembertLMHead( (dense): Linear(in_features=768, out_features=768, bias=True) (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (decoder): Linear(in_features=768, out_features=30522, bias=True) ) ) ``` --- examples/hf/pippy_camemBert.py | 109 +++++++++++++++++++++++++++++++++ 1 file changed, 109 insertions(+) create mode 100644 examples/hf/pippy_camemBert.py diff --git a/examples/hf/pippy_camemBert.py b/examples/hf/pippy_camemBert.py new file mode 100644 index 000000000..019d64faa --- /dev/null +++ b/examples/hf/pippy_camemBert.py @@ -0,0 +1,109 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +# Minimum effort to run this example: +# $ torchrun --nproc-per-node 4 pippy_camemBert.py + +import argparse +import os + +import torch +import torch.distributed as dist + +from pippy.IR import Pipe, PipeSplitWrapper, annotate_split_points +from pippy.PipelineStage import PipelineStage + +from transformers import CamembertForMaskedLM, CamembertConfig + +from hf_utils import generate_inputs_for_model, get_number_of_params + + +def add_split_points(camembert, nranks): + layers_per_rank = camembert.config.num_hidden_layers // nranks + for i in range(1, nranks): + annotate_split_points( + camembert, {f"roberta.encoder.layer.{i * layers_per_rank}": PipeSplitWrapper.SplitPoint.BEGINNING}) + + +def run(args): + # Model configs + config = CamembertConfig() + print("Using device:", args.device) + + # Create model + model_class = CamembertForMaskedLM + model_name = "CamembertForMaskedLM" + camembert = model_class(config) + camembert.to(args.device) + camembert.eval() + if args.rank == 0: + print(camembert.config) + print(f"Total number of params = {get_number_of_params(camembert) // 10 ** 6}M") + print(camembert) + + # Input configs + example_inputs = generate_inputs_for_model( + model_class, camembert, model_name, args.batch_size, args.device) + input_ids = example_inputs["input_ids"] + + # Annotate split points + add_split_points(camembert, args.world_size) + + # Create pipeline + camembert_pipe = Pipe.from_tracing( + camembert, + num_chunks=args.chunks, + example_args=(input_ids, ), + ) + nstages = len(list(camembert_pipe.split_gm.children())) + assert nstages == args.world_size, f"nstages = {nstages} nranks = {args.world_size}" + if args.rank == 0: + for i, sm in enumerate(camembert_pipe.split_gm.children()): + print(f"Pipeline stage {i} {get_number_of_params(sm) // 10 ** 6}M params") + + # Create schedule runtime + stage = PipelineStage( + camembert_pipe, + args.rank, + device=args.device, + ) + + # Run + if args.rank == 0: + stage(input_ids) + elif args.rank == args.world_size - 1: + out = stage() + else: + stage() + + print(f"Rank {args.rank} completes") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 4))) + parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1))) + parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost')) + parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500')) + parser.add_argument('--schedule', type=str, default="FillDrain") + parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available())) + parser.add_argument("--chunks", type=int, default=4) + parser.add_argument('--batch_size', type=int, default=4) + parser.add_argument('--batches', type=int, default=1) + + args = parser.parse_args() + + if args.cuda: + dev_id = args.rank % torch.cuda.device_count() + args.device = torch.device(f"cuda:{dev_id}") + else: + args.device = torch.device("cpu") + + # Init process group + backend = "nccl" if args.cuda else "gloo" + dist.init_process_group( + backend=backend, + rank=args.rank, + world_size=args.world_size, + ) + + run(args) From 01018054b8eece4ed97c39d1a05ba064b4f57b6e Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Mon, 4 Dec 2023 17:50:26 -0500 Subject: [PATCH 53/96] Add Deberta example (#885) ``` DebertaForMaskedLM( (deberta): DebertaModel( (embeddings): DebertaEmbeddings( (word_embeddings): Embedding(50265, 768, padding_idx=0) (position_embeddings): Embedding(512, 768) (LayerNorm): DebertaLayerNorm() (dropout): StableDropout() ) (encoder): DebertaEncoder( (layer): ModuleList( (0-11): 12 x DebertaLayer( (attention): DebertaAttention( (self): DisentangledSelfAttention( (in_proj): Linear(in_features=768, out_features=2304, bias=False) (dropout): StableDropout() ) (output): DebertaSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): DebertaLayerNorm() (dropout): StableDropout() ) ) (intermediate): DebertaIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (intermediate_act_fn): GELUActivation() ) (output): DebertaOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): DebertaLayerNorm() (dropout): StableDropout() ) ) ) ) ) (cls): DebertaOnlyMLMHead( (predictions): DebertaLMPredictionHead( (transform): DebertaPredictionHeadTransform( (dense): Linear(in_features=768, out_features=768, bias=True) (transform_act_fn): GELUActivation() (LayerNorm): LayerNorm((768,), eps=1e-07, elementwise_affine=True) ) (decoder): Linear(in_features=768, out_features=50265, bias=True) ) ) ) ``` --- examples/hf/pippy_deberta.py | 109 +++++++++++++++++++++++++++++++++++ 1 file changed, 109 insertions(+) create mode 100644 examples/hf/pippy_deberta.py diff --git a/examples/hf/pippy_deberta.py b/examples/hf/pippy_deberta.py new file mode 100644 index 000000000..e7bca39b9 --- /dev/null +++ b/examples/hf/pippy_deberta.py @@ -0,0 +1,109 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +# Minimum effort to run this example: +# $ torchrun --nproc-per-node 4 pippy_deberta.py + +import argparse +import os + +import torch +import torch.distributed as dist + +from pippy.IR import Pipe, PipeSplitWrapper, annotate_split_points +from pippy.PipelineStage import PipelineStage + +from transformers import DebertaForMaskedLM, DebertaConfig + +from hf_utils import generate_inputs_for_model, get_number_of_params + + +def add_split_points(deberta, nranks): + layers_per_rank = deberta.config.num_hidden_layers // nranks + for i in range(1, nranks): + annotate_split_points( + deberta, {f"deberta.encoder.layer.{i * layers_per_rank}": PipeSplitWrapper.SplitPoint.BEGINNING}) + + +def run(args): + # Model configs + config = DebertaConfig() + print("Using device:", args.device) + + # Create model + model_class = DebertaForMaskedLM + model_name = "DebertaForMaskedLM" + deberta = model_class(config) + deberta.to(args.device) + deberta.eval() + if args.rank == 0: + print(deberta.config) + print(f"Total number of params = {get_number_of_params(deberta) // 10 ** 6}M") + print(deberta) + + # Input configs + example_inputs = generate_inputs_for_model( + model_class, deberta, model_name, args.batch_size, args.device) + input_ids = example_inputs["input_ids"] + + # Annotate split points + add_split_points(deberta, args.world_size) + + # Create pipeline + deberta_pipe = Pipe.from_tracing( + deberta, + num_chunks=args.chunks, + example_args=(input_ids, ), + ) + nstages = len(list(deberta_pipe.split_gm.children())) + assert nstages == args.world_size, f"nstages = {nstages} nranks = {args.world_size}" + if args.rank == 0: + for i, sm in enumerate(deberta_pipe.split_gm.children()): + print(f"Pipeline stage {i} {get_number_of_params(sm) // 10 ** 6}M params") + + # Create schedule runtime + stage = PipelineStage( + deberta_pipe, + args.rank, + device=args.device, + ) + + # Run + if args.rank == 0: + stage(input_ids) + elif args.rank == args.world_size - 1: + out = stage() + else: + stage() + + print(f"Rank {args.rank} completes") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 4))) + parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1))) + parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost')) + parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500')) + parser.add_argument('--schedule', type=str, default="FillDrain") + parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available())) + parser.add_argument("--chunks", type=int, default=4) + parser.add_argument('--batch_size', type=int, default=32) + parser.add_argument('--batches', type=int, default=1) + + args = parser.parse_args() + + if args.cuda: + dev_id = args.rank % torch.cuda.device_count() + args.device = torch.device(f"cuda:{dev_id}") + else: + args.device = torch.device("cpu") + + # Init process group + backend = "nccl" if args.cuda else "gloo" + dist.init_process_group( + backend=backend, + rank=args.rank, + world_size=args.world_size, + ) + + run(args) From 4ea77e4a57577b079b46a5561abeeeca870aea51 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Mon, 4 Dec 2023 17:50:40 -0500 Subject: [PATCH 54/96] Add DebertaV2 example (#886) ``` DebertaV2ForMaskedLM( (deberta): DebertaV2Model( (embeddings): DebertaV2Embeddings( (word_embeddings): Embedding(50265, 768, padding_idx=0) (position_embeddings): Embedding(512, 768) (LayerNorm): LayerNorm((768,), eps=1e-07, elementwise_affine=True) (dropout): StableDropout() ) (encoder): DebertaV2Encoder( (layer): ModuleList( (0-11): 12 x DebertaV2Layer( (attention): DebertaV2Attention( (self): DisentangledSelfAttention( (query_proj): Linear(in_features=768, out_features=768, bias=True) (key_proj): Linear(in_features=768, out_features=768, bias=True) (value_proj): Linear(in_features=768, out_features=768, bias=True) (dropout): StableDropout() ) (output): DebertaV2SelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-07, elementwise_affine=True) (dropout): StableDropout() ) ) (intermediate): DebertaV2Intermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (intermediate_act_fn): GELUActivation() ) (output): DebertaV2Output( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-07, elementwise_affine=True) (dropout): StableDropout() ) ) ) ) ) (cls): DebertaV2OnlyMLMHead( (predictions): DebertaV2LMPredictionHead( (transform): DebertaV2PredictionHeadTransform( (dense): Linear(in_features=768, out_features=768, bias=True) (transform_act_fn): GELUActivation() (LayerNorm): LayerNorm((768,), eps=1e-07, elementwise_affine=True) ) (decoder): Linear(in_features=768, out_features=50265, bias=True) ) ) ) ``` --- examples/hf/pippy_debertaV2.py | 109 +++++++++++++++++++++++++++++++++ 1 file changed, 109 insertions(+) create mode 100644 examples/hf/pippy_debertaV2.py diff --git a/examples/hf/pippy_debertaV2.py b/examples/hf/pippy_debertaV2.py new file mode 100644 index 000000000..1c8aff883 --- /dev/null +++ b/examples/hf/pippy_debertaV2.py @@ -0,0 +1,109 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +# Minimum effort to run this example: +# $ torchrun --nproc-per-node 4 pippy_deberta.py + +import argparse +import os + +import torch +import torch.distributed as dist + +from pippy.IR import Pipe, PipeSplitWrapper, annotate_split_points +from pippy.PipelineStage import PipelineStage + +from transformers import DebertaV2ForMaskedLM, DebertaConfig + +from hf_utils import generate_inputs_for_model, get_number_of_params + + +def add_split_points(deberta, nranks): + layers_per_rank = deberta.config.num_hidden_layers // nranks + for i in range(1, nranks): + annotate_split_points( + deberta, {f"deberta.encoder.layer.{i * layers_per_rank}": PipeSplitWrapper.SplitPoint.BEGINNING}) + + +def run(args): + # Model configs + config = DebertaConfig() + print("Using device:", args.device) + + # Create model + model_class = DebertaV2ForMaskedLM + model_name = "DebertaV2ForMaskedLM" + deberta = model_class(config) + deberta.to(args.device) + deberta.eval() + if args.rank == 0: + print(deberta.config) + print(f"Total number of params = {get_number_of_params(deberta) // 10 ** 6}M") + print(deberta) + + # Input configs + example_inputs = generate_inputs_for_model( + model_class, deberta, model_name, args.batch_size, args.device) + input_ids = example_inputs["input_ids"] + + # Annotate split points + add_split_points(deberta, args.world_size) + + # Create pipeline + deberta_pipe = Pipe.from_tracing( + deberta, + num_chunks=args.chunks, + example_args=(input_ids, ), + ) + nstages = len(list(deberta_pipe.split_gm.children())) + assert nstages == args.world_size, f"nstages = {nstages} nranks = {args.world_size}" + if args.rank == 0: + for i, sm in enumerate(deberta_pipe.split_gm.children()): + print(f"Pipeline stage {i} {get_number_of_params(sm) // 10 ** 6}M params") + + # Create schedule runtime + stage = PipelineStage( + deberta_pipe, + args.rank, + device=args.device, + ) + + # Run + if args.rank == 0: + stage(input_ids) + elif args.rank == args.world_size - 1: + out = stage() + else: + stage() + + print(f"Rank {args.rank} completes") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 4))) + parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1))) + parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost')) + parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500')) + parser.add_argument('--schedule', type=str, default="FillDrain") + parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available())) + parser.add_argument("--chunks", type=int, default=4) + parser.add_argument('--batch_size', type=int, default=8) + parser.add_argument('--batches', type=int, default=1) + + args = parser.parse_args() + + if args.cuda: + dev_id = args.rank % torch.cuda.device_count() + args.device = torch.device(f"cuda:{dev_id}") + else: + args.device = torch.device("cpu") + + # Init process group + backend = "nccl" if args.cuda else "gloo" + dist.init_process_group( + backend=backend, + rank=args.rank, + world_size=args.world_size, + ) + + run(args) From 17ba274019b2de57f9910330b6d4a99333e5f7a3 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Mon, 4 Dec 2023 17:51:03 -0500 Subject: [PATCH 55/96] Add DistilBert example (#887) ``` DistilBertForMaskedLM( (activation): GELUActivation() (distilbert): DistilBertModel( (embeddings): Embeddings( (word_embeddings): Embedding(30522, 768, padding_idx=0) (position_embeddings): Embedding(512, 768) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) (transformer): Transformer( (layer): ModuleList( (0-5): 6 x TransformerBlock( (attention): MultiHeadSelfAttention( (dropout): Dropout(p=0.1, inplace=False) (q_lin): Linear(in_features=768, out_features=768, bias=True) (k_lin): Linear(in_features=768, out_features=768, bias=True) (v_lin): Linear(in_features=768, out_features=768, bias=True) (out_lin): Linear(in_features=768, out_features=768, bias=True) ) (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (ffn): FFN( (dropout): Dropout(p=0.1, inplace=False) (lin1): Linear(in_features=768, out_features=3072, bias=True) (lin2): Linear(in_features=3072, out_features=768, bias=True) (activation): GELUActivation() ) (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) ) ) ) ) (vocab_transform): Linear(in_features=768, out_features=768, bias=True) (vocab_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (vocab_projector): Linear(in_features=768, out_features=30522, bias=True) (mlm_loss_fct): CrossEntropyLoss() ) ``` --- examples/hf/pippy_distilBert.py | 113 ++++++++++++++++++++++++++++++++ 1 file changed, 113 insertions(+) create mode 100644 examples/hf/pippy_distilBert.py diff --git a/examples/hf/pippy_distilBert.py b/examples/hf/pippy_distilBert.py new file mode 100644 index 000000000..32f8455b4 --- /dev/null +++ b/examples/hf/pippy_distilBert.py @@ -0,0 +1,113 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +# Minimum effort to run this example: +# $ torchrun --nproc-per-node 4 pippy_distilBert.py + +import argparse +import os + +import torch +import torch.distributed as dist + +from pippy.IR import Pipe, PipeSplitWrapper, annotate_split_points +from pippy.PipelineStage import PipelineStage + +from transformers import DistilBertForMaskedLM, DistilBertConfig + +from hf_utils import generate_inputs_for_model, get_number_of_params + + +def add_split_points(distilbert, nranks): + # The first rank carries the embedding layer + annotate_split_points( + distilbert, {f"distilbert.embeddings": PipeSplitWrapper.SplitPoint.END}) + # 6 Transformer layers divided over the rest 3 ranks + layers_per_rank = distilbert.config.num_hidden_layers // (nranks - 1) + for i in range(1, nranks - 1): + annotate_split_points( + distilbert, {f"distilbert.transformer.layer.{i * layers_per_rank}": PipeSplitWrapper.SplitPoint.BEGINNING}) + + +def run(args): + # Model configs + config = DistilBertConfig() + print("Using device:", args.device) + + # Create model + model_class = DistilBertForMaskedLM + model_name = "DistilBertForMaskedLM" + distilbert = model_class(config) + distilbert.to(args.device) + distilbert.eval() + if args.rank == 0: + print(distilbert.config) + print(f"Total number of params = {get_number_of_params(distilbert) // 10 ** 6}M") + print(distilbert) + + # Input configs + example_inputs = generate_inputs_for_model( + model_class, distilbert, model_name, args.batch_size, args.device) + input_ids = example_inputs["input_ids"] + + # Annotate split points + add_split_points(distilbert, args.world_size) + + # Create pipeline + distilbert_pipe = Pipe.from_tracing( + distilbert, + num_chunks=args.chunks, + example_args=(input_ids, ), + ) + nstages = len(list(distilbert_pipe.split_gm.children())) + assert nstages == args.world_size, f"nstages = {nstages} nranks = {args.world_size}" + if args.rank == 0: + for i, sm in enumerate(distilbert_pipe.split_gm.children()): + print(f"Pipeline stage {i} {get_number_of_params(sm) // 10 ** 6}M params") + + # Create schedule runtime + stage = PipelineStage( + distilbert_pipe, + args.rank, + device=args.device, + ) + + # Run + if args.rank == 0: + stage(input_ids) + elif args.rank == args.world_size - 1: + out = stage() + else: + stage() + + print(f"Rank {args.rank} completes") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 4))) + parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1))) + parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost')) + parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500')) + parser.add_argument('--schedule', type=str, default="FillDrain") + parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available())) + parser.add_argument("--chunks", type=int, default=4) + parser.add_argument('--batch_size', type=int, default=256) + parser.add_argument('--batches', type=int, default=1) + + args = parser.parse_args() + + if args.cuda: + dev_id = args.rank % torch.cuda.device_count() + args.device = torch.device(f"cuda:{dev_id}") + else: + args.device = torch.device("cpu") + + # Init process group + backend = "nccl" if args.cuda else "gloo" + dist.init_process_group( + backend=backend, + rank=args.rank, + world_size=args.world_size, + ) + + run(args) From b1fd30c3e9317ee76d56e8a7d1628530c1cc2e1b Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Mon, 4 Dec 2023 17:51:19 -0500 Subject: [PATCH 56/96] Add Electra example (#888) ``` ElectraForCausalLM( (electra): ElectraModel( (embeddings): ElectraEmbeddings( (word_embeddings): Embedding(30522, 128, padding_idx=0) (position_embeddings): Embedding(512, 128) (token_type_embeddings): Embedding(2, 128) (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) (embeddings_project): Linear(in_features=128, out_features=256, bias=True) (encoder): ElectraEncoder( (layer): ModuleList( (0-11): 12 x ElectraLayer( (attention): ElectraAttention( (self): ElectraSelfAttention( (query): Linear(in_features=256, out_features=256, bias=True) (key): Linear(in_features=256, out_features=256, bias=True) (value): Linear(in_features=256, out_features=256, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): ElectraSelfOutput( (dense): Linear(in_features=256, out_features=256, bias=True) (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): ElectraIntermediate( (dense): Linear(in_features=256, out_features=1024, bias=True) (intermediate_act_fn): GELUActivation() ) (output): ElectraOutput( (dense): Linear(in_features=1024, out_features=256, bias=True) (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) ) ) ) (generator_predictions): ElectraGeneratorPredictions( (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True) (dense): Linear(in_features=256, out_features=128, bias=True) ) (generator_lm_head): Linear(in_features=128, out_features=30522, bias=True) ) ``` --- examples/hf/pippy_electra.py | 109 +++++++++++++++++++++++++++++++++++ 1 file changed, 109 insertions(+) create mode 100644 examples/hf/pippy_electra.py diff --git a/examples/hf/pippy_electra.py b/examples/hf/pippy_electra.py new file mode 100644 index 000000000..5c6a3e7a9 --- /dev/null +++ b/examples/hf/pippy_electra.py @@ -0,0 +1,109 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +# Minimum effort to run this example: +# $ torchrun --nproc-per-node 4 pippy_electra.py + +import argparse +import os + +import torch +import torch.distributed as dist + +from pippy.IR import Pipe, PipeSplitWrapper, annotate_split_points +from pippy.PipelineStage import PipelineStage + +from transformers import ElectraForCausalLM, ElectraConfig + +from hf_utils import generate_inputs_for_model, get_number_of_params + + +def add_split_points(electra, nranks): + layers_per_rank = electra.config.num_hidden_layers // nranks + for i in range(1, nranks): + annotate_split_points( + electra, {f"electra.encoder.layer.{i * layers_per_rank}": PipeSplitWrapper.SplitPoint.BEGINNING}) + + +def run(args): + # Model configs + config = ElectraConfig() + print("Using device:", args.device) + + # Create model + model_class = ElectraForCausalLM + model_name = "ElectraForCausalLM" + electra = model_class(config) + electra.to(args.device) + electra.eval() + if args.rank == 0: + print(electra.config) + print(f"Total number of params = {get_number_of_params(electra) // 10 ** 6}M") + print(electra) + + # Input configs + example_inputs = generate_inputs_for_model( + model_class, electra, model_name, args.batch_size, args.device) + input_ids = example_inputs["input_ids"] + + # Annotate split points + add_split_points(electra, args.world_size) + + # Create pipeline + electra_pipe = Pipe.from_tracing( + electra, + num_chunks=args.chunks, + example_args=(input_ids, ), + ) + nstages = len(list(electra_pipe.split_gm.children())) + assert nstages == args.world_size, f"nstages = {nstages} nranks = {args.world_size}" + if args.rank == 0: + for i, sm in enumerate(electra_pipe.split_gm.children()): + print(f"Pipeline stage {i} {get_number_of_params(sm) // 10 ** 6}M params") + + # Create schedule runtime + stage = PipelineStage( + electra_pipe, + args.rank, + device=args.device, + ) + + # Run + if args.rank == 0: + stage(input_ids) + elif args.rank == args.world_size - 1: + out = stage() + else: + stage() + + print(f"Rank {args.rank} completes") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 4))) + parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1))) + parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost')) + parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500')) + parser.add_argument('--schedule', type=str, default="FillDrain") + parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available())) + parser.add_argument("--chunks", type=int, default=4) + parser.add_argument('--batch_size', type=int, default=64) + parser.add_argument('--batches', type=int, default=1) + + args = parser.parse_args() + + if args.cuda: + dev_id = args.rank % torch.cuda.device_count() + args.device = torch.device(f"cuda:{dev_id}") + else: + args.device = torch.device("cpu") + + # Init process group + backend = "nccl" if args.cuda else "gloo" + dist.init_process_group( + backend=backend, + rank=args.rank, + world_size=args.world_size, + ) + + run(args) From 329caf7992a50194079fc4bed15a0880257a2a26 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Mon, 4 Dec 2023 17:51:31 -0500 Subject: [PATCH 57/96] Add GPTNeo example (#889) ``` GPTNeoForCausalLM( (transformer): GPTNeoModel( (wte): Embedding(50257, 2048) (wpe): Embedding(2048, 2048) (drop): Dropout(p=0.0, inplace=False) (h): ModuleList( (0-23): 24 x GPTNeoBlock( (ln_1): LayerNorm((2048,), eps=1e-05, elementwise_affine=True) (attn): GPTNeoAttention( (attention): GPTNeoSelfAttention( (attn_dropout): Dropout(p=0.0, inplace=False) (resid_dropout): Dropout(p=0.0, inplace=False) (k_proj): Linear(in_features=2048, out_features=2048, bias=False) (v_proj): Linear(in_features=2048, out_features=2048, bias=False) (q_proj): Linear(in_features=2048, out_features=2048, bias=False) (out_proj): Linear(in_features=2048, out_features=2048, bias=True) ) ) (ln_2): LayerNorm((2048,), eps=1e-05, elementwise_affine=True) (mlp): GPTNeoMLP( (c_fc): Linear(in_features=2048, out_features=8192, bias=True) (c_proj): Linear(in_features=8192, out_features=2048, bias=True) (act): NewGELUActivation() (dropout): Dropout(p=0.0, inplace=False) ) ) ) (ln_f): LayerNorm((2048,), eps=1e-05, elementwise_affine=True) ) (lm_head): Linear(in_features=2048, out_features=50257, bias=False) ) ``` --- examples/hf/pippy_gptNeo.py | 109 ++++++++++++++++++++++++++++++++++++ 1 file changed, 109 insertions(+) create mode 100644 examples/hf/pippy_gptNeo.py diff --git a/examples/hf/pippy_gptNeo.py b/examples/hf/pippy_gptNeo.py new file mode 100644 index 000000000..795abff5a --- /dev/null +++ b/examples/hf/pippy_gptNeo.py @@ -0,0 +1,109 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +# Minimum effort to run this example: +# $ torchrun --nproc-per-node 4 pippy_gptNeo.py + +import argparse +import os + +import torch +import torch.distributed as dist + +from pippy.IR import Pipe, PipeSplitWrapper, annotate_split_points +from pippy.PipelineStage import PipelineStage + +from transformers import GPTNeoForCausalLM, GPTNeoConfig + +from hf_utils import generate_inputs_for_model, get_number_of_params + + +def add_split_points(gptneo, nranks): + layers_per_rank = gptneo.config.num_hidden_layers // nranks + for i in range(1, nranks): + annotate_split_points( + gptneo, {f"transformer.h.{i * layers_per_rank}": PipeSplitWrapper.SplitPoint.BEGINNING}) + + +def run(args): + # Model configs + config = GPTNeoConfig() + print("Using device:", args.device) + + # Create model + model_class = GPTNeoForCausalLM + model_name = "GPTNeoForCausalLM" + gptneo = model_class(config) + gptneo.to(args.device) + gptneo.eval() + if args.rank == 0: + print(gptneo.config) + print(f"Total number of params = {get_number_of_params(gptneo) // 10 ** 6}M") + print(gptneo) + + # Input configs + example_inputs = generate_inputs_for_model( + model_class, gptneo, model_name, args.batch_size, args.device) + input_ids = example_inputs["input_ids"] + + # Annotate split points + add_split_points(gptneo, args.world_size) + + # Create pipeline + gptneo_pipe = Pipe.from_tracing( + gptneo, + num_chunks=args.chunks, + example_args=(input_ids, ), + ) + nstages = len(list(gptneo_pipe.split_gm.children())) + assert nstages == args.world_size, f"nstages = {nstages} nranks = {args.world_size}" + if args.rank == 0: + for i, sm in enumerate(gptneo_pipe.split_gm.children()): + print(f"Pipeline stage {i} {get_number_of_params(sm) // 10 ** 6}M params") + + # Create schedule runtime + stage = PipelineStage( + gptneo_pipe, + args.rank, + device=args.device, + ) + + # Run + if args.rank == 0: + stage(input_ids) + elif args.rank == args.world_size - 1: + out = stage() + else: + stage() + + print(f"Rank {args.rank} completes") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 4))) + parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1))) + parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost')) + parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500')) + parser.add_argument('--schedule', type=str, default="FillDrain") + parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available())) + parser.add_argument("--chunks", type=int, default=4) + parser.add_argument('--batch_size', type=int, default=32) + parser.add_argument('--batches', type=int, default=1) + + args = parser.parse_args() + + if args.cuda: + dev_id = args.rank % torch.cuda.device_count() + args.device = torch.device(f"cuda:{dev_id}") + else: + args.device = torch.device("cpu") + + # Init process group + backend = "nccl" if args.cuda else "gloo" + dist.init_process_group( + backend=backend, + rank=args.rank, + world_size=args.world_size, + ) + + run(args) From 6359968af23e7fd7d6ee3f8c872e6de9b1d69d3d Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Mon, 4 Dec 2023 17:51:43 -0500 Subject: [PATCH 58/96] Add FNet example (#890) ``` FNetForMaskedLM( (fnet): FNetModel( (embeddings): FNetEmbeddings( (word_embeddings): Embedding(32000, 768, padding_idx=3) (position_embeddings): Embedding(512, 768) (token_type_embeddings): Embedding(4, 768) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (projection): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (encoder): FNetEncoder( (layer): ModuleList( (0-11): 12 x FNetLayer( (fourier): FNetFourierTransform( (self): FNetBasicFourierTransform() (output): FNetBasicOutput( (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) ) ) (intermediate): FNetIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (intermediate_act_fn): NewGELUActivation() ) (output): FNetOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) ) ) (pooler): FNetPooler( (dense): Linear(in_features=768, out_features=768, bias=True) (activation): Tanh() ) ) (cls): FNetOnlyMLMHead( (predictions): FNetLMPredictionHead( (transform): FNetPredictionHeadTransform( (dense): Linear(in_features=768, out_features=768, bias=True) (transform_act_fn): NewGELUActivation() (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) ) (decoder): Linear(in_features=768, out_features=32000, bias=True) ) ) ) ``` --- examples/hf/pippy_fnet.py | 109 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 109 insertions(+) create mode 100644 examples/hf/pippy_fnet.py diff --git a/examples/hf/pippy_fnet.py b/examples/hf/pippy_fnet.py new file mode 100644 index 000000000..7be90cf70 --- /dev/null +++ b/examples/hf/pippy_fnet.py @@ -0,0 +1,109 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +# Minimum effort to run this example: +# $ torchrun --nproc-per-node 4 pippy_fnet.py + +import argparse +import os + +import torch +import torch.distributed as dist + +from pippy.IR import Pipe, PipeSplitWrapper, annotate_split_points +from pippy.PipelineStage import PipelineStage + +from transformers import FNetForMaskedLM, FNetConfig + +from hf_utils import generate_inputs_for_model, get_number_of_params + + +def add_split_points(fnet, nranks): + layers_per_rank = fnet.config.num_hidden_layers // nranks + for i in range(1, nranks): + annotate_split_points( + fnet, {f"fnet.encoder.layer.{i * layers_per_rank}": PipeSplitWrapper.SplitPoint.BEGINNING}) + + +def run(args): + # Model configs + config = FNetConfig() + print("Using device:", args.device) + + # Create model + model_class = FNetForMaskedLM + model_name = "FNetForMaskedLM" + fnet = model_class(config) + fnet.to(args.device) + fnet.eval() + if args.rank == 0: + print(fnet.config) + print(f"Total number of params = {get_number_of_params(fnet) // 10 ** 6}M") + print(fnet) + + # Input configs + example_inputs = generate_inputs_for_model( + model_class, fnet, model_name, args.batch_size, args.device) + input_ids = example_inputs["input_ids"] + + # Annotate split points + add_split_points(fnet, args.world_size) + + # Create pipeline + fnet_pipe = Pipe.from_tracing( + fnet, + num_chunks=args.chunks, + example_args=(input_ids, ), + ) + nstages = len(list(fnet_pipe.split_gm.children())) + assert nstages == args.world_size, f"nstages = {nstages} nranks = {args.world_size}" + if args.rank == 0: + for i, sm in enumerate(fnet_pipe.split_gm.children()): + print(f"Pipeline stage {i} {get_number_of_params(sm) // 10 ** 6}M params") + + # Create schedule runtime + stage = PipelineStage( + fnet_pipe, + args.rank, + device=args.device, + ) + + # Run + if args.rank == 0: + stage(input_ids) + elif args.rank == args.world_size - 1: + out = stage() + else: + stage() + + print(f"Rank {args.rank} completes") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 4))) + parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1))) + parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost')) + parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500')) + parser.add_argument('--schedule', type=str, default="FillDrain") + parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available())) + parser.add_argument("--chunks", type=int, default=4) + parser.add_argument('--batch_size', type=int, default=32) + parser.add_argument('--batches', type=int, default=1) + + args = parser.parse_args() + + if args.cuda: + dev_id = args.rank % torch.cuda.device_count() + args.device = torch.device(f"cuda:{dev_id}") + else: + args.device = torch.device("cpu") + + # Init process group + backend = "nccl" if args.cuda else "gloo" + dist.init_process_group( + backend=backend, + rank=args.rank, + world_size=args.world_size, + ) + + run(args) From 8ed449077aef86793e438386ec21e159ad5d7d6c Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Mon, 4 Dec 2023 17:51:56 -0500 Subject: [PATCH 59/96] Add LayoutLM example (#891) ``` LayoutLMForMaskedLM( (layoutlm): LayoutLMModel( (embeddings): LayoutLMEmbeddings( (word_embeddings): Embedding(30522, 768, padding_idx=0) (position_embeddings): Embedding(512, 768) (x_position_embeddings): Embedding(1024, 768) (y_position_embeddings): Embedding(1024, 768) (h_position_embeddings): Embedding(1024, 768) (w_position_embeddings): Embedding(1024, 768) (token_type_embeddings): Embedding(2, 768) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) (encoder): LayoutLMEncoder( (layer): ModuleList( (0-11): 12 x LayoutLMLayer( (attention): LayoutLMAttention( (self): LayoutLMSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): LayoutLMSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): LayoutLMIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (intermediate_act_fn): GELUActivation() ) (output): LayoutLMOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) ) ) (pooler): LayoutLMPooler( (dense): Linear(in_features=768, out_features=768, bias=True) (activation): Tanh() ) ) (cls): LayoutLMOnlyMLMHead( (predictions): LayoutLMLMPredictionHead( (transform): LayoutLMPredictionHeadTransform( (dense): Linear(in_features=768, out_features=768, bias=True) (transform_act_fn): GELUActivation() (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) ) (decoder): Linear(in_features=768, out_features=30522, bias=True) ) ) ) ``` --- examples/hf/pippy_layoutLM.py | 116 ++++++++++++++++++++++++++++++++++ 1 file changed, 116 insertions(+) create mode 100644 examples/hf/pippy_layoutLM.py diff --git a/examples/hf/pippy_layoutLM.py b/examples/hf/pippy_layoutLM.py new file mode 100644 index 000000000..efc1a2af9 --- /dev/null +++ b/examples/hf/pippy_layoutLM.py @@ -0,0 +1,116 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +# Minimum effort to run this example: +# $ torchrun --nproc-per-node 4 pippy_layoutLM.py + +import argparse +import os + +import torch +import torch.distributed as dist + +from pippy.IR import Pipe, PipeSplitWrapper, annotate_split_points +from pippy.PipelineStage import PipelineStage + +from transformers import LayoutLMForMaskedLM, LayoutLMConfig + +from hf_utils import generate_inputs_for_model, get_number_of_params + + +def add_split_points(layoutlm, nranks): + # First stage carries the embedding layer + annotate_split_points( + layoutlm, {"layoutlm.embeddings": PipeSplitWrapper.SplitPoint.END}) + # Last stage carries the LM head + annotate_split_points( + layoutlm, {"cls": PipeSplitWrapper.SplitPoint.BEGINNING}) + # 12 Transformer layers divided over the rest 2 ranks + layers_per_rank = layoutlm.config.num_hidden_layers // (nranks - 2) + for i in range(1, nranks - 2): + annotate_split_points( + layoutlm, {f"layoutlm.encoder.layer.{i * layers_per_rank}": PipeSplitWrapper.SplitPoint.BEGINNING}) + + +def run(args): + # Model configs + config = LayoutLMConfig() + print("Using device:", args.device) + + # Create model + model_class = LayoutLMForMaskedLM + model_name = "LayoutLMForMaskedLM" + layoutlm = model_class(config) + layoutlm.to(args.device) + layoutlm.eval() + if args.rank == 0: + print(layoutlm.config) + print(f"Total number of params = {get_number_of_params(layoutlm) // 10 ** 6}M") + print(layoutlm) + + # Input configs + example_inputs = generate_inputs_for_model( + model_class, layoutlm, model_name, args.batch_size, args.device) + input_ids = example_inputs["input_ids"] + + # Annotate split points + add_split_points(layoutlm, args.world_size) + + # Create pipeline + layoutlm_pipe = Pipe.from_tracing( + layoutlm, + num_chunks=args.chunks, + example_args=(input_ids, ), + ) + nstages = len(list(layoutlm_pipe.split_gm.children())) + assert nstages == args.world_size, f"nstages = {nstages} nranks = {args.world_size}" + if args.rank == 0: + for i, sm in enumerate(layoutlm_pipe.split_gm.children()): + print(f"Pipeline stage {i} {get_number_of_params(sm) // 10 ** 6}M params") + + # Create schedule runtime + stage = PipelineStage( + layoutlm_pipe, + args.rank, + device=args.device, + ) + + # Run + if args.rank == 0: + stage(input_ids) + elif args.rank == args.world_size - 1: + out = stage() + else: + stage() + + print(f"Rank {args.rank} completes") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 4))) + parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1))) + parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost')) + parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500')) + parser.add_argument('--schedule', type=str, default="FillDrain") + parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available())) + parser.add_argument("--chunks", type=int, default=4) + parser.add_argument('--batch_size', type=int, default=32) + parser.add_argument('--batches', type=int, default=1) + + args = parser.parse_args() + + if args.cuda: + dev_id = args.rank % torch.cuda.device_count() + args.device = torch.device(f"cuda:{dev_id}") + else: + args.device = torch.device("cpu") + + # Init process group + backend = "nccl" if args.cuda else "gloo" + dist.init_process_group( + backend=backend, + rank=args.rank, + world_size=args.world_size, + ) + + run(args) From e76d6c196ed3fe2577ac87c3481c3f2ada9d57dd Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Mon, 4 Dec 2023 17:55:46 -0500 Subject: [PATCH 60/96] Add MobileBert example (#894) ``` MobileBertForMaskedLM( (mobilebert): MobileBertModel( (embeddings): MobileBertEmbeddings( (word_embeddings): Embedding(30522, 128, padding_idx=0) (position_embeddings): Embedding(512, 512) (token_type_embeddings): Embedding(2, 512) (embedding_transformation): Linear(in_features=384, out_features=512, bias=True) (LayerNorm): NoNorm() (dropout): Dropout(p=0.0, inplace=False) ) (encoder): MobileBertEncoder( (layer): ModuleList( (0-23): 24 x MobileBertLayer( (attention): MobileBertAttention( (self): MobileBertSelfAttention( (query): Linear(in_features=128, out_features=128, bias=True) (key): Linear(in_features=128, out_features=128, bias=True) (value): Linear(in_features=512, out_features=128, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): MobileBertSelfOutput( (dense): Linear(in_features=128, out_features=128, bias=True) (LayerNorm): NoNorm() ) ) (intermediate): MobileBertIntermediate( (dense): Linear(in_features=128, out_features=512, bias=True) (intermediate_act_fn): ReLU() ) (output): MobileBertOutput( (dense): Linear(in_features=512, out_features=128, bias=True) (LayerNorm): NoNorm() (bottleneck): OutputBottleneck( (dense): Linear(in_features=128, out_features=512, bias=True) (LayerNorm): NoNorm() (dropout): Dropout(p=0.0, inplace=False) ) ) (bottleneck): Bottleneck( (input): BottleneckLayer( (dense): Linear(in_features=512, out_features=128, bias=True) (LayerNorm): NoNorm() ) (attention): BottleneckLayer( (dense): Linear(in_features=512, out_features=128, bias=True) (LayerNorm): NoNorm() ) ) (ffn): ModuleList( (0-2): 3 x FFNLayer( (intermediate): MobileBertIntermediate( (dense): Linear(in_features=128, out_features=512, bias=True) (intermediate_act_fn): ReLU() ) (output): FFNOutput( (dense): Linear(in_features=512, out_features=128, bias=True) (LayerNorm): NoNorm() ) ) ) ) ) ) ) (cls): MobileBertOnlyMLMHead( (predictions): MobileBertLMPredictionHead( (transform): MobileBertPredictionHeadTransform( (dense): Linear(in_features=512, out_features=512, bias=True) (transform_act_fn): ReLU() (LayerNorm): LayerNorm((512,), eps=1e-12, elementwise_affine=True) ) (dense): Linear(in_features=30522, out_features=384, bias=False) (decoder): Linear(in_features=128, out_features=30522, bias=True) ) ) ) ``` --- examples/hf/pippy_mobileBert.py | 113 ++++++++++++++++++++++++++++++++ 1 file changed, 113 insertions(+) create mode 100644 examples/hf/pippy_mobileBert.py diff --git a/examples/hf/pippy_mobileBert.py b/examples/hf/pippy_mobileBert.py new file mode 100644 index 000000000..81bcc0e6b --- /dev/null +++ b/examples/hf/pippy_mobileBert.py @@ -0,0 +1,113 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +# Minimum effort to run this example: +# $ torchrun --nproc-per-node 4 pippy_mobileBert.py + +import argparse +import os + +import torch +import torch.distributed as dist + +from pippy.IR import Pipe, PipeSplitWrapper, annotate_split_points +from pippy.PipelineStage import PipelineStage + +from transformers import MobileBertForMaskedLM, MobileBertConfig + +from hf_utils import generate_inputs_for_model, get_number_of_params + + +def add_split_points(mobilebert, nranks): + # The last rank carries LM head + annotate_split_points( + mobilebert, {"cls": PipeSplitWrapper.SplitPoint.BEGINNING}) + # The rest ranks divide the 24 layers + layers_per_rank = mobilebert.config.num_hidden_layers // (nranks - 1) + for i in range(1, nranks - 1): + annotate_split_points( + mobilebert, {f"mobilebert.encoder.layer.{i * layers_per_rank}": PipeSplitWrapper.SplitPoint.BEGINNING}) + + +def run(args): + # Model configs + config = MobileBertConfig() + print("Using device:", args.device) + + # Create model + model_class = MobileBertForMaskedLM + model_name = "MobileBertForMaskedLM" + mobilebert = model_class(config) + mobilebert.to(args.device) + mobilebert.eval() + if args.rank == 0: + print(mobilebert.config) + print(f"Total number of params = {get_number_of_params(mobilebert) // 10 ** 6}M") + print(mobilebert) + + # Input configs + example_inputs = generate_inputs_for_model( + model_class, mobilebert, model_name, args.batch_size, args.device) + input_ids = example_inputs["input_ids"] + + # Annotate split points + add_split_points(mobilebert, args.world_size) + + # Create pipeline + mobilebert_pipe = Pipe.from_tracing( + mobilebert, + num_chunks=args.chunks, + example_args=(input_ids, ), + ) + nstages = len(list(mobilebert_pipe.split_gm.children())) + assert nstages == args.world_size, f"nstages = {nstages} nranks = {args.world_size}" + if args.rank == 0: + for i, sm in enumerate(mobilebert_pipe.split_gm.children()): + print(f"Pipeline stage {i} {get_number_of_params(sm) // 10 ** 6}M params") + + # Create schedule runtime + stage = PipelineStage( + mobilebert_pipe, + args.rank, + device=args.device, + ) + + # Run + if args.rank == 0: + stage(input_ids) + elif args.rank == args.world_size - 1: + out = stage() + else: + stage() + + print(f"Rank {args.rank} completes") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 4))) + parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1))) + parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost')) + parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500')) + parser.add_argument('--schedule', type=str, default="FillDrain") + parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available())) + parser.add_argument("--chunks", type=int, default=4) + parser.add_argument('--batch_size', type=int, default=256) + parser.add_argument('--batches', type=int, default=1) + + args = parser.parse_args() + + if args.cuda: + dev_id = args.rank % torch.cuda.device_count() + args.device = torch.device(f"cuda:{dev_id}") + else: + args.device = torch.device("cpu") + + # Init process group + backend = "nccl" if args.cuda else "gloo" + dist.init_process_group( + backend=backend, + rank=args.rank, + world_size=args.world_size, + ) + + run(args) From 0f4d3e75e4d6499b8545001a5e7bde1f8146ae40 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Mon, 4 Dec 2023 18:20:00 -0500 Subject: [PATCH 61/96] Add MT5 example (#892) ``` MT5ForConditionalGeneration( (shared): Embedding(250112, 512) (encoder): MT5Stack( (embed_tokens): Embedding(250112, 512) (block): ModuleList( (0): MT5Block( (layer): ModuleList( (0): MT5LayerSelfAttention( (SelfAttention): MT5Attention( (q): Linear(in_features=512, out_features=384, bias=False) (k): Linear(in_features=512, out_features=384, bias=False) (v): Linear(in_features=512, out_features=384, bias=False) (o): Linear(in_features=384, out_features=512, bias=False) (relative_attention_bias): Embedding(32, 6) ) (layer_norm): MT5LayerNorm() (dropout): Dropout(p=0.1, inplace=False) ) (1): MT5LayerFF( (DenseReluDense): MT5DenseGatedActDense( (wi_0): Linear(in_features=512, out_features=1024, bias=False) (wi_1): Linear(in_features=512, out_features=1024, bias=False) (wo): Linear(in_features=1024, out_features=512, bias=False) (dropout): Dropout(p=0.1, inplace=False) (act): NewGELUActivation() ) (layer_norm): MT5LayerNorm() (dropout): Dropout(p=0.1, inplace=False) ) ) ) (1-7): 7 x MT5Block( (layer): ModuleList( (0): MT5LayerSelfAttention( (SelfAttention): MT5Attention( (q): Linear(in_features=512, out_features=384, bias=False) (k): Linear(in_features=512, out_features=384, bias=False) (v): Linear(in_features=512, out_features=384, bias=False) (o): Linear(in_features=384, out_features=512, bias=False) ) (layer_norm): MT5LayerNorm() (dropout): Dropout(p=0.1, inplace=False) ) (1): MT5LayerFF( (DenseReluDense): MT5DenseGatedActDense( (wi_0): Linear(in_features=512, out_features=1024, bias=False) (wi_1): Linear(in_features=512, out_features=1024, bias=False) (wo): Linear(in_features=1024, out_features=512, bias=False) (dropout): Dropout(p=0.1, inplace=False) (act): NewGELUActivation() ) (layer_norm): MT5LayerNorm() (dropout): Dropout(p=0.1, inplace=False) ) ) ) ) (final_layer_norm): MT5LayerNorm() (dropout): Dropout(p=0.1, inplace=False) ) (decoder): MT5Stack( (embed_tokens): Embedding(250112, 512) (block): ModuleList( (0): MT5Block( (layer): ModuleList( (0): MT5LayerSelfAttention( (SelfAttention): MT5Attention( (q): Linear(in_features=512, out_features=384, bias=False) (k): Linear(in_features=512, out_features=384, bias=False) (v): Linear(in_features=512, out_features=384, bias=False) (o): Linear(in_features=384, out_features=512, bias=False) (relative_attention_bias): Embedding(32, 6) ) (layer_norm): MT5LayerNorm() (dropout): Dropout(p=0.1, inplace=False) ) (1): MT5LayerCrossAttention( (EncDecAttention): MT5Attention( (q): Linear(in_features=512, out_features=384, bias=False) (k): Linear(in_features=512, out_features=384, bias=False) (v): Linear(in_features=512, out_features=384, bias=False) (o): Linear(in_features=384, out_features=512, bias=False) ) (layer_norm): MT5LayerNorm() (dropout): Dropout(p=0.1, inplace=False) ) (2): MT5LayerFF( (DenseReluDense): MT5DenseGatedActDense( (wi_0): Linear(in_features=512, out_features=1024, bias=False) (wi_1): Linear(in_features=512, out_features=1024, bias=False) (wo): Linear(in_features=1024, out_features=512, bias=False) (dropout): Dropout(p=0.1, inplace=False) (act): NewGELUActivation() ) (layer_norm): MT5LayerNorm() (dropout): Dropout(p=0.1, inplace=False) ) ) ) (1-7): 7 x MT5Block( (layer): ModuleList( (0): MT5LayerSelfAttention( (SelfAttention): MT5Attention( (q): Linear(in_features=512, out_features=384, bias=False) (k): Linear(in_features=512, out_features=384, bias=False) (v): Linear(in_features=512, out_features=384, bias=False) (o): Linear(in_features=384, out_features=512, bias=False) ) (layer_norm): MT5LayerNorm() (dropout): Dropout(p=0.1, inplace=False) ) (1): MT5LayerCrossAttention( (EncDecAttention): MT5Attention( (q): Linear(in_features=512, out_features=384, bias=False) (k): Linear(in_features=512, out_features=384, bias=False) (v): Linear(in_features=512, out_features=384, bias=False) (o): Linear(in_features=384, out_features=512, bias=False) ) (layer_norm): MT5LayerNorm() (dropout): Dropout(p=0.1, inplace=False) ) (2): MT5LayerFF( (DenseReluDense): MT5DenseGatedActDense( (wi_0): Linear(in_features=512, out_features=1024, bias=False) (wi_1): Linear(in_features=512, out_features=1024, bias=False) (wo): Linear(in_features=1024, out_features=512, bias=False) (dropout): Dropout(p=0.1, inplace=False) (act): NewGELUActivation() ) (layer_norm): MT5LayerNorm() (dropout): Dropout(p=0.1, inplace=False) ) ) ) ) (final_layer_norm): MT5LayerNorm() (dropout): Dropout(p=0.1, inplace=False) ) (lm_head): Linear(in_features=512, out_features=250112, bias=False) ) ``` --- examples/hf/pippy_mt5.py | 129 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 129 insertions(+) create mode 100644 examples/hf/pippy_mt5.py diff --git a/examples/hf/pippy_mt5.py b/examples/hf/pippy_mt5.py new file mode 100644 index 000000000..152901528 --- /dev/null +++ b/examples/hf/pippy_mt5.py @@ -0,0 +1,129 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +# Minimum effort to run this example: +# $ torchrun --nproc-per-node 4 pippy_mt5.py + +import argparse +import os + +import torch +import torch.distributed as dist + +from pippy.IR import Pipe, PipeSplitWrapper, annotate_split_points +from pippy.PipelineStage import PipelineStage + +from transformers import MT5ForConditionalGeneration, MT5Config + +from hf_utils import generate_inputs_for_model, get_number_of_params + + +def add_split_points(mt5, nranks): + # Number of encoder layers (mt5.config.num_layers): 8 + # Number of decoder layers (mt5.config.num_decoder_layers): 8 + # 16 layers in total + total_layers = mt5.config.num_layers + mt5.config.num_decoder_layers + layers_per_rank = (total_layers + nranks - 1) // nranks + print(f"Layers per rank = {layers_per_rank}") + nstages = 1 + # Split encoder + for i in range(1, mt5.config.num_layers // layers_per_rank): + annotate_split_points( + mt5, {f'encoder.block.{i * layers_per_rank}': PipeSplitWrapper.SplitPoint.BEGINNING}) + nstages += 1 + # Split at the boundary of encoder and decoder + annotate_split_points( + mt5, {f'decoder.embed_tokens': PipeSplitWrapper.SplitPoint.BEGINNING}) + nstages += 1 + # Split decoder + for i in range(1, mt5.config.num_decoder_layers // layers_per_rank): + annotate_split_points( + mt5, {f'decoder.block.{i * layers_per_rank}': PipeSplitWrapper.SplitPoint.BEGINNING}) + nstages += 1 + assert nstages == nranks, f"nstages = {nstages} nranks = {nranks}" + + +def run(args): + # Model configs + config = MT5Config() + print("Using device:", args.device) + + # Create model + model_class = MT5ForConditionalGeneration + model_name = "MT5ForConditionalGeneration" + mt5 = model_class(config) + mt5.to(args.device) + mt5.eval() + if args.rank == 0: + print(mt5.config) + print(f"Total number of params = {get_number_of_params(mt5) // 10 ** 6}M") + print(mt5) + + # Input configs + example_inputs = generate_inputs_for_model( + model_class, mt5, model_name, args.batch_size, args.device) + + # Annotate split points + add_split_points(mt5, args.world_size) + + # Create pipeline + mt5_pipe = Pipe.from_tracing( + mt5, + num_chunks=args.chunks, + example_args=(), + example_kwargs=example_inputs, + ) + nstages = len(list(mt5_pipe.split_gm.children())) + assert nstages == args.world_size, f"nstages = {nstages} nranks = {args.world_size}" + if args.rank == 0: + for i, sm in enumerate(mt5_pipe.split_gm.children()): + print(f"Pipeline stage {i} {get_number_of_params(sm) // 10 ** 6}M params") + + # Create schedule runtime + stage = PipelineStage( + mt5_pipe, + args.rank, + device=args.device, + ) + + # Run + if args.rank == 0: + stage(example_inputs["input_ids"]) + elif args.rank == 1: + stage(example_inputs["decoder_input_ids"]) + elif args.rank == args.world_size - 1: + out = stage() + else: + stage() + + print(f"Rank {args.rank} completes") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 4))) + parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1))) + parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost')) + parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500')) + parser.add_argument('--schedule', type=str, default="FillDrain") + parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available())) + parser.add_argument("--chunks", type=int, default=4) + parser.add_argument('--batch_size', type=int, default=32) + parser.add_argument('--batches', type=int, default=1) + + args = parser.parse_args() + + if args.cuda: + dev_id = args.rank % torch.cuda.device_count() + args.device = torch.device(f"cuda:{dev_id}") + else: + args.device = torch.device("cpu") + + # Init process group + backend = "nccl" if args.cuda else "gloo" + dist.init_process_group( + backend=backend, + rank=args.rank, + world_size=args.world_size, + ) + + run(args) From ef209f248e75a3ceaae4037fe657adc33c23ba5b Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Mon, 4 Dec 2023 18:20:43 -0500 Subject: [PATCH 62/96] Add MegatronBert example (#893) ``` MegatronBertForCausalLM( (bert): MegatronBertModel( (embeddings): MegatronBertEmbeddings( (word_embeddings): Embedding(29056, 1024, padding_idx=0) (position_embeddings): Embedding(512, 1024) (token_type_embeddings): Embedding(2, 1024) (dropout): Dropout(p=0.1, inplace=False) ) (encoder): MegatronBertEncoder( (layer): ModuleList( (0-23): 24 x MegatronBertLayer( (attention): MegatronBertAttention( (ln): LayerNorm((1024,), eps=1e-12, elementwise_affine=True) (self): MegatronBertSelfAttention( (query): Linear(in_features=1024, out_features=1024, bias=True) (key): Linear(in_features=1024, out_features=1024, bias=True) (value): Linear(in_features=1024, out_features=1024, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): MegatronBertSelfOutput( (dense): Linear(in_features=1024, out_features=1024, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (ln): LayerNorm((1024,), eps=1e-12, elementwise_affine=True) (intermediate): MegatronBertIntermediate( (dense): Linear(in_features=1024, out_features=4096, bias=True) (intermediate_act_fn): GELUActivation() ) (output): MegatronBertOutput( (dense): Linear(in_features=4096, out_features=1024, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) ) ) (ln): LayerNorm((1024,), eps=1e-12, elementwise_affine=True) ) ) (cls): MegatronBertOnlyMLMHead( (predictions): MegatronBertLMPredictionHead( (transform): MegatronBertPredictionHeadTransform( (dense): Linear(in_features=1024, out_features=1024, bias=True) (transform_act_fn): GELUActivation() (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True) ) (decoder): Linear(in_features=1024, out_features=29056, bias=True) ) ) ) ``` --- examples/hf/pippy_megatronBert.py | 109 ++++++++++++++++++++++++++++++ 1 file changed, 109 insertions(+) create mode 100644 examples/hf/pippy_megatronBert.py diff --git a/examples/hf/pippy_megatronBert.py b/examples/hf/pippy_megatronBert.py new file mode 100644 index 000000000..31ca3f58d --- /dev/null +++ b/examples/hf/pippy_megatronBert.py @@ -0,0 +1,109 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +# Minimum effort to run this example: +# $ torchrun --nproc-per-node 4 pippy_megatronBert.py + +import argparse +import os + +import torch +import torch.distributed as dist + +from pippy.IR import Pipe, PipeSplitWrapper, annotate_split_points +from pippy.PipelineStage import PipelineStage + +from transformers import MegatronBertForCausalLM, MegatronBertConfig + +from hf_utils import generate_inputs_for_model, get_number_of_params + + +def add_split_points(bert, nranks): + layers_per_rank = bert.config.num_hidden_layers // nranks + for i in range(1, nranks): + annotate_split_points( + bert, {f"bert.encoder.layer.{i * layers_per_rank}": PipeSplitWrapper.SplitPoint.BEGINNING}) + + +def run(args): + # Model configs + config = MegatronBertConfig() + print("Using device:", args.device) + + # Create model + model_class = MegatronBertForCausalLM + model_name = "MegatronBertForCausalLM" + bert = model_class(config) + bert.to(args.device) + bert.eval() + if args.rank == 0: + print(bert.config) + print(f"Total number of params = {get_number_of_params(bert) // 10 ** 6}M") + print(bert) + + # Input configs + example_inputs = generate_inputs_for_model( + model_class, bert, model_name, args.batch_size, args.device) + input_ids = example_inputs["input_ids"] + + # Annotate split points + add_split_points(bert, args.world_size) + + # Create pipeline + bert_pipe = Pipe.from_tracing( + bert, + num_chunks=args.chunks, + example_args=(input_ids, ), + ) + nstages = len(list(bert_pipe.split_gm.children())) + assert nstages == args.world_size, f"nstages = {nstages} nranks = {args.world_size}" + if args.rank == 0: + for i, sm in enumerate(bert_pipe.split_gm.children()): + print(f"Pipeline stage {i} {get_number_of_params(sm) // 10 ** 6}M params") + + # Create schedule runtime + stage = PipelineStage( + bert_pipe, + args.rank, + device=args.device, + ) + + # Run + if args.rank == 0: + stage(input_ids) + elif args.rank == args.world_size - 1: + out = stage() + else: + stage() + + print(f"Rank {args.rank} completes") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 4))) + parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1))) + parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost')) + parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500')) + parser.add_argument('--schedule', type=str, default="FillDrain") + parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available())) + parser.add_argument("--chunks", type=int, default=4) + parser.add_argument('--batch_size', type=int, default=16) + parser.add_argument('--batches', type=int, default=1) + + args = parser.parse_args() + + if args.cuda: + dev_id = args.rank % torch.cuda.device_count() + args.device = torch.device(f"cuda:{dev_id}") + else: + args.device = torch.device("cpu") + + # Init process group + backend = "nccl" if args.cuda else "gloo" + dist.init_process_group( + backend=backend, + rank=args.rank, + world_size=args.world_size, + ) + + run(args) From 2a698b6e6035d790b59807bdcacd00031e19b835 Mon Sep 17 00:00:00 2001 From: Less Wright Date: Mon, 4 Dec 2023 17:00:53 -0800 Subject: [PATCH 63/96] add multinode_trainer.slurm for easy multinode reference (#878) ## Description Adding a default multinode reference file for slurm, with appropriate updates to ensure Pippy works nicely with AWS EFA. testing: Have used and verified this script with AWS A100's and EFA (multinode). Note that for EFA, having export FI_EFA_SET_CUDA_SYNC_MEMOPS=0 is vital (otherwise, AWS servers will complain about unable to register device), so part of reason to add this to repo for easy reference. --- test/multinode_trainer.slurm | 51 ++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) create mode 100644 test/multinode_trainer.slurm diff --git a/test/multinode_trainer.slurm b/test/multinode_trainer.slurm new file mode 100644 index 000000000..e6b2d4acf --- /dev/null +++ b/test/multinode_trainer.slurm @@ -0,0 +1,51 @@ +#!/bin/bash + +#SBATCH --job-name=looped-bfs-trainer + +#SBATCH --ntasks=2 + +#SBATCH --nodes=2 + +#SBATCH --gpus-per-task=8 + +#SBATCH --cpus-per-task=96 + +#SBATCH --partition=train + + +nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) ) +nodes_array=($nodes) +head_node=${nodes_array[0]} +head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) + +echo Node IP: $head_node_ip +export LOGLEVEL=INFO +# Enable for A100 +export FI_PROVIDER="efa" +# Ensure that P2P is available +# export NCCL_P2P_DISABLE=1 +export NCCL_IB_DISABLE=1 + +# debugging flags (optional) +export NCCL_DEBUG=WARN +export PYTHONFAULTHANDLER=1 +# optional debug settings +# export NCCL_DEBUG=INFO +# NCCL_DEBUG_SUBSYS=INIT,GRAPH,ENV + +export LD_LIBRARY_PATH=/opt/amazon/efa/lib:$LD_LIBRARY_PATH +export LD_LIBRARY_PATH=/usr/local/lib/:$LD_LIBRARY_PATH +export CUDA_LAUNCH_BLOCKING=0 + +# on your cluster you might need these: +# set the network interface +export NCCL_SOCKET_IFNAME="eth0,en,eth,em,bond" +export NCCL_BUFFSIZE=2097152 +#export TORCH_DIST_INIT_BARRIER=1 +export FI_EFA_SET_CUDA_SYNC_MEMOPS=0 + +dcgmi profile --pause +# adjust sbatch --ntasks and sbatch --nodes above and --nnodes below +# to your specific node count, and update target launch file. +srun torchrun --nnodes 2 --nproc_per_node 8 --rdzv_id 101 --rdzv_backend c10d --rdzv_endpoint "$head_node_ip:29500" ./multinode_bfs.py +dcgmi profile --resume From cdcd1b682791850937e9865165af14bba654b6b7 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Tue, 5 Dec 2023 16:08:32 -0500 Subject: [PATCH 64/96] Add MBart example (#900) ``` MBartForCausalLM( (model): MBartDecoderWrapper( (decoder): MBartDecoder( (embed_tokens): Embedding(50265, 1024, padding_idx=1) (embed_positions): MBartLearnedPositionalEmbedding(1026, 1024) (layers): ModuleList( (0-11): 12 x MBartDecoderLayer( (self_attn): MBartAttention( (k_proj): Linear(in_features=1024, out_features=1024, bias=True) (v_proj): Linear(in_features=1024, out_features=1024, bias=True) (q_proj): Linear(in_features=1024, out_features=1024, bias=True) (out_proj): Linear(in_features=1024, out_features=1024, bias=True) ) (activation_fn): GELUActivation() (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True) (encoder_attn): MBartAttention( (k_proj): Linear(in_features=1024, out_features=1024, bias=True) (v_proj): Linear(in_features=1024, out_features=1024, bias=True) (q_proj): Linear(in_features=1024, out_features=1024, bias=True) (out_proj): Linear(in_features=1024, out_features=1024, bias=True) ) (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True) (fc1): Linear(in_features=1024, out_features=4096, bias=True) (fc2): Linear(in_features=4096, out_features=1024, bias=True) (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True) ) ) (layernorm_embedding): LayerNorm((1024,), eps=1e-05, elementwise_affine=True) (layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True) ) ) (lm_head): Linear(in_features=1024, out_features=50265, bias=False) ) ``` --- examples/hf/pippy_mbart.py | 109 +++++++++++++++++++++++++++++++++++++ 1 file changed, 109 insertions(+) create mode 100644 examples/hf/pippy_mbart.py diff --git a/examples/hf/pippy_mbart.py b/examples/hf/pippy_mbart.py new file mode 100644 index 000000000..fdaadddfd --- /dev/null +++ b/examples/hf/pippy_mbart.py @@ -0,0 +1,109 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +# Minimum effort to run this example: +# $ torchrun --nproc-per-node 4 pippy_mbart.py + +import argparse +import os + +import torch +import torch.distributed as dist + +from pippy.IR import Pipe, PipeSplitWrapper, annotate_split_points +from pippy.PipelineStage import PipelineStage + +from transformers import MBartForCausalLM, MBartConfig + +from hf_utils import generate_inputs_for_model, get_number_of_params + + +def add_split_points(mbart, nranks): + layers_per_rank = mbart.config.num_hidden_layers // nranks + for i in range(1, nranks): + annotate_split_points( + mbart, {f"model.decoder.layers.{i * layers_per_rank}": PipeSplitWrapper.SplitPoint.BEGINNING}) + + +def run(args): + # Model configs + config = MBartConfig() + print("Using device:", args.device) + + # Create model + model_class = MBartForCausalLM + model_name = "MBartForCausalLM" + mbart = model_class(config) + mbart.to(args.device) + mbart.eval() + if args.rank == 0: + print(mbart.config) + print(f"Total number of params = {get_number_of_params(mbart) // 10 ** 6}M") + print(mbart) + + # Input configs + example_inputs = generate_inputs_for_model( + model_class, mbart, model_name, args.batch_size, args.device) + input_ids = example_inputs["input_ids"] + + # Annotate split points + add_split_points(mbart, args.world_size) + + # Create pipeline + mbart_pipe = Pipe.from_tracing( + mbart, + num_chunks=args.chunks, + example_args=(input_ids, ), + ) + nstages = len(list(mbart_pipe.split_gm.children())) + assert nstages == args.world_size, f"nstages = {nstages} nranks = {args.world_size}" + if args.rank == 0: + for i, sm in enumerate(mbart_pipe.split_gm.children()): + print(f"Pipeline stage {i} {get_number_of_params(sm) // 10 ** 6}M params") + + # Create schedule runtime + stage = PipelineStage( + mbart_pipe, + args.rank, + device=args.device, + ) + + # Run + if args.rank == 0: + stage(input_ids) + elif args.rank == args.world_size - 1: + out = stage() + else: + stage() + + print(f"Rank {args.rank} completes") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 4))) + parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1))) + parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost')) + parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500')) + parser.add_argument('--schedule', type=str, default="FillDrain") + parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available())) + parser.add_argument("--chunks", type=int, default=4) + parser.add_argument('--batch_size', type=int, default=4) + parser.add_argument('--batches', type=int, default=1) + + args = parser.parse_args() + + if args.cuda: + dev_id = args.rank % torch.cuda.device_count() + args.device = torch.device(f"cuda:{dev_id}") + else: + args.device = torch.device("cpu") + + # Init process group + backend = "nccl" if args.cuda else "gloo" + dist.init_process_group( + backend=backend, + rank=args.rank, + world_size=args.world_size, + ) + + run(args) From 13723f556223be14d82304f674d979ee199e60b2 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Tue, 5 Dec 2023 16:08:50 -0500 Subject: [PATCH 65/96] Add OPT example (#901) ``` OPTForCausalLM( (model): OPTModel( (decoder): OPTDecoder( (embed_tokens): Embedding(50272, 768, padding_idx=1) (embed_positions): OPTLearnedPositionalEmbedding(2050, 768) (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True) (layers): ModuleList( (0-11): 12 x OPTDecoderLayer( (self_attn): OPTAttention( (k_proj): Linear(in_features=768, out_features=768, bias=True) (v_proj): Linear(in_features=768, out_features=768, bias=True) (q_proj): Linear(in_features=768, out_features=768, bias=True) (out_proj): Linear(in_features=768, out_features=768, bias=True) ) (activation_fn): ReLU() (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True) (fc1): Linear(in_features=768, out_features=3072, bias=True) (fc2): Linear(in_features=3072, out_features=768, bias=True) (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True) ) ) ) ) (lm_head): Linear(in_features=768, out_features=50272, bias=False) ) ``` --- examples/hf/pippy_opt.py | 109 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 109 insertions(+) create mode 100644 examples/hf/pippy_opt.py diff --git a/examples/hf/pippy_opt.py b/examples/hf/pippy_opt.py new file mode 100644 index 000000000..a301e2ed3 --- /dev/null +++ b/examples/hf/pippy_opt.py @@ -0,0 +1,109 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +# Minimum effort to run this example: +# $ torchrun --nproc-per-node 4 pippy_opt.py + +import argparse +import os + +import torch +import torch.distributed as dist + +from pippy.IR import Pipe, PipeSplitWrapper, annotate_split_points +from pippy.PipelineStage import PipelineStage + +from transformers import OPTForCausalLM, OPTConfig + +from hf_utils import generate_inputs_for_model, get_number_of_params + + +def add_split_points(opt, nranks): + layers_per_rank = opt.config.num_hidden_layers // nranks + for i in range(1, nranks): + annotate_split_points( + opt, {f"model.decoder.layers.{i * layers_per_rank}": PipeSplitWrapper.SplitPoint.BEGINNING}) + + +def run(args): + # Model configs + config = OPTConfig() + print("Using device:", args.device) + + # Create model + model_class = OPTForCausalLM + model_name = "OPTForCausalLM" + opt = model_class(config) + opt.to(args.device) + opt.eval() + if args.rank == 0: + print(opt.config) + print(f"Total number of params = {get_number_of_params(opt) // 10 ** 6}M") + print(opt) + + # Input configs + example_inputs = generate_inputs_for_model( + model_class, opt, model_name, args.batch_size, args.device) + input_ids = example_inputs["input_ids"] + + # Annotate split points + add_split_points(opt, args.world_size) + + # Create pipeline + opt_pipe = Pipe.from_tracing( + opt, + num_chunks=args.chunks, + example_args=(input_ids, ), + ) + nstages = len(list(opt_pipe.split_gm.children())) + assert nstages == args.world_size, f"nstages = {nstages} nranks = {args.world_size}" + if args.rank == 0: + for i, sm in enumerate(opt_pipe.split_gm.children()): + print(f"Pipeline stage {i} {get_number_of_params(sm) // 10 ** 6}M params") + + # Create schedule runtime + stage = PipelineStage( + opt_pipe, + args.rank, + device=args.device, + ) + + # Run + if args.rank == 0: + stage(input_ids) + elif args.rank == args.world_size - 1: + out = stage() + else: + stage() + + print(f"Rank {args.rank} completes") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 4))) + parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1))) + parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost')) + parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500')) + parser.add_argument('--schedule', type=str, default="FillDrain") + parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available())) + parser.add_argument("--chunks", type=int, default=4) + parser.add_argument('--batch_size', type=int, default=4) + parser.add_argument('--batches', type=int, default=1) + + args = parser.parse_args() + + if args.cuda: + dev_id = args.rank % torch.cuda.device_count() + args.device = torch.device(f"cuda:{dev_id}") + else: + args.device = torch.device("cpu") + + # Init process group + backend = "nccl" if args.cuda else "gloo" + dist.init_process_group( + backend=backend, + rank=args.rank, + world_size=args.world_size, + ) + + run(args) From 3b49a370bfb27c60842fd3eca8e091009ca76341 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Tue, 5 Dec 2023 16:59:03 -0500 Subject: [PATCH 66/96] Add Bart example (#902) ``` BartForCausalLM( (model): BartDecoderWrapper( (decoder): BartDecoder( (embed_tokens): Embedding(50265, 1024, padding_idx=1) (embed_positions): BartLearnedPositionalEmbedding(1026, 1024) (layers): ModuleList( (0-11): 12 x BartDecoderLayer( (self_attn): BartAttention( (k_proj): Linear(in_features=1024, out_features=1024, bias=True) (v_proj): Linear(in_features=1024, out_features=1024, bias=True) (q_proj): Linear(in_features=1024, out_features=1024, bias=True) (out_proj): Linear(in_features=1024, out_features=1024, bias=True) ) (activation_fn): GELUActivation() (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True) (encoder_attn): BartAttention( (k_proj): Linear(in_features=1024, out_features=1024, bias=True) (v_proj): Linear(in_features=1024, out_features=1024, bias=True) (q_proj): Linear(in_features=1024, out_features=1024, bias=True) (out_proj): Linear(in_features=1024, out_features=1024, bias=True) ) (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True) (fc1): Linear(in_features=1024, out_features=4096, bias=True) (fc2): Linear(in_features=4096, out_features=1024, bias=True) (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True) ) ) (layernorm_embedding): LayerNorm((1024,), eps=1e-05, elementwise_affine=True) ) ) (lm_head): Linear(in_features=1024, out_features=50265, bias=False) ) ``` --- examples/hf/pippy_bart.py | 109 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 109 insertions(+) create mode 100644 examples/hf/pippy_bart.py diff --git a/examples/hf/pippy_bart.py b/examples/hf/pippy_bart.py new file mode 100644 index 000000000..0a84b97f0 --- /dev/null +++ b/examples/hf/pippy_bart.py @@ -0,0 +1,109 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +# Minimum effort to run this example: +# $ torchrun --nproc-per-node 4 pippy_bart.py + +import argparse +import os + +import torch +import torch.distributed as dist + +from pippy.IR import Pipe, PipeSplitWrapper, annotate_split_points +from pippy.PipelineStage import PipelineStage + +from transformers import BartForCausalLM, BartConfig + +from hf_utils import generate_inputs_for_model, get_number_of_params + + +def add_split_points(bart, nranks): + layers_per_rank = bart.config.num_hidden_layers // nranks + for i in range(1, nranks): + annotate_split_points( + bart, {f"model.decoder.layers.{i * layers_per_rank}": PipeSplitWrapper.SplitPoint.BEGINNING}) + + +def run(args): + # Model configs + config = BartConfig() + print("Using device:", args.device) + + # Create model + model_class = BartForCausalLM + model_name = "BartForCausalLM" + bart = model_class(config) + bart.to(args.device) + bart.eval() + if args.rank == 0: + print(bart.config) + print(f"Total number of params = {get_number_of_params(bart) // 10 ** 6}M") + print(bart) + + # Input configs + example_inputs = generate_inputs_for_model( + model_class, bart, model_name, args.batch_size, args.device) + input_ids = example_inputs["input_ids"] + + # Annotate split points + add_split_points(bart, args.world_size) + + # Create pipeline + bart_pipe = Pipe.from_tracing( + bart, + num_chunks=args.chunks, + example_args=(input_ids, ), + ) + nstages = len(list(bart_pipe.split_gm.children())) + assert nstages == args.world_size, f"nstages = {nstages} nranks = {args.world_size}" + if args.rank == 0: + for i, sm in enumerate(bart_pipe.split_gm.children()): + print(f"Pipeline stage {i} {get_number_of_params(sm) // 10 ** 6}M params") + + # Create schedule runtime + stage = PipelineStage( + bart_pipe, + args.rank, + device=args.device, + ) + + # Run + if args.rank == 0: + stage(input_ids) + elif args.rank == args.world_size - 1: + out = stage() + else: + stage() + + print(f"Rank {args.rank} completes") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 4))) + parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1))) + parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost')) + parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500')) + parser.add_argument('--schedule', type=str, default="FillDrain") + parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available())) + parser.add_argument("--chunks", type=int, default=4) + parser.add_argument('--batch_size', type=int, default=4) + parser.add_argument('--batches', type=int, default=1) + + args = parser.parse_args() + + if args.cuda: + dev_id = args.rank % torch.cuda.device_count() + args.device = torch.device(f"cuda:{dev_id}") + else: + args.device = torch.device("cpu") + + # Init process group + backend = "nccl" if args.cuda else "gloo" + dist.init_process_group( + backend=backend, + rank=args.rank, + world_size=args.world_size, + ) + + run(args) From 73bbf70ce249e93c7fb6da6d964425e79c8978a8 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Wed, 6 Dec 2023 15:14:34 -0500 Subject: [PATCH 67/96] Mute an assert to coup with pytorch upstream change (#905) After certain torch 2.2.0.dev version, submod_0, submod_1, submod_2 ... are named as submod_0, submod_2, submod_4 ... Muting this assert to let program pass. --- pippy/IR.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/pippy/IR.py b/pippy/IR.py index bb98d2827..f484f5430 100644 --- a/pippy/IR.py +++ b/pippy/IR.py @@ -666,7 +666,13 @@ def _number_and_count_forward_stages(gm: fx.GraphModule): num_stages += 1 # this assert will fail if a split point is inserted before the first layer, which creates empty first submodule - assert all(i in found_idxs for i in range(num_stages)) + # Update: the following assert may fail against some torch versions >= + # 2.2.0, as: + # submod_0, submod_1, submod_2, ... + # may be named as + # submod_0, submod_2, submod_4, ... + # TODO: investigate + # assert all(i in found_idxs for i in range(num_stages)) return num_stages From 7f26cca9fb593b42cd61cf12b0f97925aeff4294 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Wed, 6 Dec 2023 16:38:07 -0500 Subject: [PATCH 68/96] Add Pegasus example (#904) ``` PegasusForCausalLM( (model): PegasusDecoderWrapper( (decoder): PegasusDecoder( (embed_tokens): Embedding(50265, 1024, padding_idx=0) (embed_positions): PegasusSinusoidalPositionalEmbedding(1024, 1024) (layers): ModuleList( (0-11): 12 x PegasusDecoderLayer( (self_attn): PegasusAttention( (k_proj): Linear(in_features=1024, out_features=1024, bias=True) (v_proj): Linear(in_features=1024, out_features=1024, bias=True) (q_proj): Linear(in_features=1024, out_features=1024, bias=True) (out_proj): Linear(in_features=1024, out_features=1024, bias=True) ) (activation_fn): GELUActivation() (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True) (encoder_attn): PegasusAttention( (k_proj): Linear(in_features=1024, out_features=1024, bias=True) (v_proj): Linear(in_features=1024, out_features=1024, bias=True) (q_proj): Linear(in_features=1024, out_features=1024, bias=True) (out_proj): Linear(in_features=1024, out_features=1024, bias=True) ) (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True) (fc1): Linear(in_features=1024, out_features=4096, bias=True) (fc2): Linear(in_features=4096, out_features=1024, bias=True) (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True) ) ) (layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True) ) ) (lm_head): Linear(in_features=1024, out_features=50265, bias=False) ) ``` --- examples/hf/pippy_pegasus.py | 109 +++++++++++++++++++++++++++++++++++ 1 file changed, 109 insertions(+) create mode 100644 examples/hf/pippy_pegasus.py diff --git a/examples/hf/pippy_pegasus.py b/examples/hf/pippy_pegasus.py new file mode 100644 index 000000000..fdbdd0704 --- /dev/null +++ b/examples/hf/pippy_pegasus.py @@ -0,0 +1,109 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +# Minimum effort to run this example: +# $ torchrun --nproc-per-node 4 pippy_pegasus.py + +import argparse +import os + +import torch +import torch.distributed as dist + +from pippy.IR import Pipe, PipeSplitWrapper, annotate_split_points +from pippy.PipelineStage import PipelineStage + +from transformers import PegasusForCausalLM, PegasusConfig + +from hf_utils import generate_inputs_for_model, get_number_of_params + + +def add_split_points(pegasus, nranks): + layers_per_rank = pegasus.config.num_hidden_layers // nranks + for i in range(1, nranks): + annotate_split_points( + pegasus, {f"model.decoder.layers.{i * layers_per_rank}": PipeSplitWrapper.SplitPoint.BEGINNING}) + + +def run(args): + # Model configs + config = PegasusConfig() + print("Using device:", args.device) + + # Create model + model_class = PegasusForCausalLM + model_name = "PegasusForCausalLM" + pegasus = model_class(config) + pegasus.to(args.device) + pegasus.eval() + if args.rank == 0: + print(pegasus.config) + print(f"Total number of params = {get_number_of_params(pegasus) // 10 ** 6}M") + print(pegasus) + + # Input configs + example_inputs = generate_inputs_for_model( + model_class, pegasus, model_name, args.batch_size, args.device) + input_ids = example_inputs["input_ids"] + + # Annotate split points + add_split_points(pegasus, args.world_size) + + # Create pipeline + pegasus_pipe = Pipe.from_tracing( + pegasus, + num_chunks=args.chunks, + example_args=(input_ids, ), + ) + nstages = len(list(pegasus_pipe.split_gm.children())) + assert nstages == args.world_size, f"nstages = {nstages} nranks = {args.world_size}" + if args.rank == 0: + for i, sm in enumerate(pegasus_pipe.split_gm.children()): + print(f"Pipeline stage {i} {get_number_of_params(sm) // 10 ** 6}M params") + + # Create schedule runtime + stage = PipelineStage( + pegasus_pipe, + args.rank, + device=args.device, + ) + + # Run + if args.rank == 0: + stage(input_ids) + elif args.rank == args.world_size - 1: + out = stage() + else: + stage() + + print(f"Rank {args.rank} completes") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 4))) + parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1))) + parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost')) + parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500')) + parser.add_argument('--schedule', type=str, default="FillDrain") + parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available())) + parser.add_argument("--chunks", type=int, default=4) + parser.add_argument('--batch_size', type=int, default=4) + parser.add_argument('--batches', type=int, default=1) + + args = parser.parse_args() + + if args.cuda: + dev_id = args.rank % torch.cuda.device_count() + args.device = torch.device(f"cuda:{dev_id}") + else: + args.device = torch.device("cpu") + + # Init process group + backend = "nccl" if args.cuda else "gloo" + dist.init_process_group( + backend=backend, + rank=args.rank, + world_size=args.world_size, + ) + + run(args) From 122bdad38fcec545a51008781177ef2669b74d52 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Wed, 6 Dec 2023 16:41:37 -0500 Subject: [PATCH 69/96] Add M2M100 example (#898) Note: the way the splitting plan is written is for a 3-rank run. If you would like to run it on more ranks, please modify the splitting plan. ``` M2M100ForConditionalGeneration( (model): M2M100Model( (shared): Embedding(128112, 1024, padding_idx=1) (encoder): M2M100Encoder( (embed_tokens): Embedding(128112, 1024, padding_idx=1) (embed_positions): M2M100SinusoidalPositionalEmbedding() (layers): ModuleList( (0-11): 12 x M2M100EncoderLayer( (self_attn): M2M100Attention( (k_proj): Linear(in_features=1024, out_features=1024, bias=True) (v_proj): Linear(in_features=1024, out_features=1024, bias=True) (q_proj): Linear(in_features=1024, out_features=1024, bias=True) (out_proj): Linear(in_features=1024, out_features=1024, bias=True) ) (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True) (activation_fn): ReLU() (fc1): Linear(in_features=1024, out_features=4096, bias=True) (fc2): Linear(in_features=4096, out_features=1024, bias=True) (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True) ) ) (layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True) ) (decoder): M2M100Decoder( (embed_tokens): Embedding(128112, 1024, padding_idx=1) (embed_positions): M2M100SinusoidalPositionalEmbedding() (layers): ModuleList( (0-11): 12 x M2M100DecoderLayer( (self_attn): M2M100Attention( (k_proj): Linear(in_features=1024, out_features=1024, bias=True) (v_proj): Linear(in_features=1024, out_features=1024, bias=True) (q_proj): Linear(in_features=1024, out_features=1024, bias=True) (out_proj): Linear(in_features=1024, out_features=1024, bias=True) ) (activation_fn): ReLU() (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True) (encoder_attn): M2M100Attention( (k_proj): Linear(in_features=1024, out_features=1024, bias=True) (v_proj): Linear(in_features=1024, out_features=1024, bias=True) (q_proj): Linear(in_features=1024, out_features=1024, bias=True) (out_proj): Linear(in_features=1024, out_features=1024, bias=True) ) (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True) (fc1): Linear(in_features=1024, out_features=4096, bias=True) (fc2): Linear(in_features=4096, out_features=1024, bias=True) (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True) ) ) (layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True) ) ) (lm_head): Linear(in_features=1024, out_features=128112, bias=False) ) ``` --- examples/hf/pippy_m2m100.py | 114 ++++++++++++++++++++++++++++++++++++ 1 file changed, 114 insertions(+) create mode 100644 examples/hf/pippy_m2m100.py diff --git a/examples/hf/pippy_m2m100.py b/examples/hf/pippy_m2m100.py new file mode 100644 index 000000000..3e6ab5063 --- /dev/null +++ b/examples/hf/pippy_m2m100.py @@ -0,0 +1,114 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +# Minimum effort to run this example: +# $ torchrun --nproc-per-node 3 pippy_m2m100.py + +import argparse +import os + +import torch +import torch.distributed as dist + +from pippy.IR import Pipe, PipeSplitWrapper, annotate_split_points +from pippy.PipelineStage import PipelineStage + +from transformers import M2M100ForConditionalGeneration, M2M100Config + +from hf_utils import generate_inputs_for_model, get_number_of_params + + +def add_split_points(m2m100, nranks): + # First rank takes encoder + annotate_split_points( + m2m100, {"model.encoder": PipeSplitWrapper.SplitPoint.END}) + # Second rank takes decoder + annotate_split_points( + m2m100, {"model.decoder": PipeSplitWrapper.SplitPoint.END}) + # Last rank takes LM head + + +def run(args): + # Model configs + config = M2M100Config() + print("Using device:", args.device) + + # Create model + model_class = M2M100ForConditionalGeneration + model_name = "M2M100ForConditionalGeneration" + m2m100 = model_class(config) + m2m100.to(args.device) + m2m100.eval() + if args.rank == 0: + print(m2m100.config) + print(f"Total number of params = {get_number_of_params(m2m100) // 10 ** 6}M") + print(m2m100) + + # Input configs + example_inputs = generate_inputs_for_model( + model_class, m2m100, model_name, args.batch_size, args.device) + + # Annotate split points + add_split_points(m2m100, args.world_size) + + # Create pipeline + m2m100_pipe = Pipe.from_tracing( + m2m100, + num_chunks=args.chunks, + example_args=(), + example_kwargs=example_inputs, + ) + nstages = len(list(m2m100_pipe.split_gm.children())) + assert nstages == args.world_size, f"nstages = {nstages} nranks = {args.world_size}" + if args.rank == 0: + for i, sm in enumerate(m2m100_pipe.split_gm.children()): + print(f"Pipeline stage {i} {get_number_of_params(sm) // 10 ** 6}M params") + + # Create schedule runtime + stage = PipelineStage( + m2m100_pipe, + args.rank, + device=args.device, + ) + + # Run + if args.rank == 0: + stage(example_inputs["input_ids"]) + elif args.rank == 1: + out = stage(example_inputs["decoder_input_ids"]) + elif args.rank == args.world_size - 1: + out = stage() + else: + stage() + + print(f"Rank {args.rank} completes") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 3))) + parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1))) + parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost')) + parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500')) + parser.add_argument('--schedule', type=str, default="FillDrain") + parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available())) + parser.add_argument("--chunks", type=int, default=4) + parser.add_argument('--batch_size', type=int, default=64) + parser.add_argument('--batches', type=int, default=1) + + args = parser.parse_args() + + if args.cuda: + dev_id = args.rank % torch.cuda.device_count() + args.device = torch.device(f"cuda:{dev_id}") + else: + args.device = torch.device("cpu") + + # Init process group + backend = "nccl" if args.cuda else "gloo" + dist.init_process_group( + backend=backend, + rank=args.rank, + world_size=args.world_size, + ) + + run(args) From 44ab9e978e6737632d122355f74286e73b63405b Mon Sep 17 00:00:00 2001 From: Howard Huang Date: Wed, 6 Dec 2023 16:57:09 -0500 Subject: [PATCH 70/96] add abstract classes for PipelineStage and PipelineSchedule (#903) Add abstract class definitions --- pippy/PipelineSchedule.py | 125 +++++++++++++++++++++++---------- test/test_pipeline_schedule.py | 6 +- 2 files changed, 90 insertions(+), 41 deletions(-) diff --git a/pippy/PipelineSchedule.py b/pippy/PipelineSchedule.py index 094bfdcbb..67af9b3b6 100644 --- a/pippy/PipelineSchedule.py +++ b/pippy/PipelineSchedule.py @@ -1,6 +1,7 @@ # (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. import logging +from abc import ABC, abstractmethod from collections import deque from typing import Deque, List, Optional, Tuple @@ -16,7 +17,65 @@ logger.addHandler(handler) -class PipelineStage(nn.Module): +class PipelineStage(ABC, nn.Module): + @abstractmethod + def forward(self, microbatch): + """ + TODO: this will be updated to support multiple arguments + + Perform forward pass on the module. + This should only be called once per microbatch. + + Args: + microbatch: The input to the module + """ + raise NotImplementedError + + @abstractmethod + def get_fwd_recv_ops(self) -> List[dist.P2POp]: + """ + Get the list of P2P operations that need to be performed before calling forward() + """ + raise NotImplementedError + + @abstractmethod + def get_fwd_send_ops(self) -> List[dist.P2POp]: + """ + Get the list of P2P operations that need to be performed after calling forward() + """ + raise NotImplementedError + + @abstractmethod + def backward(self): + """ + Perform backward pass on the module. + This should only be called once per microbatch. + """ + raise NotImplementedError + + @abstractmethod + def get_bwd_recv_ops(self) -> List[dist.P2POp]: + """ + Get the list of P2P operations that need to be performed before calling backward() + """ + raise NotImplementedError + + @abstractmethod + def get_bwd_send_ops(self) -> List[dist.P2POp]: + """ + Get the list of P2P operations that need to be performed after calling backward() + """ + raise NotImplementedError + + @abstractmethod + def compute_loss(self): + """ + Compute loss from the outputs of the last stage + """ + raise NotImplementedError + + +class PipelineStageV2Impl(PipelineStage): def __init__( self, module: nn.Module, @@ -89,12 +148,10 @@ def get_fwd_send_ops(self) -> List[dist.P2POp]: return [] return [dist.P2POp(dist.isend, self.fwd_output, self.next_stage)] - def forward(self, input_data, is_first_mb, is_last_mb): - logger.info( - f"[{self.rank} FORWARD {self.stage_id}] is_first_mb {is_first_mb} is_last_mb {is_last_mb}" - ) + def forward(self, microbatch: torch.Tensor): + logger.info(f"[{self.rank} FORWARD {self.stage_id}") if self.is_first_stage: - self.fwd_input = input_data + self.fwd_input = microbatch # this is needed when we access the gradients for this in backward() self.fwd_input.requires_grad = True @@ -138,10 +195,8 @@ def _wait_backward_inputs(self): self.bwd_recv_queue = None return self.fwd_output_grads - def backward(self, is_first_mb, is_last_mb): - logger.info( - f"[{self.rank} BACKWARD {self.stage_id}] is_first_mb {is_first_mb} is_last_mb {is_last_mb}" - ) + def backward(self): + logger.info(f"[{self.rank} BACKWARD {self.stage_id}]") if self.is_last_stage: fwd_inputs, loss = self.fwd_outputs_for_backward.popleft() @@ -165,22 +220,31 @@ def compute_loss(self): return self.fwd_output.mean() -class PipelineScheduleGPipe: +class PipelineSchedule(ABC): + @abstractmethod + def step(self, microbatches: List[torch.Tensor]) -> None: + """ + Run one iteration of the pipeline schedule. Will go through all the microbatches + according to the schedule implementation. + + Args: + microbatches: list of microbatch tensors + """ + raise NotImplementedError + + +class PipelineScheduleGPipe(PipelineSchedule): def __init__(self, stage: PipelineStage): self._stage = stage def step(self, microbatches): for i, mb in enumerate(microbatches): with record_function(f"Forward {i}"): - is_last_mb = i == len(microbatches) - 1 - ops = self._stage.get_fwd_recv_ops() if ops: dist.batch_isend_irecv(ops).pop().wait() - self._stage.forward( - mb, is_first_mb=i == 0, is_last_mb=is_last_mb - ) + self._stage.forward(mb) ops = self._stage.get_fwd_send_ops() if ops: @@ -196,10 +260,7 @@ def step(self, microbatches): if ops: dist.batch_isend_irecv(ops).pop().wait() - self._stage.backward( - is_first_mb=i == 0, - is_last_mb=i == len(microbatches) - 1, - ) + self._stage.backward() ops = self._stage.get_bwd_send_ops() if ops: @@ -208,7 +269,7 @@ def step(self, microbatches): logger.info(f"{self._stage.stage_id} backward {i} finished") -class PipelineScheduleLoopedBFS: +class PipelineScheduleLoopedBFS(PipelineSchedule): def __init__(self, stages: List[PipelineStage]): self._stages = stages @@ -216,13 +277,11 @@ def step(self, microbatches): for s, stage in enumerate(self._stages): for i, mb in enumerate(microbatches): with record_function(f"Stage {s} Forward"): - is_last_mb = i == len(microbatches) - 1 - ops = stage.get_fwd_recv_ops() if ops: dist.batch_isend_irecv(ops).pop().wait() - stage.forward(mb, is_first_mb=i == 0, is_last_mb=is_last_mb) + stage.forward(mb) ops = stage.get_fwd_send_ops() if ops: @@ -235,17 +294,14 @@ def step(self, microbatches): if ops: dist.batch_isend_irecv(ops).pop().wait() - stage.backward( - is_first_mb=i == 0, - is_last_mb=i == len(microbatches) - 1, - ) + stage.backward() ops = stage.get_bwd_send_ops() if ops: dist.batch_isend_irecv(ops) -class PipelineScheduleLoopedDFS: +class PipelineScheduleLoopedDFS(PipelineSchedule): def __init__(self, stages: List[PipelineStage], n_microbatch, pp_id, n_pp): assert ( n_microbatch % n_pp == 0 @@ -385,11 +441,7 @@ def stage_index(step): logger.info( f"pp_id {self.pp_id} step {step} forward_stage {forward_stage.stage_id} mb_id {mb_id_fwd}" ) - forward_stage.forward( - microbatches[mb_id_fwd], - is_first_mb=mb_id_fwd == 0, - is_last_mb=mb_id_fwd == len(microbatches) - 1, - ) + forward_stage.forward(microbatches[mb_id_fwd]) requests: List[dist.P2POp] = [] @@ -432,10 +484,7 @@ def stage_index(step): logger.info( f"pp_id {self.pp_id} step {step}/{self.total_steps} backward_step {backward_step} backward_stage_id {backward_stage.stage_id} mb_id {mb_id_bwd}" ) - backward_stage.backward( - is_first_mb=mb_id_bwd == 0, - is_last_mb=mb_id_bwd == len(microbatches) - 1, - ) + backward_stage.backward() requests = [] diff --git a/test/test_pipeline_schedule.py b/test/test_pipeline_schedule.py index 9b18c0c07..7b5cb4853 100644 --- a/test/test_pipeline_schedule.py +++ b/test/test_pipeline_schedule.py @@ -32,7 +32,7 @@ PipelineScheduleGPipe, PipelineScheduleLoopedBFS, PipelineScheduleLoopedDFS, - PipelineStage, + PipelineStageV2Impl, ) logger = logging.getLogger(__name__) @@ -107,13 +107,13 @@ def main(**kwargs): x = torch.randn([microbatch_size, input_dim]).to("meta") - stage_model = PipelineStage( + stage_model = PipelineStageV2Impl( module_list[rank], rank, world_size, rank, world_size, x, device ) stage_model.init_p2p_neighbors() stage_model_looped = [ - PipelineStage( + PipelineStageV2Impl( module_list[rank], stage_id=(world_size * i) + rank, num_stages=world_size * world_size, From 90a06ee7140705f8dcac15c5f0e86ac37001fb3a Mon Sep 17 00:00:00 2001 From: Less Wright Date: Wed, 6 Dec 2023 14:05:14 -0800 Subject: [PATCH 71/96] Add torch profiling based on context manager for pipeline scheduler (#899) ## Description This PR adds a context based profiling manager (maybe_run_profiler) to the pipeline scheduler training loop. The idea of using a context manager is to ensure no profiler overhead if not profiling, while at the same time not requiring two different training loops (where one has a profiler, the other does not, resulting in duplicate code). Profiling is controlled via the --profiler True flag on launch. Traces are saved to the --trace_dir which now defaults to ./pipeline_traces Updates to how the profiler is run, can be done in the context manager code block (i.e. adding warmup, etc). profiler_custom_handler Side update is the multinode.slurm file is updated to launch the new name, test_pipeline_schedules.py. Fixes #(issue) ## Type of change Please delete options that are not relevant. - [ X] New feature (non-breaking change which adds functionality) ## Feature/Issue validation/testing Tested with and without running profiler on single node. Tested without profiler on multinode. --- test/multinode_trainer.slurm | 4 +- test/run_pipeline_scheduler.sh | 3 ++ test/test_pipeline_schedule.py | 75 +++++++++++++++++++++++++++------- 3 files changed, 65 insertions(+), 17 deletions(-) create mode 100644 test/run_pipeline_scheduler.sh diff --git a/test/multinode_trainer.slurm b/test/multinode_trainer.slurm index e6b2d4acf..bc69812e5 100644 --- a/test/multinode_trainer.slurm +++ b/test/multinode_trainer.slurm @@ -1,6 +1,6 @@ #!/bin/bash -#SBATCH --job-name=looped-bfs-trainer +#SBATCH --job-name=test_pipeline_schedules #SBATCH --ntasks=2 @@ -47,5 +47,5 @@ export FI_EFA_SET_CUDA_SYNC_MEMOPS=0 dcgmi profile --pause # adjust sbatch --ntasks and sbatch --nodes above and --nnodes below # to your specific node count, and update target launch file. -srun torchrun --nnodes 2 --nproc_per_node 8 --rdzv_id 101 --rdzv_backend c10d --rdzv_endpoint "$head_node_ip:29500" ./multinode_bfs.py +srun torchrun --nnodes 2 --nproc_per_node 8 --rdzv_id 101 --rdzv_backend c10d --rdzv_endpoint "$head_node_ip:29500" ./test_pipeline_schedule.py --schedules gpipe looped_bfs dcgmi profile --resume diff --git a/test/run_pipeline_scheduler.sh b/test/run_pipeline_scheduler.sh new file mode 100644 index 000000000..24fee401a --- /dev/null +++ b/test/run_pipeline_scheduler.sh @@ -0,0 +1,3 @@ +# To run samples: +# launcher for testing pipeline schedules +torchrun --nnodes=1 --nproc_per_node 8 --rdzv_endpoint="localhost:59124" test_pipeline_schedule.py \ No newline at end of file diff --git a/test/test_pipeline_schedule.py b/test/test_pipeline_schedule.py index 7b5cb4853..5e3125daa 100644 --- a/test/test_pipeline_schedule.py +++ b/test/test_pipeline_schedule.py @@ -21,6 +21,7 @@ import argparse import logging import os +from contextlib import contextmanager, nullcontext from datetime import timedelta @@ -34,9 +35,42 @@ PipelineScheduleLoopedDFS, PipelineStageV2Impl, ) +from torch.profiler import record_function logger = logging.getLogger(__name__) +_null_context = nullcontext() + + +# profiling context manager +@contextmanager +def maybe_run_profiler( + use_profiler, trace_dir, schedule, rank, *args, **kwargs +): + def trace_handler(prof): + if rank == 0: + (f"about to EXPORT traces for {schedule} to {trace_dir}") + prof.export_chrome_trace( + f"{trace_dir}/{schedule}_rank{rank}_trace.json" + ) + + if use_profiler: + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + # schedule=torch.profiler.schedule(wait=1, warmup=2, active=3, repeat=1), + on_trace_ready=trace_handler, + profile_memory=True, + with_stack=False, + record_shapes=True, + ) as torch_profiler: + yield torch_profiler + else: + torch_profiler = nullcontext() + yield None + class MLP(nn.Module): def __init__( @@ -79,7 +113,7 @@ def setup(local_rank, world_size): def main(**kwargs): torch.manual_seed(42) - print(f"MY KWARGS ARE {kwargs}") + rank = kwargs["rank"] local_rank = kwargs["local_rank"] world_size = kwargs["world_size"] @@ -90,6 +124,12 @@ def main(**kwargs): f"====== World Rank {rank}, Local Rank {local_rank}, World Size {world_size}, Device {device} main ======" ) + def rank_print(msg): + if rank == 0: + print(f"{msg}") + + rank_print(f"My KWARGS are {kwargs}") + input_dim = 4000 hidden_dim = 8000 output_dim = 4000 @@ -129,6 +169,16 @@ def main(**kwargs): torch.randn_like(x_cuda_empty) for _ in range(n_microbatches) ] + # profiling setup (enable with --profiler True) + _run_profiler = kwargs["profiler"] + _torch_profiler = None + _trace_dir = kwargs["trace_dir"] + + if _run_profiler: + if not os.path.exists(_trace_dir): + os.mkdir(_trace_dir) + rank_print(f"Profiling active -- saving traces to {_trace_dir}") + for schedule in kwargs["schedules"]: logger.info(f"====== Rank {rank} running schedule {schedule} ======") if schedule == "gpipe": @@ -143,19 +193,15 @@ def main(**kwargs): n_pp=n_pp, ) - logger.info(f"====== Rank {rank} profile ======") + if _run_profiler: + logger.info(f"====== Rank {rank} profile ======") - # with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: - # with record_function(schedule): - pipeline.step(microbatches) + with maybe_run_profiler( + _run_profiler, _trace_dir, schedule, rank + ) as _torch_profiler: + with record_function(schedule): + pipeline.step(microbatches) - # TODO - default should be no profiling. - """if not kwargs["no_trace"]: - trace_dir = kwargs["trace_dir"] - if not os.path.exists(trace_dir): - os.mkdir(trace_dir) - prof.export_chrome_trace(f"{trace_dir}/{schedule}_rank{rank}_trace.json") - """ logger.info(f"====== Rank {rank} finished {schedule} ======") @@ -196,8 +242,8 @@ def set_up_logging(rank, log_level=logging.INFO): master_port = os.environ.get("MASTER_PORT", None) parser = argparse.ArgumentParser(description="Pipeline Stages Runner") - parser.add_argument("--no_trace", action="store_true") - parser.add_argument("--trace_dir", type=str, default="./traces") + parser.add_argument("--profiler", type=bool, default=False) + parser.add_argument("--trace_dir", type=str, default="./pipeline_traces") parser.add_argument( "--schedules", type=str, @@ -208,7 +254,6 @@ def set_up_logging(rank, log_level=logging.INFO): parser.add_argument("--device", type=str, default="cuda") args = parser.parse_args() kwargs = vars(args) - print(kwargs) if ( rank is None From fa24505beedd8e237c949dbba0d3aae1bbe08383 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Wed, 6 Dec 2023 17:59:22 -0500 Subject: [PATCH 72/96] Add Blenderbot example (#897) Needs https://github.com/pytorch/pytorch/pull/114982 to work. ``` BlenderbotForCausalLM( (model): BlenderbotDecoderWrapper( (decoder): BlenderbotDecoder( (embed_tokens): Embedding(8008, 2560, padding_idx=0) (embed_positions): BlenderbotLearnedPositionalEmbedding(128, 2560) (layers): ModuleList( (0-23): 24 x BlenderbotDecoderLayer( (self_attn): BlenderbotAttention( (k_proj): Linear(in_features=2560, out_features=2560, bias=True) (v_proj): Linear(in_features=2560, out_features=2560, bias=True) (q_proj): Linear(in_features=2560, out_features=2560, bias=True) (out_proj): Linear(in_features=2560, out_features=2560, bias=True) ) (activation_fn): GELUActivation() (self_attn_layer_norm): LayerNorm((2560,), eps=1e-05, elementwise_affine=True) (encoder_attn): BlenderbotAttention( (k_proj): Linear(in_features=2560, out_features=2560, bias=True) (v_proj): Linear(in_features=2560, out_features=2560, bias=True) (q_proj): Linear(in_features=2560, out_features=2560, bias=True) (out_proj): Linear(in_features=2560, out_features=2560, bias=True) ) (encoder_attn_layer_norm): LayerNorm((2560,), eps=1e-05, elementwise_affine=True) (fc1): Linear(in_features=2560, out_features=10240, bias=True) (fc2): Linear(in_features=10240, out_features=2560, bias=True) (final_layer_norm): LayerNorm((2560,), eps=1e-05, elementwise_affine=True) ) ) (layer_norm): LayerNorm((2560,), eps=1e-05, elementwise_affine=True) ) ) (lm_head): Linear(in_features=2560, out_features=8008, bias=False) ) ``` --- examples/hf/pippy_blenderbot.py | 109 ++++++++++++++++++++++++++++++++ 1 file changed, 109 insertions(+) create mode 100644 examples/hf/pippy_blenderbot.py diff --git a/examples/hf/pippy_blenderbot.py b/examples/hf/pippy_blenderbot.py new file mode 100644 index 000000000..25073a822 --- /dev/null +++ b/examples/hf/pippy_blenderbot.py @@ -0,0 +1,109 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +# Minimum effort to run this example: +# $ torchrun --nproc-per-node 4 pippy_blenderbot.py + +import argparse +import os + +import torch +import torch.distributed as dist + +from pippy.IR import Pipe, PipeSplitWrapper, annotate_split_points +from pippy.PipelineStage import PipelineStage + +from transformers import BlenderbotForCausalLM, BlenderbotConfig + +from hf_utils import generate_inputs_for_model, get_number_of_params + + +def add_split_points(blenderbot, nranks): + layers_per_rank = blenderbot.config.decoder_layers // nranks + for i in range(1, nranks): + annotate_split_points( + blenderbot, {f"model.decoder.layers.{i * layers_per_rank}": PipeSplitWrapper.SplitPoint.BEGINNING}) + + +def run(args): + # Model configs + config = BlenderbotConfig() + print("Using device:", args.device) + + # Create model + model_class = BlenderbotForCausalLM + model_name = "BlenderbotForCausalLM" + blenderbot = model_class(config) + blenderbot.to(args.device) + blenderbot.eval() + if args.rank == 0: + print(blenderbot.config) + print(f"Total number of params = {get_number_of_params(blenderbot) // 10 ** 6}M") + print(blenderbot) + + # Input configs + example_inputs = generate_inputs_for_model( + model_class, blenderbot, model_name, args.batch_size, args.device) + input_ids = example_inputs["input_ids"] + + # Annotate split points + add_split_points(blenderbot, args.world_size) + + # Create pipeline + blenderbot_pipe = Pipe.from_tracing( + blenderbot, + num_chunks=args.chunks, + example_args=(input_ids, ), + ) + nstages = len(list(blenderbot_pipe.split_gm.children())) + assert nstages == args.world_size, f"nstages = {nstages} nranks = {args.world_size}" + if args.rank == 0: + for i, sm in enumerate(blenderbot_pipe.split_gm.children()): + print(f"Pipeline stage {i} {get_number_of_params(sm) // 10 ** 6}M params") + + # Create schedule runtime + stage = PipelineStage( + blenderbot_pipe, + args.rank, + device=args.device, + ) + + # Run + if args.rank == 0: + stage(input_ids) + elif args.rank == args.world_size - 1: + out = stage() + else: + stage() + + print(f"Rank {args.rank} completes") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 4))) + parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1))) + parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost')) + parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500')) + parser.add_argument('--schedule', type=str, default="FillDrain") + parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available())) + parser.add_argument("--chunks", type=int, default=4) + parser.add_argument('--batch_size', type=int, default=32) + parser.add_argument('--batches', type=int, default=1) + + args = parser.parse_args() + + if args.cuda: + dev_id = args.rank % torch.cuda.device_count() + args.device = torch.device(f"cuda:{dev_id}") + else: + args.device = torch.device("cpu") + + # Init process group + backend = "nccl" if args.cuda else "gloo" + dist.init_process_group( + backend=backend, + rank=args.rank, + world_size=args.world_size, + ) + + run(args) From d9d3aad03a205ed220786b33f571b771e06a86f0 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Wed, 6 Dec 2023 17:59:37 -0500 Subject: [PATCH 73/96] Add PLBart example (#895) Needs https://github.com/pytorch/pytorch/pull/114982 to work. ``` PLBartForCausalLM( (model): PLBartDecoderWrapper( (decoder): PLBartDecoder( (embed_tokens): Embedding(50005, 768, padding_idx=1) (embed_positions): PLBartLearnedPositionalEmbedding(1026, 768) (layers): ModuleList( (0-5): 6 x PLBartDecoderLayer( (self_attn): PLBartAttention( (k_proj): Linear(in_features=768, out_features=768, bias=True) (v_proj): Linear(in_features=768, out_features=768, bias=True) (q_proj): Linear(in_features=768, out_features=768, bias=True) (out_proj): Linear(in_features=768, out_features=768, bias=True) ) (activation_fn): GELUActivation() (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True) (encoder_attn): PLBartAttention( (k_proj): Linear(in_features=768, out_features=768, bias=True) (v_proj): Linear(in_features=768, out_features=768, bias=True) (q_proj): Linear(in_features=768, out_features=768, bias=True) (out_proj): Linear(in_features=768, out_features=768, bias=True) ) (encoder_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True) (fc1): Linear(in_features=768, out_features=3072, bias=True) (fc2): Linear(in_features=3072, out_features=768, bias=True) (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True) ) ) (layernorm_embedding): LayerNorm((768,), eps=1e-05, elementwise_affine=True) ) ) (lm_head): Linear(in_features=768, out_features=50005, bias=False) ) ``` --- examples/hf/pippy_plBart.py | 116 ++++++++++++++++++++++++++++++++++++ 1 file changed, 116 insertions(+) create mode 100644 examples/hf/pippy_plBart.py diff --git a/examples/hf/pippy_plBart.py b/examples/hf/pippy_plBart.py new file mode 100644 index 000000000..c01ac99ba --- /dev/null +++ b/examples/hf/pippy_plBart.py @@ -0,0 +1,116 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +# Minimum effort to run this example: +# $ torchrun --nproc-per-node 4 pippy_plBart.py + +import argparse +import os + +import torch +import torch.distributed as dist + +from pippy.IR import Pipe, PipeSplitWrapper, annotate_split_points +from pippy.PipelineStage import PipelineStage + +from transformers import PLBartForCausalLM, PLBartConfig + +from hf_utils import generate_inputs_for_model, get_number_of_params + + +def add_split_points(plbart, nranks): + # The first rank carries embedding + annotate_split_points( + plbart, {f"model.decoder.embed_positions": PipeSplitWrapper.SplitPoint.END}) + # The last rank carries LM head + annotate_split_points( + plbart, {"lm_head": PipeSplitWrapper.SplitPoint.BEGINNING}) + # The rest ranks divide 6 layers + layers_per_rank = plbart.config.num_hidden_layers // (nranks - 2) + for i in range(1, nranks - 2): + annotate_split_points( + plbart, {f"model.decoder.layers.{i * layers_per_rank}": PipeSplitWrapper.SplitPoint.BEGINNING}) + + +def run(args): + # Model configs + config = PLBartConfig() + print("Using device:", args.device) + + # Create model + model_class = PLBartForCausalLM + model_name = "PLBartForCausalLM" + plbart = model_class(config) + plbart.to(args.device) + plbart.eval() + if args.rank == 0: + print(plbart.config) + print(f"Total number of params = {get_number_of_params(plbart) // 10 ** 6}M") + print(plbart) + + # Input configs + example_inputs = generate_inputs_for_model( + model_class, plbart, model_name, args.batch_size, args.device) + input_ids = example_inputs["input_ids"] + + # Annotate split points + add_split_points(plbart, args.world_size) + + # Create pipeline + plbart_pipe = Pipe.from_tracing( + plbart, + num_chunks=args.chunks, + example_args=(input_ids, ), + ) + nstages = len(list(plbart_pipe.split_gm.children())) + assert nstages == args.world_size, f"nstages = {nstages} nranks = {args.world_size}" + if args.rank == 0: + for i, sm in enumerate(plbart_pipe.split_gm.children()): + print(f"Pipeline stage {i} {get_number_of_params(sm) // 10 ** 6}M params") + + # Create schedule runtime + stage = PipelineStage( + plbart_pipe, + args.rank, + device=args.device, + ) + + # Run + if args.rank == 0: + stage(input_ids) + elif args.rank == args.world_size - 1: + out = stage() + else: + stage() + + print(f"Rank {args.rank} completes") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 4))) + parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1))) + parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost')) + parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500')) + parser.add_argument('--schedule', type=str, default="FillDrain") + parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available())) + parser.add_argument("--chunks", type=int, default=4) + parser.add_argument('--batch_size', type=int, default=16) + parser.add_argument('--batches', type=int, default=1) + + args = parser.parse_args() + + if args.cuda: + dev_id = args.rank % torch.cuda.device_count() + args.device = torch.device(f"cuda:{dev_id}") + else: + args.device = torch.device("cpu") + + # Init process group + backend = "nccl" if args.cuda else "gloo" + dist.init_process_group( + backend=backend, + rank=args.rank, + world_size=args.world_size, + ) + + run(args) From a3b35db312e1f53d41877a6b667c0fc481fdd668 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Thu, 7 Dec 2023 13:10:19 -0500 Subject: [PATCH 74/96] Add XLNet example (#908) ``` XLNetLMHeadModel( (transformer): XLNetModel( (word_embedding): Embedding(32000, 1024) (layer): ModuleList( (0-23): 24 x XLNetLayer( (rel_attn): XLNetRelativeAttention( (layer_norm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) (ff): XLNetFeedForward( (layer_norm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True) (layer_1): Linear(in_features=1024, out_features=4096, bias=True) (layer_2): Linear(in_features=4096, out_features=1024, bias=True) (dropout): Dropout(p=0.1, inplace=False) (activation_function): GELUActivation() ) (dropout): Dropout(p=0.1, inplace=False) ) ) (dropout): Dropout(p=0.1, inplace=False) ) (lm_loss): Linear(in_features=1024, out_features=32000, bias=True) ) ``` --- examples/hf/pippy_xlnet.py | 109 +++++++++++++++++++++++++++++++++++++ 1 file changed, 109 insertions(+) create mode 100644 examples/hf/pippy_xlnet.py diff --git a/examples/hf/pippy_xlnet.py b/examples/hf/pippy_xlnet.py new file mode 100644 index 000000000..a92679927 --- /dev/null +++ b/examples/hf/pippy_xlnet.py @@ -0,0 +1,109 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +# Minimum effort to run this example: +# $ torchrun --nproc-per-node 4 pippy_xlnet.py + +import argparse +import os + +import torch +import torch.distributed as dist + +from pippy.IR import Pipe, PipeSplitWrapper, annotate_split_points +from pippy.PipelineStage import PipelineStage + +from transformers import XLNetLMHeadModel, XLNetConfig + +from hf_utils import generate_inputs_for_model, get_number_of_params + + +def add_split_points(xlnet, nranks): + layers_per_rank = xlnet.config.num_hidden_layers // nranks + for i in range(1, nranks): + annotate_split_points( + xlnet, {f"transformer.layer.{i * layers_per_rank}": PipeSplitWrapper.SplitPoint.BEGINNING}) + + +def run(args): + # Model configs + config = XLNetConfig() + print("Using device:", args.device) + + # Create model + model_class = XLNetLMHeadModel + model_name = "XLNetLMHeadModel" + xlnet = model_class(config) + xlnet.to(args.device) + xlnet.eval() + if args.rank == 0: + print(xlnet.config) + print(f"Total number of params = {get_number_of_params(xlnet) // 10 ** 6}M") + print(xlnet) + + # Input configs + example_inputs = generate_inputs_for_model( + model_class, xlnet, model_name, args.batch_size, args.device) + input_ids = example_inputs["input_ids"] + + # Annotate split points + add_split_points(xlnet, args.world_size) + + # Create pipeline + xlnet_pipe = Pipe.from_tracing( + xlnet, + num_chunks=args.chunks, + example_args=(input_ids, ), + ) + nstages = len(list(xlnet_pipe.split_gm.children())) + assert nstages == args.world_size, f"nstages = {nstages} nranks = {args.world_size}" + if args.rank == 0: + for i, sm in enumerate(xlnet_pipe.split_gm.children()): + print(f"Pipeline stage {i} {get_number_of_params(sm) // 10 ** 6}M params") + + # Create schedule runtime + stage = PipelineStage( + xlnet_pipe, + args.rank, + device=args.device, + ) + + # Run + if args.rank == 0: + stage(input_ids) + elif args.rank == args.world_size - 1: + out = stage() + else: + stage() + + print(f"Rank {args.rank} completes") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 4))) + parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1))) + parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost')) + parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500')) + parser.add_argument('--schedule', type=str, default="FillDrain") + parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available())) + parser.add_argument("--chunks", type=int, default=4) + parser.add_argument('--batch_size', type=int, default=16) + parser.add_argument('--batches', type=int, default=1) + + args = parser.parse_args() + + if args.cuda: + dev_id = args.rank % torch.cuda.device_count() + args.device = torch.device(f"cuda:{dev_id}") + else: + args.device = torch.device("cpu") + + # Init process group + backend = "nccl" if args.cuda else "gloo" + dist.init_process_group( + backend=backend, + rank=args.rank, + world_size=args.world_size, + ) + + run(args) From 7d93a41d149d5c535d349e0e41175badb4826ca9 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Thu, 7 Dec 2023 13:11:15 -0500 Subject: [PATCH 75/96] Add ConvBert example (#909) ``` ConvBertForMaskedLM( (convbert): ConvBertModel( (embeddings): ConvBertEmbeddings( (word_embeddings): Embedding(30522, 768, padding_idx=1) (position_embeddings): Embedding(512, 768) (token_type_embeddings): Embedding(2, 768) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) (encoder): ConvBertEncoder( (layer): ModuleList( (0-11): 12 x ConvBertLayer( (attention): ConvBertAttention( (self): ConvBertSelfAttention( (query): Linear(in_features=768, out_features=384, bias=True) (key): Linear(in_features=768, out_features=384, bias=True) (value): Linear(in_features=768, out_features=384, bias=True) (key_conv_attn_layer): SeparableConv1D( (depthwise): Conv1d(768, 768, kernel_size=(9,), stride=(1,), padding=(4,), groups=768, bias=False) (pointwise): Conv1d(768, 384, kernel_size=(1,), stride=(1,), bias=False) ) (conv_kernel_layer): Linear(in_features=384, out_features=54, bias=True) (conv_out_layer): Linear(in_features=768, out_features=384, bias=True) (unfold): Unfold(kernel_size=[9, 1], dilation=1, padding=[4, 0], stride=1) (dropout): Dropout(p=0.1, inplace=False) ) (output): ConvBertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): ConvBertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (intermediate_act_fn): GELUActivation() ) (output): ConvBertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) ) ) ) (generator_predictions): ConvBertGeneratorPredictions( (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dense): Linear(in_features=768, out_features=768, bias=True) ) (generator_lm_head): Linear(in_features=768, out_features=30522, bias=True) ) ``` --- examples/hf/pippy_convBert.py | 114 ++++++++++++++++++++++++++++++++++ 1 file changed, 114 insertions(+) create mode 100644 examples/hf/pippy_convBert.py diff --git a/examples/hf/pippy_convBert.py b/examples/hf/pippy_convBert.py new file mode 100644 index 000000000..9baa31d29 --- /dev/null +++ b/examples/hf/pippy_convBert.py @@ -0,0 +1,114 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +# Minimum effort to run this example: +# $ torchrun --nproc-per-node 4 pippy_convBert.py + +import argparse +import os + +import torch +import torch.distributed as dist + +from pippy.IR import Pipe, PipeSplitWrapper, annotate_split_points +from pippy.PipelineStage import PipelineStage + +from transformers import ConvBertForMaskedLM, ConvBertConfig + +from hf_utils import generate_inputs_for_model, get_number_of_params + + +def add_split_points(convbert, nranks): + # The first rank takes embedding + annotate_split_points(convbert, {"convbert.embeddings": PipeSplitWrapper.SplitPoint.END}) + # The last rank takes generation + annotate_split_points(convbert, {"generator_predictions": PipeSplitWrapper.SplitPoint.BEGINNING}) + # The rest ranks divide encoder layers + layers_per_rank = convbert.config.num_hidden_layers // (nranks - 2) + for i in range(1, nranks - 2): + annotate_split_points( + convbert, {f"convbert.encoder.layer.{i * layers_per_rank}": PipeSplitWrapper.SplitPoint.BEGINNING}) + + +def run(args): + # Model configs + config = ConvBertConfig() + print("Using device:", args.device) + + # Create model + model_class = ConvBertForMaskedLM + model_name = "ConvBertForMaskedLM" + convbert = model_class(config) + convbert.to(args.device) + convbert.eval() + if args.rank == 0: + print(convbert.config) + print(f"Total number of params = {get_number_of_params(convbert) // 10 ** 6}M") + print(convbert) + + # Input configs + example_inputs = generate_inputs_for_model( + model_class, convbert, model_name, args.batch_size, args.device) + input_ids = example_inputs["input_ids"] + + # Annotate split points + add_split_points(convbert, args.world_size) + + # Create pipeline + convbert_pipe = Pipe.from_tracing( + convbert, + num_chunks=args.chunks, + example_args=(input_ids, ), + ) + nstages = len(list(convbert_pipe.split_gm.children())) + assert nstages == args.world_size, f"nstages = {nstages} nranks = {args.world_size}" + if args.rank == 0: + for i, sm in enumerate(convbert_pipe.split_gm.children()): + print(f"Pipeline stage {i} {get_number_of_params(sm) // 10 ** 6}M params") + + # Create schedule runtime + stage = PipelineStage( + convbert_pipe, + args.rank, + device=args.device, + ) + + # Run + if args.rank == 0: + stage(input_ids) + elif args.rank == args.world_size - 1: + out = stage() + else: + stage() + + print(f"Rank {args.rank} completes") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 4))) + parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1))) + parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost')) + parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500')) + parser.add_argument('--schedule', type=str, default="FillDrain") + parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available())) + parser.add_argument("--chunks", type=int, default=4) + parser.add_argument('--batch_size', type=int, default=32) + parser.add_argument('--batches', type=int, default=1) + + args = parser.parse_args() + + if args.cuda: + dev_id = args.rank % torch.cuda.device_count() + args.device = torch.device(f"cuda:{dev_id}") + else: + args.device = torch.device("cpu") + + # Init process group + backend = "nccl" if args.cuda else "gloo" + dist.init_process_group( + backend=backend, + rank=args.rank, + world_size=args.world_size, + ) + + run(args) From 545bded557e4b18bb4ccad29fdc6eb8f7acab95b Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Fri, 8 Dec 2023 11:40:47 -0500 Subject: [PATCH 76/96] Add TrOCR example (#907) Requires https://github.com/pytorch/pytorch/pull/114982 to work. ``` TrOCRForCausalLM( (model): TrOCRDecoderWrapper( (decoder): TrOCRDecoder( (embed_tokens): Embedding(50265, 1024, padding_idx=1) (embed_positions): TrOCRLearnedPositionalEmbedding(514, 1024) (layernorm_embedding): LayerNorm((1024,), eps=1e-05, elementwise_affine=True) (layers): ModuleList( (0-11): 12 x TrOCRDecoderLayer( (self_attn): TrOCRAttention( (k_proj): Linear(in_features=1024, out_features=1024, bias=True) (v_proj): Linear(in_features=1024, out_features=1024, bias=True) (q_proj): Linear(in_features=1024, out_features=1024, bias=True) (out_proj): Linear(in_features=1024, out_features=1024, bias=True) ) (activation_fn): GELUActivation() (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True) (encoder_attn): TrOCRAttention( (k_proj): Linear(in_features=1024, out_features=1024, bias=True) (v_proj): Linear(in_features=1024, out_features=1024, bias=True) (q_proj): Linear(in_features=1024, out_features=1024, bias=True) (out_proj): Linear(in_features=1024, out_features=1024, bias=True) ) (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True) (fc1): Linear(in_features=1024, out_features=4096, bias=True) (fc2): Linear(in_features=4096, out_features=1024, bias=True) (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True) ) ) ) ) (output_projection): Linear(in_features=1024, out_features=50265, bias=False) ) ``` --- examples/hf/pippy_trOCR.py | 109 +++++++++++++++++++++++++++++++++++++ 1 file changed, 109 insertions(+) create mode 100644 examples/hf/pippy_trOCR.py diff --git a/examples/hf/pippy_trOCR.py b/examples/hf/pippy_trOCR.py new file mode 100644 index 000000000..d2a016057 --- /dev/null +++ b/examples/hf/pippy_trOCR.py @@ -0,0 +1,109 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +# Minimum effort to run this example: +# $ torchrun --nproc-per-node 4 pippy_trOCR.py + +import argparse +import os + +import torch +import torch.distributed as dist + +from pippy.IR import Pipe, PipeSplitWrapper, annotate_split_points +from pippy.PipelineStage import PipelineStage + +from transformers import TrOCRForCausalLM, TrOCRConfig + +from hf_utils import generate_inputs_for_model, get_number_of_params + + +def add_split_points(trocr, nranks): + layers_per_rank = trocr.config.num_hidden_layers // nranks + for i in range(1, nranks): + annotate_split_points( + trocr, {f"model.decoder.layers.{i * layers_per_rank}": PipeSplitWrapper.SplitPoint.BEGINNING}) + + +def run(args): + # Model configs + config = TrOCRConfig() + print("Using device:", args.device) + + # Create model + model_class = TrOCRForCausalLM + model_name = "TrOCRForCausalLM" + trocr = model_class(config) + trocr.to(args.device) + trocr.eval() + if args.rank == 0: + print(trocr.config) + print(f"Total number of params = {get_number_of_params(trocr) // 10 ** 6}M") + print(trocr) + + # Input configs + example_inputs = generate_inputs_for_model( + model_class, trocr, model_name, args.batch_size, args.device) + input_ids = example_inputs["input_ids"] + + # Annotate split points + add_split_points(trocr, args.world_size) + + # Create pipeline + trocr_pipe = Pipe.from_tracing( + trocr, + num_chunks=args.chunks, + example_args=(input_ids, ), + ) + nstages = len(list(trocr_pipe.split_gm.children())) + assert nstages == args.world_size, f"nstages = {nstages} nranks = {args.world_size}" + if args.rank == 0: + for i, sm in enumerate(trocr_pipe.split_gm.children()): + print(f"Pipeline stage {i} {get_number_of_params(sm) // 10 ** 6}M params") + + # Create schedule runtime + stage = PipelineStage( + trocr_pipe, + args.rank, + device=args.device, + ) + + # Run + if args.rank == 0: + stage(input_ids) + elif args.rank == args.world_size - 1: + out = stage() + else: + stage() + + print(f"Rank {args.rank} completes") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 4))) + parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1))) + parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost')) + parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500')) + parser.add_argument('--schedule', type=str, default="FillDrain") + parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available())) + parser.add_argument("--chunks", type=int, default=4) + parser.add_argument('--batch_size', type=int, default=64) + parser.add_argument('--batches', type=int, default=1) + + args = parser.parse_args() + + if args.cuda: + dev_id = args.rank % torch.cuda.device_count() + args.device = torch.device(f"cuda:{dev_id}") + else: + args.device = torch.device("cpu") + + # Init process group + backend = "nccl" if args.cuda else "gloo" + dist.init_process_group( + backend=backend, + rank=args.rank, + world_size=args.world_size, + ) + + run(args) From 196436f8fb357f17cce354be59bc0737ae534d2b Mon Sep 17 00:00:00 2001 From: Less Wright Date: Tue, 12 Dec 2023 22:15:58 -0800 Subject: [PATCH 77/96] Upgrade pipeline scheduler test to use device_mesh (#914) ## Description Moves the pipeline scheduler testing to using device_mesh, instead of using init_process_group and torch.cuda.set_device. ## Type of change - [x] New feature (non-breaking change which adds functionality) ## Feature/Issue validation/testing Verified on single node and multi-node. --- test/test_pipeline_schedule.py | 48 ++++++++++++++++++++++------------ 1 file changed, 32 insertions(+), 16 deletions(-) diff --git a/test/test_pipeline_schedule.py b/test/test_pipeline_schedule.py index 5e3125daa..604a3ec4a 100644 --- a/test/test_pipeline_schedule.py +++ b/test/test_pipeline_schedule.py @@ -23,10 +23,7 @@ import os from contextlib import contextmanager, nullcontext -from datetime import timedelta - import torch -import torch.distributed as dist import torch.multiprocessing as mp import torch.nn as nn from pippy.PipelineSchedule import ( @@ -35,8 +32,10 @@ PipelineScheduleLoopedDFS, PipelineStageV2Impl, ) +from torch.distributed._tensor.device_mesh import init_device_mesh from torch.profiler import record_function + logger = logging.getLogger(__name__) _null_context = nullcontext() @@ -97,37 +96,54 @@ def forward(self, x): return c -def setup(local_rank, world_size): +def setup(local_rank, init_process_group=False): # If this is a child process (i.e., its PID is not the same as the PID of the process that started this script) if os.getppid() != os.getpid(): set_up_logging(local_rank) - # initialize the process group - logger.info(f"init for rank {local_rank}") - dist.init_process_group("nccl", timeout=timedelta(seconds=20)) - if torch.distributed.is_initialized(): - torch.cuda.set_device(local_rank) + # initialize the process group (not needed if using device_mesh) + if init_process_group: + logger.info(f"init for rank {local_rank}") + if torch.distributed.is_initialized(): + torch.cuda.set_device(local_rank) logger.info(f"finish init for rank {local_rank}") def main(**kwargs): torch.manual_seed(42) + device = torch.device(kwargs["device"]) - rank = kwargs["rank"] local_rank = kwargs["local_rank"] world_size = kwargs["world_size"] - device = torch.device(kwargs["device"]) - - setup(local_rank, world_size) - logger.info( - f"====== World Rank {rank}, Local Rank {local_rank}, World Size {world_size}, Device {device} main ======" - ) def rank_print(msg): if rank == 0: print(f"{msg}") + # use device mesh for cuda - create a device mesh based on the given world_size. + device_mesh = None + rank = None + + if device.type == "cuda": + device_mesh = init_device_mesh( + device_type="cuda", mesh_shape=(int(os.environ["WORLD_SIZE"]),) + ) + + rank = device_mesh.get_rank() + rank_print(f"using {device_mesh=}") + + if not device_mesh: + rank = kwargs["rank"] + + init_process_group = bool(device_mesh is not None) + + setup(local_rank, init_process_group=init_process_group) + + logger.info( + f"====== World Rank {rank}, Local Rank {local_rank}, World Size {world_size}, Device {device} main ======" + ) + rank_print(f"My KWARGS are {kwargs}") input_dim = 4000 From c3e1ced6a26261ea5acb2b21cb9e23bd3a21a496 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Wed, 13 Dec 2023 13:45:26 -0500 Subject: [PATCH 78/96] Fix qualname mapping after tracer migration (#915) ## Description Dynamo adds `L__self___` prefix while FX tracer does not. Also, export methods (including `_export_to_torch_ir`) flatten modules and hence replace "." with "_". The above changes breaks the `remap_qualname` functionality of PiPPy. ## Solution As inspired by https://github.com/pytorch/pytorch/pull/115462, we prepare a mapping of parameters (and buffers) before and after tracing. This allows us to perform a second remap, which remaps qualname from after tracing to the one before. The previous remap, still serves as a first remap pass, which maps qualname from after splitting to the one before splitting. ## Test Test 1: `pytest test/test_ir.py -k test_remap_qualname` Test 2: `python test/test_pipe.py` Output: ``` Testing ExampleCode equivalence test passed 96870448.0 ref 96870448.0 split_gm.submod_1.L__self___lin.bias -> lin.bias split_gm.submod_0.moved_L__self___mm_param -> mm_param split_gm.submod_3.L__self___lin.weight -> lin.weight split_gm.submod_3.L__self___lin.bias -> lin.bias split_gm.submod_1.moved_L__self___mm_param -> mm_param split_gm.submod_1.L__self___lin.weight -> lin.weight split_gm.submod_2.moved_L__self___mm_param2 -> mm_param2 Qualname check passed Testing MultiMLP equivalence test passed -456.07415771484375 ref -456.07415771484375 split_gm.submod_3.L__self___mlp3_net2.bias -> mlp3.net2.bias split_gm.submod_2.L__self___mlp2_net1.weight -> mlp2.net1.weight split_gm.submod_2.L__self___mlp2_net1.bias -> mlp2.net1.bias split_gm.submod_2.L__self___mlp2_net2.bias -> mlp2.net2.bias split_gm.submod_1.L__self___mlp1_net1.weight -> mlp1.net1.weight split_gm.submod_0.L__self___mlp0_net2.weight -> mlp0.net2.weight split_gm.submod_1.L__self___mlp1_net2.weight -> mlp1.net2.weight split_gm.submod_0.L__self___mlp0_net1.weight -> mlp0.net1.weight split_gm.submod_3.L__self___mlp3_net1.weight -> mlp3.net1.weight split_gm.submod_3.L__self___mlp3_net1.bias -> mlp3.net1.bias split_gm.submod_1.L__self___mlp1_net2.bias -> mlp1.net2.bias split_gm.submod_3.L__self___mlp3_net2.weight -> mlp3.net2.weight split_gm.submod_2.L__self___mlp2_net2.weight -> mlp2.net2.weight split_gm.submod_0.L__self___mlp0_net1.bias -> mlp0.net1.bias split_gm.submod_0.L__self___mlp0_net2.bias -> mlp0.net2.bias split_gm.submod_1.L__self___mlp1_net1.bias -> mlp1.net1.bias Qualname check passed ``` Fixes #912 --- pippy/IR.py | 101 ++++++++++++++++++++++++++++++++++++++-------- test/test_ir.py | 70 +++----------------------------- test/test_pipe.py | 11 +++++ 3 files changed, 101 insertions(+), 81 deletions(-) diff --git a/pippy/IR.py b/pippy/IR.py index f484f5430..3b41f3526 100644 --- a/pippy/IR.py +++ b/pippy/IR.py @@ -492,20 +492,25 @@ class QualnameMapMixin: def __init__( self, - qualname_mapping: Dict[str, str] = None, + splitter_qualname_map: Dict[str, str] = None, + tracer_qualname_map: Dict[str, str] = None, ): self.new_to_old_qualname_mapping: Dict[str, str] = ( - qualname_mapping or {} + splitter_qualname_map or {} ) + self.tracer_qualname_map = tracer_qualname_map def remap_qualname(self, qualname: str): # TODO: annoying if qualname.startswith("split_gm."): qualname = qualname[len("split_gm.") :] - # The qualname map does not store recursive items, thus, - # when passed a qualname with leaves, we need to perform longest prefix match - if qualname not in self.new_to_old_qualname_mapping: + name_before_split = None + if qualname in self.new_to_old_qualname_mapping: + name_before_split = self.new_to_old_qualname_mapping[qualname] + else: + # The qualname map does not store recursive items, thus, + # when passed a qualname with leaves, we need to perform longest prefix match # Split from the right, one each time split_names = qualname.rsplit(".", 1) leaf = split_names[-1] @@ -513,25 +518,34 @@ def remap_qualname(self, qualname: str): prefix = split_names[0] if prefix in self.new_to_old_qualname_mapping: old_prefix = self.new_to_old_qualname_mapping[prefix] - return ".".join([old_prefix, leaf]) + name_before_split = ".".join([old_prefix, leaf]) + break split_names = prefix.rsplit(".", 1) leaf = ".".join([split_names[-1], leaf]) - # Either full name match, or key not found - return self.new_to_old_qualname_mapping[qualname] + if name_before_split is None: + raise RuntimeError(f"Could not find mapping for {qualname}") + + if self.tracer_qualname_map is not None: + return self.tracer_qualname_map[name_before_split] + else: + return name_before_split class Pipe(QualnameMapMixin, torch.nn.Module): def __init__( self, split_gm: fx.GraphModule, - qualname_mapping: Dict[str, str], + splitter_qualname_map: Dict[str, str], num_stages: int, has_loss_and_backward: bool, loss_spec, + tracer_qualname_map: Optional[Dict[str, str]] = None, ): # TODO: is there a way not to hard wire init? - QualnameMapMixin.__init__(self, qualname_mapping) + QualnameMapMixin.__init__( + self, splitter_qualname_map, tracer_qualname_map + ) torch.nn.Module.__init__(self) self.split_gm: fx.GraphModule = split_gm self.executor: DetachExecutor = DetachExecutor(self.split_gm) @@ -728,10 +742,10 @@ def split_callback(n: fx.Node): return part_idx # Ask split_module to return mapping from new qualname to old qualname - qualname_map: Dict[str, str] = {} + splitter_qualname_map: Dict[str, str] = {} # TODO: what does split do with module invocations? does it move the modules # into the submodules? - split = split_module(traced, mod, split_callback, qualname_map) + split = split_module(traced, mod, split_callback, splitter_qualname_map) # a (custom) tracer can produce dead code like orphan get_attr nodes split.graph.eliminate_dead_code() @@ -783,13 +797,15 @@ def move_param_to_callee( # Update qualname mapping # New qualname will have submodule prefix new_qualname = f"{callee_name}.{new_param_name}" - if node.target in qualname_map: - # Just in case the target name is already in the qualname_map + if node.target in splitter_qualname_map: + # Just in case the target name is already in the splitter_qualname_map # returned by split_module() -- we update the mapping using the # new name as a new key - qualname_map[new_qualname] = qualname_map.pop(node.target) + splitter_qualname_map[new_qualname] = splitter_qualname_map.pop( + node.target + ) else: - qualname_map[new_qualname] = node.target + splitter_qualname_map[new_qualname] = node.target ph_counter = 0 for sn in callee.graph.nodes: @@ -1005,12 +1021,17 @@ def move_param_to_callee( "Pipeline is in inference mode, backward pass not generated" ) + # Tracer may modify qualname, get the qualname mapping before and after tracing. + # This qualname mapping is different from the mapping before and after splitting. + tracer_qualname_map = Pipe._get_param_buffer_mapping(mod, traced) + return Pipe( split, - qualname_map, + splitter_qualname_map, num_stages, has_loss_and_backward, generated_loss_spec, + tracer_qualname_map, ) @staticmethod @@ -1117,6 +1138,52 @@ def __str__(self): def __repr__(self): return self.split_gm.__repr__() + # TODO: this util comes from pytorch/pytorch#115462, delete it from PiPPy + # when PyTorch 2.3 comes with support, or when PiPPy migrates from + # `_export_to_torch_ir` to export + unflattener. + @staticmethod + def _get_param_buffer_mapping( + original_module: torch.nn.Module, + traced_module: torch.nn.Module, + ) -> Dict[str, str]: + """ + Returns a mapping of parameter/buffer names from the new module to the + original model. This is to help with restoring the FQN for parameter/buffers + of a traced module to what the original module contains. + """ + + param_lookup: Dict[int, List[str]] = {} + buffer_lookup: Dict[int, List[str]] = {} + for name, param in original_module.named_parameters( + remove_duplicate=False + ): + param_lookup.setdefault(id(param), []).append(name) + for name, buffer in original_module.named_buffers( + remove_duplicate=False + ): + buffer_lookup.setdefault(id(buffer), []).append(name) + + param_buffer_table: Dict[str, str] = {} + for dynamo_name, dynamo_param in traced_module.named_parameters( + remove_duplicate=False + ): + assert dynamo_name not in param_buffer_table + if id(dynamo_param) in param_lookup: + param_buffer_table[dynamo_name] = param_lookup[ + id(dynamo_param) + ].pop() + + for dynamo_name, dynamo_buffer in traced_module.named_buffers( + remove_duplicate=False + ): + assert dynamo_name not in param_buffer_table + if id(dynamo_buffer) in buffer_lookup: + param_buffer_table[dynamo_name] = buffer_lookup[ + id(dynamo_buffer) + ].pop() + + return param_buffer_table + class PipeSplitWrapper(torch.nn.Module): class SplitPoint(Enum): diff --git a/test/test_ir.py b/test/test_ir.py index 768dec27f..6e1fcae2c 100644 --- a/test/test_ir.py +++ b/test/test_ir.py @@ -5,7 +5,6 @@ import unittest from typing import NamedTuple -import pippy.fx import torch from pippy.IR import ( @@ -26,11 +25,6 @@ ) -@pippy.fx.wrap -def arange_wrapper(*args, **kwargs): - return torch.arange(*args, **kwargs) - - class ExampleCode(torch.nn.Module): def __init__(self): super().__init__() @@ -74,6 +68,7 @@ def setUp(self): mods += [mods[0]] self.seq = torch.nn.Sequential(*mods) self.ec = ExampleCode() + self.example_inputs = (torch.randn(50, 512),) def test_sequential(self): pipe_seq = PipeSequential.from_sequential(self.seq) @@ -337,39 +332,6 @@ def forward(self, x, target): v_ref = ref_grads[k_ref] torch.testing.assert_close(v_test, v_ref) - def test_custom_tracer_serialization(self): - class CustomTracer(pippy.fx.Tracer): - def trace(self, root, concrete_args=None): - rv = super().trace(root, concrete_args) - for node in rv.nodes: - if node.target == arange_wrapper: - node.target = torch.arange - node.meta.clear() - return rv - - class FooMod(torch.nn.Module): - def __init__(self): - super().__init__() - self.param = torch.nn.Parameter(torch.randn(1)) - - def forward(self, x): - return x + arange_wrapper(x.shape[-1]) + 1 + torch.zeros(1) - - fm = FooMod() - - tracer = CustomTracer() - pipe = Pipe.from_tracing(fm, tracer=tracer) - - with tempfile.TemporaryDirectory() as d: - with open(d + "tmp.pkl", "wb") as f: - pickle.dump(pipe.split_gm.submod_0, f) - - with open(d + "tmp.pkl", "rb") as f: - loaded = pickle.load(f) - - x = torch.randn(5, 3) - torch.testing.assert_close(pipe.split_gm.submod_0(x), loaded(x)) - def test_direct_serialization_recursion_depth(self): class LomgBoi(torch.nn.Module): def forward(self, x): @@ -814,31 +776,8 @@ def forward( chunks_merged_masked["multiplied"], ref_out["multiplied"] ) - def test_remap_qualname_transmit(self): - ec_pipe = Pipe.from_tracing(self.ec, MultiUseParameterConfig.TRANSMIT) - - # Get the first field of all tuples, i.e. names - old_named_params = zip(*list(self.ec.named_parameters())) - old_names = list(old_named_params)[0] - - # Check qualname mapping for pipe - for new_name, _ in ec_pipe.named_parameters(): - old_name = ec_pipe.remap_qualname(new_name) - # print(f"{new_name} -> {old_name}") - assert ( - old_name in old_names - ), f"Remapped parameter {old_name} not found in {old_names}" - - # Check qualname mapping for submodule - for _, stage_mod in ec_pipe.split_gm.named_children(): - for new_name, _ in stage_mod.named_parameters(): - old_name = stage_mod.remap_qualname(new_name) - assert ( - old_name in old_names - ), f"Remapped parameter {old_name} not found in {old_names}" - - def test_remap_qualname_replicate(self): - ec_pipe = Pipe.from_tracing(self.ec, MultiUseParameterConfig.REPLICATE) + def test_remap_qualname(self): + ec_pipe = Pipe.from_tracing(self.ec, 1, self.example_inputs) # Get the first field of all tuples, i.e. names old_named_params = zip(*list(self.ec.named_parameters())) @@ -853,12 +792,15 @@ def test_remap_qualname_replicate(self): ), f"Remapped parameter {old_name} not found in {old_names}" # Check qualname mapping for submodule + # Not supported at the moment + """ for _, stage_mod in ec_pipe.split_gm.named_children(): for new_name, _ in stage_mod.named_parameters(): old_name = stage_mod.remap_qualname(new_name) assert ( old_name in old_names ), f"Remapped parameter {old_name} not found in {old_names}" + """ if __name__ == "__main__": diff --git a/test/test_pipe.py b/test/test_pipe.py index 6259953e7..1400c2906 100644 --- a/test/test_pipe.py +++ b/test/test_pipe.py @@ -89,6 +89,17 @@ def run_worker(args, model_class): torch.testing.assert_close(out, ref_out) print(f"equivalence test passed {torch.sum(out)} ref {torch.sum(ref_out)}") + # Check qualname + old_names = set(mod.state_dict().keys()) + new_names = set(pipe.state_dict().keys()) + for new_name in new_names: + old_name = pipe.remap_qualname(new_name) + assert ( + old_name in old_names + ), f"Remapped parameter {old_name} not found in {old_names}" + print(f"{new_name} -> {old_name}") + print("Qualname check passed") + def main(args=None): parser = argparse.ArgumentParser() From d53405b601929258401c9ce93e8a77c1079cdc7f Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Thu, 14 Dec 2023 10:11:12 -0500 Subject: [PATCH 79/96] Limit torch version in CI (#916) ## Description https://github.com/pytorch/pytorch/pull/115557 moved `_export_to_torch_ir` to `torch.export._trace`. Limit our torch version in our CI to avoid a break. Since https://github.com/pytorch/pytorch/pull/115557 came after PyTorch 2.2 release cut, PyTorch won't have it till 2.3. So once 2.2 is released, we can switch to release version instead of nightly. Meanwhile, we will be migrating from `_export_to_torch_ir` to `export`. So this version limit is a temporary workaround while we migrate. --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 4d73a1d52..256160137 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ -torch >= 2.2.0.dev +torch >= 2.2.0.dev, <2.2.0.dev20231212 packaging >= 21.3 From 12c0b6ba9ad0e61041c3583b19b596b72c2e35eb Mon Sep 17 00:00:00 2001 From: Howard Huang Date: Mon, 18 Dec 2023 10:47:40 -0500 Subject: [PATCH 80/96] Support multiple arguments in PipelineStage forward (#913) Currently the `PipelineStage` API wraps around an `nn.module` but does not support `nn.module`s which have multiple arguments in their `forward()`. This updates `PipelineStage` so that the forward may take an arbitrary number of arguments (tensors only) and corresponding p2p ops (`get_fwd_recv_ops`, `get_bwd_send_ops`, etc.) which are used in the pipeline schedules are still valid. Tested on all the current schedules --- pippy/PipelineSchedule.py | 153 ++++++++++++++++++++++----------- test/test_pipeline_schedule.py | 37 ++++---- 2 files changed, 120 insertions(+), 70 deletions(-) diff --git a/pippy/PipelineSchedule.py b/pippy/PipelineSchedule.py index 67af9b3b6..ed5ca147f 100644 --- a/pippy/PipelineSchedule.py +++ b/pippy/PipelineSchedule.py @@ -3,7 +3,7 @@ import logging from abc import ABC, abstractmethod from collections import deque -from typing import Deque, List, Optional, Tuple +from typing import Deque, List, Optional, Tuple, Union import torch import torch.distributed as dist @@ -19,10 +19,8 @@ class PipelineStage(ABC, nn.Module): @abstractmethod - def forward(self, microbatch): + def forward(self, args: List[torch.tensor]) -> torch.tensor: """ - TODO: this will be updated to support multiple arguments - Perform forward pass on the module. This should only be called once per microbatch. @@ -75,6 +73,23 @@ def compute_loss(self): raise NotImplementedError +def create_buffers( + input: Union[torch.Tensor, List[torch.tensor]], device: torch.device +) -> List[torch.Tensor]: + """ + Creates buffers for a given input on a specified device. + This function takes as input a tensor or a list of tensors and returns a tensor or a list of tensors (respectively) + of the same shape, but located on the specified device and uninitialized (i.e., filled with arbitrary data). + """ + if isinstance(input, torch.Tensor): + return [torch.empty_like(input, device=device)] + elif isinstance(input, (list, tuple)): + return [torch.empty_like(inp, device=device) for inp in input] + raise ValueError( + f"Unsupported input type {type(input)} cannot create buffers" + ) + + class PipelineStageV2Impl(PipelineStage): def __init__( self, @@ -83,7 +98,7 @@ def __init__( num_stages: int, rank: int, world_size: int, - meta_input: torch.Tensor, + input_args: List[torch.tensor], device: torch.device, ): super().__init__() @@ -94,11 +109,13 @@ def __init__( self.num_stages = num_stages # When we materialize the model partition on cuda, we call reset_parameters() if it is available self.module = module.to(device) - - meta_output = self.module(meta_input) - self.fwd_input = torch.empty_like(meta_input, device=device) - self.fwd_output = None - self.fwd_output_grads = torch.empty_like(meta_output, device=device) + logger.info(f"input args {input_args=}") + meta_output = self.module(*input_args) + self.fwd_inputs: List[torch.tensor] = create_buffers(input_args, device) + self.fwd_outputs = None + self.fwd_output_grads: List[torch.tensor] = create_buffers( + meta_output, device + ) self.fwd_outputs_for_backward: Deque[ Tuple[torch.tensor, torch.tensor] ] = deque() @@ -110,8 +127,13 @@ def __init__( self.bwd_recv_queue = None self.requests: List[dist.P2POp] = [] - logger.info( - f"finished pipeline stage init, {self.stage_id=}, {self.is_first_stage=}, {self.is_last_stage=}, {self.num_stages=}, {self.fwd_input.shape=}, {self.fwd_output_grads.shape=}" + logger.debug( + f""" + finished pipeline stage init, {self.stage_id=}, {self.is_first_stage=}, + {self.is_last_stage=}, {self.num_stages=}, + {[fwd_input.shape for fwd_input in self.fwd_inputs]}, + {[fwd_output_grad.shape for fwd_output_grad in self.fwd_output_grads]} + """ ) def init_p2p_neighbors(self): @@ -141,83 +163,110 @@ def init_p2p_neighbors(self): def get_fwd_recv_ops(self) -> List[dist.P2POp]: if self.is_first_stage: return [] - return [dist.P2POp(dist.irecv, self.fwd_input, self.prev_stage)] + return [ + dist.P2POp(dist.irecv, fwd_input, self.prev_stage) + for fwd_input in self.fwd_inputs + ] def get_fwd_send_ops(self) -> List[dist.P2POp]: + assert ( + self.fwd_outputs is not None + ), "forward() must be called before get_fwd_send_ops" if self.is_last_stage: return [] - return [dist.P2POp(dist.isend, self.fwd_output, self.next_stage)] + return [ + dist.P2POp(dist.isend, fwd_output, self.next_stage) + for fwd_output in self.fwd_outputs + ] - def forward(self, microbatch: torch.Tensor): + def forward(self, args: List[torch.tensor]) -> torch.tensor: logger.info(f"[{self.rank} FORWARD {self.stage_id}") if self.is_first_stage: - self.fwd_input = microbatch + self.fwd_inputs = args # this is needed when we access the gradients for this in backward() - self.fwd_input.requires_grad = True - self.fwd_input.retain_grad() + if not self.is_first_stage: + for tensor in self.fwd_inputs: + tensor.requires_grad = True + tensor.retain_grad() # perform forward pass on module - self.fwd_output = self.module(self.fwd_input) + self.fwd_outputs = self.module(*self.fwd_inputs) - output_for_backward = ( - self.compute_loss() if self.is_last_stage else self.fwd_output + fwd_outputs_for_backward = ( + self.compute_loss() if self.is_last_stage else self.fwd_outputs ) # we store a ref to the input/output pair for this forward to be later used by the corresponding backward self.fwd_outputs_for_backward.append( - (self.fwd_input, output_for_backward) + (self.fwd_inputs, fwd_outputs_for_backward) ) - return self.fwd_output - - def get_bwd_send_ops(self) -> List[dist.P2POp]: - if self.is_first_stage: - return [] - assert self.fwd_input.grad is not None, "grad must be valid" - return [dist.P2POp(dist.isend, self.fwd_input.grad, self.prev_stage)] + return self.fwd_outputs def get_bwd_recv_ops(self) -> List[dist.P2POp]: if self.is_last_stage: return [] - return [dist.P2POp(dist.irecv, self.fwd_output_grads, self.next_stage)] - - def sync_recv_backward_inputs(self) -> None: - ops = self.get_bwd_recv_ops() - if ops: - dist.batch_isend_irecv(ops).pop().wait() + return [ + dist.P2POp(dist.irecv, output_grad, self.next_stage) + for output_grad in self.fwd_output_grads + ] - def _wait_backward_inputs(self): - assert ( - self.bwd_recv_queue is not None - ), "Waiting for backward input without enqueueing one" - self.bwd_recv_queue.wait() - self.bwd_recv_queue = None - return self.fwd_output_grads + def get_bwd_send_ops(self) -> List[dist.P2POp]: + if self.is_first_stage: + return [] + for fwd_input in self.fwd_inputs: + logger.info(f"{fwd_input.grad=}") + assert fwd_input.grad is not None, "grad must be valid" + return [ + dist.P2POp(dist.isend, fwd_input.grad, self.prev_stage) + for fwd_input in self.fwd_inputs + ] def backward(self): logger.info(f"[{self.rank} BACKWARD {self.stage_id}]") if self.is_last_stage: - fwd_inputs, loss = self.fwd_outputs_for_backward.popleft() + self.fwd_inputs, loss = self.fwd_outputs_for_backward.popleft() else: - fwd_inputs, fwd_outputs = self.fwd_outputs_for_backward.popleft() + ( + self.fwd_inputs, + fwd_outputs, + ) = self.fwd_outputs_for_backward.popleft() # Compute gradients + # TODO: HACK materialize_grads=True sets gradients to 0s on backward pass, + # we need set all the gradients for the inputs that need it, but should not send 0s + # due to extra communication if self.is_last_stage: - torch.autograd.backward(loss, retain_graph=True) + gradients = torch.autograd.grad( + outputs=loss, + inputs=self.fwd_inputs, + retain_graph=True, + allow_unused=True, + materialize_grads=True, + ) else: - torch.autograd.backward( - fwd_outputs, self.fwd_output_grads, retain_graph=True + gradients = torch.autograd.grad( + outputs=fwd_outputs, + inputs=self.fwd_inputs, + grad_outputs=self.fwd_output_grads, + retain_graph=True, + allow_unused=True, + materialize_grads=True, ) - return fwd_inputs + # Set the gradients for each tensor in self.fwd_inputs + for i in range(len(self.fwd_inputs)): + self.fwd_inputs[i].grad = gradients[i] + + return self.fwd_inputs def compute_loss(self): - if self.fwd_output is None: + if self.fwd_outputs is None: raise RuntimeError("forward() must be called before compute_loss()") # TODO: use a real loss function passed in - return self.fwd_output.mean() + return self.fwd_outputs[0].mean() class PipelineSchedule(ABC): @@ -251,7 +300,7 @@ def step(self, microbatches): dist.batch_isend_irecv(ops) logger.info( - f"{self._stage.stage_id} forward {i} finished, microbatch: {mb.shape}" + f"{self._stage.stage_id} forward mb {i} finished, microbatch: {[inp.shape for inp in mb]}" ) for i, _ in enumerate(microbatches): @@ -266,7 +315,7 @@ def step(self, microbatches): if ops: dist.batch_isend_irecv(ops) - logger.info(f"{self._stage.stage_id} backward {i} finished") + logger.info(f"{self._stage.stage_id} backward mb {i} finished") class PipelineScheduleLoopedBFS(PipelineSchedule): diff --git a/test/test_pipeline_schedule.py b/test/test_pipeline_schedule.py index 604a3ec4a..92721bd73 100644 --- a/test/test_pipeline_schedule.py +++ b/test/test_pipeline_schedule.py @@ -86,14 +86,14 @@ def __init__( self.wo = nn.Linear(hidden_dim, out_dim, bias=False) self.gelu_act = nn.GELU(approximate="tanh") - def forward(self, x): + def forward(self, x, arg1): a = self.wi(x) a = self.wh1(a) a = self.wh2(a) a = self.wh3(a) b = self.gelu_act(a) c = self.wo(b) - return c + return c, arg1 def setup(local_rank, init_process_group=False): @@ -106,7 +106,6 @@ def setup(local_rank, init_process_group=False): logger.info(f"init for rank {local_rank}") if torch.distributed.is_initialized(): torch.cuda.set_device(local_rank) - logger.info(f"finish init for rank {local_rank}") @@ -162,9 +161,17 @@ def rank_print(msg): n_pp = world_size x = torch.randn([microbatch_size, input_dim]).to("meta") + unused = torch.ones((1, 1), device="meta") + input_args = (x, unused) stage_model = PipelineStageV2Impl( - module_list[rank], rank, world_size, rank, world_size, x, device + module_list[rank], + rank, + world_size, + rank, + world_size, + input_args, + device, ) stage_model.init_p2p_neighbors() @@ -175,26 +182,21 @@ def rank_print(msg): num_stages=world_size * world_size, rank=rank, world_size=world_size, - meta_input=x, + input_args=input_args, device=device, ) for i in range(world_size) ] x_cuda_empty = torch.empty_like(x, device="cuda") - microbatches = [ - torch.randn_like(x_cuda_empty) for _ in range(n_microbatches) - ] - # profiling setup (enable with --profiler True) + microbatches = [] + for i in range(n_microbatches): + microbatches.append( + (torch.randn_like(x_cuda_empty), torch.ones((1, 1), device="cuda")) + ) + _run_profiler = kwargs["profiler"] - _torch_profiler = None _trace_dir = kwargs["trace_dir"] - - if _run_profiler: - if not os.path.exists(_trace_dir): - os.mkdir(_trace_dir) - rank_print(f"Profiling active -- saving traces to {_trace_dir}") - for schedule in kwargs["schedules"]: logger.info(f"====== Rank {rank} running schedule {schedule} ======") if schedule == "gpipe": @@ -217,7 +219,6 @@ def rank_print(msg): ) as _torch_profiler: with record_function(schedule): pipeline.step(microbatches) - logger.info(f"====== Rank {rank} finished {schedule} ======") @@ -282,7 +283,7 @@ def set_up_logging(rank, log_level=logging.INFO): master_addr = "localhost" os.environ["MASTER_ADDR"] = master_addr os.environ["MASTER_PORT"] = master_port - n_gpus = 4 + n_gpus = 2 world_size = n_gpus os.environ["WORLD_SIZE"] = str(world_size) print( From cca94f7794938663903afacbf832024198c54704 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Mon, 18 Dec 2023 13:13:43 -0500 Subject: [PATCH 81/96] Add LLaMA example (#917) ## Description Apply PiPPy to LLaMA 2 with equal layer cutting. We are getting the llama model from HuggingFace Model Hub: https://huggingface.co/meta-llama/Llama-2-7b-chat-hf You may need to request access to this model before running the example. The prompt generation and decoding is are learned from @SunMarc --- examples/llama/pippy_llama.py | 111 ++++++++++++++++++++++++++++++++++ 1 file changed, 111 insertions(+) create mode 100644 examples/llama/pippy_llama.py diff --git a/examples/llama/pippy_llama.py b/examples/llama/pippy_llama.py new file mode 100644 index 000000000..33dfc43c3 --- /dev/null +++ b/examples/llama/pippy_llama.py @@ -0,0 +1,111 @@ +# Minimum effort to run this example: +# $ pip install transformers +# $ torchrun --nproc-per-node 2 pippy_llama.py + +import argparse +import os + +import torch +import torch.distributed as dist + +from transformers import AutoModelForCausalLM, AutoTokenizer + +from pippy.IR import Pipe, PipeSplitWrapper, annotate_split_points +from pippy.PipelineStage import PipelineStage + + +def add_split_points(llama, nranks): + # Cut model by equal number of layers per rank + layers_per_rank = (llama.config.num_hidden_layers + nranks - 1) // nranks + print(f"layers_per_rank = {layers_per_rank}") + for i in range(1, nranks): + annotate_split_points( + llama, + {f'model.layers.{i * layers_per_rank}': PipeSplitWrapper.SplitPoint.BEGINNING}, + ) + + +def get_number_of_params(model): + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + +def run(args): + # Create a blank model + llama = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", low_cpu_mem_usage=True) + tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf") + + prompts = ( + "How do you", "I like to", "Can I help", "You have to", + "The weather is", "I have a", "What is your", "You are a", + ) # bs = 8 + tokenizer.pad_token = tokenizer.eos_token + inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(args.device) + + # Move model to `device` and set to evaluation + llama.to(args.device) + llama.eval() + print(llama) + + # Annotate split points + add_split_points(llama, args.world_size) + + # Create a pipeline stage from the model + llama_pipe = Pipe.from_tracing( + llama, + num_chunks=args.world_size, + example_args=(inputs['input_ids'],), + ) + + assert len(list(llama_pipe.split_gm.children())) == args.world_size + if args.rank == 0: + for i, sm in enumerate(llama_pipe.split_gm.children()): + print(f"Pipeline stage {i} {get_number_of_params(sm) // 10 ** 6}M params") + + # Create schedule runtime + stage = PipelineStage( + llama_pipe, + args.rank, + device=args.device, + ) + + # Run + output = None + if args.rank == 0: + stage(inputs['input_ids']) + elif args.rank == args.world_size - 1: + output = stage() + else: + stage() + + if output is not None: + next_token_logits = output[0][:, -1, :] + next_token = torch.argmax(next_token_logits, dim=-1) + print(tokenizer.batch_decode(next_token)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 4))) + parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1))) + parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost')) + parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500')) + parser.add_argument('--schedule', type=str, default="FillDrain") + parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available())) + + args = parser.parse_args() + + if args.cuda: + dev_id = args.rank % torch.cuda.device_count() + args.device = torch.device(f"cuda:{dev_id}") + else: + args.device = torch.device("cpu") + + # Init process group + backend = "nccl" if args.cuda else "gloo" + dist.init_process_group( + backend=backend, + rank=args.rank, + world_size=args.world_size, + ) + + run(args) From 4ab9266bb65c1520a32ee3bdc6cb012c019bd32b Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Tue, 19 Dec 2023 11:46:40 -0500 Subject: [PATCH 82/96] Update example.py (#918) ## Description to use stage based API instead of driver based API and cover it with CI --- .github/workflows/pippy_tests.yaml | 2 + example.py | 133 +++++++++++------------------ 2 files changed, 53 insertions(+), 82 deletions(-) diff --git a/.github/workflows/pippy_tests.yaml b/.github/workflows/pippy_tests.yaml index b79122078..561a94eaa 100644 --- a/.github/workflows/pippy_tests.yaml +++ b/.github/workflows/pippy_tests.yaml @@ -98,6 +98,8 @@ jobs: run: torchrun --nproc-per-node 4 test/test_fwd.py - name: Run forward-loss-backward integration test run: torchrun --nproc-per-node 4 test/test_bwd.py --schedule ${{ matrix.schedule }} + - name: Run example + run: torchrun --nproc-per-node 3 example.py # - name: Run null_coalesce_accumulate integration test # run: python test/local_test_null_coalesce_accumulate.py --replicate ${{ matrix.replicate }} --schedule ${{ matrix.schedule }} # - name: Run PP + DDP test diff --git a/example.py b/example.py index 0793c5ac0..b59232811 100644 --- a/example.py +++ b/example.py @@ -1,6 +1,11 @@ # Copyright (c) Meta Platforms, Inc. and affiliates +# Minimal effort to run this code: +# $ torchrun --nproc-per-node 3 example.py + +import os import torch -from typing import Any +from pippy.IR import annotate_split_points, Pipe, PipeSplitWrapper +from pippy.PipelineStage import PipelineStage class MyNetworkBlock(torch.nn.Module): @@ -34,16 +39,30 @@ def forward(self, x): return self.output_proj(x) -mn = MyNetwork(512, [512, 1024, 256]) +# To run a distributed training job, we must launch the script in multiple +# different processes. We are using `torchrun` to do so in this example. +# `torchrun` defines two environment variables: `RANK` and `WORLD_SIZE`, +# which represent the index of this process within the set of processes and +# the total number of processes, respectively. +# +# To learn more about `torchrun`, see +# https://pytorch.org/docs/stable/elastic/run.html -from pippy.IR import Pipe +torch.manual_seed(0) +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) -pipe = Pipe.from_tracing(mn) -print(pipe) -print(pipe.split_gm.submod_0) +# Figure out device to use +if torch.cuda.is_available(): + device = torch.device(f"cuda:{rank % torch.cuda.device_count()}") +else: + device = torch.device("cpu") -from pippy.IR import annotate_split_points, PipeSplitWrapper +# Create the model +in_dim = 512 +layer_dims = [512, 1024, 256] +mn = MyNetwork(in_dim, layer_dims).to(device) annotate_split_points( mn, @@ -53,7 +72,12 @@ def forward(self, x): }, ) -pipe = Pipe.from_tracing(mn) +batch_size = 32 +example_input = torch.randn(batch_size, in_dim, device=device) +chunks = 4 + +pipe = Pipe.from_tracing(mn, chunks, example_args=(example_input,)) + print(" pipe ".center(80, "*")) print(pipe) print(" submod0 ".center(80, "*")) @@ -64,85 +88,30 @@ def forward(self, x): print(pipe.split_gm.submod_2) -# To run a distributed training job, we must launch the script in multiple -# different processes. We are using `torchrun` to do so in this example. -# `torchrun` defines two environment variables: `LOCAL_RANK` and `WORLD_SIZE`, -# which represent the index of this process within the set of processes and -# the total number of processes, respectively. -# -# To learn more about `torchrun`, see -# https://pytorch.org/docs/stable/elastic/run.html -import os +# Initialize distributed environment +import torch.distributed as dist -local_rank = int(os.environ["LOCAL_RANK"]) -world_size = int(os.environ["WORLD_SIZE"]) +dist.init_process_group(rank=rank, world_size=world_size) -# PiPPy uses the PyTorch RPC interface. To use RPC, we must call `init_rpc` -# and inform the RPC framework of this process's rank and the total world -# size. We can directly pass values `torchrun` provided.` -# -# To learn more about the PyTorch RPC framework, see -# https://pytorch.org/docs/stable/rpc.html -import torch.distributed.rpc as rpc - -rpc.init_rpc(f"worker{local_rank}", rank=local_rank, world_size=world_size) - -# PiPPy relies on the concept of a "driver" process. The driver process -# should be a single process within the RPC group that instantiates the -# PipelineDriver and issues commands on that object. The other processes -# in the RPC group will receive commands from this process and execute -# the pipeline stages -if local_rank == 0: - # We are going to use the PipelineDriverFillDrain class. This class - # provides an interface for executing the `Pipe` in a style similar - # to the GPipe fill-drain schedule. To learn more about GPipe and - # the fill-drain schedule, see https://arxiv.org/abs/1811.06965 - from pippy.PipelineDriver import PipelineDriverFillDrain - from pippy.microbatch import TensorChunkSpec - - # Pipelining relies on _micro-batching_--that is--the process of - # dividing the program's input data into smaller chunks and - # feeding those chunks through the pipeline sequentially. Doing - # this requires that the data and operations be _separable_, i.e. - # there should be at least one dimension along which data can be - # split such that the program does not have interactions across - # this dimension. PiPPy provides `chunk_spec` arguments for this - # purpose, to specify the batch dimension for tensors in each of - # the args, kwargs, and outputs. The structure of the `chunk_spec`s - # should mirror that of the data type. Here, the program has a - # single tensor input and single tensor output, so we specify - # a single `TensorChunkSpec` instance indicating dimension 0 - # for args[0] and the output value. - args_chunk_spec: Any = (TensorChunkSpec(0),) - kwargs_chunk_spec: Any = {} - output_chunk_spec: Any = TensorChunkSpec(0) - - # Finally, we instantiate the PipelineDriver. We pass in the pipe, - # chunk specs, and world size, and the constructor will distribute - # our code to the processes in the RPC group. `driver` is an object - # we can invoke to run the pipeline. - driver = PipelineDriverFillDrain( - pipe, - 64, - world_size=world_size, - args_chunk_spec=args_chunk_spec, - kwargs_chunk_spec=kwargs_chunk_spec, - output_chunk_spec=output_chunk_spec, - ) - - x = torch.randn(512, 512) - - # Run the pipeline with input `x`. Divide the batch into 64 micro-batches - # and run them in parallel on the pipeline - output = driver(x) +# Pipeline stage is our main pipeline runtime. It takes in the pipe object, +# the rank of this process, and the device. +stage = PipelineStage(pipe, rank, device) + +# Input data +x = torch.randn(batch_size, in_dim, device=device) +# Run the pipeline with input `x`. Divide the batch into 4 micro-batches +# and run them in parallel on the pipeline +if rank == 0: + stage(x) +elif rank == world_size - 1: + output = stage() +else: + stage() + +if rank == world_size - 1: # Run the original code and get the output for comparison reference_output = mn(x) - # Compare numerics of pipeline and original model torch.testing.assert_close(output, reference_output) - print(" Pipeline parallel model ran successfully! ".center(80, "*")) - - -rpc.shutdown() From c81a65da9fb3c3bc8ff5f621159f1b2e7204a018 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Tue, 19 Dec 2023 11:46:51 -0500 Subject: [PATCH 83/96] Remove PipelineDriver from README (#919) Partially done --- README.md | 259 ++++++++++++------------------------------------------ 1 file changed, 57 insertions(+), 202 deletions(-) diff --git a/README.md b/README.md index e212dbe97..c6628952d 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ # Why PiPPy? -One of the most important techniques for advancing the state of the art in deep learning is scaling. Common techniques for scaling neural networks include _data parallelism_, _tensor/model parallelism_, and _pipeline parallelism_. In many cases, pipeline parallelism in particular can be an effective technique for scaling, however it is often difficult to implement, requiring intrusive code changes to model code and difficult-to-implement runtime orchestration code. PiPPy aims to provide a toolkit that does said things automatically to allow high-productivity scaling of models. +One of the most important techniques for advancing the state of the art in deep learning is scaling. Common techniques for scaling neural networks include _data parallelism_, _tensor/operation parallelism_, and _pipeline parallelism_. In many cases, pipeline parallelism in particular can be an effective technique for scaling, however it is often difficult to implement, requiring intrusive code changes to model code and difficult-to-implement runtime orchestration code. PiPPy aims to provide a toolkit that does said things automatically to allow high-productivity scaling of models. # What is PiPPy? @@ -21,17 +21,17 @@ The PiPPy project consists of a compiler and runtime stack for automated paralle PiPPy provides the following features that make pipeline parallelism easier: -* Automatic splitting of model code via `torch.fx`. The goal is for the user to provide model code as-is to the system for parallelization, without having to make heavyweight modifications to make parallelism work. +* Automatic splitting of model code via PyTorch tracer. The goal is for the user to provide model code as-is to the system for parallelization, without having to make heavyweight modifications to make parallelism work. * Related to the last point, PiPPy supports non-trivial topologies, including skip connections and tied weights/layers. PiPPy provides configurable behavior for tied weights, allowing for transmission across pipeline stages or replication and gradient synchronization. * First-class support for cross-host pipeline parallelism, as this is where PP is typically used (over slower interconnects). This is currently missing from the torchgpipe-based `torch.distributed.pipeline.sync.Pipe`. * Composability with other parallelism schemes such as data parallelism or tensor splitting model parallelism (overall, known as "3d parallelism"). Currently, pipelining and data parallelism can be composed. Other compositions will be available in the future. -* Support for pipeline scheduling paradigms, including static schedules like fill-drain (GPipe), 1f1b, interleaved 1f1b and dynamic schedules like lookahead or registers/back-pressure. +* Support for pipeline scheduling paradigms, including schedules like fill-drain (GPipe), 1F1B and interleaved 1F1B. More schedules will be added too. For in-depth technical architecture, see [ARCHITECTURE.md](ARCHITECTURE.md). # Install -PiPPy requires PyTorch version newer than 1.12 to work. To quickly install, for example, PyTorch nightly, run the following command from the same directory as this README: +PiPPy requires PyTorch version newer than 2.2.0.dev to work. To quickly install, for example, PyTorch nightly, run the following command from the same directory as this README: ``` pip install -r requirements.txt --find-links https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html @@ -40,7 +40,7 @@ pip install -r requirements.txt --find-links https://download.pytorch.org/whl/ni You can also select the CUDA build of PyTorch if your system has NVIDIA GPUs, for example: ``` -pip install -r requirements.txt --find-links https://download.pytorch.org/whl/nightly/cu116/torch_nightly.html +pip install -r requirements.txt --find-links https://download.pytorch.org/whl/nightly/cu118/torch_nightly.html ``` To install PiPPy from source, run the following command in the same directory as this README: @@ -57,7 +57,7 @@ python setup.py develop # PiPPy Quickstart -PiPPy consists of two parts: a _compiler_ and a _runtime_. The compiler takes your model code, splits it up, and transforms it into a `Pipe`, which is a wrapper that describes how to execute the model in pipeline parallelism. The runtime executes the `Pipe` in parallel, handling things like micro-batch splitting and gradient propagation/syncing. We will cover the APIs for these concepts in this section. +PiPPy consists of two parts: a _compiler_ and a _runtime_. The compiler takes your model code, splits it up, and transforms it into a `Pipe`, which is a wrapper that describes the model at each pipeline stage and their data-flow relationship. The runtime executes the `Pipe` in parallel, handling things like micro-batch splitting, scheduling, communication, and gradient propagation, etc. We will cover the APIs for these concepts in this section. ## Splitting a Model with Pipe @@ -97,7 +97,9 @@ class MyNetwork(torch.nn.Module): return self.output_proj(x) -mn = MyNetwork(512, [512, 1024, 256]) +in_dim = 512 +layer_dims = [512, 1024, 256] +mn = MyNetwork(in_dim, layer_dims).to(device) ``` This network is written as free-form Python code; it has not been modified for any specific parallelism technique. @@ -105,241 +107,94 @@ This network is written as free-form Python code; it has not been modified for a Let us see our first usage of the `pippy.IR.Pipe` interface: ```python -from pippy.IR import annotate_split_points, PipeSplitWrapper +from pippy.IR import annotate_split_points, Pipe, PipeSplitWrapper annotate_split_points(mn, {'layer0': PipeSplitWrapper.SplitPoint.END, 'layer1': PipeSplitWrapper.SplitPoint.END}) -pipe = Pipe.from_tracing(mn) +batch_size = 32 +example_input = torch.randn(batch_size, in_dim, device=device) +chunks = 4 + +pipe = Pipe.from_tracing(mn, chunks, example_args=(example_input,)) print(pipe) """ ************************************* pipe ************************************* GraphModule( - (submod_0): GraphModule( - (layer0_mod_lin): Linear(in_features=512, out_features=512, bias=True) + (submod_0): PipeStageModule( + (L__self___layer0_mod_lin): Linear(in_features=512, out_features=512, bias=True) ) - (submod_1): GraphModule( - (layer1_mod_lin): Linear(in_features=512, out_features=1024, bias=True) + (submod_1): PipeStageModule( + (L__self___layer1_mod_lin): Linear(in_features=512, out_features=1024, bias=True) ) - (submod_2): GraphModule( - (layer2_lin): Linear(in_features=1024, out_features=256, bias=True) - (output_proj): Linear(in_features=256, out_features=10, bias=True) + (submod_2): PipeStageModule( + (L__self___layer2_lin): Linear(in_features=1024, out_features=256, bias=True) + (L__self___output_proj): Linear(in_features=256, out_features=10, bias=True) ) ) -def forward(self, x): - submod_0 = self.submod_0(x); x = None +def forward(self, arg0): + submod_0 = self.submod_0(arg0); arg0 = None submod_1 = self.submod_1(submod_0); submod_0 = None submod_2 = self.submod_2(submod_1); submod_1 = None - return submod_2 -""" - -print(pipe.split_gm.submod_0) - -""" -*********************************** submod0 ************************************ -GraphModule( - (layer0_mod_lin): Linear(in_features=512, out_features=512, bias=True) -) - -def forward(self, x): - layer0_mod_lin = self.layer0_mod_lin(x); x = None - relu = torch.relu(layer0_mod_lin); layer0_mod_lin = None - return relu -""" - -print(pipe.split_gm.submod_1) - -""" -*********************************** submod1 ************************************ -GraphModule( - (layer1_mod_lin): Linear(in_features=512, out_features=1024, bias=True) -) - -def forward(self, relu): - layer1_mod_lin = self.layer1_mod_lin(relu); relu = None - relu_1 = torch.relu(layer1_mod_lin); layer1_mod_lin = None - return relu_1 -""" - -print(pipe.split_gm.submod_2) - -""" -*********************************** submod2 ************************************ -GraphModule( - (layer2_lin): Linear(in_features=1024, out_features=256, bias=True) - (output_proj): Linear(in_features=256, out_features=10, bias=True) -) - -def forward(self, relu_1): - layer2_lin = self.layer2_lin(relu_1); relu_1 = None - relu = torch.relu(layer2_lin); layer2_lin = None - output_proj = self.output_proj(relu); relu = None - return output_proj + return [submod_2] """ ``` -So what's going on here? First, `Pipe.from_tracing` uses `torch.fx` symbolic tracing to turn our model into a directed acyclic graph (DAG) representation. Then, it groups together the operations and parameters into _pipeline stages_. Stages are represented as `submod_N` submodules, where `N` is a natural number. +So what's going on here? First, `Pipe.from_tracing` uses a PyTorch tracer to turn our model into a directed acyclic graph (DAG) representation. Then, it groups together the operations and parameters into _pipeline stages_. Stages are represented as `submod_N` submodules, where `N` is a natural number. -Our code has now been split into _three_ pipeline stages. We used `annotate_split_points` to specify that the code should be split and the end of `layer0` and `layer1`. - -In addition to custom splitting policy, PiPPy also provides automatic splitting policies. For example: -* `split_on_size_threshold(numel)`: create a new pipeline stage upon reaching a given number of parameters; -* `split_into_equal_size(num_stages)`: split the model in to specified number of equal-size stages. - -We can pass the splitting policy to the `from_tracing` API: - -```python -from pippy import split_into_equal_size +We used `annotate_split_points` to specify that the code should be split and the end of `layer0` and `layer1`. Our code has thus been split into _three_ pipeline stages. PiPPy also provides `SplitPoint.BEGINNING` if a user wants to split before certain annotation point. -split_policy = split_into_equal_size(world_size) - -pipe = Pipe.from_tracing(mn, split_policy=split_policy) -``` +While the `annotate_split_points` API gives users a way to specify the split points without modifying the model, PiPPy also provides an API for in-model annotation: `pipe_split()`. For details, you can read [this example](https://github.com/pytorch/PiPPy/blob/main/test/test_pipe.py). This covers the basic usage of the `Pipe` API. For more information, see the documentation. -## Using PipelineDriver for Pipelined Execution +## Using PipelineStage for Pipelined Execution -Given the above `Pipe` object, we can use one of the `PipelineDriver` classes to execute our model in a pipelined fashion. First off, let us instantiate a `PipelineDriverFillDrain` instance: +Given the above `Pipe` object, we can use one of the `PipelineStage` classes to execute our model in a pipelined fashion. First off, let us instantiate a `PipelineStage` instance: ```python -# To run a distributed training job, we must launch the script in multiple -# different processes. We are using `torchrun` to do so in this example. -# `torchrun` defines two environment variables: `LOCAL_RANK` and `WORLD_SIZE`, -# which represent the index of this process within the set of processes and -# the total number of processes, respectively. -# -# To learn more about `torchrun`, see -# https://pytorch.org/docs/stable/elastic/run.html -import os -local_rank = int(os.environ["LOCAL_RANK"]) -world_size = int(os.environ['WORLD_SIZE']) - -# PiPPy uses the PyTorch RPC interface. To use RPC, we must call `init_rpc` -# and inform the RPC framework of this process's rank and the total world -# size. We can directly pass values `torchrun` provided.` -# -# To learn more about the PyTorch RPC framework, see -# https://pytorch.org/docs/stable/rpc.html -import torch.distributed.rpc as rpc -rpc.init_rpc(f'worker{local_rank}', rank=local_rank, world_size=world_size) - -# PiPPy relies on the concept of a "driver" process. The driver process -# should be a single process within the RPC group that instantiates the -# PipelineDriver and issues commands on that object. The other processes -# in the RPC group will receive commands from this process and execute -# the pipeline stages -if local_rank == 0: - # We are going to use the PipelineDriverFillDrain class. This class - # provides an interface for executing the `Pipe` in a style similar - # to the GPipe fill-drain schedule. To learn more about GPipe and - # the fill-drain schedule, see https://arxiv.org/abs/1811.06965 - from pippy.PipelineDriver import PipelineDriverFillDrain - from pippy.microbatch import TensorChunkSpec - - # Pipelining relies on _micro-batching_--that is--the process of - # dividing the program's input data into smaller chunks and - # feeding those chunks through the pipeline sequentially. Doing - # this requires that the data and operations be _separable_, i.e. - # there should be at least one dimension along which data can be - # split such that the program does not have interactions across - # this dimension. PiPPy provides `chunk_spec` arguments for this - # purpose, to specify the batch dimension for tensors in each of - # the args, kwargs, and outputs. The structure of the `chunk_spec`s - # should mirror that of the data type. Here, the program has a - # single tensor input and single tensor output, so we specify - # a single `TensorChunkSpec` instance indicating dimension 0 - # for args[0] and the output value. - args_chunk_spec = (TensorChunkSpec(0),) - kwargs_chunk_spec = {} - output_chunk_spec = TensorChunkSpec(0) - - # Finally, we instantiate the PipelineDriver. We pass in the pipe, - # chunk specs, and world size, and the constructor will distribute - # our code to the processes in the RPC group. `driver` is an object - # we can invoke to run the pipeline. - driver = PipelineDriverFillDrain( - pipe, args_chunk_spec=args_chunk_spec, kwargs_chunk_spec=kwargs_chunk_spec, - output_chunk_spec=output_chunk_spec, world_size=world_size) - - # - -rpc.shutdown() +# We are using `torchrun` to run this example with multiple processes. +# `torchrun` defines two environment variables: `RANK` and `WORLD_SIZE`. +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) + +# Initialize distributed environment +import torch.distributed as dist +dist.init_process_group(rank=rank, world_size=world_size) + +# Pipeline stage is our main pipeline runtime. It takes in the pipe object, +# the rank of this process, and the device. +from pippy.PipelineStage import PipelineStage +stage = PipelineStage(pipe, rank, device) ``` -Note that our script must now be replicated across multiple workers. For this example, we will use `torchrun` to run multiple processes within a single machine for demonstration purposes. We can collect up all of the code blocks above into a file named [example.py](example.py) and then run it with `torchrun` like so: - -``` -torchrun --nproc_per_node=3 example.py -``` - -Note that we have launched 3 processes, as we have 3 pipeline stages. - -We can now run the pipeline by passing input to the `PipelineDriver` because `PipelineDriver` is also a `nn.module`: +We can now run the pipeline by passing input to the first `PipelineStage`: ```python - # Instantiate a random input for testing purposes. - x = torch.randn(512, 512) - - # Run the pipeline with input `x`. Divide the batch into 64 micro-batches - # and run them in parallel on the pipeline - driver.chunks = 64 - output = driver(x) - - # Run the original code and get the output for comparison - reference_output = mn(x) - - # Compare numerics of pipeline and original model - torch.testing.assert_close(output, reference_output) - - print(' Pipeline parallel model ran successfully! '.center(80, '*')) -``` - -We can see that we can now execute our model in a pipelined fashion and get the same numeric outputs. - -## pippy.compile and pippy.all_compile - -Most users do not need to use the `pipe` object generated by `Pipe.from_tracing`. For convenience, PiPPy provides a `compile` API that directly generates a `PipelineDriver` from user's model. - -```python -import pippy +# Input data +x = torch.randn(batch_size, in_dim, device=device) +# Run the pipeline with input `x`. Divide the batch into 4 micro-batches +# and run them in parallel on the pipeline if rank == 0: - # Create pipeline driver - driver = pippy.compile( - mn, - num_ranks=world_size, - num_chunks=world_size, - schedule="FillDrain", - split_policy=split_poicy, - ) - - output = driver(x) + stage(x) +elif rank == world_size - 1: + output = stage() +else: + stage() ``` -All examples above assume that the driver process has enough memory to materialize the model (before splitting). In case that the model is so large that the driver process cannot materialize it on a single device, it would be necessary to first split the model and then let each process materialize its pipeline stage on its own device. `pippy.all_compile` provides such functionality. Different from `pippy.compile`, `pippy.all_compile` requires all ranks to call into it so that they all know which part of the model they should materialize. For example: - -```python -import pippy - -# All ranks call into it -driver, stage_mod = pippy.all_compile( - mn, - num_ranks=world_size, - num_chunks=world_size, - schedule="FillDrain", - split_policy=split_poicy, -) +Note that since we split our model into three stages, we must run this script with three workers. For this example, we will use `torchrun` to run multiple processes within a single machine for demonstration purposes. We can collect up all of the code blocks above into a file named [example.py](example.py) and then run it with `torchrun` like so: -if rank == 0: - output = driver(x) +``` +torchrun --nproc_per_node=3 example.py ``` -Only rank 0 will have the pipeline driver returned, but all ranks will be returned a handle to their local stage module (`stage_mod`). +## Note: the following sections need to be updated. ## ## Forward vs. Forward-loss-backward From ea7d1d65f28b508ee9f5db65603d88a83608e46c Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Wed, 20 Dec 2023 14:10:02 -0500 Subject: [PATCH 84/96] Simplify llama example (#921) Add README. Remove unnecessary logics. --- examples/llama/README.md | 17 ++++ examples/llama/pippy_llama.py | 154 ++++++++++------------------------ pippy/__init__.py | 2 + 3 files changed, 65 insertions(+), 108 deletions(-) create mode 100644 examples/llama/README.md diff --git a/examples/llama/README.md b/examples/llama/README.md new file mode 100644 index 000000000..e7346f052 --- /dev/null +++ b/examples/llama/README.md @@ -0,0 +1,17 @@ +``` +$ torchrun --nproc-per-node 2 pippy_llama.py +``` +``` +$ torchrun --nproc-per-node 4 pippy_llama.py +``` +``` +$ torchrun --nproc-per-node 8 pippy_llama.py +``` +``` +prompts = ( + "How do you", "I like to", "Can I help", "You need to", + "The weather is", "I found a", "What is your", "You are so", +) +Outputs: +['make', 'think', 'you', 'be', 'getting', 'great', 'favorite', 'right'] +``` diff --git a/examples/llama/pippy_llama.py b/examples/llama/pippy_llama.py index 33dfc43c3..5ac314ca7 100644 --- a/examples/llama/pippy_llama.py +++ b/examples/llama/pippy_llama.py @@ -1,111 +1,49 @@ -# Minimum effort to run this example: -# $ pip install transformers -# $ torchrun --nproc-per-node 2 pippy_llama.py - -import argparse +# $ torchrun --nproc-per-node 4 pippy_llama.py import os - import torch -import torch.distributed as dist - from transformers import AutoModelForCausalLM, AutoTokenizer - -from pippy.IR import Pipe, PipeSplitWrapper, annotate_split_points -from pippy.PipelineStage import PipelineStage - - -def add_split_points(llama, nranks): - # Cut model by equal number of layers per rank - layers_per_rank = (llama.config.num_hidden_layers + nranks - 1) // nranks - print(f"layers_per_rank = {layers_per_rank}") - for i in range(1, nranks): - annotate_split_points( - llama, - {f'model.layers.{i * layers_per_rank}': PipeSplitWrapper.SplitPoint.BEGINNING}, - ) - - -def get_number_of_params(model): - return sum(p.numel() for p in model.parameters() if p.requires_grad) - - -def run(args): - # Create a blank model - llama = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", low_cpu_mem_usage=True) - tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf") - - prompts = ( - "How do you", "I like to", "Can I help", "You have to", - "The weather is", "I have a", "What is your", "You are a", - ) # bs = 8 - tokenizer.pad_token = tokenizer.eos_token - inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(args.device) - - # Move model to `device` and set to evaluation - llama.to(args.device) - llama.eval() - print(llama) - - # Annotate split points - add_split_points(llama, args.world_size) - - # Create a pipeline stage from the model - llama_pipe = Pipe.from_tracing( - llama, - num_chunks=args.world_size, - example_args=(inputs['input_ids'],), - ) - - assert len(list(llama_pipe.split_gm.children())) == args.world_size - if args.rank == 0: - for i, sm in enumerate(llama_pipe.split_gm.children()): - print(f"Pipeline stage {i} {get_number_of_params(sm) // 10 ** 6}M params") - - # Create schedule runtime - stage = PipelineStage( - llama_pipe, - args.rank, - device=args.device, - ) - - # Run - output = None - if args.rank == 0: - stage(inputs['input_ids']) - elif args.rank == args.world_size - 1: - output = stage() - else: - stage() - - if output is not None: - next_token_logits = output[0][:, -1, :] - next_token = torch.argmax(next_token_logits, dim=-1) - print(tokenizer.batch_decode(next_token)) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 4))) - parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1))) - parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost')) - parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500')) - parser.add_argument('--schedule', type=str, default="FillDrain") - parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available())) - - args = parser.parse_args() - - if args.cuda: - dev_id = args.rank % torch.cuda.device_count() - args.device = torch.device(f"cuda:{dev_id}") - else: - args.device = torch.device("cpu") - - # Init process group - backend = "nccl" if args.cuda else "gloo" - dist.init_process_group( - backend=backend, - rank=args.rank, - world_size=args.world_size, - ) - - run(args) +from pippy import Pipe, PipeSplitWrapper, annotate_split_points, PipelineStage + +# Grab the model +llama = AutoModelForCausalLM.from_pretrained( + "meta-llama/Llama-2-7b-chat-hf", low_cpu_mem_usage=True +) +tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf") + +prompts = ( + "How do you", "I like to", "Can I help", "You need to", + "The weather is", "I found a", "What is your", "You are so", +) # bs = 8 +tokenizer.pad_token = tokenizer.eos_token + +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +device = torch.device(f"cuda:{rank % torch.cuda.device_count()}") +llama.to(device).eval() +inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(device) + +# Cut model by equal number of layers per rank +layers_per_rank = llama.config.num_hidden_layers // world_size +for i in range(1, world_size): + annotate_split_points(llama, + {f"model.layers.{i * layers_per_rank}": PipeSplitWrapper.SplitPoint.BEGINNING}) + +# Create a pipeline representation from the model +llama_pipe = Pipe.from_tracing(llama, world_size, example_args=(inputs["input_ids"],)) + +# Create pipeline stage for each rank +torch.distributed.init_process_group(rank=rank, world_size=world_size) +stage = PipelineStage(llama_pipe, rank, device=device) + +# Run +if rank == 0: + args = inputs["input_ids"] +else: + args = None +output = stage(args) + +# Decode +if output is not None: + next_token_logits = output[0][:, -1, :] + next_token = torch.argmax(next_token_logits, dim=-1) + print(tokenizer.batch_decode(next_token)) diff --git a/pippy/__init__.py b/pippy/__init__.py index 75de5f7bd..a68c01a3e 100644 --- a/pippy/__init__.py +++ b/pippy/__init__.py @@ -9,6 +9,7 @@ TrivialLossWrapper, ) from pippy.ModelSplit import split_into_equal_size, split_on_size_threshold +from pippy.PipelineStage import PipelineStage __all__ = [ @@ -16,6 +17,7 @@ "LossWrapper", "TrivialLossWrapper", "Pipe", + "PipelineStage", "pipe_split", "PipeSplitWrapper", "annotate_split_points", From a4cc35fd5b11bac50b21d25f063db5dd09f224eb Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Thu, 28 Dec 2023 11:09:20 -0500 Subject: [PATCH 85/96] Support device dispatching during stage creation (#923) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Description This PR adds support to a case where the user creates model and trace model on CPU, then creates pipeline stage on GPU. PiPPy would move only the stage module to the corresponding GPU. ## Test ``` torchrun --nproc-per-node 2 test_cpu_init.py ``` ## Update: Sometimes, the `forward` function of user code may create constant tensors based on input device: ``` device = input_ids.device attention_mask = torch.ones(…, device=device) ``` As of now, PT2 tracer does not treat `input_ids.device` as a symbolic device. As a result, `device="cpu"` got burned in the generated code: ``` ones = torch.ones(…, device = device(type='cpu')) ``` To workaround this, this PR added call in `PipelineStage` creation: ``` def _move_ops_to_device(new_device) ``` After this call, the `device=` kwarg of `torch.ones` will be modified to the `new_device`. This call is hidden from user, thus when symbolic device support is added, we can silently remove this and not involve user code change. We also checked native_functions.yaml, all APIs involving the "device" kwarg are generator ops, which are safe to change the device value. (And we should). ## Real Example ``` cd examples/cpu_init torchrun --nproc-per-node 4 bert_cpu_init.py ``` Cc: @muellerzr @SunMarc --- examples/cpu_init/bert_cpu_init.py | 116 +++++++++++++++++++++++++ pippy/IR.py | 11 ++- pippy/PipelineStage.py | 29 ++++++- pippy/utils.py | 23 +++++ test/test_cpu_init.py | 135 +++++++++++++++++++++++++++++ 5 files changed, 310 insertions(+), 4 deletions(-) create mode 100644 examples/cpu_init/bert_cpu_init.py create mode 100644 test/test_cpu_init.py diff --git a/examples/cpu_init/bert_cpu_init.py b/examples/cpu_init/bert_cpu_init.py new file mode 100644 index 000000000..bb1ecf5de --- /dev/null +++ b/examples/cpu_init/bert_cpu_init.py @@ -0,0 +1,116 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +# Minimum effort to run this example: +# $ torchrun --nproc-per-node 4 bert_cpu_init.py + +import argparse +import os + +import torch +import torch.distributed as dist + +from pippy.IR import Pipe, PipeSplitWrapper, annotate_split_points +from pippy.PipelineStage import PipelineStage + +from transformers import BertForMaskedLM, BertConfig + + +def add_split_points(bert, nranks): + layers_per_rank = bert.config.num_hidden_layers // nranks + for i in range(1, nranks): + annotate_split_points( + bert, {f"bert.encoder.layer.{i * layers_per_rank}": PipeSplitWrapper.SplitPoint.BEGINNING}) + + +def run(args): + # Model configs + config = BertConfig() + + # Create model on CPU + model_class = BertForMaskedLM + model_name = "BertForMaskedLM" + bert = model_class(config) + bert.eval() + if args.rank == 0: + print(bert.config) + print(bert) + + # Example input on CPU + example_input = torch.randint( + low=0, + high=config.vocab_size, + size=(args.batch_size, 512), # bs x seq_len + device="cpu", + dtype=torch.int64, + requires_grad=False, + ) + + # Annotate split points + add_split_points(bert, args.world_size) + + # Create pipeline + bert_pipe = Pipe.from_tracing( + bert, + num_chunks=args.chunks, + example_args=(example_input,), + ) + nstages = len(list(bert_pipe.split_gm.children())) + assert nstages == args.world_size, f"nstages = {nstages} nranks = {args.world_size}" + + # Create schedule runtime + stage = PipelineStage( + bert_pipe, + args.rank, + device=args.device, + ) + + # Real input on GPU + real_input = torch.randint( + low=0, + high=config.vocab_size, + size=(args.batch_size, 512), # bs x seq_len + device=args.device, + dtype=torch.int64, + requires_grad=False, + ) + + # Run + if args.rank == 0: + stage(real_input) + elif args.rank == args.world_size - 1: + out = stage() + else: + stage() + + print(f"Rank {args.rank} completes") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 4))) + parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1))) + parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost')) + parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500')) + parser.add_argument('--schedule', type=str, default="FillDrain") + parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available())) + parser.add_argument("--chunks", type=int, default=4) + parser.add_argument('--batch_size', type=int, default=4) + parser.add_argument('--batches', type=int, default=1) + + args = parser.parse_args() + + if args.cuda: + dev_id = args.rank % torch.cuda.device_count() + args.device = torch.device(f"cuda:{dev_id}") + else: + args.device = torch.device("cpu") + + # Init process group + backend = "nccl" if args.cuda else "gloo" + dist.init_process_group( + backend=backend, + rank=args.rank, + world_size=args.world_size, + ) + + run(args) diff --git a/pippy/IR.py b/pippy/IR.py index 3b41f3526..7161f4fb8 100644 --- a/pippy/IR.py +++ b/pippy/IR.py @@ -611,9 +611,14 @@ def __init__( new_key = k[len(mod_prefix) :] mod_qualname_mapping.setdefault(new_key, v) # Add a remap mixin to submodule instance - mod.__class__ = type( - "PipeStageModule", (QualnameMapMixin, mod.__class__), {} - ) + # TODO: this class change is commented out because it breaks + # recompilation if we want to recompile mod after. For example, we + # may recompile mod to modify the "device" kwarg of a `torch.ones` + # node (trace on cpu/meta, run on cuda). + # See: https://github.com/pytorch/vision/issues/5826 + # mod.__class__ = type( + # "PipeStageModule", (QualnameMapMixin, mod.__class__), {} + # ) setattr(mod, "new_to_old_qualname_mapping", mod_qualname_mapping) def throw(self, *args, **kwargs): diff --git a/pippy/PipelineStage.py b/pippy/PipelineStage.py index a3b1903e9..fe39a15d6 100644 --- a/pippy/PipelineStage.py +++ b/pippy/PipelineStage.py @@ -13,7 +13,7 @@ from pippy.debug import map_debug_info from pippy.IR import Pipe from pippy.microbatch import merge_chunks, split_args_kwargs_into_chunks -from pippy.utils import flatten_args +from pippy.utils import flatten_args, modify_graph_op_device logger = logging.getLogger(__name__) @@ -135,6 +135,33 @@ def __init__( # Prepare send/recv infrastructure self._prepare_send_recv_infra() + # Cast submodule to device + self._move_submod_to_device() + # Move ops argument to device + self._move_ops_to_device() + + def _move_submod_to_device(self): + # Move submodule to indicated device if possible + # Note: we cannot move meta module to real devices because meta tensors + # do not support to() method. One needs to do an in-place tensor swap in + # that case. + has_meta_param = any( + isinstance(p, FakeTensor) or p.is_meta + for p in self.submod.parameters() + ) + if has_meta_param: + logger.debug(f"[{self.group_rank}] Found meta parameters!") + else: + logger.debug(f"[{self.group_rank}] No meta parameters found!") + self.submod.to(self.device) + + def _move_ops_to_device(self): + # Today PT2 tracer does not treat `x.device` as a symbolic device; + # instead, the device of tracing time got burned into the generated + # code. Here we provide a workaround for users to manually modify the + # "device" kwarg of operations. Such operation may include: + # `torch.ones`, `torch.zeros`, `torch.rand`, etc. + modify_graph_op_device(self.submod, self.device) def is_first(self): return self.stage_index == 0 diff --git a/pippy/utils.py b/pippy/utils.py index 69a9b7ae1..b0417248c 100644 --- a/pippy/utils.py +++ b/pippy/utils.py @@ -1,9 +1,14 @@ # Copyright (c) Meta Platforms, Inc. and affiliates +import logging + import torch import torch.distributed as dist from torch import fx +logger = logging.getLogger(__name__) + + def flatten_args_detach(args): flat_detached_args = [] @@ -69,3 +74,21 @@ def _get_binary_filename(cur_idx: int, is_optim: bool = False) -> str: # type: state_type = "optim" if is_optim else "model" return f"pytorch_{state_type}-{idx}-of-{world_size}.bin" + + +def modify_graph_op_device( + gm: torch.fx.GraphModule, + new_device: torch.device, +): + modified = False + for node in gm.graph.nodes: + if node.op == "call_function": + if "device" in node.kwargs and node.kwargs["device"] != new_device: + logger.debug( + f"Changing device of Node {node.name} from {node.kwargs['device']} to {new_device}" + ) + node.update_kwarg("device", new_device) + modified = True + + if modified: + gm.recompile() diff --git a/test/test_cpu_init.py b/test/test_cpu_init.py new file mode 100644 index 000000000..5b15338d2 --- /dev/null +++ b/test/test_cpu_init.py @@ -0,0 +1,135 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +import argparse +import os +import unittest + +import pippy + +import torch +import torch.distributed as dist +from pippy.IR import Pipe, pipe_split +from pippy.PipelineStage import PipelineStage + + +pippy.microbatch._debug_mask_minibatches = True + +d_hid = 512 +batch_size = 256 + +torch.manual_seed(0) + + +class ExampleCode(torch.nn.Module): + def __init__(self): + super().__init__() + self.mm_param = torch.nn.Parameter(torch.randn(d_hid, d_hid)) + self.lin = torch.nn.Linear(d_hid, d_hid) + + def forward(self, x): + # Test change of tensor creation device after tracing + a = torch.ones(batch_size, d_hid, device=x.device) + x = x + a + x = torch.mm(x, self.mm_param) + x = torch.relu(x) + pipe_split() + x = self.lin(x) + x = torch.relu(x) + return x + + +def run_worker(args): + # Create module and trace model in CPU + mod = ExampleCode() + + xe = torch.randn(batch_size, d_hid) + + pipe = Pipe.from_tracing( + mod, + args.chunks, + example_args=(xe,), + ) + + # Create pipeline stages and move stage to GPU + stage = PipelineStage( + pipe, + args.rank, + device=args.device, + ) + + # Create real input on real device + x = torch.randn(batch_size, d_hid, device=args.device) + + # Run + if args.rank == 0: + stage(x) + elif args.rank == args.world_size - 1: + out = stage() + else: + stage() + + dist.barrier() + print(f"Rank {args.rank} completes") + + # Last rank checks result + if args.rank == args.world_size - 1: + mod.to(args.device) + ref_out = mod(x) + torch.testing.assert_close(out, ref_out) + print( + f"equivalence test passed {torch.sum(out)} ref {torch.sum(ref_out)}" + ) + + +def main(args=None): + parser = argparse.ArgumentParser() + parser.add_argument( + "--world_size", type=int, default=int(os.getenv("WORLD_SIZE", 2)) + ) + parser.add_argument("--rank", type=int, default=int(os.getenv("RANK", -1))) + parser.add_argument( + "--master_addr", type=str, default=os.getenv("MASTER_ADDR", "localhost") + ) + parser.add_argument( + "--master_port", type=str, default=os.getenv("MASTER_PORT", "29500") + ) + parser.add_argument( + "--cuda", type=int, default=int(torch.cuda.is_available()) + ) + parser.add_argument( + "--chunks", + type=int, + default=4, + ) + args = parser.parse_args(args) + + if args.cuda: + dev_id = args.rank % torch.cuda.device_count() + args.device = torch.device(f"cuda:{dev_id}") + else: + args.device = torch.device("cpu") + + # Init process group + backend = "nccl" if args.cuda else "gloo" + dist.init_process_group( + backend=backend, + rank=args.rank, + world_size=args.world_size, + ) + + run_worker(args) + + +if __name__ == "__main__": + main() + + +class TestFwd(unittest.TestCase): + def test_fwd(self): + import random + + port = random.randint(29500, 30000) + args = [ + "--master_port", + str(port), + ] + main(args) From 02903a4305c260aa83d6ae57244e15cbb827d600 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Tue, 2 Jan 2024 16:03:37 -0500 Subject: [PATCH 86/96] [copy] update import for _export_to_torch_ir (#926) ## Description This is a copy of #924 with an addition to support both new and old import paths. --------- Co-authored-by: lessw2020 --- pippy/IR.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/pippy/IR.py b/pippy/IR.py index 7161f4fb8..5a4a4757c 100644 --- a/pippy/IR.py +++ b/pippy/IR.py @@ -12,6 +12,19 @@ from torch.fx.interpreter import Interpreter from torch.fx.passes.split_module import split_module +try: + # New import path + from torch.export._trace import _export_to_torch_ir +except ImportError: + try: + # Old import path + from torch._export import _export_to_torch_ir + except ImportError: + print( + "Could not import _export_to_torch_ir. Please make sure your PyTorch " + "version is newer than 2.2.0." + ) + from pippy.backward import _null_coalesce_accumulate, stage_backward from pippy.debug import PIPPY_VERBOSITY from pippy.microbatch import LossReducer, split_args_kwargs_into_chunks @@ -1052,7 +1065,7 @@ def _trace_with_export( logger.info("Tracing model ...") try: torch._dynamo.allow_in_graph(pipe_split) - traced: torch.fx.GraphModule = torch._export._export_to_torch_ir( + traced: torch.fx.GraphModule = _export_to_torch_ir( mod, example_args, example_kwargs, From 169892c43297f2ec7f89d36f79c94ef555f1703a Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Tue, 2 Jan 2024 16:10:26 -0500 Subject: [PATCH 87/96] Remove cap on torch version (#927) ## Description No more cap needed after #926 --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 256160137..4d73a1d52 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ -torch >= 2.2.0.dev, <2.2.0.dev20231212 +torch >= 2.2.0.dev packaging >= 21.3 From bb90773e3b3d4e9737c1aeedb010233aca86573f Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Tue, 2 Jan 2024 16:11:24 -0500 Subject: [PATCH 88/96] Add example for unrolling iterative blocks (#920) ## Description Demonstrate PiPPy's functionality in unrolling iterative blocks. For details, please see [README](https://github.com/pytorch/PiPPy/tree/unroll_example/examples/unrolling). Many thanks to @mortzur 's inspiration! --- examples/unrolling/README.md | 93 ++++++++++++++++++++++++++++ examples/unrolling/pippy_unroll.py | 98 ++++++++++++++++++++++++++++++ 2 files changed, 191 insertions(+) create mode 100644 examples/unrolling/README.md create mode 100644 examples/unrolling/pippy_unroll.py diff --git a/examples/unrolling/README.md b/examples/unrolling/README.md new file mode 100644 index 000000000..17565f53d --- /dev/null +++ b/examples/unrolling/README.md @@ -0,0 +1,93 @@ +## What does this example do? + +This is a synthetic example used to demonstrate PiPPy's functionality in unrolling iterative blocks in a model. + +We create a model that runs an iteration block in a for loop: +```python +class IterationBlock(torch.nn.Module): + def __init__(self, d_hid): + super().__init__() + self.lin = torch.nn.Linear(d_hid, d_hid) + + def forward(self, x): + x = self.lin(x) + x = torch.relu(x) + return x + + +class IterativeNetwork(torch.nn.Module): + def __init__(self, d_hid, num_iters): + super().__init__() + self.num_iters = num_iters + self.iter_block = IterationBlock(d_hid) + # 10 output classes + self.output_proj = torch.nn.Linear(d_hid, 10) + + def forward(self, x): + for i in range(self.num_iters): + x = self.iter_block(x) + return self.output_proj(x) +``` + +If we annotate the model as follows, we will create a pipeline stage per +iteration block: + +```python +# Add a split point after each iter_block +annotate_split_points( + model, + {"iter_block": PipeSplitWrapper.SplitPoint.END}, +) +``` + +That is, PiPPy would create a split point every time it sees "self.iter_block". + +Run it with 4 ranks: +``` +$ torchrun --nproc-per-node 4 pippy_unroll.py +``` + +Print-out of the pipe: +``` +************************************* pipe ************************************* +GraphModule( + (submod_0): PipeStageModule( + (L__self___iter_block_mod_lin): Linear(in_features=512, out_features=512, bias=True) + ) + (submod_1): PipeStageModule( + (L__self___iter_block_mod_lin): Linear(in_features=512, out_features=512, bias=True) + ) + (submod_2): PipeStageModule( + (L__self___iter_block_mod_lin): Linear(in_features=512, out_features=512, bias=True) + ) + (submod_3): PipeStageModule( + (L__self___output_proj): Linear(in_features=512, out_features=10, bias=True) + ) +) + +def forward(self, arg0): + submod_0 = self.submod_0(arg0); arg0 = None + submod_1 = self.submod_1(submod_0); submod_0 = None + submod_2 = self.submod_2(submod_1); submod_1 = None + submod_3 = self.submod_3(submod_2); submod_2 = None + return [submod_3] +``` +We can see 4 stages as expected (3 iterations plus 1 output projection). + +If we print one of the stages, we can see that it contains the code of one iteration: +``` +*********************************** submod0 ************************************ +PipeStageModule( + (L__self___iter_block_mod_lin): Linear(in_features=512, out_features=512, bias=True) +) + +def forward(self, l_x_): + l__self___iter_block_mod_lin = self.L__self___iter_block_mod_lin(l_x_); l_x_ = None + relu = torch.relu(l__self___iter_block_mod_lin); l__self___iter_block_mod_lin = None + return relu +``` + +## How can this functionality help? +Increase throughput of your model. + +Imagine your for loop needs to iterate on the data for `n` times, and it takes time `t` to process 1 sample (yielding a throughput of `1/t`). If we were to unroll the for loop onto `n` devices, then we can push `n` microbatches into the pipeline, each microbatch containing 1 sample. Then at any timeslot, the pipeline is processing `n` samples, yielding a throughput of `n/t`. diff --git a/examples/unrolling/pippy_unroll.py b/examples/unrolling/pippy_unroll.py new file mode 100644 index 000000000..c4d236a69 --- /dev/null +++ b/examples/unrolling/pippy_unroll.py @@ -0,0 +1,98 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +# Minimal effort to run this code: +# $ torchrun --nproc-per-node 4 pippy_unroll.py + +import os +import torch +import torch.distributed as dist + +from pippy.IR import annotate_split_points, Pipe, PipeSplitWrapper +from pippy.PipelineStage import PipelineStage + + +class IterationBlock(torch.nn.Module): + def __init__(self, d_hid): + super().__init__() + self.lin = torch.nn.Linear(d_hid, d_hid) + + def forward(self, x): + x = self.lin(x) + x = torch.relu(x) + return x + + +class IterativeNetwork(torch.nn.Module): + def __init__(self, d_hid, num_iters): + super().__init__() + self.num_iters = num_iters + self.iter_block = IterationBlock(d_hid) + # 10 output classes + self.output_proj = torch.nn.Linear(d_hid, 10) + + def forward(self, x): + for i in range(self.num_iters): + x = self.iter_block(x) + return self.output_proj(x) + + +# We are using `torchrun` to run this example with multiple processes. +# `torchrun` defines two environment variables: `RANK` and `WORLD_SIZE`. +torch.manual_seed(0) +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) + +# Figure out device to use +if torch.cuda.is_available(): + device = torch.device(f"cuda:{rank % torch.cuda.device_count()}") +else: + device = torch.device("cpu") + +# Create the model +d_hid = 512 +# (n-1) iterations + 1 output projection +num_iters = world_size - 1 +model = IterativeNetwork(d_hid, num_iters).to(device) + +# Add a split point after each iter_block +annotate_split_points( + model, + {"iter_block": PipeSplitWrapper.SplitPoint.END}, +) + +batch_size = 32 +example_input = torch.randn(batch_size, d_hid, device=device) +chunks = world_size + +pipe = Pipe.from_tracing(model, chunks, example_args=(example_input,)) + +if rank == 0: + print(" pipe ".center(80, "*")) + print(pipe) + print(" submod0 ".center(80, "*")) + print(pipe.split_gm.submod_0) + +# Initialize distributed environment +dist.init_process_group(rank=rank, world_size=world_size) + +# Pipeline stage is our main pipeline runtime. It takes in the pipe object, +# the rank of this process, and the device. +stage = PipelineStage(pipe, rank, device) + +# Input data +x = torch.randn(batch_size, d_hid, device=device) + +# Run the pipeline with input `x`. Divide the batch into n micro-batches +# and run them in parallel on the pipeline +if rank == 0: + stage(x) +elif rank == world_size - 1: + output = stage() +else: + stage() + +if rank == world_size - 1: + # Run the original code and get the output for comparison + reference_output = model(x) + # Compare numerics of pipeline and original model + torch.testing.assert_close(output, reference_output) + print(" Pipeline parallel model ran successfully! ".center(80, "*")) From e9e2d5f0164a2e5d952a1424a3926da543365801 Mon Sep 17 00:00:00 2001 From: Mao <62143443+Mao-Siang@users.noreply.github.com> Date: Wed, 3 Jan 2024 05:20:24 +0800 Subject: [PATCH 89/96] Correct spelling errors in README.md for inference examples (#925) ## Description I am new to the PiPPy community and this is my first pr, so feel free to correct me if I am wrong. I have corrected spelling errors and markdown format in the `README.md` file for inference examples. For instance, the `threshold` option for `split_policy`. I hope that I understand it right. ## Type of change Documentation ## Feature/Issue validation/testing None --- examples/inference/README.md | 68 +++++++++++++++++------------------- 1 file changed, 33 insertions(+), 35 deletions(-) diff --git a/examples/inference/README.md b/examples/inference/README.md index f4fbeea38..9a85d1fc6 100644 --- a/examples/inference/README.md +++ b/examples/inference/README.md @@ -1,35 +1,36 @@ -# PiPPy (Pipline Parallelism for PyTorch) Distributed Inference for Large Models +# PiPPy (Pipeline Parallelism for PyTorch) Distributed Inference for Large Models -PiPPy helps to run very large models for inference by splitting the model into mutliple stages running on multiple GPUs. -PiPPy make this easier by providing an auto split API that automates this process for user. +PiPPy helps to run very large models for inference by splitting the model into multiple stages running on multiple GPUs. +PiPPy makes this easier by providing an auto split API that automates this process for user. ## How It Works PiPPy splits your model into multiple stages, each stage loaded on one gpu then the input batch will be further divided into micro-batches and run through the splits from -rank0..rankN. Results are returned to rank0 as rank 0 is running the PipelineDriver. Please read more on pipleines [here](https://github.com/pytorch/tau/blob/main/README.md) +rank0..rankN. Results are returned to rank0 as rank 0 is running the PipelineDriver. Please read more on pipelines [here](https://github.com/pytorch/tau/blob/main/README.md) The flowchart below helps to visualize the process in high level as well. drawing -## PiPPy Supports Arbitary Model Partitioning +## PiPPy Supports Arbitrary Model Partitioning -Unlike most of the available solutions that need to know the model architecture beforehand, PiPPy supports arbitary PyTorch models. -* PiPPy supports both manual splitting and auto split. -* Auto split uses `split_policy` and support both `equal_size` and `threshod` policies, the name are self-explanatory. -* PiPPy use FX to trace and split the model. +Unlike most of the available solutions that need to know the model architecture beforehand, PiPPy supports arbitrary PyTorch models. + +- PiPPy supports both manual splitting and auto split. +- Auto split uses `split_policy` and supports both `equal_size` and `threshold` policies, the name is self-explanatory. +- PiPPy uses FX to trace and split the model. ## Settings To Care About -* **world_size** specifies your availble number of gpus for partitioning your model +- **world_size** specifies your available number of gpus for partitioning your model -* **split_policy** it can be either `equal_size`, `split_into_equal_size(number_of_workers)` or `threshod`, `split_on_size_threshold(#some number)` +- **split_policy** it can be either `equal_size`, `split_into_equal_size(number_of_workers)` or `threshold`, `split_on_size_threshold(#some number)` -* **schedule** for the pipline, we use `PipelineDriverFillDrain` for inference, please learn more about it [here](https://github.com/pytorch/tau/blob/main/README.md#advanced-pipeline-schedules). +- **schedule** for the pipeline, we use `PipelineDriverFillDrain` for inference, please learn more about it [here](https://github.com/pytorch/tau/blob/main/README.md#advanced-pipeline-schedules). -* **chunks** it detemines the size of microbatches, microbatch = batch size/ chuncks +- **chunks** it detemines the size of microbatches, microbatch = batch size / chunks -* **FX Tracers** use PiPPyHFTracer() is dealing with a HuggingFace model otherwise set to `None` +- **FX Tracers** use PiPPyHFTracer() is dealing with a HuggingFace model otherwise set to `None` ## Get the Pipeline Driver @@ -46,27 +47,24 @@ pipe_driver, stage_mode = pippy.all_compile( concrete_args=concrete_args, ) ``` -**Note** As PiPPY leverage FX tracing for partitioning, as a result for HuggingFace models that have `generate` method will need to call `inject_pipeline_forward(model, pipe_driver)` to make `model.generate` available. This works for the decoder only models so far, encoder-decoder models such as `T5` is in progress. - -**Main difference between Pippy for training and inference is we dont need to call the init_data_parallel API in the inference. The reason is DDP init is only needed if we need backward pass which is not the case for inference.** +**Note** As PiPPY leverage FX tracing for partitioning, as a result for HuggingFace models that have `generate` method will need to call `inject_pipeline_forward(model, pipe_driver)` to make `model.generate` available. This works for the decoder only models so far, encoder-decoder models such as `T5` is in progress. +**Main difference between Pippy for training and inference is we don't need to call the init_data_parallel API in the inference. The reason is DDP init is only needed if we need backward pass which is not the case for inference.** ## HuggingFace Example **Define a function such as run_all() and add the followings to it.** -We use a HuggingFace `OPT` model as the running example here. The `hf_generate.py` also support other models for text generation such as `Bloom`, `gpt2` and `codegen` family of the models as well. Make sure to specifiy the model name as follows ` python hf_generate.py --model_name "facebook/opt-2.7b" `. This is not limited to LLMs it also works for models such [RegNet 10B](https://huggingface.co/facebook/regnet-y-10b-seer). +We use a HuggingFace `OPT` model as the running example here. The `hf_generate.py` also supports other models for text generation such as `Bloom`, `gpt2` and `codegen` family of the models as well. Make sure to specify the model name as follows `python hf_generate.py --model_name "facebook/opt-2.7b"`. This is not limited to LLMs it also works for models such [RegNet 10B](https://huggingface.co/facebook/regnet-y-10b-seer). - -* Load your model normally on CPU +- Load your model normally on CPU example: -` model = AutoModelForCausalLM.from_pretrained('facebook/opt-6.7b', use_cache=False) ` - +`model = AutoModelForCausalLM.from_pretrained('facebook/opt-6.7b', use_cache=False)` -* Setup the model split policy +- Setup the model split policy ```python from pippy import split_on_size_threshold, split_into_equal_size @@ -76,7 +74,8 @@ if args.auto_split == "threshold": elif args.auto_split == "equal_size": split_policy = split_into_equal_size(number_of_workers) ``` -* Make the concerete args (optional), If the model has inside an if-else condition, the concrete args can help FX determine which path to trace. For now control flow is not supported in FX tracing, we are working on integrating Torch Dynamo to make this more flexible. + +- Make the concrete args (optional), If the model has inside an if-else condition, the concrete args can help FX determine which path to trace. For now control flow is not supported in FX tracing, we are working on integrating Torch Dynamo to make this more flexible. ```python inputs = tokenizer(prompt, return_tensors="pt") @@ -85,20 +84,18 @@ sig = inspect.signature(model.forward) concrete_args = {p.name: p.default for p in sig.parameters.values() if p.name not in input_names} ``` -* Get the pipline driver and model stages with `pippy.all_compile()`. See the section above. +- Get the pipeline driver and model stages with `pippy.all_compile()`. See the section above. -This under the hood, splits the model into a pipline, `Pipe.from_tracing` uses `torch.fx` symbolic tracing to turn our model into a directed acyclic graph (DAG) representation. Then, it groups together the operations and parameters into _pipeline stages_. Stages are represented as `submod_N` submodules, where `N` is a natural number. Note: here we use HF FX_tracer for tracing. +This under the hood, splits the model into a pipeline, `Pipe.from_tracing` uses `torch.fx` symbolic tracing to turn our model into a directed acyclic graph (DAG) representation. Then, it groups together the operations and parameters into _pipeline stages_. Stages are represented as `submod_N` submodules, where `N` is a natural number. Note: here we use HF FX_tracer for tracing. Loads to device directly using `defer_stage_init`, which basically let each rank trace the model and split the model and only materialize its own shard. -Finally, we get a `PipelineDriver` that runs the pipeline. It implements the runtime scheduling and communcation between stages. - +Finally, we get a `PipelineDriver` that runs the pipeline. It implements the runtime scheduling and communication between stages. -* Run the inference by passing input data to the `PipelineDriver`. +- Run the inference by passing input data to the `PipelineDriver`. `pipe_driver(**inputs)` - **we Now pass the run_all() function to the run_pippy() along with args to run the program** ```python @@ -114,26 +111,27 @@ if __name__ == "__main__": To run the full example, simply run your Python inference script: -` python hf_generate.py --model_name 'facebook/opt-6.7b'' ` +`python hf_generate.py --model_name 'facebook/opt-6.7b''` or -` torchrun --nproc_per_node=8 hf_generate.py --model_name 'facebook/opt-6.7b' ` +`torchrun --nproc_per_node=8 hf_generate.py --model_name 'facebook/opt-6.7b'` ### Run OPT model example This has been tested for [OPT 2.7 and 30B](https://huggingface.co/facebook/opt-30b) on 8 V100 GPUs. -` python hf_generate.py --model_name 'facebook/opt-30b' ` +`python hf_generate.py --model_name 'facebook/opt-30b'` ### Run Bloom model example This has been tested for [Bloom 3b](https://huggingface.co/docs/transformers/model_doc/bloom) on 8 V100 GPUs. -` python hf_generate.py --model_name 'bigscience/bloom-3b' ` +`python hf_generate.py --model_name 'bigscience/bloom-3b'` ### More models to try -- "gpt2" + +- "gpt2" - "bigscience/bloom-3b" - EleutherAI/gpt-neo-2.7B - Salesforce/codegen-2B-multi From 50250636c13fbea663a9c17454160ecae29a3f2c Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Fri, 19 Jan 2024 16:12:35 -0500 Subject: [PATCH 90/96] Re-support kwargs at run time (#929) ## Description Implements #928 Users want the first pipeline stage to accept kwargs if the original program does. This is controlled by the `_codegen` field of the graph as @angelayi suggests, so we make a copy from the traced program to submod0. ## Feature/Issue validation/testing Added kwargs in test_fwd.py. Also changed a few HF examples to directly kwargs. --- examples/hf/pippy_albert.py | 6 ++-- examples/hf/pippy_bart.py | 6 ++-- examples/hf/pippy_bert.py | 6 ++-- examples/hf/pippy_camemBert.py | 6 ++-- examples/hf/pippy_gpt2.py | 6 ++-- examples/hf/pippy_gptNeo.py | 6 ++-- examples/hf/pippy_opt.py | 6 ++-- pippy/IR.py | 32 +++++++++++++++++-- pippy/PipelineStage.py | 56 +++++++++++++++++++++------------- test/test_fwd.py | 9 +++--- 10 files changed, 91 insertions(+), 48 deletions(-) diff --git a/examples/hf/pippy_albert.py b/examples/hf/pippy_albert.py index 816ece388..15da03319 100644 --- a/examples/hf/pippy_albert.py +++ b/examples/hf/pippy_albert.py @@ -46,7 +46,6 @@ def run(args): # Input configs example_inputs = generate_inputs_for_model( model_class, albert, model_name, args.batch_size, args.device) - input_ids = example_inputs["input_ids"] # Annotate split points add_split_points(albert, args.world_size) @@ -55,7 +54,8 @@ def run(args): albert_pipe = Pipe.from_tracing( albert, num_chunks=args.chunks, - example_args=(input_ids, ), + example_args=(), + example_kwargs=example_inputs, ) nstages = len(list(albert_pipe.split_gm.children())) assert nstages == args.world_size, f"nstages = {nstages} nranks = {args.world_size}" @@ -72,7 +72,7 @@ def run(args): # Run if args.rank == 0: - stage(input_ids) + stage(**example_inputs) elif args.rank == args.world_size - 1: out = stage() else: diff --git a/examples/hf/pippy_bart.py b/examples/hf/pippy_bart.py index 0a84b97f0..6015334ec 100644 --- a/examples/hf/pippy_bart.py +++ b/examples/hf/pippy_bart.py @@ -43,7 +43,6 @@ def run(args): # Input configs example_inputs = generate_inputs_for_model( model_class, bart, model_name, args.batch_size, args.device) - input_ids = example_inputs["input_ids"] # Annotate split points add_split_points(bart, args.world_size) @@ -52,7 +51,8 @@ def run(args): bart_pipe = Pipe.from_tracing( bart, num_chunks=args.chunks, - example_args=(input_ids, ), + example_args=(), + example_kwargs=example_inputs, ) nstages = len(list(bart_pipe.split_gm.children())) assert nstages == args.world_size, f"nstages = {nstages} nranks = {args.world_size}" @@ -69,7 +69,7 @@ def run(args): # Run if args.rank == 0: - stage(input_ids) + stage(**example_inputs) elif args.rank == args.world_size - 1: out = stage() else: diff --git a/examples/hf/pippy_bert.py b/examples/hf/pippy_bert.py index 7d67d6077..e104f1967 100644 --- a/examples/hf/pippy_bert.py +++ b/examples/hf/pippy_bert.py @@ -43,7 +43,6 @@ def run(args): # Input configs example_inputs = generate_inputs_for_model( model_class, bert, model_name, args.batch_size, args.device) - input_ids = example_inputs["input_ids"] # Annotate split points add_split_points(bert, args.world_size) @@ -52,7 +51,8 @@ def run(args): bert_pipe = Pipe.from_tracing( bert, num_chunks=args.chunks, - example_args=(input_ids, ), + example_args=(), + example_kwargs=example_inputs, ) nstages = len(list(bert_pipe.split_gm.children())) assert nstages == args.world_size, f"nstages = {nstages} nranks = {args.world_size}" @@ -69,7 +69,7 @@ def run(args): # Run if args.rank == 0: - stage(input_ids) + stage(**example_inputs) elif args.rank == args.world_size - 1: out = stage() else: diff --git a/examples/hf/pippy_camemBert.py b/examples/hf/pippy_camemBert.py index 019d64faa..b56caaf38 100644 --- a/examples/hf/pippy_camemBert.py +++ b/examples/hf/pippy_camemBert.py @@ -43,7 +43,6 @@ def run(args): # Input configs example_inputs = generate_inputs_for_model( model_class, camembert, model_name, args.batch_size, args.device) - input_ids = example_inputs["input_ids"] # Annotate split points add_split_points(camembert, args.world_size) @@ -52,7 +51,8 @@ def run(args): camembert_pipe = Pipe.from_tracing( camembert, num_chunks=args.chunks, - example_args=(input_ids, ), + example_args=(), + example_kwargs=example_inputs, ) nstages = len(list(camembert_pipe.split_gm.children())) assert nstages == args.world_size, f"nstages = {nstages} nranks = {args.world_size}" @@ -69,7 +69,7 @@ def run(args): # Run if args.rank == 0: - stage(input_ids) + stage(**example_inputs) elif args.rank == args.world_size - 1: out = stage() else: diff --git a/examples/hf/pippy_gpt2.py b/examples/hf/pippy_gpt2.py index 060212277..2dae0d536 100644 --- a/examples/hf/pippy_gpt2.py +++ b/examples/hf/pippy_gpt2.py @@ -52,7 +52,6 @@ def run(args): # Input configs example_inputs = generate_inputs_for_model( model_class, gpt2, model_name, args.batch_size, args.device) - input_ids = example_inputs["input_ids"] # Annotate split points add_split_points(gpt2, args.world_size) @@ -61,7 +60,8 @@ def run(args): gpt2_pipe = Pipe.from_tracing( gpt2, num_chunks=args.chunks, - example_args=(input_ids, ), + example_args=(), + example_kwargs=example_inputs, ) assert len(list(gpt2_pipe.split_gm.children())) == args.world_size if args.rank == 0: @@ -77,7 +77,7 @@ def run(args): # Run if args.rank == 0: - stage(input_ids) + stage(**example_inputs) elif args.rank == args.world_size - 1: out = stage() else: diff --git a/examples/hf/pippy_gptNeo.py b/examples/hf/pippy_gptNeo.py index 795abff5a..983e2c7fa 100644 --- a/examples/hf/pippy_gptNeo.py +++ b/examples/hf/pippy_gptNeo.py @@ -43,7 +43,6 @@ def run(args): # Input configs example_inputs = generate_inputs_for_model( model_class, gptneo, model_name, args.batch_size, args.device) - input_ids = example_inputs["input_ids"] # Annotate split points add_split_points(gptneo, args.world_size) @@ -52,7 +51,8 @@ def run(args): gptneo_pipe = Pipe.from_tracing( gptneo, num_chunks=args.chunks, - example_args=(input_ids, ), + example_args=(), + example_kwargs=example_inputs, ) nstages = len(list(gptneo_pipe.split_gm.children())) assert nstages == args.world_size, f"nstages = {nstages} nranks = {args.world_size}" @@ -69,7 +69,7 @@ def run(args): # Run if args.rank == 0: - stage(input_ids) + stage(**example_inputs) elif args.rank == args.world_size - 1: out = stage() else: diff --git a/examples/hf/pippy_opt.py b/examples/hf/pippy_opt.py index a301e2ed3..4bbf2cb1a 100644 --- a/examples/hf/pippy_opt.py +++ b/examples/hf/pippy_opt.py @@ -43,7 +43,6 @@ def run(args): # Input configs example_inputs = generate_inputs_for_model( model_class, opt, model_name, args.batch_size, args.device) - input_ids = example_inputs["input_ids"] # Annotate split points add_split_points(opt, args.world_size) @@ -52,7 +51,8 @@ def run(args): opt_pipe = Pipe.from_tracing( opt, num_chunks=args.chunks, - example_args=(input_ids, ), + example_args=(), + example_kwargs=example_inputs, ) nstages = len(list(opt_pipe.split_gm.children())) assert nstages == args.world_size, f"nstages = {nstages} nranks = {args.world_size}" @@ -69,7 +69,7 @@ def run(args): # Run if args.rank == 0: - stage(input_ids) + stage(**example_inputs) elif args.rank == args.world_size - 1: out = stage() else: diff --git a/pippy/IR.py b/pippy/IR.py index 5a4a4757c..48f32fcac 100644 --- a/pippy/IR.py +++ b/pippy/IR.py @@ -3,6 +3,7 @@ import logging import operator from enum import Enum +from inspect import Parameter, signature, Signature from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch @@ -655,8 +656,6 @@ def throw(self, *args, **kwargs): def forward(self, *args, **kwargs): executor_args = args if len(kwargs) > 0: - from inspect import Parameter, Signature - parameters = [] for node in self.split_gm.graph.nodes: if node.op == "placeholder": @@ -1005,6 +1004,34 @@ def move_param_to_callee( split.delete_all_unused_submodules() + # Users want the first pipeline stage to accept kwargs if the original + # program does. This is controlled by the `_codegen` field of the graph, + # so we make a copy here. Note: we only want the input spec and not the + # output spec, because the output spec is for the last stage. Maybe a + # TODO? Not sure yet. + submod0 = list(split.children())[0] + model_sign = signature(traced.forward) + model_num_args = len(model_sign.parameters) + submod0_sign = signature(submod0.forward) + submod0_num_args = len(submod0_sign.parameters) + if model_num_args != submod0_num_args: + # We don't change the signature of the first stage if it takes + # different number of args than original model + logger.info( + f"Original model takes {model_num_args} args but the first pipeline stage takes {submod0_num_args}. " + "Please provide args to respective pipeline stages." + ) + else: + # Support kwargs for the first stage + submod0.graph._codegen = copy.deepcopy(traced.graph._codegen) + # `_replace` is actually not "private" or internal. based on this doc: + # To prevent conflicts with field names, the method and attribute names + # start with an underscore + submod0.graph._codegen.pytree_info = ( + submod0.graph._codegen.pytree_info._replace(out_spec=None) + ) + submod0.recompile() + split.graph.lint() split.recompile() @@ -1071,6 +1098,7 @@ def _trace_with_export( example_kwargs, constraints, ) + logger.debug(f"Traced model: {traced}") if split_policy is not None: traced = split_policy(traced) finally: diff --git a/pippy/PipelineStage.py b/pippy/PipelineStage.py index fe39a15d6..9367f6cfb 100644 --- a/pippy/PipelineStage.py +++ b/pippy/PipelineStage.py @@ -7,6 +7,7 @@ import torch.distributed as dist import torch.fx as fx from torch._subclasses.fake_tensor import FakeTensor +from torch.fx.node import map_aggregate, map_arg from torch.nn.parallel import DistributedDataParallel from pippy.backward import stage_backward @@ -47,6 +48,10 @@ class StageArgPlaceholder: pass +class StageKwargPlaceholder: + pass + + class PipelineStage(torch.nn.Module): def __init__( self, @@ -269,14 +274,15 @@ def create_recv_tensor( # `args` is a Tuple, hence we will have: # Tuple[RecvInfo] - args_recv_info = fx.node.map_arg(self.node.args, create_recv_tensor) + args_recv_info = map_arg(self.node.args, create_recv_tensor) # `kwargs` is a Dict, hence we will have: # Dict[keyword, RecvInfo] - kwargs_recv_info = fx.node.map_arg(self.node.kwargs, create_recv_tensor) + kwargs_recv_info = map_arg(self.node.kwargs, create_recv_tensor) logger.info( - f"[{self.group_rank}] " f"Activation recv info: {args_recv_info}" + f"[{self.group_rank}] " + f"Activation recv / args info: {args_recv_info}" ) return args_recv_info, kwargs_recv_info @@ -370,9 +376,9 @@ def map_recv_to_send(a): grad_send_info.append(None) return None - fx.node.map_aggregate(args_recv_info, map_recv_to_send) + map_aggregate(args_recv_info, map_recv_to_send) - fx.node.map_aggregate(kwargs_recv_info, map_recv_to_send) + map_aggregate(kwargs_recv_info, map_recv_to_send) logger.info(f"[{self.group_rank}] " f"Grad send info: {grad_send_info}") return grad_send_info @@ -422,35 +428,43 @@ def _recv_and_fill_inputs( act_recv = self.recv_tensor_fn(recv_reqs) + chunk_args_list: List = [] if self.args_split: chunk_args = self.args_split[chunk] chunk_args_list = list(chunk_args) def recv_args(info): if isinstance(info, RecvInfo): + # This is an activation to receive return act_recv(info) else: - return chunk_args_list.pop(0) # type: ignore[has-type] + # This is a pass-in argument + if len(chunk_args_list): + return chunk_args_list.pop(0) # type: ignore[has-type] + else: + # kwargs were treated as args in graph phase. That's why + # there are extra placeholders here. We mark them and filter + # them out later. + return StageKwargPlaceholder() - composite_args = fx.node.map_aggregate( + composite_args = map_aggregate( self.args_recv_info[chunk], recv_args, ) + # Filter out kwarg placeholders + composite_args = tuple( + x + for x in composite_args + if not isinstance(x, StageKwargPlaceholder) + ) + # Middle stages won't have incoming activations in kwargs form. So if + # kwargs_split is not empty, it must be model inputs for stage 0. We + # hence pass it as is to the interal submodule, without performing + # `recv_args` on it. + composite_kwargs: Dict = {} if self.kwargs_split: - chunk_kwargs = self.kwargs_split[chunk] - - def recv_kwargs(info): - if isinstance(info, RecvInfo): - return act_recv(info) - else: - k = next(iter(chunk_kwargs)) # type: ignore[has-type] - return chunk_kwargs.pop(k) # type: ignore[has-type] - - composite_kwargs = fx.node.map_aggregate( - self.kwargs_recv_info[chunk], - recv_kwargs, - ) + composite_kwargs = self.kwargs_split[chunk] # Wait for all recvs to finish for work in recv_reqs: @@ -496,7 +510,7 @@ def _recv_grads( recv_grad = self.recv_tensor_fn(grad_recv_reqs) # Receive gradients - grads = fx.node.map_aggregate( + grads = map_aggregate( self.grad_recv_info[bwd_chunk], recv_grad, ) diff --git a/test/test_fwd.py b/test/test_fwd.py index 675844170..9feedfd4c 100644 --- a/test/test_fwd.py +++ b/test/test_fwd.py @@ -26,7 +26,7 @@ def __init__(self): self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) self.lin = torch.nn.Linear(d_hid, d_hid) - def forward(self, x, y): + def forward(self, x, y=torch.zeros(batch_size, d_hid)): x = torch.mm(x, self.mm_param) skip_connection = x x = x + y @@ -54,7 +54,8 @@ def run_worker(args): pipe = Pipe.from_tracing( mod, args.chunks, - example_args=(x, y), + example_args=(x,), + example_kwargs={"y": y}, ) stage = PipelineStage( @@ -65,7 +66,7 @@ def run_worker(args): # Run if args.rank == 0: - stage(x, y) + stage(x, y=y) elif args.rank == args.world_size - 1: out = stage() else: @@ -76,7 +77,7 @@ def run_worker(args): # Last rank checks result if args.rank == args.world_size - 1: - ref_out = mod(x, y) + ref_out = mod(x, y=y) torch.testing.assert_close(out, ref_out) print( f"equivalence test passed {torch.sum(out)} ref {torch.sum(ref_out)}" From 3d128641c59c8601d99485da275e8ea57838d74f Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Mon, 22 Jan 2024 10:54:11 -0500 Subject: [PATCH 91/96] Remove hf utils used by old FX tracer and pipeline driver (#930) --- pippy/hf/__init__.py | 18 --- pippy/hf/bart.py | 47 ------- pippy/hf/bert.py | 25 ---- pippy/hf/gpt2.py | 25 ---- pippy/hf/roberta.py | 25 ---- pippy/hf/t5.py | 45 ------ pippy/hf/utils.py | 326 ------------------------------------------- 7 files changed, 511 deletions(-) delete mode 100644 pippy/hf/__init__.py delete mode 100644 pippy/hf/bart.py delete mode 100644 pippy/hf/bert.py delete mode 100644 pippy/hf/gpt2.py delete mode 100644 pippy/hf/roberta.py delete mode 100644 pippy/hf/t5.py delete mode 100644 pippy/hf/utils.py diff --git a/pippy/hf/__init__.py b/pippy/hf/__init__.py deleted file mode 100644 index 9c93e2faf..000000000 --- a/pippy/hf/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from pippy.hf.utils import ( - inject_pipeline_forward, - PiPPyHFTracer, - PiPPySeq2SeqTrainer, - PiPPySeq2SeqTrainingArguments, - PiPPyTrainer, - PiPPyTrainingArguments, -) - -__all__ = [ - "PiPPyHFTracer", - "PiPPyTrainingArguments", - "PiPPySeq2SeqTrainingArguments", - "PiPPyTrainer", - "PiPPySeq2SeqTrainer", - "inject_pipeline_forward", -] diff --git a/pippy/hf/bart.py b/pippy/hf/bart.py deleted file mode 100644 index db801d245..000000000 --- a/pippy/hf/bart.py +++ /dev/null @@ -1,47 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from pippy import annotate_split_points, PipeSplitWrapper - - -def add_split_points(bart, xxcoders_per_rank): - assert bart.config.encoder_layers == bart.config.decoder_layers - assert ( - bart.config.encoder_layers + bart.config.decoder_layers - ) % xxcoders_per_rank == 0 - encoders_per_rank = xxcoders_per_rank - for i in range( - 0, - (bart.config.encoder_layers + encoders_per_rank - 1) - // encoders_per_rank, - ): - annotate_split_points( - bart, - { - f"model.encoder.layers.{i * encoders_per_rank}": PipeSplitWrapper.SplitPoint.BEGINNING - }, - ) - decoders_per_rank = xxcoders_per_rank - for i in range( - 0, - (bart.config.decoder_layers + decoders_per_rank - 1) - // decoders_per_rank, - ): - annotate_split_points( - bart, - { - f"model.decoder.layers.{i * decoders_per_rank}": PipeSplitWrapper.SplitPoint.BEGINNING - }, - ) - annotate_split_points( - bart, {"lm_head": PipeSplitWrapper.SplitPoint.BEGINNING} - ) - - -def split(model, num_ranks): - emb_head = 2 # encoder embeddings + decoder embeddings - num_of_ranks_for_xxcoders = num_ranks - emb_head - xxcoders = model.config.encoder_layers + model.config.decoder_layers - xxcoders_per_rank = ( - xxcoders + num_of_ranks_for_xxcoders - 1 - ) // num_of_ranks_for_xxcoders - # print(f"xxcoders_per_rank = {xxcoders_per_rank}") - add_split_points(model, xxcoders_per_rank) diff --git a/pippy/hf/bert.py b/pippy/hf/bert.py deleted file mode 100644 index 537275422..000000000 --- a/pippy/hf/bert.py +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from pippy import annotate_split_points, PipeSplitWrapper - - -def add_split_points(bert, encoders_per_rank): - for i in range(0, bert.config.num_hidden_layers // encoders_per_rank): - annotate_split_points( - bert, - { - f"bert.encoder.layer.{i * encoders_per_rank}": PipeSplitWrapper.SplitPoint.BEGINNING - }, - ) - annotate_split_points( - bert, {"classifier": PipeSplitWrapper.SplitPoint.BEGINNING} - ) - - -def split(model, num_ranks): - emb_head = 2 # embeddings + head - num_of_ranks_for_encoders = num_ranks - emb_head - encoders_per_rank = ( - model.config.num_hidden_layers + num_of_ranks_for_encoders - 1 - ) // num_of_ranks_for_encoders # a divider of bert.config.num_hidden_layers: [1, 2, 3, 4, 6, 12] - # print(f"encoders_per_rank = {encoders_per_rank}") - add_split_points(model, encoders_per_rank) diff --git a/pippy/hf/gpt2.py b/pippy/hf/gpt2.py deleted file mode 100644 index 1e5dab8ad..000000000 --- a/pippy/hf/gpt2.py +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from pippy import annotate_split_points, PipeSplitWrapper - - -def add_split_points(gpt2, decoders_per_rank): - for i in range(0, gpt2.config.n_layer // decoders_per_rank): - annotate_split_points( - gpt2, - { - f"transformer.h.{i * decoders_per_rank}": PipeSplitWrapper.SplitPoint.BEGINNING - }, - ) - annotate_split_points( - gpt2, {"transformer.ln_f": PipeSplitWrapper.SplitPoint.BEGINNING} - ) - - -def split(model, num_ranks): - emb_head = 2 # embeddings + head - num_of_ranks_for_decoders = num_ranks - emb_head - decoders_per_rank = ( - model.config.n_layer + num_of_ranks_for_decoders - 1 - ) // num_of_ranks_for_decoders # a divider of model.config.n_layer: [1, 2, 3, 4, 6, 12] - # print(f"encoders_per_rank = {decoders_per_rank}") - add_split_points(model, decoders_per_rank) diff --git a/pippy/hf/roberta.py b/pippy/hf/roberta.py deleted file mode 100644 index 6458af536..000000000 --- a/pippy/hf/roberta.py +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from pippy import annotate_split_points, PipeSplitWrapper - - -def add_split_points(roberta, encoders_per_rank): - for i in range(0, roberta.config.num_hidden_layers // encoders_per_rank): - annotate_split_points( - roberta, - { - f"roberta.encoder.layer.{i}": PipeSplitWrapper.SplitPoint.BEGINNING - }, - ) - annotate_split_points( - roberta, {"lm_head": PipeSplitWrapper.SplitPoint.BEGINNING} - ) - - -def split(model, num_ranks): - emb_head = 2 # embeddings + head - num_of_ranks_for_encoders = num_ranks - emb_head - encoders_per_rank = ( - model.config.num_hidden_layers + num_of_ranks_for_encoders - 1 - ) // num_of_ranks_for_encoders # a divider of roberta.config.num_hidden_layers: [1, 2, 3, 4, 6, 12] - # print(f"encoders_per_rank = {encoders_per_rank}") - add_split_points(model, encoders_per_rank) diff --git a/pippy/hf/t5.py b/pippy/hf/t5.py deleted file mode 100644 index 33da268d6..000000000 --- a/pippy/hf/t5.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from pippy import annotate_split_points, PipeSplitWrapper - - -def add_split_points(t5, xxcoders_per_rank): - assert t5.config.num_layers == t5.config.num_decoder_layers - assert ( - t5.config.num_layers + t5.config.num_decoder_layers - ) % xxcoders_per_rank == 0 - encoders_per_rank = xxcoders_per_rank - for i in range( - 0, (t5.config.num_layers + encoders_per_rank - 1) // encoders_per_rank - ): - annotate_split_points( - t5, - { - f"encoder.block.{i * encoders_per_rank}": PipeSplitWrapper.SplitPoint.BEGINNING - }, - ) - decoders_per_rank = xxcoders_per_rank - for i in range( - 0, - (t5.config.num_decoder_layers + decoders_per_rank - 1) - // decoders_per_rank, - ): - annotate_split_points( - t5, - { - f"decoder.block.{i * decoders_per_rank}": PipeSplitWrapper.SplitPoint.BEGINNING - }, - ) - annotate_split_points( - t5, {"lm_head": PipeSplitWrapper.SplitPoint.BEGINNING} - ) - - -def split(model, num_ranks): - emb_head = 2 # encoder embeddings + decoder embeddings - num_of_ranks_for_xxcoders = num_ranks - emb_head - xxcoders = model.config.num_layers + model.config.num_decoder_layers - xxcoders_per_rank = ( - xxcoders + num_of_ranks_for_xxcoders - 1 - ) // num_of_ranks_for_xxcoders - # print(f"xxcoders_per_rank = {xxcoders_per_rank}") - add_split_points(model, xxcoders_per_rank) diff --git a/pippy/hf/utils.py b/pippy/hf/utils.py deleted file mode 100644 index 7a8053f78..000000000 --- a/pippy/hf/utils.py +++ /dev/null @@ -1,326 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import contextlib -import logging -import os -import types -from dataclasses import dataclass, field -from typing import Optional - -import torch -import torch.distributed -import transformers -import transformers.utils.fx as fx -from transformers import ( - Seq2SeqTrainer, - Seq2SeqTrainingArguments, - Trainer, - TrainingArguments, -) -from transformers.modeling_utils import ModuleUtilsMixin -from transformers.utils import cached_property, is_torch_available - -from pippy.PipelineDriver import PipelineDriverBase - - -logger = logging.getLogger(__name__) - - -@dataclass -class PiPPyTrainingArguments(TrainingArguments): - dp_group_size: int = field(default=-1, metadata={"help": "DP group size."}) - - pp_group_size: int = field( - default=-1, metadata={"help": "Pipeline group size."} - ) - - rank: int = field( - default=int(os.getenv("RANK", -1)), metadata={"help": "Rank."} - ) - - driver_index: int = field( - default=-1, - metadata={ - "help": "Index of current pipeline driver in all pipeline drivers." - }, - ) - - local_driver_index: int = field( - default=-1, - metadata={ - "help": "Index of current pipeline driver in local pipeline drivers." - }, - ) - - master_addr: str = field( - default=os.getenv("MASTER_ADDR", "localhost"), - metadata={"help": "Master address."}, - ) - - master_port: str = field( - default=os.getenv("MASTER_PORT", "29500"), - metadata={"help": "Master port."}, - ) - - exclude_master: int = field( - default=0, - metadata={"help": "Exclude master.", "choices": [0, 1]}, - ) - - # TODO: use `no_cuda` instead? - cuda: int = field( - default=int(torch.cuda.is_available()), - metadata={"help": "Exclude master.", "choices": [0, 1]}, - ) - - chunks: Optional[int] = field( - default=None, metadata={"help": "Number of Chunks."} - ) - - record_mem_dumps: int = field( - default=0, metadata={"help": "Record memory dumps flag."} - ) - - checkpoint: int = field(default=1, metadata={"help": "Checkpoint flag."}) - - _device: Optional[torch.device] = None - - @property - def device(self): - if self.rank == -1: - if self.cuda and torch.cuda.is_available(): - return torch.device("cuda") - else: - return torch.device("cpu") - else: - return super().device - - @device.setter - def device(self, value): - self._device = value - - # Process Group including all drivers - _driver_group = None - - @property - def driver_group(self): - return self._driver_group - - @driver_group.setter - def driver_group(self, value): - self._driver_group = value - - @cached_property - def _setup_devices(self) -> "torch.device": - if self.cuda: - n_devs = torch.cuda.device_count() - if n_devs > 0: - dev_id = self.rank % n_devs - self._device = torch.device(f"cuda:{dev_id}") - else: - self.cuda = 0 - self._device = torch.device("cpu") - else: - self._device = torch.device("cpu") - return self._device - - # Overriding property `world_size` in TrainingArguments - # Here it means number of pipelines - @property - def world_size(self): - return self.dp_group_size - - # Overriding property `process_index` in TrainingArguments - # Here it means the index of current pipeline driver in all pipeline drivers - @property - def process_index(self): - return self.driver_index - - # Overriding property `local_process_index` in TrainingArguments - # Here it means the index of current pipeline driver in local pipeline drivers - @property - def local_process_index(self): - return self.local_driver_index - - def __post_init__(self): - super().__post_init__() - self.local_rank = ( - -1 - ) # must be -1 to disable automatic DDP in the HF trainer - - @contextlib.contextmanager - def main_process_first(self, local=True, desc="work"): - if is_torch_available() and self.world_size > 1: - main_process_desc = "main process" - if local: - is_main_process = self.local_process_index == 0 - main_process_desc = "main local process" - # elif is_sagemaker_mp_enabled(): - # is_main_process = smp.rank() == 0 - else: - is_main_process = self.process_index == 0 - - try: - if not is_main_process: - # tell all replicas to wait - logger.debug( - f"{self.process_index}: waiting for the {main_process_desc} to perform {desc}" - ) - # if is_torch_tpu_available(): - # xm.rendezvous(desc) - # elif is_sagemaker_dp_enabled(): - # dist.barrier() - # else: - torch.distributed.barrier(group=self.driver_group) - yield - finally: - if is_main_process: - # the wait is over - # logger.debug(f"{self.process_index}: {main_process_desc} completed {desc}, releasing all replicas") - # if is_torch_tpu_available(): - # pass # TODO xm.rendezvous(desc) - # elif is_sagemaker_dp_enabled(): - # pass # TODO dist.barrier() - # else: - torch.distributed.barrier(group=self.driver_group) - else: - yield - - -@dataclass -class PiPPySeq2SeqTrainingArguments( - PiPPyTrainingArguments, Seq2SeqTrainingArguments -): - pass - - -def _backward( - self, gradient=None, retain_graph=None, create_graph=False, inputs=None -): - # No-op backward for pipe mode, because otherwise HF Trainer will call loss.backward second time and will crash - pass - - -class PiPPyTrainer(Trainer): - def create_optimizer(self): - optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs( - self.args - ) - self.optimizer = self.model.instantiate_optimizer( # type: ignore[operator] - optimizer_cls, **optimizer_kwargs - ) - return self.optimizer - - def create_scheduler( - self, num_training_steps: int, optimizer: torch.optim.Optimizer = None - ): - self.lr_scheduler = self.model.instantiate_lr_scheduler( # type: ignore[operator] - transformers.optimization.TYPE_TO_SCHEDULER_FUNCTION[ - self.args.lr_scheduler_type - ], - num_warmup_steps=self.args.get_warmup_steps(self.args.max_steps), - num_training_steps=self.args.max_steps, - ) - return self.lr_scheduler - - def compute_loss(self, model, inputs, return_outputs=False): - if return_outputs: - loss, outputs = Trainer.compute_loss( - self, model, inputs, return_outputs - ) - loss.backward = types.MethodType(_backward, loss) - return loss, outputs - else: - loss = Trainer.compute_loss(self, model, inputs, return_outputs) - loss.backward = types.MethodType(_backward, loss) - return loss - - -class PiPPySeq2SeqTrainer(PiPPyTrainer, Seq2SeqTrainer): - pass - - -def torch_ones_wrapper(*args, **kwargs): - return torch.ones(*args, **kwargs) - - -def torch_arange_wrapper(*args, **kwargs): - return torch.arange(*args, **kwargs) - - -def torch_full_like_wrapper(*args, **kwargs): - return torch.full_like(*args, **kwargs) - - -def torch_create_extended_attention_mask_for_decoder_wrapper(*args, **kwargs): - return ModuleUtilsMixin.create_extended_attention_mask_for_decoder( - *args, **kwargs - ) - - -def torch_zeros_wrapper(*args, **kwargs): - return torch.zeros(*args, **kwargs) - - -class PiPPyHFTracer(fx.HFTracer): - def trace(self, *args, **kwargs): - graph = super().trace(*args, **kwargs) - for node in graph.nodes: - if node.op == "call_function": - if getattr(node.target, "_orig", None) == torch.ones: - node.target = torch_ones_wrapper - elif getattr(node.target, "_orig", None) == torch.arange: - node.target = torch_arange_wrapper - elif getattr(node.target, "_orig", None) == torch.full_like: - node.target = torch_full_like_wrapper - elif ( - getattr(node.target, "_orig", None) - == ModuleUtilsMixin.create_extended_attention_mask_for_decoder - ): - node.target = ( - torch_create_extended_attention_mask_for_decoder_wrapper - ) - elif getattr(node.target, "_orig", None) == torch.zeros: - node.target = torch_zeros_wrapper - return graph - - -# The `DotDict` class adds dot notation access to dictionary attributes. -class DotDict(dict): - def __getattr__(self, attr): - return self.get(attr) - - def __setattr__(self, key, value): - self.__setitem__(key, value) - - def __delattr__(self, item): - self.__delitem__(item) - - -# This is an experimental utility function that replaces the original model's forward method with PiPPy's PipelineDriver -# forward method. It is used to support HuggingFace's `generate()` method, which is defined in a `GenerationMixin` -# class that `PreTrainedModel` inherits from. We choose this replacement path instead of writing our own `generate()` -# method because the `generate()` method would call into many `GenerationMixin` APIs that may be implemented differently -# by each model. -def inject_pipeline_forward( - model: torch.nn.Module, - pipe_driver: PipelineDriverBase, -): - logger.info( - f"Inserting PiPPy pipeline forward into model {model._get_name()}" - ) - # Inject pipeline driver as a member object of original model - setattr(model, "pippy_pipeline_driver", pipe_driver) - - # Define a new forward method that uses PiPPy's pipeline driver - def pippy_forward(self, *args, **kwargs): - output = self.pippy_pipeline_driver(*args, **kwargs) - if isinstance(output, dict): - # Add dot access if output is a dictionary. The output of a traced HF model is a traditional dict which has - # only [`key`] access. The wrapping is needed for Transformer versons >= 4.28 which access attributes of - # output via dot notation, such as `output.logits`. See for example the `generate()` method and - # `modeling_output.py`. - output = DotDict(output) - return output - - # Replace the forward method in original model - setattr(model, "forward", types.MethodType(pippy_forward, model)) From 099f140965dc892f6886f2a774ed21bfda409098 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Mon, 22 Jan 2024 11:44:03 -0500 Subject: [PATCH 92/96] Remove stale tests (#931) Tests written in pipeline driver style --- test/local_test_ddp.py | 249 ---------------- test/local_test_forward_auto_parallel.py | 173 ------------ test/local_test_forward_backward.py | 346 ----------------------- test/local_test_forward_hf_bert.py | 157 ---------- test/local_test_forward_hf_gpt2.py | 153 ---------- 5 files changed, 1078 deletions(-) delete mode 100644 test/local_test_ddp.py delete mode 100644 test/local_test_forward_auto_parallel.py delete mode 100644 test/local_test_forward_backward.py delete mode 100644 test/local_test_forward_hf_bert.py delete mode 100644 test/local_test_forward_hf_gpt2.py diff --git a/test/local_test_ddp.py b/test/local_test_ddp.py deleted file mode 100644 index 1b8e60b08..000000000 --- a/test/local_test_ddp.py +++ /dev/null @@ -1,249 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import argparse -import copy -import os -import unittest - -import pippy.fx - -import torch -import torch.distributed.rpc as rpc -from pippy import run_pippy -from pippy.IR import ( - MultiUseParameterConfig, - Pipe, - pipe_split, - TrivialLossWrapper, -) -from pippy.PipelineDriver import ( - PipelineDriver1F1B, - PipelineDriverBase, - PipelineDriverFillDrain, - PipelineDriverInterleaved1F1B, -) - -# TODOs for implementing forward/backward/loss with schedules: -# * ability to switch between full-batch loss vs. per-microbatch loss. shen mentioned -# this might change numerics. So we should have the ability to compute loss over -# the whole minibatch rather than doing it for each micro-batch - -schedules = { - "FillDrain": PipelineDriverFillDrain, - "1F1B": PipelineDriver1F1B, - "Interleaved1F1B": PipelineDriverInterleaved1F1B, -} - - -def get_grad_from_executor(executor, qualname): - mod = executor.local_value().mod - if isinstance(mod, torch.nn.parallel.DistributedDataParallel): - return mod.module.get_parameter(qualname).grad - else: - return mod.get_parameter(qualname).grad - - -pippy.fx.Tracer.proxy_buffer_attributes = True - - -def run_master(pp_ranks, args): - torch.manual_seed(42) - - d_hid = 50 - bs = 503 - CHUNKS = 5 - DEBUG_MASK_MINIBATCHES = True - check_numeric = True if args.cuda == 0 else False # TODO - - MULTI_USE_PARAM_CONFIG = ( - MultiUseParameterConfig.REPLICATE - if args.replicate - else MultiUseParameterConfig.TRANSMIT - ) - print(f"REPLICATE config: {args.replicate} -> {MULTI_USE_PARAM_CONFIG}") - - print("Using schedule:", args.schedule) - - def rand_zeros_or_ones(shape): - return torch.randint(0, 2, shape).float() - - class ZeroOneLinear(torch.nn.Module): - def __init__(self, in_dim, out_dim): - super().__init__() - self.w = torch.nn.Parameter(rand_zeros_or_ones((in_dim, out_dim))) - - def forward(self, x): - return x @ self.w - - class ExampleCode(torch.nn.Module): - def __init__(self): - super().__init__() - self.mm_param = torch.nn.Parameter( - rand_zeros_or_ones((d_hid, d_hid)) - ) - self.mm_param2 = torch.nn.Parameter( - rand_zeros_or_ones((d_hid, d_hid)) - ) - self.lin = ZeroOneLinear(d_hid, d_hid) - self.register_buffer( - "buffer", 0.00001 * rand_zeros_or_ones((bs + 100, d_hid)) - ) - - def forward(self, x): - x = torch.mm(x, self.mm_param) - skip_connection = x - x = torch.relu(x) - x = torch.mm(x, self.mm_param) + self.buffer[: x.shape[0]] - x = self.lin(x) - pipe_split() - x = torch.relu(x) - x = x + skip_connection - x = torch.mm(x, self.mm_param2) - x = self.lin(x) - return x - - ec = ExampleCode() - ec.to(args.device) - ec(torch.randn(bs, d_hid, device=args.device)) - ec.train() - - # TODO: works with sum, need to define semantics for e.g. mean - mse_loss = torch.nn.MSELoss(reduction="sum") - wrapper = TrivialLossWrapper(ec, mse_loss) - ec_pipe = Pipe.from_tracing(wrapper, MULTI_USE_PARAM_CONFIG) - if args.rank == 0: - print(ec_pipe.split_gm) - - pipe_driver: PipelineDriverBase = schedules[args.schedule]( - ec_pipe, - CHUNKS, - args.pp_group_size, - all_ranks=pp_ranks, - _debug_mask_minibatches=DEBUG_MASK_MINIBATCHES, - checkpoint=bool(args.checkpoint), - ) - print(f"Rank {args.rank} Instantiated pipe with ranks {pp_ranks}") - - pipe_driver.init_data_parallel(dp_group_size=args.dp_group_size) - - torch.manual_seed(args.rank) - input = torch.randn(bs, d_hid, device=args.device) - target = torch.randn(bs, d_hid, device=args.device) - - # TODO: distributed optimizer - out = pipe_driver(input, target) - - print(f"Rank {args.rank} got loss value {out}") - - if not check_numeric: - print("DDP + PP API test passed") - return - - all_grad_qualnames = {k: None for k, v in ec_pipe.named_parameters()} - - pipe_grads = {} - - for name in all_grad_qualnames: - assert "split_gm." in name - _, module_name, param_qualname = name.split(".", maxsplit=2) - - assert module_name in pipe_driver.remote_stage_executor_rrefs - rank, module_rref = pipe_driver.remote_stage_executor_rrefs[module_name] - grad_value = rpc.rpc_sync( - module_rref.owner(), - get_grad_from_executor, - (module_rref, param_qualname), - ) - pipe_grads[name] = copy.deepcopy(grad_value) - - # User driver group as the DDP reference group - wrapper_ddp = torch.nn.parallel.DistributedDataParallel( - wrapper, process_group=args.driver_group - ) - - wrapper_out = wrapper_ddp(input, target) - wrapper_out.backward() - - not_close_grads = [] - ref_grads = {} - - for name in all_grad_qualnames: - remapped_qualname = ec_pipe.remap_qualname(name) - param = wrapper_ddp.module.get_parameter(remapped_qualname) - assert ( - name in pipe_grads - ), f"{name} not in pipe_grads keys {pipe_grads.keys()}" - ref_grads[name] = copy.deepcopy(param.grad) - if not torch.allclose(pipe_grads[name], ref_grads[name]): - not_close_grads.append(name) - - if len(not_close_grads): - raise AssertionError(f"Gradients not close: {not_close_grads}") - - print("Gradient equivalence test passed") - - -def main(args=None): - parser = argparse.ArgumentParser() - parser.add_argument( - "--world_size", type=int, default=int(os.getenv("WORLD_SIZE", 4)) - ) - # in row-major - # DP ranks are contiguous rows of size `args.dp_group_size` - # PP ranks are non-contiguous columns of size `args.pp_group_size` - # - # if dp_group_size = 4 and pp_group_size = 3 - # - # 0 1 2 3 - # 4 5 6 7 - # 8 9 10 11 - # - # DP ranks are [0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11] - # PP ranks are [0, 4, 8], [1, 5, 9], [2, 6, 10], [3, 7, 11] - parser.add_argument("--rank", type=int, default=int(os.getenv("RANK", -1))) - parser.add_argument( - "--master_addr", type=str, default=os.getenv("MASTER_ADDR", "localhost") - ) - parser.add_argument( - "--master_port", type=str, default=os.getenv("MASTER_PORT", "29500") - ) - parser.add_argument( - "-s", - "--schedule", - type=str, - default=list(schedules.keys())[0], - choices=schedules.keys(), - ) - parser.add_argument( - "--replicate", type=int, default=int(os.getenv("REPLICATE", "0")) - ) - parser.add_argument( - "--cuda", type=int, default=int(torch.cuda.is_available()) - ) - parser.add_argument("--checkpoint", type=int, default=0, choices=[0, 1]) - args = parser.parse_args(args) - - # ExampleCode has two stages - args.pp_group_size = 2 - assert args.world_size % args.pp_group_size == 0 - - # Use world size to determine DDP size - args.dp_group_size = args.world_size // args.pp_group_size - print(f"Using data parallel group size: {args.dp_group_size}") - - run_pippy(run_master, args) - - -if __name__ == "__main__": - main() - - -class LocalTestDDP(unittest.TestCase): - def test_ddp(self): - import random - - port = random.randint(29500, 30000) - args = [ - "--master_port", - str(port), - ] - main(args) diff --git a/test/local_test_forward_auto_parallel.py b/test/local_test_forward_auto_parallel.py deleted file mode 100644 index 19f0edd01..000000000 --- a/test/local_test_forward_auto_parallel.py +++ /dev/null @@ -1,173 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import argparse -import os -import unittest - -import pippy.fx - -import torch -import torch.autograd.profiler_legacy -from pippy import run_pippy -from pippy.auto_parallelization import AutoParallelConfig, dp_auto_parallel -from pippy.IR import MultiUseParameterConfig, Pipe -from pippy.PipelineDriver import ( - PipelineDriver1F1B, - PipelineDriverBase, - PipelineDriverFillDrain, - PipelineDriverInterleaved1F1B, -) - -PROFILING_ENABLED = True -CHECK_NUMERIC_EQUIVALENCE = True - -schedules = { - "FillDrain": PipelineDriverFillDrain, - "1F1B": PipelineDriver1F1B, - "Interleaved1F1B": PipelineDriverInterleaved1F1B, -} - -pippy.fx.Tracer.proxy_buffer_attributes = True - - -def run_master(_, args): - d_hid = 512 - bs = 503 - - MULTI_USE_PARAM_CONFIG = ( - MultiUseParameterConfig.REPLICATE - if args.replicate - else MultiUseParameterConfig.TRANSMIT - ) - print(f"REPLICATE config: {args.replicate} -> {MULTI_USE_PARAM_CONFIG}") - - print("Using schedule:", args.schedule) - - class ExampleCode(torch.nn.Module): - def __init__(self): - super().__init__() - self.mm_param = torch.nn.Parameter(torch.randn(d_hid, d_hid)) - self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) - self.lin = torch.nn.Linear(d_hid, d_hid) - self.register_buffer("buffer", torch.randn(bs + 100, d_hid)) - - def forward(self, x): - x = torch.mm(x, self.mm_param) - skip_connection = x - x = torch.relu(x) - x = torch.mm(x, self.mm_param) + self.buffer[: x.shape[0]] - x = self.lin(x) - x = torch.relu(x) - x = x + skip_connection - x = torch.mm(x, self.mm_param2) - x = self.lin(x) - x = torch.relu(x) - return {"out": x} - - ec = ExampleCode() - ec.to(args.device) - ec_input = torch.randn(bs, d_hid, device=args.device) - ec(ec_input) - - auto_parallel_ctx = AutoParallelConfig( - n_compute_nodes=args.world_size, n_devices_per_node=1, n_microbatches=5 - ) - ec_pipe = Pipe.from_tracing( - ec, - MULTI_USE_PARAM_CONFIG, - split_policy=dp_auto_parallel(auto_parallel_ctx), - ) - print(ec_pipe.split_gm) - - pipe_driver: PipelineDriverBase = schedules[args.schedule]( - ec_pipe, - 5, - args.world_size, - _debug_mask_minibatches=True, - _record_mem_dumps=bool(args.record_mem_dumps), - checkpoint=bool(args.checkpoint), - ) - - # # Warm up and correctness runs - out = pipe_driver(ec_input) - ref_out = ec_pipe(ec_input) - - # run with different chunk size to exercise microbatch and scheduling components - pipe_driver.chunks = 1 - pipe_driver(ec_input) - pipe_driver.chunks = 100 - pipe_driver(ec_input) - - if CHECK_NUMERIC_EQUIVALENCE: - torch.testing.assert_close(out["out"], ref_out["out"]) - print( - f'equivalence test passed {torch.sum(out["out"])} ref {torch.sum(ref_out["out"])}' - ) - - # # Profiling runs - with torch.autograd.profiler_legacy.profile( - enabled=PROFILING_ENABLED - ) as prof: - pipe_driver.chunks = 5 - out = pipe_driver(ec_input) - ref_out = ec_pipe(ec_input) - print( - f'profiling run completed {torch.sum(out["out"])} ref {torch.sum(ref_out["out"])}' - ) - if PROFILING_ENABLED: - prof.export_chrome_trace( - f"{os.path.splitext(os.path.basename(__file__))[0]}.json" - ) - - -def main(args=None): - parser = argparse.ArgumentParser() - parser.add_argument( - "--world_size", type=int, default=int(os.getenv("WORLD_SIZE", 4)) - ) - parser.add_argument("--rank", type=int, default=int(os.getenv("RANK", -1))) - parser.add_argument( - "--master_addr", type=str, default=os.getenv("MASTER_ADDR", "localhost") - ) - parser.add_argument( - "--master_port", type=str, default=os.getenv("MASTER_PORT", "29500") - ) - parser.add_argument( - "-s", - "--schedule", - type=str, - default=list(schedules.keys())[0], - choices=schedules.keys(), - ) - parser.add_argument( - "--replicate", type=int, default=int(os.getenv("REPLICATE", "0")) - ) - parser.add_argument( - "--cuda", type=int, default=int(torch.cuda.is_available()) - ) - parser.add_argument( - "--record_mem_dumps", type=int, default=0, choices=[0, 1] - ) - parser.add_argument("--checkpoint", type=int, default=0, choices=[0, 1]) - args = parser.parse_args(args) - - # Interleaved 1F1B uses less ranks than number of stages - if args.schedule == "Interleaved1F1B": - args.world_size = 2 - - run_pippy(run_master, args) - - -if __name__ == "__main__": - main() - - -class LocalTestForwardAutoParallelTest(unittest.TestCase): - def test_forward_auto_parallel(self): - import random - - port = random.randint(29500, 30000) - args = [ - "--master_port", - str(port), - ] - main(args) diff --git a/test/local_test_forward_backward.py b/test/local_test_forward_backward.py deleted file mode 100644 index 43c177a9b..000000000 --- a/test/local_test_forward_backward.py +++ /dev/null @@ -1,346 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import argparse -import copy -import os -import unittest - -import pippy.fx - -import torch -import torch.distributed.rpc as rpc -from pippy import run_pippy -from pippy.IR import ( - MultiUseParameterConfig, - Pipe, - pipe_split, - TrivialLossWrapper, -) -from pippy.microbatch import split_args_kwargs_into_chunks -from pippy.PipelineDriver import ( - PipelineDriver1F1B, - PipelineDriverBase, - PipelineDriverFillDrain, - PipelineDriverInterleaved1F1B, -) - -# TODOs for implementing forward/backward/loss with schedules: -# * ability to switch between full-batch loss vs. per-microbatch loss. shen mentioned -# this might change numerics. So we should have the ability to compute loss over -# the whole minibatch rather than doing it for each micro-batch - -PROFILING_ENABLED = True -CHECK_NUMERIC_EQUIVALENCE = True - -schedules = { - "FillDrain": PipelineDriverFillDrain, - "1F1B": PipelineDriver1F1B, - "Interleaved1F1B": PipelineDriverInterleaved1F1B, -} - - -# import ctypes -# libc = ctypes.cdll.LoadLibrary("libc.so.6") -# libc.prctl.argtypes = [ -# ctypes.c_int, -# ctypes.c_ulong, -# ctypes.c_ulong, -# ctypes.c_ulong, -# ctypes.c_ulong, -# ] -# libc.prctl.restype = ctypes.c_int -# libc.prctl(0x59616D61, -1, 0, 0, 0) - - -def get_grad_from_executor(executor, qualname): - return executor.local_value().mod.get_parameter(qualname).grad - - -def set_grad_in_executor(executor, qualname, value): - param = executor.local_value().mod.get_parameter(qualname) - param.grad = value - - -pippy.fx.Tracer.proxy_buffer_attributes = True - - -def run_master(_, args): - torch.manual_seed(42) - - d_hid = 50 - bs = 503 - CHUNKS = 5 - DEBUG_MASK_MINIBATCHES = True - REF_USE_MICROBATCHES = True - MULTI_USE_PARAM_CONFIG = ( - MultiUseParameterConfig.REPLICATE - if args.replicate - else MultiUseParameterConfig.TRANSMIT - ) - print(f"REPLICATE config: {args.replicate} -> {MULTI_USE_PARAM_CONFIG}") - - print("Using schedule:", args.schedule) - - def rand_zeros_or_ones(shape): - return torch.randint(0, 2, shape).float() - - class ZeroOneLinear(torch.nn.Module): - def __init__(self, in_dim, out_dim): - super().__init__() - self.w = torch.nn.Parameter(rand_zeros_or_ones((in_dim, out_dim))) - - def forward(self, x): - return x @ self.w - - class ExampleCode(torch.nn.Module): - def __init__(self): - super().__init__() - self.mm_param = torch.nn.Parameter( - rand_zeros_or_ones((d_hid, d_hid)) - ) - self.mm_param2 = torch.nn.Parameter( - rand_zeros_or_ones((d_hid, d_hid)) - ) - self.lin = ZeroOneLinear(d_hid, d_hid) - self.register_buffer( - "buffer", 0.00001 * rand_zeros_or_ones((bs + 100, d_hid)) - ) - - def forward(self, x): - x = torch.mm(x, self.mm_param) - skip_connection = x - x = torch.relu(x) - size = x.size() # for https://github.com/pytorch/PiPPy/issues/256 - pipe_split() - x.reshape( - size[0], size[1] - ) # for https://github.com/pytorch/PiPPy/issues/256 - x = torch.mm(x, self.mm_param) + self.buffer[: x.shape[0]] - x = self.lin(x) - size = x.size() # for https://github.com/pytorch/PiPPy/issues/256 - pipe_split() - x.reshape( - size[0], size[1] - ) # for https://github.com/pytorch/PiPPy/issues/256 - x = torch.relu(x) - x = x + skip_connection - x = torch.mm(x, self.mm_param2) - x = self.lin(x) - size = x.size() # for https://github.com/pytorch/PiPPy/issues/256 - pipe_split() - x.reshape( - size[0], size[1] - ) # for https://github.com/pytorch/PiPPy/issues/256 - x = torch.relu(x) - return x - - ec = ExampleCode() - ec.to(args.device) - ec_input = torch.randn(bs, d_hid, device=args.device) - ec(ec_input) - ec.train() - - # TODO: works with sum, need to define semantics for e.g. mean - mse_loss = torch.nn.MSELoss(reduction="sum") - wrapper = TrivialLossWrapper(ec, mse_loss) - ec_pipe = Pipe.from_tracing(wrapper, MULTI_USE_PARAM_CONFIG) - print(ec_pipe.split_gm) - - pipe_driver: PipelineDriverBase = schedules[args.schedule]( - ec_pipe, - CHUNKS, - args.world_size, - _debug_mask_minibatches=DEBUG_MASK_MINIBATCHES, - _record_mem_dumps=bool(args.record_mem_dumps), - checkpoint=bool(args.checkpoint), - ) - - target = torch.randn(bs, d_hid, device=args.device) - - # TODO: distributed optimizer - out = pipe_driver(ec_input, target) - - all_grad_qualnames = {k: None for k, v in ec_pipe.named_parameters()} - - pipe_grads = {} - - for name in all_grad_qualnames: - assert "split_gm." in name - _, module_name, param_qualname = name.split(".", maxsplit=2) - - assert module_name in pipe_driver.remote_stage_executor_rrefs - stage_id, module_rref = pipe_driver.remote_stage_executor_rrefs[ - module_name - ] - grad_value = rpc.rpc_sync( - module_rref.owner(), - get_grad_from_executor, - (module_rref, param_qualname), - ) - pipe_grads[name] = copy.deepcopy(grad_value) - - optim = torch.optim.SGD(ec_pipe.split_gm.parameters(), lr=0.05) - optim.zero_grad() - if REF_USE_MICROBATCHES: - args_split, kwargs_split = split_args_kwargs_into_chunks( - (ec_input, target), - {}, - CHUNKS, - args_chunk_spec=None, - kwargs_chunk_spec=None, - _debug_mask_minibatches=DEBUG_MASK_MINIBATCHES, - ) - ref_outs = [] - for chunk in range(CHUNKS): - ref_outs.append(ec_pipe(*args_split[chunk])) - ref_out = torch.sum(torch.stack(ref_outs)) - else: - ref_out = ec_pipe(ec_input, target) - - # Shared parameter sync for reference. TODO: move this to actual runtime - for param_set in ec_pipe.replicated_params: - grad_values = [] - for module_name, param_qualname in param_set.items(): - grad_values.append( - ec_pipe.get_parameter( - f"split_gm.{module_name}.{param_qualname}" - ).grad - ) - - synced_value = torch.sum(torch.stack(grad_values), dim=0) - - for module_name, param_qualname in param_set.items(): - ec_pipe.get_parameter( - f"split_gm.{module_name}.{param_qualname}" - ).grad = synced_value - - # TODO: scale output - if CHECK_NUMERIC_EQUIVALENCE: - torch.testing.assert_close(out, ref_out) - print( - f"equivalence test passed {torch.sum(out)} ref {torch.sum(ref_out)}" - ) - - not_close_grads = [] - ref_grads = {} - for name in all_grad_qualnames: - param = ec_pipe.get_parameter(name) - assert ( - name in pipe_grads - ), f"{name} not in pipe_grads keys {pipe_grads.keys()}" - ref_grads[name] = param.grad - if not torch.allclose(pipe_grads[name], param.grad): - not_close_grads.append(name) - - for name in not_close_grads: - pipe_grad = pipe_grads[name] - ref_grad = ref_grads[name] - - relative_delta = torch.abs(pipe_grad - ref_grad) / ref_grad - assert False, ( - f"Gradient for parameter {name} is not numerically close! Relative diff mean " - f"{torch.mean(relative_delta)} std {torch.std(relative_delta)} max {torch.max(relative_delta)}" - ) - - print("Gradient equivalence test passed") - - # Test equivalence with initial code as well - orig_optim = torch.optim.SGD(ec.parameters(), lr=0.05) - orig_optim.zero_grad() - orig_loss = mse_loss(ec(ec_input), target) - orig_loss.backward() - torch.testing.assert_close(out, orig_loss) - - not_close_orig_grads = [] - not_found_mappings = [] - - for name in all_grad_qualnames: - try: - remapped_qualname = ec_pipe.remap_qualname(name) - except KeyError: - not_found_mappings.append(name) - else: - orig_grad = wrapper.get_parameter(remapped_qualname).grad - pipe_grad = pipe_grads[name] - if not torch.allclose(pipe_grad, orig_grad): - not_close_orig_grads.append(name) - print(name, torch.abs(pipe_grad - orig_grad) / orig_grad) - print( - name, - torch.max(torch.abs(pipe_grad - orig_grad) / orig_grad), - ) - - assert len(not_found_mappings) == 0, ( - f"No qualname mapping found between pipelined and original " - f"model: {not_found_mappings}" - ) - - assert len(not_close_orig_grads) == 0, ( - f"Grads not close between pipelined and original " - f"model: {not_close_orig_grads}" - ) - - print("correctness checks with original module passed") - - # # # Profiling runs - # with torch.autograd.profiler_legacy.profile(enabled=PROFILING_ENABLED) as prof: - # pipe_driver._debug_mask_minibatches = False - # pipe_driver.chunks = CHUNKS - # out = pipe_driver(ec_input, target) - # ref_out = ec_pipe.split_gm(ec_input, target) - # print(f'profiling run completed {torch.sum(ref_out)} ref {torch.sum(ref_out)}') - # if PROFILING_ENABLED: - # prof.export_chrome_trace(f'{os.path.splitext(os.path.basename(__file__))[0]}.json') - - -def main(args=None): - parser = argparse.ArgumentParser() - parser.add_argument( - "--world_size", type=int, default=int(os.getenv("WORLD_SIZE", 4)) - ) - parser.add_argument("--rank", type=int, default=int(os.getenv("RANK", -1))) - parser.add_argument( - "--master_addr", type=str, default=os.getenv("MASTER_ADDR", "localhost") - ) - parser.add_argument( - "--master_port", type=str, default=os.getenv("MASTER_PORT", "29500") - ) - parser.add_argument( - "-s", - "--schedule", - type=str, - default=list(schedules.keys())[0], - choices=schedules.keys(), - ) - parser.add_argument( - "--replicate", type=int, default=int(os.getenv("REPLICATE", "0")) - ) - parser.add_argument( - "--cuda", type=int, default=int(torch.cuda.is_available()) - ) - parser.add_argument( - "--record_mem_dumps", type=int, default=0, choices=[0, 1] - ) - parser.add_argument("--checkpoint", type=int, default=0, choices=[0, 1]) - args = parser.parse_args(args) - - # Interleaved 1F1B uses less ranks than number of stages - if args.schedule == "Interleaved1F1B": - args.world_size = 2 - - run_pippy(run_master, args) - - -if __name__ == "__main__": - main() - - -class LocalTestForwardAutoParallelTest(unittest.TestCase): - def test_forward_backward(self): - import random - - port = random.randint(29500, 30000) - args = [ - "--master_port", - str(port), - ] - main(args) diff --git a/test/local_test_forward_hf_bert.py b/test/local_test_forward_hf_bert.py deleted file mode 100644 index 3b5d99563..000000000 --- a/test/local_test_forward_hf_bert.py +++ /dev/null @@ -1,157 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import argparse -import inspect -import os - -import pippy.fx - -import torch -import torch.autograd.profiler_legacy -from pippy import run_pippy -from pippy.hf import PiPPyHFTracer -from pippy.IR import ( - annotate_split_points, - MultiUseParameterConfig, - Pipe, - PipeSplitWrapper, -) -from pippy.PipelineDriver import ( - PipelineDriver1F1B, - PipelineDriverBase, - PipelineDriverFillDrain, -) -from transformers import BertConfig, BertModel - -PROFILING_ENABLED = True -CHECK_NUMERIC_EQUIVALENCE = True - -schedules = { - "FillDrain": PipelineDriverFillDrain, - "1F1B": PipelineDriver1F1B, -} - -pippy.fx.Tracer.proxy_buffer_attributes = True - - -def run_master(_, args): - bs = 20 - seq_length = 32 - - MULTI_USE_PARAM_CONFIG = ( - MultiUseParameterConfig.REPLICATE - if args.replicate - else MultiUseParameterConfig.TRANSMIT - ) - print(f"REPLICATE config: {args.replicate} -> {MULTI_USE_PARAM_CONFIG}") - - print("Using schedule:", args.schedule) - - bert = BertModel(BertConfig()) - bert.to(args.device) - bert.eval() - bert_input = torch.zeros( - bs, seq_length, dtype=torch.long, device=args.device - ).random_(bert.config.vocab_size) - bert(bert_input) - - for i in range(bert.config.num_hidden_layers): - annotate_split_points( - bert, {f"encoder.layer.{i}": PipeSplitWrapper.SplitPoint.BEGINNING} - ) - annotate_split_points( - bert, {"pooler": PipeSplitWrapper.SplitPoint.BEGINNING} - ) - - input_names = bert.dummy_inputs.keys() - sig = inspect.signature(bert.forward) - concrete_args = { - p.name: p.default - for p in sig.parameters.values() - if p.name not in input_names - } - - print("Instantiating BERT Pipeline") - bert_pipe = Pipe.from_tracing( - bert, - MULTI_USE_PARAM_CONFIG, - tracer=PiPPyHFTracer(), - concrete_args=concrete_args, - ) - - assert bert.config.num_hidden_layers + 2 == len( - list(bert_pipe.split_gm.children()) - ) - - pipe_driver: PipelineDriverBase = schedules[args.schedule]( - bert_pipe, - 5, - args.world_size, - _debug_mask_minibatches=True, - _record_mem_dumps=bool(args.record_mem_dumps), - checkpoint=bool(args.checkpoint), - ) - - # # Warm up and correctness runs - out = pipe_driver(bert_input) - ref_out = bert_pipe(bert_input) - - if CHECK_NUMERIC_EQUIVALENCE: - torch.testing.assert_close( - out["last_hidden_state"], ref_out["last_hidden_state"] - ) - torch.testing.assert_close( - out["pooler_output"], ref_out["pooler_output"] - ) - print( - f'equivalence test passed {torch.sum(out["last_hidden_state"])} ref {torch.sum(ref_out["last_hidden_state"])}' - ) - - # # Profiling runs - with torch.autograd.profiler_legacy.profile( - enabled=PROFILING_ENABLED - ) as prof: - pipe_driver._debug_mask_minibatches = False - pipe_driver.chunks = 5 - out = pipe_driver(bert_input) - ref_out = bert_pipe(bert_input) - print( - f'profiling run completed {torch.sum(out["last_hidden_state"])} ref {torch.sum(ref_out["last_hidden_state"])}' - ) - if PROFILING_ENABLED: - prof.export_chrome_trace( - f"{os.path.splitext(os.path.basename(__file__))[0]}.json" - ) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--world_size", type=int, default=int(os.getenv("WORLD_SIZE", 14)) - ) - parser.add_argument("--rank", type=int, default=int(os.getenv("RANK", -1))) - parser.add_argument( - "--master_addr", type=str, default=os.getenv("MASTER_ADDR", "localhost") - ) - parser.add_argument( - "--master_port", type=str, default=os.getenv("MASTER_PORT", "29500") - ) - parser.add_argument( - "-s", - "--schedule", - type=str, - default=list(schedules.keys())[0], - choices=schedules.keys(), - ) - parser.add_argument( - "--replicate", type=int, default=int(os.getenv("REPLICATE", "0")) - ) - parser.add_argument( - "--cuda", type=int, default=int(torch.cuda.is_available()) - ) - parser.add_argument( - "--record_mem_dumps", type=int, default=0, choices=[0, 1] - ) - parser.add_argument("--checkpoint", type=int, default=0, choices=[0, 1]) - args = parser.parse_args() - - run_pippy(run_master, args) diff --git a/test/local_test_forward_hf_gpt2.py b/test/local_test_forward_hf_gpt2.py deleted file mode 100644 index a8beb677e..000000000 --- a/test/local_test_forward_hf_gpt2.py +++ /dev/null @@ -1,153 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import argparse -import inspect -import os - -import pippy.fx - -import torch -import torch.autograd.profiler_legacy -from pippy import run_pippy -from pippy.hf import PiPPyHFTracer -from pippy.IR import ( - annotate_split_points, - MultiUseParameterConfig, - Pipe, - PipeSplitWrapper, -) -from pippy.PipelineDriver import ( - PipelineDriver1F1B, - PipelineDriverBase, - PipelineDriverFillDrain, -) -from transformers import GPT2Config, GPT2Model - -PROFILING_ENABLED = True -CHECK_NUMERIC_EQUIVALENCE = True - -schedules = { - "FillDrain": PipelineDriverFillDrain, - "1F1B": PipelineDriver1F1B, -} - -pippy.fx.Tracer.proxy_buffer_attributes = True - - -def run_master(_, args): - bs = 20 - seq_length = 32 - - MULTI_USE_PARAM_CONFIG = ( - MultiUseParameterConfig.REPLICATE - if args.replicate - else MultiUseParameterConfig.TRANSMIT - ) - print(f"REPLICATE config: {args.replicate} -> {MULTI_USE_PARAM_CONFIG}") - - print("Using schedule:", args.schedule) - - gpt2 = GPT2Model(GPT2Config(use_cache=False)) - gpt2.to(args.device) - gpt2.eval() - gpt2_input = torch.zeros( - bs, seq_length, dtype=torch.long, device=args.device - ).random_(gpt2.config.vocab_size) - - for i in range(gpt2.config.n_layer): - annotate_split_points( - gpt2, {f"h.{i}": PipeSplitWrapper.SplitPoint.BEGINNING} - ) - annotate_split_points(gpt2, {"ln_f": PipeSplitWrapper.SplitPoint.BEGINNING}) - - input_names = gpt2.dummy_inputs.keys() - sig = inspect.signature(gpt2.forward) - concrete_args = { - p.name: p.default - for p in sig.parameters.values() - if p.name not in input_names - } - - print("Instantiating GPT2 Pipeline") - gpt2_pipe = Pipe.from_tracing( - gpt2, - MULTI_USE_PARAM_CONFIG, - tracer=PiPPyHFTracer(), - concrete_args=concrete_args, - ) - - assert gpt2.config.n_layer + 2 == len(list(gpt2_pipe.split_gm.children())) - - pipe_driver: PipelineDriverBase = schedules[args.schedule]( - gpt2_pipe, - 5, - args.world_size, - _debug_mask_minibatches=True, - _record_mem_dumps=bool(args.record_mem_dumps), - checkpoint=bool(args.checkpoint), - ) - - # # Warm up and correctness runs - print( - "Running GPT2 pipeline. NB: if this is too slow, set OMP_NUM_THREADS to a higher value" - ) - out = pipe_driver(gpt2_input) - print("Running reference pipeline") - ref_out = gpt2_pipe(gpt2_input) - - if CHECK_NUMERIC_EQUIVALENCE: - torch.testing.assert_close( - out["last_hidden_state"], ref_out["last_hidden_state"] - ) - print( - f'equivalence test passed {torch.sum(out["last_hidden_state"])} ref {torch.sum(ref_out["last_hidden_state"])}' - ) - - # # Profiling runs - with torch.autograd.profiler_legacy.profile( - enabled=PROFILING_ENABLED - ) as prof: - pipe_driver._debug_mask_minibatches = False - pipe_driver.chunks = 5 - out = pipe_driver(gpt2_input) - ref_out = gpt2_pipe(gpt2_input) - print( - f'profiling run completed {torch.sum(out["last_hidden_state"])} ref {torch.sum(ref_out["last_hidden_state"])}' - ) - if PROFILING_ENABLED: - prof.export_chrome_trace( - f"{os.path.splitext(os.path.basename(__file__))[0]}.json" - ) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--world_size", type=int, default=int(os.getenv("WORLD_SIZE", 14)) - ) - parser.add_argument("--rank", type=int, default=int(os.getenv("RANK", -1))) - parser.add_argument( - "--master_addr", type=str, default=os.getenv("MASTER_ADDR", "localhost") - ) - parser.add_argument( - "--master_port", type=str, default=os.getenv("MASTER_PORT", "29500") - ) - parser.add_argument( - "-s", - "--schedule", - type=str, - default=list(schedules.keys())[0], - choices=schedules.keys(), - ) - parser.add_argument( - "--replicate", type=int, default=int(os.getenv("REPLICATE", "0")) - ) - parser.add_argument( - "--cuda", type=int, default=int(torch.cuda.is_available()) - ) - parser.add_argument( - "--record_mem_dumps", type=int, default=0, choices=[0, 1] - ) - parser.add_argument("--checkpoint", type=int, default=0, choices=[0, 1]) - args = parser.parse_args() - - run_pippy(run_master, args) From 3632106dacaf66c179f00fec49df8f78d2f8b53d Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Fri, 26 Jan 2024 11:53:39 -0500 Subject: [PATCH 93/96] Fill value for non-tensor inputs during shape prop (#933) Fixes #932 --- pippy/IR.py | 61 ++++++++++++++++++++++++++++------------------------- 1 file changed, 32 insertions(+), 29 deletions(-) diff --git a/pippy/IR.py b/pippy/IR.py index 48f32fcac..23c0db526 100644 --- a/pippy/IR.py +++ b/pippy/IR.py @@ -1004,34 +1004,6 @@ def move_param_to_callee( split.delete_all_unused_submodules() - # Users want the first pipeline stage to accept kwargs if the original - # program does. This is controlled by the `_codegen` field of the graph, - # so we make a copy here. Note: we only want the input spec and not the - # output spec, because the output spec is for the last stage. Maybe a - # TODO? Not sure yet. - submod0 = list(split.children())[0] - model_sign = signature(traced.forward) - model_num_args = len(model_sign.parameters) - submod0_sign = signature(submod0.forward) - submod0_num_args = len(submod0_sign.parameters) - if model_num_args != submod0_num_args: - # We don't change the signature of the first stage if it takes - # different number of args than original model - logger.info( - f"Original model takes {model_num_args} args but the first pipeline stage takes {submod0_num_args}. " - "Please provide args to respective pipeline stages." - ) - else: - # Support kwargs for the first stage - submod0.graph._codegen = copy.deepcopy(traced.graph._codegen) - # `_replace` is actually not "private" or internal. based on this doc: - # To prevent conflicts with field names, the method and attribute names - # start with an underscore - submod0.graph._codegen.pytree_info = ( - submod0.graph._codegen.pytree_info._replace(out_spec=None) - ) - submod0.recompile() - split.graph.lint() split.recompile() @@ -1176,6 +1148,33 @@ def from_tracing( f"{node.meta['example_value'] if 'example_value' in node.meta else 'None'}", ) + # Users want the first pipeline stage to accept kwargs if the original + # program does. This is controlled by the `_codegen` field of the graph, + # so we make a copy here. Note: we only want the input spec and not the + # output spec, because the output spec is for the last stage. Maybe a + # TODO? Not sure yet. + split = pipe.split_gm + submod0 = list(split.children())[0] + submod0_sign = signature(submod0.forward) + model_sign = signature(traced.forward) + if len(model_sign.parameters) != len(submod0_sign.parameters): + # We don't change the signature of the first stage if it takes + # different number of args than original model + logger.info( + f"Original model takes {len(model_sign.parameters)} args but the first pipeline stage takes {len(submod0_sign.parameters)}. " + "Please provide args to respective pipeline stages." + ) + else: + # Support kwargs for the first stage + submod0.graph._codegen = copy.deepcopy(traced.graph._codegen) + # `_replace` is actually not "private" or internal. based on this doc: + # To prevent conflicts with field names, the method and attribute names + # start with an underscore + submod0.graph._codegen.pytree_info = ( + submod0.graph._codegen.pytree_info._replace(out_spec=None) + ) + submod0.recompile() + return pipe def __str__(self): @@ -1284,8 +1283,12 @@ def __init__( self.stop_prop = False def run(self): + # Prepare input from node.meta, which will be filled during tracing if + # input is a tensor. For non-tensor inputs, e.g. constants, its value + # would have been burned into the program, so we use an arbitrary value + # here (None). inp = tuple( - node.meta["val"] + node.meta["val"] if "val" in node.meta else None for node in self.module.graph.nodes if node.op == "placeholder" ) From e51c8b9bc51afb99319ef097904c7e8d941c992a Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Fri, 26 Jan 2024 13:26:53 -0500 Subject: [PATCH 94/96] Add back support for stage.remap_qualname() (#934) ## Description `stage.remap_qualname(key)` maps a stage's parameter name (`key`) back to the original model's parameter name. This now works: ``` # Stage module's state dict sd = stage.submod.state_dict() remapped_keys = [stage.remap_qualname(k) for k in sd.keys()] # Original model's state dict old_keys = mod.state_dict().keys() # Confirm they match assert all(rk in old_keys for rk in remapped_keys) ``` --- pippy/IR.py | 64 ++++-------------------------------------- pippy/PipelineStage.py | 11 ++++++-- pippy/utils.py | 49 ++++++++++++++++++++++++++++++++ test/test_fwd.py | 10 +++++++ test/test_ir.py | 11 -------- 5 files changed, 74 insertions(+), 71 deletions(-) diff --git a/pippy/IR.py b/pippy/IR.py index 23c0db526..18791b988 100644 --- a/pippy/IR.py +++ b/pippy/IR.py @@ -29,6 +29,7 @@ from pippy.backward import _null_coalesce_accumulate, stage_backward from pippy.debug import PIPPY_VERBOSITY from pippy.microbatch import LossReducer, split_args_kwargs_into_chunks +from pippy.utils import QualnameMapMixin logger = logging.getLogger(__name__) @@ -498,54 +499,6 @@ def _direct_serialization_reduce(self): ) -class QualnameMapMixin: - """ - A mixin class to provide qualname remap functionality for both Pipe object - and submodules - """ - - def __init__( - self, - splitter_qualname_map: Dict[str, str] = None, - tracer_qualname_map: Dict[str, str] = None, - ): - self.new_to_old_qualname_mapping: Dict[str, str] = ( - splitter_qualname_map or {} - ) - self.tracer_qualname_map = tracer_qualname_map - - def remap_qualname(self, qualname: str): - # TODO: annoying - if qualname.startswith("split_gm."): - qualname = qualname[len("split_gm.") :] - - name_before_split = None - if qualname in self.new_to_old_qualname_mapping: - name_before_split = self.new_to_old_qualname_mapping[qualname] - else: - # The qualname map does not store recursive items, thus, - # when passed a qualname with leaves, we need to perform longest prefix match - # Split from the right, one each time - split_names = qualname.rsplit(".", 1) - leaf = split_names[-1] - while len(split_names) > 1: - prefix = split_names[0] - if prefix in self.new_to_old_qualname_mapping: - old_prefix = self.new_to_old_qualname_mapping[prefix] - name_before_split = ".".join([old_prefix, leaf]) - break - split_names = prefix.rsplit(".", 1) - leaf = ".".join([split_names[-1], leaf]) - - if name_before_split is None: - raise RuntimeError(f"Could not find mapping for {qualname}") - - if self.tracer_qualname_map is not None: - return self.tracer_qualname_map[name_before_split] - else: - return name_before_split - - class Pipe(QualnameMapMixin, torch.nn.Module): def __init__( self, @@ -615,6 +568,10 @@ def __init__( ) # Create qualname mapping for each submodule + # Dict looks like this: + # {submod_name : Dict{old_qualname : new_qualname}} + # We save this information here for use during pipeline stage creation. + self.submod_qualname_mappings: Dict[str, Dict[str, str]] = {} for m_qualname, mod in self.split_gm.named_children(): # "submod_x." prefix mod_prefix = m_qualname + "." @@ -624,16 +581,7 @@ def __init__( # Remove prefix new_key = k[len(mod_prefix) :] mod_qualname_mapping.setdefault(new_key, v) - # Add a remap mixin to submodule instance - # TODO: this class change is commented out because it breaks - # recompilation if we want to recompile mod after. For example, we - # may recompile mod to modify the "device" kwarg of a `torch.ones` - # node (trace on cpu/meta, run on cuda). - # See: https://github.com/pytorch/vision/issues/5826 - # mod.__class__ = type( - # "PipeStageModule", (QualnameMapMixin, mod.__class__), {} - # ) - setattr(mod, "new_to_old_qualname_mapping", mod_qualname_mapping) + self.submod_qualname_mappings[m_qualname] = mod_qualname_mapping def throw(self, *args, **kwargs): raise RuntimeError( diff --git a/pippy/PipelineStage.py b/pippy/PipelineStage.py index 9367f6cfb..1d5e6184b 100644 --- a/pippy/PipelineStage.py +++ b/pippy/PipelineStage.py @@ -14,7 +14,7 @@ from pippy.debug import map_debug_info from pippy.IR import Pipe from pippy.microbatch import merge_chunks, split_args_kwargs_into_chunks -from pippy.utils import flatten_args, modify_graph_op_device +from pippy.utils import flatten_args, modify_graph_op_device, QualnameMapMixin logger = logging.getLogger(__name__) @@ -52,7 +52,7 @@ class StageKwargPlaceholder: pass -class PipelineStage(torch.nn.Module): +class PipelineStage(torch.nn.Module, QualnameMapMixin): def __init__( self, pipe: Pipe, @@ -98,6 +98,13 @@ def __init__( f"{self.submod}" ) + # Enable `remap_qualname` method + QualnameMapMixin.__init__( + self, + pipe.submod_qualname_mappings[self.name], + pipe.tracer_qualname_map, + ) + # Find my forward node in graph found_node = False for node in self.split_gm.graph.nodes: diff --git a/pippy/utils.py b/pippy/utils.py index b0417248c..1e400b17f 100644 --- a/pippy/utils.py +++ b/pippy/utils.py @@ -1,5 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates import logging +from typing import Dict import torch import torch.distributed as dist @@ -92,3 +93,51 @@ def modify_graph_op_device( if modified: gm.recompile() + + +class QualnameMapMixin: + """ + A mixin class to provide qualname remap functionality for both Pipe object + and submodules + """ + + def __init__( + self, + splitter_qualname_map: Dict[str, str] = None, + tracer_qualname_map: Dict[str, str] = None, + ): + self.new_to_old_qualname_mapping: Dict[str, str] = ( + splitter_qualname_map or {} + ) + self.tracer_qualname_map = tracer_qualname_map + + def remap_qualname(self, qualname: str): + # TODO: annoying + if qualname.startswith("split_gm."): + qualname = qualname[len("split_gm.") :] + + name_before_split = None + if qualname in self.new_to_old_qualname_mapping: + name_before_split = self.new_to_old_qualname_mapping[qualname] + else: + # The qualname map does not store recursive items, thus, + # when passed a qualname with leaves, we need to perform longest prefix match + # Split from the right, one each time + split_names = qualname.rsplit(".", 1) + leaf = split_names[-1] + while len(split_names) > 1: + prefix = split_names[0] + if prefix in self.new_to_old_qualname_mapping: + old_prefix = self.new_to_old_qualname_mapping[prefix] + name_before_split = ".".join([old_prefix, leaf]) + break + split_names = prefix.rsplit(".", 1) + leaf = ".".join([split_names[-1], leaf]) + + if name_before_split is None: + raise RuntimeError(f"Could not find mapping for {qualname}") + + if self.tracer_qualname_map is not None: + return self.tracer_qualname_map[name_before_split] + else: + return name_before_split diff --git a/test/test_fwd.py b/test/test_fwd.py index 9feedfd4c..74afc32e4 100644 --- a/test/test_fwd.py +++ b/test/test_fwd.py @@ -83,6 +83,16 @@ def run_worker(args): f"equivalence test passed {torch.sum(out)} ref {torch.sum(ref_out)}" ) + # Test qualname mapping + sd = stage.submod.state_dict() + print(f"Rank {args.rank} state dict keys: {sd.keys()}") + remapped_keys = [stage.remap_qualname(k) for k in sd.keys()] + print(f"Rank {args.rank} remapped keys: {remapped_keys}") + # Confirm remapped keys are consistent with original model + old_keys = mod.state_dict().keys() + assert all(rk in old_keys for rk in remapped_keys) + print(f"Qualname test passed") + def main(args=None): parser = argparse.ArgumentParser() diff --git a/test/test_ir.py b/test/test_ir.py index 6e1fcae2c..63e754c15 100644 --- a/test/test_ir.py +++ b/test/test_ir.py @@ -791,17 +791,6 @@ def test_remap_qualname(self): old_name in old_names ), f"Remapped parameter {old_name} not found in {old_names}" - # Check qualname mapping for submodule - # Not supported at the moment - """ - for _, stage_mod in ec_pipe.split_gm.named_children(): - for new_name, _ in stage_mod.named_parameters(): - old_name = stage_mod.remap_qualname(new_name) - assert ( - old_name in old_names - ), f"Remapped parameter {old_name} not found in {old_names}" - """ - if __name__ == "__main__": unittest.main() From 7f4c69d85b05d1b653123fd1e3571a5a66d5ed66 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Fri, 26 Jan 2024 14:09:09 -0500 Subject: [PATCH 95/96] Clean stale code, doc & examples (#935) Also added examples section in doc to refer to HuggingFace examples --- .github/workflows/pippy_tests.yaml | 2 +- .gitmodules | 3 - README.md | 301 +--------- example.py => examples/basic/example.py | 0 .../basic/example_train.py | 0 examples/{hf => huggingface}/hf_utils.py | 0 examples/{hf => huggingface}/pippy_albert.py | 0 examples/{hf => huggingface}/pippy_bart.py | 0 examples/{hf => huggingface}/pippy_bert.py | 0 .../{hf => huggingface}/pippy_blenderbot.py | 0 .../{hf => huggingface}/pippy_camemBert.py | 0 .../{hf => huggingface}/pippy_convBert.py | 0 examples/{hf => huggingface}/pippy_deberta.py | 0 .../{hf => huggingface}/pippy_debertaV2.py | 0 .../{hf => huggingface}/pippy_distilBert.py | 0 examples/{hf => huggingface}/pippy_electra.py | 0 examples/{hf => huggingface}/pippy_fnet.py | 0 examples/{hf => huggingface}/pippy_gpt2.py | 0 examples/{hf => huggingface}/pippy_gptNeo.py | 0 .../{hf => huggingface}/pippy_layoutLM.py | 0 examples/{hf => huggingface}/pippy_m2m100.py | 0 examples/{hf => huggingface}/pippy_mbart.py | 0 .../{hf => huggingface}/pippy_megatronBert.py | 0 .../{hf => huggingface}/pippy_mobileBert.py | 0 examples/{hf => huggingface}/pippy_mt5.py | 0 examples/{hf => huggingface}/pippy_opt.py | 0 examples/{hf => huggingface}/pippy_pegasus.py | 0 examples/{hf => huggingface}/pippy_plBart.py | 0 examples/{hf => huggingface}/pippy_t5.py | 0 examples/{hf => huggingface}/pippy_trOCR.py | 0 examples/{hf => huggingface}/pippy_xlnet.py | 0 examples/selective2d/2d_train.py | 497 ---------------- examples/selective2d/dim_solver | 1 - examples/selective2d/model.py | 540 ------------------ examples/selective2d/run.sh | 20 - pippy-ddp.pptx | Bin 39551 -> 0 bytes pippy-ddp.svg | 1 - pippy/auto_parallelization.py | 308 ---------- run_all_tests.sh | 15 - 39 files changed, 13 insertions(+), 1675 deletions(-) rename example.py => examples/basic/example.py (100%) rename example_train.py => examples/basic/example_train.py (100%) rename examples/{hf => huggingface}/hf_utils.py (100%) rename examples/{hf => huggingface}/pippy_albert.py (100%) rename examples/{hf => huggingface}/pippy_bart.py (100%) rename examples/{hf => huggingface}/pippy_bert.py (100%) rename examples/{hf => huggingface}/pippy_blenderbot.py (100%) rename examples/{hf => huggingface}/pippy_camemBert.py (100%) rename examples/{hf => huggingface}/pippy_convBert.py (100%) rename examples/{hf => huggingface}/pippy_deberta.py (100%) rename examples/{hf => huggingface}/pippy_debertaV2.py (100%) rename examples/{hf => huggingface}/pippy_distilBert.py (100%) rename examples/{hf => huggingface}/pippy_electra.py (100%) rename examples/{hf => huggingface}/pippy_fnet.py (100%) rename examples/{hf => huggingface}/pippy_gpt2.py (100%) rename examples/{hf => huggingface}/pippy_gptNeo.py (100%) rename examples/{hf => huggingface}/pippy_layoutLM.py (100%) rename examples/{hf => huggingface}/pippy_m2m100.py (100%) rename examples/{hf => huggingface}/pippy_mbart.py (100%) rename examples/{hf => huggingface}/pippy_megatronBert.py (100%) rename examples/{hf => huggingface}/pippy_mobileBert.py (100%) rename examples/{hf => huggingface}/pippy_mt5.py (100%) rename examples/{hf => huggingface}/pippy_opt.py (100%) rename examples/{hf => huggingface}/pippy_pegasus.py (100%) rename examples/{hf => huggingface}/pippy_plBart.py (100%) rename examples/{hf => huggingface}/pippy_t5.py (100%) rename examples/{hf => huggingface}/pippy_trOCR.py (100%) rename examples/{hf => huggingface}/pippy_xlnet.py (100%) delete mode 100644 examples/selective2d/2d_train.py delete mode 160000 examples/selective2d/dim_solver delete mode 100644 examples/selective2d/model.py delete mode 100644 examples/selective2d/run.sh delete mode 100644 pippy-ddp.pptx delete mode 100644 pippy-ddp.svg delete mode 100644 pippy/auto_parallelization.py delete mode 100755 run_all_tests.sh diff --git a/.github/workflows/pippy_tests.yaml b/.github/workflows/pippy_tests.yaml index 561a94eaa..ded9ab246 100644 --- a/.github/workflows/pippy_tests.yaml +++ b/.github/workflows/pippy_tests.yaml @@ -99,7 +99,7 @@ jobs: - name: Run forward-loss-backward integration test run: torchrun --nproc-per-node 4 test/test_bwd.py --schedule ${{ matrix.schedule }} - name: Run example - run: torchrun --nproc-per-node 3 example.py + run: torchrun --nproc-per-node 3 examples/basic/example.py # - name: Run null_coalesce_accumulate integration test # run: python test/local_test_null_coalesce_accumulate.py --replicate ${{ matrix.replicate }} --schedule ${{ matrix.schedule }} # - name: Run PP + DDP test diff --git a/.gitmodules b/.gitmodules index 3161f1f29..9778ca341 100644 --- a/.gitmodules +++ b/.gitmodules @@ -4,6 +4,3 @@ [submodule "docs/src/pytorch-sphinx-theme"] path = docs/src/pytorch-sphinx-theme url = https://github.com/pytorch/pytorch_sphinx_theme.git -[submodule "examples/selective2d/dim_solver"] - path = examples/selective2d/dim_solver - url = https://github.com/moonbucks/dim_solver.git diff --git a/README.md b/README.md index c6628952d..6560e8c41 100644 --- a/README.md +++ b/README.md @@ -2,11 +2,9 @@ [**Why PiPPy?**](#why-pippy) | [**Install guide**](#install) -| [**PiPPy Quickstart**](#pippy-quickstart) +| [**Examples**](#examples) +| [**PiPPy Explained**](#pippy-explained) | [**Future Work**](#future-work) -| [**References**](#references) -| [**License**](#license) -| [**Citing PiPPy**](#citing-pippy) # Why PiPPy? @@ -55,7 +53,11 @@ To expose PiPPy for development such that changes to this repo are reflected in python setup.py develop ``` -# PiPPy Quickstart +# Examples + +In this repo, we provide rich examples based on realistic models. In particular, we show how to apply PiPPy without any code change to the model. Please refer to the [HuggingFace examples directory](examples/huggingface/). Examples include: [BERT](examples/huggingface/pippy_bert.py), [GPT2](examples/huggingface/pippy_gpt2.py), [T5](examples/huggingface/pippy_t5.py), [LLaMA](examples/llama/), etc. + +# PiPPy Explained PiPPy consists of two parts: a _compiler_ and a _runtime_. The compiler takes your model code, splits it up, and transforms it into a `Pipe`, which is a wrapper that describes the model at each pipeline stage and their data-flow relationship. The runtime executes the `Pipe` in parallel, handling things like micro-batch splitting, scheduling, communication, and gradient propagation, etc. We will cover the APIs for these concepts in this section. @@ -188,302 +190,23 @@ else: stage() ``` -Note that since we split our model into three stages, we must run this script with three workers. For this example, we will use `torchrun` to run multiple processes within a single machine for demonstration purposes. We can collect up all of the code blocks above into a file named [example.py](example.py) and then run it with `torchrun` like so: +Note that since we split our model into three stages, we must run this script with three workers. For this example, we will use `torchrun` to run multiple processes within a single machine for demonstration purposes. We can collect up all of the code blocks above into a file named [example.py](examples/basic/example.py) and then run it with `torchrun` like so: ``` torchrun --nproc_per_node=3 example.py ``` -## Note: the following sections need to be updated. ## - -## Forward vs. Forward-loss-backward - -The above example demonstrated only pipelining the `forward()` computation, for example for the purposes of model evaluation. We can extend the example to include the loss and back-propagation computation for the purposes of model training. Let us replace the code under the `if local_rank == 0:` block in the example: - -```python -if local_rank == 0: - from pippy.PipelineDriver import PipelineDriverFillDrain - from pippy.microbatch import TensorChunkSpec - - # LossWrapper is a convenient base class you can use to compose your model - # with the desired loss function for the purpose of pipeline parallel training. - # Since the loss is executed as part of the pipeline, it cannot reside in the - # training loop, so you must embed it like this - from pippy.IR import LossWrapper - class ModelLossWrapper(LossWrapper): - def forward(self, x, target): - return self.loss_fn(self.module(x), target) - - # TODO: mean reduction - loss_wrapper = ModelLossWrapper(module=mn, loss_fn=torch.nn.MSELoss(reduction='sum')) - - # Instantiate the `Pipe` similarly to before, but with two differences: - # 1) We pass in the `loss_wrapper` module to include the loss in the - # computation - # 2) We specify `output_loss_value_spec`. This is a data structure - # that should mimic the structure of the output of LossWrapper - # and has a True value in the position where the loss value will - # be. Since LossWrapper returns just the loss, we just pass True - pipe = Pipe.from_tracing(loss_wrapper, output_loss_value_spec=True) - - # We now have two args: `x` and `target`, so specify batch dimension - # for both. - args_chunk_spec = (TensorChunkSpec(0), TensorChunkSpec(0)) - kwargs_chunk_spec = {} - # The output is now a `loss` value, which is a scalar tensor. - # PiPPy's default is to concatenate outputs, but that will not - # work with a scalar tensor. So we use a LossReducer instead - # to merge together the partial loss values. - from pippy.microbatch import LossReducer - output_chunk_spec = LossReducer(0.0, lambda a, b: a + b) - - # Instantiate the driver as usual. - driver = PipelineDriverFillDrain( - pipe, args_chunk_spec=args_chunk_spec, kwargs_chunk_spec=kwargs_chunk_spec, - output_chunk_spec=output_chunk_spec, world_size=world_size) -``` - -The comments describe the new components that have been added to enable training. We can print out the new `pipe` to see the loss and backward stages: - -```python - print(pipe) - - """ - def forward(self, x, target): - submod_0 = self.submod_0(x) - submod_1 = self.submod_1(submod_0) - submod_2 = self.submod_2(submod_1, target) - stage_backward = pippy_IR_stage_backward(stage_output = (submod_2,), output_grads = (None,), input_values = [submod_1, target], outputs_with_grads_idxs = [0], stage_info = 'stage_backward for stage %submod_2 : [#users=2] = call_module[target=submod_2](args = (%submod_1, %target), kwargs = {})'); target = None - getitem = stage_backward[0] - getitem_1 = stage_backward[1]; stage_backward = None - getitem_2 = getitem[0] - getitem_3 = getitem[1]; getitem = None - stage_backward_1 = pippy_IR_stage_backward(stage_output = (submod_1,), output_grads = (getitem_2,), input_values = [submod_0], outputs_with_grads_idxs = [0], stage_info = 'stage_backward_1 for stage %submod_1 : [#users=3] = call_module[target=submod_1](args = (%submod_0,), kwargs = {})'); submod_1 = getitem_2 = None - getitem_4 = stage_backward_1[0] - getitem_5 = stage_backward_1[1]; stage_backward_1 = None - getitem_6 = getitem_4[0]; getitem_4 = None - stage_backward_2 = pippy_IR_stage_backward(stage_output = (submod_0,), output_grads = (getitem_6,), input_values = [x], outputs_with_grads_idxs = [0], stage_info = 'stage_backward_2 for stage %submod_0 : [#users=3] = call_module[target=submod_0](args = (%x,), kwargs = {})'); submod_0 = getitem_6 = x = None - getitem_7 = stage_backward_2[0] - getitem_8 = stage_backward_2[1]; stage_backward_2 = None - getitem_9 = getitem_7[0]; getitem_7 = None - sync_barrier = pippy_IR_sync_barrier(submod_2, [getitem_1, getitem_5, getitem_8]); submod_2 = getitem_1 = getitem_5 = getitem_8 = None - return sync_barrier - """ -``` - -As before, we can now call the `driver` object to execute the pipeline; However this time, the forward, loss, and backward passes will all be executed: - -```python - x = torch.randn(512, 512) - target = torch.randn(512, 256) - - # note the additional `target` argument, as the module we're running is - # ModelLossWrapper - driver.chunks = 64 - output = driver(x, target) - - # NOTE: Backpropagation is run implicitly by `driver.forward()` when supplied with - # a Pipe with loss computation. You should not run `output.backward()`; PiPPy's - # runtime has already done that. This divergence from the PyTorch API exists - # because of the distributed nature of pipeline parallelism. - - reference_output = loss_wrapper(x, target) - - # Compare numerics of pipeline and original model - torch.testing.assert_close(output, reference_output) - - print(' Pipeline parallel model ran successfully! '.center(80, '*')) -``` - - - -The above code has computed the gradients for the parameters of the model, but has not applied updates to the parameters. We use an `Optimizer` to do this by using the `instantiate_optimizer()` method on the pipeline driver: - -```python - # Instantiate remote Adam optimizers. `instantiate_optimizer` takes the - # optimizer class as the first argument, then additional arguments to that - # optimizer. Note that the `parameters` argument is omitted; PiPPy will - # populate that value for each pipeline stage for you. - optimizer = driver.instantiate_optimizer(torch.optim.Adam) - # Also instantiate a learning rate scheduler. Note that the `optimizer` argument is - # omitted; PiPPy will populate that argument for each pipeline stage - lr_scheduler = driver.instantiate_lr_scheduler(torch.optim.lr_scheduler.LinearLR, total_iters=100) - - N_TRAINING_STEPS = 100 - - x = torch.randn(512, 512) - target = torch.randn(512, 10) - driver.chunks = 64 - for i in range(N_TRAINING_STEPS): - optimizer.zero_grad() - pipe_loss = driver(x, target) - optimizer.step() - lr_scheduler.step() - - log_info = f' Training step {i}, loss: {pipe_loss}, LR: {lr_scheduler.get_last_lr()} ' - print(log_info.center(80, '*')) - - - print(' Pipeline parallel model ran successfully! '.center(80, '*')) -``` - -Launching this file [example_train.py](example_train.py) with torchrun similarly as before: - -``` -torchrun --nproc_per_node=3 example_train.py -``` - -We see the model train, memorizing the 512 examples in our input batch: - -``` -***** Training step 0, loss: 5197.943359375, LR: [0.00033999999999999997] ****** -***** Training step 1, loss: 5104.2080078125, LR: [0.0003466666666666666] ****** -**** Training step 2, loss: 5025.17236328125, LR: [0.00035333333333333327] ***** -****** Training step 3, loss: 4940.39453125, LR: [0.0003599999999999999] ******* -***** Training step 4, loss: 4845.97998046875, LR: [0.0003666666666666666] ***** -**** Training step 5, loss: 4742.01220703125, LR: [0.00037333333333333326] ***** -<...> -**** Training step 94, loss: 16.55620765686035, LR: [0.0009666666666666657] **** -*** Training step 95, loss: 12.990996360778809, LR: [0.0009733333333333323] **** -**** Training step 96, loss: 8.712753295898438, LR: [0.000979999999999999] ***** -*** Training step 97, loss: 3.0083038806915283, LR: [0.0009866666666666659] **** -*** Training step 98, loss: 6.2024617195129395, LR: [0.0009933333333333326] **** -*** Training step 99, loss: 12.104667663574219, LR: [0.0009999999999999994] **** -****************** Pipeline parallel model ran successfully! ******************* -``` - -## PiPPy on CUDA - -When using PiPPy on CUDA devices, the model can be on a CUDA device before being passed to PiPPy, for example: - -```python -model = MyNetwork() -# `dev_id` is the GPU index -model.to(f'cuda:{dev_id}') -pipe = Pipe.from_tracing(model) -``` - -Note: in cases where a model's parameters do not fit into the memory of a single GPU, PiPPy also supports deferred -distributed initialization which only materializes a pipeline stage on its corresponding GPU worker. For details, -please see PiPPy's `Pipe.defer_stage_init` API. - -In addition, some backend options need to be passed to RPC initialization. RPC by default uses the TensorPipe backend -that supports point-to-point communication in an asynchronous manner. Configurations for TensorPipe can be specified -with a `TensorPipeRpcBackendOptions` object. Here is an example: - -```python -# Create a backend option with 256 threads in the thread-pool used by -# TensorPipeAgent to execute requests -options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=256) - -# Number of GPUs per node -# (Assuming each node has the same number of GPUs) -n_devs = torch.cuda.device_count() -# Local GPU index of this worker within the node -dev_id = rank % n_devs -# Set device mappings from this worker to other RPC callees -for i in range(world_size): - options.set_device_map(f"worker{i}", {dev_id: i % n_devs}) - -# Initialize RPC -rpc.init_rpc(f'worker{rank}', rank=rank, world_size=world_size, - rpc_backend_options=options) -``` - -The `set_device_map` call takes two arguments: the first one is the callee worker's name, the second one is a dictionary -that maps from this worker's device to the callee worker's device. For more details, please refer to the documentation -of `TensorPipeRpcBackendOptions`. - -## PiPPy + Data Parallelism - -![pippy-ddp](pippy-ddp.svg) - -PiPPy supports combination with Distributed Data Parallel (DDP) via the `init_data_parallel` API. Users can create -multiple pipelines each targeting a distinct subset of ranks. For the same stages of the different pipelines, data -parallel training is possible as these stages are replicas of the same model chunk. The created pipelines can -collectively call the `init_data_parallel` API. PiPPy will then create a DDP group for each stage, across the pipelines. -In the backward pass of the training, gradients will be exchanged among the same stages of the different pipelines via -the DDP groups. - -Here is an example of PiPPy + Distributed Data Parallel: - -```python -# Number of ranks in each pipeline -pp_group_size = 3 -# Number of pipelines that coordinate in DDP fashion -dp_group_size = 4 -# The total number of ranks -world_size = pp_group_size * dp_group_size - -# In this example: -# DP ranks are contiguous rows of size `dp_group_size` -# PP ranks are non-contiguous columns of size `pp_group_size` -# PP ranks -# | -# v -# DP ranks -> 0 1 2 3 -# 4 5 6 7 -# 8 9 10 11 - -# The driver of each pipeline creates and runs the pipeline -def run_driver(pp_ranks): - # Code to create the pipe object - # ... - - # Create a PipelineDriver using the pipeline group size and the pipeline - # ranks given to this driver, e.g. [0, 4, 8] for driver 0 - pipe_driver = PipelineDriverFillDrain(pipe, chunks, args_chunk_spec, - kwargs_chunk_spec, output_chunk_spec, - pp_group_size, pp_ranks) - - # Create DDP groups for same pipeline stages, across pipelines - # `dp_group_size` specifies the number of pipelines expected to collectively - # make this call - pipe_driver.init_data_parallel(dp_group_size) - - # Run training combining PiPPy and DDP - out = pipe_driver(input, target) - - -# Initialize the default distributed process group (involving all ranks) -# This is needed for DDP collectives -torch.distributed.init_process_group(backend=backend, rank=rank, - world_size=world_size) - -# Initialize RPC (involving all ranks) -# This is needed by PiPPy -rpc.init_rpc(f'worker{rank}', rank=rank, world_size=world_size) - -# Assuming each driver process is the first rank in its respective pipeline, -# then the driver processes are the first `dp_group_size` ranks of the world -if rank < dp_group_size: - # The list of ranks belonging to the pipeline of this driver - pp_ranks = [i * dp_group_size + rank for i in range(pp_group_size)] - run_driver(pp_ranks) - -rpc.shutdown() -``` - ## Advanced: Pipeline Schedules Pipeline parallel training of deep neural networks is _bidirectional_ since training requires running both forward- and back-propagation of the network. As a result, multiple items of work may be ready to run on a pipeline stage at a given time. The problem of selecting between these work items is known as _scheduling_, and a specific policy for selecting work-items is known as a _pipeline schedule_. PiPPy provides both off-the-shelf pipeline schedules as described in the research literature as well as a programmable interface for creating new schedules. The schedules include: -* Fill-Drain. Fill-drain is a schedule that executes all forward microbatches before executing any backward microbatches. This is the "standard" schedule used in GPipe (Huang, 2018). Fill-drain scheduling can be used in PiPPy via the `PipelineDriverFillDrain` driver class. A diagram illustrating the fill-drain schedule is below. - -GPipe Schedule -(Diagram from Huang, 2018) - -* 1F1B (one forward, one backward) is a schedule that provides good hardware utilization as well as limits the amount of memory needed on a stage. At steady-state, a pipeline stage will alternate between processing forward and backward micro-batches. 1F1B was introduced in its asynchronous form in (Harlap, 2018) and in its synchronous form in (Narayanan, 2021). 1F1B scheduling can be used in PiPPy via the `PipelineDriver1F1B` driver class. A diagram illustrating the 1F1B schedule is below. - -Synchronous 1F1B Schedule -(Diagram from Narayanan, 2021) +* Fill-Drain. Fill-drain is a schedule that executes all forward microbatches before executing any backward microbatches. This is the "standard" schedule used in GPipe (Huang, 2018). -* Interleaved 1F1B. Interleaved 1F1B is a variant of 1F1B that divides the program into smaller chunks and assigns multiple chunks per stage in a wrap-around fashion. Interleaving improves pipeline throughput with similar memory characteristics to 1F1B. Interleaved 1F1B was introduced by (Narayanan, 2021). Interleaved 1F1B can be using in PiPPy via the `PipelineDriverInterleaved1F1B` driver class. +* 1F1B (one forward, one backward) is a schedule that provides good hardware utilization as well as limits the amount of memory needed on a stage. At steady-state, a pipeline stage will alternate between processing forward and backward micro-batches. 1F1B was introduced in its asynchronous form in (Harlap, 2018) and in its synchronous form in (Narayanan, 2021). -Interleaved 1F1B Schedule -(Diagram from Narayanan, 2021) +* Interleaved 1F1B. Interleaved 1F1B is a variant of 1F1B that divides the program into smaller chunks and assigns multiple chunks per stage in a wrap-around fashion. Interleaving improves pipeline throughput with similar memory characteristics to 1F1B. Interleaved 1F1B was introduced by (Narayanan, 2021). # Future Work @@ -518,7 +241,7 @@ If you use PiPPy in your publication, please cite it by using the following BibT ```bibtex @Misc{pippy2022, - author = {James Reed, Pavel Belevich, Ke Wen}, + author = {James Reed, Pavel Belevich, Ke Wen, Howard Huang, Will Constable}, title = {PiPPy: Pipeline Parallelism for PyTorch}, howpublished = {\url{https://github.com/pytorch/PiPPy}}, year = {2022} diff --git a/example.py b/examples/basic/example.py similarity index 100% rename from example.py rename to examples/basic/example.py diff --git a/example_train.py b/examples/basic/example_train.py similarity index 100% rename from example_train.py rename to examples/basic/example_train.py diff --git a/examples/hf/hf_utils.py b/examples/huggingface/hf_utils.py similarity index 100% rename from examples/hf/hf_utils.py rename to examples/huggingface/hf_utils.py diff --git a/examples/hf/pippy_albert.py b/examples/huggingface/pippy_albert.py similarity index 100% rename from examples/hf/pippy_albert.py rename to examples/huggingface/pippy_albert.py diff --git a/examples/hf/pippy_bart.py b/examples/huggingface/pippy_bart.py similarity index 100% rename from examples/hf/pippy_bart.py rename to examples/huggingface/pippy_bart.py diff --git a/examples/hf/pippy_bert.py b/examples/huggingface/pippy_bert.py similarity index 100% rename from examples/hf/pippy_bert.py rename to examples/huggingface/pippy_bert.py diff --git a/examples/hf/pippy_blenderbot.py b/examples/huggingface/pippy_blenderbot.py similarity index 100% rename from examples/hf/pippy_blenderbot.py rename to examples/huggingface/pippy_blenderbot.py diff --git a/examples/hf/pippy_camemBert.py b/examples/huggingface/pippy_camemBert.py similarity index 100% rename from examples/hf/pippy_camemBert.py rename to examples/huggingface/pippy_camemBert.py diff --git a/examples/hf/pippy_convBert.py b/examples/huggingface/pippy_convBert.py similarity index 100% rename from examples/hf/pippy_convBert.py rename to examples/huggingface/pippy_convBert.py diff --git a/examples/hf/pippy_deberta.py b/examples/huggingface/pippy_deberta.py similarity index 100% rename from examples/hf/pippy_deberta.py rename to examples/huggingface/pippy_deberta.py diff --git a/examples/hf/pippy_debertaV2.py b/examples/huggingface/pippy_debertaV2.py similarity index 100% rename from examples/hf/pippy_debertaV2.py rename to examples/huggingface/pippy_debertaV2.py diff --git a/examples/hf/pippy_distilBert.py b/examples/huggingface/pippy_distilBert.py similarity index 100% rename from examples/hf/pippy_distilBert.py rename to examples/huggingface/pippy_distilBert.py diff --git a/examples/hf/pippy_electra.py b/examples/huggingface/pippy_electra.py similarity index 100% rename from examples/hf/pippy_electra.py rename to examples/huggingface/pippy_electra.py diff --git a/examples/hf/pippy_fnet.py b/examples/huggingface/pippy_fnet.py similarity index 100% rename from examples/hf/pippy_fnet.py rename to examples/huggingface/pippy_fnet.py diff --git a/examples/hf/pippy_gpt2.py b/examples/huggingface/pippy_gpt2.py similarity index 100% rename from examples/hf/pippy_gpt2.py rename to examples/huggingface/pippy_gpt2.py diff --git a/examples/hf/pippy_gptNeo.py b/examples/huggingface/pippy_gptNeo.py similarity index 100% rename from examples/hf/pippy_gptNeo.py rename to examples/huggingface/pippy_gptNeo.py diff --git a/examples/hf/pippy_layoutLM.py b/examples/huggingface/pippy_layoutLM.py similarity index 100% rename from examples/hf/pippy_layoutLM.py rename to examples/huggingface/pippy_layoutLM.py diff --git a/examples/hf/pippy_m2m100.py b/examples/huggingface/pippy_m2m100.py similarity index 100% rename from examples/hf/pippy_m2m100.py rename to examples/huggingface/pippy_m2m100.py diff --git a/examples/hf/pippy_mbart.py b/examples/huggingface/pippy_mbart.py similarity index 100% rename from examples/hf/pippy_mbart.py rename to examples/huggingface/pippy_mbart.py diff --git a/examples/hf/pippy_megatronBert.py b/examples/huggingface/pippy_megatronBert.py similarity index 100% rename from examples/hf/pippy_megatronBert.py rename to examples/huggingface/pippy_megatronBert.py diff --git a/examples/hf/pippy_mobileBert.py b/examples/huggingface/pippy_mobileBert.py similarity index 100% rename from examples/hf/pippy_mobileBert.py rename to examples/huggingface/pippy_mobileBert.py diff --git a/examples/hf/pippy_mt5.py b/examples/huggingface/pippy_mt5.py similarity index 100% rename from examples/hf/pippy_mt5.py rename to examples/huggingface/pippy_mt5.py diff --git a/examples/hf/pippy_opt.py b/examples/huggingface/pippy_opt.py similarity index 100% rename from examples/hf/pippy_opt.py rename to examples/huggingface/pippy_opt.py diff --git a/examples/hf/pippy_pegasus.py b/examples/huggingface/pippy_pegasus.py similarity index 100% rename from examples/hf/pippy_pegasus.py rename to examples/huggingface/pippy_pegasus.py diff --git a/examples/hf/pippy_plBart.py b/examples/huggingface/pippy_plBart.py similarity index 100% rename from examples/hf/pippy_plBart.py rename to examples/huggingface/pippy_plBart.py diff --git a/examples/hf/pippy_t5.py b/examples/huggingface/pippy_t5.py similarity index 100% rename from examples/hf/pippy_t5.py rename to examples/huggingface/pippy_t5.py diff --git a/examples/hf/pippy_trOCR.py b/examples/huggingface/pippy_trOCR.py similarity index 100% rename from examples/hf/pippy_trOCR.py rename to examples/huggingface/pippy_trOCR.py diff --git a/examples/hf/pippy_xlnet.py b/examples/huggingface/pippy_xlnet.py similarity index 100% rename from examples/hf/pippy_xlnet.py rename to examples/huggingface/pippy_xlnet.py diff --git a/examples/selective2d/2d_train.py b/examples/selective2d/2d_train.py deleted file mode 100644 index c20af5938..000000000 --- a/examples/selective2d/2d_train.py +++ /dev/null @@ -1,497 +0,0 @@ -""" -This training script updates NanoGPT to run with either TP, PP, or TP+PP (2D). -Usage: -gpurun4 torchrun --nproc-per-node 4 2d_train.py -""" - -import argparse -import os -import time - -import torch -import torch.distributed as dist - -from model import GPT, GPTConfig -from pippy.compile import compile_stage - -from pippy.IR import annotate_split_points, PipeSplitWrapper -from pippy.microbatch import sum_reducer, TensorChunkSpec - -from torch.distributed._tensor import DeviceMesh -from torch.distributed.tensor.parallel import ( - ColwiseParallel, - PairwiseParallel, - parallelize_module, - RowwiseParallel, -) - -from torch.profiler import profile, ProfilerActivity - - -def get_args(): - # default config values designed to train a gpt2 (124M) on OpenWebText - - def str_to_bool(v): - if isinstance(v, bool): - return v - if v.lower() in ("true", "t", "1"): - return True - elif v.lower() in ("false", "f", "0"): - return False - else: - raise ArgumentTypeError("Boolean expected.") - - # I/O - parser = argparse.ArgumentParser() - parser.add_argument("--out_dir", type=str, default="out") - parser.add_argument("--eval_interval", type=int, default=2000) - parser.add_argument("--log_interval", type=int, default=2) - parser.add_argument("--eval_iters", type=int, default=200) - parser.add_argument( - "--eval_only", type=str_to_bool, default=False - ) # if True, script exits right after the first eval - parser.add_argument( - "--always_save_checkpoint", type=str_to_bool, default=True - ) # if True, always save a checkpoint after each eval - parser.add_argument( - "--init_from", type=str, default="scratch" - ) # 'scratch', 'resume', 'gpt2*' - parser.add_argument("--train_iters", type=int, default=200000) - parser.add_argument("--seed", type=int, default=1337) - - # data - parser.add_argument( - "--dataset", type=str, default="shakespeare_char" - ) # "openwebtext" - parser.add_argument( - "--gradient_accumulation_steps", type=int, default=1 - ) # used to simulate larger batch sizes - parser.add_argument( - "--batch_size", type=int, default=12 - ) # if gradient_accumulation_steps > 1, this is the micro-batch size - parser.add_argument("--block_size", type=int, default=1024) - parser.add_argument("--vocab_size", type=int, default=50304) - - # model - parser.add_argument("--n_layer", type=int, default=12) - parser.add_argument("--n_head", type=int, default=12) - parser.add_argument("--n_embd", type=int, default=768) - parser.add_argument( - "--dropout", type=float, default=0.0 - ) # for pretraining 0 is good, for finetuning try 0.1+ - parser.add_argument("--bias", type=str_to_bool, default=False) - - # adamw optimizer - parser.add_argument( - "--learning_rate", type=float, default=4e-4 - ) # max learning rate - parser.add_argument( - "--max_iters", type=int, default=600000 - ) # total number of training iterations - parser.add_argument("--weight_decay", type=float, default=1e-2) - parser.add_argument("--beta1", type=float, default=0.9) - parser.add_argument("--beta2", type=float, default=0.95) - parser.add_argument( - "--grad_clip", type=float, default=1.0 - ) # clip gradients at this value, or disable if == 0.0 - parser.add_argument( - "--decay_lr", type=str_to_bool, default=True - ) # whether to decay the learning rate - parser.add_argument("--warmup_iters", type=int, default=2000) - parser.add_argument("--lr_decay_iters", type=int, default=600000) - parser.add_argument( - "--min_lr", type=float, default=6e-5 - ) # minimum learning rate - - # distributed - parser.add_argument( - "--backend", type=str, default="nccl" - ) # 'nccl', 'gloo', etc. - parser.add_argument( - "--compile", type=str_to_bool, default=False - ) # use PyTorch 2.0 to compile the model to be faster - parser.add_argument("--rank", type=int, default=int(os.environ["RANK"])) - parser.add_argument( - "--local_rank", type=int, default=int(os.environ["LOCAL_RANK"]) - ) - parser.add_argument( - "--world_size", type=int, default=int(os.environ["WORLD_SIZE"]) - ) - parser.add_argument( - "--device", type=str, default=f"cuda:{os.environ['LOCAL_RANK']}" - ) - parser.add_argument( - "--master_process", - type=str_to_bool, - default=bool(os.environ["RANK"] == 0), - ) - parser.add_argument("--tp_size", type=int, default=2) - parser.add_argument("--pp_size", type=int, default=2) - parser.add_argument("--i_stage", type=int, default=1) - parser.add_argument("--n_chunks", type=int, default=2) - - parser.add_argument("--debug", dest="debug", action="store_true") - - args = parser.parse_args() - - return args - - -def rank_print(x): - _rank = os.getenv("RANK") - if _rank == "0": - print(x) - - -def get_rand(args): - x = torch.randint( - 0, - args.vocab_size, - (args.batch_size, args.block_size), - device=args.device, - ) - y = torch.randint( - 0, - args.vocab_size, - (args.batch_size, args.block_size), - device=args.device, - ) - return x, y - - -def tp_attention(model, name, mesh, tp_dim=0, q="q", k="k", v="v", o="c_proj"): - layer = model.get_submodule(name) - parallelize_module( - layer, - mesh, - { - q: ColwiseParallel(), - k: ColwiseParallel(), - v: ColwiseParallel(), - o: RowwiseParallel(), - }, - tp_mesh_dim=tp_dim, - ) - - return model - - -def tp_mlp(model, name, mesh, tp_dim=0, mlp="mlp"): - layer = model.get_submodule(name) - parallelize_module( - layer, mesh, {mlp: PairwiseParallel()}, tp_mesh_dim=tp_dim - ) - - return model - - -def tp(model, n_layer, mesh, offset=0, tp_dim=0): - for i in range(n_layer): - block = model.get_submodule(f"transformer.h.{i + offset}") - parallelize_module( - block, - mesh, - { - "attn.q": ColwiseParallel(), - "attn.k": ColwiseParallel(), - "attn.v": ColwiseParallel(), - "attn.c_proj": RowwiseParallel(), - "mlp": PairwiseParallel(), - }, - tp_mesh_dim=tp_dim, - ) - - return model - - -def pp(model, pp_device_mesh, args): - pp_chunks = args.world_size - pp_groups = pp_device_mesh.get_dim_groups()[0] - - output_chunk_spec = (TensorChunkSpec(0), sum_reducer) - stage = compile_stage( - model, - args.rank, - args.world_size, - pp_chunks, - pp_device_mesh, - pp_groups, - example_inputs=[X, Y], - output_chunk_spec=output_chunk_spec, - ) - - print(f"[Rank{_rank}] {stage.submod.print_readable()}") - return model, stage - - -def pp_and_tp(model, mesh, args): - """ - Apply TP and PP to all layers in a model - This function assumes the model is already cut manually - """ - pp_dim, tp_dim = 0, 1 - pp_rank, tp_rank = args.rank // args.tp_size, args.rank % args.tp_size - pp_groups = mesh.get_dim_groups()[pp_dim] - - # TP - tp(model, args.n_layer, mesh, 0, tp_dim) - - X, Y = get_rand(args) - - # PP - output_chunk_spec = (TensorChunkSpec(0), sum_reducer) - stage = compile_stage( - model, - pp_rank, - args.world_size, - args.n_chunks, - args.device, - pp_groups, - example_inputs=[X, Y], - output_chunk_spec=output_chunk_spec, - ) - - return model, stage - - -def even_cut(model, args, pp_size): - """ - Evenly cut a model into pp_size stages - """ - cut = {} - cutpoint = args.n_layer // pp_size - for i in range(args.n_layer): - name = f"transformer.h.{i}" - if i > 0 and i % cutpoint == 0: - cut[name] = PipeSplitWrapper.SplitPoint.BEGINNING # or END - - annotate_split_points(model, cut) - - -def after_ar_cut(model, args, pp_size): - """ - Cut a model right after AllReduce happens - """ - cut = {} - cutpoint = args.n_layer // pp_size - for i in range(args.n_layer): - name = f"transformer.h.{i}" - if i != args.n_layer - 1 and i % cutpoint == cutpoint - 1: - cut[f"{name}.mlp.dropout"] = PipeSplitWrapper.SplitPoint.BEGINNING - - annotate_split_points(model, cut) - - -def pp_and_tp_selective( - model, mesh, args, tp_attn_layers=None, tp_mlp_layers=None, cut_fn=even_cut -): - """ - Apply pipeline parallelism and tensor parallelism to a model. - """ - - pp_dim, tp_dim = 0, 1 - pp_rank, tp_rank = args.rank // args.tp_size, args.rank % args.tp_size - pp_groups = mesh.get_dim_groups()[pp_dim] - - # TP - # Apply TP to layers if layer_id is in tp_attn / tp_mlp - tp_attn_layers = ( - list(range(args.n_layer)) if tp_attn_layers is None else tp_attn_layers - ) - tp_mlp_layers = ( - list(range(args.n_layer)) if tp_mlp_layers is None else tp_mlp_layers - ) - for i in range(args.n_layer): - name = f"transformer.h.{i}" - att = tp_attention(model, f"{name}.attn", mesh, tp_dim) - mlp = tp_mlp(model, f"{name}", mesh, tp_dim) - - X, Y = get_rand(args) - - # PP - cut_fn(model, args, args.pp_size) - output_chunk_spec = (TensorChunkSpec(0), sum_reducer) - stage = compile_stage( - model, - pp_rank, - args.pp_size, - args.n_chunks, - args.device, - pp_groups, - example_inputs=[X, Y], - output_chunk_spec=output_chunk_spec, - ) - - return model, stage - - -def pp_tp_train(stage, mesh, args): - pp_dim, tp_dim = 0, 1 - pp_rank, tp_rank = args.rank // args.tp_size, args.rank % args.tp_size - pp_groups = mesh.get_dim_groups()[pp_dim] - - train_iters = 10 if args.debug else args.train_iters - optimizer = torch.optim.AdamW( - stage.submod.parameters(), lr=args.learning_rate - ) - local_iter_num = 0 - iter_time = 0.0 - with profile( - activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], - schedule=torch.profiler.schedule( - skip_first=5, wait=0, warmup=4, active=1, repeat=1 - ), - ) as prof: - while local_iter_num < train_iters: - optimizer.zero_grad() - t0 = time.perf_counter() - X, Y = get_rand(args) - if pp_rank == 0: - out = stage(X) - elif pp_rank == args.pp_size - 1: - out = stage(Y) - else: - out = stage() - optimizer.step() - t1 = time.perf_counter() - dt = t1 - t0 - local_iter_num += 1 - iter_time += dt - prof.step() - - prof.export_chrome_trace(f"trace_rank{args.rank}.json") - - return local_iter_num, iter_time - - -def pp_train(stage, args): - train_iters = 10 if args.debug else args.train_iters - optimizer = torch.optim.AdamW( - stage.submod.parameters(), lr=args.learning_rate - ) - local_iter_num = 0 - iter_time = 0.0 - while local_iter_num < train_iters: - optimizer.zero_grad() - t0 = time.perf_counter() - X, Y = get_rand(args) - if args.rank == 0: - out = stage(X) - elif args.rank == args.world_size - 1: - out = stage(Y) - else: - out = stage() - optimizer.step() - t1 = time.perf_counter() - dt = t1 - t0 - local_iter_num += 1 - iter_time += dt - - return local_iter_num, iter_time - - -def tp_train(): - local_iter_num = 0 - iter_time = 0.0 - optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate) - while local_iter_num < train_iters: - optimizer.zero_grad(set_to_none=True) - t0 = time.perf_counter() - X, Y = get_rand(args) - logits, loss = model(X, Y) - loss.backward() - optimizer.step() - torch.distributed.barrier() - t1 = time.perf_counter() - dt = t1 - t0 - lossf = loss.item() - rank_print( - f"iter {local_iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms" - ) - local_iter_num += 1 - iter_time += dt - - return local_iter_num, iter_time - - -if __name__ == "__main__": - _multi_gpu = int(os.environ.get("RANK", -1)) != -1 # verify distributed run - assert ( - _multi_gpu - ), "this config assumes distributed setup - multi-gpu not ready here." - - args = get_args() - - device_type = ( - "cuda" if "cuda" in args.device else "cpu" - ) # for later use in torch.autocast - torch.cuda.set_device(args.device) - - dist.init_process_group( - backend=args.backend, rank=args.rank, world_size=args.world_size - ) - - if args.master_process: - os.makedirs(args.out_dir, exist_ok=True) - - torch.manual_seed(args.seed) - torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul - torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn - torch.backends.cuda.enable_mem_efficient_sdp(enabled=False) - - # init these up here, can override if init_from='resume' (i.e. from a checkpoint) - iter_num = 0 - best_val_loss = 1e9 - - # model init - model_args = dict( - n_layer=args.n_layer, - n_head=args.n_head, - n_embd=args.n_embd, - block_size=args.block_size, - bias=args.bias, - vocab_size=None, - dropout=args.dropout, - ) # start with model_args from command line - - # init a new model from scratch - rank_print("Initializing a new model from scratch") - - oned_mesh = DeviceMesh(device_type, list(range(args.world_size))) - twod_mesh = DeviceMesh( - device_type=device_type, - mesh=torch.arange(0, args.world_size).view(-1, args.tp_size), - ) - - model_args["vocab_size"] = args.vocab_size - - gptconf = GPTConfig(**model_args) - model = GPT(twod_mesh, gptconf, args.device, args.pp_size) - model.to(args.device) - - _current_model_params = model.get_num_params() / 1e6 - - # model = tp(model, args.n_layer, oned_mesh) - # model, stage = pp(model, oned_mesh, args) - # model, stage = pp_and_tp(model, twod_mesh, args) - model, stage = pp_and_tp_selective( - model, twod_mesh, args, cut_fn=after_ar_cut - ) - - # iter_count, iter_time = pp_train(stage, args) - iter_count, iter_time = pp_tp_train(stage, twod_mesh, args) - - # display run stats - rank_print(f"\nTraining completed.\n") - - gpu_type = torch.cuda.get_device_name(0) - gpu_count = dist.get_world_size() - rank_print(f"\n----- Performance Stats --------\n") - rank_print(f"\nModel Size: {_current_model_params:.2f}M") - rank_print(f"Run completed with {gpu_count} gpus, of type {gpu_type}") - iter_avg = round(iter_time / iter_count, 4) - rank_print( - f"Avg iter speed (in seconds): {iter_avg}, with {iter_count} iterations averaged.\n" - ) - - dist.destroy_process_group() diff --git a/examples/selective2d/dim_solver b/examples/selective2d/dim_solver deleted file mode 160000 index 342e002e5..000000000 --- a/examples/selective2d/dim_solver +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 342e002e5e9c6d0e593ed0738e7dd7bb8da51a6e diff --git a/examples/selective2d/model.py b/examples/selective2d/model.py deleted file mode 100644 index e202c45b5..000000000 --- a/examples/selective2d/model.py +++ /dev/null @@ -1,540 +0,0 @@ -# Original code from https://github.com/karpathy/nanoGPT -""" -Full definition of a GPT Language Model, all of it in this single file. -References: -1) the official GPT-2 TensorFlow implementation released by OpenAI: -https://github.com/openai/gpt-2/blob/master/src/model.py -2) huggingface/transformers PyTorch implementation: -https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py -""" -import inspect - -import math -import os -from dataclasses import dataclass - -import torch -import torch.nn as nn -from torch.nn import functional as F - - -# @torch.jit.script # good to enable when not using torch.compile, disable when using (our default) -def new_gelu(x): - """ - Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). - Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415 - """ - return ( - 0.5 - * x - * ( - 1.0 - + torch.tanh( - math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)) - ) - ) - ) - - -class LayerNorm(nn.Module): - """LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False""" - - def __init__(self, mesh, ndim, bias): - super().__init__() - self.weight = nn.Parameter(torch.ones(ndim)) - self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None - self.mesh = mesh - - def forward(self, input): - return F.layer_norm( - input, self.weight.shape, self.weight, self.bias, 1e-5 - ) - - -class CausalSelfAttention(nn.Module): - def __init__(self, mesh, config): - super().__init__() - tp_size = mesh.mesh.size(1) - assert config.n_head % tp_size == 0 - assert config.n_embd % config.n_head == 0 - self.mesh = mesh - self.tp_size = tp_size - # key, query, value projections for all heads, but in a batch - self.c_attn = nn.Linear( - config.n_embd, 3 * config.n_embd, bias=config.bias - ) - self.q = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) - self.k = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) - self.v = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) - # output projection - self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) - # regularization - self.attn_dropout = nn.Dropout(config.dropout) - self.resid_dropout = nn.Dropout(config.dropout) - self.n_head = config.n_head - self.n_embd = config.n_embd - self.dropout = config.dropout - # flash attention make GPU go brrrrr but support is only in PyTorch nightly and still a bit scary - self.flash = ( - hasattr(torch.nn.functional, "scaled_dot_product_attention") - and self.dropout == 0.0 - ) - - if not self.flash: - print( - "WARNING: using slow attention. Flash Attention atm needs PyTorch nightly and dropout=0.0" - ) - # causal mask to ensure that attention is only applied to the left in the input sequence - self.block_size = config.block_size - self.register_buffer( - "bias", - torch.tril( - torch.ones(config.block_size, config.block_size) - ).view(1, 1, config.block_size, config.block_size), - ) - - def forward(self, x): - ( - B, - T, - C, - ) = ( - x.size() - ) # batch size, sequence length, embedding dimensionality (n_embd) - - def print0(msg): - if os.getenv("RANK") == "0": - print(msg) - - channel_head_size = C // self.n_head - - # calculate query, key, values for all heads in batch and move head forward to be the batch dim - q = ( - self.q(x) - .split(self.n_embd // self.tp_size, dim=2)[0] - .view(B, T, self.n_head // self.tp_size, C // self.n_head) - .transpose(1, 2) - ) # (B, nh, T, hs) - k = ( - self.k(x) - .split(self.n_embd // self.tp_size, dim=2)[0] - .view(B, T, self.n_head // self.tp_size, C // self.n_head) - .transpose(1, 2) - ) # (B, nh, T, hs) - v = ( - self.v(x) - .split(self.n_embd // self.tp_size, dim=2)[0] - .view(B, T, self.n_head // self.tp_size, C // self.n_head) - .transpose(1, 2) - ) # (B, nh, T, hs) - - # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) - if self.flash: - # efficient attention using Flash Attention CUDA kernels - y = torch.nn.functional.scaled_dot_product_attention( - q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True - ) - else: - # manual implementation of attention - from torch.distributed._tensor import ( - DeviceMesh, - distribute_tensor, - Replicate, - Shard, - ) - - mesh = DeviceMesh("cuda", list(range(2))) - att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) - att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf")) - att = F.softmax(att, dim=-1) - att = self.attn_dropout(att) - y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) - y = ( - y.transpose(1, 2).contiguous().view(B, T, C // self.tp_size) - ) # re-assemble all head outputs side by side - - # output projection - y = self.resid_dropout(self.c_proj(y)) - return y - - -class MLP(nn.Module): - def __init__(self, config): - super().__init__() - self.c_fc = nn.Linear( - config.n_embd, 4 * config.n_embd, bias=config.bias - ) - self.gelu = nn.GELU() - self.c_proj = nn.Linear( - 4 * config.n_embd, config.n_embd, bias=config.bias - ) - self.dropout = nn.Dropout(config.dropout) - - def forward(self, x): - x = self.c_fc(x) - x = self.gelu(x) - x = self.c_proj(x) - x = self.dropout(x) - return x - - -class Block(nn.Module): - def __init__(self, mesh, config): - super().__init__() - self.ln_1 = LayerNorm(mesh, config.n_embd, bias=config.bias) - self.attn = CausalSelfAttention(mesh, config) - self.ln_2 = LayerNorm(mesh, config.n_embd, bias=config.bias) - self.mlp = MLP(config) - self.mesh = mesh - - def forward(self, x): - x = x + self.attn(self.ln_1(x)) - x = x + self.mlp(self.ln_2(x)) - return x - - -@dataclass -class GPTConfig: - block_size: int = 1024 - vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency - n_layer: int = 12 - n_head: int = 12 - n_embd: int = 768 - dropout: float = 0.0 - bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster - - -class GPT(nn.Module): - def __init__(self, mesh, config, device, pp_size=2): - super().__init__() - assert config.vocab_size is not None - assert config.block_size is not None - self.config = config - self.mesh = mesh - self.pp_size = pp_size - self.device = device - - self.transformer = nn.ModuleDict( - dict( - wte=nn.Embedding(config.vocab_size, config.n_embd), - wpe=nn.Embedding(config.block_size, config.n_embd), - drop=nn.Dropout(config.dropout), - h=nn.ModuleList( - [Block(mesh, config) for _ in range(config.n_layer)] - ), - ln_f=LayerNorm(mesh, config.n_embd, bias=config.bias), - ) - ) - self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) - # with weight tying when using torch.compile() some warnings get generated: - # "UserWarning: functional_call was passed multiple values for tied weights. - # This behavior is deprecated and will be an error in future versions" - # not 100% sure what this is, so far seems to be harmless. TODO investigate - self.transformer.wte.weight = ( - self.lm_head.weight - ) # https://paperswithcode.com/method/weight-tying - - # init all weights - self.apply(self._init_weights) - # apply special scaled init to the residual projections, per GPT-2 paper - for pn, p in self.named_parameters(): - if pn.endswith("c_proj.weight"): - torch.nn.init.normal_( - p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer) - ) - - # report number of parameters - print("number of parameters: %.2fM" % (self.get_num_params() / 1e6,)) - - def get_num_params(self, non_embedding=True): - """ - Return the number of parameters in the model. - For non-embedding count (default), the position embeddings get subtracted. - The token embeddings would too, except due to the parameter sharing these - params are actually used as weights in the final layer, so we include them. - """ - n_params = sum(p.numel() for p in self.parameters()) - if non_embedding: - n_params -= self.transformer.wpe.weight.numel() - return n_params - - def _init_weights(self, module): - if isinstance(module, nn.Linear): - torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) - if module.bias is not None: - torch.nn.init.zeros_(module.bias) - elif isinstance(module, nn.Embedding): - torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) - - def forward(self, idx, targets=None): - # device = idx.device - # b, t = idx.size() - # assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" - - # WARNING: t needs to actual sequence length, shape should be (1,t) - pos = torch.arange( - 0, self.config.block_size, dtype=torch.long, device=self.device - ).unsqueeze(0) - - # forward the GPT model itself - tok_emb = self.transformer.wte( - idx - ) # token embeddings of shape (b, t, n_embd) - pos_emb = self.transformer.wpe( - pos - ) # position embeddings of shape (1, t, n_embd) - x = self.transformer.drop(tok_emb + pos_emb) - for block in self.transformer.h: - x = block(x) - x = self.transformer.ln_f(x) - - if targets is not None: - # if we are given some desired targets also calculate the loss - logits = self.lm_head(x) - loss = F.cross_entropy( - logits.view(-1, logits.size(-1)), - targets.view(-1), - ignore_index=-1, - ) - else: - # inference-time mini-optimization: only forward the lm_head on the very last position - logits = self.lm_head( - x[:, [-1], :] - ) # note: using list [-1] to preserve the time dim - loss = None - - return logits, loss - - def crop_block_size(self, block_size): - # model surgery to decrease the block size if necessary - # e.g. we may load the GPT2 pretrained model checkpoint (block size 1024) - # but want to use a smaller block size for some smaller, simpler model - assert block_size <= self.config.block_size - self.config.block_size = block_size - self.transformer.wpe.weight = nn.Parameter( - self.transformer.wpe.weight[:block_size] - ) - for block in self.transformer.h: - block.attn.bias = block.attn.bias[:, :, :block_size, :block_size] - - @classmethod - def from_pretrained(cls, model_type, override_args=None): - assert model_type in {"gpt2", "gpt2-medium", "gpt2-large", "gpt2-xl"} - override_args = override_args or {} # default to empty dict - # only dropout can be overridden see more notes below - assert all(k == "dropout" for k in override_args) - from transformers import GPT2LMHeadModel - - print("loading weights from pretrained gpt: %s" % model_type) - - # n_layer, n_head and n_embd are determined from model_type - config_args = { - "gpt2": dict(n_layer=12, n_head=12, n_embd=768), # 124M params - "gpt2-medium": dict( - n_layer=24, n_head=16, n_embd=1024 - ), # 350M params - "gpt2-large": dict( - n_layer=36, n_head=20, n_embd=1280 - ), # 774M params - "gpt2-xl": dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params - }[model_type] - print("forcing vocab_size=50257, block_size=1024, bias=True") - config_args[ - "vocab_size" - ] = 50257 # always 50257 for GPT model checkpoints - config_args[ - "block_size" - ] = 1024 # always 1024 for GPT model checkpoints - config_args["bias"] = True # always True for GPT model checkpoints - # we can override the dropout rate, if desired - if "dropout" in override_args: - print(f"overriding dropout rate to {override_args['dropout']}") - config_args["dropout"] = override_args["dropout"] - # create a from-scratch initialized minGPT model - config = GPTConfig(**config_args) - model = GPT(config) - sd = model.state_dict() - sd_keys = sd.keys() - sd_keys = [ - k for k in sd_keys if not k.endswith(".attn.bias") - ] # discard this mask / buffer, not a param - - # init a huggingface/transformers model - model_hf = GPT2LMHeadModel.from_pretrained(model_type) - sd_hf = model_hf.state_dict() - - # copy while ensuring all of the parameters are aligned and match in names and shapes - sd_keys_hf = sd_hf.keys() - sd_keys_hf = [ - k for k in sd_keys_hf if not k.endswith(".attn.masked_bias") - ] # ignore these, just a buffer - sd_keys_hf = [ - k for k in sd_keys_hf if not k.endswith(".attn.bias") - ] # same, just the mask (buffer) - transposed = [ - "attn.c_attn.weight", - "attn.c_proj.weight", - "mlp.c_fc.weight", - "mlp.c_proj.weight", - ] - # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear - # this means that we have to transpose these weights when we import them - assert len(sd_keys_hf) == len( - sd_keys - ), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}" - for k in sd_keys_hf: - if any(k.endswith(w) for w in transposed): - # special treatment for the Conv1D weights we need to transpose - assert sd_hf[k].shape[::-1] == sd[k].shape - with torch.no_grad(): - sd[k].copy_(sd_hf[k].t()) - else: - # vanilla copy over the other parameters - assert sd_hf[k].shape == sd[k].shape - with torch.no_grad(): - sd[k].copy_(sd_hf[k]) - - return model - - def configure_optimizers( - self, weight_decay, learning_rate, betas, device_type - ): - """ - This long function is unfortunately doing something very simple and is being very defensive: - We are separating out all parameters of the model into two buckets: those that will experience - weight decay for regularization and those that won't (biases, and layernorm/embedding weights). - We are then returning the PyTorch optimizer object. - """ - - # separate out all parameters to those that will and won't experience regularizing weight decay - decay = set() - no_decay = set() - whitelist_weight_modules = (torch.nn.Linear,) - blacklist_weight_modules = ( - torch.nn.LayerNorm, - LayerNorm, - torch.nn.Embedding, - ) - for mn, m in self.named_modules(): - for pn, p in m.named_parameters(): - fpn = "%s.%s" % (mn, pn) if mn else pn # full param name - # random note: because named_modules and named_parameters are recursive - # we will see the same tensors p many many times. but doing it this way - # allows us to know which parent module any tensor p belongs to... - if pn.endswith("bias"): - # all biases will not be decayed - no_decay.add(fpn) - elif pn.endswith("weight") and isinstance( - m, whitelist_weight_modules - ): - # weights of whitelist modules will be weight decayed - decay.add(fpn) - elif pn.endswith("weight") and isinstance( - m, blacklist_weight_modules - ): - # weights of blacklist modules will NOT be weight decayed - no_decay.add(fpn) - - # subtle: 'transformer.wte.weight' and 'lm_head.weight' are tied, so they - # will appear in the no_decay and decay sets respectively after the above. - # In addition, because named_parameters() doesn't return duplicates, it - # will only return the first occurence, key'd by 'transformer.wte.weight', below. - # so let's manually remove 'lm_head.weight' from decay set. This will include - # this tensor into optimization via transformer.wte.weight only, and not decayed. - decay.remove("lm_head.weight") - - # validate that we considered every parameter - param_dict = {pn: p for pn, p in self.named_parameters()} - inter_params = decay & no_decay - union_params = decay | no_decay - assert ( - len(inter_params) == 0 - ), "parameters %s made it into both decay/no_decay sets!" % ( - str(inter_params), - ) - assert len(param_dict.keys() - union_params) == 0, ( - "parameters %s were not separated into either decay/no_decay set!" - % (str(param_dict.keys() - union_params),) - ) - - # create the pytorch optimizer object - optim_groups = [ - { - "params": [param_dict[pn] for pn in sorted(list(decay))], - "weight_decay": weight_decay, - }, - { - "params": [param_dict[pn] for pn in sorted(list(no_decay))], - "weight_decay": 0.0, - }, - ] - # new PyTorch nightly has a new 'fused' option for AdamW that is much faster - use_fused = (device_type == "cuda") and ( - "fused" in inspect.signature(torch.optim.AdamW).parameters - ) - use_fused = False # YEONJU - print(f"using fused AdamW: {use_fused}") - extra_args = dict(fused=True) if use_fused else dict() - optimizer = torch.optim.AdamW( - optim_groups, lr=learning_rate, betas=betas, **extra_args - ) - - return optimizer - - def estimate_mfu(self, num_params, fwdbwd_per_iter, dt): - """estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS""" - # first estimate the number of flops we do per iteration. - # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311 - # N = self.get_num_params() - N = num_params - cfg = self.config - tp_size = 2 - actual_head = cfg.n_head // tp_size - # L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size - L, H, Q, T = ( - cfg.n_layer, - actual_head, - cfg.n_embd // actual_head, - cfg.block_size, - ) - flops_per_token = 6 * N + 12 * L * H * Q * T - flops_per_fwdbwd = flops_per_token * T - flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter - # express our flops throughput as ratio of A100 bfloat16 peak flops - flops_achieved = flops_per_iter * (1.0 / dt) # per second - # flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS - # mfu = flops_achieved / flops_promised - flops_promised = 125e12 # A10 TFlops .... 312e12 A100 GPU bfloat16 peak flops is 312 TFLOPS - mfu = (flops_achieved / flops_promised) / tp_size - return mfu - - @torch.no_grad() - def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None): - """ - Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete - the sequence max_new_tokens times, feeding the predictions back into the model each time. - Most likely you'll want to make sure to be in model.eval() mode of operation for this. - """ - for _ in range(max_new_tokens): - # if the sequence context is growing too long we must crop it at block_size - idx_cond = ( - idx - if idx.size(1) <= self.config.block_size - else idx[:, -self.config.block_size :] - ) - # forward the model to get the logits for the index in the sequence - logits, _ = self(idx_cond) - # pluck the logits at the final step and scale by desired temperature - logits = logits[:, -1, :] / temperature - # optionally crop the logits to only the top k options - if top_k is not None: - v, _ = torch.topk(logits, min(top_k, logits.size(-1))) - logits[logits < v[:, [-1]]] = -float("Inf") - # apply softmax to convert logits to (normalized) probabilities - probs = F.softmax(logits, dim=-1) - # sample from the distribution - idx_next = torch.multinomial(probs, num_samples=1) - # append sampled index to the running sequence and continue - idx = torch.cat((idx, idx_next), dim=1) - - return idx diff --git a/examples/selective2d/run.sh b/examples/selective2d/run.sh deleted file mode 100644 index 0f713af2f..000000000 --- a/examples/selective2d/run.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/bash -NUM_GPUS=8 - -# 1. run dimension solver -python3 ./dim_solver/main.py --num-gpu $NUM_GPUS --opt-ar1 - -# 2. process output file -# output file name: solver.out -# output data: pp_group_size tp_group_size size_microbatch i_stage n_chunks - -pp_group_size=$(cat ./solver.out | awk '{print $1}') -tp_group_size=$(cat ./solver.out | awk '{print $2}') -size_microbatch=$(cat ./solver.out | awk '{print $3}') -i_stage=$(cat ./solver.out | awk '{print $4}') -n_chunks=$(cat ./solver.out | awk '{print $5}') - -batch_size=$((size_microbatch * n_chunks)) - -# 3. run training with optimal configuration -torchrun --nproc-per-node=$NUM_GPUS 2d_train.py --batch_size $size_microbatch --n_chunks $n_chunks --i_stage $i_stage diff --git a/pippy-ddp.pptx b/pippy-ddp.pptx deleted file mode 100644 index 9904fef319ee0b2cba23a984df19d064b17c29a6..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 39551 zcmagFV{~QTw(cEO?4)8}YJ` zWZ-0OYeVN|ZFP~lX|c|L?spC436;OT%lR3NwwF|tIh+$X(}0w-CTT<^C7Lk*%jmj` zL?mW2Nr13NEyc;l>CX#MW_=6q!on)r59+C(2YNDbLh8Na5uLwIP8nU54#a~s!YNHj zg~W(wAMVamMVM?M;)5tFdKN(0q^ImdTqIc>rxOQb7SwfbLm1+$TT5yKO5JWH{kMC( zWjghLN_60;!E1QlVtaONs(x$T7t!i}q@W0pM+@ChiN5Gy$-ZUf9(;=<-Aa^sqX{45 zP(nQ+f@+S~zSG1Vbyg|CfW%JEA-jodcn>Zl-ZOQlUmT23O2&qLH^V73Km}E1T&5I*%T=d*?PZn2@;~TP=@X@rM|WBJKr9=Ko{m(fpsw5tP?bUA%ca;dRKmPClfsWiMKIA zR;=<{Eh#;G3c^7uxPG_>$eD4YXZBkM-`Q&C`P;9L=JZx&a~W=mXF>iO3HcE?fghfD zNI0RcM+UPHJn4X}Ge0q2$cRIIU<(XSU=|P zaDY?LZvyFQYWffbrRx0ipgHE~F|L1pM|?|q^kdGvP2g3k^)mtIm*B>yKV|x^)apaS z2;_*eriCLZ{;m@8={0S8r-B@EMGE7{2CG^N?8`r%y1HYC{aB3VYket)v*%m-eiPe7 z1;c*o0m}C|gRr3Tqvb%$S>APbJWYBdVDrhT=hlagHpU^XuusSdAP zBf3XC!%dxA2`0g-ywkDnT^W;gK_bR=nstKAg6!j^ARp%UDAhV^2t218(t7);XqL&)Bd3^&t-&|l@r{^+erM-??8JP- z50dJta_XwB8+-3T0rD{B=m5IjE=9j|v(*LbTGhavSWq0Wh^M=D z3iAtDED0g$riEX(_h=oXhIbIIa@{%0?xor)`5XI(>>eC!zej%CC(j4bj&Xm}GpB*# z)L9a=Y~EX)c_)d))=AN{08z!~Irg;?bzf2;XEGFJElN+_{j7?5*RH ze`-b=LahxQxivo-V5(8Ba4*S`yu7EmaV%Zgo}d>oA&{FfraT0DzI;AJvuBuLI$HTd zGO+MB!y0LkG*T;Kf%qyyA1Vl0O|wp!=7+(TE88bM<4Ie$W8RgV7h&uWnqe?|Q{~7Y zYt&GaF;@xUiVZX!tGoMCs-I+V!Tcwmf*)9@TWCK>J4CCE-K)Wk)pT{ieDz{|#Php& z9EOs=JRb@S(srI7=X;7}N|K1hfz{+42`Q&us$2cY^DUT49TyD)otidHOiEa6T-qce zw3jY#FK4!}@C7dNj%w-pA+%76YCU4 z3^8|IfB5%22{REaF;&7RF;(GrEO#5^S2}e1OEhFbVy!xMU^V-j`y&4>Ls+y2D>y;=M@2nXT{t3Qlm-S#1 zzLqXwY-JA8JQw}-%&b>QTDuM0vO!~VVnhc~jrfv^zW9rWWqv9Oz(rDPTlCW6@>n-LM4C^Gz3U+e#HQ~{kD#Xh>Cz5IGLIAvox{T-zUX)9?e`AL z^e?zUDwX7aa}SSVnk-*?1@9>7FUJW0?~n~uE35_YunAQ&&Ij*k=|A@lPIYqosJ`*Q z*FBVla{EBE{i5y=z=Zo=2$kmQH($8Vc(W&4n4%ed7{0#66AHR|8X{ZG*NELTT1Hj) z(?cWZT@@cVp8u8C8C?2}6dI7KR!)rKpesRFRbD`8H>xssG}6^M8Ltx}Pq1-aXQOy( z;?nl2jxiZ8MWMwex%HFISAv7pu{p0@+d_gWd@2VCJN--hbA3!i3+^-!Msz*#?X_1} zvAHgL;;wzET6xnGVld%O%5TLyeYe8T2~o6>dDkfp)ww@{)43X4Ra?+_L(_DC3070B z0btqEh|d)?1m3UWOQcvXO&Xa-SFg>A%P)?Rv@C6Pw&=tH89|IRNuT{SFVA(gQ-a?o z)gn|ojUVBCC6DI<#p2Wy=LxDX)(ti?O8Mi{5M~FLn9jfk#kD!CKV&JHl~3qXdL(&P z7SAEv5su@EAMvK_asm6!R3JuINH>+eIWAF_vN^n3pIhkjk?5bcC;kXbUv;53*FP7M zP0=3sK;3n1oB4*kUkcFK+x$Rkifdnq7yWx*j358>f8vUR@^`h_-U1w;9v^+>uRH#> zkdOGb4^+Z&+bgq=a&eAD8g8a(7vrSgBdwksk|XpkA+(!%T=kUCiMDH+&eCZ-ACG=m z<$rt30oZ{Reys9)x>%J(Wg|5B`sRN@*|E-KZGx0|%Fcy)n8rRIlXk9~y8xvgVqrgo ztnayO?ze9uyZ`!x|LG(~$$Z||;6OlRJwZUe|92il=ibCGo*^kZOE))mVX!0=I>xddlSRfGMd>J3}GraMj`G z^svmQL-zaN;h|Uhi{Y#J??Yf)+)piO)|tD92d~Dhl%}=f29@8HxoK2So#Y=NJU}1y zi9P$y*UE`Hx0LO7?ny7fsj&Fzeq{UTbvYPNM%Z>nR=eZKi5O%%l#MRp-Ptr)8v$ik#kew8yoqVM2_ zaF*P71;bzJ}&hv2iUO9ox1{q%qT`R+&ExaYyd&l{r2g!EnVW;=E{)D7$b} z9IQR=K~bYgA@NTONYJ$o9@S2E%hwI3F2^pHz(o!vWyeW|@*_Xk7UHAN3QD7bg1jsQ z6NuhA$=RvV122^-Ha_D)2lT&>oh-_oWhD{U%sA5tNV=tq8Tmq!NsqB8f>%F5J2pG{ z%bxMmO(ebVx={1rmW_VT`Gb9nWBm~}21p-SA^T$K0z-n^uFA3}@GbEaeMaz1Y&lUw zbsGcgvx&vlk@>rTKjQ=0!e+`Q`izc?j&x!d&iJ?#Rw65t4~HW#zpQn{^>^xh2cR8R z{lcPZei>XDY2wVw(K2#1HSL-N50v82WH&A%a=qK@a~`psFE>eZK(qQ4)+j_rzB4!ruH{s-f!k zv6i$-Gf8u`%{WD_OK}3|+BRwf*8t;rIXC%`7MUVuEe*1(cJ4rcyy}1!J;OWDNDz1C zM?s2f4lR@!(##cYgELfvid=3qb`z%rm?X=0aZmonWND(fgBXb)8F8OsGmvC@_PLOM z`+kb?YdP@)pU~Gb3@kDX0Q#v!*>=$l`-mDPD``PRyeA}rKt_g45+^OQbI>*LhBsL2 zE&q!wFfMjMV>`&W2@VG=pC}_dM*0Rh8RVEYBQT(&D0HRYB($pJ&07!T5O|Rr!o|@$ zh<5ik<9fcPhG*F4jUPM90;X!~xcc?u?_a@_W~}sVt{LOikn*8s7v;Yo%#F}6;T0tr z2b8Yjf9S7zP|yyFQfSszetdgiL2=qV)>L5cvuG5#%wY8Tp)E@O$R#7QsJg6k*4FmB6ASv|sU8z|p5VP-iC z0pISRH@6PQH%f8I+~B=NQZ(uD(1Z%c>lXQjM%yg@{kGqIE>{*c+y>{1<=N93UcmJ! z-Y8F=>bjdN+x=&+Yvk)TVVrLKq|vnRx_v)->A`0I zpofLjE64LIKr@R-I!3@Vfihlx2L#TLS|ft)Vc5j_17abBr%D%)jQ2qXY{A0h&f^OZ z)Yd2y3T&*ASR|yu3nht;%)p4_mW^HHa9hEE_%v3+P1~&h$(Nxh)bT*S`0sU{d(C4NSzDIKu#*(iUMa3+h+N zei*}4i-5+<3S!6NraJyTe&C5)(Ao8*HCI7ZuaUFt&n5z(URmDYrPQ5(u}yKwBG=3L~zkM-H7;>t~*%CQmV_P2AtfFIm7pn1DjQ zxC}P6)5mVV;}mq3*hPn393e@D77kXz50d(^I*XK#ejI?mkzAeOt(d>8LlgEpS0fjk4*~VZd zh7NgYi5?=un5zj9#9;zpzf$PK!!l8YW9SnsgW)@B&(2Vs0`BzweI{!uSeIdaIo{EI zcJP6(CuE8d91xSpfN)mqCPBDxT8x7cY;7zG#EcjK{Ms@A9ufmqjgCmjFpPGW zQnKX9o(-jvz25)Ks7t^BlRXg5$j(6cxdSg);0-5OAPhZxfW6EHrjEGk-9rXtp9YJf z;Lb&lalxEb#P|`Wo-*RRd(uiVfie(S()V{bZFEto`3z=y`3zui>RQ9`%P?Xaezb)! z`tgR-XcT*qiJ=C!r+}yuP`q!EeR_S~NunS(yJ+CJ`sDb2T;f`0ENRzAXa&n_7vt}k zfpu)QZ)Y(zyT17|dY|sw6iSd<808Fy6RKW!eLyV;g8gm@E|f|}zJ^{B9x!z5rH2D@ z%05B2BNaT}ObxYg&VK4|^H$FY%NO#0~^zri0wyoyJ=!}d;dk|WR?OVI0Wj_#E z_+CL$eQC{TX8(vpX7*p`vk+O{vMFRf44PrnYPk#)ys`lpYTOJwGB_qJ`hso3wa#aFB@yO zwbu5%qu+^3Z?%g|mv zEUFSG&FhoOYxfYZNy1N%O2t8}QyPqsMHK4=KcdIA%{11zwlxet=5?OceCXb}u$I+_ z6!9BMoH2{nV=at#D(Hhs`78FayQ)&qT|s|yhxV0LvUB81*ZiJc@1QbKE+>q(5vO9z z$XFS5aHHlcTw<~QxH2bue|mhH_ikM0;i`myrrb0tskW&p5>T0Vl`zmH)b$1W?~;N1 zUnG{%e}DN24g!Mm-x#^PgRPz8KNzuMPQiAa2pMGON#YF&;~Ij2HKvk186VGlj-aUQ z19W5vCY=#le7*F{m(X#JYLmG%{n3fYY1q{c2&CX4Lz;6pKEjFn4jG=|$yD+pANbbJ zqa8Q^#fW71v%aDi>9lw9@xC~^vLaD;N+2lhb&{-h=7L%w&G*}Ot`q+&X zr&uIR%auP6E%p|SCl}o)q(Ppk<)A(G0BEn=c~Hc63eApb@1RNgSyJI!X~L&?*%s?s zU~Q$y8(ht+tvWhdio@;tZQa>;i9h5N(S%gWK+A>pSH=vOQ2?;P?Il^J&O*Km)8;vE zP{1V+RvY3rf@GDDyVAq71G~SB?RAc#?nr^ta$%0DH}%N_SqXuGe;q;7?m9Xn0kXcq zU!>JWUvwUomG-#C8t6Qw9^n|U9p!*tOi3eZv*>#;+n_(V9p;(k`c|}y?AnfK!f~nJ zS_ENIA!EKi$Y0OISpAxBw>uU7>6ro4;%Mp{-$h9Of@AIsNw7TNpxch|A-uno`?ral zqcb$N)8Dzn{BLvTVs7mEf95VqQO3590lrgi%{R<}&>DqGUZF*L{hN8Be!FQYrB^W%g*eY>5MV^co9_1_bZ&Vx-m^4OX81(=SIn(2b4KaUB=Q;ShdPlD-$jhPQR zKo9Cq_xlG-w7#{w4vX^}kvnufZkZ?qIUqFlWYy+`YhkOw$;~7y><+{Pol~6++v6+> znr|((d!9{N6q0h|ww4OK5$)xmW=-sCYH*jWdM;Fell~TZ2ri^VC|(~pyJs+rq+G^@ z8pKo0+LfCc?*BEA?`UCer4M=PHlsM(2(Mb^#|OYvK|vOxy_oCZR5{3qLEc$kW)qKf zN(-c1J+ukd5?APDjG;}F6%O%S(yh(56a6WH*USpo*nk!*<^*CFOBSwXjTFdj;(EZ( z)0LmE9VO$|@riAtP&^4-IKj(Ihh$-8`J;m^S@GU zjmbz{HuP@HS3Y=0*LRlvSY(P7BQD@CHjO%95?*5ycfLB)&x8(KG%zxb2Sq?ZVXB}l zPtqJeI6~*2SQyfVQ^Hw_=u7O?^dn0&Fdh1pgrlb{ACK$P)5@d1LQh8L7)Hz>GW=$b zsy$A2>{>>Nz8KGQJ9UfJm^c77$r$Q}J*Ij~DxJ^g>Nh7{QT8Fk5sE$TvzGXz`uoE= ztX>Ax;}v*{8>>FIC{ibPbpl5xFb^ZmDix||hbZwqPM8RR%jnUds2ax-p}IS2?Uzwt~z;cPbgAmljUvY8>`1P*CtL*~pvtEM|JI==hRoyFt({O6^AJDbnr&--rmcx&^ai#+Plq3*|@Nj^TW)Ax_{*lGe_%i&1w z&qI2>&iDQ4FGYF;&21j{mp`MC*WI7z$HN$IsrmWb?{Nd_?8~>|j){Xc{NoVMVcnSn z20`MwdB7krGVxLM8E|aulXE!X_;1VoMfkRk;+U*Ehc7yTHv$943F>~!c`G1SjG~@V zTFXu{lgZ6}+%T7k_wEF)Xt_q~8@(Z^Q1{paJhi3p#8xvK#XAn@dk|7$JVr`Uw1%$G zj2CTUq8Tu(@l*JFe~@d_5908!XOmzq> z5twvE9m5T&jnL=09WX~nMDO%_#X@?6A>=~Sews2239kIIxd|3{!YRORUw=9wW{uZQ zC`D+E#2P1gv%(PuWNC0Dl8zZwhfxLr93Fo+B*x^UErirEt2AWWNtOft=>k&IAr zt(XAIN1{{~FMre4H+bknix_qDAX=0CdP8rf-1@gMjcYT^)HNF#zv6M|3) zq1_A!+x?Ulwx*zEZ61@EkO8*lb}0NAAa%F?)}ZkAOj!K*qOUT2y38d)X*=N1P z3c7D}cbdxT%`ErGm_cruJgH%P?^|IN@p!MnAy^p8mvjZJDhLImOboGWn~o$X%)uzew0-1}-=# zYL3r%OR#e_hBJHkWM`sGbL2{-gU!;MHsZB$UUhzsN5QR%5lb4pDzts3iBI558iqY| z;F=}amoKGTL}vq9T@L5#Gd@STu;s|zNZ~~5=(wel)w8=O)SXxG#0iD>(T>BWUSuFb z?tm)BJ&Y(0){~68tvgc{E%8+;TC;BF3Dye2rYM?OfR|rcea@aT6wQc}PL3)jzO8Bb z@*Jm~a+}L=)})gG;KG^dkMFm;h%b2MG}pKla|@ln&)3Se)tn?dzWn3zQ)Eo!|Nem% zBBo-LXrLYDD~0M$qS!Cz^_eUdL2JvXoo&w_Sh?EZ;vUhpe)rHg*^ z$5z#i^PEm_oVGks`z*s*_{f(vybFw$dnPN6UoT1NY*(v6`83X5FB14E{6bgzaJu^O zUN46^-u&d@b&lZELDe`Nt*&j-hDc{=@B!JaN$nGOja)PFA!@x+)$AFq*uI6I(&1CB zVgVJcTe=OM{$ap(qjk9pzRa>7oT*{ibhj zss3{2YOFhXzW475^U6nbY{g&8()9O0`+NLTc;RGVXl1PA@Zy_~yaIl9{UOpzJT4l>`-Q83I}EVqRxP2=S9!9-!h zx)r`a>)CTYnbiLnKYT9{av?mcm#r*-D6gTXc>FNTTpV26gL%kad1R z@6i9}l1|9YEb$2e0wRR^|J#U{HgLCfcKY8C{l6OjTHhzswQM$6(Ld;Cc%aYNND_!2 z7=9xDp@C%&pS8yzlERBc?xeCOcL4`qRc+6xC`e5@z~|M0YP@9I2qJE3+Qx>p%{CIl5Wa({+UsQo%FE zycEUpn}vs3?bDW8#aCa7s}{*bs9uw0zOh<-bSB@m^fmeJ{RmTYK9+~YrJ9)>%~5$x zSl6XfIi@{e-Bl_=A{M8s{5$A5)Oi(wN%IN0p$Zec=VWnNC;=1y7qyfI%>2y!D=j?o zJ@U`#R22he4Fqrsi;&uiF~JitA=f#$T#w1er!Fm*;ZEIK8_x|~<6PvTkbR(;1G;TI4cb&X|0GDHQ8h2Lm5jC7e4_V?Z#0A^R_YbXu&Ug(r;jeAwDlp{7hm zU9%W_^dBL?1c1N>aA_g<}-Y3RASz0Nzm^!w$?{ ztC(SwI~MuXp)yFAyNxr(QqpzOv-H3We&fqZ-f%^&uro{tgLyytY!8wNP z_Gy>~f&(QzRqoq*gp-A&^U}#POLwsP2TWl9Hr%{Idkn6OuZ8e)KJIe zT%^TwHW5wJkNBqgI?A3!=chJ1IkxK^HEHk*bM-+FrT#W4*OI7q<-s@g6U$+tU749j zM30KW{M_(nq{|3fHSwqTGnVxNAM(l z2gX!Zr$4QgedZb029#G#vBDc->@W!YkPt*?!XU|h0gY${q7FX_F3Zv#M>>Myjn!e8 zyEmNU>mewGL>9pgLvgJ+^=7mYf{nN zO2zfeb+{Lb6$KrR0D-y-EasH^TBLznmh>~F$bCFS<>$HaW`$N*Zuwm zf0YM((ebTnyOLRh5!%j;$LzB=wo8un7AJMyX>kD70{#SuzmQe@xM*_GzcE-lU6aeZ zUQ_(S6ucqLwpIn@p}WJsxPQR>4hzA7r& zIOg+3{9C~EZ^YC6BXdFGuXxHL{b%v~`44`WP&>C-XGQ&D&$X#6}wZ?ngW zW9`l01tFQORh3u(tTn1ZoLE}N&Qq-ZARDw!XuR41c_C3;*lqRR<|xqZ++u00HO{Mc zYiG!0T4z?*x~a3U#XM@a3=o_@A+wg!yicK>;5naMKIZ%SQwDr;b{BZ^B0-p$eROPN zSXOj#96yU>u{84?;#%r!{*H-|d_!eZsW^j_(_Tuu+yq{pt6~}RHpuw)VxjV#W`D5A zpb{?$T2%?=^Xa_ooWgpOagDRrLq6Zt_cCF_6`8jCrJuPWZl$l>C`)K#u#PS|g|)^` zXEYr<4xi+(5<*5Z=ep{fk;RgY@Au8Rc2tk;^~+?2GR0HuC*abv0s4IPrK(&7)Vggk zW#+YNo&pJJ{|c>ik5zS6ilz~lIeNdLy*lkus@w?1`8$lc~u=(VVcBQLgY_V0qVcZns@*nuz?A+E5!;WRMJtE0toB63O`rHqRv#p)|2HaqU5= zp2i1|qMhQ$&h)7cdQ2a4rJ*iDAjm8#-^Vro*l_q!PtUheZQ&;1R#X(+hCC*~2(5=E zK+;n6{E$8jRyD(i*Frn^Lg-)v<*VCnDKqa~>Y(e5dlNMG+hmsVMkTU6f&kRphvc!x z#t{~@V$rRpw(k(UeuIVChHmB6B;ge_YJ3gpBP^%pUQ`h!dn+sHz|>R@M!NN$y;jY?Tkd_jzeBM7TDfSs3(uV~tQ6HvJHSCIWNfcJ@Y!K!}GvTdE#Gq2h z$%m-G^(w&I3!#G1OPlEWuxz{9CfpGN}*MCJ|=K0>z!$E?rGz-P>-|K7K6c`*rs+&qr}W=aJ^)UQ}_uI+95m4#tLnc9NP#lp%r)4dVq+ zL)vGC*;js$TG`(j-tZsXZ<)VEl6LBt>kf|LnKPAI{0_d)05K=L3q9WxAJ`20To^$( zq}h?Xd8nFLFsoLeOz#|#7I-2cJjnY2mHqW^xy+MI>=(;75D>n<1rYyFHRZoCaZLZx zA@%i`byj3=`Wav31yfOC5s3xjE&k=E9&(L^Q78avCq|wY<=Uz>GEw)Uk8D?gkYrsa zQ09zzHCT@(A{{&Hx1)IU`=3>3C0s4KkTOe5avwt=q-|tX#aemr%d1LAnmPPW+wqPo z%`OJpXMJ9-ub1NDk5&0p%omyFc*p^>(0VU}s=|wD8cV=Z$7m>p&a{?W)0uFhRLgfY zRmwY8ExSq9mpkgWexrg;-(^v+Z;C>LQ6qo~Fmu@!T?jd`uZKb(v{sXZ6Ibw)1qGkHfldhN10Pwz2R zkgIb=7dD?tC2n-Jp$IDRcb9duhyJF0cY5f{>VqWn8SElM6()I{OTG^+%W!p}LR8|Q zPm_9g$8eK+@sqmNJ=juj9#uKY#_?$!N2XeGMQg5PfxR}ogJ=WFVxL#GF{#6RR35x& z*cyt6f+TnnN@(H0wBeo`VG`Tp6l^aGeRpboSIrZ4v5ZQX&57!r#4Jl9yGb-?W}2)M zRnqk(O|H#e0U%-zeKyVnpdTJey#5qE3<$SA?pFrs;nw@VMo!OZaJ|=z>t9B$sUB-o zdD(}@T(S&D#${iKXGQHrH{y5TZ!U*WZ>}4B7ofFO`^v*bxM-jD9@nSN&*7Di-yJb+ zccQJ`{?u`4&4X=$A5SLmnn(_mwiZ>x>o{_8l3lOhQK8)!s)2Qs4nrv7!J{b^!(GC@~)_{{TYWTuF4;xP-+g=;St< zoUD76T*KuiMlQ_~J;mbLgFQh)Llvz)|EiER%X_db&9!s>j-a*u{oL~MpXFzo+kied zZ8#&~a*a&&u#vQCC7!95EB$Ncty^=f(IoZe7Osc*qj>2#no3~)PMI+wTvZ=?uk)iF zxU|Y-OhDu+NuLT9MT-*>CKvA#Di>^Zofl09>_N~nPtQ2Ck_s2kzZ&cNWI?BxJ*T@W zL=+vLpt{+dVz$j>Quk+vlGc0vKAE>YFhs&E)YaR+Ln- zw8&g8Z#BXexf@Io@ip*nte#=hB%yvDnO_w(5ESAHpohv31>`adfFcMcwfKbago<<} z`lRvUNbUTU-08CN=pA}CMI|>i&7ql*mLvCer5tWoT4@@=M+Y#F^gm4<6CzbEGx*M* zgte{pD|J;$g1kw%JPNI?O^Jo1`haL`tFAicubKj)`!9 zO`6uPlZ0*&2fh>w+aeDbF;X^Nt8_M3Xji!;yA=&kiF%M$4OAi%C8TdQd&H7Pd}E3f z;fpbzW__V;LFt?09PWNbA%^TS9YMqa8X=0J-p}v#0%xp+1F``$g{xO-*yx3{Ib4rU zFMn>Y^?1GCkD)T?UVnZKvZ(3Al3xld?4UF%kw`NMCxW1ka$_1-fkTMWEL`{_6=hpO zk;l%F-VW|@<57R4q-_?vfcK>oH7nwUr^{goN zAe)$VvdweI{(9#A!u)sMPY+x?t@^L`v-*n!{*&f1|Eq#sebXj`9lcA>d>gcOneG`1 zTp`h7V%nmxA>3sSifCuF=Gz3$G^x2tX?5cpyLUn(0_mh2P@=igVHVda%G)cZ`No6)H=DI!A%T5pP{dN4_$JVT@ ze^1dWDc{2VteK)x#u84^WH;HGV$p;O#WM;gGq>v<_su z;ay!2X5n;Y!($G0R4fnSI`XO@3=K@@SsU@#_D($x`=%kwU*+ld90igRk97boh7EfU zVeqg=d0c3u!%^h7EwBloJ2-ylFRLJ9KY9m-<4?o45=kU^Ml@J*@XOR1xmW~)K|J;Y%YO~nTN?2ZZ! zCC5I_-z;~#JM!=A3I#NM=p-Y=j~|iMlC7>=-A0pdG7+`OpQJUjbs~}&yyGB&*z=-5%=@`yi~INH=LiCnE9|YTXE*hSY(e;=K@s*u!I_>3iQ|5{2&vJp*Lx zFQ67)sma8;J~}B2c-wwEAT4;s=590xb7th~vN{iDssSJWVC?(}2fL#1&lZLE526K& z6N3N^v4bbpS=&#WWe@MRG-d$fWlrX;_hj_VQGG6B7Rs3-XC?#b1wd&O@EK!>*Jr#0 z5%EReM;)k{n=r(=thMXZj}Ie#y+7XecW)E5ub6yN{LF5(;FiiB6^hr#?o^?om-vf9Lm{WSkw-TBS85VES9l_o z1P_+_r+R}h!QEYkryFgzw`;QLn&g6vCPpaQyJ)sqUf4J(Egd2S;WWD;TBUhDad=%F za%iIEtn?QC$NT*HSWq1w6lt6?Xix#owJ=rZt5$|tjw`9VBO6D;&!;6W;Oy4itF=RO(B!jI5&%VQG+(f z>U+7~xrJ)xl0{GGqIqRuIK&B*Y|%$x7w}cHLfOc}^m%Y53bE?QnXbA9cWrSOLcBdC zTghEcyHRsvGiup&e{eLV>y~(9&XZG`VKYrxbdmPThy54yF^x4-?&zh$IlIl8fyg)~ zPlne~hMUO<<=%{-@)O`r@Qg?n9s?ElHc-dFY4z30~deHc!i&4TATFOoGnO&drKLh$h*UP z2OLGlS$ilOrxM@~7_}T;mx+y$cCW+;EIJ@~_-G3)tKB-RL}%)+#a7zDP6-cp2U=AC zR~r6giC7`eH$=5-yUZ*KcpfH?Hsbjpsueb#5LvEicxJ*p6ebW_V6(_)h|IwY`}0n< zX{?j*N$tyNYFQ9VXkB_QG~_KQ|JuRuBR;-ll^-ewDLxNh)FenvvL~U$-ZRG-G*mEw znGk$b*5T|WsMn6FxwQMm0Ndx_6;(0Qn_+YWxv|>WMw;&+pom6~g588Qlo_-{1s=W1 zTF({FXvm%xwU~zS{WcFGlJA-bs$BuX3e6JeA3l0-18G!eW6hmXkai3%b!IEY_Y9CFb!E|u4I!{F>sHKt-v(dUfnJUuwQVpsj zx3Shg=jT>EnnfSPW2CCfTO%-sJr1ZRc8=P7oAJg+R6;ho7b8w~zq)BOq5q^%;(A() zks~Rv&A3x#9yCP1wudLIScmRZDyv|9eAM_-)|*wwR;tjp((LBzfJY}jwc|% zjp3er)H(VNm$h?6uJsI?{tj=UL8V2bj*sS7ZhvRrTP3a>27eshG4kusUb-}7CdT(q zn<51D(!5GXQdM%2LrdjcF^Q3lmlDEs=T%n9GVPmIkb8oD2w)g^?=CC)h5FSRWijKD zJ*;E?k;L6pBeU@P(aw4ho`&h8vB{&dQrEKlOU>GtM^)?ho2HgKmLqsO;1@SPcL5eh z$|cco4Nv9@d6Tu}a^s@4mTK`?z4eR#t3`i`-g)`*tA*v~n$Rbt0<{fUbz|}9xkr@% zJfJdXrVp^Qz)y1{h<+k(Y#pg zNhki`EU4YJzeN=_d(MX1#+$4E(u&J-KE58E1>@&PhZX0=v&Mb*7!j`{P@Nr zX6XT$K*epND*6x9A$<=$E<48zm$Ey>1(n}rU-q?I^cm^{3daB5J-ACw8-Emlgx)}Oas5gQSHw(!3GTkBk&!(O!T7pOk`g&F_JI*;?9%V2A_tGkrj)XKUp*C1(?IK?KcmuDvPt(wppbF)i%x>T zGgwu%?g-P`yMcC}>TwMEGFecs;(X3RBOhLEZTzQ=hw(`@#IeWH)poegrXDQXje##^ zMcD#O5&!yc405*n)vL*YR!korxz)WIlxTuQ4fW8%N&wfHp4j@cfaahXjqH-4=WUp6 z?Ihu?%Mn5b7tli4$R24o2Ybp5@zqTl|2Qy%dM(_-+_xAVg|8S(o&|h`3@Z+A*pHZg zOhH`f;2Rq!4rHXDUH@42VFeu)w|+w@>@32jXx4WfCJ7_eG-9;9aA}>xC-#`WnS6)= z=wc#Cdx!*Xfu4MgFmNm~c5&MxaCw&Le1TnQH;Nr56d~WO@7Yl0m*@)V%kj-xvGLM8 zkL&Fb>F+}#OdOea1z_$(I_?WVFU5SuIuV|cM((8zuRpKDNk7n0SmrTPns@d+NY$$5 z+NgN7>gwGzB+hwoNikDa^)}6tMZob4&ZwjXj~i;4ls>lyJeh8>6F1r_G~P41%mJNwPyowN(0oXSc2OOA_ReS6U%z6uZ2n_A8Tcf52+0%`2BpLAG=2+? z=NE(wtB5mbkOg6W(VLpOE*ha=tLGwJk`E+f8|Do>$V#>FE%MM&BWvcpNdGIT!_`T4 z+vB;I#=QIzg8s)w5QZn^F5m_wWr__XPlqsSNKDQqu?{i1Ig|P&{}4<+ePIEDU1ZV` z84!C$Q5%eTy!mxjCNoJbl})Ag@oGZ8ujhyTYts`tleBD;`%>U)0lwfKNFj=XlL4$9 z1~n1$4d??;loBdg=~XHf)=HaSFL~ZHGnQCjVzmjlWo@?_8Z2Y(t3O8I&Cq1`ZLgdj~r{G<+o%a1W zoNP&7dkx48>1-epLe(9NR)TxPc5toSZ?!pb zsyMwQ)8(e{S+M~|91AsGyU`z$1~f|xe#}>4QzTk@$u=q6z|5VTe5H`ek>qsRXnc(} zH2~TDK-Z7`I@TH*7wZw>l|2Fb%s$8&5}mCHOZ!nGqigQIN+R&8Yjx%d?hxcFp?C}_qatV z$;l+`l8E?V=l@rGUmXxfvhF)r(BSSCB*EPY5ZocSySoI}KyY^m?i$?PJ-EADaDwwD z**zz_oW1And++}B&MWAdo}RAXbbVc2Ri9@Dh((QUI_^PCnk$BOsHD?!)1^pkbj!FW z6?*_iP`erZKHxK!nc{s2&NYR8%5MdMHmfb9TLl?~!!&BHC)Y*_;2O{1DG(^rnG&io zC9jXni#v#y<*p%=42{cW7u8TD41x`^<+qKAAQf~*jAgsf!@lV*L2`MCJCM6!U{{H8 zZ%fPU2b%xFy_0m6`?k)zfitl>c08J=hNj8ll(n;fSTMOasy~`G>(e~(l4{0ON1D)*%JTG8{?*owzdSd z5~vVm1Sw&RVl1a6tdGx0!YpL9kheQnYH9R%TlZ9noVvw6A`;S+FCw!Y1^P;KjhOdz zBW1MRK%wZWHkCfnf-doOd6upWO4((fkt{pN-@{` zZ}k&f3p1V`cUPAy7zs2k^sH5T4a?N?8BEa4RJnAtYr2Hapl=80-Wsf;)M2eN$)6+Z z(m5C7(eX%gZk!|r+ZbQe!a8Ey;3~s1(m;P95aPq6W*Ww>FH=NnmA%%n=uR_l;xdkG z(|K!~E18A7eduzW`6W5L`uzAWSwtA0jfG;M5`qgj=l^%JI@`}jSoz2{SqSxEC9wr8 z>v1{Y0e~)j2$%MGge@@75g*Vwkhxamg<6j_x>-Hn+aG=LtwXh#p6~?18Av(yr4@Rr_dGa-cr6 zfqg&dkg2{DqZtHO;X@3r^tSs7u4JAD2Zq6Tb0LmwyJuwsbu_-zLXpWdS(;ZPRc;s4 zkxAm_ls@hXK0{0D@Y(y14hn3E-nyFXkPIA> zZtiFV0X;_d$Fl%7o45JyB*J=k6c&Dk!&ZsYm^H^!)k70C-}Y6WUY88Y5~>=>$PM*G z4JdfPXL)gy;*qRes_R2iz4hpmP!7SFOWkwLwO9_p0hmn-$9CHfeUMPc$TX8WxX?3t z&x|TI?U27ftxJ*1#d0pI)V!F5#Cl)gb4hbdMW0vNSzm}(iHmr!q$?kf%>OwfV3uVV zsa$2iyPkbt$>X!e;8?u%M`SVuj1j_ZKdTMdawqg5k)BY2RnkU-1hyL#a4wlLRw_)nU{JMD=ZlVV(C~rpaeI z*Qg4>E9LpIcV}^Fe52D{ip>kfBWBLp?5`6z0U#jkqQJQdZ{{;yV(PG>eazx8&ILTeQ zDdx1dbyG@0{QZtjviPPvNndMH$kDev4`jm66~$f65-!%mYAZ1IOH#3JxiWb=7B9BC zcx;CwNIJmG*C=fB1yfhlDj!oT*E4QcsuV9RAJRbzQF0=&nl_8f-JS41DRn~Ku${^F zVP2qayHHf6U@pYm!G~I2ko&@!&hF#{0Ns_}EYrFubXb^RY0N?Ti4>|qT#-+sMpF`* zvA2W1A1gYfGC`p{rFP&EG_OHdQ=gIROSY}l%wNpkI*+)auNa=?m+dl7z>9mUXQ?x2 zpGgRps%2T(f_HJXLu{gRP_=p0^qO19mbYdyjONlg(x}6!mV&kx#DcH&)H~LkHw!^1}!5La02X{I6lY`^)T%2qy&6MSAIi4q7{fbG)WncN#@15=f ze>sF7oZB~uaatWX!|&B#7zh(m@}9^R@;m!)=vahrs8EdeOGUNGCdLKvNh>sqi|^_s z2*L#KSWK2_V}-<~7Qh?|A{%hi%>pDeJTG}}{*r+c{oH=$1{~_4{>D(x{@YML_rqxL zcYU*AU0`O&toL-LS2SkzIY>~~Db$!^AR`tlRK$uYFS(XnRFvUqlthZnT0BusJy%2g z&WX(ApwdBMm15>Wa4t{v7_6~c=15Xgewp;q2PNp$nyJC6k%@|e?gDobm%FLCveqVZ z>01U~X{)51DpaARwCubq{(&ofmV-*a{P9P@i6GSnR;z~BvHKN`saqPivN}pzZ~}dX z3aXP=Y%QGnZj5Odi2+IGDuJr-gr6zo3(k%)1Q(=@kBE#d$Xy`xgpTt_tyC2()by3i z78rDv#FpM(B^GiuQrQ(ry>Zyc?AxqBneOGMDW*mY89s zi1Arr6WjFT7o!aVy@dK=rCi2F=6jd@T=D~Fr9H*rfNGvV9x42`Vg}R7P{9GRCnb&vcsP6?VdEXtbsd6 z+9k`abp}TO3e)SU4Vvze7C9{2^0;u>*IP6?0$oI6?jos*v5q@Ez%U<5silpElV_px zbwl~b`7`#X=%(s=kZU?E!FhZ_Tov8th&hlOdwnlcUtpc?q9Y(Fm7RfstkI+kVS#+c zdmC0?x`E_{Z2+#X_BMolLMI$BNkNj{2(~NNxkv(F9)Rkr^};@~mb?0~zk38==yfNh z#NP5iKy&bjbS%Jk926wb3$m}2Ke((r2-r;I&|QZ_dWR3j^%foqdB7Xe297`OTG>d* z?fN1t|Dl(qiF(|YCyC@SQ3F??7Z2kWY=LbJiv9JiC$sRzYhfngDkDQ)t1#oyrf}`B z-P!Fr1SfF2MMy3m_HSVL?YQ4TfQRqA5R$X}l|Yiuhr(N}_hXHgS?|(3*6hiwLgnh8 zYDCitYB@Z1-^Mqzb=trYW*|Y`_kkhVC4Azh6>D1EYi5rBkb?H|kk`0mJK?ma923Y_SejVu}*XKvO?hLBek*G8KhbNAdFa2tB z3@&5g#3QG1AO;wX{7>LEN>bi~_OA5Jk=A6JR7HH)+UAn4hoBH43=sghql3N6N8rNe zM~sSanUwCB&5)+8tv`Ix{7@~y$?Ool*-)$1}fh^M?_X0c@wQ2v(a~;=#8~bcn(f*KvMY~h>$woVts`Y zfdBz`gd7|RZJSa~yi}#7OO~Tgn;zSy_nCMS z0Nc4R;v+ku4OEU!+k?zq4z};E#!-IqCzCu2o}f|-m~%dic3GaH_ZItoIg`HdXZ{G_ zWVHoL#+!08W%CuTpKGj?gh*rO76t5zraEx>Yl#}J_F%8({+iKV85s;W0;;Y3fF_H? zzZZ6yegK#$u)3EcX zRi-#BGvAA6v{IjL2ZJrAmAWC8YH>2W!*bKc>RZq=ld*)-%2XupDr@v!-m6tN<%0U( z&r`6uBQzWgwB_z$GgITZtB*>ntduGmj6|`lOco>gB9FX&-S+6A;rNhU!YQcMtAb!U zxxYXb=nlEeuia&vXC@YArPDJ$+J5OKi>4}ImY^J$sf!IRatHfyFhy}b>gXf#LE|eU ziL0ce<0dtGCG`@I=H_)(6%mUp^d(Hs5M9$*bA{fzO5ngfK4;l&HplTt8Z9Y&RV(#Ig-8xRCaD zF?&rBlN65+gv;(WdIWHzs!;M7zBNa6jtfWTwGKqlW~@$;)&sO{7bs|AX0>K*pdh+` z7~4K}DKP<|2t4pw2<-yonm1ndQ4{Gg4ZP%4*g@QigwVyP?$)$CLsw-MHBz8O*VUqD z-rW>VYzGMl8kjVA=~9W=hPGp)uzk!lW?D?)cNKKyVwJcBZ6580LECK6;8A)(BfdJ56H^6Z)VD6Hu1g?_}> zebqtnd?A6k$n7KeV1f=rl~F#}EaXscuwJ)4QiYl+NoXfFPg0*Co(Jyx`o`e1jk`uF zZbS#x2arf67h0zKFTRkecabgSz7^=!Uj zqO;KPe1(RSNoVaw+qxLlBYuRNO9?brSs1WAVP5L9FH}54qMI1*PK7#V^YR?xxnegIf`n{Ya`z~ ztQ#|1>I*8NiwaQ#;YF^Kke%S9bA92}DFChONV~m{=(@=uK$TBs$&Xlb50}GpDibh_ zC`d*$E8#=+YECk7?^X6o-f~Q5~jm>QMx}gY|fQT%ceVOPW?p|fiSrEuEcyD{nUMSs5 zTz2I*-dZA0870?m<~1zQL`@w8gQnh-2eJp?zso+JIKR2){Yzm&o~EQ#GH{qOK>QoS z)bD;N7FIx$#z&r?I)X}|j(|E}J|T_`PGmt_>bpklu5$je&{%0@=G&*~E4y29$L#Dn zoAEBJBqA+|pa#z6rSrwevq7cChKd^gD}#k0^ZSl!>ak4=S(bRlyCn$WNQ=wYW$iIMGwUe5s@{3xj&vUn30}t&?|P3yU^}2 zUa=R|a6qIc71D?)L_W{0o1Ke`j8}8=J(3sUYbk~)>Ni!s275QiRYXN%#n)2|TP`nD zd`y##=rV2Y8M{=$8837!V^~IN!pgz%G zBH&_4CE0VgBp0HKsSV6)1{ti-<#3v0E@mXsqPhZH41N)(`h_gx9Ql*fW>&=4Svi0^qbO@gx z>7_&wOBt5<-uW(5L#jA53fBxgfnfwb&_g$P#^>~@v7x;#N1H-I z@9#vXozCG}PXGReZ`m?imB1nrU#DDsN$1!>PUw1gMEfU{96%HnpCBUjr%VfX*0i52`I23Gb(Y-o#ucq08OAD{`FCi< zkvb^_pmU&uOt^#t32bqlkYJlH&h4vT0|XSNed6PBD3>iSB==Z?i(3?f#q@_A@QUPD z+PjGKrOB<^yp~aS5{EomCA*&ZSv-kxtkjm^TXJyRbL~k234~!0M7B|8qrsLEj%88T z2|JL&DB;)h@)pGUy{$Yg8K4W2DZaO1ona08EyH~r~F3FrF#V4q4|$;)@`l3iof!7fv*F*Ce=EduEKwEH9jbv_ zkYt6sJI*2{cyA=CUps7`#ixBuy1qejwlx^W3y3*ywJS4of&#Jx>U^|iZ_d7rBGy-hC-GDALA-lunWZJ!217wln&Fn) z5=#YD-pzb`YeO=X;==UVsqeO7^l62v0_%aX`Y;Q%1?(9$h}8hrHG)w~>w-TzH;(H2 zw85>+BqjyCn#+bBh4&lR->#NSmzvH_qTD`1p1Wb@S3*pbmLA((r zEyjiQ=@yoR8-(>?&G8DSfCa{j&@M0jMlR^;CwbUNhbeuq4jQK1S(^(O+Jy2-ivyc%2 z8K2rnKs>%Khw~425E7IfkYS;T9o>CDAW+*MhJAWmJ@K8Ynw9MfUVk8_7%1>8CSnTT z83MURBpmO>5H5pnG`4o@G9Ab>WCQf4+kZntMEcf-py9v9|@>~01`vgW< zF$q5GbW8|XCu-xjmU%m#u+C_GS>-{}wO7*4yp~zODk}Zy%W&IYFRe@Yi6g7?2;*b{ApJ7wMY8U%I2q@S?fLNPoyL z6hNzni&+oxVhMACj_az3LFo#?mSoN~F&i)I#956HYgIo*XXdu3--p0P5cHlvE#MQ^ zO_e!DL~Le9#FB*;k-(KPm_r!Pi|7FxN0yJ6R;F!LN6s1Jq)Vhj?nN&#fEHUTUBJ4S zi@J(X?`KBodwF;|7|i-&v|^8$i9ql?hE}8rt&brVxd0E57rS{@?x0qmtUSHino>F~ z+N)bXEhdH8lpfWvPoK2PM8JzaIqJB$}|bB4ntZwp+r z;a6YJOK}+Z9)ulHkX;_KaKAx#y5!${S!y_ceHq{Hb~ok9=7%U=yU$q5tcrU~Qt5m$ zr4GU#f{c)^cDJp_vBw2GE7a1?i55Fb89GReC5@k&8=tNgMj-R3C{W;;NECet)hE|H z$(mNcU#KLBeGY71;fjBFc;)30QEE$6HJq@*as$-GC%9~0W&{V(w+|VAxS$jtOJz#q zJjo+_s0H_^i;a4r7X}}Al7lkBZPQ6d)FBERN38!|i`-@{uqf`@!_US0%U9W8lO%)7 zwYXOUBxEDqtKT7w>qHn|ElqUME@8ZVdDtR57wrb(b#lkh?a*s9Kqk>j^Q51Cndr*R z78Q-^57cq+%-Q&uZIF`GoF?8O?)SOF${XirR-u26vXlPi9V`?;F&?!8*jeAO4)r3?U#}P$5A$yL7loAHUZa`9K87H&Sp7VRV zy|Zis&SQ?&9X1#g1{%y*vFTQdWm^s!nmT|K*%|R-Sf2cOUQ49sC~kvC{Y}o*fhWjc znV;_*s+nN|zbOW8|HuYlYv5>N2h0TUV;cFvxramyns(BmzWeU?=v(iq63$_2P7K*- zlK$1;7KA?z#rqYOsHCdcQ?n)h$FGv{odzRAi`{3P@N=4sPFwnY?ww6}6sRL*C|Zx~QB^Gc_bkVWlJD&U zO#NTFH6lI*T0#y%t<8#UtR5mfo!+&uij_6F)XwLiknZy3n(!?q(rX&``g;aU=dWfR z%a=BBcb@DbaD*)5I@|kGvR@XKF>2{GP%tM@vNL}6DvPQOIdSx(NMj8s4PrM9&J{E` zzpt`P_OV@d6R5#Bk$F4)3M7QizYk8fCZma1P~ppcITgqZtXr~UY%6c+A0;9%K(^yY zAp(%3DRTCp0{}lZpgbF(Z>9G$G&ZabOCEPbL5$ zZ}amUDd78$q7B@B)cNflEFI+cvw_FbLH}_!2p}EoANzm^;MstuBLJI>v6a1*ow1cQ zAtMbPfQ?^N68r}`u>9PA^>Y_-5-(&f98dwi=LMZ`c+Kze)CNF;0z?3R2ZR^^iUb0N z1oG4cAOKzk1jx_wvl{RP1QZM$;sqoWGz={8gla?pCjMpgj)94V{fdm7g7P&JGYcylJBPqK zK_OugQ8D@V3W`d~Dyn+=28Kq)CZ=}w4vrt3oLziB`uTqf2n>pfj){$nPe@G4%*xKm z&C4$+tf;K2uBol7Z|Lak>h9_7>mL}On4FrPnVp+oTi@8++TPjS+dn-!zqq{metmQM zqg_AR`H%8j%bv9h3D_=BaBwhi$RF(j0sZi!;Yi>RMD#C^`D7t=Y*2_9yrEF}BQwg| zp-C9!PSA91$6(M&nbuyN{%G3Imi@hk`TQTX?4J$$)vg7A008u70RsgEg8%~qgLnY} zEH9v5{3uW`P(KUI?*;B>f&Wnu{#l-YARs^taBy%);6EZPG%Vu(c6nL?c1y&kc>o+3 zu!AFkApv*+5BHft6oCJW4eaUb{{Y6={{Y7LzhHns1N2isB+fG)Zy$A=j@PA%a-RSl zI7M&CRS0wo-ao8C4uOHwxSkHXC0+&VJ(i9Lk3P^Z>^(fRv}RnCY0(ZmG~w2loUz#` zuNaY**d50@c)a($V9!$;C4eX`B?(tKW7?CFfW!)PULe2oX!A_v9z9Y`WK%$ zipM(7qt?uQDmxByi`s&sxVUwPalpgcL6cq@?F>O#9#1n>cv4at&;FffJDX9$ycAb! z;P=l&L_*4|FeB&8Ay0tbCd|?ME04q{K&eMs*^Z|BDfgx03HNo}@(TFE&i5hDtt!up z(HUNU9^bAE=~SCfAN0lwXq4+6xp_xE-5uNn**4S-LFB>j3rO3&4^^n6M9=@u%0BLD9-Lv`FlUy+D8QJ-%_fFJy;gtIX19H>~WYxE7;NI zhO6>!w=Q>zeD!EmS2=x~wdi-WS8lv($&D2El3A4Yih5Q z(vwBU&Gk5ykr@~mIKp0Hg+b1W3MYL>$3QufcPzIIqA26qhoQ*y`s2p?H1X&Dl22o? z-c%DeZqq;+8RHJLc|eWX|N^fI-e4zrd}ADCsSY(Jt>aaMP^E#Ii^K>T@m+-qV)I zhs5l;0|dN6l6UEyAVkFBn^^QawW?RvPO8-uk9+gOmuJEdE?@nm-R++Mj6gRb>gD8O zr<$JeHfNTZRMpp`jiyE7Ua(%SApOTjngg^}jQ$f_2a{ay9D1B>F3Zrac?;DN-T0jI zwS*?r#ZfCbD<}=E=o3~gz*^t3p&c2ijDrAUH;RtXk;>VnxZIwqCx)mNPu5PH30*6u z&@=&mvpr4v%PB_pXTu5AaZ4KFurX4kV};H`Fa5(9eJ_!H*D0$29{iIhz$fjoGA0Y` z^>L%gjL-yn_x=i@1z#dO8G?Z?Tzcz0V~@n9NpXP%!6w-vnN&ehFNA{7clSF>+0l9h zP9c>~C2PXm1d5pFd$|}75s48PV@JPjKxSAnKD?RD`B=R8n$opK!!5F%<6=Apd;0jC z=j-Ey`Np9Vp3QfqNrYq9o%1sG>#{=8CjiV7;Jraby&O0xUQzow;uFC2@(|qq{ zp}8`}v&f^q_q)52h%4e=QlQ{$U^B=My6R7GUB4q z-j{n9qhYk&;b^mqDkfhJ4!=O@W3rs=AeqDV2Tw2`JOLmsL)nz~AKmwvDl|SB+7_y8 z0;0VDTgmBY`?_?o1!KuYZ;ob$HgYsmcF@3GDVrxn224iaIhZ2D<_^_b(N-H+Kdqgz zb`hWp7VVX*Q#W?$A=nW^(2RrHV4sO796Ydp-fnn#T$*c_?>)SUr_SO?In6yhy3JD85b2GRn|}giOkbyU zpfBUi+>`S}T0sOan%a;YndA(Fu)u!@%ja(9ZvU=qUwM^KRz_|XxCTx+pI=l5utqgX zHNxx3(N_{=WMQyVVd(}8+XTnexk7|m~pPm+pdO;pCGiz_iWF*PYjJKTCF&UDQ z;s|^un>vf%giJS?N*u3a{M)z-iq;8-q#||@fynSDGiN+WC~x{)g7+F z>3J@adv%>^+ol+xOFdPCU-UInTup@JJI;jib*2T)o3%a>rI(~1@#wLkZgq5 zl;z98kfln8OVrx0e^=R!$JoN%PouGbE@>-$8wj~sBE`+Qq^lLt&mpCk?GfA)Kc?3 zi~L4W%aS}w=Yn07<*G_sZWZshJWesW*>3IDMMiy?We>2Q0J~a&ej&`vV-4Jy{QFl1)r0Doh4^HM0~bo2-QU| zb3_D)?6KL?9GJd39_%W)lrf8@x4jxocBlPL;znJMK{NkoG@ME}mUn8@HLNBvVWsV> zPd}n5QAa2Y;{_V|aJXGhazxw$OHcPZ8+RC=OQN6pM^Gdf#90AC_?IvETxWhqU#Ei*9 zc(xI*jxYO&Th2j*JNH{&Kj)cl4HPjRkjag(0SGE0*}I$#hRjo zB(%wdq^w0A%u>s*OV;k7%@zWp9--^?wBwUneWb5t-H=wu_tFcZLF#ARg6jymEs&qMxXZDxCaVHIYJpM=Jn!qcxM{u{xZxbSqrnS52-};ldpt_ z2g`%N1?LNpfpWEdapYO`GiXBmXhBpN5a0&+H1+&RH$tx&UyLyU7a6U7v2BsXCcBv0 zk(cSkRn|Kny7UC_xs*4b?(@bV8Dglz*S%C&@C}<|<5KGXj?Efc>cEX`#!VvAP+5n< zx!A7TK@rmh!)P@~7je?acQ69Ud_~Z8_eMM%!n8sfRAimG-2goNvo!cL@5gFSkcDeC zA-~IDaTAh?8b!foBs3jlP$*dEOS{J}VK+zNdE-s4Yq2FI-x@by zli$o~V$lYZMh|-3&b` zuR_-JOgyU%XELF2b1zU;?sav39iXL8%~g!16M*c9;+d(ZjxFGg5_Y^*~ zz&Kj!D~Z5|9U@UylrMF<$u&M6iabkfx1h`G3!>MpC8C99X}ZFdf-2<3BybDV*Mk&T z8*Bsv$)~Ie6V>99$SdNfaIE8@a_Dm?d~aTVjNk6C^N!$V!>>LZA0o$q5#vztpu|P# z=q`}DU3xUC_tc*^a@?y-c>07xV5?DiLso>?|$F~Rv zlSvl{<@@ZT3kz6#TY33c_q^EW;lb!G0oU$@J2Hkhx&-?Z(c6RkC>y!>lRP)!K%1I+;uDNE};^8z55(cO~0w{!3?+YVX&V;%ijZnz66D)K=q%FTfzrucCBetn?d*JXsB&S9olumUbz{6g8gyHh)uYGpsR z)T_g}Us|?#BBx0_m~CpI?XLzYuBb$K%OuY2Z)~ zq5J-wx`*ei%-j*4vgdAwN@yO4$*oK|@5qIl{q zv3Rca((wn`8k(pyDlN~ir{SQ=*QGsHP%X82quHZOafI!zrl(bwFz6u-H5~&$)tj(Q z#GAuANsX;Waja?9Nbf;D@`NHk$Q?`wyHt*y(=zzAukNU35tB0S(_I-%uXf1J-xtq<}OF40o zi_eO{$Khy(L8a?zc&EpbC@|M)Zey{qgPGc&gd~sDXeKAMZsPq7i{i>7NGDfqlBN2J z_tZ?SI# z*@=`Wr{v5>k;y{zoetbPBu#AZLefs+^kv|reIm99c;(vpc-ABynE}67IEf--W=~*p$ z1J(|Xi=-}m|GMncCynT?lJo_ZJkqg6=WbE`INF9?#*aKXZ

>%J?$9=r6Poc3bdj z<7W6Pf_kvFoR^3-GIHE3^^V){+W>5!H9&Tp9RcYYvi*@d%=n%!r_<3A`$c5m8xkDb z9@@_$R`GUvkXZ!sqPb3B?MYOQ)_d68%`vL(gUf2!Z?>|PkTXIz71RM@^qL9nJk_(Z zZz1F-nGRm6j%aYo*NPXZcp;z$VPS%$AUbgc)e8cw(N|!32vp4~oRVI`&zeSQxK|!Q z*zWNxjBa}ttiU6hBnv|1+ zZLVNlSL|0Qqw=(0hGHX$5C9p+M!qFtuSB=kvD7qj%Q|4?z0=YiQo=Rvsv1;5S%6eFD%7>et=Tv9|tkbF$ydTFuiTdZbl7X>T=yH^7(l z)fMI=A{Yn~SB+Y_5h>y=RB$U^9el7!LVQs;$xM6m?b~4d;7;S$qta`I&^icmlOI3JoP$StzTmq(bXeXp~Z}OB?k}>NhRGi`MVzg zgV=BUY|s!e8!SwN__wW%&a}{WBQ)VZZJ=s`8H22MIQocYvCa8Cp`nqL-JOTeEob6q zB|@tkRDe1AGJT{r_!hz{NGP5vju%4PKXCt|mH~6D zYE)3TO%IYjsrSWPH0r7yWV(q{f&=Y zb_mtU6Z*ZMH#taOmOv9LPDAzr(|Jj=;Dbvx>VI^%Bp7^iMU9NuByL{CfD3W$b+rpN zKgJW)pWf#Ah(;g4L(Jn?=iDveZG=C9`PgjbbeK`zXtINTsnFUCKU^0zD%GMy^4(U= zzcS=go!DmADmfJ_gWTzkUN`zMe~J+@T*2YVpL< zvJqS(T|jE;UZwvM{oNot^TyiPOU?F*LcWNZcIA~fPL>4dwS){=tZXl&Vc=tQ zCe$^H(~RHI2snSC59U8rX>Y?l&K`0h{8SGU>9%K2#J*`O(HxS$@`b#<@Bk;#gCRYU zu6vq7M9mH&ll$E*ff|P{Qy+zcNL4_A1bK}C#+dGYAX-8?fjBz>iUN^u3Wn-urGQVL zT~CdBv0}?KFR=@rB#m!AWPEi{oN)ijDV(~~lpAQ3hpE!@t{ydfPoCA@U5aR)(OF>o zB1$P+KGrD45ZaEP^AUbf&Gl}{TscWr&#}uelep`AL$T#j>%fU z_B6zlFhW&kgX!0_b?x1~EV0s&RL2V_tgJqap&^P&u#LvGsZ@M-ed+J;HqBymA;NX0d^oN?VCgvv|1V__L6$1E%vVph#w~}xJ9^5E~1!-$wZeP8*BWvmSa(cvPOz&?FAj; zL~X1M%m#ci%7d|;cWGr49W56(wLgE5zQlW#F#(A81V}K=$S#_kDtcF-aF5Rec(9Ke zxNTCh*#izdrHHaaLTk_=l-O6Tp}1`#+PAe8#ihW-6FU&T(%;J&5zVe6BH&UdAc>K~ zGZ>b4yeg+!i9a!4Kz8|_Lq6hNi~(U&-7LW2|Cx<5r05kEHuP~;o?TQ6(eA2J5HFC@ z|Jq+oGPQ`%fZy~2Glo0) zL+=mdOR<73kBkOWla7d&=vy(tBz$epaC4h&)unZ>+nfg`0uvdNdVBD2-IY+JikmGy z9Kew>*#kfRNqXUntm1J6U!0vLo$F_iEhs1i!xX(5 zSWgxdvAGXO7zC3_070U%~C2#bjKBA%rWF^?qbX5&ec5rlDy)V z+=tCTva$ldCHy0K?P&i%*k2?K8EuB?phJW_(Kx~DI3k0t z9rf*oqK;zUrfg_ZvT%p?F}yo7824w=r7_rl|L(jkY%;g7=CwP{*C7d;8tn18 zQohApk*C=Vi_aN)bfb}Ok4PC#)^Rj!envD@xnUJAC|mC;b3&<*(vwt&mo0dJ4g1%g zH8mD>+Vp$V|J1Yo;q>qa0{nL9za&G6Zvc}6QhU4L1xVQD{EG$B&CK0}ba0~!^xjp`&ScE1Bik``<1Y0uC+kQ-SuQh7 z+c3_YSvV%5LB#%eQH=4A^)UrXNW*xpatDR;`4<*%IT^eAO-0DneCt|B9`rty+WEt` zbCI3h2!C83Q{f;n67%bxl%FQt`rujfeh?X?&|>m%=?Z9R@IOI?pE0X{K3~r|!+#VN zen!mwkEg=V7?%I>RQMU&?mwOiKZ7#;$5Y{F9Gd@lD*O!P@gGlxpI+|&4^si?tMHR2 z`je~Zzpk2oI*k5*GW@@|daP27Oj-c44x}%C%YvmQU}b4-Q-r3r~PUD9_YA9Yq zdVvnr|Cs!gfOsXdhZ=`KK(bytnF(6x)heV@zv%uN?VM%CQ|bn(AV_t#+xv6L9F#lv z74n*il#o&UY!P&AyqVV|S@t5xGW9fdg$%Z?=VsRY$UBLz6TZSQRDcVQkhyC0u!@7x z@KxJlB|2M4Q99hfk=hloq+4fCDl&xUl;rE&;1zk{!+jEFni=pt10&$-)4vB3W5M9L zx$%m}iSm5$Ix2+~ML&LHc7h^L`KXbE!m3A+T%%?A%jhm4g%w}sBocu)+Z5kmtG}#f z*u3Pk8c_+n^Rc$sZvky|fwj_FQ-H;v1oIWr=JtHBc!&v@CRmn(T5<>JPBK^Z+>pTk zCo2g8N)0qLe*rXc{$VuTh9qy443sNL0(-*`^YcHK9~sh~%S~j&|BTSAp&kYc1fys9 z%?;`o!Y~l`2f`m1{om7&{U!+b`E)&p>wo2OKM$00Nx1?5z$_SlC>X$_|AhM;=r19j zClCIGK?<}B|3{ntnU41P6+F*h`>XmY=5N*il+E^e_2;=&e^vLz{;m4+Osmfkp699i zh2Tu|JHnr{);$M!o(t|5fFkj40Dl_(|M1~2>G?c6(l5AZ@_!Halka?fbg^X?`n!J*_j9%Q9PW8Ipv z@3|+?FF1JZ-{7A61^tHbUwgp(LQvNEgWBoO{xHukPipeliningDDPPipe 0Pipe 1Pipe 2Pipe 3 \ No newline at end of file diff --git a/pippy/auto_parallelization.py b/pippy/auto_parallelization.py deleted file mode 100644 index a22cc81f1..000000000 --- a/pippy/auto_parallelization.py +++ /dev/null @@ -1,308 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -# Adaptation of https://github.com/alpa-projects/alpa/blob/a88992ce3b46024c0a4ee4aa8cb069a62830cec2/alpa/pipeline_parallel/stage_construction.py -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from dataclasses import dataclass -from enum import Enum -from typing import List, Tuple - -import numpy as np - -from torch import fx - -from pippy import pipe_split - -try: - from numba import njit # type: ignore - from numba.typed import List as NumbaList # type: ignore -except ImportError: - - def njit(*args, **kwargs): - def wrapper(func): - return func - - return wrapper - - NumbaList = list - - -class SubmeshSpace(Enum): - ALL = "ALL" - POWER_OF_TWO = "POWER_OF_TWO" - SMALL_POWER_OF_TWO = "SMALL_POWER_OF_TWO" - - -def get_possible_submesh_shapes( - n_compute_nodes: int, n_devices_per_node: int, submesh_space: SubmeshSpace -): - submeshes = [] - i = 1 - while i <= n_devices_per_node: - submeshes.append((1, i)) - i *= 2 - assert submeshes[-1][1] == n_devices_per_node - - # larger meshes: - if submesh_space == SubmeshSpace.ALL: - for i in range(2, n_compute_nodes + 1): - submeshes.append((i, n_devices_per_node)) - elif submesh_space == SubmeshSpace.POWER_OF_TWO: - i = 2 - while i <= n_compute_nodes: - submeshes.append((i, n_devices_per_node)) - i *= 2 - elif submesh_space == SubmeshSpace.SMALL_POWER_OF_TWO: - i = 2 - while i <= min(n_compute_nodes, 4): - submeshes.append((i, n_devices_per_node)) - i *= 2 - else: - raise ValueError(f"Invalid submesh space: {submesh_space}") - - return submeshes - - -NUMPY_RANDOM_SEED = 42 - - -def estimate_intra_costs( - n_submesh_choices, - n_layers, - max_n_succ_stages=4096, - n_autosharding_configs=1, -): - np.random.seed(NUMPY_RANDOM_SEED) - intra_costs = np.random.rand( - n_layers, n_layers, n_submesh_choices, n_autosharding_configs - ) - max_n_succ_stages = np.full( - (n_layers, n_layers, n_submesh_choices, n_autosharding_configs), - max_n_succ_stages, - ) - return intra_costs, max_n_succ_stages - - -@njit(fastmath=True) -def get_optimal_submesh_assignments( - best_n_stages, F_argmin, n_devices, n_ops, submesh_sizes -): - """ - Standard backtracking approach to find the optimal op-mesh assignment, starting with - the optimal number of stages (best_n_stages). - - The return is a list [((layer_start, next_layer_start), submesh_shape_idx, sharding_config_idx)] - where (layer_start, next_layer_start) is [) slice of the ops and submesh_shape_idx is the submesh - those ops should be mapped to (sharding_config_idx is currently always 1 but will be eventually used - pick optimal tensor sharding configuration). - """ - current_s = best_n_stages - current_layer = 0 - current_devices = n_devices - - optimal_layer_submesh_assignments = [] - while current_s > 0 and current_layer < n_ops and current_devices > 0: - next_start_layer, submesh_shape_idx, sharding_config_idx = F_argmin[ - current_s, current_layer, current_devices - ] - assert next_start_layer != -1 and current_devices != -1 - optimal_layer_submesh_assignments.append( - ( - (current_layer, next_start_layer), - submesh_shape_idx, - sharding_config_idx, - ) - ) - current_s -= 1 - current_layer = next_start_layer - current_devices -= submesh_sizes[submesh_shape_idx] - - assert current_s == 0 and current_layer == n_ops and current_devices == 0 - - return optimal_layer_submesh_assignments - - -@njit(fastmath=True) -def inter_op_dp_inner_loop( - n_layers, n_devices, submesh_sizes, valid_idxs_costs, max_n_succ_stages -): - """ - Equation 3 from the Alpa paper. Primary difference from the paper is the - s - 1 <= max_n_succ_stages check, which is used to characterize memory capacity - of each stage placement (if s - 1 > max_n_succ_stages check then placing that stage - would lead to OOM and thus continue). - """ - F = np.full( - (n_layers + 1, n_layers + 1, n_devices + 1), np.inf, dtype=np.float32 - ) - F_stage_max = np.full( - (n_layers + 1, n_layers + 1, n_devices + 1), 0.0, dtype=np.float32 - ) - F_argmin = np.full( - (n_layers + 1, n_layers + 1, n_devices + 1, 3), -1, dtype=np.int32 - ) - F[0, n_layers, 0] = 0 - - for d in range(1, n_devices + 1): - for ( - l, - i, - submesh_shape_idx, - sharding_config_idx, - stage_cost, - ) in valid_idxs_costs: - l, i, submesh_shape_idx, sharding_config_idx = map( - int, (l, i, submesh_shape_idx, sharding_config_idx) - ) - - n_submesh_devices = submesh_sizes[submesh_shape_idx] - if n_submesh_devices <= d: - for s in range(1, n_layers + 1): - if ( - s - 1 - > max_n_succ_stages[ - l, i, submesh_shape_idx, sharding_config_idx - ] - ): - continue - - new_cost = ( - F[s - 1, i + 1, d - n_submesh_devices] + stage_cost - ) - if new_cost < F[s, l, d]: - F[s, l, d] = new_cost - F_argmin[s, l, d] = ( - i + 1, - submesh_shape_idx, - sharding_config_idx, - ) - F_stage_max[s, l, d] = max( - F_stage_max[s - 1, i + 1, d - n_submesh_devices], - stage_cost, - ) - - return F, F_stage_max, F_argmin - - -def inter_op_dp( - n_layers: int, - n_devices: int, - n_microbatches: int, - submesh_shapes: List[Tuple[int, int]], - intra_compute_costs, - max_n_succ_stages, -): - """ - DP to compute optimal latency and number of pipeline stages and mapping of - stages to compute cluster submeshes. - """ - min_cost = np.inf - best_solution = None - prev_intra_cost = 0.0 - gap = 1e-6 - - submesh_sizes: list = NumbaList() - for n, m in submesh_shapes: - submesh_sizes.append(n * m) - - for intra_cost in np.sort(np.unique(intra_compute_costs)): - if intra_cost - prev_intra_cost < gap: - continue - if intra_cost * n_microbatches >= min_cost: - break - - # Optimization that lifts a check for stage_cost <= t_max_stage_cost - # out of the inner dp loop (see alpa/~/stage_construction.py#L121). - # This yields a ~100-200x improvement over the baseline implementation. - valid_cost_idxs = np.transpose( - (intra_compute_costs <= intra_cost).nonzero() - ) - # This corresponds to the i of k <= i <= K from eqn. 3 in the alpa paper. - valid_cost_idxs = valid_cost_idxs[ - valid_cost_idxs[:, 0] <= valid_cost_idxs[:, 1] - ] - valid_costs = intra_compute_costs[tuple(valid_cost_idxs.T)] - valid_idxs_costs = np.hstack( - [valid_cost_idxs, valid_costs[:, np.newaxis]] - ) - - F, F_stage_max, F_argmin = inter_op_dp_inner_loop( - n_layers, - n_devices, - submesh_sizes, - valid_idxs_costs, - max_n_succ_stages, - ) - - best_n_stages = F[:, 0, n_devices].argmin() - all_stages_cost = F[best_n_stages, 0, n_devices] - slowest_stage_cost = F_stage_max[best_n_stages, 0, n_devices] - if np.isinf(all_stages_cost): - continue - slowest_stage_total_cost = (n_microbatches - 1) * slowest_stage_cost - - if all_stages_cost + slowest_stage_total_cost < min_cost: - min_cost = all_stages_cost + slowest_stage_total_cost - best_solution = best_n_stages, F_argmin - prev_intra_cost = intra_cost - - assert best_solution is not None - best_n_stages, F_argmin = best_solution - optimal_layer_submesh_assignments = get_optimal_submesh_assignments( - best_n_stages, F_argmin, n_devices, n_layers, submesh_sizes - ) - return optimal_layer_submesh_assignments - - -@dataclass -class AutoParallelConfig: - n_compute_nodes: int - n_devices_per_node: int - n_microbatches: int - submesh_space: SubmeshSpace = SubmeshSpace.ALL - - -def dp_auto_parallel(config: AutoParallelConfig): - def _dp_auto_parallel(fx_mod: fx.GraphModule): - n_graph_nodes = len(fx_mod.graph.nodes) - submesh_shapes = get_possible_submesh_shapes( - n_compute_nodes=config.n_compute_nodes, - n_devices_per_node=config.n_devices_per_node, - submesh_space=config.submesh_space, - ) - intra_costs, max_n_succ_stages = estimate_intra_costs( - len(submesh_shapes), n_layers=n_graph_nodes - ) - optimal_layer_submesh_assignments = inter_op_dp( - n_layers=n_graph_nodes, - n_devices=config.n_compute_nodes * config.n_devices_per_node, - n_microbatches=config.n_microbatches, - submesh_shapes=submesh_shapes, - intra_compute_costs=intra_costs, - max_n_succ_stages=max_n_succ_stages, - ) - split_points = { - current_layer - for ( - (current_layer, _next_start_layer), - _submesh_choice, - _autosharding_choice, - ) in optimal_layer_submesh_assignments - } - for i, node in reversed(list(enumerate(fx_mod.graph.nodes))): - if i in split_points: - with fx_mod.graph.inserting_before(node): - fx_mod.graph.call_function(pipe_split, (), {}) - fx_mod.recompile() - return fx_mod - - return _dp_auto_parallel diff --git a/run_all_tests.sh b/run_all_tests.sh deleted file mode 100755 index 897556562..000000000 --- a/run_all_tests.sh +++ /dev/null @@ -1,15 +0,0 @@ -#!/bin/bash -# Copyright (c) Meta Platforms, Inc. and affiliates - -set -ex -python test/min_gpt_tracing.py -pytest --cov=pippy test - -for CUDA in 0 1; do - for REP in 0 1; do - for SCHD in FillDrain 1F1B; do - python test/local_test_forward.py -s ${SCHD} --replicate ${REP} --cuda ${CUDA} - python test/local_test_forward_backward.py -s ${SCHD} --replicate ${REP} --cuda ${CUDA} - done - done -done From 3e905135688eabd4ae32c54e97034f84bd89e3a6 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Fri, 26 Jan 2024 13:23:01 -0800 Subject: [PATCH 96/96] Version 0.2.0 --- version.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/version.txt b/version.txt index 17e51c385..0ea3a944b 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.1.1 +0.2.0