"""Module containing functional implementations of 3D transformations.

This module provides a collection of utility functions for manipulating and transforming
3D volumetric data (such as medical imaging volumes). The functions here implement the core
algorithms for operations like padding, cropping, rotation, and other spatial manipulations
specifically designed for 3D data.
"""

from __future__ import annotations

import random
from typing import Literal

import numpy as np

from albumentations.augmentations.utils import handle_empty_array
from albumentations.core.type_definitions import NUM_VOLUME_DIMENSIONS


def adjust_padding_by_position3d(
    paddings: list[tuple[int, int]],  # [(front, back), (top, bottom), (left, right)]
    position: Literal["center", "random"],
    py_random: random.Random,
) -> tuple[int, int, int, int, int, int]:
    """Adjust padding values based on desired position for 3D data.

    With ``position == "center"`` the given pairs are returned unchanged. Any other position
    (i.e. ``"random"``) keeps the total padding of every axis but splits it at a random point,
    so the padded volume has exactly the same shape as with ``"center"``; only the placement
    of the original data inside it changes.

    Args:
        paddings (list[tuple[int, int]]): Padding pairs for each dimension
            [(d_front, d_back), (h_top, h_bottom), (w_left, w_right)].
        position (Literal["center", "random"]): Position of the volume after padding.
        py_random (random.Random): Random number generator used for the random split;
            ``randint`` is called once per axis, in depth, height, width order.

    Returns:
        tuple[int, int, int, int, int, int]: Final padding values
            (d_front, d_back, h_top, h_bottom, w_left, w_right).

    """
    if position == "center":
        return (
            paddings[0][0],  # d_front
            paddings[0][1],  # d_back
            paddings[1][0],  # h_top
            paddings[1][1],  # h_bottom
            paddings[2][0],  # w_left
            paddings[2][1],  # w_right
        )

    # For random position, redistribute the total padding of each dimension
    d_pad = sum(paddings[0])
    h_pad = sum(paddings[1])
    w_pad = sum(paddings[2])

    # Draw the leading pad ONCE per axis and give the remainder to the trailing side.
    # Drawing the trailing pad independently would make front + back != total, so the
    # output shape would no longer match the requested size.
    d_front = py_random.randint(0, d_pad)
    h_top = py_random.randint(0, h_pad)
    w_left = py_random.randint(0, w_pad)

    return (
        d_front,  # d_front
        d_pad - d_front,  # d_back
        h_top,  # h_top
        h_pad - h_top,  # h_bottom
        w_left,  # w_left
        w_pad - w_left,  # w_right
    )


def pad_3d_with_params(
    volume: np.ndarray,
    padding: tuple[int, int, int, int, int, int],
    value: tuple[float, ...] | float,
) -> np.ndarray:
    """Pad 3D volume with given parameters.

    Args:
        volume (np.ndarray): Input volume with shape (depth, height, width) or (depth, height, width, channels)
        padding (tuple[int, int, int, int, int, int]): Padding values in format:
            (depth_front, depth_back, height_top, height_bottom, width_left, width_right)
            where:
            - depth_front/back: padding at start/end of depth axis (z)
            - height_top/bottom: padding at start/end of height axis (y)
            - width_left/right: padding at start/end of width axis (x)
        value (tuple[float, ...] | float): Value to fill the padding. A scalar fills every
            padded voxel. A sequence is broadcast along the last axis like a NumPy assignment,
            i.e. one value per channel for a (depth, height, width, channels) volume - the
            same way ``cutout3d`` applies its ``fill``.

    Returns:
        np.ndarray: Padded volume with same number of dimensions as input. If every padding
            value is 0, ``volume`` itself is returned unchanged.

    Raises:
        ValueError: If any padding value is negative (raised by ``np.pad``), or if a
            sequence ``value`` cannot be broadcast along the last axis.

    Note:
        The padding order matches the volume dimensions (depth, height, width).
        For each dimension, the first value is padding at the start (smaller indices),
        and the second value is padding at the end (larger indices).

    """
    depth_front, depth_back, height_top, height_bottom, width_left, width_right = padding

    # Skip if no padding is needed
    if all(p == 0 for p in padding):
        return volume

    pad_width = [
        (depth_front, depth_back),  # depth (z) padding
        (height_top, height_bottom),  # height (y) padding
        (width_left, width_right),  # width (x) padding
    ]

    # Add channel padding if 4D array
    if volume.ndim == NUM_VOLUME_DIMENSIONS:
        pad_width.append((0, 0))  # no padding for channels

    if np.ndim(value) == 0:
        return np.pad(
            volume,
            pad_width=pad_width,
            mode="constant",
            constant_values=value,
        )

    # np.pad reads a sequence ``constant_values`` as per-axis (before, after) pairs rather than
    # per-channel fills, so it would either raise or silently fill the wrong voxels. Zero-pad
    # instead, then overwrite every voxel outside the original region with the broadcast value.
    padded = np.pad(volume, pad_width=pad_width, mode="constant")
    depth, height, width = volume.shape[:3]
    outside_original = np.ones(padded.shape[:3], dtype=bool)
    outside_original[
        depth_front : depth_front + depth,
        height_top : height_top + height,
        width_left : width_left + width,
    ] = False
    padded[outside_original] = value
    return padded


def crop3d(
    volume: np.ndarray,
    crop_coords: tuple[int, int, int, int, int, int],
) -> np.ndarray:
    """Return the sub-volume selected by half-open bounds on the three spatial axes.

    Args:
        volume (np.ndarray): Volume shaped (z, y, x) or (z, y, x, channels)
        crop_coords (tuple[int, int, int, int, int, int]): Bounds ordered as
            (z_min, z_max, y_min, y_max, x_min, x_max); each max bound is exclusive.

    Returns:
        np.ndarray: The cropped region, obtained by basic slicing of the first three axes;
            any trailing channel axis is kept whole.

    """
    z_lo, z_hi, y_lo, y_hi, x_lo, x_hi = crop_coords
    spatial_window = (slice(z_lo, z_hi), slice(y_lo, y_hi), slice(x_lo, x_hi))
    return volume[spatial_window]


def cutout3d(volume: np.ndarray, holes: np.ndarray, fill: tuple[float, ...] | float) -> np.ndarray:
    """Cut out holes in 3D volume and fill them with a given value.

    Args:
        volume (np.ndarray): Input volume with shape (depth, height, width) or (depth, height, width, channels)
        holes (np.ndarray): Array of holes with shape (num_holes, 6).
            Each hole is represented as [z1, y1, x1, z2, y2, x2]
        fill (tuple[float, ...] | float): Value to fill the holes

    Returns:
        np.ndarray: Volume with holes filled with the given value

    """
    volume = volume.copy()
    for z1, y1, x1, z2, y2, x2 in holes:
        volume[z1:z2, y1:y2, x1:x2] = fill
    return volume


def transform_cube(cube: np.ndarray, index: int) -> np.ndarray:
    """Apply one of 48 axis-aligned rotations/reflections of a cube, selected by ``index``.

    Entries 0-23 are rotations built from ``np.rot90`` over the three spatial axes. Entries
    24-47 first reverse axis 2 (width) and then apply rotation ``index - 24`` - the same
    split that ``transform_cube_keypoints`` uses. Several entries are written as an
    equivalent single slice or ``transpose`` instead of nested ``np.rot90`` calls. Only
    axes 0-2 are touched; a trailing channel axis, if present, stays last.

    Args:
        cube (np.ndarray): Input array with shape (D, H, W) or (D, H, W, C)
        index (int): Integer from 0 to 47 specifying which transformation to apply

    Returns:
        np.ndarray: Transformed array. For a true cube (D == H == W) the shape is unchanged;
            otherwise transformations that map one spatial axis onto another permute
            D, H and W. The result is derived from a copy of ``cube``, so it never shares
            memory with the input (it may be a non-contiguous view of that copy).

    Raises:
        ValueError: If ``index`` is outside [0, 47].

    """
    if not (0 <= index < 48):
        raise ValueError("Index must be between 0 and 47")

    # Dispatch table index -> transformation (rebuilt on every call).
    transformations = {
        # 0-3: k = 0..3 quarter turns in the (H, W) plane, i.e. about axis 0
        0: lambda x: x,
        1: lambda x: np.rot90(x, k=1, axes=(1, 2)),
        2: lambda x: np.rot90(x, k=2, axes=(1, 2)),
        3: lambda x: np.rot90(x, k=3, axes=(1, 2)),
        # 4-7: half turn in the (D, W) plane (about axis 1), then k = 0..3 quarter turns about axis 0
        4: lambda x: x[::-1, :, ::-1],  # same as np.flip(x, axis=(0, 2)) == np.rot90(x, k=2, axes=(0, 2))
        5: lambda x: np.rot90(np.rot90(x, k=2, axes=(0, 2)), k=1, axes=(1, 2)),
        6: lambda x: x[::-1, ::-1, :],  # same as np.flip(x, axis=(0, 1)): entry 4 followed by a half turn about axis 0
        7: lambda x: np.rot90(np.rot90(x, k=2, axes=(0, 2)), k=3, axes=(1, 2)),
        # 8-15: quarter turn in the (D, W) plane (k=+1 for 8-11, k=-1 for 12-15),
        # then k = 0..3 quarter turns in the (D, H) plane, i.e. about axis 2
        8: lambda x: np.rot90(x, k=1, axes=(0, 2)),
        9: lambda x: np.rot90(np.rot90(x, k=1, axes=(0, 2)), k=1, axes=(0, 1)),
        10: lambda x: np.rot90(np.rot90(x, k=1, axes=(0, 2)), k=2, axes=(0, 1)),
        11: lambda x: x.transpose(1, 2, 0, *range(3, x.ndim)),  # shortcut for the k=1 (0,2) then k=3 (0,1) pair
        12: lambda x: np.rot90(x, k=-1, axes=(0, 2)),
        13: lambda x: np.rot90(np.rot90(x, k=-1, axes=(0, 2)), k=1, axes=(0, 1)),
        14: lambda x: np.rot90(np.rot90(x, k=-1, axes=(0, 2)), k=2, axes=(0, 1)),
        15: lambda x: np.rot90(np.rot90(x, k=-1, axes=(0, 2)), k=3, axes=(0, 1)),
        # 16-23: quarter turn in the (D, H) plane (k=+1 for 16-19, k=-1 for 20-23),
        # then k = 0..3 quarter turns in the (D, W) plane, i.e. about axis 1
        16: lambda x: np.rot90(x, k=1, axes=(0, 1)),
        17: lambda x: np.rot90(np.rot90(x, k=1, axes=(0, 1)), k=1, axes=(0, 2)),
        18: lambda x: np.rot90(np.rot90(x, k=1, axes=(0, 1)), k=2, axes=(0, 2)),
        19: lambda x: x.transpose(2, 0, 1, *range(3, x.ndim)),  # shortcut for the k=1 (0,1) then k=3 (0,2) pair
        20: lambda x: np.rot90(x, k=-1, axes=(0, 1)),
        21: lambda x: np.rot90(np.rot90(x, k=-1, axes=(0, 1)), k=1, axes=(0, 2)),
        22: lambda x: np.rot90(np.rot90(x, k=-1, axes=(0, 1)), k=2, axes=(0, 2)),
        23: lambda x: np.rot90(np.rot90(x, k=-1, axes=(0, 1)), k=3, axes=(0, 2)),
        # 24-47: reverse axis 2 (width) first, then apply entry (index - 24); most entries are
        # written as a precomposed slice/transpose instead of flip + nested np.rot90
        24: lambda x: x[:, :, ::-1],  # same as np.flip(x, axis=2)
        25: lambda x: x.transpose(0, 2, 1, *range(3, x.ndim)),
        26: lambda x: x[:, ::-1, :],  # same as np.flip(x, axis=1)
        27: lambda x: np.rot90(x[:, :, ::-1], k=3, axes=(1, 2)),
        28: lambda x: x[::-1, :, :],  # same as np.flip(x, axis=0)
        29: lambda x: np.rot90(x[::-1, :, :], k=1, axes=(1, 2)),
        30: lambda x: x[::-1, ::-1, ::-1],  # same as np.flip(x, axis=(0, 1, 2))
        31: lambda x: np.rot90(x[::-1, :, :], k=-1, axes=(1, 2)),
        32: lambda x: x.transpose(2, 1, 0, *range(3, x.ndim)),
        33: lambda x: x.transpose(1, 2, 0, *range(3, x.ndim))[::-1, :, :],
        34: lambda x: x.transpose(2, 1, 0, *range(3, x.ndim))[::-1, ::-1, :],
        35: lambda x: x.transpose(1, 2, 0, *range(3, x.ndim))[:, ::-1, :],
        36: lambda x: np.rot90(x[:, :, ::-1], k=-1, axes=(0, 2)),
        37: lambda x: x.transpose(1, 2, 0, *range(3, x.ndim))[::-1, ::-1, ::-1],
        38: lambda x: x.transpose(2, 1, 0, *range(3, x.ndim))[:, ::-1, ::-1],
        39: lambda x: x.transpose(1, 2, 0, *range(3, x.ndim))[:, :, ::-1],
        40: lambda x: np.rot90(x[:, :, ::-1], k=1, axes=(0, 1)),
        41: lambda x: x.transpose(2, 0, 1, *range(3, x.ndim))[:, :, ::-1],
        42: lambda x: x.transpose(1, 0, 2, *range(3, x.ndim)),
        43: lambda x: x.transpose(2, 0, 1, *range(3, x.ndim))[::-1, :, :],
        44: lambda x: np.rot90(x[:, :, ::-1], k=-1, axes=(0, 1)),
        45: lambda x: x.transpose(2, 0, 1, *range(3, x.ndim))[:, ::-1, :],
        46: lambda x: x.transpose(1, 0, 2, *range(3, x.ndim))[::-1, ::-1, :],
        47: lambda x: x.transpose(2, 0, 1, *range(3, x.ndim))[::-1, ::-1, ::-1],
    }

    # Copy first so the (possibly view-returning) transformation never aliases the caller's array.
    return transformations[index](cube.copy())


@handle_empty_array("keypoints")
def filter_keypoints_in_holes3d(keypoints: np.ndarray, holes: np.ndarray) -> np.ndarray:
    """Drop every keypoint that falls inside at least one 3D hole.

    A keypoint counts as inside a hole when, on all three axes, ``start <= coord < end``
    (half-open interval, matching slice semantics).

    Args:
        keypoints (np.ndarray): Keypoints of shape (num_keypoints, 3+). Columns 0, 1, 2 are
                               x, y, z; any further columns are carried through untouched.
        holes (np.ndarray): Holes of shape (num_holes, 6), each row laid out as
                           [z1, y1, x1, z2, y2, x2].

    Returns:
        np.ndarray: The keypoints lying outside every hole, with the input's dtype and
            number of columns.

    """
    if holes.size == 0:
        return keypoints

    # (K, 1, 3): coordinates reordered to z, y, x so they line up with the hole layout.
    points_zyx = keypoints[:, [2, 1, 0]][:, np.newaxis, :]
    # (1, N, 3): inclusive lower corners and exclusive upper corners of every hole.
    lower_corners = holes[np.newaxis, :, 0:3]
    upper_corners = holes[np.newaxis, :, 3:6]

    # (K, N): True where keypoint k lies within hole n on all three axes.
    contained = ((points_zyx >= lower_corners) & (points_zyx < upper_corners)).all(axis=-1)
    keep_mask = ~contained.any(axis=1)

    survivors = keypoints[keep_mask]
    if survivors.shape[0] == 0:
        # Hand back an explicitly shaped empty array with the caller's dtype and column count.
        return np.array([], dtype=keypoints.dtype).reshape(0, keypoints.shape[1])
    return survivors


def keypoints_rot90(
    keypoints: np.ndarray,
    k: int,
    axes: tuple[int, int],
    volume_shape: tuple[int, int, int],
) -> np.ndarray:
    """Rotate keypoints the same way ``np.rot90(volume, k=k, axes=axes)`` rotates the volume.

    Column ``i`` of ``keypoints`` must hold the coordinate along axis ``i`` of the volume.
    For a (depth, height, width) volume the first three columns are therefore z, y, x, not
    x, y, z; ``axes`` is used both as the pair of volume axes and as the pair of columns.
    Coordinates are treated as voxel indices: a point at index ``p`` on an axis of length
    ``n`` maps to ``n - 1 - p`` when that axis is reversed.

    Args:
        keypoints (np.ndarray): Array of keypoints with shape (num_keypoints, 3+), with
                               columns ordered like the volume axes (see above). Columns
                               outside ``axes`` are copied unchanged.
        k (int): Number of quarter turns, in the same direction as ``np.rot90`` (from the
            first axis towards the second). Any integer is accepted and reduced modulo 4.
        axes (tuple[int, int]): The two volume axes spanning the rotation plane.
        volume_shape (tuple[int, int, int]): Shape of the volume (depth, height, width)
            BEFORE the rotation.

    Returns:
        np.ndarray: Rotated keypoints with same shape and dtype as input. When ``k == 0`` or
            ``keypoints`` is empty the input array itself is returned; otherwise a new array.

    """
    if k == 0 or len(keypoints) == 0:
        return keypoints

    # Python's % with a positive modulus already yields a value in [0, 3], even for negative k.
    k %= 4

    result = keypoints.copy()

    axis_a, axis_b = axes
    size_a = volume_shape[axis_a]
    size_b = volume_shape[axis_b]

    # Read from the untouched input so writes into `result` cannot affect later reads.
    coords_a = keypoints[:, axis_a]
    coords_b = keypoints[:, axis_b]

    if k == 1:  # 90 degrees, np.rot90 direction
        result[:, axis_a] = (size_b - 1) - coords_b
        result[:, axis_b] = coords_a
    elif k == 2:  # 180 degrees: both axes reversed, no swap
        result[:, axis_a] = (size_a - 1) - coords_a
        result[:, axis_b] = (size_b - 1) - coords_b
    elif k == 3:  # 270 degrees, np.rot90 direction
        result[:, axis_a] = coords_b
        result[:, axis_b] = (size_a - 1) - coords_a

    return result


@handle_empty_array("keypoints")
def transform_cube_keypoints(
    keypoints: np.ndarray,
    index: int,
    volume_shape: tuple[int, int, int],
) -> np.ndarray:
    """Transform keypoints so they follow ``transform_cube(volume, index)``.

    Internally the first three columns are reordered from (x, y, z) to (z, y, x), so that
    column i indexes volume axis i. The points are then reflected along the width axis for
    indices 24-47 and rotated with ``keypoints_rot90``, using the same axis/quarter-turn
    sequence as ``transform_cube``. Finally the columns are reordered back to (x, y, z).

    Args:
        keypoints (np.ndarray): Array of keypoints with shape (num_keypoints, 3+).
                               The first three columns are x, y, z coordinates; any
                               remaining columns are carried through unchanged.
        index (int): Integer from 0 to 47 specifying which transformation to apply.
        volume_shape (tuple[int, int, int]): Shape of the volume (depth, height, width)
            before the transformation.

    Returns:
        np.ndarray: Transformed keypoints with same shape as input, first three columns
            again in x, y, z order. The caller's array is not modified.

    Raises:
        ValueError: If ``index`` is outside [0, 47]. NOTE(review): ``handle_empty_array``
            presumably short-circuits empty input before this check runs - confirm.

    """
    if not (0 <= index < 48):
        raise ValueError("Index must be between 0 and 47")

    # Work on a copy so the caller's array (including extra columns) is never mutated.
    working_points = keypoints.copy()

    # Reorder the first three columns from (x, y, z) to (z, y, x) so that column i lines up
    # with volume axis i (depth, height, width), which is what keypoints_rot90 expects.
    xyz = working_points[:, :3]  # first 3 columns, still in (x, y, z) order
    xyz = xyz[:, [2, 1, 0]]  # (x, y, z) -> (z, y, x); fancy indexing makes a copy
    working_points[:, :3] = xyz  # write the reordered coordinates back

    current_shape = volume_shape

    # Indices 24-47: reflect along the width axis first, mirroring transform_cube's x[:, :, ::-1]
    if index >= 24:
        working_points[:, 2] = current_shape[2] - 1 - working_points[:, 2]  # column 2 now holds x

    rotation_index = index % 24

    # Apply the same rotation sequence as transform_cube entries 0-23. The shape passed to the
    # second keypoints_rot90 call is the shape AFTER the first rotation (a quarter turn swaps axes).
    if rotation_index < 4:
        # 0-3: k quarter turns in the (H, W) plane, i.e. about axis 0
        result = keypoints_rot90(working_points, k=rotation_index, axes=(1, 2), volume_shape=current_shape)
    elif rotation_index < 8:
        # 4-7: half turn in the (D, W) plane, then quarter turns in the (H, W) plane.
        # A half turn keeps the shape, so volume_shape (== current_shape) is still correct.
        temp = keypoints_rot90(working_points, k=2, axes=(0, 2), volume_shape=current_shape)
        result = keypoints_rot90(temp, k=rotation_index - 4, axes=(1, 2), volume_shape=volume_shape)
    elif rotation_index < 16:
        # 8-15: quarter turn in the (D, W) plane (k=1 for 8-11, k=3 == -1 for 12-15), which
        # swaps D and W, then k = 0..3 quarter turns in the (D, H) plane
        if rotation_index < 12:
            temp = keypoints_rot90(working_points, k=1, axes=(0, 2), volume_shape=current_shape)
            temp_shape = (current_shape[2], current_shape[1], current_shape[0])
            result = keypoints_rot90(temp, k=rotation_index - 8, axes=(0, 1), volume_shape=temp_shape)
        else:
            temp = keypoints_rot90(working_points, k=3, axes=(0, 2), volume_shape=current_shape)
            temp_shape = (current_shape[2], current_shape[1], current_shape[0])
            result = keypoints_rot90(temp, k=rotation_index - 12, axes=(0, 1), volume_shape=temp_shape)
    elif rotation_index < 20:
        # 16-19: quarter turn in the (D, H) plane (swaps D and H), then quarter turns in the (D, W) plane
        temp = keypoints_rot90(working_points, k=1, axes=(0, 1), volume_shape=current_shape)
        temp_shape = (current_shape[1], current_shape[0], current_shape[2])
        result = keypoints_rot90(temp, k=rotation_index - 16, axes=(0, 2), volume_shape=temp_shape)
    else:
        # 20-23: as 16-19, but the first quarter turn uses k=3 (== -1)
        temp = keypoints_rot90(working_points, k=3, axes=(0, 1), volume_shape=current_shape)
        temp_shape = (current_shape[1], current_shape[0], current_shape[2])
        result = keypoints_rot90(temp, k=rotation_index - 20, axes=(0, 2), volume_shape=temp_shape)

    # Convert back from (z, y, x) to (x, y, z) for the first 3 columns only
    xyz = result[:, :3]
    xyz = xyz[:, [2, 1, 0]]  # (z, y, x) -> (x, y, z)
    result[:, :3] = xyz

    return result
