#include <complex.h>
#include <fftw3.h>
#include <assert.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <stdio.h>

#include "scf.h"
#include "scf_code.h"
#include "scf_filter.h"
#include "scf_rx.h"

/* Decimation factor: shift_input_chips keeps every DEC_RATIO-th sample of
 * the filtered baseband. */
#define DEC_RATIO 4
/* Number of rx_worker instances (timing hypotheses); worker i reads its
 * CHIP_LEN-sample window i * CHIP_LEN / CHIP_PHASES samples earlier in
 * input_chips than worker 0 (see scf_rx_init). */
#define CHIP_PHASES 8
/* Number of rx_chain instances per worker (frequency hypotheses); chain k
 * uses row k of shifter_wavetable. */
#define CHIP_FREQS 9
/* Frequency-shift resolution: adjacent shifter_wavetable rows differ by
 * SCF_SRATE / (FREQ_STEP_PER_BIN * SCF_CHIP_LEN), i.e. 1/FREQ_STEP_PER_BIN of
 * the FFT bin spacing (assuming SCF_CHIP_LEN is a multiple of DEC_RATIO). */
#define FREQ_STEP_PER_BIN 4

/* Decimated complex samples per chip; also the FFT length. */
#define CHIP_LEN (SCF_CHIP_LEN / DEC_RATIO)
/* Symbol slots per chip phase in rx_chain.tracker; the slot that gets decided
 * has SCF_RX_LATENCY slots before and after it (see update_decoded_symbol). */
#define TRACKER_LEN (2 * SCF_RX_LATENCY + 1)

/* Decoder state for one (timing, frequency) hypothesis. Each rx_worker owns
 * CHIP_FREQS of these, one per shifter_wavetable row. */
struct rx_chain {
    /* Power spectrum history: row = chip slot (ring-indexed by waterfall_idx),
     * column = one of SCF_FREQS FFT bins (written by update_waterfall). */
    float waterfall[SCF_CHIPS][SCF_FREQS];
    /* Ring buffer of per-chip best-symbol candidates, indexed by tracker_idx
     * (written by update_tracker). */
    struct scf_soft_symbol tracker[SCF_CHIPS * TRACKER_LEN];
    /* Symbol chosen by the most recent update_decoded_symbol call. */
    struct scf_soft_symbol decoded_symbol;
    /* Summed tracker weight of the winning chip phase; scf_rx_chip compares
     * this across chains to pick the overall result. */
    float decoded_group_weight;
    /* Winning chip phase, in 0 .. SCF_CHIPS-1. */
    size_t decoded_phase;
};

/* One timing hypothesis: a fixed window into input_chips, an FFT scratch
 * buffer and CHIP_FREQS frequency-hypothesis chains. */
struct rx_worker {
    /* Start of this worker's CHIP_LEN-sample window inside input_chips. */
    complex float *source;
    /* CHIP_LEN-element FFT buffer from fftwf_alloc_complex, transformed in place. */
    fftwf_complex *sa_fft_buf;
    struct rx_chain rx_chains[CHIP_FREQS];
    /* NOTE(review): not referenced anywhere in the code shown here; the decoded
     * result lives in rx_chain.decoded_symbol. Confirm whether this is still needed. */
    struct scf_soft_symbol *decoded_symbol;
};

/* CHIP_PHASES workers, allocated once by the first scf_rx_init call. */
static struct rx_worker *rx_workers;
/* waterfall row written for the current chip; advances modulo SCF_CHIPS. */
static unsigned int waterfall_idx;
/* tracker slot written for the current chip; advances modulo
 * SCF_CHIPS * TRACKER_LEN. */
static unsigned int tracker_idx;
/* Chips left until scf_rx_chip makes the next symbol decision. */
static unsigned int chip_cnt;
/* Shared in-place forward FFT plan of length CHIP_LEN, executed on each
 * worker's own buffer via fftwf_execute_dft. */
static fftwf_plan sa_fft;
/* Two consecutive decimated chips: the previous one in [0, CHIP_LEN), the
 * newest one in [CHIP_LEN, 2 * CHIP_LEN). */
static complex float input_chips[CHIP_LEN * 2];
/* Carrier frequency set by scf_rx_init; presumably in Hz since it is divided
 * by SCF_SRATE in shift_input_chips — confirm. */
static float carrier_freq;
/* Running local-oscillator phase (radians) used by shift_input_chips. */
static float carrier_phase;
/* State buffer handed to scf_filter_rx on every call (presumably the filter
 * history carried between chips). */
static complex float fir_tail[SCF_FIR_LEN_RX];
/* Precomputed frequency-shift oscillators: one row of CHIP_LEN samples per
 * chain index (filled by shifter_wavetable_init). */
static complex float shifter_wavetable[CHIP_FREQS][CHIP_LEN];

static void shift_input_chips(float *chip)
{
    memcpy(&input_chips[0], &input_chips[CHIP_LEN], CHIP_LEN * sizeof(input_chips[0]));

    complex float baseband[SCF_CHIP_LEN];
    complex float baseband_filtered[SCF_CHIP_LEN];
    for (size_t i = 0; i < SCF_CHIP_LEN; i++) {
        float carrier_i = sinf(carrier_phase);
        float carrier_q = cosf(carrier_phase);
        baseband[i] = chip[i] * carrier_i + chip[i] * I * carrier_q;

        carrier_phase += 2.0f * M_PI * carrier_freq * (1.0f / (float) SCF_SRATE);
        while (carrier_phase > 2.0f * M_PI) {
            carrier_phase -= 2.0f * M_PI;
        }
    }
    scf_filter_rx(baseband_filtered, baseband, fir_tail);
    for (size_t i = 0; i < CHIP_LEN; i++) {
        input_chips[CHIP_LEN + i] = baseband_filtered[i * DEC_RATIO];
    }
}

static void update_waterfall(struct rx_worker *w, struct rx_chain *c, size_t freq_i)
{
    for (size_t i = 0; i < CHIP_LEN; i++) {
        w->sa_fft_buf[i] = w->source[i] * shifter_wavetable[freq_i][i];
    }

    fftwf_execute_dft(sa_fft, w->sa_fft_buf, w->sa_fft_buf);

    for (size_t i = 0; i < SCF_FREQS; i++) {
        size_t bin_idx = (i - SCF_FREQSPAN + CHIP_LEN) % CHIP_LEN;
        complex float bin = w->sa_fft_buf[bin_idx];
        float power = crealf(bin) * crealf(bin) + cimagf(bin) * cimagf(bin);
        c->waterfall[waterfall_idx][i] = power;
    }
}

/*
 * Correlate chain c's waterfall against every symbol's hopping pattern and
 * record the strongest symbol for this chip in c->tracker[tracker_idx].
 *
 * SCF_SYMBOL_M * SCF_CHIPS consecutive entries of scf_code_linear, starting at
 * scf_code_linear[waterfall_idx][0], are read SCF_CHIPS at a time as flat
 * indices into c->waterfall. The powers at those cells are summed into each
 * symbol's weight. The stored weight is the winner's weight divided by the sum
 * over all symbols.
 *
 * If the total is not positive (for example, all-zero input), the stored weight
 * is 0 instead of the NaN that 0/0 would give. A NaN would stay in the tracker
 * and turn every group weight that includes it into NaN.
 */
static void update_tracker(struct rx_chain *c)
{
    const int *code_linear_ptr = &scf_code_linear[waterfall_idx][0];
    const float *waterfall_linear = &c->waterfall[0][0];
    float max_weight = 0.0f;
    float sum_weight = 0.0f;
    uint8_t max_symbol = 0;

    for (size_t s = 0; s < SCF_SYMBOL_M; s++) {
        /* Total power along symbol s's hopping pattern. */
        float weight = 0.0f;
        for (size_t t = 0; t < SCF_CHIPS; t++) {
            size_t waterfall_linear_idx = *code_linear_ptr++;
            weight += waterfall_linear[waterfall_linear_idx];
        }
        sum_weight += weight;

        if (weight > max_weight) {
            max_weight = weight;
            max_symbol = s;
        }
    }

    c->tracker[tracker_idx].weight = (sum_weight > 0.0f) ? max_weight / sum_weight : 0.0f;
    c->tracker[tracker_idx].symbol = max_symbol;
}

static void update_decoded_symbol(struct rx_chain *c)
{
    float max_group_weight = 0.0f;
    struct scf_soft_symbol *max_symbol = NULL;
    size_t max_phase = 0;

    for (size_t phase = 0; phase < SCF_CHIPS; phase++) {
        size_t tracker_pos = (tracker_idx + 1 + phase + SCF_RX_LATENCY * SCF_CHIPS) % (SCF_CHIPS * TRACKER_LEN);
        float group_weight = 0.0f;
        for (size_t ahead_cnt = 0; ahead_cnt < TRACKER_LEN; ahead_cnt++) {
            size_t ahead_pos = (tracker_pos + ahead_cnt * SCF_CHIPS) % (SCF_CHIPS * TRACKER_LEN);
            group_weight += c->tracker[ahead_pos].weight;
        }

        if (group_weight > max_group_weight) {
            max_group_weight = group_weight;
            max_phase = phase;
            max_symbol = &c->tracker[tracker_pos];
        }
    }

    c->decoded_symbol = *max_symbol;
    c->decoded_group_weight = max_group_weight;
    c->decoded_phase = max_phase;
}

static void rx_worker_process_chip(struct rx_worker *w)
{
    for (size_t i = 0; i < CHIP_FREQS; i++) {
        struct rx_chain *c = &w->rx_chains[i];
        update_waterfall(w, c, i);
        update_tracker(c);
        if (chip_cnt == 1) {
            update_decoded_symbol(c);
        }
    }
}

/*
 * Fill shifter_wavetable with CHIP_FREQS oscillators of CHIP_LEN samples each,
 * sampled at the decimated rate SCF_SRATE / DEC_RATIO.
 *
 * Row k ramps its phase at (k - CHIP_FREQS / 2) * freq_step, with
 * freq_step = SCF_SRATE / (FREQ_STEP_PER_BIN * SCF_CHIP_LEN), so the rows are
 * centred on zero shift. Each sample is sin(phase) + i*cos(phase), the same
 * form shift_input_chips uses. The phase is kept within [-2*pi, 2*pi].
 */
static void shifter_wavetable_init(void)
{
    const float freq_step = 1.0f / ((float) FREQ_STEP_PER_BIN * SCF_CHIP_LEN / SCF_SRATE);
    const int centre = CHIP_FREQS / 2;

    for (int row = 0; row < CHIP_FREQS; ++row) {
        const float shift_freq = freq_step * (row - centre);
        /* Phase advance per decimated sample (double, because M_PI is double). */
        const double step = 2.0f * M_PI * shift_freq * (DEC_RATIO / (float) SCF_SRATE);
        complex float *table = shifter_wavetable[row];
        float phase = 0.0f;

        for (size_t n = 0; n < CHIP_LEN; ++n) {
            table[n] = sinf(phase) + cosf(phase) * I;

            phase += step;
            while (phase > 2.0f * M_PI) {
                phase -= 2.0f * M_PI;
            }
            while (phase < -2.0f * M_PI) {
                phase += 2.0f * M_PI;
            }
        }
    }
}

/*
 * Abort with a diagnostic if an initialisation step returned NULL. This is used
 * instead of assert() so the check also runs in NDEBUG builds.
 */
static void rx_init_require(const void *ptr, const char *what)
{
    if (!ptr) {
        fprintf(stderr, "scf_rx_init: %s failed\n", what);
        abort();
    }
}

/*
 * Set the receiver carrier frequency (stored in carrier_freq).
 *
 * The first call also allocates the CHIP_PHASES workers and their FFT buffers,
 * creates the shared FFT plan, resets the ring indices and the chip counter,
 * and fills shifter_wavetable. Later calls update only the carrier frequency;
 * the rest of the receiver state is left as is.
 *
 * Allocation or plan-creation failure aborts the process.
 */
void scf_rx_init(float freq)
{
    carrier_freq = freq;

    if (rx_workers) {
        return;
    }

    rx_workers = calloc(CHIP_PHASES, sizeof(*rx_workers));
    rx_init_require(rx_workers, "calloc(rx_workers)");

    for (size_t i = 0; i < CHIP_PHASES; i++) {
        struct rx_worker *w = &rx_workers[i];

        /* Worker i's window starts i * CHIP_LEN / CHIP_PHASES samples before
         * the newest chip, reaching back into the previous one. */
        w->source = &input_chips[CHIP_LEN - i * CHIP_LEN / CHIP_PHASES];
        w->sa_fft_buf = fftwf_alloc_complex(CHIP_LEN);
        rx_init_require(w->sa_fft_buf, "fftwf_alloc_complex");
    }

    /* One in-place plan shared by all workers; each worker's buffer comes from
     * fftwf_alloc_complex, so it has the alignment the plan expects. */
    sa_fft = fftwf_plan_dft_1d(
        CHIP_LEN,
        rx_workers[0].sa_fft_buf,
        rx_workers[0].sa_fft_buf,
        FFTW_FORWARD,
        FFTW_ESTIMATE
    );
    rx_init_require(sa_fft, "fftwf_plan_dft_1d");

    waterfall_idx = 0;
    tracker_idx = 0;
    chip_cnt = SCF_CHIPS;
    shifter_wavetable_init();
}

bool scf_rx_chip(struct scf_soft_symbol *symbol, float *chip)
{
    shift_input_chips(chip);

    for (size_t i = 0; i < CHIP_PHASES; i++) {
        struct rx_worker *w = &rx_workers[i];
        rx_worker_process_chip(w);
    }

    waterfall_idx = (waterfall_idx + 1) % SCF_CHIPS;
    tracker_idx = (tracker_idx + 1) % (SCF_CHIPS * TRACKER_LEN);
    chip_cnt--;

    if (!chip_cnt) {
        float max_group_weight = 0.0f;
        size_t best_phase = 0;

        for (size_t i = 0; i < CHIP_PHASES; i++) {
            struct rx_worker *w = &rx_workers[i];
            for (size_t i = 0; i < CHIP_FREQS; i++) {
                struct rx_chain *c = &w->rx_chains[i];
                if (c->decoded_group_weight > max_group_weight) {
                    max_group_weight = c->decoded_group_weight;
                    *symbol = c->decoded_symbol;
                    best_phase = c->decoded_phase;
                }
            }
        }

        if (best_phase > SCF_CHIPS / 2) {
            chip_cnt = SCF_CHIPS + 1;
        } else if (best_phase < SCF_CHIPS / 2) {
            chip_cnt = SCF_CHIPS - 1;
        } else {
            chip_cnt = SCF_CHIPS;
        }
        return true;
    } else {
        return false;
    }
}
