# Transcriptor YouTube pentru Google Colab cu GPU A100
# Script complet cu instalarea tuturor dependențelor necesare

# Instalăm dependențele necesare
!pip install -q git+https://github.com/openai/whisper.git
!pip install -q yt-dlp
!pip install -q ffmpeg-python
!apt-get update && apt-get install -y ffmpeg
!pip install -q torchaudio

# Importăm bibliotecile necesare
import os
import torch
import subprocess
import numpy as np
import time
from tqdm.notebook import tqdm
from IPython.display import display, HTML, Audio, clear_output
from google.colab import files
import json
import tempfile

# Verifică instalarea whisper
try:
    import whisper
except ImportError:
    print("Reinstalez whisper...")
    !pip install -q git+https://github.com/openai/whisper.git
    import whisper

# Verifică instalarea yt-dlp
try:
    import yt_dlp
except ImportError:
    print("Reinstalez yt-dlp...")
    !pip install -q yt-dlp
    import yt_dlp

# Verifică disponibilitatea GPU
def verify_gpu():
    """Verifică dacă GPU-ul este disponibil și afișează informații despre acesta"""
    print("🔍 Verificare GPU...")
    
    # Verifică dacă CUDA este disponibil
    if torch.cuda.is_available():
        gpu_name = torch.cuda.get_device_name(0)
        gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9  # GB
        print(f"✅ GPU detectat: {gpu_name} cu {gpu_memory:.2f} GB VRAM")
        
        # Afișează informații detaliate despre GPU
        !nvidia-smi
        
        # Verifică dacă este A100
        if "A100" in gpu_name:
            print(f"\n✅ Excelent! Ai un GPU A100 care este perfect pentru transcrieri de mare viteză.")
            # Setăm pentru performanță maximă
            torch.backends.cudnn.benchmark = True
            return "cuda", "A100 GPU"
        else:
            print(f"\n⚠️ Ai un GPU {gpu_name}, dar nu este A100. Performanța va fi redusă.")
            return "cuda", "GPU"
    else:
        print("❌ Niciun GPU disponibil. Utilizez CPU (va fi mult mai lent)")
        return "cpu", "CPU"

def choose_language():
    """Permite utilizatorului să aleagă limba pentru transcriere"""
    languages = [
        ("auto", "Autodetectare"),
        ("en", "Engleză"), 
        ("es", "Spaniolă"), 
        ("fr", "Franceză"), 
        ("de", "Germană"),
        ("ro", "Română"),
        ("zh", "Chineză"), 
        ("ru", "Rusă"), 
        ("ar", "Arabă"),
        ("ja", "Japoneză"),
        ("ko", "Coreeană"),
        ("it", "Italiană"),
        ("pt", "Portugheză"),
        ("nl", "Olandeză"),
        ("tr", "Turcă"),
        ("pl", "Poloneză"),
        ("hu", "Maghiară")
    ]
    
    print("\n🌍 Selectează limba audio:")
    # Afișează limbile în coloane pentru o vizualizare mai bună
    col_size = 4
    for i in range(0, len(languages), col_size):
        row = languages[i:i+col_size]
        print(" | ".join([f"{j}. {name} ({code})" for j, (code, name) in enumerate(row, i)]))
    
    try:
        lang_idx = int(input("\nIntrodu numărul (0 pentru autodetectare): "))
        if not 0 <= lang_idx < len(languages):
            raise ValueError
        selected = languages[lang_idx][0]
        print(f"Limba selectată: {languages[lang_idx][1]}")
        
        # Verificare specială pentru limba română
        if selected == "ro":
            # Verifică dacă Whisper suportă limba română
            try:
                import whisper
                available_langs = whisper.tokenizer.LANGUAGES
                if selected not in available_langs:
                    print("\n⚠️ Atenție: Limba română nu este explicit suportată în versiunea actuală de Whisper.")
                    print("Se recomandă să utilizezi autodetectarea limbii ('auto') pentru rezultate mai bune.")
                    if input("Dorești să folosești autodetectare în loc? (da/nu, implicit: da): ").lower() not in ['nu', 'n', 'no']:
                        print("Folosesc autodetectare...")
                        return None
                    else:
                        print("Se continuă cu limba română specificată explicit.")
            except:
                print("Nu am putut verifica lista de limbi suportate.")
        
        return None if selected == "auto" else selected
    except (ValueError, IndexError):
        print("Selecție invalidă. Folosesc autodetectare.")
        return None

def choose_model():
    """Permite utilizatorului să aleagă modelul Whisper"""
    # Pentru A100 putem folosi orice model, chiar și large-v3
    models = [
        ("tiny", "Tiny", "Cel mai mic (39 MB). Rapid dar imprecis."),
        ("base", "Base", "Mic (74 MB). Bun pentru text clar."),
        ("small", "Small", "Mediu (244 MB). Echilibrat."),
        ("medium", "Medium", "Mare (769 MB). Precis pentru multe limbi."),
        ("large-v3", "Large-v3", "Cel mai mare (1.5 GB). Cea mai bună calitate. Recomandat pentru A100.")
    ]
    
    print("\n🤖 Alege modelul Whisper:")
    for i, (code, name, desc) in enumerate(models, 1):
        print(f"{i}. {name} - {desc}")
    
    print("\nCu un A100 de 40GB, poți folosi orice model, inclusiv large-v3 cu batch processing.")
    print("Recomandare: large-v3 pentru cea mai bună calitate, medium pentru echilibru viteză/precizie")
    
    try:
        model_idx = int(input("Introdu numărul (implicit: large-v3): ") or "5") - 1
        if not 0 <= model_idx < len(models):
            raise ValueError
        selected = models[model_idx][0]
        print(f"Model selectat: {models[model_idx][1]}")
        return selected
    except (ValueError, IndexError):
        print("Selecție invalidă. Folosesc 'large-v3'.")
        return "large-v3"

def choose_output_format():
    """Permite utilizatorului să aleagă formatul de output"""
    formats = [
        ("txt", "Text simplu"),
        ("srt", "Subtitrări SRT cu timestamp"),
        ("vtt", "Subtitrări VTT (pentru web)"),
        ("all", "Toate formatele de mai sus"),
        ("json", "JSON (toate detaliile)")
    ]
    
    print("\n📄 Alege formatul de ieșire:")
    for i, (code, desc) in enumerate(formats, 1):
        print(f"{i}. {desc}" + (f" (.{code})" if code != "all" else ""))
    
    try:
        format_idx = int(input("Introdu numărul (implicit: 1): ") or "1") - 1
        if not 0 <= format_idx < len(formats):
            raise ValueError
        selected = formats[format_idx][0]
        print(f"Format selectat: {formats[format_idx][1]}")
        return selected
    except (ValueError, IndexError):
        print("Selecție invalidă. Folosesc text simplu (TXT).")
        return "txt"

def download_youtube_audio(url, output_dir):
    """Descarcă audio de pe YouTube folosind yt-dlp"""
    try:
        print(f"⬇️ Descarc audio de la: {url}")
        output_template = os.path.join(output_dir, "youtube_audio.%(ext)s")
        
        # Configurăm yt-dlp pentru a extrage cel mai bun audio
        ydl_opts = {
            'format': 'bestaudio/best',
            'postprocessors': [{
                'key': 'FFmpegExtractAudio',
                'preferredcodec': 'mp3',
                'preferredquality': '192',
            }],
            'outtmpl': output_template,
            'quiet': False,
            'no_warnings': False
        }
        
        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            # Obținem informații despre video înainte de descărcare
            info = ydl.extract_info(url, download=False)
            video_title = info.get('title', 'Video necunoscut')
            video_duration = info.get('duration', 0)
            
            print(f"🎬 Titlu: {video_title}")
            print(f"⏱️ Durată: {format_time(video_duration)}")
            
            # Descărcăm video-ul
            ydl.download([url])
        
        # Găsește fișierul descărcat
        for file in os.listdir(output_dir):
            if file.startswith("youtube_audio") and file.endswith(".mp3"):
                audio_path = os.path.join(output_dir, file)
                # Redenumește fișierul cu un nume mai descriptiv
                new_name = video_title.replace("/", "-").replace("\\", "-")[:50] + ".mp3"
                new_path = os.path.join(output_dir, new_name)
                os.rename(audio_path, new_path)
                return new_path
                
        raise FileNotFoundError("Fișierul audio nu a fost găsit după descărcare.")
    except Exception as e:
        print(f"❌ Eroare la descărcare: {str(e)}")
        return None

def format_time(seconds):
    """Convertește secunde în format ore:minute:secunde"""
    if seconds is None:
        return "00:00:00"
    
    hours = seconds // 3600
    minutes = (seconds % 3600) // 60
    seconds = seconds % 60
    return f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}"

def format_srt(segments):
    """Convertește segmentele Whisper în format SRT"""
    srt_content = []
    for i, segment in enumerate(segments, 1):
        start = segment["start"]
        end = segment["end"]
        text = segment["text"].strip()
        
        start_h, start_m = divmod(int(start // 3600), 60)
        start_m, start_s = divmod(int(start % 3600 // 60), 60)
        start_ms = int((start % 1) * 1000)
        end_h, end_m = divmod(int(end // 3600), 60)
        end_m, end_s = divmod(int(end % 3600 // 60), 60)
        end_ms = int((end % 1) * 1000)
        
        timestamp = (f"{start_h:02d}:{start_m:02d}:{start_s:02d},{start_ms:03d} --> "
                    f"{end_h:02d}:{end_m:02d}:{end_s:02d},{end_ms:03d}")
        
        srt_content.append(f"{i}\n{timestamp}\n{text}\n")
    
    return "\n".join(srt_content)

def format_vtt(segments):
    """Convertește segmentele Whisper în format VTT"""
    vtt_content = ["WEBVTT\n"]
    
    for i, segment in enumerate(segments, 1):
        start = segment["start"]
        end = segment["end"]
        text = segment["text"].strip()
        
        start_h, start_m = divmod(int(start // 3600), 60)
        start_m, start_s = divmod(int(start % 3600 // 60), 60)
        start_ms = int((start % 1) * 1000)
        end_h, end_m = divmod(int(end // 3600), 60)
        end_m, end_s = divmod(int(end % 3600 // 60), 60)
        end_ms = int((end % 1) * 1000)
        
        timestamp = (f"{start_h:02d}:{start_m:02d}:{start_s:02d}.{start_ms:03d} --> "
                    f"{end_h:02d}:{end_m:02d}:{end_s:02d}.{end_ms:03d}")
        
        vtt_content.append(f"{timestamp}\n{text}\n")
    
    return "\n".join(vtt_content)

def post_process_segments(segments):
    """Post-procesează segmentele pentru a elimina repetițiile și a îmbunătăți transcrierea"""
    if not segments:
        return segments
    
    # Detectarea repetițiilor
    processed_segments = []
    repetition_threshold = 0.9  # Pragul de similaritate pentru a considera o repetitie
    
    def similarity(text1, text2):
        """Calculează similaritatea între două texte"""
        if not text1 or not text2:
            return 0
        
        # Normalizăm textele
        text1 = text1.lower().strip()
        text2 = text2.lower().strip()
        
        # Dacă un text este inclus în celălalt
        if text1 in text2 or text2 in text1:
            return min(len(text1), len(text2)) / max(len(text1), len(text2))
        
        # Numărul de caractere comune
        common_chars = sum(1 for c in text1 if c in text2)
        return common_chars / max(len(text1), len(text2))
    
    # Eliminarea repetițiilor consecutive
    prev_text = ""
    for segment in segments:
        current_text = segment["text"].strip()
        
        # Verifică similitudinea cu segmentul anterior
        if similarity(current_text, prev_text) < repetition_threshold:
            processed_segments.append(segment)
            prev_text = current_text
    
    # Dacă au fost eliminate prea multe segmente, revenire la originalul
    if len(processed_segments) < len(segments) * 0.5:
        print("⚠️ Atenție: Au fost detectate multe repetitii. S-ar putea să fie o problemă cu audio-ul.")
        
        # Alternativă: să consolidăm segmentele cu texte similare
        consolidated_segments = []
        i = 0
        while i < len(segments):
            current = segments[i]
            j = i + 1
            # Caută segmente consecutive cu text similar
            while j < len(segments) and similarity(segments[i]["text"], segments[j]["text"]) >= repetition_threshold:
                j += 1
            
            # Ajustează timpul de sfârșit dacă am găsit repetitii
            if j > i + 1:
                current["end"] = segments[j-1]["end"]
            
            consolidated_segments.append(current)
            i = j
        
        return consolidated_segments
    
    return processed_segments or segments  # Întoarce segmentele originale dacă nu avem nimic

def save_and_download_output(result, audio_file, output_format, output_dir):
    """Salvează și descarcă rezultatele în formatul selectat"""
    base_filename = os.path.splitext(os.path.basename(audio_file))[0]
    output_files = []

    # Post-procesare segmente pentru a elimina repetițiile
    print("\n🔄 Post-procesare pentru îmbunătățirea transcrierii...")
    original_segment_count = len(result["segments"])
    result["segments"] = post_process_segments(result["segments"])
    processed_segment_count = len(result["segments"])
    
    if original_segment_count != processed_segment_count:
        print(f"✅ Post-procesare completă: {original_segment_count - processed_segment_count} segmente redundante eliminate.")
        # Actualizăm textul complet după post-procesare
        result["text"] = " ".join([segment["text"] for segment in result["segments"]])
    else:
        print("✅ Post-procesare completă: nu au fost găsite repetitii semnificative.")

    # Determină ce formate să generezi
    formats_to_generate = []
    if output_format == "all":
        formats_to_generate = ["txt", "srt", "vtt"]
    else:
        formats_to_generate = [output_format]

    # Generează fiecare format
    for fmt in formats_to_generate:
        output_filename = f"{base_filename}.{fmt}"
        output_path = os.path.join(output_dir, output_filename)
        
        if fmt == "txt":
            with open(output_path, 'w', encoding='utf-8') as f:
                f.write(result["text"])
            print(f"✅ Salvat text în: {output_filename}")
        
        elif fmt == "srt":
            srt_text = format_srt(result["segments"])
            with open(output_path, 'w', encoding='utf-8') as f:
                f.write(srt_text)
            print(f"✅ Salvat subtitrări SRT în: {output_filename}")
        
        elif fmt == "vtt":
            vtt_text = format_vtt(result["segments"])
            with open(output_path, 'w', encoding='utf-8') as f:
                f.write(vtt_text)
            print(f"✅ Salvat subtitrări VTT în: {output_filename}")
        
        elif fmt == "json":
            with open(output_path, 'w', encoding='utf-8') as f:
                json.dump(result, f, indent=2, ensure_ascii=False)
            print(f"✅ Salvat date complete JSON în: {output_filename}")
        
        output_files.append(output_path)
    
    # Afișează un preview pentru text
    if "txt" in formats_to_generate or output_format == "txt":
        txt_path = os.path.join(output_dir, f"{base_filename}.txt")
        with open(txt_path, 'r', encoding='utf-8') as f:
            content = f.read()
        print("\n📝 PREVIEW TRANSCRIERE (după post-procesare):\n" + "-" * 80)
        display_text = content[:2000] + ("..." if len(content) > 2000 else "")
        print(display_text)
        print("-" * 80)
    
    # Descarcă fișierele
    for file_path in output_files:
        files.download(file_path)
    
    return output_files

def transcribe_audio(audio_file, model, language, output_format, device, compute_type="float16"):
    """Transcrie un fișier audio folosind Whisper"""
    filename = os.path.basename(audio_file)
    output_dir = os.path.dirname(audio_file)
    
    start_time = time.time()
    print(f"🔊 Procesez: {filename} cu {device}")
    
    # Recomandări speciale pentru limba română
    if language == "ro":
        print("\n⚠️ Recomandări pentru transcriere în limba română:")
        print("  - Dacă apar repetitii excesive, încearcă autodetectarea limbii sau limba 'en'")
        print("  - Modelele 'medium' și 'small' pot funcționa mai bine decât 'large' pentru română")
        print("  - Dacă audio-ul conține termeni tehnici/religioși, rezultatele pot fi imprecise")
        print("  - Post-procesorul va elimina repetițiile pentru un rezultat mai curat")
    
    try:
        # Ajustare parametri pentru limba română
        if language == "ro":
            verbose = True  # Activăm output-ul verbos pentru a vedea cum detectează limba
            temperature = 0.2  # Creștem puțin temperatura pentru a reduce repetițiile
        else:
            verbose = True
            temperature = 0  # Reducem aleatorismul
        
        # Afișăm bara de progres
        print("⏳ Transcriere în curs... (poate dura ceva timp pentru fișiere lungi)")
        
        # Pentru A100, putem folosi precizie fp16 pentru viteză mai mare
        result = model.transcribe(
            audio_file, 
            language=language,
            verbose=verbose,
            fp16=(device == "cuda" and compute_type == "float16"),
            temperature=temperature,  
            beam_size=5  # Oferă rezultate mai bune cu un impact minim asupra performanței
        )
        
        end_time = time.time()
        duration = end_time - start_time
        audio_duration = result["segments"][-1]["end"] if result["segments"] else 0
        speed_factor = audio_duration / duration if duration > 0 else 0
        
        print(f"✅ Transcriere completă!")
        print(f"⏱️ Timp procesare: {format_time(duration)} pentru audio de {format_time(audio_duration)}")
        print(f"🚀 Viteză: {speed_factor:.2f}x timp real")
        
        # Salvează și descarcă rezultatele
        output_files = save_and_download_output(result, audio_file, output_format, output_dir)
        
        return output_files
    
    except Exception as e:
        print(f"❌ Eroare la transcriere: {str(e)}")
        import traceback
        traceback.print_exc()
        return None

def create_progress_display():
    """Creează un display interactiv pentru progres"""
    display(HTML("""
    <style>
        .progress-bar {
            width: 100%;
            background-color: #f0f0f0;
            border-radius: 5px;
            margin: 10px 0;
        }
        .progress {
            width: 0%;
            height: 30px;
            background-color: #4CAF50;
            text-align: center;
            line-height: 30px;
            color: white;
            border-radius: 5px;
        }
    </style>
    <div class="progress-bar">
        <div class="progress" id="progress">0%</div>
    </div>
    <div id="status">Inițializare...</div>
    """))

def main():
    # Banner de start
    print("""
    ╔═══════════════════════════════════════════════════════════════╗
    ║                                                               ║
    ║  🎧 TRANSCRIPTOR AUDIO PENTRU GOOGLE COLAB CU A100 🎧         ║
    ║  Optimizat pentru procesare rapidă pe GPU A100                ║
    ║                                                               ║
    ╚═══════════════════════════════════════════════════════════════╝
    """)
    
    # Verifică disponibilitatea GPU
    device, device_name = verify_gpu()
    
    # Setează directorul de ieșire
    output_dir = "/content/transcriptions"
    os.makedirs(output_dir, exist_ok=True)
    
    # Alege limba, modelul și formatul
    language = choose_language()
    model_name = choose_model()
    output_format = choose_output_format()
    
    # Încarcă modelul Whisper
    print(f"\n⏳ Încărcând modelul Whisper '{model_name}'...")
    model = whisper.load_model(model_name).to(device)
    print(f"✅ Model încărcat cu succes pe {device_name}!")
    
    # Întreabă despre sursa audio
    print("\n🔍 Alege sursa audio:")
    print("1. Link YouTube")
    print("2. Încarcă fișier audio")
    print("3. Specifică URL pentru fișier audio")
    
    source_choice = input("Introdu numărul (implicit: 1): ") or "1"
    
    audio_file = None
    if source_choice == "1":
        youtube_url = input("\n🔗 Introdu link-ul YouTube: ").strip()
        audio_file = download_youtube_audio(youtube_url, output_dir)
    elif source_choice == "2":
        print("\n📂 Încarcă un fișier audio (MP3, WAV, M4A, etc.):")
        uploaded = files.upload()
        
        if uploaded:
            # Ia primul fișier încărcat
            filename = list(uploaded.keys())[0]
            audio_file = os.path.join(output_dir, filename)
            
            # Salvează fișierul în directorul de ieșire
            with open(audio_file, 'wb') as f:
                f.write(uploaded[filename])
            
            print(f"✅ Fișier încărcat: {filename}")
        else:
            print("❌ Niciun fișier încărcat.")
    elif source_choice == "3":
        audio_url = input("\n🔗 Introdu URL-ul fișierului audio: ").strip()
        
        if audio_url:
            try:
                import requests
                from urllib.parse import urlparse
                
                # Obține numele fișierului din URL
                parsed_url = urlparse(audio_url)
                filename = os.path.basename(parsed_url.path)
                
                if not filename:
                    filename = "audio_file_from_url.mp3"
                
                audio_file = os.path.join(output_dir, filename)
                
                print(f"⬇️ Descărcând {audio_url}...")
                response = requests.get(audio_url, stream=True)
                response.raise_for_status()
                
                total_size = int(response.headers.get('content-length', 0))
                block_size = 1024
                
                with open(audio_file, 'wb') as f:
                    for data in tqdm(response.iter_content(block_size), 
                                    total=total_size//block_size, 
                                    unit='KB', unit_scale=True):
                        f.write(data)
                
                print(f"✅ Fișier descărcat: {filename}")
            except Exception as e:
                print(f"❌ Eroare la descărcarea fișierului: {str(e)}")
                audio_file = None
        else:
            print("❌ URL invalid.")
    
    if audio_file:
        # Redă audio pentru verificare
        display(HTML(f"<p>🎵 <b>Audio original:</b></p>"))
        display(Audio(audio_file))
        
        # Pentru A100, putem folosi compute_type="float16" pentru viteză
        compute_type = "float16" if device == "cuda" else "float32"
        print(f"🔄 Începe transcrierea cu modelul {model_name}...")
        output_paths = transcribe_audio(audio_file, model, language, output_format, device, compute_type)
        
        if output_paths:
            print("\n📋 Rezultatele transcrierii au fost salvate și descărcate.")
            print(f"📂 Director fișiere: {output_dir}")
            
            # Opțional, adaugă un buton pentru a arhiva toate rezultatele
            if isinstance(output_paths, list) and len(output_paths) > 1:
                # Creează un fișier zip cu toate rezultatele
                import zipfile
                zip_path = os.path.join(output_dir, "transcriptions.zip")
                with zipfile.ZipFile(zip_path, 'w') as zipf:
                    for file in output_paths:
                        zipf.write(file, os.path.basename(file))
                
                print("📦 Toate fișierele au fost arhivate în transcriptions.zip")
                files.download(zip_path)
    else:
        print("❌ Nu am putut obține fișierul audio. Verifică link-ul sau încearcă să încarci manual.")
    
    print("\n✨ Procesul de transcriere s-a încheiat. Mulțumesc pentru utilizare! ✨")

if __name__ == "__main__":
    # Verifică și instalează dependențele lipsă
    !apt-get update -qq
    !apt-get install -y -qq ffmpeg
    !pip install -q IPython tqdm requests
    
    # Creează directorul pentru transcrieri
    os.makedirs("/content/transcriptions", exist_ok=True)
    
    # Rulează aplicația principală
    main()
