
#ifndef BITSTREAM_H
#define BITSTREAM_H

#include <assert.h>
#include <stdlib.h>
#include <stdio.h>
#include <inttypes.h>

#ifdef __APPLE__
#include <machine/endian.h>
#else
#include <endian.h>
#endif


#ifdef __cplusplus
extern "C" {
#endif

/* State of one open bitstream.
 * The reader keeps a window of two consecutive 32-bit words of the file
 * (current_word followed by next_word); pos and end are bit offsets into
 * that window, counted from the most significant bit of current_word.
 */
typedef struct {
    uint32_t current_word;  // word the read cursor is currently inside
    uint32_t next_word;     // lookahead word that follows current_word in the file
    uint8_t pos;            // bit offset of the read cursor within the window
    uint8_t end;            // number of bits in the window that were actually read from the file

    FILE *fi;               // underlying stdio stream the words are read from
} bitstream_t;

/**************************************************
 API
***************************************************/
// Open the file at path and return a new bitstream, or NULL if the file cannot be opened.
bitstream_t* bs_open( const char *path );
// Close the stream's file; returns the result of fclose (0 on success, EOF on failure).
int bs_close( bitstream_t *bs );
// Read num_bits (1..32) into *result and advance; returns the number of bits actually read.
uint8_t bs_getbits( bitstream_t *bs, uint8_t num_bits, uint32_t *result );
// Same as bs_getbits, but the read position is left unchanged.
uint8_t bs_showbits( bitstream_t *bs, uint8_t num_bits, uint32_t *result );
// Move the read position forward to a byte boundary.
void bs_bytealign( bitstream_t *bs );
// Print the 32 bits of the argument as '0'/'1' characters, most significant bit first.
void bs_printbits( uint32_t bits );





/**************************************************
 IMPLEMENTATION
***************************************************/

// Convert a 32-bit word whose bytes were read straight from the stream
// (most significant byte first) into a host-order value.
//
// The bytes are picked up by address and recombined, which yields the right
// value on every host byte order, so no compile-time endianness detection is
// needed. (The previous "#if __BYTE_ORDER == __BIG_ENDIAN" test silently
// selected the no-op variant whenever neither macro was defined, because
// undefined identifiers evaluate to 0 inside #if.)
//
// Each byte is widened to uint32_t before shifting so that a top byte >= 0x80
// is never shifted into the sign bit of an int.
static inline uint32_t bs_swab32_bytes( uint32_t word )
{
    const unsigned char *b = (const unsigned char *) &word;

    return ((uint32_t) b[0] << 24) | ((uint32_t) b[1] << 16) |
           ((uint32_t) b[2] << 8)  |  (uint32_t) b[3];
}

// Kept as a macro so existing uses of swab32(x) continue to work unchanged.
#define swab32(x) bs_swab32_bytes(x)

/* Opens the file at path for binary reading (mode rb) and primes the
 * bit window with up to the first 8 bytes of the file.
 *
 * Returns a newly allocated bitstream that must be released with bs_close,
 * or NULL if the file could not be opened or memory could not be allocated.
 * When fopen fails, the global variable errno shall contain the reason.
 */
bitstream_t* bs_open( const char *path )
{
    // "rb": the contents are raw bits; text mode may translate bytes on some platforms
    FILE *fi = fopen( path, "rb" );
    if (!fi)
        return NULL;

    // cast kept so the header still compiles when included from C++
    bitstream_t *bs = (bitstream_t*) malloc( sizeof *bs );
    if (!bs) {
        fclose( fi );   // do not leak the stream when allocation fails
        return NULL;
    }
    bs->fi = fi;
    bs->pos = 0;

    // words are zeroed first so a short read leaves the missing bits as 0
    bs->current_word = 0;
    size_t num = fread( &(bs->current_word), 1, 4, bs->fi );
    bs->current_word = swab32(bs->current_word);
    bs->end = 8*num;    // number of valid bits read so far

    bs->next_word = 0;
    num = fread( &(bs->next_word), 1, 4, bs->fi );
    bs->next_word = swab32(bs->next_word);
    bs->end += 8*num;

    return bs;
}

/* Closes the bitstream and releases its memory.
 * Returns 0 on success; otherwise EOF is returned and errno will contain
 * the reason reported by fclose.
 *
 * The bitstream is freed in both cases: once fclose has been called the
 * FILE is no longer usable even if the call failed, so keeping the
 * structure around would only leak it. The caller's pointer is passed by
 * value and therefore cannot be reset here; it must not be used after
 * this call. Passing NULL is allowed and does nothing.
 */
int bs_close( bitstream_t *bs )
{
    if ( !bs )
        return 0;

    int res = fclose(bs->fi);
    free( bs );
    return res;
}

/* Slide the window forward by one word once current_word is used up,
 * reading the next 4 bytes of the file into next_word.
 * meant for internal use only
 */
void bs_fillbuffer( bitstream_t *bs )
{
    // A cursor beyond the last valid bit has nothing to point at; pull it
    // back first. Without this, "end -= 32" below could wrap around when
    // pos >= 32 but fewer than 32 valid bits remain.
    if ( bs->pos > bs->end )
        bs->pos = bs->end;

    // if we used all the bits in current_word, read next set of bits from file
    if ( bs->pos >= 32 ) {
        bs->current_word = bs->next_word;
        bs->end -= 32;
        bs->pos -= 32;

        // zeroed first so a short read (EOF or error) leaves the missing bits as 0
        bs->next_word = 0;
        size_t num = fread( &(bs->next_word), 1, 4, bs->fi );
        bs->next_word = swab32(bs->next_word);
        bs->end += 8*num;
    }
}

/* Read some bits from the stream and put them in result (right-aligned).
 * num_bits must be between 1 and 32.
 * if EOF is reached or read error encountered then the actual number of bits read
 * is returned and will be less than that requested in num_bits (possibly 0);
 * result then holds only the bits that were available, and the stream is
 * left positioned exactly at its end.
 */
uint8_t bs_getbits( bitstream_t *bs, uint8_t num_bits, uint32_t *result )
{
    assert( 1 <= num_bits && num_bits <= 32 );

    *result = 0;

    // never work from a cursor that sits past the last valid bit
    if ( bs->pos > bs->end )
        bs->pos = bs->end;

    // Only extract bits that exist. Doing the truncation up front (instead
    // of extracting num_bits and shifting the excess away afterwards)
    // avoids a shift by 32 when nothing is left to read.
    uint8_t avail = bs->end - bs->pos;
    uint8_t take = (num_bits < avail) ? num_bits : avail;

    if ( take > 0 ) {
        if ( bs->pos + take <= 32 ) { // case 1: no overlap
            *result = (bs->current_word << bs->pos) >> (32 - take);
        } else { // case 2: overlaps current_word and next_word
            *result = (bs->current_word << bs->pos) >> ( bs->pos );
            *result = *result << (bs->pos + take - 32 );
            *result |= (bs->next_word >> (64 - bs->pos - take));
        }
        bs->pos += take;
    }

    // keeps pos below 32 for the next call, also after a short read
    bs_fillbuffer( bs );
    return take;
}

/* Show some bits from the stream and put them in result (right-aligned),
 * does not advance the stream. num_bits must be between 1 and 32.
 * if EOF is reached or read error encountered then the actual number of bits read
 * is returned and will be less than that requested in num_bits (possibly 0);
 * result then holds only the bits that were available.
 */
uint8_t bs_showbits( bitstream_t *bs, uint8_t num_bits, uint32_t *result )
{
    assert( 1 <= num_bits && num_bits <= 32 );

    *result = 0;

    // Bits left between the cursor and the end of valid data; a cursor that
    // is already past the end simply has none.
    uint8_t avail = (bs->pos < bs->end) ? (uint8_t)(bs->end - bs->pos) : 0;
    uint8_t take = (num_bits < avail) ? num_bits : avail;

    // Nothing to show. Returning here also avoids shifting a 32-bit value
    // by 32, which the old "extract, then shift the excess away" code did.
    if ( take == 0 )
        return 0;

    if ( bs->pos + take <= 32 ) { // case 1: no overlap
        *result = (bs->current_word << bs->pos) >> (32 - take);
    } else { // case 2: overlaps current_word and next_word
        *result = (bs->current_word << bs->pos) >> ( bs->pos );
        *result = *result << (bs->pos + take - 32 );
        *result |= (bs->next_word >> (64 - bs->pos - take));
    }

    return take;
}

/* Advance the stream to the next byte boundary.
 * A stream that is already byte aligned is left where it is (previously a
 * whole byte was skipped in that case). The position never moves past the
 * end of the available data.
 */
void bs_bytealign( bitstream_t *bs )
{
    uint8_t misalign = bs->pos % 8;
    if ( misalign != 0 )
        bs->pos += 8 - misalign;

    // end is always a whole number of bytes, so clamping keeps the alignment
    if ( bs->pos > bs->end )
        bs->pos = bs->end;

    bs_fillbuffer( bs );
}

/* Debugging aid: writes the 32 bits of a uint32_t to stdout as a run of
 * '0' and '1' characters, most significant bit first, no trailing newline.
 */
void bs_printbits( uint32_t bits )
{
    // walk a single-bit mask from bit 31 down to bit 0
    uint32_t mask;
    for (mask = 0x80000000u; mask != 0; mask >>= 1) {
        char digit = '0';
        if (bits & mask)
            digit = '1';
        printf( "%c", digit );
    }
}

#ifdef __cplusplus
}
#endif

#endif

