import os
from tkinter import filedialog, Tk
from multiprocessing import Pool, cpu_count
from tqdm import tqdm
from transformers import MarianTokenizer, MarianMTModel
import torch
import numpy as np
from openvino import Core

def setup_translation(from_code, to_code, device):
    """Load a MarianMT model and compile it with OpenVINO for *device*.

    The Hugging Face model ``Helsinki-NLP/opus-mt-{from_code}-{to_code}`` is
    exported to ONNX with fully static shapes (batch 1, 512 source tokens,
    512 target tokens, int32 inputs) and then compiled by OpenVINO.

    Args:
        from_code: Source language code used in the model name.
        to_code: Target language code used in the model name.
        device: OpenVINO device name passed to ``compile_model`` (e.g. "CPU").

    Returns:
        tuple: ``(tokenizer, compiled_model, max_length_source,
        max_length_target)``.

    Raises:
        Exception: any error while loading, exporting or compiling is printed
        and re-raised unchanged.
    """
    import shutil
    import tempfile

    model_name = f"Helsinki-NLP/opus-mt-{from_code}-{to_code}"
    max_length_source = 512
    max_length_target = 512
    try:
        print(f"Încarc modelul {model_name}...")
        tokenizer = MarianTokenizer.from_pretrained(model_name)
        model = MarianMTModel.from_pretrained(model_name)
        model.eval()
        print("Model PyTorch încărcat.")

        # MarianTokenizer has no BOS token (bos_token_id is None), so the decoder
        # start id comes from the model config; fall back to the pad token.
        start_id = model.config.decoder_start_token_id
        if start_id is None:
            start_id = tokenizer.pad_token_id

        print("Pregătesc exportul ONNX...")
        dummy_input = tokenizer("Test", return_tensors="pt", padding="max_length", max_length=max_length_source)
        input_ids = dummy_input["input_ids"].to(torch.int32)
        attention_mask = dummy_input["attention_mask"].to(torch.int32)
        decoder_input_ids = torch.full((1, max_length_target), tokenizer.pad_token_id, dtype=torch.int32)
        decoder_input_ids[0, 0] = start_id
        decoder_attention_mask = torch.zeros((1, max_length_target), dtype=torch.int32)
        decoder_attention_mask[0, 0] = 1

        # This function runs in every pool worker. A shared "model.onnx" in the
        # CWD let one process overwrite the file while another was reading it,
        # so each call exports into its own temporary directory.
        export_dir = tempfile.mkdtemp(prefix="marian_ov_")
        onnx_path = os.path.join(export_dir, "model.onnx")
        try:
            torch.onnx.export(
                model,
                (input_ids, attention_mask, decoder_input_ids, decoder_attention_mask),
                onnx_path,
                opset_version=11,
                input_names=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"],
                output_names=["logits"],
                # No dynamic_axes: shapes stay fully static.
            )
            print("Export ONNX complet.")

            print("Configurez OpenVINO...")
            ie = Core()
            print("Core OpenVINO inițializat.")
            model_openvino = ie.read_model(model=onnx_path)
            print("Model ONNX citit.")
            compiled_model = ie.compile_model(model_openvino, device)
        finally:
            # Best-effort cleanup of the temporary export; errors are ignored
            # (e.g. if the platform still holds the file open).
            shutil.rmtree(export_dir, ignore_errors=True)
        print(f"Model {model_name} încărcat pe {device} via OpenVINO.")
        return tokenizer, compiled_model, max_length_source, max_length_target
    except Exception as e:
        print(f"Eroare la configurarea modelului {model_name}: {str(e)}")
        raise  # Re-raise so the caller sees the full traceback.

def select_device(devices=None):
    """Pick the OpenVINO device to run on, preferring NPU, then GPU, then CPU.

    Args:
        devices: Optional list of OpenVINO device names to choose from. When
            None (the default) the list is read from
            ``Core().available_devices``.

    Returns:
        tuple[str, str]: ``(device, device_name)`` -- the device string to pass
        to ``compile_model`` and a human-readable label for log messages.
    """
    if devices is None:
        devices = Core().available_devices
    print(f"Dispozitive disponibile: {devices}")

    def first_of(kind):
        # OpenVINO names a lone device "GPU", but several as "GPU.0", "GPU.1", ...
        # so match the bare name as well as indexed variants.
        return next((d for d in devices if d == kind or d.startswith(kind + ".")), None)

    npu = first_of("NPU")
    if npu is not None:
        print("NPU detectat, folosesc NPU")
        return npu, "NPU (Intel AI Boost)"

    gpu = first_of("GPU")
    if gpu is not None:
        print("GPU detectat, folosesc GPU")
        return gpu, "GPU (Intel Arc)"

    print("Niciun NPU sau GPU detectat, folosesc CPU")
    return "CPU", "CPU"

def choose_language():
    """Ask the user for the source and target language codes.

    The user first confirms (or declines) the default EN -> RO pair. Declining
    shows a numbered menu, first for the source and then for the target
    language. Invalid input falls back to a default instead of aborting, and
    the returned pair never uses the same language twice.

    Returns:
        tuple[str, str]: ``(from_code, to_code)``.
    """
    print("Vrei să folosești traducerea implicită EN -> RO? (da/nu)")
    choice = input().strip().lower()
    if choice in ["da", "d", "yes", "y"]:
        return "en", "ro"
    if choice not in ["nu", "n", "no"]:
        print("Răspuns invalid. Folosesc implicit EN -> RO.")
        return "en", "ro"
    languages = [
        ("en", "Engleză"), ("es", "Spaniolă"), ("fr", "Franceză"), ("de", "Germană"),
        ("zh", "Chineză"), ("ru", "Rusă"), ("ar", "Arabă"), ("ro", "Română")
    ]

    def show_menu():
        for i, (code, name) in enumerate(languages, 1):
            print(f"{i}. {name} ({code})")

    def read_index():
        """Return the 0-based menu index typed by the user, or None if invalid."""
        try:
            idx = int(input()) - 1
        except ValueError:
            return None
        return idx if 0 <= idx < len(languages) else None

    print("\nSelectează limba sursă (introduce numărul):")
    show_menu()
    from_idx = read_index()
    if from_idx is None:
        print("Selecție invalidă. Folosesc implicit Engleză (en).")
        from_code = "en"
    else:
        from_code = languages[from_idx][0]

    print("\nSelectează limba destinație (introduce numărul):")
    show_menu()
    to_idx = read_index()
    # Compare language codes, not indices: after a fallback the source index
    # may be missing or out of range while from_code is still "en".
    if to_idx is not None and languages[to_idx][0] != from_code:
        to_code = languages[to_idx][0]
    elif from_code != "ro":
        print("Selecție invalidă. Folosesc implicit Română (ro).")
        to_code = "ro"
    else:
        # The usual default (ro) would equal the source; use English instead.
        print("Selecție invalidă. Folosesc implicit Engleză (en).")
        to_code = "en"
    return from_code, to_code

# Per-process cache of setup_translation() results, keyed by
# (from_code, to_code, device), so a pool worker loads/exports/compiles the
# model once instead of once per file.
_TRANSLATORS = {}


def _get_translator(from_code, to_code, device):
    """Return setup_translation(from_code, to_code, device), computed at most once per process."""
    key = (from_code, to_code, device)
    if key not in _TRANSLATORS:
        _TRANSLATORS[key] = setup_translation(from_code, to_code, device)
    return _TRANSLATORS[key]


def _split_sentences(text):
    """Split *text* on '.' into stripped, non-empty sentences.

    Fragments that were followed by a '.' in *text* get the period back; the
    trailing fragment after the last '.' (if any) is kept without one.
    """
    pieces = text.split('.')
    last = len(pieces) - 1
    return [
        piece.strip() + ('.' if idx < last else '')
        for idx, piece in enumerate(pieces)
        if piece.strip()
    ]


def translate_file(args):
    """Translate one UTF-8 text file and write the result into *output_dir*.

    Args:
        args: tuple ``(file_path, from_code, to_code, output_dir, device_name,
            device)``, packed into a single argument so this function can be
            used with ``Pool.map``.

    The output file has the same base name as the input. Errors are printed
    rather than raised so one bad file does not abort the rest of the batch.
    """
    file_path, from_code, to_code, output_dir, device_name, device = args
    file_name = os.path.basename(file_path)

    try:
        tokenizer, model, max_length_source, max_length_target = _get_translator(from_code, to_code, device)
        # MarianTokenizer defines no BOS token (bos_token_id is None); Marian
        # decoders start from the pad token instead.
        start_id = tokenizer.bos_token_id
        if start_id is None:
            start_id = tokenizer.pad_token_id

        print(f"Procesez: {file_name} cu {device_name}")
        with open(file_path, 'r', encoding='utf-8') as f:
            text = f.read()

        sentences = _split_sentences(text)
        translated_sentences = []

        with tqdm(total=len(sentences), desc=f"Traduc {file_name}", leave=False) as pbar:
            for sentence in sentences:
                # truncation=True keeps over-long sentences within the static
                # input length the model was exported with.
                inputs = tokenizer(sentence, return_tensors="np", padding="max_length",
                                   max_length=max_length_source, truncation=True)
                input_ids = inputs["input_ids"].astype("int32")
                attention_mask = inputs["attention_mask"].astype("int32")
                decoder_input_ids = np.full((1, max_length_target), tokenizer.pad_token_id, dtype="int32")
                decoder_input_ids[0, 0] = start_id
                decoder_attention_mask = np.zeros((1, max_length_target), dtype="int32")
                decoder_attention_mask[0, 0] = 1

                # Greedy decoding over a fixed-size target buffer: each step runs
                # the full model and reads the logits at the last filled position.
                for i in range(1, max_length_target):
                    outputs = model.infer_new_request({
                        "input_ids": input_ids,
                        "attention_mask": attention_mask,
                        "decoder_input_ids": decoder_input_ids,
                        "decoder_attention_mask": decoder_attention_mask
                    })
                    next_token_id = int(np.argmax(outputs["logits"][0, i - 1, :]))
                    if next_token_id == tokenizer.eos_token_id:
                        break
                    decoder_input_ids[0, i] = next_token_id
                    decoder_attention_mask[0, i] = 1

                translated_sentences.append(
                    tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True)
                )
                pbar.update(1)

        output_path = os.path.join(output_dir, file_name)
        with open(output_path, 'w', encoding='utf-8') as f:
            f.write(" ".join(translated_sentences))

        print(f"Complet: {file_name}")
    except Exception as e:
        print(f"Eroare la procesarea {file_name}: {str(e)}")

def translate_files():
    """Run the interactive batch translation.

    Asks for the language pair and picks an OpenVINO device. It then lets the
    user choose .txt files in a Tk file dialog and translates them in a
    process pool. Results go to a folder named after the target language code,
    next to this script.
    """
    from_code, to_code = choose_language()
    print(f"\nTraducere selectată: {from_code} -> {to_code}")
    device, device_name = select_device()

    root = Tk()
    root.withdraw()  # hide the empty root window; only the dialog is shown
    try:
        files = filedialog.askopenfilenames(
            title="Selectează fișierele de tradus",
            filetypes=[("Text files", "*.txt")]
        )
    finally:
        root.destroy()

    if not files:
        print("Niciun fișier selectat. Program terminat.")
        return

    script_dir = os.path.dirname(os.path.abspath(__file__))
    output_dir = os.path.join(script_dir, to_code)

    # Case-insensitive match: the dialog filter may also let "*.TXT" through
    # (e.g. on Windows), and those files should not be dropped silently.
    tasks = [
        (file_path, from_code, to_code, output_dir, device_name, device)
        for file_path in files
        if file_path.lower().endswith('.txt')
    ]
    if not tasks:
        print("Niciun fișier .txt selectat. Program terminat.")
        return

    os.makedirs(output_dir, exist_ok=True)

    cores = cpu_count()
    # Each worker loads its own model copy, so never start more workers than files.
    num_processes = min(cores, len(tasks))
    print(f"\nDetectate {cores} nuclee CPU. Procesez {len(tasks)} fișiere în paralel folosind {device_name}...")

    with Pool(processes=num_processes) as pool:
        pool.map(translate_file, tasks)

    print(f"\nTraducere completă! Fișierele sunt în folderul '{to_code}'.")

if __name__ == "__main__":
    # The guard matters because of multiprocessing: under the "spawn" start
    # method (the default on Windows and macOS), pool workers re-import this
    # module and must not re-run the interactive workflow.
    translate_files()
