import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import HistGradientBoostingClassifier

# === SETĂRI GENERALE ===
BITSTREAM_LENGTH = 1000     # Numărul de biți de transmis
REPETITIONS = 1             # Simulăm o singură rundă
UPDATE_GRAPH_EVERY = 50     # Actualizare grafic la fiecare N biți
TRAINING_WINDOW = 200       # Câți biți păstrăm pentru antrenarea ML
BITSTREAM_INTERVAL = 0.0    # Fără întârziere artificială între biți

# === ZGOMOT ===
MIN_NOISE = 0.01
MAX_NOISE = 0.3
NOISE_LEVEL = 0.1           # Valoare inițială a zgomotului

# === MODEL ML ===
scaler = StandardScaler()
model = HistGradientBoostingClassifier(max_iter=100, max_depth=5, random_state=42)
is_model_trained = False

# === DATE GLOBALE ===
bitstream = np.random.randint(0, 2, BITSTREAM_LENGTH)  # Flux binar aleatoriu
received_bits = []
recent_bits = []              # Context: ultimii biți decodați
X_train = []                  # Date de antrenament
y_train = []                  # Etichete de antrenament
correct = 0                   # Biți corect decodați
total = 0                     # Total biți procesați
accuracy_log = []             # Jurnal acuratețe

# === FUNCȚII AUXILIARE ===
def test_trc(bit):
    """Simulează zgomot pe bitul trimis cu multiple măsurători."""
    num_samples = 10  # Număr de măsurători per bit
    noise = np.random.normal(0, NOISE_LEVEL, num_samples)
    measured = bit + noise
    return measured, [1 if abs(m - bit) < 0.5 else 0 for m in measured]

def extract_features(measured_counts):
    """Extrage caracteristici din semnal."""
    return np.array([np.mean(measured_counts), np.std(measured_counts)])

def error_correction(bits):
    """Corecție simplificată - majoritate pe 3 biți cu padding."""
    bits = bits.copy()
    # Adaugă padding dacă este necesar
    while len(bits) % 3 != 0:
        bits.append(0)
    
    corrected = []
    for i in range(0, len(bits), 3):
        window = bits[i:i+3]
        majority = int(round(np.mean(window)))
        corrected.append(majority)
    return corrected

def bitstream_to_text(bits):
    """Convertește biții în text (ASCII simplu)."""
    bytes_list = []
    for i in range(0, len(bits), 8):
        byte_bits = bits[i:i+8]
        if len(byte_bits) < 8:
            break
        byte_str = ''.join(str(b) for b in byte_bits)
        bytes_list.append(int(byte_str, 2))
    return bytes(bytes_list).decode('ascii', errors='ignore')

# === GRAFIC LIVE ===
plt.ion()
fig, ax = plt.subplots()
line, = ax.plot([], [], 'b-', label='Acuratețe (%)')
ax.set_xlabel('Biți procesați')
ax.set_ylabel('Acuratețe (%)')
ax.legend()
ax.grid(True)

# === MAIN LOOP ===
print("🚀 Începe simularea...")

try:
    round_idx = 0
    while True:
        round_idx += 1
        print(f"\n=== Runda {round_idx} ===")
        round_bits = []
        noise_adjustment = NOISE_LEVEL

        X_batch, y_batch = [], []

        for idx, sent_bit in enumerate(bitstream):
            measured_counts, check_counts = test_trc(sent_bit)
            features = extract_features(measured_counts)

            # Context din ultimii biți decodați
            context_bits = np.zeros(8)
            if len(recent_bits) > 0:
                valid = min(len(recent_bits), 8)
                context_bits[-valid:] = recent_bits[-valid:]

            features = np.concatenate((features, context_bits))

            # Scalare și predicție
            if is_model_trained:
                features_scaled = scaler.transform([features])[0]
                decoded_bit = model.predict([features_scaled])[0]
            else:
                # Folosește caracteristici brute pentru primele predicții
                decoded_bit = 1 if np.mean(measured_counts) > 0.5 else 0
                features_scaled = features  # Nu scalează încă

            # Verificare corectitudine
            is_correct = (decoded_bit == sent_bit)
            total += 1
            correct += is_correct
            accuracy = correct / total * 100
            accuracy_log.append(accuracy)

            # Ajustare adaptivă zgomot
            if not is_correct and noise_adjustment < MAX_NOISE:
                noise_adjustment += 0.005
            elif is_correct and noise_adjustment > MIN_NOISE:
                noise_adjustment -= 0.005
            globals()['NOISE_LEVEL'] = noise_adjustment

            # Antrenare ML
            X_train.append(features)
            y_train.append(sent_bit)
            
            # Menține dimensiunea ferestrei de antrenament
            if len(X_train) > TRAINING_WINDOW:
                X_train.pop(0)
                y_train.pop(0)

            # Antrenează periodic modelul
            if len(X_train) >= 100 and not is_model_trained:
                scaler.fit(X_train)
                X_train_scaled = scaler.transform(X_train)
                model.fit(X_train_scaled, y_train)
                is_model_trained = True

            # Actualizează contextul
            recent_bits.append(decoded_bit)
            if len(recent_bits) > 8:
                recent_bits.pop(0)
            
            round_bits.append(decoded_bit)
            received_bits.append(decoded_bit)

            # Actualizare grafic
            if idx % UPDATE_GRAPH_EVERY == 0 or idx == len(bitstream) - 1:
                line.set_xdata(range(len(accuracy_log)))
                line.set_ydata(accuracy_log)
                ax.relim()
                ax.autoscale_view()
                fig.canvas.draw()
                fig.canvas.flush_events()

            # Logging
            if idx % 100 == 0 or idx == len(bitstream) - 1:
                print(f"Bit #{idx+1}/{len(bitstream)} | Trimis: {sent_bit} | {'Corect' if is_correct else 'GRESIT'} | Acuratețe: {accuracy:.1f}% | Zgomot: {NOISE_LEVEL:.3f}")

        # Reconstrucție mesaj
        corrected_bits = error_correction(round_bits)
        reconstructed = bitstream_to_text(corrected_bits)
        print(f"\n📣 Mesaj reconstruit (cu corectare): {reconstructed}\n")

except KeyboardInterrupt:
    print("\n🛑 Simulare întreruptă manual.")
    plt.ioff()
    plt.show()
