import os
import argostranslate.package
import argostranslate.translate
import torch
from tkinter import filedialog, Tk
from multiprocessing import Pool, cpu_count
from tqdm import tqdm
import re

def setup_translation(from_code, to_code):
    """Configurează pachetul de traducere pentru limbile selectate"""
    try:
        print("Actualizare index pachete...")
        argostranslate.package.update_package_index()
        available_packages = argostranslate.package.get_available_packages()
        package_to_install = next(
            filter(
                lambda x: x.from_code == from_code and x.to_code == to_code, 
                available_packages
            ), None
        )
        if not package_to_install:
            print(f"Nu există pachet de traducere disponibil pentru {from_code} -> {to_code}")
            return False
        
        print(f"Descărcare pachet de traducere {from_code} -> {to_code}...")
        argostranslate.package.install_from_path(package_to_install.download())
        return True
    except Exception as e:
        print(f"Eroare la configurare: {str(e)}")
        return False

def select_device():
    """Detectează și selectează dispozitivul (NPU -> GPU -> CPU)"""
    device = None
    device_name = None
    
    try:
        if hasattr(torch, 'npu') and torch.npu.is_available():
            device = torch.device("npu")
            device_name = "NPU"
            print("NPU detectat, voi folosi NPU pentru traducere.")
    except AttributeError:
        pass
    
    if device is None and torch.cuda.is_available():
        device = torch.device("cuda")
        device_name = "GPU"
        print("GPU detectat, voi folosi GPU pentru traducere.")
    
    if device is None:
        device = torch.device("cpu")
        device_name = "CPU"
        print("Niciun NPU sau GPU detectat, voi folosi CPU pentru traducere.")
    
    return device, device_name

def choose_language():
    """Permite utilizatorului să aleagă limbile de traducere"""
    print("Vrei să folosești traducerea implicită EN -> RO? (da/nu)")
    choice = input().strip().lower()
    
    if choice in ["da", "d", "yes", "y"]:
        return "en", "ro"
    
    if choice not in ["nu", "n", "no"]:
        print("Răspuns invalid. Folosesc implicit EN -> RO.")
        return "en", "ro"
    
    languages = [
        ("en", "Engleză"),
        ("es", "Spaniolă"),
        ("fr", "Franceză"),
        ("de", "Germană"),
        ("zh", "Chineză"),
        ("ru", "Rusă"),
        ("ar", "Arabă"),
        ("ro", "Română")
    ]
    
    print("\nSelectează limba sursă (introduce numărul):")
    for i, (code, name) in enumerate(languages, 1):
        print(f"{i}. {name} ({code})")
    
    try:
        from_idx = int(input()) - 1
        if not 0 <= from_idx < len(languages):
            raise ValueError
        from_code = languages[from_idx][0]
    except (ValueError, IndexError):
        print("Selecție invalidă. Folosesc implicit Engleză (en).")
        from_code = "en"
    
    print("\nSelectează limba destinație (introduce numărul):")
    for i, (code, name) in enumerate(languages, 1):
        print(f"{i}. {name} ({code})")
    
    try:
        to_idx = int(input()) - 1
        if not 0 <= to_idx < len(languages) or to_idx == from_idx:
            raise ValueError
        to_code = languages[to_idx][0]
    except (ValueError, IndexError):
        print("Selecție invalidă. Folosesc implicit Română (ro).")
        to_code = "ro"
    
    return from_code, to_code

def parse_srt(content):
    """Parsează conținutul SRT și returnează o listă de blocuri (index, timestamp, text)"""
    blocks = []
    lines = content.strip().split('\n\n')
    
    for block in lines:
        if not block.strip():
            continue
        parts = block.strip().split('\n', 2)
        if len(parts) < 3:
            continue
        index = parts[0].strip()
        timestamp = parts[1].strip()
        text = parts[2].strip()
        blocks.append((index, timestamp, text))
    return blocks

def translate_srt_block(block, from_code, to_code):
    """Traduce textul unui bloc SRT, păstrând indexul și timestamp-ul"""
    index, timestamp, text = block
    translated_text = argostranslate.translate.translate(text, from_code, to_code)
    return (index, timestamp, translated_text)

def translate_file(args):
    """Funcție pentru traducerea unui fișier SRT"""
    file_path, from_code, to_code, output_dir, device_name = args
    file_name = os.path.basename(file_path)
    
    try:
        print(f"Procesez: {file_name} cu {device_name}")
        with open(file_path, 'r', encoding='utf-8') as f:
            content = f.read()
        
        # Parsează fișierul SRT
        srt_blocks = parse_srt(content)
        
        # Traduce blocurile cu bară de progres
        translated_blocks = []
        with tqdm(total=len(srt_blocks), desc=f"Traduc {file_name}", leave=False) as pbar:
            for block in srt_blocks:
                translated_block = translate_srt_block(block, from_code, to_code)
                translated_blocks.append(translated_block)
                pbar.update(1)
        
        # Reconstruiește fișierul SRT
        translated_content = '\n\n'.join(
            f"{index}\n{timestamp}\n{text}" for index, timestamp, text in translated_blocks
        ) + '\n'
        
        output_path = os.path.join(output_dir, file_name)
        with open(output_path, 'w', encoding='utf-8') as f:
            f.write(translated_content)
        
        print(f"Complet: {file_name}")
    except Exception as e:
        print(f"Eroare la {file_name}: {str(e)}")

def translate_files():
    """Traduce fișierele SRT selectate în paralel"""
    # Alegere limbi
    from_code, to_code = choose_language()
    print(f"\nTraducere selectată: {from_code} -> {to_code}")
    
    # Configurare traducere
    if not setup_translation(from_code, to_code):
        return
    
    # Detectare dispozitiv
    device, device_name = select_device()
    
    # Selectare fișiere
    root = Tk()
    root.withdraw()
    files = filedialog.askopenfilenames(
        title="Selectează fișierele SRT de tradus",
        filetypes=[("SRT files", "*.srt")]
    )
    root.destroy()
    
    if not files:
        print("Niciun fișier selectat. Program terminat.")
        return
    
    # Creare folder output
    script_dir = os.path.dirname(os.path.abspath(__file__))
    output_dir = os.path.join(script_dir, to_code)
    os.makedirs(output_dir, exist_ok=True)
    
    # Număr de procesoare disponibile
    num_processes = cpu_count()
    print(f"\nDetectate {num_processes} nuclee CPU. Procesez {len(files)} fișiere în paralel folosind {device_name}...")
    
    # Pregătire argumente pentru procesare paralelă
    tasks = [(file_path, from_code, to_code, output_dir, device_name) for file_path in files if file_path.endswith('.srt')]
    
    # Procesare paralelă
    with Pool(processes=num_processes) as pool:
        pool.map(translate_file, tasks)
    
    print(f"\nTraducere completă! Fișierele sunt în folderul '{to_code}'.")

if __name__ == "__main__":
    translate_files()
