import os
import torch
from tkinter import filedialog, Tk
from multiprocessing import Pool, cpu_count
from tqdm import tqdm
import whisper
from pathlib import Path

def select_device():
    """Detectează și selectează dispozitivul (GPU -> CPU)"""
    device = None
    device_name = None
    
    if torch.cuda.is_available():
        device = torch.device("cuda")
        device_name = "GPU"
        print("GPU detectat, voi folosi GPU pentru transcriere.")
    else:
        device = torch.device("cpu")
        device_name = "CPU"
        print("Niciun GPU detectat, voi folosi CPU pentru transcriere.")
    
    return device, device_name

def choose_language():
    """Permite utilizatorului să aleagă limba audio"""
    print("Vrei să folosești limba implicită Engleză (en)? (da/nu)")
    choice = input().strip().lower()
    
    if choice in ["da", "d", "yes", "y"]:
        return "en"
    
    if choice not in ["nu", "n", "no"]:
        print("Răspuns invalid. Folosesc implicit Engleză (en).")
        return "en"
    
    languages = [
        ("en", "Engleză"),
        ("es", "Spaniolă"),
        ("fr", "Franceză"),
        ("de", "Germană"),
        ("zh", "Chineză"),
        ("ru", "Rusă"),
        ("ar", "Arabă"),
        ("ro", "Română")
    ]
    
    print("\nSelectează limba audio (introduce numărul):")
    for i, (code, name) in enumerate(languages, 1):
        print(f"{i}. {name} ({code})")
    
    try:
        lang_idx = int(input()) - 1
        if not 0 <= lang_idx < len(languages):
            raise ValueError
        lang_code = languages[lang_idx][0]
    except (ValueError, IndexError):
        print("Selecție invalidă. Folosesc implicit Engleză (en).")
        lang_code = "en"
    
    return lang_code

def choose_model():
    """Permite utilizatorului să aleagă modelul Whisper și oferă explicații"""
    models = [
        ("tiny", "Tiny", 
         "Cel mai mic și rapid model. Potrivit pentru dispozitive cu resurse limitate. "
         "Precizie scăzută, ideal pentru audio clar și limbi comune (ex. engleză). "
         "Dimensiune: ~39 MB."),
        ("base", "Base", 
         "Echilibru între viteză și precizie. Bun pentru majoritatea cazurilor generale. "
         "Funcționează bine pe audio de calitate medie. Dimensiune: ~74 MB."),
        ("small", "Small", 
         "Precizie mai bună decât Base, dar mai lent. Recomandat pentru audio cu zgomot moderat. "
         "Dimensiune: ~244 MB."),
        ("medium", "Medium", 
         "Model avansat cu precizie ridicată. Bun pentru limbi mai puțin comune sau audio complex. "
         "Necesită mai multe resurse. Dimensiune: ~769 MB."),
        ("large", "Large", 
         "Cel mai precis model, ideal pentru audio dificil (zgomot, accente, limbi rare). "
         "Foarte lent pe CPU, recomandat cu GPU. Dimensiune: ~1.5 GB.")
    ]
    
    print("\nAlege modelul Whisper (introduce numărul):")
    for i, (code, name, desc) in enumerate(models, 1):
        print(f"{i}. {name}\n   - {desc}")
    
    try:
        model_idx = int(input()) - 1
        if not 0 <= model_idx < len(models):
            raise ValueError
        model_code = models[model_idx][0]
    except (ValueError, IndexError):
        print("Selecție invalidă. Folosesc implicit modelul 'base'.")
        model_code = "base"
    
    return model_code

def transcribe_file(args):
    """Funcție pentru transcrierea unui fișier audio"""
    file_path, language, output_dir, device_name, model = args
    file_name = os.path.basename(file_path)
    output_file_name = f"{Path(file_path).stem}.txt"
    
    try:
        print(f"Procesez: {file_name} cu {device_name}")
        
        # Încarcă fișierul audio și transcrie cu Whisper
        result = model.transcribe(file_path, language=language, verbose=False)
        transcribed_text = result["text"]
        
        # Salvează transcrierea
        output_path = os.path.join(output_dir, output_file_name)
        with open(output_path, 'w', encoding='utf-8') as f:
            f.write(transcribed_text)
        
        print(f"Complet: {file_name} -> {output_file_name}")
    except Exception as e:
        print(f"Eroare la {file_name}: {str(e)}")

def transcribe_files():
    """Transcrie fișierele audio selectate în paralel"""
    # Alegere limbă
    language = choose_language()
    print(f"\nLimba selectată pentru transcriere: {language}")
    
    # Alegere model
    model_name = choose_model()
    print(f"Model selectat: {model_name}")
    
    # Detectare dispozitiv
    device, device_name = select_device()
    
    # Încarcă modelul Whisper
    print(f"Încarc modelul Whisper '{model_name}'...")
    model = whisper.load_model(model_name).to(device)
    
    # Selectare fișiere
    root = Tk()
    root.withdraw()
    files = filedialog.askopenfilenames(
        title="Selectează fișierele audio de transcris",
        filetypes=[
            ("Audio files", "*.mp3 *.wav *.m4a *.flac"),
            ("All files", "*.*")
        ]
    )
    root.destroy()
    
    if not files:
        print("Niciun fișier selectat. Program terminat.")
        return
    
    # Creare folder output
    script_dir = os.path.dirname(os.path.abspath(__file__))
    output_dir = os.path.join(script_dir, "transcriptions")
    os.makedirs(output_dir, exist_ok=True)
    
    # Număr de procesoare disponibile
    num_processes = cpu_count()
    print(f"\nDetectate {num_processes} nuclee CPU. Procesez {len(files)} fișiere în paralel folosind {device_name}...")
    
    # Pregătire argumente pentru procesare paralelă
    tasks = [(file_path, language, output_dir, device_name, model) for file_path in files
             if file_path.lower().endswith(('.mp3', '.wav', '.m4a', '.flac'))]
    
    # Procesare paralelă
    with Pool(processes=num_processes) as pool:
        list(tqdm(pool.imap(transcribe_file, tasks), total=len(tasks), desc="Transcriere fișiere"))
    
    print(f"\nTranscriere completă! Fișierele sunt în folderul 'transcriptions'.")

if __name__ == "__main__":
    transcribe_files()
