import tkinter as tk
from tkinter import filedialog, messagebox
from PIL import Image, ImageTk
import cv2
import numpy as np
from openvino import Core
import os
import pickle
import time

# Choose where inference runs: the NPU when OpenVINO reports one,
# otherwise the CPU. This is only the initial choice; it may later be
# switched to CPU if the NPU fails.
ie = Core()
devices = ie.available_devices
if "NPU" in devices:
    device_to_use = "NPU"
else:
    device_to_use = "CPU"
print(f"Folosesc inițial: {device_to_use}")

model_folder = "models"
# Each model needs both its topology (.xml) and its weights (.bin) file.
required_model_files = [
    os.path.join(model_folder, model_name + ext)
    for model_name in ("face-detection-adas-0001", "face-reidentification-retail-0095")
    for ext in (".xml", ".bin")
]
missing_model_files = [path for path in required_model_files if not os.path.exists(path)]
if missing_model_files:
    print("Modelele lipsesc. Descarcă manual din OpenVINO Model Zoo și pune-le în folderul 'models'.")
    # Tell the user exactly which files are absent.
    for missing_path in missing_model_files:
        print(f"  - {missing_path}")
    # exit() is the interactive-shell helper injected by the site module and
    # exits with status 0. Raise SystemExit instead, so a missing model is
    # reported as a failure (status 1) and does not depend on site being loaded.
    raise SystemExit(1)

# Read both model topologies from disk. Without them the application cannot run.
try:
    model_det = ie.read_model(model=os.path.join(model_folder, "face-detection-adas-0001.xml"))
    model_rec = ie.read_model(model=os.path.join(model_folder, "face-reidentification-retail-0095.xml"))
except Exception as e:
    print(f"Eroare la încărcarea modelelor: {e}")
    raise SystemExit(1)

# Compile on the preferred device. A device can appear in available_devices
# and still fail to compile these models. In that case fall back to CPU here
# instead of aborting. The device that actually succeeded is stored back in
# device_to_use.
_compile_devices = [device_to_use] if device_to_use == "CPU" else [device_to_use, "CPU"]
for _device in _compile_devices:
    try:
        compiled_det = ie.compile_model(model_det, _device)
        compiled_rec = ie.compile_model(model_rec, _device)
    except Exception as e:
        # `e` is unbound after the except clause, so keep it for the final report.
        _compile_error = e
        print(f"Eroare la compilarea modelelor pe {_device}: {e}")
    else:
        device_to_use = _device
        break
else:
    # No candidate device could compile both models.
    print(f"Eroare la încărcarea modelelor: {_compile_error}")
    raise SystemExit(1)

# Persistent face database: maps a person's name to their face embedding.
db_file = "face_database.pkl"
database = {}
if os.path.exists(db_file):
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only load a database written by this application, never one obtained
    # from an untrusted source.
    try:
        with open(db_file, "rb") as f:
            loaded = pickle.load(f)
    except (OSError, EOFError, pickle.UnpicklingError, AttributeError,
            ImportError, IndexError, ValueError) as e:
        # A truncated or corrupted file used to crash startup; start empty instead.
        print(f"Nu pot citi baza de date '{db_file}': {e}. Pornesc cu o bază goală.")
    else:
        if isinstance(loaded, dict):
            database = loaded
        else:
            print(f"Baza de date '{db_file}' are un format neașteptat. Pornesc cu o bază goală.")

# Set to True by switch_to_cpu() once inference has been moved off the NPU.
npu_failed = False

def switch_to_cpu():
    """Recompile both models on CPU and permanently disable the NPU path.

    Both models are compiled before any global is modified. If CPU
    compilation raises, compiled_det, compiled_rec, device_to_use and
    npu_failed keep their previous, mutually consistent values.
    """
    global compiled_det, compiled_rec, device_to_use, npu_failed
    cpu_det = ie.compile_model(model_det, "CPU")
    cpu_rec = ie.compile_model(model_rec, "CPU")
    # Commit the new state only after both compilations succeeded.
    compiled_det, compiled_rec = cpu_det, cpu_rec
    device_to_use = "CPU"
    npu_failed = True
    print(f"Comutat pe: {device_to_use} (NPU dezactivat permanent)")

def detect_face(image, threshold=0.5):
    """Return a crop of the first face detected in *image*, or None.

    *image* is an RGB image array; both callers in this file convert with
    cv2.COLOR_BGR2RGB first. face-detection-adas-0001 expects BGR input,
    so the channels are swapped back while building the blob.

    Each detection row is [image_id, label, conf, x_min, y_min, x_max, y_max],
    with coordinates normalized to the image size. The first row whose
    confidence exceeds *threshold* and whose clipped box is non-empty is
    returned as a slice (view) of *image*.

    If inference raises RuntimeError on a non-CPU device, both models are
    recompiled on CPU via switch_to_cpu() and the request is retried once.
    On CPU the error propagates.
    """
    blob = cv2.dnn.blobFromImage(image, size=(672, 384), swapRB=True, ddepth=cv2.CV_8U)
    try:
        outputs = compiled_det.infer_new_request({0: blob})
    except RuntimeError as e:
        # Retrying on the device that just failed is pointless. Only fall
        # back when we are not already on CPU.
        if npu_failed or device_to_use == "CPU":
            raise
        print(f"Eroare NPU: {e}. Comut pe CPU.")
        switch_to_cpu()
        outputs = compiled_det.infer_new_request({0: blob})
    height, width = image.shape[:2]
    scale = np.array([width, height, width, height])
    boxes = outputs[compiled_det.outputs[0].get_any_name()]
    for box in boxes[0][0]:
        if box[0] < 0:
            break  # image_id == -1 marks the end of the valid detections
        if box[2] > threshold:
            # Predicted coordinates can fall slightly outside [0, 1]. Clip them
            # so negative values are not treated as from-the-end slice indices.
            x_min, y_min, x_max, y_max = np.clip(box[3:7] * scale, 0, scale).astype(int)
            if x_max > x_min and y_max > y_min:
                return image[y_min:y_max, x_min:x_max]
    return None

def get_embedding(face):
    """Return the face-reidentification embedding of *face* as a 1-D array.

    *face* is an RGB crop (callers pass crops of RGB images).
    face-reidentification-retail-0095 expects BGR input, so the channels are
    swapped while the crop is resized to the 128x128 network input.

    NOTE(review): embeddings saved before the channel-order fix were computed
    on RGB input. Rebuild face_database.pkl so that stored and new embeddings
    are comparable.

    If inference raises RuntimeError on a non-CPU device, both models are
    recompiled on CPU via switch_to_cpu() and the request is retried once.
    On CPU the error propagates.
    """
    blob = cv2.dnn.blobFromImage(face, size=(128, 128), swapRB=True)
    try:
        result = compiled_rec.infer_new_request({0: blob})
    except RuntimeError as e:
        # Only fall back when there is a different device to fall back to.
        if npu_failed or device_to_use == "CPU":
            raise
        print(f"Eroare NPU: {e}. Comut pe CPU.")
        switch_to_cpu()
        result = compiled_rec.infer_new_request({0: blob})
    return result[compiled_rec.outputs[0].get_any_name()].flatten()

def compare_faces(emb1, emb2):
    """Return the cosine similarity of two embedding vectors as a float.

    The result lies in [-1, 1] (up to rounding). If either vector has zero
    norm, cosine similarity is undefined; 0.0 is returned instead of the NaN
    (and RuntimeWarning) that the plain division would produce.
    """
    denom = np.linalg.norm(emb1) * np.linalg.norm(emb2)
    if denom == 0:
        return 0.0
    return float(np.dot(emb1, emb2) / denom)

# Main window with two titled panels: the left one shows the most recently
# enrolled face, the right one the most recently tested face.
root = tk.Tk()
root.title("Bază de Date Fețe")
root.geometry("600x400")


def _make_panel(title, side):
    """Create a LabelFrame packed on *side* of root, holding one empty Label."""
    frame = tk.LabelFrame(root, text=title)
    frame.pack(side=side, padx=10, pady=10)
    label = tk.Label(frame)
    label.pack()
    return frame, label


frame_db, label_db = _make_panel("Bază de Date", tk.LEFT)
frame_test, label_test = _make_panel("Imagine Test", tk.RIGHT)

def add_multiple_to_database():
    """Enroll the faces from one or more user-selected images into `database`.

    For each selected file, the image is loaded and its first detected face
    is cropped and embedded. A file is skipped when:
    - it cannot be read,
    - no face is found, or
    - its embedding has cosine similarity > 0.9 with an embedding already
      stored (treated as a duplicate).

    Accepted faces are keyed by the file name without its extension; an
    existing entry with the same name is overwritten. Each accepted face is
    previewed in the database panel. When at least one face was added, the
    database is written to `db_file`. A dialog summarizes the counts.
    """
    files = filedialog.askopenfilenames(filetypes=[("Image files", "*.jpg *.png")])
    if not files:
        return
    added_count = 0
    skipped_count = 0
    for file in files:
        img = cv2.imread(file)
        if img is None or img.size == 0:
            skipped_count += 1
            continue
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        face = detect_face(img)
        if face is None or face.size == 0:
            skipped_count += 1
            continue
        embedding = get_embedding(face)
        # splitext strips only the final extension: "ana.maria.jpg" -> "ana.maria".
        # split('.')[0] gave "ana", and "" for file names starting with a dot.
        name = os.path.splitext(os.path.basename(file))[0]
        if any(compare_faces(embedding, existing_emb) > 0.9 for existing_emb in database.values()):
            skipped_count += 1
        else:
            database[name] = embedding
            added_count += 1
            face_img_tk = ImageTk.PhotoImage(Image.fromarray(face).resize((150, 150)))
            label_db.config(image=face_img_tk)
            # Keep a Python reference so the PhotoImage is not garbage-collected.
            label_db.image = face_img_tk
            # Redraw now, so the preview advances during the batch instead of
            # only showing the last face after the loop ends.
            label_db.update_idletasks()
        if device_to_use == "NPU":
            # Throttle only on the NPU: 0.5 s per image, at most 2 images/s.
            # The previous `not npu_failed` check also throttled CPU-only runs.
            # NOTE: this sleeps inside a Tk button callback, so the window does
            # not respond to input while a batch is processed.
            time.sleep(0.5)
    if added_count:
        # Write to a temporary file, then atomically replace the database.
        # A crash mid-write then cannot leave a truncated, unloadable file.
        tmp_file = db_file + ".tmp"
        with open(tmp_file, "wb") as f:
            pickle.dump(database, f)
        os.replace(tmp_file, db_file)
    messagebox.showinfo("Succes", f"{added_count} fețe adăugate, {skipped_count} sărite (fără fețe sau duplicate)!")

def check_face():
    """Identify the face in a user-selected image against `database`.

    Warns and returns when the database is empty. Otherwise the chosen image
    is loaded, its first detected face is embedded and shown in the test
    panel, and the stored entry with the highest cosine similarity is
    reported. It is reported as a match when the similarity exceeds 0.8,
    and otherwise as the closest non-match.
    """
    if not database:
        messagebox.showwarning("Atenție", "Baza de date este goală! Adaugă imagini mai întâi.")
        return
    path = filedialog.askopenfilename(filetypes=[("Image files", "*.jpg *.png")])
    if not path:
        return
    bgr = cv2.imread(path)
    if bgr is None or bgr.size == 0:
        messagebox.showerror("Eroare", "Imaginea nu poate fi încărcată!")
        return
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    face = detect_face(rgb)
    if face is None or face.size == 0:
        messagebox.showerror("Eroare", "Nu s-a detectat față!")
        return

    query = get_embedding(face)
    preview = ImageTk.PhotoImage(Image.fromarray(face).resize((150, 150)))
    label_test.config(image=preview)
    label_test.image = preview  # hold a reference so Tk keeps showing the image

    # Linear scan for the most similar stored embedding.
    best_name, best_score = None, -1
    for candidate, stored in database.items():
        score = compare_faces(query, stored)
        if score > best_score:
            best_name, best_score = candidate, score

    if best_score > 0.8:
        messagebox.showinfo("Rezultat", f"Persoana este: {best_name}\nSimilaritate: {best_score:.2f}")
    else:
        messagebox.showinfo("Rezultat", f"Nu s-a găsit o potrivire.\nCea mai apropiată: {best_name} (similaritate: {best_score:.2f})")

# Action buttons, packed top-to-bottom into the space left between the two
# side panels.
btn_add = tk.Button(root, text="Adaugă Imagini", command=add_multiple_to_database)
btn_test = tk.Button(root, text="Verifică Față", command=check_face)
for _btn in (btn_add, btn_test):
    _btn.pack(pady=5)

# Hand control to the Tk event loop; this returns when the window is closed.
root.mainloop()
