# Suprimăm mesajele TensorFlow
import os
import warnings
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
warnings.filterwarnings('ignore')

import strawberryfields as sf
from strawberryfields.ops import Sgate, BSgate, Rgate, MeasureFock, Xgate, Dgate
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Tuple
import time

class GhostInterferenceExperiment:
    def __init__(self, fiber_length: float, detection_rate: float, photons_per_signal: int, wavelength: float):
        self.FIBER_LENGTH = fiber_length
        self.REFRACTIVE_INDEX = 1.5
        self.DETECTION_RATE = detection_rate
        self.WAVELENGTH = wavelength
        self.PHOTONS_PER_SIGNAL = photons_per_signal
        self.eng = sf.Engine("fock", backend_options={"cutoff_dim": 15})

    def simulate_single_photon(self, bit: int) -> float:
        """Simulează comportamentul unui singur foton și returnează puterea măsurată"""
        prog = sf.Program(2)
        
        with prog.context as q:
            # Generare stare cuantică cu parametri optimizați
            Sgate(0.1) | q[0]  # Compresie redusă pentru zgomot minim
            BSgate(np.pi/3) | (q[0], q[1])  # Divizor de fascicul optimizat
            
            # Calculăm și aplicăm întârzierea prin fibră
            delay_time = (self.FIBER_LENGTH * self.REFRACTIVE_INDEX * 1000) / 299.792
            phase = 2 * np.pi * delay_time * 1e-6 * self.DETECTION_RATE
            Rgate(phase) | q[0]
            
            # Codificare bit cu contrast mărit
            if bit == 1:
                Xgate(2.0) | q[0]  # Deplasare mare pentru bit 1
            else:
                Dgate(0.01) | q[0]  # Zgomot minim pentru bit 0
            
            # Măsurători cuantice
            MeasureFock() | q[0]
            MeasureFock() | q[1]
        
        # Rulăm simularea și returnăm rezultatul
        result = self.eng.run(prog)
        return float(result.samples[0][0])

    def simulate_transmission(self, message: str) -> Tuple[str, List[float]]:
        """Simulează transmisia completă a mesajului"""
        received_bits = []
        accuracies = []
        
        print("\nSimulare transmisie retrocauzală - evenimentele apar în ordine inversă:")
        print("----------------------------------------------------------------")
        
        for idx, bit in enumerate(message):
            bit_val = int(bit)
            measurements = []
            time_before = -50  # microsecunde
            
            # Efectuăm măsurători multiple pentru fiecare bit
            for _ in range(self.PHOTONS_PER_SIGNAL):
                power = self.simulate_single_photon(bit_val)
                measurements.append(power)
            
            # Calculăm statistici pentru decizie
            avg_power = np.mean(measurements)
            max_power = np.max(measurements)
            
            # Calculăm acuratețea pentru acest bit
            if bit_val == 1:
                success_rate = sum(m > 0.5 for m in measurements) / self.PHOTONS_PER_SIGNAL
            else:
                success_rate = sum(m < 0.5 for m in measurements) / self.PHOTONS_PER_SIGNAL
            
            accuracy = success_rate * 100
            
            # Determinăm bitul primit
            received_bit = '1' if avg_power > 0.5 else '0'
            
            # Afișăm informații despre transmisie
            print(f"\nBit {idx+1}:")
            print(f"  t = {time_before}µs: Bit detectat: {received_bit}")
            print(f"  t = 0µs: Bit transmis: {bit}")
            print(f"  Putere medie: {avg_power:.3f}")
            print(f"  Putere maximă: {max_power:.3f}")
            print(f"  Acuratețe: {accuracy:.1f}%")
            
            received_bits.append(received_bit)
            accuracies.append(accuracy)
            
        return ''.join(received_bits), accuracies

    def plot_results(self, accuracies: List[float], message: str, received: str):
        """Vizualizează rezultatele transmisiei"""
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8))
        
        # Grafic acuratețe
        ax1.plot(accuracies, 'bo-', label='Acuratețe măsurată')
        ax1.axhline(y=100, color='r', linestyle='--', label='Ideal')
        ax1.set_xlabel('Număr bit')
        ax1.set_ylabel('Acuratețe (%)')
        ax1.set_title('Acuratețe Transmisie Retrocauzală')
        ax1.grid(True)
        ax1.legend()
        
        # Grafic comparație biți
        x = np.arange(len(message))
        width = 0.35
        ax2.bar(x - width/2, [int(b) for b in message], width, label='Biți transmiși', color='blue', alpha=0.5)
        ax2.bar(x + width/2, [int(b) for b in received], width, label='Biți primiți', color='red', alpha=0.5)
        ax2.set_xlabel('Poziție bit')
        ax2.set_ylabel('Valoare bit')
        ax2.set_title('Comparație Biți Transmiși vs. Primiți')
        ax2.set_xticks(x)
        ax2.legend()
        
        plt.tight_layout()
        plt.show()

def get_validated_input(prompt: str, min_val: float, default_val: float, unit: str = "") -> float:
    """Obține și validează input-ul utilizatorului"""
    while True:
        try:
            val = input(f"{prompt} [{unit}] (implicit {default_val}): ").strip()
            if val == "":
                return default_val
            val = float(val)
            if val >= min_val:
                return val
            print(f"Eroare: Valoarea trebuie să fie cel puțin {min_val}")
        except ValueError:
            print("Eroare: Introduceți un număr valid")

def main():
    """Funcția principală a programului"""
    print("Experiment de Comunicare Cuantică Retrocauzală")
    print("---------------------------------------------")
    print("\nConfigurare parametri experiment:")
    print("(Apăsați ENTER pentru a folosi valorile implicite din experimentul original)\n")

    # Obținem parametrii de la utilizator
    fiber_length = get_validated_input(
        "Lungimea fibrei optice (în experiment: 10 km)", 
        min_val=0.1, 
        default_val=10.0,
        unit="km"
    )

    detection_rate = get_validated_input(
        "Rata de detecție (în experiment: 10 MHz = 10e6 Hz)", 
        min_val=1e5, 
        default_val=1e7,
        unit="Hz"
    )

    photons_signal = get_validated_input(
        "Numărul de fotoni per bit (în experiment: 100)", 
        min_val=10, 
        default_val=100,
        unit="fotoni"
    )

    wavelength = get_validated_input(
        "Lungimea de undă (în experiment: 702 nm)", 
        min_val=100, 
        default_val=702,
        unit="nm"
    )

    # Inițializăm experimentul
    experiment = GhostInterferenceExperiment(
        fiber_length=fiber_length,
        detection_rate=detection_rate,
        photons_per_signal=int(photons_signal),
        wavelength=wavelength
    )
    
    # Afișăm configurația
    print("\nConfigurație finală:")
    print(f"- Lungime fibră: {experiment.FIBER_LENGTH} km")
    print(f"- Rată detecție: {experiment.DETECTION_RATE/1e6} MHz")
    print(f"- Fotoni per semnal: {experiment.PHOTONS_PER_SIGNAL}")
    print(f"- Lungime de undă: {experiment.WAVELENGTH} nm")
    
    # Obținem mesajul de la utilizator
    while True:
        message = input("\nIntroduceți mesajul binar (secvență de 0 și 1): ")
        if all(bit in '01' for bit in message):
            break
        print("Eroare: Introduceți doar 0 și 1")
    
    # Rulăm simularea
    print("\nÎncepere simulare transmisie retrocauzală...")
    received_message, accuracies = experiment.simulate_transmission(message)
    
    # Afișăm rezultatele
    print("\nRezultate finale:")
    print(f"Mesaj transmis:  {message}")
    print(f"Mesaj primit:    {received_message}")
    print(f"Acuratețe medie: {np.mean(accuracies):.2f}%")
    
    # Afișăm graficele
    experiment.plot_results(accuracies, message, received_message)

if __name__ == "__main__":
    main()
