/*  AOS Tool, decrypt and decompress Archos Gmini firmware files
    Copyright (C) 2004  Michael R. Donat <michael@donat.org>

    This program is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program; if not, write to the Free Software
    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
*/

/*
   To do:
     * determine why about one in three files seems to have errors with the
       checksum and compressed size, though the decompressed file seems correct
     * determine the use of the contents of the SIGN chunk
*/
#include <sys/types.h>
#include <sys/stat.h>
#include <fcntl.h>
#include <unistd.h>
#include <stdio.h>
#include <io.h>
#include <stdlib.h>
#include <string.h>

/* Path separator character; main() passes it to strrpos() to find where the
   directory part of a command-line argument ends. */
#ifdef _MSDOS_
#define DELIM '\\'
#else
#define DELIM '/'
#endif

/* Short integer aliases used throughout this file.
   NOTE(review): u32 is assumed to be 32 bits wide; that holds only where
   unsigned int is 32 bits. */
typedef unsigned int  u32;
typedef unsigned char u8;

/* XOR key applied by decryptXOR(). displayAJZ() calls guessXORkey() to fill
   it in while it is still NULL. */
u8* key;
/* Name of the output file; built in main() for each input and used by
   displayAJZ() when it writes the decompressed image. */
char *filename;

/* Read the whole of `filename` into a freshly allocated buffer.
 *
 * On success the buffer is stored in *buf (the caller owns it and must
 * free() it) and the file size in bytes is returned.  On any failure a
 * message is printed to stderr, *buf is left untouched and (u32)-1 is
 * returned.
 */
u32 loadFile( char* filename, u8 **buf )
{
   int aosfile;
   int length;
   int bytesread = 0;
   u8* buffer;
   struct stat filestats;

   /* open() reports failure with -1; descriptor 0 is a valid result */
   if( (aosfile = open( filename, O_RDONLY | O_BINARY )) < 0 ) {
      fprintf( stderr, "Error opening input file: %s\n", filename );
      return (u32)-1;
   }
   if( fstat( aosfile, &filestats ) ) {
      fprintf( stderr, "Error accessing input file\n" );
      close( aosfile );
      return (u32)-1;
   }
   /* the size must fit the int counters below and stay clear of (u32)-1 */
   if( filestats.st_size < 0 || filestats.st_size > 0x7fffffffL ) {
      fprintf( stderr, "Input file too large: %s\n", filename );
      close( aosfile );
      return (u32)-1;
   }
   length = (int)filestats.st_size;

   /* never request 0 bytes so that NULL always means "out of memory" */
   buffer = (u8*)malloc( length > 0 ? (size_t)length : 1 );
   if( buffer == NULL ) {
      fprintf( stderr, "Out of memory reading input file: %s\n", filename );
      close( aosfile );
      return (u32)-1;
   }

   while( bytesread < length ) {
      long got = read( aosfile, buffer + bytesread, length - bytesread );
      if( got <= 0 ) {
         /* -1 is an error, 0 an unexpected end of file; either would
            otherwise spin forever or move the index backwards */
         fprintf( stderr, "Error reading input file: %s\n", filename );
         free( buffer );
         close( aosfile );
         return (u32)-1;
      }
      bytesread += (int)got;
   }
   close( aosfile );
   *buf = buffer;
   return length;
}

/* Create (or truncate) `filename` and write `length` bytes from `buffer`.
 *
 * Returns 0 when every byte was written and the file closed cleanly,
 * -1 otherwise (a message is printed to stderr).
 */
int writeFile( char* filename, u8* buffer, u32 length ) {
   int binfile;
   u32 written = 0;

   /* open() reports failure with -1; descriptor 0 is a valid result */
   if( (binfile = open( filename, O_WRONLY | O_BINARY | O_CREAT | O_TRUNC,
                                   S_IWRITE | S_IREAD )) < 0 ) {
      fprintf( stderr, "Error opening output file: %s\n", filename );
      return -1;
   }

   /* write() may accept fewer bytes than requested, so keep going until
      everything is out or an error occurs */
   while( written < length ) {
      long put = write( binfile, buffer + written, length - written );
      if( put <= 0 ) {
         fprintf( stderr, "Error writing to file\n" );
         close( binfile );
         return -1;
      }
      written += (u32)put;
   }

   if( close( binfile ) != 0 ) {
      fprintf( stderr, "Error writing to file\n" );
      return -1;
   }
   return 0;
}

/* Copy exactly `len` bytes from `src` into `dest` and append a NUL
 * terminator, so `dest` must have room for len + 1 bytes.
 * Returns `dest`.
 */
char *cstr( char* dest, const char* src, u32 len ) {
   u32 i;
   for( i = 0; i < len; ++i )
      dest[i] = src[i];
   dest[i] = '\0';
   return dest;
}

/* Decode the four bytes at `src` as a little-endian unsigned 32-bit value
 * (src[0] is the least significant byte).
 *
 * Each byte is widened to u32 before shifting: shifting a promoted int
 * left by 24 is undefined once the byte is 0x80 or larger.  The parameter
 * is a const void* so both char* and u8* buffers can be passed directly.
 */
u32 lu32( const void* src ) {
   const u8* p = (const u8*)src;
   return  (u32)p[0]         |
          ((u32)p[1] <<  8 ) |
          ((u32)p[2] << 16 ) |
          ((u32)p[3] << 24 );
}

/* Decode the four bytes at `src` as a big-endian unsigned 32-bit value
 * (src[0] is the most significant byte).
 *
 * Each byte is widened to u32 before shifting: shifting a promoted int
 * left by 24 is undefined once the byte is 0x80 or larger.  The parameter
 * is a const void* so both char* and u8* buffers can be passed directly.
 */
u32 bu32( const void* src ) {
   const u8* p = (const u8*)src;
   return ((u32)p[0] << 24 ) |
          ((u32)p[1] << 16 ) |
          ((u32)p[2] <<  8 ) |
           (u32)p[3];
}

/* Return the sum of the `len` bytes starting at `src`, accumulated in a
 * u32 (so the total wraps on unsigned overflow).
 */
u32 chksum32( u8* src, u32 len ) {
   u32 total = 0;
   u32 i;
   for( i = 0; i < len; ++i )
      total += src[i];
   return total;
}

/* Decompress an LZSS-style stream from `src` (slen bytes) into `dest`,
 * writing at most `dlen` bytes.  Returns `dest`.
 *
 * Stream layout, as decoded here: a flag byte is followed by up to eight
 * tokens, one per flag bit starting with the least significant bit.
 *   bit = 1: one literal byte, copied to the output.
 *   bit = 0: a two-byte back-reference.  Byte 0 plus the high nibble of
 *            byte 1 form a 12-bit position in a 4096-byte window; the low
 *            nibble of byte 1 plus 3 is the copy length.
 * Window position p corresponds to the most recent output index that is
 * congruent to p + 18 modulo 4096.
 *
 * Decoding stops when the output is full, the input is exhausted, or a
 * token is cut short by the end of the input; nothing is read beyond
 * src[slen-1] or written beyond dest[dlen-1].
 */
u8* getText( u8* src, u32 slen, u8* dest, u32 dlen) {
   u32 flags = 0;       /* remaining bits of the current flag byte */
   u32 bitsleft = 0;    /* how many of those bits are still unused */
   u32 si = 0;
   u32 di = 0;
   while( di < dlen && si < slen )
   {
      if( bitsleft == 0 ) {
         flags = src[si++];
         bitsleft = 8;
         if( si >= slen )
            break;               /* flag byte with no token behind it */
      }
      if( flags & 0x01 )
         dest[di++] = src[si++];
      else
      {
         u32 pos, count, distance;
         if( slen - si < 2 )
            break;               /* truncated back-reference */
         pos   = (u32)src[si] | (((u32)src[si+1] & 0xf0) << 4);
         count = ((u32)src[si+1] & 0x0f) + 3;
         si += 2;
         /* Distance back from the write position, 1..4096.  Unsigned
            wrap-around is harmless because 4096 divides 2^32. */
         distance = (di - 18 - pos) & 0x0fff;
         if( distance == 0 )
            distance = 4096;
         /* Byte-by-byte so overlapping (run-length style) copies work.
            NOTE(review): references to data before the start of the
            output are filled with 0; confirm the initial window contents
            against a known-good firmware image. */
         for( ; count > 0 && di < dlen; --count, ++di )
            dest[di] = ( distance <= di ) ? dest[di - distance] : 0;
      }
      flags >>= 1;
      --bitsleft;
   }
   return dest;
}

/* Return the number of bytes the compressed stream `data` (length bytes)
 * expands to, without decompressing it.
 *
 * The stream is walked token by token: each flag byte governs up to eight
 * tokens, least significant bit first.  A set bit is a one-byte literal
 * (adds 1); a clear bit is a two-byte back-reference whose copy length is
 * the low nibble of its second byte plus 3.
 *
 * A token cut short by the end of the data is not counted, and no byte
 * beyond data[length-1] is read.
 */
u32 getByteCount( u8* data, u32 length ) {
   u32 flags = 0;       /* remaining bits of the current flag byte */
   u32 bitsleft = 0;    /* how many of those bits are still unused */
   u32 pos = 0;
   u32 count = 0;
   while( pos < length ) {
      if( bitsleft == 0 ) {
         flags = data[pos++];
         bitsleft = 8;
         if( pos >= length )
            break;               /* flag byte with no token behind it */
      }
      if( flags & 0x01 ) {
         ++count;
         ++pos;
      }
      else {
         if( length - pos < 2 )
            break;               /* truncated back-reference */
         count += (u32)(data[pos + 1] & 0x0f) + 3;
         pos += 2;
      }
      flags >>= 1;
      --bitsleft;
   }
   return count;
}

/* Guess a repeating XOR key of `keylength` bytes from `source`.
 *
 * For every key position i, the bytes source[i], source[i + keylength],
 * ... are examined and the most frequent ASCII letter (A-Z or a-z) among
 * them becomes key[i].  Ties go to the lowest byte value; a position in
 * which no letter occurs yields 0.
 *
 * The key is stored in a newly allocated, NUL-terminated buffer of
 * keylength + 1 bytes which is assigned to the global `key` and also
 * returned.  Returns NULL (and sets `key` to NULL) when allocation fails.
 */
u8* guessXORkey( u8* source, u32 length, u32 keylength ) {
   u32 bytecount[256];
   u32 i;
   u32 c;

   key = (u8*)malloc( (size_t)keylength + 1 );
   if( key == NULL )
      return NULL;
   key[keylength] = '\0';

   for( i = 0; i < keylength; ++i )
   {
      u32 topval = 0;
      u8 topchar = 0;

      memset( bytecount, 0, sizeof bytecount );

      /* gather byte frequency statistics for this key position */
      for( c = i; c < length; c += keylength )
      {
         u8 v = source[c];
         if( (v >= 'A' && v <= 'Z') || (v >= 'a' && v <= 'z') )
            ++bytecount[v];
      }

      /* strict '>' keeps the first (lowest) byte value on a tie */
      for( c = 0; c < 256; ++c )
         if( bytecount[c] > topval )
         {
            topval = bytecount[c];
            topchar = (u8)c;
         }
      key[i] = topchar;
   }
   return key;
}

/* XOR the first `length` bytes of `data` in place with the global `key`,
 * repeating the key every `keylength` bytes.  Returns `data`.
 *
 * When no key is set or `keylength` is 0 the data is left unchanged; the
 * modulo below would otherwise divide by zero.
 */
u8* decryptXOR( u8* data, u32 length, u32 keylength ) {
   u32 i;
   if( key == NULL || keylength == 0 )
      return data;
   for( i = 0; i < length; ++i )
      data[i] ^= key[ i % keylength ];
   return data;
}

/* Report on, decrypt, decompress and write out one AJZ image.
 *
 * `type` is the chunk name used as a prefix on every output line, `ajz`
 * points at the AJZ header and `length` is the length value supplied by
 * the caller for the enclosing chunk.  The header fields read here are:
 * 4-byte signature, then little-endian decompressed size, compressed size
 * and checksum, followed by the payload at offset 16.
 *
 * The payload is decrypted in place with the global `key` (guessed first
 * if none is set), decompressed into a temporary buffer and written to
 * the file named by the global `filename`.
 */
void displayAJZ( char* type, u8* ajz, u32 length ) {
   char head[5];
   u32 csize;
   u32 sumdata;
   u32 keylen;
   u32 actual;
   u8* compressed;
   u8* firmware;

   // "length - 30" below would wrap around for shorter chunks
   if( length < 30 ) {
      fprintf( stderr, "%s: chunk too short (%u bytes) for an AJZ image\n",
               type, length );
      return;
   }

   // AJZ header ID field
   cstr( head, (const char*)ajz, 4 );
   printf( "%s: SIGNATURE: %s\n", type, head );

   // Uncompressed size stored in AJZ header
   printf( "%s: Decompressed Size: %u\n", type, lu32( (char*)ajz + 4 ) );

   // Compressed size stored in AJZ header
   csize = lu32( (char*)ajz + 8 );
   if( csize > length - 30) {
      printf( "%s: Compressed size BUG: %u\n", type, csize );
      csize = length - 30;
   }
   else
      printf( "%s: Compressed size: %u\n", type, csize );

   // Checksum stored in AJZ header
   printf( "%s: Checksum: %8X\n", type, lu32( (char*)ajz + 12) );

   // Checksum of AJZ payload
   sumdata = chksum32( ajz + 16, csize );
   if( lu32( (char*)ajz + 12) != sumdata ) {
         printf( "%s: Checksum of data (%8X) doesn't match header\n", type, sumdata );
   }

   // Decrypt firmware Image
   if( key == NULL )
   {
      key = guessXORkey( ajz + 16, csize, 32 );
      if( key == NULL ) {
         fprintf( stderr, "%s: Out of memory guessing key\n", type );
         return;
      }
      printf( "%s: Guessing key: %s\n", type, key );
   }
   else
      printf( "%s: Using key: %s\n", type, key );

   // An empty key would make decryptXOR divide by zero
   keylen = (u32)strlen( (const char*)key );
   if( keylen == 0 ) {
      fprintf( stderr, "%s: No usable key, skipping image\n", type );
      return;
   }
   compressed = decryptXOR(  ajz + 16, csize, keylen );

   // Decompress firmware Image; zero-filled so that bytes the decoder
   // does not reach are well defined, and never a zero-byte request
   actual = getByteCount( compressed, csize );
   firmware = (u8*)calloc( actual ? actual : 1, 1 );
   if( firmware == NULL ) {
      fprintf( stderr, "%s: Out of memory decompressing image\n", type );
      return;
   }
   getText( compressed, csize, firmware, actual );

   // Write unencrypted/decompressed file to disk
   printf( "%s: Writing decompressed firmware to %s\n", type, filename );
   writeFile( filename, firmware, actual );
   free( firmware );
}

/* Walk the chunks stored in the `leng` bytes at `chunk` and print what is
 * known about each one.
 *
 * Every chunk starts with a 4-character type and a big-endian 32-bit
 * length.  RIFF and LIST chunks are containers and are walked
 * recursively; CCOD chunks are handed to displayAJZ().  The walk advances
 * by the chunk's length field, or by 8 when that field is 0.
 *
 * `name` is accepted for symmetry with the recursive call but not used.
 * Lengths read from the data are clamped to the bytes that are actually
 * left, so a damaged file cannot make the walk read outside the buffer.
 */
void displayChunk( char* name, u8 *chunk, u32 leng ) {
   char type[5];
   u32 pos = 0;
   (void)name;
   while( pos < leng && leng - pos > 8 ) {
      u8 *cur = chunk + pos;
      u32 remaining = leng - pos;   /* bytes from cur to the end */
      u32 cleng;
      u32 step;
      cstr( type, (const char*)cur, 4 );
      cleng = bu32( (char*)cur + 4 );
      if( strcmp( type, "RIFF" ) == 0 || strcmp( type, "LIST" ) == 0 ) {
         u32 inner = remaining - 8;
         displayChunk( type, cur + 8, cleng < inner ? cleng : inner );
      }
      else if( strcmp( type, "SWID" ) == 0 ) {
         if( remaining >= 12 ) {
            char vid[5];
            cstr( vid, (const char*)cur + 8, 4 );
            printf( "%s: Version ID %s\n", type, vid);
            /* precision bounds the read when no NUL terminator follows */
            printf( "%s: Version string: %.*s\n", type,
                    (int)(remaining - 12), (const char*)cur + 12 );
         }
         else
            fprintf( stderr, "%s: truncated chunk\n", type );
      }
      else if( strcmp( type, "FLSH" ) == 0 ) {
         if( remaining >= 16 ) {
            if( bu32( (char*)cur + 8 ) != 24576 )
               printf( "%s: %u != 24576\n", type, bu32( (char*)cur + 8 ) );
            printf( "%s: Decompressed size: %u\n", type, bu32( (char*)cur + 12 ) );
         }
         else
            fprintf( stderr, "%s: truncated chunk\n", type );
      }
      else if( strcmp( type, "HWID" ) == 0 ) {
         if( remaining >= 12 ) {
            if( bu32( (char*)cur + 8 ) != 0 )
               printf( "%s: %u != 0\n", type, bu32( (char*)cur + 8 ) );
         }
         else
            fprintf( stderr, "%s: truncated chunk\n", type );
      }
      else if( strcmp( type, "CCOD" ) == 0 ) {
         /* displayAJZ reads a 16-byte header at cur + 12 and subtracts 30
            from the length it is given, so anything shorter is unusable */
         u32 avail = cleng < remaining ? cleng : remaining;
         if( avail >= 30 )
            displayAJZ( type, cur + 12, avail );
         else
            fprintf( stderr, "%s: truncated chunk\n", type );
      }
      else if( strcmp( type, "CHOS" ) == 0 )
         {if( cleng != 0 ) printf( "%s: length %u != 0\n", type, cleng );}
      else if( strcmp( type, "SIGN" ) == 0 )
         {if( cleng != 108 ) printf( "%s: length %u != 108\n", type, cleng );}
      else
         printf( "%s: length: %u\n", type, cleng );

      step = (cleng == 0) ? 8 : cleng;
      if( step >= remaining )
         break;                     /* next chunk would start past the end */
      pos += step;
   }
}

/* Load the firmware file `file`, print its size and walk its chunks.
 *
 * Returns 1 when the file was loaded and processed, -1 when it could not
 * be loaded or is empty.  The file buffer is released before returning.
 */
int processAOS( char* file ) {
   u8 *aos = NULL;
   u32 length = loadFile( file, &aos );

   /* loadFile reports failure as -1 converted to u32, not as 0 */
   if( length == 0 || length == (u32)-1 ) {
      free( aos );
      return -1;
   }
   printf( "File: %s (%u bytes)\n", file, length );
   displayChunk( file, aos, length );
   free( aos );
   return 1;
}

/* Scan `str` backwards, starting at its terminating NUL, for the character
 * `c` and return the index where it is found.  Index 0 is never examined:
 * 0 is returned both when `c` is absent and when it occurs only at the
 * very start.
 */
int strrpos( char* str, char c ) {
   int idx;
   for( idx = strlen( str ); idx > 0; --idx )
      if( str[idx] == c )
         break;
   return idx;
}

/* Entry point: process every firmware file named on the command line.
 *
 * For each argument the output name is derived from the file name without
 * its directory part: the last three characters are replaced by "o", so
 * "dir/x.aos" is written to "x.o".  Names shorter than three characters
 * get ".o" appended instead.  The name is stored in the global `filename`
 * while that file is processed.
 *
 * Returns 0, or 1 when memory for an output name cannot be allocated.
 */
int main( int argc, char** argv ) {
   int n;
   if( argc < 2 ) {
      const char *prog = ( argc > 0 && argv[0] != NULL ) ? argv[0] : "aostool";
      const char *slash = strrchr( prog, DELIM );
      printf( "Usage: %s file.aos [file2.aos ...]\n", slash ? slash + 1 : prog );
      return 0;
   }
   for( n = 1; n < argc; ++n ) {
      /* strrchr distinguishes "no delimiter" from "delimiter at index 0",
         so a bare file name keeps its first character */
      const char *slash = strrchr( argv[n], DELIM );
      const char *base = slash ? slash + 1 : argv[n];
      size_t baselen = strlen( base );
      size_t keep = ( baselen >= 3 ) ? baselen - 3 : baselen;

      /* room for the kept part, at most ".o", and the terminator */
      filename = (char*)malloc( keep + 3 );
      if( filename == NULL ) {
         fprintf( stderr, "Out of memory\n" );
         return 1;
      }
      memcpy( filename, base, keep );
      strcpy( filename + keep, ( baselen >= 3 ) ? "o" : ".o" );

      processAOS( argv[n] );

      free( filename );
      filename = NULL;
   }
   return 0;
}

