#!/usr/bin/env python3
"""
Facial Recognition Accuracy Improvement Script
This script helps improve the accuracy of the facial recognition system
by retraining the model with better quality images.
"""

import cv2
import numpy as np
import os
import json
from livestream import train_face_recognizer, preprocess_face

def show_user_info():
    """Print a table of registered users and how many face images each has.

    Reads ``user_data.json`` from the current working directory and, for each
    user ID it contains, counts the ``.jpg`` files inside
    ``registered_faces/<user_id>``.

    The listing is best-effort: an entry without a usable ``"name"`` field is
    shown as ``Unknown`` instead of aborting the whole table, and a missing,
    unreadable or malformed JSON file is reported on stdout instead of raising.

    Returns:
        None. All output is written to stdout.
    """
    print("\n" + "="*60)
    print("REGISTERED USERS")
    print("="*60)

    user_data = None
    try:
        with open("user_data.json", "r") as f:
            user_data = json.load(f)
    except FileNotFoundError:
        print("No user data found.")
    except (OSError, ValueError) as exc:
        # json.JSONDecodeError is a ValueError subclass; OSError covers
        # permission problems and the path being a directory.
        print(f"Could not read user data: {exc}")

    if isinstance(user_data, dict):
        for user_id, info in user_data.items():
            # Tolerate records that lack a name (or are not objects at all).
            name = "Unknown"
            if isinstance(info, dict):
                name = str(info.get("name") or "Unknown")

            face_dir = os.path.join("registered_faces", user_id)
            image_count = 0
            # isdir (not exists): listdir would raise if the path is a file.
            if os.path.isdir(face_dir):
                image_count = sum(
                    1 for entry in os.listdir(face_dir) if entry.endswith('.jpg')
                )

            print(f"ID: {user_id:>6} | Name: {name:<25} | Images: {image_count:>3}")

    print("="*60)

def collect_better_training_data(user_id):
    """Capture up to 150 grayscale face crops from the webcam for one user.

    Faces are detected with the module-level ``face_cascade`` (this file
    assigns it only inside its ``__main__`` guard) and every detection larger
    than 100x100 pixels is written to
    ``registered_faces/<user_id>/high_quality_NNN.jpg``. Numbering starts at
    000 on every call, so files from an earlier session with the same index
    are overwritten.

    Args:
        user_id: ID of an already-registered user. It may be a string or an
            int; it is converted with ``str`` before being used as a folder
            name and as a key into ``user_data.json``.

    Returns:
        False if the user's folder does not exist or the camera cannot be
        opened. True once a capture session has run, even if it was stopped
        early or captured nothing.
    """
    print(f"\nCollecting training data for User {user_id}")

    # Normalise once so that int IDs work for both the path and the JSON key.
    user_key = str(user_id)

    # Check if user exists
    user_folder = os.path.join("registered_faces", user_key)
    if not os.path.exists(user_folder):
        print(f"User folder {user_folder} not found!")
        return False

    # Get user name (best-effort: the name is only used for display)
    user_name = "Unknown"
    if os.path.exists("user_data.json"):
        try:
            with open("user_data.json", "r") as f:
                user_data = json.load(f)
        except (OSError, ValueError):
            user_data = {}
        record = user_data.get(user_key) if isinstance(user_data, dict) else None
        if isinstance(record, dict):
            user_name = record.get("name", "Unknown")

    print(f"User: {user_name}")
    print("Instructions:")
    print("1. Position yourself 2-3 feet from the camera")
    print("2. Ensure good lighting (face should be well-lit)")
    print("3. Move your head slowly in different directions")
    print("4. Include variations: straight, left, right, up, down")
    print("5. If you wear glasses, include some images without them")
    print("6. Press 'q' to stop early")

    input("\nPress Enter when ready to start...")

    cap = cv2.VideoCapture(0)
    if not cap.isOpened():
        cap.release()
        print("Could not open camera.")
        return False

    count = 0
    target_count = 150  # Increased target for better accuracy
    stop_requested = False

    try:
        while count < target_count:
            ret, frame = cap.read()
            if not ret:
                break

            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            faces = face_cascade.detectMultiScale(
                gray, scaleFactor=1.1, minNeighbors=5, minSize=(100, 100)
            )

            for (x, y, w, h) in faces:
                # Only save high-quality (sufficiently large) faces
                if w <= 100 or h <= 100:
                    continue

                face_img = gray[y:y + h, x:x + w]

                # Draw rectangle and counter
                cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 0), 2)
                cv2.putText(frame, f"Capturing: {count}/{target_count}", (x, y - 10),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)

                # Save high-quality image; count it only if it reached disk.
                img_path = os.path.join(user_folder, f"high_quality_{count:03d}.jpg")
                if not cv2.imwrite(img_path, face_img):
                    print(f"Failed to save {img_path}; stopping capture.")
                    stop_requested = True
                    break
                count += 1

                # Delay to avoid near-duplicate frames. Most of the loop's time
                # is spent here, so the 'q' key has to be honoured here too.
                if cv2.waitKey(200) & 0xFF == ord('q'):
                    stop_requested = True
                    break

                if count >= target_count:
                    break

            # Show instructions on screen
            cv2.putText(frame, "Move head slowly in different directions", (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
            cv2.putText(frame, f"Captured: {count}/{target_count}", (10, 60),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
            cv2.putText(frame, "Press 'q' to stop", (10, 90),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)

            cv2.imshow("High-Quality Training Data Collection", frame)
            if stop_requested or cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Release the device and close the preview even if a frame failed.
        cap.release()
        cv2.destroyAllWindows()

    print(f"\nTraining data collection complete! Captured {count} high-quality images.")
    return True

def test_recognition_accuracy():
    """Show a live webcam preview that labels each detected face.

    Loads the LBPH model from ``face_recognizer.yml``, detects faces with the
    module-level ``face_cascade`` (this file assigns it only inside its
    ``__main__`` guard), passes each crop through ``preprocess_face`` and asks
    the recognizer for a prediction. A prediction whose confidence value is
    below 120 is drawn in green with the user's name; anything else, including
    a prediction error, is drawn in red. The loop ends when 'q' is pressed or
    a frame cannot be read.

    User names come from ``user_data.json``, which is read once before the
    preview starts.

    Returns:
        None. Returns early, after printing a message, when no trained model
        file exists.
    """
    print("\nTesting recognition accuracy...")
    print("Position yourself in front of the camera and the system will attempt recognition.")
    print("Press 'q' to stop testing.")

    if not os.path.exists("face_recognizer.yml"):
        print("No trained model found! Please train the model first.")
        return

    recognizer = cv2.face.LBPHFaceRecognizer_create()
    recognizer.read("face_recognizer.yml")

    # Load the ID -> name mapping once instead of re-reading the file for
    # every face in every frame. Names are display-only, so a bad file just
    # means every face is labelled "Unknown".
    user_data = {}
    if os.path.exists("user_data.json"):
        try:
            with open("user_data.json", "r") as f:
                user_data = json.load(f)
        except (OSError, ValueError) as exc:
            print(f"Could not read user data: {exc}")
            user_data = {}
    if not isinstance(user_data, dict):
        user_data = {}

    cap = cv2.VideoCapture(0)

    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            faces = face_cascade.detectMultiScale(
                gray, scaleFactor=1.1, minNeighbors=5, minSize=(80, 80)
            )

            for (x, y, w, h) in faces:
                face_img = preprocess_face(gray[y:y + h, x:x + w])

                try:
                    user_id, confidence = recognizer.predict(face_img)
                except Exception as e:
                    # Deliberately broad: keep the preview running and show
                    # the failure on the frame instead of crashing.
                    color = (0, 0, 255)
                    status = f"Error: {str(e)}"
                else:
                    # Get user name
                    record = user_data.get(str(user_id))
                    user_name = "Unknown"
                    if isinstance(record, dict):
                        user_name = record.get("name", "Unknown")

                    # Color coding based on confidence
                    if confidence < 120:  # Good recognition
                        color = (0, 255, 0)
                        status = f"Recognized: {user_name} (ID: {user_id})"
                    else:
                        color = (0, 0, 255)
                        status = f"Unknown (Confidence: {confidence:.1f})"

                cv2.rectangle(frame, (x, y), (x + w, y + h), color, 2)
                cv2.putText(frame, status, (x, y - 10),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)

            cv2.imshow("Recognition Accuracy Test", frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Always free the camera and close the window, even on an error.
        cap.release()
        cv2.destroyAllWindows()

def _prompt(message):
    """Read one stripped line from stdin; return None on EOF or Ctrl+C."""
    try:
        return input(message).strip()
    except (EOFError, KeyboardInterrupt):
        # Finish the interrupted prompt line so later output starts cleanly.
        print()
        return None


def main():
    """Run the interactive menu for the accuracy improvement tool.

    Repeatedly prints the options and dispatches on the user's choice:
    1 lists users, 2 collects new images for a user ID, 3 retrains the model
    via ``train_face_recognizer``, 4 runs the live recognition test and
    5 exits. Any other input prints an error and shows the menu again.

    End-of-input or Ctrl+C at either prompt is treated like choosing "Exit",
    so the tool ends cleanly instead of with a traceback.

    Returns:
        None.
    """
    print("Facial Recognition Accuracy Improvement Tool")
    print("=" * 50)

    while True:
        print("\nOptions:")
        print("1. Show registered users")
        print("2. Collect better training data for a user")
        print("3. Retrain the model")
        print("4. Test recognition accuracy")
        print("5. Exit")

        choice = _prompt("\nSelect an option (1-5): ")

        if choice is None or choice == "5":
            print("Exiting...")
            break

        if choice == "1":
            show_user_info()

        elif choice == "2":
            user_id = _prompt("Enter user ID to improve: ")
            if user_id is None:
                print("Exiting...")
                break
            # An empty answer simply returns to the menu.
            if user_id:
                collect_better_training_data(user_id)

        elif choice == "3":
            print("\nRetraining the model...")
            train_face_recognizer()
            print("Model retraining complete!")

        elif choice == "4":
            test_recognition_accuracy()

        else:
            print("Invalid choice. Please try again.")

if __name__ == "__main__":
    # Script entry point. The frontal-face Haar cascade is bound to the
    # module-level name ``face_cascade`` because the camera functions in this
    # file read it as a global rather than taking it as a parameter.
    _cascade_file = cv2.data.haarcascades + "haarcascade_frontalface_default.xml"
    face_cascade = cv2.CascadeClassifier(_cascade_file)
    main()



