from matplotlib import pyplot as plt

from vortex import Range
from vortex.scan import RasterScanConfig, RasterScan, inactive_policy
from vortex_tools.scan import plot_annotated_waveforms_space, partition_segments_by_activity

# Build a 2x2 grid of equal-aspect axes that share both the x and y axes, so
# every variant is drawn on the same scale.
fig, axs = plt.subplots(
    nrows=2,
    ncols=2,
    sharex=True,
    sharey=True,
    constrained_layout=True,
    subplot_kw=dict(adjustable='box', aspect='equal'),
)

# Template raster scan configuration; each variant below starts from a copy.
template = RasterScanConfig()
template.segment_extent = Range.symmetric(1)
template.volume_extent = Range.symmetric(2)
template.segments_per_volume = 6
template.samples_per_segment = 50
# Scale the acceleration of every limit in the default configuration by 5.
for axis_limit in template.limits:
    axis_limit.acceleration *= 5
template.loop = True

# One entry per plotted variant: (panel title, policy factory, factory args).
# NOTE(review): the meaning of the two 200s passed to FixedDynamicLimited is
# not visible here -- confirm against the vortex documentation.
variants = (
    ('Minimum Dynamic Limited', inactive_policy.MinimumDynamicLimited, ()),
    ('Fixed Dynamic Limited', inactive_policy.FixedDynamicLimited, (200, 200)),
    ('Fixed Linear', inactive_policy.FixedLinear, ()),
)

# Parallel lists of panel titles and their per-variant configurations.
names = []
cfgs = []
for label, make_policy, policy_args in variants:
    names.append(label)
    variant_cfg = template.copy()
    cfgs.append(variant_cfg)
    variant_cfg.inactive_policy = make_policy(*policy_args)

# Flatten the axes grid once so it can be both zipped and sliced below.
panels = axs.ravel()

# Generate each scan and draw its waveforms in its own panel.
for title, scan_cfg, panel in zip(names, cfgs, panels):
    scan = RasterScan()
    scan.initialize(scan_cfg)
    plot_annotated_waveforms_space(
        scan.scan_buffer(),
        scan.scan_markers(),
        inactive_marker=None,
        scan_line='w-',
        axes=panel,
    )
    panel.set_title(title)

# Drop any grid cell that did not receive a variant (the grid has four cells
# but only three variants are defined).
for leftover in panels[len(names):]:
    fig.delaxes(leftover)