# Copyright (c) ONNX Project Contributors

# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import numpy as np

from onnx.reference.op_run import OpRun
from onnx.reference.ops.op_col2im import col2im_naive_implementation


class ConvTranspose(OpRun):
    """Reference (NumPy) implementation of the ``ConvTranspose`` operator.

    The transposed convolution is evaluated, for every image and every
    group, as one matrix product between the transposed group weights and
    the flattened input, followed by a ``col2im`` scatter per output channel.
    """

    def _run(  # type: ignore
        self,
        X,
        W,
        B=None,
        auto_pad=None,
        dilations=None,
        group=None,
        kernel_shape=None,
        output_padding=None,
        output_shape=None,
        pads=None,
        strides=None,
    ):
        """Compute the transposed convolution of ``X`` with ``W``.

        Args:
            X: input of shape ``(N, C, d1, ..., dn)``.
            W: weights of shape ``(C, M / group, k1, ..., kn)``.
            B: optional bias indexed by output channel (``M`` entries).
            auto_pad: ``"SAME_UPPER"`` or ``"SAME_LOWER"`` make the padding
                be derived from the output shape when ``pads`` is missing;
                any other value means zero padding when ``pads`` is missing.
            dilations: per-axis kernel dilation, defaults to 1.
            group: number of groups, defaults to 1 when ``None``.
            kernel_shape: kernel size used in the shape/padding equations,
                defaults to ``W.shape[2:]``.
            output_padding: per-axis extra size added to the output,
                defaults to 0.
            output_shape: explicit spatial output shape; computed from the
                padding when missing.
            pads: begin paddings followed by end paddings, one per axis.
            strides: per-axis stride, defaults to 1.

        Returns:
            A 1-tuple holding the result of shape ``(N, M, o1, ..., on)``,
            cast to ``X.dtype``.

        Raises:
            ValueError: if the channel counts of ``X`` and ``W`` are not
                compatible with ``group``.
        """
        n_dims = len(X.shape) - 2
        if group is None:
            # The signature default is None; treat it as an ungrouped op.
            group = 1
        if dilations is None:
            dilations = [1] * n_dims
        if kernel_shape is None:
            kernel_shape = W.shape[2:]
        if output_padding is None:
            output_padding = [0] * n_dims
        if strides is None:
            strides = [1] * n_dims
        if pads is None and auto_pad not in {"SAME_UPPER", "SAME_LOWER"}:
            pads = [0] * (2 * len(strides))

        if pads is None:
            # auto_pad is SAME_*: derive the padding from the output shape.
            if output_shape is None:
                output_shape = [
                    X.shape[i + 2] * strides[i] for i in range(len(strides))
                ]
            total_padding = [
                strides[i] * (X.shape[i + 2] - 1)
                + output_padding[i]
                + ((kernel_shape[i] - 1) * dilations[i] + 1)
                - output_shape[i]
                for i in range(len(output_shape))
            ]
            half = [total // 2 for total in total_padding]
            rest = [total - h for total, h in zip(total_padding, half)]
            # SAME_UPPER puts the extra unit at the end, SAME_LOWER at the
            # beginning.
            pads = half + rest if auto_pad == "SAME_UPPER" else rest + half
        elif output_shape is None:
            # Explicit (or zero) padding: derive the output shape from it.
            output_shape = [
                strides[i] * (X.shape[i + 2] - 1)
                + output_padding[i]
                + ((kernel_shape[i] - 1) * dilations[i] + 1)
                - (pads[i] + pads[i + n_dims])
                for i in range(n_dims)
            ]

        # From here on the kernel shape always comes from the weights.
        kernel_shape = W.shape[2:]
        kernel_size = np.prod(kernel_shape)
        out_per_group = W.shape[1]
        num_output_channels = out_per_group * group

        C = X.shape[1]  # num_inputs_channels
        if C % group != 0 or W.shape[0] != C:
            raise ValueError(
                f"Incompatible channels: X has {C} channels, W has "
                f"{W.shape[0]} and group={group}."
            )
        k = C // group  # input channels per group
        m = out_per_group * kernel_size  # kernel_dim
        n = np.prod(X.shape[2:])  # input_image_size
        # Axis 0 of W is the input-channel axis, so consecutive blocks of k
        # rows hold the weights of consecutive groups.
        w_reshaped = W.reshape((group, k, m))
        final = None

        # N x C x H x W = X.shape
        # C x M/group x k1 x k2 = W.shape
        for image_id in range(X.shape[0]):
            for group_id in range(group):
                x_group = X[image_id, group_id * k : (group_id + 1) * k]
                gemm = np.matmul(w_reshaped[group_id].T, x_group.reshape((k, n)))
                gemmc = gemm.reshape((out_per_group, -1, gemm.shape[-1]))
                for c in range(out_per_group):
                    # Position of this channel in the full output and bias.
                    channel = group_id * out_per_group + c
                    res = col2im_naive_implementation(
                        gemmc[c], output_shape, kernel_shape, dilations, pads, strides
                    )
                    if final is None:
                        final = np.empty(
                            X.shape[:1] + (num_output_channels,) + res.shape,
                            dtype=X.dtype,
                        )
                    if B is not None:
                        res += B[channel]
                    final[image_id, channel, ...] = res[...]

        if final is None:
            # Empty batch or no output channel: nothing was computed.
            final = np.zeros(
                (X.shape[0], num_output_channels, *output_shape), dtype=X.dtype
            )

        return (final.astype(X.dtype),)
