# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.  All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging

import numpy as np
from fusion_attention import AttentionMask, FusionAttention
from onnx import helper
from onnx_model import OnnxModel

logger = logging.getLogger(__name__)


class FusionBartAttention(FusionAttention):
    """
    Fuse Bart Attention subgraph into one Attention node.
    """

    def __init__(
        self,
        model: OnnxModel,
        hidden_size: int,
        num_heads: int,
        attention_mask: AttentionMask,
    ):
        super().__init__(model, hidden_size, num_heads, attention_mask)

    def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
        # SkipLayerNormalization has two inputs, and one of them is the root input for attention.
        qkv_nodes = self.model.match_parent_path(
            normalize_node,
            ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
            [1, 1, 0, 0, 0],
        )
        if qkv_nodes is not None:
            (
                add_out,
                matmul_out,
                reshape_qkv,
                transpose_qkv,
                matmul_qkv,
            ) = qkv_nodes
        else:
            logger.debug("fuse_attention: failed to match qkv path")
            return

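        # The SkipLayerNormalization node has two relevant inputs: the output of the QKV path matched above
        # and the residual (root) input that also feeds the Q/K/V projections. Keep only the latter.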
        other_inputs = []
        for input_ in normalize_node.input:
            if input_ not in output_name_to_node:
                continue
            if input_ == qkv_nodes[0].output[0]:
                continue
            other_inputs.append(input_)
        if len(other_inputs) != 1:
            return
        root_input = other_inputs[0]

        # Sometimes the input name to the attention MatMul nodes does not match the input name to the end
        # SkipLayerNormalization node (saved in root_input). We find the true input name to the MatMul nodes
        # by locating the initial SkipLayerNormalization node and checking which of its outputs has a MatMul
        # node among its children.
        """
                                        root_input
                    +---------------------------------------------------+
                    |                                                   |
                    |                                                   |
        SkipLayerNormalization --> Attention --> MatMul --> SkipLayerNormalization
        """
        skip_layernorm = output_name_to_node[root_input]
        # For some attention blocks, root_input is produced by an Add or Clip node instead of a
        # SkipLayerNormalization node. In that case, use its first child (the normalization node) instead.
        if skip_layernorm.op_type in {"Add", "Clip"}:
            skip_layernorm = self.model.get_children(skip_layernorm)[0]
        for output in skip_layernorm.output:
            if not output:
                continue
            children = input_name_to_nodes[output]
            children_types = [child.op_type for child in children]
            if children_types.count("MatMul") >= 1:
                root_input = output
                break

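        # Collect graph input and output names so that candidate past/present key/value tensor names can be
        # validated against them.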
        graph_input_names = {node.name for node in self.model.graph().input}
        graph_output_names = {node.name for node in self.model.graph().output}

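        # Candidate value (V) paths: projection without a Concat over past_v, projection with past_v
        # concatenated, or (OpenAI cross attention) past_v reshaped and consumed directly.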
        v_nodes_past_or_present = self.model.match_parent_path(
            matmul_qkv,
            ["Transpose", "Reshape", "Add", "MatMul"],
            [1, 0, 0, None],
        )
        v_nodes_with_past = self.model.match_parent_path(
            matmul_qkv,
            ["Concat", "Transpose", "Reshape", "Add", "MatMul"],
            [1, 1, 0, 0, None],
        )
        v_nodes_past_only_oai = self.model.match_parent_path(
            matmul_qkv,
            ["Transpose", "Reshape", "Reshape", "Transpose"],
            [1, 0, 0, 0],
        )
        past_v, present_v = "", ""
        v_nodes, add_v, matmul_v = [], None, None
        if v_nodes_past_or_present is not None:
            v_nodes = v_nodes_past_or_present
            (transpose_v, reshape_v, add_v, matmul_v) = v_nodes

            # Find past_v input name
            start_child_nodes = input_name_to_nodes[add_v.output[0]]
            for start_child_node in start_child_nodes:
                if start_child_node.op_type == "Concat":
                    concat_v_nodes = self.model.match_parent_path(
                        start_child_node,
                        ["Reshape", "Transpose"],
                        [0, 0],
                    )
                    if concat_v_nodes is not None:
                        past_v = concat_v_nodes[-1].input[0]
                    start_child_nodes = input_name_to_nodes[start_child_node.output[0]]
                    break

            # Find present_v output name
            for start_child_node in start_child_nodes:
                start_grandchild_nodes = input_name_to_nodes[start_child_node.output[0]]
                for start_grandchild_node in start_grandchild_nodes:
                    if start_grandchild_node.output[0] in graph_output_names:
                        present_v = start_grandchild_node.output[0]
                        break
                if present_v != "":
                    break
        elif v_nodes_with_past is not None:
            v_nodes = v_nodes_with_past
            (concat_v, transpose_v, reshape_v, add_v, matmul_v) = v_nodes
            past_v = concat_v.input[0]
            present_v = concat_v.output[0]
        elif matmul_qkv.input[1] in graph_input_names:
            # Hugging Face's cross-attention where past_v is used directly as value
            past_v = matmul_qkv.input[1]
        elif v_nodes_past_only_oai is not None:
            # OpenAI's cross-attention where past_v is used directly as value
            v_nodes = v_nodes_past_only_oai
            past_v = v_nodes[-1].input[0]
        else:
            logger.debug("fuse_attention: failed to match v path")
            return
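        # Keep past_v/present_v only if they are actual graph inputs/outputs.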
        past_v = past_v if past_v in graph_input_names else ""
        present_v = present_v if present_v in graph_output_names else ""

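        # Match the QK path, with or without an additive attention mask (Add before Softmax).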
        qk_nodes_no_mask = self.model.match_parent_path(matmul_qkv, ["Softmax", "MatMul"], [0, 0])
        qk_nodes_with_mask = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "MatMul"], [0, 0, 0])
        qk_nodes, add_qk = [], None
        if qk_nodes_no_mask is not None:
            _, matmul_qk = qk_nodes_no_mask
            qk_nodes = qk_nodes_no_mask
        elif qk_nodes_with_mask is not None:
            _, add_qk, matmul_qk = qk_nodes_with_mask
            qk_nodes = qk_nodes_with_mask
        else:
            logger.debug("fuse_attention: failed to match qk path")
            return

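        # Candidate query (Q) paths: Hugging Face applies the scaling Mul before the Reshape/Transpose,
        # while the OpenAI export applies it last, right before the QK MatMul.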
        q_nodes_hf = self.model.match_parent_path(
            matmul_qk,
            ["Transpose", "Reshape", "Mul", "Add", "MatMul"],
            [0, 0, 0, 0, 1],
        )
        q_nodes_oai = self.model.match_parent_path(
            matmul_qk,
            ["Mul", "Transpose", "Reshape", "Add", "MatMul"],
            [0, 0, 0, 0, 1],
        )
        q_nodes = []
        if q_nodes_hf is not None:
            q_nodes = q_nodes_hf
            (transpose_q, reshape_q, mul_q, add_q, matmul_q) = q_nodes
        elif q_nodes_oai is not None:
            q_nodes = q_nodes_oai
            (mul_q, transpose_q, reshape_q, add_q, matmul_q) = q_nodes
        else:
            logger.debug("fuse_attention: failed to match q path")
            return

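        # Candidate key (K) paths: Hugging Face without past_k, Hugging Face with past_k via Concat,
        # OpenAI with the scaling Mul on the key, and OpenAI cross attention where past_k is consumed
        # directly.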
        k_nodes_no_past_hf = self.model.match_parent_path(
            matmul_qk,
            ["Transpose", "Reshape", "MatMul"],
            [1, 0, 0],
        )
        k_nodes_with_past_hf = self.model.match_parent_path(
            matmul_qk,
            ["Transpose", "Concat", "Transpose", "Reshape", "MatMul"],
            [1, 0, 1, 0, 0],
        )
        k_nodes_past_or_present_oai = self.model.match_parent_path(
            matmul_qk,
            ["Mul", "Transpose", "Reshape", "MatMul"],
            [1, 0, 0, 0],
        )
        k_nodes_past_only_oai = self.model.match_parent_path(
            matmul_qk,
            ["Mul", "Transpose", "Reshape", "Reshape", "Transpose"],
            [1, 0, 0, 0, 0],
        )
        past_k, present_k = "", ""
        k_nodes, add_k, matmul_k = [], None, None
        if k_nodes_no_past_hf is not None:
            k_nodes = k_nodes_no_past_hf
            (transpose_k, reshape_k, matmul_k) = k_nodes

            # Find present_k output name
            transpose_k_nodes = input_name_to_nodes[reshape_k.output[0]]
            for transpose_k_node in transpose_k_nodes:
                if transpose_k_node.output[0] in graph_output_names:
                    present_k = transpose_k_node.output[0]
                    break
        elif k_nodes_with_past_hf is not None:
            k_nodes = k_nodes_with_past_hf
            (_, concat_k, transpose_k, reshape_k, matmul_k) = k_nodes
            past_k = concat_k.input[0]
            present_k = concat_k.output[0]
        elif output_name_to_node[matmul_qk.input[1]].input[0] in graph_input_names:
            # Hugging Face's cross-attention where past_k is used directly as key
            k_nodes = [output_name_to_node[matmul_qk.input[1]]]
            past_k = k_nodes[0].input[0]
        elif k_nodes_past_or_present_oai is not None:
            k_nodes = k_nodes_past_or_present_oai
            (_, transpose_k, reshape_k, matmul_k) = k_nodes

            # Find past_k input name
            start_child_nodes = input_name_to_nodes[matmul_k.output[0]]
            for start_child_node in start_child_nodes:
                if start_child_node.op_type == "Concat":
                    concat_k_nodes = self.model.match_parent_path(
                        start_child_node,
                        ["Reshape", "Transpose"],
                        [0, 0],
                    )
                    if concat_k_nodes is not None:
                        past_k = concat_k_nodes[-1].input[0]
                    start_child_nodes = input_name_to_nodes[start_child_node.output[0]]
                    break

            # Find present_k output name
            for start_child_node in start_child_nodes:
                start_grandchild_nodes = input_name_to_nodes[start_child_node.output[0]]
                for start_grandchild_node in start_grandchild_nodes:
                    if start_grandchild_node.output[0] in graph_output_names:
                        present_k = start_grandchild_node.output[0]
                        break
                if present_k != "":
                    break
        elif k_nodes_past_only_oai is not None:
            # OpenAI's cross-attention where past_k is used directly as key
            k_nodes = k_nodes_past_only_oai
            past_k = k_nodes[-1].input[0]
        else:
            logger.debug("fuse_attention: failed to match k path")
            return
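        # Keep past_k/present_k only if they are actual graph inputs/outputs.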
        past_k = past_k if past_k in graph_input_names else ""
        present_k = present_k if present_k in graph_output_names else ""

        if matmul_k is not None and add_k is None:
            # Create empty Add node for attention graph
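            # The zero-valued bias mirrors the dimensions and element type of the V bias so the fused node
            # still receives a bias input for K.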
            add_v_tensor = self.model.get_initializer(add_v.input[0])
            bias_dim = add_v_tensor.dims[0]
            dtype = add_v_tensor.data_type
            empty_bias_name = "empty_bias"
            empty_tensor = self.model.get_initializer(empty_bias_name)
            if empty_tensor is None:
                self.add_initializer(
                    empty_bias_name,
                    dtype,
                    dims=[bias_dim],
                    vals=np.array([0.0] * bias_dim, dtype=helper.tensor_dtype_to_np_dtype(dtype)),
                )

            add_name = self.model.create_node_name("Add")
            add_k = helper.make_node("Add", [empty_bias_name, matmul_k.output[0]], [reshape_k.name], add_name)

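        # Classify the attention pattern by how many distinct root inputs feed the Q/K/V projections.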
        three_root_inputs = bool(past_k) and bool(past_v) and matmul_k is None and matmul_v is None
        one_root_input = (
            not three_root_inputs
            and matmul_q.input[0] == root_input
            and matmul_k.input[0] == root_input
            and matmul_v.input[0] == root_input
        )
        two_root_inputs = (
            not three_root_inputs
            and matmul_q.input[0] == root_input
            and matmul_k.input[0] == matmul_v.input[0]
            and matmul_k.input[0] != matmul_q.input[0]
        )

        # There are 5 types of attention:
        # 1) Encoder attention with one_root_input=True and qk_nodes=qk_nodes_no_mask
        # 2) Decoder self attention with one_root_input=True and qk_nodes=qk_nodes_with_mask
        # 3) Decoder cross attention with two_root_inputs=True and qk_nodes=qk_nodes_no_mask
        # 4) Decoder self attention with past with one_root_input=True and qk_nodes=qk_nodes_with_mask and past_k=past_decoder_key and past_v=past_decoder_value
        # 5) Decoder cross attention with past with three_root_inputs=True and qk_nodes=qk_nodes_no_mask
        encoder_attention = one_root_input and qk_nodes == qk_nodes_no_mask
        decoder_self_attention = one_root_input and qk_nodes == qk_nodes_with_mask
        decoder_cross_attention = two_root_inputs and qk_nodes == qk_nodes_no_mask
        decoder_self_attention_with_past = decoder_self_attention and bool(past_k) and bool(past_v)
        decoder_cross_attention_with_past = three_root_inputs and qk_nodes == qk_nodes_no_mask

        # For decoder self-attentions, the attention mask needs to be included in the attention node
        causal_mask = qk_nodes == qk_nodes_with_mask
        mask_nodes = []
        if causal_mask:
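            # Candidate mask subgraphs: BART (Where), Whisper Hugging Face export (Slice over an expanded
            # Where), Whisper OpenAI export (Slice whose end index comes from Shape/Gather), and a simplified
            # Slice-of-Slice pattern used in unit tests.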
            mask_nodes_bart = self.model.match_parent_path(
                add_qk,
                ["Where"],
                [1],
            )
            mask_nodes_whisper_hf = self.model.match_parent_path(
                add_qk,
                ["Slice", "Expand", "Where"],
                [1, 0, 1],
            )
            mask_nodes_whisper_oai = self.model.match_parent_path(
                add_qk,
                ["Slice", "Unsqueeze", "Gather", "Shape", "Add"],
                [1, 2, 0, 0, 0],
            )
            mask_nodes_whisper_oai_unit_test = self.model.match_parent_path(
                add_qk,
                ["Slice", "Slice"],
                [1, 0],
            )
            if mask_nodes_whisper_hf is not None:
                mask_nodes = mask_nodes_whisper_hf
            elif mask_nodes_whisper_oai is not None:
                mask_nodes = mask_nodes_whisper_oai
            elif mask_nodes_whisper_oai_unit_test is not None:
                mask_nodes = mask_nodes_whisper_oai_unit_test
            elif mask_nodes_bart is not None:
                mask_nodes = mask_nodes_bart
            else:
                logger.debug("fuse_attention: failed to match mask nodes")
                return
            assert len(mask_nodes) > 0

        if (
            encoder_attention
            or decoder_self_attention
            or decoder_cross_attention
            or decoder_self_attention_with_past
            or decoder_cross_attention_with_past
        ):
            attention_last_node = reshape_qkv
            num_heads, hidden_size = self.get_num_heads_and_hidden_size(reshape_q)

            if num_heads <= 0 or hidden_size <= 0 or (hidden_size % num_heads) != 0:
                logger.debug("fuse_attention: failed to detect num_heads or hidden_size")
                return

            new_node = None
            if decoder_self_attention_with_past or decoder_cross_attention or decoder_cross_attention_with_past:
                # Note: Decoder attention with past key and past value is fused as multi-head attention
                # rather than attention because multi-head attention supports separate past key and past
                # value whereas attention supports concatenated past key and past value.
                new_node = (
                    self.create_multihead_attention_node(
                        q_matmul=matmul_q,
                        k_matmul=matmul_k if decoder_cross_attention or decoder_self_attention_with_past else past_k,
                        v_matmul=matmul_v if decoder_cross_attention or decoder_self_attention_with_past else past_v,
                        q_add=add_q,
                        k_add=add_k if decoder_cross_attention or decoder_self_attention_with_past else None,
                        v_add=add_v if decoder_cross_attention or decoder_self_attention_with_past else None,
                        num_heads=num_heads,
                        hidden_size=hidden_size,
                        output=attention_last_node.output[0],
                        unidirectional=causal_mask,
                        past_k=past_k if decoder_self_attention_with_past else "",
                        past_v=past_v if decoder_self_attention_with_past else "",
                        present_k=present_k,
                        present_v=present_v,
                    )
                    if self.use_multi_head_attention
                    else None
                )
            else:
                # Temporarily set multi-head attention flag to false
                use_multi_head_attention_ground_truth = self.use_multi_head_attention
                self.use_multi_head_attention = False
                new_node = self.create_attention_node(
                    mask_index=None,
                    q_matmul=matmul_q,
                    k_matmul=matmul_k,
                    v_matmul=matmul_v,
                    q_add=add_q,
                    k_add=add_k,
                    v_add=add_v,
                    num_heads=num_heads,
                    hidden_size=hidden_size,
                    first_input=root_input,
                    output=attention_last_node.output[0],
                    causal=causal_mask,
                    past_k=past_k,
                    past_v=past_v,
                    present_k=present_k,
                    present_v=present_v,
                )
                self.use_multi_head_attention = use_multi_head_attention_ground_truth
            if new_node is None:
                logger.debug("fuse_attention: failed to create fused node")
                return

            self.nodes_to_add.append(new_node)
            self.node_name_to_graph_name[new_node.name] = self.this_graph_name

            self.nodes_to_remove.extend([attention_last_node, transpose_qkv, matmul_qkv])
            self.nodes_to_remove.extend(qk_nodes)

            # When using multi-head attention, keep MatMul nodes in original graph
            if decoder_self_attention_with_past or decoder_cross_attention or decoder_cross_attention_with_past:
                if len(q_nodes) > 0 and q_nodes[-1].op_type == "MatMul":
                    q_nodes.pop()
                if len(k_nodes) > 0 and k_nodes[-1].op_type == "MatMul":
                    k_nodes.pop()
                if len(v_nodes) > 0 and v_nodes[-1].op_type == "MatMul":
                    v_nodes.pop()
                if self.disable_multi_head_attention_bias:
                    if len(q_nodes) > 0 and q_nodes[-1].op_type == "Add":
                        q_nodes.pop()
                    if len(k_nodes) > 0 and k_nodes[-1].op_type == "Add":
                        k_nodes.pop()
                    if len(v_nodes) > 0 and v_nodes[-1].op_type == "Add":
                        v_nodes.pop()

            self.nodes_to_remove.extend(q_nodes)
            self.nodes_to_remove.extend(k_nodes)
            self.nodes_to_remove.extend(v_nodes)

            # Use prune graph to remove mask nodes since they are shared by all attention nodes.
            self.prune_graph = True
