# Copyright (c) ONNX Project Contributors

# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import numpy as np

from onnx.reference.ops._op import OpRunUnaryNum


class Hardmax(OpRunUnaryNum):
    """Reference implementation of Hardmax built on NumPy.

    Along the selected axis, the output holds ``1`` at the position of the
    first maximum of the input and ``0`` everywhere else.
    """

    def _run(self, x, axis=None):  # type: ignore
        """Compute the hardmax of ``x`` along ``axis``.

        Args:
            x: Input array.
            axis: Axis along which the maximum is searched. When ``None``,
                the value stored on the instance (``self.axis``) is used.

        Returns:
            A one-element tuple holding an array with the same shape and
            dtype as ``x`` (it is created with ``np.zeros_like``).
        """
        # Compare against None explicitly: ``axis or self.axis`` would treat
        # an explicit ``axis=0`` as "not provided" because 0 is falsy.
        if axis is None:
            axis = self.axis  # type: ignore
        # np.argmax returns the index of the first occurrence of the maximum,
        # so ties resolve to the lowest index along the axis.
        x_argmax = np.argmax(x, axis=axis)  # type: ignore
        y = np.zeros_like(x)
        # argmax drops the reduced axis; put_along_axis needs the index array
        # to have the same number of dimensions as ``y``, so restore it.
        np.put_along_axis(
            y,
            np.expand_dims(x_argmax, axis=axis),
            1,
            axis=axis,  # type: ignore
        )
        return (y,)
