import cirq
import numpy as np
import matplotlib.pyplot as plt
from time import sleep
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import StandardScaler
import uuid
import json

# === SIMULATION CONFIGURATION ===
# Depolarizing probability read by test_trc(). The main loop rebinds this
# global after every transmitted bit (adaptive noise), so it is not constant
# at runtime.
NOISE_LEVEL = 0.05
REPETITIONS = 300  # shots per circuit run in test_trc(); reduced for speed
BITSTREAM_INTERVAL = 0.0  # seconds to sleep between bits; 0 disables the delay
TRAINING_WINDOW = 100  # maximum number of samples kept in X_train / y_train
MIN_NOISE = 0.01  # lower bound for the main loop's adaptive noise adjustment
MAX_NOISE = 0.1  # upper bound for the main loop's adaptive noise adjustment
UPDATE_GRAPH_EVERY = 10  # redraw the live plot every N bits of a round

# === ML MODEL INITIALIZATION ===
# Classifier used to decode bits from measurement features once it is trained.
model = RandomForestClassifier(n_estimators=50, max_depth=5, n_jobs=-1)
scaler = StandardScaler()  # feature standardizer used together with `model`
X_train = []  # training feature vectors (sliding window, see TRAINING_WINDOW)
y_train = []  # labels: the bit that was actually sent for each feature vector
is_model_trained = False  # becomes True once `model` has been fitted

# Bits recorded by the main loop; the most recent ones are appended to the
# feature vector as context.
recent_bits = []

# === FUNCTIONS ===
def test_trc(bit, noise_level=None, repetitions=None):
    """Sample the noisy two-qubit circuit and histogram both measured qubits.

    Builds H(q0) then CNOT(q0, q1), appends X(q1) when ``bit == 1``, wraps the
    gates built so far with a depolarizing channel via ``Circuit.with_noise``,
    then appends the measurements (q0 under key ``'result'``, q1 under key
    ``'check'``) and runs ``repetitions`` shots on ``cirq.Simulator``.

    Note: the bit-dependent X acts only on q1, after the CNOT that uses q0 as
    control, so the ``'result'`` (q0) distribution cannot depend on ``bit``;
    only the ``'check'`` histogram reflects the sent bit.

    Args:
        bit: 1 to flip q1 before measurement; any other value leaves it alone.
        noise_level: depolarizing probability. Defaults to the module-level
            NOISE_LEVEL read at call time (the main loop adapts that global).
        repetitions: number of shots. Defaults to the module-level REPETITIONS.

    Returns:
        (result_histogram, check_histogram): outcome -> count mappings from
        ``Result.histogram`` for q0 and q1 respectively.
    """
    # Resolve defaults at call time so runtime changes to the globals are seen.
    if noise_level is None:
        noise_level = NOISE_LEVEL
    if repetitions is None:
        repetitions = REPETITIONS

    q0, q1 = cirq.LineQubit.range(2)
    circuit = cirq.Circuit([
        cirq.H(q0),
        cirq.CNOT(q0, q1),
    ])

    if bit == 1:
        circuit.append(cirq.X(q1))

    # Noise is attached before the measurements are appended.
    noise = cirq.DepolarizingChannel(p=noise_level)
    circuit = circuit.with_noise(noise)

    circuit.append([
        cirq.measure(q0, key='result'),
        cirq.measure(q1, key='check'),
    ])

    result = cirq.Simulator().run(circuit, repetitions=repetitions)
    return result.histogram(key='result'), result.histogram(key='check')

def extract_features(counts):
    """Convert a measurement histogram into outcome frequencies [P(0), P(1)].

    Args:
        counts: mapping of measured outcome -> count; a missing outcome
            counts as 0.

    Returns:
        numpy array with the frequencies of outcomes 0 and 1, each divided by
        the total count over *all* outcomes present in ``counts``.

    Raises:
        ZeroDivisionError: if ``counts`` is empty.
    """
    shots = sum(counts.values())
    return np.array([counts.get(outcome, 0) / shots for outcome in (0, 1)])

def error_correction(bits, repeat=3):
    """Decode a repetition code by majority vote.

    Splits *bits* into consecutive chunks of ``repeat`` elements (the final
    chunk may be shorter). A chunk decodes to 1 when its sum exceeds half of
    its length and to 0 otherwise, so a tie in an even-length chunk gives 0.

    Args:
        bits: sequence of 0/1 values.
        repeat: chunk size, i.e. how many copies of each data bit were sent.

    Returns:
        list with one decoded bit per chunk.

    Raises:
        ValueError: if ``repeat`` is 0 (from ``range``).
    """
    def majority(chunk):
        return int(sum(chunk) > len(chunk) / 2)

    return [majority(bits[start:start + repeat]) for start in range(0, len(bits), repeat)]

def collect_calibration_data():
    """Estimate how often the 'result' qubit reproduces each sent bit.

    Runs test_trc() once with bit 0 and then once with bit 1, and for each
    run divides the number of 'result' shots equal to the sent bit by
    REPETITIONS.

    Returns:
        (p0_g0, p1_g1): estimated P(result=0 | sent 0) and
        P(result=1 | sent 1).
    """
    fractions = []
    for sent in (0, 1):
        result_hist, _ = test_trc(sent)
        fractions.append(result_hist.get(sent, 0) / REPETITIONS)
    p_zero_given_zero, p_one_given_one = fractions
    return p_zero_given_zero, p_one_given_one

def draw_quantum_circuit():
    """Print a text diagram of the base entangling circuit.

    Shows H on q0, CNOT(q0 -> q1) and one measurement per qubit. This is a
    display-only sketch: unlike test_trc() it omits the bit-dependent X on
    q1, the depolarizing noise and the named measurement keys.
    """
    qubits = cirq.LineQubit.range(2)
    control, target = qubits
    gates = [cirq.H(control), cirq.CNOT(control, target)]
    measurements = [cirq.measure(qubit) for qubit in qubits]
    diagram = cirq.Circuit(gates + measurements)
    print("\nCircuit cuantic folosit:")
    print(diagram)

def text_to_bitstream(text):
    """Convert *text* into a flat list of bits, 8 per byte, MSB first.

    Text whose characters all fit in one byte is encoded as Latin-1, which is
    one byte per character and matches the previous per-character framing
    exactly. Any other text is encoded as UTF-8, so every byte still maps to
    exactly 8 bits.

    Args:
        text: the message to transmit.

    Returns:
        list of ints (0/1); its length is always a multiple of 8.
    """
    # format(ord(c), '08b') used to emit 9+ bits for code points above 255
    # (e.g. Romanian 'ă', 'ș', 'ț'), breaking the fixed 8-bit framing that
    # bitstream_to_text() relies on.
    try:
        data = text.encode('latin-1')
    except UnicodeEncodeError:
        # 'replace' keeps lone surrogates from aborting the transmission.
        data = text.encode('utf-8', errors='replace')
    return [(byte >> shift) & 1 for byte in data for shift in range(7, -1, -1)]

def bitstream_to_text(bitstream):
    """Rebuild text from a list of bits grouped into 8-bit bytes, MSB first.

    A trailing partial byte (fewer than 8 bits) is ignored. The bytes are
    decoded as UTF-8 when they form valid UTF-8. Otherwise each byte becomes
    one character (Latin-1, i.e. ``chr(byte)``), which never fails, so a
    corrupted stream still yields one character per byte.

    Args:
        bitstream: sequence of bits; each element counts as 1 if truthy.

    Returns:
        The decoded string ('' when fewer than 8 bits are given).
    """
    data = bytearray()
    # Stop 7 short of the end so only complete 8-bit groups are read.
    for start in range(0, len(bitstream) - 7, 8):
        value = 0
        for bit in bitstream[start:start + 8]:
            value = (value << 1) | (1 if bit else 0)
        data.append(value)
    try:
        return data.decode('utf-8')
    except UnicodeDecodeError:
        return data.decode('latin-1')

# === LIVE PLOT ===
# Interactive accuracy figure. The main loop pushes new data into `line` and
# redraws through `fig` / `ax`.
plt.ion()
fig, ax = plt.subplots()
ax.set(
    title="Acuratețe în timp real",
    xlabel="Bit transmis",
    ylabel="Acuratețe (%)",
    ylim=(0, 100),
)
ax.grid(True)
line = ax.plot([], [], 'g-', marker='o', markersize=4)[0]

# === MAIN LOOP ===
# Interactive driver: read a message, repetition-encode it, send every
# physical bit through test_trc() in endless rounds, and decode each bit
# (calibrated threshold at first, then the trained classifier). At the end of
# each round, majority-vote the round back into text and log it. Stop with
# Ctrl+C; training data and statistics are saved on every exit path.
REPETITION_FACTOR = 3  # copies sent per data bit; also the chunk size for error_correction()
CONTEXT_SIZE = 8  # number of previous sent bits appended to the feature vector
NOISE_STEP = 0.005  # adaptive noise change applied after each bit

print("\U0001f4e1 Sistem activ: Transmitere Retrocauzala Cuantica cu Invatare Automata")
draw_quantum_circuit()

# Re-prompt until the message is non-empty: an empty bitstream would cycle
# through rounds forever without transmitting anything.
message = ""
while not message:
    message = input("\n🌤 Introduceți un mesaj: ")
bitstream = text_to_bitstream(message)
# Repeat every data bit so error_correction()'s majority vote has redundant
# copies to vote over. Without this it merged unrelated neighbouring bits and
# the "corrected" message was always garbage.
transmitted = [bit for bit in bitstream for _ in range(REPETITION_FACTOR)]
length = len(transmitted)

print(f"\n💌 Mesajul '{message}' va fi transmis ca {length} biți.")
input("Press Enter to start transmission...")

accuracy_log = []
correct = 0
total = 0
accuracy = 0.0  # defined up front so the final report works even before the first bit
received_bits = []
p0_g0, p1_g1 = collect_calibration_data()
# Midpoint between the expected P(result=1) when 0 is sent (1 - p0_g0) and
# when 1 is sent (p1_g1); used only until the classifier has been trained.
threshold = ((1 - p0_g0) + p1_g1) / 2

session_id = uuid.uuid4()  # one id shared by both output files of this run
log_path = f"received_messages_{session_id}.txt"
training_path = f"training_data_{session_id}.json"

with open(log_path, "w", encoding="utf-8") as log_file, \
        open(training_path, "w", encoding="utf-8") as training_data_file:
    try:
        round_idx = 0
        while True:
            round_bits = []
            round_idx += 1
            print(f"\n=== Runda {round_idx} ===")

            for idx, sent_bit in enumerate(transmitted):
                measured_counts, _check_counts = test_trc(sent_bit)
                # Context: the last CONTEXT_SIZE sent bits, newest first, padded with -1.
                # NOTE(review): these are the *sent* bits, which a real receiver
                # would not know; kept as in the original design.
                context_bits = np.array([
                    recent_bits[-(i + 1)] if len(recent_bits) > i else -1
                    for i in range(CONTEXT_SIZE)
                ])
                features = np.concatenate((extract_features(measured_counts), context_bits))

                if is_model_trained:
                    decoded_bit = int(model.predict(scaler.transform([features]))[0])
                else:
                    # features[1] is the raw fraction of shots with 'result' == 1.
                    decoded_bit = 1 if features[1] > threshold else 0

                is_correct = (decoded_bit == sent_bit)
                total += 1
                if is_correct:
                    correct += 1
                accuracy = correct / total * 100
                accuracy_log.append(accuracy)

                # Adaptive noise: lower it after a correct bit, raise it after an
                # error. Clamping keeps float steps from overshooting the bounds.
                # Module-level code, so this rebinds the global test_trc() reads.
                if is_correct:
                    NOISE_LEVEL = max(MIN_NOISE, NOISE_LEVEL - NOISE_STEP)
                else:
                    NOISE_LEVEL = min(MAX_NOISE, NOISE_LEVEL + NOISE_STEP)

                # Store raw (unscaled) features; the scaler is refitted on the
                # whole window right before each model fit below.
                X_train.append(features.tolist())
                y_train.append(sent_bit)
                recent_bits.append(sent_bit)
                del recent_bits[:-CONTEXT_SIZE]  # only the newest CONTEXT_SIZE bits are read
                round_bits.append(decoded_bit)
                received_bits.append(decoded_bit)

                if len(X_train) > TRAINING_WINDOW:
                    X_train.pop(0)
                    y_train.pop(0)

                if len(X_train) >= 2:
                    # Fit scaler and forest on the same window so predict() above
                    # sees inputs scaled exactly like the training data.
                    model.fit(scaler.fit_transform(X_train), y_train)
                    is_model_trained = True

                if idx % UPDATE_GRAPH_EVERY == 0 or idx == length - 1:
                    line.set_xdata(range(len(accuracy_log)))
                    line.set_ydata(accuracy_log)
                    ax.relim()
                    ax.autoscale_view()
                    fig.canvas.draw()
                    fig.canvas.flush_events()

                print(f"Bit #{idx+1}/{length} | Trimis: {sent_bit} | Decodificat: {decoded_bit} | {'Corect' if is_correct else 'GRESIT'} | Acuratețe: {accuracy:.1f}% | Zgomot: {NOISE_LEVEL:.3f}")

                if BITSTREAM_INTERVAL > 0:
                    sleep(BITSTREAM_INTERVAL)

            # Majority-vote each group of REPETITION_FACTOR copies back into data bits.
            corrected_bits = error_correction(round_bits, REPETITION_FACTOR)
            reconstructed = bitstream_to_text(corrected_bits)
            print(f"\n📣 Mesaj reconstruit (cu corectare): {reconstructed}\n")
            log_file.write(f"Runda {round_idx} | Mesaj primit: {reconstructed} | Acuratețe: {accuracy:.1f}%\n")
            log_file.flush()

    except KeyboardInterrupt:
        print("\nOprit simularea.")
    finally:
        # Persist and report on every exit path, not only Ctrl+C; the
        # enclosing `with` closes both files afterwards.
        training_data = {'X_train': X_train, 'y_train': y_train}
        json.dump(training_data, training_data_file, indent=2)

        print("\n=== Statistici finale ===")
        print(f"Total biți transmiși: {total}")
        print(f"Biți corecți: {correct}")
        print(f"Acuratețe finală: {accuracy:.1f}%")
        print(f"Nivel zgomot final: {NOISE_LEVEL:.3f}")

        log_file.write(f"\nStatistici finale:\n")
        log_file.write(f"Total biți: {total}\n")
        log_file.write(f"Biți corecți: {correct}\n")
        log_file.write(f"Acuratețe: {accuracy:.1f}%\n")