275 lines
12 KiB
Python
275 lines
12 KiB
Python
"""SPI Mode 3 byte-oriented slave for the EXI bus.

CPOL=1, CPHA=1: CLK idles HIGH.
Slave samples MOSI on the FALLING CLK edge.
Slave drives MISO on the RISING CLK edge (master samples on next falling edge).
All three raw inputs are run through a 2-stage FFSynchronizer before use.
"""
|
||
|
||
from amaranth import *
|
||
from amaranth.lib.cdc import FFSynchronizer
|
||
|
||
|
||
# ── public API: names exported by a star-import of this module ──────────────
|
||
__all__ = ["SPIMode3Slave"]
|
||
|
||
|
||
class SPIMode3Slave(Elaboratable):
    """Byte-oriented SPI Mode 3 slave.

    The bit engine oversamples the SPI pins in the clock domain named by
    ``domain``: the three inputs are synchronized, edges are detected by
    comparing each synchronized level with its value one cycle earlier, and
    all state advances on those detected edges.

    Parameters
    ----------
    domain : str
        Name of the clock domain every register in this module is clocked
        by (default ``"capture"``).

    Ports
    -----
    spi_clk / spi_mosi / spi_cs_n : raw async inputs from GC (synchronized internally)
    spi_miso    : combinational output to GC; HIGH while CS is deasserted
    rx_byte     : last complete received byte (valid when rx_valid pulses)
    rx_valid    : 1-cycle registered pulse (in ``domain``) when rx_byte holds a new byte
    tx_byte     : byte to transmit; its MSB is driven live on MISO while the
                  engine is armed and the whole byte is latched on the first
                  falling CLK edge of each byte
    tx_load     : 1-cycle registered pulse (in ``domain``), raised together with
                  rx_valid, requesting the next TX byte from upstream
    frame_start : 1-cycle combinational pulse when CS assertion is detected
    cs_active   : level, high while the synchronized CS is asserted
    """

    def __init__(self, domain="capture"):
        # Clock domain this byte engine runs in.  Split-domain design puts the
        # bit engine in a fast `capture` domain (54 MHz) so it can oversample
        # a 27 MHz EXI clock; the register file lives in a slower domain.
        self._domain = domain

        # Raw asynchronous inputs.
        self.spi_clk = Signal(init=1)   # idles HIGH
        self.spi_mosi = Signal()
        self.spi_cs_n = Signal(init=1)  # active LOW

        self.spi_miso = Signal()        # combinatorial output

        # Byte-level interface towards upstream logic.
        self.rx_byte = Signal(8)
        self.rx_valid = Signal()
        self.tx_byte = Signal(8)
        self.tx_load = Signal()

        # 1-cycle pulse on CS assertion (transaction start).  The capture
        # wrapper uses it to reset its per-transaction TX byte counter.
        self.frame_start = Signal()

        # Level: high while CS is asserted (a transaction is in progress).
        # Lets downstream logic detect variable-length (DMA) transaction ends.
        self.cs_active = Signal()

    def elaborate(self, platform):
        """Build the synchronizers, edge detectors and the RX/TX bit engine."""
        m = Module()
        d = self._domain

        # ── Input synchronization (async → `d`, 2 stages) ──────────────────
        # CLK and CS idle HIGH, so their synchronizers reset to 1 to avoid a
        # spurious edge right after reset.
        clk_s = Signal(init=1)
        mosi_s = Signal()
        cs_s = Signal(init=1)

        m.submodules.sync_clk = FFSynchronizer(self.spi_clk, clk_s, o_domain=d, init=1)
        m.submodules.sync_mosi = FFSynchronizer(self.spi_mosi, mosi_s, o_domain=d)
        m.submodules.sync_cs = FFSynchronizer(self.spi_cs_n, cs_s, o_domain=d, init=1)

        # ── Edge detection ──────────────────────────────────────────────────
        # One extra register per signal; an edge is "level now != level last
        # cycle", so each edge strobe is high for exactly one cycle of `d`.
        clk_prev = Signal(init=1)
        cs_prev = Signal(init=1)
        m.d[d] += clk_prev.eq(clk_s)
        m.d[d] += cs_prev.eq(cs_s)

        falling_clk = Signal()
        rising_clk = Signal()
        cs_fall = Signal()
        m.d.comb += falling_clk.eq(~clk_s & clk_prev)
        m.d.comb += rising_clk.eq(clk_s & ~clk_prev)
        m.d.comb += cs_fall.eq(~cs_s & cs_prev)
        m.d.comb += self.frame_start.eq(cs_fall)
        m.d.comb += self.cs_active.eq(~cs_s)

        # ── Shift registers ─────────────────────────────────────────────────
        rx_shift = Signal(8)
        tx_shift = Signal(8)
        bit_ctr = Signal(4)       # counts 0..7; 7 means "8th (last) bit"
        armed = Signal(init=1)    # between bytes: drive the LIVE tx_byte MSB
        rearm = Signal()          # arm for next byte on the next rising edge

        # MISO: idle HIGH when CS deasserted.  While "armed" — i.e. at the start
        # of a byte, including the inter-byte / clock-idle gap before the first
        # falling edge — drive the LIVE tx_byte MSB.  This is what lets a
        # response that upstream pushes DURING the EXI clock-idle gap reach MISO
        # in time: there is no clock edge during the gap to latch it, so MISO
        # must be combinational on tx_byte until the byte actually starts.  Once
        # shifting (after the first falling edge) drive the latched shift reg.
        m.d.comb += self.spi_miso.eq(
            Mux(cs_s, 1, Mux(armed, self.tx_byte[7], tx_shift[7]))
        )

        # Default: deassert single-cycle pulses every cycle.  Later assignments
        # in the same cycle take priority, which is what produces the pulse.
        m.d[d] += self.rx_valid.eq(0)
        m.d[d] += self.tx_load.eq(0)

        with m.If(cs_fall):
            # Transaction start: first byte drives its MSB live (armed).
            m.d[d] += bit_ctr.eq(0)
            m.d[d] += armed.eq(1)
            # Drop any byte-boundary request left over from a transaction that
            # ended before its final rising CLK edge was observed; otherwise
            # the first rising edge of this transaction would re-arm mid-byte.
            m.d[d] += rearm.eq(0)

        with m.Elif(cs_s):
            # CS deasserted / idle: reset state.  (A detected CS rising edge
            # implies cs_s is high, so testing cs_s alone covers it.)
            m.d[d] += bit_ctr.eq(0)
            m.d[d] += armed.eq(1)
            m.d[d] += rearm.eq(0)

        with m.Else():
            # CS asserted: run bit engine
            with m.If(falling_clk):
                # Sample MOSI (MSB first: left-shift, new bit enters at LSB)
                # Cat(a, b) → a at lower bits; so Cat(mosi, rx[6:0]) = {rx[6:0], mosi}
                m.d[d] += rx_shift.eq(Cat(mosi_s, rx_shift[:-1]))

                with m.If(armed):
                    # First falling edge of this byte: master has just sampled
                    # the MSB (driven live above).  Latch tx_byte so the
                    # remaining 7 bits shift out of a stable register.
                    m.d[d] += tx_shift.eq(self.tx_byte)
                    m.d[d] += armed.eq(0)

                with m.If(bit_ctr == 7):
                    # 8th falling edge: byte complete.  The master samples the
                    # LSB on THIS edge, so MISO must still hold tx_shift[7].
                    # Defer arming to the next rising edge (rearm) so MISO is
                    # not switched to the next byte's live MSB too early.
                    m.d[d] += self.rx_byte.eq(Cat(mosi_s, rx_shift[:-1]))
                    m.d[d] += self.rx_valid.eq(1)
                    m.d[d] += bit_ctr.eq(0)
                    m.d[d] += self.tx_load.eq(1)   # advance source to next byte
                    m.d[d] += rearm.eq(1)          # arm on the next rising edge
                with m.Else():
                    m.d[d] += bit_ctr.eq(bit_ctr + 1)

            with m.If(rising_clk):
                with m.If(rearm):
                    # Byte boundary: arm for the next byte (live MSB drive).
                    m.d[d] += armed.eq(1)
                    m.d[d] += rearm.eq(0)
                with m.Elif(~armed):
                    # Shift left: next bit into MSB position
                    # Cat(0, tx[6:0]) = {tx[6:0], 0} — left shift
                    m.d[d] += tx_shift.eq(Cat(0, tx_shift[:-1]))

        return m
|
||
|
||
|
||
# ── Testbench ───────────────────────────────────────────────────────────────
|
||
|
||
if __name__ == "__main__":
    # Self-checking simulation of SPIMode3Slave.  Writes SPIMode3Slave.vcd and
    # exits with status 1 if any check fails.
    from amaranth.sim import Simulator, Period

    dut = SPIMode3Slave()

    # Capture-domain ticks per SPI half-period.  An edge takes 3 cycles to act
    # on (2 synchronizer stages + 1 registered update), so 4 leaves one cycle
    # of margin before the testbench samples MISO.
    HALF = 4

    async def spi_send_byte(ctx, mosi_val, next_tx_byte=None):
        """Drive one SPI Mode 3 byte on MOSI; return the MISO byte assembled.

        Bits are sent MSB first.  MISO is sampled HALF ticks after each
        falling edge is driven, i.e. after the DUT has processed that edge.

        next_tx_byte: if given, written to tx_byte after the LAST falling edge
        has been processed and before the last rising edge is driven.  The DUT
        re-arms on that rising edge and from then on drives tx_byte's MSB live,
        latching the whole byte on the next byte's first falling edge.
        """
        miso_byte = 0
        for bit in range(7, -1, -1):
            ctx.set(dut.spi_mosi, (mosi_val >> bit) & 1)
            ctx.set(dut.spi_clk, 0)          # falling edge
            await ctx.tick("capture").repeat(HALF)
            miso_byte = (miso_byte << 1) | ctx.get(dut.spi_miso)
            # Set next TX byte here — after last fall, before rising edge.
            # The current byte was latched into the DUT's shift register on
            # its first falling edge, so changing tx_byte now cannot disturb
            # the bit just sampled.
            if bit == 0 and next_tx_byte is not None:
                ctx.set(dut.tx_byte, next_tx_byte)
            ctx.set(dut.spi_clk, 1)          # rising edge
            await ctx.tick("capture").repeat(HALF)
        return miso_byte

    errors = []

    async def testbench(ctx):
        # ── Test 1: Single byte TX/RX ──────────────────────────────────────
        ctx.set(dut.spi_cs_n, 0)
        ctx.set(dut.spi_clk, 1)
        ctx.set(dut.tx_byte, 0xA5)          # pre-load before CS fall is detected
        await ctx.tick("capture").repeat(HALF)

        miso = await spi_send_byte(ctx, 0x37)
        await ctx.tick("capture").repeat(2)
        rx = ctx.get(dut.rx_byte)

        ctx.set(dut.spi_cs_n, 1)
        await ctx.tick("capture").repeat(HALF)

        if rx != 0x37:
            errors.append(f"Test1 rx_byte: expected 0x37, got 0x{rx:02X}")
        if miso != 0xA5:
            errors.append(f"Test1 miso: expected 0xA5, got 0x{miso:02X}")
        print(f"Test1 – MOSI→rx_byte: 0x{rx:02X}  MISO←tx_byte: 0x{miso:02X}")

        await ctx.tick("capture").repeat(HALF)

        # ── Test 2: Two-byte transaction; second byte supplied mid-frame ────
        ctx.set(dut.spi_cs_n, 0)
        ctx.set(dut.tx_byte, 0xBE)          # first response byte
        await ctx.tick("capture").repeat(HALF)

        # Pass next_tx_byte=0xEF so it is set after the last falling edge of
        # byte 0; the DUT re-arms on the following rising edge and picks it up
        # as the second response byte.
        miso0 = await spi_send_byte(ctx, 0x00, next_tx_byte=0xEF)
        # rx_byte still holds 0x37 from Test 1 unless byte 0 was received, so
        # this read distinguishes "byte 0 captured" from "nothing captured".
        rx0 = ctx.get(dut.rx_byte)
        miso1 = await spi_send_byte(ctx, 0xFF)

        await ctx.tick("capture").repeat(2)
        rx1 = ctx.get(dut.rx_byte)

        ctx.set(dut.spi_cs_n, 1)
        await ctx.tick("capture").repeat(HALF)

        if miso0 != 0xBE:
            errors.append(f"Test2 miso0: expected 0xBE, got 0x{miso0:02X}")
        if miso1 != 0xEF:
            errors.append(f"Test2 miso1: expected 0xEF, got 0x{miso1:02X}")
        if rx0 != 0x00:
            errors.append(f"Test2 rx0: expected 0x00, got 0x{rx0:02X}")
        if rx1 != 0xFF:
            errors.append(f"Test2 rx1: expected 0xFF, got 0x{rx1:02X}")
        print(f"Test2 – byte0 MISO: 0x{miso0:02X}  byte1 MISO: 0x{miso1:02X}  rx1: 0x{rx1:02X}")

        await ctx.tick("capture").repeat(HALF)

        # ── Test 3: MISO idles HIGH when CS deasserted ─────────────────────
        miso_idle = ctx.get(dut.spi_miso)
        if miso_idle != 1:
            errors.append(f"Test3 MISO idle: expected 1, got {miso_idle}")
        print(f"Test3 – MISO idle (CS=1): {miso_idle}")

        # ── Test 4: All-zeros byte (0x00) TX and RX ────────────────────────
        ctx.set(dut.spi_cs_n, 0)
        ctx.set(dut.tx_byte, 0x00)
        await ctx.tick("capture").repeat(HALF)

        miso = await spi_send_byte(ctx, 0xFF)
        await ctx.tick("capture").repeat(2)
        rx = ctx.get(dut.rx_byte)
        ctx.set(dut.spi_cs_n, 1)
        await ctx.tick("capture").repeat(HALF)

        if miso != 0x00:
            errors.append(f"Test4 miso: expected 0x00, got 0x{miso:02X}")
        if rx != 0xFF:
            errors.append(f"Test4 rx: expected 0xFF, got 0x{rx:02X}")
        print(f"Test4 – 0x00 TX / 0xFF RX: MISO=0x{miso:02X}  rx=0x{rx:02X}")

    sim = Simulator(dut)
    sim.add_clock(Period(MHz=54), domain="capture")
    sim.add_testbench(testbench)

    with sim.write_vcd("SPIMode3Slave.vcd"):
        sim.run()

    if errors:
        print("\nFAILURES:")
        for e in errors:
            print("  ", e)
        raise SystemExit(1)
    else:
        print("\nAll tests passed.")
|