import cv2
import os
import numpy as np

# Root folder under which each registered user gets an image subfolder.
# exist_ok=True replaces a separate exists()/makedirs() check-then-create pair.
os.makedirs("registered_faces", exist_ok=True)

# Initialize the Haar-cascade frontal-face detector shipped with OpenCV.
face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
# CascadeClassifier does not raise when the XML cannot be loaded; it returns an
# empty classifier that only fails later inside detectMultiScale. Fail fast
# instead, before the user is asked for an ID.
if face_cascade.empty():
    raise SystemExit("Failed to load the Haar cascade face detector.")

# Initialize the default webcam (device index 0) and make sure it actually opened,
# rather than discovering the problem only on the first cap.read().
cap = cv2.VideoCapture(0)
if not cap.isOpened():
    raise SystemExit("Could not open the webcam (device 0).")

# Font for the ID text drawn above each detected face.
font = cv2.FONT_HERSHEY_SIMPLEX

def _is_safe_user_id(candidate):
    """Return True if *candidate* is usable as a single folder name.

    Rejects the empty string, "." and "..", and anything that is not a bare
    path component (i.e. contains a path separator or drive prefix). This
    keeps every user's images inside ``registered_faces``.
    """
    if candidate in ("", ".", ".."):
        return False
    # basename() strips any directory/drive part; a bare component is unchanged.
    return os.path.basename(candidate) == candidate


print("Please enter your user ID for registration:")

# User ID to register. It comes straight from the keyboard and is used as a
# directory name, so re-prompt until it is a safe single path component.
user_id = input("Enter user ID: ").strip()
while not _is_safe_user_id(user_id):
    print("Invalid user ID: it must be non-empty and must not contain path separators.")
    user_id = input("Enter user ID: ").strip()

# Create a subfolder to store the registered user's images (reused if it exists).
user_folder = os.path.join("registered_faces", user_id)
os.makedirs(user_folder, exist_ok=True)

# Number of face images saved so far for this user.
count = 0

print(f"Starting registration for user {user_id}. Please look at the camera...")

# Number of face crops to collect before registration is considered complete.
REGISTRATION_IMAGE_COUNT = 20

# Capture loop: detect faces in each webcam frame and save cropped face images
# until enough are collected, the user presses 'q', or a frame cannot be read.
# The try/finally guarantees the camera and windows are released even if an
# exception (e.g. a failed write or Ctrl+C) escapes the loop.
try:
    while True:
        ret, frame = cap.read()
        if not ret:
            print("Failed to grab frame!")
            break

        # Unannotated copy of the frame. Crops must come from this copy:
        # cropping `frame` would include the rectangles and labels drawn
        # below (for this face and any face processed earlier in the frame).
        clean_frame = frame.copy()

        # Convert the image to grayscale (necessary for face detection)
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)

        # Detect faces in the image
        faces = face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30))

        for (x, y, w, h) in faces:
            # Draw rectangle around face and display the user ID above it
            cv2.rectangle(frame, (x, y), (x + w, y + h), (255, 0, 0), 2)
            cv2.putText(frame, f"ID: {user_id}", (x, y - 10), font, 1, (0, 255, 0), 2, cv2.LINE_AA)

            # Several faces in one frame must not push the total past the
            # target; keep drawing them, but stop saving.
            if count >= REGISTRATION_IMAGE_COUNT:
                continue

            # Save the face crop as a .jpg file in the user folder. imwrite
            # reports failure via its return value, so only count real saves.
            face_img = clean_frame[y:y + h, x:x + w]
            image_path = os.path.join(user_folder, f"face_{count + 1}.jpg")
            if not cv2.imwrite(image_path, face_img):
                raise OSError(f"Could not write face image to {image_path}")
            count += 1

        # Display the annotated frame
        cv2.imshow("Face Registration", frame)

        # Exit if 'q' is pressed
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

        # Stop once enough images have been saved for registration
        if count >= REGISTRATION_IMAGE_COUNT:
            print("Registration complete.")
            break
finally:
    # Release the webcam and close the window
    cap.release()
    cv2.destroyAllWindows()
