#include <assert.h>
#include <fec.h>
#include <stddef.h>
#include <endian.h>
#include <string.h>
#include <float.h>

#include "scf.h"
#include "scf_fec.h"

/*
 * Shared Reed-Solomon codec handle. It is created once by scf_fec_init() and
 * then used by scf_fec_encode() and scf_fec_decode(). Objects with static
 * storage duration start out as a null pointer, so no initializer is needed.
 */
static void *rs_code;

/*
 * Compute a CRC-24 over a byte buffer.
 *
 * Parameters: generator 0x864CFB (0x1864CFB with the implicit x^24 term),
 * initial value 0xB704CE, MSB-first (no reflection), no final XOR.
 * These are the CRC-24 parameters of RFC 4880, section 6.1.
 *
 * input: bytes to checksum; they are only read.
 * len:   number of bytes at input (0 yields the initial value).
 *
 * Returns the 24-bit CRC in the low 24 bits of the result.
 */
static uint32_t crc24(const uint8_t *input, size_t len)
{
    uint32_t crc = 0xB704CE;

    while (len--) {
        /* Widen before shifting so the arithmetic stays unsigned 32-bit. */
        crc ^= (uint32_t) (*input++) << 16;
        for (size_t i = 0; i < 8; i++) {
            crc <<= 1;
            if (crc & 0x1000000) {
                /* The x^24 term overflowed: drop it and reduce by the generator. */
                crc = (crc & 0xFFFFFF) ^ 0x864CFB;
            }
        }
    }
    return crc & 0xFFFFFF;
}

/*
 * Create the shared Reed-Solomon codec if it does not exist yet.
 * Calling this more than once is harmless.
 *
 * init_rs_char() arguments:
 * - 8-bit symbols;
 * - field generator polynomial 0x11d;
 * - first consecutive root 1 and primitive element 1;
 * - SCF_PARITY_LEN parity symbols;
 * - 255 - SCF_PKT_LEN pad symbols, which shortens the code to an
 *   SCF_PKT_LEN-byte block.
 *
 * A failed allocation is only caught by the assert, so it goes undetected
 * in NDEBUG builds.
 */
void scf_fec_init(void)
{
    if (rs_code != NULL) {
        return;
    }

    rs_code = init_rs_char(8, 0x11d, 1, 1, SCF_PARITY_LEN, 255 - SCF_PKT_LEN);
    assert(rs_code != NULL);
}

/*
 * Build a complete SCF packet from a payload.
 *
 * packet:  output buffer, interpreted as a struct scf_packet.
 * message: SCF_MSG_LEN payload bytes, copied into packet->message.
 *
 * Steps:
 * 1. Copy the payload into the packet.
 * 2. Store its CRC-24 in packet->crc, most significant byte first.
 * 3. Have encode_rs_char() write the Reed-Solomon parity into
 *    packet->parity. It covers the SCF_PKT_LEN - SCF_PARITY_LEN data bytes
 *    starting at packet->message.
 *
 * NOTE(review): step 3 assumes message and crc are laid out back to back in
 * struct scf_packet, directly followed by parity; confirm in scf.h.
 * rs_code is only set by scf_fec_init(), which must have run first.
 */
void scf_fec_encode(uint8_t *packet, uint8_t *message)
{
    struct scf_packet *const pkt = (struct scf_packet *) packet;

    memcpy(pkt->message, message, SCF_MSG_LEN);

    const uint32_t checksum = crc24(pkt->message, SCF_MSG_LEN);
    pkt->crc[0] = (uint8_t) (checksum >> 16);
    pkt->crc[1] = (uint8_t) (checksum >> 8);
    pkt->crc[2] = (uint8_t) checksum;

    encode_rs_char(rs_code, pkt->message, pkt->parity);
}

bool scf_fec_decode(uint8_t *message, struct scf_soft_symbol *packet)
{
    struct scf_packet p;
    uint8_t *pb = (void*) &p;

    for (size_t i = 0; i < SCF_PKT_LEN; i++) {
        pb[i] = packet[i].symbol;
    }

    int eras_pos[SCF_PARITY_LEN];
    int eras_no_max = SCF_PARITY_LEN;
    int eras_mark[SCF_PKT_LEN] = {0};

    for (size_t i = 0; i < eras_no_max; i++) {
        float min_weight = FLT_MAX;
        int min_pos = 0;

        for (size_t j = 0; j < SCF_PKT_LEN; j++) {
            if (!eras_mark[j] && packet[j].weight < min_weight) {
                min_weight = packet[j].weight;
                min_pos = j;
            }
        }

        eras_mark[min_pos] = 1;
        eras_pos[i] = min_pos;
    }

    for (size_t eras_no = 0; eras_no <= eras_no_max; eras_no += 2) {
        int eras_pos_tmp[SCF_PARITY_LEN];
        memcpy(eras_pos_tmp, eras_pos, sizeof(eras_pos_tmp));

        int rs_symbol_error_count = decode_rs_char(rs_code, pb, eras_pos_tmp, eras_no);

        if (rs_symbol_error_count >= 0) {
            uint32_t crc = crc24(p.message, SCF_MSG_LEN);
            uint32_t crc_received = (p.crc[0] << 16) | (p.crc[1] << 8) | p.crc[2];
            if (crc == crc_received) {
                memcpy(message, p.message, SCF_MSG_LEN);
                return true;
            }
        }
    }

    return false;
}
