Skip to content

Commit ba3bcab

Browse files
smessmer authored and facebook-github-bot committed
Making ops c10-full: Generator arguments (#49013)
Summary: Pull Request resolved: pytorch/pytorch#49013 I don't know why this works. I know, this is never a good way to start a PR description :P I know that Generator is a dispatch relevant argument when called from an unboxed API and is ignored for dispatch purposes when called from a boxed API. This should break something, but maybe we don't have test cases for that. We likely need to align the unboxed and boxed dispatch behavior before landing this. The best solution would be to make Generator not dispatch relevant in unboxing. But that might be a bigger change. An acceptable solution could be to make Generator dispatch relevant in boxing, but that needs perf measurements. This PR needs further discussion. ghstack-source-id: 118619230 (Note: this ignores all push blocking failures!) Reviewed By: bhosmer Differential Revision: D25394998 fbshipit-source-id: f695c659ee6e3738f74cdf0af1a514ac0c30ebff
1 parent 52e17d8 commit ba3bcab

File tree

1 file changed

+21
-21
lines changed

1 file changed

+21
-21
lines changed

torchcsprng/csrc/csprng.cpp

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ Tensor& normal_(Tensor& self, double mean, double std, c10::optional<Generator>
9393
}
9494
}
9595

96-
Tensor& normal_Tensor_float_out(Tensor& output, const Tensor& mean, double std, c10::optional<Generator> gen) {
96+
Tensor& normal_Tensor_float_out(const Tensor& mean, double std, c10::optional<Generator> gen, Tensor& output) {
9797
if (output.device().type() == DeviceType::CPU) {
9898
return cpu::normal_Tensor_float_out(output, mean, std, gen);
9999
#ifdef WITH_CUDA
@@ -105,7 +105,7 @@ Tensor& normal_Tensor_float_out(Tensor& output, const Tensor& mean, double std,
105105
}
106106
}
107107

108-
Tensor& normal_float_Tensor_out(Tensor& output, double mean, const Tensor& std, c10::optional<Generator> gen) {
108+
Tensor& normal_float_Tensor_out(double mean, const Tensor& std, c10::optional<Generator> gen, Tensor& output) {
109109
if (output.device().type() == DeviceType::CPU) {
110110
return cpu::normal_float_Tensor_out(output, mean, std, gen);
111111
#ifdef WITH_CUDA
@@ -117,7 +117,7 @@ Tensor& normal_float_Tensor_out(Tensor& output, double mean, const Tensor& std,
117117
}
118118
}
119119

120-
Tensor& normal_Tensor_Tensor_out(Tensor& output, const Tensor& mean, const Tensor& std, c10::optional<Generator> gen) {
120+
Tensor& normal_Tensor_Tensor_out(const Tensor& mean, const Tensor& std, c10::optional<Generator> gen, Tensor& output) {
121121
if (output.device().type() == DeviceType::CPU) {
122122
return cpu::normal_Tensor_Tensor_out(output, mean, std, gen);
123123
#ifdef WITH_CUDA
@@ -272,12 +272,12 @@ namespace {
272272
}
273273
} // namespace
274274

275-
Tensor& randperm_generator_out(Tensor& result, int64_t n, c10::optional<Generator> generator) {
275+
Tensor& randperm_generator_out(int64_t n, c10::optional<Generator> generator, Tensor& result) {
276276
TORCH_CHECK(n >= 0, "n must be non-negative, got", n);
277277
check_supported_max_int_with_precision(n, result);
278278
if (result.device().type() == at::kCUDA) {
279279
auto result_cpu = at::empty({n}, result.options().device(kCPU));
280-
randperm_generator_out(result_cpu, n, generator);
280+
randperm_generator_out(n, generator, result_cpu);
281281
result.resize_({n});
282282
return result.copy_(result_cpu);
283283
}
@@ -344,29 +344,29 @@ bool supports_cuda() {
344344

345345
TORCH_LIBRARY_IMPL(aten, CustomRNGKeyId, m) {
346346
// Random
347-
m.impl_UNBOXED("random_.from", random_from_to);
348-
m.impl_UNBOXED("random_.to", random_to);
349-
m.impl_UNBOXED("random_", random_);
347+
m.impl("random_.from", random_from_to);
348+
m.impl("random_.to", random_to);
349+
m.impl("random_", random_);
350350
// Uniform
351-
m.impl_UNBOXED("uniform_", uniform_);
351+
m.impl("uniform_", uniform_);
352352
// Normal
353-
m.impl_UNBOXED("normal_", normal_);
354-
m.impl_UNBOXED("normal.Tensor_float_out", normal_Tensor_float_out);
355-
m.impl_UNBOXED("normal.float_Tensor_out", normal_float_Tensor_out);
356-
m.impl_UNBOXED("normal.Tensor_Tensor_out", normal_Tensor_Tensor_out);
357-
m.impl_UNBOXED("normal.Tensor_float", normal_Tensor_float);
358-
m.impl_UNBOXED("normal.float_Tensor", normal_float_Tensor);
359-
m.impl_UNBOXED("normal.Tensor_Tensor", normal_Tensor_Tensor);
353+
m.impl("normal_", normal_);
354+
m.impl("normal.Tensor_float_out", normal_Tensor_float_out);
355+
m.impl("normal.float_Tensor_out", normal_float_Tensor_out);
356+
m.impl("normal.Tensor_Tensor_out", normal_Tensor_Tensor_out);
357+
m.impl("normal.Tensor_float", normal_Tensor_float);
358+
m.impl("normal.float_Tensor", normal_float_Tensor);
359+
m.impl("normal.Tensor_Tensor", normal_Tensor_Tensor);
360360
// Cauchy
361-
m.impl_UNBOXED("cauchy_", cauchy_);
361+
m.impl("cauchy_", cauchy_);
362362
// LogNormal
363-
m.impl_UNBOXED("log_normal_", log_normal_);
363+
m.impl("log_normal_", log_normal_);
364364
// Geometric
365-
m.impl_UNBOXED("geometric_", geometric_);
365+
m.impl("geometric_", geometric_);
366366
// Exponential
367-
m.impl_UNBOXED("exponential_", exponential_);
367+
m.impl("exponential_", exponential_);
368368
// Random permutation
369-
m.impl_UNBOXED("randperm.generator_out", randperm_generator_out);
369+
m.impl("randperm.generator_out", randperm_generator_out);
370370
}
371371

372372
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

0 commit comments

Comments (0)