import os
import torch
from tkinter import filedialog, Tk
from multiprocessing import Pool, cpu_count
from tqdm import tqdm
from transformers import MarianTokenizer, MarianMTModel

def setup_translation(from_code, to_code):
    """Configurează modelul de traducere MarianMT"""
    model_name = f"Helsinki-NLP/opus-mt-{from_code}-{to_code}"
    try:
        print(f"Încarc modelul {model_name}...")
        tokenizer = MarianTokenizer.from_pretrained(model_name)
        model = MarianMTModel.from_pretrained(model_name)
        print(f"Model {model_name} încărcat cu succes.")
        return tokenizer, model
    except Exception as e:
        print(f"Eroare la încărcarea modelului {model_name}: {str(e)}")
        print("Verifică dacă modelul există sau dacă ai conexiune pentru prima descărcare.")
        return None, None

def select_device():
    """Detectează și selectează dispozitivul (NPU -> GPU -> CPU)"""
    device = None
    device_name = None
    
    try:
        if hasattr(torch, 'npu') and torch.npu.is_available():
            device = torch.device("npu")
            device_name = "NPU"
            print("NPU detectat, voi folosi NPU pentru traducere.")
    except AttributeError:
        pass
    
    if device is None and torch.cuda.is_available():
        device = torch.device("cuda")
        device_name = "GPU"
        print("GPU detectat, voi folosi GPU pentru traducere.")
    
    if device is None:
        device = torch.device("cpu")
        device_name = "CPU"
        print("Niciun NPU sau GPU detectat, voi folosi CPU pentru traducere.")
    
    return device, device_name

def choose_language():
    """Permite utilizatorului să aleagă limbile de traducere"""
    print("Vrei să folosești traducerea implicită EN -> RO? (da/nu)")
    choice = input().strip().lower()
    
    if choice in ["da", "d", "yes", "y"]:
        return "en", "ro"
    
    if choice not in ["nu", "n", "no"]:
        print("Răspuns invalid. Folosesc implicit EN -> RO.")
        return "en", "ro"
    
    languages = [
        ("en", "Engleză"),
        ("es", "Spaniolă"),
        ("fr", "Franceză"),
        ("de", "Germană"),
        ("zh", "Chineză"),
        ("ru", "Rusă"),
        ("ar", "Arabă"),
        ("ro", "Română")
    ]
    
    print("\nSelectează limba sursă (introduce numărul):")
    for i, (code, name) in enumerate(languages, 1):
        print(f"{i}. {name} ({code})")
    
    try:
        from_idx = int(input()) - 1
        if not 0 <= from_idx < len(languages):
            raise ValueError
        from_code = languages[from_idx][0]
    except (ValueError, IndexError):
        print("Selecție invalidă. Folosesc implicit Engleză (en).")
        from_code = "en"
    
    print("\nSelectează limba destinație (introduce numărul):")
    for i, (code, name) in enumerate(languages, 1):
        print(f"{i}. {name} ({code})")
    
    try:
        to_idx = int(input()) - 1
        if not 0 <= to_idx < len(languages) or to_idx == from_idx:
            raise ValueError
        to_code = languages[to_idx][0]
    except (ValueError, IndexError):
        print("Selecție invalidă. Folosesc implicit Română (ro).")
        to_code = "ro"
    
    return from_code, to_code

def translate_file(args):
    """Funcție pentru traducerea unui singur fișier cu bară de progres"""
    file_path, from_code, to_code, output_dir, device_name, device = args
    file_name = os.path.basename(file_path)
    
    try:
        # Încarcă modelul în procesul curent
        tokenizer, model = setup_translation(from_code, to_code)
        if not tokenizer or not model:
            return
        
        model.to(device)
        model.eval()
        
        print(f"Procesez: {file_name} cu {device_name}")
        with open(file_path, 'r', encoding='utf-8') as f:
            text = f.read()
        
        # Împărțim textul în propoziții pentru progres (aproximativ)
        sentences = text.split('.')
        sentences = [s.strip() + '.' for s in sentences if s.strip()]  # Adaugăm punctul înapoi
        translated_sentences = []
        
        with tqdm(total=len(sentences), desc=f"Traduc {file_name}", leave=False) as pbar:
            for sentence in sentences:
                inputs = tokenizer(sentence, return_tensors="pt", padding=True).to(device)
                translated = model.generate(**inputs)
                translated_text = tokenizer.decode(translated[0], skip_special_tokens=True)
                translated_sentences.append(translated_text)
                pbar.update(1)
        
        translated_text = " ".join(translated_sentences)
        
        output_path = os.path.join(output_dir, file_name)
        with open(output_path, 'w', encoding='utf-8') as f:
            f.write(translated_text)
        
        print(f"Complet: {file_name}")
    except Exception as e:
        print(f"Eroare la {file_name}: {str(e)}")

def translate_files():
    """Traduce fișierele selectate în paralel"""
    # Alegere limbi
    from_code, to_code = choose_language()
    print(f"\nTraducere selectată: {from_code} -> {to_code}")
    
    # Detectare dispozitiv
    device, device_name = select_device()
    
    # Verificare model (doar pentru feedback inițial)
    tokenizer, model = setup_translation(from_code, to_code)
    if not tokenizer or not model:
        return
    
    # Selectare fișiere
    root = Tk()
    root.withdraw()
    files = filedialog.askopenfilenames(
        title="Selectează fișierele de tradus",
        filetypes=[("Text files", "*.txt")]
    )
    root.destroy()
    
    if not files:
        print("Niciun fișier selectat. Program terminat.")
        return
    
    # Creare folder output
    script_dir = os.path.dirname(os.path.abspath(__file__))
    output_dir = os.path.join(script_dir, to_code)
    os.makedirs(output_dir, exist_ok=True)
    
    # Număr de procesoare disponibile
    num_processes = cpu_count()
    print(f"\nDetectate {num_processes} nuclee CPU. Procesez {len(files)} fișiere în paralel folosind {device_name}...")
    
    # Pregătire argumente pentru procesare paralelă
    tasks = [(file_path, from_code, to_code, output_dir, device_name, device) for file_path in files if file_path.endswith('.txt')]
    
    # Procesare paralelă
    with Pool(processes=num_processes) as pool:
        pool.map(translate_file, tasks)
    
    print(f"\nTraducere completă! Fișierele sunt în folderul '{to_code}'.")

if __name__ == "__main__":
    translate_files()
