import os
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from mlkem.ml_kem import ML_KEM
from mlkem.parameter_set import ML_KEM_1024

# Unlike RSA, which acts as a general-purpose asymmetric encryption/decryption mechanism, ML-KEM is a
# Key Encapsulation Mechanism (KEM), which is not able to encrypt/decrypt a given payload!

# --- Step 1: Alice creates a keypair -------------------------------------------------------------
# We start out playing the role of Alice. She instantiates ML-KEM and generates a keypair,
# much like she would with RSA.
alice = ML_KEM(ML_KEM_1024, fast=False)
# key_gen() hands back the encapsulation (public) key first, then the decapsulation (private) key.
enc_key, dec_key = alice.key_gen()
# The encapsulation key is what Alice sends to Bob; the decapsulation key never leaves her side.

# --- Step 2: Bob encapsulates --------------------------------------------------------------------
# Switching roles: Bob sets up ML-KEM using the same parameter set as Alice.
bob = ML_KEM(ML_KEM_1024, fast=False)
# From Alice's encapsulation key, Bob obtains a SHARED KEY together with a ciphertext.
shared_key_bob, ciphertext = bob.encaps(enc_key)
# Bob sends the ciphertext back to Alice; she will use it to derive the same SHARED KEY.

# Note that encaps() was given nothing but the encapsulation key: Bob did not pick the SHARED KEY
# beforehand, it is an *output* of the operation. That is why a KEM cannot be used to encrypt a
# payload of your choosing. What it does guarantee is that the key returned by encaps() matches
# the key returned by decaps(), which Alice runs next.

# --- Step 3: Alice decapsulates ------------------------------------------------------------------
# Back as Alice: combining the decapsulation (private) key from step 1 with the ciphertext Bob
# sent yields the same SHARED KEY that Bob got from encaps().
shared_key_alice = alice.decaps(dec_key, ciphertext)

# Both sides should now hold the same SHARED KEY, so let's prove it...
assert shared_key_alice == shared_key_bob

# --- Step 4: exchange data with a symmetric cipher -----------------------------------------------
# With a shared key in place, a *symmetric* encryption/decryption algorithm such as AES can be
# used to exchange data securely between Alice and Bob.
# Note that the symmetric algorithm chosen of course also needs to be "quantum-safe".

# As Alice, build an AES-GCM cipher keyed with her copy of the shared key.
alice_cipher = AESGCM(shared_key_alice)
# This is the secret data she wants to send.
secret_data_alice = b"Hey Bob, here's my secret cookie ingredient. It's maple syrup - don't tell anyone! Bye"
# AES-GCM also needs an initialisation vector (also called a nonce); 12 random bytes are used here.
iv = os.urandom(12)
# No associated data is authenticated alongside the message, hence the trailing None.
encrypted_message = alice_cipher.encrypt(iv, secret_data_alice, None)

# As Bob, decrypt with his copy of the shared key and the same initialisation vector.
bob_cipher = AESGCM(shared_key_bob)
secret_data_bob = bob_cipher.decrypt(iv, encrypted_message, None)

# Alice and Bob now have shared secret data, so let's prove it...
assert secret_data_alice == secret_data_bob
