import os
import torch
import subprocess
from tkinter import filedialog, Tk
from multiprocessing import Pool, cpu_count
from tqdm import tqdm
import whisper
from pathlib import Path

def select_device():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    device_name = "GPU" if torch.cuda.is_available() else "CPU"
    print(f"{device_name} detectat, voi folosi {device_name} pentru transcriere.")
    return device, device_name

def choose_language():
    print("Vrei să folosești limba implicită Engleză (en)? (da/nu)")
    choice = input().strip().lower()
    if choice in ["da", "d", "yes", "y"]:
        return "en"
    if choice not in ["nu", "n", "no"]:
        print("Răspuns invalid. Folosesc implicit Engleză (en).")
        return "en"
    
    languages = [("en", "Engleză"), ("es", "Spaniolă"), ("fr", "Franceză"), ("de", "Germană"),
                 ("zh", "Chineză"), ("ru", "Rusă"), ("ar", "Arabă"), ("ro", "Română")]
    print("\nSelectează limba audio (introduce numărul):")
    for i, (code, name) in enumerate(languages, 1):
        print(f"{i}. {name} ({code})")
    
    try:
        lang_idx = int(input()) - 1
        if not 0 <= lang_idx < len(languages):
            raise ValueError
        return languages[lang_idx][0]
    except (ValueError, IndexError):
        print("Selecție invalidă. Folosesc implicit Engleză (en).")
        return "en"

def choose_model():
    models = [
        ("tiny", "Tiny", "Cel mai mic și rapid. Precizie scăzută, bun pentru audio clar. ~39 MB."),
        ("base", "Base", "Echilibru viteză-precizie, uz general. ~74 MB."),
        ("small", "Small", "Mai precis, lent, bun pentru zgomot moderat. ~244 MB."),
        ("medium", "Medium", "Precizie ridicată, audio complex. ~769 MB."),
        ("large", "Large", "Cel mai precis, lent, ideal cu GPU. ~1.5 GB.")
    ]
    print("\nAlege modelul Whisper (introduce numărul):")
    for i, (code, name, desc) in enumerate(models, 1):
        print(f"{i}. {name}\n   - {desc}")
    
    try:
        model_idx = int(input()) - 1
        if not 0 <= model_idx < len(models):
            raise ValueError
        return models[model_idx][0]
    except (ValueError, IndexError):
        print("Selecție invalidă. Folosesc 'base'.")
        return "base"

def choose_output_format():
    print("\nCe format de ieșire dorești? (introduce numărul):")
    print("1. TXT - Text simplu")
    print("2. SRT - Subtitrări cu timestamp")
    try:
        choice = int(input()) - 1
        if choice == 0:
            return "txt"
        elif choice == 1:
            return "srt"
        else:
            raise ValueError
    except (ValueError, IndexError):
        print("Selecție invalidă. Folosesc implicit TXT.")
        return "txt"

def download_youtube_audio(url, output_dir):
    """Descarcă audio de pe YouTube folosind yt-dlp"""
    try:
        print(f"Descarc audio de la: {url}")
        output_template = os.path.join(output_dir, "youtube_audio.%(ext)s")
        cmd = [
            "yt-dlp",
            "-x",  # Extrage doar audio
            "--audio-format", "mp3",  # Convertește în MP3
            "-o", output_template,  # Numele fișierului de ieșire
            url
        ]
        subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        # Găsește fișierul descărcat (yt-dlp adaugă extensia automat)
        for file in os.listdir(output_dir):
            if file.startswith("youtube_audio") and file.endswith(".mp3"):
                return os.path.join(output_dir, file)
        raise FileNotFoundError("Fișierul audio nu a fost găsit după descărcare.")
    except subprocess.CalledProcessError as e:
        print(f"Eroare la descărcare: {e.stderr.decode()}")
        return None
    except Exception as e:
        print(f"Eroare neașteptată la descărcare: {str(e)}")
        return None

def format_srt(segments):
    """Convertește segmentele Whisper în format SRT"""
    srt_content = []
    for i, segment in enumerate(segments, 1):
        start = segment["start"]
        end = segment["end"]
        text = segment["text"].strip()
        
        start_h, start_m = divmod(int(start // 3600), 60)
        start_m, start_s = divmod(int(start % 3600 // 60), 60)
        start_ms = int((start % 1) * 1000)
        end_h, end_m = divmod(int(end // 3600), 60)
        end_m, end_s = divmod(int(end % 3600 // 60), 60)
        end_ms = int((end % 1) * 1000)
        
        timestamp = (f"{start_h:02d}:{start_m:02d}:{start_s:02d},{start_ms:03d} --> "
                    f"{end_h:02d}:{end_m:02d}:{end_s:02d},{end_ms:03d}")
        
        srt_content.append(f"{i}\n{timestamp}\n{text}\n")
    
    return "\n".join(srt_content)

def transcribe_file(args):
    file_path, language, output_dir, device_name, model, output_format = args
    file_name = os.path.basename(file_path)
    output_file_name = f"{Path(file_path).stem}.{output_format}"
    
    try:
        print(f"Procesez: {file_name} cu {device_name}")
        result = model.transcribe(file_path, language=language, verbose=False)
        
        if output_format == "txt":
            transcribed_text = result["text"]
            output_path = os.path.join(output_dir, output_file_name)
            with open(output_path, 'w', encoding='utf-8') as f:
                f.write(transcribed_text)
        elif output_format == "srt":
            srt_text = format_srt(result["segments"])
            output_path = os.path.join(output_dir, output_file_name)
            with open(output_path, 'w', encoding='utf-8') as f:
                f.write(srt_text)
        
        print(f"Complet: {file_name} -> {output_file_name}")
    except Exception as e:
        print(f"Eroare la {file_name}: {str(e)}")

def transcribe_files():
    language = choose_language()
    print(f"\nLimba selectată: {language}")
    
    model_name = choose_model()
    print(f"Model selectat: {model_name}")
    
    output_format = choose_output_format()
    print(f"Format de ieșire selectat: {output_format.upper()}")
    
    device, device_name = select_device()
    
    try:
        print(f"Încep descărcarea modelului '{model_name}'...")
        model = whisper.load_model(model_name).to(device)
        print("Model descărcat și încărcat cu succes!")
    except Exception as e:
        print(f"Eroare la descărcarea/încărcarea modelului: {str(e)}")
        return
    
    # Întreabă despre sursa audio
    print("\nVrei să procesezi un link YouTube sau fișiere locale? (introduce numărul):")
    print("1. Link YouTube")
    print("2. Fișiere locale")
    try:
        source_choice = int(input()) - 1
        if source_choice not in [0, 1]:
            raise ValueError
    except (ValueError, IndexError):
        print("Selecție invalidă. Program terminat.")
        return
    
    script_dir = os.path.dirname(os.path.abspath(__file__))
    output_dir = os.path.join(script_dir, "transcriptions")
    os.makedirs(output_dir, exist_ok=True)
    
    files = []
    if source_choice == 0:  # YouTube
        youtube_url = input("Introdu link-ul YouTube: ").strip()
        audio_file = download_youtube_audio(youtube_url, output_dir)
        if audio_file:
            files.append(audio_file)
        else:
            print("Descărcarea a eșuat. Program terminat.")
            return
    else:  # Fișiere locale
        root = Tk()
        root.withdraw()
        files = filedialog.askopenfilenames(
            title="Selectează fișierele audio",
            filetypes=[("Audio files", "*.mp3 *.wav *.m4a *.flac"), ("All files", "*.*")]
        )
        root.destroy()
        files = [f for f in files if f.lower().endswith(('.mp3', '.wav', '.m4a', '.flac'))]
    
    if not files:
        print("Niciun fișier selectat sau descărcat. Program terminat.")
        return
    
    num_processes = cpu_count()
    print(f"\nDetectate {num_processes} nuclee CPU. Procesez {len(files)} fișiere cu {device_name}...")
    
    tasks = [(file_path, language, output_dir, device_name, model, output_format) for file_path in files]
    
    with Pool(processes=num_processes) as pool:
        list(tqdm(pool.imap(transcribe_file, tasks), total=len(tasks), desc="Transcriere fișiere"))
    
    print(f"\nTranscriere completă! Fișierele sunt în 'transcriptions'.")

if __name__ == "__main__":
    transcribe_files()
