import strawberryfields as sf
from strawberryfields.ops import Sgate, BSgate, Rgate, MeasureFock, Xgate
import matplotlib.pyplot as plt
import numpy as np
import time


def ascii_to_binary(message):
    """ Conversia unui mesaj ASCII în binar (8 biți per caracter). """
    return ''.join(format(ord(char), '08b') for char in message)


def binary_to_ascii(binary_message):
    """ Conversia unui mesaj binar înapoi în ASCII. """
    chars = [binary_message[i:i + 8] for i in range(0, len(binary_message), 8)]
    return ''.join(chr(int(char, 2)) for char in chars)


def simulate_quantum_communication(message, photons_per_bit, delay_time, fiber_length):
    """
    Simulează experimentul de semnalizare retrocausală cu fotoni înlănțuiți.
    
    Parametri:
    - message (str): Mesajul binar de transmis.
    - photons_per_bit (int): Numărul de fotoni folosiți pentru fiecare bit.
    - delay_time (float): Întârzierea temporală (în microsecunde) pentru fiecare bit.
    - fiber_length (float): Lungimea fibrei optice (în km).
    """
    
    total_bits = len(message)
    accuracies = []  # Pentru a urmări acuratețea transmisiei
    received_message = ''  # Mesajul binar primit
    
    for i, bit in enumerate(message):
        successful_transmissions = 0
        
        for photon in range(photons_per_bit):
            prog = sf.Program(2)

            with prog.context as q:
                # Stoarcerea stării de vid (echivalentul generării de perechi de fotoni înlănțuiți)
                Sgate(0.54) | q[0]
                
                # Divizarea fasciculului de lumină
                BSgate(0.5) | (q[0], q[1])
                
                # Întârzierea fotonului prin fibră optică
                delay_phase = (2 * np.pi * delay_time * 1e-6)  # Întârziere în microsecunde
                fiber_phase = (2 * np.pi * fiber_length * 1.5 / 299792)  # Lungime de fibră (refracție 1.5)
                
                Rgate(delay_phase + fiber_phase) | q[0]
                
                # Polarizare pentru a reprezenta bitul de transmis
                if bit == '1':
                    Xgate(1) | q[0]  # Aplica un semnal de schimbare
                else:
                    Xgate(0) | q[0]  # Nu face nimic pentru 0
                
                # Măsurarea semnalului
                MeasureFock() | q
            
            eng = sf.Engine("fock", backend_options={"cutoff_dim": 10})
            result = eng.run(prog)
            detected = result.samples[0][0]  # Prima măsurătoare pentru firul 0
            
            if (bit == '1' and detected > 0) or (bit == '0' and detected == 0):
                successful_transmissions += 1
        
        accuracy = successful_transmissions / photons_per_bit * 100
        accuracies.append(accuracy)
        
        # Decide ce bit a fost transmis corect (folosind majoritatea voturilor)
        received_bit = '1' if successful_transmissions > photons_per_bit / 2 else '0'
        received_message += received_bit
        
        print(f"Bit {i+1}/{total_bits}: transmisia a fost {accuracy:.2f}% precisă")
        plot_accuracy(accuracies)
    
    print("\nTransmisie completă!")
    print(f"Mesajul binar transmis: {message}")
    print(f"Mesajul binar primit:   {received_message}")
    
    # Convertirea mesajului binar primit în ASCII
    received_ascii = binary_to_ascii(received_message)
    print(f"Mesajul ASCII primit: {received_ascii}")
    

def plot_accuracy(accuracies):
    """ Graficul care arată procentul de transmisie corectă în timp real. """
    plt.clf()
    plt.plot(accuracies, marker='o', linestyle='-', color='b', label='Acuratețe transmisie (%)')
    plt.axhline(y=100, color='r', linestyle='--', label='Acuratețe ideală')
    plt.xlabel('Bit')
    plt.ylabel('Acuratețe (%)')
    plt.ylim([0, 110])
    plt.title('Acuratețea transmisiei cuantice')
    plt.legend(loc="lower right")
    plt.pause(0.5)  # Actualizează graficul la fiecare bit
    plt.show(block=False)


def user_interface():
    """ Interfața interactivă pentru utilizator. """
    print("\nExperimentul de semnalizare retrocausală cu fotoni înlănțuiți\n")
    
    ascii_message = input("Introduceți mesajul ASCII de transmis (ex: Salut): ")
    binary_message = ascii_to_binary(ascii_message)
    print(f"Mesajul binar corespunzător este: {binary_message}")
    
    photons_per_bit = int(input("Introduceți numărul de fotoni folosiți pentru fiecare bit: "))
    delay_time = float(input("Introduceți întârzierea temporală (în microsecunde): "))
    fiber_length = float(input("Introduceți lungimea fibrei optice (în km): "))
    
    print("\nSimularea a început...\n")
    simulate_quantum_communication(binary_message, photons_per_bit, delay_time, fiber_length)


if __name__ == "__main__":
    plt.ion()  # Modul interactiv
    user_interface()
