import cirq
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Tuple
from matplotlib.animation import FuncAnimation
import random
import time

class QuantumTransmitter:
    """Send random classical bits over a simulated entanglement-based channel.

    Each bit is encoded on one qubit of a freshly prepared Bell pair, decoded
    with a CNOT and read out from the other qubit. Sent bits, received bits and
    the running accuracy are plotted live with matplotlib.
    """

    def __init__(self, buffer_size=50):
        """Initialise history buffers, the simulator and the live figure.

        Args:
            buffer_size: How many of the most recent bits the bit plot shows.
        """
        self.buffer_size = buffer_size
        self.sent_bits = []
        self.received_bits = []
        self.accuracy_history = []
        self.time_points = []  # seconds since construction, one entry per bit
        self.start_time = time.time()
        # Reuse one simulator for every bit instead of constructing one per call.
        self.simulator = cirq.Simulator()

        # Setup real-time plotting
        plt.ion()
        self.fig, (self.ax1, self.ax2) = plt.subplots(2, 1, figsize=(10, 8))
        self.transmission_line, = self.ax1.plot([], [], 'b-', label='Sent bits')
        self.received_line, = self.ax1.plot([], [], 'r--', label='Received bits')
        self.accuracy_line, = self.ax2.plot([], [], 'g-', label='Accuracy')

        # Configure plots
        self.ax1.set_title('Bit Transmission')
        self.ax1.set_ylim(-0.5, 1.5)
        self.ax1.legend()
        self.ax2.set_title('Transmission Accuracy')
        self.ax2.set_ylim(0, 1.1)
        self.ax2.set_xlabel('Time (s)')
        self.ax2.set_ylabel('Accuracy')
        self.ax2.legend()

    def create_quantum_channel(self) -> Tuple[cirq.Qid, cirq.Qid, cirq.Circuit]:
        """Create a Bell pair (|00> + |11>)/sqrt(2) on two line qubits.

        Returns:
            (sender qubit, receiver qubit, circuit preparing the pair).
        """
        q1, q2 = cirq.LineQubit.range(2)
        circuit = cirq.Circuit(
            cirq.H(q1),
            cirq.CNOT(q1, q2)
        )
        return q1, q2, circuit

    def encode_bit(self, circuit: cirq.Circuit, qubit: cirq.Qid, bit: int) -> None:
        """Encode a classical bit by applying X to ``qubit`` when ``bit`` is 1.

        Raises:
            ValueError: if ``bit`` is not 0 or 1 (previously such values were
                silently sent as 0).
        """
        if bit not in (0, 1):
            raise ValueError(f"bit must be 0 or 1, got {bit!r}")
        if bit == 1:
            circuit.append(cirq.X(qubit))

    def decode_bit(self, circuit: cirq.Circuit, sender: cirq.Qid, receiver: cirq.Qid) -> None:
        """Map the Bell pair's parity onto ``receiver`` so measuring it yields the bit.

        The pair starts as (|00> + |11>)/sqrt(2); an X on the sender turns it
        into (|01> + |10>)/sqrt(2). CNOT(sender, receiver) maps these to
        |+>|0> and |+>|1> respectively, leaving the sent bit on ``receiver``.
        """
        circuit.append(cirq.CNOT(sender, receiver))

    def measure_qubit(self, circuit: cirq.Circuit, qubit: cirq.Qid, key: str) -> None:
        """Append a computational-basis measurement of ``qubit`` under ``key``."""
        circuit.append(cirq.measure(qubit, key=key))

    def transmit_bit(self, bit: int) -> int:
        """Send one classical bit through the entangled channel and return what arrives.

        Measuring the receiver qubit alone cannot reveal an X applied to the
        sender: its outcome is a fair coin either way (no-communication
        theorem). The sender qubit is therefore combined with the receiver via
        a CNOT before measurement, the same decoding idea used in superdense
        coding, which makes the received bit equal the sent bit.
        """
        q1, q2, circuit = self.create_quantum_channel()
        self.encode_bit(circuit, q1, bit)
        self.decode_bit(circuit, q1, q2)
        self.measure_qubit(circuit, q2, 'received')

        result = self.simulator.run(circuit, repetitions=1)
        # measurements[key] is indexed [repetition][qubit]; there is one of each.
        return int(result.measurements['received'][0][0])

    def update_plots(self):
        """Redraw both plots from the recorded history and let the GUI refresh."""
        # A negative slice already yields the whole list when it is shorter
        # than the window, so no separate branch is needed.
        display_sent = self.sent_bits[-self.buffer_size:]
        display_received = self.received_bits[-self.buffer_size:]
        x_data = range(len(display_sent))

        self.transmission_line.set_data(x_data, display_sent)
        self.received_line.set_data(x_data, display_received)
        self.ax1.set_xlim(-1, len(x_data))

        # Update accuracy plot
        self.accuracy_line.set_data(self.time_points, self.accuracy_history)
        self.ax2.set_xlim(0, max(self.time_points) + 1 if self.time_points else 1)

        plt.pause(0.1)

    def calculate_accuracy(self) -> float:
        """Return the fraction of sent bits that were received unchanged (0.0 if none)."""
        if not self.sent_bits:
            return 0.0
        correct = sum(s == r for s, r in zip(self.sent_bits, self.received_bits))
        return correct / len(self.sent_bits)

    def transmit_stream(self, num_bits: int, delay: float = 0.5):
        """Send ``num_bits`` random bits, updating the live plots after each one.

        Args:
            num_bits: Number of bits to transmit.
            delay: Extra seconds to wait after each bit. This uses
                ``plt.pause`` instead of ``time.sleep`` so the figure keeps
                processing GUI events. Non-positive values skip the wait.
        """
        for _ in range(num_bits):
            bit_to_send = random.randint(0, 1)
            received_bit = self.transmit_bit(bit_to_send)

            # Record this transmission
            self.sent_bits.append(bit_to_send)
            self.received_bits.append(received_bit)
            self.time_points.append(time.time() - self.start_time)
            self.accuracy_history.append(self.calculate_accuracy())

            self.update_plots()

            if delay > 0:
                plt.pause(delay)

def main():
    """Run the demo: stream 100 random bits, print the final accuracy, keep the figure up."""
    tx = QuantumTransmitter(buffer_size=50)
    print("Starting quantum transmission...")
    tx.transmit_stream(num_bits=100, delay=0.1)

    report = (
        "\nTransmission completed!",
        f"Final accuracy: {tx.calculate_accuracy():.2%}",
    )
    for line in report:
        print(line)

    # Leave interactive mode so show() keeps the plot window open.
    plt.ioff()
    plt.show()


if __name__ == "__main__":
    main()
