#!/usr/bin/env python
"""Decode hierarchical binary packets (given as hex) and evaluate them.

Each non-blank line of ``input01.txt`` is a hexadecimal transmission; the
value of its outermost packet is printed, one result per line.
"""
import functools
import operator

import numpy  # noqa: F401  -- no longer used by the decoder; kept so nothing relying on it breaks

# Type ID 4 is a literal value; every other type ID is an operator that
# combines the values of its sub-packets with the function below.
# The product uses operator.mul on plain Python ints: numpy.prod would
# convert small ints to int64 and silently wrap around on overflow.
_LITERAL_TYPE = 4
_OPERATIONS = {
    0: sum,
    1: lambda vals: functools.reduce(operator.mul, vals, 1),
    2: min,
    3: max,
    5: lambda vals: int(vals[0] > vals[1]),
    6: lambda vals: int(vals[0] < vals[1]),
    7: lambda vals: int(vals[0] == vals[1]),
}


def _read_int(bits, pos, width):
    """Return ``(int(bits[pos:pos + width], 2), pos + width)``.

    Raises:
        ValueError: if fewer than ``width`` bits remain after ``pos``.
    """
    end = pos + width
    if end > len(bits):
        raise ValueError(
            "truncated packet: need %d bits at offset %d, only %d available"
            % (width, pos, len(bits) - pos)
        )
    return int(bits[pos:end], 2), end


def _decode_at(bits, pos):
    """Decode the packet starting at ``bits[pos]``; return ``(value, next_pos)``."""
    _version, pos = _read_int(bits, pos, 3)  # version is consumed but not used
    type_id, pos = _read_int(bits, pos, 3)

    if type_id == _LITERAL_TYPE:
        # Literal: 5-bit groups; a leading 1 means another group follows,
        # and the 4-bit payloads concatenate into one binary number.
        value = 0
        while True:
            more, pos = _read_int(bits, pos, 1)
            group, pos = _read_int(bits, pos, 4)
            value = (value << 4) | group
            if not more:
                return value, pos

    length_type, pos = _read_int(bits, pos, 1)
    values = []
    if length_type == 0:
        # Next 15 bits give the total bit length of the sub-packets.
        length, pos = _read_int(bits, pos, 15)
        end = pos + length
        if end > len(bits):
            raise ValueError("sub-packet length %d exceeds remaining input" % length)
        while pos < end:
            value, pos = _decode_at(bits, pos)
            values.append(value)
        if pos != end:
            raise ValueError("sub-packets overran their declared length")
    else:
        # Next 11 bits give the number of immediately following sub-packets.
        count, pos = _read_int(bits, pos, 11)
        for _ in range(count):
            value, pos = _decode_at(bits, pos)
            values.append(value)

    # All eight 3-bit type IDs are covered (4 above, the rest here).
    return _OPERATIONS[type_id](values), pos


def decode_packet(packet):
    """Decode and evaluate the first packet in a string of '0'/'1' characters.

    Args:
        packet: the bit string, e.g. as produced by ``hex_to_bits``.

    Returns:
        ``(value, rest)`` where ``value`` is the evaluated integer and
        ``rest`` is the unconsumed suffix of ``packet`` (e.g. padding bits).

    Raises:
        ValueError: if ``packet`` is empty or truncated mid-packet.
    """
    value, pos = _decode_at(packet, 0)
    return value, packet[pos:]


def hex_to_bits(hex_string):
    """Convert a hex string to a bit string, keeping leading zero bits.

    A '1' nibble is prepended so that ``bin`` does not drop leading zeros;
    ``[3:]`` then strips the ``'0b1'`` prefix.
    """
    return bin(int("1" + hex_string, 16))[3:]


def main(path="input01.txt"):
    """Print the evaluated value of each non-blank hex line in ``path``."""
    with open(path, "r") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # blank lines (e.g. a trailing newline) carry no packet
            result, _rest = decode_packet(hex_to_bits(line))
            print(result)


if __name__ == "__main__":
    main()