import os
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import padding

def generate_key():
    """Generează o cheie aleatoare de 256 biti pentru criptare/descriptare."""
    return os.urandom(32)

def encrypt_file(input_filename, output_filename, key):
    """Criptează un fișier folosind AES cu modul CBC și o cheie dată."""
    iv = os.urandom(16)
    cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
    encryptor = cipher.encryptor()
    
    with open(input_filename, 'rb') as f:
        plaintext = f.read()

    # Aplicăm padding-ul pentru a se potrivi dimensiunea blocului
    padder = padding.PKCS7(algorithms.AES.block_size).padder()
    padded_data = padder.update(plaintext) + padder.finalize()

    ciphertext = encryptor.update(padded_data) + encryptor.finalize()

    with open(output_filename, 'wb') as f:
        f.write(iv + ciphertext)

def decrypt_file(input_filename, output_filename, key):
    """Decriptează un fișier folosind AES cu modul CBC și o cheie dată."""
    with open(input_filename, 'rb') as f:
        iv = f.read(16)
        ciphertext = f.read()

    cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
    decryptor = cipher.decryptor()
    
    padded_plaintext = decryptor.update(ciphertext) + decryptor.finalize()

    # Înlăturăm padding-ul
    unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder()
    plaintext = unpadder.update(padded_plaintext) + unpadder.finalize()

    with open(output_filename, 'wb') as f:
        f.write(plaintext)

def main():
    key = generate_key()
    print(f"Cheie generată: {key.hex()}")

    try:
        encrypt_file('input.txt', 'encrypted.bin', key)
        print("Fișierul a fost criptat cu succes.")
        
        decrypt_file('encrypted.bin', 'decrypted.txt', key)
        print("Fișierul a fost decriptat cu succes.")

    except Exception as e:
        print(f"A apărut o eroare: {e}")

if __name__ == '__main__':
    main()