Post

ISITDTU Quals 2024

I wasn’t able to solve it during the event, so I took notes to capture the knowledge for later review.

Sharemixer1

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import random   # TODO: heard that this is unsafe but nvm
from Crypto.Util.number import getPrime, bytes_to_long

with open("flag.txt", "rb") as flag_file:
    # Interpret the raw flag bytes as one big-endian integer.
    flag = bytes_to_long(flag_file.read())
p = getPrime(256)
# The flag is used as a polynomial coefficient mod p, so it must be below p.
assert flag < p
l = 32  # number of coefficients: l - 1 random ones plus the flag


def share_mixer(xs):
    """Evaluate a secret polynomial with ``l`` coefficients at every query point.

    The coefficients are ``l - 1`` random values in [1, p - 1] plus ``flag``.
    The caller's list ``xs`` is shuffled in place, then the coefficient list is
    shuffled, so neither the coefficient positions nor the output order is
    known to the caller.

    Returns one value per entry of ``xs`` (in the shuffled order), each equal
    to sum(c_i * x**i) mod p.
    """
    coeffs = [random.randint(1, p - 1) for _ in range(l - 1)]
    coeffs.append(flag)

    # mixy mix: queries first, then coefficients (same RNG call order).
    random.shuffle(xs)
    random.shuffle(coeffs)

    def evaluate(x):
        # Horner's rule gives the same residue as sum(c * x**i) mod p.
        acc = 0
        for c in reversed(coeffs):
            acc = (acc * x + c) % p
        return acc

    return [evaluate(x) for x in xs]


if __name__ == "__main__":
    try:
        print(f"{p = }")
        raw = input("Gib me the queries: ")
        xs = [int(tok) % p for tok in raw.split()]

        # Reject x = 0 (its share would be the constant coefficient alone)
        # and cap the number of queries.
        if len(xs) > 256 or 0 in xs:
            print("GUH")
            exit(1)

        shares = share_mixer(xs)
        print("shares =", shares)
    except BaseException:
        # Equivalent to a bare except: any failure, including the SystemExit
        # raised just above, ends the process with status 1.
        exit(1)

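We get $shares_j = \sum_{i=0}^{31} c_i x_j^i \bmod p$. The coefficients $c_i$ (31 random values plus the flag) are shuffled, and so is our list $xs$. We may send at most 256 values, and none of them may be 0. The shuffles only permute the coefficients and the outputs, so we can still tell which share belongs to which $x$: equal inputs give equal shares, so if we send each $x$ a different number of times (1 once, 2 twice, …), a share value that appears $k$ times in the output is the evaluation at the $x$ we sent $k$ times. However, $1 + 2 + \dots + 32 = 528 > 256$, so the 32 points cannot all have distinct multiplicities. A few multiplicities are therefore shared by two (or four) points, and the solver brute-forces those assignments before solving the $32 \times 32$ Vandermonde system for the coefficients, one of which is the flag.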

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from pwn import*
from sage.all import*
import random
from Crypto.Util.number import*

# Placeholders: paste the values printed by the challenge server here.
share = ...  # the `shares` list
p = ...      # the prime modulus

def count_occurrences(lst):
    """Return a dict mapping each distinct value in *lst* to its number of occurrences.

    Uses ``collections.Counter`` (a single pass) instead of calling
    ``lst.count`` once per distinct value, which is quadratic.  Key order
    follows first occurrence in *lst*.
    """
    from collections import Counter

    return dict(Counter(lst))

def get_values_with_count(counts, num):
    """List the keys of *counts* whose occurrence count equals *num*.

    Keys are returned in the iteration order of *counts*.
    """
    matches = []
    for value, occurrences in counts.items():
        if occurrences == num:
            matches.append(value)
    return matches

# How often each share value occurs.  Equal x values give equal shares, so a
# value seen k times presumably comes from the x that was queried k times
# (assuming no accidental collisions between shares of different x).
counts = count_occurrences(share)

def gen_vector(vector, a, b, row):
    """Return both ways of placing *a* and *b* at mirrored positions of *vector*.

    Positions ``row`` and ``len(vector) - 1 - row`` receive (a, b) in the
    first copy and (b, a) in the second.  *vector* itself is not modified.
    For the 32-entry vectors used by the solver the mirror index is 31 - row.
    """
    mirror = len(vector) - 1 - row
    vt1 = vector.copy()
    vt2 = vector.copy()
    vt1[row] = vt2[mirror] = a
    vt1[mirror] = vt2[row] = b
    return [vt1, vt2]

def gen_vector2(vector, a, row):
    """Write *a* into ``vector[row]`` in place and hand back the same list.

    Not called anywhere in this script.
    """
    target = vector
    target[row] = a
    return target

# Candidate share vectors, indexed by x - 1 for x = 1..32 (the rows of the
# Vandermonde matrix built below).  For cnt = 2..14 the solver expects exactly
# two share values seen cnt times; they go to the mirrored positions cnt - 1
# and 32 - cnt in an unknown order, so both orderings are kept and the
# candidate list doubles at each step (2**13 vectors in the end).
vts = [[0 for _ in range(32)]]
for cnt in range(2, 15):
    a, b = get_values_with_count(counts, cnt)
    vts = [cand for vt in vts for cand in gen_vector(vt, a, b, cnt - 1)]

# The values seen 16 and 15 times are unique; place them at positions 16 and
# 17 of every candidate.
for position, cnt in ((16, 16), (17, 15)):
    value = get_values_with_count(counts, cnt)[0]
    for vt in vts:
        vt[position] = value

# NOTE(review): this global is never read below (get_vt_cand uses its own
# local `a`); it looks like leftover scratch state.
a = []
from itertools import permutations 

from Crypto.Util.number import*

def get_vt_cand(vector, per):
    """Fill the four positions whose share value occurred exactly once.

    Returns a copy of *vector* with positions 0, 14, 15 and 31 set to the
    count-1 share values, chosen by the index 4-tuple *per* (indices into the
    list of count-1 values).  Reads the module-level ``counts``.
    """
    singles = get_values_with_count(counts, 1)
    first, second, third, fourth = per
    cand = vector.copy()
    cand[0], cand[14], cand[15], cand[31] = (
        singles[first], singles[second], singles[third], singles[fourth])
    return cand

# Complete every partial candidate with all 4! orderings of the four share
# values that appeared only once (positions 0, 14, 15 and 31).
rs = []
for vt_rd in vts:
    for per in permutations([0, 1, 2, 3]):
        rs.append(get_vt_cand(vt_rd, per))

n = 32
# Vandermonde matrix: row x - 1 maps the coefficients c_0..c_31 to the share
# sum(c_i * x**i), for x = 1..32.
M = matrix(GF(p), [[x**i for i in range(n)] for x in range(1, n+1)])
# M does not depend on the candidate, so invert it once; inverting it inside
# the loop (once per candidate) would dominate the run time.
M_inv = M**(-1)
print("solving matrix")
found = False
for r in rs:
    rr = vector(GF(p), r)
    res = M_inv*rr
    for coeff in res:
        flag = long_to_bytes(int(coeff))
        if b"ISIT" in flag:
            print(flag)
            found = True
            break
    # Stop at the first candidate that yields the flag instead of solving
    # the remaining systems.
    if found:
        break
# ISITDTU{Mix1_a5850c98ad583157f0}

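This script takes a very long time to run: it tries $2^{13} \cdot 4! = 196{,}608$ candidate share vectors, each needing a $32 \times 32$ linear solve over $GF(p)$. The inverse of the Vandermonde matrix should therefore be computed once, outside the loop.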

Sharemixer2

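The only difference from Sharemixer1 is that the $xs$ list is limited to 32 values. The solution uses roots of unity.

Pick a primitive 32nd root of unity $\omega$ modulo $p$ (this requires $32 \mid p - 1$, so we reconnect until the server's prime satisfies it) and send $x_j = \omega^j$ for $j = 0, \dots, 31$. Summing all the shares gives:

\[\sum_{j=0}^{31} shares_j = \sum_{j=0}^{31} \sum_{i=0}^{31} c_i\,\omega^{ij} = \sum_{i=0}^{31} c_i \sum_{j=0}^{31} \omega^{ij}\] \[\sum_{j=0}^{31} \omega^{ij} = \begin{cases} 32 & i = 0 \\ \dfrac{1 - \omega^{32i}}{1 - \omega^{i}} = 0 & 1 \le i \le 31 \end{cases} \quad (\text{since } \omega^{32} = 1 \text{ and } \omega^{i} \ne 1)\]

Hence $\sum_j shares_j = 32\,c_0 \pmod p$, where $c_0$ is whichever coefficient the shuffle put first. We reconnect until the flag lands in position 0 (about 1 in 32 tries); then $flag = \left(\sum_j shares_j\right) \cdot 32^{-1} \bmod p$.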

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from Crypto.Util.number import *
from pwn import *
import random
import sys
import ast

# Quiet pwntools logging to warnings and above; the loop below may open many
# connections before finding a usable prime.
context.log_level = 'warn'

def _find_root_of_unity(p, order=32, tries=10000):
    """Return an element of exact multiplicative order ``order`` mod p, or None.

    For prime p with ``order`` dividing p - 1, a random ``e`` raised to
    ``(p - 1) // order`` lies in the subgroup of ``order``-th roots of unity.
    ``order`` must be a power of two: then the element is primitive exactly
    when its ``order // 2``-th power is not 1.  Gives up after ``tries``
    random attempts.
    """
    for _ in range(tries):
        e = random.randint(2, p - 1)
        base = pow(e, (p - 1) // order, p)
        if pow(base, order, p) == 1 and pow(base, order // 2, p) != 1:
            return base
    return None


def solve(r, p):
    """Try to recover the flag over one connection using 32nd roots of unity.

    Sends x_j = w**j (j = 0..31) for a primitive 32nd root of unity w mod p.
    For these points sum_j w**(i*j) is 32 when i == 0 and 0 otherwise, so the
    sum of all shares is 32 * c_0 mod p, where c_0 is whichever coefficient
    the server's shuffle put first.  If the constant term, or p minus it,
    decodes as UTF-8 text, it is printed and the process exits via exit(0).

    Parameters:
        r: an open pwntools tube, positioned just after the ``p = ...`` line.
        p: the server's prime; (p - 1) must be divisible by 32.

    The tube is always closed before this function returns or exits.
    """
    try:
        assert (p - 1) % 32 == 0

        print("Found (p - 1) % 32 == 0")

        base = _find_root_of_unity(p)
        if base is None:
            # Failed to find
            return

        print("Found base")

        xs = [pow(base, i, p) for i in range(32)]
        print(sum(xs) % p)  # sanity check: prints 0 for a primitive root
        r.sendline(' '.join(str(x) for x in xs).encode())

        r.recvuntil(b'shares = ')
        shares = ast.literal_eval(r.recvline().decode().strip())

        const = sum(shares) * pow(32, -1, p) % p

        # Accept the constant term or its negation mod p if it decodes as text.
        for candidate in (const, p - const):
            try:
                msg = long_to_bytes(candidate).decode()
            except UnicodeDecodeError:
                continue
            print(msg)
            print(f"{xs = }")
            exit(0)
    finally:
        r.close()

def _connect():
    """Open a tube: remote host/port from argv when given, else the local challenge."""
    if len(sys.argv) >= 3:
        return remote(sys.argv[1], sys.argv[2])
    return process(['python3', 'chall.py'])


# Reconnect until the server's prime has 32 | p - 1 (so a primitive 32nd root
# of unity exists); solve() exits the script once the flag is printed.
while True:
    r = _connect()

    r.recvuntil(b'p = ')
    p = int(r.recvline().decode().strip())

    if (p - 1) % 32:
        r.close()
        continue

    solve(r, p)

Sign

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
#!/usr/bin/env python3

import os

from Crypto.Util.number import *
from Crypto.Signature import PKCS1_v1_5
from Crypto.PublicKey import RSA
from Crypto.Hash import SHA256

# Flag placeholder, left-padded with random bytes to 255 bytes in total (one
# byte shorter than the 256-byte RSA modulus).
_flag_tail = b'ISITDTU{aaaaaaaaaaaaaaaaaaaaaaaaaa}'
flag = os.urandom(255 - len(_flag_tail)) + _flag_tail


def genkey(e=11):
    """Build an RSA key from two fresh 1024-bit primes and public exponent *e*.

    Primes are redrawn until gcd(e, p - 1) == gcd(e, q - 1) == 1, which makes
    e invertible modulo (p - 1)(q - 1); the private exponent d is that inverse.
    """
    def draw_pair():
        return getPrime(1024), getPrime(1024)

    p, q = draw_pair()
    while not (GCD(p-1, e) == 1 and GCD(q-1, e) == 1):
        p, q = draw_pair()
    phi = (p-1)*(q-1)
    return RSA.construct((p*q, e, pow(e, -1, phi)))


def gensig(key: RSA.RsaKey) -> bytes:
    """Return a PKCS#1 v1.5 / SHA-256 signature over 256 fresh random bytes.

    The random message is never returned, so the caller only sees the signature.
    """
    digest = SHA256.new(os.urandom(256))
    signer = PKCS1_v1_5.new(key)
    return signer.sign(digest)


def getflagsig(key: RSA.RsaKey) -> bytes:
    """Return flag**d mod n as bytes: a raw RSA signature with no padding."""
    signature = pow(bytes_to_long(flag), key.d, key.n)
    return long_to_bytes(signature)


key = genkey()

# Menu text printed before every prompt.
MENU = """=================
1. Generate random signature
2. Get flag signature
================="""
# Menu option -> function producing the signature bytes for that option.
ACTIONS = {1: gensig, 2: getflagsig}

while True:
    print(MENU)

    try:
        action = ACTIONS.get(int(input('> ')))
        # Unknown options are ignored and the menu is shown again.
        if action is not None:
            print('sig =', action(key).hex())
    except Exception:
        # Any error (bad input, EOF, signing failure) ends the session.
        print('huh')
        exit(-1)

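The oracle gives us PKCS#1 v1.5 signatures of random 256-byte messages, but it never reveals $n$. The encoded message is 0x00 || 0x01 || 0xFF…FF || 0x00 || DigestInfo(SHA-256) || H(m), so everything except the final 32-byte hash is known. Since $e = 11$, we have $s_i^{11} = EM_i + k_i n$, hence $s_i^{11} - pad = k_i n + H(m_i)$, where $pad$ is the encoded message with the hash zeroed and $0 \le H(m_i) < 2^{256}$. This is an approximate-GCD instance: the values are multiples of the hidden $n$ up to 256-bit noise. LLL on the standard AGCD lattice recovers $k_0$; then $n = \lfloor (s_0^{11} - pad) / k_0 \rfloor$, and the flag is $\sigma_{flag}^{11} \bmod n$.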

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from sage.all import *
from pwn import remote

from Crypto.Util.number import long_to_bytes, bytes_to_long

import hashlib

# Proof-of-work parameters: `difficulty` leading zero hex digits of the
# SHA-256 digest (6 hex digits == 3 zero bytes).
difficulty = 6
zeros = '0' * difficulty

def is_valid(digest):
    """Return True if *digest* starts with ``difficulty`` zero hex digits.

    Uses the module-level ``zeros`` string, so the check follows ``difficulty``
    (including odd digit counts) instead of hard-coding a byte count.
    """
    return digest.hex().startswith(zeros)

# Connect and answer the proof-of-work: send the smallest positive integer i
# such that sha256(prefix + str(i)) passes is_valid.
io = remote('35.187.238.100', 5003)
io.recvuntil(b'sha256("')
prefix = io.recvuntil(b'"')[:-1].decode()

i = 1
while not is_valid(hashlib.sha256((prefix + str(i)).encode()).digest()):
    i += 1
io.sendline(str(i).encode())

def _read_sig(option):
    """Send a menu option and parse the hex signature from the 'sig = ' reply."""
    io.sendline(option)
    return int(io.recvline_contains(b'sig = ').split()[-1], 16)


# 30 random-message signatures for the lattice, then the flag's raw signature.
sigs = [_read_sig(b'1') for _ in range(30)]
flag = _read_sig(b'2')


# Known part of every PKCS#1 v1.5 encoded message for a 256-byte modulus:
# 0x00 || 0x01 || 0xFF * 202 || 0x00 || DigestInfo(SHA-256) || H(m), with the
# 32-byte hash H(m) zeroed (the leading 0x00 vanishes in the integer).
SHA256_DIGEST_INFO = bytes.fromhex('3031300d060960864801650304020105000420')
pad = bytes_to_long(b'\x01' + b'\xff' * 202 + b'\x00' + SHA256_DIGEST_INFO + b'\x00' * 32)

# s_i**11 = EM_i + k_i*n, so s_i**11 - pad = k_i*n + H(m_i) with
# 0 <= H(m_i) < 2**256: approximate multiples of the unknown n (AGCD).
noise = 1 << 256
sigs = [x**11 - pad for x in sigs]

# AGCD lattice: k_0*row_0 - sum(k_i*row_i) = (k_0*2**256, k_0*y_i - k_i*y_0, ...)
# is unusually short, so LLL exposes k_0 in its first coordinate.
M = diagonal_matrix(ZZ, [sigs[0]] * len(sigs), sparse = False)
M[0] = [noise] + sigs[1:]
M = M.LLL()

for row in M:
    # LLL may return the vector negated; k_0 itself is positive.
    k0 = abs(row[0]) // noise
    if k0 == 0:
        continue
    # y_0 = k_0*n + h_0 with h_0 < k_0, so floor division gives n exactly.
    n = int(sigs[0] // k0)
    candidate = long_to_bytes(pow(flag, 11, n))
    if b'ISITDTU' in candidate:
        print(candidate)
        break
else:
    print("flag not found in the reduced basis")
This post is licensed under CC BY 4.0 by the author.