import argparse
import datetime
import json
import logging
import os
import subprocess

import librosa
import torch
from benchmark_helper import setup_logger
from transformers import WhisperConfig, WhisperProcessor

# Module-level logger used by every function in this file.
logger = logging.getLogger(__name__)


def get_args():
    """Parse the command line and derive the settings shared by all benchmark runs.

    Besides the parsed options, the returned namespace carries two derived
    attributes: ``model_size`` (last path component of ``--model-name`` with
    dots replaced by dashes) and ``log_folder`` (``./<model_size>-<precision>``).
    The log folder is created if it does not exist, and ``timeout`` is
    converted from minutes to seconds.

    Returns:
        argparse.Namespace: the parsed and augmented arguments.
    """
    # Each entry is (flags, keyword arguments) in the order the options are registered.
    option_specs = [
        (
            ("-a", "--audio-path"),
            {"type": str, "required": True, "help": "Path to folder of audio files for E2E evaluation"},
        ),
        (("-l", "--language"), {"default": None, "help": "Language of audio file"}),
        (
            ("-t", "--task"),
            {"default": None, "choices": ["transcribe", "translate"], "help": "Task to complete"},
        ),
        (("-w", "--warmup-runs"), {"type": int, "default": 5}),
        (("-n", "--num-runs"), {"type": int, "default": 10}),
        (
            ("--hf-pt-eager",),
            {"default": False, "action": "store_true", "help": "Benchmark in PyTorch without `torch.compile`"},
        ),
        (
            ("--hf-pt-compile",),
            {"default": False, "action": "store_true", "help": "Benchmark in PyTorch with `torch.compile`"},
        ),
        (
            ("--hf-ort-dir-path",),
            {"type": str, "help": "Path to folder containing ONNX models for Optimum + ORT benchmarking"},
        ),
        (("--ort-model-path",), {"type": str, "help": "Path to ONNX model for ORT benchmarking"}),
        (
            ("--model-name",),
            {"type": str, "required": True, "help": "Model name in Hugging Face (e.g. openai/whisper-large-v2)"},
        ),
        (
            ("--precision",),
            {"type": str, "required": True, "choices": ["int8", "fp16", "fp32"], "help": "Precision to run model"},
        ),
        (
            ("--device",),
            {"type": str, "required": True, "choices": ["cpu", "cuda", "rocm"], "help": "Device to benchmark models"},
        ),
        (("--device-id",), {"type": int, "default": 0, "help": "GPU device ID"}),
        (("--verbose",), {"default": False, "action": "store_true", "help": "Print detailed logs"}),
        (
            ("--timeout",),
            {"type": int, "default": 5, "help": "Number of mins to attempt the benchmark before moving on"},
        ),
        (("--tune",), {"default": False, "action": "store_true"}),
    ]

    cli = argparse.ArgumentParser()
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)
    parsed = cli.parse_args()

    # Derive the folder that will hold the logs and the CSV for this model/precision pair
    parsed.model_size = parsed.model_name.split("/")[-1].replace(".", "-")
    parsed.log_folder = f"./{parsed.model_size}-{parsed.precision}"
    os.makedirs(parsed.log_folder, exist_ok=True)

    # The timeout is given in minutes on the command line but consumed in seconds
    parsed.timeout = parsed.timeout * 60

    return parsed


def process_log_file(device_id, log_file, base_results):
    """Parse one benchmark log into CSV-ready result rows.

    The log is scanned line by line. Marker lines select the current pipeline
    step ("load-audio", "feature-extraction" or "process"); metric lines
    ("Latency: ", "Throughput: ", "Generated token length: ") update the
    running values; every line containing "peak=" closes a measurement and
    emits one row.

    The latency, throughput and token-length values are sliced starting at
    ``len(pattern)``, so those lines are expected to begin with the pattern
    and to end with a space followed by a unit (e.g. ``Latency: 0.5 s``).

    Args:
        device_id: Index of the GPU whose entry is read from the ``peak=[...]``
            list on non-CPU memory lines.
        log_file: Path of the log file to parse.
        base_results: Leading columns copied into every row. Its last element
            is used as the audio duration in seconds when computing the
            real-time factor.

    Returns:
        list: One list per memory line, made of ``base_results`` followed by
        token length, load-audio latency/throughput, feature-extraction
        latency/throughput (``-1`` when falsy), latency, per-token latency in
        ms, throughput, memory (MB value divided by 1000) and real-time
        factor (``-1`` when the duration is missing or not positive).
    """
    entries = []

    # Detect steps in speech pipeline
    step = None
    load_audio_pattern = "Load audio: "
    feat_ext_pattern = "Feature extraction: "
    pytorch_pattern = "Evaluating PyTorch..."
    onnxruntime_pattern = "Evaluating ONNX Runtime..."

    load_audio_latency_s, load_audio_throughput_s = None, None
    feat_ext_latency_s, feat_ext_throughput_s = None, None
    token_length, latency_s, per_token_latency_s, per_token_latency_ms = None, None, None, None
    throughput, memory = None, None

    # Detect metrics
    latency_pattern = "Latency: "
    throughput_pattern = "Throughput: "
    token_length_pattern = "Generated token length: "
    memory_pattern = "peak="

    with open(log_file) as f:
        for input_line in f:
            line = input_line.replace("\n", "")

            # Get step in speech recognition pipeline
            if load_audio_pattern in line:
                step = "load-audio"
            elif feat_ext_pattern in line:
                step = "feature-extraction"
            elif pytorch_pattern in line or onnxruntime_pattern in line:
                step = "process"

            # Check metrics
            if latency_pattern in line:
                latency_s = float(line[len(latency_pattern) : line.rfind(" ")])
            elif throughput_pattern in line:
                throughput = float(line[len(throughput_pattern) : line.rfind(" ")])
                # A throughput line closes the load-audio / feature-extraction
                # step: remember its numbers and wait for the next step marker.
                if step == "load-audio":
                    load_audio_latency_s, load_audio_throughput_s = latency_s, throughput
                    step = None
                elif step == "feature-extraction":
                    feat_ext_latency_s, feat_ext_throughput_s = latency_s, throughput
                    step = None
            elif token_length_pattern in line:
                token_length = int(line[len(token_length_pattern) : line.rfind(" ")])
                # Per-token latency is undefined without a latency or with zero
                # generated tokens; record it as missing instead of raising.
                if latency_s is not None and token_length > 0:
                    per_token_latency_s = latency_s / token_length
                    per_token_latency_ms = per_token_latency_s * 1000
                else:
                    per_token_latency_s, per_token_latency_ms = None, None
            elif memory_pattern in line:
                if "CPU" in line:
                    # Example format for log entry:
                    # CPU memory usage: before=1000.0 MB, peak=2000.0 MB
                    memory = float(line[line.rfind("=") + 1 : line.rfind(" MB")]) / 1000
                else:
                    # Example format for log entry:
                    # GPU memory usage: before=[{'device_id': 0, 'name': 'Tesla V100-PCIE-16GB', 'max_used_MB': 1638.875}, {'device_id': 1, 'name': 'Tesla V100-PCIE-16GB', 'max_used_MB': 236.875},  peak=[{'device_id': 0, 'name': 'Tesla V100-PCIE-16GB', 'max_used_MB': 1780.875}, {'device_id': 1, 'name': 'Tesla V100-PCIE-16GB', 'max_used_MB': 236.875}]
                    # The list is printed with single quotes, so swap them for
                    # double quotes before handing it to the JSON parser.
                    peak = line[line.find(memory_pattern) + len(memory_pattern) :].replace("'", '"')
                    usage = json.loads(peak)[device_id]["max_used_MB"]
                    memory = float(usage) / 1000

                # Calculate real-time factor (RTF):
                # RTF = total latency / audio duration
                total_latency = (
                    (load_audio_latency_s if load_audio_latency_s else 0)
                    + (feat_ext_latency_s if feat_ext_latency_s else 0)
                    + (latency_s if latency_s else 0)
                )
                audio_duration = base_results[-1]
                # A missing, zero or negative duration (e.g. a -1 "unknown"
                # sentinel) cannot produce a meaningful ratio, so report -1.
                duration_is_valid = bool(audio_duration) and audio_duration > 0
                rtf = (total_latency / audio_duration) if duration_is_valid else -1
                logger.info(f"Total latency: {total_latency} s")
                logger.info(f"Audio duration: {audio_duration} s")
                logger.info(f"Real-time factor: {rtf}")

                # Append log entry to list of entries
                entry = base_results + [  # noqa: RUF005
                    token_length,
                    load_audio_latency_s,
                    load_audio_throughput_s,
                    feat_ext_latency_s if feat_ext_latency_s else -1,
                    feat_ext_throughput_s if feat_ext_throughput_s else -1,
                    latency_s,
                    per_token_latency_ms,
                    throughput,
                    memory,
                    rtf,
                ]
                entries.append(entry)

    return entries


def save_results(results, filename):
    """Write the collected benchmark rows to a CSV file.

    Args:
        results: List of rows, each holding the fifteen values named in the
            column list below, in that order.
        filename: Path of the CSV file to create.
    """
    import pandas as pd

    df = pd.DataFrame(
        results,
        columns=[
            "Engine",
            "Precision",
            "Device",
            "Audio File",
            "Duration (s)",
            "Token Length",
            "Load Audio Latency (s)",
            "Load Audio Throughput (qps)",
            "Feature Extractor Latency (s)",
            "Feature Extractor Throughput (qps)",
            "Latency (s)",
            "Per Token Latency (ms/token)",
            "Throughput (qps)",
            "Memory (GB)",
            "Real Time Factor (RTF)",
        ],
    )

    # Set column types. "Token Length" uses pandas' nullable integer dtype so a
    # row with a missing token count is written as an empty cell instead of
    # making the cast (and therefore the whole save) fail.
    column_types = {
        "Duration (s)": "float",
        "Token Length": "Int64",
        "Load Audio Latency (s)": "float",
        "Load Audio Throughput (qps)": "float",
        "Feature Extractor Latency (s)": "float",
        "Feature Extractor Throughput (qps)": "float",
        "Latency (s)": "float",
        "Per Token Latency (ms/token)": "float",
        "Throughput (qps)": "float",
        "Memory (GB)": "float",
        "Real Time Factor (RTF)": "float",
    }
    df = df.astype(column_types)

    df.to_csv(filename, index=False)
    logger.info(f"Results saved in {filename}!")


def benchmark(args, benchmark_cmd, engine, audio_file, duration):
    """Run one benchmark command and turn its log into result rows.

    The command's stdout and stderr are redirected to a timestamped log file
    inside ``args.log_folder``; that log is then parsed with
    ``process_log_file``.

    Args:
        args: Parsed arguments; ``log_folder``, ``timeout`` (seconds),
            ``precision``, ``device`` and ``device_id`` are read.
        benchmark_cmd: Argument list passed to ``subprocess.Popen``.
        engine: Label used in the log file name and as the first result column.
        audio_file: Audio file name stored in the result rows.
        duration: Audio duration stored as the last base column.

    Returns:
        list: The rows produced by ``process_log_file`` for this run.
    """
    log_filename = f"{engine}_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}.log"
    log_path = os.path.join(args.log_folder, log_filename)
    with open(log_path, "w") as log_file:
        process = subprocess.Popen(benchmark_cmd, stdout=log_file, stderr=log_file)
        try:
            return_code = process.wait(args.timeout)
        except subprocess.TimeoutExpired:
            logger.warning(f"{engine} benchmark exceeded the {args.timeout} s timeout and was killed")
            process.kill()
            # Reap the killed child so it does not linger as a zombie and has
            # released the log file before the file is closed and parsed.
            process.wait()
        else:
            if return_code != 0:
                logger.warning(f"{engine} benchmark exited with code {return_code}; see {log_path}")

    # Create entries for csv
    logger.info("Gathering data from log files...")
    base_results = [engine, args.precision, args.device, audio_file, duration]
    results = process_log_file(args.device_id, log_path, base_results)

    return results


def _build_benchmark_cmd(args, audio_path, benchmark_type, extra_args=()):
    """Build the argv list that runs `models.whisper.benchmark` on one audio file.

    ``extra_args`` are inserted right after the benchmark type, ahead of the
    options shared by every benchmark type.
    """
    import sys

    return [
        # Re-use the interpreter running this script so the child sees the same
        # environment; fall back to "python" if the path is unavailable.
        sys.executable or "python",
        "-m",
        "models.whisper.benchmark",
        "--audio-path",
        audio_path,
        "--benchmark-type",
        benchmark_type,
        *extra_args,
        "--model-name",
        args.model_name,
        "--precision",
        args.precision,
        "--device",
        args.device,
        "--device-id",
        str(args.device_id),
        "--warmup-runs",
        str(args.warmup_runs),
        "--num-runs",
        str(args.num_runs),
        "--log-folder",
        args.log_folder,
    ]


def main():
    """Benchmark every audio file in ``--audio-path`` with each selected engine.

    For each file the requested engines (PyTorch eager, PyTorch compile,
    Optimum + ORT, ORT) are run through ``benchmark``; all resulting rows are
    written to a timestamped CSV inside the log folder.
    """
    args = get_args()
    setup_logger(args.verbose)
    logger.info(args.__dict__)
    torch.backends.cudnn.benchmark = True

    config = WhisperConfig.from_pretrained(args.model_name)
    processor = WhisperProcessor.from_pretrained(args.model_name)

    # Calculate forced decoder input ids
    hf_forced_decoder_ids = processor.get_decoder_prompt_ids(language=args.language, task=args.task)
    ort_forced_decoder_ids = [config.decoder_start_token_id, *(ids[1] for ids in hf_forced_decoder_ids)]

    # The ids are only forwarded when both a language and a task were given
    forward_decoder_ids = bool(args.language and args.task)
    hf_decoder_input_ids_cmd = ["--decoder-input-ids", str(hf_forced_decoder_ids)] if forward_decoder_ids else []
    ort_decoder_input_ids_cmd = ["--decoder-input-ids", str(ort_forced_decoder_ids)] if forward_decoder_ids else []
    ort_tune_cmd = ["--tune"] if args.tune else []

    # Selected runs as (benchmark type, engine label, log message, type-specific
    # args, trailing args), in the order they are executed for each file.
    selected_runs = []
    if args.hf_pt_eager:
        selected_runs.append(
            ("hf-pt-eager", "pytorch-eager", "Benchmark PyTorch without torch.compile", [], hf_decoder_input_ids_cmd)
        )
    if args.hf_pt_compile:
        selected_runs.append(
            ("hf-pt-compile", "pytorch-compile", "Benchmark PyTorch with torch.compile", [], hf_decoder_input_ids_cmd)
        )
    if args.hf_ort_dir_path:
        selected_runs.append(
            (
                "hf-ort",
                "optimum-ort",
                "Benchmark Optimum + ONNX Runtime",
                ["--hf-ort-dir-path", args.hf_ort_dir_path],
                hf_decoder_input_ids_cmd,
            )
        )
    if args.ort_model_path:
        selected_runs.append(
            (
                "ort",
                "onnxruntime",
                "Benchmark ONNX Runtime",
                ["--ort-model-path", args.ort_model_path],
                ort_decoder_input_ids_cmd + ort_tune_cmd,
            )
        )

    all_results = []
    # os.listdir returns entries in arbitrary order; sort for reproducible output
    for audio_file in sorted(os.listdir(args.audio_path)):
        audio_path = os.path.join(args.audio_path, audio_file)
        if not os.path.isfile(audio_path):
            logger.warning(f"Skipping {audio_path}: not a file")
            continue

        try:
            duration = librosa.get_duration(path=audio_path)
        except Exception as e:
            # Keep benchmarking even when the duration is unknown; -1 marks it as missing
            duration = -1
            logger.warning(f"An error occurred while trying to calculate the audio duration: {e}", exc_info=True)
            logger.warning(
                f"If you get an error that says:\n\tsoundfile.LibsndfileError: Error opening '{audio_file}': File contains data in an unknown format.\nyou may not have installed `ffmpeg` in addition to installing `librosa`."
            )
        logger.info(f"Testing {audio_path}...")

        for benchmark_type, engine, description, type_args, trailing_args in selected_runs:
            benchmark_cmd = _build_benchmark_cmd(args, audio_path, benchmark_type, type_args) + trailing_args
            logger.info(description)
            results = benchmark(args, benchmark_cmd, engine, audio_file, duration)
            all_results.extend(results)

    csv_file = f"{args.model_size}-{args.precision}_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}.csv"
    save_results(all_results, os.path.join(args.log_folder, csv_file))


# Run the benchmark sweep only when this file is executed as a script.
if __name__ == "__main__":
    main()
