from roundfunc import round_func

# Cipher block size in bytes. feistel() splits each block into two equal
# 54-byte halves (so this must stay even), and pad() stores the pad length in
# a single byte (so this must stay <= 255).
BLOCK_SIZE = 108

def ror(i):
    """Rotate a byte string one position to the right.

    The last byte moves to the front, e.g. ``b'abc' -> b'cab'``.
    Empty and single-byte inputs come back unchanged.

    :param i: a bytes-like sequence (``bytes`` or ``bytearray``).
    :returns: a new ``bytes`` object.
    """
    # Slicing i[-1:] instead of indexing i[-1] keeps this well-defined for an
    # empty input (returns b'') rather than raising IndexError.
    return bytes(i[-1:] + i[:-1])

def xor(a, b):
    """Return the bytewise XOR of two equal-length byte sequences.

    Lengths are checked with ``assert`` (so the check is skipped under -O).
    """
    assert len(a) == len(b)
    return bytes(p ^ q for p, q in zip(a, b))

def feistel(roundkeys, inp, rounds) -> bytes:
    """Run ``rounds`` rounds of a balanced Feistel network over one block.

    The block is split into equal left/right halves. Each round maps
    ``(l, r) -> (r, l XOR round_func(k, r))`` with ``k = roundkeys[idx]``.
    The halves are emitted swapped (``r + l``) at the end, so running this
    again with the round keys in reverse order undoes the transform.

    :param roundkeys: sequence of at least ``rounds`` round keys, used in order.
    :param inp: the block; must have even length (108 bytes for BLOCK_SIZE).
    :param rounds: number of rounds to apply.
    :returns: the transformed block.
    """
    # Split at the midpoint rather than a hard-coded 54, so the network follows
    # the block length (identical to the old split for 108-byte blocks).
    half = len(inp) // 2
    l = bytes(inp[:half])
    r = bytes(inp[half:])

    # Odd-length blocks cannot be split into equal halves.
    assert len(l) == len(r)

    for idx in range(rounds):
        k = roundkeys[idx]
        # NOTE(review): assumes round_func(k, r) yields len(r) bytes; xor()
        # asserts the lengths match.
        l, r = r, xor(l, bytes(round_func(k, r)))

    return r + l

def keyschedule(k, n):
    """Derive the round keys from a 16-byte master key.

    Key 0 is ``k`` itself. Every following key is the previous one rotated
    right by one byte via ``ror``. At least one key is always returned, even
    when ``n`` is 0 or negative.
    """
    assert len(k) == 16

    current = k
    schedule = [current]
    for _ in range(1, n):
        current = ror(current)
        schedule.append(current)
    return schedule
    

def pad(data, block_size=None):
    """Append PKCS#7-style padding so the length is a multiple of the block size.

    Between 1 and ``block_size`` bytes are always appended, each equal to the
    number of bytes added. Already-aligned input gets a whole extra block, so
    the padding can always be stripped unambiguously.

    :param data: bytes to pad.
    :param block_size: block length in bytes; defaults to ``BLOCK_SIZE``.
        Must be in 1..255 because the pad length is stored in a single byte.
    :returns: the padded data.
    :raises ValueError: if ``block_size`` is outside 1..255.
    """
    if block_size is None:
        block_size = BLOCK_SIZE
    if not 1 <= block_size <= 255:
        raise ValueError(f"block_size must be in 1..255, got {block_size}")
    # len(data) % block_size is in 0..block_size-1, so pad_len is always in
    # 1..block_size and never needs a zero special case.
    pad_len = block_size - (len(data) % block_size)
    return data + bytes([pad_len] * pad_len)

def unpad(data, block_size=None):
    """Strip padding added by ``pad`` and return the original bytes.

    :param data: padded bytes (normally the output of decryption).
    :param block_size: block length in bytes; defaults to ``BLOCK_SIZE``.
    :returns: ``data`` without its trailing padding.
    :raises ValueError: if ``data`` is empty or its padding is malformed.
    """
    if block_size is None:
        block_size = BLOCK_SIZE
    if not data:
        raise ValueError("cannot unpad empty data")
    pad_len = data[-1]
    # pad() never emits 0 or more than block_size bytes. A 0 here would make
    # data[:-0] return b'' and silently drop everything, and a length beyond
    # the data would do the same. A raised error (unlike assert) survives -O.
    if not 1 <= pad_len <= min(block_size, len(data)):
        raise ValueError("invalid padding")
    # Every padding byte must equal the pad length, as pad() writes them.
    if data[-pad_len:] != bytes([pad_len]) * pad_len:
        raise ValueError("invalid padding")
    return data[:-pad_len]

def encrypt(key, plain, rounds=1):
    """Pad ``plain`` and encrypt it block by block with the Feistel cipher.

    Every BLOCK_SIZE block is transformed independently with the same round
    keys, with no IV and no chaining. Identical plaintext blocks therefore
    yield identical ciphertext blocks.

    :param key: 16-byte master key.
    :param plain: plaintext bytes of any length.
    :param rounds: number of Feistel rounds.
    :returns: ciphertext whose length is a positive multiple of BLOCK_SIZE.
    """
    assert len(key) == 16
    roundkeys = keyschedule(key, rounds)

    padded = pad(plain)
    # Collect the blocks and join once; repeated bytes += copies the
    # accumulated ciphertext every time, which is quadratic.
    return b"".join(
        feistel(roundkeys, padded[start:start + BLOCK_SIZE], rounds)
        for start in range(0, len(padded), BLOCK_SIZE)
    )

def decrypt(key, cipher, rounds=1):
    """Reverse ``encrypt``: decrypt each block, then strip the padding.

    The Feistel network is its own inverse when the round keys are applied
    in reverse order, so decryption runs the same network with the key
    schedule reversed.

    :param key: 16-byte master key.
    :param cipher: ciphertext produced by ``encrypt``.
    :param rounds: number of Feistel rounds (must match encryption).
    :returns: the original plaintext.
    :raises ValueError: if ``cipher`` is empty or not a multiple of BLOCK_SIZE
        (or, depending on ``unpad``, if the recovered padding is malformed).
    """
    assert len(key) == 16
    # encrypt() always emits at least one full block. Reject anything else
    # here rather than failing obscurely inside feistel() or unpad().
    if not cipher or len(cipher) % BLOCK_SIZE:
        raise ValueError(
            f"ciphertext length must be a positive multiple of {BLOCK_SIZE}, "
            f"got {len(cipher)}"
        )
    roundkeys = keyschedule(key, rounds)
    # Loop-invariant: reverse the schedule once rather than once per block.
    reversed_keys = roundkeys[::-1]

    decrypted = b"".join(
        feistel(reversed_keys, cipher[start:start + BLOCK_SIZE], rounds)
        for start in range(0, len(cipher), BLOCK_SIZE)
    )
    return unpad(decrypted)