import os
import sys
import numpy as np
import cv2
import time
import base64
from dotenv import load_dotenv

# Populate os.environ from a .env file before any configuration below is read;
# variables already set in the process environment keep their values.
load_dotenv(override=False)

# --- Path Setup ---
# Resolve paths from this file's location (not the working directory) and put
# the vendored plugin on sys.path so its top-level ``src`` package is importable.
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.split(current_dir)[0]
plugin_path = os.path.join(
    project_root,
    'plugins',
    'Silent-Face-Anti-Spoofing-master',
)
if plugin_path not in sys.path:
    sys.path.append(plugin_path)

# --- Plugin Imports ---
from src.anti_spoof_predict import AntiSpoofPredict
from src.generate_patches import CropImage
from src.utility import parse_model_name

# --- Service Configuration ---
# Anti-spoofing weights bundled inside the vendored plugin directory.
MODEL_DIR = os.path.join(plugin_path, 'resources', 'anti_spoof_models')
MODEL_FILE = os.path.join(MODEL_DIR, '4_0_0_300x300_MultiFTNet.pth')
# Minimum model confidence required to accept a face as live. Overridable via
# the SPOOF_THRESHOLD environment variable (which load_dotenv may fill from .env).
SPOOF_THRESHOLD = float(os.getenv("SPOOF_THRESHOLD", "0.95"))

# --- Service Initialization ---
# Build the shared predictor and cropper once at import time. If either
# constructor raises, BOTH handles end up as None, which _is_real() checks for.
try:
    # NOTE(review): the 0 passed to AntiSpoofPredict is presumably a device id
    # -- confirm against the plugin's constructor.
    _model_test, _image_cropper = AntiSpoofPredict(0), CropImage()
    print("[OK] Spoof detection service initialized successfully.")
except Exception as e:
    _model_test = _image_cropper = None
    print(f"[ERROR] Failed to initialize spoof detection service: {e}")

def _decode_image(base64_string):
    """Decode a base64-encoded image string into an OpenCV image array.

    Accepts either a data URL such as ``"data:image/jpeg;base64,<payload>"``
    or a bare base64 payload without any header.

    Args:
        base64_string: The encoded image as a ``str``.

    Returns:
        The array produced by ``cv2.imdecode(..., cv2.IMREAD_COLOR)``, or
        ``None`` if any step of decoding fails.
    """
    try:
        # Drop an optional "data:...;base64," header. Base64 itself never
        # contains a comma, so a header-less payload passes through unchanged
        # (previously it failed tuple unpacking and was rejected).
        encoded = base64_string.split(',', 1)[-1]
        image_bytes = base64.b64decode(encoded)
        nparr = np.frombuffer(image_bytes, np.uint8)
        # imdecode signals undecodable image bytes by returning None; callers
        # check for None, so that result is passed through as-is.
        return cv2.imdecode(nparr, cv2.IMREAD_COLOR)
    except Exception as e:
        print(f"[ERROR] Error during image decoding: {e}")
        return None

def _is_real(image_array):
    """Run the anti-spoofing model on a single decoded image.

    Args:
        image_array: A decoded image, as returned by ``_decode_image``.

    Returns:
        tuple: ``(is_live, score)``. ``is_live`` is a built-in ``bool`` and
        ``score`` is the model's confidence for its predicted label as a
        built-in ``float``. Returns ``(False, 0.0)`` when no face is found or
        when detection/cropping/inference raises, and ``(True, 0.0)`` when the
        service failed to initialize.
    """
    if _model_test is None or _image_cropper is None:
        # NOTE(review): this path fails OPEN -- if initialization failed at
        # import time, every image is reported as live. Confirm that this
        # degradation is acceptable outside development.
        print("⚠️ Spoof detection service is not available.")
        return True, 0.0

    try:
        image_bbox = _model_test.get_bbox(image_array)
        if image_bbox is None:
            print("⚠️ Spoof check: No face detected.")
            return False, 0.0

        # Input size and crop scale are derived from the weights' file name.
        model_name = os.path.basename(MODEL_FILE)
        h_input, w_input, _model_type, scale = parse_model_name(model_name)

        img = _image_cropper.crop(
            org_img=image_array,
            bbox=image_bbox,
            scale=scale,
            out_w=w_input,
            out_h=h_input,
            # Only crop around the face when the model name specifies a scale.
            crop=scale is not None,
        )
        prediction = _model_test.predict(img, MODEL_FILE)

        # Presumably prediction is shaped (1, num_classes), so argmax over the
        # flattened array is also the column index used below -- TODO confirm.
        label = int(np.argmax(prediction))
        # Built-in float (not a numpy scalar) so callers can JSON-serialise it.
        score = float(prediction[0][label])

        # Label 0 is treated as "Real Face"; it must also clear the threshold.
        # NOTE(review): the upstream Silent-Face-Anti-Spoofing test script
        # treats label 1 as real -- confirm the label mapping for this model.
        is_live = label == 0 and score > SPOOF_THRESHOLD
        print(f"Spoof check details: Label={label}, Score={score:.2f}, Threshold={SPOOF_THRESHOLD}, Result={'Pass' if is_live else 'Fail'}")

        return is_live, score

    except Exception as e:
        # Fail CLOSED: an error while detecting, cropping or predicting must not
        # be reported as a live face, otherwise input crafted to make the model
        # raise would bypass the liveness check entirely.
        print(f"[ERROR] Error during spoof detection: {e}")
        return False, 0.0

def check_liveness(base64_image_list):
    """
    Perform a full liveness check on one or more base64-encoded images.

    Every image must pass. Checking stops at the first image that cannot be
    decoded or is judged a spoof. An empty sequence is rejected rather than
    treated as a pass.

    Args:
        base64_image_list: A list or tuple of base64 image strings, in the
            format accepted by ``_decode_image``. Any other single value is
            treated as a one-element list.

    Returns:
        tuple: (is_live: bool, score: float, error_message: str or None).
        On success ``score`` is the score of the last image checked and
        ``error_message`` is None. On failure ``score`` is the failing image's
        score (0.0 for missing or undecodable input) and ``error_message``
        describes the failure.
    """
    if isinstance(base64_image_list, (list, tuple)):
        images = list(base64_image_list)
    else:
        # Handle the case where a single image string might still be passed
        images = [base64_image_list]

    # With no images the loop below would fall straight through to the PASS
    # branch, so an empty request would be accepted as live. Reject it.
    if not images:
        message = "No images provided for spoof check"
        print(f" LIVENESS FAIL: {message}")
        return False, 0.0, message

    final_score = 0.0
    for i, base64_image_string in enumerate(images, start=1):
        decoded_image = _decode_image(base64_image_string)
        if decoded_image is None:
            return False, 0.0, f"Invalid image data for spoof check in image {i}"

        is_live, score = _is_real(decoded_image)
        final_score = score  # Keep the score of the last checked image

        if not is_live:
            message = f"Liveness check failed on image {i}: Spoof attempt detected. Score: {score:.2f}, Threshold: {SPOOF_THRESHOLD}"
            print(f" LIVENESS FAIL: {message}")
            return False, score, message

    print(f" LIVENESS PASS (all images): Final Score: {final_score:.2f}")
    return True, final_score, None