import tkinter as tk
from tkinter import filedialog, messagebox
from PIL import Image, ImageTk
import cv2
import numpy as np
from openvino import Core
import os
import pickle

# Inițializare OpenVINO și verificare dispozitive
ie = Core()
devices = ie.available_devices
device_to_use = "CPU"
if "NPU" in devices:
    device_to_use = "NPU"
elif "GPU" in devices:
    device_to_use = "GPU"
print(f"Folosesc: {device_to_use} (NPU = NPU, GPU = GPU, CPU = CPU)")

# Verifică dacă modelele există (descărcate manual)
model_folder = "models"
if not (os.path.exists(os.path.join(model_folder, "face-detection-adas-0001.xml")) and
        os.path.exists(os.path.join(model_folder, "face-detection-adas-0001.bin")) and
        os.path.exists(os.path.join(model_folder, "face-reidentification-retail-0095.xml")) and
        os.path.exists(os.path.join(model_folder, "face-reidentification-retail-0095.bin"))):
    print("Modelele lipsesc. Descarcă manual din OpenVINO Model Zoo și pune-le în folderul 'models'.")
    exit()

# Încărcare modele
try:
    model_det = ie.read_model(model=os.path.join(model_folder, "face-detection-adas-0001.xml"))
    model_rec = ie.read_model(model=os.path.join(model_folder, "face-reidentification-retail-0095.xml"))
    compiled_det = ie.compile_model(model_det, device_to_use)
    compiled_rec = ie.compile_model(model_rec, device_to_use)
except Exception as e:
    print(f"Eroare la încărcarea modelelor: {e}")
    exit()

# Bază de date (fișier pickle pentru stocare)
db_file = "face_database.pkl"
database = {}  # {nume: embedding}

# Încarcă baza de date dacă există
if os.path.exists(db_file):
    with open(db_file, "rb") as f:
        database = pickle.load(f)

def detect_face(image):
    blob = cv2.dnn.blobFromImage(image, size=(672, 384), ddepth=cv2.CV_8U)
    outputs = compiled_det.infer_new_request({0: blob})
    boxes = outputs[compiled_det.outputs[0].get_any_name()]
    for box in boxes[0][0]:
        if box[2] > 0.5:
            x_min, y_min, x_max, y_max = (box[3:7] * np.array([image.shape[1], image.shape[0], image.shape[1], image.shape[0]])).astype(int)
            return image[y_min:y_max, x_min:x_max]
    return None

def get_embedding(face):
    # Ajustăm dimensiunea la 128x128 conform cerințelor modelului
    blob = cv2.dnn.blobFromImage(face, size=(128, 128))
    embedding = compiled_rec.infer_new_request({0: blob})
    return embedding[compiled_rec.outputs[0].get_any_name()].flatten()

def compare_faces(emb1, emb2):
    return np.dot(emb1, emb2) / (np.linalg.norm(emb1) * np.linalg.norm(emb2))

# Interfață grafică
root = tk.Tk()
root.title("Bază de Date Fețe")
root.geometry("600x400")

frame_db = tk.LabelFrame(root, text="Bază de Date")
frame_db.pack(side=tk.LEFT, padx=10, pady=10)
frame_test = tk.LabelFrame(root, text="Imagine Test")
frame_test.pack(side=tk.RIGHT, padx=10, pady=10)

label_db = tk.Label(frame_db)
label_db.pack()
label_test = tk.Label(frame_test)
label_test.pack()

# Adaugă mai multe imagini în baza de date
def add_multiple_to_database():
    files = filedialog.askopenfilenames(filetypes=[("Image files", "*.jpg *.png")])
    if files:
        added_count = 0
        for file in files:
            img = cv2.imread(file)
            if img is None:
                continue
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            face = detect_face(img)
            if face is not None:
                embedding = get_embedding(face)
                name = os.path.basename(file).split('.')[0]
                database[name] = embedding
                added_count += 1
                # Afișează ultima față adăugată
                face_img = Image.fromarray(face).resize((150, 150))
                face_img_tk = ImageTk.PhotoImage(face_img)
                label_db.config(image=face_img_tk)
                label_db.image = face_img_tk
        # Salvează baza de date
        with open(db_file, "wb") as f:
            pickle.dump(database, f)
        messagebox.showinfo("Succes", f"{added_count} fețe au fost adăugate în baza de date!")

# Verifică imaginea de test
def check_face():
    if not database:
        messagebox.showwarning("Atenție", "Baza de date este goală! Adaugă imagini mai întâi.")
        return
    file = filedialog.askopenfilename(filetypes=[("Image files", "*.jpg *.png")])
    if file:
        img = cv2.imread(file)
        if img is None:
            messagebox.showerror("Eroare", "Imaginea nu poate fi încărcată!")
            return
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        face = detect_face(img)
        if face is not None:
            test_embedding = get_embedding(face)
            # Afișează imaginea de test
            test_img = Image.fromarray(face).resize((150, 150))
            test_img_tk = ImageTk.PhotoImage(test_img)
            label_test.config(image=test_img_tk)
            label_test.image = test_img_tk
            # Compară cu toate fețele din baza de date
            best_match = None
            best_similarity = -1
            for name, embedding in database.items():
                similarity = compare_faces(test_embedding, embedding)
                if similarity > best_similarity:
                    best_similarity = similarity
                    best_match = name
            if best_similarity > 0.8:
                messagebox.showinfo("Rezultat", f"Persoana este: {best_match}\nSimilaritate: {best_similarity:.2f}")
            else:
                messagebox.showinfo("Rezultat", f"Nu s-a găsit o potrivire.\nCea mai apropiată: {best_match} (similaritate: {best_similarity:.2f})")
        else:
            messagebox.showerror("Eroare", "Nu s-a detectat față!")

btn_add = tk.Button(root, text="Adaugă Imagini", command=add_multiple_to_database)
btn_add.pack(pady=5)
btn_test = tk.Button(root, text="Verifică Față", command=check_face)
btn_test.pack(pady=5)

root.mainloop()
