import cirq
import numpy as np
from typing import List, Tuple
import matplotlib.pyplot as plt

def create_bell_pair() -> Tuple[cirq.LineQubit, cirq.LineQubit, cirq.Circuit]:
    """Build the circuit that prepares a Bell pair in the Phi+ state.

    Nothing is simulated here; the function only allocates two line qubits
    and describes the preparation circuit (Hadamard on the first qubit
    followed by a CNOT from the first onto the second).

    Returns:
        A 3-tuple ``(q1, q2, circuit)`` where ``q1`` and ``q2`` are the line
        qubits at positions 0 and 1, and ``circuit`` is the preparation
        circuit acting on them.
    """
    q1, q2 = cirq.LineQubit.range(2)
    circuit = cirq.Circuit(
        cirq.H(q1),
        cirq.CNOT(q1, q2),
    )
    return q1, q2, circuit

def create_measurement_setup(qubit: cirq.Qid, is_eraser: bool = False) -> cirq.Circuit:
    """Build a single-qubit measurement circuit in direct or eraser mode.

    Args:
        qubit: The qubit to measure. Its ``x`` attribute is embedded in the
            measurement key, so the qubit must provide one (as
            ``cirq.LineQubit`` does).
        is_eraser: When true, a Hadamard is applied before the measurement
            and the key is ``eraser_<x>``; otherwise the qubit is measured
            directly under the key ``direct_<x>``.

    Returns:
        A new circuit containing only the requested measurement operations.
    """
    if is_eraser:
        # Eraser configuration: Hadamard first, then measure.
        operations = [
            cirq.H(qubit),
            cirq.measure(qubit, key=f'eraser_{qubit.x}'),
        ]
    else:
        # Direct configuration: measure the qubit as it is.
        operations = [cirq.measure(qubit, key=f'direct_{qubit.x}')]

    setup = cirq.Circuit()
    setup.append(operations)
    return setup

def delayed_choice_quantum_eraser(num_shots: int = 1000) -> dict:
    """
    Simulate the delayed choice quantum eraser experiment.

    A Bell pair (signal, idler) is prepared, a third "choice" qubit is put
    through a Hadamard and measured, and the idler receives a Hadamard
    controlled by that choice qubit before being measured. Shots are then
    split by the measured choice bit.

    Args:
        num_shots: Number of repetitions to sample from the simulator.

    Returns:
        Dictionary with three entries:
        ``'eraser_correlations'`` - list of ``(signal, idler)`` outcome pairs
        from shots whose choice bit was 1;
        ``'direct_correlations'`` - the same for shots whose choice bit was 0;
        ``'total_shots'`` - the ``num_shots`` value that was passed in.
    """
    signal_qubit, idler_qubit, bell_circuit = create_bell_pair()
    choice_qubit = cirq.LineQubit(2)

    # Assemble the circuit in three stages: entangle, choose, measure.
    circuit = cirq.Circuit()
    circuit.append(bell_circuit)

    # The choice bit is an unbiased coin: Hadamard followed by measurement.
    choice_stage = [
        cirq.H(choice_qubit),
        cirq.measure(choice_qubit, key='choice'),
    ]
    circuit.append(choice_stage)

    # The idler only gets its Hadamard (eraser mode) when the choice qubit is 1.
    readout_stage = [
        cirq.measure(signal_qubit, key='signal'),
        cirq.H(idler_qubit).controlled_by(choice_qubit),
        cirq.measure(idler_qubit, key='idler'),
    ]
    circuit.append(readout_stage)

    samples = cirq.Simulator().run(circuit, repetitions=num_shots).measurements

    # Each measurement array has one row per shot and a single column,
    # so column 0 holds the per-shot outcome bit.
    choice_bits = samples['choice'][:, 0]
    signal_bits = samples['signal'][:, 0]
    idler_bits = samples['idler'][:, 0]

    eraser_pairs = []
    direct_pairs = []
    for chose_eraser, signal_bit, idler_bit in zip(choice_bits, signal_bits, idler_bits):
        bucket = eraser_pairs if chose_eraser else direct_pairs
        bucket.append((signal_bit, idler_bit))

    return {
        'eraser_correlations': eraser_pairs,
        'direct_correlations': direct_pairs,
        'total_shots': num_shots,
    }

def _plot_correlation_panel(position: int, pairs: np.ndarray, title: str) -> None:
    """Draw one signal-vs-idler 2D histogram in the given subplot position."""
    plt.subplot(position)
    # Pin the bin edges to -0.5, 0.5, 1.5 so the two bins always correspond to
    # outcomes 0 and 1. Without an explicit range the edges are derived from
    # the data, which mislabels the bins when only one outcome value occurs.
    plt.hist2d(
        pairs[:, 0],
        pairs[:, 1],
        bins=2,
        range=[[-0.5, 1.5], [-0.5, 1.5]],
    )
    plt.xticks([0, 1])
    plt.yticks([0, 1])
    plt.title(title)
    plt.xlabel('Signal Qubit')
    plt.ylabel('Idler Qubit')


def analyze_results(results: dict) -> None:
    """Plot the signal/idler outcome histograms for both measurement modes.

    Args:
        results: Dictionary with ``'eraser_correlations'`` and
            ``'direct_correlations'`` entries, each a sequence of
            ``(signal, idler)`` outcome pairs. A mode with no recorded
            pairs is skipped and its panel is left undrawn.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    eraser_corr = np.array(results['eraser_correlations'])
    direct_corr = np.array(results['direct_correlations'])

    plt.figure(figsize=(12, 5))

    # An empty list becomes a 1-D array that cannot be column-indexed,
    # so only plot modes that actually collected shots.
    if eraser_corr.size > 0:
        _plot_correlation_panel(121, eraser_corr, 'Eraser Mode Correlations')

    if direct_corr.size > 0:
        _plot_correlation_panel(122, direct_corr, 'Direct Measurement Correlations')

    plt.tight_layout()
    plt.show()

# Script entry point: simulate, plot, then summarize
if __name__ == "__main__":
    shot_count = 1000
    results = delayed_choice_quantum_eraser(shot_count)

    # Show the correlation histograms before printing the summary
    analyze_results(results)

    # Report how the shots were split between the two measurement modes
    summary_lines = (
        f"Total shots: {results['total_shots']}",
        f"Eraser mode measurements: {len(results['eraser_correlations'])}",
        f"Direct mode measurements: {len(results['direct_correlations'])}",
    )
    for line in summary_lines:
        print(line)
