import os
import queue
import re
import tkinter as tk
from multiprocessing import Pool
from threading import Thread
from tkinter import filedialog, ttk

import faiss
import numpy as np
from sentence_transformers import SentenceTransformer

class NewsSearchApp:
    """Tkinter app that embeds news text files and runs semantic search over them.

    Loaded files are cut into fixed-length character chunks, embedded with a
    SentenceTransformer model and stored in a faiss L2 index; a query returns
    the ``k`` nearest chunks.
    """

    # Chunks embedded per model call; also the granularity of progress updates.
    ENCODE_BATCH_SIZE = 64
    # IVF index parameters. faiss k-means wants roughly 39 training points per
    # centroid, so corpora smaller than IVF_NLIST * 39 use an exact flat index.
    IVF_NLIST = 50
    IVF_MIN_POINTS_PER_CENTROID = 39
    IVF_NPROBE = 10
    # How often (ms) the Tk thread runs callbacks posted by the worker thread.
    POLL_INTERVAL_MS = 50

    def __init__(self, root):
        self.root = root
        self.root.title("Interogare Știri")
        self.root.geometry("600x500")

        self.chunks = []
        self.embeddings = None
        self.index = None
        self.results_to_save = []
        self.model = SentenceTransformer('all-MiniLM-L6-v2')  # loaded once
        self.chunk_size = tk.IntVar(value=1000)
        self.top_k = tk.IntVar(value=5)

        # The worker thread must not touch Tk; it posts callables here and the
        # Tk main loop executes them (see _post / _drain_events).
        self._ui_queue = queue.SimpleQueue()

        # UI
        self.label = tk.Label(root, text="Încarcă fișiere cu știri:")
        self.label.pack(pady=10)

        self.load_button = tk.Button(root, text="Încarcă Fișiere", command=self.load_files)
        self.load_button.pack(pady=5)

        self.progress = ttk.Progressbar(root, length=400, mode='determinate')
        self.progress.pack(pady=10)

        self.status_label = tk.Label(root, text="")
        self.status_label.pack(pady=5)

        tk.Label(root, text="Lungime fragment (caractere):").pack(pady=5)
        tk.Entry(root, textvariable=self.chunk_size, width=10).pack(pady=5)
        tk.Label(root, text="Număr rezultate (k):").pack(pady=5)
        tk.Entry(root, textvariable=self.top_k, width=10).pack(pady=5)

        self.query_label = tk.Label(root, text="Introdu întrebarea:")
        self.query_label.pack(pady=5)
        self.query_entry = tk.Entry(root, width=50)
        self.query_entry.pack(pady=5)

        self.search_button = tk.Button(root, text="Caută", command=self.search, state=tk.DISABLED)
        self.search_button.pack(pady=5)

        self.save_button = tk.Button(root, text="Salvează Rezultate", command=self.save_results, state=tk.DISABLED)
        self.save_button.pack(pady=5)

        self.result_text = tk.Text(root, height=10, width=70)
        self.result_text.pack(pady=10)

        self.root.after(self.POLL_INTERVAL_MS, self._drain_events)

    # ------------------------------------------------------------------ #
    # Thread hand-off helpers
    # ------------------------------------------------------------------ #
    def _post(self, func, *args):
        """Queue ``func(*args)`` to be run on the Tk main-loop thread."""
        self._ui_queue.put((func, args))

    def _drain_events(self):
        """Run every callable posted by worker threads, then reschedule itself."""
        try:
            while True:
                try:
                    func, args = self._ui_queue.get_nowait()
                except queue.Empty:
                    return
                func(*args)
        finally:
            # Reschedule even if a callback raised, so later updates still arrive.
            self.root.after(self.POLL_INTERVAL_MS, self._drain_events)

    # ------------------------------------------------------------------ #
    # Pure helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _read_positive_int(var):
        """Return ``var.get()`` if it is an integer >= 1, otherwise None.

        ``var`` is a Tk variable; an empty or non-numeric entry raises
        ``tk.TclError`` on ``get()``, which is treated as invalid input.
        """
        try:
            value = var.get()
        except tk.TclError:
            return None
        return value if value >= 1 else None

    @staticmethod
    def _split_into_chunks(text, size):
        """Split ``text`` into consecutive pieces of at most ``size`` characters.

        Raises:
            ValueError: if ``size`` is not positive (``range`` would otherwise
                fail with a less helpful message).
        """
        if size <= 0:
            raise ValueError(f"chunk size must be positive, got {size}")
        return [text[i:i + size] for i in range(0, len(text), size)]

    @classmethod
    def _build_index(cls, embeddings):
        """Return a faiss L2 index holding every row of the 2-D ``embeddings``.

        IVF is used only when there are enough vectors to train its
        ``IVF_NLIST`` centroids. Training an IVF index with fewer points than
        centroids fails, so smaller corpora get an exact ``IndexFlatL2``.
        """
        n, d = embeddings.shape
        if n >= cls.IVF_NLIST * cls.IVF_MIN_POINTS_PER_CENTROID:
            quantizer = faiss.IndexFlatL2(d)
            index = faiss.IndexIVFFlat(quantizer, d, cls.IVF_NLIST)
            index.train(embeddings)
            # faiss defaults to nprobe=1 (a single cluster scanned), which
            # misses many true neighbours.
            index.nprobe = cls.IVF_NPROBE
        else:
            index = faiss.IndexFlatL2(d)
        index.add(embeddings)
        return index

    # ------------------------------------------------------------------ #
    # Loading / indexing
    # ------------------------------------------------------------------ #
    def load_files(self):
        """Ask for .txt files and chunk/embed/index them on a background thread."""
        file_paths = filedialog.askopenfilenames(filetypes=[("Text files", "*.txt")])
        if not file_paths:
            return

        # Read Tk variables here on the Tk thread, never from the worker.
        chunk_size = self._read_positive_int(self.chunk_size)
        if chunk_size is None:
            self.status_label.config(text="Lungimea fragmentului trebuie să fie un număr întreg pozitiv.")
            return

        self.status_label.config(text="Încarcă și prelucrează fișierele...")
        self.progress['value'] = 0
        self.load_button.config(state=tk.DISABLED)
        # Block searches while the index is being rebuilt.
        self.search_button.config(state=tk.DISABLED)
        # daemon=True: closing the window must not wait for a long encoding run.
        Thread(target=self.process_files, args=(file_paths, chunk_size), daemon=True).start()

    def process_chunk(self, chunk):
        """Embed a single chunk with the loaded model."""
        return self.model.encode(chunk)

    def _encode_chunks(self, chunks):
        """Embed ``chunks`` batch by batch, posting real progress after each batch.

        Returns the per-batch embeddings stacked into one float32 matrix.
        A multiprocessing.Pool is deliberately not used: mapping a bound method
        pickles ``self``, which holds unpicklable Tk widgets, and every worker
        process would need its own copy of the model.
        """
        total = len(chunks)
        batches = []
        for start in range(0, total, self.ENCODE_BATCH_SIZE):
            batch = chunks[start:start + self.ENCODE_BATCH_SIZE]
            batches.append(self.model.encode(batch))
            self._post(self._set_progress, (start + len(batch)) / total * 100)
        # faiss works on float32; normalise dtype and memory layout once here.
        return np.ascontiguousarray(np.vstack(batches), dtype=np.float32)

    def process_files(self, file_paths, chunk_size=None):
        """Chunk, embed and index ``file_paths``. Intended to run on a worker thread.

        Args:
            file_paths: paths of UTF-8 text files.
            chunk_size: characters per chunk. If None, it is read from the
                ``chunk_size`` Tk variable, so pass it explicitly when calling
                from a non-Tk thread.

        On success ``chunks``, ``embeddings`` and ``index`` are replaced. On
        failure or empty input the previous index stays usable. All UI
        feedback is posted to the Tk thread.
        """
        try:
            if chunk_size is None:
                chunk_size = self.chunk_size.get()
            chunks = []
            for file_path in file_paths:
                with open(file_path, "r", encoding="utf-8") as f:
                    chunks.extend(self._split_into_chunks(f.read(), chunk_size))

            if not chunks:
                self._post(self._finish_processing, "Niciun conținut găsit.")
                return

            embeddings = self._encode_chunks(chunks)
            index = self._build_index(embeddings)
        except Exception as exc:  # thread boundary: report instead of dying silently
            self._post(self._finish_processing, f"Eroare la prelucrare: {exc}")
            return

        self.chunks = chunks
        self.embeddings = embeddings
        self.index = index
        self._post(self._finish_processing, "Fișiere prelucrate. Poți interoga!")

    def _set_progress(self, percent):
        """Update the progress bar and status text (Tk thread only)."""
        self.progress['value'] = percent
        self.status_label.config(text=f"Prelucrare: {int(percent)}%")

    def _finish_processing(self, message):
        """Show ``message`` and re-enable the controls after a load attempt (Tk thread only)."""
        self.status_label.config(text=message)
        self.load_button.config(state=tk.NORMAL)
        if self.index is not None:
            self.search_button.config(state=tk.NORMAL)

    # ------------------------------------------------------------------ #
    # Querying / saving
    # ------------------------------------------------------------------ #
    def search(self):
        """Embed the query and list the ``k`` nearest chunks in the result box."""
        query = self.query_entry.get()
        if not query or self.index is None:
            return

        k = self._read_positive_int(self.top_k)
        if k is None:
            self.status_label.config(text="Numărul de rezultate (k) trebuie să fie un număr întreg pozitiv.")
            return
        # Asking for more neighbours than stored vectors only yields -1 padding.
        k = min(k, self.index.ntotal)

        query_embedding = np.asarray(self.model.encode([query]), dtype=np.float32)
        distances, indices = self.index.search(query_embedding, k)

        self.result_text.delete(1.0, tk.END)
        results = []
        for distance, idx in zip(distances[0], indices[0]):
            # faiss marks missing hits with -1; chunks[-1] would silently
            # show the last chunk as a bogus result.
            if idx < 0:
                continue
            fragment = self.chunks[idx][:100] + "..."
            self.result_text.insert(tk.END, f"Fragment: {fragment} (Distanță: {distance:.2f})\n")
            results.append(self.chunks[idx])

        self.results_to_save = results
        if results:
            self.save_button.config(state=tk.NORMAL)

    def save_results(self):
        """Write the chunks from the last search to rezultate_stiri.txt next to this script."""
        if not getattr(self, "results_to_save", None):
            return
        script_dir = os.path.dirname(os.path.abspath(__file__))
        file_path = os.path.join(script_dir, "rezultate_stiri.txt")
        try:
            with open(file_path, "w", encoding="utf-8") as f:
                # results_to_save holds raw chunk text, without the display
                # annotations ("Fragment:", "Distanță:"), so it is written as-is.
                # A regex that strips those annotations could only remove
                # genuine article text.
                f.writelines(result.strip() + "\n" for result in self.results_to_save)
        except OSError as exc:
            self.status_label.config(text=f"Eroare la salvare: {exc}")
            return
        self.status_label.config(text="Rezultate salvate în rezultate_stiri.txt")

if __name__ == "__main__":
    # Create the Tk root window, build the search UI on it, then block in the
    # Tk event loop until the window is closed.
    window = tk.Tk()
    app = NewsSearchApp(window)
    app.root.mainloop()
